diff --git a/.gitignore b/.gitignore index 828bbe9bd3363853ae3f58f54a8d5f60cefad837..b5306b8b79c37166e5496cf17a3e39b86b9a6314 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ __pycache__ cmake_build/ .idea/** /build/ +[Bb]uild/ /tensorflow/core/util/version_info.cc /tensorflow/python/framework/fast_tensor_util.cpp Pods diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8669c25c452b53da48239bc20c9a2d3528e75422..f598999f351c10f8bd01dfbd3ad8897f19d570e8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -90,7 +90,7 @@ Bazel BUILD files also need to include a license section, e.g., Changes to TensorFlow C++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). -Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do: +Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do: ```bash apt-get install -y clang-tidy @@ -107,7 +107,7 @@ diff /tmp/my_cc_file.cc #### Python coding style Changes to TensorFlow Python code should conform to -[Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) +[Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md) Use `pylint` to check your Python changes. To install `pylint` and retrieve TensorFlow's custom style definition: diff --git a/README.md b/README.md index 6fb4486d0de9ff476b5cf1dbd63d66879637df84..05fcb23f7edd657f2ea495d848fadc226e56b524 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow between them. This flexible architecture enables you to deploy computation to one or more CPUs or GPUs in a desktop, server, or mobile device without rewriting -code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), a data visualization toolkit. +code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit. TensorFlow was originally developed by researchers and engineers working on the Google Brain team within Google's Machine Intelligence Research @@ -56,6 +56,7 @@ $ python 42 >>> sess.close() ``` +Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). ## Contribution guidelines @@ -95,6 +96,8 @@ The TensorFlow project strives to abide by generally accepted best practices in | --- | --- | --- | | **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | +| **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | +| **Linux CPU with Intel® MKL-DNN®** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA | ## For more information diff --git a/RELEASE.md b/RELEASE.md index 84d9d52868ecd55d38d6073315749d11c2340e8c..4b0339442768afbd97ac21323bb0351eea13a6ca 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,82 @@ +# Release 1.9.0 + +## Major Features And Improvements +* Updated docs for `tf.keras`: New Keras-based [get started](http://tensorflow.org/versions/r1.9/get_started), + and [programmers guide page](http://tensorflow.org/versions/r1.9/programmers_guide/keras). +* Update `tf.keras` to the Keras 2.1.6 API. +* Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082). +* Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees). +* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/lite) + for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md) + has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again + included in the standard `pip` installation. +* Improved data-loading and text processing with: + * [`tf.decode_compressed`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/decode_compressed) + * [`tf.string_strip`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/string_strip) + * [`tf.strings.regex_full_match`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/strings/regex_full_match) +* Added experimental support for new pre-made Estimators: + * [`tf.contrib.estimator.BaselineEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/BaselineEstimator) + * [`tf.contrib.estimator.RNNClassifier`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNEstimator) + * [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier) +* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector) + API supports broadcasting for Bijectors with new API changes. + +## Breaking Chances + * If you're opening empty variable scopes; replace `variable_scope('', ...)` by + `variable_scope(tf.get_variable_scope(), ...)`. + * Headers used for building custom ops have been moved from site-packages/external into site-packages/tensorflow/include/external. + +## Bug Fixes and Other Changes + +* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. +* Layered variable names have changed in the following conditions: + * Using `tf.keras.layers` with custom variable scopes. + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See + [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details +* `tf.data`: + * The `DatasetBase::DebugString()` method is now `const`. + * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets. +* Eager Execution: +* `tf.keras`: + * Move Keras code out of _impl folder and remove API files. + * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. + * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. +* Accelerated Linear Algebra (XLA): +* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). +* `tf.contrib`: + * Add `tf.contrib.data.choose_from_datasets()`. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`. + * `tf.contrib.framework.zero_initializer` supports ResourceVariable. + * Adding "constrained_optimization" to tensorflow/contrib. +* Other: + * Add GCS Configuration Ops. + * Changing signature of `MakeIterator` to enable propagating error status. + * KL divergence for two Dirichlet distributions. + * More consistent GcsFileSystem behavior for certain reads past EOF. + * Update benchmark for tf.scan to match ranges across eager and graph modes. + * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. + * Add optional `args` argument to `Dataset.from_generator()`. + * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). + * Benchmark for tf.scan in graph and eager modes. + * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. + * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch. + * Support indicator column in boosted trees. + * Prevent `tf.gradients()` from backpropagating through integer tensors. + * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. + * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. + * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. + * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. + * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. + * Allow LinearOperator to broadcast. + * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. + + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Abdullah Alrasheed, Achal Shah, Ad-530, ADiegoCAlonso, Aditya Yogi, Ag Ramesh, akindyakov, Andy Kernahan, Anya Petrova, Aurelien Geron, Ben, Ben Barsdell, Bhavani-Subramanian, braincodercn, Brett Koonce, Brian Nemsick, Brian Zier, Bryan Heden, candy.dc, cclauss, Clayne Robison, ctiijima, Dalmo Cirne, David Norman, David T.H. Kao, DosLin, ekelsen, Elson Rodriguez, Erik Smistad, Felix Abecassis, Fergal Cotter, fo40225, foo0x29a, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, gdh1995, Geoffrey Irving, Giuseppe, gracehoney, Guido Zuidhof, Guillaume Klein, Guozhong Zhuang, Haggai, Harald Husum, imsheridan, Ivan Zhang, Jan Zikes, Jayaram Bobba, Jesse Benson, Jesse Gumz, Jiajia Li, Jie, jinghuangintel, Jingwen, jjsjann123, Joe Yearsley, Joel Hestness, Joel Shor, josephyearsley, Junpeng Lao, Karol M. Langner, Kb Sriram, krantideep95, Krish Ravindranath, Letian Feng, Loo Rong Jie, Lukas Geiger, Maciej, Mahmoud Abuzaina, ManHyuk, Mark Ryan, mbhuiyan, Michal Turek, Mostafa Alaa, Myungsung Kwak, Nand Dalal, Nehal J Wani, Neil Tenenholtz, ngc92, Nicholas Nadeau, P.Eng., Avs, Niranjan Hasabnis, P-Hidringer, Paul Van Eck, Peng Yu, Qing Zhao, Qingying Chen, Quanlong, Rajendra Arora, Rholais Lii, rmanyari, Robin Richtsfeld, Russell Klopfer, Sagi, Sam Sendelbach, Sandeep N Gupta, Sandip Giri, Sarah Edkins, Scott Tseng, Sdalbsoo, Sergii Khomenko, Seungwoo Choi (Biggie), Seyed Majid Azimi, Shaoning Zeng, shengfuintel, Siu Kei, Muk, Smit Shilu, soonson, Stefan Schweter, Sukhwan Kim, Sunitha Kambhampati, Taehoon Lee, tamimaddari82, Tang, Wenyi, Ted Chang, u2takey, Utkarsh Upadhyay, Vadim Markovtsev, voegtlel, Wai Hon Law, wangsiyu, Wenhao Hu, wenhao.hu, William D. Irons, Yan Facai (颜发才), Yanbo Liang, Yihong Wang, Yilei (Dolee) Yang, Yong Tang, Yuan (Terry) Tang + # Release 1.8.0 ## Major Features And Improvements @@ -406,15 +485,7 @@ answered questions, and were part of inspiring discussions. ## Major Features And Improvements * `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of - the core TensorFlow API. - * The API is now subject to backwards compatibility guarantees. - -# Release 1.4.0 - -## Major Features And Improvements -* `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of +* [`tf.data`](http://tensorflow.org/guide/datasets) is now part of the core TensorFlow API. * The API is now subject to backwards compatibility guarantees. * For a guide to migrating from the `tf.contrib.data` API, see the @@ -434,7 +505,7 @@ answered questions, and were part of inspiring discussions. * TensorFlow Debugger (tfdbg): * Add `eval` command to allow evaluation of arbitrary Python/numpy expressions in tfdbg command-line interface. See - [Debugging TensorFlow Programs](https://www.tensorflow.org/programmers_guide/debugger) + [Debugging TensorFlow Programs](https://www.tensorflow.org/guide/debugger) for more details. * Usability improvement: The frequently used tensor filter `has_inf_or_nan` is now added to `Session` wrappers and hooks by default. So there is no need @@ -721,7 +792,7 @@ answered questions, and were part of inspiring discussions. * Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters. * TensorFlow C library now available for Windows. * We released a new open-source version of TensorBoard. -* [`SavedModel CLI`](https://www.tensorflow.org/versions/master/programmers_guide/saved_model_cli) tool available to inspect and execute MetaGraph in SavedModel +* [`SavedModel CLI`](https://www.tensorflow.org/versions/master/guide/saved_model_cli) tool available to inspect and execute MetaGraph in SavedModel * Android releases of TensorFlow are now pushed to jcenter for easier integration into apps. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/android/README.md diff --git a/SECURITY.md b/SECURITY.md index 01886b613e5d93793953124331b57f075fe7a373..0b52fdc7ab84b7bd5bce5d247ede81b40699005c 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -168,7 +168,7 @@ below). Please use a descriptive subject line for your report email. After the initial reply to your report, the security team will endeavor to keep you informed of -the progress being made towards a fix and announcement. +the progress being made towards a fix and announcement. In addition, please include the following information along with your report: @@ -242,9 +242,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= -----END PGP PUBLIC KEY BLOCK----- ``` -### Known vulnerabilities - -| Type | Versions affected | Reported by | Additional Information | -|--------------------|:-----------------:|-----------------------|-----------------------------| -| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +### Known Vulnerabilities +For a list of known vulnerabilities and security advisories for TensorFlow, +[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md). diff --git a/WORKSPACE b/WORKSPACE index 4ddfb9a3832ea1ea639ace887e1d601bdd857086..fd7570a80ae2ee0087f7d2fd771fcce5b9690028 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -22,26 +22,10 @@ check_bazel_version_at_least("0.10.0") load("//tensorflow:workspace.bzl", "tf_workspace") -# Uncomment and update the paths in these entries to build the Android demo. -#android_sdk_repository( -# name = "androidsdk", -# api_level = 23, -# # Ensure that you have the build_tools_version below installed in the -# # SDK manager as it updates periodically. -# build_tools_version = "26.0.1", -# # Replace with path to Android SDK on your system -# path = "", -#) -# -#android_ndk_repository( -# name="androidndk", -# path="", -# # This needs to be 14 or higher to compile TensorFlow. -# # Please specify API level to >= 21 to build for 64-bit -# # archtectures or the Android NDK will automatically select biggest -# # API level that it supports without notice. -# # Note that the NDK version is not the API level. -# api_level=14) +load("//third_party/android:android_configure.bzl", "android_configure") +android_configure(name="local_config_android") +load("@local_config_android//:android.bzl", "android_workspace") +android_workspace() # Please add all new TensorFlow dependencies in workspace.bzl. tf_workspace() diff --git a/configure.py b/configure.py index 6d9aba61bbc73ba1b80321d6859877c371dc5427..31a83b4a1589b7f038bcdde5cec9007cd16b261c 100644 --- a/configure.py +++ b/configure.py @@ -498,10 +498,6 @@ def set_cc_opt_flags(environ_cp): if not is_ppc64le() and not is_windows(): write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') - # TODO(mikecase): Remove these default defines once we are able to get - # TF Lite targets building without them. - write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') - write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -674,8 +670,9 @@ def create_android_ndk_rule(environ_cp): error_msg=('The path %s or its child file "source.properties" ' 'does not exist.') ) - - write_android_ndk_workspace_rule(android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL', + check_ndk_level(android_ndk_home_path)) def create_android_sdk_rule(environ_cp): @@ -737,41 +734,12 @@ def create_android_sdk_rule(environ_cp): error_msg=('The selected SDK does not have build-tools version %s ' 'available.')) - write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level) - - -def write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level): - print('Writing android_sdk_workspace rule.\n') - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_sdk_repository( - name="androidsdk", - api_level=%s, - path="%s", - build_tools_version="%s")\n -""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) - - -def write_android_ndk_workspace_rule(android_ndk_home_path): - print('Writing android_ndk_workspace rule.') - ndk_api_level = check_ndk_level(android_ndk_home_path) - if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: - print('WARNING: The API level of the NDK in %s is %s, which is not ' - 'supported by Bazel (officially supported versions: %s). Please use ' - 'another version. Compiling Android targets may result in confusing ' - 'errors.\n' % (android_ndk_home_path, ndk_api_level, - _SUPPORTED_ANDROID_NDK_VERSIONS)) - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_ndk_repository( - name="androidndk", - path="%s", - api_level=%s)\n -""" % (android_ndk_home_path, ndk_api_level)) + write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION', + android_build_tools_version) + write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', + android_api_level) + write_action_env_to_bazelrc('ANDROID_SDK_HOME', + android_sdk_home_path) def check_ndk_level(android_ndk_home_path): @@ -784,18 +752,16 @@ def check_ndk_level(android_ndk_home_path): revision = re.search(r'Pkg.Revision = (\d+)', filedata) if revision: - return revision.group(1) - return None - - -def workspace_has_any_android_rule(): - """Check the WORKSPACE for existing android_*_repository rules.""" - with open(_TF_WORKSPACE, 'r') as f: - workspace = f.read() - has_any_rule = re.search(r'^android_[ns]dk_repository', - workspace, - re.MULTILINE) - return has_any_rule + ndk_api_level = revision.group(1) + else: + raise Exception('Unable to parse NDK revision.') + if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_api_level, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + return ndk_api_level def set_gcc_host_compiler_path(environ_cp): @@ -977,6 +943,35 @@ def set_tf_cudnn_version(environ_cp): write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) +def is_cuda_compatible(lib, cuda_ver, cudnn_ver): + """Check compatibility between given library and cudnn/cudart libraries.""" + ldd_bin = which('ldd') or '/usr/bin/ldd' + ldd_out = run_shell([ldd_bin, lib], True) + ldd_out = ldd_out.split(os.linesep) + cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') + cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') + cudnn = None + cudart = None + cudnn_ok = True # assume no cudnn dependency by default + cuda_ok = True # assume no cuda dependency by default + for line in ldd_out: + if 'libcudnn.so' in line: + cudnn = cudnn_pattern.search(line) + cudnn_ok = False + elif 'libcudart.so' in line: + cudart = cuda_pattern.search(line) + cuda_ok = False + if cudnn and len(cudnn.group(1)): + cudnn = convert_version_to_int(cudnn.group(1)) + if cudart and len(cudart.group(1)): + cudart = convert_version_to_int(cudart.group(1)) + if cudnn is not None: + cudnn_ok = (cudnn == cudnn_ver) + if cudart is not None: + cuda_ok = (cudart == cuda_ver) + return cudnn_ok and cuda_ok + + def set_tf_tensorrt_install_path(environ_cp): """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION. @@ -993,8 +988,8 @@ def set_tf_tensorrt_install_path(environ_cp): raise ValueError('Currently TensorRT is only supported on Linux platform.') # Ask user whether to add TensorRT support. - if str(int(get_var( - environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1': + if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', + False))) != '1': return for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): @@ -1007,47 +1002,29 @@ def set_tf_tensorrt_install_path(environ_cp): # Result returned from "read" will be used unexpanded. That make "~" # unusable. Going through one more level of expansion to handle that. - trt_install_path = os.path.realpath( - os.path.expanduser(trt_install_path)) + trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path)) def find_libs(search_path): """Search for libnvinfer.so in "search_path".""" fl = set() if os.path.exists(search_path) and os.path.isdir(search_path): - fl.update([os.path.realpath(os.path.join(search_path, x)) - for x in os.listdir(search_path) if 'libnvinfer.so' in x]) + fl.update([ + os.path.realpath(os.path.join(search_path, x)) + for x in os.listdir(search_path) + if 'libnvinfer.so' in x + ]) return fl possible_files = find_libs(trt_install_path) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib'))) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64'))) - - def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver): - """Check the compatibility between tensorrt and cudnn/cudart libraries.""" - ldd_bin = which('ldd') or '/usr/bin/ldd' - ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep) - cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') - cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') - cudnn = None - cudart = None - for line in ldd_out: - if 'libcudnn.so' in line: - cudnn = cudnn_pattern.search(line) - elif 'libcudart.so' in line: - cudart = cuda_pattern.search(line) - if cudnn and len(cudnn.group(1)): - cudnn = convert_version_to_int(cudnn.group(1)) - if cudart and len(cudart.group(1)): - cudart = convert_version_to_int(cudart.group(1)) - return (cudnn == cudnn_ver) and (cudart == cuda_ver) - cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION']) cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$') highest_ver = [0, None, None] for lib_file in possible_files: - if is_compatible(lib_file, cuda_ver, cudnn_ver): + if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver): matches = nvinfer_pattern.search(lib_file) if len(matches.groups()) == 0: continue @@ -1063,12 +1040,13 @@ def set_tf_tensorrt_install_path(environ_cp): # Try another alternative from ldconfig. ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' ldconfig_output = run_shell([ldconfig_bin, '-p']) - search_result = re.search( - '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output) + search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)', + ldconfig_output) if search_result: libnvinfer_path_from_ldconfig = search_result.group(2) if os.path.exists(libnvinfer_path_from_ldconfig): - if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver): + if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver, + cudnn_ver): trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) tf_tensorrt_version = search_result.group(1) break @@ -1156,7 +1134,9 @@ def set_tf_nccl_install_path(environ_cp): nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') - if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + nccl_license_path = os.path.join(nccl_install_path, 'NCCL-SLA.txt') + if os.path.exists(nccl_lib_path) and os.path.exists( + nccl_hdr_path) and os.path.exists(nccl_license_path): # Set NCCL_INSTALL_PATH environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) @@ -1227,7 +1207,7 @@ def set_tf_cuda_compute_capabilities(environ_cp): # Check whether all capabilities from the input is valid all_valid = True # Remove all whitespace characters before splitting the string - # that users may insert by accident, as this will result in error + # that users may insert by accident, as this will result in error tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) @@ -1431,6 +1411,10 @@ def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') +def set_build_strip_flag(): + write_to_bazelrc('build --strip=always') + + def set_windows_build_flags(): if is_windows(): # The non-monolithic build is not supported yet @@ -1465,7 +1449,7 @@ def main(): setup_python(environ_cp) if is_windows(): - environ_cp['TF_NEED_S3'] = '0' + environ_cp['TF_NEED_AWS'] = '0' environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' @@ -1489,8 +1473,8 @@ def main(): 'with_gcp_support', True, 'gcp') set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', 'with_hdfs_support', True, 'hdfs') - set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', - 'with_s3_support', True, 's3') + set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform', + 'with_aws_support', True, 'aws') set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', 'with_kafka_support', True, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', @@ -1553,23 +1537,18 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) + set_build_strip_flag() set_windows_build_flags() - if workspace_has_any_android_rule(): - print('The WORKSPACE file has at least one of ["android_sdk_repository", ' - '"android_ndk_repository"] already set. Will not ask to help ' - 'configure the WORKSPACE. Please delete the existing rules to ' - 'activate the helper.\n') - else: - if get_var( - environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', - False, - ('Would you like to interactively configure ./WORKSPACE for ' - 'Android builds?'), - 'Searching for NDK and SDK installations.', - 'Not configuring the WORKSPACE for Android builds.'): - create_android_ndk_rule(environ_cp) - create_android_sdk_rule(environ_cp) + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. See tools/bazel.rc for ' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f2ad16fa04f5beb6616c58c28d0f0c460c3e3a17..51eea94847e47ac3ffee89ed6bbae269b7b92c77 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -19,6 +19,10 @@ load( "//tensorflow/core:platform/default/build_config.bzl", "tf_additional_binary_deps", ) +load( + "//tensorflow/tools/api/generator:api_gen.bzl", + "gen_api_init_files", # @unused +) # Config setting for determining if we are building for Android. config_setting( @@ -150,6 +154,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "linux_s390x", + values = {"cpu": "s390x"}, + visibility = ["//visibility:public"], +) + config_setting( name = "debug", values = { @@ -206,8 +216,8 @@ config_setting( ) config_setting( - name = "with_s3_support", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support", + define_values = {"with_aws_support": "true"}, visibility = ["//visibility:public"], ) @@ -234,8 +244,8 @@ config_setting( ) config_setting( - name = "with_s3_support_windows_override", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support_windows_override", + define_values = {"with_aws_support": "true"}, values = {"cpu": "x64_windows"}, visibility = ["//visibility:public"], ) @@ -247,6 +257,13 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_cuda_support_windows_override", + define_values = {"using_cuda_nvcc": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gcp_support_android_override", define_values = {"with_gcp_support": "true"}, @@ -262,8 +279,8 @@ config_setting( ) config_setting( - name = "with_s3_support_android_override", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support_android_override", + define_values = {"with_aws_support": "true"}, values = {"crosstool_top": "//external:android/crosstool"}, visibility = ["//visibility:public"], ) @@ -283,8 +300,8 @@ config_setting( ) config_setting( - name = "with_s3_support_ios_override", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support_ios_override", + define_values = {"with_aws_support": "true"}, values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, visibility = ["//visibility:public"], ) @@ -394,6 +411,7 @@ config_setting( package_group( name = "internal", packages = [ + "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", "//tensorflow_fold/llgtm/...", @@ -420,6 +438,22 @@ filegroup( data = glob(["docs_src/**/*.md"]), ) +cc_library( + name = "grpc", + deps = select({ + ":linux_s390x": ["@grpc//:grpc_unsecure"], + "//conditions:default": ["@grpc"], + }), +) + +cc_library( + name = "grpc++", + deps = select({ + ":linux_s390x": ["@grpc//:grpc++_unsecure"], + "//conditions:default": ["@grpc//:grpc++"], + }), +) + # A shared object which includes registration mechanisms for ops and # kernels. Does not include the implementations of any ops or kernels. Instead, # the library which loads libtensorflow_framework.so @@ -447,6 +481,15 @@ filegroup( tf_cc_shared_object( name = "libtensorflow_framework.so", framework_so = [], + linkopts = select({ + "//tensorflow:darwin": [], + "//tensorflow:windows": [], + "//tensorflow:windows_msvc": [], + "//conditions:default": [ + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "$(location //tensorflow:tf_framework_version_script.lds)", + ], + }), linkstatic = 1, visibility = ["//visibility:public"], deps = [ @@ -456,6 +499,7 @@ tf_cc_shared_object( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", "//tensorflow/core:lib_internal_impl", "//tensorflow/stream_executor:stream_executor_impl", + "//tensorflow:tf_framework_version_script.lds", ] + tf_additional_binary_deps(), ) @@ -471,7 +515,7 @@ tf_cc_shared_object( # excludes all but a subset of function names. # On MacOS, the linker does not support version_script, but has an # an "-exported_symbols_list" command. -z defs disallows undefined -# symbols in object files and -s strips the output. +# symbols in object files. tf_cc_shared_object( name = "libtensorflow.so", @@ -485,7 +529,6 @@ tf_cc_shared_object( "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow/c:version_script.lds)", ], @@ -511,7 +554,6 @@ tf_cc_shared_object( "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow:tf_version_script.lds)", ], @@ -536,13 +578,28 @@ exports_files( ], ) +gen_api_init_files( + name = "tensorflow_python_api_gen", + srcs = ["api_template.__init__.py"], + root_init_template = "api_template.__init__.py", +) + py_library( name = "tensorflow_py", - srcs = ["__init__.py"], + srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/python", - "//tensorflow/tools/api/generator:python_api", + ":tensorflow_py_no_contrib", + "//tensorflow/contrib:contrib_py", + "//tensorflow/python/estimator:estimator_py", ], ) + +py_library( + name = "tensorflow_py_no_contrib", + srcs = [":tensorflow_python_api_gen"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = ["//tensorflow/python:no_contrib"], +) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index c8683e3976c90add3f1f54d8e575c798327e9273..440e9f8dbd2f4b2a2ab78eaaf26408584e7c1446 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -22,9 +22,6 @@ from __future__ import print_function # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import -# pylint: disable=wildcard-import -from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin -# pylint: enable=wildcard-import from tensorflow.python.util.lazy_loader import LazyLoader contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..779f65d5b17c350833f67f07985b00e8eb561e72 --- /dev/null +++ b/tensorflow/api_template.__init__.py @@ -0,0 +1,59 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import + +try: + import os # pylint: disable=g-import-not-at-top + # Add `estimator` attribute to allow access to estimator APIs via + # "tf.estimator..." + from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top + + # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." + # style imports. + from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top + __path__ += [os.path.dirname(estimator_api.__file__)] + del estimator_api + del os +except (ImportError, AttributeError): + print('tf.estimator package not installed.') + +# API IMPORTS PLACEHOLDER + +from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader + +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + +del absolute_import +del division +del print_function + +# These symbols appear because we import the python package which +# in turn imports from tensorflow.core and tensorflow.python. They +# must come from this module. So python adds these symbols for the +# resolution to succeed. +# pylint: disable=undefined-variable +del python +del core +# pylint: enable=undefined-variable diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index b86b277ac3200b88ae03490a6c1b64d464e81950..5c218d3f25e01f0e78916d4a5a8b1d2751f9dc25 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -390,64 +391,6 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, status->status = Reset(opt->options, container_names); } -// This traverses the specified nodes in topological order to verify there are -// no cycles. Starting with inputless nodes, it visits nodes whose inputs have -// all been visited, and counts the total number of visited nodes. If there is a -// cycle, nodes in the cycle will never be visited, and the visited count will -// be less than the total node count. -Status ValidateNoCycles(const Graph& g) { - // TODO(nolivia): check this on a subset of the graph instead of all of it. - // A node is ready when all of its inputs have been visited. - std::vector ready; - std::vector pending_count(g.num_node_ids(), 0); - - for (int i = 0; i < g.num_node_ids(); ++i) { - const Node* n = g.FindNodeId(i); - if (n == nullptr) continue; - pending_count[i] = n->in_edges().size(); - if (n->IsMerge()) { - // While-loop cycles are legal cycles so we manually adjust the - // pending_count to make sure that the loop is visited. - for (const Edge* e : n->in_edges()) { - if (!e->IsControlEdge() && e->src()->IsNextIteration()) { - pending_count[i]--; - } - } - } - if (pending_count[i] == 0) { - ready.push_back(n); - } - } - - int processed = 0; - while (!ready.empty()) { - const Node* node = ready.back(); - ready.pop_back(); - ++processed; - - for (const Edge* out : node->out_edges()) { - const int output_id = out->dst()->id(); - pending_count[output_id]--; - if (pending_count[output_id] == 0) { - ready.push_back(out->dst()); - } - } - } - - if (processed < g.num_nodes()) { - std::vector nodes_in_cycle; - for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; - ++i) { - if (pending_count[i] != 0) { - nodes_in_cycle.push_back(g.FindNodeId(i)->name()); - } - } - return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, - " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); - } - return Status::OK(); -} } // namespace } // namespace tensorflow @@ -631,7 +574,22 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, "Failed to allocate memory to serialize message of type '", in.GetTypeName(), "' and size ", proto_size); } - in.SerializeToArray(buf, proto_size); + // SerializeToArray takes size as an int. + // This next 'if' is a workaround till we update to depend on a version + // of protocol buffers that includes + // https://github.com/google/protobuf/pull/4739 + if (proto_size > std::numeric_limits::max()) { + return InvalidArgument("Cannot serialize protocol buffer of type ", + in.GetTypeName(), " as the serialized size (", + proto_size, + "bytes) would be larger than the limit (", + std::numeric_limits::max(), " bytes)"); + } + if (!in.SerializeToArray(buf, proto_size)) { + return InvalidArgument("Unable to serialize ", in.GetTypeName(), + " protocol buffer, perhaps the serialized size (", + proto_size, " bytes) is too large?"); + } out->data = buf; out->length = proto_size; out->data_deallocator = [](void* data, size_t length) { @@ -731,7 +689,9 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { - status->status = tensorflow::ValidateNoCycles(session->graph->graph); + // TODO(nolivia): check this on a subset of the graph instead of all of + // it. + status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); if (!status->status.ok()) { session->graph->mu.unlock(); return false; @@ -2108,7 +2068,8 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status) { GraphDef def; - if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return nullptr; } @@ -2138,7 +2099,8 @@ void TF_GraphImportGraphDefWithReturnOutputs( return; } GraphDef def; - if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return; } @@ -2454,7 +2416,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { Node* n = g->graph.FindNodeId(i); if (n == nullptr) continue; - g->name_map[n->name()] = n; + // We have a convoluted scheme here: Using the C++ graph construction API + // to add potentially many nodes to the graph without running the checks + // (such as uniqueness of the names of nodes) we run with other functions + // that add a node to the graph (like TF_FinishOperation). + if (!g->name_map.insert(std::make_pair(n->name(), n)).second) { + status->status = tensorflow::errors::Internal( + "BUG: The API allowed construction of a graph with duplicate node " + "names (", + n->name(), + "). This is a bug. Please file an issue at " + "https://github.com/tensorflow/tensorflow/issues."); + } } } diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index c8594347451dffd465d7fa926cc53818dc9e38d4..1eb75ef11ff337dfcb2e016e09804fc04662fcda 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -894,7 +894,8 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_ImportGraphDefOptions* opts); // Set the prefix to be prepended to the names of nodes in `graph_def` that will -// be imported into `graph`. +// be imported into `graph`. `prefix` is copied and has no lifetime +// requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( TF_ImportGraphDefOptions* opts, const char* prefix); @@ -915,6 +916,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. +// `src_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst); @@ -922,7 +924,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( // Set any imported nodes with control input `src_name` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references an operation already existing in the graph being imported -// into. +// into. `src_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); @@ -934,6 +936,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( // Add an output in `graph_def` to be returned via the `return_outputs` output // parameter of TF_GraphImportGraphDef(). If the output is remapped via an input // mapping, the corresponding existing tensor in `graph` will be returned. +// `oper_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( TF_ImportGraphDefOptions* opts, const char* oper_name, int index); @@ -943,7 +946,8 @@ TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( const TF_ImportGraphDefOptions* opts); // Add an operation in `graph_def` to be returned via the `return_opers` output -// parameter of TF_GraphImportGraphDef(). +// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no +// lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( TF_ImportGraphDefOptions* opts, const char* oper_name); diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 95b04f9058afdfaadbc24f0238860279fcd3e800..170046c8024dc85c899108b254cd3a95a3be4096 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -57,6 +57,33 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { } } +TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, + unsigned char gpu_memory_allow_growth) { + tensorflow::ConfigProto config; + auto* optimizer_options = + config.mutable_graph_options()->mutable_optimizer_options(); + if (enable_xla_compilation) { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1); + + // These XLA flags are needed to trigger XLA properly from C (more generally + // non-Python) clients. If this API is called again with `enable` set to + // false, it is safe to keep these flag values as is. + tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = + tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + flags->tf_xla_cpu_global_jit = true; + flags->tf_xla_min_cluster_size = 1; + } else { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF); + } + + auto* gpu_options = config.mutable_gpu_options(); + gpu_options->set_allow_growth(gpu_memory_allow_growth); + + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(config, ret)); + return ret; +} + const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { tensorflow::mutex_lock c(graph->mu); const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString(); diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 20bdace40f1272ded06e710034053a7610326e7f..2d81c01e0dd056e9beb3b45f24809381554a7924 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -55,11 +55,21 @@ extern "C" { // set XLA flag values to prepare for XLA compilation. Otherwise set // global_jit_level to OFF. // -// This API is syntax sugar over TF_SetConfig(), and is used by clients that -// cannot read/write the tensorflow.ConfigProto proto. +// This and the next API are syntax sugar over TF_SetConfig(), and is used by +// clients that cannot read/write the tensorflow.ConfigProto proto. +// TODO: Migrate to TF_CreateConfig() below. TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable); +// Create a serialized tensorflow.ConfigProto proto, where: +// +// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if +// `enable_xla_compilation` is non-zero, and OFF otherwise. +// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`. +TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig( + unsigned char enable_xla_compilation, + unsigned char gpu_memory_allow_growth); + // Returns the graph content in a human-readable format, with length set in // `len`. The format is subject to change in the future. // The returned string is heap-allocated, and caller should call free() on it. diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 577f10c5e69ea9ecbe8ce821c6bd5167e98bef25..bc04b53fbb7fa9ba46228ae5a4ec8ee96df5f3dc 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -1160,7 +1160,7 @@ TEST(CAPI, GetOpDef) { } void StringVectorToArrays(const std::vector& v, - std::unique_ptr* ptrs, + std::unique_ptr* ptrs, std::unique_ptr* lens) { ptrs->reset(new const void*[v.size()]); lens->reset(new size_t[v.size()]); @@ -1196,7 +1196,7 @@ class CApiColocationTest : public ::testing::Test { void SetViaStringList(TF_OperationDescription* desc, const std::vector& list) { - std::unique_ptr list_ptrs; + std::unique_ptr list_ptrs; std::unique_ptr list_lens; StringVectorToArrays(list, &list_ptrs, &list_lens); TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(), @@ -1700,6 +1700,61 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } +void ScalarFloatFromTensor(const TF_Tensor* t, float* f) { + ASSERT_TRUE(t != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(t)); + ASSERT_EQ(0, TF_NumDims(t)); + ASSERT_EQ(4, TF_TensorByteSize(t)); + float* p = static_cast(TF_TensorData(t)); + *f = *p; +} + +TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) { + const float X = 3.0f, Y = 7.0f; + TF_Operation* x = Placeholder(graph_, s_, "x", TF_FLOAT); + TF_Operation* y = Placeholder(graph_, s_, "y", TF_FLOAT); + TF_Operation* xy = Mul(x, y, graph_, s_, "xy"); + TF_Output dxy_dx, dxy_dy; + + TF_Output outputs[1] = {{xy, 0}}; + TF_Output inputs[1] = {{x, 0}}; + TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + inputs[0] = {y, 0}; + TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_SessionOptions* opts = TF_NewSessionOptions(); + TF_Session* sess = TF_NewSession(graph_, opts, s_); + TF_DeleteSessionOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Output feeds[] = {{x, 0}, {y, 0}}; + TF_Tensor* feedValues[] = {FloatTensor(X), FloatTensor(Y)}; + TF_Output fetches[] = {dxy_dx, dxy_dy}; + TF_Tensor* fetchValues[] = {nullptr, nullptr}; + + TF_SessionRun(sess, nullptr /* run_options */, feeds, feedValues, 2, fetches, + fetchValues, 2, nullptr /* target_opers */, 0, + nullptr /* run_metadata */, s_); + TF_DeleteTensor(feedValues[0]); + TF_DeleteTensor(feedValues[1]); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_DeleteSession(sess, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + float dxy_dxValue = 0.0f, dxy_dyValue = 0.0f; + ScalarFloatFromTensor(fetchValues[0], &dxy_dxValue); + EXPECT_EQ(Y, dxy_dxValue); + + ScalarFloatFromTensor(fetchValues[1], &dxy_dyValue); + EXPECT_EQ(X, dxy_dyValue); + + TF_DeleteTensor(fetchValues[0]); + TF_DeleteTensor(fetchValues[1]); +} + // REGISTER_OP for CApiAttributesTest test cases. // Registers two ops, each with a single attribute called 'v'. // The attribute in one op will have a type 'type', the other @@ -1784,7 +1839,7 @@ TEST_F(CApiAttributesTest, String) { TEST_F(CApiAttributesTest, StringList) { std::vector list = {"bugs", "bunny", "duck"}; - std::unique_ptr list_ptrs; + std::unique_ptr list_ptrs; std::unique_ptr list_lens; StringVectorToArrays(list, &list_ptrs, &list_lens); int list_total_size = 0; @@ -1800,7 +1855,7 @@ TEST_F(CApiAttributesTest, StringList) { ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size); - std::unique_ptr values(new void*[list.size()]); + std::unique_ptr values(new void*[list.size()]); std::unique_ptr lens(new size_t[list.size()]); std::unique_ptr storage(new char[list_total_size]); TF_OperationGetAttrStringList(oper, "v", values.get(), lens.get(), @@ -2025,7 +2080,7 @@ TEST_F(CApiAttributesTest, TensorShapeProtoList) { tensorflow::PartialTensorShape(pts2).AsProto(&proto); proto.SerializeToString(&bytes2); - std::unique_ptr list_ptrs; + std::unique_ptr list_ptrs; std::unique_ptr list_lens; const std::vector list = {bytes1, bytes2}; StringVectorToArrays(list, &list_ptrs, &list_lens); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index f3b28c1708129d39e451d927a89c0d10e2193b63..24eb6c069b21349fce288db3e79fbf14e824ad11 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -216,6 +216,13 @@ TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, return MinWithDevice(l, r, graph, /*op_device=*/"", s, name); } +TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + TF_Operation* op; + BinaryOpHelper("Mul", l, r, graph, s, name, &op, "", true); + return op; +} + TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index c16aba666ee6974fed5351c2d9ac291dcbcdecab..38313d647ca93d4779bb1325f8ed7bde4b743879 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -80,6 +80,9 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "min"); +TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "mul"); + // If `op_device` is non-empty, set the created op on that device. TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, const string& op_device, TF_Status* s, diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 28f974c5d41327daa4565c62caf834b3c9519273..37be52f57d865c1e59611540d5dab04b59e89444 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -14,6 +14,7 @@ tf_cuda_library( name = "c_api", srcs = [ "c_api.cc", + "c_api_debug.cc", "c_api_internal.h", ], hdrs = ["c_api.h"], @@ -24,10 +25,10 @@ tf_cuda_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", @@ -45,6 +46,7 @@ tf_cuda_library( "//tensorflow:with_xla_support": [ "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:xla_device", ], "//conditions:default": [], }) + [ @@ -52,7 +54,6 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", @@ -70,7 +71,6 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = [ ":c_api", - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", @@ -80,6 +80,7 @@ tf_cuda_library( "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:eager_operation", @@ -91,71 +92,54 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", ], ) -tf_cuda_cc_test( - name = "c_api_test", - srcs = ["c_api_test.cc"], - extra_copts = tfe_xla_copts(), - tags = [ - "guitar", - "multi_gpu", +tf_cuda_library( + name = "c_api_test_util", + testonly = 1, + srcs = ["c_api_test_util.cc"], + hdrs = ["c_api_test_util.h"], + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", ], deps = [ ":c_api", "//tensorflow/c:c_test_util", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", ], ) -tf_cuda_library( - name = "runtime", - srcs = ["runtime.cc"], - hdrs = ["runtime.h"], - copts = tf_copts(), - visibility = ["//tensorflow:internal"], - deps = select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", - ], - "//conditions:default": [ - "//tensorflow/c:c_api", - "//tensorflow/core:core_cpu", - "//tensorflow/core/common_runtime/eager:kernel_and_device", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - ], - }), -) - -tf_cc_test( - name = "runtime_test", - srcs = ["runtime_test.cc"], +tf_cuda_cc_test( + name = "c_api_test", + size = "small", + srcs = [ + "c_api_debug_test.cc", + "c_api_test.cc", + ], + extra_copts = tfe_xla_copts(), + tags = [ + "guitar", + "multi_gpu", + ], deps = [ - ":runtime", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", + ":c_api", + ":c_api_test_util", + "//tensorflow/c:c_test_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 1c1020f812bfa4c95ebfa17aae7cc7a96d48588e..82ca2be2cff885967dd798a1cb84b164a9df399e 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -32,13 +31,14 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/node_def_util.h" @@ -46,10 +46,12 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -73,10 +75,6 @@ string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? "cpu:0" : d->name(); } -#ifdef TENSORFLOW_EAGER_USE_XLA -std::atomic_int_fast64_t func_id_generator(0); -#endif // TENSORFLOW_EAGER_USE_XLA - tensorflow::Status GetAllRemoteDevices( const std::vector& remote_workers, tensorflow::WorkerCacheInterface* worker_cache, @@ -111,7 +109,8 @@ tensorflow::Status GetAllRemoteDevices( } tensorflow::Status CreateRemoteContexts( - const std::vector& remote_workers, + const std::vector& remote_workers, int64 rendezvous_id, + const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -119,12 +118,14 @@ tensorflow::Status CreateRemoteContexts( tensorflow::eager::CreateContextRequest request; tensorflow::eager::CreateContextResponse response; + request.set_rendezvous_id(rendezvous_id); tensorflow::DeviceNameUtils::ParsedName parsed_name; if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, &parsed_name)) { return tensorflow::errors::InvalidArgument( "Unable to parse ", remote_worker, " as a device name"); } + *request.mutable_server_def() = server_def; request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); @@ -151,46 +152,82 @@ tensorflow::Status CreateRemoteContexts( tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, TFE_Context** ctx) { + // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the + // server object (which currently CHECK-fails) and we miss the error, instead, + // we log the error, and then return to allow the user to see the error + // message. +#define LOG_AND_RETURN_IF_ERROR(...) \ + do { \ + const ::tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + LOG(ERROR) << _status.error_message(); \ + return _status; \ + } \ + } while (0); + string worker_name = tensorflow::strings::StrCat( "/job:", opts->server_def.job_name(), "/replica:0/task:", opts->server_def.task_index()); - std::unique_ptr server; - TF_RETURN_IF_ERROR( - tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server)); - TF_RETURN_IF_ERROR(server->Start()); + std::unique_ptr server; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server)); + + tensorflow::GrpcServer* grpc_server = + dynamic_cast(server.get()); + if (grpc_server == nullptr) { + LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal( + "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); + } + + LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); + + int64 rendezvous_id = tensorflow::random::New64(); std::vector remote_workers; - server->master_env()->worker_cache->ListWorkers(&remote_workers); + grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers); remote_workers.erase( std::remove(remote_workers.begin(), remote_workers.end(), worker_name), remote_workers.end()); std::unique_ptr remote_device_mgr; - TF_RETURN_IF_ERROR(GetAllRemoteDevices( - remote_workers, server->master_env()->worker_cache, &remote_device_mgr)); + LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices( + remote_workers, grpc_server->master_env()->worker_cache, + &remote_device_mgr)); std::shared_ptr channel_cache = - server->channel_cache(); + grpc_server->channel_cache(); std::unique_ptr remote_eager_workers( tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); // Initialize remote eager workers. tensorflow::gtl::FlatMap remote_contexts; - TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, - remote_eager_workers.get(), - opts->async, &remote_contexts)); + LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( + remote_workers, rendezvous_id, opts->server_def, + remote_eager_workers.get(), opts->async, &remote_contexts)); tensorflow::RemoteRendezvous* r = - server->worker_env()->rendezvous_mgr->Find(0); + grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); + + auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); + TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( + session_name, opts->server_def, true)); + + std::shared_ptr worker_session; + TF_RETURN_IF_ERROR( + grpc_server->worker_env()->session_mgr->WorkerSessionForSession( + session_name, &worker_session)); - auto* device_mgr = server->worker_env()->device_mgr; + // Initialize remote tensor communication based on worker session. + TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); + + auto* device_mgr = grpc_server->worker_env()->device_mgr; *ctx = new TFE_Context(opts->session_options.options, opts->policy, opts->async, device_mgr, r, std::move(server), std::move(remote_eager_workers), std::move(remote_device_mgr), remote_contexts); return tensorflow::Status::OK(); +#undef LOG_AND_RETURN_IF_ERROR } } // namespace @@ -311,16 +348,16 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->Tensor(&t); - return t == nullptr ? 0 : t->dims(); + int result; + status->status = h->handle->NumDims(&result); + return result; } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->Tensor(&t); - return t == nullptr ? 0 : t->dim_size(dim_index); + tensorflow::int64 result; + status->status = h->handle->Dim(dim_index, &result); + return result; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { @@ -425,8 +462,11 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, return ret; } -void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { - op->operation.MutableAttrs()->Set(attr_name, value); +void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, + size_t length) { + op->operation.MutableAttrs()->Set( + attr_name, + tensorflow::StringPiece(static_cast(value), length)); } void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { @@ -477,16 +517,22 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } -#define TFE_OP_SET_ATTR_LIST(fn, type) \ - void fn(TFE_Op* op, const char* attr_name, const type* values, \ - int num_values) { \ - op->operation.MutableAttrs()->Set( \ - attr_name, \ - tensorflow::gtl::ArraySlice(values, num_values)); \ +void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values) { + std::vector v(num_values); + for (int i = 0; i < num_values; ++i) { + v[i] = tensorflow::StringPiece(static_cast(values[i]), + lengths[i]); } -TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*) -TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) -#undef TFE_OP_SET_ATTR_LIST + op->operation.MutableAttrs()->Set(attr_name, v); +} + +void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, + const float* values, int num_values) { + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice(values, num_values)); +} void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, const int64_t* values, int num_values) { @@ -659,9 +705,11 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, const char* attr_name, TF_Status* status) { switch (default_value.value_case()) { - case tensorflow::AttrValue::kS: - TFE_OpSetAttrString(op, attr_name, default_value.s().data()); + case tensorflow::AttrValue::kS: { + const string& v = default_value.s(); + TFE_OpSetAttrString(op, attr_name, v.data(), v.size()); break; + } case tensorflow::AttrValue::kI: TFE_OpSetAttrInt(op, attr_name, static_cast(default_value.i())); break; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 574a097e0d6f5d6e7acd77cae246678b6675129b..fdbd5374b2afe815c3a81b453930eb8f1fa351d3 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -191,6 +191,45 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status); +// Debugging/Profiling information for TFE_TensorHandle +// +// TFE_TensorDebugInfo contains information useful for debugging and +// profiling tensors. +typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo; + +// Retrieves TFE_TensorDebugInfo for `handle`. +// If TFE_TensorHandleTensorDebugInfo succeeds, `status` is set to OK and caller +// is responsible for deleting returned TFE_TensorDebugInfo. +// If TFE_TensorHandleTensorDebugInfo fails, `status` is set to appropriate +// error and nullptr is returned. This function can block till the operation +// that produces `handle` has completed. +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* handle, TF_Status* status); + +// Deletes `debug_info`. +TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of dimensions used to represent the tensor on its device. +// The number of dimensions used to reprensent the tensor on device can be +// different from the number returned by TFE_TensorHandleNumDims. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of elements in dimension `dim_index`. +// Tensor representation on device can be transposed from its representation +// on host. The data contained in dimension `dim_index` on device +// can correspond to the data contained in another dimension in on-host +// representation. The dimensions are indexed using the standard TensorFlow +// major-to-minor order (slowest varying dimension first), +// not the XLA's minor-to-major order. +// On-device dimensions can be padded. TFE_TensorDebugInfoOnDeviceDim returns +// the number of elements in a dimension after padding. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index); + // Description of the TensorFlow op to execute. // // Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e., @@ -239,7 +278,8 @@ TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, - const char* value); + const void* value, + size_t length); TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value); TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, @@ -266,7 +306,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, - const char** value, + const void* const* values, + const size_t* lengths, int num_values); TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc new file mode 100644 index 0000000000000000000000000000000000000000..5006b76f1981d068e99a2c081115ebb3a66d8c7f --- /dev/null +++ b/tensorflow/c/eager/c_api_debug.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#ifdef TENSORFLOW_EAGER_USE_XLA +#include "tensorflow/compiler/jit/xla_device.h" +#endif // TENSORFLOW_EAGER_USE_XLA + +using tensorflow::int64; +using tensorflow::string; + +namespace { + +std::vector TensorShapeAsVector(TFE_TensorHandle* handle, + TF_Status* status) { + std::vector shape; + int rank = TFE_TensorHandleNumDims(handle, status); + if (!status->status.ok()) { + return shape; + } + shape.reserve(rank); + for (int i = 0; i < rank; ++i) { + shape.push_back(TFE_TensorHandleDim(handle, i, status)); + if (!status->status.ok()) { + return shape; + } + } + return shape; +} + +} // namespace + +extern "C" { + +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* handle, TF_Status* status) { + const tensorflow::Tensor* tensor; + status->status = handle->handle->Tensor(&tensor); + if (!status->status.ok()) { + return nullptr; + } + + tensorflow::Device* device; + status->status = handle->handle->Device(&device); + if (!status->status.ok()) { + return nullptr; + } + +#ifdef TENSORFLOW_EAGER_USE_XLA + // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. + tensorflow::XlaDevice* xla_device = + dynamic_cast(device); + if (xla_device != nullptr) { + tensorflow::XlaDevice::PaddedShapeFn shape_fn = + xla_device->metadata().padded_shape_fn(); + xla::Shape padded_shape; + status->status = shape_fn(*tensor, &padded_shape); + if (!status->status.ok()) { + return nullptr; + } + if (VLOG_IS_ON(3)) { + std::vector shape_to_log = TensorShapeAsVector(handle, status); + if (!status->status.ok()) { + // Ignore the status here as we are simply logging. + status->status = tensorflow::Status::OK(); + } else { + VLOG(3) << "Fully padded shape of [" + << tensorflow::str_util::Join(shape_to_log, ", ") << "] is " + << padded_shape.DebugString(); + } + } + + if (xla::ShapeUtil::IsTuple(padded_shape)) { + if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { + // Currently, the only case of XlaTensor containing a tuple shape is to + // represent 64 bit ints, doubles, and complex numbers (we don't support + // 64bit complex numbers). + status->status = tensorflow::errors::InvalidArgument( + "XlaTensors should only contain tuples of size 2. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + + // shape0 is not a const& because we will assign it to padded_shape below. + // It is illegal to assign a part of a message to itself. + xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); + const xla::Shape& shape1 = + xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); + if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) { + status->status = tensorflow::errors::InvalidArgument( + "XlaTensors should not contain nested tuples. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + if (!xla::ShapeUtil::Equal(shape0, shape1)) { + status->status = tensorflow::errors::InvalidArgument( + "Subshapes of XlaTensors should be the same. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + + // Since the only case we handle here are two equal subshapes, we + // simply return one of them. The caller will interpret it as this + // shape directly storing the 64bit types. This approximation is good + // enough for this API's debugging use case. + padded_shape = shape0; + } + + int rank = padded_shape.dimensions_size(); + std::vector dev_dims; + dev_dims.reserve(rank); + if (rank == 1) { + // Rank 1 tensors might not have padded_shape.layout.minor_to_major set, + dev_dims.push_back(padded_shape.dimensions(0)); + } else { + for (int i = rank - 1; i >= 0; --i) { + int64 dim_index = padded_shape.layout().minor_to_major(i); + dev_dims.push_back(padded_shape.dimensions(dim_index)); + } + } + status->status = tensorflow::Status::OK(); + return new TFE_TensorDebugInfo(dev_dims); + } +#endif // TENSORFLOW_EAGER_USE_XLA + + // If the tensor is not an XLA tensor, the device shape is + // the same as regular tensor shape. + std::vector dev_dims = TensorShapeAsVector(handle, status); + if (!status->status.ok()) { + return nullptr; + } + return new TFE_TensorDebugInfo(dev_dims); +} + +TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info) { + delete debug_info; +} + +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info) { + return debug_info->dev_dims.size(); +} + +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index) { + return debug_info->dev_dims[dim_index]; +} + +} // extern "C" diff --git a/tensorflow/c/eager/c_api_debug_test.cc b/tensorflow/c/eager/c_api_debug_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cddb9f6e00e9d639026f4bbe061d58f76771c0a9 --- /dev/null +++ b/tensorflow/c/eager/c_api_debug_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +TEST(CApiDebug, ScalarCPU) { + TFE_TensorHandle* h = TestScalarTensorHandle(); + TF_Status* status = TF_NewStatus(); + TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + ASSERT_EQ(0, TFE_TensorDebugInfoOnDeviceNumDims(debug_info)); + + TFE_DeleteTensorDebugInfo(debug_info); + TFE_DeleteTensorHandle(h); + TF_DeleteStatus(status); +} + +TEST(CApiDebug, 2DCPU) { + TFE_TensorHandle* h = TestMatrixTensorHandle3X2(); + TF_Status* status = TF_NewStatus(); + TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + ASSERT_EQ(2, TFE_TensorDebugInfoOnDeviceNumDims(debug_info)); + // Shape is the same for CPU tensors. + EXPECT_EQ(3, TFE_TensorDebugInfoOnDeviceDim(debug_info, 0)); + EXPECT_EQ(2, TFE_TensorDebugInfoOnDeviceDim(debug_info, 1)); + + TFE_DeleteTensorDebugInfo(debug_info); + TFE_DeleteTensorHandle(h); + TF_DeleteStatus(status); +} diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index f506ede0871e1b345a27bab32e1d4342de9ba6f4..4c5077023d5bb3b83808bf3908e7110dd026e3ad 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -28,8 +28,8 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" @@ -39,7 +39,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/remote_device.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" @@ -78,7 +78,7 @@ struct TFE_Context { TFE_ContextDevicePlacementPolicy default_policy, bool async, tensorflow::DeviceMgr* local_device_mgr, tensorflow::Rendezvous* rendezvous, - std::unique_ptr server, + std::unique_ptr server, std::unique_ptr remote_eager_workers, std::unique_ptr remote_device_mgr, const tensorflow::gtl::FlatMap& @@ -107,6 +107,14 @@ struct TFE_TensorHandle { tensorflow::TensorHandle* handle; }; +struct TFE_TensorDebugInfo { + TFE_TensorDebugInfo(const std::vector& dims) + : dev_dims(dims) {} + + // Fully-padded, minor-to-major. + std::vector dev_dims; +}; + struct TFE_Op { // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a // primitive operation. diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 49646bb73599d96fce2df90f918e692df7972aeb..3504a8b5e78480732d3454097c1b2197ac2b2e17 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -16,7 +16,8 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -32,122 +33,6 @@ using tensorflow::string; namespace { -TFE_TensorHandle* DoubleTestMatrixTensorHandle() { - int64_t dims[] = {2, 2}; - double data[] = {1.0, 2.0, 3.0, 4.0}; - TF_Tensor* t = TF_AllocateTensor( - TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_TensorHandle* TestMatrixTensorHandle() { - int64_t dims[] = {2, 2}; - float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_TensorHandle* TestMatrixTensorHandle3X2() { - int64_t dims[] = {3, 2}; - double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { - TF_Status* status = TF_NewStatus(); - - TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, a, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, b, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteStatus(status); - TFE_OpSetAttrBool(op, "transpose_a", 0); - TFE_OpSetAttrBool(op, "transpose_b", 0); - TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); - - return op; -} - -TFE_TensorHandle* TestAxisTensorHandle() { - int64_t dims[] = {1}; - int data[] = {1}; - TF_Tensor* t = TF_AllocateTensor( - TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, - TFE_TensorHandle* axis) { - TF_Status* status = TF_NewStatus(); - - TFE_Op* op = TFE_NewOp(ctx, "Min", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, input, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, axis, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpSetAttrBool(op, "keep_dims", 1); - TFE_OpSetAttrType(op, "Tidx", TF_INT32); - TF_DeleteStatus(status); - TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); - - return op; -} - -// If there is a GPU device, returns true and sets 'gpu_device_name' -// accordingly. -bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); - CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - const int num_devices = TF_DeviceListCount(devices); - for (int i = 0; i < num_devices; ++i) { - const string device_type(TF_DeviceListType(devices, i, status.get())); - CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - const string device_name(TF_DeviceListName(devices, i, status.get())); - CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - if (device_type == "GPU") { - *gpu_device_name = device_name; - LOG(INFO) << "Found GPU device " << device_name; - TF_DeleteDeviceList(devices); - return true; - } - } - TF_DeleteDeviceList(devices); - return false; -} - void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); @@ -247,18 +132,20 @@ void TestRemoteExecute(bool async) { server_def.set_task_index(1); - std::unique_ptr worker_server; - ASSERT_TRUE( - tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) - .ok()); + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); ASSERT_TRUE(worker_server->Start().ok()); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), status); - TFE_ContextOptionsSetAsync(opts, static_cast(1)); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -320,6 +207,95 @@ void TestRemoteExecute(bool async) { TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } +void TestRemoteExecuteSilentCopies(bool async) { + tensorflow::ServerDef server_def = GetServerDef(3); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + + auto* h1_task2 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Handles are on task0 (local), and task2, but op is on task1. + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2); + TFE_OpSetDevice(matmul, task1_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h1_task2); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server1.release(); + worker_server2.release(); +} + +TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); } +TEST(CAPI, RemoteExecuteSilentCopiesAsync) { + TestRemoteExecuteSilentCopies(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); @@ -536,7 +512,7 @@ void TensorHandleSilentCopy(bool async) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -583,7 +559,7 @@ void TensorHandleSilentCopyLocal(bool async) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -624,7 +600,7 @@ void SetAndGetOpDevices(bool async) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_OpSetDevice(matmul, "GPU:0", status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); const char* device_name = TFE_OpGetDevice(matmul, status); @@ -688,7 +664,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TFE_DeleteContextOptions(opts); TFE_TensorHandle* m1 = TestMatrixTensorHandle(); - TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2(); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(); TFE_Op* matmul = MatMulOp(ctx, m1, m2); TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0", status); @@ -1198,8 +1174,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, if (TF_GetCode(status) != TF_OK) return nullptr; TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", ""); - TFE_OpSetAttrString(op, "shared_name", ""); + TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "shared_name", "", 0); if (TF_GetCode(status) != TF_OK) return nullptr; TFE_TensorHandle* var_handle = nullptr; int num_retvals = 1; diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..5607c9dcb0bbec72b2f86def3dd4e6590d73197b --- /dev/null +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -0,0 +1,163 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_test_util.h" + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +using tensorflow::string; + +TFE_TensorHandle* TestScalarTensorHandle() { + float data[] = {1.0f}; + TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* DoubleTestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, b, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + TFE_OpSetAttrBool(op, "transpose_a", 0); + TFE_OpSetAttrBool(op, "transpose_b", 0); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; +} + +TFE_TensorHandle* TestAxisTensorHandle() { + int64_t dims[] = {1}; + int data[] = {1}; + TF_Tensor* t = TF_AllocateTensor( + TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Min", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, axis, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrBool(op, "keep_dims", 1); + TFE_OpSetAttrType(op, "Tidx", TF_INT32); + TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); + + return op; +} + +bool GetDeviceName(TFE_Context* ctx, string* device_name, + const char* device_type) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + const int num_devices = TF_DeviceListCount(devices); + for (int i = 0; i < num_devices; ++i) { + const string dev_type(TF_DeviceListType(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + const string dev_name(TF_DeviceListName(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + if (dev_type == device_type) { + *device_name = dev_name; + LOG(INFO) << "Found " << device_type << " device " << *device_name; + TF_DeleteDeviceList(devices); + return true; + } + } + TF_DeleteDeviceList(devices); + return false; +} diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..474cae67c89249af3a62707f0db00ba458ca8f31 --- /dev/null +++ b/tensorflow/c/eager/c_api_test_util.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ +#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ + +#include "tensorflow/c/eager/c_api.h" + +#include "tensorflow/core/platform/types.h" + +// Return a tensor handle containing a float scalar +TFE_TensorHandle* TestScalarTensorHandle(); + +// Return a tensor handle containing a 2x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle(); + +// Return a tensor handle containing a 2x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle(); + +// Return a tensor handle containing a 3x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(); + +// Return a tensor handle containing a 3x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle3X2(); + +// Return a matmul op multiplying `a` by `b`. +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); + +// Return an 1-D INT32 tensor containing a single value 1. +TFE_TensorHandle* TestAxisTensorHandle(); + +// Return an op taking minimum of `input` long `axis` dimension. +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis); + +// If there is a device of type `device_type`, returns true +// and sets 'device_name' accordingly. +// `device_type` must be either "GPU" or "TPU". +bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name, + const char* device_type); + +#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index dcc2357b71a68ba39d1c376242fb35e287f9d033..734e712daa39c03f0177eb199b1acb1b19e5d845 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -48,7 +48,7 @@ struct OpTapeEntry { // Should be called before deleting the backward function. TODO(apassos) use // unique_ptrs to ensure this happens. - std::function backward_function_deleter; + std::function backward_function_deleter; }; // Map from tensor_id to internally-defined operation-id of the operation which @@ -104,14 +104,12 @@ class VSpace { gtl::ArraySlice output_gradients, std::vector* result) const = 0; + // Marks the following gradient as a result so it's not consumed by backward + // functions. + virtual void MarkAsResult(Gradient* gradient) const = 0; + // Deletes the input tensor. virtual void DeleteGradient(Gradient* gradient) const = 0; - - // Lets this VSpace know that it can release resources held by the - // `backward_function`, It will not be called again. - // `backward_function` must not be null. - virtual void ReleaseBackwardFunction( - BackwardFunction* backward_function) const = 0; }; // Traces the execution of operations, doing eager garbage collection, and @@ -126,7 +124,7 @@ class GradientTape { GradientTape(bool persistent) : persistent_(persistent) {} ~GradientTape() { for (const auto& pair : op_tape_) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } @@ -135,12 +133,12 @@ class GradientTape { void Watch(int64 tensor_id); - void RecordOperation(const string& op_type, - gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, - gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, - const std::function& backward_function_deleter); + void RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, + BackwardFunction* backward_function, + const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -214,9 +212,9 @@ void GradientTape::RecordOperation( gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, BackwardFunction* backward_function, - const std::function& backward_function_deleter) { + const std::function& backward_function_deleter) { if (!ShouldRecord(input_tensor_id, input_dtypes)) { - backward_function_deleter(); + backward_function_deleter(backward_function); return; } std::vector ids; @@ -271,7 +269,7 @@ void GradientTape::DeleteTrace(int64 tensor_id) { for (int64 id : op_it->second.input_tensor_id) { DeleteTrace(id); } - op_it->second.backward_function_deleter(); + op_it->second.backward_function_deleter(op_it->second.backward_function); op_tape_.erase(op_it); } @@ -356,8 +354,7 @@ BackpropInitialState PrepareBackprop( count_it->second++; } else { result.tensor_usage_counts[it] = 1; - if (sources_set.find(it) == sources_set.end() && - tensor_tape.find(it) != tensor_tape.end()) { + if (tensor_tape.find(it) != tensor_tape.end()) { tensor_stack.push_back(it); } } @@ -378,7 +375,8 @@ BackpropInitialState PrepareBackprop( // backward functions that will be used for gradient computation // has been transferred to `result`. for (const auto& op_pair : *op_tape) { - op_pair.second.backward_function_deleter(); + op_pair.second.backward_function_deleter( + op_pair.second.backward_function); } op_tape->clear(); } @@ -470,7 +468,7 @@ Status GradientTape::ComputeGradient( if (!persistent_) { // Release all backprop functions for (const auto& pair : state.op_tape) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } }; @@ -522,10 +520,15 @@ Status GradientTape::ComputeGradient( } } else { any_gradient_nonzero = true; - out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); + auto new_gradients = vspace.AggregateGradients(grad_it->second); if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); + } else { + grad_it->second.clear(); + grad_it->second.push_back(new_gradients); + vspace.MarkAsResult(new_gradients); } + out_gradients.push_back(new_gradients); } } std::vector in_gradients; @@ -533,7 +536,7 @@ Status GradientTape::ComputeGradient( Status s = vspace.CallBackwardFunction(trace.backward_function, out_gradients, &in_gradients); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } if (!s.ok()) { cleanup(); @@ -542,7 +545,7 @@ Status GradientTape::ComputeGradient( } else { in_gradients.resize(trace.input_tensor_id.size()); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } for (Gradient* grad : out_gradients) { if (grad != nullptr) { diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh index 02a6a58b6153bb78c684f9290ef95900f96e9357..7184ad68fb79f2598067d68d5ab5ba8f2c7a22c8 100755 --- a/tensorflow/c/generate-pc.sh +++ b/tensorflow/c/generate-pc.sh @@ -15,10 +15,12 @@ # ============================================================================== TF_PREFIX='/usr/local' +LIBDIR='lib' usage() { echo "Usage: $0 OPTIONS" echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-l, --libdir\tset lib directory (default: lib)" echo -e "-v, --version\tset TensorFlow version" echo -e "-h, --help\tdisplay this message" } @@ -26,7 +28,7 @@ usage() { [ $# == 0 ] && usage && exit 0 # read the options -ARGS=$(getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@") +ARGS=$(getopt -o p:l:v:h --long prefix:,libdir:,version:,help -n $0 -- "$@") eval set -- "$ARGS" # extract options and their arguments into variables. @@ -38,6 +40,11 @@ while true ; do "") shift 2 ;; *) TF_PREFIX=$2 ; shift 2 ;; esac ;; + -l|--libdir) + case "$2" in + "") shift 2 ;; + *) LIBDIR=$2 ; shift 2 ;; + esac ;; -v|--version) case "$2" in "") shift 2 ;; @@ -55,7 +62,7 @@ echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" cat << EOF > tensorflow.pc prefix=${TF_PREFIX} exec_prefix=\${prefix} -libdir=\${exec_prefix}/lib +libdir=\${exec_prefix}/${LIBDIR} includedir=\${prefix}/include Name: TensorFlow diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 079e063d3e3fbdaf833e9031f5f9438853c14099..a98f0b00b2c70055f697ed4f15cb14708384b62f 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -530,7 +530,7 @@ cc_library_with_android_deps( "//tensorflow/core/api_def:base_api_def", ], deps = [ - "//tensorflow/core:framework", + "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d6a4f141b6bb8ccadb77f1fa83b5fb742d78f70f..dfdef88945deca376368edd6f7aa322b1e1cbf94 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -273,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) { return ""; // Prevent missing return warning } +bool IsEmptyList(const AttrValue::ListValue& list) { + return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 && + list.b_size() == 0 && list.type_size() == 0 && + list.shape_size() == 0 && list.tensor_size() == 0; +} + string ToCamelCase(const string& str) { string result; const char joiner = '_'; @@ -297,9 +303,9 @@ string ToCamelCase(const string& str) { // indicate whether to treat the type as const when accepting the C++ type as an // argument to a function. std::pair AttrTypeName(StringPiece attr_type) { - static const std::unordered_map, - StringPieceHasher> - attr_type_map{ + static const auto* attr_type_map = + new std::unordered_map, + StringPieceHasher>{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice", true}}, {"int", {"int64", false}}, @@ -317,14 +323,34 @@ std::pair AttrTypeName(StringPiece attr_type) { {"func", {"NameAttrList", true}}, }; - auto entry = attr_type_map.find(attr_type); - if (entry == attr_type_map.end()) { + auto entry = attr_type_map->find(attr_type); + if (entry == attr_type_map->end()) { LOG(FATAL) << "Unsupported Attr type: " << attr_type; return {"", false}; } return entry->second; } +const char* ListElementTypeName(StringPiece attr_type) { + static const auto* attr_list_type_map = + new std::unordered_map{ + {"list(string)", "string"}, + {"list(int)", "int"}, + {"list(float)", "float"}, + {"list(bool)", "bool"}, + {"list(type)", "DataType"}, + {"list(shape)", "PartialTensorShape"}, + {"list(tensor)", "TensorProto"}, + }; + + auto entry = attr_list_type_map->find(attr_type); + if (entry == attr_list_type_map->end()) { + LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type; + return ""; + } + return entry->second; +} + bool IsCPPKeyword(StringPiece name) { static const std::unordered_set // Keywords obtained from http://en.cppreference.com/w/cpp/keyword @@ -668,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; + string defaults_static_storage; for (int i = 0; i < graph_op_def.attr_size(); ++i) { const auto& attr(graph_op_def.attr(i)); @@ -705,11 +732,32 @@ string OpInfo::GetOpAttrStruct() const { "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); - strings::StrAppend( - &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), - "_ = ", - PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), - ";\n"); + string field_initiliazer; + auto& default_value = api_def_attr.default_value(); + if (default_value.value_case() == AttrValue::kList && + !IsEmptyList(default_value.list())) { + // Non-empty lists need static storage for their defaults. Define a + // function with static local variable that stores the array. + strings::StrAppend(&defaults_static_storage, " static ", + attr_type_name, " Default_", api_def_attr.rename_to(), + "() {\n"); + strings::StrAppend( + &defaults_static_storage, " static const ", + ListElementTypeName(attr.type()), " kStorage[] = ", + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), + ";\n"); + strings::StrAppend(&defaults_static_storage, " return ", + attr_type_name, "(kStorage);\n }\n"); + // Set the field_initializer to call the defined function. + strings::StrAppend(&field_initiliazer, "Default_", + api_def_attr.rename_to(), "()"); + } else { + field_initiliazer = + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()); + } + strings::StrAppend(&struct_fields, " ", attr_type_name, " ", + api_def_attr.rename_to(), "_ = ", field_initiliazer, + ";\n"); } if (struct_fields.empty()) { @@ -721,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const { string struct_decl = MakeComment(attrs_comment, " "); strings::StrAppend(&struct_decl, " struct Attrs {\n"); strings::StrAppend(&struct_decl, setters, struct_fields); + if (!defaults_static_storage.empty()) { + strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage); + } strings::StrAppend(&struct_decl, " };\n"); return struct_decl; diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 62a889181e787f2e181135ab0563c45e1bab8812..8c886f31711eb014fb9e9d600c9c78cf22073f71 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -37,6 +37,11 @@ Scope& Scope::operator=(const Scope& other) { return *this; } +namespace { +const char kScopeSeparator[] = "/"; +const char kSuffixSeparator[] = "_"; +} // namespace + Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, bool disable_shape_inference) : graph_(graph), @@ -308,19 +313,23 @@ string Scope::Impl::GetUniqueName(const string& prefix, return prefix; } auto entry = name_map_->find(prefix); - string unique_name = prefix; if (entry == name_map_->end()) { name_map_->insert({prefix, 0}); - } else { - unique_name = strings::StrCat(unique_name, "_", ++entry->second); + return prefix; } + string unique_name; + do { + unique_name = strings::StrCat(prefix, kSuffixSeparator, ++entry->second); + } while (name_map_->find(unique_name) != name_map_->end()); + name_map_->insert({unique_name, 0}); return unique_name; } string Scope::Impl::GetNameForOp(const string& default_name) const { const string unique_name = GetUniqueName(default_name, true /* check_single_use */); - const string sep = name_.empty() || unique_name.empty() ? "" : "/"; + const string sep = + name_.empty() || unique_name.empty() ? "" : kScopeSeparator; return strings::StrCat(name_, sep, unique_name); } @@ -345,7 +354,8 @@ Scope Scope::NewSubScope(const string& child_scope_name) const { } const string unique_name = impl()->GetUniqueName(child_scope_name, false /* check_single_use */); - const string sep = impl()->name_.empty() || unique_name.empty() ? "" : "/"; + const string sep = + impl()->name_.empty() || unique_name.empty() ? "" : kScopeSeparator; return Scope(new Impl(*this, Impl::Tags::ScopeName(), strings::StrCat(impl()->name_, sep, unique_name), false /* copy_names */)); @@ -412,7 +422,7 @@ CompositeOpScopes Scope::GetCompositeOpScopes( if (!impl()->single_use_scope()) { Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name : impl()->op_name_); - const string child_op_sep = impl()->name_.empty() ? "" : "_"; + const string child_op_sep = impl()->name_.empty() ? "" : kSuffixSeparator; const string child_name = strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_); return {child, @@ -435,7 +445,13 @@ class InternalScope { static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) { Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap; for (const Node* node : graph->nodes()) { - (*name_map)[node->name()] = 0; + const string& name = node->name(); + (*name_map)[name] = 0; + // Add all name prefixes ('/' separated). + size_t idx = -1; + while ((idx = name.find(kScopeSeparator, idx + 1)) != string::npos) { + (*name_map)[name.substr(0, idx)] = 0; + } } // We provide null destructors for these shared ptrs (except for name_map) // since the caller owns them and doesn't want the scope to destroy them. diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 8efcfed20d0b86d86d8c20a3d8630c7c6bc909c3..58adaef2e942a7fa6b0ce8d5534ac3e2fd380580 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -34,8 +34,7 @@ class Scope::Impl { // name that has not been used so far in a scope will get no suffix. Later // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes // can share the same NameMap. For instance, a new scope created using - // WithControlDependencies() should would share the same NameMap with the - // parent. + // WithControlDependencies() would share the same NameMap with the parent. typedef std::unordered_map NameMap; Impl(const std::shared_ptr& graph, diff --git a/tensorflow/cc/framework/scope_test.cc b/tensorflow/cc/framework/scope_test.cc index 9eca9d3face34319413e1acbc2f5ac0b2ba85374..b40b345eb84237c34ea593021bea022ad28095f7 100644 --- a/tensorflow/cc/framework/scope_test.cc +++ b/tensorflow/cc/framework/scope_test.cc @@ -26,6 +26,16 @@ TEST(ScopeTest, BasicNames) { EXPECT_EQ(root.GetUniqueNameForOp("mul"), "mul"); } +TEST(ScopeTest, OpAndScopeNameCollision) { + Scope root = Scope::NewRootScope(); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo"); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_1"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_1"), "foo_1_1"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2"); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_3"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2_1"); +} + TEST(ScopeTest, HierarchicalNames) { Scope root = Scope::NewRootScope(); Scope child = root.NewSubScope("child"); diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index ff348fadb24e29a83bd6c8853aa67931f6df4182..b353accddcb6db9a07c112de03ead2f02c4ee6a6 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -421,6 +421,58 @@ Status StridedSliceGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper); +Status SliceGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // Propagate the incoming gradient along all the selected values, + // and zero everywhere else. Use the Pad operator for this. + // + // First create an Nx2 padding where N is the number of input + // dimensions. The first column is the number of prepended zeros + // for each dimension, and the second column is the number of + // appended zeros. + // + // The first column is just the begin vector. + // The second column is the shape of the input element-wise + // subtracted by begin+size + + // Running example: + // input.shape = [3, 5, 3] + // begin = [1, 2, 1], size = [1, 3, 2] + Input input = op.input(0); + Input begin = op.input(1); + // input_rank = 3 + auto input_rank = Rank(scope, input); + // slice_size = [1, 3, 2] + auto slice_size = Shape(scope, op.output(0)); + // padding_shape = [3, 1] + auto padding_shape = Stack(scope, {input_rank, 1}); + // before_padding = [[1] + // [2] + // [1]] + Input before_padding = Reshape(scope, begin, padding_shape); + // after_padding_sizes = shape(input) - slice_size - begin + // = [3, 5, 3] - [1, 3, 2] - [1, 2, 1] + // = [1, 0, 0] + auto after_padding_sizes = + Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin); + // after_padding = [[1] + // [0] + // [0]] + Input after_padding = Reshape(scope, after_padding_sizes, padding_shape); + // paddings = [[1 1] + // [2 0] + // [1 0]] + auto paddings = + Concat(scope, {before_padding, after_padding}, Const(scope, 1)); + grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings)); + // Nothing propagated for "begin" and "size" inputs + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Slice", SliceGrad); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index de3bd0fc9e2493f8ff76163f5be6bd4327c58c5a..d09275b6487b4212aa35a0476002f2bb587fa210 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -378,5 +378,12 @@ TEST_F(ArrayGradTest, StridedSliceGrad) { RunTest(x, x_shape, y, {1, 2, 2, 2}); } +TEST_F(ArrayGradTest, SliceGrad) { + TensorShape x_shape({3, 5, 3}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = Slice(scope_, x, {1, 2, 1}, {1, 3, 2}); + RunTest(x, x_shape, y, {1, 3, 2}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 52c177212a8c88f1857defcc38de4a01ac47dab0..35a01e0341cb08c9b314908b6dcd76fd99c1e68b 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -38,6 +38,7 @@ REGISTER_NO_GRADIENT_OP("NotEqual"); REGISTER_NO_GRADIENT_OP("LogicalAnd"); REGISTER_NO_GRADIENT_OP("LogicalOr"); REGISTER_NO_GRADIENT_OP("LogicalNot"); +REGISTER_NO_GRADIENT_OP("Floor"); // Conjugate helper function returns the conjugate of an Output if it // is complex valued. diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 0cb3132e94e381f672d69aefe4a199d2b590830c..c73482d5f4d13ade0dc0412941251d1651371b6e 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -255,6 +255,53 @@ Status LRNGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("LRN", LRNGradHelper); +Status SoftplusGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softplus", SoftplusGradHelper); + +Status SoftsignGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softsign", SoftsignGradHelper); + +Status FractionalAvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalAvgPoolGrad( + scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)), + grad_inputs[0], op.output(1), op.output(2), + internal::FractionalAvgPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalAvgPool", FractionalAvgPoolGradHelper); + +Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalMaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], op.output(1), + op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index c4eba7ecb017fe4628140d75a63bc7f0f09deb7f..b4d457a9d14eb79232cda9412fa0050f6a9968cc 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -28,6 +28,8 @@ namespace { using ops::BiasAdd; using ops::Conv2D; using ops::Elu; +using ops::FractionalAvgPool; +using ops::FractionalMaxPool; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; @@ -41,6 +43,8 @@ using ops::Relu; using ops::Relu6; using ops::Selu; using ops::Softmax; +using ops::Softplus; +using ops::Softsign; class NNGradTest : public ::testing::Test { protected: @@ -71,22 +75,30 @@ class NNGradTest : public ::testing::Test { EXPECT_LT(max_error, 1e-3); } - // Sets tensor with random values, ensuring that the max value is largest by - // a reasonable amount. - // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which - // perturbations by the numeric gradient computation in the gradient checker - // can change the max value if values are too close together. + // Sets tensor with random values, ensuring that every pair of elements are at + // least a reasonable amount apart. + // This is an issue for max pooling operations, in which perturbations by the + // numeric gradient computation in the gradient checker can change the max + // value if a pool has values that are too close together. template - void SetRandomValuesWithBumpedMax(Tensor* tensor) { + void SetRandomValuesForMaxPooling(Tensor* tensor) { auto tensor_flat = tensor->flat(); - tensor_flat.setRandom(); - int32 max_index = 0; - for (size_t i = 1; i < tensor->NumElements(); i++) { - if (tensor_flat(i) > tensor_flat(max_index)) { - max_index = i; - } + // First set the array to an increasing sequence of values spaced + // a reasonable amount apart + T cur = 0; + for (size_t i = 0; i < tensor->NumElements(); i++) { + tensor_flat(i) = cur; + cur += 5e-2; + } + // Fischer-Yates shuffle the array + for (size_t i = tensor->NumElements() - 1; i >= 1; i--) { + // j <- random integer 0 <= j <= i + size_t j = random::New64() % (i + 1); + // swap values at i, j + T tmp = tensor_flat(i); + tensor_flat(i) = tensor_flat(j); + tensor_flat(j) = tmp; } - tensor_flat(max_index) += 1e-2; } Scope scope_; @@ -189,7 +201,7 @@ TEST_F(NNGradTest, MaxPoolGradHelper) { const std::vector strides{1, 2, 2, 1}; auto y = MaxPool(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -202,7 +214,7 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { Tensor strides = test::AsTensor({1, 2, 2, 1}, {4}); auto y = MaxPoolV2(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -215,7 +227,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) { const std::vector strides{1, 3, 3, 3, 1}; auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -248,5 +260,45 @@ TEST_F(NNGradTest, LRN){ RunTest(x, x_shape, y, x_shape); } +TEST_F(NNGradTest, SoftplusGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softplus(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, SoftsignGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softsign(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, FractionalAvgPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalAvgPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalAvgPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_shape, y.output, y_shape); +} + +TEST_F(NNGradTest, FractionalMaxPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalMaxPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalMaxPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesForMaxPooling(&x_init_value); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_init_value, y.output, y_shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 0025842aead53973befc794378a26fa8db2ae1cb..28070d60dbbe6dd8f930b8e6509cedcf09f94e11 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -287,7 +287,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); const int64 result_index = compile_result.aot->result_buffer_index(); const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes(); - if (result_index < 0 || result_index > temp_sizes.size()) { + if (result_index < 0 || result_index >= temp_sizes.size()) { return errors::InvalidArgument("result index: ", result_index, " is outside the range of temp sizes: [0,", temp_sizes.size(), ")"); diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 63d22de1ca4aa0872b6fad3e0ac0182306d7cb8c..4e27aafec7747655d8e4ea3ddd1788d495ca0710 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -82,7 +82,8 @@ static StatusOr CodegenModule(llvm::TargetMachine* target_machine, llvm::legacy::PassManager codegen_passes; if (target_machine->addPassesToEmitFile( - codegen_passes, ostream, llvm::TargetMachine::CGFT_ObjectFile)) { + codegen_passes, ostream, nullptr, + llvm::TargetMachine::CGFT_ObjectFile)) { return xla::InternalError( "Could not create pass pipeline to generate object file"); } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index ebfe4806c203e901358d5c5096c10c03d4c738c3..4e194a6aba9a9efcad27c47c42e148d8e537ae68 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -71,7 +71,7 @@ struct ProtobufToEmbed { const ::tensorflow::protobuf::MessageLite* message; }; -// Embeds a a sequence of protocol buffers into an object file. +// Embeds a sequence of protocol buffers into an object file. // // `target_triple` is the target triple for the target architecture for the // generated object file. diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index fd2cf2b67d4618dd626b8eef78eed044d7fde0a4..0ecc3feeb6fef1dd691ab2785b3221075a79ba88 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -7,6 +7,10 @@ package( load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +# We disable some tfcompile tests in the open source build with the +# "manual" tag to avoid making our OSS users build LLVM twice +# (once for host and once for target). + test_suite( name = "all_tests", tags = ["manual"], diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 980e0eec9e23b15a97b826067bac08053a437712..c2245b8eae8fd27d96feaf58e26418b92e646910 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -25,6 +25,7 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( @@ -175,10 +176,14 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:fifo_queue", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:queue_op", + "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", + "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", ], ) @@ -311,9 +316,9 @@ cc_library( ":common", ":shape_inference_helpers", ":union_find", + ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", - "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", @@ -331,6 +336,19 @@ cc_library( ], ) +cc_library( + name = "xla_cluster_util", + srcs = ["xla_cluster_util.cc"], + hdrs = ["xla_cluster_util.h"], + deps = [ + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:bounds_check", + ], +) + cc_library( name = "union_find", hdrs = ["union_find.h"], @@ -382,6 +400,32 @@ tf_cc_test( ], ) +tf_cc_test( + name = "xla_cluster_util_test", + size = "small", + srcs = [ + "xla_cluster_util_test.cc", + ], + deps = [ + ":common", + ":xla_cluster_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "xla_launch_util_test", size = "small", @@ -407,6 +451,38 @@ tf_cc_test( ], ) +cc_library( + name = "xla_fusion_optimizer", + srcs = ["xla_fusion_optimizer.cc"], + hdrs = ["xla_fusion_optimizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":common", + ":union_find", + ":xla_cluster_util", + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ], +) + +tf_cuda_cc_test( + name = "xla_fusion_optimizer_test", + srcs = ["xla_fusion_optimizer_test.cc"], + deps = [ + ":common", + ":xla_cluster_util", + ":xla_fusion_optimizer", + "//tensorflow/core:graph", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc index 9a2bb0007527557f79b70ad2b9c9576af2ab10ea..b17ff589e2597f8d1b5e61f4eaaed7d6ebe6214c 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -40,7 +40,7 @@ static Status BuildLaunchNode( Graph* graph, Node** node) { NodeDef def; def.set_name(graph->NewName(nodename)); - def.set_op("_XlaLaunch"); + def.set_op("XlaLaunch"); def.set_device(device_name); AddNodeAttr("Tconstants", constant_dtypes, &def); AddNodeAttr("Targs", arg_dtypes, &def); @@ -79,7 +79,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { node->input_types().begin() + num_constant_args, node->input_types().begin() + num_constant_args + num_nonconst_args); - // Build a _XlaLaunch operator to execute the function body. + // Build a XlaLaunch operator to execute the function body. Node* launch_node; TF_RETURN_IF_ERROR(BuildLaunchNode( graph->NewName(node->name()), node->type_string(), node->def().attr(), diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 731b8ebfdc6262500940274c94a03ae7c0376096..a2e6285339f9ed0bde8d72f5b4752b1ecc22f426 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -66,8 +66,28 @@ class SinglePassSearch { Status CompilationRequested(const FunctionLibraryRuntime& flr, const NodeDef& node_def) { + const FunctionDef* function_def = + flr.GetFunctionLibraryDefinition()->Find(node_def.name()); + if (function_def == nullptr) { + // The node def is not calling a function. Individual ops can be + // run directly using on-demand mode, no need to create XlaLaunch + // kernel for them. + // TODO(b/110359382): Make custom kernel creation return a bool instead of + // status. + // We don't set error messages here to avoid unnecessary string copy. + // Similarly below. + return Status(error::INVALID_ARGUMENT, ""); + } + + // If kXlaCompileAttr is set on the node_def, use its value. + const auto& it = node_def.attr().find(kXlaCompileAttr); + if (it != node_def.attr().end()) { + return it->second.b() ? Status::OK() : Status(error::INVALID_ARGUMENT, ""); + } + + // kXlaCompileAttr is not set on node_def, check if it is set on + // FunctionDef. bool xla_compile = false; - // Check if op is marked _XlaCompile=true. Status status = flr.GetFunctionLibraryDefinition()->GetAttr( node_def, kXlaCompileAttr, &xla_compile); if (!status.ok() || !xla_compile) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 6d1e3325ebd35b9608ea273fb7de39bad381e60d..b3a1c19c9e555161ec64aae46bfd4deb6b05e9ff 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" -#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" @@ -107,41 +106,11 @@ void MarkGuaranteedConstants( } } -// A node/slot pair. -// TODO(phawkins): is there a common definition of this? -struct NodeSlot { - NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {} - NodeSlot(const Node* node, int slot) - : node(node), slot(slot), dtype(DT_INVALID) {} - NodeSlot(const Node* node, int slot, DataType dtype) - : node(node), slot(slot), dtype(dtype) {} - - const Node* node; - int slot; - - // Optional: used to record the destination type of a source NodeSlot in case - // the source output is a Ref type that is cast to a Tensor at the - // destination. - DataType dtype; - - bool operator==(const NodeSlot& other) const { - return node == other.node && slot == other.slot && dtype == other.dtype; - } - - // Leave dtype out of the hash since there are never two NodeSlots with the - // same node and slot and different dtypes. - struct Hasher { - uint64 operator()(NodeSlot const& s) const { - return Hash64Combine(std::hash()(s.node), - std::hash()(s.slot)); - } - }; - - struct PairHasher { - uint64 operator()(std::pair const& s) const { - return Hash64Combine(Hasher()(s.first), Hasher()(s.second)); - } - }; +struct OutputInputTensorPairHasher { + uint64 operator()(std::pair const& s) const { + return Hash64Combine(OutputTensor::Hash()(s.first), + InputTensor::Hash()(s.second)); + } }; // TODO(phawkins) add a canonical copy of these operator names and refactor @@ -182,8 +151,7 @@ class Encapsulator { // Write a copy of the input graph to 'graph_out', where the subgraphs are // replaced with calls to the new functions. - Status BuildOutputGraph(bool parallel_checking, Graph* graph_out, - FunctionLibraryDefinition* library); + Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library); private: // A subgraph of the input, all marked with a common 'group_attribute' @@ -271,7 +239,7 @@ class Encapsulator { // Adds the function call node to graph_out. Status AddFunctionCallNode( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. Status AddOutsideCompilationHostIONodes( @@ -284,11 +252,9 @@ class Encapsulator { // Subgraph. void GetOutsideCompilationSubgraphNames(std::vector* names) const; - // Returns the Node that inputs to the function should be wired up to. - Node* GetCallNodeForInputs() const; - - // Returns the Node that outputs to the function should be wired up to. - Node* GetCallNodeForOutputs() const; + // Returns the Node that the inputs and outputs of the function should be + // wired up to. + Node* GetCallNode() const; // Returns the index of the arg that the dst of edge should connect to. int GetArgIndexForEdge(const Edge* edge) const; @@ -380,7 +346,7 @@ class Encapsulator { // Map from source (producer node/slot) tensors in the original graph to // input index (slot number in the HostCompute/RecvAtHost nodes that will // be created) for the outside_compilation subgraph. - std::unordered_map inputs; + std::unordered_map inputs; // Set of nodes in the original graph that are the source of control edges // that cross from the containing compiled subgraph into the @@ -396,8 +362,15 @@ class Encapsulator { // node/slot) tensors in the original graph to output index (slot number // in the SendFromHost/HostCompute nodes that will be created) for the // outside_compilation subgraph. - std::unordered_map outputs_by_src; - std::unordered_map outputs_by_dst; + struct ArgNumAndType { + int index; + DataType dtype; + + ArgNumAndType(int i, DataType t) : index(i), dtype(t) {} + }; + std::unordered_map + outputs_by_src; + std::unordered_map outputs_by_dst; // Set of nodes in the original graph that are the destination of control // edges that cross from the outside_compilation subgraph into the @@ -425,12 +398,6 @@ class Encapsulator { OutsideCompilationSubgraph* LookupOrCreateOutsideCompilationSubgraph( const string& outside_compilation_id); - // Builds a ParallelCheck op that compares the output of the original - // subgraph with the encapsulated subgraph. - Status BuildParallelCheckOp( - const std::unordered_map& node_images, - Graph* graph_out); - // Builds a placeholder node used to provide the key input to a RecvAtHost // or SendFromHost node. This placeholder node will be removed by a later // pass. @@ -482,26 +449,21 @@ class Encapsulator { // Not owned. Node* host_compute_key_placeholder_ = nullptr; - // Function call node(s) in the output graph. Not owned. - // If parallel_checking is enabled, 'call_node_inputs' is the function call - // node to which inputs should be fed, and 'call_node_outputs' is the - // parallel check op from which outputs should be read. If parallel checking - // is disabled, both point to the function call node. - Node* call_node_inputs_; - Node* call_node_outputs_; + // Function call node in the output graph. Not owned. + Node* call_node_; // Maps from source (producer node/slot) and destination // (consumer node/slot) tensors in the input graph to _Arg numbers in // the subgraph. The source map is one-to-one, whereas the dest map may be // many-to-one. - std::unordered_map args_by_src_; - std::unordered_map args_by_dst_; + std::unordered_map args_by_src_; + std::unordered_map args_by_dst_; - // The _Arg nodes in the subgraph, in order by argument number. + // The arguments to the subgraph, in order. std::vector args_; // Map from source tensor in the input graph to result #. - std::unordered_map results_; + std::unordered_map results_; // The outside_compilation clusters in this subgraph. std::unordered_map @@ -541,13 +503,12 @@ class Encapsulator { // Copies all nodes that aren't in a compiled subgraph to the output graph. Status CopyNodesToOutputGraph( - bool parallel_checking, Graph* graph_out, - std::unordered_map* node_images); + Graph* graph_out, std::unordered_map* node_images); // Adds function call nodes for each compiled subgraph. Status AddFunctionCallNodes( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all // outside_compilation subgraphs. @@ -598,9 +559,9 @@ class Encapsulator { const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out, - std::unordered_set, NodeSlot::PairHasher>* - edges_added); + Graph* graph_out, + std::unordered_set, + OutputInputTensorPairHasher>* edges_added); // Adds control dependencies between subgraph call nodes that have // dependencies via outside_compilation edges. @@ -609,7 +570,7 @@ class Encapsulator { // Adds all edges to the output graph. Status AddEdgesToOutputGraph( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Constructs a minimal shape inference graph that can be used to determine // the shape of send_node at the time that the subgraph is compiled. @@ -729,20 +690,14 @@ void TopologicalClusterSort( } // namespace -Node* Encapsulator::Subgraph::GetCallNodeForInputs() const { - return call_node_inputs_; -} - -Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const { - return call_node_outputs_; -} +Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; } int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const { - return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input())); + return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input())); } int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const { - return results_.at(NodeSlot(edge->src(), edge->src_output())); + return results_.at(OutputTensor(edge->src(), edge->src_output())); } Node* Encapsulator::Subgraph::GetRecvAtHostNode( @@ -754,7 +709,7 @@ Node* Encapsulator::Subgraph::GetRecvAtHostNode( int Encapsulator::Subgraph::GetRecvAtHostSlot( const string& outside_compilation_subgraph_name, const Edge* edge) const { return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .inputs.at(NodeSlot(edge->src(), edge->src_output())); + .inputs.at(OutputTensor(edge->src(), edge->src_output())); } Node* Encapsulator::Subgraph::GetSendFromHostNode( @@ -766,7 +721,7 @@ Node* Encapsulator::Subgraph::GetSendFromHostNode( int Encapsulator::Subgraph::GetSendFromHostSlot( const string& outside_compilation_subgraph_name, const Edge* edge) const { return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); + .outputs_by_dst.at(InputTensor(edge->dst(), edge->dst_input())); } Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { @@ -791,10 +746,10 @@ Status Encapsulator::Subgraph::RecordArg( std::vector>* src_arg_pairs) { Node* src_node = edge->src(); int src_slot = edge->src_output(); - std::unordered_map::iterator iter; + std::unordered_map::iterator iter; bool inserted; - std::tie(iter, inserted) = - args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size()); + std::tie(iter, inserted) = args_by_src_.emplace( + OutputTensor(src_node, src_slot), args_by_src_.size()); int arg_index = iter->second; if (inserted) { NodeDef arg_def; @@ -815,7 +770,7 @@ Status Encapsulator::Subgraph::RecordArg( Node* dst_node = edge->dst(); Node* dst_image = node_images.at(dst_node); int dst_slot = edge->dst_input(); - args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index; + args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index; graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot); return Status::OK(); } @@ -826,10 +781,10 @@ Status Encapsulator::Subgraph::RecordResult( Node* src_node = edge->src(); Node* src_image = node_images.at(src_node); int src_slot = edge->src_output(); - std::unordered_map::iterator iter; + std::unordered_map::iterator iter; bool inserted; std::tie(iter, inserted) = - results_.emplace(NodeSlot(src_node, src_slot), results_.size()); + results_.emplace(OutputTensor(src_node, src_slot), results_.size()); int ret_index = iter->second; if (inserted) { NodeDef ret_def; @@ -867,8 +822,8 @@ void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( outside_subgraph->control_inputs.insert(edge->src()); } else { int input_index = outside_subgraph->inputs.size(); - outside_subgraph->inputs.emplace(NodeSlot(edge->src(), edge->src_output()), - input_index); + outside_subgraph->inputs.emplace( + OutputTensor(edge->src(), edge->src_output()), input_index); } } @@ -882,11 +837,13 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( DataType dtype = edge->dst()->input_type(edge->dst_input()); auto output_iter = outside_subgraph->outputs_by_src - .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), - outside_subgraph->outputs_by_src.size()) + .emplace(OutputTensor(edge->src(), edge->src_output()), + OutsideCompilationSubgraph::ArgNumAndType( + outside_subgraph->outputs_by_src.size(), dtype)) .first; - int output_index = output_iter->second; - outside_subgraph->outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + const int output_index = output_iter->second.index; + outside_subgraph + ->outputs_by_dst[InputTensor(edge->dst(), edge->dst_input())] = output_index; } } @@ -968,7 +925,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (const auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); - int src_slot = input_src.first.slot; + int src_slot = input_src.first.index; int input_index = input_src.second; DataType dtype = src_node->output_type(src_slot); @@ -976,8 +933,8 @@ Status Encapsulator::Subgraph::AddHostComputes( input_dtypes[input_index] = dtype; } for (const auto& output : oc_subgraph.outputs_by_src) { - DataType dtype = output.first.dtype; - int output_index = output.second; + DataType dtype = output.second.dtype; + int output_index = output.second.index; output_dtypes[output_index] = dtype; } @@ -1015,7 +972,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); - int src_slot = input_src.first.slot; + int src_slot = input_src.first.index; int input_index = input_src.second; graph_->AddEdge(src_image, src_slot, host_compute, input_index); } @@ -1037,7 +994,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (const auto& output : oc_subgraph.outputs_by_dst) { const Node* dst_node = output.first.node; Node* dst_image = node_images.at(dst_node); - int dst_slot = output.first.slot; + int dst_slot = output.first.index; int output_index = output.second; graph_->AddEdge(host_compute, output_index, dst_image, dst_slot); @@ -1075,7 +1032,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { VLOG(2) << "ConnectSequencerToCallNode"; - graph_out->AddControlEdge(sequencer_, call_node_inputs_); + graph_out->AddControlEdge(sequencer_, call_node_); } } @@ -1090,14 +1047,19 @@ Status Encapsulator::Subgraph::BuildFunctionDef( call_node_def_.set_device(device_); if (rewrite_subgraph_fn) { + std::vector arg_source_tensors(args_by_src_.size()); + for (const auto& arg : args_by_src_) { + arg_source_tensors.at(arg.second) = arg.first; + } // Initialize the input and output permutations to the identity. std::vector input_permutation(args_by_src_.size()); std::iota(input_permutation.begin(), input_permutation.end(), 0); std::vector output_permutation(results_.size()); std::iota(output_permutation.begin(), output_permutation.end(), 0); - TF_RETURN_IF_ERROR(rewrite_subgraph_fn( - &graph_, &input_permutation, &output_permutation, &call_node_def_)); + TF_RETURN_IF_ERROR( + rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation, + &output_permutation, &call_node_def_)); // Apply the input/output permutations to the 'args_by_...' and 'results_' // mappings, so when we build edges in BuildOutputGraph() we @@ -1174,7 +1136,10 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); host_compute->AddAttr("shape_inference_graph", inference_graph_name); host_compute->AddAttr("shapes", std::vector()); - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + // TODO(sibyl-Aix6ihai): Understand why there are multiple calls to Encapsulator. + if (library->Find(inference_graph_name) == nullptr) { + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + } } return Status::OK(); } @@ -1200,83 +1165,16 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( return Status::OK(); } -Status Encapsulator::Subgraph::BuildParallelCheckOp( - const std::unordered_map& node_images, - Graph* graph_out) { - // Build an index mapping output positions to node/slot pairs in the - // original graph. - std::vector results_by_num(results_.size()); - for (const auto& entry : results_) { - results_by_num[entry.second] = entry.first; - } - - // Build a parallel check NodeDef. - int num_results = results_by_num.size(); - std::vector result_dtypes(num_results); - std::vector expected_outputs(num_results); - std::vector actual_outputs(num_results); - for (int i = 0; i < num_results; ++i) { - const NodeSlot& node_slot = results_by_num[i]; - result_dtypes[i] = node_slot.node->output_type(node_slot.slot); - expected_outputs[i] = - NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(), - node_slot.slot, result_dtypes[i]); - actual_outputs[i] = - NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]); - } - // Assign the parallel check op to a CPU on the same task as the cluster it is - // checking. - string device, dummy; - if (!DeviceNameUtils::SplitDeviceName( - call_node_inputs_->assigned_device_name(), &device, &dummy)) { - return errors::InvalidArgument("Could not parse device name"); - } - strings::StrAppend(&device, "/cpu:0"); - - NodeDef check_def; - TF_RETURN_IF_ERROR( - NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(), - "_parallel_check")), - "ParallelCheck") - .Device(device) - .Attr("T", result_dtypes) - .Input(expected_outputs) - .Input(actual_outputs) - .Finalize(&check_def)); - - Status s; - Node* check_op = graph_out->AddNode(check_def, &s); - if (!s.ok()) return s; - check_op->set_assigned_device_name(device); - - // TODO(phawkins): it seems redundant to call AddEdge as well as - // pass Inputs to the NodeDefBuilder, but I have been unable to find a - // way to avoid it. - for (int i = 0; i < num_results; ++i) { - const NodeSlot& node_slot = results_by_num[i]; - graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op, - i); - graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i); - } - - call_node_outputs_ = check_op; - return Status::OK(); -} - Status Encapsulator::Subgraph::AddFunctionCallNode( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { Status s; - call_node_inputs_ = graph_out->AddNode(call_node_def_, &s); + call_node_ = graph_out->AddNode(call_node_def_, &s); if (!s.ok()) return s; // Copy the assigned device and the key_annotation over. - call_node_inputs_->set_assigned_device_name(device_); - call_node_outputs_ = call_node_inputs_; + call_node_->set_assigned_device_name(device_); - if (parallel_checking) { - TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out)); - } return Status::OK(); } @@ -1315,7 +1213,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( for (const auto& input : oc_subgraph->inputs) { const Node* src_node = input.first.node; - int src_slot = input.first.slot; + int src_slot = input.first.index; int input_index = input.second; DataType dtype = src_node->output_type(src_slot); @@ -1369,8 +1267,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( for (const auto& output : oc_subgraph->outputs_by_src) { const Node* src_node = output.first.node; Node* src_image = node_images.at(src_node); - int src_slot = output.first.slot; - int output_index = output.second; + int src_slot = output.first.index; + int output_index = output.second.index; DataType dtype = src_node->output_type(src_slot); dtypes[output_index] = dtype; @@ -1609,6 +1507,9 @@ Status Encapsulator::SplitIntoSubgraphs() { for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; FixupSourceAndSinkEdges(subgraph.GetGraph()); + // Verify that the graph has well-formed control flow structure. + std::vector dummy; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy)); } return s; @@ -1627,27 +1528,17 @@ Status Encapsulator::BuildFunctionDefs( } Status Encapsulator::CopyNodesToOutputGraph( - bool parallel_checking, Graph* graph_out, - std::unordered_map* node_images) { + Graph* graph_out, std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; string outside_compilation_id; TF_RETURN_IF_ERROR( GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); - // Don't copy nodes that going to be encapsulated, unless parallel checking - // is enabled. - if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking) - continue; + // Don't copy nodes that are going to be encapsulated. + if (IsInSubgraph(func_id, outside_compilation_id)) continue; Node* image = graph_out->CopyNode(node); - if (!outside_compilation_id.empty()) { - if (parallel_checking) { - return errors::InvalidArgument( - "Parallel checking is not supported when outside_compilation " - "clusters are present."); - } - } (*node_images)[node] = image; } (*node_images)[graph_in_->source_node()] = graph_out->source_node(); @@ -1657,10 +1548,10 @@ Status Encapsulator::CopyNodesToOutputGraph( Status Encapsulator::AddFunctionCallNodes( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { for (auto& subgraph_entry : subgraphs_) { - TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode( - node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR( + subgraph_entry.second.AddFunctionCallNode(node_images, graph_out)); } return Status::OK(); } @@ -1694,7 +1585,7 @@ Status Encapsulator::FindOutputImageOfEdgeSrc( } else { // The edge is from a subgraph to a regular node in the output graph so // use the subgraph's call node output. - *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); + *src_image = subgraphs_.at(src_func_id).GetCallNode(); } } else { // The source of the edge is in the output graph so use the node image in @@ -1742,7 +1633,7 @@ Status Encapsulator::FindOutputImageOfEdgeDst( } else { // The edge is to a subgraph from a regular node in the output graph so // use the subgraph's call node input. - *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); + *dst_image = subgraphs_.at(dst_func_id).GetCallNode(); } } else { // The destination of the edge is in the output graph so use the node image @@ -1778,10 +1669,9 @@ Status Encapsulator::CopyEdgeToOutputGraph( const Edge* edge, const string& src_func_id, const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, - const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out, - std::unordered_set, NodeSlot::PairHasher>* - edges_added) { + const std::unordered_map& node_images, Graph* graph_out, + std::unordered_set, + OutputInputTensorPairHasher>* edges_added) { Node* src_image; TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( src_func_id, src_outside_compilation_id, dst_func_id, @@ -1796,16 +1686,12 @@ Status Encapsulator::CopyEdgeToOutputGraph( if (edge->IsControlEdge()) { // Add the control edge, if we have not already added it, using the images // determined above (potentially call operators or RecvAtHost/SendFromHost). - if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) + if (edges_added + ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1)) .second) { graph_out->AddControlEdge(src_image, dst_image); } - // If parallel checking is enabled, also add a control edge to the - // corresponding parallel check op. - if (parallel_checking) { - graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); - } return Status::OK(); } @@ -1817,18 +1703,10 @@ Status Encapsulator::CopyEdgeToOutputGraph( FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id, dst_func_id, dst_outside_compilation_id, edge); - if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) && - parallel_checking) { - // If we are parallel checking, also feed the tensor as an input to the - // corresponding parallel check subgraph. - graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), - edge->dst_input()); - } - // Add the edge, if we have not already added it. if (edges_added - ->emplace(NodeSlot(src_image, src_output), - NodeSlot(dst_image, dst_input)) + ->emplace(OutputTensor(src_image, src_output), + InputTensor(dst_image, dst_input)) .second) { graph_out->AddEdge(src_image, src_output, dst_image, dst_input); } @@ -1839,8 +1717,8 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { for (const auto& ancestors : subgraph_ancestors_) { const string& subgraph = ancestors.first; for (const string& ancestor : ancestors.second) { - graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNodeForOutputs(), - subgraphs_[subgraph].GetCallNodeForInputs()); + graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(), + subgraphs_[subgraph].GetCallNode()); } } return Status::OK(); @@ -1848,11 +1726,12 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { Status Encapsulator::AddEdgesToOutputGraph( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { // Set of edges already added to the output graph, represented as (src, dst) // pairs. We use the set to deduplicate edges; multiple edges in the input // graph may map to one edge in the output graph. - std::unordered_set, NodeSlot::PairHasher> + std::unordered_set, + OutputInputTensorPairHasher> edges_added; for (const Edge* edge : graph_in_->edges()) { @@ -1870,16 +1749,6 @@ Status Encapsulator::AddEdgesToOutputGraph( if (IsInSubgraph(src_func_id, src_outside_compilation_id) && IsInSubgraph(dst_func_id, dst_outside_compilation_id) && src_func_id == dst_func_id) { - if (parallel_checking) { - Node* src_image = node_images.at(edge->src()); - Node* dst_image = node_images.at(edge->dst()); - if (edge->IsControlEdge()) { - graph_out->AddControlEdge(src_image, dst_image); - } else { - graph_out->AddEdge(src_image, edge->src_output(), dst_image, - edge->dst_input()); - } - } continue; } @@ -1887,8 +1756,7 @@ Status Encapsulator::AddEdgesToOutputGraph( // unclustered graph. TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph( edge, src_func_id, src_outside_compilation_id, dst_func_id, - dst_outside_compilation_id, node_images, parallel_checking, graph_out, - &edges_added)); + dst_outside_compilation_id, node_images, graph_out, &edges_added)); } for (auto& subgraph_entry : subgraphs_) { @@ -2504,18 +2372,15 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, +Status Encapsulator::BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. std::unordered_map node_images; - TF_RETURN_IF_ERROR( - CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images)); - TF_RETURN_IF_ERROR( - AddFunctionCallNodes(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images)); + TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out)); TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out)); - TF_RETURN_IF_ERROR( - AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out)); TF_RETURN_IF_ERROR( GetShapeInfoForOutsideCompilationSends(graph_out, library)); @@ -2528,8 +2393,8 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, Status EncapsulateSubgraphsInFunctions( string group_attribute, string outside_compilation_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, - bool parallel_checking, bool reuse_existing_functions, - std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library) { Status s; Encapsulator encapsulator(std::move(group_attribute), @@ -2543,8 +2408,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); - TF_RETURN_IF_ERROR( - encapsulator.BuildOutputGraph(parallel_checking, out.get(), library)); + TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library)); *graph_out = std::move(out); return Status::OK(); @@ -2585,8 +2449,6 @@ static Status RenumberArguments(Graph* graph, Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; - legacy_flags::EncapsulateSubgraphsPassFlags* flags = - legacy_flags::GetEncapsulateSubgraphsPassFlags(); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph, options.flib_def); @@ -2602,69 +2464,73 @@ Status EncapsulateSubgraphsPass::Run( FunctionLibraryRuntime* flr = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); - auto rewrite_subgraph = [flr](std::unique_ptr* subgraph, - std::vector* input_permutation, - std::vector* output_permutation, - NodeDef* node) { - // Optimize the subgraph. - OptimizeGraph(flr, subgraph); - - const int num_args = input_permutation->size(); - std::vector const_args(num_args); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); - - DataTypeVector arg_types(num_args); - TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); - - // Compute a permutation of the arguments such that the constant arguments - // are first. - const int num_consts = - std::count(const_args.begin(), const_args.end(), true); - - const int num_resources = - std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE); - const int num_nonconsts = num_args - num_resources - num_consts; - if (num_nonconsts < 0) { - return errors::Internal("num_nonconsts should be >= 0, was ", - num_nonconsts); - } - - int const_pos = 0; - int arg_pos = num_consts; - int resource_pos = num_consts + num_nonconsts; - for (int i = 0; i < num_args; ++i) { - if (const_args[i]) { - if (arg_types[i] == DT_RESOURCE) { - return errors::Internal( - "Resource arguments cannot be constant (argument ", i, ")"); + auto rewrite_subgraph = + [flr](const std::vector& arg_source_tensors, + std::unique_ptr* subgraph, + std::vector* input_permutation, + std::vector* output_permutation, NodeDef* node) { + // Optimize the subgraph. + OptimizeGraph(flr, subgraph); + + const int num_args = input_permutation->size(); + std::vector const_args(num_args); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); + + DataTypeVector arg_types(num_args); + TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); + + // Compute a permutation of the arguments such that the constant + // arguments are first. + const int num_consts = + std::count(const_args.begin(), const_args.end(), true); + + const int num_resources = + std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE); + const int num_nonconsts = num_args - num_resources - num_consts; + if (num_nonconsts < 0) { + return errors::Internal("num_nonconsts should be >= 0, was ", + num_nonconsts); } - (*input_permutation)[i] = const_pos; - ++const_pos; - } else if (arg_types[i] == DT_RESOURCE) { - (*input_permutation)[i] = resource_pos; - ++resource_pos; - } else { - (*input_permutation)[i] = arg_pos; - ++arg_pos; - } - } - // Renumber argument nodes in the graph. - TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation)); - - // TODO(phawkins): add a forward is-constant analysis, similarly split - // outputs into host-memory constants and device-memory non-constants. - - AddNodeAttr(kXlaCompiledKernelAttr, true, node); - AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); - AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); - return Status::OK(); - }; + int const_pos = 0; + int arg_pos = num_consts; + int resource_pos = num_consts + num_nonconsts; + for (int i = 0; i < num_args; ++i) { + if (const_args[i]) { + if (arg_types[i] == DT_RESOURCE) { + return errors::Internal( + "Resource arguments cannot be constant (argument ", i, ")"); + } + (*input_permutation)[i] = const_pos; + ++const_pos; + } else if (arg_types[i] == DT_RESOURCE) { + (*input_permutation)[i] = resource_pos; + ++resource_pos; + } else { + (*input_permutation)[i] = arg_pos; + ++arg_pos; + } + } - TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, - rewrite_subgraph, flags->tf_xla_parallel_checking, - /*reuse_existing_functions=*/false, &graph_out, library)); + // Renumber argument nodes in the graph. + TF_RETURN_IF_ERROR( + RenumberArguments(subgraph->get(), *input_permutation)); + + // TODO(phawkins): add a forward is-constant analysis, similarly split + // outputs into host-memory constants and device-memory non-constants. + + AddNodeAttr(kXlaCompiledKernelAttr, true, node); + AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); + AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); + return Status::OK(); + }; + + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, + rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out, + library), + "EncapsulateSubgraphsPass failed"); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 34be4409a381197d2191e083727aa8d48ab8cd63..926589546fec72048485d30966f31b24e44b1245 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -28,6 +28,9 @@ limitations under the License. namespace tensorflow { // A rewriting function to apply to each subgraph during encapsulation. +// 'arg_source_tensors' are the tensors corresponding to the arguments in the +// original source graph (*not* 'graph'). +// // 'graph' is the subgraph. The rewriting may renumber the inputs and outputs; // 'input_permutation' is a mapping from old argument numbers to new argument // numbers, whereas 'output_permutation' is the same for outputs. Both @@ -37,6 +40,7 @@ namespace tensorflow { // The rewrite may also change the NodeDef's operator name, and that // name will be used as the name of the generated function. typedef std::function& arg_source_tensors, std::unique_ptr* graph, std::vector* input_permutation, std::vector* output_permutation, NodeDef* node_def)> RewriteSubgraphFn; @@ -61,10 +65,6 @@ typedef std::function* graph_out, FunctionLibraryDefinition* library); + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate -// subgraphs pass and that should in turn be compiled via _XlaLaunch operators. +// subgraphs pass and that should in turn be compiled via XlaLaunch operators. extern const char* const kXlaCompiledKernelAttr; // Does `node` have the kXlaCompiledKernelAttr attribute? diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 5ec24d39a2c40a766dbb0ec51ebe798de620e24b..4eb389e0c653f2d32c17f448687f865a44a11b96 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -511,7 +511,6 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_out, lib_def.get()); if (!s.ok()) return s; @@ -560,8 +559,9 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { Node* b = Input(b1.opts().WithName("B")); // Give nodes 'c' and 'd' names that collide after lowercasing. Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); - Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr( - "_encapsulate", "F1")); + Node* d = Binary(b, c, + b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); Binary(a, d, b1.opts().WithName("E")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } @@ -614,8 +614,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { Node* c = Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr( "_encapsulate", "F1")); - Node* d = - Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr( + Node* d = Binary(b, c, + b1.opts().WithName("D").WithControlInput(control).WithAttr( "_encapsulate", "F2")); Binary(a, d, b1.opts().WithName("E")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); @@ -707,7 +707,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", "_outside", graph_before_encapsulation, - /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, + /*rewrite_subgraph_fn=*/{}, /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; @@ -721,47 +721,6 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } -TEST(EncapsulateSubgraphsTest, ParallelChecking) { - Scope root = Scope::NewRootScope().ExitOnError().WithDevice( - "/job:localhost/replica:0/task:0/cpu:0"); - auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); - auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); - auto add1 = ops::Add(root.WithOpName("add1"), x1, x2); - add1.node()->AddAttr("_cluster", "cluster1"); - auto add2 = ops::Add(root.WithOpName("add2"), add1, x2); - add2.node()->AddAttr("_cluster", "cluster1"); - auto out = ops::Mul(root.WithOpName("mul"), x1, add2); - - Graph graph_before_encapsulation(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation)); - - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - std::unique_ptr graph; - TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", "_outside", graph_before_encapsulation, - /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true, - /*reuse_existing_functions=*/false, &graph, &library)); - - std::vector expected_nodes = { - "add1", "add2", "cluster1", "cluster1_parallel_check/_0", - "mul", "x1", "x2"}; - EXPECT_EQ(expected_nodes, GraphNodes(*graph)); - - std::vector> expected_edges = { - {"add1:0", "add2:0"}, - {"add2:0", "cluster1_parallel_check/_0:0"}, - {"cluster1:0", "cluster1_parallel_check/_0:1"}, - {"cluster1_parallel_check/_0:0", "mul:1"}, - {"x1:0", "add1:0"}, - {"x1:0", "cluster1:0"}, - {"x1:0", "mul:0"}, - {"x2:0", "add1:1"}, - {"x2:0", "add2:1"}, - {"x2:0", "cluster1:1"}, - }; - EXPECT_EQ(expected_edges, GraphEdges(*graph)); -} - const Node* FindNodeByName(const Graph& graph, const string& name) { for (const Node* node : graph.nodes()) { if (node->name() == name) return node; @@ -798,7 +757,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_encapsulate", "_outside", graph_before, /*rewrite_subgraph_fn=*/ - [&guaranteed_consts](std::unique_ptr* graph_ptr, + [&guaranteed_consts](const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, std::vector* input_permutation, std::vector* output_permutation, NodeDef* call_def) { @@ -814,7 +774,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { } return Status::OK(); }, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_after, &library)); EXPECT_EQ(2, guaranteed_consts); } @@ -843,7 +802,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_encapsulate", "_outside", graph_before, /*rewrite_subgraph_fn=*/ - [&guaranteed_consts](std::unique_ptr* graph_ptr, + [&guaranteed_consts](const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, std::vector* input_permutation, std::vector* output_permutation, NodeDef* call_def) { @@ -859,7 +819,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { } return Status::OK(); }, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_after, &library)); // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const // and another non-const, so overall non-const. @@ -1050,7 +1009,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { .WithAttr("_outside", "O1")); Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, shape2.opts()); - Node* h = Binary(ops::NodeOut(recv2, 0), e, + Node* h = Binary(ops::NodeOut(recv2, 1), e, shape2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") @@ -1075,7 +1034,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", - {"D:o:0", "F:o:0"}, + {"F:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"ancestors", @@ -1123,13 +1082,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, b2.opts()); - Node* g = Binary(e, ops::NodeOut(recv2, 1), + Node* g = Binary(e, ops::NodeOut(recv2, 0), b2.opts() .WithName("G") .WithControlInputs({recv2, e}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - Node* h = Binary(ops::NodeOut(recv2, 0), e, + Node* h = Binary(ops::NodeOut(recv2, 1), e, b2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 9d856346eca06b1f2ed8bf450a4265a6a589b818..251a07304eaeb21f1313d7a6ef6af668f99d8551 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -148,7 +148,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; - options.device_type = &cache->device_type(); + options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); @@ -166,6 +166,14 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { } XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; + // Optimization: don't resolve constants. If we resolve constants we never + // emit them on the device, meaning that if they are needed by a following + // computation the host has to transfer them. + compile_options.resolve_compile_time_constants = false; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. + compile_options.always_return_tuple = false; + OP_REQUIRES_OK( ctx, cache->Compile(options, function_, constant_args, variables, ctx, &kernel, &executable, &compile_options)); @@ -256,10 +264,9 @@ XlaLocalLaunchOp::~XlaLocalLaunchOp() { VLOG(1) << "XlaLocalLaunchOp destroyed"; } -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU), - XlaLocalLaunchOp); +REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") +REGISTER_KERNEL_BUILDER(Name("XlaLaunch") .Device(DEVICE_GPU) .HostMemory("constants") .HostMemory("resources"), diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 5d211f4d733d8d807426e62dd116092799184f35..5b6692f523658749f7ef48f9d7d89e97d4ce8b09 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -16,18 +16,6 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -cc_library( - name = "encapsulate_subgraphs_pass_flags", - srcs = ["encapsulate_subgraphs_pass_flags.cc"], - hdrs = ["encapsulate_subgraphs_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "mark_for_compilation_pass_flags", srcs = ["mark_for_compilation_pass_flags.cc"], diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc deleted file mode 100644 index 856475f12c8a411cd80c1c1859323304ca4029e0..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static EncapsulateSubgraphsPassFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new EncapsulateSubgraphsPassFlags; - flags->tf_xla_parallel_checking = false; - flag_list = new std::vector({ - Flag("tf_xla_parallel_checking", &flags->tf_xla_parallel_checking, - "Debug tool. Runs both JIT-compiled and interpreted graphs in " - "parallel and verifies they produce the same outputs."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -void AppendEncapsulateSubgraphsPassFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the EncapsulateSubgraphsPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h deleted file mode 100644 index d371bd269dbdfbf737d81490fb877fcf88661a8f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ - -// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -void AppendEncapsulateSubgraphsPassFlags( - std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -typedef struct { - bool tf_xla_parallel_checking; // Debug tool. Runs both JIT-compiled and - // interpreted graphs in parallel and verifies - // they produce the same outputs. -} EncapsulateSubgraphsPassFlags; - -// Return a pointer to the EncapsulateSubgraphsPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 8e2ee0f1d71bc17b4c12c792c38002af4f9eb5eb..8c3882116dd4f048ea3e32c037bf4139c67a3eb9 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -41,9 +42,6 @@ limitations under the License. namespace tensorflow { -const char* const kXlaClusterAttr = "_XlaCluster"; -const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; - namespace { bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { @@ -60,6 +58,14 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { return false; } } + + // XLA does not offer guaranteed aliasing between the input and output of the + // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave + // such nodes out of XLA clusters. + if (HasForwardedRefInput(node)) { + return false; + } + return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } @@ -165,16 +171,6 @@ bool IsCompilableCall(const NodeDef& call_def, return true; } -// Returns the DeviceType corresponding to 'device'. -Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device, &parsed)) { - return errors::Internal("Malformed assigned device '", device, "'"); - } - *device_type = DeviceType(parsed.type); - return Status::OK(); -} - // Tests whether `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node) { return std::find(node.input_types().begin(), node.input_types().end(), @@ -183,18 +179,11 @@ bool HasResourceInputOrOutput(const Node& node) { DT_RESOURCE) != node.output_types().end(); } -struct NodeCompare { - bool operator()(const Node* a, const Node* b) const { - return a->id() < b->id(); - } -}; -using OrderedNodeSet = std::set; - // Returns true if the op can be decomposed into XLA ops for which // there are fusable elemental implementations. // -// TODO(hpucha): Consider a black list instead of a white list as -// implemented below. +// TODO(hpucha): Remove this code since this functionality is subsumed by +// Grappler XlaFusionOptimizer. bool IsXlaFusable(const NodeDef& node) { static const std::unordered_set* elementwise_ops = new std::unordered_set( @@ -364,7 +353,7 @@ Status FindCompilationCandidates( for (Node* node : graph.op_nodes()) { sorted_nodes.push_back(node); } - std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare()); + std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); for (Node* node : sorted_nodes) { VLOG(2) << "Fuel: " << fuel; @@ -379,9 +368,13 @@ Status FindCompilationCandidates( DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(node->assigned_device_name(), &device_type)); + DeviceToDeviceType(node->assigned_device_name(), &device_type)); - if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue; + if (is_compilable_fn && !is_compilable_fn(node, device_type)) { + VLOG(2) << "Compilation rejected node: not compilable " << node->name() + << ": " << node->type_string(); + continue; + } const XlaOpRegistry::DeviceRegistration* registration; CHECK( @@ -430,46 +423,6 @@ struct Cluster { int representative = -1; }; -// Returns a string describing how an edge from src to dst would -// create a cycle. -string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src, - int dst) { - int32 max_path_size = graph.num_node_ids() + 1; - std::vector path(max_path_size); - int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data()); - if (path_size == 0) { - return ""; - } - - auto node_name = [&cycles, &graph](int node_id) { - if (!FastBoundsCheck(node_id, graph.num_node_ids())) { - return string("(null)"); - } - auto* node = graph.FindNodeId(node_id); - if (node == nullptr) { - return string("(null)"); - } - return node->name(); - }; - - string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); - path.resize(path_size); - for (int32 node_id : path) { - string ascii_art; - if (node_id == dst) { - ascii_art = "+-> "; - } else if (node_id != src) { - ascii_art = "| "; - } else { - ascii_art = "+-- "; - } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); - } - return description; -} - } // anonymous namespace bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { @@ -575,84 +528,13 @@ Status MarkForCompilationPass::RunImpl( : Env::Default(), is_compilable_fn, &compilation_candidates)); - GraphCycles cycles; - for (int i = 0; i < graph->num_node_ids(); ++i) { - // We rely on the node IDs in the cycle detection graph being consecutive - // integers starting from 0. - CHECK_EQ(i, cycles.NewNode()); + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + return Status::OK(); } - // Compute the loop structure of the graph. - std::vector control_flow_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); - - // The clustering code must avoid adding cycles to the graph to prevent - // deadlock. However, the graph may contain loops, which would trigger the - // cycle detection code. To handle loops, we alter the structure of the cycle - // detection graph, disconnecting each loop from the enclosing graph. - // Specifically, we: - // * add a new "frame" node for each loop. - // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges - // to/from the corresponding frame node. In essence, we collapse the loop - // into a single node for the purpose of cycle detection in the enclosing - // graph. - // * the body of the loop should now be disconnected from the rest of the - // graph; we make it acyclic by breaking loop backedges (edges outgoing from - // "NextIteration" nodes. - - // Map from frame name strings to node IDs in the cycle detection graph. - std::unordered_map frame_nodes; - - // Get the cycle graph node ID for frame 'frame_name', or add one if none - // exists. - auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) { - int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; - if (frame_id < 0) { - // The emplace succeeded; we have not allocated a frame node yet. - frame_id = cycles.NewNode(); - } - return frame_id; - }; - - for (Edge const* edge : graph->edges()) { - if (edge->dst()->IsEnter()) { - // Lift edges to an "Enter" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->dst()->id()].frame_name; - int dst = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(edge->src()->id(), dst)) { - return errors::Internal( - "Cycle detected when adding enter->frame edge: ", - DescribeCycle(cycles, *graph, edge->src()->id(), dst)); - } - continue; - } - if (edge->src()->IsExit()) { - // Lift edges from an "Exit" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->src()->id()].frame_name; - int src = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(src, edge->dst()->id())) { - return errors::Internal( - "Cycle detected when adding frame->exit edge: ", - DescribeCycle(cycles, *graph, src, edge->dst()->id())); - } - // Drop the original edge. - continue; - } - if (edge->src()->IsNextIteration()) { - // Break loop back-edges. - continue; - } - if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) { - // This should never happen. All cycles in the graph should contain - // a control flow operator. - return errors::Internal( - "Found cycle in graph without control flow operator during XLA " - "compilation: ", - DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); - } - } + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); // Each compilation candidate belongs to a cluster. The cluster's // representative @@ -670,6 +552,9 @@ Status MarkForCompilationPass::RunImpl( // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. + // + // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for + // example, from the Grappler fusion pass). while (!worklist.empty()) { int from = worklist.front()->Get().representative; worklist.pop_front(); @@ -778,7 +663,7 @@ Status MarkForCompilationPass::RunImpl( // compilation. DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(n->assigned_device_name(), &device_type)); + DeviceToDeviceType(n->assigned_device_name(), &device_type)); const XlaOpRegistry::DeviceRegistration* registration; XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 703d8825d74ced8d4d69c31ccd730adc89a8bffe..772c92d369e67f431b5d030d1d5cdc5ae2700d39 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -633,5 +633,52 @@ TEST(XlaCompilationTest, ConstOp) { } } +TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output add = ops::Add(root.WithOpName("add"), neg, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + +TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output identity = ops::Negate(root.WithOpName("identity"), neg); + Output add = ops::Add(root.WithOpName("add"), identity, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, + {"identity", cluster_name}, + {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 07320b43dab790e6cda5e85688bdacf48a35adc4..f2473d98ffd5dae55983e601b8d2d65af6a6d54c 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -17,7 +17,7 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("_XlaLaunch") +REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") .Input("args: Targs") @@ -28,7 +28,7 @@ REGISTER_OP("_XlaLaunch") .Attr("Tresults: list(type) >= 0") .Attr("function: func") // XLA random-number generation ops are stateful. - // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch. + // TODO(phawkins): create stateful and non-stateful variants of XlaLaunch. .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5628b12a27c9ed052e22c784517a07f2c1c059a --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -0,0 +1,188 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_cluster_util.h" + +#include + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +const char* const kXlaClusterAttr = "_XlaCluster"; +const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; + +namespace { +// Returns a string describing how an edge from src to dst would +// create a cycle. +string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, + int dst) { + int32 max_path_size = graph.num_node_ids() + 1; + std::vector path(max_path_size); + int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data()); + if (path_size == 0) { + return ""; + } + + auto node_name = [cycles, &graph](int node_id) { + if (!FastBoundsCheck(node_id, graph.num_node_ids())) { + return string("(null)"); + } + auto* node = graph.FindNodeId(node_id); + if (node == nullptr) { + return string("(null)"); + } + return node->name(); + }; + + string description; + strings::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); + path.resize(path_size); + for (int32 node_id : path) { + string ascii_art; + if (node_id == dst) { + ascii_art = "+-> "; + } else if (node_id != src) { + ascii_art = "| "; + } else { + ascii_art = "+-- "; + } + strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + } + return description; +} + +bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); } + +} // namespace + +Status DeviceToDeviceType(const string& device, DeviceType* device_type) { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(device, &parsed)) { + return errors::Internal("Malformed assigned device '", device, "'"); + } + *device_type = DeviceType(parsed.type); + return Status::OK(); +} + +bool HasForwardedRefInput(const Node& node) { + if (AlwaysForwardsRefInput(node)) { + for (const Edge* incoming_edge : node.in_edges()) { + if (incoming_edge->IsControlEdge()) { + continue; + } + + Node* incoming_node = incoming_edge->src(); + if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) { + VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input " + << incoming_node->name() << " " << incoming_node->type_string(); + return true; + } + } + } + return false; +} + +Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { + for (int i = 0; i < graph->num_node_ids(); ++i) { + // We rely on the node IDs in the cycle detection graph being consecutive + // integers starting from 0. + CHECK_EQ(i, cycles->NewNode()); + } + + // Compute the loop structure of the graph. + std::vector control_flow_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); + + // The clustering code must avoid adding cycles to the graph to prevent + // deadlock. However, the graph may contain loops, which would trigger the + // cycle detection code. To handle loops, we alter the structure of the cycle + // detection graph, disconnecting each loop from the enclosing graph. + // Specifically, we: + // * add a new "frame" node for each loop. + // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges + // to/from the corresponding frame node. In essence, we collapse the loop + // into a single node for the purpose of cycle detection in the enclosing + // graph. + // * the body of the loop should now be disconnected from the rest of the + // graph; we make it acyclic by breaking loop backedges (edges outgoing from + // "NextIteration" nodes. + + // Map from frame name strings to node IDs in the cycle detection graph. + std::unordered_map frame_nodes; + + // Get the cycle graph node ID for frame 'frame_name', or add one if none + // exists. + auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) { + int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; + if (frame_id < 0) { + // The emplace succeeded; we have not allocated a frame node yet. + frame_id = cycles->NewNode(); + } + return frame_id; + }; + + for (Edge const* edge : graph->edges()) { + if (edge->dst()->IsEnter() || edge->src()->IsExit()) { + const char* src_type = "pre-enter"; + const char* dst_type = "post-exit"; + int src = edge->src()->id(); + int dst = edge->dst()->id(); + + if (edge->dst()->IsEnter()) { + // Lift edges to an "Enter" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->dst()->id()].frame_name; + dst = GetOrAddFrameNodeId(frame_name); + dst_type = "frame"; + } + + if (edge->src()->IsExit()) { + // Lift edges from an "Exit" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->src()->id()].frame_name; + src = GetOrAddFrameNodeId(frame_name); + src_type = "frame"; + } + + if (!cycles->InsertEdge(src, dst)) { + return errors::Internal( + "Cycle detected when adding ", src_type, "->", dst_type, + " edge: ", DescribeCycle(cycles, *graph, src, dst)); + } + // Drop the original edge. + continue; + } + if (edge->src()->IsNextIteration()) { + // Break loop back-edges. + continue; + } + if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) { + // This should never happen. All cycles in the graph should contain + // a control flow operator. + return errors::Internal( + "Found cycle in graph without control flow operator during XLA " + "compilation: ", + DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h new file mode 100644 index 0000000000000000000000000000000000000000..bcce082aaf6044ff0654efa4d78c0f493a350d00 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for clustering compilable graph nodes via XLA. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +// The attribute that marks nodes to be grouped into functions by the +// encapsulate subgraphs pass. +extern const char* const kXlaClusterAttr; + +// The attribute that marks nodes in a cluster to be placed outside the xla +// compilation by the encapsulate subgraphs pass. +extern const char* const kXlaOutsideCompilationAttr; + +using OrderedNodeSet = std::set; + +// Returns the DeviceType corresponding to 'device'. +Status DeviceToDeviceType(const string& device, DeviceType* device_type); + +// Returns true if `node` has a ref tensor input that it forwards to its output. +bool HasForwardedRefInput(const Node& node); + +// Creates a graph representation to enable cycle detection when clustering. +// This representation handles loops in graph by disconnecting each loop from +// the enclosing graph. +Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2cb351e1ecdb4523a8652886af156540e4736b18 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_cluster_util.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output enter = + ops::internal::Enter(root.WithOpName("enter"), a, "only_frame"); + Output exit = ops::internal::Exit(root.WithOpName("exit"), enter); + Output b = ops::Add(root.WithOpName("b"), a, exit); + + FixupSourceAndSinkEdges(root.graph()); + + GraphCycles cycles; + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); +} + +TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output enter_0 = + ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0"); + Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0); + Output enter_1 = + ops::internal::Enter(root.WithOpName("enter_1"), a, "frame_1"); + Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1); + Output b = ops::Add(root.WithOpName("b"), a, exit_1); + + FixupSourceAndSinkEdges(root.graph()); + + GraphCycles cycles; + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 6430975335f5eef5b53c80213e6090ffd6166a91..54a41a4daa790401c797277e7eaab531dd34ac80 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -40,7 +40,23 @@ namespace tensorflow { XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} -XlaCompilationCache::~XlaCompilationCache() = default; +XlaCompilationCache::~XlaCompilationCache() { + // Ensure any use of our programs have completed by waiting for all stream + // executors to complete. + for (auto* executor : client_->backend().stream_executors()) { + bool ok = executor->SynchronizeAllActivity(); + if (!ok) { + LOG(ERROR) << "Error synchronizing activity while waiting for all " + "programs to complete"; + } + } + // TODO(b/110813685): Think about the program ownership model. Programs are + // currently owned by the compilation cache which means we must wait for + // program completion in the destructor. There are multiple compilation caches + // around, which complicates things a little. Perhaps having programs be + // shared_ptrs (an invasive change) would make the model easier to reason + // about? +} string XlaCompilationCache::DebugString() { return "XLA JIT compilation cache"; @@ -122,8 +138,7 @@ Status XlaCompilationCache::BuildSignature( namespace { -// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch -// op. +// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op. Status BuildArguments(const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index ab644ff5a61c407b246b97af5328bf5cd8c1893b..baccea2d6a793df8c5cf8c8941706d41d2c044ca 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -61,14 +61,18 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; TF_RET_CHECK(stream); - VLOG(2) << "Executing computation."; + VLOG(2) << "Executing computation: " << name(); + for (const xla::ShapedBuffer* arg : launch_context.arguments()) { + VLOG(2) << name() << ": " << *arg; + } xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(client->backend().memory_allocator()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(ctx->step_id()); - auto run_result = executable->Run(launch_context.arguments(), run_options); + xla::StatusOr run_result = + executable->Run(launch_context.arguments(), run_options); TF_RETURN_IF_ERROR(run_result.status()); launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); @@ -151,8 +155,7 @@ Status XlaCompileOnDemandOp::Compile( core::ScopedUnref cache_ref(cache); XlaCompiler::Options options; - DeviceType device_type = metadata.jit_device_type(); - options.device_type = &device_type; + options.device_type = metadata.jit_device_type(); options.client = metadata.client(); options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); @@ -160,6 +163,13 @@ Status XlaCompileOnDemandOp::Compile( XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; + // Optimization: don't resolve constants. If we resolve constants we never + // emit them on the device, meaning that if they are needed by a following + // computation the host has to transfer them. + compile_options.resolve_compile_time_constants = false; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. + compile_options.always_return_tuple = false; std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index 23c6f3903f841a6c39104983c6f7f409757a7319..7cc3d0e007ba2974fbfbe6fbabc4aa08f9fa910f 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -29,11 +29,8 @@ limitations under the License. namespace tensorflow { // An OpKernel that compiles an op to an XLA computation and runs it. Unlike -// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a +// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a // vanilla TensorFlow op as long as the bridge supports it. -// -// Importantly _XlaLaunch assumes all input and output tensors are on the host, -// whereas XlacompileOnDemandOp works with tensors in device memory. class XlaCompileOnDemandOp : public OpKernel { public: explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index ea9e0366043a4a64bfe43703c55d4470693bbac8..43648402f65c656b6b4eb2e83e61ce45f1c73669 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -50,11 +50,12 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR( - XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, - name_prefix, registration, - /*transfer_as_literal=*/false, - /*shape_representation_fn=*/{}, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, + DEVICE_CPU_XLA_JIT, options, name_prefix, + registration, + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index f13b46c532e6008477849f2e06887901c90038ab..ed007d603ea1b3d27dd25f00726261cdd029c20c 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" @@ -105,6 +106,25 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( return alloc_ptr; } +namespace { + +// Default PaddedShapeFn implementation that simply returns the unpadded +// on-device shape. This is accurate for CPU and GPU devices that neither +// transpose nor pad tensors. +Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { + const tensorflow::XlaTensor* xla_tensor = + tensorflow::XlaTensor::FromTensor(&tensor); + if (xla_tensor == nullptr) { + return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape); + } + + const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); + *shape = shaped_buffer.on_device_shape(); + return Status::OK(); +} + +} // namespace + /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, @@ -112,7 +132,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( const XlaOpRegistry::DeviceRegistration& registration, bool transfer_as_literal, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - std::unique_ptr* device) { + const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; @@ -133,17 +153,20 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( device->reset(new XlaDevice( options, attrs, device_ordinal, DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal, shape_representation_fn)); + platform.ValueOrDie(), transfer_as_literal, shape_representation_fn, + padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn)); return Status::OK(); } XlaDevice::Metadata::Metadata( int device_ordinal, se::Platform* platform, const DeviceType& device_type, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + PaddedShapeFn padded_shape_fn) : device_ordinal_(device_ordinal), device_type_(device_type), platform_(platform), - shape_representation_fn_(std::move(shape_representation_fn)) {} + shape_representation_fn_(std::move(shape_representation_fn)), + padded_shape_fn_(std::move(padded_shape_fn)) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -178,10 +201,11 @@ XlaDevice::XlaDevice( const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, se::Platform* platform, bool transfer_as_literal, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn) + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn) : LocalDevice(options, attrs), xla_metadata_(device_ordinal, platform, jit_device_name, - shape_representation_fn), + shape_representation_fn, padded_shape_fn), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index d5d345d43b16c43c7a202791b2604b39d29e8cdb..02e88ee6793e984a7b782790f8011cbcbc5a5026 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -45,13 +45,19 @@ namespace tensorflow { class XlaDevice : public LocalDevice { public: + // Given a tensor, sets `xla::Shape*` the shape of tensor's representation + // on device, fully padded. On error, the contents of `xla::Shape*` + // are undefined. + typedef std::function PaddedShapeFn; + // Wrapper class to store metadata about the XlaDevice, where it can be // retrieved e.g., when lazily creating the XlaCompilationCache device. class Metadata { public: Metadata(int device_ordinal, se::Platform* platform, const DeviceType& device_type, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + PaddedShapeFn padded_shape_fn); // The index of the device on this host. int device_ordinal() const; @@ -62,12 +68,14 @@ class XlaDevice : public LocalDevice { const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { return shape_representation_fn_; } + const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; } private: const int device_ordinal_; const DeviceType device_type_; se::Platform* platform_; // Not owned. XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + PaddedShapeFn padded_shape_fn_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -81,6 +89,8 @@ class XlaDevice : public LocalDevice { // 'transfer_as_literal' is true if device<->host transfers must be done using // XLA's TransferLiteral{To,From}Device interface. If false, we can use // ThenMemcpy instead. + // If padded_shape_fn is empty, a default implementation that returns + // the on-host shape is used. static Status Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, @@ -88,12 +98,16 @@ class XlaDevice : public LocalDevice { const XlaOpRegistry::DeviceRegistration& registration, bool transfer_as_literal, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - std::unique_ptr* device); + const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device); + // Creates a new XLA Device. + // If padded_shape_fn is empty, a default implementation that returns + // the logical on-device shape without padding is used. XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, se::Platform* platform, bool transfer_as_literal, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn); + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -110,6 +124,7 @@ class XlaDevice : public LocalDevice { Tensor* tensor) override; xla::LocalClient* client() const; + const Metadata& metadata() { return xla_metadata_; } xla::StatusOr GetStream(); // If not already set, create and set GpuDeviceInfo. diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index ff30b62bad782f281bcd25275521ed8b0c4c0bfd..3bbf97afadd2c8a70add16b748a35832a2ef8538 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -54,38 +54,66 @@ XlaTransferManager::XlaTransferManager( client_(client), transfer_manager_(client->backend().transfer_manager()), transfer_as_literal_(transfer_as_literal), - shape_representation_fn_(std::move(shape_representation_fn)) {} + shape_representation_fn_(std::move(shape_representation_fn)) { + if (!shape_representation_fn_) { + shape_representation_fn_ = + [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { return shape; }; + } +} Status XlaTransferManager::TransferLiteralToDevice( const Tensor& host_tensor, Tensor* device_tensor) const { - xla::Literal literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal)); - VLOG(1) << "Transfer to device as literal: " << literal.ToString(); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + // Create a reference to hold onto host_tensor until after the literal has + // been transferred. Also make sure the literal exists until the function + // asynchronously completes, as it will be wrapped in an xla::LiteralSlice. + TensorReference ref(host_tensor); + auto literal = std::make_shared( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(device_tensor)->shaped_buffer(); - return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal, - shaped_buffer); + VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " " + << shaped_buffer.ToString(); + TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( + stream_, *literal, shaped_buffer)); + // Unref the host tensor, and capture the literal shared_ptr too so it goes + // out of scope when the lambda completes. + stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); + return Status::OK(); } -Status XlaTransferManager::TransferLiteralFromDevice( - Tensor* host_tensor, const Tensor& device_tensor) const { +void XlaTransferManager::TransferLiteralFromDevice( + Tensor* host_tensor, const Tensor& device_tensor, + const StatusCallback& done) const { const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - transfer_manager_->TransferLiteralFromDevice( - stream_->parent(), shaped_buffer)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString(); - Tensor tensor; - TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); - // Reshape the tensor back to its declared shape. - if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { - return errors::Internal( - "Tensor::CopyFrom failed when copying from XLA device to CPU"); - } - return Status::OK(); + TensorReference ref(device_tensor); + transfer_manager_->TransferLiteralFromDevice( + stream_, shaped_buffer, + [=, &shaped_buffer]( + xla::StatusOr > literal_or) { + ref.Unref(); + done([&]() -> Status { + TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or)); + VLOG(1) << "Transfer from device as literal: " << literal->ToString() + << " " << shaped_buffer.ToString(); + Tensor tensor; + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); + // Reshape the tensor back to its declared shape. + Status status; + if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { + status = errors::Internal( + "Tensor::CopyFrom failed when copying from XLA device to CPU"); + } + return status; + }()); + }); } void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, @@ -98,7 +126,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, << " " << reinterpret_cast( device_tensor->tensor_data().data()) - << " " << cpu_tensor->NumElements(); + << " " << cpu_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); const int64 total_bytes = cpu_tensor->TotalBytes(); @@ -106,24 +136,23 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); - TensorShape shape; - if (shape_representation_fn_) { - shape = shape_representation_fn_(device_tensor->shape(), - device_tensor->dtype()); - } else { - shape = device_tensor->shape(); + Status status; + xla::StatusOr shape_or_status = shape_representation_fn_( + device_tensor->shape(), device_tensor->dtype()); + if (!shape_or_status.ok()) { + done(shape_or_status.status()); + return; } + TensorShape shape = shape_or_status.ValueOrDie(); if (!xla_tensor->has_shaped_buffer()) { - Status s = xla_tensor->AllocateShapedBuffer( + status = xla_tensor->AllocateShapedBuffer( device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal()); - if (!s.ok()) { - done(s); - return; + if (!status.ok()) { + return done(status); } } - Status status; if (transfer_as_literal_) { Tensor reshaped_cpu_tensor; if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { @@ -165,7 +194,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, device_tensor->tensor_data().data()) << " " << reinterpret_cast(cpu_tensor->tensor_data().data()) - << device_tensor->NumElements(); + << " " << device_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); const int64 total_bytes = cpu_tensor->TotalBytes(); se::DeviceMemoryBase dev_src_ptr = @@ -174,7 +205,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, Status status; if (transfer_as_literal_) { - status = TransferLiteralFromDevice(cpu_tensor, *device_tensor); + TransferLiteralFromDevice(cpu_tensor, *device_tensor, done); + return; } else { stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. @@ -184,9 +216,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, "Failed to complete data transfer on stream %p: %s", stream_, block_status.error_message().c_str()); } + done(status); } - - done(status); return; } @@ -194,6 +225,42 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } +void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { + // Perform memory allocation now, and enqueue the device-to-device transfer. + Status status = [&]() -> Status { + if (src_tensor.NumElements() == 0) { + return Status::OK(); + } + XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor); + XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor); + CHECK(xla_src && xla_dst) + << "Missing destination tensor for device-to-device copy"; + if (!xla_dst->has_shaped_buffer()) { + TF_ASSIGN_OR_RETURN( + TensorShape shape, + shape_representation_fn_(src_tensor.shape(), src_tensor.dtype())); + TF_RETURN_IF_ERROR( + xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, + stream_->parent()->device_ordinal())); + } + auto from_iter = xla_src->shaped_buffer().buffers().begin(); + auto to_iter = xla_dst->shaped_buffer().buffers().begin(); + for (auto end_iter = xla_src->shaped_buffer().buffers().end(); + from_iter != end_iter; ++from_iter, ++to_iter) { + stream_->ThenMemcpyD2D(&to_iter->second, from_iter->second, + to_iter->second.size()); + } + return Status::OK(); + }(); + if (!status.ok()) { + return done(status); + } else { + stream_->ThenDoHostCallback([=]() { done(Status::OK()); }); + } +} + XlaDeviceContext::XlaDeviceContext( se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, XlaCompiler::ShapeRepresentationFn shape_representation_fn) @@ -215,4 +282,10 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, done); } +void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { + manager_.CopyDeviceTensorToDevice(src_tensor, dst_tensor, done); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 9af9655868448ce5116db3611c5f88339135947e..c5c81d65fe0f4a2774aab9f742454467e052071e 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -55,13 +55,18 @@ class XlaTransferManager { void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); + + void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, + const StatusCallback& done); + se::Stream* stream() const { return stream_; } private: Status TransferLiteralToDevice(const Tensor& host_tensor, Tensor* device_tensor) const; - Status TransferLiteralFromDevice(Tensor* host_tensor, - const Tensor& device_tensor) const; + void TransferLiteralFromDevice(Tensor* host_tensor, + const Tensor& device_tensor, + const StatusCallback& done) const; // Stream obtained from a Device, used to transfer tensors between // CPU and device. @@ -72,7 +77,7 @@ class XlaTransferManager { xla::TransferManager* transfer_manager_; // True if we must use XLA's TransferManager for correct device transfers. const bool transfer_as_literal_; - const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -90,6 +95,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, + const StatusCallback& done); + se::Stream* stream() const override { return manager_.stream(); } private: diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index f68dba6b6a26c0c289fd8457ad143d62e5fb9a69..5ecb1afa7bcec910ca843ccd3a782745f2bb6ca8 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_ops.h" +#include + #include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/jit/xla_tensor.h" namespace tensorflow { @@ -26,4 +29,82 @@ void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) { << type_string() << " on an XLA device. This should never happen."; } +XlaAssignVariableOp::XlaAssignVariableOp(OpKernelConstruction* c) + : AsyncOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); +} + +void XlaAssignVariableOp::ComputeAsync(OpKernelContext* context, + DoneCallback done) { + OP_REQUIRES_ASYNC(context, dtype_ == context->input(1).dtype(), + errors::InvalidArgument( + "Variable and value dtypes don't match; respectively, ", + dtype_, " and ", context->input(1).dtype()), + done); + Var* variable = nullptr; + OP_REQUIRES_OK_ASYNC( + context, + LookupOrCreateResource( + context, HandleFromInput(context, 0), &variable, + [this, context](Var** ptr) { + *ptr = new Var(dtype_); + PersistentTensor unused; + Tensor* tmp; + AllocatorAttributes attr; + TF_RETURN_IF_ERROR(context->allocate_persistent( + dtype_, context->input(1).shape(), &unused, &tmp, attr)); + *(*ptr)->tensor() = *tmp; + return Status::OK(); + }), + done); + core::ScopedUnref s(variable); + + OP_REQUIRES_ASYNC(context, variable->tensor()->dtype() == dtype_, + errors::InvalidArgument( + "Trying to assign variable with wrong dtype. Expected ", + DataTypeString(variable->tensor()->dtype()), " got ", + DataTypeString(dtype_)), + done); + + const Tensor& value = context->input(1); + AllocatorAttributes attr; + + // Copying is unnecessary if we are the last user of the value tensor, we can + // just adopt the input tensor's buffer instead. + std::unique_ptr input_alias = context->forward_input( + 1, /*output_index=*/OpKernelContext::Params::kNoReservation, dtype_, + value.shape(), DEVICE_MEMORY, attr); + mutex_lock ml(*variable->mu()); + variable->is_initialized = true; + if (input_alias) { + *variable->tensor() = *input_alias; + done(); + return; + } + + // Need to copy, but maybe we can re-use variable's buffer? + if (!XlaTensor::RefCountIsOne(*variable->tensor()) || + !variable->tensor()->shape().IsSameSize(value.shape())) { + // Copy to new buffer + PersistentTensor unused; + Tensor* tmp; + OP_REQUIRES_OK_ASYNC(context, + context->allocate_persistent(dtype_, value.shape(), + &unused, &tmp, attr), + done); + *variable->tensor() = *tmp; + } + + XlaDeviceContext* device_context = + static_cast(context->op_device_context()); + + variable->Ref(); + device_context->CopyDeviceTensorToDevice( + value, variable->tensor(), [context, variable, done](Status status) { + variable->Unref(); + context->SetStatus(status); + done(); + }); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 65c0e8577f1d0324df9edbf6a805721436c04669..a605335a94f8687e0af4566f912b38dca9b5ac26 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,17 +23,21 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/core/kernels/queue_op.h" +#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" +#include "tensorflow/core/kernels/shape_ops.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { // Dummy OpKernel, used for kernels assigned to an XLA device that should be // compiled. Should never be called at runtime since such ops should be -// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an +// rewritten to a XlaLaunch op. If it is called, it means the placer placed an // operator on an XLA device but the compiler did not compile it. class XlaDeviceDummyOp : public OpKernel { public: @@ -41,8 +45,17 @@ class XlaDeviceDummyOp : public OpKernel { void Compute(OpKernelContext* ctx) override; }; +class XlaAssignVariableOp : public AsyncOpKernel { + public: + explicit XlaAssignVariableOp(OpKernelConstruction* c); + void ComputeAsync(OpKernelContext* context, DoneCallback done) override; + + private: + DataType dtype_; +}; + #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \ - REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") \ + REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \ .Device(DEVICE) \ .HostMemory("constants") \ .HostMemory("resources"), \ @@ -73,7 +86,93 @@ class XlaDeviceDummyOp : public OpKernel { \ REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ - ResourceHandleOp); + ResourceHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ + ReadVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + RankOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \ + XlaAssignVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ + ControlTriggerOp); \ + REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ + SwitchOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ + REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("LoopCond") \ + .Device(DEVICE) \ + .HostMemory("input") \ + .HostMemory("output"), \ + LoopCondOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \ + REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \ + .Device(DEVICE) \ + .HostMemory("size") \ + .HostMemory("handle"), \ + QueueSizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \ + QueueIsClosedOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); + +// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read +// and write the tensors they access in order to concatenate them into a batch. +// We would need either to call out to an XLA computation to perform the +// concatenation, or we would need to refactor those kernels so the splitting +// or merging is done in a separate operator that can be compiled. } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..74257b09a808a39454eace3b1a9bf57a2e071360 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -0,0 +1,328 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +namespace tensorflow { + +// Is 'node' an operator that consumes only the shape of its input, not the +// data itself? +static bool IsShapeConsumerOp(const Node& node) { + return node.type_string() == "Shape" || node.type_string() == "ShapeN" || + node.type_string() == "Rank" || node.type_string() == "Size"; +} + +// Returns true if the op can be decomposed into XLA ops for which +// there are fusable elemental implementations. +bool IsXlaFusable(const NodeDef& node) { + static const std::unordered_set* elementwise_ops = + new std::unordered_set( + {// tf2xla/kernels/aggregate_ops.cc + "AddN", + // tf2xla/kernels/binary_ops.cc + "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", + "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", + "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", + "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", + "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", + // tf2xla/kernels/unary_ops.cc + "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", + "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", + "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", + "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", + "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", + "Square", "Tan", "Tanh", "Real", "Imag", + // tf2xla/kernels/bcast_ops.cc + "BroadcastArgs", "BroadcastGradientArgs", + // tf2xla/kernels/bias_ops.cc + "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, + // tf2xla/kernels/cast_op.cc + "Cast", + // tf2xla/kernels/concat_op.cc + "Concat", "ConcatV2", "ConcatOffset", + // tf2xla/kernels/const_op.cc + "Const", + // tf2xla/kernels/elu_op.cc + "Elu", "EluGrad", "Selu", "SeluGrad", + // tf2xla/kernels/fill_op.cc + "Fill", + // tf2xla/kernels/identity_op.cc + "Identity", "IdentityN", "PreventGradient", + "StopGradient", /*"Snapshot",*/ + // tf2xla/kernels/index_ops.cc + "ArgMax", "ArgMin", + // tf2xla/kernels/mirror_pad_op.cc + "MirrorPad", + // tf2xla/kernels/one_hot_op.cc + "OneHot", + // tf2xla/kernels/pack_op.cc + "Pack", + // tf2xla/kernels/pad_op.cc + "Pad", "PadV2", + // tf2xla/kernels/relu_op.cc + "Relu", "Relu6", "ReluGrad", "Relu6Grad", + // tf2xla/kernels/reshape_op.cc + "Reshape", + // tf2xla/kernels/reverse_op.cc + "Reverse", "ReverseV2", + // tf2xla/kernels/reverse_sequence_op.cc + "ReverseSequence", + // tf2xla/kernels/shape_op.cc + "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", + "ZerosLike", "OnesLike", + // tf2xla/kernels/slice_op.cc + "Slice", + // tf2xla/kernels/split_op.cc + "Split", "SplitV", + // tf2xla/kernels/strided_slice_op.cc + "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", + // tf2xla/kernels/tile_ops.cc + "Tile", + // tf2xla/kernels/transpose_op.cc + "Transpose", "InvertPermutation", + // tf2xla/kernels/unpack_op.cc + "Unpack"}); + + return elementwise_ops->count(node.op()) > 0; +} + +Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) { + VLOG(2) << "Here at fusion optimizer"; + + // TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op. + // Once that happens, the expected interaction between this optimizer and when + // the global_jit_level is set is as follows: Fusion optimizer will replace + // appropriate fusion clusters with XlaLaunch nodes. The remaining graph can + // be further compiled where possible via mark_for_compilation_pass. Note that + // this might lead to inefficient clustering, and it is best to use either the + // fusion optimizer or the global_jit flag, and not combine the two. + + // Create a Graph out of GraphDef. This is required currently because the + // helpers around clustering, encapsulation etc work on graphs. + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + Graph graph(function_library); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); + shape_refiner.set_require_shape_inference_fns(false); + shape_refiner.set_disable_constant_propagation(true); + ImportGraphDefOptions options; + // Graph optimization happens at the late stage of graph execution, when + // colocation constraints are already validated previously and the device + // placement of nodes has also completed, so there is no need to validate + // colocation constraints again. + options.validate_colocation_constraints = false; + options.validate_shape = false; + TF_RETURN_IF_ERROR( + ImportGraphDef(options, item.graph, &graph, &shape_refiner)); + + // Collect nodes that can be fused via XLA, while ignoring those that + // explicitly ask for XLA: (*) nodes that are marked to be compiled + // explicitly. (*) nodes assigned to XLA device. + OrderedNodeSet compilation_candidates; + for (Node* node : graph.op_nodes()) { + // If there is a _XlaCompile annotation, ignore the node if it is + // true. Nodes are marked with this attr via experimental_jit_scope, and + // will be handled by the mark_for_compilation pass. + bool compile = false; + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); + if (status.ok() && compile) { + continue; + } + // If there is already a _XlaCluster annotation, ignore the node. Nodes are + // marked with this attr to indicate they are already part of a cluster and + // hence ignored. + status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile); + if (status.ok()) { + continue; + } + + // If there is an explicit XLA device placement, ignore the node. + DeviceType device_type(""); + TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); + if (device_type.type_string().find("XLA") != string::npos) continue; + + // Assume all fusable ops are registered. + // TODO(hpucha): Check for registration if possible. + if (!IsXlaFusable(node->def())) { + continue; + } + + // XLA does not offer guaranteed aliasing between the input and output of + // the XLA cluster so it can't implement the forward-tensor-ref semantic. + // Leave such nodes out of XLA clusters. + if (HasForwardedRefInput(*node)) { + continue; + } + + compilation_candidates.insert(node); + } + + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + *output = item.graph; + return Status::OK(); + } + + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + + // TODO(hpucha): Make clustering more robust. There are two known issues that + // we need to mitigate: (a) Non-resource variables can cause deadlocks + // when clustering changes order of execution. See b/77263461 for a specific + // example. (b) Queue operations can also cause deadlocks. See b/77261498 for + // example. + + struct Cluster { + // Identifies the node that represents this cluster in the cycle detection + // graph. + int representative = -1; + }; + + // Each compilation candidate belongs to a cluster. The cluster's + // representative names the node in the 'cycles' graph that represents the + // cluster. + std::vector> clusters(graph.num_node_ids()); + std::deque*> worklist; + for (Node* node : compilation_candidates) { + Cluster& cluster = clusters[node->id()].Get(); + cluster.representative = node->id(); + worklist.push_back(&clusters[node->id()]); + } + + // Repeatedly contract edges between clusters that are on the same device, + // provided the contraction would not create a cycle. This is a simplified + // version of the clustering in mark_for_compilation_pass that also deals with + // nodes that are explicitly tagged to be compiled/clustered. + while (!worklist.empty()) { + int from = worklist.front()->Get().representative; + worklist.pop_front(); + + Node* node_from = graph.FindNodeId(from); + if (node_from->IsControlFlow()) { + // Control flow nodes aren't compilation candidates and should never + // appear. + return errors::Internal( + "Found control flow node in clustering worklist: ", + node_from->type_string()); + } + for (int to : cycles.Successors(from)) { + if (to >= graph.num_node_ids()) { + // Node is a "frame" node that is present only in the cycle detection + // graph. No clustering is possible. + continue; + } + Node* node_to = graph.FindNodeId(to); + if (compilation_candidates.find(node_to) == + compilation_candidates.cend()) { + continue; + } + + // Do not cluster across devices. + if (node_from->def().device() != node_to->def().device()) { + VLOG(2) << "Devices " << node_from->def().device() << " " + << node_to->def().device(); + VLOG(2) << "Device names " << node_from->assigned_device_name() << " " + << node_to->assigned_device_name(); + continue; + } + + // Ops that consume shapes cannot be the root of a cluster. This is an + // optimization. + if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { + continue; + } + + // If contracting the edge would create a cycle, bail out. + // However, just because we can't merge the clusters now does not mean + // we won't be able to merge them in the future. + // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge + // 1->3. But if we first contract 1->2 then we can later contract 1->3. + if (!cycles.ContractEdge(from, to)) continue; + + // Merge the clusters. ContractEdge uses 'from' as the number of the + // merged node, so make sure 'from' is the chosen representative. + clusters[from].Merge(&clusters[to]); + + worklist.push_back(&clusters[from]); + break; + } + } + + // Count the number of non-trivial elements in each cluster. + std::vector effective_cluster_sizes(graph.num_node_ids()); + for (const Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + // Identity nodes will be removed if the node gets marked for compilation. + // Therefore we don't want to count them towards the effective cluster size. + if (n->def().op() != "Identity") { + effective_cluster_sizes[cluster]++; + } + } + + const int min_cluster_size = 2; + int num_clusters = 0; + for (auto size : effective_cluster_sizes) { + if (size >= min_cluster_size) { + VLOG(3) << "Cluster " << num_clusters << " " << size; + num_clusters++; + } + } + + // Names for each cluster. + std::unordered_map cluster_names; + // Sequence number generator to ensure clusters have unique names. + static std::atomic cluster_sequence_num; + + for (Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + + // Compile if this is a cluster of >= min_cluster_size compilable operators. + if (effective_cluster_sizes[cluster] >= min_cluster_size) { + string& name = cluster_names[cluster]; + + if (name.empty()) { + name = strings::StrCat("cluster_", cluster_sequence_num++); + } + n->AddAttr(kXlaClusterAttr, name); + VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; + } + } + + graph.ToGraphDef(output); + return Status::OK(); +} + +REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.h b/tensorflow/compiler/jit/xla_fusion_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..3d2309e782d38725f8db025fbfda0bf0f63d18be --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { + +// Optimizes graphs by fusing ops where possible, resulting in more efficient +// execution. +class XlaFusionOptimizer : public grappler::CustomGraphOptimizer { + public: + XlaFusionOptimizer() {} + ~XlaFusionOptimizer() override {} + + Status Init( + const RewriterConfig_CustomGraphOptimizer* config = nullptr) override { + return Status::OK(); + } + + string name() const override { return "xla-fusion"; }; + + Status Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) override; + + void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item, + const GraphDef& optimize_output, double result) override { + // Nothing to do for XlaFusionOptimizer. + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5736760a878dc857a8558093054d0adc0f727398 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +REGISTER_OP("UncompilableNullary").Output("o: float"); +REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); + +class XlaFusionOptimizerTest : public grappler::GrapplerTest { + protected: + std::unordered_map GetClusters(const GraphDef& graph) { + std::unordered_map ids; + for (const NodeDef& node : graph.node()) { + string cluster; + if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) { + CHECK(!cluster.empty()); + ids[node.name()] = cluster; + } + } + return ids; + } +}; + +TEST_F(XlaFusionOptimizerTest, Chains) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); + Node* d = + ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); + ops::UnaryOp("Relu", e, builder.opts().WithName("F")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(4, clusters.size()); + EXPECT_EQ(clusters["B"], clusters["C"]); + EXPECT_EQ(clusters["E"], clusters["F"]); + EXPECT_NE(clusters["B"], clusters["E"]); + EXPECT_TRUE(clusters.find("A") == clusters.cend()); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, FusableOps) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(2, clusters.size()); + EXPECT_EQ(clusters["C"], clusters["E"]); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp( + "Add", a, b, + builder.opts().WithName("C").WithDevice("/device:XLA_CPU")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + ops::UnaryOp("Cos", e, + builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true)); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, UncompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = + ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, CompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(3, clusters.size()); + EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_EQ(clusters["A"], clusters["C"]); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 26842fbe5cc110fa9ce7a2767d245484fd67556d..c0d86a28c7698c302e28bab972bb2f847cc00ca4 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -49,7 +49,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, registration, /*transfer_as_literal=*/false, - /*shape_representation_fn=*/{}, &device); + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 9e098c46f422b436c722bb909dc58930ab7c0ef6..661187f4a873b03b8d013aa74cb6b6315bb4e2eb 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -51,7 +51,9 @@ Status XlaInterpreterDeviceFactory::CreateDevices( TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, &device)); + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index d0c7a9365125708b2af43f87c7617d8d84050a61..5ceccc769fa2e95d4cf4d2b4ebd8dbf312ebdfd0 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -176,6 +176,21 @@ void XlaComputationLaunchContext::PopulateOutputs( } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); + // If the on-host-shape isn't a tuple, create a new single-element tuple + // buffer with a nullptr root index table. This allows the code below to treat + // output as a tuple unconditionally. + if (!xla::ShapeUtil::IsTuple(output.on_host_shape())) { + ShapedBuffer nontuple_buffer = output.release(); + ShapedBuffer buffer( + xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}), + xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}), + output.platform(), output.device_ordinal()); + buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(), + /*source_base_index=*/{}, + /*target_base_index=*/{0}); + output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator()); + } + // Copy XLA results to the OpOutputList. int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { @@ -230,9 +245,14 @@ void XlaComputationLaunchContext::PopulateOutputs( Tensor* output_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - CHECK(xla_tensor); - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (xla_tensor) { + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + } else { + // xla_tensor wasn't valid, which must mean this is a zero-element + // tensor. + CHECK_EQ(output_tensor->TotalBytes(), 0); + } } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( ctx->expected_output_dtype(i), shape, buffer, allocator); diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index a7211c9c7e281a8141d5671b345c628441b2359d..3c44c4ae6df7f3e2d60d8933561c0c71888e8c3f 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -18,7 +18,7 @@ limitations under the License. namespace tensorflow { -/*static*/ XlaTensor* XlaTensor::FromTensor(Tensor* tensor) { +/*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { if (tensor->NumElements() == 0) { return nullptr; } @@ -27,8 +27,8 @@ namespace tensorflow { return xla_tensor; } -/*static*/ const XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { - return FromTensor(const_cast(tensor)); +/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) { + return tensor.RefCountIsOne(); } /*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor( @@ -67,6 +67,8 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, index_to_buffer.second = buffer.Forget(); } + VLOG(4) << shaped_buffer.ToString(); + set_shaped_buffer(std::move(shaped_buffer)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 6b29c82ec11e39ad525663991e179443c2b6dca7..c54001a999998f45c0cdacd752ca4036f0792857 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -34,10 +34,9 @@ class XlaTensor { public: // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast // fails. - static XlaTensor* FromTensor(Tensor* tensor); - // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast - // fails. - static const XlaTensor* FromTensor(const Tensor* tensor); + static XlaTensor* FromTensor(const Tensor* tensor); + + static bool RefCountIsOne(const Tensor& tensor); // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in // which case the returned value is shaped_buffer()->root_buffer(), or a @@ -62,6 +61,10 @@ class XlaTensor { CHECK(has_shaped_buffer()); return *shaped_buffer_; } + xla::ShapedBuffer& shaped_buffer() { + CHECK(has_shaped_buffer()); + return *shaped_buffer_; + } // Mutates the XlaTensor to set the ShapedBuffer. void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { shaped_buffer_ = diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 2a88743c80559a30a6f27eafebdb838c9c3d9949..273641f1978f2aa13265d491f15f0994c08bb0e7 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -51,6 +51,38 @@ py_library( ], ) +py_library( + name = "test_utils", + testonly = 1, + srcs = ["test_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//third_party/py/numpy", + ], +) + +py_test( + name = "xla_test_test", + size = "small", + srcs = ["xla_test_test.py"], + deps = [ + ":xla_test", + ], +) + +tf_xla_py_test( + name = "adadelta_test", + size = "medium", + srcs = ["adadelta_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "adagrad_test", size = "small", @@ -120,6 +152,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "bucketize_op_test", + size = "small", + srcs = ["bucketize_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "categorical_op_test", size = "small", @@ -196,9 +241,11 @@ tf_xla_py_test( name = "oom_test", size = "medium", srcs = ["oom_test.py"], + # TODO(b/80081500): Re-enable on GPU. Disabled on 2018-05-21. disabled_backends = [ "cpu", "cpu_ondemand", + "gpu", ], tags = [ # Allocates very large amounts of memory and does not work under TSAN. @@ -223,6 +270,7 @@ tf_xla_py_test( srcs = ["conv2d_test.py"], shard_count = 10, deps = [ + ":test_utils", ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:framework", @@ -230,6 +278,7 @@ tf_xla_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) @@ -335,6 +384,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "fifo_queue_test", + size = "medium", + srcs = ["fifo_queue_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "fft_test", size = "medium", @@ -520,17 +583,48 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "proximal_adagrad_test", + size = "medium", + srcs = ["proximal_adagrad_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "proximal_gradient_descent_test", + size = "medium", + srcs = ["proximal_gradient_descent_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "random_ops_test", size = "small", srcs = ["random_ops_test.py"], - # TODO(b/31361304): enable RNG ops on GPU when parallelized. disabled_backends = [ + # TODO(b/110300529): RngNormal doesn't return values with the expected variance + "cpu", + "cpu_ondemand", + # TODO(b/31361304): enable RNG ops on GPU when parallelized. "gpu", ], deps = [ ":xla_test", + "//tensorflow/python:array_ops", "//tensorflow/python:framework", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -647,6 +741,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "sparse_to_dense_op_test", + size = "small", + srcs = ["sparse_to_dense_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:sparse_ops", + ], +) + tf_xla_py_test( name = "stack_ops_test", size = "small", @@ -726,9 +833,10 @@ tf_xla_py_test( tf_xla_py_test( name = "fused_batchnorm_test", - size = "small", + size = "medium", srcs = ["fused_batchnorm_test.py"], deps = [ + ":test_utils", ":xla_test", "//tensorflow/python:framework", "//tensorflow/python:math_ops", @@ -738,6 +846,7 @@ tf_xla_py_test( "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", ], ) @@ -813,6 +922,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "sort_ops_test", + size = "small", + srcs = ["sort_ops_test.py"], + # Times out in fastbuild mode. + tags = ["optonly"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + ], +) + tf_xla_py_test( name = "xla_device_test", size = "small", diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3c09c66e72c4de141b64cea3c4693fabb7b2a2 --- /dev/null +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -0,0 +1,134 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Adadelta Optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adadelta + + +class AdadeltaOptimizerTest(xla_test.XLATestCase): + + def testBasic(self): + num_updates = 4 # number of ADADELTA steps to perform + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + for grad in [0.2, 0.1, 0.01]: + for lr in [1.0, 0.5, 0.1]: + var0_init = [1.0, 2.0] + var1_init = [3.0, 4.0] + var0 = resource_variable_ops.ResourceVariable( + var0_init, dtype=dtype) + var1 = resource_variable_ops.ResourceVariable( + var1_init, dtype=dtype) + + grads = constant_op.constant([grad, grad], dtype=dtype) + + accum = 0.0 + accum_update = 0.0 + + # ADADELTA gradient optimizer + rho = 0.95 + epsilon = 1e-8 + adadelta_opt = adadelta.AdadeltaOptimizer( + learning_rate=lr, rho=rho, epsilon=epsilon) + adadelta_update = adadelta_opt.apply_gradients( + zip([grads, grads], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + opt_vars = adadelta_opt.variables() + self.assertStartsWith(opt_vars[0].name, var0._shared_name) + self.assertStartsWith(opt_vars[1].name, var0._shared_name) + self.assertStartsWith(opt_vars[2].name, var1._shared_name) + self.assertStartsWith(opt_vars[3].name, var1._shared_name) + self.assertEqual(4, len(opt_vars)) + # Assign slots + slot = [None] * 2 + slot_update = [None] * 2 + self.assertEqual(["accum", "accum_update"], + adadelta_opt.get_slot_names()) + slot[0] = adadelta_opt.get_slot(var0, "accum") + self.assertEquals(slot[0].get_shape(), var0.get_shape()) + self.assertFalse(slot[0] in variables.trainable_variables()) + + slot_update[0] = adadelta_opt.get_slot(var0, "accum_update") + self.assertEquals(slot_update[0].get_shape(), var0.get_shape()) + self.assertFalse(slot_update[0] in variables.trainable_variables()) + + slot[1] = adadelta_opt.get_slot(var1, "accum") + self.assertEquals(slot[1].get_shape(), var1.get_shape()) + self.assertFalse(slot[1] in variables.trainable_variables()) + + slot_update[1] = adadelta_opt.get_slot(var1, "accum_update") + self.assertEquals(slot_update[1].get_shape(), var1.get_shape()) + self.assertFalse(slot_update[1] in variables.trainable_variables()) + + # Fetch params to validate initial values + self.assertAllClose(var0_init, self.evaluate(var0)) + self.assertAllClose(var1_init, self.evaluate(var1)) + + update = [None] * num_updates + tot_update = 0 + for step in range(num_updates): + # Run adadelta update for comparison + self.evaluate(adadelta_update) + + # Perform initial update without previous accum values + accum = accum * rho + (grad**2) * (1 - rho) + update[step] = ( + np.sqrt(accum_update + epsilon) * + (1. / np.sqrt(accum + epsilon)) * grad) + accum_update = ( + accum_update * rho + (update[step]**2) * (1.0 - rho)) + tot_update += update[step] * lr + + # Check that the accumulators have been updated + for slot_idx in range(2): + self.assertAllCloseAccordingToType( + np.array([accum, accum], dtype=dtype), + self.evaluate(slot[slot_idx]), + rtol=1e-5) + + self.assertAllCloseAccordingToType( + np.array([accum_update, accum_update], dtype=dtype), + self.evaluate(slot_update[slot_idx]), + rtol=1e-5) + + # Check that the parameters have been updated + self.assertAllCloseAccordingToType( + np.array( + [var0_init[0] - tot_update, var0_init[1] - tot_update], + dtype=dtype), + self.evaluate(var0), + rtol=1e-5) + + self.assertAllCloseAccordingToType( + np.array( + [var1_init[0] - tot_update, var1_init[1] - tot_update], + dtype=dtype), + self.evaluate(var1), + rtol=1e-5) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index 9a93b3216404d8ed21fd6c57757bec1730c119b4..d775850a80e9f83f7b2c9f1cf8997dd50e229635 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -28,7 +28,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import adagrad -class AdagradOptimizerTest(XLATestCase): +class AdagradOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 3215dc36e5b2d517aa951db1b0d41188185ef93a..03554d6933aca39b428c6af4be0c78e2c7ccb0c9 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops @@ -48,7 +48,7 @@ def adam_update_numpy(param, return param_t, m_t, v_t -class AdamOptimizerTest(XLATestCase): +class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 1e4dd32916c3a40282735fb8f75670b0e9ef0dc9..9cb3d0454608c37e669d5b4360bc39bf1bf7e68c 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest -class BinaryOpsTest(XLATestCase): +class BinaryOpsTest(xla_test.XLATestCase): """Test cases for binary operators.""" def _testBinary(self, op, a, b, expected, equality_test=None): @@ -226,6 +226,11 @@ class BinaryOpsTest(XLATestCase): np.array([0b1, 0b101, 0b1000], dtype=dtype), np.array([0b0, 0b101, 0b1001], dtype=dtype), expected=np.array([0b1, 0b101, 0b1001], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_xor, + np.array([0b1, 0b111, 0b1100], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b1, 0b010, 0b0101], dtype=dtype)) lhs = np.array([0, 5, 3, 14], dtype=dtype) rhs = np.array([5, 0, 7, 11], dtype=dtype) @@ -1216,6 +1221,24 @@ class BinaryOpsTest(XLATestCase): np.array([1, 0], dtype=np.int32), expected=np.array([[1, 3], [2, 4]], dtype=dtype)) + def testConjugateTranspose(self): + for dtype in self.complex_types: + self._testBinary( + array_ops.conjugate_transpose, + np.zeros(shape=[1, 0, 4], dtype=dtype), + np.array([1, 2, 0], dtype=np.int32), + expected=np.zeros(shape=[0, 4, 1], dtype=dtype)) + self._testBinary( + array_ops.conjugate_transpose, + np.array([[1 - 1j, 2 + 2j], [3 - 3j, 4 + 4j]], dtype=dtype), + np.array([0, 1], dtype=np.int32), + expected=np.array([[1 + 1j, 2 - 2j], [3 + 3j, 4 - 4j]], dtype=dtype)) + self._testBinary( + array_ops.conjugate_transpose, + np.array([[1 - 1j, 2 + 2j], [3 - 3j, 4 + 4j]], dtype=dtype), + np.array([1, 0], dtype=np.int32), + expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype)) + def testCross(self): for dtype in self.float_types: self._testBinary( diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ef4d5f6322b7ae79b051795b5af7e6f7f1e55550 --- /dev/null +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -0,0 +1,78 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for bucketize_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class BucketizationOpTest(xla_test.XLATestCase): + + def testInt(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual(expected_out, + sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]})) + + def testFloat(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual( + expected_out, + sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]})) + + def test2DInput(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]] + self.assertAllEqual( + expected_out, sess.run(op, + {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) + + def testInvalidBoundariesOrder(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11]) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Expected sorted boundaries"): + sess.run(op, {p: [-5, 0]}) + + def testBoundariesNotList(self): + with self.test_session(): + with self.assertRaisesRegexp(TypeError, "Expected list.*"): + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + math_ops._bucketize(p, boundaries=0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 035cdea1786d39f3d21bb63be5c8ccffe1608bdf..a4e7f75081dfd07fd4b5c94c33908aab8e7d8aa9 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -22,7 +22,7 @@ import collections import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest # TODO(srvasude): Merge this with # third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py. -class CategoricalTest(XLATestCase): +class CategoricalTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def output_dtypes(self): diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 1a8989d7c2f617525c301f30fd899a01362310bf..d2867278af93812eae804b66a7a6b706f98fa600 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -23,7 +23,7 @@ import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class CholeskyOpTest(XLATestCase): +class CholeskyOpTest(xla_test.XLATestCase): # Cholesky defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/cholesky) diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index 574f82fc717818334ac5d72ebef2191f1c18e669..e42ebf8f9e01dab13cde15979ffc42b7c0fbc57b 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" -class ClusteringTest(XLATestCase): +class ClusteringTest(xla_test.XLATestCase): def testAdd(self): val1 = np.array([4, 3, 2, 1], dtype=np.float32) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index f10973e19f1945515b776cf86349445ed7334629..d9ad4281477e87f79f2ecb52989ae86a5030d0cc 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ConcatTest(XLATestCase): +class ConcatTest(xla_test.XLATestCase): def testHStack(self): with self.test_session(): @@ -292,7 +292,7 @@ class ConcatTest(XLATestCase): array_ops.concat([scalar, scalar, scalar], dim) -class ConcatOffsetTest(XLATestCase): +class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): with self.test_session() as sess: @@ -306,7 +306,7 @@ class ConcatOffsetTest(XLATestCase): self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) -class PackTest(XLATestCase): +class PackTest(xla_test.XLATestCase): def testBasic(self): with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index 62577b70ce96e220d79978f01614b2d9a3647680..98d41ba7edd52eedbf035097a48a1ce2ac7d5e9e 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -22,9 +22,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import test_utils +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops @@ -32,7 +34,15 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest -class Conv2DTest(XLATestCase): +DATA_FORMATS = ( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), +) + + +class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -40,6 +50,8 @@ class Conv2DTest(XLATestCase): strides=None, dilations=None, padding=None, + data_format_src="NHWC", + data_format_dst="NHWC", expected=None): """Tests that tf.nn.conv2d produces the expected value. @@ -51,8 +63,12 @@ class Conv2DTest(XLATestCase): strides: Strides. dilations: RHS dilations. padding: Padding type. + data_format_src: Data format input is in. + data_format_dst: Data format verification will run and input is converted + to. expected: Expected output. """ + total_size_1 = np.prod(input_sizes) total_size_2 = np.prod(filter_sizes) x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes) @@ -62,6 +78,18 @@ class Conv2DTest(XLATestCase): dilations = [1, 1] dilations = [1] + dilations + [1] + # Convert between data formats. + expected = test_utils.ConvertBetweenDataFormats(expected, data_format_src, + data_format_dst) + x1 = test_utils.ConvertBetweenDataFormats(x1, data_format_src, + data_format_dst) + input_sizes = test_utils.PermuteDimsBetweenDataFormats( + input_sizes, data_format_src, data_format_dst) + strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src, + data_format_dst) + dilations = test_utils.PermuteDimsBetweenDataFormats( + dilations, data_format_src, data_format_dst) + with self.test_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) @@ -71,12 +99,14 @@ class Conv2DTest(XLATestCase): t2, strides=strides, padding=padding, - data_format="NHWC", + data_format=data_format_dst, dilations=dilations) + value = sess.run(out, {t1: x1, t2: x2}) self.assertAllClose(expected, value, 1e-3) - def testConv2D1x1Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x1Filter(self, data_format): expected_output = np.reshape([ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, 204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0 @@ -86,9 +116,12 @@ class Conv2DTest(XLATestCase): filter_sizes=[1, 1, 3, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter(self, data_format): expected_output = np.reshape( [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0], [1, 1, 2, 3]) self._VerifyValues( @@ -96,9 +129,12 @@ class Conv2DTest(XLATestCase): filter_sizes=[2, 2, 3, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter2x1Dilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter2x1Dilation(self, data_format): expected_output = np.array([[[[72], [82], [92]], [[112], [122], [132]]]]) self._VerifyValues( input_sizes=[1, 4, 4, 1], @@ -106,9 +142,12 @@ class Conv2DTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2Filter(self, data_format): expected_output = np.reshape([ 231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0, 936.0, 1029.0 @@ -118,18 +157,24 @@ class Conv2DTest(XLATestCase): filter_sizes=[1, 2, 3, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2(self, data_format): expected_output = np.reshape([2271.0, 2367.0, 2463.0], [1, 1, 1, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], strides=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2Same(self, data_format): expected_output = np.reshape( [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0], [1, 1, 2, 3]) self._VerifyValues( @@ -137,47 +182,61 @@ class Conv2DTest(XLATestCase): filter_sizes=[2, 2, 3, 3], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2DEmptyDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DEmptyDilation(self, data_format): self._VerifyValues( input_sizes=[0, 2, 3, 3], filter_sizes=[1, 1, 3, 3], strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.zeros([0, 2, 3, 3])) - def testConv2D2x2FilterDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterDilation(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.reshape([2667, 2781, 2895], [1, 1, 1, 3])) - def testConv2D1x2FilterDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterDilation(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[1, 2, 3, 3], strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.array([[[[231, 252, 273], [384, 423, 462]], [[690, 765, 840], [843, 936, 1029]]]])) - def testConv2DKernelSizeMatchesInputSizeDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DKernelSizeMatchesInputSizeDilation(self, data_format): self._VerifyValues( input_sizes=[1, 3, 3, 1], filter_sizes=[2, 2, 1, 2], strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.reshape([108, 128], [1, 1, 1, 2])) -class Conv2DBackpropInputTest(XLATestCase): +class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -186,6 +245,8 @@ class Conv2DBackpropInputTest(XLATestCase): strides=None, dilations=None, padding=None, + data_format_src="NHWC", + data_format_dst="NHWC", expected=None): """Tests that gen_nn_ops.conv2d_backprop_input produces the expected output. @@ -198,8 +259,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides: Strides. dilations: Dilations. padding: Padding type. + data_format_src: Data format input is in. + data_format_dst: Data format verification will run and input is converted + to. expected: Expected output. """ + total_size_1 = np.prod(filter_sizes) total_size_2 = np.prod(out_backprop_sizes) x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(filter_sizes) @@ -209,6 +274,23 @@ class Conv2DBackpropInputTest(XLATestCase): if dilations is not None: dilations = [1] + dilations + [1] + expected = np.reshape(expected, input_sizes) + + # Convert between data formats. + expected = test_utils.ConvertBetweenDataFormats(expected, data_format_src, + data_format_dst) + x2 = test_utils.ConvertBetweenDataFormats(x2, data_format_src, + data_format_dst) + input_sizes = test_utils.PermuteDimsBetweenDataFormats( + input_sizes, data_format_src, data_format_dst) + out_backprop_sizes = test_utils.PermuteDimsBetweenDataFormats( + out_backprop_sizes, data_format_src, data_format_dst) + strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src, + data_format_dst) + if dilations is not None: + dilations = test_utils.PermuteDimsBetweenDataFormats( + dilations, data_format_src, data_format_dst) + with self.test_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) @@ -220,12 +302,14 @@ class Conv2DBackpropInputTest(XLATestCase): strides=strides, dilations=dilations, padding=padding, - data_format="NHWC") + data_format=data_format_dst) + value = sess.run(out, {t1: x1, t2: x2}) self.assertAllEqual(input_sizes, value.shape) - self.assertAllClose(expected, np.ravel(value), 1e-3) + self.assertAllClose(expected, value, 1e-3) - def testConv2D1x1Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x1Filter(self, data_format): expected_output = [ 5, 11, 17, 11, 25, 39, 17, 39, 61, 23, 53, 83, 29, 67, 105, 35, 81, 127, 41, 95, 149, 47, 109, 171, 53, 123, 193, 59, 137, 215, 65, 151, 237, 71, @@ -237,9 +321,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 4, 4, 2], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width5(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width5(self, data_format): expected_output = [1, 2, 0, 2, 4] self._VerifyValues( input_sizes=[1, 1, 5, 1], @@ -247,9 +334,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width6(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width6(self, data_format): expected_output = [1, 2, 0, 2, 4, 0] self._VerifyValues( input_sizes=[1, 1, 6, 1], @@ -257,9 +347,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width7(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width7(self, data_format): expected_output = [1, 2, 0, 2, 4, 0, 0] self._VerifyValues( input_sizes=[1, 1, 7, 1], @@ -267,9 +360,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterC1Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterC1Same(self, data_format): expected_output = [1, 4, 7, 7, 23, 33] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -277,9 +373,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 2, 3, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter(self, data_format): expected_output = [ 14, 32, 50, 100, 163, 226, 167, 212, 257, 122, 140, 158, 478, 541, 604, 437, 482, 527 @@ -290,9 +389,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterSame(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterSame(self, data_format): expected_output = [ 14, 32, 50, 100, 163, 226, 217, 334, 451, 190, 307, 424, 929, 1217, 1505, 1487, 1883, 2279 @@ -303,9 +405,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 2, 3, 3], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2Filter(self, data_format): expected_output = [1, 4, 4, 3, 10, 8, 5, 16, 12] self._VerifyValues( input_sizes=[1, 3, 3, 1], @@ -313,9 +418,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 3, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterSame(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterSame(self, data_format): expected_output = [1, 4, 7, 4, 13, 16, 7, 22, 25] self._VerifyValues( input_sizes=[1, 3, 3, 1], @@ -323,9 +431,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 3, 3, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2(self, data_format): expected_output = [1, 2, 5, 4, 6, 0, 0, 0, 0, 0, 3, 6, 13, 8, 12] self._VerifyValues( input_sizes=[1, 3, 5, 1], @@ -333,9 +444,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 2, 2, 1], strides=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2Same(self, data_format): expected_output = [1, 2, 2, 3, 4, 6] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -343,9 +457,13 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 6, 1], filter_sizes=[2, 2, 1, 1], @@ -353,9 +471,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 4, 7, 10, 13, 10, 0, 0, 0, 0, 0, 0, 3, 10, 17, 24, 31, 20]) - def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], @@ -363,9 +484,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 0, 2, 3, 0, 4]) - def testConv2DEmptyBackpropInputDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DEmptyBackpropInputDilation1x2(self, data_format): self._VerifyValues( input_sizes=[0, 2, 3, 1], filter_sizes=[2, 2, 1, 1], @@ -373,9 +497,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.zeros([0])) - def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self, data_format): # The GPU version of this test is not very stable. So adjusting the # error threshold to 1e-4. self._VerifyValues( @@ -385,12 +512,16 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[ 14, 32, 50, 68, 86, 104, 0, 0, 0, 0, 0, 0, 122, 140, 158, 176, 194, 212 ]) - def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 3, 1], filter_sizes=[2, 2, 1, 2], @@ -398,10 +529,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[5, 0, 11, 0, 0, 0, 17, 0, 23]) -class Conv2DBackpropFilterTest(XLATestCase): +class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -410,6 +543,8 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=None, dilations=None, padding=None, + data_format_src="NHWC", + data_format_dst="NHWC", expected=None): """Tests that gen_nn_ops.conv2d_backprop_filter produces the right output. @@ -422,6 +557,9 @@ class Conv2DBackpropFilterTest(XLATestCase): strides: Stride. dilations: Dilations. padding: Padding type. + data_format_src: Data format input is in. + data_format_dst: Data format verification will run and input is converted + to. expected: Expected output. """ @@ -434,6 +572,23 @@ class Conv2DBackpropFilterTest(XLATestCase): if dilations is not None: dilations = [1] + dilations + [1] + expected = np.reshape(expected, filter_sizes) + + # Convert between data formats. + x1 = test_utils.ConvertBetweenDataFormats(x1, data_format_src, + data_format_dst) + x2 = test_utils.ConvertBetweenDataFormats(x2, data_format_src, + data_format_dst) + input_sizes = test_utils.PermuteDimsBetweenDataFormats( + input_sizes, data_format_src, data_format_dst) + out_backprop_sizes = test_utils.PermuteDimsBetweenDataFormats( + out_backprop_sizes, data_format_src, data_format_dst) + strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src, + data_format_dst) + if dilations is not None: + dilations = test_utils.PermuteDimsBetweenDataFormats( + dilations, data_format_src, data_format_dst) + with self.test_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) @@ -445,13 +600,14 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=strides, dilations=dilations, padding=padding, - data_format="NHWC") + data_format=data_format_dst) value = sess.run(tensor, {t1: x1, t2: x2}) self.assertAllEqual(filter_sizes, value.shape) - self.assertAllClose(expected, np.ravel(value), 1e-3) + self.assertAllClose(expected, value, 1e-3) - def testConv2D1x1Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x1Filter(self, data_format): expected_output = [8056, 8432, 8312, 8704, 8568, 8976] self._VerifyValues( input_sizes=[1, 4, 4, 3], @@ -459,9 +615,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 4, 4, 2], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2Filter(self, data_format): expected_output = [120, 141] self._VerifyValues( input_sizes=[1, 3, 3, 1], @@ -469,9 +628,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 3, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterDepth1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterDepth1(self, data_format): expected_output = [5, 8, 14, 17] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -479,9 +641,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter(self, data_format): expected_output = [ 17, 22, 27, 22, 29, 36, 27, 36, 45, 32, 43, 54, 37, 50, 63, 42, 57, 72, 62, 85, 108, 67, 92, 117, 72, 99, 126, 77, 106, 135, 82, 113, 144, 87, @@ -493,9 +658,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width5(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width5(self, data_format): expected_output = [9, 12] self._VerifyValues( input_sizes=[1, 1, 5, 1], @@ -503,9 +671,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width6(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width6(self, data_format): expected_output = [9, 12] self._VerifyValues( input_sizes=[1, 1, 6, 1], @@ -513,9 +684,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width7(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width7(self, data_format): expected_output = [9, 12] self._VerifyValues( input_sizes=[1, 1, 7, 1], @@ -523,9 +697,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x3Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x3Filter(self, data_format): expected_output = [5, 8, 11] self._VerifyValues( input_sizes=[1, 1, 4, 1], @@ -533,9 +710,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x3FilterSame(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x3FilterSame(self, data_format): expected_output = [20, 30, 20] self._VerifyValues( input_sizes=[1, 1, 4, 1], @@ -543,9 +723,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 4, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x3FilterSameOutbackprop2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x3FilterSameOutbackprop2(self, data_format): expected_output = [7, 10, 3] self._VerifyValues( input_sizes=[1, 1, 4, 1], @@ -553,9 +736,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterC1Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterC1Same(self, data_format): expected_output = [91, 58, 32, 17] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -563,9 +749,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 2, 3, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2(self, data_format): expected_output = [92, 102, 112] self._VerifyValues( input_sizes=[1, 3, 5, 1], @@ -573,9 +762,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 2, 2, 1], strides=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2Same(self, data_format): expected_output = [7, 2, 16, 5] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -583,9 +775,13 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 6, 1], filter_sizes=[2, 2, 1, 1], @@ -593,9 +789,12 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[55, 70, 235, 250]) - def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], @@ -603,9 +802,12 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 3, 4, 6]) - def testConv2DEmptyBackpropFilterDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DEmptyBackpropFilterDilation1x2(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 0], @@ -613,9 +815,12 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.zeros([0])) - def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self, data_format): self._VerifyValues( input_sizes=[1, 3, 4, 3], filter_sizes=[2, 2, 3, 3], @@ -623,13 +828,17 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[ 17, 22, 27, 22, 29, 36, 27, 36, 45, 47, 64, 81, 52, 71, 90, 57, 78, 99, 137, 190, 243, 142, 197, 252, 147, 204, 261, 167, 232, 297, 172, 239, 306, 177, 246, 315 ]) - def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 3, 1], filter_sizes=[2, 2, 1, 2], @@ -637,6 +846,8 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 2, 3, 6, 7, 14, 9, 18]) diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 3bebf46511cbc471d3fbbbe92d28511fcc717387..31ee41f04f27d387415e9fa2c4fa70b33cab7b04 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest # Test cloned from # tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py -class Conv3DBackpropFilterV2GradTest(XLATestCase): +class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): def testGradient(self): with self.test_session(), self.test_scope(): @@ -66,7 +66,7 @@ class Conv3DBackpropFilterV2GradTest(XLATestCase): # Test cloned from tensorflow/python/kernel_tests/conv3d_transpose_test.py -class Conv3DTransposeTest(XLATestCase): +class Conv3DTransposeTest(xla_test.XLATestCase): def testConv3DTransposeSingleStride(self): with self.test_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index b0bf1b79d6c8be3170db3079e1c83ceead0584de..865f60ccab46ec6829e49409508303052944e13b 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -46,8 +46,8 @@ def InLabels(labels, substr): def XlaLaunchOpCount(labels): - """Count how many _XlaLaunch labels are present.""" - return sum("_XlaLaunch(" in x for x in labels) + """Count how many XlaLaunch labels are present.""" + return sum("XlaLaunch(" in x for x in labels) class DenseLayerTest(test.TestCase): @@ -55,7 +55,7 @@ class DenseLayerTest(test.TestCase): def testDenseLayerAutoJit(self): """Tests dense layer compilation in auto-jit mode. - Dense layer should be compiled into a single _XlaLaunch op in auto-jit mode. + Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. """ os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") @@ -83,7 +83,7 @@ class DenseLayerTest(test.TestCase): """Tests that the dense layer node is properly compiled in jit scope. Dense layer with static shape input tensor should be compiled into a single - _XlaLaunch op by XLA. + XlaLaunch op by XLA. """ with self.test_session() as sess: @@ -110,7 +110,7 @@ class DenseLayerTest(test.TestCase): Dense layer uses shape op to get shape of input tensor if its shape is not fully defined. XLA does not cluster shape op with other operators. But in experimental_jit_scope, XLA is forced to compile shape op into its own - cluster, causing dense layer to be split into TWO _XlaLaunch ops. + cluster, causing dense layer to be split into TWO XlaLaunch ops. """ with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 0a0d335ca76dd7ec7ca3b12f9e8a83b596daa07e..98dc73e189f99b7b811487756659d89dacb97d8a 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -114,7 +114,7 @@ def CheckGradConfigsToTest(): yield i, f, o, s, p -class DepthwiseConv2DTest(XLATestCase): +class DepthwiseConv2DTest(xla_test.XLATestCase): # This is testing that depthwise_conv2d and depthwise_conv2d_native # produce the same results. It also tests that NCHW and NWHC @@ -153,7 +153,7 @@ class DepthwiseConv2DTest(XLATestCase): dtype=data_type).reshape(filter_in_sizes) with self.test_session() as sess: if data_type == np.float32: - tolerance = 1e-5 + tolerance = 1e-4 else: self.assertEqual(data_type, np.float64) tolerance = 1e-8 @@ -339,7 +339,7 @@ class DepthwiseConv2DTest(XLATestCase): gpu_value = _GetVal(use_xla=True) cpu_value = _GetVal(use_xla=False) - self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4) + self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-3) def testDepthwiseConv2DInputGradCompare(self): for index, (input_size, filter_size, output_size, stride, diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 6a46d2ec3e7aee3a4ecfbf1ab9f622d8eb659e3c..154e36b10e6da409606ae6022aaf53e34c8e37cc 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class DynamicUpdateSliceOpsTest(XLATestCase): +class DynamicUpdateSliceOpsTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index c109c27abe2f145685f83251e1d21ec8ddad563a..edd78153b56bb5bf1c268936fb82a60581389733 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.platform import googletest -class DynamicStitchTest(XLATestCase): +class DynamicStitchTest(xla_test.XLATestCase): def _AssertDynamicStitchResultIs(self, indices, data, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 311f2ada15a68d860f2f3a89f9ee90cea1f7fd95..3524666499cbb2ef3eae2bb3b314dda0a9be64c8 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -31,14 +31,16 @@ from tensorflow.python.framework import ops from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import googletest +from tensorflow.python.training import adam -class EagerTest(XLATestCase): +class EagerTest(xla_test.XLATestCase): def testBasic(self): with self.test_scope(): @@ -47,6 +49,21 @@ class EagerTest(XLATestCase): product = three * five self.assertAllEqual(15, product) + def testGradientTape(self): + with self.test_scope(): + + x = constant_op.constant(1.0) + y = constant_op.constant(10.0) + with backprop.GradientTape(persistent=True) as tape: + tape.watch(x) + tape.watch(y) + a = x + y + x * y + da_dx = tape.gradient(a, x) + da_dy = tape.gradient(a, y) + + self.assertEqual(11.0, da_dx.numpy()) + self.assertEqual(2.0, da_dy.numpy()) + def testExecuteListOutputLen0(self): with self.test_scope(): empty = constant_op.constant([], dtype=dtypes.float32) @@ -117,6 +134,15 @@ class EagerTest(XLATestCase): v.assign_add(2.0) self.assertEqual(3.0, v.numpy()) + def testReadAssignRead(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + val1 = v.read_value() + v.assign_add(2.0) + val2 = v.read_value() + self.assertEqual(1.0, val1.numpy()) + self.assertEqual(3.0, val2.numpy()) + def testGradient(self): def f(x): return x @@ -136,12 +162,135 @@ class EagerTest(XLATestCase): grads = backprop.implicit_grad(f)() self.assertEqual(2., grads[0][0].numpy()) + def testMultipleVariableReads(self): + # This test makes sure consecutive variable reads don't copy + # the underlying memory. + with self.test_scope(): + # Create 128MiB variables + var = resource_variable_ops.ResourceVariable( + array_ops.ones([32, 1024, 1024])) + + # Read the same variable 100 times. If the underlying tensor + # is not copied, this is a trivial operation. If it is copied, + # this will eat over 13GB and OOM. + values = [] + for _ in range(100): + values.append(var.value()) + + # The shape, shape_n, size, and rank are tested here because their + # execution kernels (as opposed to compilation only tf2xla kernels) + # are distincts from tf2xla kernels. + + def testShape(self): + def const(value): + return array_ops.shape( + constant_op.constant(value)).numpy() + + def ones(value): + return array_ops.shape( + array_ops.ones(value)).numpy() -class EagerFunctionTest(XLATestCase): + with self.test_scope(): + # Shapes of directly constructed tensors + self.assertAllEqual([], const(3)) + self.assertAllEqual([3], const([1.0, 2.0, 3.0])) + self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]])) + self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]])) + + # Shapes of tensors created by op running on device + # We make this distinction because directly constructed tensors + # are treated differently in a few places that can influence shape: + # - they always have on_host_tensor + # - they and their shapes can be cached + # - they end up on device via a copy, instead of as program output + self.assertAllEqual([], ones([])) + self.assertAllEqual([3], ones([3])) + self.assertAllEqual([2, 2], ones([2, 2])) + self.assertAllEqual([2, 1, 2], ones([2, 1, 2])) + + def testShapeN(self): + with self.test_scope(): + # Shapes of directly constructed tensors + shapes = array_ops.shape_n([ + constant_op.constant(1.0), + constant_op.constant([1.0, 2.0, 3.0]), + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + # Shapes of tensors created by op running on device + shapes = array_ops.shape_n([ + array_ops.ones([]), + array_ops.ones([3]), + array_ops.ones([2, 2])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + def testSize(self): + with self.test_scope(): + self.assertEqual( + 1, array_ops.size(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 4, array_ops.size( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testRank(self): + with self.test_scope(): + self.assertEqual( + 0, array_ops.rank(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 2, array_ops.rank( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testAdam(self): + with self.test_scope(): + optimizer = adam.AdamOptimizer(0.1) + x = resource_variable_ops.ResourceVariable(10.0) + with backprop.GradientTape() as tape: + y = x * x + dy_dx = tape.gradient(y, x) + optimizer.apply_gradients([(dy_dx, x)]) + self.assertAlmostEqual(9.9, x.numpy(), places=3) + + def testAdamSparse(self): + with ops.device('/cpu:0'): + # Create 2-D embedding for 3 objects on CPU because sparse/sliced updates + # are not implemented on TPU. + embedding_matrix = resource_variable_ops.ResourceVariable( + array_ops.ones([3, 2])) + + with self.test_scope(): + with backprop.GradientTape() as tape: + embedding = embedding_ops.embedding_lookup(embedding_matrix, [1]) + y = math_ops.reduce_sum(embedding) + dy_dx = tape.gradient(y, embedding_matrix) + self.assertIsInstance(dy_dx, ops.IndexedSlices) + optimizer = adam.AdamOptimizer(0.1) + # The gradient application operations will run on CPU because optimizer + # updates are always collocated with the variable. + optimizer.apply_gradients([(dy_dx, embedding_matrix)]) + + # This assign_add will run on CPU because when an input to an + # operation is a resource, this operation is placed on the resource's + # device by the eager runtime. + embedding_matrix.assign_add(array_ops.ones([3, 2])) + + self.assertAllClose([[2.0, 2.0], + [1.9, 1.9], + [2.0, 2.0]], embedding_matrix.numpy()) + + +class EagerFunctionTest(xla_test.XLATestCase): def testBasic(self): with self.test_scope(): - matmul = function.defun(math_ops.matmul, compiled=True) + matmul = function.defun(math_ops.matmul) t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) sq = matmul(t, t, transpose_a=True) self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) @@ -163,7 +312,7 @@ class EagerFunctionTest(XLATestCase): def model(x): x = conv(x) return pool(x) - model = function.defun(model, compiled=True) + model = function.defun(model) x = array_ops.ones([1, 4, 4, 1]) y = model(x) @@ -173,7 +322,7 @@ class EagerFunctionTest(XLATestCase): with self.test_scope(): v = resource_variable_ops.ResourceVariable(1.0) - @function.defun(compiled=True) + @function.defun def f(): return v.read_value() @@ -188,7 +337,7 @@ class EagerFunctionTest(XLATestCase): v.assign_add(1.0) return v - f = function.defun(f, compiled=True) + f = function.defun(f) var = f(v) self.assertEqual(2.0, var.numpy()) @@ -216,7 +365,7 @@ class EagerFunctionTest(XLATestCase): d = r2 * v2 return a, b, c, d - foo = function.defun(foo, compiled=True) + foo = function.defun(foo) c1 = [0, 0] c2 = array_ops.ones([2], dtype=dtypes.int32) @@ -238,7 +387,7 @@ class EagerFunctionTest(XLATestCase): with self.test_scope(): v0 = resource_variable_ops.ResourceVariable(5.0) - @function.defun(compiled=True) + @function.defun def f(x): x = v0 * v0 * x return x @@ -251,6 +400,75 @@ class EagerFunctionTest(XLATestCase): self.assertEqual(75, y.numpy()) self.assertEqual(30, dy.numpy()) + def testSliceInDefun(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x, y): + return x[0::2, y:, ...] + + x = array_ops.ones([2, 3, 4]) + y = array_ops.ones([], dtype=dtypes.int32) + with backprop.GradientTape() as tape: + tape.watch(x) + tape.watch(y) + z = f(x, y) + dz = tape.gradient(z, x) + + self.assertAllEqual(np.ones([1, 2, 4]), z.numpy()) + self.assertAllEqual((2, 3, 4), dz.shape.as_list()) + + +class ExcessivePaddingTest(xla_test.XLATestCase): + """Test that eager execution works with TPU flattened tensors. + + Tensors that would normally be excessively padded when written + to TPU memory are reshaped to 1-D flat tensors. + + This test case verifies that such tensors work with eager execution. + + The flattening currently only happens on TPU, but tests should work + fine with all backends as flattening is transparent. + """ + + def testFromConstant(self): + with self.test_scope(): + # Create constant of shape [100, 2, 1]. This tensor would be + # excessively padded on TPU. + tensor = constant_op.constant(100 * [[[10.0], [2.0]]]) + # Use reduce_sum since it requires correctly working with + # a particular dimension. + reduced = math_ops.reduce_sum(tensor, axis=1) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testFromOperation(self): + with self.test_scope(): + tensor = array_ops.ones([3, 100, 2, 2]) + reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3]) + self.assertAllEqual(100 * [12.0], reduced) + + def testAsFunctionInput(self): + with self.test_scope(): + + @function.defun + def f(x): + return math_ops.reduce_sum(x, axis=2) + + tensor = constant_op.constant(100 * [[[10.0, 2.0]]]) + reduced = f(tensor) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testAsFunctionOutput(self): + with self.test_scope(): + + @function.defun + def f(x): + return x * constant_op.constant(100 * [[[10.0, 2.0]]]) + + y = f(3) + reduced = math_ops.reduce_sum(y, axis=2) + self.assertAllEqual(100 * [[36.0]], reduced) + if __name__ == '__main__': ops.enable_eager_execution( diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py index 0361702e7af778176daed941d64e61198090daf2..5529fdbb090315e1d7f47589777d8a538c90db2b 100644 --- a/tensorflow/compiler/tests/extract_image_patches_op_test.py +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ExtractImagePatches(XLATestCase): +class ExtractImagePatches(xla_test.XLATestCase): """Functional tests for ExtractImagePatches op.""" def _VerifyValues(self, image, ksizes, strides, rates, padding, patches): diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py index dfe9400ef0f55ca011d4e23ba5d735899ca2e054..c48ab178bf53558084fb500b2811c6f0b77a7943 100644 --- a/tensorflow/compiler/tests/fake_quant_ops_test.py +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -17,14 +17,14 @@ from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.platform import googletest -class FakeQuantWithMinMaxArgsTest(XLATestCase): +class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxArgs operation.""" # 8 bits, wide range. @@ -122,7 +122,7 @@ class FakeQuantWithMinMaxArgsTest(XLATestCase): result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) -class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): +class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxArgsGradient operation.""" # 8 bits, wide range. @@ -223,7 +223,7 @@ class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): bfloat16_rtol=0.03) -class FakeQuantWithMinMaxVarsTest(XLATestCase): +class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxVars operation.""" # 8 bits, wide range. @@ -328,7 +328,7 @@ class FakeQuantWithMinMaxVarsTest(XLATestCase): result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) -class FakeQuantWithMinMaxVarsGradientTest(XLATestCase): +class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxVarsGradient operation.""" # 8 bits, wide range. diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index afb5fa4bb4fefe5bc2ecded826143ffc83c2b559..c64ea249ecb97991952a960a6d16e1bb3be35b17 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -23,10 +23,11 @@ import itertools import numpy as np import scipy.signal as sps -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.contrib.signal.python.ops import spectral_ops as signal from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import spectral_ops from tensorflow.python.platform import googletest @@ -57,7 +58,7 @@ INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2)) INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2)) -class FFTTest(XLATestCase): +class FFTTest(xla_test.XLATestCase): def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected, tf_method): @@ -97,8 +98,11 @@ class FFTTest(XLATestCase): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) out = signal.stft(ph, ws, hs) + grad = gradients_impl.gradients(out, ph, + grad_ys=array_ops.ones_like(out)) - value = sess.run(out, {ph: data}) + # For gradients, we simply verify that they compile & execute. + value, _ = sess.run([out, grad], {ph: data}) self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL) def testFFT(self): diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0f64cc87cde77fbbef6c4e570879e992bc34bafa --- /dev/null +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -0,0 +1,201 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.platform import test + + +class FIFOQueueTest(xla_test.XLATestCase): + + def testEnqueue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + enqueue_op.run() + + def testEnqueueWithShape(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) + enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) + enqueue_correct_op.run() + with self.assertRaises(ValueError): + q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],)) + self.assertEqual(1, q.size().eval()) + + def testMultipleDequeues(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue([1])) + self.evaluate(q.enqueue([2])) + self.evaluate(q.enqueue([3])) + a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) + self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + + def testQueuesDontShare(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue(1)) + q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q2.enqueue(2)) + self.assertAllEqual(self.evaluate(q2.dequeue()), 2) + self.assertAllEqual(self.evaluate(q.dequeue()), 1) + + def testEnqueueDictWithoutNames(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + with self.assertRaisesRegexp(ValueError, "must have names"): + q.enqueue({"a": 12.0}) + + def testParallelEnqueue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Run one producer thread for each element in elems. + def enqueue(enqueue_op): + sess.run(enqueue_op) + + threads = [ + self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Dequeue every element using a single thread. + results = [] + for _ in xrange(len(elems)): + results.append(dequeued_t.eval()) + self.assertItemsEqual(elems, results) + + def testParallelDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Enqueue every element using a single thread. + for enqueue_op in enqueue_ops: + enqueue_op.run() + + # Run one consumer thread for each element in elems. + results = [] + + def dequeue(): + results.append(sess.run(dequeued_t)) + + threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertItemsEqual(elems, results) + + def testDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + vals = dequeued_t.eval() + self.assertEqual([elems[i]], vals) + + def testEnqueueAndBlockingDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + def enqueue(): + # The enqueue_ops should run after the dequeue op has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + for enqueue_op in enqueue_ops: + sess.run(enqueue_op) + + results = [] + + def dequeue(): + for _ in xrange(len(elems)): + results.append(sess.run(dequeued_t)) + + enqueue_thread = self.checkedThread(target=enqueue) + dequeue_thread = self.checkedThread(target=dequeue) + enqueue_thread.start() + dequeue_thread.start() + enqueue_thread.join() + dequeue_thread.join() + + for elem, result in zip(elems, results): + self.assertEqual([elem], result) + + def testMultiEnqueueAndDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) + elems = [(5, 10.0), (10, 20.0), (15, 30.0)] + enqueue_ops = [q.enqueue((x, y)) for x, y in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + x_val, y_val = sess.run(dequeued_t) + x, y = elems[i] + self.assertEqual([x], x_val) + self.assertEqual([y], y_val) + + def testQueueSizeEmpty(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + self.assertEqual([0], q.size().eval()) + + def testQueueSizeAfterEnqueueAndDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + dequeued_t = q.dequeue() + size = q.size() + self.assertEqual([], size.get_shape()) + + enqueue_op.run() + self.assertEqual(1, size.eval()) + dequeued_t.op.run() + self.assertEqual(0, size.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 8e6407dffdac3adbcda8cbca2109ef9196defa8c..1da97fd51217a0f28d4b3ba2ccfae3f6b094e65b 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -30,7 +30,7 @@ from tensorflow.python.training import ftrl from tensorflow.python.training import gradient_descent -class FtrlOptimizerTest(XLATestCase): +class FtrlOptimizerTest(xla_test.XLATestCase): def initVariableAndGradient(self, dtype): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 8a3f4b0bdc7a61d6cfa2ba7474ce8579e293a5c7..04fba444460e714ce96205361ac02ed492206b04 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class FunctionTest(XLATestCase): +class FunctionTest(xla_test.XLATestCase): def testFunction(self): """Executes a simple TensorFlow function.""" diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index a80d69fa5f5099b8a8b67df0da9c92b957e9d194..132e42ac7a28d0769b0de12ea0cee6eae752b245 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -18,9 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import test_utils +from tensorflow.compiler.tests import xla_test from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker @@ -28,7 +30,7 @@ from tensorflow.python.ops import nn from tensorflow.python.platform import test -class FusedBatchNormTest(XLATestCase): +class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): def _reference_training(self, x, scale, offset, epsilon, data_format): if data_format != "NHWC": @@ -63,24 +65,36 @@ class FusedBatchNormTest(XLATestCase): grad_offset = np.sum(grad_y, axis=(0, 1, 2)) return grad_x, grad_scale, grad_offset - def testInference(self): + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testInference(self, data_format): channel = 3 x_shape = [2, 2, 6, channel] scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) - offset_val = np.random.random_sample(scale_shape).astype(np.float32) - data_format = "NHWC" + epsilon = 0.001 + data_format_src = "NHWC" + y_ref, mean_ref, var_ref = self._reference_training( + x_val, scale_val, offset_val, epsilon, data_format_src) + with self.test_session() as sess, self.test_scope(): # To avoid constant folding - t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + y_ref_converted = test_utils.ConvertBetweenDataFormats( + y_ref, data_format_src, data_format) + + t_val = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") offset = array_ops.placeholder( np.float32, shape=scale_shape, name="offset") - epsilon = 0.001 - y_ref, mean_ref, var_ref = self._reference_training( - x_val, scale_val, offset_val, epsilon, data_format) y, mean, variance = nn.fused_batch_norm( t_val, scale, @@ -91,31 +105,39 @@ class FusedBatchNormTest(XLATestCase): data_format=data_format, is_training=False) - y_val, _, _ = sess.run( - [y, mean, - variance], {t_val: x_val, - scale: scale_val, - offset: offset_val}) - self.assertAllClose(y_val, y_ref, atol=1e-3) + y_val, _, _ = sess.run([y, mean, variance], { + t_val: x_val_converted, + scale: scale_val, + offset: offset_val + }) + self.assertAllClose(y_val, y_ref_converted, atol=1e-3) - def _testLearning(self, use_gradient_checker): + def _testLearning(self, use_gradient_checker, data_format): channel = 3 x_shape = [2, 2, 6, channel] scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) - offset_val = np.random.random_sample(scale_shape).astype(np.float32) mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) - data_format = "NHWC" + epsilon = 0.001 + data_format_src = "NHWC" + y_ref, mean_ref, var_ref = self._reference_training( + x_val, scale_val, offset_val, epsilon, data_format_src) + with self.test_session() as sess, self.test_scope(): # To avoid constant folding - t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + y_ref_converted = test_utils.ConvertBetweenDataFormats( + y_ref, data_format_src, data_format) + + t_val = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") offset = array_ops.placeholder( np.float32, shape=scale_shape, name="offset") - epsilon = 0.001 y, mean, var = nn.fused_batch_norm( t_val, scale, @@ -129,33 +151,50 @@ class FusedBatchNormTest(XLATestCase): if use_gradient_checker: err = gradient_checker.compute_gradient_error( t_val, - x_shape, + x_val_converted.shape, y, - x_shape, + x_val_converted.shape, extra_feed_dict={ - t_val: x_val, + t_val: x_val_converted, scale: scale_val, offset: offset_val }) self.assertLess(err, 1e-3) - y_val, mean_val, var_val = sess.run( - [y, mean, var], {t_val: x_val, - scale: scale_val, - offset: offset_val}) - y_ref, mean_ref, var_ref = self._reference_training( - x_val, scale_val, offset_val, epsilon, data_format) + y_val, mean_val, var_val = sess.run([y, mean, var], { + t_val: x_val_converted, + scale: scale_val, + offset: offset_val + }) self.assertAllClose(mean_val, mean_ref, atol=1e-3) - self.assertAllClose(y_val, y_ref, atol=1e-3) + self.assertAllClose(y_val, y_ref_converted, atol=1e-3) self.assertAllClose(var_val, var_ref, atol=1e-3) - def testLearning(self): - self._testLearning(False) + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testLearning(self, data_format): + self._testLearning(False, data_format) - def testLearningWithGradientChecker(self): - self._testLearning(True) + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testLearningWithGradientChecker(self, data_format): + self._testLearning(True, data_format) - def testGradientTraining(self): + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testGradientTraining(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. channel = 3 @@ -167,33 +206,48 @@ class FusedBatchNormTest(XLATestCase): mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) epsilon = 0.001 + data_format_src = "NHWC" + grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( + x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) with self.test_session() as sess, self.test_scope(): - grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") - x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + grad_val_converted = test_utils.ConvertBetweenDataFormats( + grad_val, data_format_src, data_format) + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + grad_x_ref_converted = test_utils.ConvertBetweenDataFormats( + grad_x_ref, data_format_src, data_format) + + grad = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="grad") + x = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC", is_training=True) + grad, x, scale, mean, var, data_format=data_format, is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { - grad: grad_val, - x: x_val, + grad: grad_val_converted, + x: x_val_converted, mean: mean_val, var: var_val, scale: scale_val }) - grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( - x_val, grad_val, scale_val, mean_val, var_val, epsilon, "NHWC") - - self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2) + self.assertAllClose(grad_x_val, grad_x_ref_converted, atol=1e-2) self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) - def testGradientInference(self): + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testGradientInference(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. channel = 3 @@ -204,33 +258,47 @@ class FusedBatchNormTest(XLATestCase): scale_val = np.random.random_sample(scale_shape).astype(np.float32) mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) + data_format_src = "NHWC" with self.test_session() as sess, self.test_scope(): - grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") - x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + grad_val_converted = test_utils.ConvertBetweenDataFormats( + grad_val, data_format_src, data_format) + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + + grad = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="grad") + x = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") with self.test_scope(): out = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad, + x, + scale, + mean, + var, + data_format=data_format, + is_training=False) grad_x, grad_scale, grad_offset, _, _ = out ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad, x, scale, mean, var, data_format=data_format, is_training=False) grad_x_val, grad_scale_val, grad_offset_val, = sess.run( [grad_x, grad_scale, grad_offset], { - grad: grad_val, - x: x_val, + grad: grad_val_converted, + x: x_val_converted, mean: mean_val, var: var_val, scale: scale_val }) grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run( [ref_x, ref_scale, ref_offset], { - grad: grad_val, - x: x_val, + grad: grad_val_converted, + x: x_val_converted, mean: mean_val, var: var_val, scale: scale_val diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 9378b1db7245c0da3e8298e7dcd972491616b0cd..23b0aed34fb460f50c241e5a920cb4f6f613b947 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class GatherNdTest(XLATestCase): +class GatherNdTest(xla_test.XLATestCase): def _runGather(self, params, indices): with self.test_session(): diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 1a8c4519118f69ce51ca9a5eb95a9d706c7766cc..e9c8ef7c91a728b7dfc948fd9b315e6c9102f6a3 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -136,6 +136,20 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual( [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) + def testGatherPrecision(self): + with self.test_session() as session, self.test_scope(): + data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0], + [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]]) + indices = np.array([1, 2, 3, 1]) + dtype = dtypes.float32 + params_np = self._buildParams(data, dtype) + params = array_ops.placeholder(dtype=dtype) + indices_tf = constant_op.constant(indices) + gather_t = array_ops.gather(params, indices_tf) + gather_val = session.run(gather_t, feed_dict={params: params_np}) + np_val = params_np[indices] + self.assertAllEqual(np_val, gather_val) + class GatherBenchmark(test.Benchmark): """Microbenchmarks for the gather op.""" diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 42e637734c578fcc70473060cb156e172a0a1995..8b01ef96db3e8ab58850df234c2e05b764be52ba 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -25,7 +25,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,7 +41,7 @@ def GenerateNumpyRandomRGB(shape): return np.random.randint(0, 256, shape) / 256. -class RGBToHSVTest(XLATestCase): +class RGBToHSVTest(xla_test.XLATestCase): def testBatch(self): # Build an arbitrary RGB image @@ -65,9 +65,7 @@ class RGBToHSVTest(XLATestCase): join1 = array_ops.stack(split1) join2 = array_ops.stack(split2) batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2], - { - batch0: inp - }) + {batch0: inp}) # Verify that processing batch elements together is the same as separate self.assertAllClose(batch1, join1) @@ -106,7 +104,7 @@ class RGBToHSVTest(XLATestCase): self.assertAllCloseAccordingToType(hsv_tf, hsv_np) -class AdjustContrastTest(XLATestCase): +class AdjustContrastTest(xla_test.XLATestCase): def _testContrast(self, x_np, y_np, contrast_factor): with self.test_session(): @@ -170,7 +168,7 @@ class AdjustContrastTest(XLATestCase): self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5) -class AdjustHueTest(XLATestCase): +class AdjustHueTest(xla_test.XLATestCase): def testAdjustNegativeHue(self): x_shape = [2, 2, 3] @@ -305,7 +303,7 @@ class AdjustHueTest(XLATestCase): self._adjustHueTf(x_np, delta_h) -class AdjustSaturationTest(XLATestCase): +class AdjustSaturationTest(xla_test.XLATestCase): def _adjust_saturation(self, image, saturation_factor): image = ops.convert_to_tensor(image, name="image") @@ -401,18 +399,17 @@ class AdjustSaturationTest(XLATestCase): x = array_ops.placeholder(dtypes.float32, shape=x_shape) with self.test_scope(): y_fused = self._adjust_saturation(x, - scale).eval(feed_dict={ - x: x_np - }) + scale).eval(feed_dict={x: x_np}) self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) -class ResizeBilinearTest(XLATestCase): +class ResizeBilinearTest(xla_test.XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, target_shape, - expected=None): + expected=None, + large_tolerance=False): if expected is None: self.fail("expected must be specified") with self.test_session() as sess, self.test_scope(): @@ -420,7 +417,11 @@ class ResizeBilinearTest(XLATestCase): resized = gen_image_ops.resize_bilinear( image, target_shape, align_corners=True) out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) - self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + if large_tolerance: + self.assertAllClose( + expected[np.newaxis, :, :, np.newaxis], out, rtol=0.03, atol=0.1) + else: + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) def _assertBackwardOpMatchesExpected(self, grads_np, @@ -555,6 +556,28 @@ class ResizeBilinearTest(XLATestCase): [[12.5, 27.5, 21.875], [42.5, 80.0, 57.5], [40.625, 72.5, 50]], dtype=np.float32)) + def testAlignCorners4x4To8x8(self): + self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3]], dtype=np.float32) + np.array( + [[0], [1], [2], [3]], dtype=np.float32)) * 7.0, [8, 8], + expected=3 * + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)), + large_tolerance=True) + + def testAlignCorners8x8To16x16(self): + self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0, + [16, 16], + expected=7 * (np.array( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], + dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], + [12], [13], [14], [15]], + dtype=np.float32)), + large_tolerance=True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 0310cdde660c912d593fe034fbfbd749f258fc1f..6e0db54b7a74b284dc7d18bcbb07c178c664c1e5 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -78,10 +78,10 @@ def InLabels(labels, substr): def MetadataHasXlaLaunch(run_metadata): - """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline.""" + """Returns true if there is a XlaLaunch kernel in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. - return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch") + return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch") class JitLaunchTest(test.TestCase): @@ -90,8 +90,8 @@ class JitLaunchTest(test.TestCase): # Verifies that the outputs match and that XLA was invoked. 'fn' must take # the same number of tensors as arguments that are in 'args', and must return # a tuple of output tensors. - # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node - # actually ran. However, it is sometimes possible for _XlaLaunch ops to be + # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node + # actually ran. However, it is sometimes possible for XlaLaunch ops to be # constant-folded away, so the check is optional. def _compare(self, fn, args, require_kernel_launch=True, noinline=None): with session_lib.Session(config=NoRewriteSessionConfig()) as sess: @@ -125,7 +125,7 @@ class JitLaunchTest(test.TestCase): for (x, y) in zip(compiled, direct): self.assertAllClose(x, y, rtol=1e-1) else: - self.assertAllClose(compiled, direct) + self.assertAllClose(compiled, direct, rtol=1e-2) def testNoOutputs(self): with session_lib.Session() as sess: @@ -441,14 +441,14 @@ class XlaCompilationTest(test.TestCase): self.assertFalse(InLabels(labels, "Log")) self.assertTrue(InLabels(labels, "Reciprocal")) self.assertTrue(InLabels(labels, "Mul")) - self.assertFalse(InLabels(labels, "_XlaLaunch")) + self.assertFalse(InLabels(labels, "XlaLaunch")) - # Compile the backprop. One _XlaLaunch. + # Compile the backprop. One XlaLaunch. labels = _Run(compiled=True) self.assertFalse(InLabels(labels, "Log")) self.assertFalse(InLabels(labels, "Reciprocal")) self.assertFalse(InLabels(labels, "Mul")) - self.assertTrue(InLabels(labels, "_XlaLaunch")) + self.assertTrue(InLabels(labels, "XlaLaunch")) class ElementWiseFusionTest(test.TestCase): @@ -482,7 +482,7 @@ class ElementWiseFusionTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = RunMetadataLabels(run_metadata) - count = sum("_XlaLaunch(" in x for x in labels) + count = sum("XlaLaunch(" in x for x in labels) return output, count diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 69bd8f7230d4394c45764d02a88fb0ec097c5756..253b45902fba2df64e5234f135b373cd2a0a7e2a 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -22,7 +22,7 @@ import copy import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -36,7 +36,7 @@ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" # Local response normalization tests. The forward tests are copied from # tensorflow/python/kernel_tests/lrn_op_test.py -class LRNTest(XLATestCase): +class LRNTest(xla_test.XLATestCase): def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0, beta=0.5): diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 29394f9ea5139b30f88f53de0469b27e37d79195..0d9f99f8a6803ecae5f9233518a1768109161ac0 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -19,14 +19,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MatrixBandPartTest(XLATestCase): +class MatrixBandPartTest(xla_test.XLATestCase): def _testMatrixBandPart(self, dtype, shape): with self.test_session(): diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 5819b2bf2b55b9213a039c0ba82dd0bf1c738b00..2bb8a97bdaf5836a05501ab9754433e29ae34675 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -22,7 +22,7 @@ import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -35,7 +35,7 @@ def MakePlaceholder(x): return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape) -class MatrixTriangularSolveOpTest(XLATestCase): +class MatrixTriangularSolveOpTest(xla_test.XLATestCase): # MatrixTriangularSolve defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve) diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index af9394e7d7dc9cf7dd009420ff9c845aec8785bd..c2592c54cf83d41f0e3bdbc1f4dc9ff276ddb078 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import momentum as momentum_lib -class MomentumOptimizerTest(XLATestCase): +class MomentumOptimizerTest(xla_test.XLATestCase): def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum): var += accum * lr * momentum diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index e4843b169b943b63346b783ddc50039030988ca5..da08225e9fc0d5a8ec21ee9961c4758fa38628b4 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -22,14 +22,14 @@ import unittest import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class NAryOpsTest(XLATestCase): +class NAryOpsTest(xla_test.XLATestCase): def _testNAry(self, op, args, expected, equality_fn=None): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index 6f588d8ab562cb24f33c4c2987df22264aede027..2f9122645d3c5ccabc8130ac30a3f09cf4bc2de7 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import googletest -class NullaryOpsTest(XLATestCase): +class NullaryOpsTest(xla_test.XLATestCase): def _testNullary(self, op, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index 5e6d1313bd0336eba71fcf3658d949bd3342ae11..a75d99189b5b673261c9e48f1c5998ea0c575594 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -class PlaceholderTest(XLATestCase): +class PlaceholderTest(xla_test.XLATestCase): def test_placeholder_with_default_default(self): with self.test_session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index 4eed903963a34a253ea5c409782d9a89a97a4fdf..17f860db61aeda98326a6820771d67ee948b6dda 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,7 +41,7 @@ def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding): padding=padding) -class Pooling3DTest(XLATestCase): +class Pooling3DTest(xla_test.XLATestCase): def _VerifyValues(self, pool_func, input_sizes, window, strides, padding, expected): @@ -187,8 +187,14 @@ class Pooling3DTest(XLATestCase): padding="VALID", expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5]) - def _VerifyGradient(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding): + def _VerifyGradient(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + pool_grad_grad_func=None): """Verifies the output values of the pooling gradient function. Args: @@ -198,6 +204,7 @@ class Pooling3DTest(XLATestCase): ksize: The kernel size dimensions strides: The stride dimensions padding: Padding type. + pool_grad_grad_func: Second-order gradient function, if available. """ ksize = [1] + ksize + [1] strides = [1] + strides + [1] @@ -218,6 +225,8 @@ class Pooling3DTest(XLATestCase): output_gradient_vals = np.arange( 1, output_vals.size + 1, dtype=np.float32) output_gradient_vals = output_gradient_vals.reshape(output_vals.shape) + output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32) + output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape) # Use the Tensorflow CPU pooling gradient to compute the expected input # gradients. @@ -236,6 +245,22 @@ class Pooling3DTest(XLATestCase): {inputs: x, output_gradients: output_gradient_vals}) + output_grad_gradients = array_ops.placeholder( + dtypes.float32, shape=expected_input_gradient_vals.shape) + if pool_grad_grad_func is not None: + expected_grad_gradients = pool_grad_grad_func( + inputs, + outputs, + output_grad_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NDHWC") + expected_grad_gradients_vals = sess.run(expected_grad_gradients, { + inputs: x, + output_grad_gradients: output_grad_grad_vals + }) + # Run the gradient op on the XLA device with self.test_scope(): outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape) @@ -246,6 +271,16 @@ class Pooling3DTest(XLATestCase): ksize=ksize, strides=strides, padding=padding) + if pool_grad_grad_func is not None: + actual_grad_gradients = pool_grad_grad_func( + inputs, + outputs, + output_grad_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NDHWC") + actual = sess.run(actual_input_gradients, { inputs: x, outputs: output_vals, @@ -260,6 +295,22 @@ class Pooling3DTest(XLATestCase): atol=1e-6) self.assertShapeEqual(actual, inputs) + if pool_grad_grad_func is not None: + actual_grad_gradients_vals = sess.run( + actual_grad_gradients, { + inputs: x, + outputs: output_vals, + output_grad_gradients: output_grad_grad_vals + }) + + # Compare the Tensorflow and XLA results. + self.assertAllClose( + expected_grad_gradients_vals, + actual_grad_gradients_vals, + rtol=1e-4, + atol=1e-6) + self.assertShapeEqual(actual_grad_gradients_vals, outputs) + def testMaxPoolGradValidPadding1_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, @@ -267,7 +318,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[1, 3, 3, 3, 1], ksize=[1, 1, 1], strides=[1, 1, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradValidPadding2_1_6_3d(self): self._VerifyGradient( @@ -276,9 +328,13 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 3, 3, 6, 3], ksize=[2, 2, 2], strides=[1, 1, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradValidPadding2_1_7_3d(self): + # TODO(b/73062247): the bfloat16 implementation of MaxPool3DGradGrad does + # not have enough precision for this test case to pass if + # pool_grad_grad_func is passed. self._VerifyGradient( nn_ops.max_pool3d, gen_nn_ops.max_pool3d_grad, @@ -294,7 +350,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 2, 2, 2, 3], ksize=[2, 2, 2], strides=[2, 2, 2], - padding="VALID") + padding="VALID", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding1_1_3d(self): self._VerifyGradient( @@ -303,7 +360,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 3, 2, 4, 1], ksize=[1, 1, 1], strides=[1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding2_1_3d(self): self._VerifyGradient( @@ -312,7 +370,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 3, 2, 4, 1], ksize=[2, 2, 2], strides=[1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding2_2_3d(self): self._VerifyGradient( @@ -321,7 +380,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 5, 2, 4, 3], ksize=[2, 2, 2], strides=[2, 2, 2], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding3_1_3d(self): self._VerifyGradient( @@ -330,7 +390,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[1, 3, 3, 7, 1], ksize=[3, 3, 3], strides=[1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testAvgPoolGradValidPadding1_1_3d(self): self._VerifyGradient( diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index fe270af3d636c0824621f36360ce9e7d14d8fc91..9fc94752ea660f7fb8b2c792180f01485ad04419 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -69,7 +69,7 @@ def GetTestConfigs(): return test_configs -class PoolingTest(XLATestCase): +class PoolingTest(xla_test.XLATestCase): def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, data_format, expected): @@ -288,7 +288,7 @@ class PoolingTest(XLATestCase): expected=expected_output) -class PoolGradTest(XLATestCase): +class PoolGradTest(xla_test.XLATestCase): CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cde87db63dbfd7c8d823c6fd0e41eee8b23735bb --- /dev/null +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -0,0 +1,172 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Proximal Adagrad optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adagrad +from tensorflow.python.training import proximal_adagrad + + +class ProximalAdagradOptimizerTest(xla_test.XLATestCase): + + def testResourceProximalAdagradwithoutRegularization(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run 3 steps Proximal Adagrad. + for _ in range(3): + update.run() + + self.assertAllClose(np.array([-2.60260963, -4.29698515]), var0.eval()) + self.assertAllClose(np.array([-0.28432083, -0.56694895]), var1.eval()) + opt_vars = opt.variables() + self.assertStartsWith(opt_vars[0].name, var0._shared_name) + self.assertStartsWith(opt_vars[1].name, var1._shared_name) + self.assertEqual(2, len(opt_vars)) + + def testProximalAdagradwithoutRegularization2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 3 steps Proximal Adagrad. + for _ in range(3): + update.run() + self.assertAllClose(np.array([-1.60261, -2.296985]), var0.eval()) + self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval()) + + def testProximalAdagradWithL1(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps Proximal Adagrad + for _ in range(10): + update.run() + self.assertAllClose(np.array([-6.663634, -9.190331]), var0.eval()) + self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval()) + + def testProximalAdagradWithL1_L2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps Proximal Adagrad. + for _ in range(10): + update.run() + + self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) + self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + + def applyOptimizer(self, opt, steps=5): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run ProximalAdagrad for a few steps + for _ in range(steps): + update.run() + + return var0.eval(), var1.eval() + + def testEquivAdagradwithoutRegularization(self): + with self.test_session(), self.test_scope(): + val0, val1 = self.applyOptimizer( + proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0)) + + with self.test_session(), self.test_scope(): + val2, val3 = self.applyOptimizer( + adagrad.AdagradOptimizer( + 3.0, initial_accumulator_value=0.1)) + + self.assertAllClose(val0, val2) + self.assertAllClose(val1, val3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py new file mode 100644 index 0000000000000000000000000000000000000000..11eb76871133eba8fcd24621afb03e16614fb005 --- /dev/null +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -0,0 +1,156 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Proximal Gradient Descent optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import proximal_gradient_descent + + +class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): + + def testResourceProximalGradientDescentwithoutRegularization(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run 3 steps Proximal Gradient Descent. + for _ in range(3): + update.run() + + self.assertAllClose(np.array([-0.9, -1.8]), var0.eval()) + self.assertAllClose(np.array([-0.09, -0.18]), var1.eval()) + + def testProximalGradientDescentwithoutRegularization2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 3 steps Proximal Gradient Descent + for _ in range(3): + update.run() + + self.assertAllClose(np.array([0.1, 0.2]), var0.eval()) + self.assertAllClose(np.array([3.91, 2.82]), var1.eval()) + + def testProximalGradientDescentWithL1(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.001, l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps proximal gradient descent. + for _ in range(10): + update.run() + + self.assertAllClose(np.array([-1.988, -3.988001]), var0.eval()) + self.assertAllClose(np.array([3.67, 2.37]), var1.eval()) + + def testProximalGradientDescentWithL1_L2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.001, l2_regularization_strength=2.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps Proximal Gradient Descent + for _ in range(10): + update.run() + + self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) + self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + + def applyOptimizer(self, opt, steps=5): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run ProximalAdagrad for a few steps + for _ in range(steps): + update.run() + + return var0.eval(), var1.eval() + + def testEquivGradientDescentwithoutRegularization(self): + with self.test_session(), self.test_scope(): + val0, val1 = self.applyOptimizer( + proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0)) + + with self.test_session(), self.test_scope(): + val2, val3 = self.applyOptimizer( + gradient_descent.GradientDescentOptimizer(3.0)) + + self.assertAllClose(val0, val2) + self.assertAllClose(val1, val3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index d6c93088d4efff7d8306e262a79ae49d3d8ac722..b880b2a3fea3ee72af96396bc2d61b2887e6e9b8 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -18,15 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math + import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import googletest -class RandomOpsTest(XLATestCase): +class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): @@ -47,18 +52,18 @@ class RandomOpsTest(XLATestCase): # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. self.assertTrue((not np.array_equal(y, z)) or - (not np.array_equal(z, w)) or - (not np.array_equal(y, w))) + (not np.array_equal(z, w)) or (not np.array_equal(y, w))) def testRandomUniformIsNotConstant(self): + def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, - maxval=1000000) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) def testRandomNormalIsNotConstant(self): + def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) @@ -70,22 +75,90 @@ class RandomOpsTest(XLATestCase): for dtype in self._random_types(): with self.test_session() as sess: with self.test_scope(): - x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2, - maxval=33) + x = random_ops.random_uniform( + shape=[1000], dtype=dtype, minval=-2, maxval=33) y = sess.run(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) + def testTruncatedNormalIsNotConstant(self): + + def rng(dtype): + return random_ops.truncated_normal(shape=[2], dtype=dtype) + + # TODO(b/34339814): implement inverse erf support for non-F32 types. + self._testRngIsNotConstant(rng, dtypes.float32) + def testTruncatedNormalIsInRange(self): - count = 10000 + count = 10000000 # TODO(b/34339814): implement inverse erf support for non-F32 types. for dtype in [dtypes.float32]: with self.test_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42) y = sess.run(x) - self.assertTrue((y >= -2).sum() == count) - self.assertTrue((y <= 2).sum() == count) + + def normal_cdf(x): + return .5 * math.erfc(-x / math.sqrt(2)) + + def normal_pdf(x): + return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) + + def probit(x, sess=sess): + return sess.run(special_math.ndtri(x)) + + a = -2. + b = 2. + mu = 0. + sigma = 1. + + alpha = (a - mu) / sigma + beta = (b - mu) / sigma + z = normal_cdf(beta) - normal_cdf(alpha) + + self.assertTrue((y >= a).sum() == count) + self.assertTrue((y <= b).sum() == count) + + # For more information on these calculations, see: + # Burkardt, John. "The Truncated Normal Distribution". + # Department of Scientific Computing website. Florida State University. + expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma + actual_mean = np.mean(y) + self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + + expected_median = mu + probit( + (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma + actual_median = np.median(y) + self.assertAllClose(actual_median, expected_median, atol=8e-4) + + expected_variance = sigma**2 * (1 + ( + (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( + (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) + actual_variance = np.var(y) + self.assertAllClose(actual_variance, expected_variance, rtol=3e-4) + + def testShuffle1d(self): + with self.test_session() as sess: + with self.test_scope(): + x = math_ops.range(20) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = range(20) + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(set(result), set(expected)) + + def testShuffle2d(self): + with self.test_session() as sess: + with self.test_scope(): + x = array_ops.diag(math_ops.range(20)) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = np.diag(range(20)).flatten() + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(len(result.flatten()), len(expected)) + self.assertAllEqual(set(result.flatten()), set(expected)) if __name__ == '__main__': diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index e53efc3091d8935e745122af29abd7b8063b1d01..16f293891d56d78885dd515bb7b9899faf0690f7 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -619,8 +619,8 @@ std::vector OpTest::ImageDims(TensorFormat format, int batch, dims.push_back(dim); } break; - case FORMAT_NCHW_VECT_C: - LOG(FATAL) << "FORMAT_NCHW_VECT_C not supported."; + default: + LOG(FATAL) << "Tensor format " << ToString(format) << " not supported."; } return dims; } diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 7420724bdbeab63b39542ada59328621febad895..cea2ec816f85e88b11e6e80c91c14fca9015f45c 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -22,7 +22,7 @@ import functools import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ReduceOpsTest(XLATestCase): +class ReduceOpsTest(xla_test.XLATestCase): def _testReduction(self, tf_reduce_fn, @@ -156,7 +156,7 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) -class ReduceOpPrecisionTest(XLATestCase): +class ReduceOpPrecisionTest(xla_test.XLATestCase): def _testReduceSum(self, expected_result, diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py index e78a63465b80644d8810d9fa7433653bc4639fed..c69b6837b0f88ced844faf3713a29a1c14c8790d 100644 --- a/tensorflow/compiler/tests/reduce_window_test.py +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class ReduceWindowTest(XLATestCase): +class ReduceWindowTest(xla_test.XLATestCase): """Test cases for xla.reduce_window.""" def _reduce_window(self, operand, init, reducer, **kwargs): diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index 18fabca28c9817fc8517595fa1694a18399f54b0..d01c676e7c2fe705344f26818350c46c30451c67 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -21,14 +21,14 @@ from __future__ import print_function import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class ReverseOpsTest(XLATestCase): +class ReverseOpsTest(xla_test.XLATestCase): def testReverseOneDim(self): shape = (7, 5, 9, 11) diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index 1a5d05094e53cfecd9476d7d87f023e8a02d7458..ccfa63001653537c4d1b7140e3d745c126f9034b 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ReverseSequenceTest(XLATestCase): +class ReverseSequenceTest(xla_test.XLATestCase): def _testReverseSequence(self, x, diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index ecdce4f052bbe3eeae8697c02c891105103f4f69..9489fded32a7b6aada0543721a8bfe5f2d74575e 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -28,7 +28,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import rmsprop -class RmspropTest(XLATestCase): +class RmspropTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 3260e63b23226d736a7ddc0f21a94a8c791e0442..4292352e76ebcef7dbf41df7b857d2604a468117 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -69,7 +69,7 @@ def handle_options(func, x, axis, exclusive, reverse): return x -class CumsumTest(XLATestCase): +class CumsumTest(xla_test.XLATestCase): valid_dtypes = [np.float32] @@ -147,7 +147,7 @@ class CumsumTest(XLATestCase): math_ops.cumsum(input_tensor, [0]).eval() -class CumprodTest(XLATestCase): +class CumprodTest(xla_test.XLATestCase): valid_dtypes = [np.float32] diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 638946e234daf28dc4a34e6c33fc0f78b8e8699b..f606f88545d0b6f0b52cee9b93083a6bd91169bc 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -22,7 +22,7 @@ import functools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -68,7 +68,7 @@ def _NumpyUpdate(indices, updates, shape): return _NumpyScatterNd(ref, indices, updates, lambda p, u: u) -class ScatterNdTest(XLATestCase): +class ScatterNdTest(xla_test.XLATestCase): def _VariableRankTest(self, np_scatter, diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 4a9c0e7471f9cdb2a47b54705495d2dda9748890..772c20fd424577c3e06eeae409f424b77b52aa8a 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -21,26 +21,40 @@ from __future__ import print_function import functools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class SegmentReductionOpsTest(XLATestCase): +class SegmentReductionOpsTest(xla_test.XLATestCase): """Test cases for segment reduction ops.""" - def UnsortedSegmentSum(self, data, indices, num_segments): + def _segmentReduction(self, op, data, indices, num_segments): with self.test_session() as sess, self.test_scope(): d = array_ops.placeholder(data.dtype, shape=data.shape) if isinstance(indices, int): i = array_ops.placeholder(np.int32, shape=[]) else: i = array_ops.placeholder(indices.dtype, shape=indices.shape) - return sess.run( - math_ops.unsorted_segment_sum(d, i, num_segments), - {d: data, - i: indices}) + return sess.run(op(d, i, num_segments), {d: data, i: indices}) + + def _unsortedSegmentSum(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_sum, data, indices, + num_segments) + + def _unsortedSegmentProd(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_prod, data, indices, + num_segments) + + def _unsortedSegmentMin(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_min, data, indices, + num_segments) + + def _unsortedSegmentMax(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_max, data, indices, + num_segments) def testUnsortedSegmentSum0DIndices1DData(self): for dtype in self.numeric_types: @@ -49,14 +63,14 @@ class SegmentReductionOpsTest(XLATestCase): [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 0]], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4)) def testUnsortedSegmentSum1DIndices1DData(self): for dtype in self.numeric_types: self.assertAllClose( np.array([1, 3, 2, 9], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) @@ -64,7 +78,7 @@ class SegmentReductionOpsTest(XLATestCase): for dtype in self.numeric_types: self.assertAllClose( np.array([6, 3, 0, 6], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) @@ -76,7 +90,7 @@ class SegmentReductionOpsTest(XLATestCase): dtype=dtype) indices = np.array([8, 1, 0, 3, 7], dtype=np.int32) num_segments = 10 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0], @@ -92,7 +106,7 @@ class SegmentReductionOpsTest(XLATestCase): dtype=dtype) indices = np.array([0, 1, 2, 0, 1], dtype=np.int32) num_segments = 4 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33], @@ -102,30 +116,30 @@ class SegmentReductionOpsTest(XLATestCase): def testUnsortedSegmentSum2DIndices3DData(self): for dtype in self.numeric_types: data = np.array( - [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], - [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], - [310, 311, 312]]], + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ + 200, 201, 202 + ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], dtype=dtype) indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32) num_segments = 8 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( - [[210, 211, 212], [110, 111, 112], [310, 311, 312], - [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301, - 302], [0, 0, 0]], + [[210, 211, 212], [110, 111, 112], [310, 311, 312], [ + 100, 102, 104 + ], [0, 0, 0.], [210, 212, 214], [300, 301, 302], [0, 0, 0]], dtype=dtype), y) def testUnsortedSegmentSum1DIndices3DData(self): for dtype in self.numeric_types: data = np.array( - [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], - [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], - [310, 311, 312]]], + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ + 200, 201, 202 + ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], dtype=dtype) indices = np.array([3, 0, 2, 5], dtype=np.int32) num_segments = 6 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]], @@ -138,10 +152,40 @@ class SegmentReductionOpsTest(XLATestCase): data = np.ones((4, 8, 7), dtype=dtype) indices = np.ones((3, 2), dtype=np.int32) num_segments = 4 - self.assertRaises(ValueError, - functools.partial(self.UnsortedSegmentSum, data, - indices, num_segments)) + self.assertRaises( + ValueError, + functools.partial(self._segmentReduction, + math_ops.unsorted_segment_sum, data, indices, + num_segments)) + + def testUnsortedSegmentOps1DIndices1DDataNegativeIndices(self): + """Tests for min, max, and prod ops. + + These share most of their implementation with sum, so we only test basic + functionality. + """ + for dtype in self.numeric_types: + self.assertAllClose( + np.array([8, 3, 1, 0], dtype=dtype), + self._unsortedSegmentProd( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) + + for dtype in self.int_types | self.float_types: + minval = dtypes.as_dtype(dtype).min + maxval = dtypes.as_dtype(dtype).max + + self.assertAllClose( + np.array([2, 3, maxval, 0], dtype=dtype), + self._unsortedSegmentMin( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) + self.assertAllClose( + np.array([4, 3, minval, 6], dtype=dtype), + self._unsortedSegmentMax( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) -if __name__ == '__main__': +if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 305ca0c6b78d3ef985deb38816f9388e7983906b..6c4890565d2083a9493abc59bd563c4dd9fdb186 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class SliceTest(XLATestCase): +class SliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: @@ -110,7 +110,7 @@ class SliceTest(XLATestCase): self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result) -class StridedSliceTest(XLATestCase): +class StridedSliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2ef964a1ff00a861a874135b7dfa1358a7020e --- /dev/null +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -0,0 +1,140 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sorting operators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class XlaSortOpTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + if isinstance(output, ops.Tensor): + output = [output] + + results = session.run(output, feeds) + for result, v in zip(results, expected): + self.assertAllClose(v, result, rtol=1e-3) + + def testSort(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + for dtype in supported_types.intersection(self.numeric_types): + x = np.arange(101, dtype=dtype) + np.random.shuffle(x) + self._assertOpOutputMatchesExpected( + xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) + + def testTopK(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for dtype in supported_types.intersection(self.numeric_types): + # Use small input size for bfloat16. Otherwise, we'll get duplicate values + # after conversion to bfloat16, so the possible resulting index array is + # no longer unique. + if dtype == dtypes.bfloat16.as_numpy_dtype: + array_size = 20 + k_options = [0, 1, 2, 10, 20] + else: + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + for x in [np.arange(array_size)]: + np.random.shuffle(x) + for k in k_options: + indices = x.argsort()[::-1][:k] + + def topk(v, k=k): + return nn_ops.top_k(v, k=k, sorted=True) + + self._assertOpOutputMatchesExpected( + topk, [x.astype(dtype)], + expected=[x[indices].astype(dtype), indices]) + + def testTopKZeros(self): + """Tests that positive and negative zeros sort correctly.""" + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 not in self.numeric_types: + return + + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.bfloat16) + with self.test_scope(): + topk = nn_ops.top_k(p, k=4) + results = sess.run( + topk, + {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) + self.assertAllEqual( + np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) + self.assertEqual(list([3, 0, 2, 6]), list(results[1])) + + def testTopKInfinities(self): + """Tests that positive and negative infinity sort correctly.""" + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 not in self.numeric_types: + return + + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.bfloat16) + with self.test_scope(): + topk = nn_ops.top_k(p, k=6) + results = sess.run(topk, { + p: np.array( + [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16) + }) + self.assertAllEqual( + np.array( + [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], + dtype=bfloat16), results[0]) + self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index f37c34156f96761632247be4bc1b62fca54f666e..c685bc548f9f6f8f7723c6f94dfd45f5420b4a67 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -68,7 +68,7 @@ def space_to_batch_direct(input_array, block_shape, paddings): return permuted_reshaped_padded.reshape(output_shape) -class SpaceToBatchTest(XLATestCase): +class SpaceToBatchTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" def _testPad(self, inputs, paddings, block_size, outputs): @@ -149,7 +149,7 @@ class SpaceToBatchTest(XLATestCase): self._testOne(x_np, block_size, x_out) -class SpaceToBatchNDTest(XLATestCase): +class SpaceToBatchNDTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops.""" def _testPad(self, inputs, block_shape, paddings, outputs): diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3db8101c4bfbb1b53c7318a36519612984d6f179 --- /dev/null +++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py @@ -0,0 +1,118 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.kernels.sparse_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +def _SparseToDense(sparse_indices, + output_size, + sparse_values, + default_value, + validate_indices=True): + feed_sparse_indices = array_ops.placeholder(dtypes.int32) + feed_dict = {feed_sparse_indices: sparse_indices} + return sparse_ops.sparse_to_dense( + feed_sparse_indices, + output_size, + sparse_values, + default_value=default_value, + validate_indices=validate_indices).eval(feed_dict=feed_dict) + + +class SparseToDenseTest(xla_test.XLATestCase): + + def testInt(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], 1, 0) + np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def testFloat(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0) + np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) + self.assertAllClose(np_ans, tf_ans) + + def testSetValue(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1) + np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def testSetSingleValue(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], 1, -1) + np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def test2d(self): + # pylint: disable=bad-whitespace + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1) + np_ans = np.array([[-1, -1, -1, -1], + [-1, -1, -1, 1], + [ 1, -1, -1, -1]]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def testZeroDefault(self): + with self.test_session(): + x = sparse_ops.sparse_to_dense(2, [4], 7).eval() + self.assertAllEqual(x, [0, 0, 7, 0]) + + def test3d(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1) + np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 + np_ans[1, 3, 0] = 1 + np_ans[2, 0, 1] = 1 + self.assertAllClose(np_ans, tf_ans) + + def testBadShape(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): + _SparseToDense([1, 3], [[5], [3]], 1, -1) + + def testBadValue(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesOpError( + r"sparse_values has incorrect shape \[2,1\], " + r"should be \[\] or \[2\]"): + _SparseToDense([1, 3], [5], [[5], [3]], -1) + + def testBadNumValues(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesOpError( + r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): + _SparseToDense([1, 3], [5], [1, 2, 3], -1) + + def testBadDefault(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesOpError("default_value should be a scalar"): + _SparseToDense([1, 3], [5], [1, 2], [0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index 94342f9567ca71274609e63b0482d55637c98d51..b7dd787feff2b22a9cfb5d43a4ba6ceb6eb0b301 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.platform import test -class StackOpTest(XLATestCase): +class StackOpTest(xla_test.XLATestCase): def testStackPushPop(self): with self.test_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index b6f8390a45d43bf7666b90e14cc6ff2f3f61947e..d162675ef840131485128414b4a29e3cd89c8761 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -22,14 +22,15 @@ import math import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.contrib import stateless from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test -class StatelessRandomOpsTest(XLATestCase): +class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): @@ -122,6 +123,56 @@ class StatelessRandomOpsTest(XLATestCase): # so to avoid flakiness the seed is fixed. self.assertTrue(self._anderson_darling(y) < 2.492) + def testTruncatedNormalIsInRange(self): + # TODO(b/34339814): implement inverse erf support for non-F32 types. + for dtype in [dtypes.float32]: + with self.test_session() as sess, self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 10000000 + x = stateless.stateless_truncated_normal( + shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + + def normal_cdf(x): + return .5 * math.erfc(-x / math.sqrt(2)) + + def normal_pdf(x): + return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) + + def probit(x, sess=sess): + return sess.run(special_math.ndtri(x)) + + a = -2. + b = 2. + mu = 0. + sigma = 1. + + alpha = (a - mu) / sigma + beta = (b - mu) / sigma + z = normal_cdf(beta) - normal_cdf(alpha) + + self.assertTrue((y >= a).sum() == n) + self.assertTrue((y <= b).sum() == n) + + # For more information on these calculations, see: + # Burkardt, John. "The Truncated Normal Distribution". + # Department of Scientific Computing website. Florida State University. + expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma + actual_mean = np.mean(y) + self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + + expected_median = mu + probit( + (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma + actual_median = np.median(y) + self.assertAllClose(actual_median, expected_median, atol=8e-4) + + expected_variance = sigma**2 * (1 + ( + (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( + (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) + actual_variance = np.var(y) + self.assertAllClose(actual_variance, expected_variance, rtol=1e-3) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index ef047005b60bd156a677050368ef67ae030d6c3a..effa5a59fee7dda543b2c409dfaa27a972a55808 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class TernaryOpsTest(XLATestCase): +class TernaryOpsTest(xla_test.XLATestCase): def _testTernary(self, op, a, b, c, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/test_utils.py b/tensorflow/compiler/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6abde18ea91f16d153a154b94effab037a911c6c --- /dev/null +++ b/tensorflow/compiler/tests/test_utils.py @@ -0,0 +1,63 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for helping test ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +def ConvertBetweenDataFormats(x, data_format_src, data_format_dst): + """Converts 4D tensor between data formats.""" + + valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"] + if data_format_src not in valid_data_formats: + raise ValueError("data_format_src must be of %s, got %s." % + (valid_data_formats, data_format_src)) + if data_format_dst not in valid_data_formats: + raise ValueError("data_format_dst must be of %s, got %s." % + (valid_data_formats, data_format_dst)) + if len(x.shape) != 4: + raise ValueError("x must be 4D, got shape %s." % x.shape) + + if data_format_src == data_format_dst: + return x + + dim_map = {d: i for i, d in enumerate(data_format_src)} + transpose_dims = [dim_map[d] for d in data_format_dst] + return np.transpose(x, transpose_dims) + + +def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst): + """Get new shape for converting between data formats.""" + + valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"] + if data_format_src not in valid_data_formats: + raise ValueError("data_format_src must be of %s, got %s." % + (valid_data_formats, data_format_src)) + if data_format_dst not in valid_data_formats: + raise ValueError("data_format_dst must be of %s, got %s." % + (valid_data_formats, data_format_dst)) + if len(dims) != 4: + raise ValueError("dims must be of length 4, got %s." % dims) + + if data_format_src == data_format_dst: + return dims + + dim_map = {d: i for i, d in enumerate(data_format_src)} + permuted_dims = [dims[dim_map[d]] for d in data_format_dst] + return permuted_dims diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 52633f619db5c4bdfed064a624fa8f74f87c3487..6a7011aea6cc3f942fecf27a640b998bfc10c0de 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -23,7 +23,7 @@ import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops @@ -44,11 +44,16 @@ def nhwc_to_format(x, data_format): raise ValueError("Unknown format {}".format(data_format)) -class UnaryOpsTest(XLATestCase): +class UnaryOpsTest(xla_test.XLATestCase): """Test cases for unary operators.""" - def _assertOpOutputMatchesExpected(self, op, inp, expected, - equality_test=None, rtol=1e-3, atol=1e-5): + def _assertOpOutputMatchesExpected(self, + op, + inp, + expected, + equality_test=None, + rtol=1e-3, + atol=1e-5): """Verifies that 'op' produces 'expected' when fed input 'inp' . Args: @@ -81,10 +86,10 @@ class UnaryOpsTest(XLATestCase): def testAllTypeOps(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( - array_ops.diag, - np.array([1, 2, 3, 4], dtype=dtype), - np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], - dtype=dtype)) + array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype), + np.array( + [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.diag_part, np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), @@ -102,8 +107,7 @@ class UnaryOpsTest(XLATestCase): expected=np.array([[-1, 1]], dtype=dtype)) self._assertOpOutputMatchesExpected( - array_ops.matrix_diag, - np.array([[1, 2], [3, 4]], dtype=dtype), + array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype), @@ -115,10 +119,10 @@ class UnaryOpsTest(XLATestCase): np.array( [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype), np.array( - [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], - [[4, 0, 0], [0, 5, 0], [0, 0, 6]]], - [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], - [[10, 0, 0], [0, 11, 0], [0, 0, 12]]]], + [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], [[4, 0, 0], [0, 5, 0], [ + 0, 0, 6 + ]]], [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], [[10, 0, 0], [0, 11, 0], + [0, 0, 12]]]], dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag_part, @@ -159,36 +163,30 @@ class UnaryOpsTest(XLATestCase): continue x = np.arange(-0.90, 0.90, 0.25) self._assertOpOutputMatchesExpected( - math_ops.acos, - x.astype(dtype), - expected=np.arccos(x).astype(dtype)) + math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype)) self._assertOpOutputMatchesExpected( - math_ops.asin, - x.astype(dtype), - expected=np.arcsin(x).astype(dtype)) + math_ops.asin, x.astype(dtype), expected=np.arcsin(x).astype(dtype)) x = np.arange(-3, 3).reshape(1, 3, 2) self._assertOpOutputMatchesExpected( - math_ops.atan, - x.astype(dtype), - expected=np.arctan(x).astype(dtype)) + math_ops.atan, x.astype(dtype), expected=np.arctan(x).astype(dtype)) self._assertOpOutputMatchesExpected( math_ops.acosh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([0, 1.3169579, 1.76274717, 2.06343707], - dtype=dtype)) + expected=np.array( + [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.asinh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([0.88137359, 1.44363548, 1.81844646, 2.09471255], - dtype=dtype)) + expected=np.array( + [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.atanh, np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype), - expected=np.array([0.10033535, 0.20273255, 0.3095196, 0.42364893], - dtype=dtype)) + expected=np.array( + [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.ceil, @@ -198,8 +196,18 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.cosh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([1.54308063, 3.76219569, 10.067662, 27.30823284], - dtype=dtype)) + expected=np.array( + [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype)) + + # Disable float16 testing for now + if dtype != np.float16: + x = np.arange(-10, 10, 1).astype(dtype) + with self.test_session() as session: + erf_x = session.run(math_ops.erf(x)) + erfc_x = session.run(math_ops.erfc(x)) + + self._assertOpOutputMatchesExpected(math_ops.erf, x, expected=erf_x) + self._assertOpOutputMatchesExpected(math_ops.erfc, x, expected=erfc_x) self._assertOpOutputMatchesExpected( math_ops.exp, @@ -210,8 +218,7 @@ class UnaryOpsTest(XLATestCase): math_ops.expm1, np.array([[-1, 1]], dtype=dtype), expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), - rtol=1e-5, - atol=1e-6) + rtol=1e-5) self._assertOpOutputMatchesExpected( math_ops.floor, @@ -220,8 +227,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.is_finite, - np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], - dtype=dtype), + np.array( + [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool)) # Tests for tf.nn ops. @@ -262,16 +269,20 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.rint, - np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], - [0.5, 1.5, 2.5, 3.5]], dtype=dtype), - expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], - dtype=dtype)) + np.array( + [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5]], + dtype=dtype), + expected=np.array( + [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.round, - np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], - [0.5, 1.5, 2.5, 3.5]], dtype=dtype), - expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], - dtype=dtype)) + np.array( + [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5]], + dtype=dtype), + expected=np.array( + [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.rsqrt, @@ -280,10 +291,7 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.sigmoid, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[0.7310586, 0.7310586, 0.7310586, 0.7310586], [0.7310586, 0.880797, 0.95257413, 0.98201376]], @@ -297,8 +305,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.sinh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([1.17520119, 3.62686041, 10.01787493, 27.2899172], - dtype=dtype)) + expected=np.array( + [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.sqrt, @@ -308,15 +316,12 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.tan, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([1.55740772, -2.18503986, -0.14254654, 1.15782128], - dtype=dtype)) + expected=np.array( + [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.tanh, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[0.76159418, 0.76159418, 0.76159418, 0.76159418], [0.76159418, 0.96402758, 0.99505478, 0.99932933]], @@ -324,10 +329,7 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.log_softmax, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[-1.3862944, -1.3862944, -1.3862944, -1.3862944], [-3.4401896, -2.4401896, -1.4401897, -0.44018969]], @@ -335,13 +337,19 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.elu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-6]], dtype=dtype), + expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.selu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-5]], dtype=dtype), + expected=np.array( + [[-1.11133074, 0., 1.05070099, -1.758090550379974e-05]], + dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.relu, @@ -355,10 +363,7 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.softmax, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[0.25, 0.25, 0.25, 0.25], [0.032058604, 0.087144323, 0.23688284, 0.64391428]], @@ -367,8 +372,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.softsign, np.array([[-2, -1, 0, 1, 2]], dtype=dtype), - expected=np.array([[-0.66666669, -0.5, 0, 0.5, 0.66666669]], - dtype=dtype)) + expected=np.array( + [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.is_finite, @@ -377,10 +382,23 @@ class UnaryOpsTest(XLATestCase): expected=np.array( [[True, False, True], [False, True, True]], dtype=np.bool)) + def quantize_and_dequantize_v2(x): + return array_ops.quantize_and_dequantize_v2( + x, -127, 127, signed_input=True, num_bits=8) + self._assertOpOutputMatchesExpected( - lambda x: array_ops.quantize_and_dequantize_v2(x, -127, 127, True, 8), + quantize_and_dequantize_v2, np.array([-1, -0.5, 0, 0.3], dtype=dtype), - expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype)) + expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) + + def quantize_and_dequantize_v3(x): + return array_ops.quantize_and_dequantize_v3( + x, -127, 127, num_bits=8, signed_input=True, range_given=False) + + self._assertOpOutputMatchesExpected( + quantize_and_dequantize_v3, + np.array([-1, -0.5, 0, 0.3], dtype=dtype), + expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) def testComplexOps(self): for dtype in self.complex_types: @@ -561,13 +579,13 @@ class UnaryOpsTest(XLATestCase): for dtype in self.float_types: self._assertOpOutputMatchesExpected( math_ops.is_inf, - np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], - dtype=dtype), + np.array( + [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool)) self._assertOpOutputMatchesExpected( math_ops.is_nan, - np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], - dtype=dtype), + np.array( + [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool)) def testLogicalOps(self): @@ -584,14 +602,15 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"), - np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], - dtype=np.float32), + np.array( + [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32), expected=np.array([10., 26.], dtype=np.float32)) def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] - types = (set([dtypes.bool, dtypes.int32, dtypes.float32]) | - self.complex_tf_types) + types = ( + set([dtypes.bool, dtypes.int32, dtypes.float32]) + | self.complex_tf_types) for shape in shapes: for src_type in types: for dst_type in types: @@ -633,14 +652,11 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( rank_op, dtype(7), expected=np.int32(0)) self._assertOpOutputMatchesExpected( - rank_op, np.array( - [[], []], dtype=dtype), expected=np.int32(2)) + rank_op, np.array([[], []], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( - rank_op, np.array( - [-1, 1], dtype=dtype), expected=np.int32(1)) + rank_op, np.array([-1, 1], dtype=dtype), expected=np.int32(1)) self._assertOpOutputMatchesExpected( - rank_op, np.array( - [[-1, 1]], dtype=dtype), expected=np.int32(2)) + rank_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( rank_op, np.array([[-1], [1], [4]], dtype=dtype), @@ -705,97 +721,97 @@ class UnaryOpsTest(XLATestCase): equality_test=self.ListsAreClose) def testDepthToSpace(self): + def make_op(data_format): + def op(x): - return array_ops.depth_to_space(x, block_size=2, - data_format=data_format) + return array_ops.depth_to_space( + x, block_size=2, data_format=data_format) + return op for dtype in self.numeric_types: for data_format in ["NCHW", "NHWC"]: self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - data_format), - expected=nhwc_to_format(np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - data_format)) + nhwc_to_format( + np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format), + expected=nhwc_to_format( + np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), nhwc_to_format( - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype), + np.array( + [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], - dtype=dtype), - data_format)) + np.array( + [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), nhwc_to_format( - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - data_format), + np.array( + [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], + [13, 14, 15, 16]]]], + dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - data_format)) + np.array( + [[[[1], [2], [5], [6]], [[3], [4], [7], [8]], + [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], + dtype=dtype), data_format)) def testSpaceToDepth(self): + def make_op(data_format): + def op(x): - return array_ops.space_to_depth(x, block_size=2, - data_format=data_format) + return array_ops.space_to_depth( + x, block_size=2, data_format=data_format) + return op for dtype in self.numeric_types: for data_format in ["NCHW", "NHWC"]: self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - data_format), - expected=nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - data_format)) + nhwc_to_format( + np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), - data_format), + nhwc_to_format( + np.array( + [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype), + np.array( + [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - data_format), + nhwc_to_format( + np.array( + [[[[1], [2], [5], [6]], [[3], [4], [7], [8]], + [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], + dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - data_format)) + np.array( + [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], + [13, 14, 15, 16]]]], + dtype=dtype), data_format)) def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) zero = np.asarray(0).astype(dtype) expected = np.logaddexp(zero, features) self._assertOpOutputMatchesExpected( - nn_ops.softplus, features, expected=expected, - rtol=1e-6, - atol=9.1e-6) + nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6) def testSoftplus(self): for dtype in self.float_types: @@ -809,9 +825,10 @@ class UnaryOpsTest(XLATestCase): one = dtype(1) ten = dtype(10) self._assertSoftplusMatchesExpected([ - log_eps, log_eps - one, log_eps + one, log_eps - ten, - log_eps + ten, -log_eps, -log_eps - one, -log_eps + one, - -log_eps - ten, -log_eps + ten], dtype) + log_eps, log_eps - one, log_eps + one, log_eps - ten, log_eps + ten, + -log_eps, -log_eps - one, -log_eps + one, -log_eps - ten, + -log_eps + ten + ], dtype) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index 8ecad00f6e23b3a7746bbb473102ac847bf4cbfd..dd2c252d383bca9c59033ac07e442b487e4975a6 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -20,12 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -36,7 +37,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training.gradient_descent import GradientDescentOptimizer -class VariableOpsTest(XLATestCase): +class VariableOpsTest(xla_test.XLATestCase): """Test cases for resource variable operators.""" def testOneWriteOneOutput(self): @@ -52,9 +53,7 @@ class VariableOpsTest(XLATestCase): with ops.control_dependencies([x]): y = v.read_value() self.assertAllClose( - np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, { - p: 1 - })) + np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, {p: 1})) def testSparseRead0DIndices(self): for dtype in self.numeric_types: @@ -103,9 +102,9 @@ class VariableOpsTest(XLATestCase): x = v.sparse_read([[2, 1], [3, 0]]) self.assertAllClose( np.array( - [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]], - [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], - ).astype(dtype), sess.run(x)) + [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]] + ], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]] + ],).astype(dtype), sess.run(x)) def testShape(self): for dtype in self.numeric_types: @@ -187,6 +186,225 @@ class VariableOpsTest(XLATestCase): rtol=1e-4) self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4) + def testWriteOfAliasedTensor(self): + for dtype in self.numeric_types: + init = np.array([[1, 2j], [3, 4]]).astype(dtype) + update = np.array([[7, 1j], [2, 11]]).astype(dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + p = array_ops.placeholder(dtype) + q = array_ops.identity(p) + x = v.read_value() + # Writes the value of 'p' to 'v', but keeps a reference to the original + # value of 'v' so the variable update cannot reuse its buffer. + with ops.control_dependencies([x]): + y = v.assign(q) + result = sess.run([x, y, q], {p: update}) + self.assertAllClose(init, result[0]) + self.assertAllClose(update, result[1]) + self.assertAllClose(update, result[2]) + + def testScatterAdd(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[2, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1], [7]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_add( + handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertAllEqual(sess.run(read), [[3], [7]]) + + def testScatterSub(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[2, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[4], [1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_sub( + handle, [1], constant_op.constant([[2]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertAllEqual(sess.run(read), [[4], [-1]]) + + def testScatterMul(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_mul( + handle, [0], constant_op.constant([[5]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[5]]) + + def testScatterDiv(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_div( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertAllEqual(sess.run(read), [[2]]) + + def testScatterMin(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_min( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterMax(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_max( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[6]]) + + def testScatterUpdate(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_update( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterAddScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_add( + handle, [0], constant_op.constant(2, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterSubScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_sub( + handle, [0], constant_op.constant(2, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[-1]]) + + def testScatterMulScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_mul( + handle, [0], constant_op.constant(5, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[5]]) + + def testScatterDivScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_div( + handle, [0], constant_op.constant(3, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[2]]) + + def testScatterMinScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_min( + handle, [0], constant_op.constant(3, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterMaxScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_max( + handle, [0], constant_op.constant(3, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[6]]) + + def testScatterNdAddOps(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.float32, shape=[8]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([1] * 8, dtype=dtypes.float32))) + indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) + updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) + expected = np.array([1, 12, 1, 11, 10, 1, 1, 13]) + sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates)) + read = resource_variable_ops.read_variable_op( + handle, dtype=dtypes.float32) + self.assertAllClose(expected, sess.run(read)) + + def testScatterNdUpdateAddOps(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.float32, shape=[8]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([1] * 8, dtype=dtypes.float32))) + indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) + updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) + expected = np.array([1, 11, 1, 10, 9, 1, 1, 12]) + sess.run( + gen_state_ops.resource_scatter_nd_update(handle, indices, updates)) + read = resource_variable_ops.read_variable_op( + handle, dtype=dtypes.float32) + self.assertAllClose(expected, sess.run(read)) + class StridedSliceAssignChecker(object): """Compares the results of a slice assignment using Tensorflow and numpy.""" @@ -217,12 +435,12 @@ class StridedSliceAssignChecker(object): self.test.assertAllEqual(val, valnp) -class SliceAssignTest(XLATestCase): +class SliceAssignTest(xla_test.XLATestCase): def testSliceAssign(self): for dtype in self.numeric_types: - checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]], - dtype=dtype) + checker = StridedSliceAssignChecker( + self, [[1, 2, 3], [4, 5, 6]], dtype=dtype) # No-op assignment checker[:] = [[10, 20, 30], [40, 50, 60]] # Checks trivial (1,1) shape tensor diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index f79eb27435cc954cebde4357c1d946a320f4ed75..b637cf31cfc303ebe84ce8307ef4ad8b0b5cd720 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class WhileTest(XLATestCase): +class WhileTest(xla_test.XLATestCase): def testSingletonLoopHandrolled(self): # Define a function for the loop body diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index b707bd0963d71d7c4b43b8d42752b4c50e9bbf7c..06d977b93c28792704b910c688af510bc650d2a4 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -20,13 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test -class XlaDeviceTest(XLATestCase): +class XlaDeviceTest(xla_test.XLATestCase): def testCopies(self): """Tests that copies onto and off XLA devices work.""" @@ -46,6 +47,12 @@ class XlaDeviceTest(XLATestCase): result = sess.run(z, {x: inputs}) self.assertAllCloseAccordingToType(result, inputs + inputs) + def testControlTrigger(self): + with self.test_session() as sess: + with self.test_scope(): + x = gen_control_flow_ops.control_trigger() + sess.run(x) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index e924fe1e61454aefda622a5a46a0e483d26db5c1..88827cb53bee7bb809d0163d6badcef17e59aa78 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -49,6 +49,32 @@ flags.DEFINE_string('tf_xla_flags', None, 'Value to set the TF_XLA_FLAGS environment variable to') +def parse_disabled_manifest(manifest_content): + comments_re = re.compile('#.*$') + disabled_tests = [] + disabled_method_types = [] + for l in manifest_content.splitlines(): + stripped = comments_re.sub('', l).strip() + if not stripped: + continue + entry = stripped.split(' ') + if len(entry) == 1: + disabled_tests.append(entry[0]) + elif len(entry) == 2: + disabled_method_types.append((entry[0], entry[1].strip().split(','))) + else: + raise ValueError('Bad entry in manifest file.') + + disabled_regex = '|'.join(disabled_tests) + method_types_filter = dict() + for method, types in disabled_method_types: + method_types_filter[method] = set([ + dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype + for name in types + ]) + return disabled_regex, method_types_filter + + class XLATestCase(test.TestCase): """XLA test cases are parameterized test cases.""" @@ -85,38 +111,21 @@ class XLATestCase(test.TestCase): # Parse the manifest file, if any, into a regex identifying tests to # disable - self.disabled_regex = None - self._method_types_filter = dict() # TODO(xpan): Make it text proto if it doesn't scale. # Each line of the manifest file specifies an entry. The entry can be # 1) TestNameRegex // E.g. CumprodTest.* Or # 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16 # The 1) disables the entire test. While 2) only filter some numeric types # so that they are not used in those tests. + self.disabled_regex = None + self._method_types_filter = {} if FLAGS.disabled_manifest is not None: - comments_re = re.compile('#.*$') - manifest_file = open(FLAGS.disabled_manifest, 'r') - disabled_tests = [] - disabled_method_types = [] - for l in manifest_file.read().splitlines(): - if not l: - continue - entry = comments_re.sub('', l).strip().split(' ') - if len(entry) == 1: - disabled_tests.append(entry[0]) - elif len(entry) == 2: - disabled_method_types.append( - (entry[0], entry[1].strip().split(','))) - else: - raise ValueError('Bad entry in manifest file.') - - self.disabled_regex = re.compile('|'.join(disabled_tests)) - for method, types in disabled_method_types: - self._method_types_filter[method] = set([ - dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype - for name in types]) - manifest_file.close() + with open(FLAGS.disabled_manifest, 'r') as manifest_file: + disabled_regex, self._method_types_filter = ( + parse_disabled_manifest(manifest_file.read())) + if disabled_regex: + self.disabled_regex = re.compile(disabled_regex) if FLAGS.tf_xla_flags is not None: os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags diff --git a/tensorflow/compiler/tests/xla_test_test.py b/tensorflow/compiler/tests/xla_test_test.py new file mode 100644 index 0000000000000000000000000000000000000000..24664451579445edaadb335c30d253ee55f003da --- /dev/null +++ b/tensorflow/compiler/tests/xla_test_test.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the XLATestCase test fixture base class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.platform import test + + +class XlaTestCaseTestCase(test.TestCase): + + def testManifestEmptyLineDoesNotCatchAll(self): + manifest = """ +testCaseOne +""" + disabled_regex, _ = xla_test.parse_disabled_manifest(manifest) + self.assertEqual(disabled_regex, "testCaseOne") + + def testManifestWholeLineCommentDoesNotCatchAll(self): + manifest = """# I am a comment +testCaseOne +testCaseTwo +""" + disabled_regex, _ = xla_test.parse_disabled_manifest(manifest) + self.assertEqual(disabled_regex, "testCaseOne|testCaseTwo") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index cd57452302fcbde37d79ce760a80615a76d7ad8c..40e32f2e757c96de86414b5699b67935f4d92776 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -164,11 +164,15 @@ cc_library( "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:core_cpu", @@ -462,3 +466,13 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) + +tf_cc_test( + name = "xla_op_registry_test", + srcs = ["xla_op_registry_test.cc"], + deps = [ + ":xla_compiler", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index 4f8bb8ad743afe69a6544c2ae0dc7309891b2df3..ea8d1b3d14939d4f4fba598318200f71c2eb0270 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -27,3 +27,25 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) + +tf_gen_op_wrapper_cc( + name = "xla_jit_op_gen", + out_ops_file = "ops/xla_jit_op", + deps = ["//tensorflow/compiler/jit/ops:xla_ops"], +) + +cc_library( + name = "xla_jit_ops", + srcs = ["ops/xla_jit_op.cc"], + hdrs = ["ops/xla_jit_op.h"], + deps = [ + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/jit/ops:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 42585ad4d8a17d71146e48b69f9fa56f9ff24c3e..6cc95149a16a59fce8486c5d103ad09e3e262765 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -166,6 +166,27 @@ StatusOr AddNode(const NodeDef& node_def, Graph* graph) { return inserted_node; } +// Check that the graph has no cycle containing the given node. +Status CheckNoCycleContains(const Node* node, const int num_nodes) { + std::vector ready; + ready.push_back(node); + std::vector visited(num_nodes); + while (!ready.empty()) { + const Node* current_node = ready.back(); + ready.pop_back(); + visited[current_node->id()] = true; + for (const Edge* out : current_node->out_edges()) { + if (out->dst() == node) { + return errors::Internal("Detect a cycle: Node \"", node->name(), "\"(", + node->def().op(), ") feeds into itself."); + } else if (!visited[out->dst()->id()]) { + ready.push_back(out->dst()); + } + } + } + return Status::OK(); +} + StatusOr BuildArgNode(Graph* graph, DataType type, int index) { NodeDef arg_def; NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); @@ -1407,6 +1428,10 @@ StatusOr FunctionalizeCond::ConvertToXlaIf( TF_RETURN_IF_ERROR( AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node)); TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); + // Check that the if_node doesn't feed into itself. + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNoCycleContains(if_node, graph_->num_node_ids()), + "ConvertToXlaIf failed."); return if_node; } @@ -1438,7 +1463,15 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, // connected to all source nodes in the graph. Many graphs violate this // invariant. std::vector cf_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info)); + std::vector unreachable_nodes; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes), + "FunctionalizeControlFlow failed"); + if (!unreachable_nodes.empty()) { + return errors::InvalidArgument( + "The following nodes are unreachable from the source in the graph: ", + tensorflow::str_util::Join(unreachable_nodes, ", ")); + } // Builds Frames, indexed by name. std::unordered_map frames; @@ -1458,10 +1491,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, frame.parent = parent; frame.name = cf.frame_name; ++parent->num_children; - } else if (frame.parent != parent) { - return errors::InvalidArgument("Mismatched parent frames for ", - cf.frame->id(), ": ", parent->name, " vs ", - frame.parent->name); } if (IsEnter(node)) { @@ -1471,12 +1500,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, &arg.is_loop_invariant)); frame.args.push_back(arg); } else if (IsLoopCond(node)) { - if (frame.loop_cond) { - return errors::InvalidArgument( - "Loop ", cf.frame_name, - " has more than one LoopCond node: ", node->name(), " and ", - frame.loop_cond->name()); - } frame.loop_cond = node; } frame.nodes.insert(node); @@ -1508,6 +1531,16 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, worklist.push_back(frame->parent); } } + // There should be no cycle at this point, since while loops have been removed + // from graph. + // Check that the newly added XlaWhile nodes don't feed into themselves. + for (const Node* node : graph->op_nodes()) { + if (node->def().op() == "XlaWhile") { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNoCycleContains(node, graph->num_node_ids()), + "FunctionalizeLoop failed."); + } + } // FunctionalizeControlFlow is invoked for every function, so the loops's // bodies and conditionals that were extracted into functions will be handled diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 14977a908ae2b0ff7e13b634c41b6d331b4b8a36..aae2f8ee5acd6249f8b6002d94c877f18064f936 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -1012,5 +1013,60 @@ TEST(FunctionalizeControlFlow, Complex) { } } +TEST(FunctionalizeControlFlow, Cycle) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + // ----------------------------------------------------- + // | | + // | v + // less -> switch_1 --> add -> merge_1 -> identity -> switch_2 + // | ^ | + // | | v + // --------> one -------------------------> add_2 ---> merge_2 + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); + auto two = + ops::Const(scope.WithOpName("cond/two") + .WithControlDependencies(switch_1.output_true), + 2); + auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"), + switch_1.output_true, two); + auto one = + ops::Const(scope.WithOpName("cond/one") + .WithControlDependencies(switch_1.output_false), + 1); + auto add = ops::Add(scope.WithOpName("cond/false/add"), + switch_1.output_false, one); + + auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"), + std::initializer_list{add, mul}); + auto identity = + ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output); + auto switch_2 = + ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less); + auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"), + switch_2.output_false, one); + auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"), + switch_2.output_true, two); + auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"), + std::initializer_list{add_2, mul_2}); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + // No cycle before functionalize control flow. + TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph)); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + // switch_1 and switch_2 have the same switch depth. They are replaced by a + // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle: + // less -> XlaIf <--> identity. + Status status = FunctionalizeControlFlow(graph.get(), &library); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detect a cycle")) + << status.error_message(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 212f6f3966149ca0b2d2e012b19300e1f488f996..4900af6df17f360630abb1e64b7f144ccd4a0289 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -87,6 +89,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, } } // namespace Status GraphCompiler::Compile() { + // Check that the graph has no illegal cycles. + TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_)); // Maintain a mapping from node id to node outputs. using NodeOutputs = std::vector; std::vector output_registry(graph_->num_node_ids()); @@ -227,7 +231,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, XlaContext& context = XlaContext::Get(op_context); auto* b = context.builder(); - auto output_handle = b->Call(*result.computation, handles); + auto output_handle = xla::Call(b, *result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so // that it can fit into future computations. int computation_output = 0; @@ -236,7 +240,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { xla_op_context.SetOutput( - i, b->GetTupleElement(output_handle, computation_output)); + i, xla::GetTupleElement(output_handle, computation_output)); ++computation_output; } } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index e6da157c111ad9167bf7b1e743d9afbb8fb2ad03..a8eb7d942dfbabff3c53e2b5225c1018b01eb315 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -18,6 +18,7 @@ tf_kernel_library( "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", + "bucketize_op.cc", "cast_op.cc", "categorical_op.cc", "cholesky_op.cc", @@ -78,14 +79,17 @@ tf_kernel_library( "shape_util.cc", "slice_op.cc", "softmax_op.cc", + "sort_ops.cc", "spacetobatch_op.cc", "spacetodepth_op.cc", + "sparse_to_dense_op.cc", "split_op.cc", "stack_ops.cc", "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", "tile_ops.cc", + "topk_op.cc", "training_ops.cc", "transpose_op.cc", "unary_ops.cc", @@ -103,6 +107,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", @@ -116,6 +121,9 @@ tf_kernel_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 1e59868621475cf72f4cc8b14dafec2dd8cd5c95..e33532828040123243f839ab1aa655b4bbc72520 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -31,7 +32,7 @@ class AddNOp : public XlaOpKernel { xla::XlaOp sum = ctx->Input(0); for (int i = 1; i < ctx->num_inputs(); ++i) { - sum = ctx->builder()->Add(sum, ctx->Input(i)); + sum = xla::Add(sum, ctx->Input(i)); } ctx->SetOutput(0, sum); diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index b0ba25b9983c3a9af26728ce4b1c263c844327db..4cfe946b2e6146f034867c06e996ffae42b90705 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -28,11 +28,10 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), + auto result = BatchDot(ctx->Input(0), ctx->Input(1), /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); - OP_REQUIRES_OK(ctx, result.status()); - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 15e1815a4cf07ff50dd1431b6790d14781da590f..c4af79281d2162b1dbfb0a7881720892f4bc49d2 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -34,10 +35,11 @@ class FusedBatchNormOp : public XlaOpKernel { ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || + data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), errors::InvalidArgument( "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC and NCHW")); + "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { @@ -48,8 +50,6 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); - xla::XlaBuilder* builder = ctx->builder(); - xla::XlaOp input = ctx->Input(0); TensorShape input_shape = ctx->InputShape(0); @@ -59,30 +59,30 @@ class FusedBatchNormOp : public XlaOpKernel { // TODO(b/69928690): support mixed precision in the XLA batch normalization // operators. As a workaround, cast everything to the statistics type (which // may be more precise than the input type). - input = builder->ConvertElementType(input, scale_type); + input = xla::ConvertElementType(input, scale_type); if (is_training_) { - xla::XlaOp output = builder->BatchNormTraining( + xla::XlaOp output = xla::BatchNormTraining( input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the // calculated mean and variance. - ctx->SetOutput(0, builder->ConvertElementType( - builder->GetTupleElement(output, 0), input_type)); - ctx->SetOutput(1, builder->GetTupleElement(output, 1)); - ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0), + input_type)); + ctx->SetOutput(1, xla::GetTupleElement(output, 1)); + ctx->SetOutput(2, xla::GetTupleElement(output, 2)); // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". They are used to pass the per-batch mean and // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. - ctx->SetOutput(3, builder->GetTupleElement(output, 1)); - ctx->SetOutput(4, builder->GetTupleElement(output, 2)); + ctx->SetOutput(3, xla::GetTupleElement(output, 1)); + ctx->SetOutput(4, xla::GetTupleElement(output, 2)); } else { - xla::XlaOp output = builder->BatchNormInference( + xla::XlaOp output = xla::BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), epsilon_, feature_index); - ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); + ctx->SetOutput(0, xla::ConvertElementType(output, input_type)); // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -111,10 +111,11 @@ class FusedBatchNormGradOp : public XlaOpKernel { ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || + data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), errors::InvalidArgument( "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC and NCHW")); + "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { @@ -142,12 +143,12 @@ class FusedBatchNormGradOp : public XlaOpKernel { xla::XlaOp offset_backprop; if (is_training_) { xla::XlaOp output = - b->BatchNormGrad(activations, scale, mean, var, grad_backprop, - epsilon_, feature_index); + xla::BatchNormGrad(activations, scale, mean, var, grad_backprop, + epsilon_, feature_index); - x_backprop = b->GetTupleElement(output, 0); - scale_backprop = b->GetTupleElement(output, 1); - offset_backprop = b->GetTupleElement(output, 2); + x_backprop = xla::GetTupleElement(output, 0); + scale_backprop = xla::GetTupleElement(output, 1); + offset_backprop = xla::GetTupleElement(output, 2); } else { // Reduce over all dimensions except the feature dim. std::vector reduction_dims(input_dims - 1); @@ -164,35 +165,35 @@ class FusedBatchNormGradOp : public XlaOpKernel { auto converted = XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); - auto scratch1 = - b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); + auto scratch1 = xla::Pow( + xla::Add(var, xla::ConstantR0(b, epsilon_)), neg_half); // scratch2 = sum(y_backprop * (x - mean)) auto mul = - b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})); + xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index})); converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); x_backprop = - b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); - scale_backprop = b->Mul(scratch1, scratch2); + xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index}); + scale_backprop = xla::Mul(scratch1, scratch2); } ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, x_backprop, input_dtype)); ctx->SetOutput(1, scale_backprop); ctx->SetOutput(2, offset_backprop); - ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); - ctx->SetConstantOutput(4, Tensor(scale_dtype, {})); + ctx->SetConstantOutput(3, Tensor()); + ctx->SetConstantOutput(4, Tensor()); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 642278ab994bf3cc84396f093ed56b009a1435c1..26130fd9e7fce75c6d2a5a53cfc85842cf762b35 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -45,7 +46,6 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, ", 2] instead of ", xla::ShapeUtil::HumanString(crops.shape()))); - xla::XlaBuilder* b = ctx->builder(); const int64 batch_size = input_shape[0]; // Compute the product of the block_shape values. @@ -72,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, reshaped_shape[block_rank] = batch_size / block_num_elems; std::copy(input_shape.begin() + 1, input_shape.end(), reshaped_shape.begin() + block_rank + 1); - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce `permuted` of shape // [batch / prod(block_shape), @@ -90,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, } std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); - xla::XlaOp permuted = b->Transpose(reshaped, permutation); + xla::XlaOp permuted = xla::Transpose(reshaped, permutation); // 3. Reshape `permuted` to produce `reshaped_permuted` of shape // [batch / prod(block_shape), @@ -110,7 +110,8 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_permuted_shape.begin() + 1 + block_rank); - xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape); + xla::XlaOp reshaped_permuted = + xla::Reshape(permuted, reshaped_permuted_shape); // 4. Crop the start and end of dimensions `[1, ..., M]` of // `reshaped_permuted` according to `crops` to produce the output of shape: @@ -138,7 +139,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i])); } xla::XlaOp output = - b->Slice(reshaped_permuted, start_indices, end_indices, strides); + xla::Slice(reshaped_permuted, start_indices, end_indices, strides); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index 9d677f426650ea17a49e5ab1401078f04623fe97..e9b2c0b16d39cb3b747c0316621fb01de709b12e 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/tensor_format.h" @@ -60,8 +61,7 @@ class BiasOp : public XlaOpKernel { "of the input tensor: ", bias_shape.DebugString(), " vs. ", input_shape.DebugString())); - xla::XlaOp result = - ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim}); + xla::XlaOp result = xla::Add(ctx->Input(0), ctx->Input(1), {feature_dim}); ctx->SetOutput(0, result); } @@ -109,8 +109,8 @@ class BiasAddGradOp : public XlaOpKernel { auto converted = XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); } diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index f04cde878e98002d9442e0f3ec251c5197ef7969..d6d4ae89376b67c14af8ef4f3a608fcc83b6fb59 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -41,18 +41,19 @@ namespace { const BCast& broadcast_helper, \ const std::vector& extend_dimensions) override { \ xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ return HLO; \ } \ }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op) -XLA_MAKE_BINARY(Add, b->Add(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); // Implementation of FloorDiv. Pseudo-code: // if ((x < 0) != (y < 0)) { @@ -67,13 +68,13 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); - auto different_sign = b->Ne(b->Lt(x, zero), b->Lt(y, zero)); - auto abs_x = b->Abs(x); - auto abs_y = b->Abs(y); - auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); - auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); + auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)); + auto abs_x = xla::Abs(x); + auto abs_y = xla::Abs(y); + auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one)); + auto result = xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y)); if (DataTypeIsFloating(dtype)) { - result = b->Floor(result); + result = xla::Floor(result); } return result; } @@ -87,75 +88,78 @@ static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); - auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero)); - auto trunc_mod = b->Rem(x, y); - return b->Select(same_sign, trunc_mod, b->Rem(b->Add(trunc_mod, y), y)); + auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero)); + auto trunc_mod = xla::Rem(x, y); + return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y)); } XLA_MAKE_BINARY(FloorMod, FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LeftShift, b->ShiftLeft(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(RightShift, (DataTypeIsUnsigned(ctx->input_type(0)) - ? b->ShiftRightLogical(lhs, rhs, extend_dimensions) - : b->ShiftRightArithmetic(lhs, rhs, extend_dimensions))); - -XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs)))); + ? xla::ShiftRightLogical(lhs, rhs, extend_dimensions) + : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions))); + +XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs)))); XLA_MAKE_BINARY( RsqrtGrad, - b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), - b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), - extend_dimensions)); -XLA_MAKE_BINARY(SqrtGrad, - b->Div(b->Mul(rhs, - XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), - lhs, extend_dimensions)); + xla::Mul(xla::Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), + xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), + extend_dimensions)); +XLA_MAKE_BINARY( + SqrtGrad, + xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), + lhs, extend_dimensions)); static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { - return builder->Mul(x, x); + return xla::Mul(x, x); } XLA_MAKE_BINARY(SquaredDifference, - Square(b, b->Sub(lhs, rhs, extend_dimensions))); + Square(b, xla::Sub(lhs, rhs, extend_dimensions))); -XLA_MAKE_BINARY(TruncateDiv, b->Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(TruncateMod, b->Rem(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions)); // Comparison ops -XLA_MAKE_BINARY(Equal, b->Eq(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(NotEqual, b->Ne(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Greater, b->Gt(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions)); // Non-linear ops XLA_MAKE_BINARY(SigmoidGrad, - b->Mul(b->Mul(rhs, lhs), - b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); + xla::Mul(xla::Mul(rhs, lhs), + xla::Sub(XlaHelpers::One(b, input_type(0)), lhs))); XLA_MAKE_BINARY(SoftplusGrad, - b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), - XlaHelpers::One(b, input_type(1))))); + xla::Div(lhs, xla::Add(xla::Exp(xla::Neg(rhs)), + XlaHelpers::One(b, input_type(1))))); // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 XLA_MAKE_BINARY(SoftsignGrad, - b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)), - b->Abs(rhs))))); + xla::Div(lhs, + Square(b, xla::Add(XlaHelpers::One(b, input_type(0)), + xla::Abs(rhs))))); -XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(lhs, lhs)))); +XLA_MAKE_BINARY(TanhGrad, + xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)), + xla::Mul(lhs, lhs)))); -XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); #undef XLA_MAKE_BINARY @@ -168,12 +172,13 @@ class ApproximateEqualOp : public XlaOpKernel { // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))); + auto abs = xla::Abs(xla::Sub(ctx->Input(0), ctx->Input(1))); auto abs_shape = b->GetShape(abs); OP_REQUIRES_OK(ctx, abs_shape.status()); auto abs_type = abs_shape.ValueOrDie().element_type(); - auto result = b->Lt( - abs, b->ConvertElementType(b->ConstantR0(tolerance_), abs_type)); + auto result = + xla::Lt(abs, xla::ConvertElementType( + xla::ConstantR0(b, tolerance_), abs_type)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..efbdb76eaaf78904fe783a018940b1b096ec39bd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BucketizeOp : public XlaOpKernel { + public: + explicit BucketizeOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("boundaries", &boundaries_)); + OP_REQUIRES(context, std::is_sorted(boundaries_.begin(), boundaries_.end()), + errors::InvalidArgument("Expected sorted boundaries")); + } + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* builder = context->builder(); + const DataType dtype = context->input_type(0); + xla::XlaOp input = context->Input(0); + + xla::XlaOp boundaries = xla::ConstantR1(builder, boundaries_); + // TODO(phawkins): the following behavior matches the behavior of the core + // Bucketize kernel. However, comparing an int32 or int64 against float may + // lead to inaccurate bucketing due to rounding. + if (dtype == DT_DOUBLE) { + input = xla::ConvertElementType(input, xla::F64); + boundaries = xla::ConvertElementType(boundaries, xla::F64); + } else { + input = xla::ConvertElementType(input, xla::F32); + } + xla::XlaOp comparison = + xla::ConvertElementType(xla::Ge(xla::Broadcast(input, {1}), boundaries, + /*broadcast_dimensions=*/{0}), + xla::S32); + xla::XlaOp buckets = xla::Reduce( + comparison, /*init_value=*/xla::ConstantR0(builder, 0), + /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + context->SetOutput(0, buckets); + } + + private: + std::vector boundaries_; +}; + +REGISTER_XLA_OP(Name("Bucketize"), BucketizeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index e9d98c768572c52825fa5192ecec834889f040fe..62eebf762b3e063da8ec456cc4726d3cc9b77d1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -40,14 +41,14 @@ class CastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; } else if (dst_dtype_ == DT_BOOL) { - output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_)); + output = xla::Ne(input, XlaHelpers::Zero(builder, src_dtype_)); } else if (xla::primitive_util::IsComplexType(src_type_) && !xla::primitive_util::IsComplexType(dst_type_)) { // As in cast_op.h, we replicate the numpy behavior of truncating the // imaginary part. - output = builder->ConvertElementType(builder->Real(input), dst_type_); + output = xla::ConvertElementType(xla::Real(input), dst_type_); } else { - output = builder->ConvertElementType(input, dst_type_); + output = xla::ConvertElementType(input, dst_type_); } ctx->SetOutput(0, output); @@ -72,7 +73,6 @@ class BitcastOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); xla::XlaOp input = ctx->Input(0); xla::XlaOp output; @@ -92,7 +92,7 @@ class BitcastOp : public XlaOpKernel { xla::primitive_util::BitWidth(dst_type_), errors::Unimplemented( "Only bitcasts between equally sized types supported.")); - output = builder->BitcastConvertType(input, dst_type_); + output = xla::BitcastConvertType(input, dst_type_); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 835a7f568945f0bee86fe2b39491c3326726e1aa..1784e712b56145bbdff5f1daa2e031b65d0774b6 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -65,24 +66,22 @@ class CategoricalOp : public XlaOpKernel { DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); xla::Shape uniform_shape = xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); - auto uniforms = builder->RngUniform( - XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + auto uniforms = + xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), + XlaHelpers::One(builder, input_type(0)), uniform_shape); // Use Gumbel softmax trick to generate categorical samples. // See: // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ // TODO(b/68769470): Switch to using a cumulative sum approach. - auto softmax_entries = - builder->Sub(logits, builder->Log(builder->Neg(builder->Log(uniforms))), - /*broadcast_dimensions=*/{0, 2}); - - TensorShape softmax_shape(uniform_shape_array); - xla::XlaOp argmax; - OP_REQUIRES_OK( - ctx, - XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape, - input_type(0), output_type(0), /*axis=*/2, &argmax)); + auto softmax_entries = xla::Sub(logits, xla::Log(-xla::Log(uniforms)), + /*broadcast_dimensions=*/{0, 2}); + + xla::PrimitiveType xla_output_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(output_type(0), &xla_output_type)); + xla::XlaOp argmax = + XlaHelpers::ArgMax(softmax_entries, xla_output_type, /*axis=*/2); ctx->SetOutput(0, argmax); } diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index fe6651793dc763d13f4a4b0ac294ec3ecf64af8f..9fcbc86adc0967cbb7fb73da8bdabc58b60953da 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -24,12 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - auto result = Cholesky(ctx->builder(), ctx->Input(0)); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index a00bc912f9f40052565446c6bf9390629af9a4cd..4e6d33304c4ae08a0fd1e0a8373267a527087528 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -29,7 +30,6 @@ class ClipByValueOp : public XlaOpKernel { const TensorShape min_shape = ctx->InputShape(1); const TensorShape max_shape = ctx->InputShape(2); - xla::XlaBuilder* builder = ctx->builder(); auto input = ctx->Input(0); auto min = ctx->Input(1); auto max = ctx->Input(2); @@ -45,13 +45,13 @@ class ClipByValueOp : public XlaOpKernel { if (shape != min_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(min_shape), shape_error()); - min = builder->Broadcast(min, shape.dim_sizes()); + min = xla::Broadcast(min, shape.dim_sizes()); } if (shape != max_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(max_shape), shape_error()); - max = builder->Broadcast(max, shape.dim_sizes()); + max = xla::Broadcast(max, shape.dim_sizes()); } - ctx->SetOutput(0, builder->Clamp(min, input, max)); + ctx->SetOutput(0, xla::Clamp(min, input, max)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 78285affa1c399ae107a9172fb85cf257457c368..e3a32a5c0e2f93237c8c7ebeea3668b5d1ab6c23 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -88,7 +89,7 @@ class ConcatBaseOp : public XlaOpKernel { "] = ", in_shape.DebugString())); if (in_shape.dims() == 0) { // Inputs that come in as scalars must be reshaped to 1-vectors. - input_data.push_back(ctx->builder()->Reshape(handle, {1})); + input_data.push_back(xla::Reshape(handle, {1})); } else { input_data.push_back(handle); } @@ -96,7 +97,7 @@ class ConcatBaseOp : public XlaOpKernel { } VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 59d06c654de18c9003fe0bdc706d0c2443de6d7b..f4360d8c3f6fc4007c31fdcfd7f7634de15c76d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -53,41 +54,41 @@ class ConstOp : public XlaOpKernel { switch (proto_.dtype()) { case DT_BOOL: if (proto_.bool_val_size() == 1) { - ctx->SetOutput(0, - b->Broadcast(b->ConstantR0(proto_.bool_val(0)), - shape.dim_sizes())); + ctx->SetOutput( + 0, xla::Broadcast(xla::ConstantR0(b, proto_.bool_val(0)), + shape.dim_sizes())); return; } break; case DT_FLOAT: if (proto_.float_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0(proto_.float_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0( + b, proto_.float_val(0)), + shape.dim_sizes())); return; } break; case DT_DOUBLE: if (proto_.double_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0(proto_.double_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0( + b, proto_.double_val(0)), + shape.dim_sizes())); return; } break; case DT_INT32: if (proto_.int_val_size() == 1) { - ctx->SetOutput(0, - b->Broadcast(b->ConstantR0(proto_.int_val(0)), - shape.dim_sizes())); + ctx->SetOutput( + 0, xla::Broadcast(xla::ConstantR0(b, proto_.int_val(0)), + shape.dim_sizes())); return; } break; case DT_INT64: if (proto_.int64_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0(proto_.int64_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0( + b, proto_.int64_val(0)), + shape.dim_sizes())); return; } break; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 627bad12f33c82e91bc3c6f3323f562bc8174056..48ac4867edcef97be001a24f42f6a35225d466c9 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -51,8 +53,8 @@ xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return builder->Broadcast(XlaHelpers::Zero(builder, dtype), - expanded_filter_shape.dim_sizes()); + return xla::Broadcast(XlaHelpers::Zero(builder, dtype), + expanded_filter_shape.dim_sizes()); } // Create a mask for depthwise convolution that will make a normal convolution @@ -95,32 +97,27 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, // Create a M sized linspace and an M*N sized linspace that will be // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, - &input_feature_iota)); - xla::XlaOp expanded_feature_iota; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, - input_feature * depthwise_multiplier, - &expanded_feature_iota)); + xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); + xla::XlaOp expanded_feature_iota = + xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); // Divide the M*N sized linspace by the depthwise_multiplier to create // [0 0 1 1 2 2] in the example in the function comment. expanded_feature_iota = - builder->Div(expanded_feature_iota, - XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, - depthwise_multiplier)); + xla::Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); // Broadcast the N*M linspace to [H, W, ..., M, M*N]. auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = builder->Broadcast( - expanded_feature_iota, expanded_feature_broadcast_dims); + auto broadcasted_expanded_feature_iota = + xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); // Compare the broadcasted linspace to the input feature linspace in the // input feature dimension to create a diagonal predicate. - return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dims() - 2}); + return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dims() - 2}); } // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding @@ -142,16 +139,16 @@ xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape, implicit_broadcast_filter_shape.dims() - 1, depthwise_multiplier * input_feature); auto implicit_broadcast_filter = - builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); + xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); // Broadcast the filter to [H, W, ..., M, M*N]. auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); - auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); + auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero); // If the filter mask is set, choose the broadcasted filter, othwerwise, // choose zero. - return builder->Select(CreateExpandedFilterMask(filter_shape, builder), - expanded_filter, expanded_zero); + return xla::Select(CreateExpandedFilterMask(filter_shape, builder), + expanded_filter, expanded_zero); } // Inverse of ExpandFilterForDepthwiseConvolution. @@ -162,17 +159,17 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - auto masked_expanded_filter = builder->Select( + auto masked_expanded_filter = xla::Select( CreateExpandedFilterMask(filter_shape, builder), filter_backprop, CreateExpandedZero(filter_shape, dtype, builder)); - return builder->Reshape( + return xla::Reshape( // This reduce does not need inputs to be converted with // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with // ExpandedZero guarantees that only one element is non zero, so there // cannot be accumulated precision error. - builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), - {expanded_filter_shape.dims() - 2}), + xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), + *ctx->GetOrCreateAdd(dtype), + {expanded_filter_shape.dims() - 2}), filter_shape.dim_sizes()); } @@ -289,8 +286,8 @@ class ConvOp : public XlaOpKernel { } xla::XlaOp conv = - b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, - lhs_dilation, rhs_dilation, dims); + xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, + lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); } @@ -435,11 +432,11 @@ class ConvBackpropInputOp : public XlaOpKernel { } // Mirror the filter in the spatial dimensions. - xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims); + xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); // activation gradients // = gradients (with padding and dilation) mirrored_weights - xla::XlaOp in_backprop = b->ConvGeneralDilated( + xla::XlaOp in_backprop = xla::ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, lhs_dilation, rhs_dilation, dnums); @@ -638,8 +635,8 @@ class ConvBackpropFilterOp : public XlaOpKernel { // This is done by specifying the window dilation factors in the // convolution HLO below. auto filter_backprop = - b->ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); + xla::ConvGeneralDilated(activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums); if (depthwise_) { filter_backprop = ContractFilterForDepthwiseBackprop( diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 7fcd4170fb79a574663c1abffe873d4b53f471d3..500a564f3f0489a42dbc9d5b70ae7708a7a43973 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -58,21 +59,21 @@ class CrossOp : public XlaOpKernel { auto in1 = ctx->Input(1); starts.back() = 0; limits.back() = 1; - auto u1 = b->Slice(in0, starts, limits, strides); - auto v1 = b->Slice(in1, starts, limits, strides); + auto u1 = xla::Slice(in0, starts, limits, strides); + auto v1 = xla::Slice(in1, starts, limits, strides); starts.back() = 1; limits.back() = 2; - auto u2 = b->Slice(in0, starts, limits, strides); - auto v2 = b->Slice(in1, starts, limits, strides); + auto u2 = xla::Slice(in0, starts, limits, strides); + auto v2 = xla::Slice(in1, starts, limits, strides); starts.back() = 2; limits.back() = 3; - auto u3 = b->Slice(in0, starts, limits, strides); - auto v3 = b->Slice(in1, starts, limits, strides); + auto u3 = xla::Slice(in0, starts, limits, strides); + auto v3 = xla::Slice(in1, starts, limits, strides); - auto s1 = b->Sub(b->Mul(u2, v3), b->Mul(u3, v2)); - auto s2 = b->Sub(b->Mul(u3, v1), b->Mul(u1, v3)); - auto s3 = b->Sub(b->Mul(u1, v2), b->Mul(u2, v1)); - auto output = b->ConcatInDim({s1, s2, s3}, in0_shape.dims() - 1); + auto s1 = xla::Sub(xla::Mul(u2, v3), xla::Mul(u3, v2)); + auto s2 = xla::Sub(xla::Mul(u3, v1), xla::Mul(u1, v3)); + auto s3 = xla::Sub(xla::Mul(u1, v2), xla::Mul(u2, v1)); + auto output = xla::ConcatInDim(b, {s1, s2, s3}, in0_shape.dims() - 1); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 01aa1a83e7967921f1583b3ef18ec57e452dcfea..9ff3e0222831cb4339943966810eeae451e47a2c 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -96,18 +96,16 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // First reshape the inputs, which should be a metadata-only // operation since we are flattening the dimensions in order. - auto lhs_shaped = builder->Reshape(lhs, broadcast_helper.x_reshape()); - auto rhs_shaped = builder->Reshape(rhs, broadcast_helper.y_reshape()); + auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape()); + auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape()); // Next broadcast the necessary input dimensions. We rely on the // XLA optimizer to be smart about the fact that we are asking // it to broadcast size 1 on some of these dimensions, to avoid // adding complexity to this code. - auto lhs_broadcast = - builder->Broadcast(lhs_shaped, broadcast_helper.x_bcast()); + auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast()); int lhs_size = broadcast_helper.x_bcast().size(); - auto rhs_broadcast = - builder->Broadcast(rhs_shaped, broadcast_helper.y_bcast()); + auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast()); int rhs_size = broadcast_helper.y_bcast().size(); // Now reshape them to the correct output shape. After the @@ -122,15 +120,15 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { lhs_reorder.push_back(i); lhs_reorder.push_back(i + lhs_size); } - auto lhs_output = builder->Reshape(lhs_broadcast, lhs_reorder, - broadcast_helper.output_shape()); + auto lhs_output = + xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape()); std::vector rhs_reorder; for (int i = 0; i < rhs_size; ++i) { rhs_reorder.push_back(i); rhs_reorder.push_back(i + rhs_size); } - auto rhs_output = builder->Reshape(rhs_broadcast, rhs_reorder, - broadcast_helper.output_shape()); + auto rhs_output = + xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape()); return {lhs_output, rhs_output}; } diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 23243f62462c6315e359d9621823b19fc98c6218..f3149200250935629a6e4bf67bff0c048135ce3e 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -50,7 +51,6 @@ class DepthToSpaceOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); @@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel { ") is not divisible by square of the block size (", block_size_, ")")); - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -141,7 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2], // block_size_, // depth / (block_size_ * block_size_)] - xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -151,7 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 931705ba837153e1175cd9a209876ef5ec93f0fc..6dec414c53bee6b0102e229c86cfafb4072a35f0 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -18,6 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -25,10 +28,10 @@ namespace tensorflow { namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. -xla::StatusOr CreateDiagonal( - const xla::XlaOp& input, int64 last_dim_size, - tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, - xla::XlaBuilder* builder) { +xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size, + gtl::ArraySlice other_dims, + xla::PrimitiveType element_type) { + xla::XlaBuilder* builder = input.builder(); // Create two matrices that have the following forms, and compare them: // // [[0, 0, 0, 0] [[0, 1, 2, 3] @@ -38,16 +41,14 @@ xla::StatusOr CreateDiagonal( // // This produces a predicate matrix of the right size, with "true" on the // diagonal. - xla::XlaOp iota; - TF_RETURN_IF_ERROR( - XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); - xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size}); - xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0}); + xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size); + xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size}); + xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0}); // If this is a batched diagonal, broadcast the mask across the other // dimensions. if (!other_dims.empty()) { - mask = builder->Broadcast(mask, other_dims); + mask = xla::Broadcast(mask, other_dims); } // Broadcast the input, and then use the mask computed above to select the @@ -64,18 +65,15 @@ xla::StatusOr CreateDiagonal( std::vector broadcast_dims(other_dims.begin(), other_dims.end()); broadcast_dims.push_back(1LL); broadcast_dims.push_back(last_dim_size); - xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims); + xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims); broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; - xla::PrimitiveType element_type; - TF_RETURN_IF_ERROR( - DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); auto broadcast_shape = xla::ShapeUtil::MakeShape(element_type, broadcast_dims); - xla::XlaOp zeros = Zeros(builder, broadcast_shape); + xla::XlaOp zeros = xla::Zeros(builder, broadcast_shape); - input_broadcast = builder->Add(input_broadcast, zeros); - return builder->Select(mask, input_broadcast, zeros); + input_broadcast = xla::Add(input_broadcast, zeros); + return xla::Select(mask, input_broadcast, zeros); } class DiagOp : public XlaOpKernel { @@ -83,8 +81,6 @@ class DiagOp : public XlaOpKernel { explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("Diag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); @@ -104,19 +100,17 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - input = builder->Reshape(input, {size}); + input = xla::Reshape(input, {size}); // Create an R2 with the R1 diagonal. - auto diag_or_status = - CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); - OP_REQUIRES_OK(ctx, diag_or_status.status()); - xla::XlaOp diag = diag_or_status.ValueOrDie(); + xla::XlaOp diag = + CreateDiagonal(input, size, /*other_dims=*/{}, ctx->input_xla_type(0)); // Reshapes to the final shape. std::vector new_dims(dims.size() * 2); std::copy(dims.begin(), dims.end(), new_dims.begin()); std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size()); - diag = builder->Reshape(diag, new_dims); + diag = xla::Reshape(diag, new_dims); ctx->SetOutput(0, diag); } @@ -170,21 +164,21 @@ class DiagPartOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - diag = builder->Reshape(diag, {size}); + diag = xla::Reshape(diag, {size}); // Adds padding after the last element of 'new_size'. xla::PaddingConfig config; auto* dim = config.add_dimensions(); dim->set_edge_padding_high(new_size); auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = builder->Pad(diag, zero, config); + diag = xla::Pad(diag, zero, config); // Reshapes so the diagonal is now in the first column. - diag = builder->Reshape(diag, {new_size, new_size + 1}); + diag = xla::Reshape(diag, {new_size, new_size + 1}); // Slices out the first column and reshapes to the final shape. - diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); - diag = builder->Reshape(diag, new_dims); + diag = xla::Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); + diag = xla::Reshape(diag, new_dims); ctx->SetOutput(0, diag); } @@ -197,8 +191,6 @@ class MatrixDiagOp : public XlaOpKernel { explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("MatrixDiag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); @@ -208,17 +200,15 @@ class MatrixDiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::XlaOp diag = ctx->Input(0); int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); tensorflow::gtl::ArraySlice other_dims(dims); other_dims.pop_back(); - auto diag_or_status = - CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder); - OP_REQUIRES_OK(ctx, diag_or_status.status()); - diag = diag_or_status.ValueOrDie(); + xla::XlaOp input = ctx->Input(0); + xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims, + ctx->input_xla_type(0)); ctx->SetOutput(0, diag); } }; @@ -265,7 +255,7 @@ class MatrixDiagPartOp : public XlaOpKernel { // Collapses the last two dimensions. std::vector flattened_dims(dims.begin(), dims.end() - 1); flattened_dims.back() *= dims.back(); - diag = builder->Reshape(diag, flattened_dims); + diag = xla::Reshape(diag, flattened_dims); // Slices or pads the last dimension to 'target_size'. int64 actual_size = flattened_dims.back(); @@ -276,13 +266,13 @@ class MatrixDiagPartOp : public XlaOpKernel { auto* dim = config.mutable_dimensions(flattened_dims.size() - 1); dim->set_edge_padding_high(target_size - actual_size); auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = builder->Pad(diag, zero, config); + diag = xla::Pad(diag, zero, config); } else if (actual_size > target_size) { std::vector start(flattened_dims.size(), 0); std::vector limits(flattened_dims.begin(), flattened_dims.end()); std::vector strides(flattened_dims.size(), 1); limits[flattened_dims.size() - 1] = target_size; - diag = builder->Slice(diag, start, limits, strides); + diag = xla::Slice(diag, start, limits, strides); } // Reshape so the target values are in the first position of the last @@ -290,18 +280,18 @@ class MatrixDiagPartOp : public XlaOpKernel { std::vector unflattened_dims(dims.begin(), dims.end()); dims[last_dim - 1] = smaller_dim_size; dims[last_dim] = last_dim_size + 1; - diag = builder->Reshape(diag, dims); + diag = xla::Reshape(diag, dims); // Slices out the first column and reshapes to the final shape. std::vector start(dims.size(), 0); std::vector limits(dims.begin(), dims.end()); std::vector strides(dims.size(), 1); limits[last_dim] = 1; - diag = builder->Slice(diag, start, limits, strides); + diag = xla::Slice(diag, start, limits, strides); // Collapses away the last dimension. dims.pop_back(); - diag = builder->Reshape(diag, dims); + diag = xla::Reshape(diag, dims); ctx->SetOutput(0, diag); } diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 0419de78b2ee83fd395e8bf23444fde84f30bba2..3b86ea34c9e7d943eb9c7de222e0a2be049ebc68 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -57,8 +57,8 @@ class DynamicUpdateSliceOp : public XlaOpKernel { input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::XlaOp result = ctx->builder()->DynamicUpdateSlice( - ctx->Input(0), ctx->Input(1), ctx->Input(2)); + xla::XlaOp result = + xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2)); ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index dd4a16908779508380b36f43ce2306ff2f5fb8c4..958231505b50431b9bb267b0a3cc5ed56e3aeb21 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -150,8 +151,7 @@ class DynamicStitchOp : public XlaOpKernel { if (new_shape == data_shapes[input_num]) { input[input_num] = handle; } else { - input[input_num] = - ctx->builder()->Reshape(handle, new_shape.dim_sizes()); + input[input_num] = xla::Reshape(handle, new_shape.dim_sizes()); } } @@ -175,10 +175,10 @@ class DynamicStitchOp : public XlaOpKernel { // And place it in the concat list in the place indicated by // the index. to_concat[index_num] = - ctx->builder()->Slice(expression, slice_start, slice_limit, stride); + xla::Slice(expression, slice_start, slice_limit, stride); } - ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), to_concat, 0)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index ed7462c16615f7f63a174e29843c2a1675c17058..2c76bcee2593b820eafe09af3a52736ed8a92f86 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -34,10 +34,9 @@ class EluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); - const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); - ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); + const auto pred = xla::Gt(ctx->Input(0), zero); + const auto expm1 = xla::Expm1(ctx->Input(0)); + ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), expm1)); } }; @@ -52,9 +51,9 @@ class EluGradOp : public XlaOpKernel { const auto one = XlaHelpers::One(b, input_type(0)); const auto grad = ctx->Input(0); const auto activation = ctx->Input(1); - const auto exp_grad = b->Mul(grad, b->Add(activation, one)); - const auto pred = b->Gt(activation, zero); - ctx->SetOutput(0, b->Select(pred, grad, exp_grad)); + const auto exp_grad = xla::Mul(grad, xla::Add(activation, one)); + const auto pred = xla::Gt(activation, zero); + ctx->SetOutput(0, xla::Select(pred, grad, exp_grad)); } }; @@ -68,15 +67,14 @@ class SeluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), 1.7580993408473768599402175208123); - const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); - ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), - b->Mul(scale_alpha, expm1))); + const auto pred = xla::Gt(ctx->Input(0), zero); + const auto expm1 = xla::Expm1(ctx->Input(0)); + ctx->SetOutput(0, xla::Select(pred, xla::Mul(scale, ctx->Input(0)), + xla::Mul(scale_alpha, expm1))); } }; @@ -94,10 +92,10 @@ class SeluGradOp : public XlaOpKernel { 1.7580993408473768599402175208123); const auto grad = ctx->Input(0); const auto activation = ctx->Input(1); - const auto lin_grad = b->Mul(grad, scale); - const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha)); - const auto pred = b->Gt(activation, zero); - ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad)); + const auto lin_grad = xla::Mul(grad, scale); + const auto exp_grad = xla::Mul(grad, xla::Add(activation, scale_alpha)); + const auto pred = xla::Gt(activation, zero); + ctx->SetOutput(0, xla::Select(pred, lin_grad, exp_grad)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 6df01cabbf1d98c0299bfd808bcc6db6223c4777..65d42a302fca48c7b5f88813f80e975823f63ddf 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -110,13 +112,11 @@ class ExtractImagePatchesOp : public XlaOpKernel { // Builds an identity matrix as a broadcast equality of iotas. // iota = np.arange(np.prod(ksize), depth) // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) - xla::XlaOp iota; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, - kernel_size * depth, &iota)); + xla::XlaOp iota = xla::Iota(builder, xla::S32, kernel_size * depth); - auto lhs = builder->Reshape(iota, lhs_shape); - auto filter = builder->ConvertElementType( - builder->Eq(lhs, iota, {num_spatial_dims + 1}), type); + auto lhs = xla::Reshape(iota, lhs_shape); + auto filter = xla::ConvertElementType( + xla::Eq(lhs, iota, {num_spatial_dims + 1}), type); xla::ConvolutionDimensionNumbers dims; std::vector window_strides(num_spatial_dims); @@ -148,8 +148,8 @@ class ExtractImagePatchesOp : public XlaOpKernel { } xla::XlaOp conv = - builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, - padding, lhs_dilation, rhs_dilation, dims); + xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, + lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); } diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 8f0de0a524c908b598c1a2165a462275346ad137..2fd1a34741e1c7235397f9a69dd8444b4679fa22 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -49,20 +50,20 @@ void XlaNudge(xla::XlaBuilder* b, const DataType data_type, const float quant_min_value, const float quant_max_value, xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, xla::XlaOp* scale) { - *scale = b->Div(b->Sub(max, min), - XlaHelpers::FloatLiteral(b, data_type, - quant_max_value - quant_min_value)); + *scale = xla::Div(xla::Sub(max, min), + XlaHelpers::FloatLiteral( + b, data_type, quant_max_value - quant_min_value)); xla::XlaOp quant_min = XlaHelpers::FloatLiteral(b, data_type, quant_min_value); - xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale)); + xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale)); xla::XlaOp quant_max = XlaHelpers::FloatLiteral(b, data_type, quant_max_value); xla::XlaOp nudged_zero_point = - b->Select(b->Le(zero_point_from_min, quant_min), quant_min, - b->Select(b->Ge(zero_point_from_min, quant_max), quant_max, - b->Round(zero_point_from_min))); - *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale); - *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale); + xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min, + xla::Select(xla::Ge(zero_point_from_min, quant_max), + quant_max, xla::Round(zero_point_from_min))); + *nudged_min = xla::Mul(xla::Sub(quant_min, nudged_zero_point), *scale); + *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale); } xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, @@ -71,14 +72,14 @@ xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, const xla::XlaOp& nudged_input_max, const xla::XlaOp& input_scale) { xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); - xla::XlaOp inv_scale = b->Div(one, input_scale); + xla::XlaOp inv_scale = xla::Div(one, input_scale); xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f); - xla::XlaOp clamped = b->Clamp(nudged_input_min, input, nudged_input_max); - xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min); + xla::XlaOp clamped = xla::Clamp(nudged_input_min, input, nudged_input_max); + xla::XlaOp clamped_shifted = xla::Sub(clamped, nudged_input_min); xla::XlaOp rounded = - b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half)); - return b->Add(b->Mul(rounded, input_scale), nudged_input_min); + xla::Floor(xla::Add(xla::Mul(clamped_shifted, inv_scale), half)); + return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min); } class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { @@ -163,11 +164,11 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { xla::XlaOp nudged_input_max = XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); - xla::XlaOp between_nudged_min_max = - b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); - xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type), - gradient_shape.dim_sizes()); - xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp between_nudged_min_max = xla::And( + xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); + xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type), + gradient_shape.dim_sizes()); + xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output); } @@ -249,25 +250,25 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); - xla::XlaOp between_nudged_min_max = - b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); + xla::XlaOp between_nudged_min_max = xla::And( + xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); xla::XlaOp zero = XlaHelpers::Zero(b, data_type); - xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes()); - xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); - xla::XlaOp below_min = b->Lt(input, nudged_input_min); - xla::XlaOp select1 = b->Select(below_min, gradient, zeroes); - xla::XlaOp reduce1 = b->ReduceAll( + xla::XlaOp below_min = xla::Lt(input, nudged_input_min); + xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes); + xla::XlaOp reduce1 = xla::ReduceAll( XlaHelpers::ConvertElementType(b, select1, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type); ctx->SetOutput(1, output1); - xla::XlaOp above_max = b->Gt(input, nudged_input_max); - xla::XlaOp select2 = b->Select(above_max, gradient, zeroes); - xla::XlaOp reduce2 = b->ReduceAll( + xla::XlaOp above_max = xla::Gt(input, nudged_input_max); + xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes); + xla::XlaOp reduce2 = xla::ReduceAll( XlaHelpers::ConvertElementType(b, select2, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 933924cad1c7cac2879bd4720cb21ffc33c23f50..b2b00e51e3b00fa93c258af489cf0f4a3e6e764b 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -62,8 +63,7 @@ class GenericFftOp : public XlaOpKernel { } } - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length); + xla::XlaOp fft = xla::Fft(ctx->Input(0), fft_type_, fft_length); ctx->SetOutput(0, fft); } diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index e4467a0fb138ed7919af62ed032c0f5abee3e4f6..95faa1d058f4c0d3fa802b157c6daba1e1adaf41 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -59,11 +60,11 @@ class FillOp : public XlaOpKernel { xla::XlaOp data = ctx->Input(1); if (value_shape.dims() > 0) { CHECK_EQ(value_shape.dims(), 1); - data = ctx->builder()->Reshape(data, {}); + data = xla::Reshape(data, {}); } // Emit the actual computation, which broadcasts the scalar to the // desired shape. - auto result = ctx->builder()->Broadcast(data, broadcast); + auto result = xla::Broadcast(data, broadcast); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index d13e25bcddae16d0cd630403219657121b80868d..5f041be5df226ed996b21844c0cf92b6dfac005c 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -75,8 +76,8 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, out_shape.AppendShape(indices_shape_no_index_vectors); out_shape.AppendShape(input_shape_post_axis); - *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - out_shape.dim_sizes()); + *gather_output = + xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes()); return Status::OK(); } @@ -142,7 +143,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, dim_numbers.add_gather_dims_to_operand_dims(i); } - *gather_output = builder->Gather(input, indices, dim_numbers, window_bounds); + *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 8b9b026643cf35216a2082dfcce9270c017bd14f..f5fcf3cacdbff8297bc42fcb0cf79c2bc83a4e11 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -48,11 +49,11 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; - std::vector inputs(input_types_.size()); std::vector arguments(input_types_.size()); for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; DataType type = ctx->input_type(i + 1); + if (type == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); @@ -60,7 +61,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.initialized = resource->initialized(); arg.kind = XlaCompiler::Argument::kResource; arg.resource_kind = resource->kind(); - OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); arg.type = resource->type(); arg.shape = resource->shape(); @@ -79,7 +79,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; arg.shape = ctx->InputShape(i + 1); - inputs[i] = ctx->Input(i + 1); VLOG(2) << "Arg type: " << DataTypeString(arg.type) << " shape: " << arg.shape.DebugString(); } @@ -100,6 +99,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, arguments, &else_result)); + bool has_tensor_array_gradients = false; for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { XlaResource* resource; @@ -121,9 +121,21 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } + if (!resource->tensor_array_gradients().empty()) + has_tensor_array_gradients = true; } } + // Recompile the functions to update the argument shapes for tensor arrays. + if (has_tensor_array_gradients) { + then_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + else_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + } + // Check that both branches have identical input shapes. OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); @@ -175,13 +187,26 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { "Mismatch in resource of then and else branch for resource ", i)); } - xla::XlaOp outputs = - b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, - b->Tuple(inputs), *else_result.computation); + int num_inputs = then_result.input_mapping.size(); + std::vector inputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + int input_num = then_result.input_mapping[i] + 1; + if (ctx->input_type(input_num) == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + } else { + inputs[i] = ctx->Input(i + 1); + } + } + + xla::XlaOp outputs = xla::Conditional( + ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation, + xla::Tuple(b, inputs), *else_result.computation); // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { - xla::XlaOp output_handle = b->GetTupleElement(outputs, i); + xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); if (VLOG_IS_ON(2)) { LOG(INFO) << "Setting output " << i; auto shape_or = b->GetShape(output_handle); @@ -209,7 +234,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - b->GetTupleElement(outputs, pos), b)); + xla::GetTupleElement(outputs, pos), b)); } VLOG(2) << "If variable: pos: " << update.input_index << " name: " << resource->name() diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 1568b33679963c1a6630525f60560180d40b8d53..cb4caf7bcb4caaa1bf7e0e79e52bb966a8838db3 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -32,23 +33,26 @@ std::array RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, auto red = rgb[0]; auto green = rgb[1]; auto blue = rgb[2]; - auto value = b->Max(b->Max(red, green), blue); - auto minimum = b->Min(b->Min(red, green), blue); - auto range = b->Sub(value, minimum); - - auto zeros = b->Broadcast(zero, shape.dim_sizes()); - auto saturation = b->Select(b->Gt(value, zero), b->Div(range, value), zeros); - - auto norm = b->Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); - - auto hue = b->Select(b->Eq(green, value), - b->Add(b->Mul(norm, b->Sub(blue, red)), - XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), - b->Add(b->Mul(norm, b->Sub(red, green)), - XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); - hue = b->Select(b->Eq(red, value), b->Mul(norm, b->Sub(green, blue)), hue); - hue = b->Select(b->Gt(range, zero), hue, zeros); - hue = b->Select(b->Lt(hue, zero), b->Add(hue, one), hue); + auto value = xla::Max(xla::Max(red, green), blue); + auto minimum = xla::Min(xla::Min(red, green), blue); + auto range = xla::Sub(value, minimum); + + auto zeros = xla::Broadcast(zero, shape.dim_sizes()); + auto saturation = + xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros); + + auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); + + auto hue = + xla::Select(xla::Eq(green, value), + xla::Add(xla::Mul(norm, xla::Sub(blue, red)), + XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), + xla::Add(xla::Mul(norm, xla::Sub(red, green)), + XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); + hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)), + hue); + hue = xla::Select(xla::Gt(range, zero), hue, zeros); + hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue); return {hue, saturation, value}; } @@ -66,15 +70,15 @@ std::array HSVToRGB(xla::XlaBuilder* b, auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0); auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0); - auto dh = b->Mul(hue, six); - auto dr = b->Clamp(zero, b->Sub(b->Abs(b->Sub(dh, three)), one), one); - auto dg = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, two))), one); - auto db = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, four))), one); - auto one_minus_s = b->Sub(one, saturation); + auto dh = xla::Mul(hue, six); + auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one); + auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one); + auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one); + auto one_minus_s = xla::Sub(one, saturation); - auto red = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dr)), value); - auto green = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dg)), value); - auto blue = b->Mul(b->Add(one_minus_s, b->Mul(saturation, db)), value); + auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value); + auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value); + auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value); return {red, green, blue}; } @@ -97,21 +101,21 @@ class RGBToHSVOp : public XlaOpKernel { xla::XlaBuilder* b = context->builder(); xla::XlaOp input = context->Input(0); - xla::XlaOp red = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp green = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp blue = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), channel_shape); - context->SetOutput(0, b->ConcatInDim(hsv, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim)); } }; REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp); @@ -134,20 +138,20 @@ class HSVToRGBOp : public XlaOpKernel { xla::XlaBuilder* b = context->builder(); xla::XlaOp input = context->Input(0); - xla::XlaOp hue = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp saturation = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp value = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); auto rgb = HSVToRGB(context->builder(), {hue, saturation, value}, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp); @@ -182,18 +186,20 @@ class AdjustContrastOpV2 : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(b, input, accumulation_type); - auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *context->GetOrCreateAdd(accumulation_type), - {height_dim, width_dim}); + auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *context->GetOrCreateAdd(accumulation_type), + {height_dim, width_dim}); auto output = XlaHelpers::ConvertElementType(b, reduce, type); - output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); + output = + xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); std::vector broadcast_dims(input_shape.dims() - 2); std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); broadcast_dims.back() = channel_dim; - output = b->Add(b->Mul(input, factor), - b->Mul(output, b->Sub(XlaHelpers::One(b, type), factor)), - broadcast_dims); + output = + xla::Add(xla::Mul(input, factor), + xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)), + broadcast_dims); context->SetOutput(0, output); } }; @@ -226,26 +232,26 @@ class AdjustSaturationOp : public XlaOpKernel { DataType type = context->input_type(0); - xla::XlaOp red = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp green = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp blue = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), channel_shape); - hsv[1] = b->Clamp(XlaHelpers::Zero(b, type), b->Mul(hsv[1], scale), - XlaHelpers::One(b, type)); + hsv[1] = xla::Clamp(XlaHelpers::Zero(b, type), xla::Mul(hsv[1], scale), + XlaHelpers::One(b, type)); auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); @@ -276,15 +282,15 @@ class AdjustHueOp : public XlaOpKernel { DataType type = context->input_type(0); - xla::XlaOp red = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp green = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp blue = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), @@ -294,12 +300,13 @@ class AdjustHueOp : public XlaOpKernel { auto one = XlaHelpers::One(b, type); auto& hue = hsv[0]; - hue = b->Rem(b->Add(hsv[0], delta), one); - hue = b->Select(b->Lt(hue, zero), b->Rem(b->Add(one, hue), one), hue); + hue = xla::Rem(xla::Add(hsv[0], delta), one); + hue = + xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue); auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 9058cbc74762576c7e6f8ec1b2b0f6b247ac0502..d6bf92fb3df8d38909df99e11c85ede4fac2bf81 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/math/math_util.h" @@ -99,46 +101,71 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( return dims; } +// Form a 2D convolution kernel like: +// 1 2 3 2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 +// by multiplying two 1D kernels of the form: +// 1/3 * [1 2 3 2 1] +// If the 2D kernel would be very large, the 1D kernel can be applied once in +// each dimension due to the symmetry of the kernel along all axis to reduce the +// computational intensity. +std::vector Make1DKernel(int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = (i + 1.0f) / n; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; +} + +// Kernels with more than 16 spatial elements are considered intense and the +// kernel should applied to each dimension independently. +const int64 kMax2DKernelSize = 16; + xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, gtl::ArraySlice kernel_size, int64 channels) { - // Form a 2D convolution kernel like: - // 1 2 3 2 1 - // 2 4 6 4 2 - // 1/9 * 3 6 9 6 3 - // 2 4 6 4 2 - // 1 2 3 2 1 - // by multiplying two 1D kernels of the form: - // 1/3 * [1 2 3 2 1] - auto make_1d_kernel = [](int64 n) { - std::vector kernel(n * 2 - 1); - for (int64 i = 0; i < n; ++i) { - float v = (i + 1.0f) / n; - kernel[i] = v; - kernel[n * 2 - 2 - i] = v; - } - return kernel; - }; - - xla::XlaOp channels_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK( - XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); - auto diag = builder->ConvertElementType( - builder->Eq( - builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1, + auto diag = xla::ConvertElementType( + xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1, 2 * kernel_size[1] - 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), + channels_iota, /*broadcast_dimensions=*/{2}), xla::PrimitiveType::F32); - return builder->Mul( - builder->Mul(diag, - builder->ConstantR1(make_1d_kernel(kernel_size[1])), - /*broadcast_dimensions=*/{1}), - builder->ConstantR1(make_1d_kernel(kernel_size[0])), + return xla::Mul( + xla::Mul(diag, + xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}), + xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), /*broadcast_dimensions=*/{0}); } +xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, + gtl::ArraySlice kernel_size, + int64 channels, int64 dim) { + xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); + + auto diag = xla::ConvertElementType( + xla::Eq( + xla::Broadcast(channels_iota, + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), + xla::PrimitiveType::F32); + if (dim == 1) { + return xla::Mul( + diag, xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}); + } + return xla::Mul(diag, + xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); +} + xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, const xla::XlaOp& input, const int num_spatial_dims, @@ -165,27 +192,49 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, dimension_numbers.add_output_spatial_dimensions(1 + i); dimension_numbers.add_kernel_spatial_dimensions(i); } - dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); - dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, out_size); - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - xla::XlaOp output = builder->ConvGeneralDilated( - input, kernel, dims.stride, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp output; + // Split convolutions into independent dimensions if they wmuld be a very + // large kernel. + if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + output = xla::ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + output = xla::ConvGeneralDilated( + input, kernel0, {dims.stride[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.kernel_size[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + output = xla::ConvGeneralDilated( + output, kernel1, {1, dims.stride[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.kernel_size[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // Add broadcasts to handle expanding from a size == 1 dimension to a // size > 1 dimension. for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] == 1 && out_size[i] > 1) { - output = builder->Add(output, builder->ConstantR1(out_size[i], 0), - /*broadcast_dimensions=*/{1 + i}); + output = xla::Add(output, xla::ConstantR1(builder, out_size[i], 0), + /*broadcast_dimensions=*/{1 + i}); } } return output; @@ -214,26 +263,63 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, } dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp output; + if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = + xla::Add(kernel, xla::ConstantR1(builder, grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } - // Broadcast the input kernel where the forward op expanded from a size == 1 - // dimension to a size > 1 dimension. This has the effect of summing the - // gradient contributions in that dimension. - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && grad_size[i] > 1) { - kernel = builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), - /*broadcast_dimensions=*/{i}); + output = xla::ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + if (in_size[0] == 1 && grad_size[0] > 1) { + kernel0 = + xla::Add(kernel0, xla::ConstantR1(builder, grad_size[0], 0), + /*broadcast_dimensions=*/{0}); + } + if (in_size[1] == 1 && grad_size[1] > 1) { + kernel1 = + xla::Add(kernel0, xla::ConstantR1(builder, grad_size[1], 0), + /*broadcast_dimensions=*/{1}); } - } - xla::XlaOp output = builder->ConvGeneralDilated( - grad, kernel, /*window_strides=*/dims.kernel_size, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.stride, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + output = xla::ConvGeneralDilated( + grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.stride[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + + output = xla::ConvGeneralDilated( + output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.stride[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. // Opposite of the slice performed by the forward op. @@ -246,7 +332,7 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, } } if (pad_output) { - output = builder->Pad(output, builder->ConstantR0(0.0f), padding); + output = xla::Pad(output, xla::ConstantR0(builder, 0.0f), padding); } return output; } @@ -302,13 +388,13 @@ class ResizeBilinearOp : public XlaOpKernel { } } if (slice_input) { - input = b->Slice(input, {0, 0, 0, 0}, - {batch, slice_size[0], slice_size[1], channels}, - {1, 1, 1, 1}); + input = xla::Slice(input, {0, 0, 0, 0}, + {batch, slice_size[0], slice_size[1], channels}, + {1, 1, 1, 1}); } // Output is always type float. - input = b->ConvertElementType(input, xla::F32); + input = xla::ConvertElementType(input, xla::F32); // Special Case: // Instead of doing a ResizeUsingDilationAndConvolution directly, @@ -438,7 +524,7 @@ class ResizeBilinearGradOp : public XlaOpKernel { } } - output = b->ConvertElementType(output, output_type_); + output = xla::ConvertElementType(output, output_type_); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 36eb4c75454ed82804c40b82e5dbaec2eef0a719..f3964748587c1b31cf8b1b76643ff19a9044bf44 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -60,19 +60,15 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { input_shape.DebugString())); DataType index_type = output_type(0); + xla::PrimitiveType index_xla_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(index_type, &index_xla_type)); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); - xla::XlaOp output; if (is_min_) { - OP_REQUIRES_OK(ctx, - XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0), - index_type, axis, &output)); + output = XlaHelpers::ArgMin(input, index_xla_type, axis); } else { - OP_REQUIRES_OK(ctx, - XlaHelpers::ArgMax(b, ctx, input, input_shape, input_type(0), - index_type, axis, &output)); + output = XlaHelpers::ArgMax(input, index_xla_type, axis); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 2c2d88486fda99d2380382a3e2f633f5bdc7478c..a020ebc729e4c07d1b182cc0585ba0f2bca46403 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -76,14 +77,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // XLA passes to the function, so it is not included here. std::vector args; args.push_back(ctx->Input(0)); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1(input_shape.dim_sizes()))); + args.push_back(xla::ConstantLiteral( + &b, *xla::Literal::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1(output_shape.dim_sizes()))); - args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0(dim))); + args.push_back(xla::ConstantLiteral( + &b, *xla::Literal::CreateR1(output_shape.dim_sizes()))); + args.push_back( + xla::ConstantLiteral(&b, *xla::Literal::CreateR0(dim))); } xla::Shape xla_shape = @@ -94,10 +96,12 @@ class ArgMaxCustomCallOp : public XlaOpKernel { xla::XlaOp output; switch (input_shape.dims()) { case 1: - output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); + output = + xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape); break; case 2: - output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape); + output = + xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape); break; default: OP_REQUIRES(ctx, false, diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 1decf7d72d72bb697477e7f841ced2a1a0d5fbe9..9e64711051d31107db1bf6f1966f9ed6f5630c34 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -39,12 +39,12 @@ class L2LossOp : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); auto t = XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); - auto square = b->Mul(t, t); - auto reduce = b->Reduce(square, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), dims); + auto square = xla::Mul(t, t); + auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), dims); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); - ctx->SetOutput(0, b->Div(deconverted, two)); + ctx->SetOutput(0, xla::Div(deconverted, two)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index 0388b4c830702ea00ec69fc42c6468326c88cf38..2fb072f827906d40dcf410f0312394c4f568a28d 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/core/errors.h" @@ -90,8 +91,10 @@ class ListDiffOp : public XlaOpKernel { idx_output.push_back(i); } - context->SetOutput(0, context->builder()->ConstantR1(val_output)); - context->SetOutput(1, context->builder()->ConstantR1(idx_output)); + context->SetOutput(0, + xla::ConstantR1(context->builder(), val_output)); + context->SetOutput(1, + xla::ConstantR1(context->builder(), idx_output)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 39fbf98a6274918840e9e351470f04c2d80c5d01..dc934543cb2f94fbe1e8f1f865156eb082d6a127 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -50,8 +51,8 @@ class LRNOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = XlaHelpers::ConvertElementType(builder, input, accumulation_type); - auto squared = builder->Mul(converted, converted); - auto reduce = builder->ReduceWindow( + auto squared = xla::Mul(converted, converted); + auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -59,12 +60,12 @@ class LRNOp : public XlaOpKernel { auto sqr_sum = XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); - auto scale = builder->Pow( - builder->Add(builder->ConstantR0(bias_), - builder->Mul(builder->ConstantR0(alpha_), sqr_sum)), - builder->ConstantR0(-beta_)); + auto scale = xla::Pow( + xla::Add(xla::ConstantR0(builder, bias_), + xla::Mul(xla::ConstantR0(builder, alpha_), sqr_sum)), + xla::ConstantR0(builder, -beta_)); - ctx->SetOutput(0, builder->Mul(input, scale)); + ctx->SetOutput(0, xla::Mul(input, scale)); } private: @@ -138,8 +139,8 @@ class LRNGradOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); - auto squared = builder->Mul(converted, converted); - auto reduce = builder->ReduceWindow( + auto squared = xla::Mul(converted, converted); + auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -148,17 +149,17 @@ class LRNGradOp : public XlaOpKernel { XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto norm = - builder->Add(builder->ConstantR0(bias_), - builder->Mul(builder->ConstantR0(alpha_), sqr_sum)); + xla::Add(xla::ConstantR0(builder, bias_), + xla::Mul(xla::ConstantR0(builder, alpha_), sqr_sum)); - auto dy = builder->Mul( - builder->Mul(builder->ConstantR0(-2.0f * alpha_ * beta_), - builder->Div(out_image, norm)), + auto dy = xla::Mul( + xla::Mul(xla::ConstantR0(builder, -2.0f * alpha_ * beta_), + xla::Div(out_image, norm)), in_grads); auto converted_dy = XlaHelpers::ConvertElementType(builder, dy, accumulation_type); - auto dy_reduce = builder->ReduceWindow( + auto dy_reduce = xla::ReduceWindow( converted_dy, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -166,10 +167,10 @@ class LRNGradOp : public XlaOpKernel { auto dy_reduced = XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); - xla::XlaOp gradients = builder->Add( - builder->Mul(in_image, dy_reduced), - builder->Mul(in_grads, - builder->Pow(norm, builder->ConstantR0(-beta_)))); + xla::XlaOp gradients = xla::Add( + xla::Mul(in_image, dy_reduced), + xla::Mul(in_grads, + xla::Pow(norm, xla::ConstantR0(builder, -beta_)))); ctx->SetOutput(0, gradients); } diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 6949b296f4b9afe4a0c9152c763a9ad233b9f595..844080b8cf5462da201ce7671e4f9d02fa52c861 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -70,15 +71,15 @@ class MatMulOp : public XlaOpKernel { xla::XlaOp b = ctx->Input(1); if (is_sparse_) { if (a_type_ == DT_BFLOAT16) { - a = ctx->builder()->ConvertElementType(a, xla::F32); + a = xla::ConvertElementType(a, xla::F32); } if (b_type_ == DT_BFLOAT16) { - b = ctx->builder()->ConvertElementType(b, xla::F32); + b = xla::ConvertElementType(b, xla::F32); } } - auto lhs = (transpose_a_) ? ctx->builder()->Transpose(a, {1, 0}) : a; - auto rhs = (transpose_b_) ? ctx->builder()->Transpose(b, {1, 0}) : b; - ctx->SetOutput(0, ctx->builder()->Dot(lhs, rhs)); + auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a; + auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b; + ctx->SetOutput(0, xla::Dot(lhs, rhs)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index fbd5dc0fdad4483aadbe9bc263cc1f7a034cee09..e06c87db7adb1840606208fe15cd68a3ca4d137a 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -50,6 +52,7 @@ class MatrixBandPartOp : public XlaOpKernel { xla::XlaOp num_upper = context->Input(2); DataType input_type = context->input_type(0); DataType index_type = context->input_type(1); + xla::PrimitiveType index_xla_type = context->input_xla_type(1); TensorShape batch_shape = input_shape; batch_shape.RemoveLastDims(2); @@ -58,33 +61,29 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. - xla::XlaOp iota_m; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); + xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m); + xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n); - xla::XlaOp iota_n; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); - - auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m, - /*broadcast_dimensions=*/{0}); + auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m, + /*broadcast_dimensions=*/{0}); // If num_lower or num_upper are negative, include all lower/upper // diagonals. auto zero_index = XlaHelpers::Zero(builder, index_type); - num_lower = builder->Select( - builder->Lt(num_lower, zero_index), - XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower); - num_upper = builder->Select( - builder->Lt(num_upper, zero_index), - XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper); + num_lower = xla::Select(xla::Lt(num_lower, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, m), + num_lower); + num_upper = xla::Select(xla::Lt(num_upper, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, n), + num_upper); - auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset), - builder->Le(offset, num_upper)); - indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + auto indicator = xla::And(xla::Le(xla::Neg(num_lower), offset), + xla::Le(offset, num_upper)); + indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); auto zero_input = XlaHelpers::Zero(builder, input_type); - auto output = builder->Select( - indicator, input, - builder->Broadcast(zero_input, input_shape.dim_sizes())); + auto output = xla::Select( + indicator, input, xla::Broadcast(zero_input, input_shape.dim_sizes())); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index db53f6fef8d6bf901c8281f50791ca6766c46efd..e2ab4b83cfb45b2f9a7f3aba2d2a927d10ad8b85 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -61,14 +63,11 @@ class MatrixSetDiagOp : public XlaOpKernel { auto zero = XlaHelpers::Zero(builder, context->input_type(0)); // Create an indicator tensor that is true only on the diagonal. - xla::XlaOp iota_m; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); - xla::XlaOp iota_n; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); - auto indicator = builder->Eq(iota_m, - builder->Broadcast(iota_n, {m}), - /*broadcast_dimensions=*/{0}); - indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m); + xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n); + auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}), + /*broadcast_dimensions=*/{0}); + indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); // Broadcast diag up to the input shape. Use an implicit broadcast (Add) // because we need to broadcast on the right. @@ -77,10 +76,10 @@ class MatrixSetDiagOp : public XlaOpKernel { if (min_dim != m) { diag_broadcast_dims.back() = rank - 1; } - diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()), - /*broadcast_dimensions=*/diag_broadcast_dims); + diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()), + /*broadcast_dimensions=*/diag_broadcast_dims); - auto output = builder->Select(indicator, diag, input); + auto output = xla::Select(indicator, diag, input); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index eaed93146460de5a6e8328432302cc75bf36a534..f4def11d08c31513aec5aad15187016a7294c2fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -30,13 +30,9 @@ class MatrixTriangularSolveOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { auto result = TriangularSolve( - ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true, + ctx->Input(0), ctx->Input(1), /*left_side=*/true, /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 7e9de3ef9b245c113cc143128fe58e7e017a361c..529959dbd90b05f8860360f70e087ef225150600 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/mirror_pad_mode.h" namespace tensorflow { @@ -27,21 +28,21 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr DoMirrorPad(const xla::XlaOp& t, const xla::Shape& original_shape, - const xla::Literal& pad_literal, + const xla::LiteralSlice& pad_literal, xla::XlaBuilder* b) { xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { - auto t_rev = b->Rev(accum, {dimno}); + auto t_rev = xla::Rev(accum, {dimno}); TF_ASSIGN_OR_RETURN(int64 lhs_padding, pad_literal.GetIntegralAsS64({dimno, 0})); TF_ASSIGN_OR_RETURN(int64 rhs_padding, pad_literal.GetIntegralAsS64({dimno, 1})); int64 dim_size = original_shape.dimensions(dimno); - auto lhs_pad = b->SliceInDim(t_rev, dim_size - 1 - lhs_padding, - dim_size - 1, 1, dimno); - auto rhs_pad = b->SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); - accum = b->ConcatInDim({lhs_pad, accum, rhs_pad}, dimno); + auto lhs_pad = xla::SliceInDim(t_rev, dim_size - 1 - lhs_padding, + dim_size - 1, 1, dimno); + auto rhs_pad = xla::SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); + accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno); } return accum; } diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc index 8c8a9bbe787f3224e7444b62dcf8ad99130cf37f..65ab9da8d7ca0509a4a69c43727a0e6c0435908a 100644 --- a/tensorflow/compiler/tf2xla/kernels/no_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -24,8 +24,7 @@ namespace tensorflow { REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp); // We register ControlTrigger as a no-op. This is correct since nodes seen -// by the XLA compiler are never dead. This may need rethinking when we add -// support for conditionals to XLA. -REGISTER_XLA_OP(Name("ControlTrigger"), NoOp); +// by the XLA compiler are never dead. +REGISTER_XLA_OP(Name("ControlTrigger").CompilationOnly(), NoOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index aecaabb6dcf46bdd6ae3da929448d6370acb989b..3aed47de2603f3e187ad515d4db3f884da4c6cc8 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -76,11 +77,10 @@ class PackOp : public XlaOpKernel { for (int i = 0; i < num; ++i) { // Reshape the inputs to have an extra dimension of size 1. - reshaped_inputs[i] = - ctx->builder()->Reshape(values[i], child_shape.dim_sizes()); + reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes()); } - ctx->SetOutput(0, ctx->builder()->ConcatInDim(reshaped_inputs, axis)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), reshaped_inputs, axis)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 7c95475e7b1f02183e44f73f116a4aeb25f05c09..89fd610bc63349d008836c3c4e6ec8927c232a54 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -63,8 +64,8 @@ class PadOp : public XlaOpKernel { int before = pad_literal.Get({i, 0}); int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, - errors::InvalidArgument("Paddings must be non-negative: ", - before, " ", after)); + errors::InvalidArgument( + "Paddings must be non-negative: ", before, " ", after)); dim->set_edge_padding_low(before); dim->set_edge_padding_high(after); } @@ -74,11 +75,10 @@ class PadOp : public XlaOpKernel { if (ctx->num_inputs() == 3) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), errors::InvalidArgument("constant_values must be a scalar.")); - ctx->SetOutput(0, - ctx->builder()->Pad(ctx->Input(0), ctx->Input(2), config)); + ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config)); } else { auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config)); + ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index f8e7b48a0fd94835964aea033ad33523150067b4..a81f5fddf69523619d03ea2041c40222de46174e 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -61,6 +63,9 @@ class PoolingOp : public XlaOpKernel { Padding padding; OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } int num_dims() const { return num_spatial_dims_ + 2; } @@ -113,8 +118,8 @@ class PoolingOp : public XlaOpKernel { xla::XlaBuilder* const b = ctx->builder(); auto input = XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); - auto reduce = ctx->builder()->ReduceWindow( - input, InitValue(b), *Reduction(ctx), ksize, stride, padding_); + auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize, + stride, padding_); auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); ctx->SetOutput(0, PostProcessOutput(ctx, pooled, input_type(0), input_shape)); @@ -127,6 +132,7 @@ class PoolingOp : public XlaOpKernel { xla::Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; DataType reduction_type_; + xla::PrimitiveType xla_reduction_type_; }; class MaxPoolOp : public PoolingOp { @@ -136,7 +142,7 @@ class MaxPoolOp : public PoolingOp { /*reduction_type=*/ctx->input_type(0)) {} xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return XlaHelpers::MinValue(b, reduction_type_); + return xla::MinValue(b, xla_reduction_type_); } const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { @@ -190,7 +196,7 @@ static xla::XlaOp AvgPoolDivideByCount( auto divisor = XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); - return ctx->builder()->Div(output, divisor); + return xla::Div(output, divisor); } else { // For SAME padding, the padding shouldn't be included in the // counts. We use another ReduceWindow to find the right counts. @@ -212,18 +218,18 @@ static xla::XlaOp AvgPoolDivideByCount( // Build a matrix of all 1s, with the same width/height as the input. const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto ones = ctx->builder()->Broadcast( + auto ones = xla::Broadcast( XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); // Perform a ReduceWindow with the same window size, strides, and padding // to count the number of contributions to each result element. - auto reduce = ctx->builder()->ReduceWindow( + auto reduce = xla::ReduceWindow( ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, xla::Padding::kSame); auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); - return ctx->builder()->Div(output, counts, window_dims); + return xla::Div(output, counts, window_dims); } } @@ -235,7 +241,7 @@ class AvgPoolOp : public PoolingOp { XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return XlaHelpers::Zero(b, reduction_type_); + return xla::Zero(b, xla_reduction_type_); } const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { @@ -347,9 +353,9 @@ class MaxPoolGradOp : public XlaOpKernel { xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2)); auto select = CreateScalarGeComputation(element_type, ctx->builder()); auto scatter = CreateScalarAddComputation(element_type, ctx->builder()); - xla::XlaOp gradients = ctx->builder()->SelectAndScatter( - input, select, ksize_, stride_, xla_padding, out_backprop, init_value, - scatter); + xla::XlaOp gradients = + xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding, + out_backprop, init_value, scatter); ctx->SetOutput(0, gradients); } @@ -485,12 +491,12 @@ class AvgPoolGradOp : public XlaOpKernel { } auto zero = XlaHelpers::Zero(b, dtype); - auto padded_gradients = b->Pad(out_backprop_div, zero, padding_config); + auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config); // in_backprop = padded_gradients ones std::vector ones(num_dims(), 1LL); auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto in_backprop = b->ReduceWindow( + auto in_backprop = xla::ReduceWindow( XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), ksize_, @@ -614,58 +620,61 @@ class MaxPoolGradGradOp : public XlaOpKernel { auto b = ctx->builder(); - auto sixteen = b->ConstantR0(16); + auto sixteen = xla::ConstantR0(b, 16); // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32 - auto in_hi = b->BitcastConvertType( - b->ConvertElementType(b->ConvertElementType(input, xla::BF16), - xla::F32), + auto in_hi = xla::BitcastConvertType( + xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16), + xla::F32), xla::U32); - auto bp_int = b->BitcastConvertType(out_backprop, xla::U32); - auto bp_hi = b->ShiftRightLogical(bp_int, sixteen); - auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen); - auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add. - auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add. - - auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); + auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32); + auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen); + auto bp_lo = + xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen); + auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add. + auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add. + + auto init_value = xla::MinValue(b, xla::F32); // We will reduce by taking the maximal value up to 16 bits (ignoring the lo // 16 bits of packed-in hi/lo backprop value). auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits"); { // F32 parameters to satisfy lowering type restriction for reduce opcode. const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {}); - auto lhs = rb->Parameter(0, scalar, "lhs"); - auto rhs = rb->Parameter(1, scalar, "rhs"); - auto sixteen = rb->ConstantR0(16); - auto lhs_criteria = rb->ShiftLeft( - rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen), - sixteen); - auto rhs_criteria = rb->ShiftLeft( - rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen), - sixteen); + auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs"); + auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs"); + auto sixteen = xla::ConstantR0(rb.get(), 16); + auto lhs_criteria = + xla::ShiftLeft(xla::ShiftRightLogical( + xla::BitcastConvertType(lhs, xla::S32), sixteen), + sixteen); + auto rhs_criteria = + xla::ShiftLeft(xla::ShiftRightLogical( + xla::BitcastConvertType(rhs, xla::S32), sixteen), + sixteen); // Must use a F32 comparison, because S32 would not work for negatives. - rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32), - rb->BitcastConvertType(rhs_criteria, xla::F32)), - lhs, rhs); + xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32), + xla::BitcastConvertType(rhs_criteria, xla::F32)), + lhs, rhs); } auto reduce = rb->BuildAndNoteError(); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; auto pooled_hi = - b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32), - init_value, reduce, ksize_, stride_, xla_padding); + xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); auto pooled_lo = - b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32), - init_value, reduce, ksize_, stride_, xla_padding); + xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); auto grads_hi = - b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen); - auto grads_lo = b->ShiftRightLogical( - b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen), + xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen); + auto grads_lo = xla::ShiftRightLogical( + xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen), sixteen); - auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add. + auto grads = xla::Add(grads_hi, grads_lo); // Want an unsigned add. xla::PrimitiveType element_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); - ctx->SetOutput(0, b->BitcastConvertType(grads, element_type)); + ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type)); } protected: @@ -694,5 +703,18 @@ REGISTER_XLA_OP(Name("MaxPoolGradGradV2") .CompileTimeConstInput("strides"), MaxPool2DGradGradOp); +class MaxPool3DGradGradOp : public MaxPoolGradGradOp { + public: + explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx) + : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT), + MaxPool3DGradGradOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 661cd5923e1023eaf89a6bc4f56fcc362c8bcfb6..e88221e4f400abeec59d85c1539d4f70bf515d3c 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -28,82 +31,115 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_)); - OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), - errors::InvalidArgument("num_bits is out of range: ", num_bits_, - " with signed_input_ ", signed_input_)); } void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); - // Comments taken from semantics description at - // https://www.tensorflow.org/versions/r1.0/api_docs/cc/class/tensorflow/ops/quantize-and-dequantize - // - // ... we find m such that - // - // m = max(abs(input_min), abs(input_max)) if range_given is true, - // m = max(abs(min_elem(input)), - // abs(max_elem(input))) otherwise. + xla::PrimitiveType xla_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(data_type, &xla_type)); + xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp input_min, input_max; + + // The implementation follows + // tensorflow/core/kernels/quantize_and_dequantize_op.h closely. + xla::XlaOp min_range, max_range; if (range_given_) { - double input_min_value, input_max_value; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value)); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &input_max_value)); - input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value); - input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value); + min_range = ctx->Input(1); + max_range = ctx->Input(2); } else { const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type); const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type); - input_min = - b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); - input_max = - b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); + min_range = ReduceAll(input, xla::MaxValue(b, xla_type), *fmin); + max_range = ReduceAll(input, xla::MinValue(b, xla_type), *fmax); } - xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max)); - - // Next, we choose our fixed-point quantization buckets, [min_fixed, - // max_fixed]. If signed_input is true, this is - // - // [min_fixed, max_fixed ] = [-((1 << (num_bits - 1)) - 1), - // (1 << (num_bits - 1)) - 1]. - // - // Otherwise, if signed_input is false, the fixed-point range is - // - // [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]. - int64 min_fixed, max_fixed; + + xla::XlaOp num_bits; + if (num_bits_ < 0) { + OP_REQUIRES( + ctx, ctx->num_inputs() == 4, + errors::Internal("Expected 4 inputs to QuantizeAndDequantize")); + num_bits = ctx->Input(3); + } else { + num_bits = xla::ConstantR0(b, num_bits_); + } + + const xla::XlaOp zero = XlaHelpers::Zero(b, data_type); + const xla::XlaOp one = XlaHelpers::One(b, data_type); + const xla::XlaOp two = XlaHelpers::FloatLiteral(b, data_type, 2.0); + const xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5); + + // Calculate the range for the simulated integer quantization: + // e.g. [-128,127] for signed = true, num_bits = 8, + // or [0, 255] for signed = false, num_bits = 8. + // We do this in floating point for hardware that does not have 64-bit + // integer support. + xla::XlaOp min_quantized, max_quantized; if (signed_input_) { - min_fixed = -((1LL << (num_bits_ - 1)) - 1); - max_fixed = (1LL << (num_bits_ - 1)) - 1; + min_quantized = + -Pow(two, ConvertElementType(num_bits - xla::ConstantR0(b, 1), + xla_type)); + max_quantized = + Pow(two, ConvertElementType(num_bits - xla::ConstantR0(b, 1), + xla_type)) - + one; } else { - min_fixed = 0; - max_fixed = (1LL << num_bits_) - 1; + min_quantized = zero; + max_quantized = Pow(two, ConvertElementType(num_bits, xla_type)) - one; } - // From this we compute our scaling factor, s: - // - // s = (max_fixed - min_fixed) / (2 * m). - xla::XlaOp s = - b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed), - b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m)); + // Determine the maximum scaling factor that would scale + // [min_range, max_range] to not exceed [min_quantized, max_quantized], + // while keeping 0 unchanged. + xla::XlaOp scale_from_min_side = + Select(Gt(min_quantized * min_range, zero), min_quantized / min_range, + xla::MaxFiniteValue(b, xla_type)); + xla::XlaOp scale_from_max_side = + Select(Gt(max_quantized * max_range, zero), max_quantized / max_range, + xla::MaxFiniteValue(b, xla_type)); - // Now we can quantize and dequantize the elements of our tensor. An element - // e is transformed into e': - // - // e' = (e * s).round_to_nearest() / s. - xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s); + // Note: Avoids changing the side of the range that determines scale. + xla::XlaOp cond = Lt(scale_from_min_side, scale_from_max_side); + xla::XlaOp scale = Select(cond, scale_from_min_side, scale_from_max_side); + xla::XlaOp inverse_scale = + Select(cond, min_range / min_quantized, max_range / max_quantized); + min_range = Select(cond, min_range, min_quantized * inverse_scale); + max_range = Select(cond, max_quantized * inverse_scale, max_range); + if (range_given_) { + // Note: The clamping here is to avoid overflow in the quantized type. + // The semantics of the op does not guarantee to clamp to the specified + // min_range and max_range - because we may have changed either min_range + // or max_range. + // No need to clamp to min_range and max_range if range_given_ == false as + // in that case they were measured from the tensor. + input = Clamp(min_range, input, max_range); + } + xla::XlaOp result = + Floor((input - min_range) * scale + half) * inverse_scale + min_range; ctx->SetOutput(0, result); } - int64 num_bits_; + protected: + int64 num_bits_ = -1; bool signed_input_; bool range_given_; }; -REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeOp); +class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { + public: + explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx) + : QuantizeAndDequantizeOp(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_)); + OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), + errors::InvalidArgument("num_bits is out of range: ", num_bits_, + " with signed_input_ ", signed_input_)); + } +}; + +REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeV2Op); +REGISTER_XLA_OP(Name("QuantizeAndDequantizeV3"), QuantizeAndDequantizeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 5f5bd586376ab368e443671ac8a5de23a5fd604b..9a0a7f9b9004f210adac44ed8b6e32cff131d23b 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,11 +17,17 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. +#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/random.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -42,8 +48,8 @@ class RandomUniformOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype), - XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -55,6 +61,77 @@ class RandomUniformOp : public XlaOpKernel { REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), RandomUniformOp); +class RandomShuffleOp : public XlaOpKernel { + public: + explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + const int64 n = input_shape.dim_size(0); + int64 num_elements = 1; + for (tensorflow::TensorShapeDim dimension : input_shape) { + num_elements *= dimension.size; + } + if (num_elements <= 1 || n <= 1) { + // No shuffling is required, so copy input directly to output + ctx->SetOutput(0, input); + } else { + // Generate the random swaps for the indices. + auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); + auto swaps = + xla::RngUniform(xla::ConstantR0(builder, 0), + xla::ConstantR0(builder, n), swaps_shape); + + // Generate range(n) as the initial value for the indices to be swapped. + xla::XlaOp indices = xla::Iota(builder, xla::S32, n); + + // Swap the indices at i and swaps[i]. + auto swap_body_fn = [&](xla::XlaOp i, + gtl::ArraySlice loop_vars, + xla::XlaBuilder* builder) + -> xla::StatusOr> { + auto swaps = loop_vars[0]; + auto indices = loop_vars[1]; + i = xla::Reshape(i, {1}); + // temp = indices[i] + auto temp = xla::DynamicSlice(indices, i, {1}); + // swap_index = swaps[i] + auto swap_index = xla::DynamicSlice(swaps, i, {1}); + // swap_value = indices[swaps[i]] + auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); + // indices[i] = indices[swaps[i]] + indices = xla::DynamicUpdateSlice(indices, swap_value, i); + // indices[swaps[i]] = temp + indices = xla::DynamicUpdateSlice(indices, temp, swap_index); + return std::vector{swaps, indices}; + }; + // for i in range(n): + auto swap_loop_result = + XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) + .ValueOrDie(); + auto swapped_indices = swap_loop_result[1]; + + // Gather the data using the swapped indices as the shuffled order. + auto indices_tensor_shape = TensorShape({n}); + DataType type = ctx->expected_output_dtype(0); + xla::XlaOp gather; + OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, + indices_tensor_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, + DT_INT32, builder, &gather)); + ctx->SetOutput(0, gather); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp); +}; + +REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp); + class RandomUniformIntOp : public XlaOpKernel { public: explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -77,7 +154,7 @@ class RandomUniformIntOp : public XlaOpKernel { auto minval = ctx->Input(1); auto maxval = ctx->Input(2); - ctx->SetOutput(0, ctx->builder()->RngUniform(minval, maxval, xla_shape)); + ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape)); } private: @@ -103,8 +180,8 @@ class RandomStandardNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); // Normal distribution with a mean of 0 and a standard deviation of 1: - xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype), - XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = xla::RngNormal(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -127,62 +204,20 @@ class TruncatedNormalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::Shape xla_element_shape = - xla::ShapeUtil::MakeShape(xla_shape.element_type(), {}); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp mean = XlaHelpers::Zero(b, dtype); - xla::XlaOp stddev = XlaHelpers::One(b, dtype); - xla::XlaOp candidate = b->RngNormal(mean, stddev, xla_shape); - - auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) { - return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0); - }; - auto out_of_range_mask = [two_sd](xla::XlaOp candidate, - xla::XlaBuilder* b) { - xla::XlaOp too_large = b->Gt(candidate, two_sd(false, b)); - xla::XlaOp too_small = b->Lt(candidate, two_sd(true, b)); - return b->Or(too_large, too_small); - }; - - // The algorithm we're using is roughly: - // - // while (any(candidate < mean-2*sd || candidate > mean+2*sd)) { - // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd - // candidate = select(out_of_range_mask, rng_normal(), candidate) - // } - std::unique_ptr test_builder = - b->CreateSubBuilder("truncated_normal_test"); - { - auto* b = test_builder.get(); - xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate"); - out_of_range_mask(candidate, b); - OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status()); - } - std::unique_ptr body_builder = - b->CreateSubBuilder("truncated_normal_body"); - { - auto* b = body_builder.get(); - xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate"); - xla::XlaOp to_resample = out_of_range_mask(candidate, b); - xla::XlaOp mean = XlaHelpers::Zero(b, dtype); - xla::XlaOp stddev = XlaHelpers::One(b, dtype); - b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate); - } - - xla::StatusOr test_computation = test_builder->Build(); - OP_REQUIRES_OK(ctx, test_computation.status()); - xla::StatusOr body_computation = body_builder->Build(); - OP_REQUIRES_OK(ctx, body_computation.status()); - xla::XlaOp result = b->While(test_computation.ValueOrDie(), - body_computation.ValueOrDie(), candidate); - - ctx->SetOutput(0, result); + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + xla::XlaOp min_positive = + XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits::min()); + auto uniform = xla::RngUniform(min_positive, one, xla_shape); + ctx->SetOutput(0, TruncatedNormal(uniform)); } }; -REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"), +REGISTER_XLA_OP(Name("TruncatedNormal") + .CompileTimeConstInput("shape") + .TypeConstraint("dtype", DT_FLOAT), TruncatedNormalOp); } // anonymous namespace diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 08894489ac77bbbe4ddb067c06a6d031a537697d..76bd1e62aa1efd85d6ed489b9a6d22a2bacf2a8b 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -98,10 +99,10 @@ class ReduceWindowOp : public XlaOpKernel { { std::unique_ptr cb = builder->CreateSubBuilder("wrapper"); - auto x = cb->Parameter(0, scalar_shape, "x"); - auto y = cb->Parameter(1, scalar_shape, "y"); - auto outputs = cb->Call(*reducer.computation, {x, y}); - cb->GetTupleElement(outputs, 0); + auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x"); + auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y"); + auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y}); + xla::GetTupleElement(outputs, 0); xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(context, result.status()); wrapper = std::move(result.ValueOrDie()); @@ -112,7 +113,7 @@ class ReduceWindowOp : public XlaOpKernel { padding[i] = {padding_low_[i], padding_high_[i]}; } - xla::XlaOp output = builder->ReduceWindowWithGeneralPadding( + xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), wrapper, window_dimensions_, window_strides_, padding); context->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 0f425637795e9633a8e36f921000ee2f5e25813a..46fae59ad4fa30b57946671518251a7e53ac4c8c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -19,6 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -31,11 +33,11 @@ class SumOp : public XlaReductionOp { : XlaReductionOp(ctx, XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::Zero(builder, reduction_type_); + return xla::Zero(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Add(scalar_lhs, scalar_rhs); + xla::Add(scalar_lhs, scalar_rhs); } }; @@ -48,12 +50,12 @@ class ProdOp : public XlaReductionOp { XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::One(builder, reduction_type_); + return xla::One(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Mul(scalar_lhs, scalar_rhs); + xla::Mul(scalar_lhs, scalar_rhs); } }; @@ -66,12 +68,12 @@ class MinOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::MaxValue(builder, reduction_type_); + return xla::MaxValue(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Min(scalar_lhs, scalar_rhs); + xla::Min(scalar_lhs, scalar_rhs); } }; @@ -83,12 +85,12 @@ class MaxOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::MinValue(builder, reduction_type_); + return xla::MinValue(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Max(scalar_lhs, scalar_rhs); + xla::Max(scalar_lhs, scalar_rhs); } }; @@ -101,11 +103,11 @@ class MeanOp : public XlaReductionOp { XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::Zero(builder, reduction_type_); + return xla::Zero(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Add(scalar_lhs, scalar_rhs); + xla::Add(scalar_lhs, scalar_rhs); } xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, @@ -113,7 +115,7 @@ class MeanOp : public XlaReductionOp { int64 num_elements_reduced) override { auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), num_elements_reduced); - return builder->Div(reduce_output, divisor); + return reduce_output / divisor; } }; @@ -126,12 +128,12 @@ class AllOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return builder->ConstantR0(true); + return xla::ConstantR0(builder, true); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->And(scalar_lhs, scalar_rhs); + xla::And(scalar_lhs, scalar_rhs); } }; @@ -143,12 +145,12 @@ class AnyOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return builder->ConstantR0(false); + return xla::ConstantR0(builder, false); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Or(scalar_lhs, scalar_rhs); + xla::Or(scalar_lhs, scalar_rhs); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 2ecfb854a1c8625524d4f1199af3927edd204926..8333f9b288e27efe9497306f031980c9eec7c99c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -64,6 +64,7 @@ class XlaReductionOp : public XlaOpKernel { protected: DataType reduction_type_; + xla::PrimitiveType xla_reduction_type_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 4fd5bfd03999a7f8b7bb081cc4b03aa1434d4c3d..909783ecb3c2a866136e1a09767144c91c46525c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -31,6 +32,8 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } // Unless BuildFinalizer is overridden the reduction has no @@ -56,9 +59,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { // Evaluate the constant, reshaping to a 1-vector if it is a scalar. xla::Literal axes_literal; - OP_REQUIRES_OK(ctx, - ctx->ConstantInputReshaped( - 1, {axes_tensor_shape.num_elements()}, &axes_literal)); + OP_REQUIRES_OK( + ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()}, + &axes_literal)); VLOG(1) << "data shape: " << data_shape.DebugString(); VLOG(1) << "axes : " << axes_literal.ToString(); @@ -101,20 +104,20 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); - auto data = b->ConvertElementType(ctx->Input(0), type); + auto data = xla::ConvertElementType(ctx->Input(0), type); // Call virtual method to get the initial value. - auto initial = b->ConvertElementType(InitialValue(b), type); + auto initial = xla::ConvertElementType(InitialValue(b), type); // Make two scalar parameters of the desired type for the lambda. - auto rx = r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); - auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); + auto rx = xla::Parameter(&r, 0, xla::ShapeUtil::MakeShape(type, {}), "x"); + auto ry = xla::Parameter(&r, 1, xla::ShapeUtil::MakeShape(type, {}), "y"); // Call virtual method to build the reduction lambda. BuildReducer(&r, rx, ry); xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); - auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes); + auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); - auto result = keep_dims_ ? b->Reshape(finalized, final_shape) : finalized; + auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index ba7d484d53d7258edaa5bc42fa116cf16e94835b..a4ba6c748a73f161ea252e2adf4050eb5dda7df5 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -34,7 +34,7 @@ class ReluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); - ctx->SetOutput(0, builder->Max(zero, ctx->Input(0))); + ctx->SetOutput(0, xla::Max(zero, ctx->Input(0))); } }; @@ -46,7 +46,7 @@ class Relu6Op : public XlaOpKernel { xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6); - ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six)); + ctx->SetOutput(0, xla::Clamp(zero, ctx->Input(0), six)); } }; @@ -59,9 +59,9 @@ class ReluGradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = - b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto pred = b->Gt(ctx->Input(1), zero); - ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero)); + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto pred = xla::Gt(ctx->Input(1), zero); + ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero)); } }; @@ -74,12 +74,12 @@ class Relu6GradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = - b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto six = b->Broadcast( + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto six = xla::Broadcast( XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); - auto out = - b->Select(b->And(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), - ctx->Input(0), zero); + auto out = xla::Select( + xla::And(xla::Lt(ctx->Input(1), six), xla::Gt(ctx->Input(1), zero)), + ctx->Input(0), zero); ctx->SetOutput(0, out); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index af4d64b159c09ed7e01017f25a2b23e58542dc3c..e0ca8dd8e27914ad60d0b97e8ac5f0b91a4fd9a6 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -90,8 +91,7 @@ class ReshapeOp : public XlaOpKernel { VLOG(1) << "Reshape " << input_shape.DebugString() << " " << shape.DebugString(); - ctx->SetOutput(0, - ctx->builder()->Reshape(ctx->Input(0), shape.dim_sizes())); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes())); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index a711278638444be01fb865561957702368b75114..5be70a4ded31a988cb77cdabe3fc8a041bc3ad16 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -62,15 +63,24 @@ class RetvalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { TensorShape shape = ctx->InputShape(0); - TensorShape representation_shape = - tc.is_entry_computation() - ? tc.RepresentationShape(shape, ctx->input_type(0)) - : shape; + ctx->SetStatus(is_constant.status()); + TensorShape representation_shape; + if (tc.is_entry_computation()) { + xla::StatusOr shape_or_status = + tc.RepresentationShape(shape, ctx->input_type(0)); + if (!shape_or_status.ok()) { + ctx->SetStatus(shape_or_status.status()); + return; + } else { + representation_shape = shape_or_status.ValueOrDie(); + } + } else { + representation_shape = shape; + } xla::XlaOp output = input; if (tc.is_entry_computation()) { - output = - ctx->builder()->Reshape(input, representation_shape.dim_sizes()); + output = xla::Reshape(input, representation_shape.dim_sizes()); } else { // The core from which a return value is returned depends on the // device assignment of the input to the retval. Since we can't change @@ -78,8 +88,8 @@ class RetvalOp : public XlaOpKernel { // introduce an operator here, even if the shape does not change. // TODO(b/76097077): propagate device assignments onto arguments and // return values of functions, and then reshape unconditionally. - output = ctx->builder()->GetTupleElement( - ctx->builder()->Tuple({output}), 0); + output = + xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0); } tc.AddRetval(index_, dtype_, shape, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 2872a3c4d49d0d269aa3d216887a5c32cd51f1c3..037c422258555289711b8754f2277d077d0cd6a7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -62,7 +63,7 @@ class ReverseOp : public XlaOpKernel { } } - ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), dimensions)); + ctx->SetOutput(0, xla::Rev(ctx->Input(0), dimensions)); } }; @@ -100,7 +101,7 @@ class ReverseV2Op : public XlaOpKernel { x_shape.dims(), ").")); } - ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), axes)); + ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 0ed4c4707df71cf5f56ccfe0af506916f04bcdb5..c810456f94322acfccae18d78efa861eede4648c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -85,89 +87,96 @@ class ReverseSequenceOp : public XlaOpKernel { auto condition_builder = builder->CreateSubBuilder("reverse_sequence_condition"); { - auto param = condition_builder->Parameter(0, tuple_shape, "param"); - auto i = condition_builder->GetTupleElement(param, 0); - condition_builder->Lt( - i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type, - batch_size)); + auto param = + xla::Parameter(condition_builder.get(), 0, tuple_shape, "param"); + auto i = xla::GetTupleElement(param, 0); + xla::Lt(i, XlaHelpers::IntegerLiteral(condition_builder.get(), + seq_lens_type, batch_size)); } auto condition = condition_builder->Build(); OP_REQUIRES_OK(context, condition.status()); auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); { - auto param = body_builder->Parameter(0, tuple_shape, "param"); - auto i = body_builder->GetTupleElement(param, 0); - auto seq_lens = body_builder->GetTupleElement(param, 1); - auto output = body_builder->GetTupleElement(param, 2); + auto param = xla::Parameter(body_builder.get(), 0, tuple_shape, "param"); + auto i = xla::GetTupleElement(param, 0); + auto seq_lens = xla::GetTupleElement(param, 1); + auto output = xla::GetTupleElement(param, 2); // seq_len is the sequence length of the current batch element (rank 1) - auto seq_len = body_builder->DynamicSlice( - seq_lens, body_builder->Reshape(i, {1}), {1}); + auto seq_len = xla::DynamicSlice(seq_lens, xla::Reshape(i, {1}), {1}); // Indices is the offset of the batch element in the input. - auto indices = body_builder->Broadcast( - XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {input_shape.dims()}); - indices = body_builder->DynamicUpdateSlice( - indices, body_builder->Reshape(i, {1}), - body_builder->Reshape( - XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, - batch_dim_), - {1})); - - // slice_indices is the offset of the start of the reversed sequence in - // the input. - auto slice_indices = body_builder->DynamicUpdateSlice( - indices, - body_builder->Sub(XlaHelpers::IntegerLiteral( - body_builder.get(), seq_lens_type, max_seq_len), - seq_len), - body_builder->Reshape( - XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, - seq_dim_), - {1})); - - // Slice out the reversed sequence. The slice will overflow the end of the - // sequence, and the contents of the overflow are implementation-defined. - // However, we will mask off these elements and replace them with elements - // from the original input so their values do not matter. + auto batch_element_indices = + xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {input_shape.dims()}); + batch_element_indices = xla::DynamicUpdateSlice( + batch_element_indices, xla::Reshape(i, {1}), + xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), + seq_lens_type, batch_dim_), + {1})); + + // Slice out the current batch element and pad it out in the sequence + // dimension. TensorShape slice_shape = input_shape; slice_shape.set_dim(batch_dim_, 1); - auto slice = body_builder->DynamicSlice(output, slice_indices, - slice_shape.dim_sizes()); + slice_shape.set_dim(seq_dim_, max_seq_len); + auto slice = xla::DynamicSlice(output, batch_element_indices, + slice_shape.dim_sizes()); + auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); + padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( + slice_shape.dim_size(seq_dim_)); + slice = xla::Pad(slice, XlaHelpers::Zero(body_builder.get(), input_type), + padding_config); + + // Now slice out the reversed sequence from its actual start. + // sequence_start_indices is the offset of the start of the reversed + // sequence in the input. The slice will go into the padding, however, we + // will mask off these elements and replace them with elements from the + // original input so their values do not matter. + auto sequence_start_indices = + xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {slice_shape.dims()}); + sequence_start_indices = xla::DynamicUpdateSlice( + sequence_start_indices, + xla::Sub(XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + max_seq_len), + seq_len), + xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), + seq_lens_type, seq_dim_), + {1})); + slice = xla::DynamicSlice(slice, sequence_start_indices, + slice_shape.dim_sizes()); // Shift the reversed sequence to the left. - output = body_builder->DynamicUpdateSlice(output, slice, indices); + output = xla::DynamicUpdateSlice(output, slice, batch_element_indices); - body_builder->Tuple( - {body_builder->Add( - i, XlaHelpers::One(body_builder.get(), seq_lens_type)), + xla::Tuple( + body_builder.get(), + {xla::Add(i, XlaHelpers::One(body_builder.get(), seq_lens_type)), seq_lens, output}); } auto body = body_builder->Build(); OP_REQUIRES_OK(context, body.status()); - auto loop_output = builder->While( + auto loop_output = xla::While( condition.ValueOrDie(), body.ValueOrDie(), - builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens, - builder->Rev(input, {seq_dim_})})); - auto output = builder->GetTupleElement(loop_output, 2); + xla::Tuple(builder, {XlaHelpers::Zero(builder, seq_lens_type), seq_lens, + xla::Rev(input, {seq_dim_})})); + auto output = xla::GetTupleElement(loop_output, 2); // Mask out elements after the sequence length. - xla::XlaOp iota; - OP_REQUIRES_OK( - context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); + xla::XlaOp iota = + xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len); std::vector dims(input_shape.dims(), 1); dims[batch_dim_] = batch_size; - auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_}); + auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_}); // Broadcast the mask up to the input shape. - mask = - builder->Or(mask, builder->Broadcast(builder->ConstantR0(false), - input_shape.dim_sizes())); + mask = xla::Or(mask, xla::Broadcast(xla::ConstantR0(builder, false), + input_shape.dim_sizes())); - output = builder->Select(mask, output, input); + output = xla::Select(mask, output, input); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 1819fb543317eed15b2fe0518d74aba5c564697d..76924c6a01a44e7a723b8c8895e8decbdd466c79 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -100,7 +101,7 @@ class ScanOp : public XlaOpKernel { init = XlaHelpers::One(builder, dtype); reducer = ctx->GetOrCreateMul(dtype); } - auto output = builder->ReduceWindowWithGeneralPadding( + auto output = xla::ReduceWindowWithGeneralPadding( XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, *reducer, window_dims, window_strides, padding); output = @@ -110,12 +111,12 @@ class ScanOp : public XlaOpKernel { // of all the input elements. Slice off this extra "last" element. if (exclusive_) { if (reverse_) { - output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1, - 1, axis); + output = + xla::SliceInDim(output, 1, input_shape.dim_size(axis) + 1, 1, axis); } else { output = - builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); + xla::SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); } } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index f2c63b4f9083ad3c7dd7cf318dc22def1e99fa9f..14709bb6cbce4b3ae0f7ff859b0fa622c6eda293 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -103,8 +104,8 @@ class ScatterNdOp : public XlaOpKernel { updates_shape)); xla::XlaBuilder* builder = context->builder(); - auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - buffer_shape.dim_sizes()); + auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype), + buffer_shape.dim_sizes()); auto indices = context->Input(0); auto updates = context->Input(1); auto result = diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 664078ca16c6d5d4b57c4a8c661ad0848f30dd7d..e2ac7da2c2630725efe3dbcc51c3f3d30e7aca2c 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -14,20 +14,30 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { -class UnsortedSegmentSum : public XlaOpKernel { +class UnsortedSegmentReduce : public XlaOpKernel { public: - explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + DataType dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &type_)); } + // The initial value to initialize elements of the output to. + virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; + + // A function to combine two scalars with the same index (e.g., sum). + virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0; + void Compile(XlaOpKernelContext* ctx) override { // output = unsorted_segment_sum(data, indices, num_segments) // Compute a tensor such that: @@ -50,28 +60,28 @@ class UnsortedSegmentSum : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments)); OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(), - errors::InvalidArgument( - "UnsortedSegmentSum requires that indices' rank be" - " less than or equal to data's rank.")); + errors::InvalidArgument(type_string(), + " requires that indices' rank be" + " less than or equal to data's rank.")); // Validate that indices.shape is a prefix of data.shape. for (int d = 0; d < indices_shape.dims(); ++d) { - OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), - errors::InvalidArgument( - "UnsortedSegmentSum requires indices shape to be prefix" - " of data_shape, but dimension ", - d, " differs ", data_shape.dim_size(d), " vs. ", - indices_shape.dim_size(d))); + OP_REQUIRES( + ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), + errors::InvalidArgument(type_string(), + " requires indices shape to be prefix" + " of data_shape, but dimension ", + d, " differs ", data_shape.dim_size(d), + " vs. ", indices_shape.dim_size(d))); } xla::XlaBuilder* builder = ctx->builder(); TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); - auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), - buffer_shape.dim_sizes()); + auto buffer = + xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); - auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { - return builder->Add(a, b); - }; + auto combiner = [this](xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) { return Combine(a, b); }; auto result = XlaScatter(buffer, /*updates=*/data, indices, /*indices_are_vectors=*/false, combiner, builder); @@ -79,13 +89,73 @@ class UnsortedSegmentSum : public XlaOpKernel { ctx->SetOutput(0, result.ValueOrDie()); } - private: - DataType dtype_; + protected: + xla::PrimitiveType type_; +}; + +class UnsortedSegmentSum : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentSum(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return xla::Zero(builder, type_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; }; }; REGISTER_XLA_OP( Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"), UnsortedSegmentSum); +class UnsortedSegmentProd : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentProd(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return xla::One(builder, type_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentProd").CompileTimeConstInput("num_segments"), + UnsortedSegmentProd); + +class UnsortedSegmentMin : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentMin(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return xla::MaxFiniteValue(builder, type_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { + return xla::Min(a, b); + }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentMin").CompileTimeConstInput("num_segments"), + UnsortedSegmentMin); + +class UnsortedSegmentMax : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentMax(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return xla::MinFiniteValue(builder, type_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { + return xla::Max(a, b); + }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentMax").CompileTimeConstInput("num_segments"), + UnsortedSegmentMax); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index f9f48164d63492b057d4950abfc2ca6153e44870..5c010c9df23ba6c7732d87fa014879d93ff586ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -40,8 +41,6 @@ class SelectOp : public XlaOpKernel { "'then' and 'else' must have the same size. but received: ", then_shape.DebugString(), " vs. ", else_shape.DebugString())); - xla::XlaBuilder* builder = ctx->builder(); - auto cond_handle = ctx->Input(0); auto then_handle = ctx->Input(1); auto else_handle = ctx->Input(2); @@ -69,14 +68,14 @@ class SelectOp : public XlaOpKernel { const auto dim_sizes = then_shape.dim_sizes(); gtl::ArraySlice bdims = dim_sizes; bdims.pop_front(); - cond_handle = builder->Broadcast(cond_handle, bdims); + cond_handle = xla::Broadcast(cond_handle, bdims); std::vector dim_order(then_shape.dims()); dim_order[0] = then_shape.dims() - 1; std::iota(dim_order.begin() + 1, dim_order.end(), 0); - cond_handle = builder->Transpose(cond_handle, dim_order); + cond_handle = xla::Transpose(cond_handle, dim_order); } - ctx->SetOutput(0, builder->Select(cond_handle, then_handle, else_handle)); + ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 9ce01d0d44509bbcbea18afdb4210a675834bb6d..6281d6c6533f7f49a269f5c7e52226ba0f1d29f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -45,7 +45,7 @@ void SendOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); - ctx->builder()->Send(ctx->Input(0), channel); + xla::Send(ctx->Input(0), channel); } REGISTER_XLA_OP(Name("XlaSend"), SendOp); @@ -76,7 +76,7 @@ void RecvOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); - ctx->SetOutput(0, ctx->builder()->Recv(shape_, channel)); + ctx->SetOutput(0, xla::Recv(ctx->builder(), shape_, channel)); } REGISTER_XLA_OP(Name("XlaRecv"), RecvOp); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 2c31f8d90891924f6f86a54ccf548de4df87f3bd..bc3d0bf5dfe9e5af8e50a25e27db7148e05e0cfd 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -55,9 +55,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { // The type-specific part of the implementation of Range. template -Status CreateRangeTensor(const xla::Literal& start_literal, - const xla::Literal& limit_literal, - const xla::Literal& delta_literal, Tensor* output) { +Status CreateRangeTensor(const xla::LiteralSlice& start_literal, + const xla::LiteralSlice& limit_literal, + const xla::LiteralSlice& delta_literal, + Tensor* output) { T start = start_literal.Get({}); T limit = limit_literal.Get({}); T delta = delta_literal.Get({}); @@ -67,13 +68,13 @@ Status CreateRangeTensor(const xla::Literal& start_literal, } if (delta > 0) { if (start > limit) { - return errors::InvalidArgument("Requires start <= limit when delta > 0: ", - start, "/", limit); + return errors::InvalidArgument( + "Requires start <= limit when delta > 0: ", start, "/", limit); } } else { if (start < limit) { - return errors::InvalidArgument("Requires start >= limit when delta < 0: ", - start, "/", limit); + return errors::InvalidArgument( + "Requires start >= limit when delta < 0: ", start, "/", limit); } } int64 size = diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 05354bca5bb089703fdcceb6f44648bbb98d004b..5798823cd54c66dd179e3611c0041f7c5a1ff2b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -43,7 +44,7 @@ class ShapeOp : public XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape"), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -65,7 +66,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -81,7 +82,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank"), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -100,7 +101,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size"), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: @@ -147,7 +148,7 @@ class ExpandDimsOp : public XlaOpKernel { dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); - ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); } }; REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp); @@ -189,10 +190,9 @@ class SqueezeOp : public XlaOpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument("Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", - existing_dim)); + errors::InvalidArgument( + "Tried to explicitly squeeze dimension ", i, + " but dimension was not 1: ", existing_dim)); } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); @@ -205,7 +205,7 @@ class SqueezeOp : public XlaOpKernel { } } - ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); } private: @@ -222,7 +222,7 @@ class ZerosLikeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes())); } }; @@ -236,7 +236,7 @@ class OnesLikeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); auto one = XlaHelpers::One(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes())); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index be1e97bf26fa4cde1b741c8d0b843a85ce33a59c..1864584adee357ce35a3e8a38a4e3c58c356bfca 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -92,8 +93,7 @@ class SliceOp : public XlaOpKernel { limits.push_back(begin[i] + size[i]); } std::vector strides(begin.size(), 1); - ctx->SetOutput( - 0, ctx->builder()->Slice(ctx->Input(0), begin, limits, strides)); + ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides)); } else { // `begin` is not a compile-time constant. for (int i = 0; i < input_dims; ++i) { @@ -106,8 +106,7 @@ class SliceOp : public XlaOpKernel { input_shape.dim_size(i), "], but ", "got ", size[i])); } - ctx->SetOutput( - 0, ctx->builder()->DynamicSlice(ctx->Input(0), ctx->Input(1), size)); + ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), ctx->Input(1), size)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index bbf5ee8b12186a582666121b1df5d8b7d881863e..a71fbcd901e8919949db5873675a7e3e785bdf4e 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,9 +15,12 @@ limitations under the License. // XLA-specific Ops for softmax. +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -41,6 +44,7 @@ class SoftmaxOp : public XlaOpKernel { const int kClassDim = 1; const DataType type = input_type(0); + const xla::PrimitiveType xla_type = ctx->input_xla_type(0); auto logits = ctx->Input(0); xla::XlaBuilder* const b = ctx->builder(); @@ -48,24 +52,27 @@ class SoftmaxOp : public XlaOpKernel { // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = - b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. - auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); - auto exp_shifted = b->Exp(shifted_logits); + auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); + auto exp_shifted = xla::Exp(shifted_logits); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + xla::PrimitiveType xla_accumulation_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type, + &xla_accumulation_type)); auto converted = - XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type); + xla::ConvertElementType(exp_shifted, xla_accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + xla::Reduce(converted, xla::Zero(b, xla_accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto sum = XlaHelpers::ConvertElementType(b, reduce, type); auto softmax = log_ // softmax = shifted_logits - log(sum(exp(shifted_logits))) - ? b->Sub(shifted_logits, b->Log(sum), {kBatchDim}) + ? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim}) // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - : b->Div(exp_shifted, sum, {kBatchDim}); + : xla::Div(exp_shifted, sum, {kBatchDim}); ctx->SetOutput(0, softmax); } @@ -77,8 +84,8 @@ REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp); REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp); std::pair CrossEntropyWithLogits( - XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits, - const xla::XlaOp& labels) { + XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type, + xla::XlaOp logits, xla::XlaOp labels) { const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); const int kBatchDim = 0; @@ -87,43 +94,44 @@ std::pair CrossEntropyWithLogits( xla::XlaBuilder* b = ctx->builder(); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = - b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. // Broadcasts along the batch dimension. - auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); + auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); // exp(logits - max_logits) - auto exp_shifted_logits = b->Exp(shifted_logits); + auto exp_shifted_logits = xla::Exp(shifted_logits); // sum_{class} (exp(logits - max_logits)) const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); - auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto reduce = + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); // log(sum(exp(logits - max_logits))) - auto log_sum_exp = b->Log(sum_exp); + auto log_sum_exp = xla::Log(sum_exp); // sum(-labels * // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) // along classes // (The subtraction broadcasts along the batch dimension.) - auto sub = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); - auto mul = b->Mul(b->Neg(labels), sub); + auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim}); + auto mul = xla::Mul(xla::Neg(labels), sub); auto sum = - b->Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + xla::Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto loss = XlaHelpers::ConvertElementType(b, sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) // (where the division broadcasts along the batch dimension) xla::XlaOp backprop = - b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); + xla::Sub(xla::Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); return {loss, backprop}; } @@ -146,12 +154,13 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel { // check that "labels" is a matrix too. const DataType type = input_type(0); + const xla::PrimitiveType xla_type = ctx->input_xla_type(0); auto logits = ctx->Input(0); auto labels = ctx->Input(1); xla::XlaOp loss, backprop; std::tie(loss, backprop) = - CrossEntropyWithLogits(ctx, type, logits, labels); + CrossEntropyWithLogits(ctx, type, xla_type, logits, labels); ctx->SetOutput(0, loss); ctx->SetOutput(1, backprop); } @@ -187,8 +196,9 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { int64 batch_size = logits_shape.dim_size(0); int64 depth = logits_shape.dim_size(1); - DataType logits_type = input_type(0); - DataType indices_type = input_type(1); + const DataType logits_type = input_type(0); + const xla::PrimitiveType xla_logits_type = ctx->input_xla_type(0); + const DataType indices_type = input_type(1); xla::XlaOp indices = ctx->Input(1); @@ -206,20 +216,18 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { // Builds a vector of {batch_size} that is 0 if the index is in range, or // NaN otherwise; then add that vector to the labels to force out-of-range // values to NaNs. - xla::XlaOp nan_or_zero = builder->Select( - builder->And( - builder->Le(XlaHelpers::Zero(builder, indices_type), indices), - builder->Lt(indices, XlaHelpers::IntegerLiteral( - builder, indices_type, depth))), - builder->Broadcast(XlaHelpers::Zero(builder, logits_type), - {batch_size}), - builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN), - {batch_size})); - labels = builder->Add(labels, nan_or_zero, {0}); + xla::XlaOp nan_or_zero = xla::Select( + xla::And(xla::Le(XlaHelpers::Zero(builder, indices_type), indices), + xla::Lt(indices, XlaHelpers::IntegerLiteral( + builder, indices_type, depth))), + xla::Broadcast(XlaHelpers::Zero(builder, logits_type), {batch_size}), + xla::Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN), + {batch_size})); + labels = xla::Add(labels, nan_or_zero, {0}); xla::XlaOp loss, backprop; - std::tie(loss, backprop) = - CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels); + std::tie(loss, backprop) = CrossEntropyWithLogits( + ctx, logits_type, xla_logits_type, ctx->Input(0), labels); ctx->SetOutput(0, loss); ctx->SetOutput(1, backprop); } diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..faaf8964ff7c40d75a493b03e6b400632117cb45 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +namespace tensorflow { +namespace { + +class XlaSortOp : public XlaOpKernel { + public: + explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + context->SetOutput(0, xla::Sort(context->Input(0))); + } +}; + +REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index ec077924b5b5af4a573c86c8d9aeb8623bd7f801..8a8525efa186ed4aa02c494f7505f6245677e96e 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -73,7 +74,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, "The product of the block dimensions must be positive")); xla::XlaOp padded = - b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); + xla::Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); // 2. Reshape `padded` to `reshaped_padded` of shape: // @@ -100,7 +101,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_padded_shape.begin() + 1 + 2 * block_rank); - xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape); + xla::XlaOp reshaped_padded = xla::Reshape(padded, reshaped_padded_shape); // 3. Permute dimensions of `reshaped_padded` to produce // `permuted_reshaped_padded` of shape: @@ -120,7 +121,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); xla::XlaOp permuted_reshaped_padded = - b->Transpose(reshaped_padded, permutation); + xla::Transpose(reshaped_padded, permutation); // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -140,7 +141,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), output_shape.begin() + 1 + block_rank); - xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped_padded, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 4c5886ee2a0f63d609f79fc690f457d93e284e3e..47d282fe9ec664bbc424793e93f778ebb13c6877 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -50,7 +51,6 @@ class SpaceToDepthOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); @@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -145,7 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_, block_size_, // depth] - xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -155,7 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e831dc30a9d3c27ec3b1494e7d8a6de836ff2a11 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -0,0 +1,88 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +// Operator to convert sparse representations to dense. +class SparseToDenseOp : public XlaOpKernel { + public: + explicit SparseToDenseOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + // sparse_indices + const TensorShape indices_shape = context->InputShape(0); + OP_REQUIRES(context, indices_shape.dims() <= 2, + errors::InvalidArgument( + "sparse_indices should be a scalar, vector, or matrix, " + "got shape ", + indices_shape.DebugString())); + const int64 num_elems = + indices_shape.dims() > 0 ? indices_shape.dim_size(0) : 1; + const int64 num_dims = + indices_shape.dims() > 1 ? indices_shape.dim_size(1) : 1; + + // output_shape + TensorShape output_shape; + OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); + OP_REQUIRES(context, output_shape.dims() == num_dims, + errors::InvalidArgument( + "output_shape has incorrect number of elements: ", + output_shape.num_elements(), " should be: ", num_dims)); + + // sparse_values + const TensorShape sparse_values_shape = context->InputShape(2); + const int64 num_values = sparse_values_shape.num_elements(); + OP_REQUIRES( + context, + sparse_values_shape.dims() == 0 || + (sparse_values_shape.dims() == 1 && num_values == num_elems), + errors::InvalidArgument("sparse_values has incorrect shape ", + sparse_values_shape.DebugString(), + ", should be [] or [", num_elems, "]")); + + // default_value + const TensorShape default_value_shape = context->InputShape(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(default_value_shape), + errors::InvalidArgument("default_value should be a scalar.")); + + xla::XlaOp indices = context->Input(0); + xla::XlaOp sparse_values = context->Input(2); + xla::XlaOp default_value = context->Input(3); + + if (sparse_values_shape.dims() == 0 && num_elems != 1) { + sparse_values = Broadcast(sparse_values, {num_elems}); + } + xla::XlaBuilder* builder = context->builder(); + auto buffer = Broadcast(default_value, output_shape.dim_sizes()); + + auto result = XlaScatter(buffer, sparse_values, indices, + /*indices_are_vectors=*/num_dims > 1, + /*combiner=*/{}, builder); + context->SetOutput(0, builder->ReportErrorOrReturn(result)); + } +}; + +REGISTER_XLA_OP(Name("SparseToDense").CompileTimeConstInput("output_shape"), + SparseToDenseOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 8958b2e7701e62d802e37a895c14b662ecf9786a..ca74cf24507e1666070751a17fb940a3ad594695 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -98,7 +99,7 @@ class SplitOp : public XlaOpKernel { // Slice out the ith split from the split dimension. begin[split_dim] = i * slice_size; limits[split_dim] = (i + 1) * slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); + ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); } } }; @@ -134,7 +135,7 @@ class SplitVOp : public XlaOpKernel { errors::InvalidArgument( "Number of ways to split should be > 0, but got ", num_split)); - // check that sizes are correct + // Check that sizes are correct. int total_split_size = 0; int neg_one_dim = -1; std::vector split_sizes_vec(num_split, -1); @@ -148,7 +149,7 @@ class SplitVOp : public XlaOpKernel { " number of elements as the output. Got ", split_size_shape.dims(), "-D and ", split_size_shape.num_elements(), " elements")); - // get the dimension of this split + // Get the dimension of this split. xla::Literal split_size_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); @@ -199,7 +200,7 @@ class SplitVOp : public XlaOpKernel { // Slice out the ith split from the split dimension. limits[split_dim] = begin[split_dim] + slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); + ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); begin[split_dim] = limits[split_dim]; } } diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 0fb05a2be7b1034d6c2e864643b69647d622ede7..591e61b4c82836bc1995cd11c4c0314c9d854e50 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -144,24 +144,25 @@ class StackPushOp : public XlaOpKernel { // Initializes the Stack, if the element shape was not already known. OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape)); - xla::XlaOp ta = b->GetTupleElement(resource->value(), 0); - xla::XlaOp index = b->GetTupleElement(resource->value(), 1); + xla::XlaOp ta = xla::GetTupleElement(resource->value(), 0); + xla::XlaOp index = xla::GetTupleElement(resource->value(), 1); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = b->Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple( - {b->DynamicUpdateSlice(ta, update, start_indices), - b->Add(index, b->ConstantR0(1))}))); + OP_REQUIRES_OK(ctx, + resource->SetValue(xla::Tuple( + b, {xla::DynamicUpdateSlice(ta, update, start_indices), + xla::Add(index, xla::ConstantR0(b, 1))}))); ctx->SetOutput(0, value); } @@ -197,27 +198,27 @@ class StackPopOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape)); xla::XlaOp state = resource->value(); - xla::XlaOp ta = b->GetTupleElement(state, 0); - xla::XlaOp index = b->GetTupleElement(state, 1); + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = xla::GetTupleElement(state, 1); - index = b->Sub(index, b->ConstantR0(1)); - OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); + index = Sub(index, xla::ConstantR0(b, 1)); + OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); auto slice_shape = stack_shape.dim_sizes(); slice_shape[0] = 1LL; // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, b->Reshape(read, value_shape)); + ctx->SetOutput(0, xla::Reshape(read, value_shape)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index a99d4ddc7c4956f7144512a9bdf6f4c2eb0f944f..a6f5769e7b7b1e550b7908caa35289cf3030120f 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -15,11 +15,15 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -32,17 +36,9 @@ namespace { // Rotates a 32-bit integer 'v' left by 'distance' bits. xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v, int distance) { - return builder->Or( - builder->ShiftLeft(v, builder->ConstantR0(distance)), - builder->ShiftRightLogical(v, builder->ConstantR0(32 - distance))); -} - -// TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than -// building XOR out of other bitwise operators. -xla::XlaOp BitwiseXor(xla::XlaBuilder* builder, const xla::XlaOp& x, - const xla::XlaOp& y) { - return builder->Or(builder->And(x, builder->Not(y)), - builder->And(builder->Not(x), y)); + return xla::Or( + xla::ShiftLeft(v, xla::ConstantR0(builder, distance)), + xla::ShiftRightLogical(v, xla::ConstantR0(builder, 32 - distance))); } using ThreeFry2x32State = std::array; @@ -58,22 +54,22 @@ ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, std::array ks; // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. - ks[2] = builder->ConstantR0(0x1BD11BDA); + ks[2] = xla::ConstantR0(builder, 0x1BD11BDA); for (int i = 0; i < 2; ++i) { ks[i] = key[i]; x[i] = input[i]; - ks[2] = BitwiseXor(builder, ks[2], key[i]); + ks[2] = xla::Xor(ks[2], key[i]); } - x[0] = builder->Add(x[0], ks[0]); - x[1] = builder->Add(x[1], ks[1]); + x[0] = xla::Add(x[0], ks[0]); + x[1] = xla::Add(x[1], ks[1]); // Performs a single round of the Threefry2x32 algorithm, with a rotation // amount 'rotation'. auto round = [builder](ThreeFry2x32State v, int rotation) { - v[0] = builder->Add(v[0], v[1]); + v[0] = xla::Add(v[0], v[1]); v[1] = RotateLeftS32(builder, v[1], rotation); - v[1] = BitwiseXor(builder, v[0], v[1]); + v[1] = xla::Xor(v[0], v[1]); return v; }; @@ -83,36 +79,36 @@ ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[1]); - x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(1)); + x[0] = xla::Add(x[0], ks[1]); + x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0(builder, 1)); x = round(x, rotations[4]); x = round(x, rotations[5]); x = round(x, rotations[6]); x = round(x, rotations[7]); - x[0] = builder->Add(x[0], ks[2]); - x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(2)); + x[0] = xla::Add(x[0], ks[2]); + x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0(builder, 2)); x = round(x, rotations[0]); x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[0]); - x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0(3)); + x[0] = xla::Add(x[0], ks[0]); + x[1] = xla::Add(xla::Add(x[1], ks[1]), xla::ConstantR0(builder, 3)); x = round(x, rotations[4]); x = round(x, rotations[5]); x = round(x, rotations[6]); x = round(x, rotations[7]); - x[0] = builder->Add(x[0], ks[1]); - x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(4)); + x[0] = xla::Add(x[0], ks[1]); + x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0(builder, 4)); x = round(x, rotations[0]); x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[2]); - x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(5)); + x[0] = xla::Add(x[0], ks[2]); + x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0(builder, 5)); return x; } @@ -123,8 +119,8 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, const TensorShape& shape, double minval, double maxval) { // Split the seed into two 32-bit scalars to form a key. - auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); ThreeFry2x32State key = {seed0, seed1}; const int64 size = shape.num_elements(); @@ -133,81 +129,36 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, // Fill the generator inputs with unique counter values. ThreeFry2x32State inputs; - TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0])); - inputs[1] = builder->Add(inputs[0], builder->ConstantR0(half_size)); + inputs[0] = xla::Iota(builder, xla::S32, half_size); + inputs[1] = xla::Add(inputs[0], xla::ConstantR0(builder, half_size)); ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); if (size_is_odd) { - outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1}); + outputs[1] = xla::Slice(outputs[1], {0}, {half_size - 1}, {1}); } auto bits = - builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes()); + xla::Reshape(xla::ConcatInDim(builder, outputs, 0), shape.dim_sizes()); // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit // forces the random bits into the mantissa. constexpr int kFloatBits = 32; constexpr int kMantissaBits = 23; - bits = builder->Or( - builder->ShiftRightLogical( - bits, builder->ConstantR0(kFloatBits - kMantissaBits)), - builder->ConstantR0(bit_cast(1.0f))); - auto floats = builder->BitcastConvertType(bits, xla::F32); + bits = xla::Or( + xla::ShiftRightLogical( + bits, xla::ConstantR0(builder, kFloatBits - kMantissaBits)), + xla::ConstantR0(builder, bit_cast(1.0f))); + auto floats = xla::BitcastConvertType(bits, xla::F32); // We have a floating point number in the range [1.0, 2.0). // Subtract 1.0f to shift to the range [0.0, 1.0) - floats = builder->Sub(floats, builder->ConstantR0(1.0f)); + floats = xla::Sub(floats, xla::ConstantR0(builder, 1.0f)); // Multiply and add to shift to the range [minval, maxval). - floats = builder->Mul(floats, builder->ConstantR0(maxval - minval)); - floats = builder->Add(floats, builder->ConstantR0(minval)); + floats = xla::Mul(floats, xla::ConstantR0(builder, maxval - minval)); + floats = xla::Add(floats, xla::ConstantR0(builder, minval)); return floats; } -// Approximation for the inverse error function from -// Giles, M., "Approximating the erfinv function". -// The approximation has the form: -// w = -log((1 - x) * (1 + x)) -// if ( w < 5 ) { -// w = w - 2.5 -// p = sum_{i=1}^n lq[i]*w^i -// } else { -// w = sqrt(w) - 3 -// p = sum_{i=1}^n gq[i]*w^i -// } -// return p*x -xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x, - const TensorShape& shape) { - constexpr int kDegree = 9; - constexpr std::array w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - auto one = b->ConstantR0(1.0); - auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); - - auto lt = b->Lt(w, b->ConstantR0(5.0)); - auto coefficient = [&](int i) { - return b->Select( - lt, - b->Broadcast(b->ConstantR0(w_less_than_5_constants[i]), - shape.dim_sizes()), - b->Broadcast(b->ConstantR0(w_greater_than_5_constants[i]), - shape.dim_sizes())); - }; - w = b->Select(lt, b->Sub(w, b->ConstantR0(2.5f)), - b->Sub(b->SqrtF32(w), b->ConstantR0(3.0f))); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = b->Add(coefficient(i), b->Mul(p, w)); - } - return b->Mul(p, x); -} - } // namespace class StatelessRandomUniformOp : public XlaOpKernel { @@ -259,8 +210,8 @@ class StatelessRandomNormalOp : public XlaOpKernel { RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) - auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), - ErfInvF32(builder, uniform, shape)); + auto normal = + xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); ctx->SetOutput(0, normal); } @@ -275,4 +226,35 @@ REGISTER_XLA_OP(Name("StatelessRandomNormal") .TypeConstraint("Tseed", DT_INT32), StatelessRandomNormalOp); +class StatelessTruncatedNormalOp : public XlaOpKernel { + public: + explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape == TensorShape({2}), + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + xla::XlaOp seed = ctx->Input(1); + xla::XlaBuilder* b = ctx->builder(); + + auto uniform = + RandomUniform(b, seed, shape, std::numeric_limits::min(), 1.0); + ctx->SetOutput(0, TruncatedNormal(uniform)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp); +}; + +REGISTER_XLA_OP(Name("StatelessTruncatedNormal") + .CompileTimeConstInput("shape") + .TypeConstraint("dtype", DT_FLOAT) + .TypeConstraint("Tseed", DT_INT32), + StatelessTruncatedNormalOp); + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 55254c746e5ebaf6b468c24ab59b968bf0d6260b..c2165ccd86dfa1c119790beb20af0844fb1bbda8 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -92,12 +93,12 @@ class StridedSliceOp : public XlaOpKernel { xla::XlaOp slice = ctx->Input(0); if (!dimensions_to_reverse.empty()) { - slice = ctx->builder()->Rev(slice, dimensions_to_reverse); + slice = xla::Rev(slice, dimensions_to_reverse); } - slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides); + slice = xla::Slice(slice, slice_begin, slice_end, slice_strides); - slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); + slice = xla::Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); } @@ -171,7 +172,7 @@ class StridedSliceGradOp : public XlaOpKernel { xla::XlaOp grad = ctx->Input(4); // Undo any new/shrink axes. - grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes()); + grad = xla::Reshape(grad, processing_shape.dim_sizes()); // Pad the input gradients. gtl::InlinedVector dimensions_to_reverse; @@ -204,9 +205,9 @@ class StridedSliceGradOp : public XlaOpKernel { } } if (!dimensions_to_reverse.empty()) { - grad = ctx->builder()->Rev(grad, dimensions_to_reverse); + grad = xla::Rev(grad, dimensions_to_reverse); } - grad = ctx->builder()->Pad(grad, zero, padding_config); + grad = xla::Pad(grad, zero, padding_config); ctx->SetOutput(0, grad); } @@ -306,17 +307,17 @@ class StridedSliceAssignOp : public XlaOpKernel { } if (!dimensions_to_reverse.empty()) { - rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse); + rhs = xla::Rev(rhs, dimensions_to_reverse); } - rhs = ctx->builder()->Reshape(rhs, slice_dims); + rhs = xla::Reshape(rhs, slice_dims); if (lhs_shape.dims() == 0) { // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix // and remove this workaround. lhs = rhs; } else { - lhs = ctx->builder()->DynamicUpdateSlice( - lhs, rhs, ctx->builder()->ConstantR1(slice_begin)); + lhs = xla::DynamicUpdateSlice( + lhs, rhs, xla::ConstantR1(ctx->builder(), slice_begin)); } OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 9adee78a1fd1fb9a12afae83197425c328b5fe7e..2f650ce3052ee4502912891cd3f60cfaec8b1d7c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -123,10 +124,9 @@ xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, const gtl::ArraySlice& update_dims, const xla::XlaOp& start_indices) { - xla::XlaOp current = - builder->DynamicSlice(operand, start_indices, update_dims); - xla::XlaOp sum = builder->Add(current, update); - return builder->DynamicUpdateSlice(operand, sum, start_indices); + xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); + xla::XlaOp sum = xla::Add(current, update); + return xla::DynamicUpdateSlice(operand, sum, start_indices); } class TensorArrayOp : public XlaOpKernel { @@ -162,7 +162,7 @@ class TensorArrayOp : public XlaOpKernel { ta_shape.AddDim(size); ta_shape.AppendShape(shape); xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); - value = b->Broadcast(zero, ta_shape.dim_sizes()); + value = xla::Broadcast(zero, ta_shape.dim_sizes()); } XlaContext& xc = XlaContext::Get(ctx); @@ -215,12 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = b->Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); @@ -259,17 +259,17 @@ class TensorArrayReadOp : public XlaOpKernel { // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; - xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, b->Reshape(read, value_shape)); + ctx->SetOutput(0, xla::Reshape(read, value_shape)); } private: @@ -326,7 +326,7 @@ class TensorArrayGatherOp : public XlaOpKernel { for (auto i = 1; i < ta_shape.dims(); i++) { end[i] = ta_shape.dim_size(i); } - ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); + ctx->SetOutput(0, xla::Slice(ta, begin, end, strides)); return; } } @@ -391,7 +391,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } if (scatter_all_elements_in_order) { - ta = b->Add(ta, value); + ta = xla::Add(ta, value); } else { auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -407,13 +407,13 @@ class TensorArrayScatterOp : public XlaOpKernel { // Slice out part of the value. value_starts[0] = i; value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends, value_strides); + auto slice = xla::Slice(value, value_starts, value_ends, value_strides); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto index = xla::Slice(indices, {i}, {i + 1}, {1}); auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } } @@ -452,7 +452,7 @@ class TensorArrayConcatOp : public XlaOpKernel { auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); shape[0] *= ta_shape.dim_size(0); - ctx->SetOutput(0, b->Reshape(ta, shape)); + ctx->SetOutput(0, xla::Reshape(ta, shape)); Tensor lengths(DT_INT64, {ta_dims[0]}); auto lengths_vec = lengths.vec(); @@ -522,8 +522,8 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - OP_REQUIRES_OK(ctx, resource->SetValue(b->Add( - ta, b->Reshape(value, ta_shape.dim_sizes())))); + OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add( + ta, xla::Reshape(value, ta_shape.dim_sizes())))); ctx->SetOutput(0, flow); } diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index e91075196bd8414939888e22b5483ad637487af6..c9e56942625a009fb3660f413a845547192460d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -93,9 +94,9 @@ class TileOp : public XlaOpKernel { if (one_dimension_is_broadcasted_without_multiple) { // Create a constant Zero the size of the output shape to leverage binary // operation broadcast semantics. - auto broadcasted_zero = ctx->builder()->Broadcast( + auto broadcasted_zero = xla::Broadcast( XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape); - ctx->SetOutput(0, ctx->builder()->Add(broadcasted_zero, input)); + ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); return; } @@ -103,7 +104,7 @@ class TileOp : public XlaOpKernel { // dimension. This prepends the broadcasted dimensions, so an // input of shape [2,3,1] broadcast with multiples [5,4,3] will // end up with shape [5,4,3,2,3,1]. - auto broadcasted = ctx->builder()->Broadcast(input, multiples_array); + auto broadcasted = xla::Broadcast(input, multiples_array); // Now flatten and reshape. The broadcasted dimensions are // paired with the original dimensions so in the above example // we flatten [0,3,1,4,2,5] then reshape to [10,12,3]. @@ -112,8 +113,7 @@ class TileOp : public XlaOpKernel { flattened.push_back(i); flattened.push_back(i + output_shape.size()); } - xla::XlaOp output = - ctx->builder()->Reshape(broadcasted, flattened, output_shape); + xla::XlaOp output = xla::Reshape(broadcasted, flattened, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9962f1207d65edea5eba0083436fa380921bb4fd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class TopKOp : public XlaOpKernel { + public: + explicit TopKOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_)); + } + + void Compile(XlaOpKernelContext* context) override { + int64 k; + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(1, &k)); + OP_REQUIRES(context, k >= 0, + errors::InvalidArgument("Need k >= 0, got ", k)); + const TensorShape input_shape = context->InputShape(0); + OP_REQUIRES(context, input_shape.dims() >= 1, + errors::InvalidArgument("input must be >= 1-D, got shape ", + input_shape.DebugString())); + OP_REQUIRES( + context, input_shape.dim_size(input_shape.dims() - 1) >= k, + errors::InvalidArgument("input must have at least k columns. Had ", + input_shape.dim_size(input_shape.dims() - 1), + ", needed ", k)); + + OP_REQUIRES( + context, input_shape.dims() == 1, + errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ", + input_shape.DebugString())); + + xla::XlaBuilder* const b = context->builder(); + if (input_shape.dim_size(0) < k) { + k = input_shape.dim_size(0); + } + const xla::XlaOp input = context->Input(0); + xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0)); + xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32); + xla::XlaOp values = + xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), + /*start_indices=*/{0}, + /*limit_indices=*/{k}, + /*strides=*/{1})); + xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), + /*start_indices=*/{0}, + /*limit_indices=*/{k}, + /*strides=*/{1}); + context->SetOutput(0, values); + context->SetOutput(1, indices); + } + + private: + bool sorted_; +}; + +REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstInput("k").TypeConstraint( + "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}), + TopKOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 34caefa050c0d58f5f7bad557286b6ed64b996ad..bef6161e8547dcc84d20b29aa74d6ef50045970b 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -31,7 +33,6 @@ class ResourceApplyGradientDescent : public XlaOpKernel { : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp handle; - xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(1); TensorShape var_shape; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); @@ -48,7 +49,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel { var_shape.DebugString(), " vs ", delta_shape.DebugString())); - handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2))); + handle = handle - ctx->Input(1) * ctx->Input(2); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -56,6 +57,64 @@ REGISTER_XLA_OP( Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes), ResourceApplyGradientDescent); +xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr, + xla::XlaOp l1, xla::XlaOp l2, + xla::XlaOp grad) { + xla::XlaOp one = xla::ScalarLike(lr, 1.0); + xla::XlaOp zero = xla::ScalarLike(lr, 0.0); + xla::XlaOp prox_var = var - grad * lr; + xla::XlaOp l1_gt_zero = xla::Sign(prox_var) * + xla::Max(xla::Abs(prox_var) - lr * l1, zero) / + (one + lr * l2); + xla::XlaOp l1_le_zero = prox_var / (one + lr * l2); + return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero); +} + +class ResourceApplyProximalGradientDescent : public XlaOpKernel { + public: + explicit ResourceApplyProximalGradientDescent(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp var; + TensorShape var_shape; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + + TensorShape alpha_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("alpha is not a scalar: ", + alpha_shape.DebugString())); + TensorShape l1_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("l1 is not a scalar: ", + l1_shape.DebugString())); + TensorShape l2_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("l2 is not a scalar: ", + l2_shape.DebugString())); + TensorShape delta_shape = ctx->InputShape(4); + OP_REQUIRES( + ctx, var_shape.IsSameSize(delta_shape), + errors::InvalidArgument("var and delta do not have the same shape: ", + var_shape.DebugString(), " vs ", + delta_shape.DebugString())); + xla::XlaOp alpha = ctx->Input(1); + xla::XlaOp l1 = ctx->Input(2); + xla::XlaOp l2 = ctx->Input(3); + xla::XlaOp delta = ctx->Input(4); + var = ProximalGradientDescentUpdate(var, alpha, l1, l2, delta); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent") + .TypeConstraint("T", kFloatTypes), + ResourceApplyProximalGradientDescent); + class ResourceApplyMomentum : public XlaOpKernel { public: explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -63,8 +122,6 @@ class ResourceApplyMomentum : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; @@ -97,14 +154,13 @@ class ResourceApplyMomentum : public XlaOpKernel { xla::XlaOp grad = ctx->Input(3); xla::XlaOp momentum = ctx->Input(4); - accum = b->Add(b->Mul(accum, momentum), grad); + accum = accum * momentum + grad; if (use_nesterov_) { // See https://github.com/tensorflow/tensorflow/pull/2798 for an // explanation of the reparameterization used here. - var = b->Sub( - var, b->Add(b->Mul(grad, lr), b->Mul(b->Mul(accum, momentum), lr))); + var = var - (grad * lr + accum * momentum * lr); } else { - var = b->Sub(var, b->Mul(accum, lr)); + var = var - accum * lr; } OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); @@ -121,8 +177,6 @@ class ResourceApplyAdagrad : public XlaOpKernel { explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; @@ -149,10 +203,8 @@ class ResourceApplyAdagrad : public XlaOpKernel { xla::XlaOp lr = ctx->Input(2); xla::XlaOp grad = ctx->Input(3); - accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); - var = b->Sub( - var, b->Mul(b->Mul(grad, lr), - b->Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5)))); + accum = accum + xla::Square(grad); + var = var - grad * lr * xla::Rsqrt(accum); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); } @@ -160,6 +212,62 @@ class ResourceApplyAdagrad : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes), ResourceApplyAdagrad); +class ResourceApplyProximalAdagrad : public XlaOpKernel { + public: + explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, accum_shape; + xla::XlaOp var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum)); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + TensorShape l1_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape), + errors::InvalidArgument("l1 is not a scalar: ", + l1_shape.DebugString())); + TensorShape l2_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape), + errors::InvalidArgument("l2 is not a scalar: ", + l2_shape.DebugString())); + TensorShape grad_shape = ctx->InputShape(5); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape: ", + var_shape.DebugString(), " vs ", grad_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp l1 = ctx->Input(3); + xla::XlaOp l2 = ctx->Input(4); + xla::XlaOp grad = ctx->Input(5); + accum = accum + xla::Square(grad); + // Adagrad learning rate. + xla::XlaOp adagrad_lr = lr * xla::Rsqrt(accum); + var = ProximalGradientDescentUpdate(var, adagrad_lr, l1, l2, grad); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP( + Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes), + ResourceApplyProximalAdagrad); + class ResourceApplyAdam : public XlaOpKernel { public: explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -227,17 +335,12 @@ class ResourceApplyAdam : public XlaOpKernel { // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon) xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); - xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); - xla::XlaOp alpha = - b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)), - b->Sub(one, beta1_power)); - m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1))); - v = b->Add(v, b->Mul(b->Sub(b->Pow(grad, two), v), b->Sub(one, beta2))); - var = - b->Sub(var, b->Div(b->Mul(m, alpha), b->Add(b->Pow(v, half), epsilon))); + xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power); + m = m + (grad - m) * (one - beta1); + v = v + (xla::Square(grad) - v) * (one - beta2); + var = var - m * alpha / (xla::Sqrt(v) + epsilon); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); @@ -255,8 +358,6 @@ class ResourceApplyRMSProp : public XlaOpKernel { explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - DataType type = ctx->input_type(3); TensorShape var_shape, ms_shape, mom_shape; @@ -320,16 +421,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { // ms <- grad**2 (1 - rho) + ms * rho // // Which is the equation listed above. - xla::XlaOp new_ms = b->Add( - ms, - b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms), - b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); + xla::XlaOp new_ms = + ms + (xla::Square(grad) - ms) * (xla::ScalarLike(ms, 1.0) - rho); xla::XlaOp new_mom = - b->Add(b->Mul(mom, momentum), - b->Mul(b->Mul(grad, lr), - b->Pow(b->Add(new_ms, epsilon), - XlaHelpers::FloatLiteral(b, type, -0.5)))); - xla::XlaOp new_var = b->Sub(var, new_mom); + mom * momentum + grad * lr * xla::Rsqrt(new_ms + epsilon); + xla::XlaOp new_var = var - new_mom; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms)); @@ -424,21 +520,18 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); xla::XlaOp grad_to_use; if (has_l2_shrinkage) { - grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var))); + grad_to_use = grad + two * l2_shrinkage * var; } else { grad_to_use = grad; } - xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two)); - xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power)); - xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); - linear = b->Add( - linear, - b->Sub(grad_to_use, - b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var))); - xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1); - xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); - var = b->Div(b->Sub(linear_clipped, linear), quadratic); + xla::XlaOp new_accum = accum + xla::Square(grad_to_use); + xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power); + xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power); + linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var; + xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1); + xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2; + var = (linear_clipped - linear) / quadratic; accum = new_accum; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var)); @@ -478,5 +571,74 @@ class ResourceApplyFtrlV2 : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes), ResourceApplyFtrlV2); +class ResourceApplyAdadelta : public XlaOpKernel { + public: + explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, accum_shape, accum_update_shape; + xla::XlaOp var, accum, accum_update; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape, + &accum_update)); + + TensorShape lr_shape = ctx->InputShape(3); + TensorShape rho_shape = ctx->InputShape(4); + TensorShape epsilon_shape = ctx->InputShape(5); + TensorShape grad_shape = ctx->InputShape(6); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape), + errors::InvalidArgument("rho is not a scalar: ", + rho_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(3); + xla::XlaOp rho = ctx->Input(4); + xla::XlaOp epsilon = ctx->Input(5); + xla::XlaOp grad = ctx->Input(6); + + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5); + xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + + accum = rho * accum + (one - rho) * xla::Pow(grad, two); + xla::XlaOp update = xla::Pow(accum_update + epsilon, half) * + xla::Pow(accum + epsilon, neg_half) * grad; + accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two); + var = var - update * lr; + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes), + ResourceApplyAdadelta); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index c167642174b328a968d7f7ce1f0ad6e0ab8a7a68..6c721c48fe3af45aff5cd0bd5e74e2693faf9f97 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -32,7 +33,8 @@ namespace { class TransposeOp : public XlaOpKernel { public: - explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit TransposeOp(OpKernelConstruction* ctx, bool conjugate = false) + : XlaOpKernel(ctx), conjugate_(conjugate) {} void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); @@ -78,19 +80,37 @@ class TransposeOp : public XlaOpKernel { errors::InvalidArgument(i, " is missing from 'perm' argument.")); } + xla::XlaOp transposed; // 0-D, 1-D, and identity transposes do nothing. if (dims <= 1 || is_identity) { - ctx->SetOutput(0, ctx->Input(0)); - return; + transposed = ctx->Input(0); + } else { + transposed = xla::Transpose(ctx->Input(0), transposed_order); } - ctx->SetOutput(0, - ctx->builder()->Transpose(ctx->Input(0), transposed_order)); + // Conjugate the transposed result if this is ConjugateTransposeOp. + if (conjugate_) { + ctx->SetOutput(0, xla::Conj(transposed)); + } else { + ctx->SetOutput(0, transposed); + } } + + private: + const bool conjugate_; +}; + +class ConjugateTransposeOp : public TransposeOp { + public: + explicit ConjugateTransposeOp(OpKernelConstruction* ctx) + : TransposeOp(ctx, /*conjugate=*/true) {} }; REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); +REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstInput("perm"), + ConjugateTransposeOp); + // InvertPermutation frequently forms part of the gradient of Transpose. // // inv = InvertPermutationOp(T p) takes a permutation of @@ -127,7 +147,7 @@ class InvertPermutationOp : public XlaOpKernel { output[d] = i; } - ctx->SetOutput(0, ctx->builder()->ConstantR1(output)); + ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 71a9fd051bfc8db09738a4bfe8ddde447895ecf0..116a020437e263f1d3d82fee5c0ea0ca4f97e634 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -16,24 +16,26 @@ limitations under the License. // Native XLA implementations of simple unary Ops #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { namespace { -// A subclass of a TlaUnaryOp must build the lambda computation that -// describes the scalar->scalar function to apply to each element of -// the input. #define XLAJIT_MAKE_UNARY(NAME, COMPUTATION) \ class NAME##Op : public XlaOpKernel { \ public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ void Compile(XlaOpKernelContext* ctx) { \ xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ xla::XlaOp x = ctx->Input(0); \ xla::XlaOp y = COMPUTATION; \ ctx->SetOutput(0, y); \ @@ -41,122 +43,100 @@ namespace { }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op); -XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x)); +XLAJIT_MAKE_UNARY(ComplexAbs, xla::Abs(x)); -XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x))); +XLAJIT_MAKE_UNARY(Angle, xla::Atan2(xla::Imag(x), xla::Real(x))); -XLAJIT_MAKE_UNARY(Conj, b->Conj(x)); +XLAJIT_MAKE_UNARY(Conj, xla::Conj(x)); // Return x if x>0, otherwise -x. -XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); +XLAJIT_MAKE_UNARY(Abs, xla::Abs(x)); // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) -XLAJIT_MAKE_UNARY( - Acos, - b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(x, x)), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), - b->Add(XlaHelpers::One(b, input_type(0)), x)))); +XLAJIT_MAKE_UNARY(Acos, + xla::ScalarLike(x, 2.0) * + xla::Atan2(xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x), + xla::ScalarLike(x, 1.0) + x)); // acosh(x) = log(x + sqrt(x^2 - 1)) // = log(x + sqrt((x+1)*(x-1))) -XLAJIT_MAKE_UNARY( - Acosh, - b->Log(b->Add(x, - b->Pow(b->Mul(b->Add(x, XlaHelpers::One(b, input_type(0))), - b->Sub(x, XlaHelpers::One(b, input_type(0)))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); +XLAJIT_MAKE_UNARY(Acosh, + xla::Log(x + xla::Sqrt((x + xla::ScalarLike(x, 1.0)) * + (x - xla::ScalarLike(x, 1.0))))); // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) XLAJIT_MAKE_UNARY( - Asin, - b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)), - b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(x, x)), - XlaHelpers::FloatLiteral(b, input_type(0), - 0.5)))))); + Asin, xla::ScalarLike(x, 2.0) * + xla::Atan2(x, xla::ScalarLike(x, 1.0) + + xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x))); // asinh(x) = log(x + sqrt(x^2 + 1)) -XLAJIT_MAKE_UNARY( - Asinh, - b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x), - XlaHelpers::One(b, input_type(0))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); +XLAJIT_MAKE_UNARY(Asinh, + xla::Log(x + xla::Sqrt(x * x + xla::ScalarLike(x, 1.0)))); -XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0)))); +XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, xla::ScalarLike(x, 1.0))); // atanh(x) = 0.5 * log((1 + x) / (1 - x)) +XLAJIT_MAKE_UNARY(Atanh, xla::Log((xla::ScalarLike(x, 1.0) + x) / + (xla::ScalarLike(x, 1.0) - x)) * + xla::ScalarLike(x, 0.5)); +XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x)); +XLAJIT_MAKE_UNARY(Cos, xla::Cos(x)); +XLAJIT_MAKE_UNARY(Cosh, (xla::Exp(x) + xla::Exp(-x)) * xla::ScalarLike(x, 0.5)); +XLAJIT_MAKE_UNARY(Sin, xla::Sin(x)); +XLAJIT_MAKE_UNARY(Exp, xla::Exp(x)); + +XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x)); + +XLAJIT_MAKE_UNARY(Floor, xla::Floor(x)); +XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x)); XLAJIT_MAKE_UNARY( - Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x), - b->Sub(XlaHelpers::One(b, input_type(0)), x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x)); -XLAJIT_MAKE_UNARY(Cos, b->Cos(x)); -XLAJIT_MAKE_UNARY(Cosh, - b->Mul(b->Add(b->Exp(x), b->Exp(b->Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); -XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); - -XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x)); - -XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); -XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); -XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x), - XlaHelpers::FloatLiteral( - b, input_type(0), - std::numeric_limits::infinity()))); -XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x)); + IsInf, + xla::Eq(xla::Abs(x), + xla::ScalarLike(x, std::numeric_limits::infinity()))); +XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x)); // Return 1/x -XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); -XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); -XLAJIT_MAKE_UNARY(Log, b->Log(x)); +XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x); +XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x); +XLAJIT_MAKE_UNARY(Log, xla::Log(x)); -XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); +XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x)); -XLAJIT_MAKE_UNARY(Invert, b->Not(x)); -XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); -XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); +XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); +XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); +XLAJIT_MAKE_UNARY(Neg, -x); // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. -static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& x) { - auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); - auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); - auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); - - auto round_val = b->Floor(x); - auto fraction = b->Sub(x, round_val); - auto nearest_even_int = - b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x)))); - auto is_odd = b->Eq(nearest_even_int, one); - return b->Select( - b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)), - b->Add(round_val, one), round_val); +xla::XlaOp RoundToEven(xla::XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + auto one = xla::ScalarLike(x, 1.0); + auto two = xla::ScalarLike(x, 2.0); + + auto round_val = xla::Floor(x); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * xla::Floor(half * x); + auto is_odd = xla::Eq(nearest_even_int, one); + return xla::Select(xla::Or(xla::Gt(fraction, half), + xla::And(xla::Eq(fraction, half), is_odd)), + round_val + one, round_val); } -XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); -XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); +XLAJIT_MAKE_UNARY(Rint, RoundToEven(x)); +XLAJIT_MAKE_UNARY(Round, RoundToEven(x)); -XLAJIT_MAKE_UNARY(Rsqrt, - b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); +XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. -static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& x) { - auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); - return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); +xla::XlaOp Sigmoid(xla::XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + return half + half * xla::Tanh(half * x); } -XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x)); +XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. -XLAJIT_MAKE_UNARY(Sign, b->Sign(x)); -XLAJIT_MAKE_UNARY(Sinh, - b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); +XLAJIT_MAKE_UNARY(Sinh, (xla::Exp(x) - xla::Exp(-x)) * xla::ScalarLike(x, 0.5)); // softplus(x) = log(1 + exp(x)) // @@ -166,24 +146,48 @@ XLAJIT_MAKE_UNARY(Sinh, // // This is equivalent to: // max(x, 0) + log1p(exp(-abs(x))) -XLAJIT_MAKE_UNARY(Softplus, - b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))), - b->Log1p(b->Exp(b->Neg(b->Abs(x)))))); +XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) + + xla::Log1p(xla::Exp(-xla::Abs(x)))); // softsign(x) = x / (abs(x) + 1) -XLAJIT_MAKE_UNARY(Softsign, - b->Div(x, - b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0))))); -XLAJIT_MAKE_UNARY(Sqrt, - b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Square, b->Mul(x, x)); -XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x))); -XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x)); - -XLAJIT_MAKE_UNARY(Real, b->Real(x)); -XLAJIT_MAKE_UNARY(Imag, b->Imag(x)); +XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0))); +XLAJIT_MAKE_UNARY(Sqrt, xla::Sqrt(x)); +XLAJIT_MAKE_UNARY(Square, x* x); +XLAJIT_MAKE_UNARY(Tan, xla::Sin(x) / xla::Cos(x)); +XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x)); + +XLAJIT_MAKE_UNARY(Real, xla::Real(x)); +XLAJIT_MAKE_UNARY(Imag, xla::Imag(x)); #undef XLAJIT_MAKE_UNARY +// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial +// is used outside of this range. +class ErfOp : public XlaOpKernel { + public: + explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp x = ctx->Input(0); + xla::XlaOp one = xla::ScalarLike(x, 1.0); + auto y = + xla::Select(xla::Gt(xla::Abs(x), one), one - xla::Erfc(x), xla::Erf(x)); + ctx->SetOutput(0, y); + } +}; +REGISTER_XLA_OP(Name("Erf"), ErfOp); + +class ErfcOp : public XlaOpKernel { + public: + explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp x = ctx->Input(0); + xla::XlaOp one = xla::ScalarLike(x, 1.0); + auto y = + xla::Select(xla::Lt(xla::Abs(x), one), one - xla::Erf(x), xla::Erfc(x)); + ctx->SetOutput(0, y); + } +}; +REGISTER_XLA_OP(Name("Erfc"), ErfcOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index f87586ba578a6138e7fb921032e1a71f8c9ac80c..0e5d58ecbaeb13571f82a1311e29dc0ba91c11ac 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -74,10 +75,9 @@ class UnpackOp : public XlaOpKernel { for (int i = 0; i < num; ++i) { start_indices[axis] = i; limit_indices[axis] = i + 1; - auto slice = ctx->builder()->Slice(input, start_indices, limit_indices, - strides); + auto slice = xla::Slice(input, start_indices, limit_indices, strides); // Reshape to drop the 'axis' dimension. - auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes()); + auto result = xla::Reshape(slice, output_shape.dim_sizes()); ctx->SetOutput(i, result); } } diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 6109db8e89e5ee67e0635d26e258bfe7cb70a15d..febac8287350e32fccfd4cb5613f21b9a5fbcb95 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" +#include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -23,8 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/no_op.h" namespace tensorflow { namespace { @@ -35,12 +33,33 @@ class VarIsInitializedOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { XlaResource* variable; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable)); - ctx->SetOutput(0, - ctx->builder()->ConstantR0(variable->initialized())); + ctx->SetOutput( + 0, xla::ConstantR0(ctx->builder(), variable->initialized())); } }; REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); +class VariableShapeOp : public XlaOpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType variable_dtype; + TensorShape shape; + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); + Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); + ctx->SetConstantOutput(0, shape_constant); + } + + private: + DataType out_dtype_; +}; +REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); + class ReadVariableOp : public XlaOpKernel { public: explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -57,7 +76,7 @@ class ReadVariableOp : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); +REGISTER_XLA_OP(Name("ReadVariableOp").CompilationOnly(), ReadVariableOp); class AssignVariableOp : public XlaOpKernel { public: @@ -67,7 +86,7 @@ class AssignVariableOp : public XlaOpKernel { ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1))); } }; -REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp); +REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp); class AssignAddVariableOp : public XlaOpKernel { public: @@ -77,7 +96,7 @@ class AssignAddVariableOp : public XlaOpKernel { xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); - handle = ctx->builder()->Add(handle, ctx->Input(1)); + handle = xla::Add(handle, ctx->Input(1)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -93,7 +112,7 @@ class AssignSubVariableOp : public XlaOpKernel { xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); - handle = ctx->builder()->Sub(handle, ctx->Input(1)); + handle = xla::Sub(handle, ctx->Input(1)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -125,29 +144,152 @@ class ResourceGatherOp : public XlaOpKernel { ctx->SetOutput(0, gather); } }; -REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), - ResourceGatherOp); +REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp); -class VariableShapeOp : public XlaOpKernel { +class ResourceScatterOp : public XlaOpKernel { public: - explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + explicit ResourceScatterOp( + OpKernelConstruction* context, bool indices_are_vectors, + std::function + combiner) + : XlaOpKernel(context), + indices_are_vectors_(indices_are_vectors), + combiner_(std::move(combiner)) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* builder = context->builder(); + + DataType dtype = context->input_type(2); + TensorShape var_shape; + xla::XlaOp var_value; + OP_REQUIRES_OK( + context, context->ReadVariableInput(0, dtype, &var_shape, &var_value)); + + const xla::XlaOp indices = context->Input(1); + const xla::XlaOp updates = context->Input(2); + + auto result = XlaScatter(var_value, updates, indices, indices_are_vectors_, + combiner_, builder); + OP_REQUIRES_OK(context, result.status()); + OP_REQUIRES_OK(context, + context->AssignVariable(0, dtype, result.ValueOrDie())); } - void Compile(XlaOpKernelContext* ctx) override { - DataType variable_dtype; - TensorShape shape; - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); - Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); - OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); - ctx->SetConstantOutput(0, shape_constant); + private: + const bool indices_are_vectors_; + const std::function + combiner_; +}; + +class ResourceScatterAddOp : public ResourceScatterOp { + public: + explicit ResourceScatterAddOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Add(x, y); } +}; +REGISTER_XLA_OP(Name("ResourceScatterAdd"), ResourceScatterAddOp); + +class ResourceScatterSubOp : public ResourceScatterOp { + public: + explicit ResourceScatterSubOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} private: - DataType out_dtype_; + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Sub(x, y); + } }; +REGISTER_XLA_OP(Name("ResourceScatterSub"), ResourceScatterSubOp); + +class ResourceScatterMulOp : public ResourceScatterOp { + public: + explicit ResourceScatterMulOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Mul(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterMul"), ResourceScatterMulOp); + +class ResourceScatterDivOp : public ResourceScatterOp { + public: + explicit ResourceScatterDivOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Div(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterDiv"), ResourceScatterDivOp); + +class ResourceScatterMinOp : public ResourceScatterOp { + public: + explicit ResourceScatterMinOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Min(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterMin"), ResourceScatterMinOp); + +class ResourceScatterMaxOp : public ResourceScatterOp { + public: + explicit ResourceScatterMaxOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Max(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterMax"), ResourceScatterMaxOp); + +class ResourceScatterUpdateOp : public ResourceScatterOp { + public: + explicit ResourceScatterUpdateOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, + /*combiner=*/{}) {} +}; +REGISTER_XLA_OP(Name("ResourceScatterUpdate"), ResourceScatterUpdateOp); + +class ResourceScatterNdUpdateOp : public ResourceScatterOp { + public: + explicit ResourceScatterNdUpdateOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/true, + /*combiner=*/{}) {} +}; +REGISTER_XLA_OP(Name("ResourceScatterNdUpdate"), ResourceScatterNdUpdateOp); + +class ResourceScatterNdAddOp : public ResourceScatterOp { + public: + explicit ResourceScatterNdAddOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/true, + /*combiner=*/Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Add(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp); -REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 5467c5d9946846ff9f14ce9c5aac9e2be4b9d6ab..340165bac6a2a214d8f84d5a116a4197b1df2c7b 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -246,7 +246,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } } - xla::XlaOp init = builder->Tuple(inputs); + xla::XlaOp init = xla::Tuple(builder, inputs); VLOG(1) << "Building while loop"; @@ -255,22 +255,21 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { { std::unique_ptr cb = builder->CreateSubBuilder("cond_wrapper"); - auto inputs = cb->Parameter(0, cond_input_shape, "inputs"); - auto outputs = cb->Call(*cond.computation, {inputs}); - cb->GetTupleElement(outputs, 0); + auto inputs = xla::Parameter(cb.get(), 0, cond_input_shape, "inputs"); + auto outputs = xla::Call(cb.get(), *cond.computation, {inputs}); + xla::GetTupleElement(outputs, 0); xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(ctx, result.status()); cond_wrapper = std::move(result.ValueOrDie()); } - xla::XlaOp while_result = - builder->While(cond_wrapper, *body.computation, init); + xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); // Sets non-variable outputs. for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { ctx->SetOutput(body.input_mapping[i], - builder->GetTupleElement(while_result, i)); + xla::GetTupleElement(while_result, i)); } } @@ -284,7 +283,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - builder->GetTupleElement(while_result, pos), builder)); + xla::GetTupleElement(while_result, pos), builder)); } VLOG(2) << "Loop-carried variable: pos: " << update.input_index << " name: " << resource->name() << " modified: " << update.modified diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index ee7f5d510ab7a3ce7d3bbe843c5fefd362f79b7b..dfa3c0595acbfeb35f944209b4354b357b11bf3c 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -44,12 +44,28 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) +cc_library( + name = "random", + srcs = ["random.cc"], + hdrs = ["random.h"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "scatter", srcs = ["scatter.cc"], @@ -81,6 +97,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 526694d5a0c7124e1696f34b516f3b202462bc19..f9f3a8c8cfcbcd0a2ac853360c629d90c94db8b0 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -25,91 +26,94 @@ limitations under the License. namespace tensorflow { -xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, - xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, - bool conjugate_y) { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { +xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x, bool conjugate_y) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", + "Arguments to BatchedDot have different ranks: ", + xla::ShapeUtil::HumanString(x_shape), " vs. ", xla::ShapeUtil::HumanString(y_shape)); } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::HasZeroElements(x_shape) || - xla::ShapeUtil::HasZeroElements(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + const int ndims = xla::ShapeUtil::Rank(x_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to BatchedDot must have rank >= 2: ", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { + return errors::InvalidArgument( + "Dimension ", i, " of inputs to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(x_shape), " vs ", + xla::ShapeUtil::HumanString(y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); + int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { + return errors::InvalidArgument( + "Dimensions ", x_inner_dim, " and ", y_inner_dim, + " of arguments to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, + " vs. ", xla::ShapeUtil::HumanString(y_shape), + " transpose: ", transpose_y); + } + + // Check for zero lhs/rhs dim size. + if (xla::ShapeUtil::IsZeroElementArray(x_shape) || + xla::ShapeUtil::IsZeroElementArray(y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); + return xla::Broadcast( + xla::ConstantLiteral(builder, + xla::Literal::Zero(x_shape.element_type())), + dimensions); + } + + if (x_shape.element_type() == xla::C64 && conjugate_x) { + x = xla::Conj(x); + } + if (y_shape.element_type() == xla::C64 && conjugate_y) { + y = xla::Conj(y); + } + + // If there are no batch dimensions, use a regular Dot. + // TODO(b/69062148) Remove this code when Dot emitters can be passed + // dimensions to transpose directly (i.e. without requiring a Transpose + // HLO). + if (batch_dimension_numbers.empty()) { + auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; + return xla::Dot(lhs, rhs); + } + + xla::DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = builder->Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = builder->Conj(y); - } - - // If there are no batch dimensions, use a regular Dot. - // TODO(b/69062148) Remove this code when Dot emitters can be passed - // dimensions to transpose directly (i.e. without requiring a Transpose HLO). - if (batch_dimension_numbers.empty()) { - auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; - return builder->Dot(lhs, rhs); - } - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - return builder->DotGeneral(x, y, dot_dnums); + return xla::DotGeneral(x, y, dot_dnums); + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 1acc72033b05e73b0f5f88907df20cde5cfffbf0..d07a9486f18c0b8f26782123a8fba4ba228f71ee 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -43,10 +43,9 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, - xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x = false, - bool conjugate_y = false); +xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, + bool transpose_y = false, bool conjugate_x = false, + bool conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 3f1384bc864abd882ebba2b90acbe0b1e664687a..cc840de393ebc2983ddf7659c6c18d8136de5dd6 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -47,179 +49,163 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, - const xla::XlaOp& a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - 2); - - xla::XlaOp l = Zeros(builder, a_shape); - - // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::Shape col_shape; - xla::Shape row_shape; - for (int64 d : major_dims) { - row_shape.add_dimensions(d); - col_shape.add_dimensions(d); - } - row_shape.add_dimensions(1); - row_shape.add_dimensions(n); - row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = Zeros(body_builder, row_shape); - - col_shape.add_dimensions(n); - col_shape.add_dimensions(1); - col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = Zeros(body_builder, col_shape); - - std::vector mask_vector(n); - std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = body_builder->ConstantR1(mask_vector); - auto mask_range_row = body_builder->Broadcast( - body_builder->Reshape(mask_range, {0}, {1, n}), major_dims); - auto mask_range_col = body_builder->Broadcast( - body_builder->Reshape(mask_range, {0}, {n, 1}), major_dims); - auto body_a = loop_vars[0]; - auto body_l = loop_vars[1]; - - // row = l[..., i, :i] - // select the whole i-th row, then mask out all columns past i-1 - auto zero = body_builder->ConstantR0(0); - TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l, - {i, zero}, {1, n})); - auto row = body_builder->Select(body_builder->Ge(mask_range_row, i), - mask_zeros_row, l_i); - // a[..., i, i] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a, - {i, i}, {1, 1})); - // np.dot(row, np.swapaxes(row, -1, -2)) - xla::XlaOp diag_dot; - TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row, - /*transpose_x=*/false, - /*transpose_y=*/true)); - // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, - // np.swapaxes(row, -1, -2))) - auto l_ii = body_builder->Pow( - body_builder->Sub(a_ii, diag_dot), - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); - - // a[..., i+1:, i] - auto ip1 = body_builder->Add(i, body_builder->ConstantR0(1)); - // select the whole i-th column, then mask out all rows above i+1 +xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int n_dims = xla::ShapeUtil::Rank(a_shape); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - 2); + + xla::XlaOp l = xla::ZerosLike(a); + + // Construct the for loop body to iterate over rows. + auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* body_builder) + -> xla::StatusOr> { + xla::Shape col_shape; + xla::Shape row_shape; + for (int64 d : major_dims) { + row_shape.add_dimensions(d); + col_shape.add_dimensions(d); + } + row_shape.add_dimensions(1); + row_shape.add_dimensions(n); + row_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_row = xla::Zeros(body_builder, row_shape); + + col_shape.add_dimensions(n); + col_shape.add_dimensions(1); + col_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_col = xla::Zeros(body_builder, col_shape); + + std::vector mask_vector(n); + std::iota(mask_vector.begin(), mask_vector.end(), 0); + auto mask_range = xla::ConstantR1(body_builder, mask_vector); + auto mask_range_row = + xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + auto mask_range_col = + xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + auto body_a = loop_vars[0]; + auto body_l = loop_vars[1]; + + // row = l[..., i, :i] + // select the whole i-th row, then mask out all columns past i-1 + auto zero = xla::ConstantR0(body_builder, 0); + auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); + auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + // a[..., i, i] + auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + // np.dot(row, np.swapaxes(row, -1, -2)) + auto diag_dot = BatchDot(row, row, + /*transpose_x=*/false, + /*transpose_y=*/true); + // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, + // np.swapaxes(row, -1, -2))) + auto l_ii = + xla::Pow(a_ii - diag_dot, + FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + + // a[..., i+1:, i] + // select the whole i-th column, then mask out all rows above i+1 + auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); + auto a_ip1i = + xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + + // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / + // l[..., i, i] + // The columns in [i, n] are zeroed out in `row`, so we just have to + // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], + // r.T) + auto dot = BatchDot(body_l, row, + /*transpose_x=*/false, + /*transpose_y=*/true); + // np.dot(l[..., i+1:, :i], r.T) + auto dot_ip1 = + xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + + body_l = + DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); + // Assign the diagonal after the rest of the column because otherwise the + // column assign will wrap around and overwrite the diagonal assign. + body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); + + return std::vector{body_a, body_l}; + }; + TF_ASSIGN_OR_RETURN( - auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); - auto a_ip1i = body_builder->Select(body_builder->Le(mask_range_col, i), - mask_zeros_col, a_0i); - - // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / - // l[..., i, i] - // The columns in [i, n] are zeroed out in `row`, so we just have to - // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], - // r.T) - TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true)); - // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = body_builder->Select(body_builder->Le(mask_range_col, i), - mask_zeros_col, dot); - - auto col_update = - body_builder->Div(body_builder->Sub(a_ip1i, dot_ip1), l_ii); - TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( - body_builder, body_l, col_update, {i})); - // Assign the diagonal after the rest of the column because otherwise the - // column assign will wrap around and overwrite the diagonal assign. - TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( - body_builder, body_l, l_ii, {i, i})); - - return std::vector{body_a, body_l}; - }; - - TF_ASSIGN_OR_RETURN( - auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); - - return cholesky_while[1]; + auto cholesky_while, + XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + + return cholesky_while[1]; + }); } } // namespace -xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); - } - - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); - } - - // Blocked left-looking Cholesky factorization. - // Algorithm 1 from - // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only - // execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::XlaOp l = Zeros(builder, a_shape); - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - if (i > 0) { - // TODO(phawkins): consider implementing SYRK for the diagonal part of - // the panel. - // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) - TF_ASSIGN_OR_RETURN(auto lhs, - SliceInMinorDims(builder, l, {i, 0}, {n, i})); - TF_ASSIGN_OR_RETURN(auto rhs, - SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); - TF_ASSIGN_OR_RETURN(auto delta, - BatchDot(builder, lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto before, - SliceInMinorDims(builder, a, {i, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN( - a, UpdateSliceInMinorDims(builder, a, builder->Sub(before, delta), - {i, i})); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to Cholesky must have rank >= 2: ", ndims); + } + + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "Arguments to Cholesky must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to Cholesky must be >= 1; got ", block_size); } - // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) - TF_ASSIGN_OR_RETURN(auto x, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x)); - TF_ASSIGN_OR_RETURN(l, - UpdateSliceInMinorDims(builder, l, factorized, {i, i})); - - if (i + k < n) { - // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) - TF_ASSIGN_OR_RETURN(auto panel, - SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN(auto update, - TriangularSolve(builder, factorized, panel, - /*left_side=*/false, - /*lower=*/true, - /*transpose_a=*/true, - /*conjugate_a=*/false, - /*block_size=*/block_size)); - TF_ASSIGN_OR_RETURN( - l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); + // Blocked left-looking Cholesky factorization. + // Algorithm 1 from + // Haidar, Azzam, et al. "High-performance Cholesky factorization for + // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. + xla::XlaOp l = xla::ZerosLike(a); + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + if (i > 0) { + // TODO(phawkins): consider implementing SYRK for the diagonal part of + // the panel. + // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) + auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); + auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); + auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, + /*transpose_y=*/true); + auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); + a = UpdateSliceInMinorDims(a, before - delta, {i, i}); + } + + // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) + auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto factorized = CholeskyUnblocked(x); + l = UpdateSliceInMinorDims(l, factorized, {i, i}); + + if (i + k < n) { + // l[i+k:, i:i+k] = + // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) + auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); + auto update = TriangularSolve(factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*transpose_a=*/true, + /*conjugate_a=*/false, + /*block_size=*/block_size); + l = UpdateSliceInMinorDims(l, update, {i + k, i}); + } } - } - return l; + return l; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 20fca7969ece2729a44933fd3ef3f87230ab6cad..0f6e0e9d152ec5daedeb9c0e355bfb9731759094 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -30,8 +30,7 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, - int64 block_size = 256); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc new file mode 100644 index 0000000000000000000000000000000000000000..8ff10fbd3fbf9308140af84c752a5a50bec8fd32 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/random.cc @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/random.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { + +xla::XlaOp TruncatedNormal(xla::XlaOp uniform) { + auto normal_cdf = [](double x) { + return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; + }; + + const double kA = -2.0; + const double kB = 2.0; + const double kMu = 0.0; + const double kSigma = 1.0; + const double kAlpha = (kA - kMu) / kSigma; + const double kBeta = (kB - kMu) / kSigma; + const double kAlphaNormalCdf = normal_cdf(kAlpha); + const double kBetaNormalCdf = normal_cdf(kBeta); + const double kZ = kBetaNormalCdf - kAlphaNormalCdf; + + xla::XlaOp one = xla::ScalarLike(uniform, 1.0); + xla::XlaOp two = xla::ScalarLike(uniform, 2.0); + xla::XlaOp sqrt_2 = xla::ScalarLike(uniform, std::sqrt(2.0)); + xla::XlaOp z = xla::ScalarLike(uniform, kZ); + xla::XlaOp alpha_normal_cdf = xla::ScalarLike(uniform, kAlphaNormalCdf); + + auto p = alpha_normal_cdf + z * uniform; + // probit(p) = sqrt(2) * erfinv(2*p-1) + return sqrt_2 * xla::ErfInv(two * p - one); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h new file mode 100644 index 0000000000000000000000000000000000000000..2c573fd85b2783fdac13457cdb277cf988ac40c4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/random.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +// Builds an array filled with values sampled from a truncated normal +// distribution such that no values are greater than two or less than negative +// two. +// +// The "uniform" parameter must be an array of random numbers distributed in +// (0,1). +xla::XlaOp TruncatedNormal(xla::XlaOp uniform); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index d5a27abb2585f699ae2719cb8a6b9a829263389e..85e3d3ab85a89615cc5a01bdb4ec8f7fec30d58e 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -97,8 +98,8 @@ xla::StatusOr XlaScatter( buffer_shape_post_axes.end()); // Construct the initial values of the loop-carried Tensors. - auto flat_indices = builder->Reshape(indices, flat_indices_shape); - auto flat_updates = builder->Reshape(updates, flat_updates_shape); + auto flat_indices = xla::Reshape(indices, flat_indices_shape); + auto flat_updates = xla::Reshape(updates, flat_updates_shape); auto init = {flat_indices, flat_updates, buffer}; // Constructs the loop body. The implementation of scatter is essentially: @@ -112,46 +113,44 @@ xla::StatusOr XlaScatter( auto updates = loop_vars[1]; auto buffer = loop_vars[2]; - auto zero_index = body_builder->ConstantLiteral( - xla::Literal::Zero(indices_shape.element_type())); + auto zero_index = xla::ConstantLiteral( + body_builder, xla::Literal::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. xla::XlaOp index; - auto indices_offset = body_builder->Reshape(i, {1}); + auto indices_offset = xla::Reshape(i, {1}); if (indices_are_vectors) { - indices_offset = body_builder->Pad(indices_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, 1}})); + indices_offset = xla::Pad(indices_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, 1}})); - index = body_builder->DynamicSlice(indices, indices_offset, - {1, num_index_dims}); - index = body_builder->Collapse(index, {0, 1}); + index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims}); + index = xla::Collapse(index, {0, 1}); } else { - index = body_builder->DynamicSlice(indices, indices_offset, {1}); + index = xla::DynamicSlice(indices, indices_offset, {1}); } // Discard updates with negative indices, since some users expect this. - auto index_in_range = - body_builder->ReduceAll(body_builder->Le(zero_index, index), - body_builder->ConstantR0(true), - xla::CreateScalarAndComputation(body_builder)); + auto index_in_range = xla::ReduceAll( + xla::Le(zero_index, index), xla::ConstantR0(body_builder, true), + xla::CreateScalarAndComputation(body_builder)); // Make the index in bounds to prevent implementation defined behavior. - index = body_builder->Max(index, zero_index); - index = body_builder->Pad( + index = xla::Max(index, zero_index); + index = xla::Pad( index, zero_index, xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); // Slice the i-th index from the updates array. - auto updates_offset = body_builder->Reshape(i, {1}); - updates_offset = body_builder->Pad( + auto updates_offset = xla::Reshape(i, {1}); + updates_offset = xla::Pad( updates_offset, zero_index, xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); std::vector flat_updates_slice_shape({1}); flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), buffer_shape_post_axes.begin(), buffer_shape_post_axes.end()); - auto update = body_builder->DynamicSlice(updates, updates_offset, - flat_updates_slice_shape); + auto update = + xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape); // Unflatten the major (iteration) dimensions of the slice to their // original shape. @@ -159,20 +158,19 @@ xla::StatusOr XlaScatter( updates_slice_shape.insert(updates_slice_shape.end(), buffer_shape_post_axes.begin(), buffer_shape_post_axes.end()); - update = body_builder->Reshape(update, updates_slice_shape); + update = xla::Reshape(update, updates_slice_shape); // Apply the update to the buffer. If there is a combiner, use it to merge // the current values with the update. - auto current_value = - body_builder->DynamicSlice(buffer, index, updates_slice_shape); + auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape); if (combiner) { update = combiner(current_value, update, body_builder); } // Use the current value instead of the update if the index is out of // bounds. - update = body_builder->Select(index_in_range, update, current_value); + update = xla::Select(index_in_range, update, current_value); // Apply the update. - buffer = body_builder->DynamicUpdateSlice(buffer, update, index); + buffer = xla::DynamicUpdateSlice(buffer, update, index); return std::vector{indices, updates, buffer}; }; diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index b4503601f94baa5a595a64c9fc81bc92d9980ac6..588afaac65122fbdc6fe9a399a7a50a3a49749cb 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -29,619 +31,564 @@ limitations under the License. namespace tensorflow { -xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. ", - xla::ShapeUtil::HumanString(b_shape)); - } - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); - } - // The batch dimensions must be equal. - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - int64 b_size = b_shape.dimensions(i); - if (a_size != b_size) { +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", + "Arguments to TriangularSolve have different ranks: ", + xla::ShapeUtil::HumanString(a_shape), " vs. ", xla::ShapeUtil::HumanString(b_shape)); } - batch_dimensions.push_back(a_size); - } - - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", - block_size); - } - - std::map base_computations; - auto get_base_triangular_solve = - [&](int k) -> xla::StatusOr { - xla::XlaComputation& computation = base_computations[k]; - if (computation.IsNull()) { - std::unique_ptr sub = builder->CreateSubBuilder( - tensorflow::strings::StrCat("trsm_base_", k)); - - auto a_param = sub->Parameter( - 0, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, {k, k})), - "a"); - - std::array b_lastd; - if (left_side) { - b_lastd = {k, n}; - } else { - b_lastd = {m, k}; - } - auto b_param = sub->Parameter( - 1, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), - "b"); - - // We use a left-looking or right-looking subroutine on the block diagonal - // in the lower=true cases, while falling back to a recursive call in - // others. The left-looking and right-looking subroutines are written with - // a While loop and so yields much faster compile times. Moreover, they - // can give higher performance on smaller (sub)problems. - if (left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else if (!left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else { - TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, - left_side, lower, transpose_a, - conjugate_a, - /*block_size=*/1) - .status()); + const int ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to TriangularSolve must have rank >= 2: ", ndims); + } + // The batch dimensions must be equal. + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + int64 b_size = b_shape.dimensions(i); + if (a_size != b_size) { + return errors::InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } + batch_dimensions.push_back(a_size); + } - TF_ASSIGN_OR_RETURN(computation, sub->Build()); + if (xla::ShapeUtil::GetDimension(a_shape, -1) != + xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "The 'a' arguments to TriangularSolve must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); + } + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } - return &computation; - }; - - xla::XlaOp output = Zeros(builder, b_shape); - - // Right-looking blocked triangular solve. - // For an explanation of the algorithm, see the TRSM discussion in: - // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation - // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 - // (2008): 4. - - // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if - // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if - // conjugate_a is True. - - if (!left_side && lower == transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) - if (i + k < n) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n})); - } - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, i + k}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); - } + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got ", + block_size); } - } else if (left_side && lower != transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < m; i += block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) - if (i + k < m) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); + std::map base_computations; + auto get_base_triangular_solve = + [&](int k) -> xla::StatusOr { + xla::XlaComputation& computation = base_computations[k]; + if (computation.IsNull()) { + std::unique_ptr sub = builder->CreateSubBuilder( + tensorflow::strings::StrCat("trsm_base_", k)); + + auto a_param = xla::Parameter( + sub.get(), 0, + xla::ShapeUtil::MakeShape(b_shape.element_type(), + ConcatVectors(batch_dimensions, {k, k})), + "a"); + + std::array b_lastd; + if (left_side) { + b_lastd = {k, n}; + } else { + b_lastd = {m, k}; + } + auto b_param = xla::Parameter( + sub.get(), 1, + xla::ShapeUtil::MakeShape(b_shape.element_type(), + ConcatVectors(batch_dimensions, b_lastd)), + "b"); + + // We use a left-looking or right-looking subroutine on the block + // diagonal in the lower=true cases, while falling back to a recursive + // call in others. The left-looking and right-looking subroutines are + // written with a While loop and so yields much faster compile times. + // Moreover, they can give higher performance on smaller (sub)problems. + if (left_side && lower) { + TriangularSolveLeftLooking(a_param, b_param, transpose_a, + conjugate_a); + } else if (!left_side && lower) { + TriangularSolveRightLooking(a_param, b_param, transpose_a, + conjugate_a); } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m})); + TriangularSolve(a_param, b_param, left_side, lower, transpose_a, + conjugate_a, + /*block_size=*/1); } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {i + k, 0}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0})); + TF_ASSIGN_OR_RETURN(computation, sub->Build()); } - } - } else if (!left_side && lower != transpose_a) { - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + return &computation; + }; + + xla::XlaOp output = xla::ZerosLike(b); + + // Right-looking blocked triangular solve. + // For an explanation of the algorithm, see the TRSM discussion in: + // Goto, Kazushige, and Robert Van De Geijn. "High-performance + // implementation of the level-3 BLAS." ACM Transactions on Mathematical + // Software (TOMS) 35.1 (2008): 4. + + // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if + // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if + // conjugate_a is True. + + if (!left_side && lower == transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {0, i}, {m, i + k}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; + } + output = UpdateSliceInMinorDims(output, update, {0, i}); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) + if (i + k < n) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i + k, i}, {n, i + k}); + } else { + a_slice_2 = SliceInMinorDims(a, {i, i + k}, {i + k, n}); + } + + auto b_update = BatchDot(update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a); + auto b_slice_2 = SliceInMinorDims(b, {0, i + k}, {m, n}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, i + k}); } + } - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {m, i})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + } else if (left_side && lower != transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < m; i += block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {i, 0}, {i + k, n}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); + } else { + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; + } + output = UpdateSliceInMinorDims(output, update, {i, 0}); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) + if (i + k < m) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i + k, i}, {m, i + k}); + } else { + a_slice_2 = SliceInMinorDims(a, {i, i + k}, {i + k, m}); + } + + auto b_update = BatchDot(a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false); + auto b_slice_2 = SliceInMinorDims(b, {i + k, 0}, {m, n}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {i + k, 0}); + } } - } - } else { // left_side && lower == transpose_a - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); + } else if (!left_side && lower != transpose_a) { + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = + xla::RoundUpToNearest(n, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {0, i}, {m, i + k}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); + } else { + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; + } + output = UpdateSliceInMinorDims(output, update, {0, i}); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) + if (i - k >= 0) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i, 0}, {i + k, i}); + } else { + a_slice_2 = SliceInMinorDims(a, {0, i}, {i, i + k}); + } + + auto b_update = BatchDot(update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a); + auto b_slice_2 = SliceInMinorDims(b, {0, 0}, {m, i}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, 0}); + } } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + } else { // left_side && lower == transpose_a + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = + xla::RoundUpToNearest(m, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {i, 0}, {i + k, n}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; + } + output = UpdateSliceInMinorDims(output, update, {i, 0}); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) + if (i - k >= 0) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i, 0}, {i + k, i}); + } else { + a_slice_2 = SliceInMinorDims(a, {0, i}, {i, i + k}); + } + + auto b_update = BatchDot(a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false); + auto b_slice_2 = SliceInMinorDims(b, {0, 0}, {i, n}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, 0}); } - - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {i, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); } } - } - return output; + return output; + }); } -xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - - // Allocate the output and set its first or last row, - // output = np.zeros_like(b) - // if transpose_a: - // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] - // else: - // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] - xla::XlaOp output = Zeros(builder, b_shape); - { - auto i = transpose_a ? m - 1 : 0; - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - auto update = builder->Div(b_slice, a_slice_conj); - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - } - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (m-2, output, a, b) - // else: - // init = (1, output, a, b) - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = builder->ConstantR0(transpose_a ? m - 2 : 1); - auto init = builder->Tuple({init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i >= 0 if transpose_a else i < m - std::unique_ptr condb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); - { - auto i = condb->GetTupleElement( - condb->Parameter(0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"), - 0); - if (transpose_a) { - condb->Ge(i, condb->ConstantR0(0)); - } else { - condb->Lt(i, condb->ConstantR0(m)); +xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + batch_dimensions.push_back(a_size); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) - // else: - // a_row = a[..., i:i+1, :i] - // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) - // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); - { - auto input_tuple = bodyb->Parameter(0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"); - // i, output, a, b = loop_carry - auto i = bodyb->GetTupleElement(input_tuple, 0); - auto body_out = bodyb->GetTupleElement(input_tuple, 1); - auto body_a = bodyb->GetTupleElement(input_tuple, 2); - auto body_b = bodyb->GetTupleElement(input_tuple, 3); - auto zero = bodyb->ConstantR0(0); + // The main computation is performed in a While loop. - // We'd like to implement this: - // if transpose_a: - // a_row = T(a[..., i+1:, i:i+1]) - // result_row = (b[..., i:i+1, :] - // - np.matmul(a_row, body_out[..., i+1:, :])) - // else: - // result_row = (b[..., i:i+1, :] - // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) - // But since we can't have intermediate array sizes depend on the loop - // counter, we instead exploit the fact that we initialized the output to - // all zeros and use that as zero-padding (doing unnecessary FLOPs). - xla::XlaOp a_row; - if (transpose_a) { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {zero, i}, {m, 1})); - } else { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, zero}, {1, m})); + // Allocate the output and set its first or last row, + // output = np.zeros_like(b) + // if transpose_a: + // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] + // else: + // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] + xla::XlaOp output = xla::ZerosLike(b); + { + auto i = transpose_a ? m - 1 : 0; + auto a_slice = SliceInMinorDims(a, {i, i}, {i + 1, i + 1}); + auto b_slice = SliceInMinorDims(b, {i, 0}, {i + 1, n}); + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + auto update = b_slice / a_slice_conj; + output = UpdateSliceInMinorDims(output, update, {i, 0}); } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN( - auto result_row_slice, - DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n})); - auto result_row = bodyb->Sub(result_row_slice, b_update); - - // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_elt_conj, - MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); - auto div_result = bodyb->Div(result_row, a_elt_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {i, zero})); + // Construct the initial loop carry tuple, // if transpose_a: - // return (i - 1, body_out, a, b) + // init = (m-2, output, a, b) // else: - // return (i + 1, body_out, a, b) - auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? -1 : 1)); - bodyb->Tuple({next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = builder->While(cond, body, init); - return builder->GetTupleElement(triangular_solve_left_looking_while, 1); + // init = (1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + b_shape, + // The coefficient matrix a is a loop invariant. + a_shape, + // The right-hand-side matrix b is a loop invariant. + b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = xla::ConstantR0(builder, transpose_a ? m - 2 : 1); + auto init = xla::Tuple(builder, {init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i >= 0 if transpose_a else i < m + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); + { + auto i = xla::GetTupleElement( + xla::Parameter(condb.get(), 0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"), + 0); + if (transpose_a) { + xla::Ge(i, xla::ConstantR0(condb.get(), 0)); + } else { + xla::Lt(i, xla::ConstantR0(condb.get(), m)); + } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) + // else: + // a_row = a[..., i:i+1, :i] + // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) + // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop + // counter. + std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); + { + auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = xla::GetTupleElement(input_tuple, 0); + auto body_out = xla::GetTupleElement(input_tuple, 1); + auto body_a = xla::GetTupleElement(input_tuple, 2); + auto body_b = xla::GetTupleElement(input_tuple, 3); + auto zero = xla::ConstantR0(bodyb.get(), 0); + + // We'd like to implement this: + // if transpose_a: + // a_row = T(a[..., i+1:, i:i+1]) + // result_row = (b[..., i:i+1, :] + // - np.matmul(a_row, body_out[..., i+1:, :])) + // else: + // result_row = (b[..., i:i+1, :] + // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) + // But since we can't have intermediate array sizes depend on the loop + // counter, we instead exploit the fact that we initialized the output to + // all zeros and use that as zero-padding (doing unnecessary FLOPs). + xla::XlaOp a_row; + if (transpose_a) { + a_row = DynamicSliceInMinorDims(body_a, {zero, i}, {m, 1}); + } else { + a_row = DynamicSliceInMinorDims(body_a, {i, zero}, {1, m}); + } + auto b_update = BatchDot(a_row, body_out, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false); + auto result_row_slice = + DynamicSliceInMinorDims(body_b, {i, zero}, {1, n}); + auto result_row = result_row_slice - b_update; + + // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + auto a_elt = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + auto a_elt_conj = MaybeConjugate(a_elt, conjugate_a); + auto div_result = xla::Div(result_row, a_elt_conj); + body_out = DynamicUpdateSliceInMinorDims(body_out, div_result, {i, zero}); + + // if transpose_a: + // return (i - 1, body_out, a, b) + // else: + // return (i + 1, body_out, a, b) + auto next_i = xla::Add( + i, xla::ConstantR0(bodyb.get(), transpose_a ? -1 : 1)); + xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = xla::While(cond, body, init); + return xla::GetTupleElement(triangular_solve_left_looking_while, 1); + }); } -xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - xla::XlaOp output = Zeros(builder, b_shape); - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (0, output, a, b) - // else: - // init = (n-1, output, a, b) - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = builder->ConstantR0(transpose_a ? 0 : n - 1); - auto init = builder->Tuple({init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i < n if transpose_a else i >= 0 - std::unique_ptr condb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); - { - auto i = condb->GetTupleElement( - condb->Parameter(0, tuple_shape, - "TriangularSolveRightLookingWhileTuple"), - 0); - if (transpose_a) { - condb->Lt(i, condb->ConstantR0(n)); - } else { - condb->Ge(i, condb->ConstantR0(0)); +xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + batch_dimensions.push_back(a_size); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) - // else: - // a_row = a[..., :, i:i+1] - // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) - // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); - { - auto input_tuple = bodyb->Parameter( - 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); - - // i, output, a, b = loop_carry - auto i = bodyb->GetTupleElement(input_tuple, 0); - auto body_out = bodyb->GetTupleElement(input_tuple, 1); - auto body_a = bodyb->GetTupleElement(input_tuple, 2); - auto body_b = bodyb->GetTupleElement(input_tuple, 3); - auto zero = bodyb->ConstantR0(0); - - // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, - // i:i+1]) But since we can't have intermediate array sizes depend on the - // loop counter, we instead exploit the fact that we initialized the output - // to all zeros and use that as zero-padding (doing unnecessary FLOPs). - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - // result = b - np.matmul(output, a) - auto result = bodyb->Sub(body_b, b_update); - // result_row = result[..., :, i:i+1] - TF_ASSIGN_OR_RETURN( - auto result_row, - DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1})); - - // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_ii_conj, - MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); - auto div_result = bodyb->Div(result_row, a_ii_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {zero, i})); + // The main computation is performed in a While loop. + xla::XlaOp output = xla::ZerosLike(b); + + // Construct the initial loop carry tuple, // if transpose_a: - // return (i + 1, body_out, a, b) + // init = (0, output, a, b) // else: - // return (i - 1, body_out, a, b) - auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? 1 : -1)); - bodyb->Tuple({next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = builder->While(cond, body, init); - return builder->GetTupleElement(triangular_solve_left_looking_while, 1); + // init = (n-1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + b_shape, + // The coefficient matrix a is a loop invariant. + a_shape, + // The right-hand-side matrix b is a loop invariant. + b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = xla::ConstantR0(builder, transpose_a ? 0 : n - 1); + auto init = xla::Tuple(builder, {init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i < n if transpose_a else i >= 0 + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); + { + auto i = xla::GetTupleElement( + xla::Parameter(condb.get(), 0, tuple_shape, + "TriangularSolveRightLookingWhileTuple"), + 0); + if (transpose_a) { + xla::Lt(i, xla::ConstantR0(condb.get(), n)); + } else { + xla::Ge(i, xla::ConstantR0(condb.get(), 0)); + } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., :, i:i+1], -1, -2) + // else: + // a_row = a[..., :, i:i+1] + // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) + // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop + // counter. + std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); + { + auto input_tuple = xla::Parameter( + bodyb.get(), 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = xla::GetTupleElement(input_tuple, 0); + auto body_out = xla::GetTupleElement(input_tuple, 1); + auto body_a = xla::GetTupleElement(input_tuple, 2); + auto body_b = xla::GetTupleElement(input_tuple, 3); + auto zero = xla::ConstantR0(bodyb.get(), 0); + + // result = b - np.matmul(output, a) + // result_row = result[..., :, i:i+1] + auto body_b_slice = DynamicSliceInMinorDims(body_b, {zero, i}, {m, 1}); + xla::XlaOp a_slice; + if (transpose_a) { + a_slice = DynamicSliceInMinorDims(body_a, {i, zero}, {1, n}); + } else { + a_slice = DynamicSliceInMinorDims(body_a, {zero, i}, {n, 1}); + } + auto b_update = body_b_slice - BatchDot(body_out, a_slice, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a); + + // body_out[..., :, i:i+1] = b_update / a[..., i:i+1, i:i+1] + auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + auto a_ii_conj = MaybeConjugate(a_ii, conjugate_a); + body_out = DynamicUpdateSliceInMinorDims(body_out, b_update / a_ii_conj, + {zero, i}); + + // if transpose_a: + // return (i + 1, body_out, a, b) + // else: + // return (i - 1, body_out, a, b) + auto next_i = xla::Add( + i, xla::ConstantR0(bodyb.get(), transpose_a ? 1 : -1)); + xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = xla::While(cond, body, init); + return xla::GetTupleElement(triangular_solve_left_looking_while, 1); + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 540c26b2473df9e7885f4e549b3e516a3d8a0d43..80c2bc4c9c38ec101db419d48db26e67e25d169b 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -57,23 +57,15 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size = 256); +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size = 256); -xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a); +xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a); -xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a); +xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 87ea4763f7c2357ae179b68ade3715b24c46432f..d5ffc1498e4b6dcfbc9f24f9b5dce58fddca8ab1 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -85,11 +85,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -107,11 +106,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -129,11 +127,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -151,11 +148,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -173,11 +169,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -196,11 +191,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 1.0, 1.5}, @@ -219,11 +213,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 1.0, 1.5}, @@ -242,11 +235,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -267,11 +259,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/true, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/true, + /*block_size=*/2); xla::Array2D expected({ {0.5, complex64(0.08333333, 0.08333333), @@ -295,11 +286,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 1., 1.5}, @@ -323,10 +313,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolveLeftLooking(&builder, a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - TF_ASSERT_OK(result.status()); + TriangularSolveLeftLooking(a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); xla::Array2D expected({ {0.5, 1.0, 1.5}, @@ -345,10 +334,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolveLeftLooking(&builder, a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - TF_ASSERT_OK(result.status()); + TriangularSolveLeftLooking(a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); xla::Array2D expected({ {0.5, 1.0, 1.5}, diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index d9ff7e6259f3fbab8957394bff5c5670a67dd0eb..fdc8bfca4932fe62a4d2a8db49f4104c3eb0cd3b 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -27,29 +28,23 @@ limitations under the License. namespace tensorflow { -xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { - return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), - xla::AsInt64Slice(shape.dimensions())); -} - xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, double value) { switch (type) { case xla::F16: - return builder->ConstantR0(static_cast(value)); + return xla::ConstantR0(builder, static_cast(value)); break; case xla::BF16: - return builder->ConstantR0(static_cast(value)); + return xla::ConstantR0(builder, static_cast(value)); break; case xla::F32: - return builder->ConstantR0(static_cast(value)); + return xla::ConstantR0(builder, static_cast(value)); break; case xla::F64: - return builder->ConstantR0(value); + return xla::ConstantR0(builder, value); break; case xla::C64: - return builder->ConstantR0(value); + return xla::ConstantR0(builder, value); break; default: LOG(FATAL) << "unhandled element type " << type; @@ -107,134 +102,140 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, default: LOG(FATAL) << "unhandled element type " << type; } - return builder->ConstantLiteral(literal); + return xla::ConstantLiteral(builder, literal); } -xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - gtl::ArraySlice start, - gtl::ArraySlice end) { - TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector strides(n_dims, 1); - return builder->Slice(x, padded_start, padded_end, strides); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, + gtl::ArraySlice end) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. + std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return xla::Slice(x, padded_start, padded_end, strides); + }); } -std::vector PrependMajorDims(xla::XlaBuilder* builder, - const gtl::ArraySlice& major_dims, - const gtl::ArraySlice& indices) { - std::vector output(indices.size() + major_dims.size()); - std::copy(major_dims.begin(), major_dims.end(), output.begin()); - std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size()); +std::vector ConcatVectors(gtl::ArraySlice xs, + gtl::ArraySlice ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); return output; } -xla::StatusOr DynamicSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts, - const gtl::ArraySlice& sizes) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - TF_ASSIGN_OR_RETURN(auto padded_starts, - PrependZerosInMajorDims(builder, x, starts)); - auto padded_sizes = PrependMajorDims(builder, major_dims, sizes); - return builder->DynamicSlice(x, padded_starts, padded_sizes); +xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, + gtl::ArraySlice starts, + gtl::ArraySlice sizes) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + auto padded_starts = PrependZerosInMajorDims(x, starts); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return xla::DynamicSlice(x, padded_starts, padded_sizes); + }); } -xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start) { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. - std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = builder->ConstantR1(start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return builder->DynamicUpdateSlice(x, update, start_constant); +xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + auto start_constant = xla::ConstantR1(builder, start_as_int32); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + xla::ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return xla::DynamicUpdateSlice(x, update, start_constant); + }); } -xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(builder, x, update, padded_start); +xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); } -xla::StatusOr DynamicUpdateSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, - const std::vector& starts) { - TF_ASSIGN_OR_RETURN(auto padded_starts, - PrependZerosInMajorDims(builder, x, starts)); - return builder->DynamicUpdateSlice(x, update, padded_starts); +xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice starts) { + auto padded_starts = PrependZerosInMajorDims(x, starts); + return xla::DynamicUpdateSlice(x, update, padded_starts); } -xla::StatusOr PrependZerosInMajorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = builder->Reshape(builder->ConstantR0(0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = - builder->Reshape(starts[i], {1}); - } - return builder->ConcatInDim(padded_starts, 0); +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + gtl::ArraySlice starts) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); + } + return xla::ConcatInDim(builder, padded_starts, 0); + }); } -xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return builder->Transpose(x, permutation); +xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return xla::Transpose(x, permutation); + }); } -xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, - const xla::XlaOp& x, bool conjugate) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? builder->Conj(x) : x; +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == xla::C64 && conjugate; + return perform_conj ? xla::Conj(x) : x; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 3c120a2548576d6ad46870583ca65beea63507a3..6cb6c088e9d20af05193f0a3da6c2595966eb495 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -23,9 +23,6 @@ limitations under the License. namespace tensorflow { -// Returns a zero-filled tensor with shape `shape`. -xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape); - // Returns a floating point scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, @@ -33,7 +30,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros // prepended until the array is length n_dims. -xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, gtl::ArraySlice starts); // Returns a integer scalar constant of 'type' with 'value'. @@ -41,54 +38,43 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last two values being +// Builds a vector of zeros of length rank(x) with the last values being // those in `starts`. -xla::StatusOr PrependZerosInMajorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts); +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + gtl::ArraySlice starts); // Performs a slice in the minor dimensions of a Tensor. -xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - gtl::ArraySlice start, - gtl::ArraySlice end); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, + gtl::ArraySlice end); -// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`. -std::vector PrependMajorDims(xla::XlaBuilder* builder, - const gtl::ArraySlice& major_dims, - const gtl::ArraySlice& indices); +// Returns the concatenation of `xs` and `ys`. +std::vector ConcatVectors(gtl::ArraySlice xs, + gtl::ArraySlice ys); // Performs a dynamic slice in the minor dimensions of a Tensor. -xla::StatusOr DynamicSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts, const gtl::ArraySlice& sizes); +xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, + gtl::ArraySlice starts, + gtl::ArraySlice sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update -xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start); +xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update -xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start); +xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start); -xla::StatusOr DynamicUpdateSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, - const std::vector& starts); +xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x); +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); // Applies a complex conjugation operation if `a` is complex and `conjugate_a` // is true, otherwise returns its argument. -xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, - const xla::XlaOp& x, bool conjugate); +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc index 265b39402c832f8c810a74f281563b05afdf2b1b..7d0f2222a9aa3ef09cb8be20c5f9b26431c6498c 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -70,8 +70,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); - auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1}); - TF_ASSERT_OK(result.status()); + DynamicSliceInMinorDims(a, {x, y}, {1, 1}); ComputeAndCompareR2(&builder, {{10}}, {a_data.get(), x_data.get(), y_data.get()}, @@ -86,10 +85,8 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); - TF_ASSERT_OK_AND_ASSIGN( - auto l_index, - DynamicSliceInMinorDims(&builder, a, - {index, builder.ConstantR0(0)}, {1, 4})); + DynamicSliceInMinorDims(a, {index, xla::ConstantR0(&builder, 0)}, + {1, 4}); ComputeAndCompareR3(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, {a_data.get(), index_data.get()}); @@ -104,8 +101,7 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); auto y_data = CreateR0Parameter(1, 3, "y", &builder, &y); - auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y}); - TF_ASSERT_OK(result.status()); + DynamicUpdateSliceInMinorDims(a, b, {x, y}); xla::Array2D expected( {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); @@ -128,13 +124,9 @@ XLA_TEST_F(UtilTest, RowBatchDot) { // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); - TF_ASSERT_OK_AND_ASSIGN( - auto l_index, - DynamicSliceInMinorDims(&builder, a, - {index, builder.ConstantR0(0)}, {1, n})); - TF_ASSERT_OK_AND_ASSIGN( - auto dot, BatchDot(&builder, l_index, row, - /*transpose_x=*/false, /*transpose_y=*/true)); + auto l_index = DynamicSliceInMinorDims( + a, {index, xla::ConstantR0(&builder, 0)}, {1, n}); + BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, {a_data.get(), row_data.get(), index_data.get()}); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 09ce594930efc0af47306590d76b322ac730f80f..7cc88f34d291f25814fba9f802c93117973120e7 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -39,7 +40,7 @@ xla::StatusOr> XlaWhileLoop( xla::XlaBuilder* builder) { std::vector elements(arity); for (int i = 0; i < arity; ++i) { - elements[i] = builder->GetTupleElement(tuple, i); + elements[i] = xla::GetTupleElement(tuple, i); } return elements; }; @@ -48,7 +49,8 @@ xla::StatusOr> XlaWhileLoop( std::unique_ptr cond_builder = builder->CreateSubBuilder(strings::StrCat(name, "_condition")); { - auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); + auto parameter = + xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), @@ -61,7 +63,8 @@ xla::StatusOr> XlaWhileLoop( std::unique_ptr body_builder = builder->CreateSubBuilder(strings::StrCat(name, "_body")); { - auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); + auto parameter = + xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); TF_ASSIGN_OR_RETURN( auto result, @@ -69,11 +72,11 @@ xla::StatusOr> XlaWhileLoop( body_builder.get())); TF_RET_CHECK(result.size() == initial_values.size()); - body_builder->Tuple(result); + xla::Tuple(body_builder.get(), result); } TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); - auto outputs = builder->While(cond, body, builder->Tuple(initial_values)); + auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); return unpack_tuple(outputs, arity, builder); } @@ -86,9 +89,8 @@ xla::StatusOr> XlaForEachIndex( auto while_cond_fn = [&](gtl::ArraySlice values, xla::XlaBuilder* cond_builder) -> xla::StatusOr { - return cond_builder->Lt( - values[0], - IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); + return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, + num_iterations)); }; auto while_body_fn = [&](gtl::ArraySlice values, xla::XlaBuilder* body_builder) @@ -97,9 +99,9 @@ xla::StatusOr> XlaForEachIndex( std::vector updated_values; updated_values.reserve(values.size()); - updated_values.push_back(body_builder->Add( - iteration, - body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); + updated_values.push_back(xla::Add( + iteration, xla::ConstantLiteral( + body_builder, xla::Literal::One(num_iterations_type)))); values.remove_prefix(1); TF_ASSIGN_OR_RETURN(std::vector body_outputs, @@ -112,7 +114,7 @@ xla::StatusOr> XlaForEachIndex( std::vector values; values.reserve(initial_values.size() + 1); values.push_back( - builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); + xla::ConstantLiteral(builder, xla::Literal::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 43e1c1e9fecec1c71db1509757251cb5d903ca49..b43405a1a407b5fa98dd740c62af91e048cc9490 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -22,21 +22,34 @@ limitations under the License. namespace tensorflow { -Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { - xla::Shape literal_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape( - host_tensor.dtype(), host_tensor.shape(), &literal_shape)); +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + *literal = xla::BorrowingLiteral( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); + return Status::OK(); +} - *literal = xla::Literal(literal_shape); +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal) { + std::vector buf_ptrs; + buf_ptrs.reserve(host_tensors.size()); + std::vector tensor_shapes(host_tensors.size()); - // memcpy over the payload ... - // TODO(phawkins): handle string types. - size_t total_bytes = host_tensor.TotalBytes(); - if (total_bytes > 0) { - void* dst_ptr = literal->untyped_data(); - const void* src_ptr = DMAHelper::base(&host_tensor); - memcpy(dst_ptr, src_ptr, total_bytes); + for (int i = 0; i < host_tensors.size(); i++) { + // Validate runtime shapes and fail if it doesn't match the contract. + const Tensor* tensor = &host_tensors[i]; + buf_ptrs.emplace_back(static_cast(DMAHelper::base(tensor))); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(), + &tensor_shapes[i])); } + + *literal = xla::BorrowingLiteral( + buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes)); + return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 220bec15538c36fa30abef9e729b64dbbb9f72b3..ab7e861f3336097d2ea52487092f16edb5c14531 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -22,12 +22,20 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { -// Copies 'host_tensor' to an XLA Literal. Fails if host_tensor is of an -// unsupported type. -Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); +// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by +// 'host_tensor'. +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal); + +// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers +// owned by 'host_tensors'. +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal); // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index bb9168fa358154f3db9dab87bacc9bf28dd16406..ace6fd1d8eeaf439509a7b75d8d986997c392e73 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -8,12 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") cc_library( name = "xla_ops", - srcs = [ - "dynamic_slice_ops.cc", - "functional_ops.cc", - "reduce_window_op.cc", - "sendrecv_ops.cc", - ], + srcs = ["xla_ops.cc"], deps = [ "//tensorflow/core:framework", ], diff --git a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc deleted file mode 100644 index d6c0edbb889b1751ac9d9d47d0c9534b543196ff..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -REGISTER_OP("XlaDynamicUpdateSlice") - .Input("input: T") - .Input("update: T") - .Input("indices: Tindices") - .Output("output: T") - .Attr("T: type") - .Attr("Tindices: {int32, int64}") - .SetShapeFn(shape_inference::UnchangedShape) - .Doc(R"doc( -Wraps the XLA DynamicUpdateSlice operator, documented at - https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice -. - -XlaDynamicUpdateSlice generates a result which is the value of the `input` -operand, with a slice update overwritten at `indices`. The shape of `update` -determines the shape of the sub-array of the result which is updated. The shape -of indices must be rank == 1, with dimension size equal to the rank of `input`. - -Handling of out-of-bounds slice indices is implementation-defined. - -input: A `Tensor` of type T. -indices: A vector of indices into `input`. Must have length equal to the rank of - `input`. -update: A `Tensor` of type T. Same rank as `input`. -output: A `Tensor` of type T. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc deleted file mode 100644 index 4a669f8e6eaf644f119f3c0a66f29d9f2c9a9d16..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -// TODO(b/37549631) setting the While Op to always be stateful is too -// conservative. -REGISTER_OP("XlaWhile") - .Input("input: T") - .Output("output: T") - .Attr("T: list(type) >= 0") - .Attr("cond: func") - .Attr("body: func") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -output = input; While (Cond(output)) { output = Body(output) } - -input: A list of input tensors whose types are T. -output: A list of output tensors whose types are T. -cond: A function takes 'input' and returns a tensor. If the tensor is - a scalar of non-boolean, the scalar is converted to a boolean - according to the following rule: if the scalar is a numerical - value, non-zero means True and zero means False; if the scalar is - a string, non-empty means True and empty means False. If the - tensor is not a scalar, non-emptiness means True and False - otherwise. -body: A function that takes a list of tensors and returns another - list of tensors. Both lists have the same types as specified by T. -)doc"); - -// TODO(b/37549631) setting the If Op to always be stateful is too -// conservative. -REGISTER_OP("XlaIf") - .Input("cond: Tcond") - .Input("inputs: Tin") - .Output("output: Tout") - .Attr("Tcond: type") - .Attr("then_branch: func") - .Attr("else_branch: func") - .Attr("Tin: list(type) >= 0") - .Attr("Tout: list(type) >= 0") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -output = cond ? then_branch(inputs) : else_branch(inputs). - -cond: A boolean scalar. -inputs: A list of input tensors. -output: A list of tensors returned by either then_branch(inputs) or - else_branch(inputs). The input shapes of the then_branch and - else_branch must match. -then_branch: A function takes 'inputs' and returns a list of tensors, - whose types are the same as what else_branch returns. -else_branch: A function takes 'inputs' and returns a list of tensors. - whose types are the same as what then_branch returns. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc b/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc deleted file mode 100644 index d9af982adc090ea78c711fd4656ba429c53b18c9..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("XlaReduceWindow") - .Input("input: T") - .Input("init_value: T") - .Attr("T: numbertype") - .Attr("computation: func") - .Attr("window_dimensions: list(int)") - .Attr("window_strides: list(int)") - .Attr("padding_low: list(int)") - .Attr("padding_high: list(int)") - .Output("output: T") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Wraps the XLA ReduceWindow operator, documented at - https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . - -input: the input tensor -init_value: a scalar representing the initial value for the reduction -computation: a reducer function to apply -window_dimensions: the shape of the window -window_strides: the inter-window strides -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc deleted file mode 100644 index 7ec7b50e905a6cbdecea4543dcb87322b5a7e844..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("XlaSend") - .Input("tensor: T") - .Attr("T: type") - .Attr("tensor_name: string") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Sends the named tensor to another XLA computation. Wraps the XLA Send operator -documented at - https://www.tensorflow.org/performance/xla/operation_semantics#send . - -tensor: The tensor to send. -tensor_name: A string key that identifies the channel. -)doc"); - -REGISTER_OP("XlaRecv") - .Output("tensor: dtype") - .Attr("dtype: type") - .Attr("tensor_name: string") - .Attr("shape: shape") - .SetIsStateful() - .SetShapeFn([](shape_inference::InferenceContext* c) { - TensorShape shape_attr; - TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); - shape_inference::ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); - c->set_output(0, s); - return Status::OK(); - }) - .Doc(R"doc( -Receives the named tensor from another XLA computation. Wraps the XLA Recv -operator documented at - https://www.tensorflow.org/performance/xla/operation_semantics#recv . - -tensor: The tensor to receive. -dtype: The type of the tensor. -tensor_name: A string key that identifies the channel. -shape: The shape of the tensor. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..a59c77f5c3a309abe8f6fbab1e48455d54e8fae5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("XlaDynamicUpdateSlice") + .Input("input: T") + .Input("update: T") + .Input("indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA DynamicUpdateSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice +. + +XlaDynamicUpdateSlice generates a result which is the value of the `input` +operand, with a slice update overwritten at `indices`. The shape of `update` +determines the shape of the sub-array of the result which is updated. The shape +of indices must be rank == 1, with dimension size equal to the rank of `input`. + +Handling of out-of-bounds slice indices is implementation-defined. + +input: A `Tensor` of type T. +indices: A vector of indices into `input`. Must have length equal to the rank of + `input`. +update: A `Tensor` of type T. Same rank as `input`. +output: A `Tensor` of type T. +)doc"); + +// TODO(b/37549631) setting the If Op to always be stateful is too +// conservative. +REGISTER_OP("XlaIf") + .Input("cond: Tcond") + .Input("inputs: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = cond ? then_branch(inputs) : else_branch(inputs). + +cond: A boolean scalar. +inputs: A list of input tensors. +output: A list of tensors returned by either then_branch(inputs) or + else_branch(inputs). The input shapes of the then_branch and + else_branch must match. +then_branch: A function takes 'inputs' and returns a list of tensors, + whose types are the same as what else_branch returns. +else_branch: A function takes 'inputs' and returns a list of tensors. + whose types are the same as what then_branch returns. +)doc"); + +REGISTER_OP("XlaRecv") + .Output("tensor: dtype") + .Attr("dtype: type") + .Attr("tensor_name: string") + .Attr("shape: shape") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + TensorShape shape_attr; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +Receives the named tensor from another XLA computation. Wraps the XLA Recv +operator documented at + https://www.tensorflow.org/performance/xla/operation_semantics#recv . + +tensor: The tensor to receive. +dtype: The type of the tensor. +tensor_name: A string key that identifies the channel. +shape: The shape of the tensor. +)doc"); + +REGISTER_OP("XlaReduceWindow") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("computation: func") + .Attr("window_dimensions: list(int)") + .Attr("window_strides: list(int)") + .Attr("padding_low: list(int)") + .Attr("padding_high: list(int)") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA ReduceWindow operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + +input: the input tensor +init_value: a scalar representing the initial value for the reduction +computation: a reducer function to apply +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding_low: the padding to apply at the start of each input dimensions +padding_high: the padding to apply at the end of each input dimension. +)doc"); + +REGISTER_OP("XlaSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Sends the named tensor to another XLA computation. Wraps the XLA Send operator +documented at + https://www.tensorflow.org/performance/xla/operation_semantics#send . + +tensor: The tensor to send. +tensor_name: A string key that identifies the channel. +)doc"); + +REGISTER_OP("XlaSort") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. + +input: A `Tensor` of type T. +output: A `Tensor` of type T. +)doc"); + +// TODO(b/37549631) setting the While Op to always be stateful is too +// conservative. +REGISTER_OP("XlaWhile") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = input; While (Cond(output)) { output = Body(output) } + +input: A list of input tensors whose types are T. +output: A list of output tensors whose types are T. +cond: A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +body: A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified by T. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index e5ce65bec950fdfd38c3ca5bc62ac745ef8ca4a7..2fc47dffb8f5f16f24e3beb1ff75aeed3e857c58 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -77,4 +77,6 @@ def reduce_window(operand, recv = gen_xla_ops.xla_recv send = gen_xla_ops.xla_send +sort = gen_xla_ops.xla_sort + while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 3a08aa8cf4f5cea6210cc9470d57c3387445ea6e..ac768b206e2a8d163a4253432a1911152f89ce86 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; - DeviceType device_type(DEVICE_CPU_XLA_JIT); - compiler_options.device_type = &device_type; + compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); compiler_options.flib_def = &graph->flib_def(); compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index f7098917b191058c53a1d6a5923e80e5e8319d72..319cbc74e96262881d32bdc9de2251b53f2b05d6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -83,12 +84,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), next_step_id_(1), - device_( - new XlaCompilationDevice(SessionOptions(), *options_.device_type)), + device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), device_mgr_({device_}) { - // We no longer need the device_type. - options_.device_type = nullptr; - + CHECK(!options_.device_type.type_string().empty()); if (options_.populate_resource_manager) { initialization_status_ = (*options_.populate_resource_manager)(device_->resource_manager()); @@ -228,15 +226,18 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // Computes the XLA shape for argument 'arg'. Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, bool is_entry_computation, - xla::Shape* xla_shape) { + xla::Shape* xla_shape) const { switch (arg.kind) { case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { - TensorShape shape = - is_entry_computation - ? options_.shape_representation_fn(arg.shape, arg.type) - : arg.shape; + TensorShape shape; + if (is_entry_computation) { + TF_ASSIGN_OR_RETURN( + shape, options_.shape_representation_fn(arg.shape, arg.type)); + } else { + shape = arg.shape; + } return TensorShapeToXLAShape(arg.type, shape, xla_shape); } case XlaCompiler::Argument::kResource: { @@ -244,8 +245,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, switch (arg.resource_kind) { case XlaResource::kVariable: { - TensorShape representation_shape = - options_.shape_representation_fn(arg.shape, arg.type); + TF_ASSIGN_OR_RETURN( + TensorShape representation_shape, + options_.shape_representation_fn(arg.shape, arg.type)); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -341,9 +343,9 @@ Status BuildComputation( const std::vector& arg_cores, const std::vector& retvals, const std::vector>& resources, - bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, - xla::XlaComputation* computation, int* num_computation_outputs, - int* num_nonconst_outputs, + bool return_updated_values_for_all_resources, bool always_return_tuple, + xla::XlaBuilder* builder, xla::XlaComputation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, std::vector* outputs, std::vector* resource_updates) { std::vector elems; @@ -387,13 +389,14 @@ Status BuildComputation( const XlaCompiler::Argument& arg = args[resource->arg_num()]; const int core = arg_cores[resource->arg_num()]; DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = resource->value() != resource->initial_value(); + bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); // TensorArray gradients were modified if their values changed or there are // any newly created gradients. for (const auto& grad : resource->tensor_array_gradients()) { - modified = modified || - grad.second->value() != grad.second->initial_value() || - arg.tensor_array_gradients.count(grad.first) == 0; + modified = + modified || + !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || + arg.tensor_array_gradients.count(grad.first) == 0; } if (return_updated_values_for_all_resources || modified) { resource_updates->emplace_back(); @@ -418,7 +421,7 @@ Status BuildComputation( // create a tuple/get-tuple-element combination so that sharding // assignment will be placed on this value, which will cause the resource // update to be returned from the same device that provided the resource. - handle = builder->GetTupleElement(builder->Tuple({handle}), 0); + handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); elems.push_back(handle); } @@ -427,7 +430,9 @@ Status BuildComputation( *num_computation_outputs = elems.size(); // Builds the XLA computation. - builder->Tuple(elems); + if (always_return_tuple || elems.size() != 1) { + xla::Tuple(builder, elems); + } builder->ClearOpMetadata(); xla::StatusOr computation_status = builder->Build(); @@ -554,16 +559,16 @@ Status XlaCompiler::BuildArguments( } xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, tuple_sharding); - tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } else { - tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); - arg_handles[i] = builder->GetTupleElement(tuple, i); + arg_handles[i] = xla::GetTupleElement(tuple, i); } } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { @@ -571,8 +576,8 @@ Status XlaCompiler::BuildArguments( xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); - arg_handles[i] = - builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i)); + arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], + strings::StrCat("arg", i)); } } @@ -603,7 +608,7 @@ Status XlaCompiler::BuildArguments( // return values of functions, and then reshape unconditionally. if (is_entry_computation) { arg_expression.set_handle( - builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + xla::Reshape(arg_handles[i], arg.shape.dim_sizes())); } else { arg_expression.set_handle(arg_handles[i]); } @@ -655,10 +660,65 @@ Status XlaCompiler::CompileSingleOp( .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); } + FixupSourceAndSinkEdges(graph.get()); return CompileGraph(options, name, std::move(graph), args, result); } +namespace { + +// Check that the ops of all non-functional nodes have been registered. +Status ValidateFunctionDef(const FunctionDef* fdef, + const FunctionLibraryDefinition& flib_def) { + for (const NodeDef& node : fdef->node_def()) { + const string& op = node.op(); + if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { + continue; + } + const OpDef* op_def; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def)); + } + return Status::OK(); +} + +// Check that the graph doesn't have any invalid nodes (e.g. incompatible with +// given device_type, invalid data type, missing attributes...) +Status ValidateGraph(const Graph* graph, + const FunctionLibraryDefinition& flib_def, + const DeviceType& device_type, const string& name) { + auto maybe_error = [&](const string& op, const Status& s) -> Status { + if (!s.ok()) { + return errors::InvalidArgument(strings::StrCat( + "Detected unsupported operations when trying to compile graph ", name, + " on ", device_type.type_string(), ": ", op, " (", s.error_message(), + ")")); + } + return Status::OK(); + }; + + for (const Node* node : graph->nodes()) { + if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { + continue; + } + const FunctionDef* fdef = flib_def.Find(node->def().op()); + Status s; + if (fdef) { + s = ValidateFunctionDef(fdef, flib_def); + TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); + continue; + } + const OpDef* op_def; + s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def); + TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); + TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); + s = FindKernelDef(device_type, node->def(), nullptr, nullptr); + TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); + } + return Status::OK(); +} + +} // namespace + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -681,6 +741,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), graph.get(), local_flib_def_.get())); + // Detect invalid nodes. + // FunctionalizeControlFlow may remove some nodes from the graph. + TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, + options_.device_type, name)); + xla::XlaBuilder builder(name); XlaContext* context = new XlaContext( this, &builder, options_.allow_cpu_custom_calls, @@ -705,9 +770,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( args, arg_cores, context->retvals(), context->resources(), - options.return_updated_values_for_all_resources, &builder, - result->computation.get(), &num_computation_outputs, - &num_nonconst_outputs, &result->outputs, &result->resource_updates)); + options.return_updated_values_for_all_resources, + options.always_return_tuple, &builder, result->computation.get(), + &num_computation_outputs, &num_nonconst_outputs, &result->outputs, + &result->resource_updates)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 621fbc149a6216904dfec84580d7a0a3da553ca9..079c99797e1f1ec26205e33b3c7c16d3764f15ca 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" @@ -38,7 +40,7 @@ class XlaContext; // It does a symbolic execution of the graph starting from specific input // shapes, using a JIT device to convert operators into XLA computations. // -// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the +// XlaCompiler is typically invoked from an `XlaLaunch` operator once the // shapes of all input parameters to the computation are known. This is // because the symbolic execution requires known shapes for all operations. // @@ -51,13 +53,7 @@ class XlaContext; // (kind kResource). // // Only kParameter and initialized kResource arguments become runtime parameters -// to the generated XLA computation. The XLA computation will have run-time -// parameters in the following order: -// +---------------------+-----------------------------------------+ -// | kParameter values | Initial values of kResource arguments | -// +---------------------+-----------------------------------------+ -// Within each block, the arguments are arranged by the _Arg index from which -// they were derived. +// to the generated XLA computation. // // The run-time outputs of the XLA computation are arranged in the following // order: @@ -76,10 +72,10 @@ class XlaContext; // tensors with a different shape to their representation inside the XLA // computation. // -// In both inputs and outputs, kResource values are placed the end. When +// In computation outputs, updated kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has -// identical input and output signatures. By moving variable values -// to the end of the argument list and using the +// identical input and output signatures. By passing variable values +// at the end of the argument list and using the // `return_updated_values_for_all_variables` option, we can ensure that the // input and output values of resources appear at the same positions. // @@ -174,6 +170,11 @@ class XlaCompiler { // computation. bool resolve_compile_time_constants = true; + // If 'always_return_tuple' is true, then the output of a computation will + // always be a tuple. Otherwise, a single-element output will not be wrapped + // in a tuple. + bool always_return_tuple = true; + // True when compiling the entry computation, false for subcomputations // (while, call, etc.) bool is_entry_computation = true; @@ -233,7 +234,8 @@ class XlaCompiler { tf2xla::HostComputeMetadata host_compute_metadata; // Resources whose values were updated by the computation, ordered - // by return value position. Resource updates follow the non-constant + // by return value position (which is the same as the order the resources + // were passed as arguments). Resource updates follow the non-constant // results in the outputs of XLA computation. std::vector resource_updates; @@ -241,12 +243,13 @@ class XlaCompiler { std::shared_ptr computation; }; - typedef std::function + typedef std::function(const TensorShape&, + DataType)> ShapeRepresentationFn; struct Options { - // Name of the compilation device to use. Needs to be live only during - // XlaCompiler's constructor. - const DeviceType* device_type = nullptr; + // Name of the compilation device to use. It must be set by the caller. + // The default empty value is invalid. + DeviceType device_type = DeviceType(""); xla::Client* client = nullptr; @@ -313,7 +316,7 @@ class XlaCompiler { // See the class comment for more details about the argument passing // convention. Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation, - xla::Shape* xla_shape); + xla::Shape* xla_shape) const; // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 55772ca324872f6d5fac008de7819b7fae64966a..07af8ef54b79b215e9e99faa161c8279488ebbf7 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -45,8 +46,6 @@ namespace tensorflow { class XlaCompilerTest : public ::testing::Test { protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); @@ -58,7 +57,7 @@ class XlaCompilerTest : public ::testing::Test { XlaCompiler::Options DefaultOptions() { XlaCompiler::Options options; - options.device_type = &cpu_device_type_; + options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); options.client = client_; options.flib_def = flib_def_.get(); return options; @@ -68,7 +67,6 @@ class XlaCompilerTest : public ::testing::Test { return compiler->local_flib_def_.get(); } - DeviceType cpu_device_type_; xla::Client* client_; std::unique_ptr flib_def_; }; @@ -979,5 +977,114 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +// Tests a graph which has a function with an invalid op. +TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { + XlaCompiler compiler(DefaultOptions()); + + FunctionDefLibrary flib; + FunctionDef fn = FillFn(); + NodeDef* node = fn.add_node_def(); + node->set_name("Invalid"); + node->set_op("InvalidOp"); /* unsupported op */ + node = fn.add_node_def(); + node->set_name("Switch"); + node->set_op("Switch"); /* control flow node */ + *flib.add_function() = fn; + + TF_ASSERT_OK(flib_def_->AddFunctionDef(fn)); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib)); + + NodeDef def; + TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get()) + .Input(value.name(), 0, DT_INT32) + .Input(shape.name(), 1, DT_INT32) + .Finalize(&def)); + Status status; + Node* fill = scope.graph()->AddNode(def, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.DoShapeInference(fill)); + scope.graph()->AddEdge(value.node(), 0, fill, 0); + scope.graph()->AddEdge(shape.node(), 0, fill, 1); + + auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); + + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + std::vector args; + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) + << status.error_message(); +} + +// Tests a graph which has a node with invalid data type. +TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef shape; + shape.set_name("Shape"); + shape.set_op("Shape"); + (*shape.mutable_attr())["T"].set_type(DT_INT32); + (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */ + Status status; + Node* shape_node = graph->AddNode(shape, &status); + TF_ASSERT_OK(status); + graph->AddControlEdge(graph->source_node(), shape_node); + + std::vector args; + XlaCompiler::CompilationResult result; + XlaCompiler compiler(DefaultOptions()); + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "is not in the list of allowed values")) + << status.error_message(); +} + +TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef no_op; + no_op.set_name("NoOp"); + no_op.set_op("NoOp"); + Status status; + graph->AddNode(no_op, &status); + TF_ASSERT_OK(status); + + std::vector args; + XlaCompiler compiler(DefaultOptions()); + // No control edge linking NoOp with source/sink. + { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: NoOp")) + << status.error_message(); + } + + // Fix control edges for NoOp. + { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result)); + EXPECT_EQ(0, result.resource_updates.size()); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 098072d33cd4eb7f7dec0ec4196b43eca0220d4a..fd39a58ce64acad12768a031c3c9d03c26c01b71 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -66,8 +66,8 @@ XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool is_entry_computation, - const std::function* - shape_representation_fn) + const std::function( + const TensorShape&, DataType)>* shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), @@ -92,7 +92,7 @@ void XlaContext::AddRetval(int retval_index, DataType type, } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, - const xla::Literal& literal) { + const xla::LiteralSlice& literal) { VLOG(1) << "Adding retval index " << retval_index << " with non-data-dependent tensor to XLA computation"; if (retvals_.size() <= retval_index) { @@ -119,8 +119,8 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::RepresentationShape(const TensorShape& shape, - DataType type) const { +xla::StatusOr XlaContext::RepresentationShape( + const TensorShape& shape, DataType type) const { return (*shape_representation_fn_)(shape, type); } @@ -131,9 +131,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { xla::XlaBuilder b("max<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Max(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Max(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -145,9 +147,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { xla::XlaBuilder b("min<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Min(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Min(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -159,9 +163,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { xla::XlaBuilder b("add<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Add(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Add(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -173,9 +179,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { xla::XlaBuilder b("mul<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Mul(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Mul(x, y); return b.Build().ConsumeValueOrDie(); }); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 341bf6ff1f37fa7cd81f41c02a941214067b1bd1..38d8cd653cbbe5b01325d6b478589d88909bac56 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -47,8 +48,8 @@ class XlaContext : public ResourceBase { XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool is_entry_computation, - const std::function* - shape_representation_fn); + const std::function( + const TensorShape&, DataType)>* shape_representation_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -83,7 +84,7 @@ class XlaContext : public ResourceBase { // As for Retval, but for return values that are compile-time constants. Status AddConstRetval(int retval_index, DataType dtype, - const xla::Literal& literal); + const xla::LiteralSlice& literal); // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` @@ -101,8 +102,8 @@ class XlaContext : public ResourceBase { // Returns the XLA shape to be used to represent a variable of TF `shape` // and `type`, or of an argument or return value of a top-level computation. - TensorShape RepresentationShape(const TensorShape& shape, - DataType type) const; + xla::StatusOr RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -160,7 +161,7 @@ class XlaContext : public ResourceBase { // should be represented in XLA. Parameters/return values will be shaped // according to this function, and reshaped back to/from their declared shapes // for computations. Must be non-null. - const std::function* + const std::function(const TensorShape&, DataType)>* shape_representation_fn_; // Cache of prebuilt computations indexed by their type. diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f1594193af09c7193f03b4685d3a7d4510d654dd..edbc5e95a8c22dd35dd7c384afdfaf80553eceaf 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -19,9 +19,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" @@ -32,103 +36,71 @@ namespace tensorflow { namespace { -Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - bool is_min, xla::XlaOp* argminmax) { - xla::XlaOp init_value; - const xla::XlaComputation* reducer; - if (is_min) { - init_value = XlaHelpers::MaxValue(builder, input_type); - reducer = ctx->GetOrCreateMin(input_type); - } else { - init_value = XlaHelpers::MinValue(builder, input_type); - reducer = ctx->GetOrCreateMax(input_type); - } - - xla::PrimitiveType xla_output_type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); - - xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer, - /*dimensions_to_reduce=*/{axis}); - std::vector broadcast_dims(input_shape.dims() - 1); - std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); - std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - // Compute a mask that has 1s for elements equal to the maximum. - xla::XlaOp partial_mask = builder->ConvertElementType( - builder->Eq(input, input_max, broadcast_dims), xla_output_type); - - // In order to make identity elements for a bitwise And, we: - // Left shift the 1 to the leftmost bit, yielding 0x10...0 - // Arithmetic right shift the 1 back to the rightmost bit, yielding - // 0xFF...F - int32 bits_in_type = - xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; - xla::XlaOp shift_amount = - XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); - xla::XlaOp full_mask = builder->ShiftRightArithmetic( - builder->ShiftLeft(partial_mask, shift_amount), shift_amount); - - // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its - // index. - xla::XlaOp iota; - - const int64 axis_size = input_shape.dim_size(axis); - TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); - xla::XlaOp product = - builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); - - // If there are multiple maximum elements, choose the one with the highest - // index. - xla::XlaOp output = - builder->Reduce(product, XlaHelpers::MinValue(builder, output_type), - *ctx->GetOrCreateMax(output_type), - /*dimensions_to_reduce=*/{axis}); - *argminmax = output; - return Status::OK(); +xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis, + bool is_min) { + xla::XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + xla::XlaOp init_value; + xla::XlaComputation reducer; + if (is_min) { + init_value = xla::MaxValue(builder, input_shape.element_type()); + reducer = + xla::CreateScalarMinComputation(input_shape.element_type(), builder); + } else { + init_value = xla::MinValue(builder, input_shape.element_type()); + reducer = + xla::CreateScalarMaxComputation(input_shape.element_type(), builder); + } + + xla::XlaOp input_max = xla::Reduce(input, init_value, reducer, + /*dimensions_to_reduce=*/{axis}); + std::vector broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + // Compute a mask that has 1s for elements equal to the maximum. + xla::XlaOp partial_mask = xla::ConvertElementType( + xla::Eq(input, input_max, broadcast_dims), output_type); + + // In order to make identity elements for a bitwise And, we: + // Left shift the 1 to the leftmost bit, yielding 0x10...0 + // Arithmetic right shift the 1 back to the rightmost bit, yielding + // 0xFF...F + int32 bits_in_type = + xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; + xla::XlaOp shift_amount = + xla::ConstantR0WithType(builder, output_type, bits_in_type); + xla::XlaOp full_mask = xla::ShiftRightArithmetic( + xla::ShiftLeft(partial_mask, shift_amount), shift_amount); + + // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its + // index. + + const int64 axis_size = xla::ShapeUtil::GetDimension(input_shape, axis); + xla::XlaOp iota = xla::Iota(builder, output_type, axis_size); + xla::XlaOp product = + xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + + // If there are multiple maximum elements, choose the one with the highest + // index. + return xla::Reduce(product, xla::MinValue(builder, output_type), + xla::CreateScalarMaxComputation(output_type, builder), + /*dimensions_to_reduce=*/{axis}); + }); } } // namespace -xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::MinValue(type)); -} - -xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::MaxValue(type)); -} - xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::Zero(type)); + return xla::ConstantLiteral(b, xla::Literal::Zero(type)); } xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::One(type)); -} - -xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) { - switch (data_type) { - case DT_HALF: - return b->ConstantR0( - static_cast(Eigen::NumTraits::epsilon())); - case DT_BFLOAT16: - return b->ConstantR0(bfloat16::epsilon()); - case DT_FLOAT: - return b->ConstantR0(std::numeric_limits::epsilon()); - case DT_DOUBLE: - return b->ConstantR0(std::numeric_limits::epsilon()); - default: - LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: " - << DataTypeString(data_type); - } + return xla::ConstantLiteral(b, xla::Literal::One(type)); } xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type, @@ -176,44 +148,14 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } -Status XlaHelpers::ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, xla::XlaOp* argmax) { - return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, - axis, /*is_min=*/false, argmax); -} - -Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, xla::XlaOp* argmin) { - return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, - axis, /*is_min=*/true, argmin); +xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, + int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/false); } -Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota) { - TensorShape linspace_shape({size}); - Tensor linspace; - switch (dtype) { - case DT_UINT8: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - case DT_INT32: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - case DT_INT64: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - default: - return errors::InvalidArgument("Invalid argument type ", - DataTypeString(dtype)); - } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); - *iota = builder->ConstantLiteral(linspace_literal); - return Status::OK(); +xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, + int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/true); } Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, @@ -245,25 +187,28 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, return errors::InvalidArgument("Invalid argument type ", DataTypeString(index_type)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::XlaOp one_hot_bool = builder->Eq( - indices, builder->ConstantLiteral(linspace_literal), broadcast_dims); + xla::XlaOp one_hot_bool = xla::Eq( + indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); // Selects the user-provided off_value and on_value values. - *one_hot = builder->Select( - one_hot_bool, builder->Broadcast(on_value, output_shape.dim_sizes()), - builder->Broadcast(off_value, output_shape.dim_sizes())); + *one_hot = xla::Select(one_hot_bool, + xla::Broadcast(on_value, output_shape.dim_sizes()), + xla::Broadcast(off_value, output_shape.dim_sizes())); return Status::OK(); } DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { + // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from + // repeated floating point additions. if (dtype == DT_BFLOAT16 || dtype == DT_HALF) { return DT_FLOAT; } @@ -275,7 +220,7 @@ xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder, const DataType new_element_type) { xla::PrimitiveType convert_to; TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); - return builder->ConvertElementType(operand, convert_to); + return xla::ConvertElementType(operand, convert_to); } } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index c3fdc5252e74363fe289eeabb2cb0d68298ee291..d6ca4ab9346593892917e8375b07a8790dc26e79 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -28,14 +28,6 @@ namespace tensorflow { // Helper methods for building XLA computations. class XlaHelpers { public: - // Returns a handle representing the minimum value of a scalar - // element of data_type. - static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type); - - // Returns a handle representing the maximum value of a scalar - // element of data_type. - static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type); - // Returns a handle representing the zero value of a scalar // element of data_type. static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type); @@ -44,10 +36,6 @@ class XlaHelpers { // element of data_type. static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type); - // Returns the machine epsilon for floating-point type `data_type`, i.e., - // the difference between 1.0 and the next representable value. - static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type); - // Returns a handle representing the given value of an integer scalar // element of data_type. // Note that unlike One and Zero, does not work on boolean types. @@ -65,25 +53,15 @@ class XlaHelpers { gtl::ArraySlice shape, xla::Literal* output); - // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and - // `input_dtype` are the shape and dtype of `input` respectively, and - // `output_type` is the dtype to use for `argmax`. - static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - xla::XlaOp* argmax); - - // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and - // `input_dtype` are the shape and dtype of `input` respectively, and - // `output_type` is the dtype to use for `argmin`. - static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - xla::XlaOp* argmin); - - // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. - static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota); + // Returns the argmax of `input` along `axis`. `output_type` is the type to + // use for the output. + static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, + int axis); + + // Returns the argmin of `input` along `axis`. `output_type` is the type to + // use for the output. + static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, + int axis); // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 76c68d81af4dd9ec40fe6b1c33b03a876a0c6dc6..359cb4c4670227e592ed4b8339825e7f95b16899 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -19,7 +19,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { @@ -38,8 +42,7 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().builder() != nullptr || - expression->resource() != nullptr); + CHECK(expression->handle().valid() || expression->resource() != nullptr); VLOG(1) << "Fetched T" << expression->handle(); return expression; } @@ -48,7 +51,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK_EQ(expression->handle().builder(), nullptr); + CHECK(!expression->handle().valid()); return const_cast(expression); } @@ -67,6 +70,20 @@ TensorShape XlaOpKernelContext::InputShape(int index) { return context_->input(index).shape(); } +DataType XlaOpKernelContext::input_type(int index) const { + return context_->input(index).dtype(); +} + +xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { + xla::PrimitiveType type; + Status status = DataTypeToPrimitiveType(input_type(index), &type); + if (!status.ok()) { + SetStatus(status); + return xla::PRIMITIVE_TYPE_INVALID; + } + return type; +} + Status XlaOpKernelContext::ConstantInput(int index, xla::Literal* constant_literal) { return ConstantInputReshaped( @@ -87,6 +104,25 @@ Status XlaOpKernelContext::ConstantInputReshaped( } const XlaExpression* expression = CastExpressionFromTensor(tensor); + auto copy_tensor_to_literal = [](const Tensor& tensor, + xla::Literal* literal) { + xla::Shape literal_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape)); + + *literal = xla::Literal(literal_shape); + + // memcpy over the payload ... + // TODO(phawkins): handle string types. + size_t total_bytes = tensor.TotalBytes(); + if (total_bytes > 0) { + void* dst_ptr = literal->untyped_data(); + const void* src_ptr = DMAHelper::base(&tensor); + memcpy(dst_ptr, src_ptr, total_bytes); + } + return Status::OK(); + }; + // If the tensor has a known constant value, there is no need to invoke XLA. if (expression->has_constant_value()) { Tensor temp(tensor.dtype()); @@ -95,19 +131,21 @@ Status XlaOpKernelContext::ConstantInputReshaped( // with the enclosing Tensor. return errors::Internal("Incompatible shapes in ConstantInputReshaped."); } - return HostTensorToLiteral(temp, constant_literal); + + return copy_tensor_to_literal(temp, constant_literal); } // Make sure we treat zero-element tensors as constant. if (new_shape.num_elements() == 0) { Tensor temp(tensor.dtype(), new_shape); - return HostTensorToLiteral(temp, constant_literal); + + return copy_tensor_to_literal(temp, constant_literal); } xla::XlaOp handle = expression->handle(); if (new_shape != tensor.shape()) { // Reshape the handle to the desired shape. - handle = builder()->Reshape(handle, new_shape.dim_sizes()); + handle = xla::Reshape(handle, new_shape.dim_sizes()); } // The XLA layout is specified minor to major, and TensorFlow's minor @@ -162,7 +200,8 @@ Status XlaOpKernelContext::ConstantInputReshaped( } // Converts an int32 or int64 scalar literal to an int64. -static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { +static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, + int64* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 0) { return errors::InvalidArgument("value is not a scalar"); } @@ -177,7 +216,8 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { } // Converts an float32 or float64 scalar literal to a float64. -static Status LiteralToFloat64Scalar(const xla::Literal& literal, double* out) { +static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, + double* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 0) { return errors::InvalidArgument("value is not a scalar"); } @@ -204,7 +244,7 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { } // Converts an int32 or int64 1D literal to an int64 vector. -static Status LiteralToInt64Vector(const xla::Literal& literal, +static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 1) { return errors::InvalidArgument("value is not 1D"); @@ -314,13 +354,13 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, } XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = - xla_context.RepresentationShape(variable->shape(), variable->type()); + TF_ASSIGN_OR_RETURN( + TensorShape representation_shape, + xla_context.RepresentationShape(variable->shape(), variable->type())); if (representation_shape == variable->shape()) { *value = variable->value(); } else { - *value = - builder()->Reshape(variable->value(), variable->shape().dim_sizes()); + *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); } return Status::OK(); } @@ -368,10 +408,11 @@ void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { const TensorShape& shape = constant.shape(); - xla::Literal literal; - OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); - xla::XlaOp handle = builder()->ConstantLiteral(literal); - CHECK_NE(handle.builder(), nullptr); + xla::BorrowingLiteral literal; + OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal)); + + xla::XlaOp handle = xla::ConstantLiteral(builder(), literal); + CHECK(handle.valid()); // Make the Tensor that will refer to the expression. Tensor* output = nullptr; @@ -416,7 +457,7 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, xla::XlaOp handle) { - TF_RET_CHECK(handle.builder() != nullptr); + TF_RET_CHECK(handle.valid()); const XlaExpression* expression = CastExpressionFromTensor(context_->input(input_index)); @@ -435,10 +476,10 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = - xla_context.RepresentationShape(shape, type); + TF_ASSIGN_OR_RETURN(TensorShape representation_shape, + xla_context.RepresentationShape(shape, type)); if (shape != representation_shape) { - handle = builder()->Reshape(handle, representation_shape.dim_sizes()); + handle = xla::Reshape(handle, representation_shape.dim_sizes()); } return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 667dc262ca03ca716ffbf015a78fc14c7a8b7c1a..2bde2c983d0cca05558e86a36698d6f0e097705a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" @@ -67,7 +68,12 @@ class XlaOpKernelContext { int num_inputs() const { return context_->num_inputs(); } // Returns the type of input 'index'. - DataType input_type(int index) { return context_->input(index).dtype(); } + DataType input_type(int index) const; + + // Returns the type of input 'index' as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType input_xla_type(int index); // Returns the shape of input 'index'. TensorShape InputShape(int index); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index e309cb1e34db7f8430c2494c03aed41652b7a167..46785bc1f0a1279bfd67a55844fe238d9797382b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -39,10 +39,10 @@ const char* const DEVICE_XLA_GPU = "XLA_GPU"; static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { const OpDef* op_def; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def)); + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def)); NodeDef node_def; node_def.set_name("_XlaLaunch-op"); - node_def.set_op("_XlaLaunch"); + node_def.set_op("XlaLaunch"); string kernel_class_name; TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr, &kernel_class_name)); @@ -71,16 +71,18 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible allow_resource_types settings."; return false; } - if (!x.has_device_whitelist || !y.has_device_whitelist) { - LOG(WARNING) << "Registrations of " << x.name - << " do not both have device whitelists."; + if (!x.has_device_whitelist && !y.has_device_whitelist) { + LOG(WARNING) << "Duplicate registrations of " << x.name + << "with no device whitelists."; return false; } - for (const auto& device : x.device_whitelist) { - if (y.device_whitelist.count(device) != 0) { - LOG(WARNING) << "Multiple registrations of " << x.name << " on device " - << device; - return false; + if (x.has_device_whitelist && y.has_device_whitelist) { + for (const auto& device : x.device_whitelist) { + if (y.device_whitelist.count(device) != 0) { + LOG(WARNING) << "Multiple registrations of " << x.name << " on device " + << device; + return false; + } } } if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) { @@ -157,97 +159,143 @@ void XlaOpRegistry::RegisterCompilationKernels() { registry.jit_kernels_registered_ = true; OpRegistryInterface* op_registry = OpRegistry::Global(); - for (const auto& op : registry.ops_) { - const string& op_name = op.first; - const std::unique_ptr& op_registration = op.second; - const OpDef* op_def; - Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); - if (!lookup_status.ok()) { - LOG(ERROR) << lookup_status.error_message(); - XLA_LOG_LINES( - ERROR, "Ops registered: \n" + - dynamic_cast(op_registry)->DebugString(true)); + // Order of op registration: + // The goal is to allow the co-existence of backend-specific kernels and + // generic kernels. To achieve this, we enforce the following order of + // registrations for one op: + // 1. Process op registration with device whitelists: + // this pass registers backend-specific kernels for this op. + // 2. Process op registration without device whitelists: + // this pass registers the kernels for all the other supported backends. + for (auto& ops : registry.ops_) { + const string& op_name = ops.first; + std::vector>& op_registrations = ops.second; + // Partition the op registration so that the ones with device whitelists + // precede the one without device whitelist. + std::partition(op_registrations.begin(), op_registrations.end(), + [](const std::unique_ptr& op_reg) { + return op_reg->has_device_whitelist; + }); + + // Collect a set of backend registered by ops with device whitelists. + // The op registration without whitelists will register a generic kernel + // for all other backends not in this set. + std::unordered_set whitelisted_backend; + for (auto& op_registration : op_registrations) { + if (op_registration->has_device_whitelist) { + whitelisted_backend.insert(op_registration->device_whitelist.begin(), + op_registration->device_whitelist.end()); + } } - TF_CHECK_OK(lookup_status); - std::unordered_set type_attrs; - for (const OpDef::AttrDef& attr_def : op_def->attr()) { - if (attr_def.type() == "type" || attr_def.type() == "list(type)") { - type_attrs.insert(attr_def.name()); + for (auto& op_registration : op_registrations) { + const OpDef* op_def; + Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); + if (!lookup_status.ok()) { + LOG(ERROR) << lookup_status.error_message(); + XLA_LOG_LINES( + ERROR, + "Ops registered: \n" + + dynamic_cast(op_registry)->DebugString(true)); } - } + TF_CHECK_OK(lookup_status); - // Checks there are no type constraints referring to unknown attributes. - for (const auto& constraint : op_registration->type_constraints) { - if (type_attrs.find(constraint.first) == type_attrs.end()) { - LOG(FATAL) << "Unknown type attribute " << constraint.first - << " in XLA op registration for " << op_name; + std::unordered_set type_attrs; + for (const OpDef::AttrDef& attr_def : op_def->attr()) { + if (attr_def.type() == "type" || attr_def.type() == "list(type)") { + type_attrs.insert(attr_def.name()); + } } - } - for (auto& backend : registry.backends_) { - // If the operator has a device whitelist, only register on whitelisted - // devices. - if (op_registration->has_device_whitelist && - op_registration->device_whitelist.find(backend.first) == - op_registration->device_whitelist.end()) { - continue; + // Checks there are no type constraints referring to unknown attributes. + for (const auto& constraint : op_registration->type_constraints) { + if (type_attrs.find(constraint.first) == type_attrs.end()) { + LOG(FATAL) << "Unknown type attribute " << constraint.first + << " in XLA op registration for " << op_name; + } } - std::unique_ptr kdef(new KernelDef); - kdef->set_op(op_registration->name); - kdef->set_device_type(backend.first); - - // Constrain each type attribute to the intersection of: - // a) the types supported by the backend, and - // b) the types allowed by the OpDef, and - // c) the type constraints. - for (const string& type_attr : type_attrs) { - KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); - attr_constraint->set_name(type_attr); - auto* allowed_values = - attr_constraint->mutable_allowed_values()->mutable_list(); - - const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); - const auto* op_def_allowed_types = - op_def_attr.has_allowed_values() - ? &op_def_attr.allowed_values().list().type() - : nullptr; - auto constraint_it = op_registration->type_constraints.find(type_attr); - const std::set* type_constraints = - constraint_it != op_registration->type_constraints.end() - ? &constraint_it->second - : nullptr; - for (DataType dtype : backend.second.supported_types) { - // Filter out types that aren't allowed by the OpDef. - if (op_def_allowed_types != nullptr && - std::find(op_def_allowed_types->begin(), - op_def_allowed_types->end(), - dtype) == op_def_allowed_types->end()) { - continue; + for (auto& backend : registry.backends_) { + // If the operator has a device whitelist, only register on whitelisted + // devices. + if (op_registration->has_device_whitelist && + op_registration->device_whitelist.find(backend.first) == + op_registration->device_whitelist.end()) { + continue; + } + + // If the operator does NOT has a device whitelist, skip all devices + // that has already been registered. + if (!op_registration->has_device_whitelist && + whitelisted_backend.find(backend.first) != + whitelisted_backend.end()) { + continue; + } + + std::unique_ptr kdef(new KernelDef); + kdef->set_op(op_registration->name); + kdef->set_device_type(backend.first); + + // Constrain each type attribute to the intersection of: + // a) the types supported by the backend, and + // b) the types allowed by the OpDef, and + // c) the type constraints. + bool unsatisfiable_type_constraint = false; + for (const string& type_attr : type_attrs) { + KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); + attr_constraint->set_name(type_attr); + auto* allowed_values = + attr_constraint->mutable_allowed_values()->mutable_list(); + + const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); + const auto* op_def_allowed_types = + op_def_attr.has_allowed_values() + ? &op_def_attr.allowed_values().list().type() + : nullptr; + auto constraint_it = + op_registration->type_constraints.find(type_attr); + const std::set* type_constraints = + constraint_it != op_registration->type_constraints.end() + ? &constraint_it->second + : nullptr; + for (DataType dtype : backend.second.supported_types) { + // Filter out types that aren't allowed by the OpDef. + if (op_def_allowed_types != nullptr && + std::find(op_def_allowed_types->begin(), + op_def_allowed_types->end(), + dtype) == op_def_allowed_types->end()) { + continue; + } + // Filter out types based on the type constraints. + if (type_constraints != nullptr && + type_constraints->find(dtype) == type_constraints->end()) { + continue; + } + // Passed all the filters, this type is allowed. + allowed_values->add_type(dtype); } - // Filter out types based on the type constraints. - if (type_constraints != nullptr && - type_constraints->find(dtype) == type_constraints->end()) { - continue; + if (op_registration->allow_resource_types) { + allowed_values->add_type(DT_RESOURCE); + } + // Don't build KernelDefs that have unsatisfiable type constraints. + if (allowed_values->type().empty()) { + unsatisfiable_type_constraint = true; + break; } - // Passed all the filters, this type is allowed. - allowed_values->add_type(dtype); } - if (op_registration->allow_resource_types) { - allowed_values->add_type(DT_RESOURCE); + if (unsatisfiable_type_constraint) continue; + + if (backend.second.op_filter != nullptr && + !backend.second.op_filter(kdef.get())) { + continue; } + VLOG(2) << "XLA op registration: device: " << backend.first + << " op: " << op_name; + registry.kernel_registrars_.emplace_back( + new kernel_factory::OpKernelRegistrar( + new KernelDef(*kdef), "XlaJitOp", op_registration->factory)); + backend.second.kernel_defs.push_back(std::move(kdef)); } - if (backend.second.op_filter != nullptr && - !backend.second.op_filter(kdef.get())) { - continue; - } - VLOG(2) << "XLA op registration: device: " << backend.first - << " op: " << op_name; - registry.kernel_registrars_.emplace_back( - new kernel_factory::OpKernelRegistrar( - new KernelDef(*kdef), "XlaJitOp", op_registration->factory)); - backend.second.kernel_defs.push_back(std::move(kdef)); } } } @@ -265,12 +313,12 @@ std::vector XlaOpRegistry::DeviceKernels( << "Unknown backend " << compilation_device_name; for (const std::unique_ptr& k : it->second.kernel_defs) { auto op_iter = registry.ops_.find(k->op()); - CHECK(op_iter != registry.ops_.end()); + CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty()); // The test in IsCompatible ensures that if there are multiple matching // registrations for this op name, they all have the same value of // compilation_only, so only the first match needs to be tested. if (include_compilation_only_kernels || - !op_iter->second->compilation_only) { + !op_iter->second.front()->compilation_only) { kernels.push_back(k.get()); } } @@ -282,10 +330,13 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); auto it = registry.ops_.find(op); - if (it == registry.ops_.end()) { + if (it == registry.ops_.end() || it->second.empty()) { return nullptr; } - return &it->second->compile_time_constant_inputs; + // The test in IsCompatible ensures that if there are multiple matching + // registrations for this op name, they all have the same value of + // compile_time_constant_inputs, so only the first match is returned. + return &it->second.front()->compile_time_constant_inputs; } std::vector XlaOpRegistry::BackendNames() { @@ -378,16 +429,15 @@ XlaOpRegistrar::XlaOpRegistrar( std::unique_ptr registration) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); mutex_lock lock(registry.mutex_); - auto existing_ops = registry.ops_.equal_range(registration->name); - for (auto existing = existing_ops.first; existing != existing_ops.second; - ++existing) { - if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) { + auto& existing_ops = registry.ops_[registration->name]; + for (auto& existing : existing_ops) { + if (!XlaOpRegistry::IsCompatible(*existing, *registration)) { LOG(FATAL) << "XLA op registration " << registration->name << " is incompatible with existing registration of the same name."; } } - registry.ops_.emplace(registration->name, std::move(registration)); + existing_ops.emplace_back(std::move(registration)); } XlaBackendRegistrar::XlaBackendRegistrar( diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index e255b01dd7fdcb095c7992d4352d2d9bb7d36ac3..2d4593ea4999ad6d8cd0f0e2eec9c6d69c3020b8 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -203,7 +203,7 @@ class XlaOpRegistry { // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP. // Registrations present under the same key must satisfy IsCompatible above, // and this is checked during registration. - std::unordered_multimap> ops_ + std::unordered_map>> ops_ GUARDED_BY(mutex_); // Have we already registered the JIT kernels on the JIT devices? diff --git a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b3b15b1af7636fddd4c29477cbfe6f9761f2c47 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// This test is to verify the correctness of XLA op registration with specific +// backend overrides. + +// A dummy backend-specific OpKernel for CPU. +class DummyCPUOp : public XlaOpKernel { + public: + explicit DummyCPUOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, ctx->Input(0)); + } +}; + +// A dummy generic OpKernel for all backends. +class DummyGenericOp : public XlaOpKernel { + public: + explicit DummyGenericOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, ctx->Input(0)); + } +}; + +REGISTER_OP("DummyDuplicateOp") + .Attr("T: {float, int32}") + .Input("input: int32") + .Output("output: int32") + .Doc(R"doc( +A dummy Op. + +input: dummy input. +output: dummy output. +)doc"); + +// Register the DummyCPUOp kernel for CPU with type INT32. +REGISTER_XLA_OP(Name("DummyDuplicateOp") + .Device(DEVICE_CPU_XLA_JIT) + .TypeConstraint("T", DT_INT32), + DummyCPUOp); +// Register the DummyGeneric kernel for all registered device (except CPU since +// it is already registered), with type FLOAT. +REGISTER_XLA_OP(Name("DummyDuplicateOp").TypeConstraint("T", DT_FLOAT), + DummyGenericOp); + +// Test the correctness of registered kernels. The kernel registered for CPU +// should have type INT32 while all other kernels should have type FLOAT. +TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) { + XlaOpRegistry::RegisterCompilationKernels(); + auto registered_kernels = GetAllRegisteredKernels().kernel(); + for (const auto& kernels : registered_kernels) { + if (kernels.op() == "DummyDuplicateOp") { + EXPECT_EQ(kernels.constraint_size(), 1); + EXPECT_EQ(kernels.constraint(0).name(), "T"); + if (kernels.device_type() == "XLA_CPU_JIT") { + EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0), + DT_INT32); + } else { + EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0), + DT_FLOAT); + } + } + } +} + +// A dummy generic OpKernel for all backends. +class DummyInfeasibleTypeConstraintOp : public XlaOpKernel { + public: + explicit DummyInfeasibleTypeConstraintOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + LOG(FATAL) << "unreachable"; + } +}; + +REGISTER_OP("DummyInfeasibleTypeConstraintOp") + .Attr("T: {float, string}") + .Input("input: T") + .Output("output: T") + .Doc(R"doc( +A dummy Op. + +input: dummy input. +output: dummy output. +)doc"); +REGISTER_XLA_OP( + Name("DummyInfeasibleTypeConstraintOp").TypeConstraint("T", DT_STRING), + DummyInfeasibleTypeConstraintOp); + +TEST(XlaOpRegistryTest, OpWithInfeasibleTypeConstraintIsNotRegistered) { + XlaOpRegistry::RegisterCompilationKernels(); + auto registered_kernels = GetAllRegisteredKernels().kernel(); + for (const auto& kernels : registered_kernels) { + // The operator should not be registered. + EXPECT_NE(kernels.op(), "DummyInfeasibleTypeConstraintOp"); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 540c65c597f20d5bb26494e56c09ff2187cfb0db..baea8149658ec0849ebb570931ca68518ec5284e 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -89,16 +90,16 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } switch (kind_) { case kVariable: { - value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), - shape_.dim_sizes()); + value_ = + xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes()); break; } case kTensorArray: { TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), - ta_shape.dim_sizes()); + value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()); break; } case kStack: { @@ -106,9 +107,9 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); value_ = - builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_), - ta_shape.dim_sizes()), - builder->ConstantR0(0)}); + xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()), + xla::ConstantR0(builder, 0)}); break; } @@ -130,8 +131,8 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - xla::XlaOp gradient_value = builder->Broadcast( - XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); + xla::XlaOp gradient_value = + xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/strings::StrCat("TensorArrayGrad: ", name_), @@ -152,7 +153,7 @@ Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { for (const auto& gradient : tensor_array_gradients_) { elems.push_back(gradient.second->value_); } - *pack = builder->Tuple(elems); + *pack = xla::Tuple(builder, elems); } return Status::OK(); } @@ -168,7 +169,7 @@ Status XlaResource::SetFromPack(const std::set& gradient_sources, } else { TF_RET_CHECK(kind_ == kTensorArray); int pos = 0; - auto v = builder->GetTupleElement(pack, pos++); + auto v = xla::GetTupleElement(pack, pos++); if (!initialized()) { initial_value_ = v; } @@ -178,7 +179,7 @@ Status XlaResource::SetFromPack(const std::set& gradient_sources, XlaResource* gradient; TF_RETURN_IF_ERROR( GetOrCreateTensorArrayGradient(source, builder, &gradient)); - auto v = builder->GetTupleElement(pack, pos++); + auto v = xla::GetTupleElement(pack, pos++); if (!gradient->initialized()) { gradient->initial_value_ = v; } diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 9ce36d1aa7622334b2acfbe9aa85d7419c4772ed..4de18a77887496d30e3b1407ecd9042e619653af 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -75,7 +75,7 @@ class XlaResource { const xla::XlaOp& initial_value() const { return initial_value_; } // A variable is initialized if it has a value. - bool initialized() const { return value_.builder() != nullptr; } + bool initialized() const { return value_.valid(); } // Sets the type and shape of the resource. The type and shape of a resource // must not change once the variable has been initialized. diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index fb1991e9ec29359dc51151add3426fd45af910fb..03e542855ba0e3ae81e0b754eb319cadbd5079ba 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -53,7 +53,6 @@ xla_proto_library( deps = [ ":xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:session_proto", ], ) @@ -143,30 +142,15 @@ cc_library( cc_library( name = "statusor", - srcs = ["statusor.cc"], hdrs = [ "statusor.h", - "statusor_internals.h", ], visibility = ["//visibility:public"], deps = [ ":status", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "statusor_test", - size = "small", - srcs = ["statusor_test.cc"], - deps = [ - ":statusor", - ":test", - ":types", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/stream_executor", ], ) @@ -176,6 +160,7 @@ cc_library( hdrs = [ "iterator_util.h", "map_util.h", + "overflow_util.h", "ptr_util.h", "util.h", ], @@ -251,7 +236,7 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework_internal", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", @@ -310,7 +295,6 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:lib", ], ) @@ -583,6 +567,7 @@ tf_cc_test( ":shape_util", ":test", ":xla_data_proto", + "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 989cd61d9fc2f18f88780f337bd13a3b6dca5918..8f08d3b2e04670ad6590aca1db0fd9d25faed83f 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -63,7 +63,6 @@ cc_library( srcs = ["client.cc"], hdrs = ["client.h"], deps = [ - ":computation", ":global_data", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", @@ -76,7 +75,7 @@ cc_library( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -87,6 +86,7 @@ cc_library( hdrs = ["executable_build_options.h"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", @@ -99,7 +99,6 @@ cc_library( hdrs = ["local_client.h"], deps = [ ":client", - ":computation", ":executable_build_options", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", @@ -111,6 +110,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", @@ -126,7 +126,6 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", - ":computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -162,22 +161,6 @@ cc_library( ], ) -cc_library( - name = "computation", - srcs = ["computation.cc"], - hdrs = ["computation.h"], - deps = [ - "//tensorflow/compiler/xla:service_interface", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:session_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "sharding_builder", srcs = ["sharding_builder.cc"], diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 0a79b3cf279e2585b1b070ec875d95d755044563..3d596a6e65430b6e9692aabd65fc8aa84b7b873d 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -64,7 +64,7 @@ StatusOr> Client::Transfer( } StatusOr> Client::TransferToServer( - const Literal& literal, const DeviceHandle* device_handle) { + const LiteralSlice& literal, const DeviceHandle* device_handle) { TransferToServerRequest request; *request.mutable_literal() = literal.ToProto(); if (device_handle) { @@ -91,7 +91,7 @@ StatusOr> Client::TransferToServer( return MakeUnique(stub_, response.data()); } -Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, +Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, const DeviceHandle* device_handle) { TransferToInfeedRequest request; *request.mutable_literal() = literal.ToProto(); @@ -161,22 +161,6 @@ Status Client::ResetDevice() { return Status::OK(); } -StatusOr> Client::ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr data, - Execute(computation, arguments, execution_options, execution_profile)); - - const Shape* shape_with_output_layout = nullptr; - if (execution_options && execution_options->has_shape_with_output_layout()) { - shape_with_output_layout = &execution_options->shape_with_output_layout(); - } - return Transfer(*data, shape_with_output_layout); -} - StatusOr> Client::ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -221,65 +205,11 @@ StatusOr> Client::ComputeConstant( return Literal::CreateFromProto(response.literal()); } -StatusOr Client::LoadSnapshot(const SessionModule& module) { - LoadComputationSnapshotRequest request; - *request.mutable_module() = module; - LoadComputationSnapshotResponse response; - - Status s = stub_->LoadComputationSnapshot(&request, &response); - if (!s.ok()) { - return s; - } - - VLOG(1) << "load snapshot response: " << response.ShortDebugString(); - return Computation(stub_, response.computation()); -} - StatusOr Client::LoadSnapshot(const HloSnapshot& module) { TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module()); return XlaComputation(module.hlo().hlo_module()); } -StatusOr> Client::Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - ExecuteRequest request; - *request.mutable_computation() = computation.handle(); - - if (execution_options == nullptr) { - *request.mutable_execution_options() = CreateDefaultExecutionOptions(); - } else { - *request.mutable_execution_options() = *execution_options; - } - for (GlobalData* argument : arguments) { - CHECK(argument != nullptr) << "Argument pointers must not be null."; - *request.add_arguments() = argument->handle(); - } - - ExecuteResponse response; - VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->Execute(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - if (execution_profile != nullptr) { - *execution_profile = response.profile(); - if (VLOG_IS_ON(1)) { - TF_ASSIGN_OR_RETURN( - auto execution_stats, - ExecutionStatsAsString(computation, response.profile())); - VLOG(1) << execution_stats; - } - } - - return MakeUnique(stub_, response.output()); -} - StatusOr> Client::Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -320,41 +250,6 @@ StatusOr> Client::Execute( return MakeUnique(stub_, response.output()); } -StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { - ExecuteParallelRequest request; - - for (const ComputationInstance& computation : computations) { - ExecuteRequest single_request; - *single_request.mutable_computation() = computation.computation.handle(); - for (GlobalData* argument : computation.arguments) { - *single_request.add_arguments() = argument->handle(); - } - *single_request.mutable_execution_options() = computation.execution_options; - *request.add_requests() = single_request; - } - - ExecuteParallelResponse response; - VLOG(1) << "making execute-parallel request: " << request.ShortDebugString(); - Status s = stub_->ExecuteParallel(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> outputs; - for (size_t i = 0; i < computations.size(); ++i) { - outputs.push_back( - MakeUnique(stub_, response.responses(i).output())); - if (computations[i].execution_profile != nullptr) { - *computations[i].execution_profile = response.responses(i).profile(); - } - } - - return std::move(outputs); -} - StatusOr>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice computations) { ExecuteGraphParallelRequest request; @@ -449,24 +344,6 @@ StatusOr>> Client::DeconstructTuple( return std::move(handles); } -StatusOr Client::GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const { - ComputationStatsRequest request; - *request.mutable_computation() = computation.handle(); - *request.mutable_debug_options() = debug_options; - ComputationStatsResponse response; - - VLOG(1) << "making computation stats request"; - Status s = stub_->GetComputationStats(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - CHECK(response.has_stats()); - return response.stats(); -} - StatusOr Client::GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const { @@ -488,23 +365,6 @@ StatusOr Client::GetComputationStats( return response.stats(); } -StatusOr> Client::GetComputationShape( - const Computation& computation) { - GetComputationShapeRequest request; - *request.mutable_computation() = computation.handle(); - GetComputationShapeResponse response; - - VLOG(1) << "making get-computation-shape request"; - Status s = stub_->GetComputationShape(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return WrapUnique(response.release_program_shape()); -} - StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); @@ -527,28 +387,6 @@ StatusOr Client::GetShape(const GlobalData& data) { return response.shape(); } -StatusOr Client::ExecutionStatsAsString( - const Computation& computation, const ExecutionProfile& profile) { - TF_ASSIGN_OR_RETURN( - auto computation_stats, - GetComputationStats(computation, - legacy_flags::GetDebugOptionsFromFlags())); - int64 total_flops = - computation_stats.flop_count() + computation_stats.transcendental_count(); - if (profile.compute_time_ns() > 0) { - int64 nanoseconds = profile.compute_time_ns(); - int64 cycle_count = profile.compute_cycle_count(); - double gflops = total_flops / nanoseconds; - return tensorflow::strings::StrCat( - "[Execution Statistics] flop count: ", computation_stats.flop_count(), - ", transcendental count: ", computation_stats.transcendental_count(), - ", compute execution time: ", nanoseconds, " nsec", - ", compute cycles: ", cycle_count, ", performance: ", gflops, - "gflop/s"); - } - return string("[Execution Statistics] not available."); -} - StatusOr Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index a63ff4c56d1dd78c7abfa2bf163b5fbd54d82b2b..68f0d0ac78c859fde7a6a007cd250b047a7bfcda 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -19,11 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,21 +51,6 @@ class Client { // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. - StatusOr> Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and returns the global - // data that was produced from the execution. - // * If execution_options is not nullptr, these options are passed to the - // service to affect how it compiles our computation. (The pointer does not - // need to live beyond this call.) - // * If execution_profile is not nullptr then the pointed-to ExecutionProfile - // will be filled with profile data from the execution. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -78,34 +62,6 @@ class Client { // executed on the devices associated with the handles by partitioning the // computation based on the attached sharding attributes. Otherwise, a // device is chosen by the service. - struct ComputationInstance { - const Computation& computation; - std::vector arguments; - ExecutionOptions execution_options; - ExecutionProfile* execution_profile; - - ComputationInstance(const Computation& computation, - std::vector arguments, - ExecutionOptions execution_options, - ExecutionProfile* execution_profile) - : computation(computation), - arguments(std::move(arguments)), - execution_options(execution_options), - execution_profile(execution_profile) {} - }; - - // Executes a list ComputationInstances and returns global data produced from - // each computation. - StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); - - // A struct to represent a computation instance to be executed. - // * If execution_options.device_handles is not empty, the computation is - // executed on the devices associated with the handles by partitioning the - // computation based on the attached sharding attributes. Otherwise, a - // device is chosen by the service. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct XlaComputationInstance { const XlaComputation& computation; std::vector arguments; @@ -125,7 +81,6 @@ class Client { // Executes a list XlaComputationInstances and returns global data produced // from each computation. // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> ExecuteParallel( tensorflow::gtl::ArraySlice computations); @@ -152,14 +107,14 @@ class Client { // device (and its replicas if replication is enabled). Otherwise, data is // transferred to the default device (and its replicas). StatusOr> TransferToServer( - const Literal& literal, const DeviceHandle* device_handle = nullptr); + const LiteralSlice& literal, const DeviceHandle* device_handle = nullptr); // Transfer the given literal to the Infeed interface of the device. // // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - Status TransferToInfeed(const Literal& literal, int64 replica_id = 0, + Status TransferToInfeed(const LiteralSlice& literal, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); // Transfers from the Outfeed of the device. @@ -177,17 +132,6 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). - StatusOr> ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and transfers the result - // to the client as a literal. Parameters are defined the same as for - // Execute() and Transfer(). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -209,8 +153,6 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; @@ -223,12 +165,6 @@ class Client { const GlobalData& data); // Retrieves the statistics of the given computation. - StatusOr GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const; - - // Retrieves the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const; @@ -239,13 +175,6 @@ class Client { // As above, but returns the shape of the provided computation (parameter // types/names and return type). - StatusOr> GetComputationShape( - const Computation& computation); - - // As above, but returns the shape of the provided computation (parameter - // types/names and return type). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> GetComputationShape( const XlaComputation& computation); @@ -253,9 +182,6 @@ class Client { // two computations via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); - StatusOr LoadSnapshot(const SessionModule& module); - - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr LoadSnapshot(const HloSnapshot& module); ServiceInterface* stub() { return stub_; } @@ -263,8 +189,6 @@ class Client { private: // Returns the execution statistics (e.g., gflop/s) as a string from the // ExecutionProfile returned from an execution of the computation. - StatusOr ExecutionStatsAsString(const Computation& computation, - const ExecutionProfile& profile); StatusOr ExecutionStatsAsString(const XlaComputation& computation, const ExecutionProfile& profile); diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 96e38bca01087991943aff40ed1cb3e21f9e6cba..5c9abad4c3126be5e45e96c770c0679fe8606788 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -21,28 +21,11 @@ limitations under the License. namespace xla { -StatusOr>> -CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector service_instances; - service_instances.reserve(computations.size()); - for (const AotComputationInstance& instance : computations) { - service_instances.push_back({}); - CompileOnlyService::AotComputationInstance& service_instance = - service_instances.back(); - TF_RET_CHECK(instance.computation != nullptr); - service_instance.computation = instance.computation->handle(); - service_instance.argument_layouts = instance.argument_layouts; - service_instance.result_layout = instance.result_layout; - } - return compiler_service_->CompileAheadOfTime(service_instances, options); -} - StatusOr>> CompileOnlyClient::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { + const AotCompilationOptions& options, + std::unique_ptr* metadata) { std::vector service_instances; service_instances.reserve(computations.size()); for (const AotXlaComputationInstance& instance : computations) { @@ -54,7 +37,8 @@ CompileOnlyClient::CompileAheadOfTime( service_instance.argument_layouts = instance.argument_layouts; service_instance.result_layout = instance.result_layout; } - return compiler_service_->CompileAheadOfTime(service_instances, options); + return compiler_service_->CompileAheadOfTime(service_instances, options, + metadata); } int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index c8725b8517484acdaf093bc3b34adb00f69155b1..332c96503637344d56e363e19db4880c37ca9684 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -38,26 +37,7 @@ class CompileOnlyClient : public Client { CompileOnlyClient(const CompileOnlyClient&) = delete; void operator=(const CompileOnlyClient&) = delete; - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - const Computation* computation; - // Inform the compiler of the expected layout for arguments. - std::vector argument_layouts; - // Specifies the expected result layout. - const Shape* result_layout; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options); - // A description of an xla computation to compile using CompileAheadOfTime. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { const XlaComputation* computation; // Inform the compiler of the expected layout for arguments. @@ -66,15 +46,15 @@ class CompileOnlyClient : public Client { const Shape* result_layout; }; - // Compiles a list of xla computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. + // Compiles a list of xla computations for ahead-of-time execution. + // This is intended for use in static compilation. The |options| + // parameter describes the target for which the compiler should emit + // code. |metadata|, if provided, is populated during compilation. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options); + const AotCompilationOptions& options, + std::unique_ptr* metadata = nullptr); // Returns the size of a pointer in bytes for a given triple. static int64 PointerSizeForTriple(tensorflow::StringPiece triple); diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc deleted file mode 100644 index e6c57bda0f0c4cb969939883efebcf3a6d6be381..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation.h" - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { - -Computation::Computation() : parent_(nullptr) {} - -Computation::Computation(ServiceInterface* parent, - const ComputationHandle& handle) - : handle_(handle), parent_(parent) {} - -Computation::Computation(Computation&& computation) - : handle_(std::move(computation.handle_)), parent_(computation.parent_) { - computation.ResetWithoutFreeing(); -} - -void Computation::Reset() { - // TODO(b/34469253) deallocate any owned computation. - ResetWithoutFreeing(); -} - -StatusOr> Computation::Snapshot() const { - SnapshotComputationRequest request; - *request.mutable_computation() = handle_; - SnapshotComputationResponse response; - - TF_RETURN_IF_ERROR(parent_->SnapshotComputation(&request, &response)); - - return WrapUnique(response.release_module()); -} - -Computation::~Computation() { Reset(); } - -Computation& Computation::operator=(Computation&& computation) { - if (&computation != this) { - Reset(); - handle_ = computation.handle_; - parent_ = computation.parent_; - computation.ResetWithoutFreeing(); - } - return *this; -} - -void Computation::ResetWithoutFreeing() { - handle_.Clear(); - parent_ = nullptr; -} - -StatusOr Computation::GetProgramShape() const { - GetComputationShapeRequest request; - *request.mutable_computation() = handle_; - GetComputationShapeResponse response; - - TF_RETURN_IF_ERROR(parent_->GetComputationShape(&request, &response)); - - return std::move(*response.mutable_program_shape()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h deleted file mode 100644 index 9a1bcde76387297cb7f374b25baad1d5ec284859..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ - -#include - -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service_interface.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" - -namespace xla { - -// Wraps a ComputationHandle protobuf with a lifetime. Computation is -// movable and not copyable to capture the same kind of unique -// ownership that std::unique_ptr represents. -// -// TODO(b/74197823): Deprecated. Use XlaComputation instead. -class Computation { - public: - // Creates a null Computation. - Computation(); - - // parent: stub for the service on which we will deallocate the computation - // when it is no longer needed. - // handle: the computation handle protobuf from the service. - Computation(ServiceInterface* parent, const ComputationHandle& handle); - - Computation(Computation&& computation); - - // Deallocates the computation. - ~Computation(); - - Computation& operator=(Computation&& computation); - - // Returns the underlying handle. - const ComputationHandle& handle() const { return handle_; } - - // Sets handle to a null state and clears any owned computation. - void Reset(); - - // Requests that we snapshot the computation into a serializable protocol - // buffer form. - StatusOr> Snapshot() const; - - // Returns true if this object is a null Computation. - bool IsNull() const { return parent_ == nullptr; } - - // Returns the "program shape" (parameter and return shapes) for this - // computation. - StatusOr GetProgramShape() const; - - private: - void ResetWithoutFreeing(); - - ComputationHandle handle_; // Handle that is wrapped by this class. - - // Stub that the handle is deallocated on when this object's lifetime ends. - ServiceInterface* parent_; - - TF_DISALLOW_COPY_AND_ASSIGN(Computation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 6e3c5cb484b8f1ef053fa287a4d462aeb886e530..7dee41f6a05025ec196b78e54015e8e71777031f 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -87,6 +87,18 @@ ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { return dump_optimized_hlo_proto_to_; } +ExecutableBuildOptions& +ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_unoptimized_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { + return dump_unoptimized_hlo_proto_to_; +} + ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( tensorflow::StringPiece dirpath) { dump_per_pass_hlo_proto_to_ = dirpath.ToString(); diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 11f10983606fe02b1edb11a260edde8e5f9a726f..9dc9be4423564fb967b247c2d1df31099cb80237 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -64,6 +65,13 @@ class ExecutableBuildOptions { tensorflow::StringPiece dirpath); const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO + // protobuf to (as in DebugOptions). + ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_unoptimized_hlo_proto_to() + const; + // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs // to (as in DebugOptions). ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( @@ -76,6 +84,13 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_hlo_profile(bool enabled); tensorflow::gtl::optional hlo_profile() const; + void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) { + disabled_hlo_passes_.push_back(std::string(pass_name)); + } + const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { + return disabled_hlo_passes_; + } + // Returns a string representation of the build options, suitable for // debugging. string ToString() const; @@ -87,8 +102,10 @@ class ExecutableBuildOptions { bool result_layout_set_ = false; tensorflow::gtl::optional generate_hlo_graph_; tensorflow::gtl::optional dump_optimized_hlo_proto_to_; + tensorflow::gtl::optional dump_unoptimized_hlo_proto_to_; tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; + std::vector disabled_hlo_passes_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index d49d959a6c8112d3701857a70cecb24701c7b6d9..a6b9b4725324adf26a136d490cf28a89c92571c0 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -13,11 +13,18 @@ filegroup( ]), ) +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") + +# Generate test_suites for all backends, named "${backend}_tests". +generate_backend_suites() + cc_library( name = "arithmetic", srcs = ["arithmetic.cc"], hdrs = ["arithmetic.h"], deps = [ + ":constants", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", @@ -28,6 +35,88 @@ cc_library( ], ) +cc_library( + name = "constants", + srcs = ["constants.cc"], + hdrs = ["constants.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + ], +) + +xla_test( + name = "constants_test", + srcs = ["constants_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "math", + srcs = ["math.cc"], + hdrs = ["math.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + ], +) + +xla_test( + name = "math_test", + srcs = ["math_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":math", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "numeric", + srcs = ["numeric.cc"], + hdrs = ["numeric.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + ], +) + +xla_test( + name = "numeric_test", + srcs = ["numeric_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":numeric", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "testing", srcs = ["testing.cc"], diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index a1d34796ccfd86f2025eff0ecb51338eb6a9b1da..978fc40f3492cd7d9d7831c370b287bf45e6d3e0 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -42,8 +43,8 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, } const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); + auto lhs = Parameter(b.get(), 0, scalar, "lhs"); + auto rhs = Parameter(b.get(), 1, scalar, "rhs"); generator(b.get(), lhs, rhs); return b->BuildAndNoteError(); } @@ -55,7 +56,7 @@ XlaComputation CreateScalarAddComputation(PrimitiveType type, return CreateScalarComputation( "add", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Add(lhs, rhs); + return Add(lhs, rhs); }); } @@ -64,17 +65,15 @@ XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, return CreateScalarComputation( "mul", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Mul(lhs, rhs); + return Mul(lhs, rhs); }); } XlaComputation CreateScalarGeComputation(PrimitiveType type, XlaBuilder* builder) { - return CreateScalarComputation( - "ge", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Ge(lhs, rhs); - }); + return CreateScalarComputation("ge", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, + const XlaOp& rhs) { return Ge(lhs, rhs); }); } XlaComputation CreateScalarMaxComputation(PrimitiveType type, @@ -82,7 +81,7 @@ XlaComputation CreateScalarMaxComputation(PrimitiveType type, return CreateScalarComputation( "max", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Max(lhs, rhs); + return Max(lhs, rhs); }); } @@ -91,7 +90,7 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type, return CreateScalarComputation( "min", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Min(lhs, rhs); + return Min(lhs, rhs); }); } @@ -99,26 +98,27 @@ XlaComputation CreateScalarAndComputation(XlaBuilder* builder) { return CreateScalarComputation( "and", PRED, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->And(lhs, rhs); + return And(lhs, rhs); }); } XlaComputation CreateScalarOrComputation(XlaBuilder* builder) { - return CreateScalarComputation( - "or", PRED, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Or(lhs, rhs); - }); + return CreateScalarComputation("or", PRED, builder, + [](XlaBuilder* b, const XlaOp& lhs, + const XlaOp& rhs) { return Or(lhs, rhs); }); } -StatusOr Any(const XlaOp& predicates, XlaBuilder* builder) { - auto f = builder->ConstantR0(false); - XlaComputation logical_or = CreateScalarOrComputation(builder); - TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, - builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); - std::iota(all_dimensions.begin(), all_dimensions.end(), 0); - return builder->Reduce(predicates, f, logical_or, all_dimensions); +XlaOp Any(XlaOp predicates) { + XlaBuilder* builder = predicates.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + auto f = ConstantR0(builder, false); + XlaComputation logical_or = CreateScalarOrComputation(builder); + TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, + builder->GetShape(predicates)); + std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::iota(all_dimensions.begin(), all_dimensions.end(), 0); + return Reduce(predicates, f, logical_or, all_dimensions); + }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 64b6b7d63353165e45bf12d35126a7eeef9e56e4..d0b916e8c8f742406caad0571d6e99224ed81404 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -53,7 +53,7 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder); // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false. -StatusOr Any(const XlaOp& predicates, XlaBuilder* builder); +XlaOp Any(XlaOp predicates); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc new file mode 100644 index 0000000000000000000000000000000000000000..1686389a234659a433f1508bd3e0458793541e47 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/constants.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +XlaOp Zero(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, Literal::Zero(type)); +} + +XlaOp Zeros(XlaBuilder* builder, const Shape& shape) { + return Broadcast(Zero(builder, shape.element_type()), + AsInt64Slice(shape.dimensions())); +} + +XlaOp ZerosLike(XlaOp prototype) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + return Zeros(builder, shape); + }); +} + +XlaOp One(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, Literal::One(type)); +} + +XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0( + builder, + static_cast(Eigen::NumTraits::epsilon())); + case BF16: + return ConstantR0(builder, bfloat16::epsilon()); + case F32: + return ConstantR0(builder, std::numeric_limits::epsilon()); + case F64: + return ConstantR0(builder, + std::numeric_limits::epsilon()); + default: + return builder->ReportError(InvalidArgument( + "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str())); + } +} + +XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, Literal::MinValue(type)); +} + +XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0(builder, + Eigen::NumTraits::lowest()); + case BF16: + return ConstantR0(builder, bfloat16::lowest()); + case F32: + return ConstantR0(builder, -std::numeric_limits::max()); + case F64: + return ConstantR0(builder, -std::numeric_limits::max()); + default: + return MinValue(builder, type); + } +} + +XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, Literal::MaxValue(type)); +} + +XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0(builder, + Eigen::NumTraits::highest()); + case BF16: + return ConstantR0(builder, bfloat16::highest()); + case F32: + return ConstantR0(builder, std::numeric_limits::max()); + case F64: + return ConstantR0(builder, std::numeric_limits::max()); + default: + return MaxValue(builder, type); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h new file mode 100644 index 0000000000000000000000000000000000000000..b47f5243f008ecb2045456e4505d1a571fbed745 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is +// determined at C++ run-time, rather than C++ compile-time. +// If 'value' is floating point but 'type' is not, or if 'value' is complex but +// 'type' is not, an error will be returned. This is to catch accidental +// truncation; in such cases, use an explicit cast. +template +XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { + if (std::is_floating_point::value && + !(primitive_util::IsFloatingPointType(type) || + primitive_util::IsComplexType(type))) { + return builder->ReportError(InvalidArgument( + "Invalid cast from floating point type to %s in ConstantR0WithType.", + PrimitiveType_Name(type).c_str())); + } + if (std::is_same::value && + !primitive_util::IsComplexType(type)) { + return builder->ReportError(InvalidArgument( + "Invalid cast from complex type to %s in ConstantR0WithType.", + PrimitiveType_Name(type).c_str())); + } + switch (type) { + case F16: + return ConstantR0(builder, static_cast(value)); + case BF16: + return ConstantR0(builder, static_cast(value)); + case F32: + return ConstantR0(builder, static_cast(value)); + case F64: + return ConstantR0(builder, static_cast(value)); + case C64: + return ConstantR0(builder, static_cast(value)); + case U8: + return ConstantR0(builder, static_cast(value)); + case U32: + return ConstantR0(builder, static_cast(value)); + case U64: + return ConstantR0(builder, static_cast(value)); + case S8: + return ConstantR0(builder, static_cast(value)); + case S32: + return ConstantR0(builder, static_cast(value)); + case S64: + return ConstantR0(builder, static_cast(value)); + default: + return builder->ReportError( + InvalidArgument("Invalid type for ConstantR0WithType (%s).", + PrimitiveType_Name(type).c_str())); + } +} + +// Returns a scalar containing 'value' cast to the same run-time type as +// 'prototype'. +// If 'value' is floating point but 'prototype' is not, or if 'value' is complex +// 'prototype' is not, an error will be returned. +template +XlaOp ScalarLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + return ConstantR0WithType(builder, shape.element_type(), value); + }); +} + +// Returns a scalar with value '0' of 'type'. +XlaOp Zero(XlaBuilder* builder, PrimitiveType type); + +// Returns a zero-filled tensor with shape `shape`. +XlaOp Zeros(XlaBuilder* builder, const Shape& shape); + +// Returns a zero-filled tensor with the same shape as `prototype`. +XlaOp ZerosLike(XlaOp prototype); + +// Returns a scalar with value '1' of 'type'. +XlaOp One(XlaBuilder* builder, PrimitiveType type); + +// Returns the machine epsilon for floating-point type `type`, i.e., +// the difference between 1.0 and the next representable value. +XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum representable finite or infinite value for 'type'. +// Returns '-inf' for floating-point types. +XlaOp MinValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum representable finite value for 'type'. For a floating +// point type, this is equal to -MaxFiniteValue(). +XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the maximum representable finite or infinite value for 'type'. +// Returns 'inf' for floating-point types. +XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the maximum representable finite value for 'type'. +XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1e3439862344c01af15ec0571155ca46a579e54 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/constants_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using ConstantsTest = ClientLibraryTestBase; + +using ::testing::HasSubstr; + +XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32) { + XlaBuilder builder(TestName()); + ConstantR0WithType(&builder, xla::S32, 4); + ComputeAndCompareR0(&builder, 4, {}); +} + +XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32DoesNotAcceptFloats) { + XlaBuilder builder(TestName()); + ConstantR0WithType(&builder, xla::S32, 4.5); + auto statusor = builder.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("Invalid cast")); +} + +XLA_TEST_F(ConstantsTest, ConstantR0WithTypeF32) { + XlaBuilder builder(TestName()); + ConstantR0WithType(&builder, xla::F32, -7); + ComputeAndCompareR0(&builder, -7, {}); + ConstantR0WithType(&builder, xla::F32, 0.5); + ComputeAndCompareR0(&builder, 0.5, {}); +} + +XLA_TEST_F(ConstantsTest, ScalarLikeS32) { + XlaBuilder builder(TestName()); + ScalarLike(ConstantR0(&builder, 42), -3); + ComputeAndCompareR0(&builder, -3, {}); +} + +XLA_TEST_F(ConstantsTest, ScalarLikeF32) { + XlaBuilder builder(TestName()); + ScalarLike(ConstantR0(&builder, 42.75), -3.2); + ComputeAndCompareR0(&builder, -3.2, {}); +} + +XLA_TEST_F(ConstantsTest, ZeroS32) { + XlaBuilder builder(TestName()); + Zero(&builder, S32); + ComputeAndCompareR0(&builder, 0, {}); +} + +XLA_TEST_F(ConstantsTest, ZeroF32) { + XlaBuilder builder(TestName()); + Zero(&builder, F32); + ComputeAndCompareR0(&builder, 0.0, {}); +} + +XLA_TEST_F(ConstantsTest, ZerosS32) { + XlaBuilder builder(TestName()); + Zeros(&builder, ShapeUtil::MakeShape(S32, {2, 2})); + ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); +} + +XLA_TEST_F(ConstantsTest, ZerosLikeF32) { + XlaBuilder builder(TestName()); + ZerosLike(ConstantR1(&builder, {1., 2., 3.})); + ComputeAndCompareR1(&builder, {0., 0., 0.}, {}); +} + +XLA_TEST_F(ConstantsTest, OneS32) { + XlaBuilder builder(TestName()); + One(&builder, S32); + ComputeAndCompareR0(&builder, 1, {}); +} + +XLA_TEST_F(ConstantsTest, OneF32) { + XlaBuilder builder(TestName()); + One(&builder, F32); + ComputeAndCompareR0(&builder, 1., {}); +} + +XLA_TEST_F(ConstantsTest, EpsilonF32) { + XlaBuilder builder(TestName()); + Epsilon(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::epsilon(), + {}); +} + +XLA_TEST_F(ConstantsTest, MinFiniteValueS32) { + XlaBuilder builder(TestName()); + MinFiniteValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::min(), {}); +} + +XLA_TEST_F(ConstantsTest, MaxFiniteValueS32) { + XlaBuilder builder(TestName()); + MaxFiniteValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MinFiniteValueF32) { + XlaBuilder builder(TestName()); + MinFiniteValue(&builder, F32); + ComputeAndCompareR0(&builder, -std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MaxFiniteValueF32) { + XlaBuilder builder(TestName()); + MaxFiniteValue(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MinValueS32) { + XlaBuilder builder(TestName()); + MinValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::min(), {}); +} + +XLA_TEST_F(ConstantsTest, MaxValueS32) { + XlaBuilder builder(TestName()); + MaxValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MinValueF32) { + XlaBuilder builder(TestName()); + MinValue(&builder, F32); + ComputeAndCompareR0(&builder, -std::numeric_limits::infinity(), + {}); +} + +XLA_TEST_F(ConstantsTest, MaxValueF32) { + XlaBuilder builder(TestName()); + MaxValue(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::infinity(), + {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc new file mode 100644 index 0000000000000000000000000000000000000000..558755904007431cc0902d95a49627ea07f59127 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); } + +XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); } + +XlaOp Square(XlaOp operand) { return Pow(operand, ScalarLike(operand, 2.0)); } + +XlaOp Reciprocal(XlaOp operand) { + return Pow(operand, ScalarLike(operand, -1.0)); +} + +namespace { + +// Polynomials for computing erf/erfc. Originally from cephes. +// Note we use float for compatibility across devices, at the cost of some +// precision for 64 bit computations. +// +// Coefficients are in descending order. +std::array kErfcPCoefficient = { + 2.46196981473530512524E-10, 5.64189564831068821977E-1, + 7.46321056442269912687E0, 4.86371970985681366614E1, + 1.96520832956077098242E2, 5.26445194995477358631E2, + 9.34528527171957607540E2, 1.02755188689515710272E3, + 5.57535335369399327526E2}; +std::array kErfcQCoefficient = { + 1.00000000000000000000E0, 1.32281951154744992508E1, + 8.67072140885989742329E1, 3.54937778887819891062E2, + 9.75708501743205489753E2, 1.82390916687909736289E3, + 2.24633760818710981792E3, 1.65666309194161350182E3, + 5.57535340817727675546E2}; +std::array kErfcRCoefficient = { + 5.64189583547755073984E-1, 1.27536670759978104416E0, + 5.01905042251180477414E0, 6.16021097993053585195E0, + 7.40974269950448939160E0, 2.97886665372100240670E0}; +std::array kErfcSCoefficient = { + 1.00000000000000000000E0, 2.26052863220117276590E0, + 9.39603524938001434673E0, 1.20489539808096656605E1, + 1.70814450747565897222E1, 9.60896809063285878198E0, + 3.36907645100081516050E0}; +std::array kErfTCoefficient = { + 9.60497373987051638749E0, 9.00260197203842689217E1, + 2.23200534594684319226E3, 7.00332514112805075473E3, + 5.55923013010394962768E4}; +std::array kErfUCoefficient = { + 1.00000000000000000000E0, 3.35617141647503099647E1, + 5.21357949780152679795E2, 4.59432382970980127987E3, + 2.26290000613890934246E4, 4.92673942608635921086E4}; +} // namespace + +// Evaluate the polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +XlaOp EvaluatePolynomial(XlaOp x, + tensorflow::gtl::ArraySlice coefficients) { + XlaOp poly = ScalarLike(x, 0.0); + for (float c : coefficients) { + poly = poly * x + ScalarLike(x, c); + } + return poly; +} + +// Compute an approximation of the error function complement (1 - erf(x)). +XlaOp Erfc(XlaOp x) { + XlaOp abs_x = Abs(x); + XlaOp z = Exp(-x * x); + + XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient); + XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient); + XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient); + XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient); + + XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps); + + return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y); +} + +// Compute a polynomial approximation of the error function. +XlaOp Erf(XlaOp x) { + XlaOp z = x * x; + XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient); + XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient); + return x * pt / pu; +} + +// Approximation for the inverse error function from +// Giles, M., "Approximating the erfinv function". +// The approximation has the form: +// w = -log((1 - x) * (1 + x)) +// if ( w < 5 ) { +// w = w - 2.5 +// p = sum_{i=1}^n lq[i]*w^i +// } else { +// w = sqrt(w) - 3 +// p = sum_{i=1}^n gq[i]*w^i +// } +// return p*x +XlaOp ErfInv(XlaOp x) { + XlaBuilder* b = x.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = ScalarLike(x, 1.0); + auto w = -Log((one - x) * (one + x)); + + auto lt = Lt(w, ScalarLike(x, 5.0)); + auto coefficient = [&](int i) { + return Select(lt, + Broadcast(ScalarLike(x, w_less_than_5_constants[i]), + AsInt64Slice(shape.dimensions())), + Broadcast(ScalarLike(x, w_greater_than_5_constants[i]), + AsInt64Slice(shape.dimensions()))); + }; + w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = coefficient(i) + p * w; + } + return p * x; + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h new file mode 100644 index 0000000000000000000000000000000000000000..e7c8b50273067a979158f79aa80abc6058901040 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +namespace xla { + +// Computes the square root of 'operand'. +XlaOp Sqrt(XlaOp operand); + +// Computes the reciprocal of the square root of 'operand'. +XlaOp Rsqrt(XlaOp operand); + +// Computes the square of 'operand'. +XlaOp Square(XlaOp operand); + +// Computes the reciprocal of 'operand'. +XlaOp Reciprocal(XlaOp operand); + +// Evaluates a polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +XlaOp EvaluatePolynomial(XlaOp x, + tensorflow::gtl::ArraySlice coefficients); + +// Computes an approximation of the error function complement (1 - erf(x)). +XlaOp Erfc(XlaOp x); + +// Computes an approximation of the error function. +XlaOp Erf(XlaOp x); + +// Computes an approximation of the inverse of the error function. +XlaOp ErfInv(XlaOp x); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1df4e6ea42a2211c285075a3ed9159a9d603ccf5 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class MathTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(MathTest, SqrtF32) { + XlaBuilder builder(TestName()); + Literal zero_literal = Literal::Zero(PrimitiveType::F32); + + std::unique_ptr zero_data = + client_->TransferToServer(zero_literal).ConsumeValueOrDie(); + + XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); + Sqrt(zero); + + ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); +} + +XLA_TEST_F(MathTest, SquareTenValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Square(x); + + std::vector expected = {4.41, 6.76, 6.76, 16., 4.41, + 5.29, 25., 0.81, 5.76, 2.56}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, ReciprocalTenValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reciprocal(x); + + std::vector expected = { + 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, + 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, SqrtZeroes) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {0.0, -0.0}); + Sqrt(x); + + ComputeAndCompareR1(&builder, {0, 0}, {}, error_spec_); +} + +XLA_TEST_F(MathTest, SqrtSixValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); + Sqrt(x); + + std::vector expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc new file mode 100644 index 0000000000000000000000000000000000000000..cbe9e7fdd1330164f1f9c4520c2bb81e38f4ceb9 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/numeric.h" + +#include +#include + +namespace xla { + +namespace { + +template +XlaOp MakeIota(XlaBuilder* builder, int64 size) { + std::vector values(size); + for (int64 i = 0; i < size; ++i) { + values[i] = static_cast(i); + } + return xla::ConstantR1(builder, values); +} + +} // namespace + +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { + switch (type) { + case S8: + return MakeIota(builder, size); + case S16: + return MakeIota(builder, size); + case S32: + return MakeIota(builder, size); + case S64: + return MakeIota(builder, size); + case U8: + return MakeIota(builder, size); + case U16: + return MakeIota(builder, size); + case U32: + return MakeIota(builder, size); + case U64: + return MakeIota(builder, size); + case BF16: + return MakeIota(builder, size); + case F16: + return MakeIota(builder, size); + case F32: + return MakeIota(builder, size); + case F64: + return MakeIota(builder, size); + case C64: + return MakeIota(builder, size); + default: + return builder->ReportError( + InvalidArgument("Unimplemented type for Iota: %s.", + PrimitiveType_Name(type).c_str())); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h new file mode 100644 index 0000000000000000000000000000000000000000..2a409ae31147a4a88367422ce31c9fbcb22fdbca --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric.h @@ -0,0 +1,30 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...]. +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc8a73e9d793ef8f65c321759e03b0de75edd500 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using NumericTest = ClientLibraryTestBase; + +XLA_TEST_F(NumericTest, Iota) { + XlaBuilder builder(TestName()); + Iota(&builder, S32, 10); + + ComputeAndCompareR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 9cd87f74735ff50df8a3382723c7d045ff6c9e52..731ad13b8d0e5d65acc316e72be9fe7d35e826a4 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -48,15 +48,15 @@ int64 DataSizeOfShape(const Shape& shape) { // Creates a XlaOp for an op what generates fake data with the given shape. XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { if (ShapeUtil::IsArray(shape)) { - return builder->Broadcast( - builder->ConstantLiteral(Literal::One(shape.element_type())), + return Broadcast( + ConstantLiteral(builder, Literal::One(shape.element_type())), AsInt64Slice(shape.dimensions())); } std::vector parts; for (const Shape& s : shape.tuple_shapes()) { parts.push_back(BuildFakeDataOpOnDevice(s, builder)); } - return builder->Tuple(parts); + return Tuple(builder, parts); } std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, @@ -92,21 +92,6 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, return MakeFakeDataViaDeviceOrDie(shape, client); } -std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client) { - auto program_shape = - client->GetComputationShape(computation).ConsumeValueOrDie(); - - // For every (unbound) parameter that the computation wants, we manufacture - // some arbitrary data so that we can invoke the computation. - std::vector> fake_arguments; - for (const Shape& parameter : program_shape->parameters()) { - fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); - } - - return fake_arguments; -} - std::vector> MakeFakeArgumentsOrDie( const XlaComputation& computation, Client* client) { CHECK(computation.proto().has_program_shape()) diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 9e06141b1f13d24cd033b72e31ee3a0442fe6a37..dc613099e2b42a60d0c11a654ab5cd41f8bd4f6f 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -32,12 +32,6 @@ namespace xla { std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client); -// Returns vector of GlobalData handles of fake data (created using -// MakeFakeDataOrDie) that are correctly shaped arguments for the given -// computation. -std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client); - // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 9d44d3ad7d52b957aa3c76a588e3cbc07bb49f8b..5f9710914bd0ceff55f5b0a2db05e553ce8bd637 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -51,24 +51,17 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend) { - const ComputationLayout& host_computation_layout = - executable_->module_config().host_entry_computation_layout(); - const ComputationLayout& device_computation_layout = - executable_->module_config().device_entry_computation_layout(); + const ComputationLayout& computation_layout = + executable_->module_config().entry_computation_layout(); // Check argument number, shapes, and layouts. - if (arguments.size() != host_computation_layout.parameter_count()) { + if (arguments.size() != computation_layout.parameter_count()) { return InvalidArgument( "invalid number of arguments for computation: expected %d, got %zu", - host_computation_layout.parameter_count(), arguments.size()); - } - if (arguments.size() != device_computation_layout.parameter_count()) { - return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", - device_computation_layout.parameter_count(), arguments.size()); + computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { - if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape( + if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( arguments[i]->on_host_shape())) { return InvalidParameterArgument( executable_.get(), i, @@ -76,24 +69,10 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString( - host_computation_layout.parameter_layout(i).shape()) + ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) .c_str(), ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } - if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape( - arguments[i]->on_device_shape())) { - return InvalidParameterArgument( - executable_.get(), i, - "Argument does not match device shape or layout of computation " - "parameter " - "%d: want %s, got %s", - i, - ShapeUtil::HumanString( - device_computation_layout.parameter_layout(i).shape()) - .c_str(), - ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str()); - } } if (run_options.stream() != nullptr) { @@ -185,7 +164,7 @@ StatusOr LocalExecutable::Run( run_options, backend_->StreamBorrower(), backend_->eigen_intra_op_thread_pool()); - if (executable_->dumping()) { + if (executable_->dumping_snapshot()) { return ExecuteAndDump(&service_options, arguments); } return executable_->ExecuteOnStreamWrapper( @@ -195,45 +174,44 @@ StatusOr LocalExecutable::Run( StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments) { - executable_->session_module()->set_execution_platform( + executable_->hlo_snapshot()->set_execution_platform( backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module())); + TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer result, executable_->ExecuteOnStream(run_options, arguments, /*hlo_execution_profile=*/nullptr)); - TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module())); - TF_RETURN_IF_ERROR(executable_->DumpSessionModule()); + TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot())); + TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot()); return std::move(result); } Status LocalExecutable::RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module) { - session_module->clear_arguments(); + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*argument)); - *session_module->add_arguments() = literal->ToProto(); + *hlo_snapshot->add_arguments() = literal->ToProto(); } return Status::OK(); } Status LocalExecutable::RecordResult(const ShapedBuffer* result, - SessionModule* session_module) { - session_module->clear_result(); + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_result(); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*result)); - *session_module->mutable_result() = literal->ToProto(); + *hlo_snapshot->mutable_result() = literal->ToProto(); return Status::OK(); } StatusOr> LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend_->stream_executor(shaped_buffer.device_ordinal())); - return backend_->transfer_manager()->TransferLiteralFromDevice(executor, + TF_ASSIGN_OR_RETURN(auto stream, + backend_->BorrowStream(shaped_buffer.device_ordinal())); + return backend_->transfer_manager()->TransferLiteralFromDevice(stream.get(), shaped_buffer); } @@ -261,25 +239,6 @@ Backend* LocalClient::mutable_backend() { return local_service_->mutable_backend(); } -StatusOr> LocalClient::Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options) { - ExecutableBuildOptions updated_options = options; - if (options.device_ordinal() == -1) { - updated_options.set_device_ordinal(default_device_ordinal()); - VLOG(3) << "Set device ordinal to default value of: " - << updated_options.device_ordinal(); - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - local_service_->CompileExecutable(computation.handle(), argument_layouts, - updated_options)); - return WrapUnique(new LocalExecutable(std::move(executable), - local_service_->mutable_backend(), - updated_options)); -} - StatusOr> LocalClient::Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -307,22 +266,26 @@ StatusOr LocalClient::LiteralToShapedBuffer( TF_ASSIGN_OR_RETURN(auto scoped_buffer, backend().transfer_manager()->AllocateScopedShapedBuffer( literal.shape(), allocator, device_ordinal)); - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - backend().stream_executor(device_ordinal)); + TF_ASSIGN_OR_RETURN(auto stream, + mutable_backend()->BorrowStream(device_ordinal)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, literal, scoped_buffer)); + stream.get(), literal, scoped_buffer)); return std::move(scoped_buffer); } StatusOr> LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend().stream_executor(shaped_buffer.device_ordinal())); - return backend().transfer_manager()->TransferLiteralFromDevice(executor, + TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream( + shaped_buffer.device_ordinal())); + return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(), shaped_buffer); } +StatusOr LocalClient::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + return local_service_->GlobalDataToShapedBuffer(data, replica_number); +} + Status LocalClient::TransferToInfeedLocal(const Literal& literal, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 31950377f4c70c0cf63e8a726d9366d0c0ee4fb4..4d9e0d7cd9d6ddebead1e12b23e94b529038039b 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -19,13 +19,13 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -59,12 +59,18 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. + // + // The given ExecutableRunOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. + // + // The given ServiceExecutableRunOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments); @@ -73,11 +79,10 @@ class LocalExecutable { // proto. Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module); + HloSnapshot* hlo_snapshot); // Records the result of the computation in a SessionModule proto. - Status RecordResult(const ShapedBuffer* result, - SessionModule* session_module); + Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. StatusOr> LiteralFromShapedBuffer( @@ -108,17 +113,11 @@ class LocalClient : public Client { LocalClient(const LocalClient&) = delete; void operator=(const LocalClient&) = delete; - // Build and return a LocalExecutable object. The executable is compiled using - // the given argument layouts and options. - StatusOr> Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. + // The given ExecutableBuildOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). StatusOr> Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -137,6 +136,11 @@ class LocalClient : public Client { StatusOr> ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + // Transfer the given literal to the infeed queue of the given device. // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index 0d6e207971ec64515ec5e6da292910920edd101a..ee00a9eada8dd906c26e07a4affccdaf544f1693 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -1,7 +1,5 @@ # Description: # The new XLA client libraries. -# -# This is NOT YET ready to use. licenses(["notice"]) # Apache 2.0 @@ -37,11 +35,11 @@ cc_library( ], ) -# TODO(b/74197823): Replace computation_builder with xla_builder. cc_library( name = "xla_builder", srcs = ["xla_builder.cc"], hdrs = ["xla_builder.h"], + visibility = ["//visibility:public"], deps = [ ":xla_computation", "//tensorflow/compiler/xla:execution_options_util", @@ -53,6 +51,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shape_inference", diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index ae506317c2e4862d77cb4f0628e919871ad1aeb2..12efcb4b4f787da9a2fd694b4ee09dd490a68a52 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -47,6 +48,7 @@ int64 GetUniqueId() { // computation. bool CanBeRoot(HloOpcode opcode) { switch (opcode) { + case HloOpcode::kAfterAll: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kOutfeed: @@ -59,6 +61,36 @@ bool CanBeRoot(HloOpcode opcode) { } // namespace +XlaOp operator-(const XlaOp& x) { return Neg(x); } +XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); } +XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); } +XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); } +XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); } +XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); } + +XlaOp operator~(const XlaOp& x) { return Not(x); } +XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); } +XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); } +XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); } +XlaOp operator<<(const XlaOp& x, const XlaOp& y) { return ShiftLeft(x, y); } + +XlaOp operator>>(const XlaOp& x, const XlaOp& y) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + if (!ShapeUtil::ElementIsIntegral(shape)) { + return InvalidArgument( + "Argument to >> operator does not have an integral type (%s).", + ShapeUtil::HumanString(shape).c_str()); + } + if (ShapeUtil::ElementIsSigned(shape)) { + return ShiftRightArithmetic(x, y); + } else { + return ShiftRightLogical(x, y); + } + }); +} + StatusOr XlaBuilder::GetShape(const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); @@ -81,7 +113,7 @@ XlaBuilder::XlaBuilder(const string& computation_name) XlaBuilder::~XlaBuilder() {} -void XlaBuilder::NoteError(const Status& error) { +XlaOp XlaBuilder::ReportError(const Status& error) { CHECK(!error.ok()); if (die_immediately_on_error_) { LOG(FATAL) << "error building computation: " << error; @@ -91,19 +123,22 @@ void XlaBuilder::NoteError(const Status& error) { first_error_ = error; first_error_backtrace_.CreateCurrent(/*skip_count=*/1); } + return XlaOp(this); } -XlaOp XlaBuilder::NoteErrorOrReturn( - const std::function()>& op_creator) { +XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr& op) { if (!first_error_.ok()) { - return {}; + return XlaOp(this); } - auto op = op_creator(); if (!op.ok()) { - NoteError(op.status()); - return {}; + return ReportError(op.status()); } - return op.ConsumeValueOrDie(); + return op.ValueOrDie(); +} + +XlaOp XlaBuilder::ReportErrorOrReturn( + const std::function()>& op_creator) { + return ReportErrorOrReturn(op_creator()); } StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { @@ -207,7 +242,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() { DCHECK(parent_builder_ != nullptr); auto build_status = Build(); if (!build_status.ok()) { - parent_builder_->NoteError( + parent_builder_->ReportError( AddStatus(build_status.status(), tensorflow::strings::StrCat("error from: ", name_))); return {}; @@ -315,7 +350,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, } XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), @@ -327,7 +362,7 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { XlaOp XlaBuilder::BinaryOp( HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -383,7 +418,7 @@ XlaOp XlaBuilder::BinaryOp( XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -430,7 +465,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = literal.shape(); *instr.mutable_literal() = literal.ToProto(); @@ -440,7 +475,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { XlaOp XlaBuilder::Call(const XlaComputation& computation, tensorflow::gtl::ArraySlice operands) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); @@ -461,7 +496,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!parameter_numbers_.insert(parameter_number).second) { return InvalidArgument("parameter %lld already registered", @@ -476,7 +511,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, XlaOp XlaBuilder::Broadcast( const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( const Shape& shape, @@ -498,6 +533,14 @@ XlaOp XlaBuilder::Broadcast( }); } +XlaOp XlaBuilder::BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions) { + return ReportErrorOrReturn([&]() -> StatusOr { + return InDimBroadcast(shape, operand, broadcast_dimensions); + }); +} + StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); @@ -510,7 +553,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -530,7 +573,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); std::vector starts(ShapeUtil::Rank(shape), 0); std::vector limits(shape.dimensions().begin(), @@ -545,7 +588,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -566,7 +609,7 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -584,7 +627,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, int64 dimension) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; @@ -603,7 +646,7 @@ XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -624,7 +667,7 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, ShapeInference::InferReshapeShape( @@ -638,7 +681,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice new_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); @@ -648,7 +691,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Collapse(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus // enqueueing a trivial reshape. @@ -690,7 +733,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { - NoteErrorOrReturn([&]() -> StatusOr { + ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); @@ -700,11 +743,19 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { - return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false); + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true)); + TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false)); + TF_RET_CHECK(ShapeUtil::IsTuple(true_shape) == + ShapeUtil::IsTuple(false_shape)); + HloOpcode opcode = ShapeUtil::IsTuple(true_shape) ? HloOpcode::kTupleSelect + : HloOpcode::kSelect; + return TernaryOp(opcode, pred, on_true, on_false); + }); } XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); @@ -718,7 +769,7 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { } XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data)); if (!ShapeUtil::IsTuple(tuple_shape)) { @@ -767,7 +818,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); DotDimensionNumbers dimension_numbers; @@ -780,7 +831,7 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -859,7 +910,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -905,7 +956,7 @@ XlaOp XlaBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -992,7 +1043,7 @@ StatusOr XlaBuilder::MakeWindow( XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, const tensorflow::gtl::ArraySlice fft_length) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1009,23 +1060,69 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, } XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Given shape to Infeed must have a layout"); } - *instr.mutable_shape() = shape; + const Shape infeed_instruction_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + *instr.mutable_shape() = infeed_instruction_shape; instr.set_infeed_config(config); - return AddInstruction(std::move(instr), HloOpcode::kInfeed); + + if (ShapeUtil::IsArray(shape) && sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { + // TODO(b/110793772): Support tiled array-shaped infeeds. + return InvalidArgument( + "Tiled sharding is not yet supported for array-shaped infeeds"); + } + + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + return InvalidArgument( + "Replicated sharding is not yet supported for infeeds"); + } + + // The sharding is set by the client according to the data tuple shape. + // However, the shape of the infeed instruction is a tuple containing the + // data and a token. For tuple sharding type, the sharding must be changed + // to accommodate the token. + XlaOp infeed; + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_TUPLE) { + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + OpSharding infeed_instruction_sharding = *sharding(); + // Arbitrarily assign the token to device 0. + *infeed_instruction_sharding.add_tuple_shardings() = + sharding_builder::AssignDevice(0); + XlaScopedShardingAssignment scoped_sharding(this, + infeed_instruction_sharding); + TF_ASSIGN_OR_RETURN(infeed, + AddInstruction(std::move(instr), HloOpcode::kInfeed)); + } else { + TF_ASSIGN_OR_RETURN(infeed, + AddInstruction(std::move(instr), HloOpcode::kInfeed)); + } + + // The infeed instruction produces a tuple of the infed data and a token + // type. Return XLA op containing the data. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto infeed_data; + *infeed_data.mutable_shape() = shape; + infeed_data.set_tuple_index(0); + return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement, + {infeed}); }); } void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config) { - NoteErrorOrReturn([&]() -> StatusOr { + ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { @@ -1042,14 +1139,33 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, instr.set_outfeed_config(outfeed_config); - return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}); + TF_RETURN_IF_ERROR( + AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}) + .status()); + + // The outfeed instruction produces a token. However, existing users expect + // a nil shape (empty tuple). This should only be relevant if the outfeed is + // the root of a computation. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto tuple_instr; + *tuple_instr.mutable_shape() = ShapeUtil::MakeNil(); + + // The dummy tuple should have no sharding. + { + XlaScopedShardingAssignment scoped_sharding(this, OpSharding()); + TF_ASSIGN_OR_RETURN( + XlaOp empty_tuple, + AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {})); + return empty_tuple; + } }); } XlaOp XlaBuilder::CustomCall(const string& call_target_name, tensorflow::gtl::ArraySlice operands, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (tensorflow::str_util::StartsWith(call_target_name, "$")) { return InvalidArgument( @@ -1066,7 +1182,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice operands, const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape; instr.set_channel_name(channel_name); @@ -1120,11 +1236,9 @@ XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); } -// TODO(b/65209188): Create a dedicated lowering for Xor. XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return Or(And(Not(lhs), rhs, broadcast_dimensions), - And(lhs, Not(rhs), broadcast_dimensions)); + return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Not(const XlaOp& operand) { @@ -1223,7 +1337,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { XlaOp XlaBuilder::Transpose(const XlaOp& operand, tensorflow::gtl::ArraySlice permutation) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1238,7 +1352,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, XlaOp XlaBuilder::Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1251,13 +1365,25 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSort, operand); -} - -XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0(0.5), - /*broadcast_dimensions=*/{}); +XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional values) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + operand_shape_ptrs.push_back(&keys_shape); + Shape values_shape; + if (values.has_value()) { + TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values)); + operand_shape_ptrs.push_back(&values_shape); + } + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, operand_shape_ptrs)); + return values.has_value() + ? AddInstruction(std::move(instr), HloOpcode::kSort, + {keys, *values}) + : AddInstruction(std::move(instr), HloOpcode::kSort, {keys}); + }); } XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, @@ -1267,7 +1393,7 @@ XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1279,7 +1405,7 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1290,16 +1416,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, }); } -XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0(2.0), - /*broadcast_dimensions=*/{}); -} - -XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0(-1.0), - /*broadcast_dimensions=*/{}); -} - XlaOp XlaBuilder::Neg(const XlaOp& operand) { return UnaryOp(HloOpcode::kNegate, operand); } @@ -1313,13 +1429,12 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, const XlaComputation& computation, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice static_operands) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { if (!static_operands.empty()) { return Unimplemented("static_operands is not supported in Map"); } HloInstructionProto instr; - std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), @@ -1331,16 +1446,32 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape, dimensions)); + const Shape& output_shape = instr.shape(); + const int64 output_rank = ShapeUtil::Rank(output_shape); AddCalledComputation(computation, &instr); + std::vector new_operands(operands.begin(), operands.end()); + for (XlaOp& new_operand : new_operands) { + TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand)); + const int64 rank = ShapeUtil::Rank(shape); + if (rank != output_rank) { + TF_ASSIGN_OR_RETURN(new_operand, + InDimBroadcast(output_shape, new_operand, {})); + TF_ASSIGN_OR_RETURN(shape, GetShape(new_operand)); + } + if (!ShapeUtil::SameDimensions(output_shape, shape)) { + TF_ASSIGN_OR_RETURN(new_operand, + AddBroadcastSequence(output_shape, new_operand)); + } + } - return AddInstruction(std::move(instr), HloOpcode::kMap, operands); + return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands); }); } XlaOp XlaBuilder::RngOp(RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; // Check the number of parameters per RNG distribution. @@ -1378,7 +1509,7 @@ XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b, XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, const XlaOp& init) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; // Infer shape. @@ -1400,7 +1531,7 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice window_bounds) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); @@ -1425,7 +1556,7 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate)); @@ -1457,13 +1588,14 @@ XlaOp XlaBuilder::Reduce( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferReduceShape( operand_shape, init_shape, dimensions_to_reduce, @@ -1482,7 +1614,7 @@ XlaOp XlaBuilder::Reduce( XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); std::vector all_dimnos(ShapeUtil::Rank(operand_shape)); std::iota(all_dimnos.begin(), all_dimnos.end(), 0); @@ -1495,7 +1627,7 @@ XlaOp XlaBuilder::ReduceWindow( const XlaComputation& computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1518,7 +1650,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1542,7 +1674,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1565,7 +1697,7 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, const XlaOp& mean, const XlaOp& variance, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1590,7 +1722,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& batch_mean, const XlaOp& batch_var, const XlaOp& grad_output, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1611,14 +1743,40 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, }); } -XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { - return NoteErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); + auto b = CreateSubBuilder("sum"); + b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), + b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); + TF_ASSIGN_OR_RETURN(auto computation, b->Build()); + return CrossReplicaSum(operand, computation, replica_group_ids, + /*channel_id=*/tensorflow::gtl::nullopt); + }); +} +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (channel_id.has_value()) { + return Unimplemented("channel_id is not supported in AllReduce"); + } + + HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + for (int64 replica_group_id : replica_group_ids) { + instr.add_replica_group_ids(replica_group_id); + } + + AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, {operand}); @@ -1631,7 +1789,7 @@ XlaOp XlaBuilder::SelectAndScatter( tensorflow::gtl::ArraySlice window_strides, Padding padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); return SelectAndScatterWithGeneralPadding( operand, select, window_dimensions, window_strides, @@ -1648,7 +1806,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( tensorflow::gtl::ArraySlice> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1676,7 +1834,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), @@ -1690,20 +1848,29 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, } void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { - NoteErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - - // Send instruction produces a tuple of {aliased operand, U32 context}. + ReportErrorOrReturn([&]() -> StatusOr { + // Send HLO takes two operands: a data operand and a token. Generate the + // token to pass into the send. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); + + // Send instruction produces a tuple of {aliased operand, U32 context, + // token}. + HloInstructionProto send_instr; TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN( - XlaOp send, - AddInstruction(std::move(instr), HloOpcode::kSend, {operand})); + *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + send_instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp send, + AddInstruction(std::move(send_instr), HloOpcode::kSend, + {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeNil(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); send_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, {send}); @@ -1711,21 +1878,42 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { } XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { - return NoteErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - - // Recv instruction produces a tuple of {receive buffer, U32 context}. - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN(XlaOp recv, - AddInstruction(std::move(instr), HloOpcode::kRecv, {})); + return ReportErrorOrReturn([&]() -> StatusOr { + // Recv HLO takes a single token operand. Generate the token to pass into + // the Recv and RecvDone instructions. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); + + // Recv instruction produces a tuple of {receive buffer, U32 context, + // token}. + HloInstructionProto recv_instr; + *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + recv_instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), + HloOpcode::kRecv, {token})); HloInstructionProto recv_done_instr; - *recv_done_instr.mutable_shape() = shape; + *recv_done_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); recv_done_instr.set_channel_id(handle.handle()); - return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, - {recv}); + TF_ASSIGN_OR_RETURN(XlaOp recv_done, + AddInstruction(std::move(recv_done_instr), + HloOpcode::kRecvDone, {recv})); + + // The RecvDone instruction produces a tuple of the data and a token + // type. Return XLA op containing the data. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto recv_data; + *recv_data.mutable_shape() = shape; + recv_data.set_tuple_index(0); + return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement, + {recv_done}); }); } @@ -1966,9 +2154,497 @@ StatusOr XlaBuilder::LookUpInstruction( return &instructions_[op.handle()]; } -XlaOp XlaBuilder::UnimplementedOp() { - NoteError(Unimplemented("Op not implemented")); - return {}; +// Enqueues a "retrieve parameter value" instruction for a parameter that was +// passed to the computation. +XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, + const string& name) { + return builder->Parameter(parameter_number, shape, name); +} + +// Enqueues a constant with the value of the given literal onto the +// computation. +XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { + return builder->ConstantLiteral(literal); +} + +XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes) { + return operand.builder()->Broadcast(operand, broadcast_sizes); +} + +XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions) { + return operand.builder()->BroadcastInDim(operand, shape, + broadcast_dimensions); +} + +XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config) { + return operand.builder()->Pad(operand, padding_value, padding_config); +} + +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + return operand.builder()->Reshape(operand, dimensions, new_sizes); +} + +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes) { + return operand.builder()->Reshape(operand, new_sizes); +} + +XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return operand.builder()->Collapse(operand, dimensions); +} + +XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { + return operand.builder()->Slice(operand, start_indices, limit_indices, + strides); +} + +XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno) { + return operand.builder()->SliceInDim(operand, start_index, limit_index, + stride, dimno); +} + +XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); +} + +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices) { + return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); +} + +XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + int64 dimension) { + return builder->ConcatInDim(operands, dimension); +} + +void Trace(const string& tag, const XlaOp& operand) { + return operand.builder()->Trace(tag, operand); +} + +XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { + return pred.builder()->Select(pred, on_true, on_false); +} + +XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements) { + return builder->Tuple(elements); +} + +XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { + return tuple_data.builder()->GetTupleElement(tuple_data, index); +} + +XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions); +} + +XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions); +} + +XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions); +} + +XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions); +} + +XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions); +} + +XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); +} + +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs) { + return lhs.builder()->Dot(lhs, rhs); +} + +XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers) { + return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers); +} + +XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding) { + return lhs.builder()->Conv(lhs, rhs, window_strides, padding); +} + +XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, + padding); +} + +XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides, + padding, dimension_numbers); +} + +XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, + dimension_numbers); +} + +XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers); +} + +XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length) { + return operand.builder()->Fft(operand, fft_type, fft_length); +} + +XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) { + return builder->Infeed(shape, config); +} + +void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config) { + return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config); +} + +XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands) { + return builder->Call(computation, operands); +} + +XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape) { + return builder->CustomCall(call_target_name, operands, shape); +} + +XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape) { + return builder->HostCompute(operands, channel_name, cost_estimate_ns, shape); +} + +XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return real.builder()->Complex(real, imag, broadcast_dimensions); +} + +XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); } + +XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Add(lhs, rhs, broadcast_dimensions); +} + +XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions); +} + +XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions); +} + +XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Div(lhs, rhs, broadcast_dimensions); +} + +XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions); +} + +XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Max(lhs, rhs, broadcast_dimensions); +} + +XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Min(lhs, rhs, broadcast_dimensions); +} + +XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->And(lhs, rhs, broadcast_dimensions); +} + +XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Or(lhs, rhs, broadcast_dimensions); +} + +XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions); +} + +XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); } + +XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions); +} + +XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions); +} + +XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions); +} + +XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + return operand.builder()->Reduce(operand, init_value, computation, + dimensions_to_reduce); +} + +XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation) { + return operand.builder()->ReduceAll(operand, init_value, computation); +} + +XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding) { + return operand.builder()->ReduceWindow(operand, init_value, computation, + window_dimensions, window_strides, + padding); +} + +XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return operand.builder()->ReduceWindowWithGeneralPadding( + operand, init_value, computation, window_dimensions, window_strides, + padding); +} + +XlaOp CrossReplicaSum(const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids) { + return operand.builder()->CrossReplicaSum(operand, replica_group_ids); +} + +XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + return operand.builder()->CrossReplicaSum(operand, computation, + replica_group_ids, channel_id); +} + +XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { + return operand.builder()->SelectAndScatter(operand, select, window_dimensions, + window_strides, padding, source, + init_value, scatter); +} + +XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return operand.builder()->SelectAndScatterWithGeneralPadding( + operand, select, window_dimensions, window_strides, padding, source, + init_value, scatter); +} + +XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); } + +XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return y.builder()->Atan2(y, x, broadcast_dimensions); +} + +XlaOp Exp(const XlaOp& operand) { return operand.builder()->Exp(operand); } + +XlaOp Expm1(const XlaOp& operand) { return operand.builder()->Expm1(operand); } + +XlaOp Floor(const XlaOp& operand) { return operand.builder()->Floor(operand); } + +XlaOp Ceil(const XlaOp& operand) { return operand.builder()->Ceil(operand); } + +XlaOp Round(const XlaOp& operand) { return operand.builder()->Round(operand); } + +XlaOp Log(const XlaOp& operand) { return operand.builder()->Log(operand); } + +XlaOp Log1p(const XlaOp& operand) { return operand.builder()->Log1p(operand); } + +XlaOp Sign(const XlaOp& operand) { return operand.builder()->Sign(operand); } + +XlaOp Clz(const XlaOp& operand) { return operand.builder()->Clz(operand); } + +XlaOp Cos(const XlaOp& operand) { return operand.builder()->Cos(operand); } + +XlaOp Sin(const XlaOp& operand) { return operand.builder()->Sin(operand); } + +XlaOp Tanh(const XlaOp& operand) { return operand.builder()->Tanh(operand); } + +XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } + +XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } + +XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); +} + +XlaOp IsFinite(const XlaOp& operand) { + return operand.builder()->IsFinite(operand); +} + +XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { + return operand.builder()->ConvertElementType(operand, new_element_type); +} + +XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { + return operand.builder()->BitcastConvertType(operand, new_element_type); +} + +XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } + +XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation) { + return operand.builder()->Transpose(operand, permutation); +} + +XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return operand.builder()->Rev(operand, dimensions); +} + +XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values) { + return keys.builder()->Sort(keys, std::move(values)); +} + +XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { + return min.builder()->Clamp(min, operand, max); +} + +XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands) { + return builder->Map(operands, computation, dimensions, static_operands); +} + +XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape) { + return mu.builder()->RngNormal(mu, sigma, shape); +} + +XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape) { + return a.builder()->RngUniform(a, b, shape); +} + +XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init) { + return init.builder()->While(condition, body, init); +} + +XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation) { + return predicate.builder()->Conditional(predicate, true_operand, + true_computation, false_operand, + false_computation); +} + +XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits) { + return operand.builder()->ReducePrecision(operand, exponent_bits, + mantissa_bits); +} + +XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + return input.builder()->Gather(input, gather_indices, dimension_numbers, + window_bounds); +} + +void Send(const XlaOp& operand, const ChannelHandle& handle) { + return operand.builder()->Send(operand, handle); +} + +XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle) { + return builder->Recv(shape, handle); +} + +XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index) { + return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon, + feature_index); +} + +XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index) { + return operand.builder()->BatchNormInference( + operand, scale, offset, mean, variance, epsilon, feature_index); +} + +XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index) { + return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var, + grad_output, epsilon, feature_index); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index d802e43bc63670dc817f105a25098baf6ceafcb9..274aba8a31072db1e821b1834178a85288d64521 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -13,15 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO(b/74197823): Replace computation_builder.h with this file. -// -// This is NOT YET ready to use. - #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ #include #include +#include #include #include "tensorflow/compiler/xla/client/padding.h" @@ -48,21 +45,25 @@ class XlaBuilder; // This represents an instruction that has been enqueued using the XlaBuilder. // This is used to pass to subsequent computations that depends upon the // instruction as an operand. -// -// TODO(b/74197823): Replace xla::ComputationDataHandle with this one. class XlaOp { public: - XlaOp() : handle_(0), builder_(nullptr) {} - ~XlaOp() {} + XlaOp() : handle_(-1), builder_(nullptr) { + static_assert(std::is_trivially_destructible::value, + "XlaOp should be trivially destructible"); + } + ~XlaOp() = default; - const XlaBuilder* builder() const { return builder_; } + XlaBuilder* builder() const { return builder_; } - bool operator==(const XlaOp& rhs) const { - return handle_ == rhs.handle_ && builder_ == rhs.builder_; - } + // Returns true if the XlaOp represents valid, non-erroneous value. + bool valid() const { return handle_ >= 0; } + + // Returns true if the XlaOp was created by the XlaOp() constructor and + // not returned by a builder. + bool IsUninitialized() const { return builder_ == nullptr; } - bool operator!=(const XlaOp& rhs) const { - return handle_ != rhs.handle_ || builder_ != rhs.builder_; + bool IsIdenticalTo(const XlaOp& rhs) const { + return handle_ == rhs.handle_ && builder_ == rhs.builder_; } friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) { @@ -71,6 +72,7 @@ class XlaOp { } private: + explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {} XlaOp(int64 handle, XlaBuilder* builder) : handle_(handle), builder_(builder) {} @@ -78,15 +80,41 @@ class XlaOp { friend class XlaBuilder; + // < 0 means "invalid handle". int64 handle_; - XlaBuilder* builder_; // Not owned. + + // Not owned. Non-null for any handle returned by XlaBuilder, even if the + // handle is invalid. + XlaBuilder* builder_; }; +// Arithmetic operator overloads for the XlaOp type. +XlaOp operator-(const XlaOp& x); +XlaOp operator+(const XlaOp& x, const XlaOp& y); +XlaOp operator-(const XlaOp& x, const XlaOp& y); +XlaOp operator*(const XlaOp& x, const XlaOp& y); +XlaOp operator/(const XlaOp& x, const XlaOp& y); +XlaOp operator%(const XlaOp& x, const XlaOp& y); + +// Bitwise operator overloads for the XlaOp type. +XlaOp operator~(const XlaOp& x); +XlaOp operator&(const XlaOp& x, const XlaOp& y); +XlaOp operator|(const XlaOp& x, const XlaOp& y); +XlaOp operator^(const XlaOp& x, const XlaOp& y); +XlaOp operator<<(const XlaOp& x, const XlaOp& y); +// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs +// a right logical shift. +XlaOp operator>>(const XlaOp& x, const XlaOp& y); + +// We don't overload the relational operators (==, !=, <, <=, >, >=) because the +// semantics might be surprising since their result types are usually 'bool'. +// Further programmers may expect == to be a structural equality. +// We also choose not to overload any of the mutating operators (e.g., +=, -=) +// because the semantics might be misleading — XLA computations are immutable. + // A convenient interface for building up computations. // // Thread-compatible. -// -// TODO(b/74197823): Replace xla::ComputationBuilder with this one. class XlaBuilder { public: // computation_name: name to use for the built computation. @@ -130,6 +158,93 @@ class XlaBuilder { die_immediately_on_error_ = enabled; } + // Default dimension numbers used for a 2D convolution. + static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr CreateSubBuilder(const string& computation_name); + + // Builds the computation with the requested operations, or returns a non-ok + // status. Note that all ops that have been enqueued will be moved to the + // computation being returned. + StatusOr Build(); + + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns a subgraph that roots on the given root. If the root is not a + // compile-time constant (see `IsConstant`), returns an error. + // + // This will copy the needed ops/computations to the subgraph. + StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. + StatusOr GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. + StatusOr GetProgramShape() const; + + // Reports an error to the builder, by + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to + // Build()/GetShape()/...) + // * dying if die_immediately_on_error_ is true. + // Returns an XlaOp with an invalid handle but a valid builder. This value can + // be returned in place of a value in APIs that return an XlaOp. + XlaOp ReportError(const Status& error); + + // A helper function that converts a StatusOr into an XlaOp. + // If the Status was an error, reports the error to builder and returns an + // invalid XlaOp handle. + XlaOp ReportErrorOrReturn(const StatusOr& op); + + // A helper function that runs a function that returns a StatusOr and + // returns an XlaOp. + XlaOp ReportErrorOrReturn(const std::function()>& op_creator); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on any parameters, or on stateful operators such + // as `RngNormal` or `Infeed`. + // + // This tests whether a computation is a compile-time constant without + // evaluating the computation. + StatusOr IsConstant(const XlaOp& operand) const; + + private: // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. XlaOp Parameter(int64 parameter_number, const Shape& shape, @@ -202,6 +317,27 @@ class XlaBuilder { XlaOp Broadcast(const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes); + // Performs in-dimension-style broadcast. + // + // Operand specifies the input to be broadcast. "shape" is expected output + // shape. "broadcast_dimensions" are the dimensions to be broadcasting into. + // Dimension numbers in broadcast_dimensions map to individual dimensions + // of the operand, and specify what dimension of the output shape they + // should be broadcast. + // e.g. + // Say operand = [1, 2], i.e., a 1D tensor with 2 elements. + // and dimension of shape is [2,2]. + // Specifying {1} as brodcast_dimension will generate output + // [1 , 2] + // [1 , 2] + // On the other hand, specifying {0} as broadcast_dimension + // will generate output + // [1 , 1] + // [2 , 2] + XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions); + // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config // specifies the padding amount for each dimension. @@ -350,26 +486,6 @@ class XlaBuilder { XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers); - // Default dimension numbers used for a 2D convolution. - static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Returns an error if the convolution dimension numbers have conflicts. - static Status Validate(const ConvolutionDimensionNumbers& dnum); - // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, @@ -536,9 +652,35 @@ class XlaBuilder { tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding); - // Returns the sum of the operand value across all replicas. All replicas - // supply one input to the sum and all replicas receive the resulting sum. - XlaOp CrossReplicaSum(const XlaOp& operand); + // Returns the sum of the operand value within each subgroup of replicas. All + // replicas supply one input to the sum and all replicas receive the resulting + // sum for each subgroup. + XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids = {}); + + // Enqueues an operation that do an AllReduce of the operand cross cores. Here + // AllReduce means doing a reduction on the input operand cross cores and then + // broadcasting the reduction result to those cores. The reduction function is + // defined by `computation`, which should be a commutative computation on + // scalars, e.g., add, min, or max. The way that AllReduce is applied is + // configured by: + // + // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // - `channel_id`: for Allreduce nodes from different models, if they have the + // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross models. + // + // TODO(b/79737069): Rename this to AllReduce when it's ready to use. + XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); // Enqueues an operation that scatters the `source` array to the selected // indices of each window. @@ -609,16 +751,6 @@ class XlaBuilder { // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); - // Enqueues a float32 sqrt instruction onto the computation. - // (float32 is specified as there is an implicit float32 0.5f constant - // exponent). - XlaOp SqrtF32(const XlaOp& operand); - - // Enqueues a float32 square instruction onto the computation. - // (float32 is specified as there is an implicit float32 2.0f constant - // exponent). - XlaOp SquareF32(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); @@ -641,14 +773,6 @@ class XlaBuilder { XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a float32 reciprocal instruction onto the computation. - // (float32 is specified as there is an implicit float32 -1.0f constant - // exponent). - // - // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the - // shape of the operand. - XlaOp ReciprocalF32(const XlaOp& operand); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); @@ -663,7 +787,18 @@ class XlaBuilder { tensorflow::gtl::ArraySlice dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. - XlaOp Sort(const XlaOp& operand); + // If only keys are provided: + // * The keys must be a rank-1 tensor (i.e. an array). + // * The result is a sorted array of keys. + // + // If both keys and values are provided: + // * The keys and the values must be rank-1 tensors with the same dimensions. + // The element types of the tensors may be different. + // * The result is a tuple that consists of a sorted array of keys as the + // first element, and an array with their corresponding values as the second + // element. + XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values = + tensorflow::gtl::nullopt); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -710,14 +845,6 @@ class XlaBuilder { // be the same as the given shape. XlaOp Recv(const Shape& shape, const ChannelHandle& handle); - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on any parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. - // - // This tests whether a computation is a compile-time constant without - // evaluating the computation. - StatusOr IsConstant(const XlaOp& operand) const; - // Normalizes operand across spatial and batch dimensions for each feature. // // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` @@ -756,47 +883,6 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); - // Returns a new XlaBuilder whose resultant Computation is used only by this - // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error - // behavior as the parent. - std::unique_ptr CreateSubBuilder(const string& computation_name); - - // Builds the computation with the requested operations, or returns a non-ok - // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. - StatusOr Build(); - - // Builds the computation with the requested operations, or notes an error in - // the parent XlaBuilder and returns an empty computation if building failed. - // This function is intended to be used where the returned XlaComputation is - // only used by the parent XlaBuilder and hence further operation on the - // returned XlaComputation will simply be error'ed out if an error occurred - // while building this computation. If the built computation is to be used by - // a XlaBuilder other than the parent XlaBuilder then Build() should be used - // instead. - XlaComputation BuildAndNoteError(); - - // Returns a subgraph that roots on the given root. If the root is not a - // compile-time constant (see `IsConstant`), returns an error. - // - // This will copy the needed ops/computations to the subgraph. - StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // XlaOp and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - // Returns the shape of the given op. - StatusOr GetShape(const XlaOp& op) const; - - // Returns the (inferred) result for the current computation's shape. - StatusOr GetProgramShape() const; - - private: StatusOr AddInstruction( HloInstructionProto&& instr, HloOpcode opcode, tensorflow::gtl::ArraySlice operands = {}); @@ -804,17 +890,6 @@ class XlaBuilder { void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); - // Notes that the error occurred by: - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to Build()) - // * dying if die_immediately_on_error_ is true - void NoteError(const Status& error); - - XlaOp NoteErrorOrReturn(const std::function()>& op_creator); - - // Helper method that creates an empty op and notes error. - XlaOp UnimplementedOp(); - StatusOr LookUpInstruction(const XlaOp& op) const; // Internal helper method that does the building for an arbitrary unary op. @@ -910,8 +985,962 @@ class XlaBuilder { bool die_immediately_on_error_ = false; XlaBuilder* parent_builder_{nullptr}; + + friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, + const Shape& shape, const string& name); + friend XlaOp ConstantLiteral(XlaBuilder* builder, + const LiteralSlice& literal); + template + friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); + template + friend XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice values); + friend XlaOp ConstantR1(XlaBuilder* builder, + const tensorflow::core::Bitmap& values); + template + friend XlaOp ConstantR2( + XlaBuilder* builder, + std::initializer_list> values); + template + friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout); + template + friend XlaOp ConstantFromArray(XlaBuilder* builder, + const Array& values); + template + friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout); + template + friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values); + template + friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout); + template + friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values); + template + friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout); + template + friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values); + + template + friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); + + friend XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + friend XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions); + + friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + + friend XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + friend XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); + + friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + friend XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + int64 dimension); + + friend void Trace(const string& tag, const XlaOp& operand); + + friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false); + friend XlaOp Tuple(XlaBuilder* builder, + tensorflow::gtl::ArraySlice elements); + friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + friend XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + friend XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const string& config); + friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + friend XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Conj(const XlaOp& operand); + friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Not(const XlaOp& operand); + friend XlaOp ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + friend XlaOp ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding); + friend XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id); + friend XlaOp SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp Abs(const XlaOp& operand); + friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Exp(const XlaOp& operand); + friend XlaOp Expm1(const XlaOp& operand); + friend XlaOp Floor(const XlaOp& operand); + friend XlaOp Ceil(const XlaOp& operand); + friend XlaOp Round(const XlaOp& operand); + friend XlaOp Log(const XlaOp& operand); + friend XlaOp Log1p(const XlaOp& operand); + friend XlaOp Sign(const XlaOp& operand); + friend XlaOp Clz(const XlaOp& operand); + friend XlaOp Cos(const XlaOp& operand); + friend XlaOp Sin(const XlaOp& operand); + friend XlaOp Tanh(const XlaOp& operand); + friend XlaOp Real(const XlaOp& operand); + friend XlaOp Imag(const XlaOp& operand); + friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp IsFinite(const XlaOp& operand); + friend XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp Neg(const XlaOp& operand); + friend XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + friend XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values); + friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + friend XlaOp Map(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); + friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& shape); + friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + friend XlaOp While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init); + friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + friend void Send(const XlaOp& operand, const ChannelHandle& handle); + friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); }; +// RAII-style object: sets the current sharding assignment in builder on +// construction, and sets back to the previous assignment on destruction. +class XlaScopedShardingAssignment { + public: + XlaScopedShardingAssignment(xla::XlaBuilder* builder, + tensorflow::gtl::optional sharding) + : builder_(builder), prev_sharding_(builder->sharding()) { + SetSharding(sharding); + } + + XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; + XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = + delete; + + ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } + + private: + void SetSharding(const tensorflow::gtl::optional& sharding) { + if (sharding.has_value()) { + builder_->SetSharding(sharding.value()); + } else { + builder_->ClearSharding(); + } + } + + xla::XlaBuilder* const builder_; + tensorflow::gtl::optional prev_sharding_; +}; + +// Free functions for building XlaOps. The intention is that these will +// become the public API for building XlaOps rather than calling methods on +// XlaBuilder directly. + +// Enqueues a "retrieve parameter value" instruction for a parameter that was +// passed to the computation. +XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, + const string& name); + +// Enqueues a constant with the value of the given literal onto the +// computation. +XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); + +// Enqueues a constant onto the computation. Methods are templated on the +// native host type (NativeT) which corresponds to a specific XLA +// PrimitiveType as given in the following table: +// +// Native Type PrimitiveType +// ----------------------------- +// bool PRED +// int32 S32 +// int64 S64 +// uint32 U32 +// uint64 U64 +// float F32 +// double F64 +// +// Note: not all primitive types defined in xla_data.proto have a +// corresponding native type yet. +template +XlaOp ConstantR0(XlaBuilder* builder, NativeT value); +template +XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice values); +XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); +template +XlaOp ConstantR2(XlaBuilder* builder, + std::initializer_list> values); +template +XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout); +template +XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values); +template +XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout); +template +XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values); +template +XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout); +template +XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values); +template +XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout); +template +XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values); + +// Enqueues a rank one constant (XlaBuilder* builder, vector) onto the +// computation. The vector has size 'length' and every element has the value +// 'value'. +template +XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); + +// Adds dimensions to an array by duplicating the data in the array. +// +// The new dimensions are inserted on the left, i.e. if +// broadcast_sizes has values {a0, ..., aN} and the operand shape +// has dimensions {b0, ..., bM} then the shape of the output has +// dimensions {a0, ..., aN, b0, ..., bM}. +// +// The new dimensions index into copies of the operand, i.e. +// +// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] +XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + +// Performs in-dimension-style broadcast. +// +// Operand specifies the input to be broadcast. "shape" is expected output +// shape. "broadcast_dimensions" are the dimensions to be broadcasting into. +// Dimension numbers in broadcast_dimensions map to individual dimensions +// of the operand, and specify what dimension of the output shape they +// should be broadcast. +// e.g. +// Say operand = [1, 2], i.e., a 1D tensor with 2 elements. +// and dimension of shape is [2,2]. +// Specifying {1} as brodcast_dimension will generate output +// [1 , 2] +// [1 , 2] +// On the other hand, specifying {0} as broadcast_dimension +// will generate output +// [1 , 1] +// [2 , 2] +XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions); + +// Enqueues a pad operation onto the computation that pads the given value on +// the edges as well as between the elements of the input. padding_config +// specifies the padding amount for each dimension. +XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + +// Enqueues an operation onto the computation that flattens the operand based +// on the dimension order (major/slowest-varying to minor/fastest-varying) +// given, followed by reshaping it into the shape with the given dimension +// sizes (also major to minor). Conceptually, this is a limited form of +// "shape casting". +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + +// Enqueues an operation onto the computation that collapses the operand, from +// first to last dimension (C order), then reshapes it to the given dimension +// sizes. Conceptually, this is a limited form of "shape casting". +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + +// Wrapper for Reshape. +// Enqueues an operation to collapse the provided dimensions; e.g. an +// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to +// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must +// be a consecutive, in-order subsequence of the operand dimensions. +// +// Note that collapsing a single dimension does nothing: +// +// {256} collapsing {0} => {256} +// {1} collapsing {0} => {1} +// +// Collapsing multiple dimensions produces a single result dimension: +// +// {256, 2} collapsing {0,1} => {512} +// {256, 2, 3} collapsing {0,1} => {512, 3} +// +// This could potentially cause data to be moved -- it provides a more +// structured form of reshaping than an arbitrary Reshape operation. +XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + +// Enqueues a slice operation onto the computation that slices the operand +// from the start indices to the limit indices; e.g. +// +// x +// [ 0 1 2 3 ] +// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] +// [ 8 9 a b ] +// +// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D +// range notation. +// The strides parameter determines the stride over the slice +XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + +// Enqueues a slice operation in a given dimension, taking all other +// dimensions as they are; e.g. if dimno is 1 from start_index 2 to +// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand +// for: +// +// array[:, 2:4:1, :] +XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + +// Enqueues a slice operation onto the computation that slices the 'operand' +// from dynamic start indices which are passed in 'start_indices'. +// The size of the slice in each dimension is passed in 'slice_sizes', +// which specify the end point of exclusive slice intervals in each +// dimension [start, start + size). +// The shape of 'start_indices' must be rank == 1, with dimension size +// equal to the rank of the 'operand'. +// Slice index calculations are computed modulo input dimension sizes to +// prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + +// Enqueues a dynamic update slice operation onto the computation, which +// updates a slice of 'operand' with 'update' at dynamic 'start_indices'. +// The shape of 'update' determines the shape of the slice of 'operand' +// which is updated. +// The indices specified in 'start_indices' specify the offset of the slice +// of 'operand' which is updated. +// +// update = {10, 11} // calculated at runtime. +// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] +// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] +// [7 8 9] [7 8 9 ] +// +// The shape of 'start_indices' must be rank == 1, with dimension size +// equal to the rank of the 'operand'. +// Slice index calculations are computed modulo update dimension sizes to +// prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + +// Enqueues a concatenate instruction onto the computation. 'operands' must +// have >= 1 entry. +XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, int64 dimension); + +// Enqueue a tracing operation onto the computation; the computation will emit +// a logging message with the operand. +void Trace(const string& tag, const XlaOp& operand); + +// Enqueues a conditional-move-like select operation onto the computation; +// predicated on pred, selects between on_true and on_false. +XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); + +// Enqueues a tuple-creation instruction onto the computation. +XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements); + +// Enqueues a tuple-element-get instruction onto the computation. +XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + +// Enqueues an equal-to comparison instruction onto the computation. +XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a not-equal comparison instruction onto the computation. +XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a greater-or-equal comparison instruction onto the computation. +XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a greater-than comparison instruction onto the computation. +XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a less-than comparison instruction onto the computation. +XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a less-or-equal comparison instruction onto the computation. +XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a dot instruction onto the computation. +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + +// Enqueues a general dot instruction onto the computation. +XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + +// Enqueues a convolution instruction onto the computation, which uses the +// default convolution dimension numbers. +XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration in the format returned by MakePadding(). +XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided dimension numbers configuration. +XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration as well as the dimension numbers. +XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration, dilation factors and dimension numbers. +XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Enqueues an FFT instruction onto the computation, of the given type and +// with the given FFT length. +XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + +// Enqueues an infeed instruction onto the computation, which writes data of +// the given shape to the infeed buffer of the device. +XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const string& config = ""); + +// Enqueues an outfeed instruction onto the computation. This instruction +// generates outgoing data transfers for the given data. +// +// shape_with_layout communicates the laid out shape that we want to outfeed +// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error +// will occur. +void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + +// Enqueues a call instruction onto the computation. +XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + +// Enqueues a custom call instruction onto the computation. +// During code generation, a call instruction is emitted which targets a +// symbol with the name |call_target_name|. The |operands| are passed to the +// call instruction. |shape| is the resultant shape. +XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + +// Enqueues a pseudo-op to represent host-side computation data-dependencies. +// During code generation, host send and receive operations will be generated +// to transfer |operands| to the host and a single result of |shape| back to +// the device. Host send/recv operations are emitted using |channel_name|. +// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO +// instruction scheduling. +XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + +// The following methods enqueue element-wise binary arithmetic operations +// onto the computation. The shapes of the operands have to match unless one +// of the operands is a scalar, or an explicit broadcast dimension is given +// (see g3doc for more details). + +// Enqueues a complex compose instruction onto the computation. +XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a complex conjugate instruction onto the computation. +XlaOp Conj(const XlaOp& operand); + +// Enqueues an add instruction onto the computation. +XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a subtract instruction onto the computation. +XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a multiply instruction onto the computation. +XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a divide instruction onto the computation. +XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a remainder instruction onto the computation. +XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a max instruction onto the computation. +XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a min instruction onto the computation. +XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Element-wise logical operators +XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +XlaOp Not(const XlaOp& operand); + +XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); +XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); +XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Reduces an array among the provided dimensions, given "computation" as a +// reduction operator. +XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + +// Convenience wrapper around the above that reduces all the dimensions in the +// operand shape. +XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + +// Enqueues a windowed reduce instruction onto the computation. +XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + +// As ReduceWindow(), but the padding is given in the format +// returned by MakePadding(). +XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + +// Returns the sum of the operand value within each subgroup of replicas. All +// replicas supply one input to the sum and all replicas receive the resulting +// sum for each subgroup. +XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids = {}); + +// Enqueues an operation that do an AllReduce of the operand cross cores. Here +// AllReduce means doing a reduction on the input operand cross cores and then +// broadcasting the reduction result to those cores. The reduction function is +// defined by `computation`, which should be a commutative computation on +// scalars, e.g., add, min, or max. The way that AllReduce is applied is +// configured by: +// +// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all +// replicas belong to one group. Allreduce will be applied within subgroups. +// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, +// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. +// +// - `channel_id`: for Allreduce nodes from different models, if they have the +// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be +// applied cross models. +// +// TODO(b/79737069): Rename this to AllReduce when it's ready to use. +XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& + channel_id = tensorflow::gtl::nullopt); + +// Enqueues an operation that scatters the `source` array to the selected +// indices of each window. +XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); + +// As SelectAndScatter(), but the padding is given in the format +// returned by MakePadding(). +XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + +// Enqueues an abs instruction onto the computation. +XlaOp Abs(const XlaOp& operand); + +// Enqueues a atan2 instruction onto the computation. +XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues an exp instruction onto the computation. +XlaOp Exp(const XlaOp& operand); + +// Enqueues an expm1 instruction onto the computation. +XlaOp Expm1(const XlaOp& operand); + +// Enqueues a floor instruction onto the computation. +XlaOp Floor(const XlaOp& operand); + +// Enqueues a ceil instruction onto the computation. +XlaOp Ceil(const XlaOp& operand); + +// Enqueues a round instruction onto the computation, rounding to nearest even +// with half-way cases rounding away from zero. +XlaOp Round(const XlaOp& operand); + +// Enqueues an log instruction (natural logarithm) onto the computation. +XlaOp Log(const XlaOp& operand); + +// Enqueues an log1p instruction (log(x+1)) onto the computation. +XlaOp Log1p(const XlaOp& operand); + +// Enqueues a sign instruction onto the computation. +XlaOp Sign(const XlaOp& operand); + +// Enqueues a count leading zeros instruction onto the computation. +XlaOp Clz(const XlaOp& operand); + +// Enqueues a cosine instruction onto the computation. +XlaOp Cos(const XlaOp& operand); + +// Enqueues a sine instruction onto the computation. +XlaOp Sin(const XlaOp& operand); + +// Enqueues a tanh instruction onto the computation. +XlaOp Tanh(const XlaOp& operand); + +// Enqueues a real-part instruction onto the computation. +XlaOp Real(const XlaOp& operand); + +// Enqueues an imaginary-part instruction onto the computation. +XlaOp Imag(const XlaOp& operand); + +// Enqueues a lhs^rhs computation onto the computation. +XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues an operator that tests if the operand's values are finite, i.e., +// not Inf or NaN. Defined only for floating-point types. Returns an array of +// booleans with the same shape where entries are true iff the corresponding +// entry was NaN. +XlaOp IsFinite(const XlaOp& operand); + +// Enqueues a convert instruction onto the computation that changes the +// element type of the operand array to primitive_type. +XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); + +// Enqueues a no-op instruction onto the computation that changes +// the element type of the operand array to primitive_type. The +// bit-widths of the source and destination element types must be +// identical. +XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); + +// Enqueues a negate instruction onto the computation. +XlaOp Neg(const XlaOp& operand); + +// Enqueues a transpose instruction onto the computation. +XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + +// Enqueues a reverse instruction onto the computation. The order of the +// elements in the given dimensions is reversed (i.e., the element at index i +// is moved to index dimension_size - 1 - i). +XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); + +// * The result is a sorted array of keys. +// +// If both keys and values are provided: +// * The keys and the values must be rank-1 tensors with the same dimensions. +// The element types of the tensors may be different. +// * The result is a tuple that consists of a sorted array of keys as the +// first element, and an array with their corresponding values as the second +// element. +XlaOp Sort(XlaOp keys, + tensorflow::gtl::optional values = tensorflow::gtl::nullopt); + +// Enqueues a clamp instruction onto the computation. +XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + +// Enqueues a map instruction onto the computation. +XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands = {}); + +// Enqueues a N(mu, sigma) random number generation instruction onto the +// computation. +XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); + +// Enqueues a U(a, b) random number generation instruction onto the +// computation. Returns values in the semi-open interval [a, b). +XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + +// Enqueues a while node onto the computation. +XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init); + +// Enqueues a conditional node onto the computation. +XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + +// Enqueues a ReducePrecision node onto the computation. +XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + +// Enqueues a Gather node onto the computation. +XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + +// Enqueues a Send node onto the computation, to send the given operand to +// a Recv instruction that shares the same channel handle. +void Send(const XlaOp& operand, const ChannelHandle& handle); + +// Enqueues a Recv node onto the computation. The data comes from a Send +// instruction that shares the same channel handle and its shape must +// be the same as the given shape. +XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + +// Normalizes operand across spatial and batch dimensions for each feature. +// +// Returns a tuple (normalized, batch_mean, batch_var) where `normalized` +// is the normalized result and batch_mean and batch_var are the mean and +// variance, respectively, across batch for the operand. +XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + +// Normalizes operand across spatial and batch dimensions for each feature. +// +// `BatchNormInference` is equivalent to calling `BatchNormTraining` without +// computing `mean` and `variance` for each batch inside the operation. It +// uses the input `mean` and `variance` instead as estimated values. The +// purpose of this op is to reduce latency in inference, hence the name +// `BatchNormInference`. +// +// The output has the same shape as `operand`, and contains the normalized +// values for each batch. +XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + +// Calculates the gradients of a batch norm op. +// +// The inputs `batch_mean` and `batch_var` represent the mean and variance +// across the batch. +// +// Returns a tuple of three elements: +// - grad_operand: Gradient with respect to input `operand` +// - grad_offset: Gradient with respect to input `offset` +// - grad_scale: Gradient with respect to input `scale` +XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + +// Implementation details below this point. + template XlaOp XlaBuilder::ConstantR0(NativeT value) { return ConstantLiteral(*Literal::CreateR0(value)); @@ -987,36 +2016,93 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { return ConstantFromArray(values); } -// RAII-style object: sets the current sharding assignment in builder on -// construction, and sets back to the previous assignment on destruction. -// -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -class XlaScopedShardingAssignment { - public: - XlaScopedShardingAssignment(xla::XlaBuilder* builder, - tensorflow::gtl::optional sharding) - : builder_(builder), prev_sharding_(builder->sharding()) { - SetSharding(sharding); - } +// Free function template implementations. - XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; - XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = - delete; +template +XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { + return ConstantLiteral(builder, *Literal::CreateR0(value)); +} - ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } +template +XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice values) { + return ConstantLiteral(builder, *Literal::CreateR1(values)); +} - private: - void SetSharding(const tensorflow::gtl::optional& sharding) { - if (sharding.has_value()) { - builder_->SetSharding(sharding.value()); - } else { - builder_->ClearSharding(); - } - } +template +XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(builder, literal); +} - xla::XlaBuilder* const builder_; - tensorflow::gtl::optional prev_sharding_; -}; +inline XlaOp ConstantR1(XlaBuilder* builder, + const tensorflow::core::Bitmap& values) { + return ConstantLiteral(builder, *Literal::CreateR1(values)); +} + +template +XlaOp ConstantR2(XlaBuilder* builder, + std::initializer_list> values) { + return ConstantLiteral(builder, *Literal::CreateR2(values)); +} + +template +XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout) { + return ConstantLiteral( + builder, *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { + return ConstantLiteral(builder, *Literal::CreateFromArray(values)); +} + +template +XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout) { + return ConstantLiteral( + builder, *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values) { + return ConstantLiteral(builder, + *Literal::CreateR2FromArray2D(values)); +} + +template +XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout) { + return ConstantLiteral( + builder, + *Literal::CreateR3FromArray3DWithLayout(values, layout)); +} + +template +XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values) { + return ConstantFromArray(builder, values); +} + +template +XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout) { + return ConstantFromArrayWithLayout(builder, values, layout); +} + +template +XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values) { + return ConstantFromArray(builder, values); +} } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc index 2df3ea3af0d4fcfb9bc803feebd96f09042ab1f3..3b8beb2c7840e23752b5f47bbc5f55d89751884d 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -53,16 +53,86 @@ class XlaBuilderTest : public ::testing::Test { TEST_F(XlaBuilderTest, OnePlusTwo) { XlaBuilder b(TestName()); - b.Add(b.ConstantR0(1.0), b.ConstantR0(2.0)); + Add(ConstantR0(&b, 1.0), ConstantR0(&b, 2.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); } +TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) { + auto test_unary_operator = + [&](std::function op, + ::testing::Matcher matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + test_unary_operator([](XlaOp x) { return -x; }, op::Negate(op::Constant())); + test_unary_operator([](XlaOp x) { return ~x; }, op::Not(op::Constant())); +} + +TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { + auto test_binary_operator = + [&](std::function op, + ::testing::Matcher matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1), ConstantR0(&b, 2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + + test_binary_operator([](XlaOp x, XlaOp y) { return x + y; }, + op::Add(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x - y; }, + op::Subtract(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x * y; }, + op::Multiply(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x / y; }, + op::Divide(op::Constant(), op::Constant())); + + test_binary_operator([](XlaOp x, XlaOp y) { return x & y; }, + op::And(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x | y; }, + op::Or(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x ^ y; }, + op::Xor(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x << y; }, + op::ShiftLeft(op::Constant(), op::Constant())); + test_binary_operator( + [](XlaOp x, XlaOp y) { return x >> y; }, + op::ShiftRightArithmetic(op::Constant(), op::Constant())); + + auto test_unsigned_binary_operator = + [&](std::function op, + ::testing::Matcher matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1), ConstantR0(&b, 2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + test_unsigned_binary_operator( + [](XlaOp x, XlaOp y) { return x >> y; }, + op::ShiftRightLogical(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { + XlaBuilder b(TestName()); + ConstantR0(&b, 1) >> ConstantR0(&b, 2); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Argument to >> operator does not have an integral type")); +} + TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); - b.Add(x, b.ConstantR0(1.0)); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); + Add(x, ConstantR0(&b, 1.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); @@ -72,9 +142,9 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { XlaBuilder b(TestName()); const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); - auto x = b.Parameter(0, x_shape, "x"); - auto y = b.Parameter(1, y_shape, "y"); - auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + auto x = Parameter(&b, 0, x_shape, "x"); + auto y = Parameter(&b, 1, y_shape, "y"); + auto add = Add(x, y, /*broadcast_dimensions=*/{0, 1}); TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); @@ -86,8 +156,8 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { TEST_F(XlaBuilderTest, XPlusX) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); - b.Add(x, x); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); + Add(x, x); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0))); @@ -95,9 +165,9 @@ TEST_F(XlaBuilderTest, XPlusX) { TEST_F(XlaBuilderTest, ShapeInferenceError) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); - b.Add(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); + Add(x, y); auto statusor = BuildHloModule(&b); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference")); @@ -105,12 +175,12 @@ TEST_F(XlaBuilderTest, ShapeInferenceError) { TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { XlaBuilder b_call("add"); - b_call.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + Parameter(&b_call, 0, ShapeUtil::MakeShape(PRED, {}), "x"); XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); - auto y = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "y"); - b.Add(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "y"); + Add(x, y); auto statusor = BuildHloModule(&b); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -119,16 +189,16 @@ TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { TEST_F(XlaBuilderTest, Call) { XlaBuilder b_call("the_only_to_apply"); - auto p0 = b_call.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); - auto p1 = b_call.Parameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); - b_call.Add(p0, p1); + auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0"); + auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1"); + Add(p0, p1); TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - auto one = b.ConstantR0(1); - auto two = b.ConstantR0(2); - b.Add(b.Call(call, {x, y}), b.Call(call, {one, two})); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); + auto one = ConstantR0(&b, 1); + auto two = ConstantR0(&b, 2); + Add(Call(&b, call, {x, y}), Call(&b, call, {one, two})); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()), @@ -137,9 +207,9 @@ TEST_F(XlaBuilderTest, Call) { TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); - b.Add(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); + Add(x, y); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); // Expected: @@ -158,9 +228,9 @@ TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); - b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); + Add(x, y, /*broadcast_dimensions=*/{0, 1}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); // The binary operation has in-dim broadcast and degenerate broadcast, should @@ -183,9 +253,10 @@ TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { XlaBuilder b1("b1"); - auto p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + auto p0 = Parameter(&b1, 0, ShapeUtil::MakeShape(F32, {}), "p0"); XlaBuilder builder("main"); - builder.Add(p0, p0); + auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "p"); + Add(p, p0); auto statusor = builder.Build(); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -196,8 +267,8 @@ TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); - b.Reshape(x, /*new_sizes=*/{6, 35}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + Reshape(x, /*new_sizes=*/{6, 35}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Parameter())); @@ -205,8 +276,8 @@ TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { TEST_F(XlaBuilderTest, ReshapeHasTranspose) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); - b.Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); @@ -214,25 +285,38 @@ TEST_F(XlaBuilderTest, ReshapeHasTranspose) { TEST_F(XlaBuilderTest, Transpose) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); - b.Transpose(x, /*permutation=*/{1, 0}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + Transpose(x, /*permutation=*/{1, 0}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Transpose(op::Parameter())); } -// TODO(b/65209188): Create a dedicated lowering for Xor. -TEST_F(XlaBuilderTest, Xor) { +TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(PRED, {}), "y"); - b.Xor(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + Add(b.ReportError(InvalidArgument("a test error")), x); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); +} + +TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) { + XlaBuilder b(TestName()); + StatusOr op(ConstantR0(&b, 1.0)); + Add(b.ReportErrorOrReturn(op), ConstantR0(&b, 2.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - LOG(ERROR) << module->ToString(); - EXPECT_THAT(root, - op::Or(op::And(op::Not(op::Parameter(0)), op::Parameter(1)), - op::And(op::Parameter(0), op::Not(op::Parameter(1))))); + EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { + XlaBuilder b(TestName()); + StatusOr op(InvalidArgument("a test error")); + Add(b.ReportErrorOrReturn(op), ConstantR0(&b, 2.0)); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); } } // namespace diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h index b70b57e9ffec40188f246f5e884146012c02f4a2..0ffba208b1f8683fe1d26107cbfd096b856267f1 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -25,8 +25,6 @@ limitations under the License. namespace xla { // The computation graph that the user builds up with the XlaBuilder. -// -// TODO(b/74197823): Replace xla::Computation with this one. class XlaComputation { public: XlaComputation() : unique_id_(-1) {} diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/BUILD b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a26b20c861846501c911253d89619591c37322b3 --- /dev/null +++ b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD @@ -0,0 +1,18 @@ +# Description: +# Python API for shardings in XLA. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_library( + name = "xla_sharding", + srcs = ["xla_sharding.py"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/compiler/xla/python_api:types", + "//tensorflow/compiler/xla/python_api:xla_shape", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..abd10b164eaef8e75ed304483861baf250c5b954 --- /dev/null +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -0,0 +1,204 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""Experimental support for defining XLA shardings.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import xla_shape +from tensorflow.core.framework import attr_value_pb2 + + +class Sharding(object): + """A class to support adding sharding attributes to Ops. + + Use the factory constructors and then call apply_to_tensor: + Sharding.replicate().apply_to_tensor(tensor) + """ + + def __init__(self, proto=None): + """Do not use this constructor; use the factory functions below.""" + self._proto = proto + + @classmethod + def replicate(cls): + """Returns a replicated sharding attribute. + + This causes an op to be computed in its entirety independently on all + cores in the XLA device. + """ + return Sharding( + proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)) + + @classmethod + def assign_device(cls, core): + """Returns an AssignDevice sharding attribute. + + This causes an op to be computed in its entirety only on one core in + the XLA device. + Args: + core: The core to assign this Op to. + """ + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.MAXIMAL, + tile_assignment_dimensions=[1], + tile_assignment_devices=[core])) + + @classmethod + def tile(cls, tile_shape, tile_assignment): + """Returns a Tiled sharding attribute. + + This causes an op to be partially computed on multiple cores in the + XLA device. + + Args: + tile_shape: A xla_shape.Shape describing the tile shape that each core + will compute. + The tile shape does not need to be divisible by the tile assignment. + tile_assignment: An np.ndarray describing the topology of the tiling and + which device will compute which part of the topology. + + Raises: + TypeError: tile_assignment was not of np.array type or tile_shape was + not of xla_shape.Shape type. + + TODO(jmolloy): This concept is nefarious and is not + something we really want to expose to users (especially as the + contract for tile_assignment is very strict). + """ + if not isinstance(tile_assignment, np.ndarray): + raise TypeError('Tile assignment must be of type np.ndarray') + if not isinstance(tile_shape, xla_shape.Shape): + raise TypeError('Tile shape must be of type xla_shape.Shape') + dims = list(tile_assignment.shape) + flattened_devices = tile_assignment.reshape(-1, order='C') + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.OTHER, + tile_shape=tile_shape.message, + tile_assignment_dimensions=dims, + tile_assignment_devices=list(flattened_devices))) + + @classmethod + def split(cls, tensor, split_dimension, num_devices): + """Returns a Sharding that splits a tensor across a dimension. + + This creates a Tiled attribute, similar to tile(), but easier to use for the + common case of tiling a tensor N ways in one dimension. + + Args: + tensor: A tf.Tensor to split. + split_dimension: The dimension number to split. + num_devices: The number of cores to split `tensor` over. + + Raises: + ValueError: The tensor to split was smaller in the split dimension than + the number of devices to split over. + """ + tensor.shape.assert_is_fully_defined() + shape = tensor.shape.as_list() + if shape[split_dimension] < num_devices: + raise ValueError('Split dimension was smaller than the required number ' + 'of splits: shape=%r, dimension=%r, num_devices=%r', + shape, split_dimension, num_devices) + + tile_shape = shape + tile_shape[split_dimension] = int( + math.ceil(tile_shape[split_dimension] / num_devices)) + tile_shape_proto = xla_data_pb2.Shape( + element_type=xla_data_pb2.F32, dimensions=tile_shape) + + tile_assignment_dims = [1] * len(shape) + tile_assignment_dims[split_dimension] = num_devices + + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.OTHER, + tile_shape=tile_shape_proto, + tile_assignment_dimensions=tile_assignment_dims, + tile_assignment_devices=range(num_devices))) + + def apply_to_tensor(self, tensor): + """Applies this Sharding attribute to `tensor`.""" + if len(tensor.op.outputs) > 1: + proto = self._get_or_create_tuple_proto(tensor.op) + # We can't mutate an element of old_proto.tuple_shardings, so create + # a new proto. + tuple_shardings = list(proto.tuple_shardings) + tuple_shardings[tensor.value_index] = self._proto + proto = xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings) + else: + proto = self._proto + + attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) + # TODO(jmolloy): This need to be seriously revisited before declaring this + # API available for public use. + # pylint: disable=protected-access + tensor.op._set_attr('_XlaSharding', attr_value) + + @property + def proto(self): + """Return the sharding protobuf of type xla_data_pb2.OpSharding.""" + return self._proto + + def _get_or_create_tuple_proto(self, op): + try: + attr = op.get_attr('_XlaSharding') + proto = xla_data_pb2.OpSharding() + proto.ParseFromString(attr) + return proto + except ValueError: + return self._create_tuple_proto(op) + + def _create_tuple_proto(self, op): + shardings = [ + xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED) + for _ in op.outputs + ] + return xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings) + + +# Helpers for the above factory functions that allow easy application of +# shardings, for example: +# tensor = xla_sharding.replicate(tensor) + + +def replicate(tensor): + Sharding.replicate().apply_to_tensor(tensor) + return tensor + + +def assign_device(tensor, device): + Sharding.assign_device(device).apply_to_tensor(tensor) + return tensor + + +def tile(tensor, tile_shape, tile_assignment): + Sharding.tile(tile_shape, tile_assignment).apply_to_tensor(tensor) + return tensor + + +def split(tensor, split_dimension, num_devices): + Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor) + return tensor diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index a76fdcda250168cbed2acd01bdd9ddc3b4c93b92..15eeb2ea13607d43c995197f8f0e3c58abd4d94a 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -65,6 +65,16 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice major_to_minor) { + Layout layout; + layout.set_format(DENSE); + for (int i = major_to_minor.size() - 1; i >= 0; i--) { + layout.add_minor_to_major(major_to_minor[i]); + } + return layout; +} + /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { Layout layout; layout.set_format(SPARSE); @@ -88,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } // namespace /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { + if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) { + // Opaque and token types have empty layouts. + return Layout(); + } + // A Layout proto corresponds to a single array, not a tuple. - DCHECK(!ShapeUtil::IsTuple(shape)); + CHECK(ShapeUtil::IsArray(shape)); return CreateDefaultLayoutForRank(shape.dimensions_size()); } @@ -116,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) { SetToDefaultLayout(&element_shape); } shape->clear_layout(); - } else if (ShapeUtil::IsOpaque(*shape)) { - shape->clear_layout(); - } else { + } else if (ShapeUtil::IsArray(*shape)) { shape->mutable_layout()->set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->Resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); + } else { + // Opaque, token types etc. have no layout. + shape->clear_layout(); } } @@ -150,18 +166,20 @@ Layout CreateDefaultLayoutForRank(int64 rank) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } return Status::OK(); - } else if (ShapeUtil::IsOpaque(shape)) { - if (shape.has_layout()) { - return InvalidArgument("opaque should not have a layout field"); - } - return Status::OK(); - } else { - // Array shape. + } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { return InvalidArgument("shape %s does not have a layout", ShapeUtil::HumanString(shape).c_str()); } return ValidateLayoutForShape(shape.layout(), shape); + } else { + // Token, opaque, etc. shape. + if (shape.has_layout()) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); + } + return Status::OK(); } } @@ -171,7 +189,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (ShapeUtil::IsOpaque(shape)) { + if (!ShapeUtil::IsArray(shape)) { + if (layout.minor_to_major_size() != 0 || + layout.padded_dimensions_size() != 0) { + return InvalidArgument( + "shape of primitive type %s should not have a non-trivial layout", + PrimitiveType_Name(shape.element_type()).c_str()); + } return Status::OK(); } @@ -224,6 +248,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } + if (layout.format() == SPARSE) { + if (!layout.padded_dimensions().empty()) { + return InvalidArgument("Sparse layout has padded dimensions"); + } + } + return Status::OK(); } @@ -263,7 +293,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsPadded(const Shape& shape) { - if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) || + if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) || shape.layout().padded_dimensions_size() == 0) { return false; } @@ -313,7 +343,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { // Tuple shape: all subshapes must have a layout. return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), [](const Shape& s) { return HasLayout(s); }); - } else if (ShapeUtil::IsOpaque(shape)) { + } else if (!ShapeUtil::IsArray(shape)) { + // Opaque, token types etc. ignore layout. return true; } return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; @@ -422,12 +453,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { - if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) { - return false; - } if (ShapeUtil::IsTuple(lhs)) { - if (ShapeUtil::TupleElementCount(lhs) != - ShapeUtil::TupleElementCount(rhs)) { + if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) != + ShapeUtil::TupleElementCount(rhs)) { return false; } for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { @@ -436,9 +464,12 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } } return true; - } else { + } else if (ShapeUtil::IsArray(lhs)) { return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && LayoutUtil::Equal(lhs.layout(), rhs.layout()); + } else { + // Layouts of non-array and non-tuple shapes is ignored. + return true; } } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index d3d6a2cc94012f7113fd1cb1b17e9c9d5323d9bf..739bbe73675c7fb855627006028eafdf703d6540 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -36,6 +36,10 @@ class LayoutUtil { // convenience function for protobuf construction.) static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + // Similar to MakeLayout, but take indices in reverse order. + static Layout MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice major_to_minor); + // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) static Layout MakeSparseLayout(int64 max_sparse_elements); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 4fd1d818e3e3b417eee9f6b14bb598bfb9480c6e..e4c825450dcd45a8fbeaacbb2ad145f94307176f 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { "elements, but shape is rank")); } +TEST_F(LayoutUtilTest, CopyTokenLayout) { + Shape src = ShapeUtil::MakeTokenShape(); + Shape dst = ShapeUtil::MakeTokenShape(); + + // Layouts are trivially the same for token types and copying layouts should + // be a nop. + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyOpaqueLayout) { + Shape src = ShapeUtil::MakeOpaqueShape(); + Shape dst = ShapeUtil::MakeOpaqueShape(); + + // Layouts are trivially the same for opaque types and copying layouts should + // be a nop. + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, ClearLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), @@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) { EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); } +TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) { + // Opaque and token types trivially have layouts. + for (Shape shape : + {ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) { + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + LayoutUtil::ClearLayout(&shape); + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + } +} + TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}), diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 3696fdbe12e311af3b286ef0dfe91377983b72dd..2125ab7c61ab5e30fe51e16994e0da4883d509c4 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -606,8 +606,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, } // namespace Status EqualShapes(const Shape& expected, const Shape& actual) { - if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { - return InvalidArgument("tupleness-mismatch! want: %s got %s", + if (expected.element_type() != actual.element_type()) { + return InvalidArgument("element type mismatch, want: %s got %s", ShapeUtil::HumanString(expected).c_str(), ShapeUtil::HumanString(actual).c_str()); } @@ -626,7 +626,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return AppendStatus(result, StrCat("mismatch in tuple index", i)); } } - } else { + } else if (ShapeUtil::IsArray(expected)) { if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { return InvalidArgument("want rank of %s got rank of %s", ShapeUtil::HumanString(expected).c_str(), @@ -652,6 +652,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { } } } + // Non-array, non-tuple shapes are trivially equivalent. return Status::OK(); } @@ -705,6 +706,9 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } break; } + case TOKEN: + // Tokens have no on-device representation and are trivially equal. + return Status::OK(); default: LOG(FATAL) << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " @@ -716,9 +720,11 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } return AppendStatus(result, - tensorflow::strings::Printf("expected: %s\nactual: %s", - expected.ToString().c_str(), - actual.ToString().c_str())); + tensorflow::strings::Printf( + "\nat index: %s\nexpected: %s\nactual: %s", + Literal::MultiIndexAsString(multi_index).c_str(), + ToStringTruncated(expected).c_str(), + ToStringTruncated(actual).c_str())); } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 1022372df20d5447dc2735ee0ad7733558d0b9bb..eeabf835ac348a5ba55699631188b0e329c98c43 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -148,8 +148,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->emplace_back(std::move(child_piece)); } - } else { - CHECK(ShapeUtil::IsArray(shape)); + } else if (ShapeUtil::IsArray(shape)) { if (allocate_arrays) { if (LayoutUtil::IsSparseArray(shape)) { // For sparse arrays, the buffer must be of the size of the maximum @@ -165,6 +164,10 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->set_buffer(new char[piece->size_bytes()]); } } + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(piece->size_bytes(), 0); } } @@ -264,8 +267,8 @@ Status Literal::CopySliceFromInternal( StridedCopy(data(), linear_index(shape(), dest_base), 0, src_literal.data(), linear_index(src_literal.shape(), src_base), 0, 1); - } else if (!ShapeUtil::HasZeroElements(shape()) && - !ShapeUtil::HasZeroElements(src_literal.shape())) { + } else if (!ShapeUtil::IsZeroElementArray(shape()) && + !ShapeUtil::IsZeroElementArray(src_literal.shape())) { // Perform copy if neither src nor dest has dimensions with zero element, // otherwise it's a no-op. TF_RET_CHECK(src_base.size() == dest_base.size()); @@ -327,6 +330,10 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } +/* static */ std::unique_ptr Literal::CreateToken() { + return MakeUnique(ShapeUtil::MakeTokenShape()); +} + std::vector Literal::DecomposeTuple() { CHECK(ShapeUtil::IsTuple(shape())); std::vector elements; @@ -379,7 +386,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, tensorflow::gtl::ArraySlice src, const Shape& dest_shape, const Shape& src_shape) { CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); - if (ShapeUtil::HasZeroElements(dest_shape)) { + if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; } std::vector index(ShapeUtil::Rank(dest_shape)); @@ -807,6 +814,47 @@ std::unique_ptr LiteralBase::Relayout( return result; } +StatusOr> LiteralBase::Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Broadcast only supports arrays."); + } + + for (int64 i = 0; i < dimensions.size(); i++) { + TF_RET_CHECK(shape().dimensions(i) == + result_shape.dimensions(dimensions[i])); + } + + std::unique_ptr result = MakeUnique(result_shape); + + // scratch_source_index is temporary storage space for the computed index into + // the input literal. We put it here to avoid allocating an std::vector in + // every iteration of ShapeUtil::ForEachIndex. + std::vector scratch_source_index(shape().dimensions_size()); + + char* dest_data = static_cast(result->untyped_data()); + const char* source_data = static_cast(untyped_data()); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + ShapeUtil::ForEachIndex( + result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + for (int64 i = 0; i < dimensions.size(); ++i) { + scratch_source_index[i] = output_index[dimensions[i]]; + } + int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( + result_shape, output_index); + int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( + shape(), scratch_source_index); + memcpy(dest_data + primitive_size * dest_index, + source_data + primitive_size * source_index, primitive_size); + return true; + }); + + return std::move(result); +} + StatusOr> LiteralBase::Reshape( tensorflow::gtl::ArraySlice dimensions) const { if (!ShapeUtil::IsArray(shape())) { @@ -939,13 +987,30 @@ std::unique_ptr LiteralBase::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - std::unique_ptr new_literal = CreateFromShape(permuted_shape); - DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), + auto new_literal = MakeUnique(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); return new_literal; } +template +std::unique_ptr LiteralBase::SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const { + auto result_literal = MakeUnique(result_shape); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + NativeT value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; +} + std::unique_ptr LiteralBase::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { @@ -963,51 +1028,17 @@ std::unique_ptr LiteralBase::Slice( const auto result_shape = ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); - - auto result_literal = MakeUnique(result_shape); - - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - float value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); + case BF16: + return SliceInternal(result_shape, start_indices); case C64: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, complex64 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - complex64 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case S32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - int32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case U32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, uint32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - uint32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); @@ -1153,7 +1184,7 @@ size_t LiteralBase::Hash() const { ShapeUtil::ForEachSubshape( shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsTuple(subshape)) { + if (!ShapeUtil::IsArray(subshape)) { return; } @@ -1344,6 +1375,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, return; } + if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + return; + } + if (LayoutUtil::IsSparseArray(subshape)) { pieces->push_back(shape_to_string(subshape)); pieces->push_back("{"); @@ -1532,7 +1568,7 @@ string LiteralBase::ToString(bool print_layout) const { void LiteralBase::EachCellAsString( const std::function indices, const string& value)>& per_cell) const { - if (ShapeUtil::HasZeroElements(shape())) { + if (ShapeUtil::IsZeroElementArray(shape())) { return; } std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( @@ -1938,7 +1974,7 @@ bool LiteralBase::IsAllFirst() const { // Empty shapes are not all the first element since there is no first // element. - if (ShapeUtil::HasZeroElements(piece.subshape())) { + if (ShapeUtil::IsZeroElementArray(piece.subshape())) { return false; } auto piece_is_all = [&]() { @@ -2106,6 +2142,7 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { } break; case TUPLE: + case TOKEN: // Nothing to do but assign the shape which is done above. return; default: @@ -2258,6 +2295,9 @@ StatusOr> Literal::CreateFromProto( } return Status::OK(); } + if (piece->subshape().element_type() == TOKEN) { + return Status::OK(); + } CHECK(ShapeUtil::IsArray(piece->subshape())); TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); @@ -2317,28 +2357,27 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal, : LiteralBase(), root_piece_(&literal.piece(view_root)) {} BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(shape) { - CHECK(ShapeUtil::IsArray(shape_)); - CHECK_NE(src_buf_ptr, nullptr); - CHECK(LayoutUtil::HasLayout(shape_)); + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsArray(*shape_)); + CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = Piece(); root_piece_.set_buffer(const_cast(src_buf_ptr)); - root_piece_.set_subshape(&shape_); + root_piece_.set_subshape(shape_.get()); } BorrowingLiteral::BorrowingLiteral( tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(shape) { - CHECK(ShapeUtil::IsTuple(shape_)); - CHECK(!ShapeUtil::IsNestedTuple(shape_)); - CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_)); + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); root_piece_ = Piece(); - root_piece_.set_subshape(&shape_); - BuildPieceSubtree(shape_, &root_piece_); + root_piece_.set_subshape(shape_.get()); + BuildPieceSubtree(*shape_, &root_piece_); for (int i = 0; i < src_buf_ptrs.size(); ++i) { - const auto& src_shape = shape_.tuple_shapes(i); + const auto& src_shape = shape_->tuple_shapes(i); CHECK(ShapeUtil::IsArray(src_shape)); root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index ad5c7c8995f2276ee0f180dd1c782a44f46fb30a..37ca8ea9f1d158b6bce8d5688288351f55c3b3c8 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -277,6 +277,12 @@ class LiteralBase { StatusOr> Reshape( tensorflow::gtl::ArraySlice dimensions) const; + // Creates a new literal by broadcasting this literal with `dimensions` to + // yield a literal of shape `result_shape`. + StatusOr> Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const; + // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers // in the original literal, and it specifies the order of the new dimensions @@ -307,6 +313,11 @@ class LiteralBase { // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). + // + // Note: It's an antipattern to use this method then immediately call + // Literal::Populate on the result (since that results in zero initialization, + // then reinitialization. Conside if a call to MakeUnique(shape), + // followed by the call to Literal::Populate can be used instead. static std::unique_ptr CreateFromShape(const Shape& shape); protected: @@ -531,6 +542,12 @@ class LiteralBase { friend class Literal; friend class LiteralSlice; friend class BorrowingLiteral; + + private: + template + std::unique_ptr SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const; }; // Class representing literal values in XLA. @@ -900,6 +917,9 @@ class Literal : public LiteralBase { return MakeTupleOwned(std::move(v)); } + // Create a constant token literal. Token types have no value. + static std::unique_ptr CreateToken(); + // Returns a vector containing the tuple elements of this Literal as separate // Literals. This Literal must be tuple-shaped and can be a nested tuple. The // elements are moved into the new Literals; no data is copied. Upon return @@ -1082,8 +1102,10 @@ class BorrowingLiteral : public LiteralBase { const Piece& root_piece() const override { return root_piece_; }; Piece root_piece_; - // Shape of this literal. - const Shape shape_; + // Shape of this literal. Stored as unique_ptr so such that the (default) + // move construction of this class would be trivially correct: the pointer to + // Shape root_piece_ stores will still point to the correct address. + std::unique_ptr shape_; }; template @@ -1437,7 +1459,7 @@ void LiteralBase::EachCell( std::function indices, NativeT value)> per_cell) const { - if (ShapeUtil::HasZeroElements(shape())) { + if (ShapeUtil::IsZeroElementArray(shape())) { return; } std::vector indices(ShapeUtil::Rank(shape()), 0); @@ -1650,7 +1672,7 @@ template const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - std::unique_ptr literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); TF_RETURN_IF_ERROR(literal.get()->Populate( [&](tensorflow::gtl::ArraySlice indexes) { return generator(indexes); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 5b85474ad11c3d8f7c3971af0f7269a25ed68a96..493d807591dd3c425293e4ee796bca3036a3088c 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -334,6 +334,22 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { EXPECT_EQ(nil, nil); } +TEST_F(LiteralUtilTest, TokenEquality) { + auto token0 = Literal::CreateToken(); + auto token1 = Literal::CreateToken(); + auto scalar = Literal::CreateR0(1.0); + + EXPECT_EQ(*token0, *token1); + EXPECT_NE(*token0, *scalar); + + EXPECT_EQ(*Literal::MakeTuple({token0.get()}), + *Literal::MakeTuple({token0.get()})); + EXPECT_EQ(*Literal::MakeTuple({token0.get(), scalar.get()}), + *Literal::MakeTuple({token1.get(), scalar.get()})); + EXPECT_NE(*Literal::MakeTuple({token0.get(), scalar.get()}), + *Literal::MakeTuple({scalar.get(), token1.get()})); +} + TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. auto colmajor = @@ -1066,7 +1082,7 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. @@ -1108,7 +1124,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. @@ -1431,7 +1447,7 @@ TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } -TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { +TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { std::vector int64_values = {1, 2, 3}; const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); @@ -1443,7 +1459,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { EXPECT_EQ(literal.Get({2}), 3); } -TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) { +TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { std::vector one_two_three = {1, 2, 3}; const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3}); @@ -1810,5 +1826,35 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } +TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { + std::unique_ptr literal = Literal::CreateR1({1, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{0})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 1}, {2, 2}})); +} + +TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { + std::unique_ptr literal = Literal::CreateR1({1, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 2}, {1, 2}})); +} + +TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { + std::unique_ptr literal = Literal::CreateR0(9); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), + /*dimensions=*/{})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{9, 9}, {9, 9}})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/overflow_util.h b/tensorflow/compiler/xla/overflow_util.h new file mode 100644 index 0000000000000000000000000000000000000000..8657d3a4bfa992b9ca0619f24923fd4542eed894 --- /dev/null +++ b/tensorflow/compiler/xla/overflow_util.h @@ -0,0 +1,50 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Multiply two nonnegative int64's, returning negative for overflow +inline int64 MultiplyWithoutOverflow(const int64 x, const int64 y) { + // Multiply in uint64 rather than int64 since signed overflow is undefined. + // Negative values will wrap around to large unsigned values in the casts + // (see section 4.7 [conv.integral] of the C++14 standard). + const uint64 ux = x; + const uint64 uy = y; + const uint64 uxy = ux * uy; + + // Check if we overflow uint64, using a cheap check if both inputs are small + if (TF_PREDICT_FALSE((ux | uy) >> 32 != 0)) { + // Ensure nonnegativity. Note that negative numbers will appear "large" + // to the unsigned comparisons above. + CHECK(x >= 0 && y >= 0); + + // Otherwise, detect overflow using a division + if (ux != 0 && uxy / ux != uy) return -1; + } + + // Cast back to signed. Any negative value will signal an error. + return static_cast(uxy); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 143c9a2366be5786b7ef2148580caeb97d67d2d8..b16147e3be71771269d8b7a18528bef3a8c72d99 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -85,5 +85,10 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) { } } +bool IsArrayType(PrimitiveType primitive_type) { + return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && + primitive_type != OPAQUE && primitive_type != TOKEN; +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index b26a10ade63a5dad3bf8f9f3a2a33c3c5e67bdb2..889e9a1ceca675689406d255d348c82c398563aa 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -133,6 +133,9 @@ bool IsUnsignedIntegralType(PrimitiveType type); bool IsIntegralType(PrimitiveType type); +// Returns true if values of the given primitive type are held in array shapes. +bool IsArrayType(PrimitiveType primitive_type); + // Returns the number of bits in the representation for a given type. int BitWidth(PrimitiveType type); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 932cce943f7c046a85984e6e5ed6b59dae371473..22cc4e2436e5d3a7ed77a2b9f5515878661ef294 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -12,6 +12,7 @@ py_library( deps = [ ":pywrap_xla", "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/compiler/xla/service:hlo_proto_py", ], ) @@ -51,6 +52,7 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:shaped_buffer", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index df262c97bfcd91a5c2921a36ecb8f8a6172cffe6..be55d50b234442ec569c85e4f5224ad1c179bca8 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -14,13 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/default/thread_annotations.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace xla { - namespace swig { // TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of @@ -97,6 +98,36 @@ const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { return &shaped_buffer_; } +ShapedBuffer LocalShapedBuffer::Release() { return shaped_buffer_.release(); } + +LocalShapedBufferTuple::LocalShapedBufferTuple( + std::vector elements) + : elements_(std::move(elements)) { + for (auto* element : elements_) { + DCHECK(element != nullptr); + } +} + +LocalShapedBufferTuple::~LocalShapedBufferTuple() { + for (LocalShapedBuffer* element : elements_) { + if (element != nullptr) { + delete element; + } + } +} + +StatusOr LocalShapedBufferTuple::Release(int i) { + LocalShapedBuffer* element = elements_[i]; + if (element == nullptr) { + return InvalidArgument("Attempted to release already-released element %d.", + i); + } + elements_[i] = nullptr; + return element; +} + +int LocalShapedBufferTuple::size() const { return elements_.size(); } + static StatusOr ToBuffer(LocalClient* client, int device_ordinal, const Literal& arg) { @@ -145,73 +176,73 @@ StatusOr> CompiledLocalComputation::Execute( GetReplicaCount()); for (int replica = 0; replica < GetReplicaCount(); ++replica) { - pool.Schedule([this, client, replica, &arguments, &shapes_with_layout, - &results] { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica " << replica - << " mapped to device ordinal for execution: " - << device_ordinal; - - // Transfer arguments in - std::vector scoped_buffers; - scoped_buffers.reserve(arguments.size()); - for (int i = 0; i < arguments.size(); ++i) { - const Literal& argument = arguments[i]; - const tensorflow::gtl::optional& shape_with_layout = - shapes_with_layout[i]; - - StatusOr pushed; - if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - pushed = ToBuffer(client, device_ordinal, *relaid); - } else { - pushed = ToBuffer(client, device_ordinal, argument); - } - if (!pushed.ok()) { - results[replica] = pushed.status(); - return; - } - - scoped_buffers.push_back(std::move(pushed).ValueOrDie()); - } - - // Execute - std::vector argument_buffers; - argument_buffers.reserve(scoped_buffers.size()); - for (auto& buffer : scoped_buffers) { - argument_buffers.push_back(&buffer); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) - .ConsumeValueOrDie(); - - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); - StatusOr result_buffer_status = - executable_->Run(argument_buffers, options); - if (!result_buffer_status.ok()) { - results[replica] = result_buffer_status.status(); - return; - } - - // Transfer result out - results[replica] = client->ShapedBufferToLiteral( - std::move(result_buffer_status).ValueOrDie()); - }); + pool.Schedule( + [this, client, replica, &arguments, &shapes_with_layout, &results] { + StatusOr device_ordinal_status = + client->ReplicaNumberToDeviceOrdinal(replica); + if (!device_ordinal_status.ok()) { + results[replica] = device_ordinal_status.status(); + return; + } + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " + << device_ordinal; + + // Transfer arguments in + std::vector scoped_buffers; + scoped_buffers.reserve(arguments.size()); + for (int i = 0; i < arguments.size(); ++i) { + const Literal& argument = arguments[i]; + const tensorflow::gtl::optional& shape_with_layout = + shapes_with_layout[i]; + + StatusOr pushed; + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, *relaid); + } else { + pushed = ToBuffer(client, device_ordinal, argument); + } + if (!pushed.ok()) { + results[replica] = pushed.status(); + return; + } + + scoped_buffers.push_back(std::move(pushed).ValueOrDie()); + } + + // Execute + std::vector argument_buffers; + argument_buffers.reserve(scoped_buffers.size()); + for (auto& buffer : scoped_buffers) { + argument_buffers.push_back(&buffer); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + StatusOr result_buffer_status = + executable_->Run(argument_buffers, options); + if (!result_buffer_status.ok()) { + results[replica] = result_buffer_status.status(); + return; + } + + // Transfer result out + results[replica] = client->ShapedBufferToLiteral( + std::move(result_buffer_status).ValueOrDie()); + }); } } @@ -276,6 +307,15 @@ const XlaComputation& LocalComputation::computation() const { return computation_; } +string LocalComputation::GetSerializedProto() const { + string result; + if (!computation_.proto().SerializeToString(&result)) { + LOG(ERROR) << "Failed to serialize the HloModuleProto."; + return ""; + } + return result; +} + StatusOr LocalComputation::GetReturnValueShape() const { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation_.GetProgramShape()); @@ -303,14 +343,11 @@ StatusOr LocalComputationBuilder::Build() { LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - return builder_.Parameter(parameter_number, shape, name); + return xla::Parameter(&builder_, parameter_number, shape, name); } -std::unique_ptr LocalComputationBuilder::GetShape( - const LocalOp& operand) { - auto result = MakeUnique(); - *result = builder_.GetShape(operand.op()).ValueOrDie(); - return result; +StatusOr LocalComputationBuilder::GetShape(const LocalOp& operand) { + return builder_.GetShape(operand.op()); } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -319,72 +356,70 @@ StatusOr LocalComputationBuilder::GetReturnValueShape() { } LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { - return builder_.Infeed(shape); + return xla::Infeed(&builder_, shape); } void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand.op(), shape, outfeed_config); + xla::Outfeed(operand.op(), shape, outfeed_config); } LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { - return builder_.ConstantLiteral(literal); + return xla::ConstantLiteral(&builder_, literal); } LocalOp LocalComputationBuilder::Broadcast( const LocalOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return builder_.Broadcast(operand.op(), broadcast_sizes); + return xla::Broadcast(operand.op(), broadcast_sizes); } LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, const LocalOp& padding_value, const PaddingConfig& padding_config) { - return builder_.Pad(operand.op(), padding_value.op(), padding_config); + return xla::Pad(operand.op(), padding_value.op(), padding_config); } LocalOp LocalComputationBuilder::Reshape( const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return builder_.Reshape(operand.op(), dimensions, new_sizes); + return xla::Reshape(operand.op(), dimensions, new_sizes); } LocalOp LocalComputationBuilder::Collapse( const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return builder_.Collapse(operand.op(), dimensions); + return xla::Collapse(operand.op(), dimensions); } LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { - return builder_.CrossReplicaSum(operand.op()); + return xla::CrossReplicaSum(operand.op()); } LocalOp LocalComputationBuilder::Slice( const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return builder_.Slice(operand.op(), start_indices, limit_indices, strides); + return xla::Slice(operand.op(), start_indices, limit_indices, strides); } LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { - return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, - dimno); + return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno); } LocalOp LocalComputationBuilder::DynamicSlice( const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); + return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } LocalOp LocalComputationBuilder::DynamicUpdateSlice( const LocalOp& operand, const LocalOp& update, const LocalOp& start_indices) { - return builder_.DynamicUpdateSlice(operand.op(), update.op(), - start_indices.op()); + return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } LocalOp LocalComputationBuilder::ConcatInDim( @@ -394,7 +429,7 @@ LocalOp LocalComputationBuilder::ConcatInDim( for (const auto& op : operands) { xla_ops.push_back(op.op()); } - return builder_.ConcatInDim(xla_ops, dimension); + return xla::ConcatInDim(&builder_, xla_ops, dimension); } LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( @@ -404,7 +439,7 @@ LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( tensorflow::gtl::ArraySlice> padding, const LocalOp& source, const LocalOp& init_value, const LocalComputation& scatter) { - return builder_.SelectAndScatterWithGeneralPadding( + return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } @@ -417,22 +452,22 @@ LocalOp LocalComputationBuilder::Tuple( xla_ops.push_back(op.op()); } - return builder_.Tuple(xla_ops); + return xla::Tuple(&builder_, xla_ops); } LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data.op(), index); + return xla::GetTupleElement(tuple_data.op(), index); } LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { - return builder_.Dot(lhs.op(), rhs.op()); + return xla::Dot(lhs.op(), rhs.op()); } LocalOp LocalComputationBuilder::DotGeneral( const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); + return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } LocalOp LocalComputationBuilder::ConvGeneralDilated( @@ -442,14 +477,13 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, - padding, lhs_dilation, rhs_dilation, - dimension_numbers); + return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding, + lhs_dilation, rhs_dilation, dimension_numbers); } LocalOp LocalComputationBuilder::ConvertElementType( const LocalOp& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand.op(), new_element_type); + return xla::ConvertElementType(operand.op(), new_element_type); } LocalOp LocalComputationBuilder::Call( @@ -460,46 +494,39 @@ LocalOp LocalComputationBuilder::Call( for (const auto& op : operands) { xla_ops.push_back(op.op()); } - return builder_.Call(local_computation.computation(), xla_ops); + return xla::Call(&builder_, local_computation.computation(), xla_ops); } LocalOp LocalComputationBuilder::Transpose( const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { - return builder_.Transpose(operand.op(), permutation); + return xla::Transpose(operand.op(), permutation); } LocalOp LocalComputationBuilder::Rev( const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return builder_.Rev(operand.op(), dimensions); + return xla::Rev(operand.op(), dimensions); } LocalOp LocalComputationBuilder::Map( tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { + tensorflow::gtl::ArraySlice dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { xla_ops.push_back(op.op()); } - std::vector static_xla_ops; - static_xla_ops.reserve(static_operands.size()); - for (const auto& op : static_operands) { - static_xla_ops.push_back(op.op()); - } - - return builder_.Map(xla_ops, local_computation.computation(), dimensions, - static_xla_ops); + return xla::Map(&builder_, xla_ops, local_computation.computation(), + dimensions); } LocalOp LocalComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return builder_.Reduce(operand.op(), init_value.op(), - local_computation.computation(), dimensions_to_reduce); + return xla::Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( @@ -508,7 +535,7 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { - return builder_.ReduceWindowWithGeneralPadding( + return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), window_dimensions, window_strides, padding); } @@ -516,27 +543,27 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma, const Shape& shape) { - return builder_.RngNormal(mu.op(), sigma.op(), shape); + return xla::RngNormal(mu.op(), sigma.op(), shape); } LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape) { - return builder_.RngUniform(a.op(), b.op(), shape); + return xla::RngUniform(a.op(), b.op(), shape); } LocalOp LocalComputationBuilder::While(const LocalComputation& condition, const LocalComputation& body, const LocalOp& init) { - return builder_.While(condition.computation(), body.computation(), init.op()); + return xla::While(condition.computation(), body.computation(), init.op()); } LocalOp LocalComputationBuilder::Conditional( const LocalOp& predicate, const LocalOp& true_operand, const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional( - predicate.op(), true_operand.op(), true_computation.computation(), - false_operand.op(), false_computation.computation()); + return xla::Conditional(predicate.op(), true_operand.op(), + true_computation.computation(), false_operand.op(), + false_computation.computation()); } StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { @@ -552,7 +579,7 @@ StatusOr LocalComputationBuilder::BuildConstantSubGraph( #define _FORWARD(method_name, return_sig, args_sig, args) \ return_sig LocalComputationBuilder::method_name args_sig { \ - return builder_.method_name args; \ + return xla::method_name args; \ } #define _FORWARD_UNOP(method_name) \ @@ -586,22 +613,25 @@ _FORWARD_BINOP(Max) _FORWARD_BINOP(Min) _FORWARD_BINOP(And) _FORWARD_BINOP(Or) +_FORWARD_BINOP(Xor) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) +_FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) +_FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) -_FORWARD_UNOP(SqrtF32) -_FORWARD_UNOP(SquareF32) +_FORWARD_UNOP(Sqrt) +_FORWARD_UNOP(Square) _FORWARD_BINOP(Pow) _FORWARD_UNOP(IsFinite) -_FORWARD_UNOP(ReciprocalF32) +_FORWARD_UNOP(Reciprocal) _FORWARD_UNOP(Neg) _FORWARD_UNOP(Sort) @@ -622,6 +652,54 @@ void DeleteLocalComputation(LocalComputation* computation) { delete computation; } -} // namespace swig +StatusOr DestructureLocalShapedBufferTuple( + LocalShapedBuffer* local_shaped_buffer) { + if (!ShapeUtil::IsTuple( + local_shaped_buffer->shaped_buffer()->on_device_shape())) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString( + local_shaped_buffer->shaped_buffer()->on_device_shape()) + .c_str()); + } + + DeviceMemoryAllocator* allocator = + local_shaped_buffer->shaped_buffer()->memory_allocator(); + ShapedBuffer tuple_buffer = local_shaped_buffer->Release(); + + // Extract some metadata we use to construct scoped buffers. + const se::Platform* platform = tuple_buffer.platform(); + int device_ordinal = tuple_buffer.device_ordinal(); + + ShapeTree& shape_tree = tuple_buffer.buffers(); + const Shape& tuple_shape = tuple_buffer.on_device_shape(); + std::vector results; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + // Create a shaped buffer for this destructured tuple element. + const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); + VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; + ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); + + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& index) { + ShapeIndex original(index); + original.push_front(i); + se::DeviceMemoryBase* device_memory = + shape_tree.mutable_element(original); + shaped_buffer.set_buffer(*device_memory, index); + *device_memory = se::DeviceMemoryBase(); + }); + + VLOG(3) << "Completed tuple element: " << i; + results.push_back(new LocalShapedBuffer( + ScopedShapedBuffer(std::move(shaped_buffer), allocator))); + } + // Deallocate the root buffer. + se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); + TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); + return new LocalShapedBufferTuple(std::move(results)); +} +} // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index a06b85b4ea28c4f386598901138930eaaed12079..690ff277e884c6f1540b12e7002248571d07fe71 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { - namespace swig { // Initializes the number of replicas that XLA will be initialized with (when @@ -69,10 +68,42 @@ class LocalShapedBuffer { StatusOr > ToLiteral() const; + // Transfers ownership of the encapsulated ShapedBuffer to the caller, + // analogous to std::unique_ptr::release(). + ShapedBuffer Release(); + private: ScopedShapedBuffer shaped_buffer_; }; +// Result of a tuple destructuring operation on a LocalShapedBuffer -- this +// appears to be a simpler mechanism for the time being than an alternative like +// using SWIG to transform std::vectors into Python lists of SWIG objects +// directly. +class LocalShapedBufferTuple { + public: + // Note: any LocalShapedBuffer elements that are not Release()'d will be + // deallocated in the destructor. + explicit LocalShapedBufferTuple(std::vector elements); + + ~LocalShapedBufferTuple(); + + // Releases the ith element to the caller. Further attempts to release the ith + // element will return an invalid argument error. + StatusOr Release(int i); + + // Returns the number of elements in the destructured tuple. + int size() const; + + private: + std::vector elements_; +}; + +// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements +// in LocalShapedBufferTuple form. +StatusOr DestructureLocalShapedBufferTuple( + LocalShapedBuffer* local_shaped_buffer); + // Wraps a LocalExecutable produced by compiling a // LocalComputation. The Execute method forwards to that of the // underlying LocalExecutable, and additionally handles tranferring @@ -112,6 +143,11 @@ class LocalComputation { const XlaComputation& computation() const; + // Returns the HloModuleProto contained in the XlaComputation in the + // serialized binary format. Logs an internal error and returns an empty + // string on failure. + string GetSerializedProto() const; + // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; @@ -151,7 +187,7 @@ class LocalComputationBuilder { LocalOp Parameter(int64 parameter_number, const Shape& shape, const string& name); - std::unique_ptr GetShape(const LocalOp& operand); + StatusOr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. StatusOr GetReturnValueShape(); @@ -234,8 +270,7 @@ class LocalComputationBuilder { LocalOp Map(tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + tensorflow::gtl::ArraySlice dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, @@ -297,22 +332,25 @@ class LocalComputationBuilder { _FORWARD_BINOP(Min) _FORWARD_BINOP(And) _FORWARD_BINOP(Or) + _FORWARD_BINOP(Xor) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) + _FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) + _FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) - _FORWARD_UNOP(SqrtF32) - _FORWARD_UNOP(SquareF32) + _FORWARD_UNOP(Sqrt) + _FORWARD_UNOP(Square) _FORWARD_BINOP(Pow) _FORWARD_UNOP(IsFinite) - _FORWARD_UNOP(ReciprocalF32) + _FORWARD_UNOP(Reciprocal) _FORWARD_UNOP(Neg) _FORWARD_UNOP(Sort) @@ -331,7 +369,6 @@ void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); void DeleteLocalComputation(LocalComputation* computation); } // namespace swig - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 04c56bbba95fbf3248df6c49700ff563c8b253c0..c44e69e6153239b39f9f8a40539a75ddffdef25d 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -200,6 +200,20 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalShapedBufferTuple*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + %typemap(out) StatusOr< std::unique_ptr > { if ($1.ok()) { std::unique_ptr value = $1.ConsumeValueOrDie(); @@ -851,6 +865,11 @@ tensorflow::ImportNumpy(); })) { return nullptr; } + if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { + build_options.set_dump_unoptimized_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); })) { @@ -900,12 +919,16 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; +%unignore xla::swig::LocalShapedBufferTuple; +%unignore xla::swig::LocalShapedBufferTuple::Release; +%unignore xla::swig::LocalShapedBufferTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; %unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers; %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalComputation::GetSerializedProto; %unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; @@ -965,24 +988,28 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Min; %unignore xla::swig::LocalComputationBuilder::And; %unignore xla::swig::LocalComputationBuilder::Or; +%unignore xla::swig::LocalComputationBuilder::Xor; %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; +%unignore xla::swig::LocalComputationBuilder::Expm1; %unignore xla::swig::LocalComputationBuilder::Floor; %unignore xla::swig::LocalComputationBuilder::Ceil; %unignore xla::swig::LocalComputationBuilder::Round; %unignore xla::swig::LocalComputationBuilder::Log; +%unignore xla::swig::LocalComputationBuilder::Log1p; %unignore xla::swig::LocalComputationBuilder::Sign; %unignore xla::swig::LocalComputationBuilder::Cos; %unignore xla::swig::LocalComputationBuilder::Sin; %unignore xla::swig::LocalComputationBuilder::Tanh; -%unignore xla::swig::LocalComputationBuilder::SqrtF32; -%unignore xla::swig::LocalComputationBuilder::SquareF32; +%unignore xla::swig::LocalComputationBuilder::Sqrt; +%unignore xla::swig::LocalComputationBuilder::Square; %unignore xla::swig::LocalComputationBuilder::Pow; %unignore xla::swig::LocalComputationBuilder::IsFinite; -%unignore xla::swig::LocalComputationBuilder::ReciprocalF32; +%unignore xla::swig::LocalComputationBuilder::Reciprocal; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DeleteCompiledLocalComputation; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 1d5b75d1bee2dcee3e448d0bcb72103b539efac6..27aee634bac613a87c919a357e085ec71c7deeb1 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -28,6 +28,7 @@ import numpy as np from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api +from tensorflow.compiler.xla.service import hlo_pb2 # Most functions are snake_case for consistency with other modules, whereas @@ -88,18 +89,20 @@ _UNARY_OPS = [ 'Not', 'Abs', 'Exp', + 'Expm1', 'Floor', 'Round', 'Ceil', 'Log', + 'Log1p', 'Sign', 'Cos', 'Sin', 'Tanh', - 'SqrtF32', - 'SquareF32', + 'Sqrt', + 'Square', 'IsFinite', - 'ReciprocalF32', + 'Reciprocal', 'Neg', 'Sort', ] @@ -120,6 +123,7 @@ _BINARY_OPS = [ 'Min', 'And', 'Or', + 'Xor', 'Pow', ] @@ -183,6 +187,14 @@ class LocalBuffer(object): self._delete(self.c_local_shaped_buffer) self.c_local_shaped_buffer = None + def destructure(self): + assert self.c_local_shaped_buffer is not None + result = c_api.DestructureLocalShapedBufferTuple(self.c_local_shaped_buffer) + self.c_local_shaped_buffer = None + size = result.size() + destructured = tuple(LocalBuffer(result.Release(i)) for i in xrange(size)) + return destructured + def is_deleted(self): return self.c_local_shaped_buffer is None @@ -246,9 +258,12 @@ class Shape(object): self._dimensions == other._dimensions and self._minor_to_major == other._minor_to_major) + def __ne__(self, other): + return not self == other + def __repr__(self): return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' - '_is_tuple={!r}), _minor_to_major={!r}').format( + '_is_tuple={!r}, _minor_to_major={!r})').format( self._dtype, self._dimensions, self._is_tuple, self._minor_to_major) @@ -352,6 +367,7 @@ class CompileOptions(object): def __init__(self): self.generate_hlo_graph = None self.dump_optimized_hlo_proto_to = None + self.dump_unoptimized_hlo_proto_to = None self.dump_per_pass_hlo_proto_to = None self.hlo_profile = False @@ -410,6 +426,17 @@ class LocalComputation(object): assert isinstance(c_local_computation, c_api.LocalComputation) self._delete = c_api.DeleteLocalComputation + def GetProto(self): + """Get the HloModuleProto proto object in this local computation. + + Returns: + An HloModuleProto proto object that has the whole-graph information. + """ + + serialized = self.c_local_computation.GetSerializedProto() + proto = hlo_pb2.HloModuleProto.FromString(serialized) + return proto + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): """Compiles an un-compiled local computation. @@ -882,20 +909,19 @@ class ComputationBuilder(object): """ return self._client.Call(computation_to_apply.c_local_computation, operands) - def Map(self, operands, computation_to_apply, dimensions, static_operands=()): + def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. Args: operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. - static_operands: auxiliary arguments passed to the applied computation. Returns: A LocalOp representing the added Map op. """ return self._client.Map(operands, computation_to_apply.c_local_computation, - dimensions, static_operands) + dimensions) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. @@ -1100,6 +1126,61 @@ class ComputationBuilder(object): dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) return dimension_numbers + def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, + rhs_dilation, dimension_numbers): + """Enqueues a ConvGeneralDilated operation onto the computation. + + Args: + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. + window_strides: length-N array-like of integer kernel strides. + padding: length-N array-like of pairs of integers of (low, high) padding. + lhs_dilation: length-N array-like of integer dilation factors. + rhs_dilation: length-N array-like of integer dilation factors. + dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers or a + triple (lhs_spec, rhs_spec, out_spec) where each element is a string of + length N+2 identifying by position (1) batch dimensions in lhs, rhs, and + the output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions + in rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers + consistent with the Conv operation with two spatial dimensions, one + could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate + dimension numbers consistent with the TensorFlow Conv2D operation, one + could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of + convolution dimension specification, window strides are associated with + spatial dimension character labels according to the order in which the + labels appear in the rhs_spec string, so that window_strides[0] is + matched with the dimension corresponding to the first character + appearing in rhs_spec that is not 'I' or 'O'. + + Returns: a LocalOp representing the ConvGenralDilated operation. + """ + if not isinstance(dimension_numbers, + xla_data_pb2.ConvolutionDimensionNumbers): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) + dimension_numbers.input_spatial_dimensions.extend( + sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]))) + dimension_numbers.output_spatial_dimensions.extend( + sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]))) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c073c02040e4d260cf760ea2b25f70d60ddd41a1..0564ddcb85ee3952f82649687e79a864999baf2c 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -157,6 +157,13 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayBool([True, True, False, False]))) self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) + def testBooleanXor(self): + c = self._NewComputation() + c.Xor( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + def testSum2DF32(self): c = self._NewComputation() c.Add( @@ -164,6 +171,16 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + def testGetProto(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) + built = c.Build() + proto = built.GetProto() # HloModuleProto + self.assertTrue(len(proto.computations) == 1) + self.assertTrue(len(proto.computations[0].instructions) == 3) + def testSum2DF64(self): c = self._NewComputation() c.Add( @@ -355,6 +372,55 @@ class LocalBufferTest(LocalComputationTest): with self.assertRaises(ValueError): compiled_c.ExecuteWithLocalBuffers([arg_buffer]) + def testDestructureTupleEmpty(self): + t = () + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 0) + + def testDestructureTupleOneArrayElement(self): + t = (np.array([1, 2, 3, 4], dtype=np.int32),) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 1) + array = pieces[0] + got = array.to_py() + want = NumpyArrayS32([1, 2, 3, 4]) + np.testing.assert_equal(want, got) + + def testDestructureTupleTwoArrayElementDifferentType(self): + t = (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + np.array([2, 3, 4, 5], dtype=np.int32)) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + array0, array1 = pieces + got = array0.to_py() + want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) + np.testing.assert_equal(want, got) + got = array1.to_py() + want = NumpyArrayS32([2, 3, 4, 5]) + np.testing.assert_equal(want, got) + + def testDestructureTupleNested(self): + t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5])) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + tuple0, array1 = pieces + got = array1.to_py() + want = NumpyArrayS32([5]) + np.testing.assert_equal(want, got) + got = tuple0.to_py() + self.assertEqual(type(got), tuple) + self.assertEqual(len(got), 2) + np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) + np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) + class SingleOpTest(LocalComputationTest): """Tests for single ops. @@ -509,6 +575,46 @@ class SingleOpTest(LocalComputationTest): [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=result) + def testConvGeneralDilatedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = ("NCHW", "OIHW", "NCHW") + c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvGeneralDilatedPermutedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + + dimension_numbers = ("NHWC", "OIHW", "CWNH") + c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))), + c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) @@ -521,6 +627,12 @@ class SingleOpTest(LocalComputationTest): c.Exp(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + def testExpm1(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Expm1(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.expm1(arr)) + def testRound(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -533,6 +645,12 @@ class SingleOpTest(LocalComputationTest): c.Log(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.log(arr)) + def testLog1p(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Log1p(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.log1p(arr)) + def testNeg(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -1057,14 +1175,6 @@ class EmbeddedComputationsTest(LocalComputationTest): self._CreateBinaryDivF64Computation(), [0]) self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) - def DISABLED_testMapWithStaticOperands(self): - c = self._NewComputation() - factor = c.ConstantF32Scalar(3.0) - c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], - self._CreateMulF32ByParamComputation(), [0], - static_operands=[factor]) - self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0]) - def testSelectAndScatterF32(self): c = self._NewComputation() c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..8999cda5ef852d1246bea45a3312575ec1ac0721 --- /dev/null +++ b/tensorflow/compiler/xla/python_api/BUILD @@ -0,0 +1,36 @@ +# Description: +# Python API for XLA. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_library( + name = "types", + srcs = ["types.py"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto_py", + "//third_party/py/numpy", + ], +) + +py_library( + name = "xla_shape", + srcs = ["xla_shape.py"], + visibility = ["//visibility:public"], + deps = [ + ":types", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) + +py_library( + name = "xla_literal", + srcs = ["xla_literal.py"], + visibility = ["//visibility:public"], + deps = [ + ":types", + ":xla_shape", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) diff --git a/tensorflow/compiler/xla/python_api/types.py b/tensorflow/compiler/xla/python_api/types.py new file mode 100644 index 0000000000000000000000000000000000000000..b60f8dce92ace1b2c682374a2605b3a477936bbc --- /dev/null +++ b/tensorflow/compiler/xla/python_api/types.py @@ -0,0 +1,124 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""Utilities for XLA-specific Python types.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 + +# Records corresponsence between a XLA primitive type and Python/Numpy types. +# +# primitive_type: value of type xla_data_pb2.PrimitiveType +# numpy_dtype: corresponsing Numpy "dtype" (like np.float32) +# literal_field_name: name of the field in the LiteralProto message elements +# of this type go into. +# literal_field_type: type of the field named 'literal_field_name'. +# +# TODO(eliben): figure out how to avoid knowing the extra Python type and the +# astype cast when writing into Literals. +TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [ + 'primitive_type', 'numpy_dtype', 'literal_field_name', 'literal_field_type' +]) + +# Maps from XLA primitive types to TypeConversionRecord. +MAP_XLA_TYPE_TO_RECORD = { + xla_data_pb2.F16: + TypeConversionRecord( + primitive_type=xla_data_pb2.F16, + numpy_dtype=np.float16, + literal_field_name='f16s', + literal_field_type=float), + xla_data_pb2.F32: + TypeConversionRecord( + primitive_type=xla_data_pb2.F32, + numpy_dtype=np.float32, + literal_field_name='f32s', + literal_field_type=float), + xla_data_pb2.F64: + TypeConversionRecord( + primitive_type=xla_data_pb2.F64, + numpy_dtype=np.float64, + literal_field_name='f64s', + literal_field_type=float), + xla_data_pb2.S8: + TypeConversionRecord( + primitive_type=xla_data_pb2.S8, + numpy_dtype=np.int8, + literal_field_name='s8s', + literal_field_type=int), + xla_data_pb2.S16: + TypeConversionRecord( + primitive_type=xla_data_pb2.S16, + numpy_dtype=np.int16, + literal_field_name='s16s', + literal_field_type=int), + xla_data_pb2.S32: + TypeConversionRecord( + primitive_type=xla_data_pb2.S32, + numpy_dtype=np.int32, + literal_field_name='s32s', + literal_field_type=int), + xla_data_pb2.S64: + TypeConversionRecord( + primitive_type=xla_data_pb2.S64, + numpy_dtype=np.int64, + literal_field_name='s64s', + literal_field_type=int), + xla_data_pb2.U8: + TypeConversionRecord( + primitive_type=xla_data_pb2.U8, + numpy_dtype=np.uint8, + literal_field_name='s8s', + literal_field_type=int), + xla_data_pb2.U16: + TypeConversionRecord( + primitive_type=xla_data_pb2.U16, + numpy_dtype=np.uint16, + literal_field_name='s16s', + literal_field_type=int), + xla_data_pb2.U32: + TypeConversionRecord( + primitive_type=xla_data_pb2.U32, + numpy_dtype=np.uint32, + literal_field_name='s32s', + literal_field_type=int), + xla_data_pb2.U64: + TypeConversionRecord( + primitive_type=xla_data_pb2.U64, + numpy_dtype=np.uint64, + literal_field_name='s64s', + literal_field_type=int), + xla_data_pb2.PRED: + TypeConversionRecord( + primitive_type=xla_data_pb2.PRED, + numpy_dtype=np.bool, + literal_field_name='preds', + literal_field_type=bool) +} + +# Maps from Numpy dtypes to TypeConversionRecord. +# Note the conversion on the key. Numpy has a known issue wherein dtype hashing +# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, +# when keying by dtype in this dict, we use the string form of dtypes. +MAP_DTYPE_TO_RECORD = { + str(np.dtype(record.numpy_dtype)): record + for record in MAP_XLA_TYPE_TO_RECORD.values() +} diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py new file mode 100644 index 0000000000000000000000000000000000000000..b040098c294ffaae92b72f678947f99289239314 --- /dev/null +++ b/tensorflow/compiler/xla/python_api/xla_literal.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""XLA LiteralProto utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import types +from tensorflow.compiler.xla.python_api import xla_shape + + +def ConvertLiteralToNumpyArray(literal): + """Converts a XLA literal to a Numpy array.""" + element_type = literal.shape.element_type + if element_type == xla_data_pb2.TUPLE: + return tuple( + ConvertLiteralToNumpyArray(subliteral) + for subliteral in literal.tuple_literals) + + type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type] + if not literal.shape.dimensions: + return np.array( + getattr(literal, type_record.literal_field_name)[0], + type_record.numpy_dtype) + else: + # Infer the proper Numpy order from the LiteralProto's layout. The repeated + # field representing the array's content in the Literal is linearized. + # Reading is done in two steps: + # + # 1. Read the array as 1D from the LiteralProto repeated field. + # 2. Reshape the array to its proper shape, using the right order depending + # on the LiteralProto's layout. + layout_order = literal.shape.layout.minor_to_major + numpy_shape = tuple(literal.shape.dimensions) + if layout_order == range(len(literal.shape.dimensions)): + numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='F') + elif layout_order == range(len(literal.shape.dimensions) - 1, -1, -1): + numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C') + else: + raise NotImplementedError('Unsupported layout: {0}'.format(layout_order)) + ndarray = np.array( + getattr(literal, type_record.literal_field_name), + copy=False, + dtype=type_record.numpy_dtype) + return numpy_reshaper(ndarray) + + +def _ConvertNumpyArrayToLiteral(ndarray): + """Converts a Numpy array to a XLA literal.""" + type_record = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)] + literal = xla_data_pb2.LiteralProto() + literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message) + + if ndarray.ndim == 0: + getattr(literal, type_record.literal_field_name).append( + np.asscalar(ndarray.astype(type_record.literal_field_type))) + else: + # Ndarrays with boolean dtypes need special type conversion with protobufs + if ndarray.dtype in {np.bool_, np.dtype('bool')}: + for element in np.nditer(ndarray): + getattr(literal, type_record.literal_field_name).append( + type_record.literal_field_type(element)) + else: + ndarray_flat = ndarray.ravel(order='A') + getattr(literal, type_record.literal_field_name).extend(ndarray_flat) + return literal + + +def ConvertNumpyArrayToLiteral(value): + """Converts a Numpy array or a nested tuple thereof to an XLA literal.""" + if isinstance(value, tuple): + literal = xla_data_pb2.LiteralProto() + literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(value).message) + for component in value: + component_literal = literal.tuple_literals.add() + component_literal.CopyFrom(ConvertNumpyArrayToLiteral(component)) + return literal + else: + return _ConvertNumpyArrayToLiteral(value) diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..6af28958035bbb03e7e1dbb0d0c7bb2c2f25b96d --- /dev/null +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -0,0 +1,155 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""XLA Shape utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import types + + +class Shape(object): + """Wraps a xla_data_pb2.Shape message with a convenient Python type. + + Provides direct access to the underlying xla_data_pb2.Shape message in the + message attribute, along with accessor wrappers to the message's fields. + Avoid direct access to .message unless interacting directly with protobuf APIs + like CopyFrom. In other words, prefer hauling the shape around in a Shape, and + only access .message when strictly required by the protobuf API. + """ + + def __init__(self, element_type, dimensions, layout=None): + """Creates a new XLA Shape. + + Args: + element_type: element type from xla_data_pb2. + dimensions: sequence of dimensions sizes (integers), or sequence + of Shapes in the case of a tuple, i.e. when element_type is + TUPLE. + layout: optional minor_to_major sequence for layout. If not given, the + default major-to-minor layout is used. + + Raises: + ValueError: if element_type is TUPLE but dimensions are not Shape objects. + """ + self.message = xla_data_pb2.Shape() + self.message.element_type = element_type + if element_type == xla_data_pb2.TUPLE: + if not all(isinstance(subshape, Shape) for subshape in dimensions): + raise ValueError( + 'XLA tuple requires sequence of Shape objects as dimensions') + self._tuple_shapes = tuple(dimensions) + for component_shape in self._tuple_shapes: + component_message = self.message.tuple_shapes.add() + component_message.CopyFrom(component_shape.message) + else: + self.message.dimensions.extend(dimensions) + if layout is None: + layout = list(reversed(range(len(dimensions)))) + self.message.layout.format = xla_data_pb2.DENSE + self.message.layout.minor_to_major.extend(layout) + + def element_type(self): + return self.message.element_type + + def is_tuple(self): + return self.element_type() == xla_data_pb2.TUPLE + + def dimensions(self): + if self.is_tuple(): + raise ValueError('Tuple shape has no dimensions. Try tuple_shapes()?') + return self.message.dimensions + + def tuple_shapes(self): + """If this is a tuple, returns its sequence of constituent Shape objects. + + Returns: + Tuple sub-shapes. + + Raises: + ValueError: if this is not a tuple. + """ + if not self.is_tuple(): + raise ValueError('tuple_shapes() called on a non-tuple shape') + return self._tuple_shapes + + def layout(self): + return self.message.layout + + @staticmethod + def from_pyval(pyval): + return CreateShapeFromNumpy(pyval) + + +def _CreateShapeFromNumpy(ndarray): # pylint: disable=invalid-name + """Create a Shape from a given Numpy array. + + Args: + ndarray: Numpy array. + + Returns: + A Shape object. + """ + element_type = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)].primitive_type + dimensions = ndarray.shape + + # Set the shape's layout based on the ordering of ndarray. + # Numpy arrays come in two orders: Fortran (column-major) and C (row-major). + if np.isfortran(ndarray): + # Column-major layout. This corresponds to a "dimension order is + # minor-to-major" layout in XLA. + layout = range(ndarray.ndim) + else: + # Row-major layout. This corresponds to a "dimension order is + # major-to-minor" layout int XLA. + layout = list(reversed(xrange(ndarray.ndim))) + + return Shape(element_type, dimensions, layout) + + +def CreateShapeFromNumpy(value): # pylint: disable=invalid-name + """Create a Shape from a Numpy array or a nested tuple structure thereof. + + Args: + value: Numpy array or (possibly nested) tuple structure that bottoms out in + Numpy arrays. + + Returns: + A Shape object. + """ + if isinstance(value, tuple): + return Shape( + xla_data_pb2.TUPLE, + [CreateShapeFromNumpy(component) for component in value]) + else: + return _CreateShapeFromNumpy(value) + + +def CreateShapeFromDtypeAndTuple(dtype, shape_tuple): # pylint: disable=invalid-name + """Create a shape from a Numpy dtype and a sequence of nonnegative integers. + + Args: + dtype: a numpy dtype, e.g. np.dtype('int32'). + shape_tuple: a sequence of nonnegative integers. + + Returns: + A Shape object. + """ + element_type = types.MAP_DTYPE_TO_RECORD[str(dtype)].primitive_type + return Shape(element_type, shape_tuple) diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 28d6a8c3fe85fa4179bf2f41c82ad4eb93a045fe..8fa6961d197dce519cf151283b8bc0836a4615c0 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -265,9 +265,9 @@ class ReferenceUtil { const Array3D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 3); - std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3()}; - std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; - std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; + const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3()}; + const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()}; + int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()}; for (int i = 0; i < 3; ++i) { if (i != concatenate_dimension) { out_dims[i] = lhs_dims[i]; @@ -299,9 +299,9 @@ class ReferenceUtil { const Array4D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 4); - std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; - std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; - std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; + const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; for (int i = 0; i < 4; ++i) { if (i != concatenate_dimension) { out_dims[i] = lhs_dims[i]; @@ -330,13 +330,14 @@ class ReferenceUtil { return result; } - // Slices with modulo-wrapping. + // Slices with index clamping template - static std::vector ModSlice1D(const tensorflow::gtl::ArraySlice& input, - int64 start, int64 size) { + static std::vector ClampSlice1D( + const tensorflow::gtl::ArraySlice& input, int64 start, int64 size) { + start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { - result.push_back(input[(start + i) % input.size()]); + result.push_back(input[(start + i)]); } return result; } @@ -552,12 +553,11 @@ class ReferenceUtil { const NativeT pad) { CHECK_EQ(padding.dimensions_size(), 3); - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3()}; - std::vector pad_low(3); - std::vector pad_high(3); - std::vector pad_interior(3); - std::vector output_bounds(3); + const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3()}; + int64 pad_low[3]; + int64 pad_high[3]; + int64 pad_interior[3]; + int64 output_bounds[3]; for (int64 i = 0; i < 3; ++i) { pad_low[i] = padding.dimensions(i).edge_padding_low(); pad_high[i] = padding.dimensions(i).edge_padding_high(); @@ -573,7 +573,7 @@ class ReferenceUtil { Array3D result(output_bounds[0], output_bounds[1], output_bounds[2]); - std::vector indices = {0, 0, 0}; + int indices[] = {0, 0, 0}; for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { @@ -611,12 +611,12 @@ class ReferenceUtil { const NativeT pad) { CHECK_EQ(padding.dimensions_size(), 4); - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3(), operand.n4()}; - std::vector pad_low(4); - std::vector pad_high(4); - std::vector pad_interior(4); - std::vector output_bounds(4); + const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; + int64 pad_low[4]; + int64 pad_high[4]; + int64 pad_interior[4]; + int64 output_bounds[4]; for (int64 i = 0; i < 4; ++i) { pad_low[i] = padding.dimensions(i).edge_padding_low(); pad_high[i] = padding.dimensions(i).edge_padding_high(); diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 0d56a9a477b15964ad45e798865aa8d2c7385073..0b1cec1925d4424db086f8a3f62c91ede090189c 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -39,10 +39,10 @@ tf_cc_binary( srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "@grpc//:grpc++_unsecure", ], ) @@ -54,6 +54,7 @@ tf_cc_test( ], deps = [ ":grpc_stub", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -61,7 +62,6 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@grpc//:grpc++_unsecure", ], ) @@ -71,9 +71,9 @@ cc_library( hdrs = ["grpc_service.h"], deps = [ ":xla_service_proto", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++_unsecure", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 313f11a9a957155eb277dc02ba5d2565c87e0235..f8414468bd9e0a9faf0072c47d94d12ab11b908d 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include -#include "grpc++/create_channel.h" -#include "grpc++/security/credentials.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" @@ -85,13 +85,13 @@ TEST_F(GRPCClientTestBase, ItsAlive) { TEST_F(GRPCClientTestBase, AxpyTenValues) { XlaBuilder builder("axpy_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto y = builder.ConstantR1( - {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); - auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1( + &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto y = ConstantR1( + &builder, {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); + auto ax = Mul(alpha, x); + Add(ax, y); std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 5f4dc6bd08f18b50e60b173432d3d305759bccea..4e1435fa30a24c320ddbedb84d37b369a3158a54 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -32,19 +32,6 @@ namespace xla { return tensorflow::ToGrpcStatus(s); } -::grpc::Status GRPCService::Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Computation(arg, result); }); -} - -::grpc::Status GRPCService::CreateOp(::grpc::ServerContext* context, - const OpRequest* arg, OpResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Op(arg, result); }); -} - ::grpc::Status GRPCService::Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) { @@ -60,21 +47,6 @@ namespace xla { }); } -::grpc::Status GRPCService::SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - return DelegateRPC([this, arg, results]() { - return service_->SetReturnValue(arg, results); - }); -} - -::grpc::Status GRPCService::Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Execute(arg, result); }); -} - ::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, const ExecuteGraphRequest* arg, ExecuteResponse* result) { @@ -82,13 +54,6 @@ namespace xla { [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); } -::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ExecuteAsync(arg, result); }); -} - ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { @@ -136,20 +101,6 @@ namespace xla { [this, arg, result]() { return service_->ResetDevice(arg, result); }); } -::grpc::Status GRPCService::IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->IsConstant(arg, result); }); -} - -::grpc::Status GRPCService::ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ComputeConstant(arg, result); }); -} - ::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) { @@ -157,43 +108,4 @@ namespace xla { [this, arg, result]() { return service_->GetShape(arg, result); }); } -::grpc::Status GRPCService::GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationShape(arg, result); - }); -} - -::grpc::Status GRPCService::GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->GetLocalShape(arg, result); }); -} - -::grpc::Status GRPCService::GetComputationStats( - ::grpc::ServerContext* context, const ComputationStatsRequest* arg, - ComputationStatsResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationStats(arg, result); - }); -} - -::grpc::Status GRPCService::SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->SnapshotComputation(arg, result); - }); -} - -::grpc::Status GRPCService::LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->LoadComputationSnapshot(arg, result); - }); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 50f02796f2d45baf894841782cd96d8d51a5ba00..ca1b09b648013ad45d806040c5ddcf11d9e5604e 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ #define TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ -#include "grpc++/server_context.h" +#include "grpcpp/server_context.h" #include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h" #include "tensorflow/compiler/xla/service/service.h" @@ -31,13 +31,6 @@ class GRPCService : public grpc::XlaService::Service { static StatusOr> NewService( se::Platform* platform = nullptr); - ::grpc::Status Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) override; - - ::grpc::Status CreateOp(::grpc::ServerContext* context, const OpRequest* arg, - OpResponse* result) override; - ::grpc::Status Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) override; @@ -46,22 +39,10 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - ::grpc::Status SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - ::grpc::Status Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) override; - ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, const ExecuteGraphRequest* arg, ExecuteResponse* result) override; - ::grpc::Status ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override; @@ -86,38 +67,10 @@ class GRPCService : public grpc::XlaService::Service { const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - ::grpc::Status IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) override; - - ::grpc::Status ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - ::grpc::Status GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) override; - ::grpc::Status GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ::grpc::Status GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - ::grpc::Status GetComputationStats(::grpc::ServerContext* context, - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - ::grpc::Status SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - ::grpc::Status LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; - private: std::unique_ptr<::xla::Service> service_; diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index e29908ccec80db76e3b5b856e57382c56430c379..c68c857c304138ff4318e243f66547c6acce1005 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -15,9 +15,9 @@ limitations under the License. // Basic server binary that exposes a xla::Service through a GRPC interface // on a configurable port. -#include "grpc++/security/server_credentials.h" -#include "grpc++/server.h" -#include "grpc++/server_builder.h" +#include "grpcpp/security/server_credentials.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index 620ac6cec4f76d938e57e87849066df59514938a..7b8ab158e1396d7087a407be180ab44d2e16e121 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -62,21 +62,6 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, }); } -Status GRPCStub::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->LoadComputationSnapshot(context, *request, response); - }); -} - -Status GRPCStub::Execute(const ExecuteRequest* request, - ExecuteResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Execute(context, *request, response); - }); -} - Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -84,13 +69,6 @@ Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, }); } -Status GRPCStub::ExecuteParallel(const ExecuteParallelRequest* request, - ExecuteParallelResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteParallel(context, *request, response); - }); -} - Status GRPCStub::ExecuteGraphParallel( const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) { @@ -99,13 +77,6 @@ Status GRPCStub::ExecuteGraphParallel( }); } -Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, - ExecuteAsyncResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteAsync(context, *request, response); - }); -} - Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request, WaitForExecutionResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -120,13 +91,6 @@ Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request, }); } -Status GRPCStub::GetComputationStats(const ComputationStatsRequest* request, - ComputationStatsResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationStats(context, *request, response); - }); -} - Status GRPCStub::GetComputationGraphStats( const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) { @@ -135,13 +99,6 @@ Status GRPCStub::GetComputationGraphStats( }); } -Status GRPCStub::GetComputationShape(const GetComputationShapeRequest* request, - GetComputationShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationShape(context, *request, response); - }); -} - Status GRPCStub::GetShape(const GetShapeRequest* request, GetShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -163,48 +120,6 @@ Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request, }); } -// Methods used by ComputationBuilder. -Status GRPCStub::Computation(const ComputationRequest* request, - ComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Computation(context, *request, response); - }); -} - -Status GRPCStub::Op(const OpRequest* request, OpResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->CreateOp(context, *request, response); - }); -} - -Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, - GetLocalShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetLocalShape(context, *request, response); - }); -} - -Status GRPCStub::SetReturnValue(const SetReturnValueRequest* request, - SetReturnValueResponse* responses) { - return MakeRPC([this, request, responses](::grpc::ClientContext* context) { - return grpc_stub_->SetReturnValue(context, *request, responses); - }); -} - -Status GRPCStub::IsConstant(const IsConstantRequest* request, - IsConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->IsConstant(context, *request, response); - }); -} - -Status GRPCStub::ComputeConstant(const ComputeConstantRequest* request, - ComputeConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ComputeConstant(context, *request, response); - }); -} - Status GRPCStub::ComputeConstantGraph( const ComputeConstantGraphRequest* request, ComputeConstantResponse* response) { @@ -213,14 +128,6 @@ Status GRPCStub::ComputeConstantGraph( }); } -// Methods used by Computation. -Status GRPCStub::SnapshotComputation(const SnapshotComputationRequest* request, - SnapshotComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->SnapshotComputation(context, *request, response); - }); -} - // Methods used by GlobalData. Status GRPCStub::Unregister(const UnregisterRequest* request, UnregisterResponse* response) { diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index 5906d45769b5749b0c590dbc0e1972077dc3e7ba..8dfcb761387d608abbb1f62974f49b976a7ff7ff 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -43,39 +43,21 @@ class GRPCStub : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) override; - - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; - Status ExecuteGraph(const ExecuteGraphRequest* request, ExecuteResponse* response) override; - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; - Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) override; - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override; Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - Status GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - Status GetComputationGraphStats(const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) override; - Status GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - Status GetShape(const GetShapeRequest* arg, GetShapeResponse* result) override; @@ -85,30 +67,9 @@ class GRPCStub : public ServiceInterface { Status CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) override; - // Methods used by ComputationBuilder. - Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - - Status Op(const OpRequest* arg, OpResponse* result) override; - Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; - - Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) override; - // Methods used by Computation. - Status SnapshotComputation(const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) override; - // Methods used by GlobalData. Status Unregister(const UnregisterRequest* arg, UnregisterResponse* result) override; diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index c47164ee1b7657ae378a053f553442bee751753e..551ae895e05586daec0ffcd425f4950f76bdd50d 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -75,19 +75,7 @@ service XlaService { rpc GetShape(GetShapeRequest) returns (GetShapeResponse) { } - // Requests the program shape of the referenced computation. - rpc GetComputationShape(GetComputationShapeRequest) - returns (GetComputationShapeResponse) { - } - - // Requests the statistics of the given computation. - rpc GetComputationStats(ComputationStatsRequest) - returns (ComputationStatsResponse) { - } - // Requests the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. rpc GetComputationGraphStats(ComputationGraphStatsRequest) returns (ComputationStatsResponse) { } @@ -121,25 +109,12 @@ service XlaService { rpc ResetDevice(ResetDeviceRequest) returns (ResetDeviceResponse) { } - // Tests if an expression is a compile-time constant. - rpc IsConstant(IsConstantRequest) returns (IsConstantResponse) { - } - - // Computes the value of a constant expression. - rpc ComputeConstant(ComputeConstantRequest) - returns (ComputeConstantResponse) { - } - // Computes the value of a constant expression. The request contains the // computation graph for the constant expression. rpc ComputeConstantGraph(ComputeConstantGraphRequest) returns (ComputeConstantResponse) { } - // Retrieves the inferred shape for a value within a computation. - rpc GetLocalShape(GetLocalShapeRequest) returns (GetLocalShapeResponse) { - } - // Requests one or more device handles from the target. The returned device // handles can be used to specify the device on which to execute computations // or transfer data. @@ -153,32 +128,6 @@ service XlaService { returns (CreateChannelHandleResponse) { } - // Requests that the referenced computation be specialized for the provided - // arguments for subsequent execution. This permits things such as value - // specialization. - rpc Specialize(SpecializeRequest) returns (SpecializeResponse) { - } - - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) { - } - - // Computation creates a new computation with the given name. - // A unique ComputationHandle is returned. - rpc Computation(ComputationRequest) returns (ComputationResponse) { - } - - // Adds a new op to a computation. - rpc CreateOp(OpRequest) returns (OpResponse) { - } - - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - rpc Execute(ExecuteRequest) returns (ExecuteResponse) { - } - // Invokes the provided computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. @@ -188,38 +137,13 @@ service XlaService { // Invokes the provided list of computations in parallel with the provided // global data for each computation. Returns a list of global data output and // execution timing. - rpc ExecuteParallel(ExecuteParallelRequest) - returns (ExecuteParallelResponse) { - } - - // Invokes the provided list of computations in parallel with the provided - // global data for each computation. Returns a list of global data output and - // execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. rpc ExecuteGraphParallel(ExecuteGraphParallelRequest) returns (ExecuteParallelResponse) { } - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - rpc ExecuteAsync(ExecuteAsyncRequest) returns (ExecuteAsyncResponse) { - } - // Waits until the given execution (aysnchronously launched) is complete, and // returns the global data output. rpc WaitForExecution(WaitForExecutionRequest) returns (WaitForExecutionResponse) { } - - // Serializes a computation to proto form, so it can be loaded via - // LoadComputationSnapshot. - rpc SnapshotComputation(SnapshotComputationRequest) - returns (SnapshotComputationResponse) { - } - - // Loads a computation from a captured snapshot. - rpc LoadComputationSnapshot(LoadComputationSnapshotRequest) - returns (LoadComputationSnapshotResponse) { - } } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 394447fb7fbc11243afbf1bcfa5af461e40f725a..fe99f700d23dbab799ba011b705c59d6ef7a2e52 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -16,19 +16,23 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) xla_proto_library( - name = "session_proto", - srcs = ["session.proto"], + name = "hlo_proto", + srcs = ["hlo.proto"], visibility = ["//visibility:public"], deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) -xla_proto_library( - name = "hlo_proto", +tf_proto_library_py( + name = "hlo_proto", # bzl adds a _py suffix only to the OSS target. srcs = ["hlo.proto"], visibility = ["//visibility:public"], - deps = ["//tensorflow/compiler/xla:xla_data_proto"], + deps = ["//tensorflow/compiler/xla:xla_data_proto_py"], ) xla_proto_library( @@ -266,6 +270,7 @@ cc_library( "dfs_hlo_visitor.cc", "hlo_computation.cc", "hlo_instruction.cc", + "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", "hlo_sharding.cc", @@ -273,18 +278,21 @@ cc_library( hdrs = [ "dfs_hlo_visitor.h", "dfs_hlo_visitor_with_default.h", + "hlo_clone_context.h", "hlo_computation.h", + "hlo_domain_metadata.h", "hlo_instruction.h", + "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", "hlo_sharding.h", ], deps = [ + ":hlo_casting_utils", ":hlo_module_config", ":hlo_proto", ":hlo_reachability", ":name_uniquer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -297,6 +305,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], @@ -336,8 +345,8 @@ tf_cc_test( ":hlo", ":pattern_matcher", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -375,6 +384,7 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -385,20 +395,9 @@ tf_cc_test( srcs = ["hlo_matchers_test.cc"], deps = [ ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", - ], -) - -cc_library( - name = "versioned_computation_handle", - srcs = ["versioned_computation_handle.cc"], - hdrs = ["versioned_computation_handle.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", ], ) @@ -407,12 +406,14 @@ tf_cc_test( srcs = ["hlo_instruction_test.cc"], deps = [ ":hlo", + ":hlo_parser", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -429,6 +430,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -533,45 +535,6 @@ tf_cc_test( ], ) -cc_library( - name = "user_computation", - srcs = ["user_computation.cc"], - hdrs = ["user_computation.h"], - deps = [ - ":hlo", - ":session_proto", - ":shape_inference", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "user_computation_test", - srcs = ["user_computation_test.cc"], - deps = [ - ":hlo_matchers", - ":user_computation", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "platform_util", srcs = ["platform_util.cc"], @@ -617,10 +580,8 @@ cc_library( ":allocation_tracker", ":backend", ":channel_tracker", - ":compilation_cache", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":execution_tracker", @@ -631,11 +592,8 @@ cc_library( ":hlo_module_config", ":hlo_proto_util", ":platform_util", - ":session_proto", ":source_map_util", ":transfer_manager", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:service_interface", @@ -662,7 +620,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":hlo", @@ -671,8 +628,6 @@ cc_library( ":platform_util", ":service", ":shaped_buffer", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -696,7 +651,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":platform_util", ":service", "//tensorflow/compiler/xla:status_macros", @@ -761,6 +715,23 @@ cc_library( ], ) +tf_cc_test( + name = "shaped_buffer_test", + srcs = ["shaped_buffer_test.cc"], + deps = [ + ":cpu_plugin", + ":device_memory_allocator", + ":platform_util", + ":shaped_buffer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:ptr_util", + "//tensorflow/core:test", + ], +) + cc_library( name = "executable", srcs = ["executable.cc"], @@ -776,9 +747,7 @@ cc_library( ":hlo_graph_dumper", ":hlo_proto", ":pool", - ":session_proto", ":shaped_buffer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -874,34 +843,12 @@ cc_library( ], ) -cc_library( - name = "computation_tracker", - srcs = ["computation_tracker.cc"], - hdrs = ["computation_tracker.h"], - deps = [ - ":hlo", - ":hlo_module_config", - ":session_proto", - ":user_computation", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "channel_tracker", srcs = ["channel_tracker.cc"], hdrs = ["channel_tracker.h"], deps = [ ":hlo", - ":session_proto", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1007,7 +954,6 @@ tf_cc_test( ":buffer_assignment", ":buffer_value", ":call_graph", - ":computation_tracker", ":copy_insertion", ":cpu_plugin", ":flatten_call_graph", @@ -1021,9 +967,9 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -1059,9 +1005,9 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1156,15 +1102,16 @@ tf_cc_test( srcs = ["hlo_scheduling_test.cc"], deps = [ ":buffer_value", + ":heap_simulator", ":hlo", ":hlo_ordering", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1197,9 +1144,22 @@ tf_cc_test( deps = [ ":hlo_matchers", ":instruction_fusion", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", ], ) @@ -1371,9 +1331,9 @@ tf_cc_test( deps = [ ":gather_expander", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1679,14 +1639,11 @@ tf_cc_test( name = "hlo_cost_analysis_test", srcs = ["hlo_cost_analysis_test.cc"], deps = [ - ":computation_tracker", ":cpu_plugin", ":hlo", ":hlo_cost_analysis", ":local_service", ":service", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", @@ -1725,9 +1682,9 @@ tf_cc_test( ":cpu_plugin", ":hlo_cost_analysis", ":hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -1878,6 +1835,44 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_liveness_analysis", + srcs = ["hlo_liveness_analysis.cc"], + hdrs = ["hlo_liveness_analysis.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_value", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_liveness_analysis_test", + srcs = ["hlo_liveness_analysis_test.cc"], + deps = [ + ":hlo", + ":hlo_liveness_analysis", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_buffer", srcs = ["hlo_buffer.cc"], @@ -1957,6 +1952,7 @@ cc_library( hdrs = ["tuple_points_to_analysis.h"], deps = [ ":hlo", + ":hlo_dataflow_analysis", ":logical_buffer", ":logical_buffer_analysis", "//tensorflow/compiler/xla:shape_tree", @@ -1989,20 +1985,6 @@ tf_cc_test( ], ) -cc_library( - name = "compilation_cache", - srcs = ["compilation_cache.cc"], - hdrs = ["compilation_cache.h"], - deps = [ - ":executable", - ":hlo_module_config", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "layout_assignment", srcs = [ @@ -2014,10 +1996,12 @@ cc_library( deps = [ ":computation_layout", ":hlo", + ":hlo_dce", ":hlo_graph_dumper", ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -2087,12 +2071,31 @@ cc_library( ], ) +cc_library( + name = "hlo_module_dce", + srcs = ["hlo_module_dce.cc"], + hdrs = ["hlo_module_dce.h"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_liveness_analysis", + ":hlo_pass", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_verifier", srcs = ["hlo_verifier.cc"], hdrs = ["hlo_verifier.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", @@ -2124,6 +2127,7 @@ cc_library( ":buffer_liveness", ":buffer_value", ":call_graph", + ":copy_insertion", ":flatten_call_graph", ":hlo", ":hlo_dce", @@ -2131,6 +2135,7 @@ cc_library( ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -2144,6 +2149,7 @@ tf_cc_test( name = "hlo_rematerialization_test", srcs = ["hlo_rematerialization_test.cc"], deps = [ + ":flatten_call_graph", ":hlo", ":hlo_matchers", ":hlo_ordering", @@ -2153,6 +2159,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -2176,6 +2183,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_module_dce_test", + srcs = ["hlo_module_dce_test.cc"], + deps = [ + ":hlo", + ":hlo_module_dce", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "layout_assignment_test", srcs = ["layout_assignment_test.cc"], @@ -2192,9 +2220,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -2242,6 +2270,7 @@ cc_library( hdrs = ["hlo_cse.h"], deps = [ ":hlo", + ":hlo_domain_map", ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -2264,10 +2293,10 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2306,6 +2335,79 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_domain_map", + srcs = ["hlo_domain_map.cc"], + hdrs = ["hlo_domain_map.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_sharding_metadata", + srcs = ["hlo_sharding_metadata.cc"], + hdrs = [ + "hlo_sharding_metadata.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_domain_isolator", + srcs = ["hlo_domain_isolator.cc"], + hdrs = ["hlo_domain_isolator.h"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + ], +) + +cc_library( + name = "hlo_domain_remover", + srcs = ["hlo_domain_remover.cc"], + hdrs = ["hlo_domain_remover.h"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_map", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_domain_test", + srcs = ["hlo_domain_test.cc"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_remover", + ":hlo_parser", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_element_type_converter", srcs = ["hlo_element_type_converter.cc"], @@ -2387,10 +2489,10 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2448,7 +2550,6 @@ cc_library( name = "hlo_tfgraph_builder", srcs = ["hlo_tfgraph_builder.cc"], hdrs = ["hlo_tfgraph_builder.h"], - visibility = ["//tensorflow/compiler/xla/tools:__pkg__"], deps = [ ":hlo", "//tensorflow/compiler/xla:literal_util", @@ -2479,6 +2580,7 @@ cc_library( hdrs = ["hlo_graph_dumper.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", "//tensorflow/compiler/xla:literal_util", @@ -2517,6 +2619,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) @@ -2535,10 +2638,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2675,7 +2778,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -2711,8 +2814,8 @@ tf_cc_test( ":tuple_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2735,9 +2838,10 @@ tf_cc_test( deps = [ ":while_util", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2763,6 +2867,7 @@ tf_cc_test( ":hlo_matchers", ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], @@ -2789,8 +2894,8 @@ tf_cc_test( ":hlo_matchers", ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -2821,3 +2926,97 @@ cc_library( "//tensorflow/core:lib", ], ) + +cc_library( + name = "indexed_array_analysis", + srcs = ["indexed_array_analysis.cc"], + hdrs = ["indexed_array_analysis.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", + ], +) + +tf_cc_test( + name = "indexed_array_analysis_test", + srcs = ["indexed_array_analysis_test.cc"], + deps = [ + ":hlo_matchers", + ":indexed_array_analysis", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo", + ":hlo_lexer", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_library( + name = "hlo_casting_utils", + hdrs = ["hlo_casting_utils.h"], + deps = ["//tensorflow/core:lib"], +) + +tf_cc_test( + name = "hlo_casting_utils_test", + srcs = ["hlo_casting_utils_test.cc"], + deps = [ + ":hlo", + ":hlo_casting_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index f732ed8f398c4699bd5247dc7fa1e9677340dcae..1ddeb27e4041df22bd3d0ec200bcddbd09937e01 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -50,20 +50,15 @@ namespace { namespace m = match; -// Returns whether operand is a literal with the given value. -bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { - return operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAll(value); -} - bool IsAll(const HloInstruction* op, int8 value) { - if (IsLiteralWithValue(op, value)) { - return true; - } - if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) { - return true; + switch (op->opcode()) { + case HloOpcode::kBroadcast: + return IsAll(op->operand(0), value); + case HloOpcode::kConstant: + return op->literal().IsAll(value); + default: + return false; } - return false; } // Returns whether the given transpose produces a result which is bit-wise @@ -75,21 +70,22 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { transpose->dimensions()); } -// Returns true if the given reshape produces a result which is bit-wise +// Returns true if the given reshape/copy produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. // // This function is conservative -- even if this function returns false, the // reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. -bool ReshapeIsBitcast( - const HloInstruction* reshape, +bool ReshapeOrCopyIsBitcast( + const HloInstruction* instr, const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { - CHECK_EQ(HloOpcode::kReshape, reshape->opcode()); + CHECK(HloOpcode::kReshape == instr->opcode() || + HloOpcode::kCopy == instr->opcode()); - const HloInstruction* operand = reshape->operand(0); + const HloInstruction* operand = instr->operand(0); // Can't insert bitcasts if the compiler used a memory layout which isn't // compatible. - return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) && - valid_bitcast_callback(operand->shape(), reshape->shape()); + return ShapeUtil::ReshapeIsBitcast(operand->shape(), instr->shape()) && + valid_bitcast_callback(operand->shape(), instr->shape()); } // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain @@ -157,8 +153,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleSubtract(HloInstruction* sub) override; - Status HandleMaximum(HloInstruction* maximum) override; - Status HandleMinimum(HloInstruction* minimum) override; + Status HandleMap(HloInstruction* map) override; // Returns whether algebraic simplification has occurred. const bool changed() const { return changed_; } @@ -198,8 +193,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Helper method to perform and add reduction in a single dimension. HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { - HloInstruction* zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction* zero = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::Zero(hlo->shape().element_type()).CloneToUnique())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -231,10 +227,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand); - // A Reshape or Broadcast that feeds an element-wise operation with a unique - // non-scalar operand can sink to after the operation. - StatusOr TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast); + // A Broadcast that feeds an element-wise operation with a unique non-scalar + // operand can sink to after the operation. + StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast); // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -431,7 +427,15 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op)); } // All copies can be eliminated (assuming layout constraints are satisified). - ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); + if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) { + return Status::OK(); + } + + if (is_layout_sensitive_ && + ReshapeOrCopyIsBitcast(copy, valid_bitcast_callback_)) { + ReplaceWithBitcast(copy); + } + return Status::OK(); } @@ -447,7 +451,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( // Filter out and remove empty operands. std::vector nonempty_operands; for (HloInstruction* operand : operands) { - if (!ShapeUtil::HasZeroElements(operand->shape())) { + if (!ShapeUtil::IsZeroElementArray(operand->shape())) { nonempty_operands.push_back(operand); } } @@ -526,6 +530,10 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { constant, BuildTupleConstant(computation_, constant->literal())); } + if (constant->shape().element_type() == TOKEN) { + return Status::OK(); + } + // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { @@ -561,6 +569,14 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { return Status::OK(); } +namespace { +template +Status InvertConstant(const HloInstruction& constant, Literal* result) { + return result->Populate([&](tensorflow::gtl::ArraySlice indices) { + return T{1.0} / constant.literal().Get(indices); + }); +} +} // namespace Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { Shape* shape; @@ -622,14 +638,31 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // (Backends can do this transformation, but generally only if the constant is // a scalar.) if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { - HloInstruction* one = - computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(a->shape().element_type()).CloneToUnique())); - HloInstruction* inverse = computation_->AddInstruction( - HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kMultiply, a, inverse)); + Literal new_literal(b->shape()); + switch (b->shape().element_type()) { + case F16: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case F32: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case BF16: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case F64: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case C64: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + default: + return Status::OK(); + } + auto inverse = computation_->AddInstruction( + HloInstruction::CreateConstant((new_literal.CloneToUnique()))); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); + return ReplaceInstruction(divide, new_divide); } // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) @@ -649,18 +682,18 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply, b, c)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kDivide, a, b_times_c)); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c)); + return ReplaceInstruction(divide, new_divide); } // A / (B / C) => (A*C) / B if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, a, c)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kDivide, a_times_c, b)); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b)); + return ReplaceInstruction(divide, new_divide); } return Status::OK(); @@ -1056,9 +1089,9 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } // Replace a zero element dot with a broadcast of the constant 0. - if (ShapeUtil::HasZeroElements(dot->shape()) || - ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { + if (ShapeUtil::IsZeroElementArray(dot->shape()) || + ShapeUtil::IsZeroElementArray(lhs->shape()) || + ShapeUtil::IsZeroElementArray(rhs->shape())) { auto zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); return ReplaceWithNewInstruction( @@ -1219,9 +1252,10 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, switch (instruction->opcode()) { case HloOpcode::kReshape: case HloOpcode::kReverse: - case HloOpcode::kSort: case HloOpcode::kTranspose: return true; + case HloOpcode::kSort: + return (!ShapeUtil::IsTuple(instruction->shape())); default: return false; } @@ -1303,7 +1337,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // broadcast after the unary element-wise operation. TF_ASSIGN_OR_RETURN( bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); + TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); changed_ |= sink_succeeded; if (sink_succeeded) { return Status::OK(); @@ -1390,7 +1424,7 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { } Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { - if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) { + if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) { return ReplaceWithNewInstruction( pad, HloInstruction::CreateBroadcast(pad->shape(), pad->mutable_operand(1), {})); @@ -1555,15 +1589,16 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { return Status::OK(); } -StatusOr AlgebraicSimplifierVisitor:: - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast) { +StatusOr +AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast) { + TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast); bool changed = false; - if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) { + if (ShapeUtil::IsScalar(broadcast->shape())) { return false; } - HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); - for (HloInstruction* user : reshape_or_broadcast->users()) { + HloInstruction* operand = broadcast->mutable_operand(0); + for (HloInstruction* user : broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { continue; } @@ -1581,55 +1616,50 @@ StatusOr AlgebraicSimplifierVisitor:: continue; } - int64 reshape_or_broadcast_operand_index = -1; // Find the unique non-scalar operand or continue if there isn't one. - int64 scalar_count = 0; - for (int64 i = 0; i < user->operand_count(); ++i) { - if (ShapeUtil::IsScalar(user->operand(i)->shape())) { - ++scalar_count; - } else { - reshape_or_broadcast_operand_index = i; + int64 scalar_broadcast_count = 0; + int64 broadcast_use_count = 0; + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + ++scalar_broadcast_count; + } else if (broadcast == user_operand) { + ++broadcast_use_count; } } - if (scalar_count != user->operand_count() - 1) { + if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) { continue; } - VLOG(4) << "Sinking reshape or broadcast after user:"; - VLOG(4) << " old reshape/broadcast: " << reshape_or_broadcast->ToString(); + std::vector new_operands; + new_operands.reserve(user->operand_count()); + + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + new_operands.push_back( + computation_->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()), + user_operand->mutable_operand(0), {}))); + } else { + CHECK_EQ(broadcast, user_operand); + new_operands.push_back(operand); + } + } + VLOG(4) << "Sinking broadcast after user:"; + VLOG(4) << " old broadcast: " << broadcast->ToString(); VLOG(4) << " old user: " << user->ToString(); - CHECK_EQ(user->operand(reshape_or_broadcast_operand_index), - reshape_or_broadcast); - auto new_user_operands = user->operands(); - new_user_operands[reshape_or_broadcast_operand_index] = operand; - auto new_user = computation_->AddInstruction(user->CloneWithNewOperands( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(operand->shape().dimensions()), - LayoutUtil::MinorToMajor(operand->shape())), - new_user_operands)); + HloInstruction* new_user = + computation_->AddInstruction(user->CloneWithNewOperands( + ShapeUtil::ChangeElementType(operand->shape(), + user->shape().element_type()), + new_operands)); VLOG(4) << " new user: " << new_user->ToString(); - HloInstruction* new_reshape_or_broadcast = nullptr; - if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) { - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user)); - } else { - TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast); - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user, reshape_or_broadcast->dimensions())); - } - VLOG(4) << " new reshape/broadcast: " - << new_reshape_or_broadcast->ToString(); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast)); + HloInstruction* new_broadcast = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + user->shape(), new_user, broadcast->dimensions())); + VLOG(4) << " new broadcast: " << new_broadcast->ToString(); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); changed = true; } return changed; @@ -1640,7 +1670,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // Reshape directly to empty constant if the shape contains zero-element // dimension. - if (ShapeUtil::HasZeroElements(reshape->shape())) { + if (ShapeUtil::IsZeroElementArray(reshape->shape())) { auto empty_constant = HloInstruction::CreateConstant( Literal::CreateFromShape(reshape->shape())); @@ -1672,19 +1702,9 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } } - // A Reshape that feeds a unary element-wise operation can sink the - // reshape after the unary element-wise operation. - TF_ASSIGN_OR_RETURN( - bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape)); - changed_ |= sink_succeeded; - if (sink_succeeded) { - return Status::OK(); - } - // Make this a bitcast if possible. if (is_layout_sensitive_ && - ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { + ReshapeOrCopyIsBitcast(reshape, valid_bitcast_callback_)) { ReplaceWithBitcast(reshape); return Status::OK(); } @@ -1751,7 +1771,7 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( // If any dimension of update is 0, elide the DynamicUpdateSlice. This // optimization becomes invalid should we later prefer to warn about out of // bound indices. - if (ShapeUtil::HasZeroElements(update->shape())) { + if (ShapeUtil::IsZeroElementArray(update->shape())) { return ReplaceInstruction(dynamic_update_slice, dynamic_update_slice->mutable_operand(0)); } @@ -1763,8 +1783,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { auto init_value = reduce->mutable_operand(1); tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); - if (ShapeUtil::HasZeroElements(arg->shape()) || - ShapeUtil::HasZeroElements(reduce->shape())) { + if (ShapeUtil::IsZeroElementArray(arg->shape()) || + ShapeUtil::IsZeroElementArray(reduce->shape())) { return ReplaceWithNewInstruction( reduce, HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); @@ -1786,6 +1806,46 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } + // If the reduction results in the same number of elements, then the only + // possible side effect would be a reshape. Since the init_value is an + // identity of the reduction function, we can therefore replace the reduce + // with a simple reshape, ignoring the reduction function completely. + if (ShapeUtil::ElementsIn(reduce->shape()) == + ShapeUtil::ElementsIn(arg->shape())) { + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); + } + + // If a reduce feeds a reduce with the same computation and initial value, + // they can be combined into a single reduce. + if (arg->opcode() == HloOpcode::kReduce && + init_value->Identical(*arg->operand(1)) && + *function == *arg->to_apply()) { + // Create a new reduce with the combined reduction dimensions of both + // reduces. + std::vector arg_dims = arg->dimensions(); + std::sort(arg_dims.begin(), arg_dims.end()); + std::vector reduce_dims = reduce->dimensions(); + std::sort(reduce_dims.begin(), reduce_dims.end()); + // Transform reduce_dims to the same rank as the operand of the operand. + for (int64 arg_dim : arg_dims) { + for (int64& dim : reduce_dims) { + if (dim >= arg_dim) { + ++dim; + } + } + } + std::vector new_dimensions; + new_dimensions.reserve(arg->dimensions().size() + + reduce->dimensions().size()); + std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), + reduce_dims.end(), std::back_inserter(new_dimensions)); + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0), + init_value, new_dimensions, function)); + } + // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. @@ -1830,21 +1890,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } } - if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape()) || - ShapeUtil::HasZeroElements(arg->shape())) { - auto reshape = computation_->AddInstruction( - HloInstruction::CreateReshape(reduce->shape(), arg)); - return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateMap(reduce->shape(), - {init_value, reshape}, function)); - } return Status::OK(); } Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* reduce_window) { - if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) { + if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) { return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateBroadcast(reduce_window->shape(), @@ -1858,7 +1909,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateMap(reduce_window->shape(), - {operand, reduce_window->mutable_operand(1)}, + {reduce_window->mutable_operand(1), operand}, function)); } @@ -2040,16 +2091,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction* convolution) { auto lhs = convolution->mutable_operand(0); auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { + if (ShapeUtil::IsZeroElementArray(lhs->shape()) || + ShapeUtil::IsZeroElementArray(rhs->shape())) { return ReplaceWithNewInstruction( convolution, HloInstruction::CreateBroadcast( convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(convolution->shape().element_type(), {}), - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))), + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::Zero(convolution->shape().element_type()) + .CloneToUnique())), {})); } const auto& window = convolution->window(); @@ -2188,66 +2238,37 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( return true; } -Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { - // Match the following tree: - // min_operand operand - // \ / - // max_operand min - // \ / - // max - // where max_operand and min_operand are scalar constants. - { - HloInstruction* min; - HloInstruction* max_operand; - HloInstruction* min_operand; - HloInstruction* operand; - - if (hlo_query::MatchBinaryInstructionOperandOpcode( - HloOpcode::kMinimum, maximum, - /*matching_operand=*/&min, - /*other_operand=*/&max_operand) && - hlo_query::MatchBinaryInstructionOperand( - hlo_query::IsScalarConstant, min, - /*matching_operand=*/&min_operand, - /*other_operand=*/&operand) && - TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum, - max_operand)) { +Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { + auto* map_computation = map->to_apply(); + auto* map_root = map_computation->root_instruction(); + if (map_root->opcode() == HloOpcode::kParameter) { + ReplaceInstructionIfSameShape( + map, map->mutable_operand(map_root->parameter_number())); + return Status::OK(); + } + if (map_root->opcode() == HloOpcode::kConstant) { + if (!ShapeUtil::IsScalar(map_root->shape())) { return Status::OK(); } + auto clone = map_root->CloneWithNewOperands(map_root->shape(), {}); + if (ShapeUtil::IsScalar(map->shape())) { + return ReplaceWithNewInstruction(map, std::move(clone)); + } + return ReplaceWithNewInstruction( + map, + HloInstruction::CreateBroadcast( + map->shape(), computation_->AddInstruction(std::move(clone)), {})); } - - return Status::OK(); -} - -Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { - // Match the following tree: - // max_operand operand - // \ / - // min_operand max - // \ / - // min - // where max_operand and min_operand are scalar constants. - { - HloInstruction* max; - HloInstruction* max_operand; - HloInstruction* min_operand; - HloInstruction* operand; - - if (hlo_query::MatchBinaryInstructionOperandOpcode( - HloOpcode::kMaximum, minimum, - /*matching_operand=*/&max, - /*other_operand=*/&min_operand) && - hlo_query::MatchBinaryInstructionOperand( - hlo_query::IsScalarConstant, max, - /*matching_operand=*/&max_operand, - /*other_operand=*/&operand) && - TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max, - max_operand)) { + std::vector new_operands; + for (auto* root_operand : map_root->operands()) { + if (root_operand->opcode() != HloOpcode::kParameter) { return Status::OK(); } + new_operands.push_back( + map->mutable_operand(root_operand->parameter_number())); } - - return Status::OK(); + auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands); + return ReplaceWithNewInstruction(map, std::move(clone)); } StatusOr AlgebraicSimplifier::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 4e082877c776c35bab499c805fef7632765a3ee1..b733f6f59eb028b2dff921722c462441251772fe 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -74,6 +74,44 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +// Test that Reduce(Reduce(A)) -> Reduce(A) +TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { + HloComputation::Builder builder(TestName()); + // Create add computation. + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r4f32, "param")); + std::vector dims0({0}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7}); + HloInstruction* reduce0 = builder.AddInstruction( + HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation)); + std::vector dims1({1, 2}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); + builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, + dims1, add_computation)); + module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reduce(param, zero)); + EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); +} + // Test that Const + A is canonicalized to A + Const. TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -143,6 +181,42 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { + HloComputation::Builder builder(TestName()); + // Create add computation. + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + builder.AddInstruction(HloInstruction::CreateMap( + r2f32, + {param0, builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, zero, {}))}, + add_computation)); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kMap); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); @@ -296,17 +370,16 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r2f32, "param1")); HloInstruction* param2 = builder.AddInstruction( HloInstruction::CreateParameter(2, r2f32, "param2")); HloInstruction* param3 = builder.AddInstruction( - HloInstruction::CreateParameter(3, r0f32, "param3")); + HloInstruction::CreateParameter(3, r2f32, "param3")); HloInstruction* div0 = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1)); HloInstruction* div1 = builder.AddInstruction( @@ -327,8 +400,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { EXPECT_THAT( computation->root_instruction(), op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); - EXPECT_TRUE( - ShapeUtil::Compatible(computation->root_instruction()->shape(), r2f32)); } // Test that A/exp(B) is simplified to A*exp(-B). @@ -388,7 +459,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { // Test that broadcasting is done on the right step when simplifying A/pow(B,C) // to A*pow(B,-C). TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -396,7 +466,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* param2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction::CreateParameter(2, r1f32, "param2")); HloInstruction* power = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2)); builder.AddInstruction( @@ -413,14 +483,9 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { ASSERT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); - - const HloInstruction* negate = - computation->root_instruction()->operand(1)->operand(1); - const Shape& negate_shape = negate->shape(); - EXPECT_EQ(0, negate_shape.dimensions_size()); } -// A / Const => A * (1 / Const) +// A / Const => A * InvertedConst TEST_F(AlgebraicSimplifierTest, DivideByConstant) { Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); HloComputation::Builder builder(TestName()); @@ -439,20 +504,19 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Divide(op::Constant(), constant))); + op::Multiply(param0, op::Constant())); } // pow(pow(A, X), Y) => pow(A, X*Y) TEST_F(AlgebraicSimplifierTest, PowerOfPower) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* exp1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* exp2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction::CreateParameter(2, r1f32, "param2")); HloInstruction* inner_power = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, @@ -469,15 +533,14 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex // numbers. TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { - Shape r0c64 = ShapeUtil::MakeShape(C64, {}); Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( HloInstruction::CreateParameter(0, r1c64, "param0")); HloInstruction* exp1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0c64, "param1")); + HloInstruction::CreateParameter(1, r1c64, "param1")); HloInstruction* exp2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0c64, "param2")); + HloInstruction::CreateParameter(2, r1c64, "param2")); HloInstruction* inner_power = builder.AddInstruction( HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1)); builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, @@ -1088,6 +1151,33 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { EXPECT_THAT(computation->root_instruction(), param0); } +TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), "param")); + *param->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({0, 1, 2, 3}); + HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param)); + *copy->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({1, 2, 0, 3}); + auto computation = module().AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + + AlgebraicSimplifier simplifier1(/*is_layout_sensitive=*/true, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier1.Run(&module()).ValueOrDie()); + // Verify that the copy is not replaced. + EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + + AlgebraicSimplifier simplifier2(/*is_layout_sensitive=*/true, + bitcasting_callback()); + ASSERT_TRUE(simplifier2.Run(&module()).ValueOrDie()); + // Verify that the copy is replaced. + EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); +} + // Test that unary concatenates are removed. TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); @@ -1318,59 +1408,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); } -TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { - HloComputation::Builder builder(TestName()); - HloInstruction* param = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param")); - HloInstruction* movable_reshape = - builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param)); - HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), - HloOpcode::kMaximum, movable_reshape, zero)); - auto computation = module().AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - - simplifier.Run(&module()).ValueOrDie(); - EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Maximum(param, zero))); -} - -// Regression test for a bug in the reshape sinking transformation, where -// moving a reshape to a scalar led to a crash. -TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { - HloComputation::Builder builder(TestName()); - HloInstruction* param = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1}), "param")); - HloInstruction* reshape = builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param)); - HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1., 2., 3.}))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero)); - auto computation = module().AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - - simplifier.Run(&module()).ValueOrDie(); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); -} - // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { @@ -1707,7 +1744,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1752,7 +1789,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1774,7 +1811,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1797,7 +1834,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1925,7 +1962,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, @@ -2030,160 +2068,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { EXPECT_EQ("NO_CHANGE", build_and_simplify()); } -// Test that max(min(A, x), y) is transformed to clamp(y, A, x) -TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMinimum, param0, min_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Minimum(param0, min_value), max_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar -// values. -TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for -// broadcasted scalar values. -TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, non-constant1), non-constant2) is not canonicalized to -// clamp(non-constant1, A, non-constant2) -TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "param1")); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); -} - -// Test that min(f(max(A, constant1)), constant2) is not transformed to -// clamp(constant1, A, constant2) -TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - HloInstruction* fmax = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value)); - builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMinimum, fmax, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), - min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), - min_value)); -} - // Test that slice(broadcast(/*scalar value*/)) simplifies to a single // broadcast. TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { @@ -2193,10 +2077,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction::CreateParameter(0, r0f32, "scalar_param")); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, scalar_param, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {})); Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( @@ -2212,10 +2094,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2230,10 +2112,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, forty_two, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {})); HloInstruction* transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -2252,7 +2132,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2261,7 +2141,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2342,7 +2223,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2437,7 +2319,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 96e02b82b97ff2fd682638f4c6297cbc2019c481..ec13fadbc75e2315d1d6ef72e24a0faca0c7de40 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -58,8 +58,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. static bool Run(HloComputation* computation, bool rewrite_training_op, - bool rewrite_inference_op, bool rewrite_grad_op, - bool use_fusion); + bool rewrite_inference_op, bool rewrite_grad_op); // Returns whether any batch norm ops were rewritten. const bool changed() const { return changed_; } @@ -70,21 +69,14 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { explicit BatchNormExpanderVisitor(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) + bool rewrite_grad_op) : computation_(computation), rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} + rewrite_grad_op_(rewrite_grad_op) {} HloComputation* GetOrCreateScalarAddComputation( PrimitiveType primitive_type) { - HloComputation** scalar_add_computation = - &scalar_add_computations_[primitive_type]; - if (*scalar_add_computation) { - return *scalar_add_computation; - } - HloComputation::Builder b("scalar_add_computation"); Shape shape = ShapeUtil::MakeShape(primitive_type, {}); auto scalar_lhs = b.AddInstruction( @@ -93,26 +85,39 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { HloInstruction::CreateParameter(1, shape, "scalar_rhs")); auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); - *scalar_add_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_add_computation; + return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); } - // Current HloComputation instance the BatchNormExpander is - // traversing. - HloComputation* computation_; - - bool rewrite_training_op_; - bool rewrite_inference_op_; - bool rewrite_grad_op_; - bool use_fusion_; - - // Whether rewrite has occurred. - bool changed_ = false; + std::unique_ptr Rsqrt( + HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(-0.5f))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, + operand, exponent); + } - // Cached computations for adding two scalars. - tensorflow::gtl::FlatMap - scalar_add_computations_; + std::unique_ptr Mean( + int64 element_count, HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* elem_count_recip = + add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(1.0 / element_count))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, + operand, elem_count_recip); + } // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -136,6 +141,16 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { changed_ = true; return Status::OK(); } + // Current HloComputation instance the BatchNormExpander is + // traversing. + HloComputation* computation_; + + bool rewrite_training_op_; + bool rewrite_inference_op_; + bool rewrite_grad_op_; + + // Whether rewrite has occurred. + bool changed_ = false; }; } // namespace @@ -143,13 +158,12 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) { + bool rewrite_grad_op) { BatchNormExpanderVisitor visitor( computation, /*rewrite_training_op=*/rewrite_training_op, /*rewrite_inference_op=*/rewrite_inference_op, - /*rewrite_grad_op=*/rewrite_grad_op, - /*use_fusion=*/use_fusion); + /*rewrite_grad_op=*/rewrite_grad_op); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -167,6 +181,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); // Expand batch norm training into smaller HLO ops. @@ -176,12 +194,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( int64 feature_index = batch_norm->feature_index(); const int64 feature_count = operand_shape.dimensions(feature_index); const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -193,8 +206,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = - add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = add(HloInstruction::CreateBroadcast( + operand_shape, + add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -213,8 +227,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( GetOrCreateScalarAddComputation(ptype); // X^2. - auto operand_squared = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = + add_binary(operand_shape, HloOpcode::kMultiply, operand, operand); // Sum[X]. auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, dimensions_without_feature, @@ -225,71 +239,48 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); - // Fuse two parallel reduces together to improve performance. - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum, squared_sum, operand_squared}, - HloInstruction::FusionKind::kInput); - - sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - squared_sum = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // E[X]. - auto mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); + auto mean = add(Mean(elements_per_feature_int64, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); + auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); // E^2[X]. - auto mean_square = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, mean, mean)); + auto mean_square = + add_binary(feature_shape, HloOpcode::kMultiply, mean, mean); // Var[X]. - auto var = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); + auto var = + add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square); auto var_broadcasted = add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd, + scaled_normalized, offset_broadcasted); auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); @@ -331,8 +322,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( + operand_shape, + computation_->AddInstruction( + HloInstruction::CreateConstant(std::move(epsilon_literal))), + {})); std::vector dimensions_without_feature; @@ -349,6 +343,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); auto scale_broadcasted = add( @@ -364,30 +362,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. auto shifted_normalized = HloInstruction::CreateBinary( @@ -435,6 +426,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); @@ -450,26 +445,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); const int64 feature_count = activation_shape.dimensions(feature_index); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = + auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon_activation = add( + HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {})); + auto epsilon_feature = + add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {})); std::vector dimensions_without_feature; @@ -489,26 +478,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, - epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = + add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation), + add)); + + auto rsqrt_var_add_epsilon = add(Rsqrt( + add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature), + add)); // X - E[X]. - auto activation_minus_mean = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); + auto activation_minus_mean = add_binary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted); // Grad[Y] * (X - E[X]). auto grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + activation_minus_mean); HloComputation* add_reduce_computation = GetOrCreateScalarAddComputation(ptype); @@ -524,25 +510,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple( - {sum_grad_output_times_activiation_minus_mean, grad_beta})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, - HloInstruction::FusionKind::kInput); - - sum_grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - grad_beta = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). - auto grad_scale = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, - sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); + auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply, + sum_grad_output_times_activiation_minus_mean, + rsqrt_var_add_epsilon); // I2 = Sum(Grad[Y]) auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, @@ -554,39 +525,40 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( {feature_index})); // I4 = (X - E[X]) * I3 - auto i4 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); + auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3, + activation_minus_mean); // I5 = I4 / (Var[X] + epsilon) - auto i5 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, i4, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)))); + auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4, + add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation)); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = + add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, - elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = add( + Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); - auto i1 = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto elements_per_feature_literal = + Literal::CreateR0(elements_per_feature_int64); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + add(HloInstruction::CreateBroadcast( + activation_shape, elements_per_feature, {}))); // I6 = I1 - I2 - I5 - auto i6 = add(HloInstruction::CreateBinary( + auto i6 = add_binary( activation_shape, HloOpcode::kSubtract, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - i1, i2)), - i5)); + add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. - auto grad_activation = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, i6)); + auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6); auto tuple = HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); if (batch_norm->has_sharding()) { @@ -615,8 +587,8 @@ StatusOr BatchNormExpander::Run(HloModule* module) { bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_, - rewrite_inference_op_, rewrite_grad_op_, - use_fusion_)) { + rewrite_inference_op_, + rewrite_grad_op_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 4ad987085da91684bb7891070afeefd19be4138f..7ae202c583516443a6263403fb5460d1adbabd97 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -31,11 +31,10 @@ class BatchNormExpander : public HloPassInterface { // When use_fusion is set, a multi-output fusion node is created. BatchNormExpander(bool rewrite_training_op = false, bool rewrite_inference_op = false, - bool rewrite_grad_op = false, bool use_fusion = true) + bool rewrite_grad_op = false) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} + rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; tensorflow::StringPiece name() const override { return "batchnorm_expander"; } @@ -47,7 +46,6 @@ class BatchNormExpander : public HloPassInterface { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - bool use_fusion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 08d0152e3cfcfcb7ae1e85f72c2f7dc856f5e8b3..1b8b2d204503576c3fcb02f6d5b37f2db45e1768 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -182,15 +182,26 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( HloInstruction* crs) { - if (!ShapeUtil::IsTuple(crs->shape()) || - !bfloat16_support_->SupportsMixedPrecisions(*crs)) { - return DefaultAction(crs); - } - // First use DefaultAction() to handle the operands. It can't handle // tuple-shaped output. TF_RETURN_IF_ERROR(DefaultAction(crs)); + if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return Status::OK(); + } + + // If the output is not a tuple, we don't need special handling. + if (!ShapeUtil::IsTuple(crs->shape())) { + return Status::OK(); + } + + // If crs is the root instruction, we should keep its original output type. + // The root instruction implicitly has a use from being the result of the + // computation, and the code below does not take this use into account. + if (crs == computation_->root_instruction()) { + return Status::OK(); + } + // Then do per-tuple-element handling on the output. std::vector> per_tuple_element_gtes( crs->operand_count()); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 28e71c2054f59ba4d5d096bf7d898161877bb42f..f7b4c1405dbc8719d8fba5476e6e41d2921ea877 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -211,6 +211,17 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto builder = HloComputation::Builder(TestName()); + + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("add"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -223,7 +234,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b})); + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, + sum, /*replica_group_ids=*/{}, /*barrier=*/"")); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( @@ -233,7 +245,6 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* tuple = builder.AddInstruction( HloInstruction::CreateTuple({gte_a, convert_gte_b})); - auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(FoldConversions(module.get())); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 1afaefd9df9c5771fb9e134ae9050f3abb00ea4a..830f26422bdc2b3bd789e7d5926bcebac815d34a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -228,6 +228,17 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -239,11 +250,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b})); + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, + /*replica_group_ids=*/{}, /*barrier=*/"")); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); - auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(Normalize(module.get())); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index ed0746980f87ac2bea79c308644dc63769f9e309..ff6d5027efba813042af65a0e50e172cc0a99ff8 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -85,9 +85,9 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes( auto root_changes_it = changes_to_bf16_.find(root); if (root_changes_it != changes_to_bf16_.end()) { - for (const auto& index : root_changes_it->second) { + for (const auto& entry : root_changes_it->second) { for (const HloValue* value : - dataflow_->GetValueSet(root, index).values()) { + dataflow_->GetValueSet(root, entry.second).values()) { changed_root_buffers.insert(value); } } @@ -204,6 +204,12 @@ void BFloat16Propagation::DetermineWhileComputationsPrecision( bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const { + // If the subshape isn't floating point then none of the users will be BF16. + const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index); + if (subshape.element_type() != BF16 && subshape.element_type() != F32) { + return false; + } + auto& value_set = dataflow_->GetValueSet(&hlo, index); for (const HloValue* value : value_set.values()) { if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { @@ -257,23 +263,34 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, // If the op propagates precision and it outputs a BF16, then it's OK to // supply BF16 also as the input. In the backward pass, the users shapes // should have already been processed. - PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID; - if (use.instruction->opcode() == HloOpcode::kTuple || - (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && - ShapeUtil::IsTuple(use.instruction->shape()))) { - ShapeIndex use_output_index{use.operand_number}; - for (int64 i : use.operand_index) { - use_output_index.push_back(i); - } - user_output_type = - OutputTypeAfterChange(use.instruction, use_output_index); - } else { - user_output_type = OutputTypeAfterChange(use.instruction, {}); - } if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( - *use.instruction, use.operand_number) && - user_output_type == BF16) { - continue; + *use.instruction, use.operand_number)) { + if (use.instruction->opcode() == HloOpcode::kTuple || + (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && + ShapeUtil::IsTuple(use.instruction->shape()))) { + ShapeIndex use_output_index{use.operand_number}; + for (int64 i : use.operand_index) { + use_output_index.push_back(i); + } + if (OutputTypeAfterChange(use.instruction, use_output_index) == + BF16) { + continue; + } + } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) { + ShapeIndex use_output_index; + for (int64 i = 1; i < use.operand_index.size(); ++i) { + use_output_index.push_back(use.operand_index[i]); + } + if (OutputTypeAfterChange(use.instruction, use_output_index) == + BF16) { + continue; + } + } else { + if (OutputTypeAfterChange(use.instruction, use.operand_index) == + BF16) { + continue; + } + } } return false; } @@ -368,6 +385,7 @@ bool BFloat16Propagation::InstructionIsCandidateForBF16Output( if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && hlo->opcode() != HloOpcode::kTuple && hlo->opcode() != HloOpcode::kGetTupleElement && + hlo->opcode() != HloOpcode::kDomain && hlo->shape().element_type() != BF16) { for (int64 i = 0; i < hlo->operand_count(); ++i) { if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, @@ -559,7 +577,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { - std::list computations_topological_order = + const auto& computations_topological_order = module->MakeComputationPostOrder(); tensorflow::gtl::FlatSet resolved; for (auto comp_it = computations_topological_order.rbegin(); @@ -631,7 +649,7 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { subshape, converted_outputs.element(parent_index), output_index.back())); } - if (ShapeUtil::IsTuple(subshape)) { + if (!ShapeUtil::IsArray(subshape)) { continue; } if (!ShapeUtil::Compatible( @@ -742,7 +760,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); - std::list computations_topological_order = + const auto& computations_topological_order = module->MakeComputationPostOrder(); // The first step is a forward pass (parameters to root), where we determine // the potential candidate instructions to use bfloat16 in the outputs that @@ -784,9 +802,8 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { // Apply the changes in changes_to_bf16_. for (auto& change : changes_to_bf16_) { - auto shape = change.first->mutable_shape(); - for (const auto& index : change.second) { - auto subshape = ShapeUtil::GetMutableSubshape(shape, index); + for (const auto& entry : change.second) { + auto subshape = entry.first; CHECK_EQ(subshape->element_type(), F32); subshape->set_element_type(BF16); changed_ = true; @@ -815,8 +832,8 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { PrimitiveType BFloat16Propagation::OutputTypeAfterChange( HloInstruction* hlo, const ShapeIndex& index) const { - PrimitiveType type_on_hlo = - ShapeUtil::GetSubshape(hlo->shape(), index).element_type(); + Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index); + const PrimitiveType type_on_hlo = subshape->element_type(); if (type_on_hlo != F32) { return type_on_hlo; } @@ -824,7 +841,7 @@ PrimitiveType BFloat16Propagation::OutputTypeAfterChange( if (it == changes_to_bf16_.end()) { return type_on_hlo; } - return ContainsKey(it->second, index) ? BF16 : F32; + return ContainsKey(it->second, subshape) ? BF16 : F32; } PrimitiveType BFloat16Propagation::ValueTypeAfterChange( @@ -838,14 +855,16 @@ void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet( HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) { if (target_type == BF16) { auto& entry = changes_to_bf16_[hlo]; - entry.insert(index); + entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index), + index); } else { CHECK_EQ(target_type, F32); auto it = changes_to_bf16_.find(hlo); if (it == changes_to_bf16_.end()) { return; } - it->second.erase(index); + it->second.erase( + ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index)); } } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index de0355ddfca127753f90d1899b424a8e77c9b291..02b8cad089dd8465b7af5c1014e37b77ded6949d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -194,17 +194,11 @@ class BFloat16Propagation : public HloPassInterface { // are subject to further adjustment, then finally applied to the HLOs. This // avoids setting changed_ to true but all changes are reverted during // adjustment. - struct IndexHasher { - int64 operator()(const ShapeIndex& index) const { - int64 hash = 0; - for (int64 i : index) { - hash = tensorflow::Hash64Combine(hash, std::hash()(i)); - } - return hash; - } - }; + // + // For each HloInstruction, changes_to_bf16_ stores the affected buffers in + // the output as a map from in-place pointers to subshapes to shape indices. tensorflow::gtl::FlatMap> + tensorflow::gtl::FlatMap> changes_to_bf16_; // Whether the last processed HLO module has been changed by this pass. diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 5e1499ee6b6ef397f95f7ed29e808d530777bd07..2124b302cccaca7f87dc4f3274233509d6a6161f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -150,11 +150,11 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - dot->operand(0)->literal(), - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)))); + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)), + dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - dot->operand(1)->literal(), - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)))); + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)), + dot->operand(1)->literal())); } // Tests that BF16 can be propagated through nested tuples. @@ -434,7 +434,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { HloInstruction* tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({param, add1})); HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary( - tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); HloInstruction* gte0 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, sel, 0)); HloInstruction* gte1 = builder.AddInstruction( @@ -742,4 +742,89 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { EXPECT_EQ(add1->shape().element_type(), BF16); } +TEST_F(BFloat16PropagationTest, TupleDomain) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* a_trans = + builder.AddInstruction(HloInstruction::CreateTranspose(shape, a, {0, 1})); + HloInstruction* b_trans = + builder.AddInstruction(HloInstruction::CreateTranspose(shape, b, {0, 1})); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({a_trans, b_trans})); + HloInstruction* domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + HloInstruction* a_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 0)); + HloInstruction* b_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 1)); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), root); + + // test BF16 propagated through domain + EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 0).element_type(), + BF16); + EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 1).element_type(), + BF16); + + EXPECT_TRUE(OutputsBF16(a_trans)); + EXPECT_TRUE(OutputsBF16(b_trans)); + EXPECT_TRUE(OutputsBF16(a_gte)); + EXPECT_TRUE(OutputsBF16(b_gte)); + EXPECT_FALSE(OutputsBF16(a)); + EXPECT_FALSE(OutputsBF16(b)); +} + +// Tests that bf16 is not propagated through a domain in case its input cannot +// be propagated. In the case below the input of the domain is the parameter +// tuple which cannot be propagated, so the domain instruction is not propagated +// either. +TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + HloInstruction* domain = builder.AddInstruction( + HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr)); + HloInstruction* a_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 0)); + HloInstruction* b_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 1)); + HloInstruction* a_trans = builder.AddInstruction( + HloInstruction::CreateTranspose(shape, a_gte, {0, 1})); + HloInstruction* b_trans = builder.AddInstruction( + HloInstruction::CreateTranspose(shape, b_gte, {0, 1})); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(a_trans)); + EXPECT_TRUE(OutputsBF16(b_trans)); + EXPECT_FALSE(OutputsBF16(a_gte)); + EXPECT_FALSE(OutputsBF16(b_gte)); + EXPECT_FALSE(OutputsBF16(domain)); + EXPECT_FALSE(OutputsBF16(param)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 07b4b14b5ec1bdbc01345091105df69368b0b2fb..23645346e6f491beb5171cc839c013ce5f83d789 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -25,6 +25,7 @@ bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo, case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: case HloOpcode::kWhile: @@ -43,6 +44,7 @@ bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: case HloOpcode::kWhile: @@ -81,6 +83,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -92,11 +95,15 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kTranspose: case HloOpcode::kTuple: return true; + case HloOpcode::kBitcast: + return hlo.shape().element_type() == + hlo.operand(0)->shape().element_type(); case HloOpcode::kDynamicSlice: return operand_index == 0; case HloOpcode::kDynamicUpdateSlice: return operand_index == 0 || operand_index == 1; case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: return operand_index == 1 || operand_index == 2; default: break; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index c0b8bf903923a327fb1378eafb51a7d493d5e62d..afe4b2e1425f9e84320ffd5f08beceaac8168c22 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -135,6 +135,7 @@ Status GatherComputationsByAllocationType( worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. break; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: @@ -632,7 +633,7 @@ Status BufferAssignment::ComputeSummaryStats() { if (module_sequence.size() == module_->computation_count()) { TF_ASSIGN_OR_RETURN( const int64 min_size, - MinimumMemoryForSequence(module_sequence, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index a4fb0eefaca094898ed9acad8062484d1a36afe7..6958ee722a8189b8089ba2d8f53aca8174f6a593 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -33,12 +32,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" @@ -82,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { class BufferAssignmentTest : public HloTestBase { protected: - BufferAssignmentTest() : computation_tracker_() {} + BufferAssignmentTest() {} ~BufferAssignmentTest() override {} std::unique_ptr RunBufferAssignment(HloModule* module, @@ -252,9 +251,6 @@ class BufferAssignmentTest : public HloTestBase { return total_size; } - // Computation tracker for nested computations. - ComputationTracker computation_tracker_; - // Shapes for use in the examples. Shape s32_ = ShapeUtil::MakeShape(xla::S32, {}); Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {}); @@ -375,11 +371,11 @@ TEST_F(BufferAssignmentTest, Basic) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -422,11 +418,11 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { // share anything. auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -481,11 +477,11 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { // have the color 0, which allows the mul and add to share buffers. auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -551,11 +547,11 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { // auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -605,7 +601,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // Creates the main kernel and verifies instruction counts. auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation)); module->AddEntryComputation(builder.Build()); @@ -658,7 +654,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0)); auto exp2 = builder.AddInstruction( @@ -822,7 +818,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0)); auto tanh = builder.AddInstruction( @@ -1369,8 +1365,9 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { HloInstruction::CreateParameter(1, tuple_shape, "param1")); auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter( 2, ShapeUtil::MakeShape(PRED, {}), "param1")); - auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1)); + auto select = builder.AddInstruction( + HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect, + pred_param, tuple_param0, tuple_param1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1500,11 +1497,11 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -1540,7 +1537,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { // be {%rev, %neg, %concat}. This occurs right at the concat itself. auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto log = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param)); auto rev = builder.AddInstruction( @@ -1677,7 +1674,7 @@ class WhileBufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { auto sequence = - CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, xla::MakeUnique(module, sequence), ByteSizeOf, @@ -1797,7 +1794,7 @@ ENTRY %test_module { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. @@ -1878,11 +1875,15 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); - auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto infeed = + builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, "")); + auto infeed_data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r0s32, infeed, 0)); auto cond0 = module->AddEmbeddedComputation(build_cond()); auto body0 = module->AddEmbeddedComputation(build_body()); auto while0 = builder.AddInstruction( - HloInstruction::CreateWhile(r0s32, cond0, body0, infeed)); + HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data)); auto cond1 = module->AddEmbeddedComputation(build_cond()); auto body1 = module->AddEmbeddedComputation(build_body()); @@ -1913,8 +1914,8 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. SequentialHloOrdering::HloModuleSequence sequence; - sequence[module->entry_computation()] = {infeed, while0, while1, zero, - add, while2, tuple}; + sequence[module->entry_computation()] = { + token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; TF_ASSERT_OK_AND_ASSIGN( auto assignment, BufferAssigner::Run( @@ -2107,7 +2108,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module.get()); auto sequence = - CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index f623aef67a4f98b447a9a15634a78deb60cfe6f1..7833ebe73ba5d2412101eede1b584ce86df084e8 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -327,11 +327,12 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param)); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto recv = builder.AddInstruction( - HloInstruction::CreateRecv(vec_, /*channel_id=*/0)); + HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); auto send = builder.AddInstruction( - HloInstruction::CreateSend(recv_done, /*channel_id=*/1)); + HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index a8053d15e124319c5c898f0034b9aaa95a007a89..a23427f00ccd88bb0fe1d973a667f80ca54b14cd 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -57,6 +57,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 738d00881dd057fc13c115006c15e8f5b6d14a1d..924348c870b9ca3d86af560a0c8359af7220427e 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -148,14 +148,16 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { HloComputation::Builder outfeeder(TestName() + ".outfeeder"); auto value = outfeeder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + auto token = outfeeder.AddInstruction(HloInstruction::CreateAfterAll({})); outfeeder.AddInstruction( - HloInstruction::CreateOutfeed(f32, value, /*outfeed_config=*/"")); + HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/"")); auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build()); HloComputation::Builder outer(TestName() + ".outer"); outer.AddInstruction(HloInstruction::CreateCall( - ShapeUtil::MakeNil(), /*operands=*/{}, outfeed_computation)); + outfeed_computation->root_instruction()->shape(), /*operands=*/{}, + outfeed_computation)); module->AddEntryComputation(outer.Build()); diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index c7763f2ca3e68490cd0cd9b4ba4d7bd180134080..fac0afd672ff3ed083aacf778dd9c4f90a2ee870 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -19,9 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc deleted file mode 100644 index b16907da9e9c909d2639f83895db27d724a84a7b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/compilation_cache.h" - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -std::shared_ptr CompilationCache::Insert( - std::unique_ptr executable, - const HloModuleConfig& module_config) { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = - BuildKey(executable->entry_computation_handle(), module_config); - VLOG(2) << "inserting cache key: " << key; - if (cache_.count(key) == 0) { - cache_.emplace(key, std::move(executable)); - } else { - // Executable already exists in the cache. This can happen if two Execute - // calls for a new computation are received simultaneously by the - // service. In this case, we discard the Executable given as a parameter and - // return what is in the cache. This is necessary because the service relies - // on the cache to keep ownership of the Executable. We only want to store - // one Executable for a given computation version and we can't discard the - // executable which is in the cache because it may be in use. - executable.reset(); - } - return cache_.at(key); -} - -std::shared_ptr CompilationCache::LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = BuildKey(versioned_handle, module_config); - VLOG(2) << "looking up cache key: " << key; - if (cache_.count(key) == 0) { - VLOG(2) << "cache key not found: " << key; - return nullptr; - } else { - std::shared_ptr result = cache_.at(key); - VLOG(2) << "hit executable with module config: " - << result->module_config().compilation_cache_key(); - return result; - } -} - -CompilationCache::CacheKey CompilationCache::BuildKey( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - // The computation shape is represented entirely by its ProgramShape member, - // so just serialize the proto as part of the key. - return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::", - versioned_handle.version, "::", - module_config.compilation_cache_key()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h deleted file mode 100644 index 09989726ae6629aa65cb1dd84c16408a75019fa5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" - -namespace xla { - -// A cache which stores Executables indexed by computation handle and version. -class CompilationCache { - public: - CompilationCache() {} - - // Insert the given Executable into the cache. Return a bare Executable - // pointer for the caller to use. Note: the returned pointer will *not* be the - // same as the given unique pointer if the computation already exists in the - // cache. See comments in the .cc implementation for details of this case. - // - // module_config is provided by the caller, instead of being taken from the - // executable, so that we can insert keys into the compilation cache that are - // devoid of layout (where XLA gets to choose what layout to compile). - // - // A shared_ptr is returned so the caller can keep the Executable from being - // destructed in the event that the Executable is evicted from the - // computation cache (and the cache's shared_ptr to the Executable is - // destructed). - std::shared_ptr Insert(std::unique_ptr executable, - const HloModuleConfig& module_config); - - // Lookup the Executable for the specified versioned computation in the cache. - // Return a shared_ptr to the Executable if it exists in the cache. Return - // nullptr otherwise. - std::shared_ptr LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - - protected: - mutable tensorflow::mutex mutex_; - - // Map from versioned handle with program layout to Executable built - // for that computation version and program layout. - using CacheKey = string; - - CacheKey BuildKey(const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - std::map> cache_ GUARDED_BY(mutex_); - - private: - TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index d39fd7307ae1b5bd0c431f98c413011ca081050b..7426672a7a2a9102bd5ea98bd51092982e1e09b4 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -64,7 +63,8 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, StatusOr>> CompileOnlyService::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { + const AotCompilationOptions& options, + std::unique_ptr* metadata) { std::vector> hlo_modules; for (const AotXlaComputationInstance& instance : computations) { TF_RET_CHECK(instance.computation.has_program_shape()); @@ -101,59 +101,8 @@ CompileOnlyService::CompileAheadOfTime( hlo_modules.push_back(std::move(hlo_module)); } - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); -} - -StatusOr>> -CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector> hlo_modules; - for (const AotComputationInstance& instance : computations) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(instance.computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - const DebugOptions& debug_options = options.debug_options(); - - // Dump computation proto state if flag is set. - const string& directory_path = debug_options.xla_dump_computations_to(); - if (!directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - string filename = tensorflow::strings::StrCat( - "computation_", versioned_handle.handle.handle(), "__", - session_module->entry().name(), "__version_", - versioned_handle.version); - const string& per_host_path = tensorflow::io::JoinPath( - directory_path, tensorflow::port::Hostname()); - - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(per_host_path, filename, - *session_module)); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - ExecutionOptions execution_options; - *execution_options.mutable_debug_options() = debug_options; - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, user_computation)); - - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule( - versioned_handle, *module_config, - /*include_unreachable_instructions=*/true)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); - hlo_modules.push_back(std::move(hlo_module)); - } - - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); + return compiler_->CompileAheadOfTime(std::move(hlo_modules), options, + metadata); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 7f2ce0e8974c01b09664235d7b9d19555b2705a3..1ac950bdd66bd034dfdafa8598ec506221e99c2f 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -38,24 +38,7 @@ class CompileOnlyService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - ComputationHandle computation; - std::vector argument_layouts; - const Shape* result_layout = nullptr; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. See - // |CompileOnlyClient::CompileAheadOfTime| for additional details. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& Options); - // A description of a xla computation to compile using CompileAheadOfTime. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { HloModuleProto computation; std::vector argument_layouts; @@ -65,31 +48,21 @@ class CompileOnlyService : public Service { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. See // |CompileOnlyClient::CompileAheadOfTime| for additional details. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options); - // Override Service methods that require or imply the existence of an - // execute backend. Note that this does not include TransferToClient, as - // computing constants produces global data that we may wish to transfer. - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options, + std::unique_ptr* metadata); + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 8b01a6c4b5004d03e6e7d23b99b923fdcdeaff99..6b3b9820f09803c8a04504e6c35c22de51abf04b 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -28,6 +28,34 @@ namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); +std::vector> +Compiler::ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const { + CHECK(executor != nullptr); + return {}; +} + +std::unique_ptr +Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo, + se::StreamExecutor* executor) const { + CHECK(executor != nullptr); + return nullptr; +} + +// Define a default version where metadata is not used. +StatusOr>> +Compiler::CompileAheadOfTime( + std::vector> modules, + const AotCompilationOptions& options, + std::unique_ptr* metadata) { + if (metadata != nullptr) { + return Unimplemented( + "Populating AotCompilationMetadata is not implemented on this " + "compiler."); + } + return CompileAheadOfTime(std::move(modules), options); +} + /* static */ std::map* Compiler::GetPlatformCompilerFactories() { static auto* r = new std::map; diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index a4b59d1ba9b24e3f886a7feb51181ae8f990951f..99abb9bae32b35652e84cddc7c38dbd97ecb5006 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -24,9 +24,11 @@ limitations under the License. #include #include #include +#include #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" @@ -34,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -91,6 +94,19 @@ class AotCompilationOptions { DebugOptions debug_options_; }; +// Abstract superclass describing metadata produced during ahead-of-time +// compilation. +class AotCompilationMetadata { + public: + AotCompilationMetadata(const AotCompilationMetadata&) = delete; + AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete; + + virtual ~AotCompilationMetadata() = default; + + protected: + AotCompilationMetadata() = default; +}; + // Abstract compiler interface that is subclassed for compilation on a // particular platform. // @@ -153,12 +169,39 @@ class Compiler { std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) = 0; + // Returns the backend configurations that the backend will consider for the + // given HLO. Returns no configurations if the backend does not support + // configurations for the given HLO. + // + // The stream executor is passed in to provide information about the hardware + // that the backend configurations would be targeting. + virtual std::vector> + ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const; + + // Returns the backend configuration that the backend chooses by default for + // the given HLO. Returns no configuration if the backend does not support + // configurations for the given HLO. + // + // The stream executor is passed in to provide information about the hardware + // that the backend configurations would be targeting. + virtual std::unique_ptr + ComputeDefaultBackendConfig(const HloInstruction& hlo, + se::StreamExecutor* executor) const; + // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. virtual StatusOr>> CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& options) = 0; + // Similar to CompileAheadOfTime above but AotCompilationMetadata + // has an argument that can be populated during compilation. + virtual StatusOr>> + CompileAheadOfTime(std::vector> modules, + const AotCompilationOptions& options, + std::unique_ptr* metadata); + ///// // The Compiler class also serves as a point to register compiler objects // for the various platforms. diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index d2d4f14fcec35f5b51a2670a646154ce8bb9bfc1..cb61f3da39fb8eef69fd81066d87a1da91a62935 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -23,12 +23,15 @@ limitations under the License. namespace xla { -ComputationLayout::ComputationLayout(const ProgramShape& program_shape) +ComputationLayout::ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts) : result_layout_(program_shape.result()) { for (auto& shape : program_shape.parameters()) { parameter_layouts_.emplace_back(shape); } - SetToDefaultLayout(); + if (ignore_layouts) { + SetToDefaultLayout(); + } } void ComputationLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 80e102411c7885669947d89f378b1ec61e3e4e96..6975f387b4864bf28ea0ad23d7d4602b5b346e08 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -32,10 +32,20 @@ namespace xla { // mutable layouts. class ComputationLayout { public: + // Creates a new ComputationLayout with the given result layout. + explicit ComputationLayout(ShapeLayout result_layout) + : result_layout_(std::move(result_layout)) {} + // Constructs a ComputationLayout from a ProgramShape. The layouts of the // parameters and results are set to the default layout. Layouts in the - // ProgramShape are ignored. - explicit ComputationLayout(const ProgramShape& program_shape); + // ProgramShape are ignored if ignore_layouts is true. + explicit ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts = true); + + // Adds a new parameter layout to the computation layout. + void add_parameter_layout(ShapeLayout shape_layout) { + parameter_layouts_.push_back(std::move(shape_layout)); + } // Returns the layout of a particular parameter. const ShapeLayout& parameter_layout(int64 param_no) const { diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc deleted file mode 100644 index 70e25eebdb068db893e24aec0f72d09090ac7027..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.cc +++ /dev/null @@ -1,256 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/computation_tracker.h" - -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" - -using ::tensorflow::strings::Appendf; - -namespace xla { - -ComputationTracker::ComputationTracker() : next_computation_(1) {} - -ComputationHandle ComputationTracker::NewComputation( - const string& computation_name) { - tensorflow::mutex_lock lock(computation_mutex_); - ComputationHandle computation_handle; - int64 handle_value = next_computation_++; - computation_handle.set_handle(handle_value); - opaque_to_computation_[handle_value] = - MakeUnique(computation_name, computation_handle); - return computation_handle; -} - -StatusOr ComputationTracker::LoadSessionModule( - const SessionModule& session_module) { - tensorflow::mutex_lock lock(computation_mutex_); - - // For each embedded computation, create a new computation based on its - // serialized data, and place the mapping from the old computation handle to - // the new computation handle. - - // Build a mapping from old embedded computation handles to new computation - // handles. We build the ID mapping first since the embedded computations are - // in no particular order and may refer to each other. - std::map old_to_new; - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - if (!old_to_new.emplace(old_handle, AllocateHandle()).second) { - return InvalidArgument("Duplicate embedded computation handle %lld", - old_handle); - } - } - - // Create a new computation from each serialized embedded computation. - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - const ComputationHandle& new_handle = old_to_new[old_handle]; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - computation, new_handle, old_to_new)); - } - - // Finally, place the entry computation in the tracker with all of the - // remappings populated from the above. - const int64 old_handle = session_module.entry().computation_handle().handle(); - TF_ASSIGN_OR_RETURN( - old_to_new[old_handle], - LoadSessionComputation(session_module.entry(), &old_to_new)); - return old_to_new[old_handle]; -} - -StatusOr> -ComputationTracker::SnapshotComputation(const ComputationHandle& computation) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation)); - const VersionedComputationHandle entry_versioned_handle = - user_computation->GetVersionedHandle(); - std::set visited; - std::list post_order; - { - tensorflow::mutex_lock lock(computation_mutex_); - ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order); - } - auto session_module = MakeUnique(); - *session_module->mutable_entry() = - Resolve(entry_versioned_handle.handle) - .ValueOrDie() - ->CloneSessionComputation(entry_versioned_handle.version); - for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) { - *session_module->add_embedded_computations() = - Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version); - } - return std::move(session_module); -} - -StatusOr ComputationTracker::Resolve( - const ComputationHandle& computation) const { - tensorflow::mutex_lock lock(computation_mutex_); - return ResolveInternal(computation); -} - -ComputationHandle ComputationTracker::AllocateHandle() { - int64 handle_value = next_computation_++; - ComputationHandle result; - result.set_handle(handle_value); - return result; -} - -StatusOr ComputationTracker::LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) { - TF_RET_CHECK(old_to_new != nullptr); - const ComputationHandle new_handle = AllocateHandle(); - (*old_to_new)[session_computation.computation_handle().handle()] = new_handle; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - session_computation, new_handle, *old_to_new)); - return new_handle; -} - -StatusOr ComputationTracker::ResolveInternal( - const ComputationHandle& computation) const { - auto it = opaque_to_computation_.find(computation.handle()); - if (it == opaque_to_computation_.end()) { - return NotFound("computation handle not found: %lld", computation.handle()); - } - UserComputation* user_computation = it->second.get(); - return user_computation; -} - -void ComputationTracker::ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const { - if (visited->count(versioned_handle) > 0) { - CHECK_EQ(1, visited->count(versioned_handle)); - return; - } - - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - std::vector embedded_handles = - computation->GetEmbeddedComputations(versioned_handle.version); - - for (const auto& embedded_handle : embedded_handles) { - ComputeComputationPostOrder(embedded_handle, visited, post_order); - } - - visited->insert(versioned_handle); - post_order->push_back(versioned_handle); -} - -StatusOr> ComputationTracker::BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(computation_mutex_); - - VLOG(1) << "BuildHloModule(" << entry_handle - << ", include_unreachable_instructions=" - << include_unreachable_instructions << ")"; - XLA_VLOG_LINES(1, ToStringInternal()); - - TF_ASSIGN_OR_RETURN(UserComputation * entry_computation, - ResolveInternal(entry_handle.handle)); - - // Build a topological sort of the entry and any embedded computations as a - // list. The root of the computation will be the last element in the list. - std::set visited; - std::list post_order; - ComputeComputationPostOrder(entry_handle, &visited, &post_order); - - // Map from ComputationHandle value and computation version to HloComputation. - std::map hlo_computations; - - // The resolver lambda resolves VersionedHandles to embedded - // HloComputation*. This is required by UserComputation::BuildHloComputation - // when lowering calling operations (map, reduce etc). - auto resolver = [&hlo_computations]( - const VersionedComputationHandle& versioned_handle) -> HloComputation* { - CHECK_GT(hlo_computations.count(versioned_handle), 0); - return hlo_computations.at(versioned_handle); - }; - - // Print the post-order list for this entry computation. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Visiting UserComputations in post order:"; - for (const VersionedComputationHandle& versioned_handle : post_order) { - VLOG(2) << " " << versioned_handle; - } - } - - string module_name = - tensorflow::strings::StrCat(entry_computation->name(), "_module"); - auto module = MakeUnique(module_name, entry_handle, config); - for (auto versioned_handle : post_order) { - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - computation->BuildHloComputation(versioned_handle.version, resolver, - config.debug_options(), - include_unreachable_instructions)); - - // Add the newly created computation to VersionedHandle-to-HloComputation - // map. - DCHECK_EQ(0, hlo_computations.count(versioned_handle)); - hlo_computations[versioned_handle] = hlo_computation.get(); - - if (computation == entry_computation) { - module->AddEntryComputation(std::move(hlo_computation)); - } else { - module->AddEmbeddedComputation(std::move(hlo_computation)); - } - } - - return std::move(module); -} - -string ComputationTracker::ToString() const { - tensorflow::mutex_lock lock(computation_mutex_); - return ToStringInternal(); -} - -string ComputationTracker::ToStringInternal() const { - string out; - Appendf(&out, "ComputationTracker(%p):\n", this); - for (const auto& handle_computation : opaque_to_computation_) { - int64 handle = handle_computation.first; - const std::unique_ptr& computation = - handle_computation.second; - Appendf(&out, " %4lld : %s \"%s\"\n", handle, - computation->GetVersionedHandle().ToString().c_str(), - computation->name().c_str()); - } - return out; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h deleted file mode 100644 index d42d66adefe7faa2751da4cd80b392a38917ce70..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.h +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Tracks computations for the XLA service; computations can be registered -// with a UserComputation instance and can be resolved from a handle for later -// use. -// -// This class is also capable of serializing/deserializing computations that it -// tracks (and to serialize properly you need to serialize all referred-to -// computations as well). -class ComputationTracker { - public: - ComputationTracker(); - - // Creates a new UserComputation object and returns the corresponding - // ComputationHandle for it. - // - // Precondition: user_computation is not already present in the map. - ComputationHandle NewComputation(const string& computation_name); - - // Restores session data for a computation that has been serialized, and - // allocates a new computation handle for it. - StatusOr LoadSessionModule( - const SessionModule& session_module); - - // Snapshots a computation (referenced by the provided handle) at its latest - // version, returning a module where it is the entry, and any referred-to - // computations are entrained as "embedded" (non-entry) computations. - StatusOr> SnapshotComputation( - const ComputationHandle& computation); - - // Resolves a ComputationHandle to a UserComputation that is present in the - // map. - StatusOr Resolve( - const ComputationHandle& computation) const; - - // Builds an HLO module using the specified computation as the entry. The - // module will include the entry computation as well as all computations which - // are called directly or indirectly from the entry computation via operations - // like "map". config is the HLO module configuration to use for the - // constructed module. - // If include_unreachable_instructions is true, then instructions - // which are not reachable from the root are lowered into HloInstructions - // including unreachable parameters. This ensures the entry HloComputation has - // the same program shape (ProgramShape) as the entry UserComputation. - StatusOr> BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions = true) const; - - string ToString() const; - - private: - // Bumps the next_computation_ number and returns the allocated number wrapped - // in a ComputationHandle. - ComputationHandle AllocateHandle() - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Loads a session computation into a UserComputation, registers it, and - // returns the computation handle of the registered computation. If old_to_new - // is provided, it is used for remapping references to computations present in - // session_computation. - // - // old_to_new will be updated with the mapping from session_computation's old - // handle to the returned handle value, and may not be null. - StatusOr LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Internal implementation of Resolve method which requires, but does not - // acquire the mutex. - StatusOr ResolveInternal( - const ComputationHandle& computation) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Builds a post order sort of a computation ("entry") and all of its embedded - // computations including all transitively embedded computations. An embedded - // computation (the callee) will always appear in the sort before the - // computation which calls the embedded computation (the caller). Necessarily, - // the entry computation is the last element in the sort. visited and - // post_order should be empty when calling. post_order contains the post order - // sort when the function return. - void ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Guards the computation mapping. Marked mutable so that the Resolve method - // can remain const; Resolve does't really modify the tracker in any way, but - // it has to lock the mutex for safety. - mutable tensorflow::mutex computation_mutex_; - - // The next sequence number to assign to a computation, guarded by the same - // mutex as the mapping as they'll be mutated at the same time. - int64 next_computation_ GUARDED_BY(computation_mutex_); - - // Mapping from ComputationHandle value to the corresponding registered - // UserComputation object. - std::map> opaque_to_computation_ - GUARDED_BY(computation_mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 868348547d9f5cbdc7576c7fc0697d72c3a3e557..68f6ffc6b7012b7674b8a046df71c7aed7a386fa 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -119,10 +119,12 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* true_computation = conditional->true_computation(); + auto* token = + true_computation->AddInstruction(HloInstruction::CreateAfterAll({})); auto* send = true_computation->AddInstruction(HloInstruction::CreateSend( true_computation->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(true))), - /*channel_id=*/0)); + token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } @@ -133,8 +135,10 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* true_computation = conditional->true_computation(); + auto* token = + true_computation->AddInstruction(HloInstruction::CreateAfterAll({})); auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( - ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); + ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } @@ -144,8 +148,10 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* false_computation = conditional->false_computation(); - false_computation->AddInstruction( - HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + auto token = + false_computation->AddInstruction(HloInstruction::CreateAfterAll({})); + false_computation->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 33d8338809d4e8c7c4774f062c3dda5494543ca6..52e66b3e77097dfdb462ed4a953581b9d316064b 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -472,6 +472,10 @@ class CopyRemover { // between copies added around aliased operations (kWhile) guarantees // this strict order. for (const HloValue* value_a : buffer.values()) { + if (ShapeUtil::IsToken(value_a->shape())) { + // Token values have no representation and cannot interfere. + continue; + } for (const HloValue* value_b : buffer.values()) { if (value_a != value_b) { DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, @@ -613,7 +617,10 @@ class CopyRemover { VLOG(2) << copy->name() << " is not removable"; return false; } - + if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) { + VLOG(2) << copy->name() << " is not removable (shape mismatch)"; + return false; + } const CopyNodes& copy_node = copy_map_.at(copy); ValueNode* src = copy_node.src; ValueNode* dest = copy_node.dest; @@ -947,28 +954,6 @@ class CopyRemover { BufferValueTracker buffer_value_tracker_; }; -// Try to remove as many copies from the module as possible without introducing -// live range interference. Copy instructions (identified by their unique id) in -// the set copies_to_exclude are not considered for removal. -Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); - CopyRemover copy_remover(*alias_analysis, ordering, module); - XLA_VLOG_LINES(3, copy_remover.ToString()); - - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - !ContainsKey(copies_to_exclude, instruction->unique_id())) { - TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); - } - } - } - return Status::OK(); -} - // Add copies to address special constraints on the roots of computations not // related to live range interference: // @@ -1065,13 +1050,23 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { HloInstruction* instruction = pair.first; const ShapeTree& indices_to_copy = pair.second; + ShapeTree copies_added(indices_to_copy.shape()); std::vector users = instruction->users(); TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, instruction->parent()->DeepCopyInstruction( - instruction, &indices_to_copy)); + instruction, &indices_to_copy, &copies_added)); for (HloInstruction* user : users) { TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); } + // Special case copies are not eligible for later copy elision passes. + indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) { + if (has_copy) { + HloInstruction* copy = *copies_added.mutable_element(index); + if (copy != nullptr) { + copy->SetCopyElisionAllowed(false); + } + } + }); if (instruction == instruction->parent()->root_instruction()) { instruction->parent()->set_root_instruction(deep_copy); } @@ -1097,6 +1092,29 @@ void MaybeDumpModule(const string& message, const HloModule& module) { } // namespace +Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, + HloModule* module) { + MaybeDumpModule("after adding copies to resolve interference", *module); + + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); + CopyRemover copy_remover(*alias_analysis, ordering, module); + XLA_VLOG_LINES(3, copy_remover.ToString()); + + std::unique_ptr call_graph = CallGraph::Build(module); + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy && + instruction->CopyElisionAllowed()) { + TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + } + } + } + MaybeDumpModule("after removing unnecessary copies", *module); + + return Status::OK(); +} + StatusOr CopyInsertion::Run(HloModule* module) { // Copy insertion is performed in three steps: // @@ -1130,16 +1148,13 @@ StatusOr CopyInsertion::Run(HloModule* module) { "Call graph must be flattened before copy insertion."); } - // Gather Ids of existing kCopy instructions in the module. We avoid removing - // these copies (except via DCE in TupleSimplifier) because they may have been - // added for reasons not considered by copy insertion (eg, layout assignment). - // Instruction id is used instead of HloInstruction* because the pointer - // values may be recycled. - tensorflow::gtl::FlatSet existing_copies; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - existing_copies.insert(instruction->unique_id()); + int64 num_existing_copies = 0; + if (VLOG_IS_ON(1)) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + ++num_existing_copies; + } } } } @@ -1158,13 +1173,8 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); - MaybeDumpModule("after adding copies to resolve interference", *module); - DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR( - RemoveUnnecessaryCopies(ordering, existing_copies, module)); - - MaybeDumpModule("after removing unnecessary copies", *module); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); @@ -1185,7 +1195,7 @@ StatusOr CopyInsertion::Run(HloModule* module) { } } } - VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size(); + VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies; VLOG(1) << "Num copies after copy-insertion: " << num_total_copies; } diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 65e3d31e347e2cb249a072e7d06ca10c55401748..c5573f76f31681ae9988039e9000636876478113 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -48,10 +47,25 @@ class CopyInsertion : public HloPassInterface { public: tensorflow::StringPiece name() const override { return "copy-insertion"; } + // fusion_can_share_buffer: backend specific function that decides whether a + // fusion can share buffer with its operand. + // + // TODO(b/80315712): Find a better way to tell whether a fusion can share + // buffer. + CopyInsertion(const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer = nullptr) + : fusion_can_share_buffer_(fusion_can_share_buffer) {} + // Run the pass on the given module. Returns whether the module was changed // (copies were inserted). StatusOr Run(HloModule* module) override; + // Try to remove as many copies from the module as possible without + // introducing live range interference. Only copy instructions that are + // eligible for copy elision are considered for removal. + Status RemoveUnnecessaryCopies(const HloOrdering& ordering, + HloModule* module); + // The CPU and GPU backend need additional copies added due to deficiencies in // buffer assignment. Specifically, copies are needed for constants live-out // of computations, and for values which are live-in and live-out of the same @@ -62,8 +76,14 @@ class CopyInsertion : public HloPassInterface { // // TODO(b/62548313): Remove this when buffer assignment is module-scoped. static StatusOr AddCopiesForBufferAssignment(HloModule* module); + + private: + // Backend specific function that decides whether a fusion can share buffer + // with its operand. + HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_; }; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 153f062d015e49db11c4c9ae0a2a61e76c020f02..105d117caccc21e6673261d44a59be30c28b9039 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -125,21 +125,27 @@ TEST_F(CopyInsertionTest, SingleConstant) { } TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { - // Verify that an kCopy instructions which exist in the pass before + // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); - HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}))); + auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape()); + Layout reversed_layout = + LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major); + Shape copy_shape = constant->shape(); + *copy_shape.mutable_layout() = reversed_layout; + HloInstruction* copy_1 = builder.AddInstruction( + HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); + HloInstruction* copy_2 = builder.AddInstruction( + HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); - HloInstruction* add_copy = builder.AddInstruction( - HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add)); + builder.AddInstruction( + HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add)); module->AddEntryComputation(builder.Build()); @@ -147,12 +153,11 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 3); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())))); + EXPECT_EQ(module->entry_computation()->root_instruction(), add); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { @@ -206,7 +211,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction* pred = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1)); EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); @@ -377,7 +382,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction* pred = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(false))); HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); HloInstruction* gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); @@ -686,7 +691,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto pred = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(false))); auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( - nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2)); + nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, data_init, &builder); @@ -1595,6 +1600,45 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { EXPECT_THAT(condition->root_instruction(), op::Constant()); } +TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) { + string module_string = R"( +HloModule TokensShouldNotBeCopied + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %TokensShouldNotBeCopied () -> s32[] { + %one = s32[] constant(1) + %negative_one = s32[] negate(%one) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + HloRunner::CreateModuleFromString( + module_string, GetDebugOptionsForTest())); + InsertCopies(module.get()); + + // There should be no copies added because tokens should not be copied. + EXPECT_EQ(CountCopies(*module), 0); +} + std::unique_ptr MakeTrivialCondition(const Shape& shape) { auto builder = HloComputation::Builder("trivial_condition"); builder.AddInstruction( @@ -1636,8 +1680,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_SequentialWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1677,8 +1720,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_ParallelWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1750,8 +1792,7 @@ void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { std::vector tuple_params(num_tuple_inputs); for (int i = 0; i < num_iters; ++i) { auto builder = HloComputation::Builder("BM_ParallelWhiles"); - HloModule module("BM_ManyElementTuple", VersionedComputationHandle(), - config); + HloModule module("BM_ManyElementTuple", config); for (int j = 0; j < num_tuple_inputs; ++j) { tuple_params[j] = builder.AddInstruction( HloInstruction::CreateParameter(j, element_shape, "")); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index d718322ba0da5d74fe4527f0db9829f07db2c433..3479240610a197aeed0c0a07099239e1161b1352 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -53,29 +53,6 @@ cc_library( alwayslink = True, # Contains per-platform transfer manager registration ) -cc_library( - name = "external_constant_pool", - srcs = ["external_constant_pool.cc"], - hdrs = ["external_constant_pool.h"], - deps = [ - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "external_constant_pool_test", - srcs = ["external_constant_pool_test.cc"], - deps = [ - ":external_constant_pool", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "cpu_compiler", srcs = ["cpu_compiler.cc"], @@ -126,6 +103,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_scheduling", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:indexed_array_analysis", "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", @@ -150,7 +128,14 @@ cc_library( "@llvm//:target", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep "@llvm//:x86_disassembler", # fixdeps: keep - ], + ] + select({ + "//tensorflow:linux_ppc64le": [ + "@llvm//:powerpc_disassembler", + "@llvm//:powerpc_code_gen", + ], + "//conditions:default": [ + ], + }), alwayslink = True, # Contains compiler registration ) @@ -167,7 +152,6 @@ cc_library( ":cpu_runtime", ":custom_call_target_registry", ":disassembler", - ":external_constant_pool", ":orc_jit_memory_mapper", ":runtime_fp16", ":runtime_conv2d", @@ -177,6 +161,7 @@ cc_library( ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", + ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", "@llvm//:core", @@ -247,7 +232,6 @@ cc_library( ":cpu_options", ":cpu_runtime", ":dot_op_emitter", - ":external_constant_pool", ":ir_emission_utils", ":ir_function", ":parallel_loop_emitter", @@ -264,6 +248,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", @@ -515,7 +500,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], @@ -577,6 +561,22 @@ cc_library( ], ) +cc_library( + name = "runtime_single_threaded_fft", + srcs = [ + "runtime_fft_impl.h", + "runtime_single_threaded_fft.cc", + ], + hdrs = ["runtime_single_threaded_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_single_threaded_matmul", srcs = ["runtime_single_threaded_matmul.cc"], @@ -632,10 +632,10 @@ tf_cc_test( deps = [ ":cpu_instruction_fusion", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -689,9 +689,9 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -881,6 +881,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", ], @@ -941,7 +942,7 @@ tf_cc_test( ":ir_emission_utils", ":target_machine_features_fake", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index beeb826747d10306e07179f16c0ecb09ca629d5d..55962ba70d213939ccb49cad3bdd75395cc4eaa5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" @@ -263,12 +264,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/false); + /*rewrite_grad_op=*/true); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, /*enable_dot_strength_reduction=*/false); + pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. @@ -283,6 +284,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pass.AddPass(); pass.AddPass(); } + pipeline.AddPass(); pipeline.AddPass( [&target_machine_features]( const HloInstruction& dot, @@ -302,14 +304,19 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->device_entry_computation_layout(), &target_machine_features); + module->mutable_entry_computation_layout(), &target_machine_features); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }, - /*enable_dot_strength_reduction=*/false); - pipeline.AddPass(/*is_layout_sensitive=*/true); + { + auto& pass = pipeline.AddPass>( + "after layout assignement"); + pass.AddPass>( + /*is_layout_sensitive=*/true, + [](const Shape&, const Shape&) { return true; }, + /*enable_dot_strength_reduction=*/false); + pass.AddPass(); + pass.AddPass(/*is_layout_sensitive=*/true); + } pipeline.AddPass(BF16, F32); // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = @@ -547,8 +554,8 @@ StatusOr> CpuCompiler::RunBackend( // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction(), - DFSMemoryScheduler)); + ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -577,7 +584,7 @@ StatusOr> CpuCompiler::RunBackend( IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features, jit->external_constant_pool()); + &target_machine_features); for (auto embedded_computation : entry_computation->MakeEmbeddedComputationsList()) { @@ -727,7 +734,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); + ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -764,8 +771,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, IrEmitter ir_emitter(*module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features, - /*external_constant_pool=*/nullptr); + &target_machine_features); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc index d12fa6bb9ad2054bdc052c9d7b3729cc28e11f6d..8727c72b6e42517b1859e98ecadb41bbceed761c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace cpu { @@ -40,7 +40,7 @@ ENTRY DotOperation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloInstruction* dot = module->entry_computation()->root_instruction(); @@ -71,7 +71,7 @@ ENTRY ConvOperation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloInstruction* conv = module->entry_computation()->root_instruction(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index cf43b74c699ca8cbbef11a0abbaf4d69476f5d77..1093559892ddb9c238fd9c1f7e3d419ec7022776 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -206,8 +206,8 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( tensorflow::gtl::MutableArraySlice buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( - /*on_host_shape=*/host_result_shape(), - /*on_device_shape=*/host_result_shape(), run_options->allocator(), + /*on_host_shape=*/result_shape(), + /*on_device_shape=*/result_shape(), run_options->allocator(), stream->parent()->device_ordinal()); // Move OwningDeviceMemory values which contain the array(s) of the result diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 46fe060817b0264d90574b45a94cf1f6e5964593..750310c633286aa8f964c9ae5dcf847f2dc0557c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -172,7 +172,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -202,7 +202,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -233,7 +233,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -501,8 +501,8 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) { HloInstruction* exp = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); - builder.AddInstruction(HloInstruction::CreateMap( - shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{})); + builder.AddInstruction( + HloInstruction::CreateMap(shape, {exp}, CreateAdderToOne(module.get()))); module->AddEntryComputation(builder.Build()); @@ -525,8 +525,8 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) { HloInstruction* exp1 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1)); - builder.AddInstruction(HloInstruction::CreateMap( - shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{})); + builder.AddInstruction( + HloInstruction::CreateMap(shape, {exp0, exp1}, CreateMax(module.get()))); module->AddEntryComputation(builder.Build()); @@ -775,7 +775,7 @@ TEST_P(GatherLoopFusionTest, GatherLoopFusion) { string hlo_string = tensorflow::strings::StrCat( "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); RunFusionAndCheckOpcodesWereFused( module.get(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 53536a277cd513627a4cff20936110d68bb31c8a..3c4fe68b830d9602f009b318d4e51e9a04a27e09 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -29,7 +29,7 @@ namespace cpu { class CpuLayoutAssignment : public LayoutAssignment { public: explicit CpuLayoutAssignment( - const ComputationLayout& entry_computation_layout, + ComputationLayout* entry_computation_layout, const TargetMachineFeatures* target_machine_features) : LayoutAssignment(entry_computation_layout), target_machine_features_(*target_machine_features) {} diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index f6c93d36f72d681b2e7b7e5748ac9e6294fc8cb1..429fc7b78608da0e9cd794ac294851b326f5be24 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -54,7 +54,7 @@ class CpuLayoutAssignmentTest : public HloTestBase { [](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); - cpu::CpuLayoutAssignment layout_assignment(*entry_computation_layout, + cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout, &target_machine_features); EXPECT_IS_OK(layout_assignment.Run(module).status()); } @@ -321,7 +321,7 @@ static StatusOr RunDotOutputFusion( [](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); - cpu::CpuLayoutAssignment layout_assignment(computation_layout, + cpu::CpuLayoutAssignment layout_assignment(&computation_layout, &target_machine_features); TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, layout_assignment.Run(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index f9c51f243c47b8069500eca3c9c2929b17f04e62..3ed7876715f64191f6e652d2b5cb1673df9a1b94 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -16,12 +16,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; +const char* const kXlaEnableExperimentalLlvmIrGemm = + "xla_enable_experimental_llvm_ir_gemm"; +const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -54,6 +58,49 @@ tensorflow::gtl::optional LlvmIrGemvTilingFactor( return tensorflow::gtl::nullopt; } +bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; +} + +static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str, + tensorflow::StringPiece suffix) { + CHECK_GE(str.size(), suffix.size()); + CHECK_EQ(str.substr(str.size() - suffix.size()), suffix); + return str.substr(0, str.size() - suffix.size()); +} + +tensorflow::gtl::optional> LlvmIrGemmTileSize( + const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + auto it = extra_options_map.find(kLlvmIrGemmTileSize); + if (it == extra_options_map.end()) { + return tensorflow::gtl::nullopt; + } + + std::vector tile_components = + tensorflow::str_util::Split(it->second, ':'); + CHECK_EQ(tile_components.size(), 3); + + int64 tile_size_m; + int64 tile_size_k; + int64 tile_size_n_in_vector_width; + + CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m)); + CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k)); + + tensorflow::StringPiece tile_size_n_in_vector_width_str = + RemoveSuffix(tile_components[2], "*vectwidth"); + + CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str, + &tile_size_n_in_vector_width)); + + return std::tuple(tile_size_m, tile_size_k, + tile_size_n_in_vector_width); +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index be62ff3cc1af23408ca8a00f1372e7a998f160c6..429b9e16cbdd6f623919533582481f1640118081 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -26,8 +26,11 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); +bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); tensorflow::gtl::optional LlvmIrGemvTilingFactor( const HloModuleConfig& config); +tensorflow::gtl::optional> LlvmIrGemmTileSize( + const HloModuleConfig& config); } // namespace options } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 215405f6802cf1956ebec011da2fcd11b95c0c64..54c52bc08f9c53b8c6898689b18c4cb7f4bdcfd0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -51,6 +51,8 @@ extern const char* const kEigenConvF16SymbolName = extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedFftSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedFft"; extern const char* const kEigenSingleThreadedMatMulF16SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 1dce6efa5cd65e67ae73a2e2affe2d2d3c537508..aa0e96712302e806a389c6ad05a2c1b6634ef901 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -52,6 +52,7 @@ extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index d97802ee45d6add3c466577d7624d9ca74e2f380..b877b295814a7e13569a1837ed3e1787f2fc3f56 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -160,9 +160,8 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int32 size_32 = static_cast(size); CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32); - Status s = - TransferBufferToDevice(executor, /*size=*/size, - /*source=*/source, queued_buffer->device_memory()); + Status s = executor->SynchronousMemcpyH2D( + /*host_src=*/source, /*size=*/size, queued_buffer->device_memory()); if (!s.ok()) { queued_buffer->Done(s); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 5cdfc110affb85c19a5059800c1b1fc5ff614efe..58228180ca55ede50c8579bbd73cfdfffc07e208 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -42,17 +42,17 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { namespace { -// Loads a tile of values from a 2D tensor. -class TileLoader { +// Provides tiled access to an in-memory rank 2 array. +class MemoryTile { public: - // Constructs a TileLoader that will load a tile consisting of + // Constructs a MemoryTile that can operate on tiles consisting of // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at // `major_dim_offset` in the major dimension. The tile size along the minor // dimension is the vector size, and that is implicitly determined by `vsl`. - TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, llvm::Value* matrix, int64 matrix_size_along_minor_dim, llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl) { + : vsl_(vsl), ir_builder_(ir_builder) { pointers_.reserve(tile_size_along_major_dim); for (int64 i = 0; i < tile_size_along_major_dim; i++) { llvm::Value* total_offset = ir_builder->CreateMul( @@ -62,9 +62,10 @@ class TileLoader { } } - // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at - // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the - // minor dimension. + // Load a tile consisting of `tile_size_along_major_dim` vectors from position + // {major: `major_dim_offset`, minor: `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. std::vector LoadTile(llvm::Value* minor_dim_offset) const { std::vector result; result.reserve(pointers_.size()); @@ -74,11 +75,104 @@ class TileLoader { return result; } + // Stores `tile` to position {major: `major_dim_offset`, minor: + // `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + void StoreTile(tensorflow::gtl::ArraySlice tile, + llvm::Value* minor_dim_offset) const { + CHECK_EQ(tile.size(), pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); + } + } + + // Loads a tile of size [`tile_size_along_major_dim`, + // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, + // minor: `minor_dim_offset`} and then broadcasts each element into a vector + // of size vsl_.vector_size(). The (i,j)'th element of the return value is + // the (i,j)'th element in the tile broadcasted into an LLVM vector. + // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector> LoadBroadcastTile( + llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { + std::vector> result; + result.resize(pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + for (int64 j = 0; j < tile_size_along_middle_dim; j++) { + result[i].push_back(vsl_->LoadBroadcast( + pointers_[i], ir_builder_->CreateAdd(minor_dim_offset, + ir_builder_->getInt64(j)))); + } + } + return result; + } + private: VectorSupportLibrary* vsl_; + llvm::IRBuilder<>* ir_builder_; std::vector pointers_; }; +// The base class for the classes representing the GEMV emitter configurations. +// +// The IR emitted (modulo the LLVM values representing the input and output +// buffers) by the row major and column major GEMV emitters should be a function +// of their configuration. This is important because their configuration is +// used as a key to cache the generated IR. +class GemvConfig { + public: + // Mixin for convenience. + template + struct User { + public: + PrimitiveType scalar_type() const { + return derived().config().scalar_type(); + } + int64 tile_rows() const { return derived().config().tile_rows(); } + int64 tile_cols() const { return derived().config().tile_cols(); } + int64 m() const { return derived().config().m(); } + int64 k() const { return derived().config().k(); } + int64 has_addend() const { return derived().config().has_addend(); } + + private: + const T& derived() const { return *static_cast(this); } + }; + + PrimitiveType scalar_type() const { return scalar_type_; } + int64 tile_rows() const { return tile_rows_; } + int64 tile_cols() const { return tile_cols_; } + int64 m() const { return m_; } + int64 k() const { return k_; } + bool has_addend() const { return has_addend_; } + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_", + tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : ""); + } + + protected: + explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, bool has_addend) + : name_(std::move(name)), + scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + has_addend_(has_addend) {} + + private: + string name_; + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + bool has_addend_; +}; + // Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the // layout of the vector does not matter). This implementation uses a tiling // scheme to improve performance. @@ -140,38 +234,46 @@ class TileLoader { // TODO(sanjoy): We should investigate if using gather loads and scatter stores // can be used here have the same inner loop for both column-major and row-major // matrix-vector products. -class ColumnMajorMatrixVectorProductEmitter { +class ColumnMajorMatrixVectorProductEmitter + : public GemvConfig::User { public: - ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, - int64 tile_rows, int64 tile_cols, - int64 m, int64 k, llvm::Value* lhs, + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"col_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), + : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") { - CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast(tile_rows_))); + vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), + ir_builder_, "") { + CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); + CHECK(!has_addend() || addend != nullptr); } void Emit(); + const Config& config() const { return config_; } + private: void EmitOuterLoopBody(llvm::Value* column, int64 column_count, bool is_first_column); - TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/m_, + MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { + return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m(), /*major_dim_offset=*/column_start, /*tile_size_along_major_dim=*/column_count); } @@ -188,18 +290,14 @@ class ColumnMajorMatrixVectorProductEmitter { return result; } - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, int64 columns, bool is_first_column); void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column); - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; + Config config_; llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* addend_; @@ -211,26 +309,26 @@ class ColumnMajorMatrixVectorProductEmitter { void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( llvm::Value* column, int64 column_count, bool is_first_column) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column, + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, /*column_count=*/column_count); std::vector rhs_tile = LoadRhsTile(column, /*count=*/column_count); - EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile, + EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, /*columns=*/column_count, is_first_column); EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); } void ColumnMajorMatrixVectorProductEmitter::Emit() { // See the comment on the class declaration for the algorithm used here. - int64 column_remainder = k_ % tile_cols_; - int64 column_limit = k_ - column_remainder; + int64 column_remainder = k() % tile_cols(); + int64 column_limit = k() - column_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_, - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols_, is_first_column); - }); + ksl_.ForReturnVoid("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); if (column_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, @@ -239,29 +337,30 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, const std::vector& rhs_tile, + MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, int64 columns, bool is_first_column) { - int64 row_limit = m_ - (m_ % tile_rows_); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows_, [&](llvm::Value* row) { - std::vector lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = - is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) - : vsl_.GetZeroVector()) - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); + int64 row_limit = m() - (m() % tile_rows()); + + ksl_.ForReturnVoid( + "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows(), [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { - int64 row_start = m_ - (m_ % tile_rows_); - if (row_start == m_) { + int64 row_start = m() - (m() % tile_rows()); + if (row_start == m()) { return; } @@ -274,25 +373,25 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // initialized. // } - ksl_.For( + ksl_.ForReturnVoid( "dot.inner.epilg.outer", /*start=*/current_tile_col, /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); llvm::Value* total_offset = - ir_builder_->CreateMul(col, ir_builder_->getInt64(m_)); + ir_builder_->CreateMul(col, ir_builder_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( - "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_, + ksl_.ForReturnVoid( + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( is_first_scalar_col, ir_builder_->getInt1(is_first_tiled_column)); - ksl_.If( + ksl_.IfReturnVoid( setting_result_first_time, /*true_block_generator=*/ [&]() { @@ -365,51 +464,55 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // We have an inner epilogue loop to deal with the "B" sub-matrix and an outer // epilogue loop to deal with the C,D submatrix. -class RowMajorMatrixVectorProductEmitter { +class RowMajorMatrixVectorProductEmitter + : public GemvConfig::User { public: - RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, - int64 tile_cols, int64 m, int64 k, - llvm::Value* lhs, llvm::Value* rhs, - llvm::Value* addend, llvm::Value* result, + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"row_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), + : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") { - CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast(tile_cols_))); + vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") { + CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); + CHECK(!has_addend() || addend != nullptr); } void Emit(); + const Config& config() const { return config_; } + private: - TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/k_, + MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { + return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k(), /*major_dim_offset=*/row_start, /*tile_size_along_major_dim=*/row_count); } void EmitOuterLoopBody(llvm::Value* row, int64 row_count); - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows, + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, std::vector* vector_accumulators); void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, std::vector* scalar_accumulators); - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; + Config config_; llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* addend_; @@ -421,7 +524,7 @@ class RowMajorMatrixVectorProductEmitter { void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, int64 row_count) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row, + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, /*row_count=*/row_count); std::vector vector_accumulators; std::vector scalar_accumulators; @@ -429,7 +532,7 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); } - EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count, + EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, &vector_accumulators); EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, &scalar_accumulators); @@ -466,12 +569,13 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, void RowMajorMatrixVectorProductEmitter::Emit() { // See the comment on the class declaration for the algorithm used here. - int64 row_remainder = m_ % tile_rows_; - int64 row_limit = m_ - row_remainder; + int64 row_remainder = m() % tile_rows(); + int64 row_limit = m() - row_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_, - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); }); + ksl_.ForReturnVoid( + "dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); @@ -479,48 +583,395 @@ void RowMajorMatrixVectorProductEmitter::Emit() { } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, int64 rows, + MemoryTile* lhs_memory_tile, int64 rows, std::vector* vector_accumulators) { - int64 column_limit = k_ - (k_ % tile_cols_); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols_, [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set( - vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); + int64 column_limit = k() - (k() % tile_cols()); + + ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set(vsl_.Add( + old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* current_tile_row, int64 rows, std::vector* scalar_accumulators) { - int64 column_start = k_ - (k_ % tile_cols_); - if (column_start == k_) { + int64 column_start = k() - (k() % tile_cols()); + if (column_start == k()) { return; } for (int r = 0; r < rows; r++) { llvm::Value* total_offset = ir_builder_->CreateMul( ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), - ir_builder_->getInt64(k_)); + ir_builder_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_, - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); + ksl_.ForReturnVoid( + "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); } } +// This class implements a tiled matrix multiplication algorithm, intended for +// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto, +// Kazushige, and Robert Van De Geijn. "High-performance implementation of the +// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008): +// 4). +// +// This only supports canonical dot operations (i.e. where the lhs contraction +// dimension is 1 and the rhs contraction dimension is 0) over row major +// matrices. +class MatrixMatrixBlockPanelEmitter { + public: + // Describe the dimensions of the GEBP kernel. These will usually not be the + // dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP + // kernels with smaller dimensions. + class Dimensions { + public: + explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + + int64 m() const { return m_; } + int64 k() const { return k_; } + int64 n() const { return n_; } + + string ToString() const { + return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); + } + + private: + const int64 m_; + const int64 k_; + const int64 n_; + }; + + // Represents the configuration of the GEBP emitter. The LLVM IR emitted by + // the emitter, modulo the LLVM values holding the input and output buffers, + // must be a function of the instance of `Config` passed to it. + // + // `dims` holds the matrix multiplication dimensions. + // + // `max_vectorization_width` is the maximum vector width (i.e. the width of + // the largest vector register we will use). This can be larger than the + // largest vector register supported by the machine -- LLVM will legalize + // these large vector widths into legally sized vectors. + // + // `max_vector_count` is the maximum number of vectors of size + // `max_vectorization_width` that we will attempt to process at once. + // + // `min_vectorization_width` is the smallest vector width the emitter will use + // -- below that it will devolve to using a scalar loop. + // + // The innermost reduction loop executes the matrix multiply in tiles of size + // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, + // ] in the RHS. + class Config { + public: + explicit Config(PrimitiveType scalar_type, Dimensions dims, + int64 max_vectorization_width, int64 max_vector_count, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k) + : scalar_type_(scalar_type), + dims_(dims), + max_vectorization_width_(max_vectorization_width), + max_vector_count_(max_vector_count), + min_vectorization_width_(min_vectorization_width), + tile_size_m_(tile_size_m), + tile_size_k_(tile_size_k) {} + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), + "_", max_vectorization_width(), "_", min_vectorization_width(), "_", + tile_size_m(), "_", tile_size_k()); + } + + PrimitiveType scalar_type() const { return scalar_type_; } + Dimensions dims() const { return dims_; } + int64 max_vectorization_width() const { return max_vectorization_width_; } + int64 max_vector_count() const { return max_vector_count_; } + int64 min_vectorization_width() const { return min_vectorization_width_; } + + int64 tile_size_m() const { return tile_size_m_; } + int64 tile_size_k() const { return tile_size_k_; } + + private: + PrimitiveType scalar_type_; + Dimensions dims_; + int64 max_vectorization_width_; + int64 max_vector_count_; + int64 min_vectorization_width_; + int64 tile_size_m_; + int64 tile_size_k_; + }; + + // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies + // `lhs` with `rhs` and stores the result in `result`. + explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* ir_builder) + : lhs_(lhs), + rhs_(rhs), + result_(result), + config_(config), + ir_builder_(ir_builder), + ksl_(ir_builder_) { + CHECK(max_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(max_vectorization_width()))); + CHECK_GT(max_vector_count(), 0); + CHECK(min_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(min_vectorization_width()))); + CHECK_GE(max_vectorization_width(), min_vectorization_width()); + CHECK_GT(tile_size_k(), 0); + } + + void Emit(); + + private: + // The HandleResiduesOnX helpers split the iteration space for dimension X + // into a multiple of the tile size on dimension X and an epilogue. These + // helpers ultimately call into `EmitTiledGemm` for emitting the + // tiled GEMM kernel. + + void HandleResiduesOnN(); + void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, + llvm::Value* n_end); + void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end); + + // This emits a tiled GEMM kernel. For a detailed description see the comment + // on the implementation. + void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); + + llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); } + + Config config() const { return config_; } + Dimensions dims() const { return config().dims(); } + + int64 max_vectorization_width() const { + return config().max_vectorization_width(); + } + int64 max_vector_count() const { return config().max_vector_count(); } + int64 min_vectorization_width() const { + return config().min_vectorization_width(); + } + int64 tile_size_m() const { return config().tile_size_m(); } + int64 tile_size_k() const { return config().tile_size_k(); } + PrimitiveType scalar_type() const { return config().scalar_type(); } + + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + Config config_; + + llvm::IRBuilder<>* ir_builder_; + KernelSupportLibrary ksl_; +}; + +void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); } + +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. + + int64 current_vectorization_width = + max_vector_count() * max_vectorization_width(); + int64 current_vector_count = max_vector_count(); + + int64 n_start = 0; + while (n_start != dims().n() && + current_vectorization_width >= min_vectorization_width()) { + int64 n_end = dims().n() - (dims().n() % current_vectorization_width); + if (n_start != n_end) { + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, + ir_builder_, "gebp"); + HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); + n_start = n_end; + } + if (current_vector_count == 1) { + current_vectorization_width /= 2; + } else { + current_vector_count--; + current_vectorization_width = + current_vector_count * max_vectorization_width(); + } + } + + if (n_start != dims().n()) { + VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp"); + ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + llvm::Value* n_i_next = + ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); + HandleResiduesOnK(&vsl, n_i, n_i_next); + }); + } +} + +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { + int64 k_start = 0; + int64 k_end = dims().k() - (dims().k() % tile_size_k()); + if (k_end != k_start) { + HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); + k_start = k_end; + } + + if (k_start != dims().k()) { + HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); + } +} + +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { + const int64 m_end = dims().m() - dims().m() % tile_size_m(); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), + GetInt64(0), GetInt64(m_end)); + + if (m_end != dims().m()) { + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); + } +} + +// The loop structure is: +// +// Iterate over dimension M as m: +// Iterate over dimension N as n: +// Iterate over dimension K as k: +// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) +// +// I.e. a just a tiled version of a "naive" GEMM. +// +// The tiling scheme is as follows: +// +// Let the LHS be: +// +// +----+----+----+ +// | a0 | b0 | c0 | . +// +----+----+----+ . +// | a1 | b1 | c1 | . +// +----+----+----+ +// .. .. +// +// and the RHS be: +// +// +----+----+----+----+ +// | p0 | p1 | p2 | p3 | . +// +----+----+----+----+ . +// | q0 | q1 | q2 | q3 | . +// +----+----+----+----+ +// | r0 | r1 | r2 | r3 | . +// +----+----+----+----+ . +// ...... ...... +// +// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted +// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] +// matrix that we can increment the result matrix by. +// +// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank +// 3 array, L, of dimension [2,3,4]: +// +// L[0,_,_] * L[1,_,_] +// * +// +----+----+----+----+ * +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | +// +----+----+----+----+ * +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | +// +----+----+----+----+ * +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | +// +----+----+----+----+ * +----+----+----+----+ +// +// +// Then we FMA L[0,_,_] with the RHS to get the first row of the result and +// L[1,_,_] with the RHS to get the second row of the result. For example, +// L[0,_,_] is computed as: +// +// +----+----+----+----+ +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | +// +----+----+----+----+ +----+----+----+----+ +// +// to get: +// +// +-------------------+-------------------+-------------------+--------- +// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... +// +-------------------+-------------------+-------------------+--------- +void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { + ksl_.ForReturnVoid( + "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile( + vsl, ir_builder_, /*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + ksl_.ForReturnVoid( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + TileVariable result_tile_var(vsl, + result_memory_tile.LoadTile(n_i)); + ksl_.ForReturnVoid( + "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, + dims().n(), k_i, tile_size_k); + std::vector> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); + std::vector rhs_tile = + rhs_memory_tile.LoadTile(n_i); + std::vector result_tile = + result_tile_var.Get(); + for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } + } + result_tile_var.Set(result_tile); + }); + + result_memory_tile.StoreTile(result_tile_var.Get(), n_i); + }); + }); +} + } // namespace DotOpEmitter::DotOpEmitter(const HloInstruction& dot, @@ -558,6 +1009,89 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, return dot_emitter.Emit(); } +bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( + const DotOpEmitter::MatMultDims& mat_mult_dims) { + if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) { + return false; + } + + if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { + return false; + } + + PrimitiveType primitive_type = dot_.shape().element_type(); + + switch (primitive_type) { + default: + return false; + + case F32: + case F64: + case S32: + case S64: + break; + } + + if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major && + mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) { + return false; + } + + llvm::Value* lhs = lhs_array_.GetBasePointer(); + llvm::Value* rhs = rhs_array_.GetBasePointer(); + llvm::Value* target = target_array_.GetBasePointer(); + int64 m = mat_mult_dims.m; + int64 k = mat_mult_dims.k; + int64 n = mat_mult_dims.n; + + if (mat_mult_dims.lhs_column_major) { + std::swap(lhs, rhs); + std::swap(m, n); + } + + int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + ir_builder_->CreateMemSet( + target, ir_builder_->getInt8(0), size_bytes, + target_machine_features_.minimum_alignment_for_allocation(size_bytes)); + + int64 max_target_vector_width = + target_machine_features_.vector_register_num_elements( + *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + + int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width; + std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = + GetGemmTileSize(); + + MatrixMatrixBlockPanelEmitter::Config config( + /*scalar_type=*/primitive_type, + MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + /*max_vectorization_width=*/max_target_vector_width, + /*max_vector_count=*/tile_size_n_in_vector_width, + /*min_vectorization_width=*/std::min(4, max_target_vector_width), + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); + + VLOG(2) << "Emitting GEBP kernel in LLVM IR with config " + << config.GetCacheKey(); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs, rhs, target, + [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { + MatrixMatrixBlockPanelEmitter gebp_emitter( + config, /*lhs=*/lhs, /*rhs=*/rhs, + /*result=*/target, ir_builder_); + gebp_emitter.Emit(); + }); + + return true; +} + bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (dot_.shape().dimensions_size() != 2) { return false; @@ -610,7 +1144,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return false; + return EmitExperimentalGebpDotIfEnabled(mat_mult_dims); } int64 tiling_factor = GetGemvTilingFactor(); @@ -643,47 +1177,39 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - int64 tile_rows = vector_register_element_size; - int64 tile_cols = tiling_factor; - - string kernel_name = tensorflow::strings::StrCat( - "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, - "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + ColumnMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/primitive_type, + /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, + /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, - lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, tile_rows, tile_cols, m, k, primitive_type]( - llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, - llvm::Value* result_op) { + [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, + llvm::Value* addend_op, llvm::Value* result_op) { ColumnMajorMatrixVectorProductEmitter emitter( - primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, - addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); emitter.Emit(); }); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - int64 tile_rows = tiling_factor; - int64 tile_cols = vector_register_element_size; - - string kernel_name = tensorflow::strings::StrCat( - "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, - "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + RowMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/primitive_type, + /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size, + /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, - lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, tile_rows, tile_cols, m, k, primitive_type]( - llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, - llvm::Value* result_op) { + [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, + llvm::Value* addend_op, llvm::Value* result_op) { RowMajorMatrixVectorProductEmitter emitter( - primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, - addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); emitter.Emit(); }); } @@ -775,8 +1301,11 @@ Status DotOpEmitter::Emit() { // from messing up the vectorization. std::unique_ptr reduction_loop = loop_nest.AddLoop( 0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction", - /*prevent_unrolling=*/lhs_reduction_along_minor_dimension && - rhs_reduction_along_minor_dimension); + /*unroll_mode=*/ + (lhs_reduction_along_minor_dimension && + rhs_reduction_along_minor_dimension) + ? xla::llvm_ir::UnrollMode::kNoUnroll + : xla::llvm_ir::UnrollMode::kDefaultUnroll); // The final entry in the rhs and lhs indexes is the indvar of the // reduction loop. @@ -851,7 +1380,7 @@ Status DotOpEmitter::Emit() { // the rhs and lhs indexes with the reduction dimensions removed. The terms // from the rhs index are the lower dimensions in the index so we add them // first. - llvm_ir::IrArray::Index target_index; + llvm_ir::IrArray::Index target_index(lhs_index.GetType()); for (int dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { target_index.push_back(lhs_index[dimension]); @@ -875,10 +1404,13 @@ Status DotOpEmitter::Emit() { Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; + // Use the same index_type for all tensor accesses in the same kernel. + llvm::Type* index_type = ir_builder_->getInt64Ty(); + llvm_ir::IrArray::Index element_index(index_type); llvm::Value* lhs_value = - lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); + lhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_); llvm::Value* rhs_value = - rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); + rhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_); if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) { #define REAL(x) ir_builder_->CreateExtractValue(x, {0}) #define IMAG(x) ir_builder_->CreateExtractValue(x, {1}) @@ -896,7 +1428,8 @@ Status DotOpEmitter::EmitScalarDot() { } else { result = ir_builder_->CreateFMul(lhs_value, rhs_value); } - target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); + target_array_.EmitWriteArrayElement(/*index=*/element_index, result, + ir_builder_); return Status::OK(); } @@ -909,8 +1442,7 @@ Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. - bool multi_threaded = - hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + bool multi_threaded = ShouldUseMultiThreadedEigen(); bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; @@ -1019,7 +1551,9 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0, /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0, - /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1}; + /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1, + /*target_column_major=*/ + LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( @@ -1097,8 +1631,8 @@ bool PotentiallyImplementedAsEigenDot( const Shape& lhs_shape = hlo.operand(0)->shape(); const Shape& rhs_shape = hlo.operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { + if (ShapeUtil::IsZeroElementArray(lhs_shape) || + ShapeUtil::IsZeroElementArray(rhs_shape)) { return false; } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index a75b8ffcbfce111167e573ecee49b04acd34f86c..ed2a18976a0f1a88e7bb4632d3a63167d5c146ad 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -123,6 +123,9 @@ class DotOpEmitter { // True if the RHS contraction dimension is not 0. bool rhs_non_canonical; + + // True if the result matrix is column major. + bool target_column_major; }; // Get the MatMultDims instance for the dot product this DotOpEmitter @@ -130,6 +133,8 @@ class DotOpEmitter { // of rank 2 as well). MatMultDims GetMatMultDims() const; + bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims); + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector // registers. int64 GetGemvTilingFactor() const { @@ -138,6 +143,28 @@ class DotOpEmitter { .value_or(kDefaultTilingFactor); } + std::tuple GetGemmTileSize() const { + // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz + // + // TODO(b/80093688): Tune for other architectures and centralize this + // information in one place. + const std::tuple kDefaultTileSize = + std::tuple(11, 9, 1); + return options::LlvmIrGemmTileSize(hlo_module_config_) + .value_or(kDefaultTileSize); + } + + // Returns true if we should use an experimental implementation of GEMM + // (general matrix matrix multiplication) if possible. + bool EnableExperimentalLlvmIrGemm() const { + return options::EnableExperimentalLlvmIrGemm(hlo_module_config_); + } + + // Returns true if we should call into multi-threaded Eigen routines. + bool ShouldUseMultiThreadedEigen() { + return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + } + const HloInstruction& dot_; const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc deleted file mode 100644 index c56286559158758ca6db5ae097729286bde346f0..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" - -namespace xla { -namespace cpu { -void ExternalConstantPool::Insert(string name, const LiteralSlice& literal, - int64 alignment) { - CHECK(!ShapeUtil::IsTuple(literal.shape())); - CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); - CHECK(entries_.find(name) == entries_.end()); - - const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); - void* raw_pointer = tensorflow::port::AlignedMalloc( - literal_size, std::max(alignment, sizeof(void*))); - CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size - << " bytes with alignment of " << alignment; - - std::memcpy(raw_pointer, literal.untyped_data(), literal_size); - entries_.emplace(std::move(name), static_cast(raw_pointer)); -} - -const uint8* ExternalConstantPool::Find(const string& name) { - auto it = entries_.find(name); - return it == entries_.end() ? nullptr : it->second.get(); -} -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h deleted file mode 100644 index 0677f5f0b58005079890052a426e5f48c5d09ed1..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ - -#include - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/platform/mem.h" - -namespace xla { -namespace cpu { -// An ExternalConstantPool maintains a set of constants kept external to -// generated LLVM IR. These constants are accessed from the IR via globals with -// extern linkage. This current incarnation of ExternalConstantPool only -// supports the JIT CPU backend; the AOT backend is not supported. -// -// Implementation-wise, this is a simple wrapper around a map of strings to byte -// buffers. This simply implementation works in a JIT scenario. This class -// will have to become smarter if we decide to support external constant pools -// on AOT compiles in the future. -class ExternalConstantPool { - public: - // Inserts a buffer with the contents of `literal` into the constant pool with - // the name `name`. It is an error to try to insert two constants with the - // same `name` into the same constant pool. The buffer for literal is aligned - // to `aligment` bytes, and `alignment` must be a power of 2. - // - // The constant pool copies out the contents of `literal` into a buffer it - // owns -- it does not keep pointers to `literal`, or to memory owned by - // `literal`. - void Insert(string name, const LiteralSlice& literal, int64 alignment); - - // Find the constant with name `name` in this constant pool. If there isn't - // such constant, return nullptr. - const uint8* Find(const string& name); - - private: - // We need to `AlignedFree` pointers allocated into `entries_` since we - // allocate them with `AlignedMalloc`. - struct FreeDeleter { - void operator()(void* ptr) { tensorflow::port::AlignedFree(ptr); } - }; - - tensorflow::gtl::FlatMap> - entries_; -}; -} // namespace cpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc deleted file mode 100644 index 9290a4e5dfc03ddb86e9d82f1f0f4f9a8ceebb88..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace cpu { -namespace { -class ExternalConstantPoolTest : public ::testing::Test {}; - -template -T GetFromBuffer(const uint8* buffer, int64 index) { - T result; - std::memcpy(&result, buffer + index * sizeof(T), sizeof(T)); - return result; -} - -TEST(ExternalConstantPoolTest, Basic) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); - constant_pool.Insert("name-0", *literal, 4); - const uint8* constant = constant_pool.Find("name-0"); - ASSERT_NE(constant, nullptr); - - EXPECT_EQ(GetFromBuffer(constant, 0), 1); - EXPECT_EQ(GetFromBuffer(constant, 1), 2); - EXPECT_EQ(GetFromBuffer(constant, 2), 3); - EXPECT_EQ(GetFromBuffer(constant, 3), 4); - - EXPECT_EQ(constant_pool.Find("name-1"), nullptr); -} - -TEST(ExternalConstantPoolTest, RowMinorLayout) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - const auto literal = Literal::CreateR2WithLayout( - {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); - constant_pool.Insert("name-0", *literal, 4); - const uint8* constant = constant_pool.Find("name-0"); - ASSERT_NE(constant, nullptr); - - EXPECT_EQ(GetFromBuffer(constant, 0), 1); - EXPECT_EQ(GetFromBuffer(constant, 1), 3); - EXPECT_EQ(GetFromBuffer(constant, 2), 2); - EXPECT_EQ(GetFromBuffer(constant, 3), 4); -} - -TEST(ExternalConstantPoolTest, Alignment) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - - for (int i = 0; i < 8; i++) { - int64 alignment = 1 << i; - string name = tensorflow::strings::StrCat("name-", i); - - const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); - constant_pool.Insert(name, *literal, alignment); - - const uint8* constant = constant_pool.Find(name); - ASSERT_NE(constant, nullptr); - EXPECT_EQ(reinterpret_cast(constant) % alignment, 0); - } -} - -} // namespace -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index b560b7531c0d24e6f670e61a15dce295d9fa2a49..1a8bedfe6afb4f096ddd4703c312b84d521a7ba5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -64,8 +64,8 @@ bool PotentiallyImplementedAsEigenConvolution( return false; } - if (ShapeUtil::HasZeroElements(input_shape) || - ShapeUtil::HasZeroElements(kernel_shape)) { + if (ShapeUtil::IsZeroElementArray(input_shape) || + ShapeUtil::IsZeroElementArray(kernel_shape)) { return false; } // Make sure input and kernel has the same data type. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc index abb2471e6ae6b2f2949ab2e91235e5047ae404f8..530ebce854fedf4e4db12139d5b56087b1176a6c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -35,7 +35,7 @@ ENTRY Conv { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* entry_computation = module->entry_computation(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 23fcb9cc712e71dc582819dc9b13b3af54bfcef5..6b9a1d8c01aee46e271bc5a950e1a4bb45b7b822 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -48,6 +48,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -83,8 +85,7 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine_features, - ExternalConstantPool* external_constant_pool) + const TargetMachineFeatures* target_machine_features) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), @@ -94,8 +95,7 @@ IrEmitter::IrEmitter( alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), - target_machine_features_(*target_machine_features), - external_constant_pool_(external_constant_pool) { + target_machine_features_(*target_machine_features) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_enable_fast_math())); @@ -160,47 +160,25 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { - llvm::GlobalVariable* result; - - // We avoid creating large constants in the LLVM IR since LLVM is not - // efficient for large constant arrays. We still emit "small enough" constant - // arrays into the Ir, in the off chance the LLVM optimizer can do something - // interesting with it. - const int kMaxInternalConstantSizeInBytes = 128; - if (external_constant_pool_ && - ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { - string global_name = tensorflow::strings::StrCat( - "constant_global_", external_global_constant_counter_++); - result = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/IrShapeType(literal.shape()), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::ExternalLinkage, - /*Initializer=*/nullptr, - /*Name=*/AsStringRef(global_name)); - result->setAlignment(MinimumAlignmentForShape(literal.shape())); - external_constant_pool_->Insert(global_name, literal, - MinimumAlignmentForShape(literal.shape())); - } else { - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - result = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/initializer->getType(), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/initializer, - /*Name=*/""); - result->setAlignment(MinimumAlignmentForShape(literal.shape())); - } - return result; +llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/initializer->getType(), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/initializer, + /*Name=*/""); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + return llvm::ConstantExpr::getBitCast( + result_global, IrShapeType(literal.shape())->getPointerTo()); } Status IrEmitter::HandleConstant(HloInstruction* constant) { VLOG(2) << "HandleConstant: " << constant->ToString(); const Literal& literal = constant->literal(); - llvm::GlobalVariable* global_for_const; + llvm::Constant* global_for_const; auto it = emitted_literals_.find(&literal); if (it != emitted_literals_.end()) { @@ -221,10 +199,13 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); return EmitMemcpy(*(copy->operand(0)), *copy); - } else { - // Use the elemental emitter for non-tuple shapes. + } else if (ShapeUtil::IsArray(copy->shape())) { + // Use the elemental emitter for array shapes. return DefaultAction(copy); } + return Unimplemented( + "unsupported operand type %s for copy instruction", + PrimitiveType_Name(copy->shape().element_type()).c_str()); } // Calculate the alignment of a buffer allocated for a given primitive type. @@ -298,45 +279,60 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); - auto on_true = select->operand(1); - auto on_false = select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); - - if (ShapeUtil::IsTuple(select->shape())) { - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select)); - llvm_ir::EmitTupleSelect( - GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), - GetEmittedValueFor(on_false), &ir_builder_, module_); - return Status::OK(); - } - return DefaultAction(select); } -Status IrEmitter::HandleInfeed(HloInstruction* infeed) { - VLOG(2) << "HandleInfeed: " << infeed->ToString(); +Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { + auto pred = tuple_select->operand(0); + auto on_true = tuple_select->operand(1); + auto on_false = tuple_select->operand(2); + TF_RET_CHECK(pred->shape().element_type() == PRED); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); + TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select)); + llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred), + GetEmittedValueFor(on_true), + GetEmittedValueFor(on_false), &ir_builder_, module_); + return Status::OK(); +} - const Shape& shape = infeed->shape(); +Status IrEmitter::HandleInfeed(HloInstruction* instruction) { + HloInfeedInstruction* infeed = Cast(instruction); + VLOG(2) << "HandleInfeed: " << infeed->ToString(); - // The infeed operation produces data (dequeued from the infeed queue) at this - // address, which has been provided by buffer assignment. + // The infeed operation produces a two-element tuple containing data and a + // token value. HloInfeedInstruction::infeed_shape gives us the data shape. + const Shape& data_shape = infeed->infeed_shape(); + DCHECK(ShapeUtil::Equal(data_shape, + ShapeUtil::GetTupleElementShape(infeed->shape(), 0))); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed)); - llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed); - if (ShapeUtil::IsTuple(shape)) { - TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); + // Write the tuple index table. + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, + assignment_.GetUniqueSlice(infeed, {0})); + llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice, + assignment_.GetUniqueSlice(infeed, {1})); + llvm::Value* token_address = EmitTempBufferPointer( + token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); + llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, + &ir_builder_, module_); + + if (ShapeUtil::IsTuple(data_shape)) { + TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape)); // For a tuple, we first copy each of the internal elements to // their corresponding target locations. We then construct the // tuple outer buffer containing pointers to the internal // elements. std::vector tuple_element_addresses; - for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, - assignment_.GetUniqueSlice(infeed, {i})); + assignment_.GetUniqueSlice(infeed, {0, i})); const Shape& tuple_element_shape = - ShapeUtil::GetTupleElementShape(shape, i); + ShapeUtil::GetTupleElementShape(data_shape, i); // Only the outer tuple buffer's target address is obtained from // GetEmittedValueFor, to handle the case when Infeed is the root @@ -351,11 +347,11 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { tuple_element_addresses.push_back(tuple_element_address); } - llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_, - module_); + llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape), + tuple_element_addresses, &ir_builder_, module_); } else { - TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape, - GetEmittedValueFor(infeed))); + TF_RETURN_IF_ERROR( + EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address)); } return Status::OK(); @@ -555,7 +551,8 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - llvm_ir::IrArray::Index input_index(index.size()); + llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), + index.size()); llvm::Value* in_bounds_condition = nullptr; for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( @@ -686,7 +683,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(source_index.size()); + llvm_ir::IrArray::Index operand_index(ir_builder_.getInt64Ty(), + source_index.size()); llvm::Value* in_bounds_condition = ir_builder_.getTrue(); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( @@ -760,7 +758,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // value and the current output value. SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - llvm_ir::IrArray::Index selected_index; + llvm_ir::IrArray::Index selected_index(source_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( selected_index_address, {ir_builder_.getInt32(i)}); @@ -1102,7 +1100,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // We are not in the padding, so carry out the computation. int num_dims = num_spatial_dims + 2; - llvm_ir::IrArray::Index input_index(num_dims); + llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), num_dims); for (int i = 0; i < num_spatial_dims; ++i) { input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; } @@ -1110,7 +1108,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { input_index[dnums.input_batch_dimension()] = batch; llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); - llvm_ir::IrArray::Index kernel_index(num_dims); + llvm_ir::IrArray::Index kernel_index(ir_builder_.getInt64Ty(), + num_dims); for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = window.dimensions(i).window_reversal() @@ -1164,7 +1163,13 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - const char* fn_name = runtime::kEigenFftSymbolName; + + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + const char* fn_name = multi_threaded_eigen + ? runtime::kEigenFftSymbolName + : runtime::kEigenSingleThreadedFftSymbolName; + llvm::Function* fft_func = llvm::cast( module_->getOrInsertFunction(fn_name, fft_type)); fft_func->setCallingConv(llvm::CallingConv::C); @@ -1186,16 +1191,45 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { } Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { - if (hlo_module_config_.replica_count() == 1) { - // When there is a single replica, a cross replica sum is the identity - // function, and the buffer assignment expects a copy (we could eliminate - // these at the HLO level as an optimization). - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + if (hlo_module_config_.replica_count() != 1) { + // TODO(b/33011107): Support nontrivial cross replica sum on CPU. + return Unimplemented( + "CrossReplicaSum with >1 replica is not implemented on CPU."); + } + + // When there is a single replica, a cross replica sum is the identity + // function, and the buffer assignment expects a copy. + // + // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely + // in algebraic-simplifier, but currently on some platforms + // HloModuleConfig::num_replicas changes between when the module is compiled + // and when it's run. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + + // CRS with one operand and one replica is simply the identity function. + if (crs->operand_count() == 1) { return EmitMemcpy(*crs->operand(0), *crs); } - // TODO(b/33011107): Support cross replica sum on CPU. - return Unimplemented("CrossReplicaSum is not implemented on CPU."); + // CRS with multiple operands and one replica produces a (one-deep) tuple. + std::vector operand_ptrs; + for (int64 i = 0; i < crs->operand_count(); ++i) { + llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i)); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, + assignment_.GetUniqueSlice(crs, {i})); + + const Shape& operand_shape = crs->operand(i)->shape(); + CHECK(ShapeUtil::IsArray(operand_shape)) + << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); + + // TODO(b/63762267): Be more aggressive about specifying alignment. + ir_builder_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, + ShapeUtil::ByteSizeOf(operand_shape)); + } + llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &ir_builder_, module_); + return Status::OK(); } // Fills up the free variables in 'index_with_free_var' with values from @@ -1386,6 +1420,10 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); }; + case HloOpcode::kXor: + return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + llvm::Value* rhs) { return ir_builder->CreateXor(lhs, rhs); }; + case HloOpcode::kMaximum: return [root_is_floating_point, root_is_signed]( llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, @@ -1642,7 +1680,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( // } llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_); - llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size()); + llvm_ir::IrArray::Index array_index(ir_builder_.getInt64Ty(), + reduce->shape().dimensions_size()); for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; --i) { int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); @@ -1833,7 +1872,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice)); - if (ShapeUtil::HasZeroElements(slice->shape())) { + if (ShapeUtil::IsZeroElementArray(slice->shape())) { return Status::OK(); } @@ -2026,7 +2065,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { // Compute the output index the operand element should be assigned to. // output_index := edge_padding_low + operand_index * (interior_padding + 1) const PaddingConfig& padding_config = pad->padding_config(); - llvm_ir::IrArray::Index output_index; + llvm_ir::IrArray::Index output_index(operand_index.GetType()); for (size_t i = 0; i < operand_index.size(); ++i) { llvm::Value* offset = ir_builder_.CreateMul( operand_index[i], @@ -2488,6 +2527,13 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { return Status::OK(); } +Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { + TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0); + // No code to generate, but we need to emit an address for book-keeping. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token)); + return Status::OK(); +} + Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding @@ -2769,7 +2815,10 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { // For the root node, we write directly to the output buffer of the // function. llvm::Argument* retval = compute_function_->result_arg(); - if (!ShapeUtil::IsNil(target_shape)) { + if ((ShapeUtil::IsArray(target_shape) && + !ShapeUtil::IsZeroElementArray(target_shape)) || + (ShapeUtil::IsTuple(target_shape) && + !ShapeUtil::IsEmptyTuple(target_shape))) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index f49cfc1dc378bb80da3ddf995363acfa2081067b..3089f6451e7dc4b2752c6ae65b3f5f8ecc3d7405 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -30,7 +30,6 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -67,17 +66,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { // index in the profiling array. // computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. - // external_constant_pool: if non-null, points to an ExternalConstantPool - // instance into which the Ir emitter can spill - // constants. IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine, - ExternalConstantPool* external_constant_pool); + const TargetMachineFeatures* target_machine); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -122,6 +117,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; @@ -150,6 +146,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; + Status HandleAfterAll(HloInstruction* gen_token) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -527,7 +524,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - llvm::GlobalVariable* EmitGlobalForLiteral(const Literal& literal); + // Returns a ConstExpr bitcast. + llvm::Constant* EmitGlobalForLiteral(const Literal& literal); const HloModuleConfig& hlo_module_config_; @@ -535,9 +533,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { const TargetMachineFeatures& target_machine_features_; - int64 external_global_constant_counter_ = 0; - ExternalConstantPool* external_constant_pool_; - struct LiteralPtrHashFunctor { size_t operator()(const Literal* literal) const { return literal->Hash(); } }; @@ -548,7 +543,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { } }; - tensorflow::gtl::FlatMap emitted_literals_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 54af40506dab48b3c2a3a44eb0b5f5fb213a32ec..59ae5acd8b7cea049f09eaf4cc98b41339973c77 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -31,13 +31,15 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { + CHECK_NE(index_type, nullptr); + CHECK(!ShapeUtil::IsTuple(shape_)); CHECK(!ShapeUtil::IsScalar(shape_)); llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); const int64 num_dims = shape_.dimensions_size(); - llvm_ir::IrArray::Index array_index(num_dims); + llvm_ir::IrArray::Index array_index(index_type, num_dims); // Add loops from outer-most to inner-most dimensions. for (int i = LayoutUtil::MinorToMajor(shape_).size() - 1; i >= 0; --i) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 755715634aa70a822b21d25dcae20a8fe053477a..25e182a26d6f21c7eba550020cf17403aa92abf7 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) override; + tensorflow::StringPiece loop_name, llvm::Type* index_type) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 63d0f7b95c7e45913c707471dbe2dc62e05251d6..4fa5984b0466b178a587e97cbced97deac749f74 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -38,7 +38,7 @@ class SimpleCostModel : public ParallelCostModel { const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism_, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: @@ -63,7 +63,7 @@ class DefaultCostModel : public ParallelCostModel { int64 max_parallelism; // Calculate flops-to-bytes-ratio for 'instruction'. const int64 bytes_accessed = - std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); + std::max(int64{1}, cost_analysis_->bytes_accessed(*instruction)); const float flops_to_bytes_ratio = cost_analysis_->flop_count(*instruction) / static_cast(bytes_accessed); @@ -93,7 +93,7 @@ class DefaultCostModel : public ParallelCostModel { } // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index fc2efbaf9a22b02cd729da2f367d53bc15506836..36c9f743859ae2da6c4fb3fd753bd7862fe2d3ab 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -110,8 +110,9 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - infeed0 = u32[12345678,2]{1,0} infeed() - ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0) + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed() + infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 + ROOT outfeed0 = token[] outfeed(infeed0.data) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 984cb0616e02475babad7160d0f43bb23de0b50e..0bf693edd0b985a4e62c16414646cc6a17db26ee 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -21,8 +21,6 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" // 'tensorflow' namespace is used so that int64 and other types don't require @@ -71,11 +69,9 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -88,8 +84,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(in_dims); + const Eigen::DSizes zero_start_indices; full_fft.device(device) = input.template fft(axes); @@ -112,11 +108,9 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -129,8 +123,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. @@ -179,7 +172,6 @@ template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, int32 fft_type, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { - CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; switch (fft_type) { case ::xla::FftType::FFT: EigenFftC2C( @@ -204,7 +196,8 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT type: " << fft_type; + // Unsupported FFT type + abort(); } } @@ -230,7 +223,8 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT rank " << fft_rank; + // Unsupported FFT rank + abort(); } } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc index 92da5f71c23d5e1450b39ea8b7bb8345f6fabb3b..f8c8dd5e93d53db8d87be0208b5cf4daac3464f1 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "third_party/intel_mkl_ml/include/mkl_cblas.h" #include "third_party/intel_mkl_ml/include/mkl_service.h" diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc new file mode 100644 index 0000000000000000000000000000000000000000..2613ddb12704aea7d0884c6c8c062dc028383639 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type, + fft_rank, input_batch, fft_length0, fft_length1, + fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h new file mode 100644 index 0000000000000000000000000000000000000000..dcd133d012cf074a4cd2f550585881388bea6156 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 167aa4adda995a259190a932a76a34ca5883444c..7e792a82b8bf28121c054332bc619d736858c729 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -49,9 +49,9 @@ int main(int argc, char** argv) { // Build computation. xla::XlaBuilder builder(""); - auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p1, p0, {0}); + auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Add(p1, p0, {0}); xla::StatusOr computation_status = builder.Build(); xla::XlaComputation computation = computation_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc index 42fe955f1917e0268dc739e44fbd0a7afb39185c..d12c5396148d32adb178b955a34e050cc56784da 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -115,7 +115,7 @@ ShapePartitionIterator::ShapePartitionIterator( for (int i = 0; i < dimension_partition_sizes_.size(); ++i) { const int64 dim_size = shape_.dimensions(dimensions_[i]); dimension_partition_sizes_[i] = - std::max(1LL, dim_size / dimension_partition_counts_[i]); + std::max(int64{1}, dim_size / dimension_partition_counts_[i]); } // Calculate the partition strides for each dimension. diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 62c97e5641da7fd1c88457ef95f5ca8be4e52eb9..be772cfb7e564cebc5725854dbf5678e5c507556 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" @@ -99,6 +100,7 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), symbol_resolver_(llvm::orc::createLegacyLookupResolver( + execution_session_, [this](const std::string& name) -> llvm::JITSymbol { return this->ResolveRuntimeSymbol(name); }, @@ -125,13 +127,6 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, } llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { - if (const uint8* from_constant_pool = - external_constant_pool_.Find(string(name))) { - return llvm::JITEvaluatedSymbol( - reinterpret_cast(from_constant_pool), - llvm::JITSymbolFlags::None); - } - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); if (func_addr == nullptr) { return nullptr; @@ -201,6 +196,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 1851a3ee0bb97b4860605d7211a6ae70ac88686b..d74b63fcf45bd70cd18ee41f1e9714ba6a222abd 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -29,7 +29,6 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -91,10 +90,6 @@ class SimpleOrcJIT { llvm::TargetMachine* target_machine() const { return target_machine_.get(); } - ExternalConstantPool* external_constant_pool() { - return &external_constant_pool_; - } - // Creates an llvm::TargetMachine suitable for JITting code that will run on // the current machine. static std::unique_ptr InferTargetMachineForJIT( @@ -112,7 +107,6 @@ class SimpleOrcJIT { std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; CompileLayerT compile_layer_; - ExternalConstantPool external_constant_pool_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 67f776e7b5883f425b41c05342b74bebe223e17f..66ae5ef0f66e90982102d73e474f5d0582f5415c 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -152,9 +152,9 @@ tf_cc_test( srcs = ["cpu_literal_caching_test.cc"], deps = [ "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -166,9 +166,9 @@ tf_cc_test( srcs = ["cpu_outfeed_test.cc"], deps = [ "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index ed8f375bd6186e4805fe9ded5be9ae7c9f4d5c84..1d4bf483aedef5a15ef51cf216030b76255d4ec8 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -56,7 +56,8 @@ class CpuExternalConstantsTest : public CpuCodegenTest { TEST_F(CpuExternalConstantsTest, Basic) { TestWithArray(/*rows=*/1024, /*cols=*/1024, R"( -CHECK: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +CHECK-NOT: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +CHECK: @0 = private constant [4194304 x i8] {{.*}}, align 16 )"); } @@ -64,8 +65,8 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) { // The constant array in this test case is small enough that there is no need // to externalize it. TestWithArray(/*rows=*/4, /*cols=*/4, R"( -CHECK-NOT: @constant_global_0 = external constant [4 x [4 x float]], align 8 -CHECK: @0 = private constant [4 x [4 x float]] {{.*}}, align 8 +CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8 +CHECK: @0 = private constant [64 x i8] {{.*}}, align 8 )"); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 23e7a3de4d8188a3add259582e11030539e154c1..783b2820e922612973632c555fc8ae01418f1754 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -96,8 +96,11 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceil)); auto floor = builder.AddInstruction( HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exp)); - auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( + vshape, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))), + {})); builder.AddInstruction( HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); @@ -114,9 +117,9 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); EXPECT_EQ(HloOpcode::kMultiply, fusion_instruction->fused_expression_root()->opcode()); - // There should be 7 fused instructions: 2 parameters and the fused + // There should be 8 fused instructions: 2 parameters and the fused // operations. - EXPECT_EQ(7, fusion_instruction->fused_instruction_count()); + EXPECT_EQ(8, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. auto result = ExecuteAndTransfer(std::move(module), {}); @@ -170,8 +173,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce)); auto floor = builder.AddInstruction( HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp)); - auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( + cshape, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))), + {})); builder.AddInstruction( HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor)); @@ -188,9 +194,9 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode()); EXPECT_EQ(HloOpcode::kMultiply, fusion_instruction1->fused_expression_root()->opcode()); - // There should be 5 fused instructions in the root fusion instruction: 2 + // There should be 6 fused instructions in the root fusion instruction: 2 // parameters, multiply, floor, and exp. - EXPECT_EQ(5, fusion_instruction1->fused_instruction_count()) + EXPECT_EQ(6, fusion_instruction1->fused_instruction_count()) << fusion_instruction1->fused_instructions_computation()->ToString(); auto fusion_instruction2 = reduce->operand(0); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index dd63b998e9b6d04981ec6f7300c883c9b23b154f..ea7e479d66fbda1bfd388fd77b25db2db56f0d65 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -47,7 +47,7 @@ class InfeedTest : public ClientLibraryTestBase { // don't use ResetDevice since it is not implemented on CPU. ASSERT_IS_OK(client_->TransferToInfeed(literal)); XlaBuilder builder(TestName()); - builder.Infeed(literal.shape()); + Infeed(&builder, literal.shape()); if (ShapeUtil::IsTuple(literal.shape())) { // TODO(b/30609564): Use ComputeAndCompareLiteral instead. ComputeAndCompareTuple(&builder, literal, {}); @@ -125,8 +125,8 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(40.0f), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 40.0f), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add the reduced value of the Infeed @@ -134,17 +134,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto infeed = builder.Infeed(infeed_shape); - auto addend = - builder.Reduce(infeed, builder.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder), {0}); - builder.Add(prev, addend); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto infeed = Infeed(&builder, infeed_shape); + auto addend = Reduce(infeed, ConstantR0(&builder, 0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + Add(prev, addend); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - auto init = builder.ConstantR0(0.0f); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0.0f); + While(condition, body, init); // Build and asynchronously launch the computation. auto computation = builder.Build().ConsumeValueOrDie(); @@ -207,8 +206,8 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.GetTupleElement(prev, 1); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + GetTupleElement(prev, 1); condition = builder.Build().ConsumeValueOrDie(); } @@ -221,27 +220,27 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { const auto build_body = [this, &result_shape](const Shape& infeed_shape) { XlaComputation body; XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto infeed = builder.Infeed(infeed_shape); - auto addend = builder.Reduce( - builder.GetTupleElement(infeed, 0), builder.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder), {0}); - auto result = builder.Add(builder.GetTupleElement(prev, 0), addend); - builder.Tuple({result, builder.GetTupleElement(infeed, 1)}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto infeed = Infeed(&builder, infeed_shape); + auto addend = + Reduce(GetTupleElement(infeed, 0), ConstantR0(&builder, 0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + auto result = Add(GetTupleElement(prev, 0), addend); + Tuple(&builder, {result, GetTupleElement(infeed, 1)}); return builder.Build().ConsumeValueOrDie(); }; // Create the first while loop with infeed1_shape. - auto init = builder.Tuple( - {builder.ConstantR0(0.0f), builder.ConstantR0(true)}); - auto while1 = builder.While(condition, build_body(infeed1_shape), init); - auto result1 = builder.Tuple( - {builder.GetTupleElement(while1, 0), builder.ConstantR0(true)}); + auto init = Tuple(&builder, {ConstantR0(&builder, 0.0f), + ConstantR0(&builder, true)}); + auto while1 = While(condition, build_body(infeed1_shape), init); + auto result1 = Tuple( + &builder, {GetTupleElement(while1, 0), ConstantR0(&builder, true)}); // Create the second while loop with infeed2_shape. Note that the result from // the first while loop is used as the initial value. - auto while2 = builder.While(condition, build_body(infeed2_shape), result1); - builder.GetTupleElement(while2, 0); + auto while2 = While(condition, build_body(infeed2_shape), result1); + GetTupleElement(while2, 0); // Build the computation. auto computation = builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index d6e0425c5542be89835571f0103b1829f63cc2c2..90b99c828e2fcfd77579026a39d3a6711599feee 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" namespace xla { namespace cpu { @@ -38,7 +38,8 @@ while_body { while_cond { arg_cond = f32[2,3,2] parameter(0) - ROOT unknown = pred[] infeed() + infeed = (pred[], token[]) infeed() + ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { @@ -49,18 +50,18 @@ ENTRY main { {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - out0 = () outfeed(f32[2,3,2] const_a) - ROOT out1 = () outfeed(f32[2,3,2] const_b) + out0 = token[] outfeed(f32[2,3,2] const_a) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b) } )"; string filecheck_pattern = R"( -CHECK: private constant [2 x [3 x [2 x float]]] -CHECK-NOT: private constant [2 x [3 x [2 x float]]] +CHECK: private constant [48 x i8] +CHECK-NOT: private constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", @@ -78,34 +79,35 @@ TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) { HloModule RepeatedConstants while_body { - arg_body = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) - ROOT const = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) + arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) } while_cond { - arg_cond = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) - ROOT unknown = pred[] infeed() + arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) + infeed = (pred[], token[]) infeed() + ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) - const_a = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) - const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b) + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b) } )"; string filecheck_pattern = R"( -CHECK: private constant [2 x float] -CHECK: private constant [2 x [1 x float]] -CHECK-NOT: private constant [2 x float] -CHECK-NOT: private constant [2 x [1 x float]] +CHECK: private constant [4 x i8] +CHECK: private constant [8 x i8] +CHECK-NOT: private constant [4 x i8] +CHECK-NOT: private constant [8 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index 879372eb13884cdb7edd8cfb3e8b4bac4e314951..dac416e1c78c2f60d458480c5062f48b77d4878d 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" namespace xla { namespace cpu { @@ -32,16 +32,17 @@ ENTRY main { {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - ROOT out = () outfeed(f32[2,3,2] const_a) + outfeed = token[] outfeed(f32[2,3,2] const_a) + ROOT root = () tuple() } )"; string filecheck_pattern = R"( -CHECK: private constant [2 x [3 x [2 x float]]] +CHECK: private constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index cd1165e23812861ba9951546b7dd744529232196..c444d151858d3a152a01b99657ffae89ebc6b487 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -427,5 +427,27 @@ llvm::Value* LlvmVariable::Get() const { void LlvmVariable::Set(llvm::Value* new_value) { ir_builder_->CreateStore(new_value, alloca_); } + +TileVariable::TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value) { + for (llvm::Value* initial_vector_value : initial_value) { + storage_.emplace_back(vector_support, initial_vector_value); + } +} + +std::vector TileVariable::Get() const { + std::vector result; + c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); + return result; +} + +void TileVariable::Set(tensorflow::gtl::ArraySlice value) { + CHECK_EQ(value.size(), storage_.size()); + for (int64 i = 0, e = value.size(); i < e; i++) { + storage_[i].Set(value[i]); + } +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 6479bf76aab581ae3ec2923d98dab53720cab203..49c2a4e2f4bae9e1672b7d2fe891301bce08bd4b 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -143,6 +144,12 @@ class VectorSupportLibrary { llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, llvm::Value* offset_elements); + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + llvm::Value* offset_elements, int64 scale) { + return ComputeOffsetPointer( + base_pointer, + ir_builder_->CreateMul(ir_builder_->getInt64(scale), offset_elements)); + } llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, int64 offset_elements) { return ComputeOffsetPointer(base_pointer, @@ -311,6 +318,21 @@ class ScalarVariable : public LlvmVariable { Set(initial_value); } }; + +// This wraps a set of alloca-backed stack variables that can, as a whole, store +// a tile. A "tile" is a sequence of vectors that is typically used as a 2D +// grid of scalar values (e.g. for tiled GEMMs). +class TileVariable { + public: + TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value); + + std::vector Get() const; + void Set(tensorflow::gtl::ArraySlice value); + + private: + std::vector storage_; +}; } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index b9d7ec9c2e17e560580fcea060bf552c42fe3b3c..52aa53dcee59379107e7da4e3afccec226ac5a6e 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -76,6 +76,7 @@ class DfsHloVisitorBase { virtual Status HandleClamp(HloInstructionPtr hlo) = 0; virtual Status HandleSelect(HloInstructionPtr hlo) = 0; + virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0; virtual Status HandleMaximum(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -183,6 +184,9 @@ class DfsHloVisitorBase { virtual Status HandleOr(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } + virtual Status HandleXor(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } virtual Status HandleShiftLeft(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -197,6 +201,10 @@ class DfsHloVisitorBase { return HandleElementwiseUnary(hlo); } + virtual Status HandleDomain(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; @@ -239,6 +247,8 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; + virtual Status HandleAfterAll(HloInstructionPtr token) = 0; + // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". virtual Status FinishVisit(HloInstructionPtr root) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 240faebe62f5cee4f61b3c36b5e8f653cfd6db8e..ecd97a87968edaa447ed2df801e95468e3dba0e4 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -79,6 +79,9 @@ class DfsHloVisitorWithDefaultBase Status HandleSelect(HloInstructionPtr select) override { return DefaultAction(select); } + Status HandleTupleSelect(HloInstructionPtr tuple_select) override { + return DefaultAction(tuple_select); + } Status HandleDot(HloInstructionPtr dot) override { return DefaultAction(dot); } @@ -188,6 +191,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleAfterAll(HloInstructionPtr token) override { + return DefaultAction(token); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 0a400e982ad50e35fc9c13f383574e1ce869877c..21c6f7d358bef171a54ebd97e7f4d2638ee179a8 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -456,17 +456,15 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { - // (x == x) && abs(x) != inf + // abs(x) o!= inf, this works because the comparison returns false if + // either operand is NaN. auto type = operand_value->getType(); - auto equal_self = - ir_builder_->CreateFCmpOEQ(operand_value, operand_value); auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_); auto infinity = llvm::ConstantFP::getInfinity(type); auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); - auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite); return ir_builder_->CreateZExt( - result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: return ir_builder_->CreateFNeg(operand_value); @@ -1166,6 +1164,8 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( return ir_builder_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: return ir_builder_->CreateOr(lhs_value, rhs_value); + case HloOpcode::kXor: + return ir_builder_->CreateXor(lhs_value, rhs_value); // Shifting out bits >= the number of bits in the type being shifted // produces a poison value in LLVM which is basically "deferred undefined @@ -1222,25 +1222,32 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const Shape& operand_shape = hlo.operand(operand_no)->shape(); // If the operand is scalar, the source index is always {}. if (ShapeUtil::IsScalar(operand_shape)) { - return llvm_ir::IrArray::Index(); + return llvm_ir::IrArray::Index(target_index.GetType()); } // If no implicit broadcast is needed for this operand, returns the target // index as the source index. - if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) { + // + // `IrArray::Index` may contain a physical linear which we can propagate to + // our operand only if our layouts match. "only if" is a bit strong since + // e.g. we can still forward the linear index if the operand shape is + // [5,1,1,5]{3,2,1,0} and the HLO shape is[5,1,1,5]{3,1,2,0}, but those cases + // are probably not worth handling here for now. + if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape()) && + LayoutUtil::Equal(operand_shape.layout(), hlo.shape().layout())) { return target_index; } // If implicit broadcast is needed, the source dimensions that are broadcast // have index 0. CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape())); - llvm_ir::IrArray::Index source_index; + llvm_ir::IrArray::Index source_index(target_index.GetType()); for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) { if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { source_index.push_back(target_index[i]); } else { CHECK_EQ(1, operand_shape.dimensions(i)); - source_index.push_back(ir_builder_->getInt64(0)); + source_index.push_back(target_index.GetConstantWithIndexType(0)); } } return source_index; @@ -1542,27 +1549,46 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); + // Use the same index type for all tensor accesses in the same kernel. + llvm::Type* index_type = index.GetType(); + llvm_ir::IrArray::Index slice_start_index(index_type, rank); for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(hlo->operand(1))(dim_index)); + + // Clamp the start index so that the sliced portion fits in the operand: + // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + start_index_value = + ir_builder_->CreateSExtOrTrunc(start_index_value, index_type); + llvm::Value* operand_dim_size = + index_typed_const(input_hlo->shape().dimensions(i)); + llvm::Value* output_dim_size = + index_typed_const(hlo->shape().dimensions(i)); + + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(operand_dim_size, output_dim_size), + EmitIntegralMax(index_typed_const(0), start_index_value, + /*is_signed=*/true), + /*is_signed=*/true); + start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; } - llvm_ir::IrArray::Index input_index(rank); + llvm_ir::IrArray::Index input_index(index_type, rank); for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: - // input_index = (start_index + offset_index) % dim_size - // Security note: this is the code that keeps the indices in-bounds. - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast( - slice_start_index[i], index[i]->getType()); - input_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateAdd(start_index, index[i]), dim_size); + // input_index = start_index + offset_index + input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]); } return operand_to_generator.at(input_hlo)(input_index); } @@ -1582,17 +1608,18 @@ StatusOr ElementalIrEmitter::EmitElementalGather( const llvm_ir::ElementGenerator& indices_generator = operand_to_generator.at(hlo->operand(1)); + llvm::Type* index_type = index.GetType(); // This is the index into `operand` that holds the element we want to // generate. This index "unsafe" as in the components in here may be // out of bounds. - IrArray::Index unsafe_operand_index; + IrArray::Index unsafe_operand_index(index_type); // First copy in the window indices to unsafe_operand_index. for (int64 i = 0, e = operand_shape.dimensions_size(), unsafe_operand_index_dim = 0; i < e; i++) { if (c_binary_search(dim_numbers.elided_window_dims(), i)) { - unsafe_operand_index.push_back(ir_builder_->getInt64(0)); + unsafe_operand_index.push_back(index.GetConstantWithIndexType(0)); } else { unsafe_operand_index.push_back( index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]); @@ -1600,7 +1627,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( } // This is the index of the index vector in the gather_indices tensor. - IrArray::Index gather_index_index; + IrArray::Index gather_index_index(index_type); { std::vector gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { @@ -1616,8 +1643,8 @@ StatusOr ElementalIrEmitter::EmitElementalGather( auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, int64 dim) { - llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc( - index_component, ir_builder_->getInt64Ty()); + llvm::Value* gather_dim_component_extended = + ir_builder_->CreateSExtOrTrunc(index_component, index_type); unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = ir_builder_->CreateAdd( unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)], @@ -1633,18 +1660,18 @@ StatusOr ElementalIrEmitter::EmitElementalGather( indices_shape.dimensions(dim_numbers.index_vector_dim()); for (int64 i = 0; i < index_vector_size; i++) { gather_index_index[dim_numbers.index_vector_dim()] = - ir_builder_->getInt64(i); + index.GetConstantWithIndexType(i); TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); add_to_unsafe_operand_index(gather_dim_component, i); } } - IrArray::Index safe_operand_index; + IrArray::Index safe_operand_index(index_type); for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { safe_operand_index.push_back(ir_builder_->CreateURem( unsafe_operand_index[i], - ir_builder_->getInt64(operand_shape.dimensions(i)))); + index.GetConstantWithIndexType(operand_shape.dimensions(i)))); } return operand_generator(safe_operand_index); @@ -1659,106 +1686,54 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* start_hlo = hlo->operand(2); // Calculate slice start/end indices. const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - llvm_ir::IrArray::Index slice_limit_index(rank); - // Slice starts at update[index - slice_start_index_adjusted], - // where adjusted value = slice_start_index when in bounds, and - // adjusted value = slice_start_index - input_dim, when wrapping. - llvm_ir::IrArray::Index slice_start_index_adjusted(rank); - + llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank); + llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); // Slice intersection gathers (ANDs) conditions on all ranks for which // 'input' is set to 'update' llvm::Value* slice_intersection = ir_builder_->getTrue(); for (int64 i = 0; i < rank; ++i) { - // Emit IR to read dynamic start indices from 'start_hlo'. - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + llvm::Type* index_type = index[0]->getType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(start_hlo)(dim_index)); + + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + start_index_value = + ir_builder_->CreateSExtOrTrunc(start_index_value, index_type); + llvm::Value* input_dim_size = + index_typed_const(input_hlo->shape().dimensions(i)); + llvm::Value* update_dim_size = + index_typed_const(update_hlo->shape().dimensions(i)); + + start_index_value = + EmitIntegralMin(ir_builder_->CreateSub(input_dim_size, update_dim_size), + EmitIntegralMax(index_typed_const(0), start_index_value, + /*is_signed=*/true), + /*is_signed=*/true); + start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( - start_index_value, index[i]->getType()); - - llvm::Value* input_dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - - // Generate code to handle wrapping semantics: - // slice_start_index[i] = slice_start_index[i] % input_dim_size; - // slice_limit_index[i] = slice_start_index[i] + update_dim_size. - // slice_start_index[i] is updated in place and it will now be in - // range. slice_limit_index[i] may be out of range, and it's being - // URem-ed below if so. - slice_start_index[i] = - ir_builder_->CreateURem(slice_start_index[i], input_dim_size); + slice_start_index[i] = start_index_value; slice_limit_index[i] = ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); - // Test if slice_limit_index[i] is in bounds - llvm::Value* in_bounds = - ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size); - llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - - // Handle true BB (slice_limit_index[i] <= input_dim_size). - SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] && - // index[i] < slice_limit_index[i] - llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd( + slice_intersection = ir_builder_->CreateAnd( slice_intersection, ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection_in"); - slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection_in_bounds, + "slice_intersection"); + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection_in"); - - // Handle false BB (slice_limit_index[i] > input_dim_size). - SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] || - // index[i] < slice_limit_index[i]%input_dim_size. - llvm::Value* index_wraps = ir_builder_->CreateICmpSLT( - index[i], - ir_builder_->CreateURem(slice_limit_index[i], input_dim_size)); - llvm::Value* slice_intersection_or = ir_builder_->CreateOr( - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), index_wraps, - "slice_intersection_out"); - llvm::Value* slice_intersection_out_of_bounds = ir_builder_->CreateAnd( - slice_intersection, slice_intersection_or, "slice_intersection_out"); - // Create value for slice_start_index_adjusted[i] when out of bounds. - // If within out-of-bounds if. - llvm_ir::LlvmIfData if_start_needs_adjustment = - llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_); - SetToFirstInsertPoint(if_start_needs_adjustment.true_block, ir_builder_); - llvm::Value* slice_start_index_adjusted_oob = - ir_builder_->CreateSub(slice_start_index[i], input_dim_size); - SetToFirstInsertPoint(if_start_needs_adjustment.after_block, ir_builder_); - llvm::PHINode* slice_start_index_adjusted_phi = - ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), 2); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index_adjusted_oob, if_start_needs_adjustment.true_block); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index[i], if_start_needs_adjustment.false_block); - // End of if within if. - - // After checking in/out of bounds. - SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_); - llvm::PHINode* phi_slice_intersection = - ir_builder_->CreatePHI(slice_intersection->getType(), 2); - phi_slice_intersection->addIncoming(slice_intersection_in_bounds, - if_in_bounds.true_block); - phi_slice_intersection->addIncoming(slice_intersection_out_of_bounds, - if_start_needs_adjustment.after_block); - slice_intersection = phi_slice_intersection; - - llvm::PHINode* phi_index = - ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2); - phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block); - phi_index->addIncoming(slice_start_index_adjusted_phi, - if_start_needs_adjustment.after_block); - slice_start_index_adjusted[i] = phi_index; + "slice_intersection"); } // Emit: @@ -1773,14 +1748,9 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Handle true BB (return data from 'update') SetToFirstInsertPoint(if_data.true_block, ir_builder_); // Compute update index for intersection case. - llvm_ir::IrArray::Index update_index(rank); + llvm_ir::IrArray::Index update_index(index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - // NOTE: Subtraction will be positive due to bounds checking above. - update_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), - update_dim_size); + update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); @@ -1846,7 +1816,8 @@ StatusOr ElementalIrEmitter::EmitElementalPad( SetToFirstInsertPoint(if_data.false_block, ir_builder_); TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, - operand_to_generator.at(hlo->operand(1))({})); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index.GetType()))); ir_builder_->CreateStore(padding_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, ir_builder_); @@ -1873,10 +1844,15 @@ StatusOr ElementalIrEmitter::EmitElementalDot( int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); - std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( - IrName(hlo, "inner"), ir_builder_->getInt64(0), - ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1), - ir_builder_); + llvm::Type* index_type = dot_result_index[0]->getType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + + std::unique_ptr inner_loop = + llvm_ir::ForLoop::EmitForLoop(IrName(hlo, "inner"), index_typed_const(0), + index_typed_const(contracted_dim_size), + index_typed_const(1), ir_builder_); SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_); PrimitiveType primitive_type = hlo->shape().element_type(); @@ -1895,7 +1871,7 @@ StatusOr ElementalIrEmitter::EmitElementalDot( // Given an output index [a,b,c,d,e] in the result, we compute: // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) - IrArray::Index lhs_index, rhs_index; + IrArray::Index lhs_index(index_type), rhs_index(index_type); for (int64 i = 0; i < lhs_dims - 1; i++) { lhs_index.push_back(dot_result_index[i]); @@ -1994,6 +1970,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kShiftLeft: diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index b43dc0c65d9b6e7c05e06010ba2ff2eb27392295..8980d4303353a132ada2b3c685b4f2856c33c6a1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -33,7 +33,7 @@ class ElementalIrEmitterExecutionTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text, config)); + ParseHloString(hlo_text, config)); EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt)); } }; diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 8119478ce934da06969024905e5e054e0b509b03..fd75847d0c0e737957401b8efc420d504a3c0706 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -82,7 +82,18 @@ StatusOr Executable::ExecuteOnStreamWrapper( StatusOr return_value = ExecuteOnStream(run_options, arguments, profile_ptr.get()); - TF_RETURN_IF_ERROR(return_value.status()); + if (!return_value.status().ok()) { + if (profile != nullptr) { + // Ensure the ThenStartTimer call has completed before we destroy timer. + // We already have a failure status to return, so just log this if it + // fails. + Status status = stream->BlockHostUntilDone(); + if (!status.ok()) { + LOG(ERROR) << "Failed to BlockHostUntilDone: " << status; + } + } + return return_value.status(); + } if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; @@ -116,6 +127,11 @@ StatusOr Executable::ExecuteOnStreamWrapper( if (profile->compute_time_ns() == 0) { profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); } + + const int64 executable_size_in_bytes = SizeInBytes(); + if (executable_size_in_bytes != 0) { + profile->set_executable_size_in_bytes(executable_size_in_bytes); + } } if (profile_ptr != nullptr) { @@ -129,19 +145,7 @@ StatusOr Executable::ExecuteOnStreamWrapper( return return_value; } -Status Executable::DumpSessionModule() { - TF_RET_CHECK(dumping()); - const string& directory_path = - module_config().debug_options().xla_dump_executions_to(); - VersionedComputationHandle versioned_handle = entry_computation_handle(); - // This filename does not include the version number because the computation - // is only ever executed at one version. - string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", versioned_handle.handle.handle(), - session_module_->entry().name().c_str(), ++execution_count_); - return Executable::DumpToDirectory(directory_path, filename, - *session_module_); -} +int64 Executable::SizeInBytes() { return -1; } Status Executable::DumpHloSnapshot() { TF_RET_CHECK(dumping_snapshot()); @@ -156,26 +160,6 @@ Status Executable::DumpHloSnapshot() { return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } -/* static */ Status Executable::DumpToDirectory( - const string& directory_path, string filename, - const SessionModule& session_module) { - tensorflow::Env* env = tensorflow::Env::Default(); - if (!env->IsDirectory(directory_path).ok()) { - // NB! CreateDir does not work reliably with multiple XLA threads -- two - // threads can race to observe the absence of the dump directory and - // simultaneously try to create it, causing the "losing" thread to get a - // "directory already exists" error. - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path)); - } - filename = SanitizeFileName(std::move(filename)); - string file_path = tensorflow::io::JoinPath(directory_path, filename); - string result; - TF_RET_CHECK( - tensorflow::SerializeToStringDeterministic(session_module, &result)); - return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, - result); -} - /* static */ Status Executable::DumpToDirectory( const string& directory_path, string filename, const HloSnapshot& hlo_session) { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 4f0466c544738fa1ec4602ee5104daee8d969c83..98eaeee30a693211ae564a5ef3c373f0364bef11 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -27,9 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -90,8 +88,7 @@ class Executable { // called explicitly for other (async, for example) variants after the stream // has completed. virtual Status PopulateExecutionProfile( - HloExecutionProfile* hlo_execution_profile, - se::StreamExecutor* executor) { + HloExecutionProfile* hlo_execution_profile, se::Stream* stream) { return Status::OK(); } @@ -132,25 +129,15 @@ class Executable { const HloModuleConfig& module_config() const { return hlo_module_->config(); } - // Returns the versioned computation handle of the computation computed by - // this executable. - const VersionedComputationHandle& entry_computation_handle() const { - return hlo_module_->entry_computation_handle(); - } - // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. - const Shape& host_result_shape() const { - return hlo_module_->config().host_entry_computation_layout().result_shape(); + const Shape& result_shape() const { + return hlo_module_->config().entry_computation_layout().result_shape(); } - // TODO(b/74197823): Delete the session module dumping helpers. - void set_session_module(std::unique_ptr session_module) { - session_module_ = std::move(session_module); - } - bool dumping() const { return session_module_ != nullptr; } - SessionModule* session_module() const { return session_module_.get(); } - Status DumpSessionModule(); + // Returns the size of the executable in bytes. Returns -1 by default if the + // method is not overridden to support this kind of query. + virtual int64 SizeInBytes(); // Dumping helpers. void set_hlo_snapshot(std::unique_ptr hlo_snapshot) { @@ -160,10 +147,6 @@ class Executable { HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); } Status DumpHloSnapshot(); - // Dump session_module to directory_path/filename. - static Status DumpToDirectory(const string& directory_path, string filename, - const SessionModule& session_module); - // Dump hlo snapshot to directory_path/filename. static Status DumpToDirectory(const string& directory_path, string filename, const HloSnapshot& hlo_session); @@ -179,9 +162,6 @@ class Executable { // around. const std::unique_ptr hlo_module_; - // SessionModule this was compiled from. Null if not dumping executions. - std::unique_ptr session_module_; - // HloSnapshot this was compiled from. Null if not dumping executions. std::unique_ptr hlo_snapshot_; diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md similarity index 100% rename from tensorflow/compiler/xla/tools/parser/README.md rename to tensorflow/compiler/xla/service/g3doc/hlo_parser.md diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 2d3e4b1fcdf6675955714cab262a8b2ca8ff4297..7cd2c9c136acac46e8e6c548c9e58b9bc8e6e0d2 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -300,7 +300,7 @@ static StatusOr PermuteGatherAndWindowDims( StatusOr GatherExpander::ExpandGather( HloInstruction* gather_instr) { - CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape())); + CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape())); HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); @@ -369,7 +369,7 @@ StatusOr GatherExpander::Run(HloModule* module) { return inst->opcode() == HloOpcode::kGather && // Avoid expanding gather ops that produce zero sized tensors, // instead punt these to ZeroSizedHloElimination. - !ShapeUtil::HasZeroElements(inst->shape()); + !ShapeUtil::IsZeroElementArray(inst->shape()); }; std::vector gather_instrs; diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 1c72ca066502eb549bf8638cdf0b7827b06f92d7..020ffcd106862cb2641a9f3bceb70acdd969a458 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -36,7 +36,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); Status status = GatherExpander{}.Run(module.get()).status(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); @@ -63,7 +63,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); ASSERT_TRUE(changed); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 5ee67ccb4ae147683c7b41941670c6fc413a0d09..85e28a0dfe38415974e435106a2d0b75863f2df5 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -43,7 +43,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { } Status GenericTransferManager::WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); @@ -52,12 +52,24 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( for (const se::DeviceMemoryBase& element : elements) { element_pointers.push_back(element.opaque()); } - return TransferBufferToDevice(executor, GetByteSizeRequirement(shape), - element_pointers.data(), region); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + stream, GetByteSizeRequirement(shape), element_pointers.data(), region)); + // Ensure the buffer is transferred before we destroy element_pointers. + return stream->BlockHostUntilDone(); +} + +void GenericTransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + std::function>)> done) { + Status status = stream->BlockHostUntilDone(); + if (!status.ok()) { + return done(status); + } + done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer)); } StatusOr> -GenericTransferManager::TransferLiteralFromDevice( +GenericTransferManager::TransferLiteralFromDeviceInternal( se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { VLOG(2) << "transferring literal from device ordinal " << executor->device_ordinal() << "; device buffer: " << device_buffer; @@ -74,9 +86,8 @@ GenericTransferManager::TransferLiteralFromDevice( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { - if (!ShapeUtil::IsTuple(subshape)) { - TF_RETURN_IF_ERROR(TransferBufferFromDevice( - executor, + if (ShapeUtil::IsArray(subshape)) { + TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ @@ -88,8 +99,8 @@ GenericTransferManager::TransferLiteralFromDevice( return std::move(literal); } -Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const LiteralSlice& literal, +Status GenericTransferManager::TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " @@ -103,9 +114,10 @@ Status GenericTransferManager::TransferLiteralToDevice( TF_RET_CHECK( ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape())); - TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + TF_RET_CHECK(stream->parent()->device_ordinal() == + device_buffer.device_ordinal()); - TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); + TF_RETURN_IF_ERROR(WriteTupleIndexTables(stream, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), @@ -121,16 +133,21 @@ Status GenericTransferManager::TransferLiteralToDevice( if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { source = subliteral.untyped_data(); + return TransferBufferToDevice( + stream, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory); } else { // Relayout data before transferring. relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); source = relayed_out_literal->untyped_data(); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + stream, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory)); + return stream->BlockHostUntilDone(); } - return TransferBufferToDevice( - executor, - /*size=*/GetByteSizeRequirement(device_subshape), source, - &device_memory); } return Status::OK(); }); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 3da9570ef7eebcdf618439f628fb4d5589993e4f..d216fe7d29e8f2e84ab4f558ee5caec32d07a70a 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -41,12 +41,13 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - StatusOr> TransferLiteralFromDevice( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; + void TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + std::function>)> done) override; - Status TransferLiteralToDevice(se::StreamExecutor* executor, - const LiteralSlice& literal, - const ShapedBuffer& device_buffer) override; + Status TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) override; Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; @@ -64,11 +65,14 @@ class GenericTransferManager : public TransferManager { const void* source) override; Status WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) override; private: + StatusOr> TransferLiteralFromDeviceInternal( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer); + // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 4012f87f2bf69d1ab056da5d6c750441c7404980..d90b0fb57d7acd24576e9e8e41316b19b6c44979 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,8 @@ # Description: # GPU-specific components in XLA service implementation. +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -23,6 +25,11 @@ filegroup( load("//tensorflow:tensorflow.bzl", "tf_cc_test") +xla_proto_library( + name = "backend_configs", + srcs = ["backend_configs.proto"], +) + cc_library( name = "gpu_constants", srcs = ["gpu_constants.cc"], @@ -133,6 +140,7 @@ cc_library( "ir_emitter_unnested.h", ], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", @@ -156,6 +164,7 @@ cc_library( "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", @@ -228,6 +237,20 @@ cc_library( ], ) +cc_library( + name = "hlo_execution_profiler", + srcs = ["hlo_execution_profiler.cc"], + hdrs = ["hlo_execution_profiler.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_execution_profile", + "//tensorflow/compiler/xla/service:pool", + "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + cc_library( name = "gpu_executable", srcs = [ @@ -266,8 +289,10 @@ cc_library( "while_thunk.h", ], deps = [ + ":backend_configs", ":buffer_allocations", ":cudnn_convolution_runner", + ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", ":partition_assignment", @@ -322,6 +347,7 @@ cc_library( srcs = ["cudnn_convolution_algorithm_picker.cc"], hdrs = ["cudnn_convolution_algorithm_picker.h"], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", @@ -338,6 +364,7 @@ cc_library( srcs = ["cudnn_convolution_runner.cc"], hdrs = ["cudnn_convolution_runner.h"], deps = [ + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -401,10 +428,42 @@ tf_cc_test( srcs = ["instruction_fusion_test.cc"], deps = [ ":instruction_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:multi_output_fusion", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "multi_output_fusion_test", + srcs = ["multi_output_fusion_test.cc"], + deps = [ + ":multi_output_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -446,9 +505,9 @@ tf_cc_test( ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -508,6 +567,7 @@ cc_library( ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", + ":multi_output_fusion", ":pad_insertion", ":partition_assignment", ":stream_assignment", @@ -525,7 +585,6 @@ cc_library( "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", - "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -542,6 +601,8 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_constant_sinking", + "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", @@ -553,7 +614,6 @@ cc_library( "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", - "@llvm//:support", ], alwayslink = True, # Contains compiler registration ) @@ -587,14 +647,18 @@ cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ + ":gpu_options", ":ir_emission_utils", + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -691,6 +755,28 @@ cc_library( ], ) +cc_library( + name = "gpu_options", + srcs = ["gpu_options.cc"], + hdrs = ["gpu_options.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "stream_executor_util", + srcs = ["stream_executor_util.cc"], + hdrs = ["stream_executor_util.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + tf_cc_test( name = "gpu_hlo_support_checker_test", srcs = ["gpu_hlo_support_checker_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto new file mode 100644 index 0000000000000000000000000000000000000000..640c6392b8b820c708b853c2a3cea4d4116e85a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package xla.gpu; + +// Backend configs for XLA:GPU. +// +// These are metadata that the GPU backend attaches to HloInstrucitons and later +// uses during e.g. codegen. +// +// Remember that proto3 doesn't give clients a way to tell the difference +// between a field not being present and a field having the default value. +// Choose your defaults carefully. +// +// No guarantee is made about the stability of these protos. +// +// See HloInstruction::backend_config() for more info. + +// Backend config for a convolution that runs through cudnn. +message CudnnConvBackendConfig { + // Opaque algorithm number of cudnn algorithm chosen for this conv. + int64 algorithm = 1; + + // Whether we may use tensor cores when running this conv. Even if this is + // true, cudnn may choose not to use tensor cores, e.g. because the GPU or + // selected algorithm doesn't support it. + bool tensor_ops_enabled = 2; +} diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 77a48965e031349b045a956fd3f28c58607328e5..5e4fe1dd398dedd999e18d7ef6dfb5a4fd3bf4cb 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -43,7 +44,9 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable, } Status ConditionalThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); // Copy the predicate value from device. bool predicate; se::DeviceMemoryBase predicate_address = @@ -59,10 +62,15 @@ Status ConditionalThunk::ExecuteOnStream( // Execute the true or the false computation depending on the value of the // predicate. if (predicate) { - TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream)); + profiler->StartHloComputation(); + TF_RETURN_IF_ERROR( + true_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->true_computation()); } else { + profiler->StartHloComputation(); TF_RETURN_IF_ERROR( - false_thunk_.ExecuteOnStream(buffer_allocations, stream)); + false_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->false_computation()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index ee03865d174469285a9e98b8a30fea90d997df37..aef24342c9fe182eb54b1c2beff840a76e7b8115 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -50,7 +51,8 @@ class ConditionalThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice predicate_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index f0881124128c9b043392ffc4fa3aee2cd5b754c7..7833a4077e6c6ee4960665f37fb01a35530fd302 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -55,7 +56,8 @@ ConvolutionThunk::ConvolutionThunk( tensor_ops_enabled_(tensor_ops_enabled) {} Status ConvolutionThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase input_data = buffer_allocations.GetDeviceAddress(input_buffer_); se::DeviceMemoryBase filter_data = @@ -68,6 +70,7 @@ Status ConvolutionThunk::ExecuteOnStream( se::dnn::AlgorithmConfig algorithm_config( se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); TF_RETURN_IF_ERROR(RunCudnnConvolution( convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 6d845025b1aef2b0a5f147401b6db0598ba94d6d..d76ca6698dcf462c3c4961ce6a9784822af3a81f 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -66,7 +67,8 @@ class ConvolutionThunk : public Thunk { // Does the convolution for the thunk on "stream". Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: class ScratchAllocator; diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index ee38c0318a878c7bcdc02afdcd146bfb4498d9a2..92e03f94c11f68082f0a8caa64f82e8533557194 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -30,9 +31,11 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( mem_size_(mem_size) {} Status HostToDeviceCopyThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemcpy(&destination_data, source_address_, mem_size_); return Status::OK(); } @@ -47,11 +50,13 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( mem_size_(mem_size) {} Status DeviceToDeviceCopyThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); se::DeviceMemoryBase source_data = buffer_allocations.GetDeviceAddress(source_buffer_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemcpy(&destination_data, source_data, mem_size_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 8b128386f61636de9ac41e856a2b00c578e05735..91564b520acae1839e0a466cf580db00bdf57e46 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -40,7 +41,8 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const void* source_address_; @@ -63,7 +65,8 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const BufferAllocation::Slice source_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc index db6924c742e4a949a3e939b6d6659e92c2d1e312..c77e3c81c9d38af7857ad1389d20221514bf38f1 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -126,12 +126,17 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { HloInstruction* variance_plus_epsilon = computation_->AddInstruction(HloInstruction::CreateBinary( inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev, - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-2))))); + computation_->AddInstruction(HloInstruction::CreateBroadcast( + inverse_stddev->shape(), + computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-2))), + {})))); HloInstruction* variance = computation_->AddInstruction(HloInstruction::CreateBinary( variance_plus_epsilon->shape(), HloOpcode::kSubtract, - variance_plus_epsilon, epsilon)); + variance_plus_epsilon, + computation_->AddInstruction(HloInstruction::CreateBroadcast( + variance_plus_epsilon->shape(), epsilon, {})))); // Repackage the results. std::unique_ptr new_tuple = HloInstruction::CreateTuple({ @@ -175,12 +180,17 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { HloInstruction* var_plus_epsilon = computation_->AddInstruction(HloInstruction::CreateBinary( batch_norm->operand(3)->shape(), HloOpcode::kAdd, - batch_norm->mutable_operand(3), epsilon)); + batch_norm->mutable_operand(3), + computation_->AddInstruction(HloInstruction::CreateBroadcast( + batch_norm->operand(3)->shape(), epsilon, {})))); HloInstruction* inverse_stddev = computation_->AddInstruction(HloInstruction::CreateBinary( var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon, - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-.5))))); + computation_->AddInstruction(HloInstruction::CreateBroadcast( + var_plus_epsilon->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(-.5))), + {})))); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 68099fd63847ef9993f9bc7ac0e28b2939631b35..7b172812c36bb141787ef3a9285d6f7ce13e343b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -99,13 +100,15 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( } Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; std::tie(operand_desc, scale_offset_desc) = MakeDescriptors(hlo_instruction()->shape(), feature_index_); se::DeviceMemory output(buffer_allocations.GetDeviceAddress(output_)); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenBatchNormalizationForward( se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), @@ -123,6 +126,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( /*is_training=*/false, // /*var_to_inv_var=*/nullptr, // /*inv_var_to_var=*/nullptr); + if (!stream->ok()) { return InternalError("BatchNormalizationForward call failed."); } @@ -158,7 +162,8 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( } Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; // The BatchNormTraining HLO outputs a tuple of three elements: output data, @@ -175,6 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( buffer_allocations.GetDeviceAddress(output_inv_stddev_)); se::DeviceMemory null_device_ptr(nullptr); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenBatchNormalizationForward( se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), @@ -240,7 +246,8 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( } Status CudnnBatchNormBackwardThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; @@ -257,6 +264,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( se::DeviceMemory output_grad_offset( buffer_allocations.GetDeviceAddress(output_grad_offset_)); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenBatchNormalizationBackward( se::DeviceMemory( buffer_allocations.GetDeviceAddress(grad_output_)), diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index 874f85a863092ee05ae5df1f92d732318c5a0554..d2143b3952984722d136757255aa0aa60e9cab7e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -60,7 +61,8 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk { const CudnnBatchNormForwardInferenceThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice operand_; @@ -90,7 +92,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { const CudnnBatchNormForwardTrainingThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice operand_; @@ -123,7 +126,8 @@ class CudnnBatchNormBackwardThunk : public Thunk { delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice operand_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 6a46bdb9b438f81dc564b9033f5d302f90b6a997..3dc98c4c93ea2b9b68dd3ee27794a39847f8756c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -316,21 +317,20 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( Shape new_call_shape = ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), ShapeUtil::MakeShape(U8, {scratch_bytes})}); - HloInstruction* algorithm_hlo = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); - HloInstruction* tensor_ops_enabled_hlo = - computation->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(tensor_ops_enabled))); + + CudnnConvBackendConfig backend_config; + backend_config.set_algorithm(algorithm); + backend_config.set_tensor_ops_enabled(tensor_ops_enabled); HloInstruction* new_call = computation->AddInstruction(HloInstruction::CreateCustomCall( new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo, - tensor_ops_enabled_hlo}, + {instr->mutable_operand(0), instr->mutable_operand(1)}, instr->custom_call_target())); new_call->set_window(instr->window()); new_call->set_convolution_dimension_numbers( instr->convolution_dimension_numbers()); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely // (conv_result, u8[0]). diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index e0c73aa73acb7f3313eb54fb07390cb76590433e..f9dccd287d955502858f6c24ccd4de80256fc148 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -42,8 +42,8 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { } // CuDNN does not accept zero-element arguments - if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) || - ShapeUtil::HasZeroElements(conv->operand(1)->shape())) { + if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 10b4c3de89989c52cfea5273c3d5b0beef76abd2..0645fbb3ad39f1f1649caf45a6068b5a196c30b9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -113,8 +115,17 @@ Status RunCudnnConvolution( // cuDNN's convolution APIs support the BDYX layout for activations/output and // the OIYX layout for weights. + DataLayout input_dl; + FilterLayout filter_dl; + DataLayout output_dl; + + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), + XlaConvLayoutsToStreamExecutorLayouts( + dnums, input_shape.layout(), filter_shape.layout(), + output_shape.layout())); + BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(DataLayout::kBatchDepthYX) + input_descriptor.set_layout(input_dl) .set_feature_map_count( input_shape.dimensions(dnums.input_feature_dimension())) .set_count(input_shape.dimensions(dnums.input_batch_dimension())); @@ -126,7 +137,7 @@ Status RunCudnnConvolution( } FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + filter_descriptor.set_layout(filter_dl) .set_input_feature_map_count( filter_shape.dimensions(dnums.kernel_input_feature_dimension())) .set_output_feature_map_count( @@ -149,7 +160,7 @@ Status RunCudnnConvolution( } BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(DataLayout::kBatchDepthYX) + output_descriptor.set_layout(output_dl) .set_feature_map_count( output_shape.dimensions(dnums.output_feature_dimension())) .set_count(output_shape.dimensions(dnums.output_batch_dimension())); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index e5e2a0478a0659986ddec8d6785827b14b9efb56..27d2c3e491bfc2108cbd168d1a5e1575c2eed11f 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -53,11 +53,17 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrAppend; +namespace { // Returns whether operand is a floating-point literal with the given value. bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { - return operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAllFloat(value); + if (operand->opcode() == HloOpcode::kConstant && + operand->literal().IsAllFloat(value)) { + return true; + } + return operand->opcode() == HloOpcode::kBroadcast && + IsFPLiteralWithValue(operand->operand(0), value); } +} // namespace GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, @@ -370,11 +376,17 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( "reduce_window_accum_ptr", ir_builder_); { TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))({})); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index.GetType()))); ir_builder_->CreateStore(init_value, accum_ptr); } - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_); + llvm::Type* index_type = index.GetType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return index.GetConstantWithIndexType(c); + }; + + llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); @@ -385,14 +397,14 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_); - IrArray::Index input_index(index.size()); + IrArray::Index input_index(index_type, index.size()); llvm::Value* in_bounds = ir_builder_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = ir_builder_->CreateNSWMul( - index[i], ir_builder_->getInt64(window.dimensions(i).stride())); + index[i], index_typed_const(window.dimensions(i).stride())); input_index[i] = ir_builder_->CreateNSWSub( ir_builder_->CreateNSWAdd(stridden_index, window_index[i]), - ir_builder_->getInt64(window.dimensions(i).padding_low())); + index_typed_const(window.dimensions(i).padding_low())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This @@ -403,7 +415,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( in_bounds, ir_builder_->CreateICmpULT( input_index[i], - ir_builder_->getInt64(operand->shape().dimensions(i)))); + index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = @@ -429,11 +441,13 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( llvm::Value* accum_ptr = ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( hlo->shape().element_type(), module_)); + llvm::Type* index_type = output_index.GetType(); TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))({})); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index_type))); ir_builder()->CreateStore(init_value, accum_ptr); - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_); + llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type); IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions( operand->shape(), hlo->dimensions(), "reduction_dim"); if (!ShapeUtil::IsScalar(hlo->shape())) { diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index e14ee6918bf148861ecccac99355fccf7ae93103..0cdddf8bcfd4e849b311bf810eda471d79dbf106 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -107,7 +108,8 @@ FftThunk::FftThunk(FftType fft_type, output_shape_(output_shape) {} Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); VLOG(3) << "Output shape: " @@ -116,6 +118,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(), buffer_allocations.memory_allocator()); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (fft_plan_ == nullptr) { const int64 fft_rank = fft_length_.size(); CHECK_LE(fft_rank, 3); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index b0a22564f3a09bb67a3c01723f6e37c604656d45..8c53be5077b0c5a88d303c729457139c6cb800f1 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -72,7 +73,8 @@ class FftThunk : public Thunk { // Does the FFT for the thunk on "stream". Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const se::fft::Type fft_type_; diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index b36539e0cb8d0a2f4758dd90acbdd8fc7181b8ca..4fdc55909a1afbac96aaa9bc931ed8ac6c0ae1df 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -37,11 +38,15 @@ Status ForThunk::Initialize(const GpuExecutable& executable, } Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); for (int64 i = 0; i < loop_limit_; ++i) { + profiler->StartHloComputation(); // Invoke loop body thunk sequence. - TF_RETURN_IF_ERROR( - body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, + stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index 41ddfe0ceb1d0516c1c64feca53212a925632209..c2d39071b292c6704e9b5857a68bd8b3f3b9a914 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -39,7 +40,8 @@ class ForThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const int64 loop_limit_; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 2217776c7d5a5f92c520d56222988f80401be9e4..b22bb1d39ba177ef42673c7a3755694b43c15d14 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace gpu { @@ -40,7 +40,7 @@ class FusionMergerTest : public HloTestBase {}; // Tuple // TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule MergeSharedFusionInstruction comp.3 { @@ -104,7 +104,7 @@ ENTRY MergeSharedFusionInstruction.Computation0 { // // Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio. TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule FlopsToBytesRatioThresholdExceeded comp.2 { @@ -162,7 +162,7 @@ ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 { // is merged into Fusion0 and Fusion1) would exceed the bytes transferred // threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule BytesTransferredThresholdExeceeded comp.2 { @@ -210,7 +210,7 @@ ENTRY BytesTransferredThresholdExeceeded.Computation2 { // Fusion2 is reduced for this test which makes the merge operation into its // operand below the bytes transferred threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule BytesTransferredThresholdNotExeceeded comp.2 { @@ -253,7 +253,7 @@ ENTRY BytesTransferredThresholdNotExeceeded.Computation2 { // Check that we're willing to merge f1_computation into f2_computation, even // though f2 is an input fusion node. TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule m f1_computation { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 79fca43d022816645b8a07b9e806fe9cc3745e7c..dbc7754e251eb8075ab97dd2f36bbc400530fcf5 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -252,7 +252,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, alpha_(alpha) {} Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { VLOG(2) << "Executing a GemmThunk"; se::DeviceMemoryBase lhs_data = @@ -352,6 +353,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, alpha_, stream); }; + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); bool launch_ok; if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) { launch_ok = launch( diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 7a4830d64e7caef5a1170cbdbf8ab373fdaf16e2..939c7f85e35b4fcb943a25aa6346d72798432920 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -48,7 +49,8 @@ class GemmThunk : public Thunk { // Does the gemm operation for the thunk on "stream", which must be non-null. Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; // Returns true if we'll perform autotuning if run on the given stream. If // so, we want the GPU to be quiescent during autotuning, so as not to diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index df494a1aa961c3d3da0403015d38e29e67c19dde..decfc40dafafe875fa02bab6695f5c54e522f267 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" @@ -52,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" @@ -73,6 +73,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -160,11 +162,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/false); - - // Rewrite gather ops into smaller ones. - pass.AddPass(); + /*rewrite_grad_op=*/true); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. @@ -174,6 +172,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -200,18 +199,28 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("layout_assignment"); + pipeline.AddPass( + hlo_module->mutable_entry_computation_layout(), stream_exec); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + pipeline.AddPass>( + /*is_layout_sensitive=*/true, + /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { + return true; + }); // Choose the fastest algorithm for each conv. // - // In theory doing this here is way too early: It needs to happen after - // layout assignment, because the layout of the inputs/outputs affects the - // speed of the conv. But currently we only allow only one input/output - // layout when calling cudnn, so there's no ambiguity. - // - // We pick the algorithm at this early stage so we can generate better HLO. - // After CudnnConvolutionRewriter, our convolutions are CustomCalls which - // return a tuple (conv_result, scratch_memory), and the each conv uses 0 - // bytes of scratch: + // We pick the algorithm before fusion so we can generate better HLO. After + // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a + // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of + // scratch: // // customcall = (f32[...], f32[0]) // return gte(customcall, 0) @@ -227,35 +236,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The new tuple and gte instructions then be simplified away, because // nobody is expected to use the scratch value. // - // However, if we were to run CudnnConvolutionAlgorithmPicker after layout - // assignment, fusion would already have run, and the gte(customcall, 0) - // would probably already be into a fusion node. We can't simplify across - // HloComputation boundaries, so in this case we wouldn't be able to - // simplify away the new_tuple bits. - // - // We'll need to revisit this if we ever allow multiple layouts for the - // inputs/outputs of a cudnn convolution. + // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion + // the gte(customcall, 0) would probably already be into a fusion node. We + // can't simplify across HloComputation boundaries, so in this case we + // wouldn't be able to simplify away the new_tuple bits. pipeline.AddPass(stream_exec, device_allocator); // Clean up new_tuple described above. pipeline.AddPass(); - pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } - - { - HloPassPipeline pipeline("layout_assignment"); - pipeline.AddPass( - hlo_module->device_entry_computation_layout()); - - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, - /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { - return true; - }); pipeline.AddPass(/*is_layout_sensitive=*/true); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -266,6 +255,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); + fusion.AddPass(); + fusion.AddPass(/*is_layout_sensitive=*/true, + /*only_fusion_computations=*/true); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); @@ -282,6 +274,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); } } + + { + // Do an aggressive LICM pass over while loops. In particular, this hoists + // constants that were sunk by WhileLoopConstantSinking. Leaving them in + // the while loop may result in unnecessary copies. + HloPassPipeline pipeline("while-loop-licm"); + pipeline.AddPass(true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 9db85bc788bde46c890a46ce9b0902ddce3f5675..fbc1303085b579e898d2f503a341754109768567 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -52,61 +52,20 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { HloDataflowAnalysis::Run(*module)); // Make sure all operands of a library call are in memory instead of constants - // in IR. - for (HloInstruction* hlo : - module->entry_computation()->MakeInstructionPostOrder()) { - // Inserts a copy of hlo->operand(n) if it's a constant. - auto copy_operand_if_constant = [&](int64 n) -> Status { - HloInstruction* operand = hlo->mutable_operand(n); - TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); - const auto& values = dataflow->GetValueSet(operand).values(); - if (std::any_of(values.begin(), values.end(), [](const HloValue* value) { - return value->defining_instruction()->opcode() == - HloOpcode::kConstant; - })) { - TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(n, copy)); - changed = true; - } - return Status::OK(); - }; - - if (IsCustomCallToDnnBatchNorm(*hlo)) { - // The epsilon and feature_index operands to a CUDNN batchnorm op don't - // need to be materialized in memory -- in fact, they must be constants. - // These are the last two operands of all three batchnorm ops. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (IsCustomCallToDnnConvolution(*hlo)) { - // The last two arguments to a CUDNN convolution are two HLO constants for - // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo)) { - // For all other library calls, materialize all the operands into memory. - for (int64 i = 0; i < hlo->operand_count(); ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } - } - - // Init values of while and conditional nodes cannot be constants. Insert - // copies for any constants found at the operands of these nodes. + // in IR. Also, init values of while and conditional nodes cannot be + // constants. Insert copies for any constants found at the operands of these + // nodes. tensorflow::gtl::FlatSet inserted_copies; for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kWhile && - instruction->opcode() != HloOpcode::kConditional) { - continue; - } - for (auto operand : instruction->operands()) { + for (HloInstruction* hlo : computation->instructions()) { + // Inserts a copy of hlo->operand(n) if it's a constant. + auto copy_operand_if_constant = [&](int64 n) -> Status { + HloInstruction* operand = hlo->mutable_operand(n); // Skip the operands that have already been replaced with a copy in a // previous iteration (which is possible when a constant is used as an // operand in multiple places). if (ContainsKey(inserted_copies, operand)) { - continue; + return Status::OK(); } for (auto& pair : dataflow->GetInstructionValueSet(operand)) { const HloValueSet& value_set = pair.second; @@ -122,6 +81,47 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { } } } + return Status::OK(); + }; + + if (IsCustomCallToDnnBatchNorm(*hlo)) { + // The epsilon and feature_index operands to a CUDNN batchnorm op don't + // need to be materialized in memory -- in fact, they must be constants. + // These are the last two operands of all three batchnorm ops. + for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } + } else if (ImplementedAsLibraryCall(*hlo) || + hlo->opcode() == HloOpcode::kCrossReplicaSum || + hlo->opcode() == HloOpcode::kWhile || + hlo->opcode() == HloOpcode::kConditional) { + // For all other library calls, cross-replica-sum, while and conditional + // ops materialize all the operands into memory. (Cross-replica-sum + // gets its constant args materialized even if it's not implemented as a + // libcall to simplify the implementation. It's slower, but we can + // constant fold away constant args *anyway*, so we just need to make it + // work.) + for (int64 i = 0; i < hlo->operand_count(); ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } + } + } + } + + if (changed) { + // Check the assumption that the epsilon and feature_index constants of the + // CUDNN batchnorm op are not shared with other ops where we would replace + // them with a copy. These custom op calls are generated with the + // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. + for (HloComputation* computation : module->computations()) { + for (HloInstruction* hlo : computation->instructions()) { + if (!IsCustomCallToDnnBatchNorm(*hlo)) { + continue; + } + for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); + ++i) { + CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); + } } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 25d8f720ea4791a4c94efcad6909cd0c113fbe70..0cad2958c72797b4d70f00676928b2b21d7a3e8d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -41,77 +41,6 @@ namespace { using tensorflow::tracing::ScopedAnnotation; -// A helper class for profiling HLO in the course of GPU program execution. -// All of the profiling is guarded internally, to avoid the caller needing to -// have lots of conditionals sprinkled around. -class HloExecutionProfiler { - public: - // If profiling is enabled, start an execution timer running. - explicit HloExecutionProfiler( - bool do_profile, HloExecutionProfile* profile, se::Stream* stream, - const std::vector::SmartPtr>& sub_streams, - const HloComputation* computation) - : do_profile_(do_profile), - profile_(profile), - stream_(stream), - sub_streams_(sub_streams), - computation_(computation) { - if (do_profile_) { - clock_rate_ghz_ = - stream->parent()->GetDeviceDescription().clock_rate_ghz(); - execution_timer_.reset(new se::Timer(stream->parent())); - per_op_timer_.reset(new se::Timer(stream->parent())); - stream->InitTimer(execution_timer_.get()) - .ThenStartTimer(execution_timer_.get()); - stream->InitTimer(per_op_timer_.get()); - } - } - - // If profiling is enabled, sets the total cycle count on the profile from the - // execution timer. - void FinishExecution() { - CHECK(!finished_execution_) << "Call FinishExecution only once!"; - finished_execution_ = true; - if (do_profile_) { - stream_->ThenWaitFor(&sub_streams_); - stream_->ThenStopTimer(execution_timer_.get()); - stream_->BlockHostUntilDone().IgnoreError(); - profile_->set_total_cycles_executed( - *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_); - } - } - - // If profiling is enabled, starts the per-operation timer. - void StartOperation() { - if (do_profile_) { - stream_->ThenStartTimer(per_op_timer_.get()); - } - } - - // If profiling is enabled, stops the per-operation timer and records the time - // that the hlo_instruction took to execute in the profile. - void FinishOperation(const HloInstruction* hlo_instruction) { - if (do_profile_) { - stream_->ThenWaitFor(&sub_streams_); - stream_->ThenStopTimer(per_op_timer_.get()); - stream_->BlockHostUntilDone().IgnoreError(); - profile_->SetCyclesTakenBy( - hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); - } - } - - private: - const bool do_profile_; - double clock_rate_ghz_; - HloExecutionProfile* profile_; - se::Stream* stream_; - const std::vector::SmartPtr>& sub_streams_; - const HloComputation* computation_; - std::unique_ptr execution_timer_; - std::unique_ptr per_op_timer_; - bool finished_execution_ = false; -}; - } // namespace // Implementation note: HLO profiling is always enabled for GPU executables, @@ -207,18 +136,17 @@ Status GpuExecutable::ExecuteThunks( TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); } - profiler.StartOperation(); VLOG(2) << "Executing the thunk for " << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); + TF_RETURN_IF_ERROR( + thunk->ExecuteOnStream(buffer_allocations, stream, &profiler)); if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); } - profiler.FinishOperation(thunk->hlo_instruction()); } main_stream->ThenWaitFor(&sub_streams); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 89f1e625884568bf7370b3801d851ef4846c2a98..09ef62c87f8875a5803497e8eb628769f883202a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,31 +18,72 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace gpu { -// cuDNN convolutions are called with specific layouts on the input, output, -// and filter: -// -// input: DataLayout::kBatchDepthYX -// output: DataLayout::kBatchDepthYX -// filter: FilterLayout::kOutputInputYX -// -// The order dimensions in the constant name is major-to-minor (eg, the -// most-major dimension of the input is batch, most-minor is X). The -// specific dimension numbers these named dimensions correspond to is -// determined by the ConvolutionDimensionNumbers argument. Y is spatial -// dimension 0, and X is spatial dimension 1. -// -// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. -static Status AddBackendConstraintsToDnnConvCustomCall( +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::FilterLayout; + +static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { + int major, minor; + CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, + &minor)); + return major >= 7; +} + +// Returns (input, filter, output) layouts. +static std::tuple +HeuristicLayoutAssignment(const HloInstruction* instr, + stream_executor::StreamExecutor* stream_executor) { + // DataLayout and FilterLayout uses weird enum names. Translations: + // N <=> Batch or Output + // C <=> Depth or Input + // H <=> Y + // W <=> X + // + // Therefore kOutputInputYX and kBatchDepthYX mean NCHW. + + // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x + // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version + // changes, as well as the hardware updates. + if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && + IsVoltaOrLater(*stream_executor))) { + return std::make_tuple(DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); + // For BackwardInput that has stride, full NHWC layouts run significantly + // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC). + // + // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW, + // NHWC). + if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && + window_util::HasStride(instr->window())) { + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth); +} + +// Adds layout constraints on the cudnn custom-call instruction. The layout +// constraints are represented in terms of minor_to_major fields of both +// operands and the output shape. Depending on the underlying algorithm, one of +// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen. +Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( HloInstruction* instr, LayoutConstraints* constraints) { CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); Shape input_shape; @@ -66,39 +107,25 @@ static Status AddBackendConstraintsToDnnConvCustomCall( << instr->custom_call_target(); } - // Construct minor-to-major dimension orders for operands and result. - // cuDNN's convolution APIs support the BDYX layout for activations/output - // and the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN - // calls after we switch to cuDNN v5. - const ConvolutionDimensionNumbers& dimension_numbers = - instr->convolution_dimension_numbers(); - std::vector input_layout; - for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0; - --i) { - input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); - } - input_layout.push_back(dimension_numbers.input_feature_dimension()); - input_layout.push_back(dimension_numbers.input_batch_dimension()); - *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); - - std::vector filter_layout; - for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0; - --i) { - filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); - } - filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension()); - filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); - - std::vector output_layout; - for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0; - --i) { - output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); + { + DataLayout input; + FilterLayout filter; + DataLayout output; + if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); + } else { + input = DataLayout::kBatchDepthYX; + filter = FilterLayout::kOutputInputYX; + output = DataLayout::kBatchDepthYX; + } + + TF_ASSIGN_OR_RETURN( + std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), + *output_shape.mutable_layout()), + StreamExecutorConvLayoutsToXlaLayouts( + instr->convolution_dimension_numbers(), input, filter, output)); } - output_layout.push_back(dimension_numbers.output_feature_dimension()); - output_layout.push_back(dimension_numbers.output_batch_dimension()); - *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); // The custom call returns a tuple of (actual_result, scratch_buffer); // call_result_buf is the logical buffer for actual_result, the thing that @@ -132,7 +159,13 @@ static Status AddBackendConstraintsToDnnConvCustomCall( Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { - for (auto* instruction : constraints->computation()->instructions()) { + // Add convolution constraints in reverse postorder that the earliest + // convolution layout propagates first. This reduces the likelihood of fusion + // nodes with copies. + auto post_order = constraints->computation()->MakeInstructionPostOrder(); + for (auto iterator = post_order.rbegin(); iterator != post_order.rend(); + ++iterator) { + HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { TF_RETURN_IF_ERROR( AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 51aae79c3d8d0000007f9d2926d245de838d3aca..ce24af1cf8856920ccf438b5bbd2ef28cfa8ba6f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { @@ -27,9 +28,10 @@ namespace gpu { // layout constraints for operands and results of library calls. class GpuLayoutAssignment : public LayoutAssignment { public: - explicit GpuLayoutAssignment( - const ComputationLayout& entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout, + se::StreamExecutor* stream_executor) + : LayoutAssignment(entry_computation_layout), + stream_executor_(stream_executor) {} ~GpuLayoutAssignment() override {} protected: @@ -42,6 +44,12 @@ class GpuLayoutAssignment : public LayoutAssignment { LayoutConstraints* constraints) override; bool CustomCallRequiresMajorFirstLayout( const HloInstruction* instruction) override; + + private: + Status AddBackendConstraintsToDnnConvCustomCall( + HloInstruction* instr, LayoutConstraints* constraints); + + se::StreamExecutor* stream_executor_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 7c801955943021def4ddc0accd9f318b7916ce93..e48165c1426ea04839c245bc20b851a0f1710246 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -69,7 +69,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape_with_layout); - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { @@ -156,7 +157,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape); } - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -225,7 +227,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { {result_shape, offset_scale_shape, offset_scale_shape})); } - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -305,7 +308,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { {result_shape, scale_shape, scale_shape})); } - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc new file mode 100644 index 0000000000000000000000000000000000000000..35b4b4e20b633792de4251a4b0e89f4b579053ce --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.cc @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { +namespace gpu { + +bool ConvUseLayoutHeuristic(const HloModuleConfig& config) { + return !config.debug_options().xla_backend_extra_options().count( + "xla_gpu_experimental_conv_disable_layout_heuristic"); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h new file mode 100644 index 0000000000000000000000000000000000000000..498d4a94955cb2c50e0b165f28ded44ac1c0bfff --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +// Helper functions for querying options that are specific to the GPU backend. + +namespace xla { +namespace gpu { + +// Returns true if we should use heuristics to assign convolution layouts, as +// opposed to always assigning NCHW. +bool ConvUseLayoutHeuristic(const HloModuleConfig& config); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 7bb8df6581b49b1bf8c84a972f715e8dc119d8de..5343497c03c13a2589363da0fa33e18520220826 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -55,33 +55,28 @@ Status GpuTransferManager::TransferLiteralToInfeed( return TransferBufferToInfeed(executor, size, literal.untyped_data()); } - if (ShapeUtil::IsNestedTuple(shape)) { - return Unimplemented( - "Infeed with a nested tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); - } - // For a tuple, we transfer each of its elements to the device and // enqueue the resulting destination device addresses with the // infeed manager. std::vector buffers; - buffers.reserve(ShapeUtil::TupleElementCount(shape)); auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { for (gpu::InfeedBuffer* b : buffers) { b->Done(); } }); - for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& tuple_element_shape = - ShapeUtil::GetTupleElementShape(shape, i); - int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); - TF_ASSIGN_OR_RETURN( - gpu::InfeedBuffer * buffer, - TransferBufferToInfeedInternal(executor, tuple_element_size, - literal.untyped_data({i}))); - buffers.push_back(buffer); - } + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + shape, [&](const Shape& literal_subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(literal_subshape)) { + int64 tuple_element_size = GetByteSizeRequirement(literal_subshape); + TF_ASSIGN_OR_RETURN( + gpu::InfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, tuple_element_size, + literal.untyped_data(index))); + buffers.push_back(buffer); + } + return Status::OK(); + })); cleanup.release(); return EnqueueBuffersToInfeed(executor, buffers); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e96beb575300614a04c856adbb6d742b34d11df --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { +namespace gpu { +namespace { +void InitAndStartTimer(std::stack>* timers, + se::Stream* stream) { + timers->push(MakeUnique(stream->parent())); + stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get()); +} + +uint64 GetCyclesTaken( + std::stack>* timers, + const std::vector::SmartPtr>& sub_streams, + se::Stream* stream, double clock_rate_ghz) { + CHECK_GT(timers->size(), 0); + stream->ThenWaitFor(&sub_streams); + stream->ThenStopTimer(timers->top().get()); + stream->BlockHostUntilDone().IgnoreError(); + double nanoseconds = timers->top()->Nanoseconds(); + timers->pop(); + return static_cast(nanoseconds * clock_rate_ghz); +} +} // namespace + +HloExecutionProfiler::HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector::SmartPtr>& sub_streams, + const HloComputation* computation) + : do_profile_(do_profile), + profile_(profile), + stream_(stream), + sub_streams_(sub_streams), + computation_(computation) { + if (do_profile_) { + clock_rate_ghz_ = stream->parent()->GetDeviceDescription().clock_rate_ghz(); + InitAndStartTimer(&timers_, stream); + } +} + +void HloExecutionProfiler::FinishExecution() { + CHECK(!finished_execution_) << "Call FinishExecution only once!"; + finished_execution_ = true; + if (do_profile_) { + profile_->set_total_cycles_executed( + *computation_, + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + } +} + +void HloExecutionProfiler::StartHloComputation() { + if (do_profile_) { + InitAndStartTimer(&timers_, stream_); + } +} + +void HloExecutionProfiler::FinishHloComputation( + const HloComputation* computation) { + if (do_profile_) { + profile_->set_total_cycles_executed( + *computation, + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + } +} + +void HloExecutionProfiler::StartHloInstruction() { + if (do_profile_) { + InitAndStartTimer(&timers_, stream_); + } +} + +void HloExecutionProfiler::FinishHloInstruction( + const HloInstruction* hlo_instruction) { + if (do_profile_) { + profile_->SetCyclesTakenBy( + hlo_instruction, + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + } +} + +std::unique_ptr +HloExecutionProfiler::MakeScopedInstructionProfiler( + const HloInstruction* hlo_instruction) { + return MakeUnique(this, hlo_instruction); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..e5c655edc65a0c58bfde6c7701c8874d39c0b5d7 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h @@ -0,0 +1,106 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +class ScopedInstructionProfiler; + +// A helper class for profiling HLO in the course of GPU program execution. +// All of the profiling is guarded internally, to avoid the caller needing to +// have lots of conditionals sprinkled around. +class HloExecutionProfiler { + public: + // If profiling is enabled, start an execution timer running. + explicit HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector::SmartPtr>& sub_streams, + const HloComputation* computation); + + // If profiling is enabled, sets the total cycle count on the profile from the + // execution timer. + void FinishExecution(); + + // If profiling is enabled, starts a timer for a (sub)computation. + void StartHloComputation(); + + // If profiling is enabled stops the timer for a (sub)computation and records + // the time that the computation took to execute in the profile. + void FinishHloComputation(const HloComputation* computation); + + // If profiling is enabled, starts a per-operation timer. + void StartHloInstruction(); + + // If profiling is enabled, stops the per-operation timer and records the time + // that the hlo_instruction took to execute in the profile. + void FinishHloInstruction(const HloInstruction* hlo_instruction); + + // Returns a ScopedInstructionProfiler and triggers a call to + // StartHloInstruction(). Once the returned ScopedInstructionProfiler goes + // out of scope, it triggers a call to FinishHloInstruction(). + std::unique_ptr MakeScopedInstructionProfiler( + const HloInstruction* hlo_instruction); + + private: + const bool do_profile_; + double clock_rate_ghz_; + HloExecutionProfile* profile_; + se::Stream* stream_; + const std::vector::SmartPtr>& sub_streams_; + const HloComputation* computation_; + std::stack> timers_; + bool finished_execution_ = false; +}; + +// This class can be used within the ExecuteOnStream() implementations of +// Thunks. It ensures that we always have a pair of matching +// StartHloInstruction() and FinishHloInstruction() calls to the profiler. +class ScopedInstructionProfiler { + public: + ScopedInstructionProfiler(HloExecutionProfiler* profiler, + const HloInstruction* hlo_instruction) + : profiler_(profiler), hlo_instruction_(hlo_instruction) { + if (hlo_instruction != nullptr) { + profiler->StartHloInstruction(); + } + } + ~ScopedInstructionProfiler() { + if (hlo_instruction_ != nullptr) { + profiler_->FinishHloInstruction(hlo_instruction_); + } + } + + private: + HloExecutionProfiler* profiler_; + const HloInstruction* hlo_instruction_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index f766f968826d960a8e86308f2395301aaa09f1ae..375709150e08996ea6a40f5e9e66a8f8d9287008 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -199,7 +199,7 @@ StatusOr> HloSchedule::Build( // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( schedule->thunk_launch_order_, - CreateMemoryMinimizingSequence( + ScheduleOneComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index e230d538cc2df826778e8d13eaaaf31ec81c57f0..45f0a1c645b2875cf90d2c11cfb66c3dd855d097 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -47,8 +47,7 @@ class HloScheduleTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", VersionedComputationHandle(), - config); + return MakeUnique("test_module", config); } HloVec RemoveHlo(const HloVec& input, diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 061210352cf12e6802d066d311fd2cb481673f15..d420863b8569771b16a03591b6a0ddd0591f7e2e 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -137,7 +137,7 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, } llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, - const ShapeIndex& shape_index, + ShapeIndexView shape_index, llvm::Value* ir_value) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType( ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_); @@ -158,7 +158,7 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, - const ShapeIndex& shape_index) { + ShapeIndexView shape_index) { VLOG(2) << "Binding " << hlo.ToString(); const Shape& hlo_shape = hlo.shape(); @@ -202,7 +202,7 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, << " of " << hlo.ToString(); llvm_ir::IrArray ir_array(base_ptr, ShapeUtil::GetSubshape(hlo.shape(), shape_index)); - alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); + alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index); // The GPU backend emits one kernel per top-level HLO, and LLVM views // execution of one kernel as the "whole program" executed on the GPU. diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 3d34311b4368d17cb074aaf33c71fc865e96387e..a86e6e78c693ac53bb2c70d88b999a4e1273ecad 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -51,7 +51,7 @@ class HloToIrBindings { // Rebinds the given HLO to the LLVM IR value that represent its address. void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, - const ShapeIndex& shape_index = {}); + ShapeIndexView shape_index = {}); // Unbinds all IR values that's defined in an LLVM function, e.g., function // arguments and stack variables. Global variables will be kept in bindings_. @@ -71,7 +71,7 @@ class HloToIrBindings { // A helper method that returns the base pointer of the IrArray containing the // output of "inst".at the given ShapeIndex. llvm::Value* GetBasePointer(const HloInstruction& hlo, - const ShapeIndex& shape_index = {}) const { + ShapeIndexView shape_index = {}) const { auto it = base_ptrs_.find(&hlo); CHECK(it != base_ptrs_.end()) << hlo.ToString(); return it->second.element(shape_index); @@ -97,7 +97,7 @@ class HloToIrBindings { // Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape. llvm::Value* GetTypedIrValue(const HloInstruction& hlo, - const ShapeIndex& shape_index, + ShapeIndexView shape_index, llvm::Value* ir_value); const BufferAssignment* buffer_assignment_; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index 3ddc1c0789d746bf021256638342364aac63e0e3..ae310beefad0c81c17fd4140b441b3a19a002e2c 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -49,13 +49,25 @@ void InfeedManager::EnqueueBuffers(const std::vector& buffers) { } InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { - tensorflow::mutex_lock l(mu_); - while (enqueued_buffer_.empty()) { - cv_.wait(l); + bool became_empty = false; + InfeedBuffer* current_buffer; + { + tensorflow::mutex_lock l(mu_); + while (enqueued_buffer_.empty()) { + cv_.wait(l); + } + current_buffer = enqueued_buffer_.front(); + enqueued_buffer_.pop_front(); + dequeued_buffer_.insert(current_buffer); + if (enqueued_buffer_.empty()) { + became_empty = true; + } + } + if (became_empty) { + for (const auto& callback : on_empty_callbacks_) { + callback(); + } } - InfeedBuffer* current_buffer = enqueued_buffer_.front(); - enqueued_buffer_.pop_front(); - dequeued_buffer_.insert(current_buffer); return current_buffer; } @@ -88,6 +100,10 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { return host_to_device_stream_.get(); } +void InfeedManager::RegisterOnEmptyCallback(std::function callback) { + on_empty_callbacks_.push_back(std::move(callback)); +} + InfeedManager* GetOrCreateInfeedManager() { static InfeedManager* manager = new InfeedManager; return manager; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index d5f2216d460a45085536b15f9bf6e3bd3579f9c8..a3fc15cfe36a490f38daabca9ff36fbb1012aead 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -21,6 +21,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ #include +#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -100,6 +101,10 @@ class InfeedManager { // returns null. se::Stream* GetStream(se::StreamExecutor* executor); + // Registers a callback that will be called when 'enqueued_buffer_' becomes + // empty. + void RegisterOnEmptyCallback(std::function callback); + private: // TODO(b/30467474): Revisit if this mutex becomes a point of // contention. @@ -122,6 +127,10 @@ class InfeedManager { // Executor that the host_to_device_stream belongs to. Not owned. se::StreamExecutor* host_to_device_executor_; + + // List of callbacks which will be called when 'enqueued_buffer_' becomes + // empty. + std::vector> on_empty_callbacks_; }; // Singleton creator-or-accessor: Returns the GPU infeed manager. diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index ea34d5b30c91e8b809e3e17a904e27e589fd6b5f..62915febb11d5defa0e44b688eacabb16a7621da 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -22,29 +23,31 @@ namespace xla { namespace gpu { InfeedThunk::InfeedThunk( - tensorflow::gtl::ArraySlice tuple_element_buffers, - const BufferAllocation::Slice& destination_buffer, + const ShapeTree& infeed_slices, const HloInstruction* hlo_instruction) - : Thunk(Kind::kInfeed, hlo_instruction), - tuple_element_buffers_(tuple_element_buffers.begin(), - tuple_element_buffers.end()), - destination_buffer_(destination_buffer) {} + : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { VLOG(2) << "Infeeding to GPU "; - se::DeviceMemoryBase destination_address = - buffer_allocations.GetDeviceAddress(destination_buffer_); - + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); + // First copy the infeed data which is element 0 of the infeed instruction's + // two-tuple output (the other element is a token). + se::DeviceMemoryBase data_address = + buffer_allocations.GetDeviceAddress(infeed_slices_.element({0})); InfeedManager* infeed_manager = GetOrCreateInfeedManager(); std::vector infeed_buffers; - if (ShapeUtil::IsTuple(hlo_instruction()->shape())) { - CHECK(!ShapeUtil::IsNestedTuple(hlo_instruction()->shape())); + const Shape& data_shape = + ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0); + if (ShapeUtil::IsTuple(data_shape)) { + CHECK(!ShapeUtil::IsNestedTuple(data_shape)); // Transfer the tuple elements first. std::vector tuple_element_addresses; - for (BufferAllocation::Slice tuple_element_buffer : - tuple_element_buffers_) { + for (int i = 0; i < ShapeUtil::TupleElementCount(data_shape); ++i) { + const BufferAllocation::Slice& tuple_element_buffer = + infeed_slices_.element({0, i}); se::DeviceMemoryBase tuple_element_address = buffer_allocations.GetDeviceAddress(tuple_element_buffer); @@ -56,15 +59,23 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, } // Transfer the tuple outer buffer. auto host_size = tuple_element_addresses.size() * sizeof(void*); - stream->ThenMemcpy(&destination_address, tuple_element_addresses.data(), + stream->ThenMemcpy(&data_address, tuple_element_addresses.data(), host_size); } else { InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); infeed_buffers.push_back(buffer); - stream->ThenMemcpy(&destination_address, *(buffer->device_memory()), + stream->ThenMemcpy(&data_address, *(buffer->device_memory()), buffer->length()); } + // Construct top-level tuple of infeed containing the data and the token. Use + // a nullptr for the token, it should never be dereferenced. + std::vector infeed_addresses = {data_address.opaque(), nullptr}; + se::DeviceMemoryBase top_level_address = + buffer_allocations.GetDeviceAddress(infeed_slices_.element({})); + stream->ThenMemcpy(&top_level_address, infeed_addresses.data(), + 2 * sizeof(void*)); + Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 93713cb12defd95bdd69cb0aa7ad7b4e37fc8fae..59487e245b78e66c45409fe712e86d3392e50580 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -32,23 +33,19 @@ namespace gpu { class InfeedThunk : public Thunk { public: // Constructs a InfeedThunk that copies data from the on-device - // infeed queue to the device buffer - // `destination_buffer`. `mem_size` is the size of the data in - // bytes. - InfeedThunk(tensorflow::gtl::ArraySlice - tuple_element_buffers, - const BufferAllocation::Slice& destination_buffer, + // infeed queue into the buffers in the given shape tree. + InfeedThunk(const ShapeTree& infeed_slices, const HloInstruction* hlo_instruction); InfeedThunk(const InfeedThunk&) = delete; InfeedThunk& operator=(const InfeedThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: - const std::vector tuple_element_buffers_; - const BufferAllocation::Slice destination_buffer_; + const ShapeTree infeed_slices_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 5d5bef6b57b57fce4255a145634745b38dccacc7..64ed3d748febd8281a8e602194b31c937a4a682a 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -40,6 +40,7 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kFusion || + hlo.opcode() == HloOpcode::kGather || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || @@ -77,15 +78,14 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (consumer->operand_count() == 2 && - (producer->opcode() == HloOpcode::kDot || - (producer->opcode() == HloOpcode::kFusion && - producer->fused_expression_root()->opcode() == HloOpcode::kDot))) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { int64 other_operand_index = 1 - operand_index; - const HloInstruction* alpha = consumer->operand(other_operand_index); HloInstruction* op1 = nullptr; HloInstruction* op2 = nullptr; - if (consumer->opcode() == HloOpcode::kFusion && + if (consumer->operand_count() == 1 && + consumer->opcode() == HloOpcode::kFusion && consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && Match(consumer->fused_expression_root(), match::Op() @@ -103,10 +103,12 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, op2->opcode() != HloOpcode::kBroadcast) { return false; } - if (IsIEEEFloatingPointScalarConstant(alpha)) { + if (IsIEEEFloatingPointScalarConstant(op2->operand(0))) { return true; } - } else if (consumer->opcode() == HloOpcode::kMultiply) { + } else if (consumer->operand_count() == 2 && + consumer->opcode() == HloOpcode::kMultiply) { + const HloInstruction* alpha = consumer->operand(other_operand_index); // Fuse if 'alpha' is a broadcast of a scalar constant. if (alpha->opcode() == HloOpcode::kBroadcast && alpha->dimensions().empty() && @@ -173,10 +175,38 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Fuse scalar constants into loop fusion nodes, this reduces the number of + // parameters and makes matching scalar broadcasts easier. + if (ShapeUtil::IsEffectiveScalar(producer->shape()) && + consumer->opcode() == HloOpcode::kFusion && + producer->opcode() == HloOpcode::kConstant) { + return true; + } + return IsFusile(*producer) && IsFusile(*consumer) && InstructionFusion::ShouldFuse(consumer, operand_index); } +bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + const HloInstruction* producer = consumer->operand(operand_index); + // The IR emitter has limited support for non-loop fusions with multi output + // at present. + // TODO(tjoerg): Relax this constraint to allow for arbitraty kinds of fusion. + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() != HloInstruction::FusionKind::kLoop) { + return false; + } + // Multi-output fusion requires instructions with compatible shapes. + if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) { + return false; + } + // TODO(tjoerg): Stop calling `ShouldFuse` to relax the criteria for + // multi-output fusion. In particular, do not check whether an instruction is + // expensive to duplicate, since this doesn't matter here. + return GpuInstructionFusion::ShouldFuse(consumer, operand_index); +} + HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { if (IsReductionToVector(*consumer)) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index 9fb06b0a244186484b1c17edf13bd28a4305a1a6..f629d9ff2c7165b652369612c30979150f93bd24 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -31,6 +31,9 @@ class GpuInstructionFusion : public InstructionFusion { bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; + bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) override; + HloInstruction::FusionKind ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 760e0e90f583d0e43975e23b731a40af75c7dc17..1963d9eef72d41fa0a275bea98f959671fa7e737 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -15,9 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -140,7 +143,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { // Tests that broadcasts fused into a fusion with a reduce root. TEST_F(InstructionFusionTest, BroadcastIntoReduce) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module add { @@ -165,11 +168,11 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); EXPECT_THAT(root->fused_expression_root(), - op::Reduce(op::Broadcast(op::Parameter()), op::Parameter())); + op::Reduce(op::Broadcast(op::Constant()), op::Constant())); } TEST_F(InstructionFusionTest, BitcastIntoAdd) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY BroadcastIntoAdd { @@ -191,7 +194,7 @@ TEST_F(InstructionFusionTest, BitcastIntoAdd) { } TEST_F(InstructionFusionTest, AddIntoBitcast) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY BroadcastIntoAdd { @@ -213,7 +216,7 @@ TEST_F(InstructionFusionTest, AddIntoBitcast) { } TEST_F(InstructionFusionTest, DontFuseGTE) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY DontFuseGTE { p0 = (f32[10], f32[10]) parameter(0) @@ -229,7 +232,7 @@ TEST_F(InstructionFusionTest, DontFuseGTE) { } TEST_F(InstructionFusionTest, DotOutputFusion) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { alpha = f32[] constant(3) @@ -252,13 +255,13 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { EXPECT_THAT( root->fused_expression_root(), op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), - op::Broadcast(op::Parameter()))); + op::Broadcast(op::Constant()))); } // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is // duplicated and fused into both reduces. TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module Add { lhs = f32[] parameter(0) @@ -281,14 +284,15 @@ TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { .ValueOrDie()); HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())) + << module->ToString(); } // Compute sum(100/p0), where p0 has type s32, twice. Check that the division // is *not* duplicated and fused into both reduces, because we say that integer // division is not cheap. TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module Add { lhs = s32[] parameter(0) @@ -308,11 +312,12 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()); + .ValueOrDie()) + << module->ToString(); } TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY NoOutputFusion { alpha = f32[] constant(3) @@ -334,7 +339,271 @@ TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); EXPECT_THAT(root->fused_expression_root(), op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), - op::Broadcast(op::Parameter()))); + op::Broadcast(op::Constant()))); +} + +// Counts the HLO ops with a given op code in the specified module. +static int Count(const HloModule& module, HloOpcode op) { + int count = 0; + for (const auto* computation : module.computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == op) { + ++count; + } + } + } + return count; +} + +// Returns an HLO instruction from the given computation with the op code. +static StatusOr FindHloInstruction( + const HloComputation& computation, HloOpcode op) { + for (const auto* instruction : computation.instructions()) { + if (instruction->opcode() == op) { + return instruction; + } + } + return NotFound( + "Computation '%s' does not contain an instruction with op code '%s'.", + computation.name().c_str(), HloOpcodeString(op).c_str()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion) { + // sub --> add --> tuple + // \---------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + sub = f32[4,3]{1,0} subtract(p0, p2) + add = f32[4,3]{1,0} add(sub, p1) + ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT( + fusion->fused_expression_root(), + op::Tuple(op::Add(op::Subtract(), op::Parameter()), op::Subtract())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { + // tanh --> add --> tuple + // \---------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + tanh = f32[4,3]{1,0} tanh(p0) + add = f32[4,3]{1,0} add(tanh, p1) + ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(tanh, add) + })") + .ValueOrDie(); + + // TODO(tjoerg): Allow multi-output fusion for expensive operations like tanh. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion2) { + // sub --> add1 --\--------\ + // \----------> add2 --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + sub = f32[4,3]{1,0} subtract(p0, p2) + add1 = f32[4,3]{1,0} add(sub, p1) + add2 = f32[4,3]{1,0} add(sub, add1) + ROOT tuple = (f32[4,3]{1,0}) tuple(add1, add2) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Add(op::Subtract(), op::Add()), + op::Add(op::Subtract(), op::Parameter()))); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion3) { + // sub --> add1 ----\--------\ + // \ --> add2 --> add3 --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + p3 = f32[4,3]{1,0} parameter(3) + sub = f32[4,3]{1,0} subtract(p0, p2) + add1 = f32[4,3]{1,0} add(sub, p1) + add2 = f32[4,3]{1,0} add(p2, sub) + add3 = f32[4,3]{1,0} add(add1, add2) + ROOT tuple = (f32[4,3]{1,0}) tuple(add3, add2) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Add(op::Add(), op::Add()), + op::Add(op::Parameter(), op::Subtract()))); +} + +TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { + // sub --> mul ---\ + // \--> call --> add --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + c = f32[] constant(42) + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + sub = f32[4,3]{1,0} subtract(p0, p1) + mul = f32[4,3]{1,0} multiply(sub, c) + call = f32[4,3]{1,0} custom-call(sub), custom_call_target="foo" + add = f32[4,3]{1,0} add(mul, call) + ROOT tuple = (f32[4,3]{1,0}) tuple(add) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + // Visit instructions in post order to detect cycles. + // TODO(tjoerg): Add cycle detection to the HloVerifier. + class DummyVisitor : public DfsHloVisitorWithDefault { + public: + DummyVisitor() {} + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + } visitor; + for (const HloComputation* computation : module->MakeComputationPostOrder()) { + // Accept will return a FailedPrecondition when a cycle is detected. + EXPECT_TRUE(computation->root_instruction()->Accept(&visitor).ok()); + } +} + +TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { + // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3]) + // \-------------------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[2,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[2,3]{1,0} parameter(2) + sub = f32[2,3]{1,0} subtract(p0, p2) + add = f32[4,3]{1,0} add(sub, p1) + ROOT tuple = (f32[2,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) + })") + .ValueOrDie(); + + // Multi-output fusion requires shapes to be compatible. Since `sub` and `add` + // have incompatible shapes, expect that no multi-output fusion happens. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { + auto module = ParseHloString(R"( + HloModule test_module + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + fused_computation { + p1 = f32[10] parameter(0) + zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(p1, zero), dimensions={0}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[10] parameter(0) + mul = f32[10] multiply(p0, p0) + fusion = f32[] fusion(mul), kind=kInput, calls=fused_computation + ROOT tuple = (f32[10], f32[]) tuple(fusion, mul) + })") + .ValueOrDie(); + + // Multi-output fusion is not supported for non-loop fusions at present. Since + // `fused_computation` is a input fusion, expect no multi-output fusion to + // happen. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseScalarConstant) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY FuseScalarConstant { + p0 = f32[] parameter(0) + c0 = f32[] constant(1) + add1 = f32[] add(p0, c0) + b0 = f32[2]{0} broadcast(add1), dimensions={} + c1 = f32[2]{0} constant({1, 2}) + ROOT add2 = f32[2]{0} add(b0, c1) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Broadcast(op::Add(op::Parameter(), op::Constant())), + op::Parameter())); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 22e715099526c20532bb298e84e50457d89f615e..388aa35d7dceeef92dbdb6c8a3bb7fb3796a0b61 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -56,8 +56,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && IsRank2WithNoPadding(output_shape) && - !ShapeUtil::HasZeroElements(lhs_shape) && - !ShapeUtil::HasZeroElements(rhs_shape); + !ShapeUtil::IsZeroElementArray(lhs_shape) && + !ShapeUtil::IsZeroElementArray(rhs_shape); } bool DotImplementedAsGemm(const HloInstruction& dot) { @@ -162,19 +162,8 @@ static HloInstruction* CreateCudnnConv( Shape call_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - // Our CustomCall takes four arguments: The conv lhs and rhs, the cudnn - // algorithm to use, and a boolean indicating whether to use tensor cores. - // - // It's up to a later pass to choose the algorithm and decide whether to use - // tensor cores, so to indicate that we haven't yet made a choice, we speicfy - // -1 and false for those args. - HloInstruction* negative_one = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))); - HloInstruction* false_constant = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); - HloInstruction* custom_call = - computation->AddInstruction(HloInstruction::CreateCustomCall( - call_shape, {lhs, rhs, negative_one, false_constant}, call_target)); + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); return custom_call; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 1e0db2821a2c212d0f212ae94ab69231bc6053ea..fe83d017f4cde36cac37400ed16faab225878ea7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -191,6 +191,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( HloOpcode root_opcode = computation.root_instruction()->opcode(); PrimitiveType element_type = computation.root_instruction()->shape().element_type(); + bool is_atomic_integral = element_type == S32 || element_type == U32 || + element_type == S64 || element_type == U64; llvm::Value* source = ir_builder_.CreateLoad(source_address, "source"); if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. @@ -201,7 +203,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( {output_address->getType()}, &ir_builder_); return true; } - if (primitive_util::IsIntegralType(element_type)) { + if (is_atomic_integral) { // integral + integral ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, @@ -210,9 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( } } - // NVPTX supports atomicMax and atomicMin on only integer types. - if (root_opcode == HloOpcode::kMaximum && - primitive_util::IsIntegralType(element_type)) { + // NVPTX supports atomicMax and atomicMin only on integer types. + if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) { // max(integral, integral) auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Max @@ -222,8 +223,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( return true; } - if (root_opcode == HloOpcode::kMinimum && - primitive_util::IsIntegralType(element_type)) { + if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) { // min(integral, integral) auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Min @@ -421,24 +421,27 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); - auto on_true = select->operand(1); - auto on_false = select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); - - if (ShapeUtil::IsTuple(select->shape())) { - llvm_ir::EmitTupleSelect(GetIrArray(*select, *select), - GetIrArray(*pred, *select), - GetBasePointer(*on_true), - GetBasePointer(*on_false), &ir_builder_, module_); - return Status::OK(); - } - // We must not call the subclass `DefaultAction` method, lest its // `HandleSelect` call `IrEmitter::HandleSelect` and its `DefaultAction` // assume no handler has already been called. return IrEmitter::DefaultAction(select); } +Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { + auto pred = tuple_select->operand(0); + auto on_true = tuple_select->operand(1); + auto on_false = tuple_select->operand(2); + TF_RET_CHECK(pred->shape().element_type() == PRED); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); + TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select), + GetIrArray(*pred, *tuple_select), + GetBasePointer(*on_true), GetBasePointer(*on_false), + &ir_builder_, module_); + return Status::OK(); +} + namespace { llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { return ir_builder->CreateExtractValue(x, {0}); @@ -475,12 +478,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = rhs_instruction->shape(); + // TODO(b/110211620): Convert to use i32 index_type when it is possible. + llvm::Type* index_type = ir_builder_.getInt64Ty(); + llvm_ir::IrArray::Index element_index(index_type); if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) { // If the operands are scalar, don't emit any loops. llvm::Value* lhs_value = - lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); + lhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_); llvm::Value* rhs_value = - rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); + rhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_); llvm::Value* result; if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_value, rhs_value, &ir_builder_); @@ -490,7 +496,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { } else { result = ir_builder_.CreateFMul(lhs_value, rhs_value); } - target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_); + target_array.EmitWriteArrayElement(/*index=*/element_index, result, + &ir_builder_); return Status::OK(); } @@ -581,7 +588,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // address. The index into the target address is the concatenation of the rhs // and lhs indexes with the reduction dimensions removed. The terms from the // rhs index are the lower dimensions in the index so we add them first. - llvm_ir::IrArray::Index target_index; + llvm_ir::IrArray::Index target_index(index_type); for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { target_index.push_back(lhs_index[dimension]); @@ -607,7 +614,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { - if (ShapeUtil::HasZeroElements(convolution->shape())) { + if (ShapeUtil::IsZeroElementArray(convolution->shape())) { // Emit no code for an empty output. return Status::OK(); } @@ -617,7 +624,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { } Status IrEmitter::HandleFft(HloInstruction* fft) { - if (ShapeUtil::HasZeroElements(fft->shape())) { + if (ShapeUtil::IsZeroElementArray(fft->shape())) { // Emit no code for an empty output. return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index b0accc08d479258d65a18202122e4c9e90ff78d0..d2dd335f10cc8346c5f941e5c8c6b5c403722fa3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -88,6 +88,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; @@ -120,10 +121,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } - // A convenient helper for calling BufferAssignment::GetUniqueTopLevelSlice. - BufferAllocation::Slice GetAllocationSlice(const HloInstruction& hlo) const { + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { return ir_emitter_context_->buffer_assignment() - .GetUniqueTopLevelSlice(&hlo) + .GetUniqueSlice(&hlo, index) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index bb47a4280541ce2806472aa9365bb0ef38c0c3b3..c9574c87a3be208915b3d6a32679553eb425d2f0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -120,9 +120,10 @@ Status IrEmitterNested::EmitTargetElementLoop( // For MOF we give the loop emitter an array for every output it should // generate. if (hlo.IsMultiOutputFusion()) { + const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape()); std::vector target_arrays; - for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e; - ++i) { + target_arrays.reserve(num_elems); + for (int64 i = 0; i != num_elems; ++i) { target_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR( @@ -130,6 +131,7 @@ Status IrEmitterNested::EmitTargetElementLoop( .EmitLoop()); std::vector tuple_operand_ptrs; + tuple_operand_ptrs.reserve(num_elems); for (const llvm_ir::IrArray& array : target_arrays) { tuple_operand_ptrs.push_back(array.GetBasePointer()); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 0d7ba4cf9a65840eb7ff7785b006c538740502f0..80208e1c98506a2d69125aa80d08218f4716101f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -58,6 +59,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" @@ -79,6 +81,7 @@ namespace { using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::InlinedVector; using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; using tensorflow::strings::StrCat; @@ -267,7 +270,10 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) { // Find the largest possible power of two to unroll by. // TODO(kramerb): Make this smarter. - int64 num_elements = ShapeUtil::ElementsIn(hlo->shape()); + const Shape& element_shape = hlo->IsMultiOutputFusion() + ? ShapeUtil::GetSubshape(hlo->shape(), {0}) + : hlo->shape(); + int64 num_elements = ShapeUtil::ElementsIn(element_shape); for (int i = max_unroll_factor; i > 1; i /= 2) { if (num_elements % i == 0) { return i; @@ -277,6 +283,69 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) { // Cannot unroll. return 1; } + +// Returns the llvm type for the indices used in the kernel that contains the +// hlo instruction. Such indices include the index for the parallel loop and +// the indices for the tensors accessed by the kernel. The return type is i32 +// iff the following conditions are met: +// . The launch_size of the kernel is within the range of i32. +// . The sizes of all the tensors accessed within the kernel are within the +// range of i32. +// Otherwise, the return type is i64. +llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, + llvm::IRBuilder<>* ir_builder) { + // Find the unnested hlo instructon for which the kernel is generated for. + const HloInstruction* unnested_hlo = hlo; + const HloComputation* computation = hlo->parent(); + if (computation->IsFusionComputation()) { + unnested_hlo = computation->FusionInstruction(); + } + + auto shape_in_range = [&](const Shape& s) { + bool in_range = true; + ShapeUtil::ForEachSubshape( + s, [&](const Shape& sub_shape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(sub_shape) && + !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { + in_range = false; + } + }); + + return in_range; + }; + + llvm::Type* i64_ty = ir_builder->getInt64Ty(); + // Check launch dimension + if (!IsInt32(launch_size)) { + return i64_ty; + } + + // Check the size of result tensors + if (!shape_in_range(unnested_hlo->shape())) { + return i64_ty; + } + + auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool { + return shape_in_range(operand->shape()); + }; + + // Check the size of input tensors + if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { + return i64_ty; + } + + // Check the size of the internal result tensors + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + if (!c_all_of( + unnested_hlo->fused_instructions_computation()->instructions(), + hlo_shape_in_range)) { + return i64_ty; + } + } + + return ir_builder->getInt32Ty(); +} + } // namespace Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { @@ -419,15 +488,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - const HloInstruction* algorithm_inst = custom_call->operand(2); - CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); - int64 algorithm = algorithm_inst->literal().Get({}); - - const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3); - CHECK(tensor_ops_enabled_inst->IsConstant()) - << tensor_ops_enabled_inst->ToString(); - bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get({}); - + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + custom_call->backend_config()); const auto& target = custom_call->custom_call_target(); std::unique_ptr thunk; if (target == kCudnnConvForwardCallTarget) { @@ -442,7 +504,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardInput, @@ -455,7 +518,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardFilter, @@ -468,7 +532,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); @@ -496,12 +561,31 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // initializes the output array to the initial value of the reduce. if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { switch (root->opcode()) { + case HloOpcode::kTuple: case HloOpcode::kReduce: { VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(fusion)); std::vector> thunks; - thunks.push_back(std::move(initializer_thunk)); + ArraySlice output_instructions = + root->opcode() == HloOpcode::kTuple + ? root->operands() + : ArraySlice(&root, 1); + + // For multi-output fusion emit an initializer for each tuple element. + // Otherwise it's sufficient to just initialize the single output. + HloInstruction* first_reduce = nullptr; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() == HloOpcode::kReduce) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk(fusion, output_instructions[i] == root + ? ShapeIndex() + : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + first_reduce = + first_reduce == nullptr ? output_instructions[i] : first_reduce; + } + } + CHECK(first_reduce != nullptr); thunks.push_back(BuildKernelThunk(fusion)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); @@ -515,11 +599,54 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - Shape input_shape = root->operand(0)->shape(); - return EmitReductionToVector( - root, input_shape, fused_emitter.GetGenerator(root->operand(0)), - fused_emitter.GetGenerator(root->operand(1)), root->dimensions(), - root->to_apply()); + // For multi-output fusion CHECK the constraints and feed all the + // reduces into a single loop code generator. Single-output reduce + // fusion is a special case of that. + InlinedVector input_gens; + InlinedVector init_value_gens; + std::vector> + extra_output_gens; + InlinedVector reducers; + InlinedVector reduce_output_shapes; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + const HloInstruction* inst = output_instructions[i]; + ShapeIndex output_shape_index; + if (root->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + if (inst->opcode() == HloOpcode::kReduce) { + CHECK(IsReductionToVector(*inst)) + << "Only reductions to vector are supported"; + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. + CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); + CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); + CHECK(first_reduce->dimensions() == inst->dimensions()); + input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + init_value_gens.push_back( + fused_emitter.GetGenerator(inst->operand(1))); + reducers.push_back(inst->to_apply()); + reduce_output_shapes.push_back(std::move(output_shape_index)); + } else { + // For extra outputs we can relax shape equality to allow different + // types (with the same number of elements). Layouts still have to + // match. + CHECK(ShapeUtil::CompatibleIgnoringElementType( + first_reduce->operand(0)->shape(), inst->shape())); + CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), + inst->shape().layout())); + extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + std::move(output_shape_index)); + } + } + const Shape& input_shape = first_reduce->operand(0)->shape(); + return EmitReductionToVector(first_reduce, input_shape, input_gens, + init_value_gens, + first_reduce->dimensions(), reducers, + reduce_output_shapes, extra_output_gens); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -565,12 +692,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - int unroll_factor = 1; - // TODO(kramerb): Unrolling multi-output loop fusions too. - if (!fusion->IsMultiOutputFusion()) { - CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); - unroll_factor = ComputeMaxUnrollFactor(fusion); - } + CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); + int unroll_factor = ComputeMaxUnrollFactor(fusion); thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); return IrEmitter::HandleFusion(fusion); @@ -908,10 +1031,33 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { return IrEmitter::HandleCopy(copy); } +Status IrEmitterUnnested::EmitExtraOutputsForReduce( + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { + for (int i = 0; i != extra_output_gens.size(); ++i) { + const HloInstruction* output = reduce->parent()->FusionInstruction(); + llvm::Value* extra_output_address = + GetIrArray(*output, *output, extra_output_gens[i].second) + .EmitArrayElementAddress(index, &ir_builder_, + "extra_output_element_address"); + TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, + extra_output_gens[i].first(index)); + ir_builder_.CreateStore(extra_output_ir_value, extra_output_address); + } + return Status::OK(); +} + Status IrEmitterUnnested::EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // Number of elements processed by a single thread. constexpr int64 kTileSize = 16; int64 num_elems = ShapeUtil::ElementsIn(input_shape); @@ -923,6 +1069,20 @@ Status IrEmitterUnnested::EmitReductionToScalar( int64 num_tiles = RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {num_tiles}, {0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + + llvm::Type* index_ty = GetIndexTypeForKernel( + reduce, + launch_dimensions.block_count() * launch_dimensions.threads_per_block(), + &ir_builder_); + + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + // Check whether every thread will process a full tile's worth of elements // without reading outside the bounds of the input. If this is true, we can // skip some bounds checks in the final algorithm. @@ -963,62 +1123,74 @@ Status IrEmitterUnnested::EmitReductionToScalar( // auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; + x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); + llvm_ir::ForLoop::EmitForLoop( + "element_id_in_tile", index_typed_const(0), + index_typed_const(kTileSize), index_typed_const(1), &ir_builder_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &ir_builder_); llvm::Value* x = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)), + ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)), tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(num_elems)), + ir_builder_.CreateICmpULT(x, index_typed_const(num_elems)), "x_in_bounds", &ir_builder_); // Emit code that reads the input element and accumulates it to // the partial reduction result. llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); } + llvm_ir::IrArray::Index input_index( /*linear=*/x, input_shape, &ir_builder_); llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - return (EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens); }; // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. llvm::Value* x_end = ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize), - ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize))); + index_typed_const(kTileSize), + ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(x_end, ir_builder_.getInt64(num_elems)), + ir_builder_.CreateICmpULE(x_end, index_typed_const(num_elems)), ir_builder_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); @@ -1042,20 +1214,24 @@ Status IrEmitterUnnested::EmitReductionToScalar( : element_ir_type; for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_address, - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, result_from_other_lane}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } } const HloInstruction* output = @@ -1065,27 +1241,30 @@ Status IrEmitterUnnested::EmitReductionToScalar( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm::Value* lane_id = ir_builder_.CreateURem( - x_in_tiles, ir_builder_.getInt64(kWarpSize), "lane_id"); + x_in_tiles, index_typed_const(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), + ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)), "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0), - output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + /*linear=*/ir_builder_.getInt64(0), + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterates through all input tiles, one per thread. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {num_tiles}, {0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, @@ -1093,13 +1272,18 @@ Status IrEmitterUnnested::EmitReductionToScalar( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + .EmitLoop(IrName(reduce), index_ty); } Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // Divide the input matrix into tiles of size Kx1. For example, when the // input matrix is 4x4 and K=2, the tiled matrix looks like // @@ -1109,12 +1293,27 @@ Status IrEmitterUnnested::EmitColumnReduction( // 4567 // Numbers indicate tile IDs. // // Each tile is first partially reduced to a scalar by a thread, and then the - // scalar is accumulated to the output vector using atomic operations. We - // choose 16 as the tile size, which matches Eigen's ColumnReduceKernel. - constexpr int64 kTileSize = 16; + // scalar is accumulated to the output vector using atomic operations. + // + // We choose 128 as the tile size based on empirical evidence. It's big enough + // to reduce the amount of atomic adds in the end, maximizing the memory + // bandwidth. + constexpr int64 kTileSize = 128; + // If the height is not a multiple of the tile size, we pad the bottom of the // input matrix. const int64 height_in_tiles = CeilOfRatio(height, kTileSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {height_in_tiles, width}, {1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + + // TODO(b/110211620): Convert to use i32 index_type when it is possible. + llvm::Type* index_ty = ir_builder_.getInt64Ty(); + + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; // linear_index < height_in_tiles * width; @@ -1141,15 +1340,21 @@ Status IrEmitterUnnested::EmitColumnReduction( // } auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } // Emit an inner for-loop that partially reduces the elements in the given @@ -1157,24 +1362,27 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm::Value* y_in_tiles = tile_index[0]; llvm::Value* x = tile_index[1]; + y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty); + x = ir_builder_.CreateZExtOrTrunc(x, index_ty); + auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); + llvm_ir::ForLoop::EmitForLoop( + "element_id_in_tile", index_typed_const(0), + index_typed_const(kTileSize), index_typed_const(1), &ir_builder_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &ir_builder_); llvm::Value* y = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)), + ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)), tile_element_loop->GetIndVarValue()); + // Unless we know the tile is entirely in bounds, we have to emit a // y-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(y, ir_builder_.getInt64(height)), + ir_builder_.CreateICmpULT(y, index_typed_const(height)), "y_in_bounds", &ir_builder_); // Emit code that reads the input element and accumulates it to @@ -1207,22 +1415,27 @@ Status IrEmitterUnnested::EmitColumnReduction( .SourceIndexOfTranspose(normalized_input_shape, input_shape, transpose_dimension_mapping, &ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, - input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); } - return (EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address)); }; // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's // immediately beyond the tile. llvm::Value* y_end = ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize), - ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize))); + index_typed_const(kTileSize), + ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize))); llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(y_end, ir_builder_.getInt64(height)), + ir_builder_.CreateICmpULE(y_end, index_typed_const(height)), ir_builder_.getInt1(height % kTileSize == 0)); // The tile is entirely in bound if "height" is a multiple of kTileSize or // y_end <= height. @@ -1242,20 +1455,23 @@ Status IrEmitterUnnested::EmitColumnReduction( &ir_builder_); const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + x, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterate through all input tiles. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {height_in_tiles, width}, {1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, @@ -1263,15 +1479,45 @@ Status IrEmitterUnnested::EmitColumnReduction( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + .EmitLoop(IrName(reduce), index_ty); +} + +static std::pair ComputeTilingSchemeForReduction( + int64 depth, int64 width, int64 kWarpSize) { + constexpr int64 kTargetNumElementsPerThread = 64; + int64 x_tile_size = kTargetNumElementsPerThread; + int64 z_tile_size = 1; + + // Only tile along the x dimension with tile size kTargetNumElementsPerThread + // if doing so doesn't require a slow version of loop with bound check on each + // dimension. A more sophisticated heuristics is to enable tile along the + // x dimension with tile size kTargetNumElementsPerThread when either width is + // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big + // enough so that only a small fraction of the threads execute the slow + // version of loop with bound check. + if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { + x_tile_size = 8; + z_tile_size = 8; + while (depth % z_tile_size != 0) { + z_tile_size -= 1; + } + } + + return std::pair(x_tile_size, z_tile_size); } Status IrEmitterUnnested::EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // A naive algorithm is: - // 1. Divide the input tensor into tiles of size 1x1xK. + // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. // 2. Partially reduces each tile to a scalar using one thread. // 3. Accumulates that scalar to the output vector using atomic operations. // @@ -1282,15 +1528,15 @@ Status IrEmitterUnnested::EmitRowReduction( // int y = linear_index / width_in_tiles % height; // int z = linear_index / (height * width_in_tiles); // float partial_result = 0; - // for (element_id_in_tile : range(kTileSize)) { - // int x = x_in_tiles * kTileSize + element_id_in_tile; + // for (element_id_in_tile : range(x_tile_size)) { + // int x = x_in_tiles * x_tile_size + element_id_in_tile; // if (x < width) - // partial_result = reducer(partial_result, input[z][y][z]); + // partial_result = reducer(partial_result, input[z][y][x]); // } // AtomicReducer(&output[y], partial_result); // } // - // Three optimizations are performed. + // Four optimizations are performed. // // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead @@ -1317,29 +1563,46 @@ Status IrEmitterUnnested::EmitRowReduction( // element_id_in_tile, which makes the code more friendly to optimizations // such as LICM. // + // 4. When the width is too small and x_tile_size is less than the target + // number of elements per thread and use a small factor of depth as + // z_tile_size to increase the number of elements calculated by each + // partial sum. This can reduce the needed number of dynamic shfl_down and + // atomic operations. + // // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; // linear_index < depth * height * width_in_tiles; // linear_index += blockDim.x * gridDim.x) { // int x_in_tiles = linear_index % width_in_tiles; // int y = linear_index / width_in_tiles % height; - // int z = linear_index / (height * width_in_tiles); + // int z_in_tiles = linear_index / (height * width_in_tiles); // int warp_id = x_in_tiles / warpSize; // int lane_id = x_in_tiles % warpSize; // float partial_result = 0; // int x = warp_id * kTileSize * warpSize + lane_id; - // if (width % (kTileSize * warpSize) == 0 || - // x + (kTileSize - 1) * warpSize < width) { - // // The entire tile is in bounds. - // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; - // ++element_id_in_tile, x += warpSize) { - // partial_result = Reducer(partial_result, input[z][y][x]); + // if (width % (x_tile_size * warpSize) == 0 || + // x + (x_tile_size - 1) * warpSize < width) { + // // The entire x_tile is in bounds. + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // int tx = x; + // for (int element_id_in_x_tile = 0; + // element_id_in_x_tile < x_tile_size; + // ++element_id_in_x_tile, tx += warpSize) { + // partial_result = Reducer(partial_result, input[z][y][tx]); + // } // } // } else { // // The tile is partially in bounds. - // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; - // ++element_id_in_tile, x += warpSize) { - // if (x < width) - // partial_result = Reducer(partial_result, input[z][y][x]); + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // int tx = x; + // for (int element_id_in_x_tile = 0; element_id_in_x_tile < + // x_tile_size; ++element_id_in_tile, tx += warpSize) { + // if (tx < width) + // partial_result = Reducer(partial_result, input[z][y][tx]); + // } // } // } // for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2) @@ -1350,132 +1613,184 @@ Status IrEmitterUnnested::EmitRowReduction( // AtomicReducer(&output[y], partial_result); // } // - // Choose 8 as the tile size, which matches Eigen's RowReduceKernel. - constexpr int64 kTileSize = 8; + + int64 x_tile_size; + int64 z_tile_size; + std::tie(x_tile_size, z_tile_size) = + ComputeTilingSchemeForReduction(depth, width, kWarpSize); + // Round the width in tiles up to the nearest multiple of kWarpSize, so that // the use of shfl_down is valid. const int64 width_in_tiles = - RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize); + RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), + {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + llvm::Type* index_ty = GetIndexTypeForKernel( + reduce, + launch_dimensions.block_count() * launch_dimensions.threads_per_block(), + &ir_builder_); - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { - // Emit the loop body that reduces one tile. + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) { + const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } - // Emit an inner for-loop that partially reduces the elements in the given - // tile. - llvm::Value* z = tile_index[0]; + llvm::Value* z_tile = tile_index[0]; llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - llvm::Value* warp_id = ir_builder_.CreateUDiv( - x_tile, ir_builder_.getInt64(kWarpSize), "warp_id"); - llvm::Value* lane_id = ir_builder_.CreateURem( - x_tile, ir_builder_.getInt64(kWarpSize), "lane_id"); - // The x-location of the last element in this tile. - // last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize); - llvm::Value* last_x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize - 1), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); + x_tile = ir_builder_.CreateZExtOrTrunc(x_tile, index_ty); - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); + llvm::Value* warp_id = + ir_builder_.CreateUDiv(x_tile, index_typed_const(kWarpSize), "warp_id"); + llvm::Value* lane_id = + ir_builder_.CreateURem(x_tile, index_typed_const(kWarpSize), "lane_id"); - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize); - llvm::Value* x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - tile_element_loop->GetIndVarValue(), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); - - // Unless we know the tile is entirely in bounds, we have to emit a - // x-in-bounds check before reading from the input. - if (!tile_in_bounds) { - llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)), - "x_in_bounds", &ir_builder_); - - // Points ir_builder_ to the then-block. - llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &ir_builder_); - } + // The x-location of the last element in this z-x-tile. + // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); + llvm::Value* last_x = ir_builder_.CreateNSWAdd( + lane_id, ir_builder_.CreateNSWMul( + index_typed_const(kWarpSize), + ir_builder_.CreateNSWAdd( + index_typed_const(x_tile_size - 1), + ir_builder_.CreateNSWMul( + warp_id, index_typed_const(x_tile_size))))); + + KernelSupportLibrary ksl( + &ir_builder_, + /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, + /*prevent_vectorization=*/false); + + // Emit a for-loop that partially reduces the elements in the given + // z-x-tile. + auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, + int64 x_tile_loop_bound) -> Status { + auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { + llvm::Value* z = ir_builder_.CreateNSWAdd( + z_indvar, + ir_builder_.CreateNSWMul(index_typed_const(z_tile_size), z_tile)); + TF_RETURN_IF_ERROR(ksl.For( + "x_tile", + /*start=*/index_typed_const(0), + /*end=*/index_typed_const(x_tile_loop_bound), + /*step=*/1, [&](llvm::Value* x_indvar) -> Status { + // x = lane_id + + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); + llvm::Value* x = ir_builder_.CreateNSWAdd( + lane_id, + ir_builder_.CreateNSWMul( + index_typed_const(kWarpSize), + ir_builder_.CreateNSWAdd( + x_indvar, ir_builder_.CreateNSWMul( + warp_id, llvm::ConstantInt::get( + index_ty, x_tile_size))))); + + // Unless we know the x-tile is entirely in bounds, we have to + // emit a x-in-bounds check before reading from the input. + if (!x_tile_in_bounds) { + llvm_ir::LlvmIfData if_x_in_bounds_data = + llvm_ir::EmitIfThenElse( + ir_builder_.CreateICmpULT(x, index_typed_const(width)), + "x_in_bounds", &ir_builder_); + // Points ir_builder_ to the then-block. + llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, + &ir_builder_); + } + + // Emit code that reads the input element and accumulates it + // to the partial reduction result. + llvm::Value* input_address = + ir_builder_.CreateAlloca(element_ir_type); + { + // {z,y,x} is an index to input_3d_tensor_shape + // [depth,height,width]. We need to convert that to an index + // to input_shape (the shape of the operand of "reduce"). + // This conversion is composed of a transposition from + // input_shape to normalized_input_shape and a reshape from + // normalized_input_shape to input_3d_tensor_shape. + const Shape normalized_input_shape = ShapeUtil:: + MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = + LayoutUtil::MinorToMajor(input_shape); + const std::vector transpose_dimension_mapping( + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); + const Shape input_3d_tensor_shape = + ShapeUtil::MakeShapeWithDescendingLayout( + input_shape.element_type(), {depth, height, width}); + const llvm_ir::IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &ir_builder_); + const llvm_ir::IrArray::Index input_index = + input_3d_tensor_index + .SourceIndexOfReshape(input_3d_tensor_shape, + normalized_input_shape, + &ir_builder_) + .SourceIndexOfTranspose( + normalized_input_shape, input_shape, + transpose_dimension_mapping, &ir_builder_); + + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); + } + })); + return Status::OK(); + }; - // Emit code that reads the input element and accumulates it to the - // partial reduction result. - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - { - // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. We - // need to convert that to an index to input_shape (the shape of the - // operand of "reduce"). This conversion is composed of a transposition - // from input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_3d_tensor_shape. - const Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), - {depth, height, width}); - const llvm_ir::IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = - input_3d_tensor_index - .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, &ir_builder_) - .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, - &ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, - input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - } - return EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address); + return ksl.For("z_tile", + /*start=*/index_typed_const(0), + /*end=*/index_typed_const(z_tile_size), + /*step=*/1, emit_z_tile_element_loop); }; llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0), - ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width))); - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - - // After the if-then-else statement on tile_in_bounds, emit calls to - // shfl_down that accumulate the partial reduction results of all threads - // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0), + ir_builder_.CreateICmpULT(last_x, index_typed_const(width))); + + TF_RETURN_IF_ERROR( + ksl.If(tile_in_bounds, + /*true_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, + x_tile_size); + }, + /*false_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop( + /*x_tile_in_bounds=*/false, + CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); + })); + + // After accumulating the elements of the z_x_tile, emit calls to + // shfl_down that accumulate the partial reduction results of all + // threads in a warp. int bit_width = llvm_ir::GetSizeInBits(element_ir_type); // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. @@ -1484,20 +1799,24 @@ Status IrEmitterUnnested::EmitRowReduction( : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_address, - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, result_from_other_lane}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } } const HloInstruction* output = @@ -1507,25 +1826,37 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), + ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)), "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + y, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + // We don't need to emit atomic operations if there is only one tile of + // results. 'depth' is the z dimension, 'width' is the x dimension. + if (z_tile_size >= depth && x_tile_size >= width) { + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {output_address, partial_reduction_result_addresses[i]}, + output_address)); + } else { + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i])); + } + } + return Status::OK(); }; // Emit a parallel loop that iterates through every input tiles. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {depth, height, width_in_tiles}, - {2, 1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, @@ -1533,7 +1864,7 @@ Status IrEmitterUnnested::EmitRowReduction( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + .EmitLoop(IrName(reduce), index_ty); } // Figures out whether `reduce` is a row or column reduction, and which @@ -1544,10 +1875,14 @@ Status IrEmitterUnnested::EmitRowReduction( // elementwise. Status IrEmitterUnnested::EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer) { + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // This emission requires "reduce" to have an input layout. It is either set // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for // a fused kReduce). @@ -1582,8 +1917,9 @@ Status IrEmitterUnnested::EmitReductionToVector( // `EmitReductionToVector`, we only need to check whether the minormost // dimension of the input is to keep. if (input_dims_to_keep.empty()) { - return EmitReductionToScalar(reduce, input_shape, input_gen, init_value_gen, - reducer); + return EmitReductionToScalar(reduce, input_shape, input_gens, + init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } else if (input_dims_to_keep.front() == LayoutUtil::Minor(input_shape.layout(), 0)) { // Column reduction. Treat the result of "input" as a matrix whose width @@ -1600,8 +1936,9 @@ Status IrEmitterUnnested::EmitReductionToVector( height *= input_shape.dimensions(input_dim); } } - return EmitColumnReduction(height, width, reduce, input_shape, input_gen, - init_value_gen, reducer); + return EmitColumnReduction(height, width, reduce, input_shape, input_gens, + init_value_gens, reducers, reduce_output_shapes, + extra_output_gens); } else { // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a // 3D tensor. The size of dimension 1 (the height) is the size of the @@ -1627,7 +1964,8 @@ Status IrEmitterUnnested::EmitReductionToVector( } const int64 height = ShapeUtil::ElementsIn(reduce->shape()); return EmitRowReduction(depth, height, width, reduce, input_shape, - input_gen, init_value_gen, reducer); + input_gens, init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } } @@ -1639,9 +1977,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { // HandleReduce specializes reduction from a multi-dimensional array to a 1D // array. The specialized version requires an initializer thunk that // initializes the output array to the initial value of the reduce. - if (IsReductionToVector(*reduce) && - // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits - 32 <= primitive_util::BitWidth(reduce->shape().element_type())) { + if (IsReductionToVector(*reduce)) { TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, BuildInitializerThunk(reduce)); std::vector> thunks; @@ -1651,16 +1987,15 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { MakeUnique(std::move(thunks), reduce)); return EmitReductionToVector( - reduce, input->shape(), - [&](const llvm_ir::IrArray::Index& index) { + reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) { return GetIrArray(*input, *reduce) .EmitReadArrayElement(index, &ir_builder_); - }, - [&](const llvm_ir::IrArray::Index& index) { + }}, + {[&](const llvm_ir::IrArray::Index& index) { return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &ir_builder_); - }, - dimensions_to_reduce, reducer); + }}, + dimensions_to_reduce, {reducer}, {{}}, {}); } thunk_sequence_->emplace_back(BuildKernelThunk(reduce)); @@ -1727,6 +2062,14 @@ Status IrEmitterUnnested::HandleSelectAndScatter( "Dilation for SelectAndScatter not implemented on GPU."); } + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + source->shape(), ir_emitter_context_->device_description()); + llvm::Type* index_type = GetIndexTypeForKernel( + select_and_scatter, launch_dimensions.launch_bound(), &ir_builder_); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + // kSelectAndScatter is implemented as two kernel launches: the first launch // initializes the output array to the given initial value, // and the second accumulates the "source" matrix to the @@ -1757,8 +2100,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( "selected_value_address", &ir_builder_); llvm::Value* selected_index_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank), - "selected_index_address", &ir_builder_); + index_type, index_typed_const(rank), "selected_index_address", + &ir_builder_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); ir_builder_.CreateStore(ir_builder_.getInt1(false), @@ -1766,7 +2109,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), - &ir_builder_); + &ir_builder_, index_type); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); @@ -1780,17 +2123,17 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(source_index.size()); + llvm_ir::IrArray::Index operand_index(index_type, source_index.size()); llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( - source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); + source_index[i], index_typed_const(window.dimensions(i).stride())); operand_index[i] = ir_builder_.CreateNSWSub( ir_builder_.CreateNSWAdd(strided_index, window_index[i]), - ir_builder_.getInt64(window.dimensions(i).padding_low())); + index_typed_const(window.dimensions(i).padding_low())); llvm::Value* index_condition = ir_builder_.CreateICmpULT( operand_index[i], - ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + index_typed_const(ShapeUtil::GetDimension(operand->shape(), i))); in_bounds_condition = ir_builder_.CreateAnd(in_bounds_condition, index_condition); } @@ -1862,7 +2205,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // value and the current output value. llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - llvm_ir::IrArray::Index selected_index; + llvm_ir::IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( selected_index_address, {ir_builder_.getInt32(i)}); @@ -1880,8 +2223,6 @@ Status IrEmitterUnnested::HandleSelectAndScatter( source_value_address); }; - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - source->shape(), ir_emitter_context_->device_description()); UpdateLaunchDimensions( launch_dimensions, // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk @@ -1892,7 +2233,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, source->shape(), launch_dimensions, &ir_builder_) - .EmitLoop(IrName(select_and_scatter)); + .EmitLoop(IrName(select_and_scatter), index_type); } Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { @@ -1928,6 +2269,61 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } +Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { + thunk_sequence_->push_back(BuildKernelThunk(tuple_select)); + return IrEmitter::HandleTupleSelect(tuple_select); +} + +Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { + if (hlo_module_config_.replica_count() != 1) { + // TODO(b/33011107): Support nontrivial cross replica sum on GPU. + return Unimplemented( + "CrossReplicaSum with >1 replica is not implemented on GPU."); + } + + // CRS with one operand and one replica is simply the identity function. + // Buffer assignment expects a copy, so that's what we do. + // + // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely + // in algebraic-simplifier, but currently on some platforms + // HloModuleConfig::num_replicas changes between when the module is compiled + // and when it's run. + if (crs->operand_count() == 1) { + CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) + << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + thunk_sequence_->push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*crs->operand(0)), + /*destination_buffer=*/GetAllocationSlice(*crs), + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); + return Status::OK(); + } + + // One-replica CRS with multiple operands produces a tuple of the inputs. + // Again, buffer assignment expects us to copy each. + std::vector> thunks; + std::vector tuple_element_buffers; + for (int64 i = 0; i < crs->operand_count(); ++i) { + tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(crs, {i}) + .ValueOrDie()); + thunks.push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*crs->operand(i)), + /*destination_buffer=*/tuple_element_buffers.back(), + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs)); + } + + // Output a tuple of the buffers above. + thunks.push_back(MakeUnique(tuple_element_buffers, + GetAllocationSlice(*crs), crs)); + thunk_sequence_->push_back( + MakeUnique(std::move(thunks), crs)); + return Status::OK(); +} + +Status IrEmitterUnnested::HandleAfterAll(HloInstruction* gen_token) { + return Status::OK(); +} + Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { thunk_sequence_->emplace_back(BuildInfeedThunk(infeed)); return Status::OK(); @@ -2051,11 +2447,6 @@ GetHloBufferSlices(const HloInstruction* hlo, return slices; } -Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { - // TODO(b/72710576): Gather is not implemented on GPUs - return Unimplemented("Gather is not implemented on GPUs."); -} - std::unique_ptr IrEmitterUnnested::BuildKernelThunk( const HloInstruction* inst, int unroll_factor) { const BufferAssignment& buffer_assn = @@ -2181,17 +2572,14 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( const HloInstruction* inst) { CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); - std::vector tuple_element_buffers; - for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) { - BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment() - .GetUniqueSlice(inst, {i}) - .ConsumeValueOrDie(); - tuple_element_buffers.push_back(buffer); - } - - return MakeUnique( - tuple_element_buffers, - /*destination_buffer=*/GetAllocationSlice(*inst), inst); + ShapeTree slices(inst->shape()); + slices.ForEachMutableElement( + [this, inst](const ShapeIndex& index, BufferAllocation::Slice* slice) { + *slice = ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(inst, index) + .ConsumeValueOrDie(); + }); + return MakeUnique(slices, inst); } namespace { @@ -2236,7 +2624,9 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (alpha->opcode() == HloOpcode::kBroadcast) { alpha = alpha->operand(0); } - alpha = inst->operand(alpha->parameter_number()); + if (alpha->opcode() == HloOpcode::kParameter) { + alpha = inst->operand(alpha->parameter_number()); + } // TODO(b/74185543): Remove the following if block once we support fusion // with a non-constant as well. Then we will just always use the constant // on the device. @@ -2279,21 +2669,30 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( } StatusOr> IrEmitterUnnested::BuildInitializerThunk( - const HloInstruction* hlo) { + const HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value = [&] { + const HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: return inst->operand(2); case HloOpcode::kReduce: return inst->operand(1); + case HloOpcode::kTuple: + CHECK(hlo->IsMultiOutputFusion()) + << ": " << hlo->ToString() << " is not a multi-output fusion."; + CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce) + << ": Found '" << inst->operand(index.back())->opcode() << "' in " + << inst->ToString() << " but expected 'reduce'."; + // For multi-output fusion look through the tuple. + return inst->operand(index.back())->operand(1); default: LOG(FATAL) << "Opcode " << inst->opcode() << " should not need an initializer."; } }(); + const HloInstruction* init_value = init_value_operand; if (fused && init_value->opcode() == HloOpcode::kParameter) { init_value = hlo->operand(init_value->parameter_number()); } @@ -2311,24 +2710,25 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ArraySlice literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return {MakeUnique(GetAllocationSlice(*hlo), hlo)}; + return {MakeUnique(GetAllocationSlice(*hlo, index), hlo)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by // repeating the literal 4 or 2 times, so long as the destination buffer is // an even multiple of 32 bits long. + const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index); if ((num_bytes == 1 || num_bytes == 2) && - ShapeUtil::ByteSizeOf(hlo->shape()) % 4 == 0) { + ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { uint16 pattern16; if (num_bytes == 1) { uint8 b = literal_bytes.front(); pattern16 = uint16{b} | (uint16{b} << 8); } else { - pattern16 = literal_bytes.front(); + memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {MakeUnique(pattern32, - GetAllocationSlice(*hlo), hlo)}; + return {MakeUnique( + pattern32, GetAllocationSlice(*hlo, index), hlo)}; } // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit @@ -2338,20 +2738,31 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( literal_bytes.size() - 4) == 0) { uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); - return {MakeUnique(word, GetAllocationSlice(*hlo), - hlo)}; + return {MakeUnique( + word, GetAllocationSlice(*hlo, index), hlo)}; } } // Otherwise fall back to our slow initializer code. std::unique_ptr kernel_thunk = BuildKernelThunk(hlo); - TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( - *hlo, - [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &ir_builder_); - }, - kernel_thunk.get())); + LaunchDimensions launch_dimensions = + CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index), + ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + // If the init_value was fused into this reduce we have to generate it first. + if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { + CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); + TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + } + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &ir_builder_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &ir_builder_) + .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally // done in IrEmitterUnnested::Postprocess().) @@ -2535,20 +2946,22 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( if (!hlo.IsMultiOutputFusion()) { return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), launch_dimensions, &ir_builder_, unroll_factor) - .EmitLoop(IrName(&hlo)); + .EmitLoop(IrName(&hlo), + GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), + &ir_builder_)); } - CHECK_EQ(unroll_factor, 1) - << "multi-output fusion does not support unrolling"; - // For multiple outputs fusion, we need to emit each operand and the root. std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } - TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, - launch_dimensions, &ir_builder_) - .EmitLoop(IrName(&hlo))); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, + &ir_builder_, unroll_factor) + .EmitLoop(IrName(&hlo), + GetIndexTypeForKernel( + &hlo, launch_dimensions.launch_bound(), &ir_builder_))); std::vector tuple_operand_ptrs; for (int64 i = 0; i < output_arrays.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b41ab2162ab81f66e123a7055ca3ffc815c3ef88..e8dce1ca539a24f91d6c9e5f3425e085e2d30a5a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -67,7 +67,6 @@ class IrEmitterUnnested : public IrEmitter { Status HandleDot(HloInstruction* dot) override; Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleGather(HloInstruction* gather) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; @@ -76,6 +75,9 @@ class IrEmitterUnnested : public IrEmitter { Status HandleInfeed(HloInstruction* xla_infeed) override; Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAfterAll(HloInstruction* gen_token) override; Status EmitTargetElementLoop( const HloInstruction& hlo, @@ -99,6 +101,13 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& inst, tensorflow::gtl::ArraySlice args); + // Helper for writing extra outputs from inside a reduce kernel. + Status EmitExtraOutputsForReduce( + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); + // EmitColumnReduction and EmitRowReduction emit code for column and row // reduction of a matrix and/or 3D tensor. Row and column reduction have // different memory access pattern, so for performance their implementations @@ -109,28 +118,43 @@ class IrEmitterUnnested : public IrEmitter { // `EmitReductionToVector`. Note that input shape might not be // [height x width], but can be bitcast to [height x weight] with "height" // being the major dimension. - Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitColumnReduction( + int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Emits code that reduces a 3D tensor of shape [depth x height x width] to a // vector of shape [height]. Other parameters have the same meaning as those // of `EmitReductionToVector`. Note that input shape might not be // [depth x height x width], but can be bitcast to [depth x height x weight] // with "depth" being the most major dimension. - Status EmitRowReduction(int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitRowReduction( + int64 depth, int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Emits code that reduces a tensor of arbitrary rank to a scalar. - Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitReductionToScalar( + HloInstruction* reduce, const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Figures out whether `reduce` is a row or column reduction, and which // dimensions to reduce, and calls either `EmitRowReduction` or @@ -140,13 +164,24 @@ class IrEmitterUnnested : public IrEmitter { // generate elements of the input and the initial value. Other parameters mean // the same as for `HandleReduce`. // + // Multiple reduces can be emitted in the same loop, assuming they have the + // same input and output shapes, and the same reduce dimensions. + // + // extra_output_gens can contain extra generators for intermediate outputs. + // These must have the same shape as the reduce input as they are computed + // when the reduce inputs are being read. + // // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer); + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned @@ -165,7 +200,7 @@ class IrEmitterUnnested : public IrEmitter { // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. StatusOr> BuildInitializerThunk( - const HloInstruction* hlo); + const HloInstruction* hlo, const ShapeIndex& index = {}); // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index f56c1ce69f11ed79c8be76834269f29de93a9645..e76823ad103dfa5ba61a0d3ba81b2c028dfeb33e 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -75,7 +76,8 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { } Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { // Load the kernel. se::StreamExecutor* executor = stream->parent(); LaunchDimensions launch_dimensions; @@ -100,6 +102,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(3) << " Arg: alloc #" << arg->index() << ": " << buf.opaque() << " (" << buf.size() << "B)"; } + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (!stream->parent()->Launch( stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 7def27e189b66747569344a3dbe5c0c446f903be..d751de50ad6671b3bf88cd4de49a8feb448e13ba 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -62,7 +63,8 @@ class KernelThunk : public Thunk { // Executes the kernel for the thunk on "stream", which must be non-null. Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: // Buffers passed to the kernel as arguments. diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 917c57682345d099a404721ebcee8028d076dc18..a4e4e85bf3d2c197cfc691b7fca0920aa6571729 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -272,7 +272,7 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); - target_machine->addPassesToEmitFile(codegen_passes, pstream, + target_machine->addPassesToEmitFile(codegen_passes, pstream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile); codegen_passes.run(*module); } diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc index d4100a898b5bb9eec382c34932c2db104c9e985b..9fd6cf7157ecd659e7eb1d2c5228eca931ff6a01 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -14,21 +14,27 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" + +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/stream_executor/stream_executor.h" namespace xla { namespace gpu { Status MemzeroThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemZero(&dest_data, dest_data.size()); return Status::OK(); } Status Memset32BitValueThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemset32(&dest_data, value_, dest_data.size()); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h index 51c332d287d139335b356fc66411b5ffaa448b5a..d1fec0bd76b8a80f4a1e1c2e818f248997da7a75 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status.h" @@ -36,7 +37,8 @@ class MemzeroThunk : public Thunk { : Thunk(Kind::kMemzero, hlo), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const BufferAllocation::Slice dest_; @@ -52,7 +54,8 @@ class Memset32BitValueThunk : public Thunk { : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: uint32 value_; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea661b3c2cb2c945297ac2098cd1c4009b2e966d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -0,0 +1,263 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {} + +bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) { + auto get_element_instr = + [&](const HloInstruction* instr) -> const HloInstruction* { + const HloInstruction* element_instr = instr; + if (instr->opcode() == HloOpcode::kFusion) { + auto fused_expression_root = instr->fused_expression_root(); + if (instr->IsMultiOutputFusion()) { + // If possible, we want to pick a reduce operand of the fusion root, + // because it has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (inst->opcode() == HloOpcode::kReduce) { + return inst; + } + } + return fused_expression_root->operands()[0]; + } else { + element_instr = fused_expression_root; + } + } + return element_instr; + }; + + auto get_element_shape = [&](const HloInstruction* element_instr) { + // Special handling of kReduce instructions -- the fusion + // applies to the first operand. + if (element_instr->opcode() == HloOpcode::kReduce) { + return element_instr->operand(0)->shape(); + } + return element_instr->shape(); + }; + + // The shapes in all tuple operands should agree, unless it is a reduce. + // In that case, the operand of the reduce needs to have the same shape + // as the other tuple operands, but also we need to compare the output + // shapes of the reduces. + // TODO(tjoerg): Allow differences in fp precision. + auto* element_instr_1 = get_element_instr(instr1); + auto* element_instr_2 = get_element_instr(instr2); + if (element_instr_1->opcode() == HloOpcode::kReduce && + element_instr_2->opcode() == HloOpcode::kReduce && + !ShapeUtil::Equal(element_instr_1->shape(), element_instr_2->shape())) { + return false; + } + // The elementwise output shapes must be the same (including layout). + return ShapeUtil::Equal(get_element_shape(element_instr_1), + get_element_shape(element_instr_2)); +} + +namespace { +bool IsInputFusibleReduction(HloInstruction* instr) { + if (instr->IsMultiOutputFusion()) { + for (const HloInstruction* operand : + instr->fused_expression_root()->operands()) { + if (operand->opcode() == HloOpcode::kReduce) { + CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput) + << " Reduce multi-output fusion " << instr->ToString() + << " must be an input fusion."; + return true; + } + } + return false; + } else if (instr->opcode() == HloOpcode::kFusion) { + // The loop emitter can handle to-vector reduce fusions. Such reduce + // fusions have the fusion kind kLoop rather than kInput. We do not fuse + // to-vector reduce fusions, because the resulting fusions may no longer be + // supported by loop emitter. + return IsReductionToVector(*instr->fused_expression_root()); + } else { + return IsReductionToVector(*instr); + } +} +} // namespace + +bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { + // We can fuse reduces and loop fusions. + return IsInputFusibleReduction(instr) || + (instr->opcode() == HloOpcode::kFusion && + instr->fusion_kind() == HloInstruction::FusionKind::kLoop); +} + +int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, + HloInstruction* instr2) { + tensorflow::gtl::FlatSet in_list; + for (auto instr : instr1->operands()) { + if (!IsProfitableOperand(instr)) { + continue; + } + in_list.insert(instr); + } + int64 profit = 0; + for (auto instr : instr2->operands()) { + if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) { + continue; + } + profit += ShapeUtil::ByteSizeOf(instr->shape()); + } + VLOG(2) << "Fusing instr1=" << instr1->name() << " instr2=" << instr2->name() + << ", the profit is =" << profit; + return profit; +} + +bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, + HloInstruction* instr2) { + if (!MultiOutputFusion::LegalToFuse(instr1, instr2)) { + return false; + } + // If we're fusing fusions only do it if the fusion kind matches. Loop fusions + // merge into bigger loop fusions and input (reduce) fusions become fusions + // with multiple reduce outputs. We could fuse reduce and loop fusions + // together too (the result being an input fusion) if we find cases where this + // improves things. + CHECK(instr1->opcode() == HloOpcode::kFusion); + if (instr2->opcode() == HloOpcode::kFusion) { + return instr1->fusion_kind() == instr2->fusion_kind(); + } + return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop; +} + +bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { + bool changed = false; + RecomputeReachability(); + + tensorflow::gtl::FlatSet to_fuse; + // Keep a list of the instructions to fuse after making all the fusion + // decisions. We first aggressively add instructions to potential_fusion_list, + // then filter out instructions that will be no longer fusable because of + // reachability change. This avoids recalculating reachability on a large set + // of instructions. + std::vector> + potential_fusion_list; + std::vector> fusion_list; + std::vector instrs_to_update_reachability; + + // For each reduce or reduce multi-output fusion, try to fuse it with loop + // fusions operands. + for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) { + if (consumer->user_count() == 0) { + continue; + } + if (!IsInputFusibleReduction(consumer)) { + continue; + } + + auto consumer_operands = consumer->operands(); + for (size_t i = 0; i < consumer_operands.size(); ++i) { + HloInstruction* producer = consumer_operands[i]; + if (!producer->IsFusable()) { + continue; + } + const bool is_loop_fusion = + producer->opcode() == HloOpcode::kFusion && + producer->fusion_kind() == HloInstruction::FusionKind::kLoop; + if (!is_loop_fusion) { + continue; + } + if (!ShapesCompatibleForFusion(producer, consumer)) { + continue; + } + // If we have already decided to fuse this producer, skip it. + if (ContainsKey(to_fuse, producer)) { + continue; + } + // Do not fuse a producer if the other operands of the fusion are + // reachable from the producer, this would create a cycle. + if (c_any_of(consumer_operands, [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { + break; + } + to_fuse.insert(producer); + potential_fusion_list.emplace_back(producer, consumer); + instrs_to_update_reachability.push_back(producer); + instrs_to_update_reachability.push_back(consumer); + break; + } + } + + // Filter out pairs that will be no longer fusable because of reachability + // change. + for (auto& fusion_pair : potential_fusion_list) { + HloInstruction* producer = fusion_pair.first; + HloInstruction* consumer = fusion_pair.second; + if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { + UpdateReachability(producer, consumer, instrs_to_update_reachability); + fusion_list.push_back(fusion_pair); + } + } + + for (auto fusions_to_create : fusion_list) { + HloInstruction* producer = fusions_to_create.first; + HloInstruction* consumer = fusions_to_create.second; + if (consumer->opcode() != HloOpcode::kFusion) { + // Fusing with a reduce (fusion) always results in an input fusion. + HloInstruction* input_fusion = + computation()->AddInstruction(HloInstruction::CreateFusion( + consumer->shape(), HloInstruction::FusionKind::kInput, consumer)); + VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " + << consumer->name() << " into " << input_fusion->name(); + TF_CHECK_OK(computation()->ReplaceInstruction(consumer, input_fusion)); + if (producer->opcode() == HloOpcode::kFusion) { + input_fusion->MergeFusionInstructionIntoMultiOutput(producer); + } else { + input_fusion->FuseInstructionIntoMultiOutput(producer); + } + } else { + VLOG(2) << "Fuse producer " << producer->name() << " into its consumer " + << consumer->name(); + + if (producer->opcode() == HloOpcode::kFusion) { + consumer->MergeFusionInstructionIntoMultiOutput(producer); + } else { + consumer->FuseInstructionIntoMultiOutput(producer); + } + } + changed = true; + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..67ca5d49eee8508e93284b134f8410eb3a89f9ce --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +namespace xla { +namespace gpu { + +// Multi-output fusion of sibling and producer-consumer instructions for the +// Jellyfish backend. +class GpuMultiOutputFusion : public MultiOutputFusion { + public: + GpuMultiOutputFusion(); + + protected: + // Test if instr1 and instr2 have the compatible shapes that can be legally + // fused. + bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) override; + + // We currently only consider reduce and reduce fusion nodes as candidates. + bool IsFusible(HloInstruction* instr) override; + + // This function estimates the amount of memory reads saved by merging + // instr1 and instr2 into one multi-output fusion instruction. For a fusion + // instruction, all the operands need to be loaded from memory. If we merge + // instr1 and instr2, common operands will not be loaded twice. The profit is + // estimated as the size of the common operands b/w instr1 and instr2. + int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override; + + // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. + bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) override; + + // Fuse loop fusions into reduce fusions. + bool DoProducerConsumerMultiOutputFusion() override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..979ea79243818c398b1b130254a41c95ced51830 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -0,0 +1,353 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace gpu { + +using InstructionFusionTest = HloTestBase; + +const char kModulePrefix[] = R"( + HloModule test_module + + scalar_add_computation { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0) + } + scalar_mul_computation { + scalar_lhs.1 = f32[] parameter(0) + scalar_rhs.1 = f32[] parameter(1) + ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1) + })"; + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { + // Fusion with reduce instruction root and a sibling reduce instruction + // sharing the same input param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] constant(1) + fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation + reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[6400]{0} parameter(1) + mul = f32[6400]{0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[6400]{0} parameter(1) + r1 = f32[64,100]{0,1} reshape(p1.2) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[6400]{0} parameter(1) + fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[10,10]{1,0} parameter(1) + mul = f32[10,10]{1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[10,10]{1,0} parameter(1) + const.2 = f32[10]{0} parameter(0) + ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1.3 = f32[10,10]{1,0} parameter(1) + fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1 + p2 = f32[] parameter(2) + fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) { + // Two sibling fusions with reduce instruction roots sharing the same input + // param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { + // Multi-output fusion with two reduce instructions root and a sibling reduce + // instruction sharing the same input param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { + const.1 = f32[] constant(1) + p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) + mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1) + reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2) + } + + ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + const = f32[] constant(1) + fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation + get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0 + get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1 + reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { + // Verify that if we already have a multi-output fusion that we prefer to pick + // a reduce op from its operands for checking shape compatibility. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[10,10]{1,0} parameter(1) + mul = f32[10,10]{1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation + ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1) + } + + fused_computation_2 { + p1.2 = f32[10,10]{1,0} parameter(1) + const.2 = f32[10] parameter(0) + ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[10,10]{1,0} parameter(1) + p2 = f32[10]{0} parameter(2) + fusion.1 = (f32[10,10], f32[10]) fusion(p0, p1), kind=kInput, calls=fused_computation_1 + get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=0 + get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=1 + fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[6400]{0} parameter(0) + const.2 = f32[] constant(1) + ROOT div = f32[6400]{0} divide(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Divide())); +} + +TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_add { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add + reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Add())); +} + +TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_select { + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} + greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + c1 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation + mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce + gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(), + op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Select())); +} + +TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_element_wise { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add_computation + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 7bda4e2fcd469bd430e5ef1846251c8504225383..c8f0d4185c63c5bafca6f30acab31cbe8e987277 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -370,26 +370,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return true; } -StatusOr PadInsertion::Run(HloModule* module) { +StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { bool changed = false; - for (HloInstruction* instruction : - module->entry_computation()->MakeInstructionPostOrder()) { - if (IsCustomCallToDnnConvolution(*instruction)) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + for (HloInstruction* instruction : convs) { + const auto& target = instruction->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + changed |= CanonicalizeBackwardFilterConvolution(instruction); + } else if (target == kCudnnConvBackwardInputCallTarget) { + changed |= CanonicalizeBackwardInputConvolution(instruction); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instruction->ToString(); } } return changed; } +StatusOr PadInsertion::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 5e1c68701daa02eba64f3e34933ce373a496c1b8..67e51509e4c717951c83c7e41943af1de762dee0 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -31,6 +31,7 @@ class PadInsertion : public HloPassInterface { StatusOr Run(HloModule* module) override; private: + StatusOr RunOnComputation(HloComputation* computation); // Returns if any changes are made to the parent computation. bool CanonicalizeForwardConvolution(HloInstruction* conv); bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv); diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index d8c07dc3119fb81a3ef22822acb11b7c4d5bbca5..cd833ec7bd858aabee84ac306d198e80eb112506 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -58,7 +58,7 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { @@ -71,14 +71,13 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // // %nctaid.x is currently specified as 2147483647. VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor " << unroll_factor_; + CHECK_NE(index_type, nullptr); std::vector array_indices; - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), static_cast(block_id)); - block_id = - ir_builder_->CreateZExt(block_id, ir_builder_->getInt64Ty(), "block_id"); + block_id = ir_builder_->CreateZExtOrTrunc(block_id, index_type, "block_id"); // Per the PTX documentation: // "It is guaranteed that [...] 0 <= %tid.x < %ntid.x" @@ -88,13 +87,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, ir_builder_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), static_cast(thread_id)); - thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(), - "thread_id"); + thread_id = + ir_builder_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); llvm::Value* linear_index_base = ir_builder_->CreateAdd( ir_builder_->CreateMul( block_id, - ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "", + llvm::ConstantInt::get(index_type, + launch_dimensions_.threads_per_block()), + "", /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); @@ -110,21 +111,23 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm::Intrinsic::assume, {ir_builder_->CreateICmpULT( linear_index_base, - ir_builder_->getInt64(launch_dimensions_.threads_per_block() * - launch_dimensions_.block_count()), + llvm::ConstantInt::get(index_type, + launch_dimensions_.threads_per_block() * + launch_dimensions_.block_count()), "linear_index_in_range")}, {}, ir_builder_); if (unroll_factor_ > 1) { linear_index_base = ir_builder_->CreateMul( - linear_index_base, ir_builder_->getInt64(unroll_factor_), + linear_index_base, llvm::ConstantInt::get(index_type, unroll_factor_), "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); } array_indices.emplace_back(linear_index_base, shape_, ir_builder_); for (int i = 1; i < unroll_factor_; ++i) { llvm::Value* linear_index = ir_builder_->CreateAdd( - linear_index_base, ir_builder_->getInt64(i), "linear_index", + linear_index_base, llvm::ConstantInt::get(index_type, i), + "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); array_indices.emplace_back(linear_index, shape_, ir_builder_); } @@ -132,7 +135,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( auto if_in_bounds = llvm_ir::EmitIfThenElse( ir_builder_->CreateICmpULT( linear_index_base, - ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), + llvm::ConstantInt::get(index_type, ShapeUtil::ElementsIn(shape_))), llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false); // Set exit_bb_ to the exit block of the if structure. diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 25318b3bed8bf4a2dfe3a4a974269d0405c3bfec..302e1bf1bc8e90f2eebd838f156a1552e86185ac 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -58,7 +58,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) override; + tensorflow::StringPiece loop_name, llvm::Type* index_type) override; private: // The thread and block dimension to parallelize the loop on. diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index c125474edb1036090a926020f2b1e7fcf64c751a..02471129e004b4876ce20a62cade34060c65b478 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -47,6 +47,7 @@ class LaunchDimensions { int64 block_count() const { return block_count_; } int64 threads_per_block() const { return threads_per_block_; } + int64 launch_bound() const { return block_count() * threads_per_block(); } private: int64 block_count_; diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index b50f5b5a903e6ae3d28bccb7234a14babfa68a98..dfdba7d7d9a60458e1b1c90cf9f5017b44b7b801 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -15,12 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace gpu { -SequentialThunk::SequentialThunk(std::vector>&& thunks, +SequentialThunk::SequentialThunk(std::vector> thunks, const HloInstruction* hlo) : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {} @@ -33,9 +34,17 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable, } Status SequentialThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + // TODO(b/71544591): We need to potentially measure the total time of the + // sequential thunk. This happens for a reduce op which consists of + // SequentialThunk with a thunk that initializes the output, and another thunk + // that does the actual reduce. Right now, in this case we would only measure + // the time of the last thunk, because both thunks would have the same + // HloInstruction. for (const auto& thunk : thunks_) { - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); + TF_RETURN_IF_ERROR( + thunk->ExecuteOnStream(buffer_allocations, stream, profiler)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index 3537110bb5c252054db4ce29171bd1a432e8cead..3c4de1d1a6c912ba31f56c29b10ca004d1e56da6 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -31,7 +32,7 @@ namespace gpu { // require multiple kernel launches or library calls. class SequentialThunk : public Thunk { public: - SequentialThunk(std::vector>&& thunks, + SequentialThunk(std::vector> thunks, const HloInstruction* hlo); SequentialThunk(const SequentialThunk&) = delete; SequentialThunk& operator=(const SequentialThunk&) = delete; @@ -41,7 +42,8 @@ class SequentialThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: // The list of sub-thunks. diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 696fa7e0194032b5c78bf11383c3280a62de07fa..6f4bb0580e8dfc1dce1cca0a60cc3dd9ea600fb3 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -33,8 +33,7 @@ class StreamAssignmentTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", VersionedComputationHandle(), - config); + return MakeUnique("test_module", config); } // Pre-canned shapes. diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..a50ddf6ac63c7fa7ccace94bc7f40f438aedccf8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" + +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { +namespace gpu { + +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::DataLayoutString; +using stream_executor::dnn::FilterLayout; +using stream_executor::dnn::FilterLayoutString; + +StatusOr> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + DataLayout input, FilterLayout filter, + DataLayout output) { + std::vector input_layout; + switch (input) { + case DataLayout::kBatchDepthYX: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.push_back(dnums.input_feature_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + input_layout.push_back(dnums.input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid input layout: ", + DataLayoutString(input)); + } + + std::vector filter_layout; + switch (filter) { + case FilterLayout::kOutputInputYX: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + break; + case FilterLayout::kOutputYXInput: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid filter layout: ", + FilterLayoutString(filter)); + } + + std::vector output_layout; + switch (output) { + case DataLayout::kBatchDepthYX: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.push_back(dnums.output_feature_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + output_layout.push_back(dnums.output_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid output layout: ", + DataLayoutString(output)); + } + + return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(output_layout)); +} + +StatusOr> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output) { + Layout nchw_input, nchw_filter, nchw_output; + std::tie(nchw_input, nchw_filter, nchw_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX) + .ConsumeValueOrDie(); + + Layout nhwc_input, nhwc_filter, nhwc_output; + std::tie(nhwc_input, nhwc_filter, nhwc_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth) + .ConsumeValueOrDie(); + + DataLayout input_layout; + if (LayoutUtil::Equal(input, nchw_input)) { + input_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(input, nhwc_input)) { + input_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid input layout: ", + input.ShortDebugString()); + } + + FilterLayout filter_layout; + if (LayoutUtil::Equal(filter, nchw_filter)) { + filter_layout = FilterLayout::kOutputInputYX; + } else if (LayoutUtil::Equal(filter, nhwc_filter)) { + filter_layout = FilterLayout::kOutputYXInput; + } else { + return tensorflow::errors::Internal("Invalid filter layout: ", + filter.ShortDebugString()); + } + + DataLayout output_layout; + if (LayoutUtil::Equal(output, nchw_output)) { + output_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(output, nhwc_output)) { + output_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid output layout: ", + output.ShortDebugString()); + } + + return std::make_tuple(input_layout, filter_layout, output_layout); +} +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h new file mode 100644 index 0000000000000000000000000000000000000000..39a6a38d001f502b2abb8de6efe2ce623b478c71 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +// Helper functions for interacting with StreamExecutor. + +namespace xla { +namespace gpu { + +// Returns (input, filter, output) XLA Layout protos given the StreamExecutor +// layouts. +StatusOr> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + stream_executor::dnn::DataLayout input, + stream_executor::dnn::FilterLayout filter, + stream_executor::dnn::DataLayout output); + +// Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. +StatusOr> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 931c0bffab850362dbd2df975657dd47d9cbd3ae..14d41033c2c7681e3262c0674be13b1f3aa83aef 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -94,11 +95,12 @@ class Thunk { // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's - // lifetime. Stream argument must be non-null. + // lifetime. 'stream' and 'profiler' must be non-null. // // Precondition: Initialize(stream->parent()) has been called. virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) = 0; + se::Stream* stream, + HloExecutionProfiler* profiler) = 0; private: Kind kind_; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index 97cb04c38fbf18e516857f5269c984696ca204c3..a10e40451c1db01ce73db7b56a3a0599769fa49b 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -15,13 +15,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace gpu { Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { std::vector tuple_element_buffer_addresses; for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { tuple_element_buffer_addresses.push_back( @@ -31,6 +33,7 @@ Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, buffer_allocations.GetDeviceAddress(dest_buffer_)); auto host_size = tuple_element_buffer_addresses.size() * sizeof(void*); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (!stream ->ThenMemcpy(&dest_buffer_address, tuple_element_buffer_addresses.data(), host_size) diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 951f809b51937c97a6e7de0345ec58a8b66a4242..2d5735d6c40ccd26f0e527f1a02403910db4c812 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -46,7 +47,8 @@ class TupleThunk : public Thunk { TupleThunk& operator=(const TupleThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 30b9640c4c75dae61e9a90da5fb10e9d4a90cd26..5e13f989c2ffb0396efc94a01783ee91725dbd44 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -43,14 +44,18 @@ Status WhileThunk::Initialize(const GpuExecutable& executable, } Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); while (true) { // Invoke thunk sequence for while 'condition' computation. - TF_RETURN_IF_ERROR( - condition_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + profiler->StartHloComputation(); + TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream( + buffer_allocations, stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->while_condition()); // Copy the result of condition computation and break the loop if 'false'. bool condition_result; @@ -66,9 +71,14 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, break; } - // Invoke thunk sequence for while 'body' computation. - TF_RETURN_IF_ERROR( - body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + // We measure the time of one execution of the while body computation. The + // while body may be executed more than once, the last measurement "wins". + profiler->StartHloComputation(); + // Invoke thunk sequence for while 'body' computation, and pass on + // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, + stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 22176685a92df9c95b10f755b209309843c0fa3a..9270f95ee67cf0bd3ab8082452a9d8703cb4304e 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,7 +49,8 @@ class WhileThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const BufferAllocation::Slice condition_result_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index ad55728c45599c801aad7e12fac95ae9f0c4fc3b..7749201cbceece216a2db2569936949eb7de5125 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -457,8 +457,8 @@ class WhileBodyComputationMatcher : public MatcherBase { return InvalidArgument("Unexpected tuple index instruction : %s", inst->name().c_str()); } else if (tag == "loop_increment") { - // Parse the constant which represents the loop induction variable - // increment value. + // ParseHloString the constant which represents the loop induction + // variable increment value. TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); } else if (tag == "param0" && inst != computation_->parameter_instruction(0)) { diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 06a5e0351b63270b61b998ca2211f480f256f759..4005fc0d114a3ec7a38dfb5edecdaeb1e8497ade 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -26,6 +26,46 @@ namespace xla { using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; +/*static*/ +StatusOr HeapSimulator::MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function) { + if (module_sequence.empty()) { + return 0; + } + + const HloModule* module = module_sequence.begin()->first->parent(); + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; +} + +/*static*/ +StatusOr HeapSimulator::MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap* + memory_by_computation) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function, + HeapSimulator::Options(), memory_by_computation)); + return result.heap_size; +} + /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, @@ -46,9 +86,11 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options, + const tensorflow::gtl::FlatMap* + memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr); + /*module_sequence=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -188,6 +230,9 @@ Status HeapSimulator::RunComputation( // // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer // that we should assign. + + // Make sure each buffer get reused at most once. + FlatSet reused_buffers; for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; @@ -200,6 +245,9 @@ Status HeapSimulator::RunComputation( bool shared = false; if (options_.may_reuse_operand_buffers) { for (const BufferValue* operand_buffer : operand_buffers_to_free) { + if (reused_buffers.count(operand_buffer) != 0) { + continue; + } if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && buffer->instruction()->opcode() != HloOpcode::kCopy && points_to_analysis.CanShareOperandBufferWithUser( @@ -209,6 +257,7 @@ Status HeapSimulator::RunComputation( << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); shared = true; + reused_buffers.insert(operand_buffer); break; } } @@ -219,6 +268,12 @@ Status HeapSimulator::RunComputation( Alloc(buffer, instruction); } } + // Account for the memory used by subcomputations when estimating the + // current heap size. + if (memory_by_computation_ != nullptr) { + algorithm_->AccountForSubcomputationMemory(instruction, + *memory_by_computation_); + } // If the whole module is sequential, we can save memory by running the // heap-simulation for sub-computations inline. E.g. the buffers for the @@ -286,12 +341,15 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence) + const SequentialHloOrdering::HloModuleSequence* module_sequence, + const tensorflow::gtl::FlatMap* + memory_by_computation) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence) { + module_sequence_(module_sequence), + memory_by_computation_(memory_by_computation) { debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); } @@ -460,6 +518,26 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } +void NoFragmentationStatsHeap::AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // We only count the memory usage of the largest subcomputation, instead of + // adding them all, because subcomputations won't execute in parallel. + int64 max_subcomputation_bytes = 0; + for (const auto* c : instruction->called_computations()) { + auto it = memory_by_computation.find(c); + if (it != memory_by_computation.end()) { + int64 subcomputation_bytes = it->second; + if (subcomputation_bytes > max_subcomputation_bytes) { + max_subcomputation_bytes = subcomputation_bytes; + } + } + } + max_heap_size_ = + std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes); +} + void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) { current_heap_size_ -= size; } diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 8b2b43a37a5c41d334e5338c6a6fad160f03a51e..811a6042df9434ac3f4bed71b9c093433e25c1bb 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -85,6 +85,23 @@ class HeapSimulator { const BufferValueFlatSet* buffers_to_assign; }; + // Returns the minimum memory required to compute an HLO module where all + // computations have been scheduled (represented by the given + // module_sequence), assuming no fragmentation. + static StatusOr MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function); + + // Returns the minimum memory required to compute the given computation, + // assuming no fragmentation. + static StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); + // Run the heap simulation with the given algorithm, assuming the given // module_sequence, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid @@ -111,7 +128,9 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + const Options& options = Options(), + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); private: // If 'module_sequence' is non-null, it is used to find kCall and kWhile @@ -120,7 +139,9 @@ class HeapSimulator { HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence); + const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); ~HeapSimulator(); Status RunComputation( @@ -144,7 +165,13 @@ class HeapSimulator { const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; + // module_sequence_ is set by buffer assignment, and memory_by_computation_ is + // set by hlo scheduling. Then, in RunComputation, we check both in order to + // handle subcomputations. It would be good to unify the handling of + // subcomputations, but it's not clear how. const SequentialHloOrdering::HloModuleSequence* module_sequence_; + const tensorflow::gtl::FlatMap* + memory_by_computation_; // In addition to Alloc and Free, the heap simulator exposes a concept of // buffer sharing. When ShareBuffer is called, instead of allocating new @@ -189,6 +216,11 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; + virtual void AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) {} + // Free de-allocates a previously allocated buffer. virtual void Free(const BufferValue* buffer, int64 size) = 0; @@ -207,7 +239,14 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { ~NoFragmentationStatsHeap() override = default; void Alloc(const BufferValue* buffer, int64 size) override; + + void AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) override; + void Free(const BufferValue* buffer, int64 size) override; + Result Finish() override; private: diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 6271652412c2979ff926702f12722102344b0dfb..3849b565e3136924b2d2b1929353885f85b1a043 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -34,6 +34,65 @@ limitations under the License. namespace xla { namespace { +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) + .ValueOrDie()); +} + const char kAlloc[] = "Alloc"; const char kFree[] = "Free"; const char kFinish[] = "Finish"; @@ -139,6 +198,11 @@ class HeapSimulatorTracker { .ConsumeValueOrDie(); } + int64 OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) { + const BufferValue* buffer = BufferAt(instruction, index); + return result_.chunk_map.at(buffer).offset; + } + // Ensures the expected sequence of Alloc/Free/Finish calls was performed. void ExpectCallSequence(const CallSequence& expected) const { EXPECT_EQ(expected, actual_calls_); @@ -150,10 +214,9 @@ class HeapSimulatorTracker { const ShapeIndex& index_a, const HloInstruction* instruction_b, const ShapeIndex& index_b) { - const BufferValue* a = BufferAt(instruction_a, index_a); - const BufferValue* b = BufferAt(instruction_b, index_b); - EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset) - << *a << ", " << *b; + int64 offset_a = OffsetAt(instruction_a, index_a); + int64 offset_b = OffsetAt(instruction_b, index_b); + EXPECT_EQ(offset_a, offset_b); } private: @@ -252,6 +315,43 @@ TEST_F(HeapSimulatorTest, MultiplyAdd) { tracker.ExpectSharedBuffers(add, {}, mul, {}); } +TEST_F(HeapSimulatorTest, BufferReusedOnce) { + HeapSimulatorTracker tracker(TestName()); + auto builder = HloComputation::Builder(TestName()); + + HloComputation::Builder fusion_builder("fusion"); + { + HloComputation::Builder& builder = fusion_builder; + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32vec4_, "A")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param)); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param)); + + builder.AddInstruction(HloInstruction::CreateTuple({exp, neg})); + } + auto fusion_computation = + tracker.module()->AddEmbeddedComputation(fusion_builder.Build()); + auto a_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec4_, "paramA")); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param)); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}), + HloInstruction::FusionKind::kLoop, {neg}, fusion_computation)); + tracker.module()->AddEntryComputation(builder.Build()); + + tracker.RunWholeModule({a_param, neg, fusion}); + + auto neg_buffer = tracker.OffsetAt(neg, {}); + int64 output_buffer_0 = tracker.OffsetAt(fusion, {0}); + int64 output_buffer_1 = tracker.OffsetAt(fusion, {1}); + // Only one buffer should be shared. + EXPECT_TRUE((neg_buffer == output_buffer_0) ^ + (neg_buffer == output_buffer_1)); +} + TEST_F(HeapSimulatorTest, MultiplyDot) { auto builder = HloComputation::Builder(TestName()); auto paramA = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 1f7c1cffd324ad2f4e4cdf11046c8459b8ceb6d5..d2417910606fdd13223076d33ff1bda1dd291d98 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -150,6 +150,11 @@ message HloInstructionProto { // Backend configuration for the instruction. Has backend-specific meaning. string backend_config = 43; + + // Cross Replica Sum fields. + repeated int64 replica_group_ids = 44; + int64 all_reduce_id = 45; + string cross_replica_sum_barrier = 46; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index a88283ed9a6459b4fa9310e160b59c77d51f1027..e8a4b034b4396860bd5873f43003844ce92dea6c 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -452,15 +452,16 @@ string HloAliasAnalysis::ToString() const { /* static */ StatusOr> HloAliasAnalysis::Run( - HloModule* module) { + HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer) { VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); - TF_ASSIGN_OR_RETURN( - alias_analysis->dataflow_analysis_, - HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, - /*bitcast_defines_value=*/false)); + TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, + /*bitcast_defines_value=*/false, + fusion_can_share_buffer)); BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); buffer_map.MergeAliasedBuffers(); @@ -493,6 +494,16 @@ StatusOr> HloAliasAnalysis::Run( bool HloAliasAnalysis::HasLiveRangeInterference( const HloOrdering& ordering) const { for (const HloBuffer& buffer : buffers()) { + CHECK(!buffer.values().empty()); + if (ShapeUtil::IsToken(buffer.values().front()->shape())) { + // Tokens have no on-device representation and cannot interfere. + for (const HloValue* value : buffer.values()) { + // If one of the values is a token, all values must be a token. + DCHECK(ShapeUtil::IsToken(value->shape())); + } + continue; + } + // Check that the values in the buffer are totally ordered with respect to // 'ordering'. Begin by sorting the values with respect to 'ordering' with a // tie-break using value ID. The tie-break is necessary because we need a @@ -517,7 +528,6 @@ bool HloAliasAnalysis::HasLiveRangeInterference( // a buffer and A interferes with C, then necessarily A also interferes // with B. So to check interference you only need to check interference // between A and B, and between B and C. - CHECK(!values.empty()); for (int i = 1; i < values.size(); ++i) { if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) { VLOG(1) << values[i - 1]->ToShortString() << " and " diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 67dfd4301b3a027a496911ecf6f06841dfd6423a..afb0c20f0cdf3eb92f72ab8bc368b4b8d723459e 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -39,7 +39,10 @@ class HloAliasAnalysis { public: // The callgraph of the given HloModule must be flattened // (xla::FlattenCallGraph) prior to running the analysis. - static StatusOr> Run(HloModule* module); + static StatusOr> Run( + HloModule* module, + const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer = nullptr); string ToString() const; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 8f18d50f6e033fab1c01f42017b951c224c22799..a59bf1750c06c091187b211c8530be126cf5e524 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -654,8 +654,7 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { } TEST_F(HloAliasAnalysisTest, TupleSelect) { - // Test a kSelect of a tuple value. Non-top-level element flow through the - // instruction. + // Test a kTupleSelect. Non-top-level element flow through the instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(false))); @@ -677,13 +676,13 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({constant4})); const Shape tuple_shape = tuple1->shape(); auto select11 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1)); auto select12 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto select34 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4)); auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, select12, select34)); + tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34)); module_->AddEntryComputation(builder.Build()); @@ -718,7 +717,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { } TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { - // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO: + // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO: // // body((F32[], F32[]) %tuple_param): // %negate = Negate(%tuple_param{0}) @@ -769,7 +768,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({constant2})); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, select)); diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7f73bba036534a62a70a80431236cffa766c9b38 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Casting utilitiy functions for HLO instructions. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ + +#include +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +class HloInstruction; + +template +using EnableIfDerivedFromHlo = + typename std::enable_if::value>::type; + +// TODO(b/93238915): Switch implementation from C++'s dynamic_cast to LLVM-like +// RTTI if it turns out to be a performance issue. +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr or runtime information does not match. +// +// Similar to LLVM's cast. +template * = nullptr> +const T* Cast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + const T* casted = dynamic_cast(instruction); + CHECK(casted != nullptr); + return casted; +} + +// Non-const overload of Cast. +template * = nullptr> +T* Cast(HloInstruction* instruction) { + return const_cast( + Cast(const_cast(instruction))); +} + +// Works just like the Cast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's cast_or_null. +template * = nullptr> +const T* CastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? Cast(instruction) : nullptr; +} + +// Non-const overload of CastOrNull. +template * = nullptr> +T* CastOrNull(HloInstruction* instruction) { + return const_cast( + CastOrNull(const_cast(instruction))); +} + +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr, returns nullptr if runtime information does not match. +// +// Similar to LLVM's dyn_cast. +template * = nullptr> +const T* DynCast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + return dynamic_cast(instruction); +} + +// Non-const overload of DynCast. +template * = nullptr> +T* DynCast(HloInstruction* instruction) { + return const_cast( + DynCast(const_cast(instruction))); +} + +// Works just like the DynCast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's dyn_cast_or_null. +template * = nullptr> +const T* DynCastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? DynCast(instruction) : nullptr; +} + +// Non-const overload of DynCastOrNull. +template * = nullptr> +T* DynCastOrNull(HloInstruction* instruction) { + return const_cast( + DynCastOrNull(const_cast(instruction))); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3364275409122254bf99b40a7d2fcbb2d7564cc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class DummyInstruction : public HloInstruction { + public: + DummyInstruction() + : HloInstruction(HloOpcode::kConstant, ShapeUtil::MakeShape(F32, {})) {} +}; + +class AnotherDummyInstruction : public HloInstruction { + public: + AnotherDummyInstruction() + : HloInstruction(HloOpcode::kParameter, ShapeUtil::MakeShape(F32, {})) {} +}; + +TEST(HloCastingUtilsTest, CastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(Cast(null), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastOrNullDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = CastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(DynCast(null), ""); +} + +TEST(HloCastingUtilsTest, DynCastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = DynCastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h new file mode 100644 index 0000000000000000000000000000000000000000..658643b427a9625fac1166151a89cbd669f817d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_clone_context.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +class HloInstruction; +class HloComputation; +class HloModule; + +// Data structure used to track the cloning of HloInstruction and HloComputation +// objects. +class HloCloneContext { + public: + // Creates a new HloCloneContext object to clone HloInstruction and + // HloComputation objects to be added to the module specified as argument. + // The suffix string will be appended to computation names. + explicit HloCloneContext(HloModule* module, const string& suffix = "") + : module_(module), suffix_(suffix) {} + + HloModule* module() const { return module_; } + + const string& suffix() const { return suffix_; } + + void MapInstruction(const HloInstruction* old_instruction, + HloInstruction* new_instruction) { + instructions_[old_instruction] = new_instruction; + } + + void MapComputation(const HloComputation* old_computation, + HloComputation* new_computation) { + computations_[old_computation] = new_computation; + } + + // Finds the new instruction mapped to its old copy, or return nullptr in case + // it is not found. + HloInstruction* FindInstruction(const HloInstruction* old_instruction) const { + return FindOrDefault(instructions_, old_instruction, nullptr); + } + + // Finds the new computation mapped to its old copy, or return nullptr in case + // it is not found. + HloComputation* FindComputation(const HloComputation* old_computation) const { + return FindOrDefault(computations_, old_computation, nullptr); + } + + // Retrieves the new instruction mapped to its old copy, or fail if not found. + HloInstruction* GetInstruction(const HloInstruction* old_instruction) const { + return FindOrDie(instructions_, old_instruction); + } + + // Retrieves the new computation mapped to its old copy, or fail if not found. + HloComputation* GetComputation(const HloComputation* old_computation) const { + return FindOrDie(computations_, old_computation); + } + + const tensorflow::gtl::FlatMap& + cloned_instructions() const { + return instructions_; + } + + const tensorflow::gtl::FlatMap& + cloned_computations() const { + return computations_; + } + + private: + HloModule* module_; + string suffix_; + tensorflow::gtl::FlatMap + instructions_; + tensorflow::gtl::FlatMap + computations_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 63c3dc4a5932f754a9ccdd70d03c999fe528a448..e36bef60a3c395af82cd93ef012de7eaf700ed4f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -64,7 +64,7 @@ HloComputation::HloComputation( const string& name, int parameter_count, std::vector>* instructions, HloInstruction* root_instruction, HloInstruction* fusion_instruction) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), unique_id_(-1), root_instruction_(root_instruction), fusion_instruction_(fusion_instruction) { @@ -120,6 +120,30 @@ HloInstruction* HloComputation::AddParameter( return instructions_.back().get(); } +namespace { + +// Returns the new name for a fusion parameter when we change its number. +// +// Fusion parameters are named foo.param_1, bar.param_2, etc. We are +// renumbering the parameters, so replace the final number in the name with +// the updated value. +string RenameFusionParameter(const string& original_name, int64 new_param_no) { + const string param_underscore = ".param_"; + size_t index = original_name.rfind(param_underscore); + if (index == string::npos) { + return original_name; + } + string after_param = original_name.substr(index + param_underscore.size()); + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + return StrCat(original_name.substr(0, index + param_underscore.size()), + new_param_no); + } + return original_name; +} + +} // namespace + Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); @@ -132,21 +156,8 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = param_instruction->name(); - // Fusion parameters are named foo.param_1, bar.param_2, etc. We are - // renumbering the parameters, so replace the final number in the name with - // the updated value. - const string param_underscore = ".param_"; - size_t index = param_name.rfind(param_underscore); - if (index == string::npos) { - string after_param = name().substr(index + param_underscore.size()); - int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { - param_name = - StrCat(param_name.substr(0, index), param_underscore, param_no); - } - } - + string param_name = + RenameFusionParameter(param_instruction->name(), param_no); HloInstruction* new_instr = AddInstructionInternal(HloInstruction::CreateParameter( param_no, param_instruction->shape(), param_name)); @@ -159,6 +170,34 @@ Status HloComputation::RemoveParameter(int64 param_no) { return Status::OK(); } +Status HloComputation::RemoveUnusedParameters() { + CHECK(IsFusionComputation()); + int64 removed = 0; + for (int64 i = 0; i < param_instructions_.size(); ++i) { + HloInstruction* param_instruction = param_instructions_[i]; + if (param_instruction->user_count() == 0 && + param_instruction != root_instruction()) { + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + ++removed; + continue; + } + + if (removed > 0) { + const int64 param_no = i - removed; + string param_name = + RenameFusionParameter(param_instruction->name(), param_no); + HloInstruction* new_instr = + AddInstructionInternal(HloInstruction::CreateParameter( + param_no, param_instruction->shape(), param_name)); + TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); + param_instructions_[param_no] = new_instr; + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + } + } + param_instructions_.resize(param_instructions_.size() - removed); + return Status::OK(); +} + bool HloComputation::IsRemovable(const HloInstruction* instruction) { // If the instruction has control predecessors or successors then we cannot // remove the instruction without violating ordering constraints (added, for @@ -234,7 +273,6 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { TF_RET_CHECK(instruction_iterators_.count(instruction) != 0); auto inst_it = instruction_iterators_.at(instruction); (*inst_it)->set_parent(nullptr); - instruction->DetachFromOperands(); instructions_.erase(inst_it); return Status::OK(); } @@ -246,9 +284,8 @@ void HloComputation::set_root_instruction( if (!IsFusionComputation()) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) - << new_root_instruction->shape().ShortDebugString() - << " is incompatible with " - << root_instruction_->shape().ShortDebugString(); + << new_root_instruction->shape() << " is incompatible with " + << root_instruction_->shape(); } bool root_found = false; for (auto& instruction : instructions_) { @@ -264,46 +301,11 @@ void HloComputation::set_root_instruction( namespace { -// Helper class which computes the post order of an expression rooted at a -// particular instruction. -class InstructionPostOrderer : public DfsHloVisitorWithDefault { - public: - // added_instructions is the set of instructions which have already been - // accounted for in the post order in previous invocations of - // GetOrder. Without this mechanism, instructions which are predecessors of - // multiple root instructions of the computation can be added to the post - // order more than once. - static std::list GetOrder( - HloInstruction* root, - tensorflow::gtl::FlatSet* added_instructions) { - InstructionPostOrderer orderer(added_instructions); - TF_CHECK_OK(root->Accept(&orderer)); - return std::move(orderer.post_order_); - } - - private: - explicit InstructionPostOrderer( - tensorflow::gtl::FlatSet* added_instructions) - : added_instructions_(added_instructions) {} - ~InstructionPostOrderer() override {} - - Status DefaultAction(HloInstruction* hlo_instruction) override { - if (added_instructions_->count(hlo_instruction) == 0) { - post_order_.push_back(hlo_instruction); - added_instructions_->insert(hlo_instruction); - } - return Status::OK(); - } - - std::list post_order_; - tensorflow::gtl::FlatSet* added_instructions_; -}; - // Helper which builds a post order of the HLO call graph. void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet* visited, - std::list* post_order) { + std::vector* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -315,12 +317,53 @@ void ComputeComputationPostOrder( } } +enum State { kVisiting, kVisited }; + +void ComputeInstructionPostOrder( + std::vector* post_order, HloInstruction* root, + tensorflow::gtl::FlatMap* visited) { + std::vector dfs_stack; + dfs_stack.push_back(root); + while (!dfs_stack.empty()) { + const auto current = dfs_stack.back(); + auto it = visited->find(current); + if (it != visited->end()) { + if (it->second == kVisited) { + // Already visited. + dfs_stack.pop_back(); + continue; + } + // Visit this node. + CHECK_EQ(kVisiting, it->second); + dfs_stack.pop_back(); + post_order->push_back(current); + it->second = kVisited; + continue; + } + + visited->insert({current, kVisiting}); + + // Add the operands to the stack in reverse order so the first operand is + // processed first. This will produce a more natural ordering and a nicer + // result for thigns like HLO stringification. + const auto& operands = current->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + dfs_stack.emplace_back(operands[i]); + } + + for (HloInstruction* op : current->control_predecessors()) { + dfs_stack.emplace_back(op); + } + } +} + } // namespace -std::list HloComputation::MakeInstructionPostOrder() const { - std::list post_order; - std::list trace_instructions; - tensorflow::gtl::FlatSet added_instructions; +std::vector HloComputation::MakeInstructionPostOrder() const { + std::vector post_order; + post_order.reserve(instruction_count()); + std::vector trace_instructions; + tensorflow::gtl::FlatMap visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -328,21 +371,20 @@ std::list HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - post_order.splice(post_order.end(), - InstructionPostOrderer::GetOrder(instruction.get(), - &added_instructions)); + ComputeInstructionPostOrder(&post_order, instruction.get(), &visited); } } - post_order.splice(post_order.end(), trace_instructions); + post_order.insert(post_order.end(), trace_instructions.begin(), + trace_instructions.end()); CHECK_EQ(instructions_.size(), post_order.size()) << "number of instructions does not match post order size"; return post_order; } -std::list HloComputation::MakeEmbeddedComputationsList() +std::vector HloComputation::MakeEmbeddedComputationsList() const { tensorflow::gtl::FlatSet visited; - std::list post_order; + std::vector post_order; // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after @@ -488,21 +530,7 @@ HloInstruction* HloComputation::CreateFusionInstruction( StatusOr HloComputation::DeepCopyHelper( HloInstruction* instruction, const ShapeTree* indices_to_copy, ShapeTree* copies_added, ShapeIndex* index) { - if (ShapeUtil::IsArray(instruction->shape())) { - if (indices_to_copy == nullptr || indices_to_copy->element(*index)) { - // Use kCopy to copy array elements - HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kCopy, instruction)); - if (copies_added != nullptr) { - *copies_added->mutable_element(*index) = copy; - } - return copy; - } else { - // Array elements which are not to be copied are passed through - // transparently. - return instruction; - } - } else if (ShapeUtil::IsTuple(instruction->shape())) { + if (ShapeUtil::IsTuple(instruction->shape())) { std::vector elements; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); i++) { @@ -519,9 +547,27 @@ StatusOr HloComputation::DeepCopyHelper( index->pop_back(); } return AddInstruction(HloInstruction::CreateTuple(elements)); + } + if (ShapeUtil::IsToken(instruction->shape())) { + // Tokens have no on-device representation and cannot be copied. Pass + // through transparently. + return instruction; + } + + // Array shape. + TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape())); + if (indices_to_copy == nullptr || indices_to_copy->element(*index)) { + // Use kCopy to copy array elements + HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kCopy, instruction)); + if (copies_added != nullptr) { + *copies_added->mutable_element(*index) = copy; + } + return copy; } else { - return FailedPrecondition( - "Can only copy array and tuple shaped instructions"); + // Elements which are not to be copied are passed through + // transparently. + return instruction; } } @@ -609,7 +655,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr HloComputation::ComputeReachability() const { - const std::list all = MakeInstructionPostOrder(); + const auto& all = MakeInstructionPostOrder(); auto result = MakeUnique(all); std::vector inputs; @@ -752,22 +798,21 @@ Status HloComputation::Accept( } std::unique_ptr HloComputation::Clone( - const string& suffix, HloModule* module, - HloInstruction::CloneMap* clone_map) { + const string& suffix, HloCloneContext* context) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, clone_map, suffix); + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, HloInstruction::CloneMap* clone_map, - const string& suffix) { - HloInstruction::CloneMap local_clone_map; - if (clone_map == nullptr) { - clone_map = &local_clone_map; + HloCloneContext* context, const string& suffix) { + std::unique_ptr context_ptr; + if (context == nullptr) { + context_ptr = MakeUnique(parent(), suffix); + context = context_ptr.get(); } // Look up instr in the replacements map, and return either the replacement, @@ -792,18 +837,18 @@ std::unique_ptr HloComputation::CloneWithReplacements( } std::vector> instructions; - std::unique_ptr new_instr = nullptr; + std::unique_ptr new_instr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); CHECK_NE(replaced_operand, nullptr) - << "Replacements map specifies to leave out " << operand->ToString() - << ", but it is used by " << instr->ToString() << "."; - new_operands.push_back(FindOrDie(*clone_map, replaced_operand)); + << "replacements map tried to eliminate a used instruction " + << operand->ToString() << ", used by " << instr->ToString(); + new_operands.push_back(context->GetInstruction(replaced_operand)); } - new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands, - module, clone_map); + new_instr = + instr->CloneWithNewOperands(instr->shape(), new_operands, context); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -811,32 +856,23 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction()))); + /*root_instruction=*/context->GetInstruction( + replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(*clone_map, instr); + HloInstruction* new_instr = context->GetInstruction(instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - CHECK_NE(replaced_successor, nullptr) - << "Replacements map specifies to leave out " << successor->ToString() - << ", but it is control-depended-on by " << instr->ToString() << "."; - - TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(*clone_map, replaced_successor))); - } - } - - // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before - // they're destroyed, otherwise they stick around in the operands' users lists - // and cause use-after-frees. - for (auto& kv : replacements) { - if (std::unique_ptr& new_instr = kv.second) { - new_instr->DetachFromOperands(); + // successor may not have been remapped, because it might have been + // removed by the replacements map. + if (replaced_successor != nullptr) { + TF_CHECK_OK(new_instr->AddControlDependencyTo( + context->GetInstruction(replaced_successor))); + } } } - + context->MapComputation(this, result.get()); return result; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 8bc97df0365a32bdc89d4636ad4c7076ffb08296..c1c3e79ebc789eff0873515c5fffd11089b92043 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -112,6 +113,11 @@ class HloComputation { // instruction. Status RemoveParameter(int64 param_no); + // Remove unused parameters from the computation. + // Note this is only applicatable to the computation for the fusion + // instruction. + Status RemoveUnusedParameters(); + // Add new parameter instruction to the computation. // This should be a new parameter. Instruction will be appended to parameters // and inserted to the instruction list. @@ -198,7 +204,7 @@ class HloComputation { // Compute and return a post-order of the instructions in the computation. In // this order, definitions of values always appear before their uses. - std::list MakeInstructionPostOrder() const; + std::vector MakeInstructionPostOrder() const; // Computes and returns the reachability between HLO instructions in the // computation. The returned HloReachabilityMap is constructed such that @@ -220,7 +226,7 @@ class HloComputation { // transitively. The embedded computations are sorted such that if computation // A calls computation B (eg, via a map instruction) then A will appear after // B in the list. - std::list MakeEmbeddedComputationsList() const; + std::vector MakeEmbeddedComputationsList() const; // Creates a fusion instruction containing the given instructions. // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion @@ -300,17 +306,11 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // - // If the module pointer is not nullptr, then the cloned computations will be - // added to this module in order to support deep cloning. Otherwise the module - // of the computation is used. - // - // If clone_map is not nullptr, then each original instruction that is cloned - // will be inserted and map to its clone. clone_map should not already contain - // any of the instructions to clone. - std::unique_ptr Clone( - const string& suffix = "clone", HloModule* module = nullptr, - HloInstruction::CloneMap* clone_map = nullptr); + // If the clone context is specified, it will be populated with the cloned + // object mappings, and its module() will be used to add new computations + // into. + std::unique_ptr Clone(const string& suffix = "clone", + HloCloneContext* context = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -320,9 +320,7 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, - HloInstruction::CloneMap* clone_map = nullptr, - const string& suffix = "clone"); + HloCloneContext* context = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 25469a54c48f4f5cab478aba929f1cc18de8b81f..a8f3f0e9c2dca8fb97ebc8f8c9dd80fcf7f4de4a 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -371,6 +371,38 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { } } +TEST_F(HloComputationTest, DeepCopyToken) { + // Test that DeepCopyInstruction properly handles tokens which should not be + // copied. + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); + + // No copy should be added. + EXPECT_THAT(copy, op::AfterAll()); +} + +TEST_F(HloComputationTest, DeepCopyTokenTuple) { + // Test that DeepCopyInstruction properly handles tokens which should not be + // copied. + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); + + // Only the array (second tuple element) should be copied. The token is passed + // through transparently. + EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple), + op::Copy(op::GetTupleElement(tuple)))); +} + TEST_F(HloComputationTest, CycleDetection) { // Test whether the visitor can detect cycles in the graph. auto builder = HloComputation::Builder(TestName()); @@ -385,6 +417,9 @@ TEST_F(HloComputationTest, CycleDetection) { // Add a control dependency to create a cycle. ASSERT_IS_OK(add->AddControlDependencyTo(negate)); + auto instructions = computation->MakeInstructionPostOrder(); + EXPECT_EQ(3, instructions.size()); + const auto visitor = [](HloInstruction* instruction) { return Status::OK(); }; auto visit_status = computation->Accept(visitor); ASSERT_FALSE(visit_status.ok()); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 35ecd4428d0dfde2de445ea34472d2c78148c6c9..436d103f230e078e62201bff377a5bab0e62f92b 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -51,14 +51,18 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Reduce operation. + // Skip Constant, Parameter, Reduce, and AfterAll operation. // TODO(b/35975797): Enable Reduce operation once arbitrary computation // are supported by the evaluator. // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. + // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one + // operand in which case constant folding will be impossible and this + // special case is not necessary. if (instruction->opcode() == HloOpcode::kParameter || instruction->opcode() == HloOpcode::kConstant || instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kReduce) { + instruction->opcode() == HloOpcode::kReduce || + instruction->opcode() == HloOpcode::kAfterAll) { continue; } // Skip instructions with non-constant operands. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 94c9c7eabcc99d4cf61f535925c068a9b55ed136..c49cf7f5db5ee9100718fbcd87dc5bdcc175ae5f 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -164,7 +164,11 @@ Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSelect(const HloInstruction*) { +Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) { + return HandleElementwiseOp(hlo); +} + +Status HloCostAnalysis::HandleTupleSelect(const HloInstruction*) { return Status::OK(); } @@ -172,15 +176,22 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { + current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; return Status::OK(); } -Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleDynamicSlice( + const HloInstruction* dynamic_slice) { + current_properties_[kBytesAccessedKey] = + shape_size_(dynamic_slice->shape()) * 2; return Status::OK(); } -Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleDynamicUpdateSlice( + const HloInstruction* dynamic_update_slice) { + current_properties_[kBytesAccessedKey] = + shape_size_(dynamic_update_slice->operand(1)->shape()) * 2; return Status::OK(); } @@ -386,6 +397,10 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { auto lhs = convolution->operand(0); auto rhs = convolution->operand(1); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index d17678d20f2a23fd98d18b77d5fb25853901a789..0181138a6dc554438957e8545c66a98d32dd68d5 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -54,7 +54,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleConstant(const HloInstruction* constant) override; Status HandleGetTupleElement( const HloInstruction* get_tuple_element) override; - Status HandleSelect(const HloInstruction* select) override; + Status HandleSelect(const HloInstruction* hlo) override; + Status HandleTupleSelect(const HloInstruction* hlo) override; Status HandleCompare(const HloInstruction* compare) override; Status HandleClamp(const HloInstruction* clamp) override; Status HandleReducePrecision(const HloInstruction* hlo) override; @@ -97,6 +98,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleAfterAll(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 16fdda8a8b9ade09ea31cda1f4cf5e8ff2c0a081..9fc4c48226fa5307f5e030a612f3957756827e37 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -59,9 +59,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a unary user function: x => exp(x + 0.5) { XlaBuilder builder("add_and_exp"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto half = builder.ConstantR0(0.5); - builder.Exp(builder.Add(x, half)); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto half = ConstantR0(&builder, 0.5); + Exp(Add(x, half)); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); add_and_exp_ = computation_status.ConsumeValueOrDie(); @@ -70,9 +70,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary user function: (x, y) => x + y { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); add_ = computation_status.ConsumeValueOrDie(); @@ -81,9 +81,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x)) { XlaBuilder builder("sigmoid"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto one = builder.ConstantR0(1.0); - builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x)))); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = ConstantR0(&builder, 1.0); + Div(one, Add(one, Exp(Neg(x)))); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); sigmoid_ = computation_status.ConsumeValueOrDie(); @@ -92,9 +92,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary max function: (x, y) => max (x, y) { XlaBuilder builder("max"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Max(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Max(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); max_ = computation_status.ConsumeValueOrDie(); @@ -103,9 +103,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary GT function: (x, y) => x > y { XlaBuilder builder("gt"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Gt(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Gt(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); gt_ = computation_status.ConsumeValueOrDie(); @@ -137,9 +137,9 @@ class HloCostAnalysisTest : public ::testing::Test { TEST_F(HloCostAnalysisTest, MatrixMultiply) { XlaBuilder builder("matrix_multiply"); - auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); - auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); - auto result = builder.Dot(lhs, rhs); + auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); + auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); + Dot(lhs, rhs); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -159,8 +159,8 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { TEST_F(HloCostAnalysisTest, Map) { XlaBuilder builder("map"); - auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); - auto result = builder.Map({input}, add_and_exp_, {0}); + auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "in"); + Map(&builder, {input}, add_and_exp_, {0}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -176,17 +176,17 @@ TEST_F(HloCostAnalysisTest, Map) { TEST_F(HloCostAnalysisTest, Convolution) { XlaBuilder builder("convolution"); - auto input = builder.Parameter( - 0, + auto input = Parameter( + &builder, 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, /*x_dim=*/20}), "input"); - auto kernel = builder.Parameter( - 1, + auto kernel = Parameter( + &builder, 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, /*x_dim=*/3}), "kernel"); - auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid); + Conv(input, kernel, {1, 1}, Padding::kValid); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -206,9 +206,8 @@ TEST_F(HloCostAnalysisTest, Convolution) { TEST_F(HloCostAnalysisTest, Reduce) { XlaBuilder builder("reduce"); auto input = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); - auto result = - builder.Reduce(input, builder.ConstantR0(0.0f), add_, {1}); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + Reduce(input, ConstantR0(&builder, 0.0f), add_, {1}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -224,9 +223,9 @@ TEST_F(HloCostAnalysisTest, Reduce) { TEST_F(HloCostAnalysisTest, ReduceWindow) { XlaBuilder builder("reduce_window"); auto input = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); - auto result = builder.ReduceWindow(input, builder.ConstantR0(0), add_, - {4, 5}, {4, 5}, Padding::kValid); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + ReduceWindow(input, ConstantR0(&builder, 0), add_, {4, 5}, {4, 5}, + Padding::kValid); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -241,12 +240,11 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { TEST_F(HloCostAnalysisTest, SelectAndScatter) { XlaBuilder builder("select_and_scatter"); auto operand = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto source = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source"); - auto result = - builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, - source, builder.ConstantR0(0), add_); + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 4}), "source"); + SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, source, + ConstantR0(&builder, 0), add_); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -261,7 +259,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { TEST_F(HloCostAnalysisTest, Broadcast) { XlaBuilder b("broadcast"); - b.Broadcast(b.ConstantR0(42), {10, 7}); + Broadcast(ConstantR0(&b, 42), {10, 7}); auto hlo_module = BuildHloGraph(&b); HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( @@ -273,13 +271,12 @@ TEST_F(HloCostAnalysisTest, Broadcast) { TEST_F(HloCostAnalysisTest, FullyConnectedForward) { XlaBuilder builder("fully_connected_forward"); auto input = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); auto weight = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight"); - auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias"); + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 20}), "weight"); + auto bias = Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {20}), "bias"); // sigmoid(input * weight + bias) - auto result = builder.Map( - {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1}); + Map(&builder, {Add(Dot(input, weight), bias, {1})}, sigmoid_, {0, 1}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -297,11 +294,11 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis conv_analysis(ShapeSize); { XlaBuilder builder("conv_looking_matmul"); - auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), - "input"); - auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), - "weights"); - builder.Conv(lhs, rhs, {1, 1}, Padding::kSame); + auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), + "input"); + auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), + "weights"); + Conv(lhs, rhs, {1, 1}, Padding::kSame); auto hlo_module = BuildHloGraph(&builder); ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( &conv_analysis)); @@ -311,10 +308,10 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { { XlaBuilder builder("matmul"); auto lhs = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); auto rhs = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64}), "weights"); - builder.Dot(lhs, rhs); + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {64, 64}), "weights"); + Dot(lhs, rhs); auto hlo_module = BuildHloGraph(&builder); ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( &matmul_analysis)); @@ -419,9 +416,9 @@ TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { XlaBuilder builder("matmul"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y"); - auto tuple = builder.Tuple({x, y}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y"); + Tuple(&builder, {x, y}); auto hlo_module = BuildHloGraph(&builder); ASSERT_IS_OK( @@ -435,21 +432,21 @@ TEST_F(HloCostAnalysisTest, TupleCost) { TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { XlaBuilder builder("BaseDilatedConvolution"); - auto input = builder.Parameter( - 0, + auto input = Parameter( + &builder, 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, /*x_dim=*/20}), "input"); - auto kernel = builder.Parameter( - 1, + auto kernel = Parameter( + &builder, 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, /*x_dim=*/3}), "kernel"); - auto result = builder.ConvGeneralDilated( - input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, - /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, - XlaBuilder::CreateDefaultConvDimensionNumbers(2)); + ConvGeneralDilated(input, kernel, /*window_strides=*/{1, 1}, + /*padding=*/{{1, 1}, {1, 1}}, + /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, + XlaBuilder::CreateDefaultConvDimensionNumbers(2)); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -460,5 +457,51 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { EXPECT_EQ(analysis.flop_count(), 1472); } +TEST_F(HloCostAnalysisTest, Slice) { + // Test the analysis on a slice. + XlaBuilder builder("slice"); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); + Slice(x, {0}, {1}, {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + +TEST_F(HloCostAnalysisTest, DynamicSlice) { + // Test the analysis on a slice. + XlaBuilder builder("dynamic-slice"); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); + DynamicSlice(x, ConstantR1(&builder, {1}), {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + +TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { + // Test the analysis on a slice. + XlaBuilder builder("dynamic-update-slice"); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); + DynamicUpdateSlice(x, ConstantR1(&builder, {1.0}), + ConstantR1(&builder, {1})); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index c17c26c5a435fe34dd1024d596004cf6b5fdce8c..a0ee8896230d6dcacb5a8eb607fc00ae5226cfa5 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -41,16 +42,16 @@ namespace { // Find and combine identical constants. Constants are identical if they have // the same type and value. -bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { - bool changed = false; - +StatusOr CombineConstants(HloComputation* computation, + bool is_layout_sensitive) { + TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, "")); // Map from ShortDebugString of the layoutless shape of the constant to the // set of constant instructions with that shape. Layoutless shape is used to // bin possible common constants together to reduce number of constant // comparisons. If we end up having too many constant comparisons, a more // precise binning might have to be used. std::multimap constants; - + int64 combined = 0; auto inst_it = computation->instructions().begin(); while (inst_it != computation->instructions().end()) { HloInstruction* instruction = *inst_it; @@ -70,7 +71,8 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto range = constants.equal_range(shape_string); HloInstruction* match = nullptr; for (auto it = range.first; it != range.second; ++it) { - if (instruction->literal() == it->second->literal()) { + if (instruction->literal() == it->second->literal() && + domain_map->InSameDomain(it->second, instruction)) { match = it->second; break; } @@ -81,12 +83,13 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { // Match found, replace this instruction with the one in the multimap. TF_CHECK_OK(instruction->ReplaceAllUsesWith(match)); TF_CHECK_OK(computation->RemoveInstruction(instruction)); - changed = true; + ++combined; } } } - - return changed; + VLOG(4) << "Combined " << combined << " constants in " << computation->name() + << " computation"; + return combined > 0; } // An instruction is considered to be equivalent to another only if they @@ -123,24 +126,27 @@ StatusOr HloCSE::Run(HloModule* module) { continue; } - changed |= CombineConstants(computation, is_layout_sensitive_); + TF_ASSIGN_OR_RETURN(bool combined, + CombineConstants(computation, is_layout_sensitive_)); + changed |= combined; // HLO instructions are grouped into equivalency classes by using the // cse_equal predicate defined above. This set holds a representative // instruction for each class. tensorflow::gtl::FlatSet - representatives(/*N=*/1024, &CseHash, cse_equal); - + representatives(/*N=*/computation->instruction_count() + 1, &CseHash, + cse_equal); for (auto instruction : computation->MakeInstructionPostOrder()) { // If the instruction has zero operands (constants, parameters, etc.) skip // over it. if (instruction->operand_count() == 0) { continue; } - - // Skip instructions which have side effects. - if (instruction->HasSideEffect()) { + // Skip instructions which have side effects or are a domain (which must + // not be CSE-ed). + if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kDomain) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 9735764b692238d6a320bcff51e43b98dcadabda..16db374566c727f1f3efe2a6d419f1f3caf0aaf1 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -32,10 +32,10 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" @@ -142,31 +142,46 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { // Test that constants with the same value but different type are *not* // commoned. auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + std::vector constants; + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); // Duplicate the float constant to verify something happens. - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + + const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); + for (int64 i = 0; i < constants.size(); ++i) { + constants[i] = builder.AddInstruction( + HloInstruction::CreateConvert(shape_r0, constants[i])); + } + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, constants[0], constants[1])); + for (int64 i = 2; i < constants.size(); ++i) { + root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, root, constants[i])); + } auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(7, computation->instruction_count()); + EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - EXPECT_EQ(6, computation->instruction_count()); + // CSE will remove both the second float(42.0f) and the corresponding + // convert/cast. + EXPECT_EQ(18, computation->instruction_count()); } TEST_F(HloCseTest, NonscalarConstants) { @@ -471,7 +486,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule m add_computation { @@ -501,5 +516,25 @@ TEST_F(HloCseTest, CompareComputations) { EXPECT_EQ(root->operand(0), root->operand(1)); } +TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { + // Test that constants with the same value but in different domains (disjoint + // in this case) are not collapsed. + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(2, computation->instruction_count()); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(2, computation->instruction_count()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index b06e6c9f3e62f375a9e48f8ef81efe7121bbef94..de1a32d8bd9217baabda4ab4b02bf28baebad531 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -34,16 +34,86 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { + +// We have this pattern in dynamaic update slice fusion, which should be +// supported: +// +// Parameters: p0, p1 +// Fusion +// ds = DynamicSlice(p0, p1) +// ROOT DynamicUpdateslice(p0, ds, p1) +// +// In this case, we should be able to reuse p0 and output, although p0 has +// multiple uses. +bool MultiDynamicSliceUseShareSameIndices( + tensorflow::gtl::ArraySlice uses) { + if (uses.empty()) { + return false; + } + const HloInstruction* indices = nullptr; + for (HloUse use : uses) { + auto user = use.instruction; + if (user->opcode() == HloOpcode::kDynamicUpdateSlice) { + if (indices == nullptr) { + indices = user->operand(2); + } else if (indices != user->operand(2)) { + return false; + } + if (use.operand_number != 0) { + return false; + } + } else if (user->opcode() == HloOpcode::kDynamicSlice) { + if (indices == nullptr) { + indices = user->operand(1); + } else if (indices != user->operand(1)) { + return false; + } + } else { + return false; + } + } + return true; +} + +} // namespace using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form, - bool bitcast_defines_value) +HloDataflowAnalysis::HloDataflowAnalysis( + const HloModule& module, bool ssa_form, bool bitcast_defines_value, + const FusionCanShareBufferFunction& fusion_can_share_buffer) : module_(module), ssa_form_(ssa_form), bitcast_defines_value_(bitcast_defines_value), - call_graph_(CallGraph::Build(&module)) {} + call_graph_(CallGraph::Build(&module)), + fusion_can_share_buffer_(fusion_can_share_buffer) {} + +bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( + const HloInstruction* inst) { + tensorflow::gtl::FlatSet visited; + tensorflow::gtl::InlinedVector stack; + stack.push_back(inst); + while (!stack.empty()) { + const HloInstruction* current = stack.back(); + stack.pop_back(); + visited.insert(current); + for (const HloInstruction* user : current->users()) { + // Found a user that is non-elementwise on current instruction. + for (const int64 use_index : user->OperandIndices(current)) { + if (!user->IsElementwiseOnOperand(use_index) && + user->opcode() != HloOpcode::kTuple) { + return false; + } + } + if (!visited.count(user)) { + stack.push_back(user); + } + } + } + return true; +} bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index) const { @@ -328,18 +398,17 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) { CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone); bool changed = false; - // RecvDone forwards the operand value at {0} to the output. + // RecvDone forwards the operand value at {0} to element {0} of its output. for (auto& pair : GetInstructionValueSet(recv_done)) { ShapeIndex& index = pair.first; HloValueSet& value_set = pair.second; - ShapeIndex operand_index = {0}; - for (int64 i : index) { - operand_index.push_back(i); + if (index.empty() || index[0] != 0) { + continue; } const HloValueSet& operand_value_set = - GetValueSet(recv_done->operand(0), operand_index); + GetValueSet(recv_done->operand(0), index); if (value_set != operand_value_set) { value_set = operand_value_set; changed = true; @@ -363,7 +432,7 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { bool HloDataflowAnalysis::UpdateConditionalValueSet( HloInstruction* conditional) { CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); - std::vector inputs = { + const InstructionValueSet* const inputs[] = { &GetInstructionValueSet( conditional->true_computation()->root_instruction()), &GetInstructionValueSet( @@ -396,6 +465,24 @@ bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { return changed; } +bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) { + // Domain instructions just forward their operand. Given that domains can have + // a tuple operand, we iterate through its indexes, like for copies. + // Unlike copies though we also propagate the top-level value. + CHECK_EQ(domain->opcode(), HloOpcode::kDomain); + bool changed = false; + for (auto& pair : GetInstructionValueSet(domain)) { + const ShapeIndex& index = pair.first; + HloValueSet& value_set = pair.second; + HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) { CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement); bool changed = false; @@ -490,17 +577,17 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { } } -bool HloDataflowAnalysis::UpdateSelectValueSet(HloInstruction* select) { - CHECK_EQ(select->opcode(), HloOpcode::kSelect); - // A phi value is not defined at a kSelect instruction because kSelect does - // not create a new value. Rather it forwards a value from its operands. This - // contrasts with kWhile instruction (which does define a phi value) which has - // in-place update semantics. +bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) { + CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect); + // A phi value is not defined at a kTupleSelect instruction because + // kTupleSelect does not create a new value. Rather it forwards a value from + // its operands. This contrasts with kWhile instruction (which does define a + // phi value) which has in-place update semantics. bool changed = false; for (auto& pair : GetInstructionValueSet(select)) { const ShapeIndex& index = pair.first; if (index.empty()) { - // kSelect copies (not forwards) the top-level value. + // kTupleSelect copies (not forwards) the top-level value. continue; } HloValueSet& value_set = pair.second; @@ -538,7 +625,7 @@ bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) { bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) { CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile); - std::vector inputs = { + const InstructionValueSet* const inputs[] = { &GetInstructionValueSet(xla_while->while_body()->root_instruction()), &GetInstructionValueSet(xla_while->operand(0))}; if (ssa_form_) { @@ -556,12 +643,14 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateBitcastValueSet(instruction); case HloOpcode::kSlice: return UpdateSliceValueSet(instruction); + case HloOpcode::kDomain: + return UpdateDomainValueSet(instruction); case HloOpcode::kCopy: return UpdateCopyValueSet(instruction); case HloOpcode::kGetTupleElement: return UpdateGetTupleElementValueSet(instruction); - case HloOpcode::kSelect: - return UpdateSelectValueSet(instruction); + case HloOpcode::kTupleSelect: + return UpdateTupleSelectValueSet(instruction); case HloOpcode::kTuple: return UpdateTupleValueSet(instruction); case HloOpcode::kParameter: @@ -734,6 +823,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kGetTupleElement: + case HloOpcode::kDomain: // These instructions define no values. The values in their output // flow from their operands or from cross computation dataflow. break; @@ -759,21 +849,25 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { } break; case HloOpcode::kCopy: - case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: case HloOpcode::kTuple: // These instructions only define their top-level values. Any other // values flow from their operands. define_top_level_only(); break; case HloOpcode::kRecvDone: - // RecvDone aliases its input tuple element {0}, therefore does not - // define any values. + // RecvDone produces a two-element tuple. Element zero aliases its + // input tuple element {0}; element one is a token. + define_value_at(/*index=*/{}); + define_value_at(/*index=*/{1}); break; case HloOpcode::kSend: - // Send produces a tuple of {aliased operand, U32 context}, therefore - // only defines the top-level tuple and the tuple element at {1}. + // Send produces a tuple of {aliased operand, U32 context, token}, + // therefore only defines the top-level tuple and the tuple elements + // at {1} and {2}. define_value_at(/*index=*/{}); define_value_at(/*index=*/{1}); + define_value_at(/*index=*/{2}); break; default: define_all_values(); @@ -787,12 +881,13 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { /* static */ StatusOr> HloDataflowAnalysis::Run( - const HloModule& module, bool ssa_form, bool bitcast_defines_value) { + const HloModule& module, bool ssa_form, bool bitcast_defines_value, + const FusionCanShareBufferFunction& fusion_can_share_buffer) { VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto dataflow_analysis = WrapUnique( - new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); + auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis( + module, ssa_form, bitcast_defines_value, fusion_can_share_buffer)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); dataflow_analysis->Propagate(); @@ -915,6 +1010,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( ShapeUtil::GetSubshape(operand->shape(), operand_index); const Shape& user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { return false; @@ -927,20 +1023,27 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); if (value.uses().size() != 1) { + if (MultiDynamicSliceUseShareSameIndices(value.uses())) { + return true; + } return false; } const HloUse& use = value.uses()[0]; - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + user->fusion_kind() == HloInstruction::FusionKind::kInput) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } else { + return AreTransitiveUsesElementwiseOrTuple(fusion_param); + } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -965,8 +1068,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // index 'other_add_operand_index'). return use.instruction == user->fused_expression_root() && use.operand_number == other_add_operand_index; + } else if (fusion_can_share_buffer_ != nullptr && + fusion_can_share_buffer_(user, operand)) { + return true; } } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, @@ -998,8 +1105,10 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( }) != uses.end(); return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; } - // Check if 'user' is element-wise. - return user->IsElementwise(); + + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + return user->IsElementwiseOnOperand(user->operand_index(operand)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 9868746b6113881949e388cd2a4aa9f610b1fdb7..f4abc7a7c7dcfb223067fe946bec0c5ef32f206b 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -42,6 +42,20 @@ namespace xla { // Analysis which identifies all HLO values and their uses in an HLO module. class HloDataflowAnalysis { public: + // Different backends can have very different ways to do fusion, so we give + // backends the flexibility to decide whether an fusion instruction can share + // buffer with it's operands. If this is not specified, a default strategy + // will be used; if this is specified, it will be applied *in addition* to the + // default strategy. + // + // The first parameter of the function should be the fusion instruction, the + // second parameter should be an operand of the fusion instruction. + // + // TODO(b/80315712): Find a better way to tell whether a fusion can share + // buffer. + using FusionCanShareBufferFunction = std::function; + // Run dataflow analysis on the given module. Parameters: // // ssa_form : If true then new values are defined at the merge points of @@ -61,7 +75,10 @@ class HloDataflowAnalysis { // value of its operand. static StatusOr> Run( const HloModule& module, bool ssa_form = false, - bool bitcast_defines_value = false); + bool bitcast_defines_value = false, + const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr); + + static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst); // Returns true if 'instruction' defines an HLO value at the given shape index // of its output. @@ -136,8 +153,10 @@ class HloDataflowAnalysis { const ShapeIndex& user_index) const; protected: - HloDataflowAnalysis(const HloModule& module, bool ssa_form, - bool bitcast_defines_value = false); + HloDataflowAnalysis( + const HloModule& module, bool ssa_form, + bool bitcast_defines_value = false, + const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr); // Returns a new HloValue defined at the given instruction and shape index. HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, @@ -166,10 +185,11 @@ class HloDataflowAnalysis { bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); + bool UpdateDomainValueSet(HloInstruction* domain); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); bool UpdateRecvDoneValueSet(HloInstruction* recv_done); - bool UpdateSelectValueSet(HloInstruction* select); + bool UpdateTupleSelectValueSet(HloInstruction* select); bool UpdateSendValueSet(HloInstruction* send); bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); @@ -221,6 +241,10 @@ class HloDataflowAnalysis { // The Id to use for the next HloValue. HloValue::Id next_value_id_ = 0; + + // Backend specific function that decides whether a fusion can share buffer + // with its operand. + FusionCanShareBufferFunction fusion_can_share_buffer_ = nullptr; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 5798326dcbf65c3c34748afb02afab1dc7af9147..f176473366ab74fa532ffb26ffc6adbb9731de67 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -860,8 +860,7 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { } TEST_P(HloDataflowAnalysisTest, TupleSelect) { - // Test a kSelect of a tuple value. Non-top-level element flow through the - // instruction. + // Test a kTupleSelect. Non-top-level element flow through the instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(false))); @@ -883,20 +882,20 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({constant4})); const Shape tuple_shape = tuple1->shape(); auto select11 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1)); auto select12 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto select34 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4)); auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, select12, select34)); + tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - // Top-level value is always defined by a kSelect. + // Top-level value is always defined by a kTupleSelect. EXPECT_TRUE(analysis.ValueIsDefinedAt(select11)); EXPECT_TRUE(analysis.ValueIsDefinedAt(select12)); EXPECT_TRUE(analysis.ValueIsDefinedAt(select34)); @@ -937,7 +936,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { } TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { - // Test kSelect of a nested tuple. + // Test kTupleSelect of a nested tuple. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(false))); @@ -960,7 +959,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant4, inner_tuple2})); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); module_->AddEntryComputation(builder.Build()); @@ -983,7 +982,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { } TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { - // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO: + // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO: // // body((F32[], F32[]) %tuple_param): // %add = Add(%tuple_param{0}, %tuple_param{1}) @@ -1043,7 +1042,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({constant2})); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0)); auto tuple = @@ -1158,44 +1157,50 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto send = builder.AddInstruction( - HloInstruction::CreateSend(param, /*channel_id=*/0)); + HloInstruction::CreateSend(param, token, /*channel_id=*/0)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_EQ(analysis.values().size(), 4); + EXPECT_EQ(analysis.values().size(), 6); EXPECT_TRUE(analysis.ValueIsDefinedAt(param)); EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{})); EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0})); EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2})); EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done)); EXPECT_THAT(HloValuesAt(send, /*index=*/{0}), UnorderedElementsAre(analysis.GetValueDefinedAt(param))); } TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { - // Test that a RecvDone forwards its operand tuple element at {0} to the - // output. + // Test that a RecvDone forwards its operand tuple element at {0} to element + // {0} of the output. auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto recv = builder.AddInstruction( - HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0)); + HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_EQ(analysis.values().size(), 3); + EXPECT_EQ(analysis.values().size(), 7); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{})); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0})); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1})); - EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done)); - EXPECT_THAT(HloValuesAt(recv_done), + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1})); + EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}), UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0}))); EXPECT_TRUE( analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module()); @@ -1880,9 +1885,14 @@ class HloDataflowAnalysisTestBase : public HloTestBase { computation_ = module_->AddEntryComputation(std::move(computation)); } - void RunAnalysis() { + void RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer = nullptr) { CHECK_NOTNULL(module_.get()); - dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); + dataflow_analysis_ = + HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false, + /*bitcast_defines_value=*/false, + fusion_can_share_buffer) + .ConsumeValueOrDie(); } void BuildModuleAndRunAnalysis(std::unique_ptr computation) { @@ -1974,6 +1984,114 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); } +TEST_F(CanShareOperandBufferWithUserTest, + NonElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "param0")); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0)); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, neg, {0, 1})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {reverse, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + MultiOutputFusionCanAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + + auto copy0 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0)); + auto copy1 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {1})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + ElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {exp, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + Shape slice_shape = ShapeUtil::MakeShape(F32, {1, 2}); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "param0")); + auto index = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0, 0}))); + auto ds = builder.AddInstruction( + HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2})); + + auto dus = builder.AddInstruction( + HloInstruction::CreateDynamicUpdateSlice(data_shape, param, ds, index)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dus, ds, index}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { auto builder = HloComputation::Builder(TestName()); @@ -2048,6 +2166,45 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { fusion, {})); } +TEST_F(CanShareOperandBufferWithUserTest, + FusedDynamicUpdateSliceWithConvertCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + auto convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape_bf16, gte1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape_bf16, convert1, update, starts)); + + auto convert2 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape, dynamic_update_slice)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {convert2, dynamic_update_slice, starts, update, convert1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { auto builder = HloComputation::Builder(TestName()); @@ -2136,6 +2293,33 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { fusion, {})); } +TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kMultiply, operand, operand)); + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, mul}, HloInstruction::FusionKind::kInput); + RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion, + const HloInstruction*) { + return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop; + }); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index fcd723af146e2227b8661b1a4993f1338f7de389..7d35e251ca21951036336ff1a1eb4aabc87bc5ca 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -41,20 +41,13 @@ StatusOr HloDCE::Run(HloModule* module) { XLA_VLOG_LINES(2, module->ToString()); for (auto* computation : module->MakeComputationPostOrder()) { - std::unordered_set live_instructions; - TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( - [&live_instructions](HloInstruction* instruction) { - live_instructions.insert(instruction); - return Status::OK(); - })); - // Remove any dead roots and their dead transitive operands. Collect them // into a separate list first to avoid problems with iterating through the // computation's instruction while simultaneously removing instructions. std::vector dead_roots; for (auto* instruction : computation->instructions()) { - if (instruction->user_count() == 0 && - live_instructions.count(instruction) == 0 && + if (instruction != computation->root_instruction() && + instruction->user_count() == 0 && computation->IsRemovable(instruction) && !instruction->HasSideEffect()) { dead_roots.push_back(instruction); @@ -85,8 +78,7 @@ StatusOr HloDCE::Run(HloModule* module) { } // Remove dead computations. - std::list computations = module->MakeComputationPostOrder(); - for (auto* computation : computations) { + for (auto* computation : module->MakeComputationPostOrder()) { if (live_computations.count(computation) == 0) { TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); changed = true; diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 5a56607a665c4cbeb7b2572f182b88e890602968..f5524dc6fef3ae11e29011ad7927ee55e1701d76 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -75,19 +75,20 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); builder.AddInstruction( - HloInstruction::CreateSend(constant, /*channel_id=*/0)); + HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); builder.AddInstruction(HloInstruction::CreateTuple({})); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(4, computation->instruction_count()); HloDCE dce; EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); - EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(4, computation->instruction_count()); } TEST_F(HloDceTest, DeadParameters) { @@ -234,9 +235,10 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { { auto param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); - - auto infeed = - body_builder.AddInstruction(HloInstruction::CreateInfeed(shape, "")); + auto token = + body_builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto infeed = body_builder.AddInstruction( + HloInstruction::CreateInfeed(shape, token, "")); body_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, infeed)); } @@ -278,8 +280,10 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { { auto param = nested_callee_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); + auto token = nested_callee_builder.AddInstruction( + HloInstruction::CreateAfterAll({})); nested_callee_builder.AddInstruction( - HloInstruction::CreateOutfeed(shape, param, "")); + HloInstruction::CreateOutfeed(shape, param, token, "")); } auto nested_called_computation = module->AddEmbeddedComputation(nested_callee_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc new file mode 100644 index 0000000000000000000000000000000000000000..78955db0da02f16eb93689db947dc1190ab7049a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainIsolator::RunContext { + public: + RunContext(HloModule* module, HloDomainIsolator* isolator) + : module_(module), isolator_(isolator) {} + + StatusOr Run(); + + private: + // Inserts a kDomain instruction between parent and operand, in case + // the attribute (ie, sharding) values change between instruction and operand. + // Returns the newly inserted kDomain instruction, or nullptr if no kDomain + // instruction was necessary. + StatusOr CreateDomain(HloInstruction* instruction, + HloInstruction* parent, + HloInstruction* operand); + + HloModule* module_; + HloDomainIsolator* isolator_; +}; + +StatusOr HloDomainIsolator::RunContext::CreateDomain( + HloInstruction* instruction, HloInstruction* parent, + HloInstruction* operand) { + HloInstruction* domain = nullptr; + std::unique_ptr domain_instruction = + isolator_->creator_(instruction, operand); + if (domain_instruction != nullptr) { + domain = operand->parent()->AddInstruction(std::move(domain_instruction)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); + } + return domain; +} + +StatusOr HloDomainIsolator::RunContext::Run() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); + + int64 added_domains = 0; + for (HloComputation* computation : module_->computations()) { + // Walk in post order and place all the required kDomain instructions. + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kDomain) { + continue; + } + for (HloInstruction* operand : instruction->unique_operands()) { + // When applying multiple domains, we could end up stacking more than + // one in one edge, so here we want to build the effective + // (kDomain-less) instruction->operand edge. + HloInstruction* parent = instruction; + while (operand->opcode() == HloOpcode::kDomain) { + parent = operand; + operand = operand->mutable_operand(0); + } + // Check whether a kDomain is necessary between instruction and operand. + TF_ASSIGN_OR_RETURN(HloInstruction * domain, + CreateDomain(instruction, parent, operand)); + if (domain != nullptr) { + VLOG(4) << "New domain: " << domain->ToString(); + ++added_domains; + } + } + } + } + VLOG(3) << "Added " << added_domains << " kDomain instructions"; + if (added_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Isolator"); + } + return added_domains > 0; +} + +HloDomainIsolator::HloDomainIsolator(DomainCreator creator) + : creator_(std::move(creator)) {} + +StatusOr HloDomainIsolator::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h new file mode 100644 index 0000000000000000000000000000000000000000..eded3e78eead76c4564daee119034c5031eba409 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Domain isolation is the task of placing kDomain instructions between HLO +// instructions having different sharding. A kDomain instruction is essentially +// used to break an HLO graph edge connecting two instructions with different +// sharding. If a set of connected instructions have all the same sharding, no +// kDomain instruction will be placed. +class HloDomainIsolator : public HloPassInterface { + public: + // Creates a new kDomain instruction for the edge between the use instruction + // (the first HloInstruction argument), and the operand instruction (the + // second HloInstruction argument). + // Returns nullptr in case no domain separation is necessary. + using DomainCreator = std::function( + HloInstruction*, HloInstruction*)>; + + explicit HloDomainIsolator(DomainCreator creator); + + tensorflow::StringPiece name() const override { return "domain_isolator"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + DomainCreator creator_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc new file mode 100644 index 0000000000000000000000000000000000000000..ebd5adb5d573ce4b556046f85eb26a6ad59efcb9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -0,0 +1,176 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +/* static */ StatusOr> HloDomainMap::Create( + HloComputation* computation, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + return std::move(domain_map); +} + +/* static */ StatusOr> HloDomainMap::Create( + HloModule* module, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + for (HloComputation* computation : module->computations()) { + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + } + return std::move(domain_map); +} + +bool HloDomainMap::InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const { + int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1); + int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1); + return domain_id1 >= 0 && domain_id1 == domain_id2; +} + +Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { + TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); + // We only check operands, so we are sure to not process the empty domain from + // both sides. + for (HloInstruction* operand : instruction->unique_operands()) { + if (IsDomainInstruction(operand)) { + auto domain = MakeUnique(); + domain->enter_domains.insert(operand); + domain->exit_domains.insert(instruction); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + } + return Status::OK(); +} + +Status HloDomainMap::Populate(HloComputation* computation) { + for (HloInstruction* instruction : computation->instructions()) { + if (IsDomainInstruction(instruction)) { + // If this is a kDomain of the kind we are currently processing, check + // whether this is an "empty domain". + TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction)); + continue; + } + int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1); + if (domain_id >= 0) { + // We have already processed this instruction. + continue; + } + TF_ASSIGN_OR_RETURN(std::unique_ptr domain, + CreateDomain(instruction)); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + return Status::OK(); +} + +Status HloDomainMap::InsertDomain( + std::unique_ptr domain) { + int64 domain_id = instruction_domains_.size(); + instruction_domains_.push_back(std::move(domain)); + for (HloInstruction* instruction : instruction_domains_.back()->reach_set) { + instruction_to_domain_[instruction] = domain_id; + } + return Status::OK(); +} + +Status HloDomainMap::ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const { + std::vector in_queue; + in_queue.push_back(instruction); + while (!in_queue.empty()) { + HloInstruction* current_instruction = in_queue.back(); + in_queue.pop_back(); + if (domain->reach_set.insert(current_instruction).second) { + // We should not be finding instructions with assigned domain here. + // If we assigned a domain to the instruction, it means that all the + // instructions reached by it, should have a domain as well. + int64 domain_id = + FindOrDefault(instruction_to_domain_, current_instruction, -1); + TF_RET_CHECK(domain_id < 0) + << "Instruction " << current_instruction->ToString() + << " already has domain " << domain_id; + for (HloInstruction* operand : current_instruction->operands()) { + if (IsDomainInstruction(operand)) { + // The reach set instruction is a user of the domain instruction + // (the instruction sees the kDomain as operand). + // IOW the dataflow enters the domain through the kDomain instruction. + domain->enter_domains.insert(operand); + } else { + in_queue.push_back(operand); + } + } + for (HloInstruction* user : current_instruction->users()) { + if (IsDomainInstruction(user)) { + // The reach set instruction is an operand of the domain instruction + // (the instruction sees the kDomain as user). + // IOW the dataflow exits the domain through the kDomain instruction. + domain->exit_domains.insert(user); + } else { + in_queue.push_back(user); + } + } + } + } + return Status::OK(); +} + +StatusOr> HloDomainMap::CreateDomain( + HloInstruction* instruction) const { + auto domain = MakeUnique(); + TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); + domain->instructions = MakeNonDomainInstructions(domain->reach_set); + return std::move(domain); +} + +bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { + if (instruction->opcode() != HloOpcode::kDomain) { + return false; + } + if (!domain_kind_.empty()) { + if (instruction->user_side_metadata().Kind() != domain_kind_) { + return false; + } + // Both user and operand side of the metadata must be of the same kind. + CHECK(instruction->operand_side_metadata().Kind() == domain_kind_) + << "Instruction " << instruction->ToString() + << " has mismatching metadata kinds"; + } + return true; +} + +/* static */ std::vector +HloDomainMap::MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set) { + std::vector instructions; + instructions.reserve(instruction_set.size()); + for (HloInstruction* instruction : instruction_set) { + if (instruction->opcode() != HloOpcode::kDomain) { + instructions.push_back(instruction); + } + } + std::sort(instructions.begin(), instructions.end(), + [](HloInstruction* a, HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + return instructions; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h new file mode 100644 index 0000000000000000000000000000000000000000..e62ef763fb3881ab6030b1f6a66266ac80a3d84d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// The HloDomainMap splits a set of instructions within a module or computation, +// into different domains, separated by kDomain instructions. +// A domain is composed by a set of instructions which can reach each other via +// operand/user edges, without crossing a kDomain insutrction of a given kind. +// A domain never crosses computation boundaries. +class HloDomainMap { + public: + // Creates a new HloDomainMap, creating all the domains within the input + // computation, of the given kind. If domain_kind is not empty, only the + // kDomain instructions of domain_kind will be considered as separators. + // Otherwise every kDomain instruction will be splitting domains. + static StatusOr> Create( + HloComputation* computation, string domain_kind); + + // Creates a new HloDomainMap, creating all the domains within the input + // module, of the given kind. If domain_kind is not empty, only the + // kDomain instructions of domain_kind will be considered as separators. + // Otherwise every kDomain instruction will be splitting domains. + static StatusOr> Create(HloModule* module, + string domain_kind); + + // Retrieves all the domains the input module or computation are composed by. + const std::vector>& GetDomains() + const { + return instruction_domains_; + } + + // Checks whether two instructions are within the same domain. + bool InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const; + + // Checks whether instruction is a kDomain instruction of the kind we are + // currently processing. + bool IsDomainInstruction(HloInstruction* instruction) const; + + private: + HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} + + // Check if the kDomain instruction is facing (via its operand link) another + // kDomain instruction of the same kind, hence defining an empty domain. + // If that is the case, create the empty domain and call the proper + // normalizer. + Status TryProcessEmptyDomain(HloInstruction* instruction); + + Status Populate(HloComputation* computation); + + // Inserts the provided domain into the ones tracked by this object, + // creating a new domain ID. + Status InsertDomain(std::unique_ptr domain); + + // From the given instruction, epxands operand and user wise, the set of + // instructions which can be reached without crossing a kDomain instruction + // of the kind specified by domain_kind_. + // The domain data structure will be populated with all the reached + // instructions, and the boundaries of the domain, with the kDomain + // instructions encountered while expanding the reach. + Status ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const; + + // Creates a domain data structure using the ExpandDomain() API. + StatusOr> CreateDomain( + HloInstruction* instruction) const; + + // Out of an instruction set, returns a vector of all the ones which are not + // a kDomain kind. + static std::vector MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set); + + string domain_kind_; + std::vector> instruction_domains_; + tensorflow::gtl::FlatMap instruction_to_domain_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..aa0308100a21f109579de75788fce7d242d6a6b0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Cannot include hlo_instruction.h as this file is included from there. +class HloInstruction; + +// The DomainMetadata represents the base class for metadata which can be +// attached to kDomain HLO instructions. +class DomainMetadata { + public: + // A Domain data structure captures all the information about a kDomain + // bounded instruction set. + struct Domain { + // The set of instructions which are reachable from each other via + // operand/user pathways, without crossing a kDomain instruction of a given + // kind. The reach_set can contain kDomain instructions of other kinds, if + // two domains of different kind intersect each other. + tensorflow::gtl::FlatSet reach_set; + + // The same instructions in reach_set, but purged from kDomain instructions. + std::vector instructions; + + // If we consider a graph edge as an arrow oriented from the operand to the + // user, the enter_domains will contain the set of kDomain instructions + // whose dataflow enters the reach set (domain), while the exit_domains + // contains the set of kDomain instructions whose dataflow exit the reach + // set. + tensorflow::gtl::FlatSet enter_domains; + tensorflow::gtl::FlatSet exit_domains; + }; + + virtual ~DomainMetadata() = default; + + // Clones the metadata object. + virtual std::unique_ptr Clone() const = 0; + + // Returns the metadata type. A unique identifier which describes the real + // metadata type. + virtual tensorflow::StringPiece Kind() const = 0; + + // Compares the metadata object with another one and returns true if the + // two matches. + virtual bool Matches(const DomainMetadata& other) const = 0; + + // Returns a string representation of the metadata. + virtual string ToString() const = 0; + + // Given a reachable set (the set of instructions which are reachable from + // each other via user/operand pathways, without crossing a kDomain + // instruciton), makes sure that all of them have metadata attributes which + // are coherent with this metadata object. + virtual Status NormalizeInstructions(const Domain& domain) const = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d06040b0e7c92b03f4cb5481bdee73a0f74f939 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainRemover::RunContext { + public: + RunContext(HloModule* module, HloDomainRemover* remover) + : module_(module), remover_(remover) {} + + StatusOr Run(); + + private: + // Verifies the consistency of the domain, and normalizes the instructions + // within it. + Status VerifyAndNormalizeDomain(const DomainMetadata::Domain& domain); + + HloModule* module_; + HloDomainRemover* remover_; +}; + +Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( + const DomainMetadata::Domain& domain) { + // Verify that the whole kDomain frontier bounding the instruction reach set, + // has matching metadata. + // A kDomain instruction has two sides of metadata, a user facing and an + // operand facing. + // A reachable instruction set can make contact with a kDomain instruction on + // a user facing side (the kDomain is operand of the instruction), or on a + // operand facing side (the kDomain is user of the instruction). + // And depending on the contact side, the proper metadata object + // (user_side_metadata() vs. operand_side_metadata()) needs to be used for + // consistency checks. + const DomainMetadata* ref_metadata = nullptr; + VLOG(4) << "Reach set:"; + for (HloInstruction* instruction : domain.instructions) { + VLOG(4) << " " << instruction->name(); + } + VLOG(4) << " Domains:"; + for (HloInstruction* instruction : domain.enter_domains) { + const DomainMetadata& meta = instruction->user_side_metadata(); + VLOG(4) << " User side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + for (HloInstruction* instruction : domain.exit_domains) { + const DomainMetadata& meta = instruction->operand_side_metadata(); + VLOG(4) << " Operand side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + if (ref_metadata != nullptr) { + VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString(); + TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain)); + } else { + // No kDomain instruction was present within this domain, so call the + // generic normalization functions and have them apply their heuristic. + VLOG(2) << "Applying domain-less normalization"; + TF_RETURN_IF_ERROR(remover_->normalizer_(domain)); + } + return Status::OK(); +} + +StatusOr HloDomainRemover::RunContext::Run() { + VLOG(4) << "Processing metadata domain: '" << remover_->kind_ << "'"; + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Remover"); + + int64 removed_domains = 0; + for (HloComputation* computation : module_->computations()) { + // First create the domain instruciton sets. A domain instruction set is + // the set of instructions whose edges never cross a kDomain instruction. + TF_ASSIGN_OR_RETURN(std::unique_ptr domain_map, + HloDomainMap::Create(computation, remover_->kind_)); + // Verify and normalize every domain populated within the map. + for (auto& domain : domain_map->GetDomains()) { + TF_RETURN_IF_ERROR(VerifyAndNormalizeDomain(*domain)); + } + + // Now remove all the kDomain instructions of the kind specified by the + // remover, that are within the currently processed computation from the + // graph. + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + for (HloInstruction* operand : instruction->unique_operands()) { + if (domain_map->IsDomainInstruction(operand)) { + VLOG(5) << "Removing " << operand->name(); + TF_RETURN_IF_ERROR( + operand->ReplaceAllUsesWith(operand->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand)); + ++removed_domains; + } + } + } + HloInstruction* root = computation->root_instruction(); + if (root != nullptr && domain_map->IsDomainInstruction(root)) { + VLOG(5) << "Removing " << root->name(); + computation->set_root_instruction(root->mutable_operand(0)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(root)); + ++removed_domains; + } + } + VLOG(3) << "Removed " << removed_domains << " kDomain instructions of '" + << remover_->kind_ << "' kind"; + if (removed_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Remover"); + } + return removed_domains > 0; +} + +StatusOr HloDomainRemover::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h new file mode 100644 index 0000000000000000000000000000000000000000..0c71dd34fd4d2944037dc965a2c9ad2c592d6e3e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// Removes all the kDomain instructions of a given kind from the input module, +// and calls the normalizer to propagate the properties on the possibly new born +// instructions. +class HloDomainRemover : public HloPassInterface { + public: + // Creates a new HloDomainRemover object tasked at removing all the kDomain + // instructions of a given kind. + // In case a reachable set (the set of instructions within a computation, + // which are mutually reachable via operand/user pathways) has all the + // instructions in it with the same attributes (ie, sharding), a normalizer + // function is tasked at applying attribute normalization on the instructions + // within such domain. + HloDomainRemover( + tensorflow::StringPiece kind, + std::function normalizer) + : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + + tensorflow::StringPiece name() const override { return "domain_remover"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + string kind_; + std::function normalizer_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3859e4cae6e15bdb783277093b80d7822b1f4670 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -0,0 +1,461 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloDomainTest : public HloVerifiedTestBase { + protected: + bool FindUserViaDomainPath(HloInstruction* instruction, + HloInstruction* operand) const { + for (HloInstruction* user : operand->users()) { + if (user == instruction) { + return true; + } + if (user->opcode() == HloOpcode::kDomain && + FindUserViaDomainPath(instruction, user)) { + return true; + } + } + return false; + } + + // Checks whether there is a kDomain instruction in the edge between the + // instruction and the operand. + bool HasDomainEdge(HloModule* module, + tensorflow::StringPiece instruction_name, + tensorflow::StringPiece operand_name) { + HloInstruction* instruction = FindInstruction(module, instruction_name); + HloInstruction* operand = FindInstruction(module, operand_name); + CHECK_NE(instruction, nullptr); + CHECK_NE(operand, nullptr); + if (!instruction->IsUserOf(operand)) { + // If instruction is not an immediate user, we must find a path from + // operand to instruction anyway, otherwise there is a corruption. + if (FindUserViaDomainPath(instruction, operand)) { + return true; + } + LOG(FATAL) << "Bad HLO module generated across the '" << instruction_name + << "' and '" << operand_name << "' instructions:\n" + << module->ToString(); + } + return false; + } + + StatusOr ParseModule(tensorflow::StringPiece hlo_string) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + ParseAndVerifyModule(hlo_string, config); + return &module(); + } +}; + +// Dummy DomainMetadata implementation which create kDomain boundaries around +// HLO instructions with the same metadata().op_name() values. +class OpNameMetadata : public DomainMetadata { + public: + explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} + + std::unique_ptr Clone() const override { + return MakeUnique(opname_); + } + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override { + const OpNameMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a OpNameMetadata, then it is clearly a no match. + return false; + } + return opname_ == other_ptr->opname_; + } + + string ToString() const override { return opname_; } + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override { + // For the purposes of this test, nothing to do. + return Status::OK(); + } + + static tensorflow::StringPiece KindName() { return "opname"; } + + private: + string opname_; +}; + +// Creator function for OpNameMetadata domains. +std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* operand) { + if (instruction->metadata().op_name() == operand->metadata().op_name()) { + return nullptr; + } + std::unique_ptr operand_side_metadata = + MakeUnique(operand->metadata().op_name()); + std::unique_ptr user_side_metadata = + MakeUnique(instruction->metadata().op_name()); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain) { + // Nothing to do for the particular use this test make of the OpName domains. + return Status::OK(); +} + +TEST_F(HloDomainTest, CheckDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b), sharding={maximal device=1} + d = f32[4] subtract(a, b), sharding={maximal device=1} + e = f32[4] multiply(c, d), sharding={maximal device=1} + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b) + d = f32[4] subtract(a, b) + e = f32[4] multiply(c, d) + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(!isolator_changed); +} + +TEST_F(HloDomainTest, CheckDomainAroundIO) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + token = token[] after-all() + b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0} + c = token[] send-done(b), channel_id=1, sharding={maximal device=0} + d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0} + e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0} + e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0} + f = f32[4] add(a, e_element) + g = f32[4] subtract(a, e_element) + ROOT h = (f32[4], f32[4]) tuple(f, g) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "b", "a")); + EXPECT_TRUE(HasDomainEdge(module, "f", "e_element")); + EXPECT_FALSE(HasDomainEdge(module, "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module, "b", "a")); + EXPECT_FALSE(HasDomainEdge(module, "f", "e_element")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + token = token[] after-all(), sharding={maximal device=-1} + a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1} + b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1} + b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1} + c = f32[4] add(b_element, b_element), sharding={maximal device=-1} + d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1} + ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_FALSE(isolator_changed); +} + +TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + token = token[] after-all(), sharding={maximal device=0} + a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0} + b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0} + b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0} + c = f32[4] add(b_element, b_element) + d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0} + ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_FALSE(remover_changed); + + HloInstruction* add = FindInstruction(module, "c"); + ASSERT_NE(add, nullptr); + auto device = add->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, 0); +} + +TEST_F(HloDomainTest, CheckMultiDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(a, b), sharding={maximal device=1} + d = f32[4] subtract(a, c), sharding={maximal device=1}, metadata={op_name="D"} + e = f32[4] multiply(c, d), sharding={maximal device=1}, metadata={op_name="D"} + f = f32[4] add(e, c), sharding={maximal device=1} + ROOT g = (f32[4], f32[4], f32[4]) tuple(c, d, f) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator sharding_isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, + sharding_isolator.Run(module)); + EXPECT_TRUE(sharding_isolator_changed); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module)); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module)); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module)); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "c")); +} + +TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + token = token[] after-all() + infeed = ((f32[4], f32[4]), token[]) infeed(token), + sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} + infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0 + gte0 = f32[4] get-tuple-element(infeed.data), index=0 + gte1 = f32[4] get-tuple-element(infeed.data), index=1 + copy0 = f32[4] copy(gte0) + copy1 = f32[4] copy(gte1) + ROOT add = f32[4] add(copy0, copy1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "infeed.data", "infeed")); + EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1")); + + // Inject unassigned tuple/gte within the infeed domain, to simulate the + // HLO passes adding unexpected instructions. + // + // infeed + // | + // infeed.data (tuple element 0 of infeed) + // / \ + // GTE0 GTE1 + // / \ + // COPY0 COPY1 + // \ / + // \ / + // TUPLE + // | + HloInstruction* infeed = FindInstruction(module, "infeed"); + ASSERT_NE(infeed, nullptr); + HloInstruction* infeed_data = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + + auto infeed_data_users = infeed_data->users(); + HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed_data->shape(), 0), infeed_data, + 0)); + HloInstruction* new_copy0 = + infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte0->shape(), HloOpcode::kCopy, new_gte0)); + HloInstruction* new_gte1 = infeed_data->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed_data->shape(), 1), infeed_data, + 1)); + HloInstruction* new_copy1 = + infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte1->shape(), HloOpcode::kCopy, new_gte1)); + HloInstruction* new_tuple = infeed_data->parent()->AddInstruction( + HloInstruction::CreateTuple({new_copy0, new_copy1})); + for (HloInstruction* user : infeed_data_users) { + TF_EXPECT_OK(infeed_data->ReplaceUseWith(user, new_tuple)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + struct Assignment { + HloInstruction* instruction; + int64 device; + } assignments[] = { + {new_gte0, 1}, + {new_copy0, 1}, + {new_gte1, 0}, + {new_copy1, 0}, + }; + for (auto& assignment : assignments) { + auto device = assignment.instruction->sharding_unique_device(); + ASSERT_TRUE(device.has_value()); + EXPECT_EQ(*device, assignment.device); + } + EXPECT_TRUE(new_tuple->has_sharding()); + EXPECT_EQ( + new_tuple->sharding(), + HloSharding::Tuple(new_tuple->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)})); +} + +// Tests that text dumps of domain instructions can be parsed back, in the +// specific case of null shardings. +TEST_F(HloDomainTest, DumpParseNullSharding) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {}); + auto sharding_md_0 = MakeUnique(nullptr); + auto sharding_md_1 = MakeUnique(nullptr); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain( + shape, param, std::move(sharding_md_0), std::move(sharding_md_1))); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto hlo_string = module->ToString(); + ASSERT_TRUE(ParseModule(hlo_string).status().ok()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index d236f83aeb9254b9c6e6d04629758ac2c8fd0da3..4ed1508d7067684a15d0fb7d86e69b055bc1333b 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -119,6 +119,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { return false; } + HloCloneContext context(module); bool changed = false; for (auto* computation : module->computations()) { for (auto* hlo : computation->MakeInstructionPostOrder()) { @@ -140,6 +141,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops with embedded computations where it suffices to convert // the embedded computations instead of converting the ops themselves. if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || + opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kSelectAndScatter || @@ -180,7 +182,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); new_hlo = computation->AddInstruction( - hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + hlo->CloneWithNewOperands(shape, new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); new_hlo = ToElementType(new_hlo, eliminate_type_); @@ -189,16 +191,16 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, replace_with_type_); - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - new_shape, new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(new_shape, new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); // Convert the elements of the result of `new_hlo` to produce a new // tuple with shape `old_shape`. new_hlo = ConvertTupleElements(new_hlo, old_shape); } else { - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - hlo->shape(), new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); } diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index 5c5a059e0fd895f03bc26a975609b57333237faf..c170e36c73ad2bef830e528de3ec72d38683d888 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -57,8 +57,10 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { const string& hlo_string = R"( HloModule InfeedOutfeed ENTRY RoundTrip16MiBR1.v2 { - ROOT infeed = bf16[4]{0} infeed() - outfeed = () outfeed(infeed) + token = token[] after-all() + infeed = (bf16[4]{0}, token[]) infeed(token) + ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) } )"; auto module = CreateModuleFromHloString(hlo_string); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index ff7d07ee16d09a81da99c46d85888296e44cf8a9..47da46bfad646e12b736ecb123f9d3db16ca1990 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -94,7 +94,7 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -124,7 +124,7 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -300,11 +300,32 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( instruction->CloneWithNewOperands(instruction->shape(), operands); auto result = Evaluate(cloned_instruction.get()); - // Clean up our cloned instructions before returning. - cloned_instruction->DetachFromOperands(); - for (auto& operand : owned_operands) { - operand->DetachFromOperands(); - } + return result; +} + +StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( + HloOpcode opcode, const Literal& lhs, const Literal& rhs) { + std::unique_ptr lhs_instr = + HloInstruction::CreateConstant(lhs.CloneToUnique()); + std::unique_ptr rhs_instr = + HloInstruction::CreateConstant(rhs.CloneToUnique()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), + rhs_instr.get()); + auto result = Evaluate(cloned_instruction.get()); + + return result; +} + +StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( + HloOpcode opcode, const Literal& operand) { + std::unique_ptr operand_instr = + HloInstruction::CreateConstant(operand.CloneToUnique()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); + auto result = Evaluate(cloned_instruction.get()); return result; } @@ -343,7 +364,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); - CHECK(!ShapeUtil::IsTuple(reference_shape)); + CHECK(ShapeUtil::IsArray(reference_shape)); const int64 rank = ShapeUtil::Rank(reference_shape); const int64 concat_dim = concatenate->dimensions()[0]; CHECK_GE(concat_dim, 0); @@ -354,7 +375,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (int64 i = 1; i < operands.size(); ++i) { const Shape& operand_shape = operands[i]->shape(); - CHECK(!ShapeUtil::IsTuple(operand_shape)); + CHECK(ShapeUtil::IsArray(operand_shape)); // Accumulate the concat dimension from all tensors taking part to the // operation. concat_dimensions[concat_dim] += @@ -859,6 +880,33 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { return Status::OK(); } +Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { + const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0)); + + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(operand.shape())) + << "broadcast dimensions is of size: " << broadcast->dimensions().size() + << " and rank of operand_to_broadcast is: " + << ShapeUtil::Rank(operand.shape()); + // Checks that operand's dimensions are the same as the broadcast's + // dimensions along the dimensions to be broadcasted. + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == + operand.shape().dimensions(i)); + } + + TF_ASSIGN_OR_RETURN( + evaluated_[broadcast], + operand.Broadcast(broadcast->shape(), broadcast->dimensions())); + + return Status::OK(); +} + +Status HloEvaluator::HandleAfterAll(HloInstruction* token) { + evaluated_[token] = Literal::CreateToken(); + return Status::OK(); +} + Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); @@ -914,9 +962,10 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { // Attach cloned computation to an empty HLO module so the existing ones are // not modified. HloModule empty_hlo_module("EmptyModuleForFusion", config); + HloCloneContext context(&empty_hlo_module); auto cloned_fused_computation = fusion->fused_instructions_computation()->Clone( - /*suffix=*/"clone_with_layout", &empty_hlo_module); + /*suffix=*/"clone_with_layout", &context); for (auto* instruction : cloned_fused_computation->instructions()) { LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); } @@ -951,8 +1000,8 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* true_computation = conditional->true_computation(); auto* false_computation = conditional->false_computation(); - auto result = Literal::CreateFromShape(conditional->shape()); HloEvaluator embedded_evaluator; + std::unique_ptr result; if (pred.Get({})) { result = embedded_evaluator .Evaluate(*true_computation, @@ -975,8 +1024,6 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { const auto& on_false = GetEvaluatedLiteralFor(select->operand(2)); // If predicate is of scalar type, no element-wise selection would be needed. - // This would also handle output array of tuple types as the DefaultAction - // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { evaluated_[select] = on_true.CloneToUnique(); @@ -989,6 +1036,19 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { return DefaultAction(select); } +Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) { + const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0)); + const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1)); + const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2)); + + if (pred.Get({})) { + evaluated_[tuple_select] = on_true.CloneToUnique(); + } else { + evaluated_[tuple_select] = on_false.CloneToUnique(); + } + return Status::OK(); +} + Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloComputation* cond_comp = while_hlo->while_condition(); HloComputation* body_comp = while_hlo->while_body(); @@ -1019,6 +1079,107 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return Status::OK(); } +// Key-value sort is a special snowflake: it's templated on two different +// element types, one for the keys, and one for the values. Jump through some +// hoops to make this work. +namespace { +template +std::unique_ptr EvaluateSortInternal(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { + CHECK_EQ(sort->operand_count(), 2); + // We need to sort and array of keys and an array of values, where the + // sorted order of the values is determined by the keys. The simplest(?) + // way to do this is to go to an array-of-pairs representation, sort the + // array using the keys, and then go back to pair-of-arrays. + VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); + VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); + const auto& keys_data = keys_literal.data(); + const auto& values_data = values_literal.data(); + using kv_pair = std::pair; + std::vector key_value_vector; + CHECK_EQ(keys_data.size(), values_data.size()); + key_value_vector.reserve(keys_data.size()); + for (int i = 0; i < keys_data.size(); ++i) { + key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i])); + } + std::sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); + std::vector result_keys; + std::vector result_values; + for (const auto& key_value : key_value_vector) { + result_keys.push_back(key_value.first); + result_values.push_back(key_value.second); + } + auto result_keys_literal = MakeUnique(sort->operand(0)->shape()); + result_keys_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_keys)); + auto result_values_literal = MakeUnique(sort->operand(1)->shape()); + result_values_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_values)); + auto result_tuple = Literal::MakeTuple( + {result_keys_literal.get(), result_values_literal.get()}); + VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + return result_tuple; +} + +template +StatusOr> EvaluateSortCurried( + HloInstruction* sort, const Literal& keys_literal, + const Literal& values_literal) { + switch (sort->operand(1)->shape().element_type()) { + case F32: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + case U32: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + case S32: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + case BF16: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + default: + return InvalidArgument("Unsupported type for Sort"); + } +} + +StatusOr> EvaluateSort(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { + switch (sort->operand(0)->shape().element_type()) { + case F32: + return EvaluateSortCurried(sort, keys_literal, values_literal); + case U32: + return EvaluateSortCurried(sort, keys_literal, values_literal); + case S32: + return EvaluateSortCurried(sort, keys_literal, values_literal); + case BF16: + return EvaluateSortCurried(sort, keys_literal, values_literal); + default: + return InvalidArgument("Unsupported type for Sort"); + } +} +} // namespace + +Status HloEvaluator::HandleSort(HloInstruction* sort) { + if (!ShapeUtil::IsTuple(sort->shape())) { + return DefaultAction(sort); + } else { + auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), + GetEvaluatedLiteralFor(sort->operand(1))); + if (result.ok()) { + evaluated_[sort] = std::move(result.ValueOrDie()); + return Status::OK(); + } else { + return result.status(); + } + } +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index cc5676ea7b05be6e0b7066bf703d8e48da0133ab..2850c5cb1a94de0dbab8ba5b27d7e21998794087 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -108,6 +109,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const std::unordered_map& substitutions); + StatusOr> EvaluateElementwiseBinaryOp( + HloOpcode opcode, const Literal& lhs, const Literal& rhs); + + StatusOr> EvaluateElementwiseUnaryOp( + HloOpcode opcode, const Literal& operand); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. @@ -165,6 +172,36 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; + + Status HandleBroadcast(HloInstruction* broadcast) override; + + Status HandleAfterAll(HloInstruction* token) override; + + Status HandleSort(HloInstruction* sort) override; + + // Returns the already-evaluated literal result for the instruction. + // A Constant instruction is considered evaluated and its literal will be + // returned directly without looking up the cache. + // Crash with log if the given instruction has not been evaluated previously. + const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { + if (hlo->IsConstant()) { + return hlo->literal(); + } + auto it = evaluated_.find(hlo); + CHECK(it != evaluated_.end()) + << "could not find evaluated value for: " << hlo->ToString(); + return *(it->second); + } + + // Tracks the HLO instruction and its evaluated literal result. + // TODO(b/35950897): have better memory management here to free instructions + // that are no longer a parent for any other subsequent instruction in + // post-orderring. + // Must be cleared for each evaluation. + tensorflow::gtl::FlatMap> + evaluated_; + private: template static StatusOr> ElementWiseUnaryOpImpl( @@ -184,8 +221,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { ShapeUtil::HumanString(operand->shape()).c_str()); } - auto result = Literal::CreateFromShape(shape); - + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { return unary_op(operand_literal.Get(multi_index)); @@ -193,20 +229,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return std::move(result); } - // Returns the already-evaluated literal result for the instruction. - // A Constant instruction is considered evaluated and its literal will be - // returned directly without looking up the cache. - // Crash with log if the given instruction has not been evaluated previously. - const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { - if (hlo->IsConstant()) { - return hlo->literal(); - } - auto it = evaluated_.find(hlo); - CHECK(it != evaluated_.end()) - << "could not find evaluated value for: " << hlo->ToString(); - return *(it->second); - } - // Map from a primitive type to its associated (templated) DfsHloVisitor. // Note: the hash function here is only needed because current gcc std::hash // does not specialize for enum types. This should however be fixed in the @@ -215,14 +237,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { std::hash> typed_visitors_; - // Tracks the HLO instruction and its evaluated literal result. - // TODO(b/35950897): have better memory management here to free instructions - // that are no longer a parent for any other subsequent instruction in - // post-orderring. - // Must be cleared for each evaluation. - tensorflow::gtl::FlatMap> - evaluated_; - // Caches pointers to input literals, assuming they are in post-order. // Literals are not owned by this class, and they must outlive the lifetime of // each invocation to the Evaluate* method. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index ae5b5e0412ef99db9b72d645a954759ca0b9eb8b..42770d848a83b2e27b87bc963d259e2b7af664a4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -206,6 +206,15 @@ TEST_P(HloEvaluatorTest, DoesOr) { std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise or with 2 operands. +TEST_P(HloEvaluatorTest, DoesXor) { + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto expected = Literal::CreateR2({{3, 4}, {-104, 0}}); + TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. TEST_P(HloEvaluatorTest, DoesMultiply) { auto lhs = Literal::CreateR2({{-1, 0}, {-100, 4}}); @@ -262,13 +271,13 @@ TEST_P(HloEvaluatorTest, DoesCosR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), - use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); + use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesSinR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), - use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); + use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesNotR2) { auto operand = @@ -333,7 +342,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { result->EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get(rindexes), 0x1.0P-5); + EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); }); } @@ -567,7 +576,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = Literal::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0x1.0P-5))); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -1248,7 +1257,7 @@ void BM_ReducePrecisely(int num_iters) { HloComputation::Builder b("BM_ReducePrecisely"); HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config); + HloModule module("BM_ReducePrecisely", config); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 5a459a4f16d3592c9b904fa805fa97b5a45d89ca..cdbac74ba4c9855ca586c8a6ef37b1e507eedea4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -34,6 +34,37 @@ using is_complex_t = std::is_same; template using is_complex64_t = std::is_same; +// It's UB to use std::sort with std::less, because of NaNs. Define +// "safe" less functions which are actually strict weak orders. +template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + return a < b; +} + +template ::value || + std::is_same::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + if (std::isnan(b)) { + return !std::isnan(a); + } else { + return a < b; + } +} + +template ::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + if (Eigen::half_impl::isnan(b)) { + return !Eigen::half_impl::isnan(a); + } else { + return a < b; + } +} + // Templated DfsHloVisitor for use by HloEvaluator. // // Typically ReturnT here indicates the resulting literal type of each evaluated @@ -161,36 +192,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleRound(round); } - Status HandleBroadcast(HloInstruction* broadcast) override { - parent_->evaluated_[broadcast] = - Literal::CreateFromShape(broadcast->shape()); - auto output = parent_->evaluated_[broadcast].get(); - const Literal& operand_to_broadcast = - parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); - std::vector broadcast_indices( - ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); - - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand_to_broadcast.shape())) - << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand_to_broadcast.shape()); - // Checks that operand's dimensions are the same as the broadcast's - // dimensions along the dimensions to be broadcasted. - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand_to_broadcast.shape().dimensions(i)); - } - - return output->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get(broadcast_indices); - }); - } - template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -640,12 +641,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el && rhs_el; - })); - return Status::OK(); + return InvalidArgument("Unsupported type for And"); } template < @@ -674,12 +670,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el || rhs_el; - })); - return Status::OK(); + return InvalidArgument("Unsupported type for Or"); } template < @@ -693,6 +684,35 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleOr(or_); } + template ::value>::type* = + nullptr> + Status HandleXor(HloInstruction* xor_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[xor_], + ElementWiseBinaryOp(xor_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el ^ rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleXor(HloInstruction* xor_) { + return InvalidArgument("Unsupported type for Xor"); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleXor(HloInstruction* xor_) { + return InvalidArgument("Unsupported type for Xor"); + } + + Status HandleXor(HloInstruction* xor_) override { + return HandleXor(xor_); + } + template ::value && @@ -808,7 +828,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override { CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(!ShapeUtil::IsTuple(select->shape())); + CHECK(ShapeUtil::IsArray(select->shape())); std::function select_op = [](bool pred, ReturnT on_true, ReturnT on_false) { if (pred) { @@ -836,7 +856,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = Literal::CreateFromShape(result_shape); + auto result = MakeUnique(result_shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice out_index) { @@ -993,7 +1013,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast(result_val); }; - auto result = Literal::CreateFromShape(result_shape); + auto result = MakeUnique(result_shape); TF_RETURN_IF_ERROR(result->PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); @@ -1033,87 +1053,50 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = Literal::CreateFromShape(dot->shape()); - CHECK_EQ(dnums.lhs_batch_dimensions_size(), dnums.rhs_batch_dimensions_size()); - std::vector lhs_non_contracting_dims; + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + + // result_index_locations[i] contains one or two pointers to the locations + // in lhs_index or rhs_index where the i'th result index should go. + tensorflow::gtl::InlinedVector, kInlineRank> + result_index_locations; + result_index_locations.reserve(lhs_rank + rhs_rank - 2); + + // The first components in the output shape are the LHS and RHS batch + // dimensions: + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) { + result_index_locations.push_back( + {&lhs_index[dnums.lhs_batch_dimensions(i)], + &rhs_index[dnums.rhs_batch_dimensions(i)]}); + } + + // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension) { - lhs_non_contracting_dims.push_back(i); + if (i != lhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) { + result_index_locations.push_back({&lhs_index[i], nullptr}); } } - - std::vector rhs_non_batch_non_contracting_dims; - tensorflow::gtl::FlatSet batch_dims_set( - dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { - rhs_non_batch_non_contracting_dims.push_back(i); + if (i != rhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) { + result_index_locations.push_back({&rhs_index[i], nullptr}); } } - const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); - const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); + auto result = MakeUnique(dot->shape()); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice result_index) { ElementwiseT result_val = static_cast(0); - // Find the corresponding non-contracting indices for lhs and rhs. - // - // For `result_index`, its batch dimension, if exists, will be at the - // same dimension as the batch dimension of lhs and rhs. More - // specifically: - // - For lhs, the non-contracting dimensions, including the batch - // dimension have the same index as the `result_index`. - // - For rhs, the batch dimension is set seperately from other - // non-contracting dimensions, since these other non-contracting - // dimensions in rhs follow the non-contracting dimensions of lhs in - // the resulting index. - // - // As an example, for a resulting index: - // result_index [result_batch, result_x, result_y] - // the effecting lhs and rhs indices are: - // lhs [result_batch, lhs_non_contracting_dim, contracting_dim - // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] - // `result_x` is only affected by the lhs_non_contracting_dim and - // likewise `result_y` only depends on rhs_non_contracting_dim. - // - // so we can look up the lhs and rhs indices by: - // - // lhs: - // batch index is the same as `result_batch`. - // non-contracting dimension is the same as - // result_index[lhs_non_contracting_dim] - // rhs: - // batch index: the same as `result_batch`. - // non-contracting dimension index: *not* the same as - // result_index[rhs_non_contractng_dim], since the - // non-contracting dimensions of lhs are included in the - // result_index first. Instead, the non_contracting_dim of rhs must - // be calculated as following: - // lhs_non_contracting_dimensions_size + - // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 - // - // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is - // the index offset to the result_index that only depends on - // the non_batch and non-contracting dimensions of rhs. -1 at the - // end translates size to index. - for (auto i : lhs_non_contracting_dims) { - lhs_index[i] = result_index[i]; - } - for (auto i : dnums.rhs_batch_dimensions()) { - rhs_index[i] = result_index[i]; - } - for (auto i : rhs_non_batch_non_contracting_dims) { - const int64 rhs_non_batch_non_contracting_dim = - lhs_non_contracting_size + (i - batch_dim_size) - 1; - rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; + for (int64 i = 0; i < result_index.size(); i++) { + *result_index_locations[i].first = result_index[i]; + if (result_index_locations[i].second) { + *result_index_locations[i].second = result_index[i]; + } } // Accumulates resulting product along the contracted dimension. @@ -1134,7 +1117,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandlePad(HloInstruction* pad) override { - CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); + CHECK(ShapeUtil::IsArray(pad->operand(0)->shape())); // Padding value must be scalar. CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), @@ -1147,13 +1130,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { /*padding_config=*/pad->padding_config())); CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = Literal::CreateFromShape(pad->shape()); + auto result = MakeUnique(pad->shape()); TF_RETURN_IF_ERROR(result->Populate( [&scalar](tensorflow::gtl::ArraySlice multi_index) { return scalar; @@ -1213,7 +1196,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { dynamic_slice->dynamic_slice_sizes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); TF_RET_CHECK( primitive_util::IsIntegralType(start_indices->shape().element_type())); @@ -1268,7 +1251,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { operand->shape(), update->shape(), start_indices->shape())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); TF_RET_CHECK( primitive_util::IsIntegralType(start_indices->shape().element_type())); @@ -1318,7 +1301,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = Literal::CreateFromShape(map->shape()); + auto result = MakeUnique(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); TF_RETURN_IF_ERROR(result->Populate( @@ -1409,6 +1392,46 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleSort(HloInstruction* sort) { + auto keys = sort->operand(0); + TF_RET_CHECK(ShapeUtil::Rank(keys->shape()) == 1) + << "Sort is only supported for R1 shapes"; + TF_RET_CHECK(sort->operand_count() == 1) + << "Typed visitor does not support key-value sort"; + + const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); + VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); + const auto& keys_data = keys_literal.data(); + + std::vector result_data(keys_data.begin(), keys_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const ReturnT& a, const ReturnT& b) { + return SafeLess(a, b); + }); + auto result_literal = MakeUnique(sort->shape()); + result_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + parent_->evaluated_[sort] = std::move(result_literal); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleSort(HloInstruction* sort) { + return InvalidArgument("Unsupported type for Sort"); + } + + Status HandleSort(HloInstruction* sort) override { + return HandleSort(sort); + } + Status HandleReduce(HloInstruction* reduce) override { auto arg = reduce->operand(0); auto init_value = reduce->operand(1); @@ -1424,7 +1447,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); @@ -1434,8 +1457,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = Literal::CreateFromShape(reduce->shape()); - const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); std::vector arg_dim_steps(arg_dimensions.size()); std::vector arg_dim_counts(arg_dimensions.size()); @@ -1454,6 +1475,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -1484,11 +1506,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Evaluate computation with specified literal operands. auto curr_val_literal = Literal::CreateR0(curr_val); auto result_val_literal = Literal::CreateR0(result_val); - std::vector args = {result_val_literal.get(), - curr_val_literal.get()}; std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) + embedded_evaluator + .Evaluate( + *function, + {result_val_literal.get(), curr_val_literal.get()}) .ConsumeValueOrDie(); // Clear visit states so that we can use the evaluator again on // the same computation. @@ -1532,7 +1555,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = Literal::CreateFromShape(select_and_scatter->shape()); + auto result = MakeUnique(select_and_scatter->shape()); // Initialize result array with the init value. TF_RETURN_IF_ERROR(result->Populate( @@ -1644,7 +1667,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanStringWithLayout(inferred_return_shape); const Literal& operand_literal = @@ -1656,8 +1679,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = Literal::CreateFromShape(reduce_window->shape()); - // Creates a Shape object from window, for iteration below. std::vector window_dimension_sizes; for (const auto& window_dimension : window.dimensions()) { @@ -1670,6 +1691,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice output_index) { @@ -1688,10 +1710,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Literal::CreateR0(curr_val); const auto result_val_literal = Literal::CreateR0(result_val); - const std::vector args = { - result_val_literal.get(), curr_val_literal.get()}; std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) + embedded_evaluator + .Evaluate( + *function, + {result_val_literal.get(), curr_val_literal.get()}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again @@ -1989,17 +2012,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector start(start_indices_typed.begin(), start_indices_typed.end()); - std::vector operand_indices(start.size()); + // Clamp the start indices so the slice is in-bounds w.r.t the operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to officially document different behavior. + for (int64 i = 0; i < start.size(); ++i) { + start[i] = std::min( + std::max(int64{0}, start[i]), + operand_literal.shape().dimensions(i) - result_shape.dimensions(i)); + } - auto result = Literal::CreateFromShape(result_shape); + std::vector operand_indices(start.size()); + auto result = MakeUnique(result_shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); - // Mod is only used here to be consistent with the existing - // backends' behavior. - operand_indices[i] = (multi_index[i] + start[i]) % - operand_literal.shape().dimensions(i); + operand_indices[i] = multi_index[i] + start[i]; } auto result = operand_literal.Get(operand_indices); @@ -2016,23 +2046,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto result = operand_literal.CloneToUnique(); auto start_indices_typed = start_indices_literal.data(); const auto rank = ShapeUtil::Rank(result->shape()); - std::vector start(rank, 0); + std::vector start(start_indices_typed.begin(), + start_indices_typed.end()); + // Clamp the update start indices so the slice is in-bounds w.r.t the + // operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. for (int64 i = 0; i < rank; ++i) { - // All other implementations currently wrap-around the index, so this - // should do so as well. - start[i] = (start_indices_typed[i] % result->shape().dimensions(i)); - start[i] += (start[i] < 0) * result->shape().dimensions(i); + start[i] = std::min( + std::max(0, start[i]), + result->shape().dimensions(i) - update_literal.shape().dimensions(i)); } std::vector result_index(rank, 0); auto func = [&](tensorflow::gtl::ArraySlice update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); - // Same as above, wrap-around only to match other implementations' - // semantics. - std::transform(result_index.begin(), result_index.end(), - result->shape().dimensions().begin(), result_index.begin(), - std::modulus()); result->Set(result_index, update_literal.Get(update_index)); return true; @@ -2083,7 +2114,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -2121,7 +2152,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index 4900c813fdf037e65c6b42d027f1cbefb6ee9830..eba80c0f199f6224f4b46ac19af482c713585154 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -29,7 +29,7 @@ using ::testing::ContainsRegex; class HloExecutionProfileTest : public HloTestBase {}; TEST_F(HloExecutionProfileTest, Basic) { - auto hlo_module = tools::Parse(R"( + auto hlo_module = ParseHloString(R"( HloModule test_module ENTRY entry_computation { lhs = f32[30,30]{1,0} parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 17e3c405f1e5269ddf2f03c031a1137f9bb14fcc..7a1372f929833a16de97c94e12b616e359b36950 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -28,6 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -321,13 +323,11 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, - const DebugOptions& debug_options, bool show_metadata, - bool show_backend_config, const HloExecutionProfile* profile, - NodeFilter filter) + const DebugOptions& debug_options, bool show_backend_config, + const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), label_(std::string(label)), debug_options_(debug_options), - show_metadata_(show_metadata), show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -395,7 +395,6 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph const DebugOptions& debug_options_; - const bool show_metadata_; const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -430,7 +429,8 @@ class HloDotDumper { // When coloring by sharding information, we track the sharding string // representation to color association, by round-robin the color schemes. - std::unordered_map sharding_colors_; + std::unordered_map + sharding_colors_; int64 next_shard_color_ = 0; }; @@ -592,15 +592,26 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, const HloInstruction* parent_instr) { VLOG(2) << "Dumping subcomputation " << subcomp->name(); - const char* computation_fmt = R"(subgraph %s { -%s -label = <%s>; -labelloc = t; -tooltip = " "; -%s -} // %s + // Add an edge from the subcomputation to its parent node. If subcomp + // belongs to a fusion node, it's drawn in place of the fusion instruction, + // so there's no need to link those. + if (parent_instr->opcode() != HloOpcode::kFusion) { + const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); + VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() + << " as " << next_edge_id_; + edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); + const char* edge_fmt = + R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; + edges_.push_back(Printf( + edge_fmt, InstructionId(from), InstructionId(parent_instr), + SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); + } -)"; + // Have we already dumped this subcomputation? If so, generating the edge + // linking it and parent_instr is all we want to do in this function. + if (cluster_ids_.find(subcomp) != cluster_ids_.end()) { + return ""; + } cluster_ids_[subcomp] = next_cluster_id_++; @@ -647,25 +658,16 @@ tooltip = " "; string comp_body = DumpComputation(subcomp); - // Add an edge from the subcomputation to its parent node. If subcomp - // belongs to a fusion node, it's drawn in place of the fusion instruction, - // so there's no need to link those. - if (parent_instr->opcode() != HloOpcode::kFusion) { - const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); - VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() - << " as " << next_edge_id_; - edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = - R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( - edge_fmt, InstructionId(from), InstructionId(parent_instr), - SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); - } - - string computation = - Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + const char* computation_fmt = R"(subgraph %s { +%s +label = <%s>; +labelloc = t; +tooltip = " "; +%s +} // %s - return computation; +)"; + return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { @@ -723,11 +725,25 @@ string HloDotDumper::DumpRootTag() { to_id, node_body, node_shape, NodeColorAttributes(color)); } +static const HloConstantInstruction* TryGetFusionParameterConstant( + const HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) { + return nullptr; + } + const HloInstruction* fusion = instr->parent()->FusionInstruction(); + const HloInstruction* operand = fusion->operand(instr->parameter_number()); + return DynCast(operand); +} + bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // If a node: // - // - is a tuple-shaped parameter, - // - is not a parameter to a fusion node, + // - is a parameter of a fusion node which is bound to a constant, + // + // or + // + // - is a tuple-shaped parameter, and + // - is not a parameter to a fusion node, and // - has at least kMinUsersToOmit users shown, and // - all of the shown users are get-tuple-elements, // @@ -735,6 +751,9 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // // This helps us handle the common case where a while loop body has one big // tuple-shaped parameter. + if (TryGetFusionParameterConstant(instr) != nullptr) { + return true; + } const int kMinUsersToOmit = 3; return instr->opcode() == HloOpcode::kParameter && ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && @@ -791,41 +810,41 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. string node_body = node_label; - for (const string& s : {trivial_subcomputation, node_metadata, - node_backend_config, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_backend_config, + extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } } - return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" "\n", - InstructionId(instr), node_body, node_shape, + InstructionId(instr), node_body, node_shape, node_metadata, NodeColorAttributes(color)); } string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { - auto stringify_constant = [](const HloInstruction* constant) { + auto stringify_constant = [](const HloConstantInstruction* constant) { const auto& shape = constant->shape(); // If the shape has a dimension of size zero, print it as e.g. // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. - if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) { + if (ShapeUtil::IsZeroElementArray(shape)) { return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); } // Print the literal value of constants with <= K elements. optional elem_count; - if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { + if (ShapeUtil::IsArray(shape)) { elem_count = 1; for (int64 dim : shape.dimensions()) { *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { + if (elem_count.has_value() && *elem_count <= 8) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -841,29 +860,26 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( ShapeUtil::HumanString(constant->shape())); }; - // Special case: If instr is a parameter to a fusion node, check whether the - // corresponding operand to the fusion node is a constant. - if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { - const HloInstruction* fusion = instr->parent()->FusionInstruction(); - const HloInstruction* operand = fusion->operand(instr->parameter_number()); - if (operand->opcode() != HloOpcode::kConstant) { - return ""; - } - return StrCat("constant ", stringify_constant(operand)); - } - std::vector lines; for (int64 i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); + const auto* constant_operand = DynCast(operand); optional operand_str; - if (operand->opcode() == HloOpcode::kConstant) { - operand_str = stringify_constant(operand); + if (constant_operand != nullptr) { + operand_str = stringify_constant(constant_operand); } else if (ShouldMergeIntoUsers(operand)) { - // Special case: If the operand is a parameter, use its parameter number - // rather than its name, because that's generally how people think of the - // node. + // Special case: If the operand is a parameter to a fusion node and it + // always has a constant value, display it like a regular constant. + // + // For other parameters, use the parameter number rather than the proper + // name, because that's generally how people think of the node. if (operand->opcode() == HloOpcode::kParameter) { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + if (const HloConstantInstruction* constant = + TryGetFusionParameterConstant(operand)) { + operand_str = stringify_constant(constant); + } else { + operand_str = Printf("Parameter %lld", operand->parameter_number()); + } } else { operand_str = operand->name(); } @@ -885,24 +901,26 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { if (!instr->has_sharding()) { return kDashedBorder; } - string shard_str = instr->sharding().ToString(); - auto it = sharding_colors_.find(shard_str); + auto it = sharding_colors_.find(instr->sharding()); if (it != sharding_colors_.end()) { return it->second; } ColorScheme color = static_cast( kBlue + (next_shard_color_++ % (kDashedBorder - kBlue))); - sharding_colors_.emplace(shard_str, color); + sharding_colors_.emplace(instr->sharding(), color); return color; } const auto kParameterColor = kOrange; // Special case: If this instruction has a parameter merged into it, paint it - // the same color as a parameter. + // the same color as a parameter. Unless the merged-in parameter is a + // parameter to a fusion node that is bound to a constant -- these aren't + // "real" parameters from the user's perspective. if (std::any_of(instr->operands().begin(), instr->operands().end(), [&](const HloInstruction* operand) { return operand->opcode() == HloOpcode::kParameter && - ShouldMergeIntoUsers(operand); + ShouldMergeIntoUsers(operand) && + TryGetFusionParameterConstant(operand) == nullptr; })) { return kParameterColor; } @@ -942,11 +960,13 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: case HloOpcode::kRng: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -965,6 +985,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: + case HloOpcode::kAfterAll: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: @@ -976,13 +997,12 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { } return kGreen; case HloOpcode::kConcatenate: - case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kPad: case HloOpcode::kReshape: case HloOpcode::kReverse: - case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: case HloOpcode::kTranspose: // De-emphasize scalar-shaped data movement ops and all data movement ops // inside fusion nodes, both of which are essentially free. @@ -998,6 +1018,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kCopy: + // Emphasize copy nodes, which are either physical transposes (and thus + // significant), or copies of read-only buffers (and thus dead weight). + return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: case HloOpcode::kFft: @@ -1013,6 +1037,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: return kPurple; + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kMap: return kGray; @@ -1068,10 +1093,6 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { } string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { - if (!show_metadata_) { - return ""; - } - std::vector lines; if (!instr->metadata().op_name().empty()) { lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); @@ -1091,11 +1112,11 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeBackendConfig( const HloInstruction* instr) { - if (!show_backend_config_ || instr->backend_config().empty()) { + if (!show_backend_config_ || instr->raw_backend_config_string().empty()) { return ""; } - return StrCat("backend_config=\"", instr->backend_config(), "\""); + return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\""); } string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { @@ -1154,6 +1175,20 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { return Join(lines, "
"); } +// Gets the total number of array elements in the given shape. For tuples, this +// is the sum of all the sizes of all of the array elements recursively in the +// tuple. +static int64 TotalElementsInShape(const Shape& shape) { + int64 elems = 0; + ShapeUtil::ForEachSubshape( + shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + elems += ShapeUtil::ElementsIn(subshape); + } + }); + return elems; +} + void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { @@ -1173,9 +1208,16 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { } else if (control_edge) { edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; } - const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)"; + + // We print "small" arrays using a hollow arrowhead and "large" arrays using + // a filled arrowhead. For now, we use an arbitrary cutoff for what "big" + // means. + bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; + + const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), - from->name(), to->name(), edge_label)); + (is_big_array ? "normal" : "empty"), from->name(), + to->name(), edge_label)); }; // Add edges from instr's operands to instr. Parameters within fusion @@ -1425,7 +1467,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata, bool show_backend_config) { + bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1436,8 +1478,8 @@ string DumpGraph(const HloComputation& computation, const string& label, graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { graph = - HloDotDumper(&computation, label, debug_options, show_metadata, - show_backend_config, hlo_execution_profile, NodeFilter()) + HloDotDumper(&computation, label, debug_options, show_backend_config, + hlo_execution_profile, NodeFilter()) .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1449,15 +1491,15 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata, bool show_backend_config) { + bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); - string graph = HloDotDumper(node.parent(), label, debug_options, - show_metadata, show_backend_config, - /*profile=*/nullptr, filter) - .Dump(); + string graph = + HloDotDumper(node.parent(), label, debug_options, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index fc8e1468aca9c2edbc22c30a41a1be8b32a1feca..0b11f34abb7f0d937a24d11f4dc5d2d6a0aae6e7 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_metadata = false, bool show_backend_config = false); + bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,6 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false, bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 8e52d926d85f1ce6fabeb2dedd2f8e0fe0c2051d..68f41a1cbb4db228f5dcf8b4a6130f05e81262a8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -121,7 +121,7 @@ TEST(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(-42))); - instruction->set_name("i_am_a_constant_root_instruction"); + instruction->SetAndSanitizeName("i_am_a_constant_root_instruction"); HloModuleConfig config; HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 31aff008a4c9271bd4542a2a3410dfaad808c2d4..6ea302f8b4170ea5043176a58b6f47003a79f5a5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include -#include #include #include #include @@ -27,19 +26,22 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -58,107 +60,352 @@ StatusOr> HloInstruction::CreateFromProto( TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); - auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); - for (const int64 operand_id : proto.operand_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) - << "No instruction with id " << operand_id; - instruction->AppendOperand(instruction_map.at(operand_id)); - } - for (const int64 predecessor_id : proto.control_predecessor_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) - << "No instruction with id " << predecessor_id; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) - ->AddControlDependencyTo(instruction.get())); - } - - // In the proto, fused computations are held exclusively within the - // HloInstructionProto and do not appear as an HloComputationProto within the - // HloModuleProto. - if (instruction->opcode() == HloOpcode::kFusion) { - TF_RET_CHECK(!proto.fusion_kind().empty()); - TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, - StringToFusionKind(proto.fusion_kind())); - - // Find the fused computation and set its fusion instruction. - TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "Expect 1 called computation for fusion instruction, but sees " - << proto.called_computation_ids_size(); - const int64 fusion_id = proto.called_computation_ids(0); - auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); - TF_RET_CHECK(fused_computation != nullptr) - << "No fusion computation with id " << fusion_id; - fused_computation->SetFusionInstruction(instruction.get()); - instruction->called_computations_.push_back(fused_computation); - } else { - for (const int64 computation_id : proto.called_computation_ids()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_id)) - << "No computation with id " << computation_id; - instruction->called_computations_.push_back( - computation_map.at(computation_id)); + std::unique_ptr instruction; + const auto operands = [&instruction_map, &proto](int index) { + return instruction_map.at(proto.operand_ids(index)); + }; + const auto all_operands = [&instruction_map, &proto]() { + std::vector result(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + result.begin(), [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + return result; + }; + const auto computations = [&computation_map, &proto](int index) { + return computation_map.at(proto.called_computation_ids(index)); + }; + switch (opcode) { + // Ops migrated to subclasses. + case HloOpcode::kBatchNormTraining: + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "BatchNormTraining instruction should have 3 operands but sees " + << proto.operand_ids_size(); + instruction = CreateBatchNormTraining( + proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), + proto.feature_index()); + break; + case HloOpcode::kBatchNormInference: + TF_RET_CHECK(proto.operand_ids_size() == 5) + << "BatchNormInference instruction should have 5 operands but sees " + << proto.operand_ids_size(); + instruction = CreateBatchNormInference( + proto.shape(), operands(0), operands(1), operands(2), operands(3), + operands(4), proto.epsilon(), proto.feature_index()); + break; + case HloOpcode::kBatchNormGrad: + TF_RET_CHECK(proto.operand_ids_size() == 5) + << "BatchNormGrad instruction should have 5 operands but sees " + << proto.operand_ids_size(); + instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), + operands(2), operands(3), operands(4), + proto.epsilon(), proto.feature_index()); + break; + case HloOpcode::kFft: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Fft instruction should have 1 operand but sees " + << proto.operand_ids_size(); + std::vector fft_length(proto.fft_length().begin(), + proto.fft_length().end()); + instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), + tensorflow::gtl::ArraySlice(fft_length)); + break; + } + case HloOpcode::kSend: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Send instruction should have 2 operand but sees " + << proto.operand_ids_size(); + instruction = CreateSend(operands(0), operands(1), proto.channel_id()); + break; + case HloOpcode::kSendDone: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "SendDone instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateSendDone(operands(0)); + break; + case HloOpcode::kRecv: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Recv instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0), + proto.channel_id()); + break; + case HloOpcode::kRecvDone: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "RecvDone instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateRecvDone(operands(0)); + break; + case HloOpcode::kReverse: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Reverse instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateReverse(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kConcatenate: + TF_RET_CHECK(proto.dimensions_size() == 1) + << "Concatenate instruction should have 1 dimension but sees " + << proto.dimensions_size(); + instruction = + CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); + break; + case HloOpcode::kReduce: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Reduce instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Reduce instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + instruction = CreateReduce(proto.shape(), operands(0), operands(1), + std::vector(proto.dimensions().begin(), + proto.dimensions().end()), + computations(0)); + break; + case HloOpcode::kTranspose: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Transpose instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = + CreateTranspose(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kBroadcast: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Broadcast instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = + CreateBroadcast(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kMap: + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Map instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + instruction = CreateMap(proto.shape(), all_operands(), computations(0)); + break; + case HloOpcode::kSlice: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Slice instruction should have 1 operand but sees " + << proto.operand_ids_size(); + std::vector slice_starts, slice_limits, slice_strides; + for (const HloInstructionProto::SliceDimensions& slice_dimensions : + proto.slice_dimensions()) { + slice_starts.push_back(slice_dimensions.start()); + slice_limits.push_back(slice_dimensions.limit()); + slice_strides.push_back(slice_dimensions.stride()); + } + instruction = CreateSlice(proto.shape(), operands(0), slice_starts, + slice_limits, slice_strides); + break; + } + case HloOpcode::kConstant: { + // TODO(b/110214922): Revert this to CHECK(proto.has_literal()). + if (proto.has_literal()) { + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateConstant(std::move(literal)); + } else { + instruction = MakeUnique(proto.shape()); + } + break; + } + case HloOpcode::kTrace: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Trace instruction should have 1 operand but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_literal()); + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + break; + } + case HloOpcode::kFusion: { + // In the proto, fused computations are held exclusively within the + // HloInstructionProto and do not appear as an HloComputationProto within + // the HloModuleProto. + TF_RET_CHECK(!proto.fusion_kind().empty()); + TF_ASSIGN_OR_RETURN(FusionKind fusion_kind, + StringToFusionKind(proto.fusion_kind())); + + // Find the fused computation and set its fusion instruction. + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Expect 1 called computation for fusion instruction but sees " + << proto.called_computation_ids_size(); + const int64 fusion_id = proto.called_computation_ids(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + TF_RET_CHECK(fused_computation != nullptr) + << "No fusion computation with id " << fusion_id; + instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), + fused_computation); + break; + } + case HloOpcode::kRng: + instruction = + CreateRng(proto.shape(), proto.distribution(), all_operands()); + break; + case HloOpcode::kParameter: + instruction = CreateParameter(proto.parameter_number(), proto.shape(), + proto.name()); + break; + case HloOpcode::kGetTupleElement: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "GetTupleElement instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateGetTupleElement(proto.shape(), operands(0), + proto.tuple_index()); + break; + case HloOpcode::kReducePrecision: + instruction = + CreateReducePrecision(proto.shape(), operands(0), + proto.exponent_bits(), proto.mantissa_bits()); + break; + case HloOpcode::kInfeed: { + const Shape& data_shape = + ShapeUtil::GetTupleElementShape(proto.shape(), 0); + if (proto.operand_ids_size() == 0) { + // TODO(b/80000000): Remove this when all uses of infeed are + // converted to take tokens. + instruction = CreateInfeed(data_shape, proto.infeed_config()); + } else { + CHECK_EQ(proto.operand_ids_size(), 2); + instruction = + CreateInfeed(data_shape, operands(0), proto.infeed_config()); + } + } break; + case HloOpcode::kOutfeed: + if (proto.operand_ids_size() == 1) { + // TODO(b/80000000): Remove this when all uses of outfeed are + // converted to take tokens. + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + proto.outfeed_config()); + } else { + CHECK_EQ(proto.operand_ids_size(), 2); + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + operands(1), proto.outfeed_config()); + } + break; + case HloOpcode::kCrossReplicaSum: { + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "CrossReplicaSum should have 1 called computation but sees " + << proto.called_computation_ids_size(); + tensorflow::gtl::optional all_reduce_id; + if (proto.all_reduce_id() > 0) { + all_reduce_id = proto.all_reduce_id(); + } + instruction = CreateCrossReplicaSum( + proto.shape(), all_operands(), computations(0), + /*replica_group_ids=*/ + std::vector(proto.replica_group_ids().begin(), + proto.replica_group_ids().end()), + /*barrier=*/proto.cross_replica_sum_barrier(), + /*all_reduce_id=*/all_reduce_id); + break; + } + case HloOpcode::kConvolution: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Convolution instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_window()); + TF_RET_CHECK(proto.has_convolution_dimension_numbers()); + instruction = + CreateConvolve(proto.shape(), operands(0), operands(1), + proto.window(), proto.convolution_dimension_numbers()); + break; + case HloOpcode::kReduceWindow: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "ReduceWindow instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "ReduceWindow should have 1 called computation but sees " + << proto.called_computation_ids_size(); + instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1), + proto.window(), computations(0)); + break; + case HloOpcode::kSelectAndScatter: + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "SelectAndScatter instruction should have 3 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 2) + << "SelectAndScatter should have 2 called computations but sees " + << proto.called_computation_ids_size(); + instruction = CreateSelectAndScatter( + proto.shape(), operands(0), computations(0), proto.window(), + operands(1), operands(2), computations(1)); + break; + case HloOpcode::kCustomCall: + instruction = CreateCustomCall(proto.shape(), all_operands(), + proto.custom_call_target()); + if (proto.has_window()) { + static_cast(instruction.get()) + ->set_window(proto.window()); + } + if (proto.has_convolution_dimension_numbers()) { + static_cast(instruction.get()) + ->set_convolution_dimension_numbers( + proto.convolution_dimension_numbers()); + } + break; + case HloOpcode::kHostCompute: + instruction = + CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(), + proto.cost_estimate_ns()); + break; + case HloOpcode::kPad: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Pad instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_padding_config()); + instruction = CreatePad(proto.shape(), operands(0), operands(1), + proto.padding_config()); + break; + case HloOpcode::kDynamicSlice: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "DynamicSlice instruction should have 2 operands but sees " + << proto.operand_ids_size(); + std::vector slice_sizes(proto.dynamic_slice_sizes_size()); + c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), + slice_sizes); + break; + } + default: { + instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + for (const int64 operand_id : proto.operand_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) + << "No instruction with id " << operand_id; + instruction->AppendOperand(instruction_map.at(operand_id)); + } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) + ->AddControlDependencyTo(instruction.get())); + } + if (instruction->opcode() != HloOpcode::kFusion) { + for (const int64 computation_id : proto.called_computation_ids()) { + TF_RET_CHECK(ContainsKey(computation_map, computation_id)) + << "No computation with id " << computation_id; + instruction->called_computations_.push_back( + computation_map.at(computation_id)); + } + } + break; } - } - - if (instruction->opcode() == HloOpcode::kTrace) { - TF_RET_CHECK(instruction->operands().size() == 1) - << "Trace instruction should have 1 operand but sees " - << instruction->operands().size(); - instruction->mutable_operand(0)->set_tracing(instruction.get()); } TF_RET_CHECK(!proto.name().empty()); - instruction->name_ = proto.name(); - + instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); - instruction->set_backend_config(proto.backend_config()); - if (proto.has_literal()) { - TF_ASSIGN_OR_RETURN(instruction->literal_, - Literal::CreateFromProto(proto.literal())); - } - instruction->parameter_number_ = proto.parameter_number(); + instruction->backend_config_ = proto.backend_config(); - instruction->tuple_index_ = proto.tuple_index(); - for (int64 dimension : proto.dimensions()) { - instruction->dimensions_.push_back(dimension); - } - if (proto.has_window()) { - instruction->window_ = MakeUnique(proto.window()); - } - if (proto.has_convolution_dimension_numbers()) { - instruction->convolution_dimension_numbers_ = - MakeUnique( - proto.convolution_dimension_numbers()); - } if (proto.has_dot_dimension_numbers()) { instruction->dot_dimension_numbers_ = MakeUnique(proto.dot_dimension_numbers()); } - for (const HloInstructionProto::SliceDimensions& slice_dimensions : - proto.slice_dimensions()) { - instruction->slice_starts_.push_back(slice_dimensions.start()); - instruction->slice_limits_.push_back(slice_dimensions.limit()); - instruction->slice_strides_.push_back(slice_dimensions.stride()); - } - instruction->exponent_bits_ = proto.exponent_bits(); - instruction->mantissa_bits_ = proto.mantissa_bits(); - for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { - instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); - } - if (proto.has_padding_config()) { - instruction->padding_config_ = - MakeUnique(proto.padding_config()); - } - instruction->outfeed_config_ = proto.outfeed_config(); - instruction->distribution_ = proto.distribution(); - instruction->epsilon_ = proto.epsilon(); - instruction->feature_index_ = proto.feature_index(); - instruction->channel_id_ = proto.channel_id(); - instruction->infeed_config_ = proto.infeed_config(); - instruction->custom_call_target_ = proto.custom_call_target(); - instruction->outfeed_shape_ = proto.outfeed_shape(); - instruction->fft_type_ = proto.fft_type(); - for (int64 fft_len : proto.fft_length()) { - instruction->fft_length_.push_back(fft_len); - } if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -173,61 +420,34 @@ StatusOr> HloInstruction::CreateFromProto( for (int64 bound : proto.gather_window_bounds()) { instruction->gather_window_bounds_.push_back(bound); } - - instruction->channel_name_ = proto.channel_name(); - instruction->cost_estimate_ns_ = proto.cost_estimate_ns(); - return std::move(instruction); } /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); - instruction->parameter_number_ = parameter_number; - instruction->name_ = name; - return instruction; + return MakeUnique(parameter_number, shape, name); } /* static */ std::unique_ptr HloInstruction::CreateTrace( const string& tag, HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); - instruction->operands_.push_back(operand); - instruction->literal_ = Literal::CreateR1U8(tag); - operand->set_tracing(instruction.get()); - return instruction; + return MakeUnique(tag, operand); } /* static */ std::unique_ptr HloInstruction::CreateConstant( std::unique_ptr literal) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape())); - instruction->literal_ = std::move(literal); - return instruction; + return MakeUnique(std::move(literal)); } /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - CHECK(ShapeUtil::IsTuple(operand->shape())); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape)); - instruction->tuple_index_ = index; - instruction->AppendOperand(operand); - return instruction; + return MakeUnique(shape, operand, index); } /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape)); - instruction->distribution_ = distribution; - instruction->shape_ = shape; - for (HloInstruction* param : parameters) { - instruction->AppendOperand(param); - } - return instruction; + return MakeUnique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( @@ -256,6 +476,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: + case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -268,7 +489,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: break; default: @@ -303,6 +523,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -320,8 +541,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, // Only certain opcodes are supported with CreateTernary: opcodes of ternary // instructions with no auxiliary fields. switch (opcode) { - case (HloOpcode::kClamp): - case (HloOpcode::kSelect): + case HloOpcode::kClamp: + case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: break; default: LOG(FATAL) << "Invalid ternary instruction opcode " @@ -339,45 +561,22 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice static_operands) { - CHECK(static_operands.empty()) << "static_operands not yet supported"; - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->called_computations_.push_back(map_computation); - return instruction; + HloComputation* map_computation) { + return MakeUnique(shape, operands, map_computation); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape)); - if (window_util::HasBaseDilation(window)) { - instruction->name_ = instruction->name() + "-base-dilated"; - } - if (window_util::HasWindowDilation(window)) { - instruction->name_ = instruction->name() + "-window-dilated"; - } - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->window_ = MakeUnique(window); - instruction->convolution_dimension_numbers_ = - MakeUnique(dimension_numbers); - return instruction; + return MakeUnique(shape, lhs, rhs, window, + dimension_numbers); } /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, tensorflow::gtl::ArraySlice fft_length) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape)); - instruction->AppendOperand(operand); - instruction->fft_type_ = fft_type; - instruction->fft_length_.assign(fft_length.begin(), fft_length.end()); - return instruction; + return MakeUnique(shape, operand, fft_type, fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( @@ -410,93 +609,86 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape)); - instruction->AppendOperand(operand); - instruction->exponent_bits_ = exponent_bits; - instruction->mantissa_bits_ = mantissa_bits; - return instruction; + return MakeUnique( + shape, operand, exponent_bits, mantissa_bits); } /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands) { - return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id) { + return MakeUnique( + shape, operands, reduce_computation, replica_group_ids, barrier, + all_reduce_id); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( - const Shape& shape, const string& config) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape)); - instruction->set_infeed_config(config); - return instruction; + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config) { + return MakeUnique(infeed_shape, token_operand, config); +} + +/* static */ std::unique_ptr HloInstruction::CreateInfeed( + const Shape& infeed_shape, const string& config) { + return MakeUnique(infeed_shape, config); } /* static */ std::unique_ptr HloInstruction::CreateOutfeed( - const Shape& shape, HloInstruction* operand, + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { + return MakeUnique(outfeed_shape, operand, + token_operand, outfeed_config); +} + +/* static */ std::unique_ptr HloInstruction::CreateOutfeed( + const Shape& outfeed_shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil())); - CHECK(ShapeUtil::Compatible(operand->shape(), shape)) - << "Outfeed shape " << shape << " must be compatible with operand shape " - << operand->shape(); - instruction->AppendOperand(operand); - instruction->outfeed_config_ = std::string(outfeed_config); - instruction->outfeed_shape_ = shape; - return instruction; + return MakeUnique(outfeed_shape, operand, + outfeed_config); } /* static */ std::unique_ptr HloInstruction::CreateSend( - HloInstruction* operand, int64 channel_id) { - // Send instruction produces a tuple of {aliased operand, U32 context}. - Shape output_shape = ShapeUtil::MakeTupleShape( - {operand->shape(), ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = channel_id; - return instruction; + HloInstruction* operand, HloInstruction* token, int64 channel_id) { + return MakeUnique(operand, token, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kSend) + auto send_operand = DynCast(operand); + CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - auto instruction = WrapUnique( - new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil())); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(send_operand); } /* static */ std::unique_ptr HloInstruction::CreateRecv( - const Shape& shape, int64 channel_id) { - // Recv instruction produces a tuple of {receive buffer, U32 context}. - Shape output_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape)); - instruction->channel_id_ = channel_id; - return instruction; + const Shape& shape, HloInstruction* token, int64 channel_id) { + return MakeUnique(shape, token, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kRecv) + auto recv_operand = DynCast(operand); + CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(recv_operand); } /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); + return MakeUnique(shape, operand, dimensions); +} + +/* static */ std::unique_ptr HloInstruction::CreateAfterAll( + tensorflow::gtl::ArraySlice operands) { + auto instruction = WrapUnique( + new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } return instruction; } @@ -533,30 +725,15 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape)); - instruction->AppendOperand(operand); - instruction->slice_starts_.assign(start_indices.begin(), start_indices.end()); - instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end()); - instruction->slice_strides_.assign(strides.begin(), strides.end()); - // For backward compatibility with old serialized computations: if there are - // no strides, assume all strides are 1. - // TODO(b/63317920): remove this code. - if (instruction->slice_strides_.empty()) { - instruction->slice_strides_ = std::vector(start_indices.size(), 1LL); - } - return instruction; + return MakeUnique(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kDynamicSlice, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(start_indices); - instruction->dynamic_slice_sizes_.assign(slice_sizes.begin(), - slice_sizes.end()); - return instruction; + return MakeUnique(shape, operand, start_indices, + slice_sizes); } /* static */ std::unique_ptr @@ -575,13 +752,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConcatenate( const Shape& shape, tensorflow::gtl::ArraySlice operands, int64 dimension) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->dimensions_.push_back(dimension); - return instruction; + return MakeUnique(shape, operands, dimension); } /* static */ std::unique_ptr HloInstruction::CreateConvert( @@ -604,25 +775,15 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, const Shape& shape, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape)); - instruction->AppendOperand(arg); - instruction->AppendOperand(init_value); - instruction->dimensions_.assign(dimensions_to_reduce.begin(), - dimensions_to_reduce.end()); - instruction->called_computations_.push_back(reduce_computation); - return instruction; + return MakeUnique( + shape, arg, init_value, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(init_value); - instruction->called_computations_.push_back(reduce_computation); - instruction->window_ = MakeUnique(window); - return instruction; + return MakeUnique(shape, operand, init_value, + window, reduce_computation); } /* static */ std::unique_ptr @@ -631,14 +792,8 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(offset); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique( + shape, operand, scale, offset, epsilon, feature_index); } /* static */ std::unique_ptr @@ -646,16 +801,8 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(offset); - instruction->AppendOperand(mean); - instruction->AppendOperand(variance); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique( + shape, operand, scale, offset, mean, variance, epsilon, feature_index); } /* static */ std::unique_ptr @@ -664,16 +811,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(mean); - instruction->AppendOperand(variance); - instruction->AppendOperand(grad_output); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique(shape, operand, scale, mean, + variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr @@ -681,27 +821,15 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, const Window& window, HloInstruction* source, HloInstruction* init_value, HloComputation* scatter) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSelectAndScatter, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(source); - instruction->AppendOperand(init_value); - // Select comes before scatter in the vector. - instruction->called_computations_.push_back(select); - instruction->called_computations_.push_back(scatter); - instruction->window_ = MakeUnique(window); - return instruction; + return MakeUnique( + shape, operand, select, window, source, init_value, scatter); } /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(broadcast_dimensions.begin(), - broadcast_dimensions.end()); - return instruction; + return MakeUnique(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr @@ -759,11 +887,8 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kPad, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(padding_value); - instruction->padding_config_ = MakeUnique(padding_config); - return instruction; + return MakeUnique(shape, operand, padding_value, + padding_config); } /* static */ std::unique_ptr HloInstruction::CreateReshape( @@ -780,45 +905,39 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - CHECK_EQ(shape.dimensions().size(), dimensions.size()); - CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); - CHECK(std::equal(operand->shape().dimensions().begin(), - operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())) - << "shape: " << ShapeUtil::HumanString(shape) - << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); + return MakeUnique(shape, operand, dimensions); +} + +/* static */ std::unique_ptr HloInstruction::CreateSort( + const Shape& shape, HloInstruction* keys, HloInstruction* values) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSort, shape)); + instruction->AppendOperand(keys); + if (values) { + instruction->AppendOperand(values); + } return instruction; } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - instruction->fusion_kind_ = fusion_kind; - instruction->name_ = "fusion"; - instruction->set_parent(fused_root->parent()); - instruction->set_metadata(fused_root->metadata()); - instruction->CloneAndFuseInternal(fused_root); - return instruction; + return MakeUnique(shape, fusion_kind, fused_root); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); + return MakeUnique(shape, fusion_kind, operands, + fusion_computation); +} + +void HloInstruction::set_single_sharding(const HloSharding& sharding) { + CHECK(!sharding.IsTuple()) << sharding; + if (ShapeUtil::IsTuple(shape())) { + set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape()))); + } else { + set_sharding(sharding); } - instruction->fusion_kind_ = fusion_kind; - instruction->name_ = "fusion"; - instruction->called_computations_.push_back(fusion_computation); - fusion_computation->SetFusionInstruction(instruction.get()); - return instruction; } void HloInstruction::SetupDerivedInstruction( @@ -831,358 +950,70 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->set_metadata(metadata_); } -HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { - CHECK_EQ(opcode(), HloOpcode::kFusion); - CHECK_EQ(operand_count(), - fused_instructions_computation()->parameter_instructions().size()); - const int64 param_no = operand_count(); - // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. - string param_name = StrCat(new_operand->name(), ".param_", param_no); - HloInstruction* fused_parameter = - fused_instructions_computation()->AddParameter( - HloInstruction::CreateParameter(param_no, new_operand->shape(), - param_name)); - AppendOperand(new_operand); - return fused_parameter; +bool HloInstruction::HasSideEffectNoRecurse() const { + switch (opcode_) { + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kRng: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kTrace: + case HloOpcode::kHostCompute: + return true; + default: + return false; + } } -void HloInstruction::MergeFusionInstruction( - HloInstruction* instruction_to_merge) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); - CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != - operands().end()); - // Clone the instruction from which to merge fused instructions. - std::unique_ptr clone = instruction_to_merge->Clone(); - // Replace uses of fused parameters with the corresponding operand of the - // fusion. Add all non-parameter fused instructions to 'unfused_instructions' - // to be merged into 'this'. This is done in reverse post order. - std::vector unfused_instructions; - auto fused_instructions = - clone->fused_instructions_computation()->MakeInstructionPostOrder(); - for (auto fused_it = fused_instructions.rbegin(); - fused_it != fused_instructions.rend(); ++fused_it) { - auto fused_instruction = *fused_it; - if (fused_instruction->opcode() == HloOpcode::kParameter) { - TF_CHECK_OK(fused_instruction->ReplaceAllUsesWith( - clone->mutable_operand(fused_instruction->parameter_number()))); - } else { - unfused_instructions.push_back(fused_instruction); - } +bool HloInstruction::HasSideEffect() const { + if (HasSideEffectNoRecurse()) { + return true; } - CHECK(unfused_instructions.front() == clone->fused_expression_root()); - // Replace instruction_to_merge use of 'this' with unfused_root. - TF_CHECK_OK( - instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); - // Fuse 'unfused_instructions' into 'this'. - for (auto& instruction : unfused_instructions) { - FuseInstruction(instruction); - instruction->DetachFromOperands(); + // Check if any of the called computations has a side effect. + for (const auto& computation : called_computations()) { + if (computation->HasSideEffect()) { + return true; + } } - CHECK_EQ(0, clone->user_count()); - clone->DetachFromOperands(); - TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( - clone->fused_instructions_computation())); + return false; } -void HloInstruction::MergeFusionInstructionIntoMultiOutput( - HloInstruction* instruction_to_merge) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); - // Add all non-parameter fused instructions to 'unfused_instructions' to be - // merged into 'this'. `old_to_new' maps the instructions in the fused node - // to the disaseembled fusion instructions. - // Note that we add the unfused instructions to this->parent_ computation. - // This is necessary because the unique_id needs for an instruction and - // it's only added when inserting to the computation. - tensorflow::gtl::FlatMap old_to_new; - std::vector unfused_instructions; - auto computation_to_merge = - instruction_to_merge->fused_instructions_computation(); - auto post_order = computation_to_merge->MakeInstructionPostOrder(); - for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) { - auto fused_instruction = *rit; - if (fused_instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie(&old_to_new, fused_instruction, - instruction_to_merge->mutable_operand( - fused_instruction->parameter_number())); - continue; - } - - // Here we clone the insertion and call FuseInstructionIntoMultiOutput() - // which clones again. This can be improved. - auto cloned_instruction = - parent_->AddInstruction(fused_instruction->Clone()); - unfused_instructions.push_back(cloned_instruction); - InsertOrDie(&old_to_new, fused_instruction, cloned_instruction); - } - for (auto unfused_instruction : unfused_instructions) { - for (int64 index = 0; index < unfused_instruction->operand_count(); - index++) { - auto new_operand = - FindOrDie(old_to_new, unfused_instruction->mutable_operand(index)); - TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand)); - } +/* static */ std::unique_ptr HloInstruction::CreateCall( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* computation) { + std::unique_ptr instruction = + WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); } + instruction->called_computations_.push_back(computation); + return instruction; +} - HloInstruction* unfused_root = unfused_instructions.front(); - TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); +/* static */ std::unique_ptr HloInstruction::CreateCustomCall( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) { + return MakeUnique(shape, operands, + custom_call_target); +} - TF_CHECK_OK( - instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge)); - if (GetModule()) { - TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge)); - } +/* static */ std::unique_ptr HloInstruction::CreateHostCompute( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { + return MakeUnique(shape, operands, channel_name, + cost_estimate_ns); +} - // Fuse the root instruction and generate multiple outputs. - FuseInstructionIntoMultiOutput(unfused_root); - TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); - // The rest instructions are of normal fusing. - for (int64 i = 1; i < unfused_instructions.size(); i++) { - auto instruction = unfused_instructions[i]; - FuseInstruction(instruction); - TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); +/* static */ std::unique_ptr HloInstruction::CreateTuple( + tensorflow::gtl::ArraySlice elements) { + std::vector element_shapes; + for (auto element : elements) { + element_shapes.push_back(element->shape()); } -} - -HloInstruction* HloInstruction::FuseInstructionInternal( - HloInstruction* instruction_to_fuse, bool add_output) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - - // When add_output is false, this fusion instruction must be a user of - // instruction_to_fuse. - if (!add_output) { - CHECK(IsUserOf(instruction_to_fuse)); - } - HloInstruction* fused_instruction = - CloneAndFuseInternal(instruction_to_fuse, add_output); - return fused_instruction; -} - -HloInstruction* HloInstruction::CloneAndFuseInternal( - HloInstruction* instruction_to_fuse, bool add_output) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); - VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); - HloInstruction* clone = nullptr; - if (called_computations_.empty()) { - // New fusion instruction. It should not be a multioutput instruction. - CHECK(!add_output); - auto builder = HloComputation::Builder("fused_computation", this); - builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); - called_computations_.push_back( - CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); - clone = fused_expression_root(); - } else { - clone = fused_instructions_computation()->AddInstruction( - instruction_to_fuse->Clone(/*suffix=*/"")); - // When add_output is false, instruction_to_fuse is necessarily an operand - // of the fusion instruction. After fusion this will no longer be the case. - // Remove the operand from the operand list and remove its corresponding - // fused parameter instruction. Renumber parameters as necessary to make - // parameter numbers consistent with their index in the - // fused_parameter_ vector. - bool in_operand_list = std::find(operands_.begin(), operands_.end(), - instruction_to_fuse) != operands_.end(); - CHECK(add_output || in_operand_list); - const std::vector& fused_parameters = - fused_instructions_computation()->parameter_instructions(); - for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { - if (instruction_to_fuse == operands_[operand_num]) { - // replace the fused parameter instruction's uses with the clone. - HloInstruction* fused_parameter = fused_parameters[operand_num]; - TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone)); - - // Remove the corresponding fused parameter and operand from their - // respective vectors. - TF_CHECK_OK( - fused_instructions_computation()->RemoveParameter(operand_num)); - operands_.erase(operands_.begin() + operand_num); - break; - } - } - // We've cloned instruction_to_fuse into this fusion instruction, so this - // fusion instruction is no longer a use of instruction_to_fuse. - if (in_operand_list) { - instruction_to_fuse->RemoveUser(this); - // When the instruction_to_fuse does not have other users, we don't need - // to generate a multioutput fusion instruction. - if (instruction_to_fuse->user_count() == 0) { - add_output = false; - } - } - } - - // Reread the parameters in the computation. - const std::vector& fused_parameters = - fused_instructions_computation()->parameter_instructions(); - - // Add each operand of the clone as an operand of the fusion instruction. A - // complication is that some clone operands may already be operands of the - // fusion instruction. - for (int64 operand_num = 0; operand_num < clone->operand_count(); - ++operand_num) { - HloInstruction* operand = clone->mutable_operand(operand_num); - - // See if this operand is already an operand of the fusion node. - CHECK_EQ(operands_.size(), fused_parameters.size()); - HloInstruction* fused_param = nullptr; - for (int64 i = 0; i < operands_.size(); ++i) { - if (operands_[i] == operand) { - fused_param = fused_parameters[i]; - break; - } - } - - if (fused_param == nullptr) { - // Clone's operand was not already an operand of the fusion - // instruction. Add it as an operand and add a corresponding fused - // parameter instruction. - fused_param = AddFusionOperand(operand); - } - TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); - } - - if (add_output) { - CHECK_GT(instruction_to_fuse->user_count(), 0); - // If this is already a multioutput fusion instruction, expand the root - // tuple by 1. - HloInstruction* fused_root = fused_expression_root(); - HloInstruction::InstructionVector tuple_elements; - bool newly_created_tuple_instr = false; - if (fused_root->opcode() == HloOpcode::kTuple) { - tuple_elements = fused_root->operands(); - } else { - tuple_elements.push_back(fused_root); - newly_created_tuple_instr = true; - } - if (clone->opcode() == HloOpcode::kTuple) { - for (auto inst : clone->operands()) { - tuple_elements.push_back(inst); - } - } else { - tuple_elements.push_back(clone); - } - HloInstruction* new_root = fused_instructions_computation()->AddInstruction( - HloInstruction::CreateTuple(tuple_elements)); - fused_instructions_computation()->set_root_instruction(new_root); - shape_ = new_root->shape(); - if (fused_root->opcode() == HloOpcode::kTuple) { - TF_CHECK_OK( - fused_instructions_computation()->RemoveInstruction(fused_root)); - } - - // If this is a newly created multioutput instruction, we need to update - // the use of the original fusion instruction. - if (newly_created_tuple_instr) { - HloInstruction* new_instr = parent_->AddInstruction( - HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); - TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); - } - int64 index = tuple_elements.size(); - if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { - index -= instruction_to_fuse->operand_count(); - std::vector to_be_removed; - for (auto old_gte : instruction_to_fuse->users()) { - CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); - int64 old_tuple_index = old_gte->tuple_index(); - HloInstruction* new_gte = - parent_->AddInstruction(HloInstruction::CreateGetTupleElement( - old_gte->shape(), this, index + old_tuple_index)); - TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte)); - to_be_removed.push_back(old_gte); - } - for (auto old_gte : to_be_removed) { - TF_CHECK_OK(parent_->RemoveInstruction(old_gte)); - } - TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); - } else { - HloInstruction* new_gte = - parent_->AddInstruction(HloInstruction::CreateGetTupleElement( - clone->shape(), this, index - 1)); - TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte)); - } - } - - VLOG(2) << "New clone:\n" << clone->ToString(); - return clone; -} - -RandomDistribution HloInstruction::random_distribution() const { - CHECK_EQ(opcode_, HloOpcode::kRng); - return distribution_; -} - -bool HloInstruction::HasSideEffect() const { - switch (opcode_) { - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kRng: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kTrace: - case HloOpcode::kHostCompute: - return true; - default: { - // Check if any of the called computations has a side effect. - for (const auto& computation : called_computations()) { - if (computation->HasSideEffect()) { - return true; - } - } - return false; - } - } -} - -/* static */ std::unique_ptr HloInstruction::CreateCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* computation) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->called_computations_.push_back(computation); - return instruction; -} - -/* static */ std::unique_ptr HloInstruction::CreateCustomCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kCustomCall, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->custom_call_target_ = std::string(custom_call_target); - return instruction; -} - -/* static */ std::unique_ptr HloInstruction::CreateHostCompute( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->channel_name_ = std::string(channel_name); - instruction->cost_estimate_ns_ = cost_estimate_ns; - return instruction; -} - -/* static */ std::unique_ptr HloInstruction::CreateTuple( - tensorflow::gtl::ArraySlice elements) { - std::vector element_shapes; - for (auto element : elements) { - element_shapes.push_back(element->shape()); - } - Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); - return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements); + Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); + return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements); } /* static */ std::unique_ptr HloInstruction::CreateGather( @@ -1219,25 +1050,68 @@ bool HloInstruction::HasSideEffect() const { return gather_dim_numbers; } +/* static */ std::unique_ptr HloInstruction::CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); + instruction->operand_side_metadata_ = std::move(operand_side_metadata); + instruction->user_side_metadata_ = std::move(user_side_metadata); + instruction->AppendOperand(operand); + return instruction; +} + std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module, CloneMap* clone_map) const { + HloCloneContext* context) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " %" << new_operand->name(); } - if (module == nullptr) { - module = GetModule(); - } std::unique_ptr clone; - // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. switch (opcode_) { + // Ops migrated to subclasses. + // TODO(b/80131774): Remove this switch when migration is complete. + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kConvolution: + case HloOpcode::kCustomCall: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kHostCompute: + case HloOpcode::kPad: + case HloOpcode::kDynamicSlice: + clone = CloneWithNewOperandsImpl(shape, new_operands, context); + break; // Unary ops. case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -1258,7 +1132,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1282,6 +1155,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kRemainder: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -1291,28 +1165,15 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: CHECK_EQ(new_operands.size(), 3); clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1], new_operands[2]); break; // Other supported ops. - case HloOpcode::kBroadcast: - CHECK_EQ(new_operands.size(), 1); - clone = CreateBroadcast(shape, new_operands[0], dimensions_); - break; case HloOpcode::kCall: clone = CreateCall(shape, new_operands, to_apply()); break; - case HloOpcode::kCustomCall: - clone = CreateCustomCall(shape, new_operands, custom_call_target_); - break; - case HloOpcode::kHostCompute: - clone = CreateHostCompute(shape, new_operands, channel_name_, - cost_estimate_ns_); - break; - case HloOpcode::kConcatenate: - clone = CreateConcatenate(shape, new_operands, dimensions(0)); - break; case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); @@ -1321,85 +1182,20 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kReducePrecision: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, - mantissa_bits_); - break; - case HloOpcode::kConvolution: - CHECK_EQ(new_operands.size(), 2); - clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, - *convolution_dimension_numbers_); - break; case HloOpcode::kDot: CHECK_EQ(new_operands.size(), 2); clone = CreateDot(shape, new_operands[0], new_operands[1], *dot_dimension_numbers_); break; - case HloOpcode::kFft: - CHECK_EQ(new_operands.size(), 1); - clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); - break; - case HloOpcode::kCrossReplicaSum: - clone = CreateCrossReplicaSum(shape, new_operands); - break; - case HloOpcode::kGetTupleElement: - CHECK_EQ(new_operands.size(), 1); - clone = CreateGetTupleElement(shape, new_operands[0], tuple_index()); - break; - case HloOpcode::kMap: - clone = CreateMap(shape, new_operands, to_apply()); - break; - case HloOpcode::kPad: - CHECK_EQ(new_operands.size(), 2); - clone = - CreatePad(shape, new_operands[0], new_operands[1], *padding_config_); - break; - case HloOpcode::kReduce: - CHECK_EQ(new_operands.size(), 2); - clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, - to_apply()); - break; - case HloOpcode::kReduceWindow: - CHECK_EQ(new_operands.size(), 2); - clone = CreateReduceWindow(shape, new_operands[0], new_operands[1], - *window_, to_apply()); - break; - case HloOpcode::kSelectAndScatter: - CHECK_EQ(new_operands.size(), 3); - clone = - CreateSelectAndScatter(shape, new_operands[0], select(), *window_, - new_operands[1], new_operands[2], scatter()); - break; - case HloOpcode::kReverse: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReverse(shape, new_operands[0], dimensions_); - break; - case HloOpcode::kRng: - clone = CreateRng(shape, distribution_, new_operands); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); break; - case HloOpcode::kSlice: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, - slice_strides_); - break; - case HloOpcode::kDynamicSlice: - clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1], - dynamic_slice_sizes_); - break; case HloOpcode::kDynamicUpdateSlice: CHECK_EQ(new_operands.size(), 3); clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], new_operands[2]); break; - case HloOpcode::kTranspose: - CHECK_EQ(new_operands.size(), 1); - clone = CreateTranspose(shape, new_operands[0], dimensions_); - break; case HloOpcode::kTuple: clone = CreateTuple(new_operands); *clone->mutable_shape() = shape; @@ -1409,94 +1205,77 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateWhile(shape, while_condition(), while_body(), new_operands[0]); break; - case HloOpcode::kConstant: - clone = CreateConstant(literal_->CloneToUnique()); - break; - case HloOpcode::kFusion: { - CHECK_NE(module, nullptr); - auto new_fused_computation = module->AddEmbeddedComputation( - fused_instructions_computation()->Clone("clone", module, clone_map)); - clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), - /*operands=*/new_operands, - /*fusion_computation=*/new_fused_computation); - break; - } - case HloOpcode::kParameter: - clone = CreateParameter(parameter_number_, shape, name_); - break; - case HloOpcode::kBatchNormTraining: - CHECK_EQ(new_operands.size(), 3); - clone = - CreateBatchNormTraining(shape, new_operands[0], new_operands[1], - new_operands[2], epsilon(), feature_index()); - break; - case HloOpcode::kBatchNormInference: - CHECK_EQ(new_operands.size(), 5); - clone = CreateBatchNormInference( - shape, new_operands[0], new_operands[1], new_operands[2], - new_operands[3], new_operands[4], epsilon(), feature_index()); - break; - case HloOpcode::kInfeed: - CHECK_EQ(new_operands.size(), 0); - clone = CreateInfeed(shape, infeed_config()); - break; - case HloOpcode::kOutfeed: - CHECK_EQ(new_operands.size(), 1); - clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); - break; - case HloOpcode::kBatchNormGrad: - CHECK_EQ(new_operands.size(), 5); - clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1], - new_operands[2], new_operands[3], - new_operands[4], epsilon(), feature_index()); - break; case HloOpcode::kConditional: CHECK_EQ(new_operands.size(), 3); clone = CreateConditional(shape, new_operands[0], new_operands[1], true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kSend: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSend(new_operands[0], channel_id()); + case HloOpcode::kGather: + CHECK_EQ(new_operands.size(), 2); + clone = CreateGather(shape, new_operands[0], new_operands[1], + *gather_dimension_numbers_, gather_window_bounds_); break; - case HloOpcode::kSendDone: + case HloOpcode::kDomain: CHECK_EQ(new_operands.size(), 1); - clone = CreateSendDone(new_operands[0]); - break; - case HloOpcode::kRecv: - CHECK_EQ(new_operands.size(), 0); - // The shape is a tuple, but CreateRecv() wants the raw data shape. clone = - CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); + CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); break; - case HloOpcode::kRecvDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateRecvDone(new_operands[0]); + case HloOpcode::kAfterAll: + clone = CreateAfterAll(new_operands); break; - case HloOpcode::kGather: - CHECK_EQ(new_operands.size(), 2); - clone = CreateGather(shape, new_operands[0], new_operands[1], - *gather_dimension_numbers_, gather_window_bounds_); + case HloOpcode::kSort: + CHECK(new_operands.size() == 1 || new_operands.size() == 2) + << "Too many operands for sort: " << new_operands.size(); + HloInstruction* keys = new_operands[0]; + HloInstruction* values = + new_operands.size() == 2 ? new_operands[1] : nullptr; + clone = CreateSort(shape, keys, values); break; - case HloOpcode::kTrace: - LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); - clone->set_backend_config(backend_config()); - if (clone_map != nullptr) { - InsertOrDie(clone_map, this, clone.get()); + clone->set_raw_backend_config_string(backend_config_); + if (context != nullptr) { + context->MapInstruction(this, clone.get()); + clone->ReplaceCalledComputations([&](HloComputation* callee) { + return callee->parent() != context->module() + ? context->module()->DeepCloneComputation(callee, context) + : callee; + }); } return clone; } -HloInstruction::~HloInstruction() {} +HloInstruction::~HloInstruction() { + // Detach from operands. An instruction may be repeated as an operand. To + // avoid calling RemoveUser twice on the same operand, check before remove. + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + HloInstruction* operand = operands_[operand_num]; + if (operand == nullptr) { + continue; + } + if (operand->user_set_.find(this) != operand->user_set_.end()) { + operand->RemoveUser(this); + } + operands_[operand_num] = nullptr; + } + + // Update users. Set `nullptr` to the correpsonding operand slot for users. + for (auto& user : this->users()) { + for (int i = 0; i < user->operand_count(); ++i) { + if (user->operands_[i] == this) { + user->operands_[i] = nullptr; + } + } + } +} std::unique_ptr HloInstruction::Clone( - const string& suffix, HloModule* module, CloneMap* clone_map) const { + const string& suffix, HloCloneContext* context) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module, clone_map); + CloneWithNewOperands(shape_, operands_, context); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1556,40 +1335,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const { return hlo; } -const Literal& HloInstruction::literal() const { - CHECK_EQ(HloOpcode::kConstant, opcode_); - return *literal_; -} - -bool HloInstruction::HasLiteral() const { return literal_ != nullptr; } - -bool HloInstruction::CanHaveDimensionsField() const { - return (opcode() == HloOpcode::kReverse || - opcode() == HloOpcode::kConcatenate || - opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast || - opcode() == HloOpcode::kTranspose); -} - -const std::vector& HloInstruction::dimensions() const { - CHECK(CanHaveDimensionsField()); - return dimensions_; -} - -int64 HloInstruction::dimensions(int64 index) const { - return dimensions()[index]; -} - -int64 HloInstruction::concatenate_dimension() const { - CHECK(opcode() == HloOpcode::kConcatenate); - CHECK_EQ(1, dimensions_.size()); - return dimensions(0); -} - -int64 HloInstruction::tuple_index() const { - CHECK_EQ(HloOpcode::kGetTupleElement, opcode_); - return tuple_index_; -} - const HloInstruction* HloInstruction::operand(int64 i) const { return operands_[i]; } @@ -1608,6 +1353,17 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand: " << target->ToString(); } +HloInstruction::InstructionVector HloInstruction::unique_operands() const { + InstructionVector unique; + tensorflow::gtl::FlatSet seen; + for (HloInstruction* operand : operands()) { + if (seen.insert(operand).second) { + unique.push_back(operand); + } + } + return unique; +} + Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); if (std::find(control_successors_.begin(), control_successors_.end(), @@ -1660,6 +1416,30 @@ void HloInstruction::AppendOperand(HloInstruction* operand) { operand->AddUser(this); } +void HloInstruction::RemoveOperandsAtAscendingIndices( + tensorflow::gtl::ArraySlice ascending_indices) { + if (ascending_indices.empty()) { + return; + } + int next_index = 0; + int removed_count = 0; + for (int to_remove : ascending_indices) { + while (next_index < to_remove) { + operands_[next_index - removed_count] = operands_[next_index]; + ++next_index; + } + CHECK_LT(to_remove, operands_.size()); + ++removed_count; + ++next_index; + } + while (next_index < operands_.size()) { + operands_[next_index - removed_count] = operands_[next_index]; + ++next_index; + } + CHECK_EQ(removed_count, ascending_indices.size()); + operands_.resize(operands_.size() - removed_count); +} + void HloInstruction::AddUser(HloInstruction* user) { if (!ContainsKey(user_set_, user)) { user_set_.insert(user); @@ -1667,10 +1447,6 @@ void HloInstruction::AddUser(HloInstruction* user) { } } -bool HloInstruction::IsConstant() const { - return opcode_ == HloOpcode::kConstant; -} - bool HloInstruction::HasConstantOperand() const { for (const HloInstruction* operand : operands_) { if (operand->IsConstant()) { @@ -1683,24 +1459,25 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const { + eq_computations) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and // operands. case HloOpcode::kAbs: case HloOpcode::kAtan2: - case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: case HloOpcode::kComplex: + case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kCos: - case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: @@ -1715,6 +1492,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -1724,61 +1502,27 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: + case HloOpcode::kReshape: + case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: + case HloOpcode::kSort: case HloOpcode::kSin: case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: return true; - case HloOpcode::kFusion: - return fusion_kind() == other.fusion_kind() && - eq_computations(fused_instructions_computation(), - other.fused_instructions_computation()); - // These opcodes have complex or special behavior so just return false. - case HloOpcode::kRng: - case HloOpcode::kTrace: + case HloOpcode::kDomain: case HloOpcode::kWhile: + case HloOpcode::kAfterAll: return false; - case HloOpcode::kParameter: - return parameter_number() == other.parameter_number() && - // Check the shape too because `this` and `other` may be in - // different HloComputations. - eq_shapes(shape(), other.shape()); - - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: - case HloOpcode::kBatchNormGrad: - return feature_index() == other.feature_index() && - epsilon() == other.epsilon(); - - // A constant is defined by the value in the literal. - case HloOpcode::kConstant: - return literal() == other.literal(); - - // A convert result is determined by the primitive type that the operand is - // converted into. - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - return shape().element_type() == other.shape().element_type(); - - // A reduce-precision operation is determined by the bit sizes. - case HloOpcode::kReducePrecision: - return exponent_bits() == other.exponent_bits() && - mantissa_bits() == other.mantissa_bits(); - - // Convolution has a window and dimensions. - case HloOpcode::kConvolution: - return protobuf_util::ProtobufEquals(window(), other.window()) && - protobuf_util::ProtobufEquals( - convolution_dimension_numbers(), - other.convolution_dimension_numbers()); // Check dot dimension numbers. case HloOpcode::kDot: return protobuf_util::ProtobufEquals(dot_dimension_numbers(), @@ -1789,89 +1533,52 @@ bool HloInstruction::IdenticalSlowPath( other.gather_dimension_numbers()) && gather_window_bounds() == other.gather_window_bounds(); - // FFT has various types & lengths. - case HloOpcode::kFft: - return fft_type() == other.fft_type() && - fft_length() == other.fft_length(); - - // Reduction results are determined by the reduction dimension and the - // reduction computation. - case HloOpcode::kReduce: - return dimensions() == other.dimensions() && - eq_computations(to_apply(), other.to_apply()); - case HloOpcode::kReduceWindow: - return eq_computations(to_apply(), other.to_apply()) && - protobuf_util::ProtobufEquals(window(), other.window()); - - // SelectAndScatter is determined by both select and scatter - // computation as well as the window configuration. - case HloOpcode::kSelectAndScatter: - return eq_computations(select(), other.select()) && - eq_computations(scatter(), other.scatter()) && - protobuf_util::ProtobufEquals(window(), other.window()); - - case HloOpcode::kReshape: - return eq_shapes(shape(), other.shape()); - - // Transpose result is determined by the final shape and the permutation. - case HloOpcode::kTranspose: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); - // Remaining instructions with special values. - case HloOpcode::kBitcast: - return eq_shapes(shape(), other.shape()); - case HloOpcode::kBroadcast: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); - case HloOpcode::kConcatenate: - return dimensions() == other.dimensions(); - case HloOpcode::kGetTupleElement: - return tuple_index() == other.tuple_index(); - case HloOpcode::kPad: - return protobuf_util::ProtobufEquals(padding_config(), - other.padding_config()); - case HloOpcode::kSlice: - return slice_starts_ == other.slice_starts_ && - slice_limits_ == other.slice_limits_ && - slice_strides_ == other.slice_strides_; - case HloOpcode::kDynamicSlice: - return eq_shapes(shape(), other.shape()) && - dynamic_slice_sizes_ == other.dynamic_slice_sizes_; - case HloOpcode::kDynamicUpdateSlice: - return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: - case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); - case HloOpcode::kCustomCall: - return custom_call_target_ == other.custom_call_target_; - case HloOpcode::kReverse: - return dimensions() == other.dimensions(); case HloOpcode::kConditional: return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); - // These opcodes are not yet supported. - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kSort: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: + // Ops migrated to subclasses should never come to this line. + // TODO(b/80131774): Remove this switch when migration is complete. + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kConvolution: + case HloOpcode::kCustomCall: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: case HloOpcode::kHostCompute: - return false; + case HloOpcode::kPad: + case HloOpcode::kDynamicSlice: + LOG(FATAL) << "Base class impl called for opcode with subclass: " + << opcode(); } } -bool HloInstruction::IsRank2Transpose() const { - return (opcode_ == HloOpcode::kTranspose) && - dimensions_ == std::vector({1, 0}) && - shape_.dimensions_size() == 2 && - std::equal(shape_.dimensions().begin(), shape_.dimensions().end(), - operands_[0]->shape_.dimensions().rbegin()); -} - void HloInstruction::RemoveUser(HloInstruction* user) { auto set_it = user_set_.find(user); CHECK(set_it != user_set_.end()); @@ -1901,6 +1608,10 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); + if (user->opcode() == HloOpcode::kFusion) { + TF_RETURN_IF_ERROR( + Cast(user)->DeduplicateFusionOperands()); + } return Status::OK(); } @@ -1909,10 +1620,14 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); + if (old_operand == new_operand) { + return Status::OK(); + } + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), new_operand->shape())) - << old_operand->shape().ShortDebugString() << " is not compatible with " - << new_operand->shape().ShortDebugString(); + << old_operand->shape() << " is not compatible with " + << new_operand->shape(); operands_[operand_num] = new_operand; VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with " @@ -1939,34 +1654,22 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); + if (user->opcode() == HloOpcode::kFusion) { + TF_RETURN_IF_ERROR( + Cast(user)->DeduplicateFusionOperands()); + } } } users_.clear(); user_set_.clear(); if (new_producer_is_user) { - AddUser(new_producer); - } - if (parent_ && parent_->root_instruction() == this) { - parent_->set_root_instruction(new_producer); - } - - return Status::OK(); -} - -void HloInstruction::DetachFromOperands() { - VLOG(3) << "DetachFromOperands:\n " << ToString(); - CHECK_EQ(0, user_count()); - // An instruction may be repeated as an operand. To avoid calling RemoveUser - // twice on the same operand, keep a set of already detached operands. - std::set detached_operands; - for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { - HloInstruction* operand = operands_[operand_num]; - if (!ContainsKey(detached_operands, operand)) { - operand->RemoveUser(this); - detached_operands.insert(operand); - } - operands_[operand_num] = nullptr; + AddUser(new_producer); + } + if (parent_ && parent_->root_instruction() == this) { + parent_->set_root_instruction(new_producer); } + + return Status::OK(); } HloComputation* HloInstruction::to_apply() const { @@ -1975,6 +1678,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -1992,6 +1696,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -2001,16 +1706,6 @@ void HloInstruction::set_to_apply(HloComputation* computation) { } } -const string& HloInstruction::custom_call_target() const { - CHECK_EQ(opcode_, HloOpcode::kCustomCall); - return custom_call_target_; -} - -const string& HloInstruction::outfeed_config() const { - CHECK_EQ(opcode_, HloOpcode::kOutfeed); - return outfeed_config_; -} - HloComputation* HloInstruction::while_condition() const { CHECK_EQ(HloOpcode::kWhile, opcode_); return called_computations_[kConditionComputationIndex]; @@ -2037,32 +1732,6 @@ void HloInstruction::set_while_body(HloComputation* computation) { called_computations_[kBodyComputationIndex] = computation; } -HloComputation* HloInstruction::select() const { - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return called_computations_[kSelectComputationIndex]; -} - -HloComputation* HloInstruction::scatter() const { - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return called_computations_[kScatterComputationIndex]; -} - -void HloInstruction::set_select(HloComputation* computation) { - // Don't allow changing the computation for fused instructions so we don't - // have to recompute called_instructions for the entire fusion instruction. - CHECK(!IsFused()); - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - called_computations_[kSelectComputationIndex] = computation; -} - -void HloInstruction::set_scatter(HloComputation* computation) { - // Don't allow changing the computation for fused instructions so we don't - // have to recompute called_instructions for the entire fusion instruction. - CHECK(!IsFused()); - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - called_computations_[kScatterComputationIndex] = computation; -} - HloComputation* HloInstruction::true_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); return called_computations_[kTrueComputationIndex]; @@ -2110,6 +1779,74 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { return ToStringWithCanonicalNameMap(options, &new_map); } +bool HloInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + switch (opcode_) { + // Unary elementwise operations. + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kCeil: + case HloOpcode::kClz: + case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kTanh: + CHECK_EQ(1, operand_count()); + return true; + + // Binary elementwise operations, the same as in IsElementwiseBinary(). + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: + case HloOpcode::kDivide: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + CHECK_EQ(2, operand_count()); + return true; + + // Ternary elementwise operations. + case HloOpcode::kSelect: + case HloOpcode::kClamp: + return true; + + case HloOpcode::kDynamicUpdateSlice: + return operand_idx.has_value() && operand_idx.value() == 0; + + default: + return false; + } +} + string HloInstruction::ToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -2145,8 +1882,8 @@ string HloInstruction::ToStringWithCanonicalNameMap( !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } - if (options.print_backend_config() && !backend_config().empty()) { - StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + if (options.print_backend_config() && !backend_config_.empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\""); } return result; } @@ -2160,107 +1897,45 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { string operands; - if (opcode() == HloOpcode::kConstant) { - // For constants, show the actual value in place of an empty operand list. - if ((!ShapeUtil::IsTuple(shape()) && - ShapeUtil::ElementsIn(shape()) <= 10) || - options.print_large_constants()) { - // Literal::ToString emits multidimensional arrays over multiple - // lines. Compact this into one line by stripping out white space. - string tmp = literal().ToString(); - std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector v = tensorflow::str_util::Split(tmp, ' '); - bool first = true; - // Concatenate elements in "v" with spaces separating them, but ignoring - // empty entries. - for (const auto& s : v) { - if (s.empty()) { - continue; - } - StrAppend(&operands, (first ? "" : " "), s); - first = false; - } - } else { - // Do not show large constants or tuples. - operands = "{...}"; + tensorflow::gtl::ArraySlice slice(operands_); + const int64 kMaxOperandsToShowIfCompact = 4; + if (options.compact_operands() && + slice.size() > kMaxOperandsToShowIfCompact) { + slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); + } + operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + // If operand is already been deleted, put `null` to the string output. + if (operand == nullptr) { + StrAppend(out, "null "); + return; } - } else if (opcode() == HloOpcode::kParameter) { - StrAppend(&operands, parameter_number_); - } else { - tensorflow::gtl::ArraySlice slice(operands_); - const int64 kMaxOperandsToShowIfCompact = 4; - if (options.compact_operands() && - slice.size() > kMaxOperandsToShowIfCompact) { - slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); + std::vector str; + if (options.print_operand_shape()) { + str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); } - operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { - std::vector str; - if (options.print_operand_shape()) { - str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); - } - // In a top-level HloInstruction::ToString() call, the operand name is not - // part of the canonical string. - if (options.canonicalize_instruction_names() && - options.is_in_nested_computation()) { - str.push_back(PrintName( - canonical_name_map->LookupOrInsert(operand->name()), options)); - } else if (!options.compact_operands()) { - str.push_back(PrintName(operand->name(), options)); - } - StrAppend(out, Join(str, " ")); - }); - const int64 remaining = operands_.size() - slice.size(); - if (slice.size() != operands_.size()) { - StrAppend(&operands, ", ...(+", remaining, ")"); + // In a top-level HloInstruction::ToString() call, the operand name is not + // part of the canonical string. + if (options.canonicalize_instruction_names() && + options.is_in_nested_computation()) { + str.push_back(PrintName( + canonical_name_map->LookupOrInsert(operand->name()), options)); + } else if (!options.compact_operands()) { + str.push_back(PrintName(operand->name(), options)); } + StrAppend(out, Join(str, " ")); + }); + const int64 remaining = operands_.size() - slice.size(); + if (slice.size() != operands_.size()) { + StrAppend(&operands, ", ...(+", remaining, ")"); } return operands; } std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { - std::vector extra; - if (opcode() == HloOpcode::kFusion) { - extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); - } - if (CanHaveDimensionsField()) { - extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); - } - if (window_ != nullptr && window_->dimensions_size() != 0) { - extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); - } - if (padding_config_ != nullptr) { - extra.push_back( - StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); - } - if (opcode() == HloOpcode::kSlice) { - std::vector bounds; - bounds.reserve(slice_starts_.size()); - const bool omit_stride = - std::all_of(slice_strides_.begin(), slice_strides_.end(), - [](int64 stride) { return stride == 1; }); - for (int i = 0; i < slice_starts_.size(); ++i) { - string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); - bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i], - stride_str, "]")); - } - extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); - } - if (opcode() == HloOpcode::kDynamicSlice) { - extra.push_back( - StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); - } - if (opcode() == HloOpcode::kBatchNormTraining || - opcode() == HloOpcode::kBatchNormInference || - opcode() == HloOpcode::kBatchNormGrad) { - extra.push_back(StrCat("epsilon=", epsilon())); - extra.push_back(StrCat("feature_index=", feature_index())); - } + std::vector extra = ExtraAttributesToStringImpl(options); - if (convolution_dimension_numbers_ != nullptr) { - extra.push_back(ConvolutionDimensionNumbersToString()); - } if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } @@ -2269,10 +1944,6 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); } - if (opcode() == HloOpcode::kFft) { - extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); - extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); - } if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { @@ -2292,7 +1963,8 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(false_computation()->name(), options))); } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || - opcode() == HloOpcode::kReduce) { + opcode() == HloOpcode::kReduce || + opcode() == HloOpcode::kCrossReplicaSum) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2327,6 +1999,7 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2342,14 +2015,7 @@ std::vector HloInstruction::ExtraAttributesToString( break; } } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || - opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { - extra.push_back(StrCat("channel_id=", channel_id_)); - } - if (opcode() == HloOpcode::kGetTupleElement) { - extra.push_back(StrCat("index=", tuple_index())); - } if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } @@ -2362,28 +2028,10 @@ std::vector HloInstruction::ExtraAttributesToString( }), "}")); } - if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) { - extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")); - } - if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) { - extra.push_back( - StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); - } - if (opcode() == HloOpcode::kRng) { - extra.push_back( - StrCat("distribution=", RandomDistributionToString(distribution_))); - } - if (opcode() == HloOpcode::kReducePrecision) { - extra.push_back(StrCat("exponent_bits=", exponent_bits_)); - extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); - } - - // By contract, we print the custom call target even if - // options.print_subcomputation_mode() == kOff, because the call target is not - // an HloComputation. - if (opcode() == HloOpcode::kCustomCall) { - extra.push_back( - StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")); } return extra; @@ -2415,32 +2063,13 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; - proto.set_backend_config(backend_config()); - if (literal_ != nullptr) { - *proto.mutable_literal() = literal_->ToProto(); - } - proto.set_parameter_number(parameter_number_); - if (opcode() == HloOpcode::kFusion) { - proto.set_fusion_kind(xla::ToString(fusion_kind())); - proto.add_called_computation_ids( - fused_instructions_computation()->unique_id()); - } else { + proto.set_backend_config(backend_config_); + if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); } } - proto.set_tuple_index(tuple_index_); - for (int64 dimension : dimensions_) { - proto.add_dimensions(dimension); - } - if (window_ != nullptr) { - *proto.mutable_window() = *window_; - } - if (convolution_dimension_numbers_ != nullptr) { - *proto.mutable_convolution_dimension_numbers() = - *convolution_dimension_numbers_; - } if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } @@ -2452,42 +2081,11 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_gather_window_bounds(bound); } } - for (int i = 0; i < slice_starts_.size(); ++i) { - auto* slice_dimension = proto.add_slice_dimensions(); - slice_dimension->set_start(slice_starts_[i]); - slice_dimension->set_limit(slice_limits_[i]); - slice_dimension->set_stride(slice_strides_[i]); - } - proto.set_exponent_bits(exponent_bits_); - proto.set_mantissa_bits(mantissa_bits_); - for (int64 slice_size : dynamic_slice_sizes_) { - proto.add_dynamic_slice_sizes(slice_size); - } - if (padding_config_ != nullptr) { - *proto.mutable_padding_config() = *padding_config_; - } - proto.set_outfeed_config(outfeed_config_); - if (opcode() == HloOpcode::kRng) { - proto.set_distribution(distribution_); - } - proto.set_epsilon(epsilon_); - proto.set_feature_index(feature_index_); - proto.set_channel_id(channel_id_); - proto.set_infeed_config(infeed_config_); - proto.set_custom_call_target(custom_call_target_); - *proto.mutable_outfeed_shape() = outfeed_shape_; - proto.set_fft_type(fft_type_); - for (int64 fft_len : fft_length_) { - proto.add_fft_length(fft_len); - } if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); } - proto.set_channel_name(channel_name_); - proto.set_cost_estimate_ns(cost_estimate_ns_); - return proto; } @@ -2497,35 +2095,6 @@ string HloInstruction::ToCategory() const { return "data formatting"; } - if (opcode() == HloOpcode::kConvolution) { - string category = "convolution"; - if (window_util::HasBaseDilation(window())) { - category += " base-dilated"; - } - if (window_util::HasWindowDilation(window())) { - category += " window-dilated"; - } - return category; - } - - // Give transpose-dot and backwards-conv fusions the categories "dot" and - // "convolution" so they match the categories of proper kDot and kConvolution - // ops. These fusion categories are really just a way of expressing a - // particular kind of dot or conv, so they should have the same category as a - // vanilla dot/conv. - if (opcode() == HloOpcode::kFusion) { - switch (fusion_kind()) { - case FusionKind::kLoop: - return "loop fusion"; - case FusionKind::kInput: - return "input fusion"; - case FusionKind::kOutput: - return "output fusion"; - case FusionKind::kCustom: - return "custom fusion"; - } - } - if (IsElementwise()) { return "non-fusion elementwise"; } @@ -2539,12 +2108,6 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { trace_instruction_ = trace_instruction; } -string HloInstruction::TracingTag() const { - CHECK_EQ(HloOpcode::kTrace, opcode()); - CHECK(literal_ != nullptr); - return literal_->GetR1U8AsString(); -} - bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } bool HloInstruction::IsFusable() const { @@ -2554,6 +2117,7 @@ bool HloInstruction::IsFusable() const { } // Some kinds of instructions don't make sense to fuse. switch (opcode_) { + case HloOpcode::kDomain: case HloOpcode::kParameter: return false; // Side effecting instrutions cannot be fused. @@ -2562,49 +2126,6 @@ bool HloInstruction::IsFusable() const { } } -HloComputation* HloInstruction::fused_instructions_computation() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(!called_computations_.empty()); - auto* fused_instructions_computation = called_computations_.front(); - CHECK(fused_instructions_computation->IsFusionComputation()); - return fused_instructions_computation; -} - -HloInstruction* HloInstruction::fused_expression_root() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->root_instruction(); -} - -HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->parameter_instruction( - parameter_number); -} - -const std::vector& HloInstruction::fused_parameters() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->parameter_instructions(); -} - -const tensorflow::gtl::iterator_range>::const_iterator>> -HloInstruction::fused_instructions() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - const HloComputation* subcomp = fused_instructions_computation(); - return subcomp->instructions(); -} - -const tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> -HloInstruction::fused_instructions() { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->instructions(); -} - -int64 HloInstruction::fused_instruction_count() const { - return fused_instructions_computation()->instruction_count(); -} - HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) : unique_id_(-1), opcode_(opcode), @@ -2659,6 +2180,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleAnd(this); case HloOpcode::kOr: return visitor->HandleOr(this); + case HloOpcode::kXor: + return visitor->HandleXor(this); case HloOpcode::kShiftLeft: return visitor->HandleShiftLeft(this); case HloOpcode::kShiftRightArithmetic: @@ -2683,6 +2206,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleRemainder(this); case HloOpcode::kSelect: return visitor->HandleSelect(this); + case HloOpcode::kTupleSelect: + return visitor->HandleTupleSelect(this); case HloOpcode::kConvolution: return visitor->HandleConvolution(this); case HloOpcode::kFft: @@ -2781,6 +2306,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); case HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kDomain: + return visitor->HandleDomain(this); + case HloOpcode::kAfterAll: + return visitor->HandleAfterAll(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -3018,117 +2547,31 @@ Status HloInstruction::AcceptOrdered( TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); } - return visitor->FinishVisit(this); -} - -const Shape& HloInstruction::outfeed_shape() const { - DCHECK_EQ(opcode_, HloOpcode::kOutfeed); - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); - return outfeed_shape_; -} - -const Shape& HloInstruction::shape() const { - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); - return shape_; -} - -std::vector HloInstruction::OperandIndices( - const HloInstruction* operand) const { - std::vector result; - for (int64 i = 0; i < operand_count(); ++i) { - if (this->operand(i) == operand) { - result.push_back(i); - } - } - return result; -} - -bool HloInstruction::IsElementwiseBinary() const { - return IsElementwise() && operand_count() == 2; -} - -bool HloInstruction::IsElementwise() const { - switch (opcode_) { - // Nullary elementwise operations. - case HloOpcode::kConstant: - return true; - - // Unary elementwise operations. - case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: - case HloOpcode::kCeil: - case HloOpcode::kClz: - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCopy: - case HloOpcode::kCos: - case HloOpcode::kExp: - case HloOpcode::kExpm1: - case HloOpcode::kFloor: - case HloOpcode::kImag: - case HloOpcode::kIsFinite: - case HloOpcode::kLog: - case HloOpcode::kLog1p: - case HloOpcode::kNot: - case HloOpcode::kNegate: - case HloOpcode::kReal: - case HloOpcode::kReducePrecision: - case HloOpcode::kSign: - case HloOpcode::kSin: - case HloOpcode::kTanh: - CHECK_EQ(1, operand_count()); - return true; - - // Binary elementwise operations, the same as in IsElementwiseBinary(). - case HloOpcode::kAdd: - case HloOpcode::kAtan2: - case HloOpcode::kComplex: - case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNe: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: - case HloOpcode::kShiftLeft: - case HloOpcode::kShiftRightArithmetic: - case HloOpcode::kShiftRightLogical: - CHECK_EQ(2, operand_count()); - return true; - - // Ternary elementwise operations. - case HloOpcode::kSelect: - return !ShapeUtil::IsTuple(shape_); - case HloOpcode::kClamp: - return true; - - // Other operations. - case HloOpcode::kRng: - case HloOpcode::kMap: - return true; - case HloOpcode::kFusion: - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - for (auto* fused : fused_instructions()) { - if (fused->opcode() != HloOpcode::kParameter && - !fused->IsElementwise()) { - return false; - } - } - return true; + return visitor->FinishVisit(this); +} - default: - return false; +const Shape& HloInstruction::shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); + return shape_; +} + +std::vector HloInstruction::OperandIndices( + const HloInstruction* operand) const { + std::vector result; + for (int64 i = 0; i < operand_count(); ++i) { + if (this->operand(i) == operand) { + result.push_back(i); + } } + return result; +} + +bool HloInstruction::IsElementwiseBinary() const { + return IsElementwise() && operand_count() == 2; +} + +bool HloInstruction::IsElementwise() const { + return IsElementwiseImpl(tensorflow::gtl::nullopt); } bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { @@ -3136,54 +2579,8 @@ bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); } -namespace { -bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, - const HloInstruction* operand) { - std::vector operand_indices = instruction->OperandIndices(operand); - return std::all_of( - operand_indices.begin(), operand_indices.end(), - [instruction](int64 operand_index) { - return instruction->IsElementwiseOnOperand(operand_index); - }); -} -} // namespace - bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { - // For all instructions other than kFusion, being elementwise on one of the - // operands is equivalent to being elementwise on all the operands. - if (opcode() != HloOpcode::kFusion) { - return IsElementwise(); - } - - CHECK_EQ(HloOpcode::kFusion, opcode()); - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - - // A loop-fusion is elementwise on an operand if all operations (computed - // using BFS) between the operand and the fused root are elementwise. - std::deque worklist; - std::unordered_set visited; - worklist.push_back(fused_parameter(operand_idx)); - visited.insert(fused_parameter(operand_idx)); - while (!worklist.empty()) { - HloInstruction* operand = worklist.front(); - worklist.pop_front(); - for (HloInstruction* user : operand->users()) { - CHECK_GE(user->unique_id(), 0); - if (ContainsKey(visited, user)) { - continue; - } - if (user->IsElementwise() || - IsInstructionElementwiseOnOperand(user, operand)) { - worklist.push_back(user); - visited.insert(user); - } else { - return false; - } - } - } - return true; + return IsElementwiseImpl(operand_idx); } // A helper class for memoized, recursive computation of HloOpcode::kFusion @@ -3205,8 +2602,10 @@ class HloInstruction::FusionReusesParamElements { static UseKind ComputeInternal( int64 i, const HloInstruction& hlo, tensorflow::gtl::FlatMap* cache) { - if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) { - return UseKind::kUse; + if (auto hlo_param = DynCast(&hlo)) { + if (hlo_param->parameter_number() == i) { + return UseKind::kUse; + } } auto p = cache->emplace(&hlo, UseKind{}); @@ -3368,42 +2767,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); } -StatusOr StringToRandomDistribution(const string& name) { - static std::unordered_map* map = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { - if (RandomDistribution_IsValid(i)) { - auto value = static_cast(i); - (*map)[RandomDistributionToString(value)] = value; - } - } - return map; - }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); - if (found == map->end()) { - return InvalidArgument("Unknown distribution"); - } - return found->second; -} - -std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { - return os << ToString(kind); -} - -string HloInstruction::ConvolutionDimensionNumbersToString() const { - string result; - if (convolution_dimension_numbers_ == nullptr) { - return result; - } - const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_; - // Show the given dimension labels in order of major to minor based on the - // shape's layout. - const auto append_dims = [&](const std::vector& dims, - const Shape& shape) { - CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - StrAppend(&result, Join(dims, "")); - }; - +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums) { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". std::vector lhs_dims(2 + dnums.input_spatial_dimensions().size()); @@ -3427,19 +2792,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - result += "dim_labels="; - append_dims(lhs_dims, operand(0)->shape()); - result += "_"; - append_dims(rhs_dims, operand(1)->shape()); - result += "->"; - - // A convolution can be represented as a kConvolution HLO or as a CustomCall - // that returns a tuple, the first element of which is the result of the - // convolution. - Shape this_shape = - ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape(); - append_dims(output_dims, this_shape); - return result; + return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", + Join(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -3465,6 +2819,28 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { + return os << ToString(kind); +} + string HloInstruction::GatherDimensionNumbersToString() const { CHECK_NE(gather_dimension_numbers_.get(), nullptr); string output_window_dims = @@ -3496,6 +2872,31 @@ bool HloInstruction::CouldBeBitcast() const { } } +Status HloInstruction::GetBackendConfigInternal( + tensorflow::protobuf::Message* proto) const { + proto->Clear(); + + // Empty string does not parse as valid JSON, but it's a valid backend config, + // corresponding to the empty proto. + if (backend_config_.empty()) { + return Status::OK(); + } + return tensorflow::HumanReadableJsonToProto(backend_config_, proto); +} + +Status HloInstruction::set_backend_config( + const tensorflow::protobuf::Message& proto) { + TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto)); + return Status::OK(); +} + +/* static */ StatusOr HloInstruction::BackendConfigToRawString( + const tensorflow::protobuf::Message& proto) { + string ret; + TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret)); + return ret; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3513,21 +2914,264 @@ void HloInstruction::set_outer_dimension_partitions( outer_dimension_partitions_ = outer_dimension_partitions; } +// TODO(b/80131774): Remove these temporary methods after transition. +int64 HloInstruction::feature_index() const { + return Cast(this)->feature_index(); +} + +float HloInstruction::epsilon() const { + return Cast(this)->epsilon(); +} + +FftType HloInstruction::fft_type() const { + return Cast(this)->fft_type(); +} + +const std::vector& HloInstruction::fft_length() const { + return Cast(this)->fft_length(); +} + +int64 HloInstruction::channel_id() const { + return Cast(this)->channel_id(); +} + +int64 HloInstruction::concatenate_dimension() const { + return Cast(this)->concatenate_dimension(); +} + +bool HloInstruction::IsRank2Transpose() const { + auto transpose = DynCast(this); + return transpose != nullptr && transpose->IsRank2Transpose(); +} + +int64 HloInstruction::slice_starts(int64 dimension) const { + return Cast(this)->slice_starts(dimension); +} + +const std::vector& HloInstruction::slice_starts() const { + return Cast(this)->slice_starts(); +} + +int64 HloInstruction::slice_limits(int64 dimension) const { + return Cast(this)->slice_limits(dimension); +} + +const std::vector& HloInstruction::slice_limits() const { + return Cast(this)->slice_limits(); +} + +int64 HloInstruction::slice_strides(int64 dimension) const { + return Cast(this)->slice_strides(dimension); +} + +const std::vector& HloInstruction::slice_strides() const { + return Cast(this)->slice_strides(); +} + +bool HloInstruction::IsInPlaceSlice() const { + return Cast(this)->IsInPlaceSlice(); +} + +const Literal& HloInstruction::literal() const { + return Cast(this)->literal(); +} + +bool HloInstruction::IsConstant() const { + return DynCast(this) != nullptr; +} + void HloInstruction::RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index) { - CHECK_EQ(opcode(), HloOpcode::kConstant); - Shape* mutable_array_subshape = - ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); - CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + Cast(this)->RelayoutConstant(new_layout, shape_index); +} + +string HloInstruction::TracingTag() const { + return Cast(this)->TracingTag(); +} + +HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { + return Cast(this)->AddFusionOperand(new_operand); +} + +// Delegates to HloFusionInstruction::MergeFusionInstruction. +void HloInstruction::MergeFusionInstruction( + HloInstruction* instruction_to_merge) { + return Cast(this)->MergeFusionInstruction( + Cast(instruction_to_merge)); +} + +// Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. +void HloInstruction::MergeFusionInstructionIntoMultiOutput( + HloInstruction* instruction_to_merge) { + return Cast(this) + ->MergeFusionInstructionIntoMultiOutput( + Cast(instruction_to_merge)); +} + +HloInstruction* HloInstruction::FuseInstruction( + HloInstruction* instruction_to_fuse) { + return Cast(this)->FuseInstruction(instruction_to_fuse); +} + +HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse) { + return Cast(this)->FuseInstructionIntoMultiOutput( + instruction_to_fuse); +} + +HloComputation* HloInstruction::fused_instructions_computation() const { + return Cast(this)->fused_instructions_computation(); +} + +HloInstruction* HloInstruction::fused_expression_root() const { + return Cast(this)->fused_expression_root(); +} + +const tensorflow::gtl::iterator_range>::const_iterator>> +HloInstruction::fused_instructions() const { + return Cast(this)->fused_instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloInstruction::fused_instructions() { + return Cast(this)->fused_instructions(); +} + +int64 HloInstruction::fused_instruction_count() const { + return Cast(this)->fused_instruction_count(); +} + +HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { + return Cast(this)->fused_parameter(parameter_number); +} + +const std::vector& HloInstruction::fused_parameters() const { + return Cast(this)->fused_parameters(); +} + +const bool HloInstruction::IsMultiOutputFusion() const { + const HloFusionInstruction* fusion = DynCast(this); + return fusion != nullptr && fusion->IsMultiOutputFusion(); +} + +HloInstruction::FusionKind HloInstruction::fusion_kind() const { + return Cast(this)->fusion_kind(); +} + +void HloInstruction::set_fusion_kind(FusionKind kind) { + return Cast(this)->set_fusion_kind(kind); +} + +RandomDistribution HloInstruction::random_distribution() const { + return Cast(this)->random_distribution(); +} + +int64 HloInstruction::parameter_number() const { + return Cast(this)->parameter_number(); +} + +int64 HloInstruction::tuple_index() const { + return Cast(this)->tuple_index(); +} + +int32 HloInstruction::exponent_bits() const { + return Cast(this)->exponent_bits(); +} + +int32 HloInstruction::mantissa_bits() const { + return Cast(this)->mantissa_bits(); +} + +string HloInstruction::infeed_config() const { + return Cast(this)->infeed_config(); +} + +void HloInstruction::set_infeed_config(const string& config) { + return Cast(this)->set_infeed_config(config); +} + +const Shape& HloInstruction::outfeed_shape() const { + return Cast(this)->outfeed_shape(); +} + +const string& HloInstruction::outfeed_config() const { + return Cast(this)->outfeed_config(); +} + +const std::vector& HloInstruction::replica_group_ids() const { + return Cast(this)->replica_group_ids(); +} - // Normally array_subshape will always have a layout, but this invariant is - // temporarily broken in LayoutAssignment::AssignLayouts. +string HloInstruction::cross_replica_sum_barrier() const { + return Cast(this)->cross_replica_sum_barrier(); +} + +void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { + return Cast(this)->set_cross_replica_sum_barrier( + barrier); +} + +tensorflow::gtl::optional HloInstruction::all_reduce_id() const { + return Cast(this)->all_reduce_id(); +} + +const ConvolutionDimensionNumbers& +HloInstruction::convolution_dimension_numbers() const { + if (auto convolution = DynCast(this)) { + return convolution->convolution_dimension_numbers(); + } + if (auto custom_call = DynCast(this)) { + return custom_call->convolution_dimension_numbers(); + } + LOG(FATAL) << "Unimplemented method."; +} - if (!mutable_array_subshape->has_layout() || - !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); - *mutable_array_subshape->mutable_layout() = new_layout; +void HloInstruction::set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + if (auto convolution = DynCast(this)) { + convolution->set_convolution_dimension_numbers(dnums); + } else if (auto custom_call = DynCast(this)) { + custom_call->set_convolution_dimension_numbers(dnums); + } else { + LOG(FATAL) << "Unimplemented method."; } } +HloComputation* HloInstruction::select() const { + return Cast(this)->select(); +} + +HloComputation* HloInstruction::scatter() const { + return Cast(this)->scatter(); +} + +void HloInstruction::set_select(HloComputation* computation) { + return Cast(this)->set_select(computation); +} + +void HloInstruction::set_scatter(HloComputation* computation) { + return Cast(this)->set_scatter(computation); +} + +const string& HloInstruction::custom_call_target() const { + return Cast(this)->custom_call_target(); +} + +const string& HloInstruction::channel_name() const { + return Cast(this)->channel_name(); +} + +const PaddingConfig& HloInstruction::padding_config() const { + return Cast(this)->padding_config(); +} + +int64 HloInstruction::slice_sizes(int64 dimension) const { + return Cast(this)->slice_sizes(dimension); +} + +const std::vector& HloInstruction::dynamic_slice_sizes() const { + return Cast(this)->dynamic_slice_sizes(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 2b05a8825d151fb703c01452eacbafe3c3177493..34e7dcb43d43483f010f226f00bdf211722f2562 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -37,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -50,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -319,7 +322,7 @@ class HloInstruction { kCustom, }; - ~HloInstruction(); + virtual ~HloInstruction(); // Creates an instruction from the given proto. Arguments: // @@ -386,11 +389,10 @@ class HloInstruction { // Creates a map instruction, where the computation (given by the handle) is // applied element-wise to every element in operands (across the operands, - // at a given index) with the same `static_operands`. + // at a given index) static std::unique_ptr CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice static_operands = {}); + HloComputation* map_computation); // Creates a convolution op, where rhs is the convolutional filter // and window describes how the filter is applied to lhs. @@ -423,10 +425,27 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits); - // Creates a cross replica sum op. + // Creates a cross replica reduction op. + // + // `reduction_computation`: the reduction function. + // + // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // `all_reduce_id`: for Allreduce nodes from different modules, if they have + // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will + // not be applied cross modules. + // + // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, - tensorflow::gtl::ArraySlice operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id = + tensorflow::gtl::nullopt); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -439,19 +458,36 @@ class HloInstruction { const Shape& shape, HloInstruction* operand); // Creates an infeed instruction, which reads data of the given shape from the - // Infeed interface of the device. - static std::unique_ptr CreateInfeed(const Shape& shape, + // Infeed interface of the device. infeed_shape is the shape of the data + // received from the infeed *not* the shape of the infeed instruction which + // is a tuple containing the infeed_shape and the TOKEN. + static std::unique_ptr CreateInfeed( + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config); + // Overload which does not require a token. + // TODO(b/80000000): Remove this overload when all uses of infeed are + // converted to take tokens. + static std::unique_ptr CreateInfeed(const Shape& infeed_shape, const string& config); - // Creates an outfeed instruction, which outputs data. + // Creates an outfeed instruction, which outputs data. outfeed_shape is the + // shape of the data being outfed *not* the shape of the outfeed instruction + // which is a TOKEN. static std::unique_ptr CreateOutfeed( - const Shape& shape, HloInstruction* operand, + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); + // Overload which does not require a token. + // TODO(b/80000000): Remove this overload when all uses of outfeed are + // converted to take tokens. + static std::unique_ptr CreateOutfeed( + const Shape& outfeed_shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config); // Creates an asynchronous send instruction with the given channel id, which // initiates sending the operand data to a unique receive instruction in // another computation that has the same channel id. static std::unique_ptr CreateSend(HloInstruction* operand, + HloInstruction* token, int64 channel_id); // Blocks until data transfer for the Send instruction (operand) is complete. @@ -463,6 +499,7 @@ class HloInstruction { // which allocates resources to receive data of the given shape from a unique // send instruction in another computation that has the same channel id. static std::unique_ptr CreateRecv(const Shape& shape, + HloInstruction* token, int64 channel_id); // Blocks until data transfer for the Recv instruction (operand) is complete @@ -576,6 +613,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates a sort op, with a keys operand, and an optional values operand. + static std::unique_ptr CreateSort( + const Shape& shape, HloInstruction* keys, + HloInstruction* values = nullptr); + // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1 @@ -597,6 +639,13 @@ class HloInstruction { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + // Creates a kDomain instruction which delimits an HLO domain which have + // the provided user and operand side metadata. + static std::unique_ptr CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -638,6 +687,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates a token instruction used for joining or creating new values of + // token type which thread through side-effecting operations. + static std::unique_ptr CreateAfterAll( + tensorflow::gtl::ArraySlice operands); + // Creates an instance of GatherDimensionNumbers. static GatherDimensionNumbers MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, @@ -648,6 +702,10 @@ class HloInstruction { // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } + // Returns true if this instruction has a side effect, irrespective of whether + // any called computations may contain an instruction with side effects. + bool HasSideEffectNoRecurse() const; + // Returns true if this instruction has a side effect. An instruction has a // side effect if it uses certain opcodes or calls a computation with a side // effect. @@ -672,6 +730,10 @@ class HloInstruction { using InstructionVector = tensorflow::gtl::InlinedVector; const InstructionVector& operands() const { return operands_; } + // Returns the vector of unique operands, in the same order they are found + // within the operand vector. + InstructionVector unique_operands() const; + // Returns the index of 'target' in the operands sequence. // Precondition: target must be an operand (or a fatal error will occur). int64 operand_index(const HloInstruction* target) const; @@ -742,10 +804,8 @@ class HloInstruction { if (opcode() != other.opcode()) { return false; } - using EqShapeFuncType = bool (*)(const Shape&, const Shape&); - EqShapeFuncType eq_shapes = - layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; - if (!eq_shapes(shape(), other.shape())) { + if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) + : ShapeUtil::Compatible(shape(), other.shape()))) { return false; } if (operands().size() != other.operands().size()) { @@ -760,21 +820,28 @@ class HloInstruction { } } - return IdenticalSlowPath(other, eq_computations, eq_shapes); + if (backend_config_ != other.backend_config_) { + return false; + } + + return IdenticalSlowPath(other, eq_computations); } // Returns whether the instruction has a constant operand. bool HasConstantOperand() const; - // Returns whether this instruction does a rank-2 transposition. - bool IsRank2Transpose() const; - // Replaces the use of this instruction in "user" with "new_producer". Note // that there might be multiple uses of this instruction in "user"; all will // be replaced. + // + // If user is a fusion instruction, this function will remove any duplicated + // operands of it which could be created due to this replacement. Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); // Replaces the specified operand with new_operand. + // + // This function does NOT remove duplicated operands even if this instruction + // is a fusion, so that the existing operand numbers do not change. Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand); // Replaces all uses of this instruction with the new producer. If @@ -783,14 +850,10 @@ class HloInstruction { // // If this instruction is the root of its computation, sets the computation's // root to new_producer. - Status ReplaceAllUsesWith(HloInstruction* new_producer); - - // Detaches an instruction from its operands. That is, remove the instruction - // from each operand's user set. This should only be called prior to - // deallocating the instruction. // - // TODO(b/78305363): Make this automatic when deleting an instruction. - void DetachFromOperands(); + // If a user is a fusion instruction, this function will remove any duplicated + // operands of it which could be created due to this replacement. + Status ReplaceAllUsesWith(HloInstruction* new_producer); // Performs a postorder DFS visit using this node as the root. If // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when @@ -837,38 +900,6 @@ class HloInstruction { template Status Visit(DfsHloVisitorBase* visitor); - // Returns the literal associated with this instruction. - // - // Note: only constant and parameter opcodes have an associated literal. - const Literal& literal() const; - - // Returns whether there is literal associated with this instruction. - bool HasLiteral() const; - - // Returns the parameter number associated with this instruction. - // - // Note: only parameter opcodes have an associated parameter number. - int64 parameter_number() const { - CHECK_EQ(HloOpcode::kParameter, opcode_); - return parameter_number_; - } - - // Returns the dimension sizes or numbers associated with this instruction. - // - // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, - // and reverse. - const std::vector& dimensions() const; - int64 dimensions(int64 index) const; - - // Accessor for the dimension in which a concatenate HLO should occur. - // Precondition: opcode() == HloOpcode::kConcatenate - int64 concatenate_dimension() const; - - // Returns the tuple index associated with this instruction. - // - // Precondition: opcode() == HloOpcode::kGetTupleElement - int64 tuple_index() const; - // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. @@ -896,18 +927,6 @@ class HloInstruction { HloComputation* to_apply() const; void set_to_apply(HloComputation* to_apply); - // Returns the custom_call_target for CustomCall. - // Precondition: opcode() == HloOpcode::kCustomCall - const string& custom_call_target() const; - - // Returns the config for the Outfeed instruction. - // Precondition: opcode() == HloOpcode::kOutfeed - const string& outfeed_config() const; - - // Returns the shape for the Outfeed instruction. - // Precondition: opcode() == HloOpcode::kOutfeed - const Shape& outfeed_shape() const; - // Gets/sets the while_condition or while_body HloComputation for While. The // setters should only be called by HloModule or HloComputation methods. // @@ -917,15 +936,6 @@ class HloInstruction { void set_while_condition(HloComputation* while_condition); void set_while_body(HloComputation* while_body); - // Gets/sets the select or scatter HloComputation for SelectAndScatter. The - // setters should only be called by HloModule or HloComputation methods. - // - // Precondition: opcode() == HloOpcode::kSelectAndScatter. - HloComputation* select() const; - HloComputation* scatter() const; - void set_select(HloComputation* select); - void set_scatter(HloComputation* scatter); - // Gets/sets the true and false HloComputation for Conditional. The setters // should only be called by HloModule or HloComputation methods. // @@ -963,11 +973,11 @@ class HloInstruction { string ToShortString() const; // Returns a serialized representation of this instruction. - HloInstructionProto ToProto() const; + virtual HloInstructionProto ToProto() const; // Returns a category for the HLO. This could be something like "convolution" // or "elementwise". - string ToCategory() const; + virtual string ToCategory() const; // Returns a logging instruction, if the output of this instruction is logged. // @@ -975,111 +985,14 @@ class HloInstruction { HloInstruction* tracing() const; void set_tracing(HloInstruction* trace_instruction); - // Returns the channel id associated with the instruction. The id is - // shared between each Send/Recv pair and is globally unique to identify each - // channel. - // - // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv - int64 channel_id() const { return channel_id_; } - - // Returns the channel name associated with the instruction. The name is - // used to identify host Send/Recv operations. - // - // Precondition: opcode() == HloOpcode::kHostCompute - string channel_name() const { return channel_name_; } - - // Returns feature_index field associated with the instruction. The index - // represents the index of the feature dimension. - // - // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, - // or kBatchNormGrad. - int64 feature_index() const { return feature_index_; } - - // Returns a epsilon value associated with the instruction. The is a small - // number added to the variance to avoid divide-by-zero error. - // - // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, - // or kBatchNormGrad. - float epsilon() const { return epsilon_; } - - // Returns the infeed configuration string. The infeed configuration includes - // any metadata needed for the backend compiler (e.g., infeed buffer address) - // and is target-dependent. - string infeed_config() const { return infeed_config_; } - void set_infeed_config(const string& config) { infeed_config_ = config; } - - // Returns a tag to be used in tracing. - // - // Precondition: opcode() == HloOpcode::kTrace - string TracingTag() const; - - // Returns whether the instruction is a constant. - bool IsConstant() const; - // Returns true if this instruction is fused, ie contained within a fusion // instruction. bool IsFused() const; - // Returns the computation for this fused instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloComputation* fused_instructions_computation() const; - // Returns true if this instruction can be legally fused into a fusion // instruction. bool IsFusable() const; - // Returns the root instruction of the fused expression contained within this - // fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* fused_expression_root() const; - - // Returns the list of fused instructions inside this fusion instruction. The - // returned type is a range of HloInstruction*s. - // - // Precondition: opcode() == HloOpcode::kFusion - const tensorflow::gtl::iterator_range>::const_iterator>> - fused_instructions() const; - - const tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> - fused_instructions(); - - // Gets the number of instructions inside this fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - int64 fused_instruction_count() const; - - // Returns the fused parameter instruction in this fusion instruction - // corresponding to the given parameter number. - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* fused_parameter(int64 parameter_number) const; - - // Returns the vector of fused parameters inside this fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - const std::vector& fused_parameters() const; - - // Returns true if this instruction is a fusion instruction that generates - // multiple outputs. - const bool IsMultiOutputFusion() const { - return opcode() == HloOpcode::kFusion && - fused_expression_root()->opcode() == HloOpcode::kTuple; - } - - FusionKind fusion_kind() const { - CHECK_EQ(HloOpcode::kFusion, opcode_); - return fusion_kind_; - } - - void set_fusion_kind(FusionKind kind) { - CHECK_EQ(HloOpcode::kFusion, opcode_); - fusion_kind_ = kind; - } - // Returns the sharding applied to this operator. // REQUIRES: has_sharding() is true. const HloSharding& sharding() const { @@ -1092,20 +1005,44 @@ class HloInstruction { } // Returns the sharding unique device, if any. tensorflow::gtl::optional sharding_unique_device() const { - if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) { + if (sharding_ == nullptr) { return tensorflow::gtl::optional(); } - return sharding_->UniqueDevice().ValueOrDie(); + auto device = sharding_->UniqueDevice(); + return device.ok() ? device.ValueOrDie() + : tensorflow::gtl::optional(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { sharding_ = MakeUnique(sharding); } + void set_single_sharding(const HloSharding& sharding); + // Sets a sharding that assigns the current instruction to device. + void set_device_sharding(int64 device) { + set_single_sharding(HloSharding::AssignDevice(device)); + } // Remove any sharding from this operator. void clear_sharding() { sharding_ = nullptr; } // Return true if this operator has a sharding assigned. bool has_sharding() const { return sharding_ != nullptr; } + // Checks whether the instruction has compatible sharding with the other + // instruction. + bool has_compatible_sharding(const HloInstruction* other) const { + if (!has_sharding()) { + return !other->has_sharding(); + } + return other->has_sharding() ? sharding() == other->sharding() : false; + } + + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. + const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain @@ -1114,172 +1051,19 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // Adds a new operand the fusion instruction. - HloInstruction* AddFusionOperand(HloInstruction* new_operand); - - // Merges the fused instructions from 'instruction_to_merge' into the - // fused instruction set of 'this', updating operands as necessary. - // - // Precondition: opcode() == HloOpcode::kFusion - // Predondition: 'instruction_to_merge' must be an operand of 'this'. - void MergeFusionInstruction(HloInstruction* instruction_to_merge); - - // Merges the fused instructions from instruction_to_merge into the fused - // instruction set of 'this' and generates multioutput fusion instructions. - // All the users of instruction_to_merge will be redirected to 'this' - // instruction. instruction_to_merge will be removed from its parent - // computation. - // - // Precondition: opcode() == HloOpcode::kFusion - void MergeFusionInstructionIntoMultiOutput( - HloInstruction* instruction_to_merge); - - // Fuses the given instruction in this fusion instruction. instruction_to_fuse - // is cloned and the clone is placed in the fusion - // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather - // than moved to cleanly handle the case where the instruction has a use - // outside the fusion instruction. Moving such an instruction into a fusion - // instruction would violate the single-result invariant of HLO instructions - // and significantly complicate code generation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { - return FuseInstructionInternal(instruction_to_fuse); - } - - // Fuses the given instruction in this fusion instruction and generate - // multioutput fusion instruction. A clone of the instruction_to_fuse will - // be part of the output of fusion instructions. The users of - // instruction_to_fuse will be redirected to this fusion instructions. - // instruction_to_fuse will be removed from its parent computation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstructionIntoMultiOutput( - HloInstruction* instruction_to_fuse) { - return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); - } - - // Returns the start index in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_starts(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_starts_[dimension]; + // TODO(b/80249101): Remove these methods once HLO scheduling and copy + // insertion are integrated, and we don't need to run a separate pass + // of copy elision anymore. + bool CopyElisionAllowed() const { + CHECK_EQ(HloOpcode::kCopy, opcode_); + return copy_elision_allowed_; } - const std::vector& slice_starts() const { return slice_starts_; } - // Returns the (exclusive) limit index in the given dimension for a slice - // node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_limits(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_[dimension]; - } - const std::vector& slice_limits() const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_; + void SetCopyElisionAllowed(bool value) { + CHECK_EQ(HloOpcode::kCopy, opcode_); + copy_elision_allowed_ = value; } - // Returns the stride in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_strides(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_strides_[dimension]; - } - const std::vector& slice_strides() const { return slice_strides_; } - - // Returns the flag that describes whether a slice must be lowered into an - // offset into the original operand. - bool IsInPlaceSlice() const { return is_in_place_slice_; } - - // Sets and returns the flag that describes whether a slice must be lowered - // into an offset into the original operand. - bool SetIsInPlaceSlice(bool value) { - is_in_place_slice_ = value; - return value; - } - - // Returns the size of the slice in the given dimension for a dynamic - // slice node. - // - // Precondition: opcode() == HloOpcode::kDynamicSlice - int64 slice_sizes(int64 dimension) const { - CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); - return dynamic_slice_sizes_[dimension]; - } - const std::vector& dynamic_slice_sizes() const { - CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); - return dynamic_slice_sizes_; - } - - // Returns the number of exponent bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 exponent_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return exponent_bits_; - } - - // Returns the number of mantissa bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 mantissa_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return mantissa_bits_; - } - - // Returns data on the window in a windowed operation such as - // convolution. - const Window& window() const { - CHECK(window_ != nullptr); - return *window_; - } - - // Sets the window data in a windowed operation such as convolution. - void set_window(const Window& window) { - window_ = MakeUnique(window); - } - - // Returns the padding configuration for a pad node. - // - // Precondition: opcode() == HloOpcode::kPad - const PaddingConfig& padding_config() const { - CHECK(padding_config_ != nullptr); - return *padding_config_; - } - - // Returns data on the dimension numbers used for a convolution operation, - // which may be a kConvolution instruction or a kCustomCall that implements a - // convolution. - const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { - CHECK(convolution_dimension_numbers_ != nullptr); - return *convolution_dimension_numbers_; - } - - // Sets the convolution dimension numbers on this instruction. In general you - // shouldn't need to call this; instead, specify the convolution dimension - // numbers when you create the instruction. - void set_convolution_dimension_numbers( - const ConvolutionDimensionNumbers& dnums) { - convolution_dimension_numbers_ = - MakeUnique(dnums); - } - - FftType fft_type() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_type_; - } - - const std::vector& fft_length() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_length_; - } - - // Returns the dump string of the convolution dimension numbers. - string ConvolutionDimensionNumbersToString() const; - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1302,35 +1086,19 @@ class HloInstruction { // Returns the dump string of the gather dimension numbers. string GatherDimensionNumbersToString() const; - // Returns the random distribution for this rng node. - // - // Precondition: opcode() == HloOpcode::kRng - RandomDistribution random_distribution() const; - - // See documentation for Clone(). - using CloneMap = std::unordered_map; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. Ignores the - // control predecessors and successors of this HLO instruction. - // - // If the module pointer is not nullptr, then any cloned computations will be - // added to this module in order to support deep cloning. Otherwise the module - // of the instruction is used. - // - // If clone_map is not nullptr, then each original instruction that is cloned - // will be inserted and map to its clone. clone_map should not already contain - // any of the instructions to clone. - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr, - CloneMap* clone_map = nullptr) const; + // the instruction to form the name of the cloned instruction. + // Ignores the control predecessors and successors of this HLO instruction. + std::unique_ptr Clone( + const string& suffix = "clone", HloCloneContext* context = nullptr) const; // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr, CloneMap* clone_map = nullptr) const; + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -1400,9 +1168,14 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Gets/sets the string identifier for this instruction. + // Gets the string identifier for this instruction. const string& name() const { return name_; } - void set_name(tensorflow::StringPiece name) { name_ = std::string(name); } + + // Sets the string identifier for this instruction. Name will be sanitized to + // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + void SetAndSanitizeName(const string& name) { + name_ = NameUniquer::GetSanitizedName(name); + } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1425,13 +1198,34 @@ class HloInstruction { // this field and they cannot interpret it due to its meaning being backend // specific. // - // TODO(b/78194644): Introduce structured configuration format as per - // go/xla-heuristics. - const string& backend_config() const { return backend_config_; } - void set_backend_config(string backend_config) { - backend_config_ = std::move(backend_config); + // ConfigProto should be a protobuf Message type. + template + StatusOr backend_config() const { + ConfigProto proto; + TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto)); + return std::move(proto); + } + Status set_backend_config(const tensorflow::protobuf::Message& proto); + + // Getter/setter for raw JSON-encoded backend config. Prefer the + // functions above that deal in proto Messages where possible. + const string& raw_backend_config_string() const { return backend_config_; } + void set_raw_backend_config_string(string config_str) { + backend_config_ = std::move(config_str); } + // Returns a string representation of a proto in the format used by + // raw_backend_config_string. + // + // This is morally equivalent to: + // + // HloInstruction instr; + // TF_RETURN_IF_ERROR(instr.set_backend_config(proto)); + // return instr.raw_backend_config_string(); + // + static StatusOr BackendConfigToRawString( + const tensorflow::protobuf::Message& proto); + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1462,13 +1256,273 @@ class HloInstruction { void set_outer_dimension_partitions( const std::vector& outer_dimension_partitions); - // Change the layout for an Constant Hlo instruction to match new_layout. For - // tuple shaped constants shape_index is the path to the internal array - // subshape whose layout needs to be changed. + // Old methods kept for smooth subclassing transition BEGIN. + // TODO(b/80131774): Remove this code. + + // Delegates to HloBatchNormInstruction::feature_index. + int64 feature_index() const; + + // Delegates to HloBatchNormInstruction::epsilon. + float epsilon() const; + + // Delegates to HloFftInstruction::fft_type. + FftType fft_type() const; + + // Delegates to HloFftInstruction::fft_length. + const std::vector& fft_length() const; + + // Delegates to HloSendRecvInstruction::channel_id. + int64 channel_id() const; + + // Returns the dimension sizes or numbers associated with this instruction. + virtual const std::vector& dimensions() const { + LOG(FATAL) << "Unimplemented method."; + } + virtual int64 dimensions(int64 index) const { + LOG(FATAL) << "Unimplemented method."; + } + + // Delegates to HloConcatenateInstruction::concatenate_dimension. + int64 concatenate_dimension() const; + + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + + // Delegates to HloSliceInstruction::slice_start. + int64 slice_starts(int64 dimension) const; + const std::vector& slice_starts() const; + + // Delegates to HloSliceInstruction::slice_limits. + int64 slice_limits(int64 dimension) const; + const std::vector& slice_limits() const; + + // Delegates to HloSliceInstruction::slice_strides. + int64 slice_strides(int64 dimension) const; + const std::vector& slice_strides() const; + + // Delegates to HloSliceInstruction::IsInPlaceSlice. + bool IsInPlaceSlice() const; + + // Returns the literal associated with this instruction. + const Literal& literal() const; + + // Returns whether the instruction is a constant. + bool IsConstant() const; + + // Delegate to HloConstantInstruction::RelayoutConstant. void RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index = {}); + // Delegates to HloTraceInstruction::TracingTag. + string TracingTag() const; + + // Delegates to HloFusionInstruction::AddFusionOperand. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + + // Delegates to HloFusionInstruction::MergeFusionInstruction. + void MergeFusionInstruction(HloInstruction* instruction_to_merge); + + // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. + void MergeFusionInstructionIntoMultiOutput( + HloInstruction* instruction_to_merge); + + // Delegates to HloFusionInstruction::FuseInstruction. + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse); + + // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput. + HloInstruction* FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse); + + // Delegates to HloFusionInstruction::fused_instruction. + HloComputation* fused_instructions_computation() const; + + // Delegates to HloFusionInstruction::fused_expression_root. + HloInstruction* fused_expression_root() const; + + // Delegates to HloFusionInstruction::fused_instructions. + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Delegates to HloFusionInstruction::fused_instruction_count. + int64 fused_instruction_count() const; + + // Delegates to HloFusionInstruction::fused_parameter. + HloInstruction* fused_parameter(int64 parameter_number) const; + + // Delegates to HloFusionInstruction::fused_parameters. + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs. + const bool IsMultiOutputFusion() const; + + // Delegates to HloFusionInstruction::fusion_kind. + FusionKind fusion_kind() const; + + // Delegates to HloFusionInstruction::set_fusion_kind. + void set_fusion_kind(FusionKind kind); + + // Delegates to HloRngInstruction::random_distribution. + RandomDistribution random_distribution() const; + + // Delegates to HloParameterInstruction::parameter_number. + int64 parameter_number() const; + + // Delegates to HloGetTupleElementInstruction::tuple_index. + int64 tuple_index() const; + + // Delegates to HloReducePrecisionInstruction::exponent_bits. + int32 exponent_bits() const; + + // Delegates to HloReducePrecisionInstruction::mantissa_bits. + int32 mantissa_bits() const; + + // Delegates to HloInfeedInstruction::infeed_config. + string infeed_config() const; + + // Delegates to HloInfeedInstruction::set_infeed_config. + void set_infeed_config(const string& config); + + // Returns the config for the Outfeed instruction. + const string& outfeed_config() const; + + // Returns the shape for the Outfeed instruction. + const Shape& outfeed_shape() const; + + // Delegates to HloAllReduceInstruction::replica_group_ids. + const std::vector& replica_group_ids() const; + + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. + string cross_replica_sum_barrier() const; + void set_cross_replica_sum_barrier(const string& barrier); + + // Delegates to HloAllReduceInstruction::all_reduce_id. + tensorflow::gtl::optional all_reduce_id() const; + + // Returns data on the window in a windowed operation such as + // convolution. + virtual const Window& window() const { + LOG(FATAL) << "Unimplemented method."; + } + + // Sets the window data in a windowed operation such as convolution. + virtual void set_window(const Window& window) { + LOG(FATAL) << "Unimplemented method."; + } + + // Returns data on the dimension numbers used for a convolution operation, + // which may be a kConvolution instruction or a kCustomCall that implements a + // convolution. + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const; + + // Sets the convolution dimension numbers on this instruction. In general you + // shouldn't need to call this; instead, specify the convolution dimension + // numbers when you create the instruction. + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums); + + // Delegates to HloSelectAndScatterInstruction::select. + HloComputation* select() const; + + // Delegates to HloSelectAndScatterInstruction::scatter. + HloComputation* scatter() const; + + // Delegates to HloSelectAndScatterInstruction::set_select. + void set_select(HloComputation* computation); + + // Delegates to HloSelectAndScatterInstruction::set_scatter. + void set_scatter(HloComputation* computation); + + // Delegates to HloCustomCallInstruction::custom_call_target. + const string& custom_call_target() const; + + // Delegates to HloHostComputeInstruction::channel_name. + const string& channel_name() const; + + // Delegates to HloPadInstruction::padding_config. + const PaddingConfig& padding_config() const; + + // Delegates to HloDynamicSliceInstruction::slice_sizes. + int64 slice_sizes(int64 dimension) const; + + // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes. + const std::vector& dynamic_slice_sizes() const; + // Old methods kept for smooth subclassing transition END. + + protected: + enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; + // Helper class for computing OperandElementUse for kFusion. + class FusionReusesParamElements; + + // Internal constructor for a given opcode/shape, other fields must be filled + // by factory methods. + HloInstruction(HloOpcode opcode, const Shape& shape); + + // Appends operand to the list of operands and adds this instruction as a user + // of the operand. + void AppendOperand(HloInstruction* operand); + + void RemoveOperandAt(int index) { + operands_.erase(operands_.begin() + index); + } + + // Removes a list of operands with the given indices in ascending order. + void RemoveOperandsAtAscendingIndices( + tensorflow::gtl::ArraySlice ascending_indices); + + void AppendComputation(HloComputation* computation) { + called_computations_.push_back(computation); + } + + void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); } + + void set_called_computation(int index, HloComputation* computation) { + called_computations_[index] = computation; + } + // Indices of computations in called_computations_ for instructions which call + // multiple computations. + enum { + // kWhile computations. + kBodyComputationIndex = 0, + kConditionComputationIndex = 1, + + // kSelectAndScatter computations. + kSelectComputationIndex = 0, + kScatterComputationIndex = 1, + + // kConditional computations. + kTrueComputationIndex = 0, + kFalseComputationIndex = 1, + }; + private: + // Implementation for non-common logic of CloneWithNewOperands. + virtual std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + // TODO(b/80131774): This should be pure virtual. + LOG(FATAL) << "Unimplemented method."; + } + + // Implementation for non-common logic of ExtraAttributesToString. + virtual std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {}; + } + + // Implementation for IsElementwise if operand_idx is nullopt and for + // IsElementwiseOnOperand if otherwise. + // + // NOTE: For all instructions other than kFusion, being elementwise on one of + // the operands is equivalent to being elementwise on all the operands. + virtual bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const; // Prints an instruction to a string. // // The canonical string representation needs to name operands and instruction @@ -1479,7 +1533,7 @@ class HloInstruction { CanonicalNameMap* canonical_name_map) const; // Prints an operand to a string. - string OperandsToStringWithCanonicalNameMap( + virtual string OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const; @@ -1487,75 +1541,30 @@ class HloInstruction { // OperandsToStringWithCanonicalNameMap() functions. friend class HloComputation; - enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; - - // Helper class for computing OperandElementUse for kFusion. - class FusionReusesParamElements; - // See comments on Identical(). - // eq_shapes() is used to check shapes for equality, and would normally be - // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on - // whether we want a layout-sensitive check or not. - bool IdenticalSlowPath( + virtual bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const; + eq_computations) const; // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, tensorflow::gtl::ArraySlice operands); - // Appends operand to the list of operands and adds this instruction as a user - // of the operand. - void AppendOperand(HloInstruction* operand); - // Adds a user for this instruction. void AddUser(HloInstruction* user); // Removes a user for this instruction. void RemoveUser(HloInstruction* user); - // Internal constructor for a given opcode/shape, other fields must be filled - // by factory methods. - HloInstruction(HloOpcode opcode, const Shape& shape); - - // Fuses the given instruction into this fusion instruction. When add_output - // is false (which is the default), instruction_to_fuse is cloned and the - // clone is placed in the fusion instruction. instruction_to_fuse is - // unchanged. - // - // When add_output is true, a clone of the instruction_to_fuse will be part - // of the output of fusion instructions. The users of instruction_to_fuse - // will be redirected to this fusion instructions. instruction_to_fuse will - // be removed from its parent computation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, - bool add_output = false); - - // Clones the given instruction_to_fuse and insert the clone into this fusion - // instruction. If add_output is true, a clone of instruction_to_fuse will - // be in the output of the this fusion instruction (part of the tuple of the - // fusion root). - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, - bool add_output = false); - - // Clones a fusion instruction with a new shape and operands. - std::unique_ptr CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; - - // Returns true if this instruction can legally have the dimensions field - // set. Used for checking precondition of dimensions field accessors. - bool CanHaveDimensionsField() const; - // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; + // Helper for implementing backend_config(). Parses backend_config_ into the + // given proto. + Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const; + int unique_id_; // Unique to this HloInstruction within a HloModule // Opcode for this instruction. @@ -1580,125 +1589,34 @@ class HloInstruction { // The computation in which this instruction is contained. HloComputation* parent_ = nullptr; - // Shape of outfeed request. - Shape outfeed_shape_; - // Result shape of this instruction. Shape shape_; - // Literal, only present for kConstant. - std::unique_ptr literal_; - - // Constant index, only present for kGetTupleElement. - int64 tuple_index_ = -1; - - // Dimensions present for some operations that require reshaping or - // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. - std::vector dimensions_; - - // Describes the window in a windowed operation such as convolution. - std::unique_ptr window_; - - // Describes the dimension numbers used for a convolution. - std::unique_ptr convolution_dimension_numbers_; - // Describes the dimension numbers used for a dot. std::unique_ptr dot_dimension_numbers_; std::unique_ptr gather_dimension_numbers_; std::vector gather_window_bounds_; - // Describes FFT type for an FFT instruction. - FftType fft_type_ = FftType::FFT; - - // Indicates the FFT length for an FFT instruction. - std::vector fft_length_; - - // Describes the [begin, end) index range for a slice. - std::vector slice_starts_; - std::vector slice_limits_; - std::vector slice_strides_; - - // Describes whether the slice can be lowered to an offset into the operand. - bool is_in_place_slice_ = false; - - // The bit sizes for a reduce-precision operation. - int32 exponent_bits_ = 0; - int32 mantissa_bits_ = 0; - - // Describes the [start, start + size) range size for a dynamic slice - // ('start' is specified dynamically in the second operand of the operation). - std::vector dynamic_slice_sizes_; - - // The padding configuration that describes the edge padding and interior - // padding of this pad instruction. Only set for pad instructions. - std::unique_ptr padding_config_; - - // The type of the fusion. Used by kFusion only. - FusionKind fusion_kind_; + // Used to tag kCopy instructions that are eligible for copy elision. + bool copy_elision_allowed_ = true; // The sharding, if one exists. std::unique_ptr sharding_; - // For parameter instructions this field holds the parameter number. - int64 parameter_number_ = 0; - - // Name of a global symbol to call, only present for kCustomCall. - string custom_call_target_; - - // Name to use for host send/recv channels, only present for kHostCompute. - string channel_name_; - - // Estimate of the duration of a host computation in nanoseconds. - int64 cost_estimate_ns_ = 0; + // Fields used by the kDomain instruction. + std::unique_ptr operand_side_metadata_; + std::unique_ptr user_side_metadata_; // Computations called by this instruction. std::vector called_computations_; - // Indices of computations in called_computations_ for instructions which call - // multiple computations. - enum { - // kWhile computations. - kBodyComputationIndex = 0, - kConditionComputationIndex = 1, - - // kSelectAndScatter computations. - kSelectComputationIndex = 0, - kScatterComputationIndex = 1, - - // kConditional computations. - kTrueComputationIndex = 0, - kFalseComputationIndex = 1, - }; - - // Outfeed configuration information, only present for kOutfeed. - string outfeed_config_; - // A trace instruction that consumes this instruction. // // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as // an operand. HloInstruction* trace_instruction_ = nullptr; - // The distribution requested for random number generation. - // Only present for kRng. - RandomDistribution distribution_; - - // A small float number added to the variance to avoid divide-by-zero error. - // Only present for kBatchNormTraining. - float epsilon_ = 0.0f; - - // An integer value representing the index of the feature dimension. - // Only present for kBatchNormTraining. - int64 feature_index_ = -1; - - // Represents a unique identifier for each Send/Recv instruction pair. - // Only present for kSend or kRecv. - int64 channel_id_ = -1; - - // The string representation of the infeed configuration. - string infeed_config_; - // The backend-specific configuration for how a backend should compile this // HLO. See the documentation on backend_config(). string backend_config_; @@ -1725,6 +1643,9 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums); + StatusOr StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index a61c472c72804b077d21274d2e866a69c5e73157..d8ca99dfd12ef95ab5e1ea61093d8bf3ea97a5e2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -24,11 +24,13 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" namespace xla { namespace { @@ -340,7 +342,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Builds a parameter and feeds it to the map. HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10, {param0}, add_f32)); module->AddEntryComputation(builder.Build()); @@ -379,7 +381,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { // Builds a parameter and an initial value and feeds them to the reduce. HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( @@ -714,10 +716,11 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { }))); auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto outfeed10 = builder.AddInstruction( - HloInstruction::CreateOutfeed(shape10, constant, "")); + HloInstruction::CreateOutfeed(shape10, constant, token, "")); auto outfeed01 = builder.AddInstruction( - HloInstruction::CreateOutfeed(shape01, constant, "")); + HloInstruction::CreateOutfeed(shape01, constant, token, "")); auto clone01 = builder.AddInstruction(outfeed01->Clone()); auto clone10 = builder.AddInstruction(outfeed10->Clone()); @@ -761,12 +764,12 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); - auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {constant}, computation_x, /*static_operands=*/{})); - auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{})); - auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{})); + auto map_1_x = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {constant}, computation_x)); + auto map_2_x = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {map_1_x}, computation_x)); + auto map_3_y = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {map_2_x}, computation_y)); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( @@ -921,6 +924,40 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2))); } +TEST_F(HloInstructionTest, IdenticalCallInstructions) { + const char* const hlo_string = R"( +HloModule Module + +subcomp1 (x: f32[]) -> f32[] { + x = f32[] parameter(0) + ROOT n = f32[] sine(x) +} + +subcomp2 (x: f32[]) -> f32[] { + x = f32[] parameter(0) + ROOT n = f32[] cosine(x) +} + +ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) { + p = f32[] parameter(0) + t1 = f32[] call(p), to_apply=subcomp1 + t2 = f32[] call(p), to_apply=subcomp1 + t3 = f32[] call(p), to_apply=subcomp2 + ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto* root = module->entry_computation()->root_instruction(); + auto* t1 = root->operand(0); + auto* t2 = root->operand(1); + auto* t3 = root->operand(2); + + EXPECT_TRUE(StructuralEqual(*t1, *t2)); + EXPECT_FALSE(StructuralEqual(*t1, *t3)); +} + TEST_F(HloInstructionTest, FunctionVisitor) { // Verify the function visitor HloInstruction::Accept visits all instructions // from a root properly given the following graph: @@ -978,6 +1015,23 @@ TEST_F(HloInstructionTest, FullyElementwise) { } } +TEST_F(HloInstructionTest, MapIsElementwise) { + auto module = CreateNewModule(); + const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); + HloComputation::Builder builder(TestName()); + HloComputation::Builder map_builder("id"); + map_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + auto map_computation = module->AddEmbeddedComputation(map_builder.Build()); + auto x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x")); + auto map = builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {x}, map_computation)); + module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(map->IsElementwise()); +} + TEST_F(HloInstructionTest, PartiallyElementwise) { const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5}); @@ -1117,6 +1171,40 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { EXPECT_TRUE(StructuralEqual(*fusion, *fusion2)); } +TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { + // Fused expression: + // + // x y + // | | + // | transpose + // \ / + // dot + const Shape s = ShapeUtil::MakeShape(F32, {10, 10}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(s, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kLoop); + + EXPECT_TRUE(x->ReplaceAllUsesWith(y).ok()); + + EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y)); + EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1); +} + TEST_F(HloInstructionTest, FusionEquality) { auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); @@ -1494,5 +1582,117 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { })"); } +TEST_F(HloInstructionTest, CheckDeepClone) { + const char* const hlo_string = R"( +HloModule Module + +addy (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT zadd = s32[] add(lhs, rhs) +} + +calla (x: s32[]) -> s32[] { + x = s32[] parameter(0) + reduce = s32[] reduce-window(x, x), to_apply=addy + ROOT xadd = s32[] add(x, reduce) +} + +body (bparam: s32[]) -> s32[] { + constant = s32[] constant(1) + bparam = s32[] parameter(0) + v = s32[] call(bparam), to_apply=calla + ROOT add = s32[] add(constant, bparam) +} + +condition (cparam: s32[]) -> pred[] { + xconstant = s32[] constant(5) + cparam = s32[] parameter(0) + ROOT greater-than = pred[] greater-than(xconstant, cparam) +} + +ENTRY entry (param: s32[]) -> s32[] { + eparam = s32[] parameter(0) + ROOT while = s32[] while(eparam), condition=condition, body=body + } +)"; + // Check that deep clones really deep clones every instruction and + // computations, without leaving dangling pointers to the old module. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + std::unique_ptr clone = module->Clone(); + for (HloComputation* computation : clone->computations()) { + EXPECT_EQ(computation->parent(), clone.get()); + for (HloInstruction* instruction : computation->instructions()) { + EXPECT_EQ(instruction->parent()->parent(), clone.get()); + } + } +} + +TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) { + const Shape shape = ShapeUtil::MakeShape(F32, {42}); + HloComputation::Builder builder("test"); + HloInstruction* p = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p)); + + EXPECT_TRUE(add1->Identical(*add2)); + add1->set_raw_backend_config_string("abc"); + EXPECT_FALSE(add1->Identical(*add2)); +} + +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + Window w = window_util::MakeWindow({1, 2, 3}); + instr1->set_window(w); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr1->set_convolution_dimension_numbers(dnums); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, CloneWindowOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + Window w = window_util::MakeWindow({1, 2, 3}); + instr->set_window(w); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals(clone->window(), w)) + << clone->window().DebugString(); +} + +TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr->set_convolution_dimension_numbers(dnums); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals( + clone->convolution_dimension_numbers(), dnums)) + << clone->convolution_dimension_numbers().DebugString(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc new file mode 100644 index 0000000000000000000000000000000000000000..7052e236cdab534864d8d4791bcdcfa162a2851d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -0,0 +1,1866 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { +namespace { + +using ::tensorflow::str_util::CEscape; +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, + const HloInstruction* operand) { + std::vector operand_indices = instruction->OperandIndices(operand); + return std::all_of( + operand_indices.begin(), operand_indices.end(), + [instruction](int64 operand_index) { + return instruction->IsElementwiseOnOperand(operand_index); + }); +} +} // namespace + +HloBatchNormInstruction::HloBatchNormInstruction( + HloOpcode opcode, const Shape& shape, HloInstruction* operand, + HloInstruction* scale, float epsilon, int64 feature_index) + : HloInstruction(opcode, shape), + epsilon_(epsilon), + feature_index_(feature_index) { + AppendOperand(operand); + AppendOperand(scale); +} + +bool HloBatchNormInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return feature_index() == casted_other.feature_index() && + epsilon() == casted_other.epsilon(); +} + +HloInstructionProto HloBatchNormInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_epsilon(epsilon_); + proto.set_feature_index(feature_index_); + return proto; +} + +std::vector HloBatchNormInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("epsilon=", epsilon()), + StrCat("feature_index=", feature_index())}; +} + +HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand, + scale, epsilon, feature_index) { + AppendOperand(offset); +} + +std::unique_ptr +HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), + feature_index()); +} + +HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand, + scale, epsilon, feature_index) { + AppendOperand(offset); + AppendOperand(mean); + AppendOperand(variance); +} + +std::unique_ptr +HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 5); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); +} + +HloBatchNormGradInstruction::HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output, + float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale, + epsilon, feature_index) { + AppendOperand(mean); + AppendOperand(variance); + AppendOperand(grad_output); +} + +std::unique_ptr +HloBatchNormGradInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 5); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); +} + +HloFftInstruction::HloFftInstruction( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length) + : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) { + fft_length_.assign(fft_length.begin(), fft_length.end()); + AppendOperand(operand); +} + +HloInstructionProto HloFftInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fft_type(fft_type_); + for (int64 fft_len : fft_length_) { + proto.add_fft_length(fft_len); + } + return proto; +} + +std::vector HloFftInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("fft_type=", FftType_Name(fft_type())), + StrCat("fft_length={", Join(fft_length(), ","), "}")}; +} + +bool HloFftInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return fft_type() == casted_other.fft_type() && + fft_length() == casted_other.fft_length(); +} + +std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], fft_type_, + fft_length_); +} + +HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, + const Shape& shape, + int64 channel_id) + : HloInstruction(opcode, shape), channel_id_(channel_id) {} + +HloInstructionProto HloSendRecvInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_channel_id(channel_id_); + return proto; +} + +std::vector HloSendRecvInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("channel_id=", channel_id_)}; +} + +bool HloSendRecvInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +// Send instruction produces a tuple of {aliased operand, U32 context}. +HloSendInstruction::HloSendInstruction(HloInstruction* operand, + HloInstruction* token, int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kSend, + ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(), + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}), + channel_id) { + AppendOperand(operand); + AppendOperand(token); +} + +std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(new_operands[0], new_operands[1], + channel_id()); +} + +HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) + : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloSendDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +// Recv instruction produces a tuple of {receive buffer, U32 context}. +HloRecvInstruction::HloRecvInstruction(const Shape& shape, + HloInstruction* token, int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kRecv, + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}), + channel_id) { + AppendOperand(token); +} + +std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id()); +} + +HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) + : HloSendRecvInstruction( + HloOpcode::kRecvDone, + ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(operand->shape(), 0), + ShapeUtil::MakeTokenShape()}), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloRecvDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +HloAllReduceInstruction::HloAllReduceInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id) + : HloInstruction(HloOpcode::kCrossReplicaSum, shape), + replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()), + cross_replica_sum_barrier_(barrier.begin(), barrier.end()), + all_reduce_id_(all_reduce_id) { + // TODO(b/79737069): Remove the CHECK when supported. + CHECK(!all_reduce_id_); + for (auto operand : operands) { + AppendOperand(operand); + } + AppendComputation(reduce_computation); +} + +HloInstructionProto HloAllReduceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 i : replica_group_ids_) { + proto.add_replica_group_ids(i); + } + // Proto3 is so sad. + if (all_reduce_id_) { + proto.set_all_reduce_id(*all_reduce_id_); + } + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; +} + +std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result = { + StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")}; + if (!cross_replica_sum_barrier().empty()) { + result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + } + if (all_reduce_id_) { + result.push_back(StrCat("all_reduce_id=", *all_reduce_id_)); + } + return result; +} + +bool HloAllReduceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return replica_group_ids() == casted_other.replica_group_ids() && + eq_computations(to_apply(), casted_other.to_apply()) && + cross_replica_sum_barrier() == + casted_other.cross_replica_sum_barrier() && + all_reduce_id() == casted_other.all_reduce_id(); +} + +std::unique_ptr +HloAllReduceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* /*context*/) const { + return MakeUnique( + shape, new_operands, to_apply(), replica_group_ids(), + cross_replica_sum_barrier(), all_reduce_id()); +} + +HloReverseInstruction::HloReverseInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kReverse, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloReverseInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloReverseInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReverseInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloConcatenateInstruction::HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension) + : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloConcatenateInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloConcatenateInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloConcatenateInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloConcatenateInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, + dimensions(0)); +} + +HloReduceInstruction::HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation) + : HloInstruction(HloOpcode::kReduce, shape), + dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { + AppendOperand(arg); + AppendOperand(init_value); + AppendComputation(reduce_computation); +} + +HloInstructionProto HloReduceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReduceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + // Reduction results are determined by the reduction dimension and the + // reduction computation. + return dimensions() == casted_other.dimensions() && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], dimensions(), to_apply()); +} + +HloTransposeInstruction::HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kTranspose, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + CHECK_EQ(shape.dimensions().size(), dimensions.size()); + CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); + CHECK(std::equal(operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(dimensions, shape.dimensions()).begin())) + << "shape: " << ShapeUtil::HumanString(shape) + << ", operand->shape(): " << ShapeUtil::HumanString(shape) + << ", dimensions: {" << Join(dimensions, ", ") << "}"; + AppendOperand(operand); +} + +bool HloTransposeInstruction::IsRank2Transpose() const { + return dimensions() == std::vector({1, 0}) && + shape().dimensions_size() == 2 && + std::equal(shape().dimensions().begin(), shape().dimensions().end(), + operand(0)->shape().dimensions().rbegin()); +} + +HloInstructionProto HloTransposeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloTransposeInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloTransposeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloTransposeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloBroadcastInstruction::HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension) + : HloInstruction(HloOpcode::kBroadcast, shape), + dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloBroadcastInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloBroadcastInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloBroadcastInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloBroadcastInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloMapInstruction::HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation) + : HloInstruction(HloOpcode::kMap, shape) { + for (auto operand : operands) { + AppendOperand(operand); + } + AppendComputation(map_computation); + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. + dimensions_.resize(ShapeUtil::Rank(shape)); + std::iota(dimensions_.begin(), dimensions_.end(), 0); +} + +HloInstructionProto HloMapInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +bool HloMapInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + if (!dimensions().empty()) { + // Check that the map is executed in elementwise compatible dimensions. + if (dimensions().size() != shape().dimensions_size()) { + return false; + } + for (int i = 0; i < dimensions().size(); ++i) { + if (dimensions()[i] != i) { + return false; + } + } + } + return true; +} + +std::vector HloMapInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloMapInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return eq_computations(to_apply(), other.to_apply()); +} + +std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, to_apply()); +} + +HloSliceInstruction::HloSliceInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) + : HloInstruction(HloOpcode::kSlice, shape), + slice_starts_(start_indices.begin(), start_indices.end()), + slice_limits_(limit_indices.begin(), limit_indices.end()), + slice_strides_(strides.begin(), strides.end()) { + AppendOperand(operand); + // For backward compatibility with old serialized computations: if there are + // no strides, assume all strides are 1. + // TODO(b/63317920): remove this code. + if (slice_strides_.empty()) { + slice_strides_ = std::vector(start_indices.size(), 1LL); + } +} + +HloInstructionProto HloSliceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int i = 0; i < slice_starts_.size(); ++i) { + auto* slice_dimension = proto.add_slice_dimensions(); + slice_dimension->set_start(slice_starts_[i]); + slice_dimension->set_limit(slice_limits_[i]); + slice_dimension->set_stride(slice_strides_[i]); + } + return proto; +} + +std::vector HloSliceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector bounds; + bounds.reserve(slice_starts_.size()); + const bool omit_stride = + std::all_of(slice_strides_.begin(), slice_strides_.end(), + [](int64 stride) { return stride == 1; }); + for (int i = 0; i < slice_starts_.size(); ++i) { + string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); + bounds.push_back( + StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); + } + return {StrCat("slice={", Join(bounds, ", "), "}")}; +} + +bool HloSliceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return slice_starts_ == other_slice.slice_starts_ && + slice_limits_ == other_slice.slice_limits_ && + slice_strides_ == other_slice.slice_strides_; +} + +std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], slice_starts_, + slice_limits_, slice_strides_); +} + +HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) + : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), + literal_(std::move(literal)) {} + +HloConstantInstruction::HloConstantInstruction(const Shape& shape) + : HloInstruction(HloOpcode::kConstant, shape) {} + +HloInstructionProto HloConstantInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + if (literal_ != nullptr) { + *proto.mutable_literal() = literal_->ToProto(); + } + return proto; +} + +bool HloConstantInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + return true; +} + +void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index) { + Shape* mutable_array_subshape = + ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); + CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + + // Normally array_subshape will always have a layout, but this invariant is + // temporarily broken in LayoutAssignment::AssignLayouts. + + if (!mutable_array_subshape->has_layout() || + !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { + literal_ = literal_->Relayout(new_layout, shape_index); + *mutable_array_subshape->mutable_layout() = new_layout; + } +} + +bool HloConstantInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return literal() == other_slice.literal(); +} + +std::unique_ptr +HloConstantInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(literal_->CloneToUnique()); +} + +string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string operands; + // For constants, show the actual value in place of an empty operand list. + if (literal_ != nullptr && + ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || + options.print_large_constants())) { + // Literal::ToString emits multidimensional arrays over multiple + // lines. Compact this into one line by stripping out white space. + string tmp = literal().ToString(); + std::replace(tmp.begin(), tmp.end(), '\n', ' '); + std::vector v = tensorflow::str_util::Split(tmp, ' '); + bool first = true; + // Concatenate elements in "v" with spaces separating them, but ignoring + // empty entries. + for (const auto& s : v) { + if (s.empty()) { + continue; + } + StrAppend(&operands, (first ? "" : " "), s); + first = false; + } + } else { + // Do not show large constants or tuples. + operands = "{...}"; + } + return operands; +} + +HloTraceInstruction::HloTraceInstruction(const string& tag, + HloInstruction* operand) + : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()), + literal_(Literal::CreateR1U8(tag)) { + AppendOperand(operand); + operand->set_tracing(this); +} + +HloInstructionProto HloTraceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_literal() = literal_->ToProto(); + return proto; +} + +bool HloTraceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloTraceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode()); +} + +HloFusionInstruction::HloFusionInstruction(const Shape& shape, + FusionKind fusion_kind, + HloInstruction* fused_root) + : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { + CHECK(fused_root != nullptr); + SetAndSanitizeName("fusion"); + set_parent(fused_root->parent()); + set_metadata(fused_root->metadata()); + CloneAndFuseInternal(fused_root); +} + +HloFusionInstruction::HloFusionInstruction( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation) + : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { + for (auto operand : operands) { + AppendOperand(operand); + } + SetAndSanitizeName("fusion"); + AppendComputation(fusion_computation); + fusion_computation->SetFusionInstruction(this); +} + +string HloFusionInstruction::ToCategory() const { + switch (fusion_kind()) { + case FusionKind::kLoop: + return "loop fusion"; + case FusionKind::kInput: + return "input fusion"; + case FusionKind::kOutput: + return "output fusion"; + case FusionKind::kCustom: + return "custom fusion"; + } +} + +HloInstructionProto HloFusionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fusion_kind(xla::ToString(fusion_kind())); + proto.add_called_computation_ids( + fused_instructions_computation()->unique_id()); + return proto; +} + +bool HloFusionInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + if (!operand_idx.has_value()) { + for (auto* fused : fused_instructions()) { + if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { + return false; + } + } + return true; + } + // A loop-fusion is elementwise on an operand if all operations (computed + // using BFS) between the operand and the fused root are elementwise. + std::deque worklist; + std::unordered_set visited; + worklist.push_back(fused_parameter(operand_idx.value())); + visited.insert(fused_parameter(operand_idx.value())); + while (!worklist.empty()) { + HloInstruction* operand = worklist.front(); + worklist.pop_front(); + for (HloInstruction* user : operand->users()) { + CHECK_GE(user->unique_id(), 0); + if (ContainsKey(visited, user)) { + continue; + } + if (user->IsElementwise() || + IsInstructionElementwiseOnOperand(user, operand)) { + worklist.push_back(user); + visited.insert(user); + } else { + return false; + } + } + } + return true; +} + +HloInstruction* HloFusionInstruction::AddFusionOperand( + HloInstruction* new_operand) { + CHECK_EQ(operand_count(), + fused_instructions_computation()->parameter_instructions().size()); + const int64 param_no = operand_count(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. + string param_name = StrCat(new_operand->name(), ".param_", param_no); + HloInstruction* fused_parameter = + fused_instructions_computation()->AddParameter( + HloInstruction::CreateParameter(param_no, new_operand->shape(), + param_name)); + AppendOperand(new_operand); + return fused_parameter; +} + +void HloFusionInstruction::MergeFusionInstruction( + HloFusionInstruction* instruction_to_merge) { + CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != + operands().end()); + // Clone the instruction from which to merge fused instructions. + std::unique_ptr cloned = instruction_to_merge->Clone(); + HloFusionInstruction* cloned_fusion = + static_cast(cloned.get()); + // Replace uses of fused parameters with the corresponding operand of the + // fusion. Add all non-parameter fused instructions to + // 'unfused_instructions' to be merged into 'this'. This is done in reverse + // post order. + std::vector unfused_instructions; + auto fused_instructions = cloned_fusion->fused_instructions_computation() + ->MakeInstructionPostOrder(); + for (auto fused_it = fused_instructions.rbegin(); + fused_it != fused_instructions.rend(); ++fused_it) { + auto fused_instruction = *fused_it; + if (fused_instruction->opcode() == HloOpcode::kParameter) { + TF_CHECK_OK( + fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand( + fused_instruction->parameter_number()))); + } else { + unfused_instructions.push_back(fused_instruction); + } + } + CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root()); + // Replace instruction_to_merge use of 'this' with unfused_root. + TF_CHECK_OK( + instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); + // Fuse 'unfused_instructions' into 'this'. + for (auto& instruction : unfused_instructions) { + FuseInstruction(instruction); + } + CHECK_EQ(0, cloned_fusion->user_count()); + TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( + cloned_fusion->fused_instructions_computation())); +} + +void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( + HloFusionInstruction* instruction_to_merge) { + // Add all non-parameter fused instructions to 'unfused_instructions' to be + // merged into 'this'. `old_to_new' maps the instructions in the fused node + // to the disaseembled fusion instructions. + // Note that we add the unfused instructions to this->parent_ computation. + // This is necessary because the unique_id needs for an instruction and + // it's only added when inserting to the computation. + tensorflow::gtl::FlatMap old_to_new; + std::vector unfused_instructions; + auto computation_to_merge = + instruction_to_merge->fused_instructions_computation(); + auto post_order = computation_to_merge->MakeInstructionPostOrder(); + for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) { + auto fused_instruction = *rit; + if (fused_instruction->opcode() == HloOpcode::kParameter) { + InsertOrDie(&old_to_new, fused_instruction, + instruction_to_merge->mutable_operand( + fused_instruction->parameter_number())); + continue; + } + + // Here we clone the insertion and call FuseInstructionIntoMultiOutput() + // which clones again. This can be improved. + auto cloned_instruction = + parent()->AddInstruction(fused_instruction->Clone()); + unfused_instructions.push_back(cloned_instruction); + InsertOrDie(&old_to_new, fused_instruction, cloned_instruction); + } + for (auto unfused_instruction : unfused_instructions) { + for (int64 index = 0; index < unfused_instruction->operand_count(); + index++) { + auto new_operand = + FindOrDie(old_to_new, unfused_instruction->mutable_operand(index)); + TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand)); + } + } + + HloInstruction* unfused_root = unfused_instructions.front(); + TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); + + TF_CHECK_OK( + instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge)); + if (GetModule()) { + TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge)); + } + + // Fuse the root instruction and generate multiple outputs. + FuseInstructionIntoMultiOutput(unfused_root); + TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); + // The rest instructions are of normal fusing. + for (int64 i = 1; i < unfused_instructions.size(); i++) { + auto instruction = unfused_instructions[i]; + FuseInstruction(instruction); + TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); + } +} + +HloComputation* HloFusionInstruction::fused_instructions_computation() const { + CHECK(!called_computations().empty()); + auto* fused_instructions_computation = called_computations().front(); + CHECK(fused_instructions_computation->IsFusionComputation()) + << "Computation " << fused_instructions_computation->name() + << " is not a fusion kind"; + return fused_instructions_computation; +} + +HloInstruction* HloFusionInstruction::fused_expression_root() const { + return fused_instructions_computation()->root_instruction(); +} + +HloInstruction* HloFusionInstruction::fused_parameter( + int64 parameter_number) const { + return fused_instructions_computation()->parameter_instruction( + parameter_number); +} + +const std::vector& HloFusionInstruction::fused_parameters() + const { + return fused_instructions_computation()->parameter_instructions(); +} + +const tensorflow::gtl::iterator_range>::const_iterator>> +HloFusionInstruction::fused_instructions() const { + const HloComputation* subcomp = fused_instructions_computation(); + return subcomp->instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloFusionInstruction::fused_instructions() { + return fused_instructions_computation()->instructions(); +} + +int64 HloFusionInstruction::fused_instruction_count() const { + return fused_instructions_computation()->instruction_count(); +} + +HloInstruction* HloFusionInstruction::FuseInstructionInternal( + HloInstruction* instruction_to_fuse, bool add_output) { + // When add_output is false, this fusion instruction must be a user of + // instruction_to_fuse. + if (!add_output) { + CHECK(IsUserOf(instruction_to_fuse)); + } + HloInstruction* fused_instruction = + CloneAndFuseInternal(instruction_to_fuse, add_output); + return fused_instruction; +} + +HloInstruction* HloFusionInstruction::CloneAndFuseInternal( + HloInstruction* instruction_to_fuse, bool add_output) { + CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); + VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); + HloInstruction* clone = nullptr; + if (called_computations().empty()) { + // New fusion instruction. It should not be a multioutput instruction. + CHECK(!add_output); + auto builder = HloComputation::Builder("fused_computation", this); + builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); + AppendComputation( + CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); + clone = fused_expression_root(); + } else { + clone = fused_instructions_computation()->AddInstruction( + instruction_to_fuse->Clone(/*suffix=*/"")); + // When add_output is false, instruction_to_fuse is necessarily an operand + // of the fusion instruction. After fusion this will no longer be the + // case. Remove the operand from the operand list and remove its + // corresponding fused parameter instruction. Renumber parameters as + // necessary to make parameter numbers consistent with their index in the + // fused_parameter_ vector. + bool in_operand_list = std::find(operands().begin(), operands().end(), + instruction_to_fuse) != operands().end(); + CHECK(add_output || in_operand_list); + const std::vector& fused_parameters = + fused_instructions_computation()->parameter_instructions(); + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + if (instruction_to_fuse == operand(operand_num)) { + // replace the fused parameter instruction's uses with the clone. + HloInstruction* fused_parameter = fused_parameters[operand_num]; + TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone)); + + // Remove the corresponding fused parameter and operand from their + // respective vectors. + TF_CHECK_OK( + fused_instructions_computation()->RemoveParameter(operand_num)); + RemoveOperandAt(operand_num); + break; + } + } + // We've cloned instruction_to_fuse into this fusion instruction, so this + // fusion instruction is no longer a use of instruction_to_fuse. + if (in_operand_list) { + DetachFrom(instruction_to_fuse); + // When the instruction_to_fuse does not have other users, we don't need + // to generate a multioutput fusion instruction. + if (instruction_to_fuse->user_count() == 0) { + add_output = false; + } + } + } + + // Reread the parameters in the computation. + const std::vector& fused_parameters = + fused_instructions_computation()->parameter_instructions(); + + // Add each operand of the clone as an operand of the fusion instruction. A + // complication is that some clone operands may already be operands of the + // fusion instruction. + for (int64 operand_num = 0; operand_num < clone->operand_count(); + ++operand_num) { + HloInstruction* operand = clone->mutable_operand(operand_num); + + // See if this operand is already an operand of the fusion node. + CHECK_EQ(operands().size(), fused_parameters.size()); + HloInstruction* fused_param = nullptr; + for (int64 i = 0; i < operands().size(); ++i) { + if (this->operand(i) == operand) { + fused_param = fused_parameters[i]; + break; + } + } + + if (fused_param == nullptr) { + // Clone's operand was not already an operand of the fusion + // instruction. Add it as an operand and add a corresponding fused + // parameter instruction. + fused_param = AddFusionOperand(operand); + } + TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); + } + + if (add_output) { + CHECK_GT(instruction_to_fuse->user_count(), 0); + // If this is already a multioutput fusion instruction, expand the root + // tuple by 1. + HloInstruction* fused_root = fused_expression_root(); + HloInstruction::InstructionVector tuple_elements; + bool newly_created_tuple_instr = false; + if (fused_root->opcode() == HloOpcode::kTuple) { + tuple_elements = fused_root->operands(); + } else { + tuple_elements.push_back(fused_root); + newly_created_tuple_instr = true; + } + if (clone->opcode() == HloOpcode::kTuple) { + for (auto inst : clone->operands()) { + tuple_elements.push_back(inst); + } + } else { + tuple_elements.push_back(clone); + } + HloInstruction* new_root = fused_instructions_computation()->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); + fused_instructions_computation()->set_root_instruction(new_root); + *mutable_shape() = new_root->shape(); + if (fused_root->opcode() == HloOpcode::kTuple) { + TF_CHECK_OK( + fused_instructions_computation()->RemoveInstruction(fused_root)); + } + + // If this is a newly created multioutput instruction, we need to update + // the use of the original fusion instruction. + if (newly_created_tuple_instr) { + HloInstruction* new_instr = parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); + TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); + } + int64 index = tuple_elements.size(); + if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { + index -= instruction_to_fuse->operand_count(); + std::vector to_be_removed; + for (auto old_gte : instruction_to_fuse->users()) { + CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); + int64 old_tuple_index = old_gte->tuple_index(); + HloInstruction* new_gte = + parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + old_gte->shape(), this, index + old_tuple_index)); + TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte)); + to_be_removed.push_back(old_gte); + } + for (auto old_gte : to_be_removed) { + TF_CHECK_OK(parent()->RemoveInstruction(old_gte)); + } + TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); + } else { + HloInstruction* new_gte = + parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + clone->shape(), this, index - 1)); + TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte)); + } + } + + VLOG(2) << "New clone:\n" << clone->ToString(); + return clone; +} + +std::vector HloFusionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("kind=", xla::ToString(fusion_kind()))}; +} + +bool HloFusionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return fusion_kind() == other.fusion_kind() && + eq_computations(fused_instructions_computation(), + other.fused_instructions_computation()); +} + +std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + HloModule* module = context != nullptr ? context->module() : GetModule(); + HloComputation* new_fused_computation = nullptr; + if (context != nullptr) { + new_fused_computation = + context->FindComputation(fused_instructions_computation()); + } + if (new_fused_computation == nullptr) { + new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", context)); + } + return MakeUnique(shape, fusion_kind(), new_operands, + new_fused_computation); +} + +Status HloFusionInstruction::DeduplicateFusionOperands() { + tensorflow::gtl::FlatMap operand_indices; + std::vector operands_to_remove; + for (int i = 0; i < operand_count(); ++i) { + auto emplace_result = operand_indices.emplace(operand(i), i); + if (!emplace_result.second) { + TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith( + fused_parameter(emplace_result.first->second))); + operands_to_remove.push_back(i); + } + } + if (operands_to_remove.empty()) { + return Status::OK(); + } + TF_RETURN_IF_ERROR( + fused_instructions_computation()->RemoveUnusedParameters()); + RemoveOperandsAtAscendingIndices(operands_to_remove); + return Status::OK(); +} + +HloRngInstruction::HloRngInstruction( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters) + : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) { + for (HloInstruction* param : parameters) { + AppendOperand(param); + } +} + +HloInstructionProto HloRngInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_distribution(distribution_); + return proto; +} + +std::vector HloRngInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("distribution=", RandomDistributionToString(distribution_))}; +} + +bool HloRngInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + return true; +} + +bool HloRngInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, distribution_, new_operands); +} + +HloParameterInstruction::HloParameterInstruction(int64 parameter_number, + const Shape& shape, + const string& name) + : HloInstruction(HloOpcode::kParameter, shape), + parameter_number_(parameter_number) { + SetAndSanitizeName(name); +} + +HloInstructionProto HloParameterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_parameter_number(parameter_number_); + return proto; +} + +string HloParameterInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + return StrCat(parameter_number_); +} + +bool HloParameterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return parameter_number() == casted_other.parameter_number(); +} + +std::unique_ptr +HloParameterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(parameter_number_, shape, name()); +} + +HloGetTupleElementInstruction::HloGetTupleElementInstruction( + const Shape& shape, HloInstruction* operand, int64 index) + : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) { + CHECK(ShapeUtil::IsTuple(operand->shape())); + AppendOperand(operand); +} + +HloInstructionProto HloGetTupleElementInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_tuple_index(tuple_index_); + return proto; +} + +std::vector HloGetTupleElementInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("index=", tuple_index())}; +} + +bool HloGetTupleElementInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return tuple_index() == casted_other.tuple_index(); +} + +std::unique_ptr +HloGetTupleElementInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + tuple_index()); +} + +HloReducePrecisionInstruction::HloReducePrecisionInstruction( + const Shape& shape, HloInstruction* operand, const int exponent_bits, + const int mantissa_bits) + : HloInstruction(HloOpcode::kReducePrecision, shape), + exponent_bits_(exponent_bits), + mantissa_bits_(mantissa_bits) { + AppendOperand(operand); +} + +HloInstructionProto HloReducePrecisionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_exponent_bits(exponent_bits_); + proto.set_mantissa_bits(mantissa_bits_); + return proto; +} + +std::vector HloReducePrecisionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("exponent_bits=", exponent_bits_), + StrCat("mantissa_bits=", mantissa_bits_)}; +} + +bool HloReducePrecisionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + // A reduce-precision operation is determined by the bit sizes. + return exponent_bits() == casted_other.exponent_bits() && + mantissa_bits() == casted_other.mantissa_bits(); +} + +std::unique_ptr +HloReducePrecisionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + shape, new_operands[0], exponent_bits(), mantissa_bits()); +} + +HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, + HloInstruction* token_operand, + const string& config) + : HloInstruction(HloOpcode::kInfeed, + ShapeUtil::MakeTupleShape( + {infeed_shape, ShapeUtil::MakeTokenShape()})), + infeed_config_(config) { + AppendOperand(token_operand); +} + +HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, + const string& config) + : HloInstruction(HloOpcode::kInfeed, + ShapeUtil::MakeTupleShape( + {infeed_shape, ShapeUtil::MakeTokenShape()})), + infeed_config_(config) {} + +HloInstructionProto HloInfeedInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_infeed_config(infeed_config_); + return proto; +} + +std::vector HloInfeedInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (infeed_config_.empty()) { + return {}; + } + return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")}; +} + +bool HloInfeedInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + if (new_operands.empty()) { + return MakeUnique(infeed_shape(), infeed_config()); + } else { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(infeed_shape(), new_operands[0], + infeed_config()); + } +} + +HloOutfeedInstruction::HloOutfeedInstruction( + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) + : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), + outfeed_shape_(outfeed_shape), + outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) + << "Outfeed shape " << outfeed_shape + << " must be compatible with operand shape " << operand->shape(); + AppendOperand(operand); + AppendOperand(token_operand); +} + +HloOutfeedInstruction::HloOutfeedInstruction( + const Shape& outfeed_shape, HloInstruction* operand, + tensorflow::StringPiece outfeed_config) + : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), + outfeed_shape_(outfeed_shape), + outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) + << "Outfeed shape " << outfeed_shape + << " must be compatible with operand shape " << operand->shape(); + AppendOperand(operand); +} + +HloInstructionProto HloOutfeedInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_outfeed_config(outfeed_config()); + *proto.mutable_outfeed_shape() = outfeed_shape(); + return proto; +} + +std::vector HloOutfeedInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (outfeed_config_.empty()) { + return {}; + } + return {StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")}; +} + +bool HloOutfeedInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + if (new_operands.size() == 1) { + return MakeUnique(outfeed_shape(), new_operands[0], + outfeed_config()); + } else { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(outfeed_shape(), new_operands[0], + new_operands[1], outfeed_config()); + } +} + +HloConvolutionInstruction::HloConvolutionInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) + : HloInstruction(HloOpcode::kConvolution, shape), + window_(window), + convolution_dimension_numbers_(dimension_numbers) { + if (window_util::HasBaseDilation(window)) { + SetAndSanitizeName(StrCat(name(), "-base-dilated")); + } + if (window_util::HasWindowDilation(window)) { + SetAndSanitizeName(StrCat(name(), "-window-dilated")); + } + AppendOperand(lhs); + AppendOperand(rhs); +} + +string HloConvolutionInstruction::ToCategory() const { + string category = "convolution"; + if (window_util::HasBaseDilation(window())) { + category += " base-dilated"; + } + if (window_util::HasWindowDilation(window())) { + category += " window-dilated"; + } + return category; +} + +HloInstructionProto HloConvolutionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + *proto.mutable_convolution_dimension_numbers() = + convolution_dimension_numbers_; + return proto; +} + +std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( + convolution_dimension_numbers_))); + return extra; +} + +bool HloConvolutionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return protobuf_util::ProtobufEquals(window(), casted_other.window()) && + protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + casted_other.convolution_dimension_numbers()); +} + +std::unique_ptr +HloConvolutionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(shape, new_operands[0], + new_operands[1], window(), + convolution_dimension_numbers_); +} + +HloReduceWindowInstruction::HloReduceWindowInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + const Window& window, HloComputation* reduce_computation) + : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) { + AppendOperand(operand); + AppendOperand(init_value); + AppendComputation(reduce_computation); +} + +HloInstructionProto HloReduceWindowInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + return proto; +} + +std::vector HloReduceWindowInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + return extra; +} + +bool HloReduceWindowInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return eq_computations(to_apply(), casted_other.to_apply()) && + protobuf_util::ProtobufEquals(window(), casted_other.window()); +} + +std::unique_ptr +HloReduceWindowInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], window(), to_apply()); +} + +HloSelectAndScatterInstruction::HloSelectAndScatterInstruction( + const Shape& shape, HloInstruction* operand, HloComputation* select, + const Window& window, HloInstruction* source, HloInstruction* init_value, + HloComputation* scatter) + : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) { + AppendOperand(operand); + AppendOperand(source); + AppendOperand(init_value); + // Select comes before scatter in the vector. + AppendComputation(select); + AppendComputation(scatter); +} + +HloInstructionProto HloSelectAndScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + return proto; +} + +std::vector HloSelectAndScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + return extra; +} + +bool HloSelectAndScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return eq_computations(select(), casted_other.select()) && + eq_computations(scatter(), casted_other.scatter()) && + protobuf_util::ProtobufEquals(window(), casted_other.window()); +} + +std::unique_ptr +HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], select(), window(), new_operands[1], + new_operands[2], scatter()); +} + +HloCustomCallInstruction::HloCustomCallInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) + : HloInstruction(HloOpcode::kCustomCall, shape), + custom_call_target_(custom_call_target.begin(), + custom_call_target.end()) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloCustomCallInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + if (window_ != nullptr) { + *proto.mutable_window() = *window_; + } + if (convolution_dimension_numbers_ != nullptr) { + *proto.mutable_convolution_dimension_numbers() = + *convolution_dimension_numbers_; + } + proto.set_custom_call_target(custom_call_target_); + return proto; +} + +std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_ != nullptr && window_->dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); + } + if (convolution_dimension_numbers_ != nullptr) { + extra.push_back(StrCat( + "dim_labels=", + ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); + } + // By contract, we print the custom call target even if + // options.print_subcomputation_mode() == kOff, because the call target is not + // an HloComputation. + extra.push_back( + StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + return extra; +} + +bool HloCustomCallInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + if ((window_ == nullptr) != (casted_other.window_ == nullptr) || + (window_ != nullptr && + !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) { + return false; + } + if ((convolution_dimension_numbers_ == nullptr) != + (casted_other.convolution_dimension_numbers_ == nullptr) || + (convolution_dimension_numbers_ != nullptr && + !protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + casted_other.convolution_dimension_numbers()))) { + return false; + } + return custom_call_target_ == casted_other.custom_call_target_; +} + +std::unique_ptr +HloCustomCallInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + auto cloned = MakeUnique(shape, new_operands, + custom_call_target()); + if (window_ != nullptr) { + cloned->set_window(*window_); + } + if (convolution_dimension_numbers_ != nullptr) { + cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); + } + return std::move(cloned); +} + +HloHostComputeInstruction::HloHostComputeInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) + : HloInstruction(HloOpcode::kHostCompute, shape), + channel_name_(channel_name.begin(), channel_name.end()), + cost_estimate_ns_(cost_estimate_ns) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloHostComputeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_channel_name(channel_name_); + proto.set_cost_estimate_ns(cost_estimate_ns_); + return proto; +} + +bool HloHostComputeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr +HloHostComputeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique( + shape, new_operands, channel_name_, cost_estimate_ns_); +} + +HloPadInstruction::HloPadInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config) + : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) { + AppendOperand(operand); + AppendOperand(padding_value); +} + +HloInstructionProto HloPadInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_padding_config() = padding_config_; + return proto; +} + +std::vector HloPadInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))}; +} + +bool HloPadInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals(padding_config(), + casted_other.padding_config()); +} + +std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(shape, new_operands[0], new_operands[1], + padding_config_); +} + +HloDynamicSliceInstruction::HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes) + : HloInstruction(HloOpcode::kDynamicSlice, shape), + dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { + AppendOperand(operand); + AppendOperand(start_indices); +} + +HloInstructionProto HloDynamicSliceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 slice_size : dynamic_slice_sizes_) { + proto.add_dynamic_slice_sizes(slice_size); + } + return proto; +} + +std::vector HloDynamicSliceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return { + StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")}; +} + +bool HloDynamicSliceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return true; +} + +std::unique_ptr +HloDynamicSliceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..df6969c410a7742a9abfff56c3d41864232a8bff --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -0,0 +1,1124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// All HloInstruction subclasses are put in this file. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +class HloBatchNormInstruction : public HloInstruction { + public: + // Returns feature_index field associated with the instruction. The index + // represents the index of the feature dimension. + int64 feature_index() const { return feature_index_; } + + // Returns a epsilon value associated with the instruction. The is a small + // number added to the variance to avoid divide-by-zero error. + float epsilon() const { return epsilon_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, float epsilon, + int64 feature_index); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // A small float number added to the variance to avoid divide-by-zero error. + float epsilon_ = 0.0f; + + // An integer value representing the index of the feature dimension. + int64 feature_index_ = -1; +}; + +class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormTrainingInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, + HloInstruction* offset, + float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloBatchNormGradInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, + HloInstruction* grad_output, float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloFftInstruction : public HloInstruction { + public: + explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, + FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + FftType fft_type() const { return fft_type_; } + + const std::vector& fft_length() const { return fft_length_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes FFT type for an FFT instruction. + FftType fft_type_ = FftType::FFT; + + // Indicates the FFT length for an FFT instruction. + std::vector fft_length_; +}; + +class HloSendRecvInstruction : public HloInstruction { + public: + // Returns the channel id associated with the instruction. The id is + // shared between each Send/Recv pair and is globally unique to identify each + // channel. + int64 channel_id() const { return channel_id_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, + int64 channel_id); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Represents a unique identifier for each Send/Recv instruction pair. + int64 channel_id_; +}; + +class HloSendInstruction : public HloSendRecvInstruction { + public: + explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, + int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloSendDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloSendDoneInstruction(HloSendInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, + int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvDoneInstruction(HloRecvInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloAllReduceInstruction : public HloInstruction { + public: + explicit HloAllReduceInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id = + tensorflow::gtl::nullopt); + + // Returns the group ids of each replica for CrossReplicaSum op. + const std::vector& replica_group_ids() const { + return replica_group_ids_; + } + + // Returns the barrier config used for the CrossReplicaSum implementation of + // each backend. + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + + tensorflow::gtl::optional all_reduce_id() const { + return all_reduce_id_; + } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The group id of each replica for CrossReplicaSum. + std::vector replica_group_ids_; + + // The string representation of the barrier config used for CrossReplicaSum. + string cross_replica_sum_barrier_; + + // For Allreduce nodes from different modules, if they have the same + // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross modules. + tensorflow::gtl::optional all_reduce_id_; +}; + +class HloReverseInstruction : public HloInstruction { + public: + explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloConcatenateInstruction : public HloInstruction { + public: + explicit HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Accessor for the dimension in which a concatenate HLO should occur. + int64 concatenate_dimension() const { return dimensions(0); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloReduceInstruction : public HloInstruction { + public: + explicit HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloTransposeInstruction : public HloInstruction { + public: + explicit HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloBroadcastInstruction : public HloInstruction { + public: + explicit HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloMapInstruction : public HloInstruction { + public: + explicit HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloSliceInstruction : public HloInstruction { + public: + explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + HloInstructionProto ToProto() const override; + + // Returns the start index in the given dimension for a slice node. + int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; } + const std::vector& slice_starts() const { return slice_starts_; } + + // Returns the (exclusive) limit index in the given dimension for a slice + // node. + int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; } + const std::vector& slice_limits() const { return slice_limits_; } + + // Returns the stride in the given dimension for a slice node. + int64 slice_strides(int64 dimension) const { + return slice_strides_[dimension]; + } + const std::vector& slice_strides() const { return slice_strides_; } + + // Returns the flag that describes whether a slice must be lowered into an + // offset into the original operand. + bool IsInPlaceSlice() const { return is_in_place_slice_; } + + // Sets and returns the flag that describes whether a slice must be lowered + // into an offset into the original operand. + bool SetIsInPlaceSlice(bool value) { + is_in_place_slice_ = value; + return value; + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes the [begin, end) index range for a slice. + std::vector slice_starts_; + std::vector slice_limits_; + std::vector slice_strides_; + + // Describes whether the slice can be lowered to an offset into the operand. + bool is_in_place_slice_ = false; +}; + +class HloConstantInstruction : public HloInstruction { + public: + explicit HloConstantInstruction(std::unique_ptr literal); + // Used when the literal is too large and dropped. + explicit HloConstantInstruction(const Shape& shape); + // Returns the literal associated with this instruction. + const Literal& literal() const { return *literal_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Change the layout for an Constant Hlo instruction to match new_layout. For + // tuple shaped constants shape_index is the path to the internal array + // subshape whose layout needs to be changed. + void RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index = {}); + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +class HloTraceInstruction : public HloInstruction { + public: + explicit HloTraceInstruction(const string& tag, HloInstruction* operand); + // Returns a tag to be used in tracing. + string TracingTag() const { return literal_->GetR1U8AsString(); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +class HloFusionInstruction : public HloInstruction { + public: + explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, + HloInstruction* fused_root); + + explicit HloFusionInstruction( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation); + + string ToCategory() const override; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Adds a new operand the fusion instruction. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + + // Merges the fused instructions from 'instruction_to_merge' into the + // fused instruction set of 'this', updating operands as necessary. + // + // Predondition: 'instruction_to_merge' must be an operand of 'this'. + void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); + + // Merges the fused instructions from instruction_to_merge into the fused + // instruction set of 'this' and generates multioutput fusion instructions. + // All the users of instruction_to_merge will be redirected to 'this' + // instruction. instruction_to_merge will be removed from its parent + // computation. + void MergeFusionInstructionIntoMultiOutput( + HloFusionInstruction* instruction_to_merge); + + // Fuses the given instruction in this fusion instruction. instruction_to_fuse + // is cloned and the clone is placed in the fusion + // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather + // than moved to cleanly handle the case where the instruction has a use + // outside the fusion instruction. Moving such an instruction into a fusion + // instruction would violate the single-result invariant of HLO instructions + // and significantly complicate code generation. + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { + return FuseInstructionInternal(instruction_to_fuse); + } + + // Fuses the given instruction in this fusion instruction and generate + // multioutput fusion instruction. A clone of the instruction_to_fuse will + // be part of the output of fusion instructions. The users of + // instruction_to_fuse will be redirected to this fusion instructions. + // instruction_to_fuse will be removed from its parent computation. + HloInstruction* FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse) { + return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); + } + + // Returns the computation for this fused instruction. + HloComputation* fused_instructions_computation() const; + + // Returns the root instruction of the fused expression contained within this + // fusion instruction. + HloInstruction* fused_expression_root() const; + + // Returns the list of fused instructions inside this fusion instruction. The + // returned type is a range of HloInstruction*s. + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Gets the number of instructions inside this fusion instruction. + int64 fused_instruction_count() const; + + // Returns the fused parameter instruction in this fusion instruction + // corresponding to the given parameter number. + HloInstruction* fused_parameter(int64 parameter_number) const; + + // Returns the vector of fused parameters inside this fusion instruction. + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs. + const bool IsMultiOutputFusion() const { + return fused_expression_root()->opcode() == HloOpcode::kTuple; + } + + FusionKind fusion_kind() const { return fusion_kind_; } + + void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } + + // If multiple operands are the same instruction, keeps only one of them. + Status DeduplicateFusionOperands(); + + private: + // Fuses the given instruction into this fusion instruction. When add_output + // is false (which is the default), instruction_to_fuse is cloned and the + // clone is placed in the fusion instruction. instruction_to_fuse is + // unchanged. + // + // When add_output is true, a clone of the instruction_to_fuse will be part + // of the output of fusion instructions. The users of instruction_to_fuse + // will be redirected to this fusion instructions. instruction_to_fuse will + // be removed from its parent computation. + HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, + bool add_output = false); + // Clones the given instruction_to_fuse and insert the clone into this fusion + // instruction. If add_output is true, a clone of instruction_to_fuse will + // be in the output of the this fusion instruction (part of the tuple of the + // fusion root). + HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, + bool add_output = false); + + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The type of the fusion. Used by kFusion only. + FusionKind fusion_kind_; +}; + +class HloRngInstruction : public HloInstruction { + public: + explicit HloRngInstruction( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters); + // Returns the random distribution for this rng node. + RandomDistribution random_distribution() const { return distribution_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The distribution requested for random number generation. + RandomDistribution distribution_; +}; + +class HloParameterInstruction : public HloInstruction { + public: + explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, + const string& name); + int64 parameter_number() const { return parameter_number_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 parameter_number_ = 0; +}; + +class HloGetTupleElementInstruction : public HloInstruction { + public: + explicit HloGetTupleElementInstruction(const Shape& shape, + HloInstruction* operand, int64 index); + // Returns the tuple index associated with this instruction. + int64 tuple_index() const { return tuple_index_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 tuple_index_ = -1; +}; + +class HloReducePrecisionInstruction : public HloInstruction { + public: + explicit HloReducePrecisionInstruction(const Shape& shape, + HloInstruction* operand, + const int exponent_bits, + const int mantissa_bits); + // Returns the number of exponent bits for a reduce-precision node. + int32 exponent_bits() const { return exponent_bits_; } + // Returns the number of mantissa bits for a reduce-precision node. + int32 mantissa_bits() const { return mantissa_bits_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The bit sizes for a reduce-precision operation. + int32 exponent_bits_ = 0; + int32 mantissa_bits_ = 0; +}; + +class HloInfeedInstruction : public HloInstruction { + public: + explicit HloInfeedInstruction(const Shape& infeed_shape, + HloInstruction* token_operand, + const string& config); + // TODO(b/80000000): Remove this constructor when all uses of infeed are + // converted to take tokens. + explicit HloInfeedInstruction(const Shape& infeed_shape, + const string& config); + // Returns the infeed configuration string. The infeed configuration includes + // any metadata needed for the backend compiler (e.g., infeed buffer address) + // and is target-dependent. + string infeed_config() const { return infeed_config_; } + void set_infeed_config(const string& config) { infeed_config_ = config; } + // Returns the shape of the data received by the infeed. This is not the same + // as the shape of the infeed instruction which produces a tuple containing + // the infeed data shape and a TOKEN. + const Shape& infeed_shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); + return ShapeUtil::GetSubshape(shape(), {0}); + } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The string representation of the infeed configuration. + string infeed_config_; +}; + +class HloOutfeedInstruction : public HloInstruction { + public: + explicit HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + HloInstruction* token_operand, + tensorflow::StringPiece outfeed_config); + // TODO(b/80000000): Remove this constructor when all uses of outfeed are + // converted to take tokens. + explicit HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + tensorflow::StringPiece outfeed_config); + + // Returns the shape for the Outfeed instruction. + const Shape& outfeed_shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); + return outfeed_shape_; + } + // Returns the config for the Outfeed instruction. + const string& outfeed_config() const { return outfeed_config_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Shape of outfeed request. + Shape outfeed_shape_; + // Outfeed configuration information, only present for kOutfeed. + string outfeed_config_; +}; + +class HloConvolutionInstruction : public HloInstruction { + public: + explicit HloConvolutionInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { + return convolution_dimension_numbers_; + } + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = dnums; + } + string ToCategory() const override; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; + // Describes the dimension numbers used for a convolution. + ConvolutionDimensionNumbers convolution_dimension_numbers_; +}; + +class HloReduceWindowInstruction : public HloInstruction { + public: + explicit HloReduceWindowInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* init_value, + const Window& window, + HloComputation* reduce_computation); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; +}; + +class HloSelectAndScatterInstruction : public HloInstruction { + public: + explicit HloSelectAndScatterInstruction( + const Shape& shape, HloInstruction* operand, HloComputation* select, + const Window& window, HloInstruction* source, HloInstruction* init_value, + HloComputation* scatter); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + // Gets/sets the select or scatter HloComputation for SelectAndScatter. The + // setters should only be called by HloModule or HloComputation methods. + HloComputation* select() const { + return called_computations()[kSelectComputationIndex]; + } + + HloComputation* scatter() const { + return called_computations()[kScatterComputationIndex]; + } + + void set_select(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + set_called_computation(kSelectComputationIndex, computation); + } + + void set_scatter(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + set_called_computation(kScatterComputationIndex, computation); + } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; +}; + +class HloCustomCallInstruction : public HloInstruction { + public: + explicit HloCustomCallInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target); + const Window& window() const override { + CHECK(window_ != nullptr); + return *window_; + } + + void set_window(const Window& window) override { + window_ = MakeUnique(window); + } + + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { + CHECK(convolution_dimension_numbers_ != nullptr); + return *convolution_dimension_numbers_; + } + + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = + MakeUnique(dnums); + } + const string& custom_call_target() const { return custom_call_target_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // Name of a global symbol to call, only present for kCustomCall. + string custom_call_target_; + // Describes the window in a windowed operation such as convolution. + std::unique_ptr window_; + // Describes the dimension numbers used for a convolution. + std::unique_ptr convolution_dimension_numbers_; +}; + +class HloHostComputeInstruction : public HloInstruction { + public: + explicit HloHostComputeInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + // Returns the channel name associated with the instruction. The name is + // used to identify host Send/Recv operations. + const string& channel_name() const { return channel_name_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // Name to use for host send/recv channels. + string channel_name_; + // Estimate of the duration of a host computation in nanoseconds. + int64 cost_estimate_ns_ = 0; +}; + +class HloPadInstruction : public HloInstruction { + public: + explicit HloPadInstruction(const Shape& shape, HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config); + // Returns the padding configuration for a pad node. + const PaddingConfig& padding_config() const { return padding_config_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The padding configuration that describes the edge padding and interior + // padding of this pad instruction. + PaddingConfig padding_config_; +}; + +class HloDynamicSliceInstruction : public HloInstruction { + public: + explicit HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + // Old methods kept for smooth subclassing transition END. + // Returns the size of the slice in the given dimension for a dynamic + // slice node. + int64 slice_sizes(int64 dimension) const { + return dynamic_slice_sizes_[dimension]; + } + const std::vector& dynamic_slice_sizes() const { + return dynamic_slice_sizes_; + } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes the [start, start + size) range size for a dynamic slice + // ('start' is specified dynamically in the second operand of the operation). + std::vector dynamic_slice_sizes_; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc similarity index 95% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.cc rename to tensorflow/compiler/xla/service/hlo_lexer.cc index 350db126535e418cbfa914edd958f47ba90a3ee5..f0d9fdbc8f86da0bb9d7f9235239df677c9506bc 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include @@ -26,9 +26,8 @@ limitations under the License. #include "tensorflow/core/platform/regexp.h" namespace xla { -namespace tools { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; namespace { @@ -67,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -StringPiece HloLexer::StringPieceFromPointers(const char* begin, - const char* end) const { +tensorflow::StringPiece HloLexer::StringPieceFromPointers( + const char* begin, const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return StringPiece(begin, end - begin); + return tensorflow::StringPiece(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -197,7 +196,8 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + tensorflow::StringPiece identifier = + StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. #define KEYWORD(STR) \ @@ -332,23 +332,24 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == StringPiece::npos) { + if (line_offset == tensorflow::StringPiece::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -StringPiece HloLexer::GetLine(LocTy loc) const { +tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == StringPiece::npos + const char* start = line_start == tensorflow::StringPiece::npos ? buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); - const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + const char* end = + line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } @@ -370,7 +371,7 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - StringPiece raw = + tensorflow::StringPiece raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { @@ -453,5 +454,4 @@ string TokKindToString(TokKind kind) { } } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h similarity index 90% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.h rename to tensorflow/compiler/xla/service/hlo_lexer.h index 27880b9b8afbfa58abfedc3b2cecd5236b78a6d6..ceb674f25e94ac3ac2e6a4a0687a93ffdcd065e0 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ #include -#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -27,9 +27,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Lexer for the HloModule::ToString() format text. +// +// This class is meant to be used by hlo_parser.cc. You shouldn't need to use +// it directly. class HloLexer { public: explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { @@ -57,7 +59,7 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - int64 GetInt64Val() const { + tensorflow::int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; } @@ -114,7 +116,7 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - int64 int64_val_; + tensorflow::int64 int64_val_; double decimal_val_; struct LineNoCacheTy { @@ -125,7 +127,6 @@ class HloLexer { mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..43c41ece6efc4f9e8ca74f16e0f63d29abc4de4e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -0,0 +1,306 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using Worklist = std::deque; +using Workset = std::unordered_set; + +namespace { + +void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, + Workset* workset) { + if (workset->count(instruction) == 0) { + worklist->push_back(instruction); + workset->insert(instruction); + VLOG(3) << "ADD instruction: " << instruction->name(); + } +} + +using VisitorFunction = std::function; + +void ForEachLiveIndex(const ShapeTree& index_tree, + const VisitorFunction& func) { + index_tree.ForEachElement([&](const ShapeIndex& shape_index, bool live) { + if (live) { + func(shape_index); + } + }); +} + +// Marks 'instruction' output live at 'shape_index'. +// Adds to 'worklist' iff: +// *) 'instruction' is not already on worklist. +// *) 'shape_index' has not yet been visited. +void MarkLiveAtIndex(const HloInstruction* instruction, + const ShapeIndex& shape_index, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + auto it_added = live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); + it = it_added.first; + } + if (it->second.element(shape_index) == false) { + AddToWorklist(instruction, worklist, workset); + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } +} + +// Marks 'instruction' live at all shape indices in its output. +void MarkLiveAtAllIndices(const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + bool add_to_worklist = false; + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/true)); + add_to_worklist = true; + } else { + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& sub_shape, const ShapeIndex& shape_index) { + if (it->second.element(shape_index) == false) { + add_to_worklist = true; + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } + }); + } + if (add_to_worklist) { + AddToWorklist(instruction, worklist, workset); + } +} + +// Propagates liveness through Tuple instructions. +// *) For each tuple operand: +// *) For tuple output shape index associated with operand: +// *) Propgate live shape indices to tuple operand at the associated +// shape index in the operands output, and add to worklist. +void PropagateLivenessThroughTuple( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kTuple); + for (int64 operand_index = 0; operand_index < instruction->operand_count(); + ++operand_index) { + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + if (shape_index.empty() || shape_index[0] != operand_index) { + return; + } + // Mark top-level index of operand at 'operand_index'. + MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map, + worklist, workset); + // Mark sub-shape index of operand at 'operand_index'. + ShapeIndex operand_shape_index; + for (int i = 1; i < shape_index.size(); ++i) { + operand_shape_index.push_back(shape_index[i]); + } + MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index, + live_index_map, worklist, workset); + }); + } +} + +// Propagates liveness through GetTupleElement instructions. +// *) For each live index in GetTupleElement output, mark output of GTE operand +// at associated shape index in its output, and add to worklist. +void PropagateLivenessThroughGTE( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement); + // Mark operand top-level index. + MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist, + workset); + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + // Propagate live shape indices along GTE -> Tuple edge. + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + ShapeIndex operand_shape_index(shape_index); + operand_shape_index.push_front(instruction->tuple_index()); + MarkLiveAtIndex(instruction->operand(0), operand_shape_index, + live_index_map, worklist, workset); + }); +} + +// Propagates liveness through While instructions. +// *) For each live index in While output, mark shape index of while.body.root +// and while.operand (adding each to worklist). +// *) Mark while.cond.root and add to worklist. +void PropagateLivenessThroughWhile( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kWhile); + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + // Propagate liveness to while body computation root instruction. + MarkLiveAtIndex(instruction->while_body()->root_instruction(), shape_index, + live_index_map, worklist, workset); + // Propagate liveness to tuple-shaped operand. + MarkLiveAtIndex(instruction->operand(0), shape_index, live_index_map, + worklist, workset); + }); + + // Propagate liveness to while condition computation root instruction. + MarkLiveAtIndex(instruction->while_condition()->root_instruction(), {}, + live_index_map, worklist, workset); +} + +// Propagates liveness out of Parameter instructions to callers and aliasing +// positions. This can occur if liveness propagates to a parameter in the +// while.condition computation, requiring liveness to propagate out to caller +// callsite while (and while.body.root). +void PropagateLivenessToParameterCallers( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset, CallGraph* call_graph) { + CHECK_EQ(instruction->opcode(), HloOpcode::kParameter); + const CallGraphNode& call_graph_node = + call_graph->GetNode(instruction->parent()); + if (call_graph_node.context() == CallContext::kSequential) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + auto* xla_while = callsite.instruction(); + const ShapeTree& index_tree = + FindOrDie(*live_index_map, instruction); + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + // Propagate liveness to while result{shape_index} + MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist, + workset); + // Propagate liveness to while body root{shape_index}. + MarkLiveAtIndex(xla_while->while_body()->root_instruction(), + shape_index, live_index_map, worklist, workset); + // Propagate liveness to operand(0){shape_index}. + MarkLiveAtIndex(xla_while->operand(0), shape_index, live_index_map, + worklist, workset); + }); + } + } + } +} + +} // namespace + +HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module) + : module_(module), call_graph_(CallGraph::Build(&module)) {} + +// Runs liveness analysis on 'module_'. +// Initializes worklist with entry root instruction (and any instruction with +// side-effects), marking all of their output shape indices live. +// Visits elements on worklist, propagating liveness from an instructions +// live output shape indices to its called computations and operands. +void HloLivenessAnalysis::RunAnalysis() { + Worklist worklist; + Workset workset; + // Add entry compuation root instruction. + MarkLiveAtAllIndices(module_.entry_computation()->root_instruction(), + &live_index_map_, &worklist, &workset); + for (auto* computation : module_.computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->HasSideEffectNoRecurse()) { + // Add instructions with side effects. + MarkLiveAtAllIndices(instruction, &live_index_map_, &worklist, + &workset); + } + } + } + + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop_front(); + workset.erase(workset.find(instruction)); + VLOG(1) << "VISIT instruction: " << instruction->name(); + + if (instruction->opcode() == HloOpcode::kTuple) { + PropagateLivenessThroughTuple(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { + PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kWhile && + ShapeUtil::IsTuple(instruction->shape())) { + PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kParameter && + ShapeUtil::IsTuple(instruction->shape())) { + PropagateLivenessToParameterCallers(instruction, &live_index_map_, + &worklist, &workset, + call_graph_.get()); + } else { + // Propagate liveness to called computations. + for (auto* called_computation : instruction->called_computations()) { + MarkLiveAtAllIndices(called_computation->root_instruction(), + &live_index_map_, &worklist, &workset); + } + // Propagate liveness to operands. + for (HloInstruction* operand : instruction->operands()) { + MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); + } + } + } +} + +bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const { + if (ContainsKey(live_index_map_, instruction)) { + return FindOrDie(live_index_map_, instruction).element(shape_index); + } + return false; +} + +/* static */ +StatusOr> HloLivenessAnalysis::Run( + const HloModule& module) { + VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); + XLA_VLOG_LINES(2, module.ToString()); + + auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + + liveness_analysis->RunAnalysis(); + + return std::move(liveness_analysis); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.h b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..fe55a8070a42a3d68836dd32cf7ce5823dd77951 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Analysis which identifies all live {HloInstruction, ShapeIndex} pairs in +// an HLO module. +// +// HloLivenessAnalysis marks the shape index of each live output of each +// instruction in the module, by propagating live shape index information +// from an instruction to its called computations and operands. +class HloLivenessAnalysis { + public: + // Maps from an HloInstruction to its live/dead output shape indices. + using HloIndexMap = + std::unordered_map>; + + // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object + // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'. + static StatusOr> Run( + const HloModule& module); + + // Returns true if output of 'instruction' at 'shape_index' is live. + // Returns false otherwise. + bool IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const; + + private: + HloLivenessAnalysis(const HloModule& module); + + void RunAnalysis(); + + const HloModule& module_; + std::unique_ptr call_graph_; + HloIndexMap live_index_map_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0275294a1a86cef13e5b267ad578f30cc18858dc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -0,0 +1,402 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class HloLivenessAnalysisTest : public HloTestBase { + protected: + HloLivenessAnalysisTest() {} + + // Run liveness analysis on the member module. For convenience returns a + // reference to the generated analysis stored in analysis_. + const HloLivenessAnalysis& RunLiveness(HloModule* module) { + liveness_ = HloLivenessAnalysis::Run(*module).ConsumeValueOrDie(); + return *liveness_; + } + + HloInstruction* GetInstruction(HloModule* module, const string& name) { + HloInstruction* to_return = nullptr; + for (auto* comp : module->computations()) { + for (auto* inst : comp->instructions()) { + if (inst->name() == name) { + to_return = inst; + break; + } + } + } + return CHECK_NOTNULL(to_return); + } + + std::unique_ptr liveness_; +}; + +// Test that add instruction at entry root is live at all output shape indices. +TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT add = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Test that a dead add instruction is marked as dead by analysis. +TEST_F(HloLivenessAnalysisTest, DeadAdd) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + add.1 = s32[] add(constant.1, constant.2) + ROOT add.2 = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {})); +} + +// Test that all output shape indices of entry root tuple (and defining +// instruction in its output) are marked live. +TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Tests that all outputs of nested tuple and entry root (and defining +// instruction values appearing in its output) are marked live. +TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(1) + constant.2 = s32[] constant(2) + constant.3 = s32[] constant(3) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + ROOT tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Tests that GTE at entry root of Tuple instruction only propgates liveness +// to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfTuple) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2) + ROOT get-tuple-element.1 = s32[] get-tuple-element(tuple.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Tests that GTE at entry root of nested Tuple instruction only propgates +// liveness to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + ROOT get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Tests that GTE of GTE (at entry root) of nested Tuple instruction only +// propgates liveness to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + ROOT get-tuple-element.2 = s32[] get-tuple-element(get-tuple-element.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.2"), {})); + + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_FALSE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_FALSE( + liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Test that live/dead while tuple elements are marked live/dead correctly. +TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.4"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a tuple element live in while.cond computation, propagates +// liveness to while.body.root/while.result/while.operand (where it is unused). +TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1 + add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4) + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(add.1, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.4"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a use of while.result{0} propagates liveness to +// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}. +TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.1), index=2 + multiply.1 = s32[] multiply(get-tuple-element.3, get-tuple-element.3) + ROOT tuple.1 = (s32[], s32[], s32[]) tuple(add.1, get-tuple-element.3, multiply.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 + constant.1 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1) + } + ENTRY SimpleLoop { + constant.2 = s32[] constant(0) + constant.3 = s32[] constant(1) + constant.4 = s32[] constant(2) + tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.3, constant.4) + while.1 = (s32[], s32[], s32[]) while(tuple.2), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0 + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {2})); + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {2})); + // While body root. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {2})); + // While body param. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index c33bdadf1c7145bf2aff09b01423c6c21382da0c..b57c940238f0672692e3b65827f43e2f5499502d 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -186,6 +187,7 @@ HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(Ge); +HLO_MATCHER(AfterAll); HLO_MATCHER(Gt); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); @@ -194,6 +196,7 @@ HLO_MATCHER(Log); HLO_MATCHER(And); HLO_MATCHER(Not); HLO_MATCHER(Or); +HLO_MATCHER(Xor); HLO_MATCHER(Lt); HLO_MATCHER(Map); HLO_MATCHER(Maximum); @@ -324,6 +327,12 @@ inline ::testing::Matcher Sharding( return ::testing::MakeMatcher( new ::xla::testing::HloShardingMatcher(sharding)); } +// Matcher for Sharding from sharding string +inline ::testing::Matcher Sharding( + tensorflow::StringPiece sharding) { + return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( + ParseSharding(sharding).ValueOrDie())); +} // Verifies that no HloSharding is set for an HLO instruction. inline ::testing::Matcher NoSharding() { return ::testing::MakeMatcher( diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 016cc01e33840aa195dfc0a21e8ac8f3d24a3e06..9a3010cf1ff75e840130d8442bbe26d6041cef25 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace op = xla::testing::opcode_matchers; using ::testing::_; @@ -147,6 +147,18 @@ TEST(HloMatchersTest, ShardingMatcher) { "param.1"); p1->set_sharding(HloSharding::AssignDevice(1)); + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}), + ShapeUtil::MakeShape(F32, {11})}); + auto p2 = HloInstruction::CreateParameter(1, tuple_shape, "param.2"); + Array assignment({2}); + assignment.SetValues({0, 1}); + auto sharding = HloSharding::Tuple( + tuple_shape, + {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), + HloSharding::AssignDevice(1), HloSharding::Replicate()}); + p2->set_sharding(sharding); + EXPECT_THAT(p0.get(), op::NoSharding()); EXPECT_THAT(p0.get(), ::testing::Not(op::Sharding(HloSharding::AssignDevice(1)))); @@ -155,6 +167,11 @@ TEST(HloMatchersTest, ShardingMatcher) { ::testing::Not(op::Sharding(HloSharding::AssignDevice(0)))); EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1))); + EXPECT_THAT( + p2.get(), + op::Sharding( + "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " "{maximal device=1})"); @@ -178,7 +195,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1), diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index fbf1d58007e318a8a08aa9e11d9d54811533703e..39bc25ba42c2cb6a9f77e2726405311ba13b3edc 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -32,15 +32,6 @@ limitations under the License. namespace xla { -HloModule::HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config) - : name_(NameUniquer::GetSanitizedName(name)), - config_(config), - has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle), - unique_id_(next_unique_module_id_++) {} - HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(NameUniquer::GetSanitizedName(name)), config_(config), @@ -67,7 +58,7 @@ HloComputation* HloModule::AddComputationInternal( // If the module configuration has no entry layout computation set, create a // default one based on the program shape. - if (!config_.has_host_entry_computation_layout()) { + if (!config_.has_entry_computation_layout()) { config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } @@ -234,21 +225,17 @@ HloModuleProto HloModule::ToProto() const { /* static */ StatusOr> HloModule::CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle) { + const HloModuleProto& proto, const HloModuleConfig& module_config) { // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. TF_RET_CHECK(proto.has_program_shape()) << "No program shape found in the proto"; const auto& expected_program_shape = proto.program_shape(); - TF_RET_CHECK( - expected_program_shape.parameters_size() == - module_config.device_entry_computation_layout().parameter_count()); + TF_RET_CHECK(expected_program_shape.parameters_size() == + module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = - module_config.device_entry_computation_layout() - .parameter_layout(i) - .shape(); + module_config.entry_computation_layout().parameter_layout(i).shape(); TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), parameter_shape)) << "HloModuleConfig has different shape for parameter " << i @@ -258,7 +245,7 @@ StatusOr> HloModule::CreateFromProto( << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape); } const Shape& result_shape = - module_config.device_entry_computation_layout().result_layout().shape(); + module_config.entry_computation_layout().result_layout().shape(); TF_RET_CHECK( ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. " @@ -287,8 +274,7 @@ StatusOr> HloModule::CreateFromProto( } TF_RET_CHECK(entry != nullptr); - auto module = MakeUnique(proto.name(), entry_computation_handle, - module_config); + auto module = MakeUnique(proto.name(), module_config); // Sort the computations in the proto id's order. std::sort(computations.begin(), computations.end(), @@ -338,7 +324,7 @@ StatusOr HloModule::CreateModuleConfigFromProto( // The module config is constructed with default layouts regardless of what is // passed in via the ProgramShape. Set the layouts to the appropriate values. ComputationLayout* entry_layout = - module_config.mutable_host_entry_computation_layout(); + module_config.mutable_entry_computation_layout(); for (int64 i = 0; i < entry_layout->parameter_count(); ++i) { TF_RETURN_IF_ERROR( entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -346,9 +332,6 @@ StatusOr HloModule::CreateModuleConfigFromProto( } TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape( program_shape.result())); - *module_config.mutable_device_entry_computation_layout() = - module_config.host_entry_computation_layout(); - return module_config; } @@ -401,7 +384,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( // as a parameter in the new function. arguments.push_back(old_operand); *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter( - parameter_count, old_operand->shape(), "")); + parameter_count, old_operand->shape(), "p")); ++parameter_count; } TF_CHECK_OK( @@ -462,7 +445,7 @@ int64 HloModule::instruction_count() const { return n; } -std::list HloModule::MakeComputationPostOrder() const { +std::vector HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the // module). @@ -480,7 +463,7 @@ std::list HloModule::MakeComputationPostOrder() const { // order. This prevents duplication as an embedded computation may be called // from two different root computations. std::set added_computations; - std::list post_order; + std::vector post_order; for (auto& computation : computations_) { if (nonroot_computations.count(computation.get()) == 0) { for (HloComputation* embedded_computation : @@ -496,7 +479,18 @@ std::list HloModule::MakeComputationPostOrder() const { added_computations.insert(computation.get()); } } - CHECK_EQ(post_order.size(), computations_.size()); + if (post_order.size() != computations_.size()) { + for (HloComputation* computation : post_order) { + LOG(ERROR) << "Post Order: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + for (auto& computation : computations_) { + LOG(ERROR) << "Computations: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size() + << " computation_count=" << computations_.size(); + } return post_order; } @@ -514,57 +508,26 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; auto module = MakeUnique(name_ + "-" + suffix, config_); - module->entry_computation_handle_ = entry_computation_handle_; - module->has_entry_computation_handle_ = has_entry_computation_handle_; - std::unordered_map clone_map; - for (auto& computation : computations_) { - if (computation->IsFusionComputation()) { - // Cloning of a fused computation is handled by its fusion instruction. - continue; - } - - // When cloning a computation, pass in the new module, so that for any - // fusion instruction in this computation, the fused computation will be - // deep cloned to the new module. - auto cloned_computation = computation->Clone(suffix, module.get()); - InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); - - if (entry_computation_ == computation.get()) { - module->AddEntryComputation(std::move(cloned_computation)); - } else { - module->AddEmbeddedComputation(std::move(cloned_computation)); - } - } - - for (auto& cloned_computation : module->computations_) { - for (auto* instruction : cloned_computation->instructions()) { - // Rewrite instruction's called_computation to point to the cloned - // computations. - instruction->ReplaceCalledComputations([&](HloComputation* hlo) { - if (hlo->IsFusionComputation()) { - // Cloning of a fused computation has already been handled when its - // fusion instruction is cloned. So this hlo computation is already - // the cloned one. - return hlo; - } - return FindOrDie(clone_map, hlo); - }); - } - } + HloCloneContext context(module.get(), suffix); + auto cloned_computation = entry_computation_->Clone(suffix, &context); + module->AddEntryComputation(std::move(cloned_computation)); return module; } -HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) { - HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this)); - TF_CHECK_OK( - clone->root_instruction()->Accept([this](HloInstruction* instruction) { - instruction->ReplaceCalledComputations([this](HloComputation* callee) { - return DeepCloneComputation(callee); - }); - return Status::OK(); - })); - return clone; +HloComputation* HloModule::DeepCloneComputation(HloComputation* computation, + HloCloneContext* context) { + HloComputation* new_computation; + if (context != nullptr) { + if ((new_computation = context->FindComputation(computation)) != nullptr) { + return new_computation; + } + new_computation = + AddEmbeddedComputation(computation->Clone(context->suffix(), context)); + } else { + new_computation = AddEmbeddedComputation(computation->Clone("")); + } + return new_computation; } uint64 HloModule::RandomNew64() const { diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 02918c377776b73f2086fe41afc406567a12af4c..d2e726a0db63f622cd5092d56b4f746232d04aad 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -26,11 +26,11 @@ limitations under the License. #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -56,10 +56,6 @@ namespace xla { // attached to. class HloModule { public: - HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config); - // Constructor without a versioned computation handle. This constructor should // only be used for HloModules used outside of the XLA service (eg // tests). The versioned handle is used by the service in the compilation @@ -94,8 +90,10 @@ class HloModule { std::unique_ptr Clone(const string& suffix = "clone") const; // Performs a deep clone of the computation, by recursively cloning all - // the called computations as well. - HloComputation* DeepCloneComputation(HloComputation* computation); + // the called computations as well. If the clone context is specified, it + // will be populated with the cloned object mappings. + HloComputation* DeepCloneComputation(HloComputation* computation, + HloCloneContext* context = nullptr); // Return a pointer to the entry computation of the module.. const HloComputation* entry_computation() const { @@ -107,24 +105,19 @@ class HloModule { return entry_computation_; } - ComputationLayout* mutable_host_entry_computation_layout() { - return config_.mutable_host_entry_computation_layout(); - } - - const ComputationLayout& host_entry_computation_layout() const { - return config_.host_entry_computation_layout(); + // Creates the ComputationLayout which describes the current status of the HLO + // module entry computation. + ComputationLayout compute_computation_layout() const { + return ComputationLayout(entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); } - ComputationLayout* mutable_device_entry_computation_layout() { - return config_.mutable_device_entry_computation_layout(); + ComputationLayout* mutable_entry_computation_layout() { + return config_.mutable_entry_computation_layout(); } - const ComputationLayout& device_entry_computation_layout() const { - return config_.device_entry_computation_layout(); - } - - const VersionedComputationHandle& entry_computation_handle() const { - return entry_computation_handle_; + const ComputationLayout& entry_computation_layout() const { + return config_.entry_computation_layout(); } // Gets the computations in this module. @@ -160,7 +153,7 @@ class HloModule { // Compute and return a post order of all computations in the module. The sort // is defined like so: if computation A has an instruction which calls // computation B, then A will appear after B in the sort. - std::list MakeComputationPostOrder() const; + std::vector MakeComputationPostOrder() const; // Gets the computations in this module which aren't for fusion nodes. // @@ -185,9 +178,7 @@ class HloModule { // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; static StatusOr> CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle = - VersionedComputationHandle()); + const HloModuleProto& proto, const HloModuleConfig& module_config); // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. @@ -261,10 +252,6 @@ class HloModule { mutable std::mt19937_64 rng_{42}; mutable tensorflow::mutex rng_mutex_; - // Versioned handle of the entry computation of the module. - bool has_entry_computation_handle_ = false; - VersionedComputationHandle entry_computation_handle_; - // Unique name generator for computation and instruction names, which are // unique per module. NameUniquer computation_name_uniquer_{/*separator=*/"."}; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index dae5578a3158fecb8219e518841dec1020b2ca98..07a8c798dbee072db3b75d5e99ca0dcabb5fdf6b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -28,16 +28,14 @@ namespace xla { using tensorflow::strings::StrAppend; -HloModuleConfig::HloModuleConfig() {} - -HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape) - : host_entry_computation_layout_(program_shape), - device_entry_computation_layout_(program_shape) {} +HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape, + bool ignore_layouts) + : entry_computation_layout_( + ComputationLayout(program_shape, ignore_layouts)) {} void HloModuleConfig::SetDefaultComputationLayout( const ProgramShape& program_shape) { - host_entry_computation_layout_ = ComputationLayout(program_shape); - device_entry_computation_layout_ = ComputationLayout(program_shape); + entry_computation_layout_ = ComputationLayout(program_shape); } string HloModuleConfig::compilation_cache_key() const { @@ -46,18 +44,11 @@ string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : - host_entry_computation_layout_->parameter_layouts()) { + entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", - host_entry_computation_layout_->result_shape().SerializeAsString()); - for (const ShapeLayout& param_layout : - device_entry_computation_layout_->parameter_layouts()) { - params.push_back(param_layout.shape().DebugString()); - } - StrAppend( - &key, tensorflow::str_util::Join(params, ", "), ") => ", - device_entry_computation_layout_->result_shape().SerializeAsString()); + entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. static std::atomic counter{0}; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index cdb0b29a2399b387bc617262032e9083ba079625..074e9c90705d432b8344aebaf3c15aeb41a59fa3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -37,48 +37,34 @@ class HloModuleConfig { // ComputationLayout. The default ctor creates it without -- in this case // accessing entry_computation_layout will CHECK-fail. The ctor accepting a // ProgramShape creates a computation layout using this shape. - HloModuleConfig(); - explicit HloModuleConfig(const ProgramShape& program_shape); + // The layouts in the ProgramShape will be reset to default unless + // ignore_layouts is set to false. + HloModuleConfig() = default; - // Checks if this config has an entry computation layout already. - bool has_host_entry_computation_layout() const { - return host_entry_computation_layout_.has_value(); - } + explicit HloModuleConfig(const ProgramShape& program_shape, + bool ignore_layouts = true); - bool has_device_entry_computation_layout() const { - return device_entry_computation_layout_.has_value(); + // Checks if this config has an entry computation layout already. + bool has_entry_computation_layout() const { + return entry_computation_layout_.has_value(); } // Sets the entry computation layout for this config. If the entry computation // layout already exists, it is silently replaced. void SetDefaultComputationLayout(const ProgramShape& program_shape); - // Returns a constant reference to the on-host layout of the entry - // computation. Assumes the layout was set. - const ComputationLayout& host_entry_computation_layout() const { - CHECK(host_entry_computation_layout_.has_value()); - return *host_entry_computation_layout_; - } - - // Returns a mutable pointer to the layout of the on-host entry computation. + // Returns a constant reference to the layout of the entry computation. // Assumes the layout was set. - ComputationLayout* mutable_host_entry_computation_layout() { - CHECK(host_entry_computation_layout_.has_value()); - return &(*host_entry_computation_layout_); - } - - // Returns a constant reference to the on-device layout of the entry - // computation. Assumes the layout was set. - const ComputationLayout& device_entry_computation_layout() const { - CHECK(device_entry_computation_layout_.has_value()); - return *device_entry_computation_layout_; + const ComputationLayout& entry_computation_layout() const { + CHECK(entry_computation_layout_.has_value()); + return *entry_computation_layout_; } - // Returns a mutable pointer to the layout of the on-device entry computation. + // Returns a mutable pointer to the layout of the entry computation. // Assumes the layout was set. - ComputationLayout* mutable_device_entry_computation_layout() { - CHECK(device_entry_computation_layout_.has_value()); - return &(*device_entry_computation_layout_); + ComputationLayout* mutable_entry_computation_layout() { + CHECK(entry_computation_layout_.has_value()); + return &(*entry_computation_layout_); } // Returns whether to enable HLO-level profiling. @@ -127,8 +113,7 @@ class HloModuleConfig { private: // If you add new members, be sure to update compilation_cache_key. - tensorflow::gtl::optional host_entry_computation_layout_; - tensorflow::gtl::optional device_entry_computation_layout_; + tensorflow::gtl::optional entry_computation_layout_; // Whether this is a 'host module'. bool is_host_module_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc new file mode 100644 index 0000000000000000000000000000000000000000..98d20315e399c6b1a3979b5d11a89ef93869f4d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +bool HasSendRecv(HloComputation* computation) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kSendDone || + instruction->opcode() == HloOpcode::kRecv || + instruction->opcode() == HloOpcode::kRecvDone) { + return true; + } + for (auto* sub_computation : instruction->called_computations()) { + if (HasSendRecv(sub_computation)) { + return true; + } + } + } + return false; +} + +StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + + const auto* xla_while = instruction; + auto* while_body_comp = xla_while->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + + if (!ShapeUtil::IsTuple(xla_while->shape()) || + while_body_root->opcode() != HloOpcode::kTuple || + HasSendRecv(while_body_comp)) { + // Only run DCE on tuple-shaped while loops where body root is Tuple, + // with no send/recv instructions. + VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); + continue; + } + + // Remove dead tuple elements. + const int64 tuple_element_count = + ShapeUtil::TupleElementCount(xla_while->shape()); + for (int64 i = 0; i < tuple_element_count; ++i) { + if (liveness->IsLive(xla_while, {i})) { + continue; + } + VLOG(1) << "WhileDCE Dead while tuple element." + << " while: " << xla_while->name() << " tuple_index: " << i; + // Transform while.body computation to make tuple element at + // 'shape_index' as simple pass-through parameter (which candidate + // be removed later by simplification pass). + HloInstruction* pass_thru_gte = while_body_comp->AddInstruction( + HloInstruction::CreateGetTupleElement( + while_body_param->shape().tuple_shapes(i), while_body_param, + i)); + // Replace while.body.root Tuple operand at 'tuple_index' with + // 'pass_thru_gte', making prior operand a dead root (to be cleaned + // up with a subsequent DCE pass). + TF_RETURN_IF_ERROR( + while_body_root->ReplaceOperandWith(i, pass_thru_gte)); + changed = true; + } + } + } + return changed; +} + +} // namespace + +StatusOr HloModuleDCE::Run(HloModule* module) { + VLOG(2) << "Before HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + std::unique_ptr liveness; + TF_ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module)); + + // Sweep through while instructions, transforming dead while tuple element + // computations to pass through tuple values (creating dead roots in while + // body computation in the process). + TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed, + RunWhileDCE(module, liveness.get())); + + // Run HloDCE to clean up any dead code created during HloModuleDCE. + HloDCE hlo_dce; + TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module)); + + VLOG(2) << "After HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + return hlo_module_dce_changed | hlo_dce_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h new file mode 100644 index 0000000000000000000000000000000000000000..29024085c1038961ef2b3721de1ce0e8a55ccf45 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which removes dead code from computations in the module using +// HloModule-scoped analysis (HloLivenessAnalysis). +// +// Sweeps through live instructions which cross computation boundaries (kWhile), +// and removes code at dead shape indices. +// +class HloModuleDCE : public HloPassInterface { + public: + ~HloModuleDCE() override {} + tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + + // Run the pass on the given module. Returns whether the module was changed + // (instructions were removed). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..363862e4905fc13a4ef07aeaac255259fc6b86ba --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -0,0 +1,371 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class HloModuleDceTest : public HloTestBase { + protected: + HloModuleDceTest() {} + + // Returns whether the given instruction exists in the given computation. + bool HasInstruction(const HloComputation& computation, + const HloInstruction* instruction) { + return std::find(computation.instructions().begin(), + computation.instructions().end(), + instruction) != computation.instructions().end(); + } + + // Returns whether the while instruction with name 'while_name' in + // 'computation' passes through its tuple element at 'tuple_index' from + // parameter to root instruction. + bool WhileBodyHasPassThroughTupleElement(const HloComputation* computation, + const string& while_name, + const int64 tuple_index) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile && + instruction->name() == while_name) { + auto* while_body_comp = instruction->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + auto* operand = while_body_root->operand(tuple_index); + if (operand->opcode() == HloOpcode::kGetTupleElement && + operand->tuple_index() == tuple_index && + operand->operand(0) == while_body_param) { + return true; + } + return false; + } + } + return false; + } +}; + +// Tests that a while with all outputs live is unmodified. +TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests a while loop with one unused output (which is used in the while loop +// body by an instruction with side-effects: rng) is unmodified. +TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], f32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1 + constant.2 = f32[] constant(1.0) + rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform + add.1 = s32[] add(get-tuple-element.2, constant.2) + ROOT tuple = (s32[], f32[]) tuple(add, add.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], f32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.3 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3) + } + ENTRY SimpleLoop { + constant.4 = s32[] constant(0) + constant.5 = f32[] constant(0.0) + tuple.1 = (s32[], f32[]) tuple(constant.4, constant.5) + while = (s32[], f32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a while loop with one dead tuple element at {1} has its while +// loop body modified to make that tuple element pass-through the while body. +TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} should now be pass-through after ModuleDCE. + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a tuple element {1} used by condition computation (which appears +// dead in while.body{1} and at while.result{1}) propgates liveness of this +// tuple element to while.body{1} and at while.result{1}. +TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + multiply = s32[] multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[]) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, constant.4) + while = (s32[], s32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} still be pass-through after ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at index {1} between +// two dependent while loops. +TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6) + while.1 = (s32[], s32[3]{0}) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=0 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1 and while.2 should not have pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1 and while.2 should have pass-thru elements, + // after being modified to pass through unused tuple element {1}. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and +// while.2{1}, between two dependent while loops. +TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=0 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[3]{0}, s32[]) tuple(multiply, add) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[3]{0}, s32[]) tuple(constant.6, constant.5) + while.1 = (s32[3]{0}, s32[]) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=1 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1{0} and while.2{1} should not be pass-thru. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1{0} and while.2{1} not be pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index b4cd3c730e323b8459312edbebc564e08f9d6840..6bcd7b042dfddfea6ac86365b82f8077be2a6101 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -87,6 +87,7 @@ Status HloModuleGroupMetadata::Build() { << "Peer instruction does not match the computation kind"; TF_RETURN_IF_ERROR( AddCompanion(tracked->instruction(), peer_tracked->instruction())); + tracked_instructions_comms_[tracked->instruction()].push_back(hlo); } // Add the parents of companion instructions (they must be all of the same @@ -112,27 +113,43 @@ Status HloModuleGroupMetadata::Build() { } } TF_RETURN_IF_ERROR(VerifyCompanionSets()); + if (VLOG_IS_ON(4)) { + DumpCollectedStats(); + } return Status::OK(); } Status HloModuleGroupMetadata::VerifyCompanionSets() const { - // TODO(dlibenzi): Migrate this to use the device instead of module ID, once - // the kDomain CL goes in. for (const auto& companions : companion_sets_) { // A companion set must be composed at most of an instruction per // device/module. std::unordered_set devices; for (HloInstruction* instruction : *companions) { - int64 device = GetModuleId(instruction->parent()->parent()); - if (!devices.insert(device).second) { - std::stringstream ss; - ss << "Companion set:" << std::endl; - for (HloInstruction* hlo : *companions) { - ss << " " << hlo->name() << " (" - << GetModuleId(hlo->parent()->parent()) << ")" << std::endl; + // Go through all the communicating instructions (send, recv) of the given + // companion, and record their device. + auto it = tracked_instructions_comms_.find(instruction); + if (it == tracked_instructions_comms_.end()) { + // Companions can be added even if they have no communicating + // instructions, if they are parent of companions. + continue; + } + std::unordered_set comm_devices; + for (HloInstruction* comm_instruction : it->second) { + auto device = GetInstructionDevice(*comm_instruction); + TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString() + << " does not have a device"; + comm_devices.insert(*device); + } + for (int64 device : comm_devices) { + if (!devices.insert(device).second) { + std::stringstream ss; + ss << "Companion set:" << std::endl; + for (HloInstruction* hlo : *companions) { + ss << " " << hlo->name() << std::endl; + } + ss << "has multiple instructions on the same device"; + return FailedPrecondition("%s", ss.str().c_str()); } - ss << "has multiple instructions on the same device"; - return FailedPrecondition("%s", ss.str().c_str()); } } } @@ -223,6 +240,28 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { LOG(FATAL) << "unknown module"; } +tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( + const HloInstruction& instruction) const { + // The module group metadata can be created in both "single module, multiple + // devices" and "multiple modules, no explicit devices" fashions. + // The API returns an optional even though the current implementation always + // returns a device, to account for cases where we cannot guess a device. + // In such cases the VerifyChannelInstructions() will return proper errors. + tensorflow::gtl::optional device = + instruction.sharding_unique_device(); + if (!device) { + device = GetModuleId(instruction.parent()->parent()); + } + return device; +} + +int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { + return std::count_if(modules_.begin(), modules_.end(), + [](const HloModule* module) { + return !module->config().is_host_module(); + }); +} + Status HloModuleGroupMetadata::RecordInstructions() { const auto visitor = [this](HloInstruction* hlo) -> Status { if (hlo->opcode() == HloOpcode::kWhile) { @@ -284,6 +323,7 @@ Status HloModuleGroupMetadata::RecordInstructions() { TF_RETURN_IF_ERROR(computation->Accept(visitor)); } } + VLOG(2) << "Created " << channels_.size() << " channels"; return Status::OK(); } @@ -342,30 +382,43 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { // Check if the shapes match for each channel. for (const Channel& channel : channels_) { const Shape& send_shape = channel.send->operand(0)->shape(); - const Shape& recv_shape = channel.recv_done->shape(); + const Shape& recv_shape = + ShapeUtil::GetTupleElementShape(channel.recv_done->shape(), 0); if (!ShapeUtil::Compatible(send_shape, recv_shape)) { return FailedPrecondition("send/recv shapes do not match"); } - const HloModule* send_module = channel.send->parent()->parent(); - const HloModule* send_done_module = channel.send_done->parent()->parent(); - if (send_module != send_done_module) { + auto send_device = GetInstructionDevice(*channel.send); + auto send_done_device = GetInstructionDevice(*channel.send_done); + if (!send_device) { + return FailedPrecondition("send instruction must have a device: %s", + channel.send->ToString().c_str()); + } + if (!send_done_device) { + return FailedPrecondition("send_done instruction must have a device: %s", + channel.send_done->ToString().c_str()); + } + if (*send_device != *send_done_device) { return FailedPrecondition( "send and send-done (channel=%lld) must be on the same device: %lld " "vs. %lld", - channel.id, GetModuleId(send_module), GetModuleId(send_done_module)); + channel.id, *send_device, *send_done_device); } - const HloModule* recv_module = channel.recv->parent()->parent(); - const HloModule* recv_done_module = channel.recv_done->parent()->parent(); - if (recv_module != recv_done_module) { + auto recv_device = GetInstructionDevice(*channel.recv); + auto recv_done_device = GetInstructionDevice(*channel.recv_done); + if (!recv_done_device) { + return FailedPrecondition("recv_done instruction must have a device: %s", + channel.recv_done->ToString().c_str()); + } + if (*recv_device != *recv_done_device) { return FailedPrecondition( "recv and recv-done (channel=%lld) must be on the same device: %lld " "vs. %lld", - channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module)); + channel.id, *recv_device, *recv_done_device); } - if (send_module == recv_module) { + if (*send_device == *recv_device) { return FailedPrecondition( "send and recv (channel=%lld) must be on different devices: %lld", - channel.id, GetModuleId(send_module)); + channel.id, *send_device); } } @@ -402,4 +455,36 @@ Status HloModuleGroupMetadata::CheckCommunicatingInstruction( return FailedPrecondition("channel is used in disallowed computation"); } +void HloModuleGroupMetadata::DumpCollectedStats() const { + std::map, int64> communication_histogram; + for (auto& channel : channels_) { + auto from_device = GetInstructionDevice(*channel.send); + auto to_device = GetInstructionDevice(*channel.recv); + LOG(INFO) << "Channel " << channel.id << ": from_device=" << *from_device + << " to_device=" << *to_device << " send=" << channel.send->name() + << " send_done=" << channel.send_done->name() + << " recv=" << channel.recv->name() + << " recv_done=" << channel.recv_done->name(); + communication_histogram[std::pair(*from_device, + *to_device)] += 1; + } + for (auto& fromto_count : communication_histogram) { + LOG(INFO) << "From " << fromto_count.first.first << " to " + << fromto_count.first.second << ": " << fromto_count.second; + } + for (auto& companion_set : companion_sets_) { + LOG(INFO) << "Companion set:"; + for (HloInstruction* instruction : *companion_set) { + LOG(INFO) << " " << instruction->name(); + } + } + for (auto& instruction_comm : tracked_instructions_comms_) { + LOG(INFO) << "Communicating instruction " << instruction_comm.first->name(); + for (HloInstruction* instruction : instruction_comm.second) { + auto device = GetInstructionDevice(*instruction); + LOG(INFO) << " " << instruction->name() << " on device " << *device; + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 3ef4542f9129632de4975688ae7e9e2c5f43a7ee..ffde3a332dfc141ca928a44cfdf4686900e9f47b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -148,6 +149,15 @@ class HloModuleGroupMetadata { // the module in the module vector. int64 GetModuleId(const HloModule* module) const; + // Retrieves the device an instruction is assigned to. Either from the + // sharding information, or from the ordinal of the module the instruction + // is in. + tensorflow::gtl::optional GetInstructionDevice( + const HloInstruction& instruction) const; + + // Returns the number of modules for devices (excluding the host module). + int64 GetDeviceModulesCount() const; + // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. @@ -220,6 +230,9 @@ class HloModuleGroupMetadata { return it != tracked_instructions_.end() ? &it->second : nullptr; } + // Dump all the collected module group statistics to the logs. + void DumpCollectedStats() const; + // List of all companion instructions sets in the module. std::vector>> companion_sets_; @@ -231,6 +244,11 @@ class HloModuleGroupMetadata { tensorflow::gtl::FlatMap tracked_instructions_; + // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of + // communicating instructions within the proper called computation(s). + tensorflow::gtl::FlatMap> + tracked_instructions_comms_; + // All channels in the module. std::vector channels_; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 5a0d1e264eb5095ff53721416ebcf4842a063f97..21a9b7291acc9e0066a9061facd13ab5acbf0bac 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -277,7 +277,7 @@ Status HloModuleGroupUtil::VerifyComputations( StatusOr> HloModuleGroupUtil::ComputeReachability( tensorflow::gtl::ArraySlice computations) { - std::list post_order; + std::vector post_order; auto visit_function = [&](HloInstruction* instruction, const std::vector& instruction_group) { diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index ac7cd2f2f517cf8831416d9265fc48bbf9fce340..39e12c48157992410a5d3b733720d677a1191611 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -69,6 +69,7 @@ namespace xla { V(kCrossReplicaSum, "cross-replica-sum") \ V(kCustomCall, "custom-call") \ V(kDivide, "divide") \ + V(kDomain, "domain") \ V(kDot, "dot") \ V(kDynamicSlice, "dynamic-slice") \ V(kDynamicUpdateSlice, "dynamic-update-slice") \ @@ -80,6 +81,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kHostCompute, "host-compute") \ @@ -92,6 +94,7 @@ namespace xla { V(kAnd, "and") \ V(kNot, "not") \ V(kOr, "or") \ + V(kXor, "xor") \ V(kLt, "less-than", kHloOpcodeIsComparison) \ V(kMap, "map", kHloOpcodeIsVariadic) \ V(kMaximum, "maximum") \ @@ -130,6 +133,7 @@ namespace xla { V(kTrace, "trace") \ V(kTranspose, "transpose") \ V(kTuple, "tuple", kHloOpcodeIsVariadic) \ + V(kTupleSelect, "tuple-select") \ V(kWhile, "while") enum class HloOpcode { diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index cd2ce5c69f030c65b889d67e082a3677b8739ddb..6f3f83f63a05fafaa3f3ddcff8a7cac7cb7b06d5 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kConcatenate: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kAfterAll: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index dcd4725fe78e8b9b5d14437e964cb5aaf1664117..6c1e015f77a62c3e3ff7ffa5ce9dea735f46e10a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -232,6 +232,11 @@ bool HloOrdering::UseIsBeforeValueDefinition( << " and def is in FALSE computation"; return true; } + if (value.defining_instruction() == use.instruction) { + VLOG(4) << " use is conditional " << use << " and def is " + << value.ToShortString(); + return true; + } } VLOG(4) << " use is not before value"; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index ee526d8dd7f7e81b3a846741d3e452935f486bd2..985f3fa64d8767b0c0063ee900f7d11c3b7f6d4a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -183,6 +183,10 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: + // TODO(dimvar): HloModuleSequence is not a good name because it sounds like + // a sequence of modules, instead of a map of schedules for all computations + // in a module. We should change it at some point. + // // A sequence of instructions for each computation in the module. using HloModuleSequence = tensorflow::gtl::FlatMap module, - tools::Parse(module_str)); + ParseHloString(module_str)); DependencyHloOrdering ordering(module.get()); ordering.ToString(); // Shouldn't crash. } @@ -347,7 +347,7 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN(auto dataflow, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc similarity index 83% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.cc rename to tensorflow/compiler/xla/service/hlo_parser.cc index d0e7af8844203da93dac5b45cb7e13916448dd47..f192debc9c75e49d0be09c1a069a20343685a134 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -24,18 +26,17 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { -namespace tools { namespace { -using tensorflow::StringPiece; -using tensorflow::gtl::optional; -using tensorflow::str_util::Join; -using tensorflow::str_util::Split; -using tensorflow::str_util::SplitAndParseAsInts; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using ::tensorflow::StringPiece; +using ::tensorflow::gtl::optional; +using ::tensorflow::str_util::Join; +using ::tensorflow::str_util::Split; +using ::tensorflow::str_util::SplitAndParseAsInts; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; const double kF16max = 65504; @@ -56,6 +57,11 @@ class HloParser { // Returns the error information. string GetError() const { return Join(error_, "\n"); } + // Stand alone parsing utils for various aggregate data types. + StatusOr ParseShardingOnly(); + StatusOr ParseWindowOnly(); + StatusOr ParseConvolutionDimensionNumbersOnly(); + private: // ParseXXX returns false if an error occurred. bool ParseHloModule(); @@ -78,11 +84,15 @@ class HloParser { // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. - bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(double value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal); + bool SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal); + bool SetValueInLiteral(double value, tensorflow::int64 linear_index, + Literal* literal); + bool SetValueInLiteral(bool value, tensorflow::int64 linear_index, + Literal* literal); template - bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, + bool SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal); bool ParseOperands(std::vector* operands); @@ -94,9 +104,15 @@ class HloParser { // Describes the start, limit, and stride on every dimension of the operand // being sliced. struct SliceRanges { - std::vector starts; - std::vector limits; - std::vector strides; + std::vector starts; + std::vector limits; + std::vector strides; + }; + + // The data parsed for the kDomain instruction. + struct DomainData { + std::unique_ptr entry_metadata; + std::unique_ptr exit_metadata; }; // Types of attributes. @@ -117,6 +133,7 @@ class HloParser { kMetadata, kFusionKind, kDistribution, + kDomain, }; struct AttrConfig { @@ -164,21 +181,27 @@ class HloParser { bool ParseComputationName(HloComputation** value); // Parses a list of names and finds the corresponding hlo instructions. bool ParseInstructionNames(std::vector* instructions); - bool ParseWindow(Window* window); + // Pass expect_outer_curlies == true when parsing a Window in the context of a + // larger computation. Pass false when parsing a stand-alone Window string. + bool ParseWindow(Window* window, bool expect_outer_curlies); bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); bool ParsePaddingConfig(PaddingConfig* padding); bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + // Parses the metadata behind a kDOmain instruction. + bool ParseDomain(DomainData* domain); + // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector* result); + bool ParseDxD(const string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. - bool ParseWindowPad(std::vector>* pad); + bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); bool ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, std::vector* result); + const TokKind delim, + std::vector* result); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -190,7 +213,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); - bool ParseInt64(int64* result); + bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); @@ -304,22 +327,15 @@ bool HloParser::ParseComputations() { // set the layouts to what the hlo text says. for (int p = 0; p < computation->num_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - TF_CHECK_OK(module_->mutable_host_entry_computation_layout() - ->mutable_parameter_layout(p) - ->CopyLayoutFromShape(param_shape)); - TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + TF_CHECK_OK(module_->mutable_entry_computation_layout() ->mutable_parameter_layout(p) ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - TF_CHECK_OK(module_->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->CopyLayoutFromShape(result_shape)); - TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + TF_CHECK_OK(module_->mutable_entry_computation_layout() ->mutable_result_layout() ->CopyLayoutFromShape(result_shape)); } - return true; } @@ -384,6 +400,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } + instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } @@ -447,7 +464,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { - int64 parameter_number; + tensorflow::int64 parameter_number; if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(¶meter_number) || @@ -492,7 +509,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -521,6 +537,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kRemainder: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: { @@ -534,7 +551,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } // Ternary ops. case HloOpcode::kClamp: - case HloOpcode::kSelect: { + case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: { if (!ParseOperands(&operands, /*expected_size=*/3) || !ParseAttributes(attrs)) { return false; @@ -563,11 +581,31 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional to_apply; + optional> replica_group_ids; + optional barrier; + optional all_reduce_id; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + attrs["replica_group_ids"] = { + /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; + attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; + attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, + &all_reduce_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands)); + if (replica_group_ids) { + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, *replica_group_ids, + barrier ? *barrier : "", all_reduce_id)); + } else { + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, {}, barrier ? *barrier : "", + all_reduce_id)); + } break; } case HloOpcode::kReshape: { @@ -579,6 +617,35 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateReshape(shape, operands[0])); break; } + case HloOpcode::kAfterAll: { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateAfterAll(operands)); + break; + } + case HloOpcode::kSort: { + auto loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + switch (operands.size()) { + case 1: + instruction = builder->AddInstruction( + HloInstruction::CreateSort(shape, /*keys=*/operands[0])); + break; + case 2: + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, + /*keys=*/operands[0], /*values=*/operands[1])); + break; + default: + return Error(loc, StrCat("expects either 1 or 2 operands, but has ", + operands.size(), " operands")); + } + break; + } case HloOpcode::kTuple: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; @@ -602,18 +669,18 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecv: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; - if (!ParseOperands(&operands, /*expected_size=*/0) || + if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id)); + instruction = builder->AddInstruction(HloInstruction::CreateRecv( + shape.tuple_shapes(0), operands[0], *channel_id)); break; } case HloOpcode::kRecvDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -627,18 +694,18 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSend: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; - if (!ParseOperands(&operands, /*expected_size=*/1) || + if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateSend(operands[0], *channel_id)); + HloInstruction::CreateSend(operands[0], operands[1], *channel_id)); break; } case HloOpcode::kSendDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -652,7 +719,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGetTupleElement: { - optional index; + optional index; attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -710,7 +777,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kFft: { optional fft_type; - optional> fft_length; + optional> fft_length; attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, &fft_length}; @@ -723,7 +790,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kBroadcast: { - optional> broadcast_dimensions; + optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &broadcast_dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -735,7 +802,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConcatenate: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || @@ -750,6 +817,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; + optional> dimensions; + attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, + &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -761,7 +831,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; - optional> dimensions_to_reduce; + optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -774,7 +844,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReverse: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -818,7 +888,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDynamicSlice: { - optional> dynamic_slice_sizes; + optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -842,7 +912,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTranspose: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -856,7 +926,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormTraining: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/3) || @@ -872,7 +942,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormInference: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -889,7 +959,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormGrad: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -929,23 +999,53 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kInfeed: { optional config; attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands, /*expected_size=*/0) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateInfeed(shape, config ? *config : "")); + // We need to know the infeed data shape to construct the infeed + // instruction. This is the zero-th element of the tuple-shaped output of + // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail + // if the shape is not a non-empty tuple, so add guard so an error message + // can be emitted instead of a check fail + if (!ShapeUtil::IsTuple(shape) && !ShapeUtil::IsEmptyTuple(shape)) { + return Error(lexer_.GetLoc(), + "infeed must have a non-empty tuple shape"); + } + + if (operands.empty()) { + // TODO(b/80000000): Remove this when all uses of infeed are + // converted to take tokens. + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), config ? *config : "")); + } else if (operands.size() == 1) { + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), operands[0], + config ? *config : "")); + } else { + return Error(lexer_.GetLoc(), + "infeed must have exactly zero or one operands"); + } break; } case HloOpcode::kOutfeed: { optional config; attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( - operands[0]->shape(), operands[0], config ? *config : "")); + if (operands.size() == 1) { + // TODO(b/80000000): Remove this when all uses of outfeed are + // converted to take tokens. + instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( + operands[0]->shape(), operands[0], config ? *config : "")); + } else if (operands.size() == 2) { + instruction = builder->AddInstruction( + HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], + operands[1], config ? *config : "")); + } else { + return Error(lexer_.GetLoc(), + "outfeed must have exactly one or two operands"); + } break; } case HloOpcode::kRng: { @@ -960,8 +1060,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReducePrecision: { - optional exponent_bits; - optional mantissa_bits; + optional exponent_bits; + optional mantissa_bits; attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, &exponent_bits}; attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, @@ -1006,7 +1106,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kHostCompute: { optional channel_name; - optional cost_estimate_ns; + optional cost_estimate_ns; attrs["channel_name"] = {/*required=*/true, AttrTy::kString, &channel_name}; attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, @@ -1019,16 +1119,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDot: { - optional> lhs_contracting_dims; + optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; - optional> rhs_contracting_dims; + optional> rhs_contracting_dims; attrs["rhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; - optional> lhs_batch_dims; + optional> lhs_batch_dims; attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &lhs_batch_dims}; - optional> rhs_batch_dims; + optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; @@ -1060,20 +1160,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional> output_window_dims; + optional> output_window_dims; attrs["output_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; - optional> elided_window_dims; + optional> elided_window_dims; attrs["elided_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; - optional> gather_dims_to_operand_dims; + optional> gather_dims_to_operand_dims; attrs["gather_dims_to_operand_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &gather_dims_to_operand_dims}; - optional index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> window_bounds; + optional> window_bounds; attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, &window_bounds}; @@ -1093,12 +1193,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dim_numbers, *window_bounds)); break; } + case HloOpcode::kDomain: { + DomainData domain; + attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateDomain( + shape, operands[0], std::move(domain.exit_metadata), + std::move(domain.entry_metadata))); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } - instruction->set_name(name); + instruction->SetAndSanitizeName(name); + if (instruction->name() != name) { + return Error(name_loc, + StrCat("illegal instruction name: ", name, + "; suggest renaming to: ", instruction->name())); + } // Add shared attributes like metadata to the instruction, if they were seen. if (sharding) { @@ -1118,7 +1235,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_metadata(*metadata); } if (backend_config) { - instruction->set_backend_config(std::move(*backend_config)); + instruction->set_raw_backend_config_string(std::move(*backend_config)); } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1169,8 +1286,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; - std::vector devices; - std::vector tile_assignment_dimensions; + std::vector devices; + std::vector tile_assignment_dimensions; Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { @@ -1197,7 +1314,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } do { - int64 dim; + tensorflow::int64 dim; if (!ParseInt64(&dim)) { return false; } @@ -1209,7 +1326,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } do { - int64 device; + tensorflow::int64 device; if (!ParseInt64(&device)) { return false; } @@ -1268,10 +1385,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); *sharding->mutable_tile_shape() = tile_shape; - for (int64 dim : tile_assignment_dimensions) { + for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } - for (int64 device : devices) { + for (tensorflow::int64 device : devices) { sharding->add_tile_assignment_devices(device); } } @@ -1280,6 +1397,34 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return true; } +// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' +// 'exit=' exit_sharding '}' +bool HloParser::ParseDomain(DomainData* domain) { + std::unordered_map attrs; + optional kind; + optional entry_sharding; + optional exit_sharding; + attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; + attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding}; + attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (*kind == ShardingMetadata::KindName()) { + auto entry_sharding_ptr = MakeUnique( + HloSharding::FromProto(*entry_sharding).ValueOrDie()); + auto exit_sharding_ptr = MakeUnique( + HloSharding::FromProto(*exit_sharding).ValueOrDie()); + domain->entry_metadata = + MakeUnique(std::move(entry_sharding_ptr)); + domain->exit_metadata = + MakeUnique(std::move(exit_sharding_ptr)); + } else { + return TokenError(StrCat("unsupported domain kind: ", *kind)); + } + return true; +} + // '{' name+ '}' bool HloParser::ParseInstructionNames( std::vector* instructions) { @@ -1306,40 +1451,50 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, +bool HloParser::SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(double value, int64 linear_index, +bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case F16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, literal); case BF16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case F32: return SetValueInLiteralHelper(value, linear_index, literal); case F64: @@ -1350,7 +1505,7 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index, } } -bool HloParser::SetValueInLiteral(bool value, int64 linear_index, +bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { @@ -1363,7 +1518,8 @@ bool HloParser::SetValueInLiteral(bool value, int64 linear_index, } template -bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, +bool HloParser::SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal) { // Check that linear_index is in range. if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { @@ -1475,7 +1631,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape) { - const int64 rank = ShapeUtil::Rank(shape); + const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; } @@ -1483,8 +1639,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // Create a literal with the given shape in default layout. *literal = Literal::CreateFromDimensions(shape.element_type(), AsInt64Slice(shape.dimensions())); - int64 nest_level = 0; - int64 linear_index = 0; + tensorflow::int64 nest_level = 0; + tensorflow::int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, // when we are parsing the 2nd '{' (right before '1'), we are seeing a @@ -1492,14 +1648,14 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // the first '}' (right after '3'), it means the sub-array ends, and the // sub-array is supposed to contain exactly 3 elements, so check if // elems_seen_per_dim[1] is 3. - std::vector elems_seen_per_dim(rank); + std::vector elems_seen_per_dim(rank); auto get_index_str = [&elems_seen_per_dim](int dim) -> string { - std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), - elems_seen_per_dim.begin() + dim); + std::vector elems_seen_until_dim( + elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", Join(elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); }), "]"); }; @@ -1575,7 +1731,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { LocTy loc = lexer_.GetLoc(); - int64 value; + tensorflow::int64 value; if (!ParseInt64(&value)) { return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); @@ -1615,29 +1771,29 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, switch (shape.element_type()) { case PRED: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F32: return ParseSparseLiteralHelper(literal, shape); case BF16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F64: return ParseSparseLiteralHelper(literal, shape); default: @@ -1650,9 +1806,9 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, template bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, const Shape& shape) { - std::vector index; + std::vector index; - int64 rank = ShapeUtil::Rank(shape); + tensorflow::int64 rank = ShapeUtil::Rank(shape); *literal = MakeUnique(shape); @@ -1670,7 +1826,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, LocTy index_loc = lexer_.GetLoc(); index.clear(); if (lexer_.GetKind() == TokKind::kInt) { - int64 single_index = lexer_.GetInt64Val(); + tensorflow::int64 single_index = lexer_.GetInt64Val(); lexer_.Lex(); if (rank != 1) { return Error( @@ -1703,7 +1859,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, value = static_cast(lexer_.GetKind() == TokKind::kw_true); lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { - int64 value_s64; + tensorflow::int64 value_s64; if (!ParseInt64(&value_s64)) { return Error(value_loc, StrCat("expects integer for primitive type: ", @@ -1876,23 +2032,24 @@ bool HloParser::ParseAttributeHelper( LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { case AttrTy::kInt64: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - static_cast*>(attr_out_ptr)->emplace(result); + static_cast*>(attr_out_ptr) + ->emplace(result); return true; } case AttrTy::kInt32: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - if (result != static_cast(result)) { + if (result != static_cast(result)) { return Error(attr_loc, "value out of range for int32"); } - static_cast*>(attr_out_ptr) - ->emplace(static_cast(result)); + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); return true; } case AttrTy::kFloat: { @@ -1926,7 +2083,7 @@ bool HloParser::ParseAttributeHelper( } case AttrTy::kWindow: { Window result; - if (!ParseWindow(&result)) { + if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { return false; } static_cast*>(attr_out_ptr)->emplace(result); @@ -1968,12 +2125,12 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kBracedInt64List: { - std::vector result; + std::vector result; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &result)) { return false; } - static_cast>*>(attr_out_ptr) + static_cast>*>(attr_out_ptr) ->emplace(result); return true; } @@ -2018,6 +2175,9 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kDomain: { + return ParseDomain(static_cast(attr_out_ptr)); + } } }(); if (!success) { @@ -2044,9 +2204,10 @@ bool HloParser::ParseComputationName(HloComputation** value) { // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' // The subattributes can appear in any order. 'size=' is required, others are // optional. -bool HloParser::ParseWindow(Window* window) { +bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { LocTy loc = lexer_.GetLoc(); - if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + if (expect_outer_curlies && + !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { return false; } @@ -2056,7 +2217,9 @@ bool HloParser::ParseWindow(Window* window) { std::vector lhs_dilate; std::vector rhs_dilate; std::vector rhs_reversal; - while (lexer_.GetKind() != TokKind::kRbrace) { + const auto end_token = + expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof; + while (lexer_.GetKind() != end_token) { LocTy attr_loc = lexer_.GetLoc(); string field_name; if (!ParseAttributeName(&field_name)) { @@ -2120,7 +2283,8 @@ bool HloParser::ParseWindow(Window* window) { window->mutable_dimensions(i)->set_window_reversal( rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); } - return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); + return !expect_outer_curlies || + ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); } // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. @@ -2144,7 +2308,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( << str; } - const int64 rank = lhs_rhs_out[0].length(); + const tensorflow::int64 rank = lhs_rhs_out[0].length(); if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); @@ -2258,7 +2422,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { return false; } - std::vector> ranges; + std::vector> ranges; if (lexer_.GetKind() == TokKind::kRbrace) { // empty return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); @@ -2292,7 +2456,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // ::= int64_val (delim int64_val)* bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, - std::vector* result) { + std::vector* result) { if (!ParseToken(start, StrCat("expects an int64 list starting with ", TokKindToString(start)))) { return false; @@ -2301,7 +2465,7 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, // empty } else { do { - int64 i; + tensorflow::int64 i; if (!ParseInt64(&i)) { return false; } @@ -2418,7 +2582,8 @@ bool HloParser::ParseString(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, std::vector* result) { +bool HloParser::ParseDxD(const string& name, + std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { return Error(loc, @@ -2426,7 +2591,7 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { } // 1D if (lexer_.GetKind() == TokKind::kInt) { - int64 number; + tensorflow::int64 number; if (!ParseInt64(&number)) { return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); } @@ -2446,7 +2611,8 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { return TokenError("expects token type kInt or kDxD"); } -bool HloParser::ParseWindowPad(std::vector>* pad) { +bool HloParser::ParseWindowPad( + std::vector>* pad) { LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { return Error(loc, "sub-attribute 'pad=' already exists"); @@ -2457,7 +2623,7 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (int i = 0; i < padding_str.size(); i++) { - std::vector low_high; + std::vector low_high; if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || low_high.size() != 2) { return Error(loc, @@ -2481,7 +2647,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (const auto& padding_dim_str : padding_str) { - std::vector padding_dim; + std::vector padding_dim; if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, @@ -2503,7 +2669,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { optional op_type; optional op_name; optional source_file; - optional source_line; + optional source_line; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; @@ -2590,7 +2756,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } -bool HloParser::ParseInt64(int64* result) { +bool HloParser::ParseInt64(tensorflow::int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); @@ -2673,10 +2839,48 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } +StatusOr HloParser::ParseShardingOnly() { + lexer_.Lex(); + OpSharding op_sharding; + if (!ParseSharding(&op_sharding)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after sharding"); + } + return HloSharding::FromProto(op_sharding); +} + +StatusOr HloParser::ParseWindowOnly() { + lexer_.Lex(); + Window window; + if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after window"); + } + return window; +} + +StatusOr +HloParser::ParseConvolutionDimensionNumbersOnly() { + lexer_.Lex(); + ConvolutionDimensionNumbers dnums; + if (!ParseConvolutionDimensionNumbers(&dnums)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after convolution dnums"); + } + return dnums; +} + } // namespace -StatusOr> Parse(StringPiece str, - const HloModuleConfig& config) { +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); @@ -2684,10 +2888,29 @@ StatusOr> Parse(StringPiece str, return parser.ConsumeHloModule(); } -StatusOr> Parse(StringPiece str) { +StatusOr> ParseHloString( + tensorflow::StringPiece str) { + HloModuleConfig config; + return ParseHloString(str, config); +} + +StatusOr ParseSharding(tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseShardingOnly(); +} + +StatusOr ParseWindow(tensorflow::StringPiece str) { HloModuleConfig config; - return Parse(str, config); + HloParser parser(str, config); + return parser.ParseWindowOnly(); +} + +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseConvolutionDimensionNumbersOnly(); } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h similarity index 52% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.h rename to tensorflow/compiler/xla/service/hlo_parser.h index 2f97a2b9b19d0cdb64a2869913da62c55e14c1d5..3f3a51215e34bbdd667f1cb20d0ae968e0ce5efd 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -13,30 +13,47 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace tools { + +// For details about the syntax accepted by this parser, see +// g3doc/hlo_parser.md. // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with the given config. -StatusOr> Parse(tensorflow::StringPiece str, - const HloModuleConfig& config); +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config); // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with default config. -StatusOr> Parse(tensorflow::StringPiece str); +StatusOr> ParseHloString( + tensorflow::StringPiece str); + +// Parses the result of HloSharding::ToString(), e.g. "{replicated}". +StatusOr ParseSharding(tensorflow::StringPiece str); + +// Parses the result of window_util::ToString(const Window&). +StatusOr ParseWindow(tensorflow::StringPiece str); + +// Parses the result of ConvolutionDimensionNumbersToString(), e.g. +// "b0f_0io->b0f". +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str); + +// ParseHloString sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string. +StatusOr ParseSharding(tensorflow::StringPiece str); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc similarity index 82% rename from tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc rename to tensorflow/compiler/xla/service/hlo_parser_test.cc index 131aded95ab04c4327c275ed8cd18b8fc7ac1bd6..88f3309baa4150c14f44a3db0b412fe80e22293c 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -13,19 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { -namespace tools { + namespace { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; struct TestData { string test_name; @@ -233,6 +234,17 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} } +)" +}, +{ +"DomainParsing", +R"(HloModule DomainParsing_module + +ENTRY %DomainParsing (v1: f32[]) -> f32[] { + %v1 = f32[] parameter(0) + ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} +} + )" }, // int32 result = 0; @@ -265,12 +277,13 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { "SendRecv", R"(HloModule TwoSendRecvBothWayRecvFist_module -ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} - ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1} +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} + ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1} %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} - %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} } )" @@ -753,7 +766,7 @@ add_F32.v3 { ENTRY MapBinaryAdder.v3 { param0 = f32[4]{0} parameter(0) param1 = f32[4]{0} parameter(1) - ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3 + ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3 } )" @@ -783,10 +796,14 @@ ENTRY ReduceR3ToR2.v3 { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - infeed = (u32[3]{0}, pred[]) infeed() - outfeed = () outfeed(infeed) - ROOT infeed.1 = (u32[3]{0}, pred[]) infeed() - outfeed.1 = () outfeed(infeed.1) + token = token[] after-all() + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 + infeed.1.token = token[] get-tuple-element(infeed.1), index=1 + outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) } )" @@ -814,6 +831,31 @@ ENTRY ReducePrecision { ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10 } +)" +}, +// Sort (Key) +{ +"SortKey", +R"(HloModule sort + +ENTRY Sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x) +} + +)" +}, +// Sort (Key, Value) +{ +"SortKeyValue", +R"(HloModule sort + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values) +} + )" }, // Conditional @@ -888,6 +930,42 @@ ENTRY Gather { )" }, +// cross-replica-sum +{ +"CrossReplicaSum", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add +} + +)" +}, +// cross-replica-sum with subgroups +{ +"CrossReplicaSumWithSubgroups", +R"(HloModule CRS_Subgroups + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CrossReplicaSumWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add +} + +)" +} }); // clang-format on } @@ -900,12 +978,12 @@ class HloParserTest : public ::testing::Test, << "'" << s << "' does not contain '" << expected << "'"; } - // Expects "ToString(Parse(string)) == string", that is, parses the string, - // asserts that it succeeded, stringifies the parsed module, and checks that - // the it equals the original string. + // Expects "ToString(ParseHloString(string)) == string", that is, parses the + // string, asserts that it succeeded, stringifies the parsed module, and + // checks that the it equals the original string. void ExpectEqual() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString( HloPrintOptions().set_print_large_constants(true))); @@ -916,7 +994,7 @@ class HloParserShortTest : public HloParserTest { protected: void ExpectEqualShort() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); @@ -937,13 +1015,13 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, TEST_F(HloParserTest, Empty) { const string original = ""; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, Garbage) { const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -957,7 +1035,7 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -969,7 +1047,7 @@ ENTRY %blabla (x: g32[]) -> g32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -982,7 +1060,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -993,7 +1071,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -1008,7 +1086,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // Constant instructions have no name. The string will be parsed successfully // but the constant names will not be exactly the same. @@ -1019,12 +1097,12 @@ TEST_F(HloParserTest, ConfigurationField) { ENTRY %configuration_test() -> s32[] { %constant = s32[] constant(42), backend_config="foo bar" })"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ("foo bar", result.ValueOrDie() ->entry_computation() ->root_instruction() - ->backend_config()); + ->raw_backend_config_string()); } TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { @@ -1035,7 +1113,7 @@ ENTRY %some_2 () -> f32[2] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 1, but sees larger"); @@ -1049,7 +1127,7 @@ ENTRY %some_2x3 () -> f32[2,3] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 2, but sees 1"); @@ -1063,7 +1141,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects 3 elements in the [0]th element"); @@ -1078,7 +1156,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "is out of range for literal's primitive type F16"); @@ -1092,7 +1170,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // The string will be parsed successfully but the output strings are not // exactly the same, because "3e2" is parsed into value 300 and will be @@ -1110,7 +1188,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, InvalidDimLabels) { @@ -1126,32 +1204,34 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; + ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); + ExpectHasSubstr( - Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), - "expects dim labels pattern"); - - ExpectHasSubstr(Parse(tensorflow::strings::StrCat( - prefix, ",dim_labels=010_1100->010", suffix)) - .status() - .error_message(), - "must have the same rank"); + "must have the same rank"); } TEST_F(HloParserTest, UnexpectedAttribute) { const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "unexpected attribute \"calls\""); } @@ -1159,15 +1239,16 @@ TEST_F(HloParserTest, MissingAttribute) { const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(-2.1) - %send = (f32[], u32[]) send(f32[] %constant) - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token) + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "attribute channel_id is expected but not seen"); } @@ -1175,15 +1256,16 @@ TEST_F(HloParserTest, PredecessorUndefined) { const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done} - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "'done' is not defined"); } @@ -1196,7 +1278,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { @@ -1210,7 +1292,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects padding_low and padding_high separated by '_'"); } @@ -1222,7 +1304,7 @@ ENTRY %test_comma.v4 () -> f32[] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { @@ -1232,7 +1314,7 @@ ENTRY %CustomCall () -> f32[1] { %constant = f32[1]{0} constant({12345}) ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "Shape of computation CustomCall, f32[1], is not compatible " "with that of its root instruction foo, f32[1,2,3]"); } @@ -1251,9 +1333,9 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); - auto program_layout = module.ValueOrDie()->host_entry_computation_layout(); + auto program_layout = module.ValueOrDie()->entry_computation_layout(); ASSERT_EQ(program_layout.parameter_count(), 1); auto param_layout = program_layout.parameter_layout(0).layout(); auto result_layout = program_layout.result_layout().layout(); @@ -1274,7 +1356,7 @@ c1 { c2 { const2 = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); } @@ -1285,7 +1367,7 @@ ENTRY consts { first = f32[1]{0} constant({12345}) last = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ( module.ValueOrDie()->entry_computation()->root_instruction()->name(), @@ -1300,7 +1382,7 @@ ENTRY c1 { ENTRY c2 { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects only one ENTRY"); } @@ -1310,25 +1392,10 @@ ENTRY consts { ROOT const1 = f32[1]{0} constant({12345}) ROOT const2 = f32[1]{0} constant({12345}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "one computation should have only one ROOT"); } -TEST_F(HloParserTest, InstructionExists) { - const string original = R"(HloModule comp_exists -c1 { - instr = f32[1]{0} constant({12345}) -} -c2 { - instr = f32[1]{0} constant({67890}) -})"; - - ExpectHasSubstr(Parse(original).status().error_message(), - R"(was parsing 3:3: error: instruction previously defined here - instr = f32[1]{0} constant({12345}) - ^)"); -} - TEST_F(HloParserTest, ComputationExists) { const string original = R"(HloModule comp_exists comp { @@ -1337,12 +1404,62 @@ comp { comp { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), R"(was parsing 2:1: error: computation previously defined here comp { ^)"); } +TEST_F(HloParserTest, CrossComputationLookup) { + const string original = R"(HloModule cross_computation_lookup: +tcalla (a: (s32[], s32[])) -> (s32[], s32[]) { + ROOT aparam = (s32[], s32[]) parameter(0) +} + +tcallb (b: (s32[], s32[])) -> s32[] { + rparam = (s32[], s32[]) parameter(0) + ROOT gte0 = s32[] get-tuple-element(aparam), index=0 +} + +ENTRY entry { + param = (s32[], s32[]) parameter(0) + call0 = (s32[], s32[]) call(param), to_apply=tcalla + ROOT call1 = s32[] call(param), to_apply=tcallb +})"; + ExpectHasSubstr( + ParseHloString(original).status().error_message(), + "was parsing 8:39: error: instruction does not exist: aparam"); +} + +TEST_F(HloParserTest, ParseSharding) { + const string original = "{maximal device=42}"; + TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); + EXPECT_EQ(sharding.ToString(), original); +} + +TEST_F(HloParserTest, ParseWindow) { + Window original = window_util::MakeWindow({1, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(Window parsed, + ParseWindow(window_util::ToString(original))) + EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed)); +} + +TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { + const string original = "b0f_0io->b0f"; + TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums, + ParseConvolutionDimensionNumbers(original)); + EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); +} + +TEST_F(HloParserTest, NontupleInfeed) { + const string original = R"(HloModule nontuple_infeed: +ENTRY nontuple_infeed { + token = token[] after-all() + ROOT infeed = pred[] infeed(token) +})"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "infeed must have a non-empty tuple shape"); +} + } // namespace -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index d45038f1f4a2e4aa19234eec93fdc9a068a902e1..2418c19f3de7b036d7ef52d3a6db11de6316203b 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -61,7 +61,7 @@ bool AllOperandsAreConstants(const HloInstruction& instruction) { } HloInstruction* GetMatchingOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction) { for (HloInstruction* op : instruction->operands()) { if (matcher(op)) { @@ -72,7 +72,7 @@ HloInstruction* GetMatchingOperand( } bool MatchBinaryInstructionOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction, HloInstruction** matching_operand, HloInstruction** other_operand) { CHECK_EQ(instruction->operand_count(), 2); diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index c79347bbf9d6146943b7b787f713369cb37fadee..c0826a6aee1f693484207a86ec258c6604d92318 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -45,7 +45,7 @@ bool IsScalarConstant(const HloInstruction* instruction); // multiple matching operands, then the first matching operand is returned. If // there are no matching operands then nullptr is returned. HloInstruction* GetMatchingOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction); // Returns whether a binary instruction has a matching operand. Sets @@ -53,7 +53,7 @@ HloInstruction* GetMatchingOperand( // other_operand. Note: in the case where both operands match, the first operand // of the instruction is returned. bool MatchBinaryInstructionOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction, HloInstruction** matching_operand, HloInstruction** other_operand); diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 8e167633bb13476301fa0c4afa0b123c9b47e40d..01b088a957554821e65db7bf9cedf334db49728f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { HloReachabilityMap::HloReachabilityMap( - const std::list& instructions) + tensorflow::gtl::ArraySlice instructions) : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { @@ -33,17 +33,27 @@ bool HloReachabilityMap::SetReachabilityToUnion( const HloInstruction* instruction) { BitVector& bit_vector = GetBitVector(instruction); tmp_bit_vector_ = bit_vector; + SetReachabilityToUnionHelper(inputs, instruction, &bit_vector); + return bit_vector != tmp_bit_vector_; +} +void HloReachabilityMap::FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction) { + SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction)); +} + +void HloReachabilityMap::SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { - bit_vector.SetToZero(); + bit_vector->SetToZero(); } - bit_vector.Set(GetIndex(instruction)); + bit_vector->Set(GetIndex(instruction)); for (const HloInstruction* input : inputs) { - bit_vector.OrWith(GetBitVector(input)); + bit_vector->OrWith(GetBitVector(input)); } - - return bit_vector != tmp_bit_vector_; } void HloReachabilityMap::SetReachable(const HloInstruction* a, diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 553ec11f6f9a2997ab7113f9b8241e04c7fe20d5..48215d32a8284919cce6beb1663e6a723eefc1c4 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -41,7 +41,8 @@ class HloReachabilityMap { public: // Sets up a graph with no edges and where the nodes correspond to the given // instructions. - explicit HloReachabilityMap(const std::list& instructions); + explicit HloReachabilityMap( + tensorflow::gtl::ArraySlice instructions); // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where @@ -57,6 +58,11 @@ class HloReachabilityMap { tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); + // As above, but faster because it does not check if the reachability changed. + void FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction); + // Sets entry so that IsReachable(a, b) will return true // // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency @@ -133,6 +139,11 @@ class HloReachabilityMap { return bit_vectors_[GetIndex(instruction)]; } + // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. + void SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector); + // Return the index of the given instruction. The value is used to index into // the vector of BitVectors and the BitVectors themselves. int GetIndex(const HloInstruction* instruction) const { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 39b85de0f12024f5e20ddd37618987c6d06bc307..0b222f43483405cf1d3f711bab3e8390903f8ded 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -71,6 +72,20 @@ bool IsRematerializable(const HloInstruction* instruction) { } } +// Checks whether an instruction can be rematerialized, by looking up the +// cache before, and eventually calling the IsRematerializable() API. +bool CanBeRematerialized( + const HloInstruction* instruction, + tensorflow::gtl::FlatMap* remat_able) { + auto it = remat_able->find(instruction); + if (it != remat_able->end()) { + return it->second; + } + bool rematerializable = IsRematerializable(instruction); + (*remat_able)[instruction] = rematerializable; + return rematerializable; +} + // Type holding a unique identifier for each Buffer object. using BufferId = int64; using BufferIdList = tensorflow::gtl::InlinedVector; @@ -843,9 +858,10 @@ int64 RematerializationCost(const HloInstruction* instruction, // candidate which reduce memory use at the program point of the current // instruction as indicated by memory_tracker. nullptr is returned if no // candidate can be found. -Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, - const InstructionList& instruction_list, - int64 memory_limit_bytes) { +Item* PickRematerializationCandidate( + const MemoryUsageTracker& memory_tracker, + const InstructionList& instruction_list, int64 memory_limit_bytes, + tensorflow::gtl::FlatMap* remat_able) { Item* best_item = nullptr; int64 best_cost = 0; @@ -869,8 +885,7 @@ Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, << " is excluded from rematerialization"; continue; } - - if (!IsRematerializable(candidate)) { + if (!CanBeRematerialized(candidate, remat_able)) { VLOG(5) << "candidate " << candidate->name() << " not viable: is not rematerializable"; continue; @@ -974,6 +989,9 @@ StatusOr HloRematerialization::RematerializeComputation( // blacklist. tensorflow::gtl::FlatSet remat_move_instructions; + // The map from instructions to their rematerializable status. + tensorflow::gtl::FlatMap remat_able; + // The peak memory of the computation at any point in the instruction // sequence. int64 peak_memory = memory_tracker.memory_usage(); @@ -1011,7 +1029,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); Item* best_item = PickRematerializationCandidate( - memory_tracker, instruction_list, memory_limit_bytes); + memory_tracker, instruction_list, memory_limit_bytes, &remat_able); if (best_item == nullptr) { VLOG(3) << "Unable to find rematerialization candidate at program " @@ -1191,6 +1209,10 @@ StatusOr HloRematerialization::Run( VLOG(1) << "HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); + if (copy_insertion_) { + TF_RETURN_IF_ERROR(copy_insertion_->Run(module).status()); + } + TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); // Adjust memory limit to account for the output of the entry @@ -1213,12 +1235,20 @@ StatusOr HloRematerialization::Run( XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( + TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( *module, [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, scheduler_algorithm_)); + if (copy_insertion_) { + // We run a separate pass of copy elision here because the sequential + // ordering from the HLO schedule allows for more copies to be eliminated. + SequentialHloOrdering ordering(module, *sequence); + TF_RETURN_IF_ERROR( + copy_insertion_->RemoveUnnecessaryCopies(ordering, module)); + } + // Compute peak memory usage of all computations in the module called in a // sequential context. call_graph_ = CallGraph::Build(module); @@ -1321,8 +1351,9 @@ StatusOr HloRematerialization::Run( int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes) { - HloRematerialization remat(scheduler_algorithm, size_function); + RematerializationSizes* sizes, CopyInsertion* copy_insertion) { + HloRematerialization remat(std::move(scheduler_algorithm), size_function, + copy_insertion); return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2ee2dd0571ae8c6604e4ca722351fd48a913bda5..1c72f42b8c6085de43af766ce8084ca059620e53 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -17,6 +17,7 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -57,6 +58,9 @@ class HloRematerialization { // sizes: Optional outparam that indicates the peak memory usage of the HLO // module before/after rematerialization. // + // copy_insertion: If non-null, run the provided copy insertion pass + // before HLO scheduling. + // // Returns whether any instructions were rematerialized. If memory use is // already below the given limit then no instructions are rematerialized and // false is returned. @@ -68,13 +72,15 @@ class HloRematerialization { const ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes = nullptr); + RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr); protected: HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, - const ShapeSizeFunction& size_function) + const ShapeSizeFunction& size_function, + CopyInsertion* copy_insertion) : scheduler_algorithm_(scheduler_algorithm), - size_function_(size_function) {} + size_function_(size_function), + copy_insertion_(copy_insertion) {} ~HloRematerialization() {} // Runs rematerialization on the given module. Returns whether the module was @@ -139,6 +145,9 @@ class HloRematerialization { // uses of the original instruction and the original instruction is // dead. Hence, no net instructions were added. int64 net_instructions_added_ = 0; + + // Copy insertion pass that runs before HLO scheduling. + CopyInsertion* copy_insertion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 83de54f3fa56ee660b79d8c366dbc0b52f9fde87..fc137c839fc18d01b8aad9e073ad24dc166ebb4a 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -40,7 +41,8 @@ class HloRematerializationTest : public HloTestBase { // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: // - // F32[] %param = {...} + // F32[1] %param = {...} + // F32[] %reshape = reshape(F32[], param) // F32[1024] %bcast = broadcast(%param) // F32[1024] %negate = negate(%bcast) // F32[2048] %concat_1 = concat({%negate, %negate}) @@ -57,9 +59,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast)); auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate( @@ -100,9 +104,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto slice_1 = builder.AddInstruction( HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, /*limit_indices=*/{1}, @@ -135,6 +141,16 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } + StatusOr RunHloRematerialization( + int64 memory_limit_bytes, HloModule* module, + SequentialHloOrdering::HloModuleSequence* sequence, + CopyInsertion* copy_insertion = nullptr) { + TF_EXPECT_OK(verifier().Run(module).status()); + return HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, + sequence, /*sizes=*/nullptr, copy_insertion); + } + // Various shapes used in the canned computations. const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); @@ -158,11 +174,9 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/14 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -188,18 +202,16 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/20 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, + module.get(), &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -225,23 +237,21 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/17 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 8); + EXPECT_EQ(body_computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -264,20 +274,53 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/15 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // Both computations should have a rematerialized instruction added. + // Both computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(body_computation->instruction_count(), 9); +} + +// Similar to RematerializeEntryAndWhileBody, except with copy insertion run +// after HLO scheduling. +TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBodyWithCopies) { + auto module = CreateNewModule(); + + auto cond_builder = HloComputation::Builder(TestName() + ".cond"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloComputation* while_cond = + module->AddEmbeddedComputation(cond_builder.Build()); + + HloComputation* body_computation = module->AddEmbeddedComputation( + MakeRematerializableComputation(/*suffix=*/".body")); + HloComputation* entry_computation = + module->AddEntryComputation(MakeRematerializableWhileComputation( + while_cond, /*while_body=*/body_computation)); + EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); + + SequentialHloOrdering::HloModuleSequence sequence; + CopyInsertion copy_insertion; + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, module.get(), + &sequence, ©_insertion)); + EXPECT_TRUE(changed); + + // Both computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(body_computation->instruction_count(), 9); } // Test rematerialization of a doubly nested computation. All computations @@ -303,24 +346,22 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/middle_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(middle_computation->instruction_count(), 6); - EXPECT_EQ(inner_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(middle_computation->instruction_count(), 7); + EXPECT_EQ(inner_computation->instruction_count(), 8); // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/13 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // All computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(middle_computation->instruction_count(), 7); - EXPECT_EQ(inner_computation->instruction_count(), 8); + // All computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(middle_computation->instruction_count(), 9); + EXPECT_EQ(inner_computation->instruction_count(), 9); } TEST_F(HloRematerializationTest, RngNotRematerialized) { @@ -382,10 +423,9 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, + bool changed, RunHloRematerialization( /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), DefaultMemoryScheduler, &sequence)); + module.get(), &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,11 +516,9 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,11 +611,9 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 2a601ec3d183023954b6f1b6bca7594384378169..b2725e2918ce76248d9f2cdbb2a6e5a63226bf9a 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -22,9 +22,9 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -36,7 +36,7 @@ HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } namespace { @@ -80,7 +80,7 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename, filename, &hlo_string)); HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } HloRunner::HloRunner(se::Platform* platform) { @@ -92,53 +92,116 @@ HloRunner::HloRunner(se::Platform* platform) { HloRunner::~HloRunner() {} -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - CreateExecutable(std::move(module), run_hlo_passes)); - se::Stream stream(backend().default_stream_executor()); - stream.Init(); +StatusOr HloRunner::TransferLiteralToDevice( + const Literal& literal) { + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + literal.shape(), backend().memory_allocator(), + backend().default_device_ordinal())); + TF_ASSIGN_OR_RETURN( + auto stream, backend().BorrowStream(backend().default_stream_executor())); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + stream.get(), literal, buffer)); + return std::move(buffer); +} - ServiceExecutableRunOptions service_run_options(GetServiceRunOptionsForDevice( - backend().default_device_ordinal(), &stream, nullptr)); - const ExecutableRunOptions& run_options = service_run_options.run_options(); +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals) { + std::vector buffers; + for (const Literal* literal : literals) { + CHECK(literal != nullptr); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + TransferLiteralToDevice(*literal)); + buffers.push_back(std::move(buffer)); + } + return std::move(buffers); +} - // Copy arguments to device. - std::vector argument_buffers; - for (Literal* argument : arguments) { - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer argument_buffer, - backend().transfer_manager()->AllocateScopedShapedBuffer( - argument->shape(), run_options.allocator(), - run_options.device_ordinal())); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - stream.parent(), *argument, argument_buffer)); - argument_buffers.push_back(std::move(argument_buffer)); +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals) { + std::vector literal_pointers; + literal_pointers.reserve(literals.size()); + for (const auto& literal : literals) { + literal_pointers.push_back(literal.get()); } + return TransferLiteralsToDevice(literal_pointers); +} + +StatusOr> HloRunner::TransferLiteralFromDevice( + const ShapedBuffer& buffer) { + TF_ASSIGN_OR_RETURN( + auto stream, backend().BorrowStream(backend().default_stream_executor())); + return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(), + buffer); +} + +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + TF_ASSIGN_OR_RETURN(std::vector argument_buffers, + TransferLiteralsToDevice(arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_buffers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile)); + return TransferLiteralFromDevice(result); +} - std::vector argument_buffer_ptrs; - argument_buffer_ptrs.reserve(argument_buffers.size()); - for (const auto& buf : argument_buffers) { - argument_buffer_ptrs.push_back(&buf); +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice> arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(argument.get()); } + return Execute( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); +} +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Get service run options. + se::Stream stream(backend().default_stream_executor()); + stream.Init(); + ServiceExecutableRunOptions service_run_options = + GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, + nullptr); + + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + CreateExecutable(std::move(module), run_hlo_passes)); TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result, - executable->ExecuteOnStreamWrapper( - &service_run_options, /*profile=*/nullptr, argument_buffer_ptrs)); - - auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice( - stream.parent(), result); - if (result_literal.ok()) { - VLOG(4) << "Executed binary and got result: " - << result_literal.ValueOrDie()->ToString(); - } else { - VLOG(4) << "Executed binary and got status: " - << result_literal.status().ToString(); + ScopedShapedBuffer retval, + executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments)); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + return std::move(retval); +} + +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); } - return result_literal; + return ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); } StatusOr>> HloRunner::ExecuteReplicated( @@ -182,7 +245,7 @@ StatusOr>> HloRunner::ExecuteReplicated( backend().transfer_manager()->AllocateScopedShapedBuffer( argument->shape(), backend().memory_allocator(), device)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, *argument, argument_buffer)); + streams.back().get(), *argument, argument_buffer)); argument_buffers.push_back(std::move(argument_buffer)); argument_buffer_ptrs[index++] = &argument_buffers.back(); } @@ -250,9 +313,10 @@ StatusOr>> HloRunner::ExecuteReplicated( std::vector> exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { + TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, backend().transfer_manager()->TransferLiteralFromDevice( - streams[i]->parent(), results[i])); + streams[i].get(), results[i])); exec_results.push_back(std::move(literal)); } return std::move(exec_results); @@ -295,4 +359,8 @@ Backend& HloRunner::backend() { return *backend_; } +const Backend& HloRunner::backend() const { + return const_cast(this)->backend(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 53f7c6fe4a09111c5ee24f2290f0f4aeed0a4401..65537f07f56e74b7fe2c2f9792af21efc7229573 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -102,6 +102,15 @@ class HloRunner { static StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); + // Transfers data between the host and device. + StatusOr TransferLiteralToDevice(const Literal& literal); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals); + StatusOr> TransferLiteralFromDevice( + const ShapedBuffer& buffer); + // Executes the given module with given literals as input and returns the // result as a Literal. // @@ -109,20 +118,25 @@ class HloRunner { // optimization. StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes = true); + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr> Execute( std::unique_ptr module, const tensorflow::gtl::ArraySlice> arguments, - bool run_hlo_passes = true) { - // Construct a vector of plain pointers for the arguments. - std::vector argument_pointers; - c_transform( - arguments, std::back_inserter(argument_pointers), - [](const std::unique_ptr& literal) { return literal.get(); }); - return Execute(std::move(module), argument_pointers, run_hlo_passes); - } + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + // As Execute(), but accepts and returns device buffers instead of host + // buffers. + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as @@ -137,6 +151,7 @@ class HloRunner { // This creates the backend lazily so it's possible to instantiate an // HloRunner in a program without any backends linked in. Backend& backend(); + const Backend& backend() const; private: // Creates an executable object given an HLO module. If run_hlo_passes is diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 02545811f7a6b15daf2e14602a3f76a56118f379..c6d3909af6103949daf4b0ab6be9b74724461e30 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -36,29 +36,6 @@ using ::tensorflow::strings::HumanReadableNumBytes; namespace xla { -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { - return 0; - } - - const HloModule* module = module_sequence.begin()->first->parent(); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. We run the heap simulation on the whole module, - // rather than summing each computation, since it gives us a better lower - // bound, by minimizing the liveness of sub-computations. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, - module_sequence, *points_to_analysis, size_function)); - return result.heap_size; -} - namespace { // Class implementing a list scheduler of HLO instructions which produces a @@ -299,6 +276,8 @@ class ListScheduler { auto best_it = ready_queue.end(); --best_it; const HloInstruction* best = best_it->second.instruction; + VLOG(2) << "Schedule instruction: " << best->ToShortString() + << " Bytes freed: " << best_it->first.first; ready_queue.erase(best_it); ready_instructions.erase(best); schedule.push_back(best); @@ -396,7 +375,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -414,28 +393,15 @@ StatusOr> CreateMemoryMinimizingSequence( } // namespace -StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; -} - -StatusOr> DFSMemorySchedulerImpl( +StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, bool reverse_heuristics) { - // This ordering is based on DFS post-order, with a heuristic to decide which - // operand to visit first. The heuristic is based on 'extra_users', which is - // simply users-1 for each instruction. By subtracting 1, we're saying that - // instructions with no users or a single user don't count; instructions with - // lots of fan-out will be visited earlier. + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; + int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { @@ -444,6 +410,11 @@ StatusOr> DFSMemorySchedulerImpl( total_sizes[hlo] = 0; continue; } + // This ordering is based on DFS post-order, with a heuristic to decide + // which operand to visit first. The heuristic is based on 'extra_users', + // which is simply users-1 for each instruction. By subtracting 1, we're + // saying that instructions with no users or a single user don't count; + // instructions with lots of fan-out will be visited earlier. extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; int64 logical_buffer_size = SumLogicalBufferSizes( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); @@ -455,7 +426,17 @@ StatusOr> DFSMemorySchedulerImpl( extra_users[hlo] += extra_users[operand]; total_sizes[hlo] += total_sizes[operand]; } + // total_sizes[hlo] transitively includes the sizes of all nodes that + // lead to it. But computation is a DAG, so we are double-counting nodes, + // which can lead to overflows for large programs. + // cumulative_total_size caps the size to prevent overflows. + // Same for total_hlos: it prevents overflows on very large and branchy + // models, where the number of paths is exponential to the number of nodes. + // NOTE(dimvar): this is quite ugly and should be changed. It's unclear + // why we care about transitive sizes; when scheduling a node, its input + // and output buffers should be all that matters, not its "history". total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); + extra_users[hlo] = std::min(extra_users[hlo], total_hlos); } CHECK_EQ(extra_users.size(), computation.instruction_count()); CHECK_EQ(total_sizes.size(), computation.instruction_count()); @@ -469,16 +450,15 @@ StatusOr> DFSMemorySchedulerImpl( return Status::OK(); }); TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( - &visitor, [&extra_users, &total_sizes, reverse_heuristics]( - const HloInstruction* a, const HloInstruction* b) { - auto lhs = std::tuple(extra_users[a], - total_sizes[a], b->name()); - auto rhs = std::tuple(extra_users[b], - total_sizes[b], a->name()); - - // Reverse heuristics. This helps some cases as a different starting - // point of gradient descent, see b/78906799 for more context. - return reverse_heuristics ? rhs > lhs : lhs > rhs; + &visitor, [&extra_users, &total_sizes](const HloInstruction* a, + const HloInstruction* b) { + if (extra_users[a] != extra_users[b]) { + return extra_users[a] > extra_users[b]; + } + if (total_sizes[a] != total_sizes[b]) { + return total_sizes[a] > total_sizes[b]; + } + return a->name() < b->name(); })); CHECK_EQ(sequence.size(), computation.instruction_count()); return sequence; @@ -505,81 +485,51 @@ StatusOr> PostOrderMemoryScheduler( post_order.end()}; } -StatusOr> DFSMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation) { - return DFSMemorySchedulerImpl(computation, points_to_analysis, size_function, - /*reverse_heuristics=*/false); -} - -StatusOr> DFSMemorySchedulerReverse( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation) { - return DFSMemorySchedulerImpl(computation, points_to_analysis, size_function, - /*reverse_heuristics=*/true); -} - StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation) { - // We try both a list-scheduler based ordering and a DFS based ordering, and - // choose whichever returns a lower min-memory, not accounting for - // fragmentation. - // - // Note that this is just a heuristic. One obvious inaccuracy is that the - // memory required for sub-computations might be different when considered - // within the caller's context. But it's good enough for now. + // We try a few schedulers and choose whichever returns a lower min-memory, + // not accounting for fragmentation. + // - List is a scheduler that uses greedy heuristics. + // - DFS visits HLOs in postorder, with a heuristic to decide the order of + // children. + // - Postorder does not use any heuristics. + // List wins for most of our benchmarks; postorder-based schedulers win for + // some RNNs. TF_ASSIGN_OR_RETURN( std::vector list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 list_memory, - MinimumMemoryForComputation(computation, list_sequence, - points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(const int64 list_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, list_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 dfs_memory, - MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, - size_function)); + TF_ASSIGN_OR_RETURN(const int64 dfs_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, dfs_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( std::vector post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 post_order_memory, - MinimumMemoryForComputation(computation, post_order_sequence, - points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(const int64 post_order_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, post_order_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); - TF_ASSIGN_OR_RETURN( - std::vector reverse_dfs, - DFSMemorySchedulerReverse(computation, points_to_analysis, size_function, - memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 reverse_dfs_memory, - MinimumMemoryForComputation(computation, reverse_dfs, points_to_analysis, - size_function)); - VLOG(2) << "Min-memory reverse_dfs sequence: " - << HumanReadableNumBytes(reverse_dfs_memory); - auto min_memory = std::min( - {dfs_memory, post_order_memory, reverse_dfs_memory, list_memory}); + auto min_memory = std::min({dfs_memory, post_order_memory, list_memory}); if (min_memory == list_memory) { VLOG(2) << "Chose min-memory list sequence: " @@ -589,10 +539,6 @@ StatusOr> DefaultMemoryScheduler( VLOG(2) << "Chose min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); return dfs_sequence; - } else if (min_memory == reverse_dfs_memory) { - VLOG(2) << "Chose min-memory reverse_dfs memory: " - << HumanReadableNumBytes(reverse_dfs_memory); - return reverse_dfs; } else { VLOG(2) << "Chose min-memory post_order sequence: " << HumanReadableNumBytes(post_order_memory); @@ -600,10 +546,9 @@ StatusOr> DefaultMemoryScheduler( } } -StatusOr -CreateMemoryMinimizingSequence(const HloModule& module, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { +StatusOr ScheduleComputationsInModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); @@ -611,12 +556,13 @@ CreateMemoryMinimizingSequence(const HloModule& module, for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(auto one_computation_sequence, - CreateMemoryMinimizingSequence( + ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = - MinimumMemoryForComputation(*computation, one_computation_sequence, - *points_to_analysis, size_function) + HeapSimulator::MinimumMemoryForComputation( + *computation, one_computation_sequence, *points_to_analysis, + size_function, &memory_by_computation) .ValueOrDie(); sequence[computation] = std::move(one_computation_sequence); } @@ -624,15 +570,15 @@ CreateMemoryMinimizingSequence(const HloModule& module, return sequence; } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); tensorflow::gtl::FlatMap empty_map; - return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function, nullptr, empty_map); + return ScheduleComputationHelper(computation, *points_to_analysis, + size_function, nullptr, empty_map); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 0e5ac2022db2059b830fb8edc10d6c4a70857415..2b33ccc8bfb895286bb3747aab0a16cf25e2cfae 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -28,20 +28,6 @@ limitations under the License. namespace xla { -// Returns the minimum memory required to compute the given module sequence, -// assuming no fragmentation. -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function); - -// Returns the minimum memory required to compute the given computation, -// assuming no fragmentation. -StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); - // A memory scheduler computes an execution sequence for the HLO instructions in // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function @@ -76,15 +62,6 @@ StatusOr> PostOrderMemoryScheduler( const tensorflow::gtl::FlatMap& memory_by_computation); -// DFS-order scheduler with reversed heuristics. This helps some cases (see -// b/78906799). -StatusOr> DFSMemorySchedulerReverse( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation); - // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. @@ -98,14 +75,13 @@ StatusOr> DefaultMemoryScheduler( // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. -StatusOr -CreateMemoryMinimizingSequence(const HloModule& module, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); +StatusOr ScheduleComputationsInModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); -// Overload of above that computes the sequence for a single computation. +// Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index c018ba2ffc404d0c6a0d08b8f5c63a9f90888b70..73f22f81f4e9cf597db8b184642acff2fdaaf2b0 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -18,78 +18,20 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; - -TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { - auto module = CreateNewModule(); - const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); - - auto cond_builder = HloComputation::Builder("WhileCond"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); - HloInstruction* cond_iter = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); - HloInstruction* cond_data = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); - // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) - HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); - HloComputation* cond_computation = - module->AddEmbeddedComputation(cond_builder.Build()); - - auto body_builder = HloComputation::Builder("WhileBody"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "body_param")); - HloComputation* body_computation = - module->AddEmbeddedComputation(body_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - // Entry params: 8 bytes (4 bytes per param), TOTAL=8 - HloInstruction* iter = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); - HloInstruction* data = builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "param_data")); - // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 - HloInstruction* tuple = - builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); - // While: 8 bytes (4 bytes per element), TOTAL=32 - // Both cond and body use a max of 24 bytes, TOTAL=56 - HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, cond_computation, body_computation, tuple)); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, - MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); -} - class HloSchedulingTest : public HloTestBase {}; TEST_F(HloSchedulingTest, LastUseScheduledFirst) { @@ -124,7 +66,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) { + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. @@ -158,14 +100,14 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler)); + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.at(module->entry_computation()).size()); @@ -203,7 +145,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { // ROOT %subtract = f32[4]{0} subtract( // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) // } - // %SubcomputationsNotAccounted () -> f32[2,4] { + // %ListAccountsForSubcomputations () -> f32[2,4] { // %constant.3 = f32[2,4]{1,0} constant( // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) // %transpose = f32[2,4]{1,0} transpose( @@ -269,16 +211,16 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence( - *module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }, - ListMemoryScheduler)); + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + auto entry_computation = module->entry_computation(); + EXPECT_EQ(entry_computation->instruction_count(), + sequence.at(entry_computation).size()); SequentialHloOrdering ordering(module.get(), sequence); // This schedule is an example of List's greedy heuristics being suboptimal. // The while_loop is more expensive than transpose, so it would have been @@ -287,6 +229,184 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); + + tensorflow::gtl::FlatMap memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations. The max mem doesn't change + // because the while body isn't live during the peak. + EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); +} + +TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { + auto builder = HloComputation::Builder(TestName()); + const auto TUPLE_SIZE = 1; + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6}); + + // Wrap lit in abs because constants are considered free by + // IgnoreInstruction, and it skews the accounting. + auto lit = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 1, 1, 1, 1, 1}))); + auto abs_const = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); + + auto abs_abs1 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( + tensorflow::gtl::ArraySlice({abs_abs1}))); + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto abs_abs2 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, + tuple_elm, abs_abs2)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, + [&TUPLE_SIZE](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // tuple allocates the tuple buffer and doesn't free anything. + // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. + // abs_abs2 should be scheduled before tuple by List. + EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple)); +} + +TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5}); + HloComputation::Builder builder(TestName()); + + auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 1, 1, 1, 1}))); + auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 2, 3, 4, 5}))); + auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0, 2, 4, 6, 8}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul})); + + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3)); + + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto fusion = computation->CreateFusionInstruction( + {tuple, mul, add}, HloInstruction::FusionKind::kLoop); + + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // fusion allocates memory for the tuple elements and doesn't free anything, + // so it's more expensive than exp. + EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); +} + +TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + auto entry_computation = module->entry_computation(); + EXPECT_EQ(entry_computation->instruction_count(), + sequence.at(entry_computation).size()); + + tensorflow::gtl::FlatMap memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations + EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 7f7e3f7dab03ce0ad64bd0fcfe4ddd020d31bf56..268b4727bcbed42ba71526f1d5ef5c887e941930 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -39,6 +39,34 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { return HloSharding(tile_shape, assignment); } +HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { + std::vector flattened_list; + flattened_list.reserve(sub_shardings.leaf_count()); + for (const auto& index_to_sharding : sub_shardings.leaves()) { + flattened_list.push_back(index_to_sharding.second); + } + if (flattened_list.empty()) { + // Empty tuple sharding ends up having no leaves, but we want to allow + // empty tuple HLO instruction results to have sharding, so we fetch the + // root ({}) sharding value from the ShapeTree. + // A ShapeTree created with ShapeTree(shape, init) will have + // init as value at its root. + flattened_list.push_back(sub_shardings.element(ShapeIndex({}))); + } + return HloSharding(flattened_list); +} + +HloSharding HloSharding::Tuple( + const Shape& tuple_shape, + tensorflow::gtl::ArraySlice shardings) { + CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + std::vector flattened_list(shardings.begin(), shardings.end()); + CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) + << "Flat list has " << flattened_list.size() << ", required " + << RequiredLeaves(tuple_shape); + return HloSharding(flattened_list); +} + string HloSharding::ToString() const { if (IsTuple()) { std::vector parts; @@ -49,9 +77,6 @@ string HloSharding::ToString() const { return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); } - string result = StrCat("{", (replicated_ ? " replicated" : ""), - (maximal_ ? " maximal" : "")); - if (replicated_) { return "{replicated}"; } else if (maximal_) { @@ -75,6 +100,29 @@ bool HloSharding::UsesDevice(int64 device) const { std::find(devices.begin(), devices.end(), device) != devices.end(); } +std::map HloSharding::UsedDevices(int64* count) const { + int64 element_count = 1; + std::map device_map; + if (IsTuple()) { + for (auto& tuple_element_sharding : tuple_elements()) { + auto unique_device = tuple_element_sharding.UniqueDevice(); + if (unique_device.ok()) { + device_map[unique_device.ValueOrDie()] += 1; + } + } + element_count = tuple_elements().size(); + } else { + auto unique_device = UniqueDevice(); + if (unique_device.ok()) { + device_map[unique_device.ValueOrDie()] += 1; + } + } + if (count != nullptr) { + *count = element_count; + } + return device_map; +} + std::vector HloSharding::TileIndexForDevice(int64 device) const { CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); @@ -126,6 +174,49 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { return index; } +int64 HloSharding::RequiredLeaves(const Shape& shape) { + // Empty tuples have no leaf nodes as far as ShapeUtil and ShapeTree are + // concerned, but they do have a single tuple_elements_ entry since we want + // to allow empty tuple results to have sharding. + return ShapeUtil::IsEmptyTuple(shape) ? 1 : ShapeUtil::GetLeafCount(shape); +} + +Status HloSharding::CheckLeafCount(const Shape& shape) const { + int64 shape_leaves = RequiredLeaves(shape); + TF_RET_CHECK(shape_leaves == tuple_elements_.size()) + << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves + << " leaf nodes while this sharding has " << tuple_elements_.size(); + return Status::OK(); +} + +StatusOr> HloSharding::AsShapeTree( + const Shape& shape) const { + if (IsTuple()) { + ShapeTree result(shape, HloSharding::Replicate()); + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); + auto it = tuple_elements_.begin(); + for (auto& index_to_sharding : result.leaves()) { + index_to_sharding.second = *it++; + } + if (ShapeUtil::IsEmptyTuple(shape)) { + // Empty tuples have no leaves, but we want to assign them a sharding + // anyway, so we use the root element sharding. + *result.mutable_element(ShapeIndex({})) = *it; + } + return std::move(result); + } else { + return ShapeTree(shape, *this); + } +} + +StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { + if (IsTuple()) { + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); + return *this; + } + return Tuple(ShapeTree(shape, *this)); +} + StatusOr HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { @@ -167,28 +258,12 @@ Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { return tensorflow::errors::InvalidArgument( StrCat("Sharding is tuple-shaped but validation shape is not.")); } - // The easiest way to get the number of elements in a nested tuple is just to - // create a shape tree. We could call GetAsShapeTree, but that will try and - // apply our tuple_shardings_ to the shape tree, and that might cause a crash - // at this point as we haven't validated them. - ShapeTree bool_shape_tree(shape, false); - int64 num_leaves = - std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end()); - if (num_leaves != tuple_elements_.size()) { - return tensorflow::errors::InvalidArgument( - StrCat("Validation tuple shape has ", num_leaves, - " leaf elements, but this sharding contains ", - tuple_elements_.size(), " elements.")); - } + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); // Now we've validated the number of tuple elements, it's safe to request a // shape tree. ShapeTree shape_tree = GetAsShapeTree(shape); for (const auto& index_to_sharding : shape_tree.leaves()) { - if (index_to_sharding.first.empty()) { - // An empty tuple has a ShapeTree with a single leaf at the empty index. - continue; - } Status status = index_to_sharding.second.ValidateNonTuple( ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices); if (!status.ok()) { @@ -370,11 +445,42 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, Shape sub_shape = ShapeUtil::GetSubshape(shape, index); ShapeTree sub_shape_tree(sub_shape, Replicate()); sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - if (ShapeUtil::IsTuple(sub_shape)) { - return Tuple(sub_shape_tree); - } else { - return sub_shape_tree.element({}); + return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree) + : sub_shape_tree.element(ShapeIndex({})); +} + +tensorflow::gtl::optional HloSharding::ExtractSingleSharding() + const { + if (!IsTuple()) { + return *this; + } + for (int64 i = 1; i < tuple_elements_.size(); ++i) { + if (tuple_elements_[0] != tuple_elements_[i]) { + return tensorflow::gtl::optional(); + } + } + return tuple_elements_.front(); +} + +size_t HloSharding::Hash() const { + if (!tuple_) { + size_t h = 0; + for (const auto& element : tuple_elements_) { + h = tensorflow::Hash64Combine(h, element.Hash()); + } + return h; + } + if (replicated_) { + return 0; + } + size_t h = 0; + for (uint32 v : tile_assignment_) { + h = tensorflow::Hash64Combine(h, std::hash{}(v)); + } + for (uint32 v : tile_shape_.dimensions()) { + h = tensorflow::Hash64Combine(h, std::hash{}(v)); } + return h; } std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 2b8e757f42991f697df37d3d34bfdff6a36bc509..34324d2058efe804cda486600dabd8a62cb84fda 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -19,7 +19,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ +#include #include +#include #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -70,26 +72,13 @@ class HloSharding { // Creates a new sharding for a tuple type. The given ShapeTree must have // elements for every leaf shape contained in the tuple. - static HloSharding Tuple(const ShapeTree& sub_shardings) { - std::vector flattened_list; - flattened_list.reserve( - std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); - for (const auto& index_to_sharding : sub_shardings.leaves()) { - flattened_list.push_back(index_to_sharding.second); - } - return HloSharding(flattened_list); - } + static HloSharding Tuple(const ShapeTree& sub_shardings); - // Creates a new sharding for a tuple type. The requested tuple shape must not - // be nested. For nested tuples, use the ShapeTree overload. + // Creates a new sharding for a tuple type. The number of elements in + // shardings must match the number of leaf nodes in tuple_shape. For + // empty tuples, the shardings array must have one element. static HloSharding Tuple(const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings) { - CHECK(ShapeUtil::IsTuple(tuple_shape)); - CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); - std::vector flattened_list(shardings.begin(), shardings.end()); - CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); - return HloSharding(flattened_list); - } + tensorflow::gtl::ArraySlice shardings); // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); @@ -99,6 +88,9 @@ class HloSharding { static bool IsReservedDevice(int64 device) { return device < 0; } OpSharding ToProto() const; + + // Note that this string canonically has outer curly braces, e.g. + // "{replicated}". string ToString() const; // Validate that this sharding can be applied to a tensor with shape `shape`. @@ -128,6 +120,14 @@ class HloSharding { // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; + // Retrieves an histogram of the devices used by the sharding. The returned + // map has the device number as key, and the occurrence count as value. + // If a sharding does not have a device, it will not be incuded in the + // histogram. The count argument, if not nullptr, will receive the total + // number of elements this sharding is made of (one for array, N leaves for + // tuples). + std::map UsedDevices(int64* count) const; + // Returns the tile that should be executed on the given device. // REQUIRES: !IsTuple() std::vector TileIndexForDevice(int64 device) const; @@ -160,25 +160,27 @@ class HloSharding { // tuple, if IsTuple, or a ShapeTree with a single element containing this // sharding. Only the leaf elements are populated. This creates a new // ShapeTree object so is not cheap. + StatusOr> AsShapeTree(const Shape& shape) const; ShapeTree GetAsShapeTree(const Shape& shape) const { - if (IsTuple()) { - ShapeTree result(shape, HloSharding::Replicate()); - CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), - tuple_elements_.size()); - auto it = tuple_elements_.begin(); - for (auto& index_to_sharding : result.leaves()) { - index_to_sharding.second = *it++; - } - return result; - } else { - return ShapeTree(shape, *this); - } + return AsShapeTree(shape).ValueOrDie(); } // Retrieves the sub sharding at a given index, out of a tuple sharding. // REQUIRES: IsTuple() HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const; + // If the current sharding is a tuple sharding, return itself as result. + // Otherwise returns a tuple sharding for the input shape, with all the leaves + // having this object sharding. + StatusOr GetTupleSharding(const Shape& shape) const; + + // Extracts the sharding that is common within the current sharding. + // If the current sharding is not a tuple sharding, the current sharding will + // be returned. If it is a tuple, and all the tuple elements are common, the + // common element will be returned. Otherwise the optional will contain no + // value. + tensorflow::gtl::optional ExtractSingleSharding() const; + bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && @@ -187,26 +189,13 @@ class HloSharding { } bool operator!=(const HloSharding& other) const { return !(*this == other); } - size_t Hash() const { - if (!tuple_) { - size_t h = 0; - for (const auto& element : tuple_elements_) { - h = tensorflow::Hash64Combine(h, element.Hash()); - } - return h; - } - if (replicated_) { - return 0; - } - size_t h = 0; - for (uint32 v : tile_assignment_) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } - for (uint32 v : tile_shape_.dimensions()) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); + size_t Hash() const; + + struct Hasher { + size_t operator()(const HloSharding& sharding) const { + return sharding.Hash(); } - return h; - } + }; // Gets the tile shape. // REQUIRES: !IsTileMaximal() && !IsTuple() @@ -242,6 +231,12 @@ class HloSharding { tuple_(false), tile_shape_(), tile_assignment_({0}) {} + // device_id values: + // -2: magic number to mean unassigned device, used by spatial partitioning + // -1: the id of the host + // 0 or positive: the id of a device + // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once + // we have fully switched to the side-effect tokens. explicit HloSharding(int64 device_id) : replicated_(false), maximal_(true), @@ -261,11 +256,19 @@ class HloSharding { tile_assignment_({0}), tuple_elements_(tuple_shardings) {} + // Checks that the number of elements in tuple_elements_ is consistent with + // the tuple shape passes as argument. + Status CheckLeafCount(const Shape& shape) const; + // Internal helper to validate a tuple sharding. Status ValidateTuple(const Shape& shape, int64 num_devices) const; + // Internal helper to validate a non-tuple (leaf) sharding. Status ValidateNonTuple(const Shape& shape, int64 num_devices) const; + // Returns the number of tuple_elements_ entries to fit the shape. + static int64 RequiredLeaves(const Shape& shape); + bool replicated_; bool maximal_; bool tuple_; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc new file mode 100644 index 0000000000000000000000000000000000000000..39036e205e76979e7da08246cd030ebd17e52f76 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -0,0 +1,411 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +namespace { + +struct PassThrough { + PassThrough(HloInstruction* user, HloInstruction* operand) + : user(user), operand(operand) {} + + HloInstruction* user = nullptr; + HloInstruction* operand = nullptr; +}; + +void SetSingleSharding(HloInstruction* instruction, + const HloSharding& sharding) { + VLOG(4) << " " << instruction->name() << " to " << sharding; + instruction->set_single_sharding(sharding); +} + +bool ShardingMatches(const HloSharding& sharding1, + const HloSharding& sharding2) { + auto single_sharding1 = sharding1.ExtractSingleSharding(); + if (single_sharding1) { + auto single_sharding2 = sharding2.ExtractSingleSharding(); + if (single_sharding2) { + return *single_sharding1 == single_sharding2; + } + } + // Anything which is not unique across all elements, gets a full sharding + // compare. + return sharding1 == sharding2; +} + +// When we create domains, they are never "empty", where with empty we mean +// that a kDomain instruction has as operand another kDomain instruction of the +// same kind. +// But when the HLO optimizations are run, empty domains can be created. +// For example: +// +// Domain(device=None, device=0) -> +// Tuple(device=0) -> +// GTE(device=0) -> +// Domain(device=0, device=None) +// +// In that case the tuple simplifier could create something like: +// +// Domain(device=None, device=0) -> Domain(device=0, device=None) +// +// Which is a so called empty domain. +// In the case above, crossing an empty domain which was transiting through +// device 0, requires the normalization phase to fixup the empty domain by +// adding back a Tuple+GTE pair with the proper device. +// One particular case where this can create problems is the result of the +// entry computation, where the GTE assignments are used by TF to tell the +// XLA where the results should be sent. +std::vector LocatePassThroughDomainLinks( + const DomainMetadata::Domain& domain) { + std::vector pass_through; + for (HloInstruction* instruction : domain.enter_domains) { + CHECK(instruction->opcode() == HloOpcode::kDomain) + << "Instruction is not a kDomain: " << instruction->ToString(); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(user) != 0) { + pass_through.emplace_back(user, instruction); + VLOG(2) << "Found passthrough domain link:"; + VLOG(2) << " " << user->ToString(); + VLOG(2) << " " << instruction->ToString(); + } + } + } + return pass_through; +} + +Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + for (auto& pass_through : LocatePassThroughDomainLinks(domain)) { + HloInstruction* tuple = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateTuple({pass_through.operand})); + HloInstruction* gte = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), + tuple, 0)); + gte->set_sharding(sharding); + TF_RETURN_IF_ERROR( + pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + } + return Status::OK(); +} + +std::unique_ptr CloneShardingForDomain( + const HloSharding& sharding) { + auto single_sharding = sharding.ExtractSingleSharding(); + if (!single_sharding) { + return MakeUnique(sharding); + } + return MakeUnique(*single_sharding); +} + +Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + VLOG(4) << "Applying " << sharding << " sharding"; + for (HloInstruction* instruction : domain.instructions) { + // We only change instructions without sharding, since otherwise we might + // mess up with eventual HLO passes which has knowledge of it. + if (!instruction->has_sharding()) { + SetSingleSharding(instruction, sharding); + } else { + VLOG(4) << " " << instruction->name() << " already has sharding " + << instruction->sharding(); + } + } + return Status::OK(); +} + +// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. +// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() +// sharding will be returned. +ShapeTree GetTupleSharding(HloInstruction* tuple) { + if (tuple->has_sharding()) { + return tuple->sharding().GetAsShapeTree(tuple->shape()); + } + return ShapeTree(tuple->shape(), HloSharding::Replicate()); +} + +// Retrieves the sharding of operand, asked from a user instruction which is +// within domain. If operand is a kDomain, it means that sharding argument is +// the operand sharding, otherwise the operand's own sharding will be returned. +const HloSharding* GetOperandSharding(const HloInstruction* operand, + const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + DCHECK_EQ(domain.reach_set.count(const_cast(operand)), 1); + // Here the user of operand is within the domain instruction set, and since it + // is user of operand, we need to look into the enter_domains set. If this is + // not a kDomain within the user domains set, then return the operand + // sharding, if any. + if (operand->opcode() != HloOpcode::kDomain || + domain.enter_domains.count(const_cast(operand)) == 0) { + return operand->has_sharding() ? &operand->sharding() : nullptr; + } + // At this point operand is a kDomain of the currently processed domain, so we + // can refer to sharding as the domain sharding. + return &sharding; +} + +// Tries to propagate the sharding information into the instructions that are +// part of the domain, in a post order manner (operand propagate to user). +StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + int64 assigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (instruction->has_sharding()) { + continue; + } + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + HloInstruction* tuple = instruction->mutable_operand(0); + const HloSharding* tuple_sharding = + GetOperandSharding(tuple, domain, sharding); + if (tuple_sharding != nullptr) { + if (tuple_sharding->IsTuple()) { + HloSharding sub_sharding = tuple_sharding->GetSubSharding( + tuple->shape(), {instruction->tuple_index()}); + VLOG(4) << " " << instruction->name() << " to sharding " + << sub_sharding; + instruction->set_sharding(sub_sharding); + } else { + SetSingleSharding(instruction, *tuple_sharding); + } + ++assigned; + } + } else if (instruction->opcode() == HloOpcode::kTuple) { + int64 tuple_assigned = 0; + ShapeTree shape_tree = GetTupleSharding(instruction); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + const HloSharding* operand_sharding = + GetOperandSharding(instruction->operand(i), domain, sharding); + if (operand_sharding != nullptr && + shape_tree.element({i}) != *operand_sharding) { + *shape_tree.mutable_element({i}) = *operand_sharding; + ++tuple_assigned; + } + } + if (tuple_assigned > 0) { + HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); + VLOG(4) << " " << instruction->name() << " to sharding " + << tuple_sharding; + instruction->set_sharding(tuple_sharding); + ++assigned; + } + } else { + // If all the operand of the given instruction has the same single device + // assignment, assign that device to this instruction as well. + const HloSharding* common_sharding = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + const HloSharding* operand_sharding = + GetOperandSharding(operand, domain, sharding); + if (operand_sharding != nullptr) { + if (common_sharding != nullptr && + *common_sharding != *operand_sharding) { + common_sharding = nullptr; + break; + } + common_sharding = operand_sharding; + } + } + if (common_sharding != nullptr) { + VLOG(4) << " " << instruction->name() << " to sharding " + << *common_sharding; + instruction->set_sharding(*common_sharding); + ++assigned; + } + } + } + return assigned; +} + +Status ApplyDomainSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + // Here is the place to call external sharding normalizers, which are + // implemented in other modules (ie, spatial partitioning). + // The signature of the external normalizer function should be something + // like: + // + // StatusOr Normalizer(const DomainMetadata::Domain&, + // const HloSharding& sharding); + // + // The function should return true if it has processed the domain + // normalization, false if domain was not one recognized by it, or an error. + // We will call the functions in order below, and fall back to local code if + // none of the external normalizers acted on the domain. + // External normalizers should not handle the cases that are already handled + // locally. + + // None of the external normalizers handled the domain sharding, try to see + // whether this is a single sharding first. + auto single_sharding = sharding.ExtractSingleSharding(); + if (single_sharding) { + // Shortcut the simple case. We have a unique sharding, so we call + // the ApplyDomainSingleSharding() API which will apply array or tuple + // shaped sharding to the domain instructions. + return ApplyDomainSingleSharding(domain, *single_sharding); + } + VLOG(1) << "Assigning non-trivial sharding " << sharding; + for (;;) { + TF_ASSIGN_OR_RETURN(int64 assigned, + ApplyDomainShardingPass(domain, sharding)); + if (assigned == 0) { + break; + } + } + int64 unassigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (!instruction->has_sharding()) { + LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); + ++unassigned; + } + } + // Should we error out if unassigned > 0? + return Status::OK(); +} + +// Creates a kDomain instruction to be placed between instruction and operand. +// The kDomain instruction will be created only if the sharding differ between +// the instruction and the operand. +std::unique_ptr CreateDomain(HloInstruction* instruction, + HloInstruction* operand) { + const HloSharding* instruction_sharding = + instruction->has_sharding() ? &instruction->sharding() : nullptr; + const HloSharding* operand_sharding = + operand->has_sharding() ? &operand->sharding() : nullptr; + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && operand_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. + if (instruction_sharding != nullptr && operand_sharding != nullptr && + ShardingMatches(*instruction_sharding, *operand_sharding)) { + return nullptr; + } + std::unique_ptr real_instruction_sharding; + std::unique_ptr real_operand_sharding; + if (instruction_sharding != nullptr) { + real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); + } + if (operand_sharding != nullptr) { + real_operand_sharding = CloneShardingForDomain(*operand_sharding); + } + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (real_instruction_sharding != nullptr + ? real_instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (real_operand_sharding != nullptr + ? real_operand_sharding->ToString() + : "None"); + + std::unique_ptr operand_side_metadata = + MakeUnique(std::move(real_operand_sharding)); + std::unique_ptr user_side_metadata = + MakeUnique(std::move(real_instruction_sharding)); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +StatusOr> ExtractOriginalCommonSharding( + tensorflow::gtl::ArraySlice instructions) { + // If we are here, all the instructions being passed had the same sharding + // (or no sharding), by the means of the ShardingMatches() API. + // As such, no kDomain was inserted, and here we are asked to extract the + // original common sharding. + // All the instructions passed to this API are part of the same computation. + const HloSharding* sharding = nullptr; + for (HloInstruction* instruction : instructions) { + if (instruction->has_sharding()) { + if (sharding == nullptr) { + sharding = &instruction->sharding(); + } else { + TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) + << "Sharding " << *sharding << " does not match the one in " + << instruction->ToString(); + } + } + } + if (sharding == nullptr) { + return std::unique_ptr(); + } + VLOG(4) << "Extracted sharding is " << *sharding; + return CloneShardingForDomain(*sharding); +} + +} // namespace + +std::unique_ptr ShardingMetadata::Clone() const { + std::unique_ptr sharding; + if (sharding_ != nullptr) { + sharding = MakeUnique(*sharding_); + } + return MakeUnique(std::move(sharding)); +} + +bool ShardingMetadata::Matches(const DomainMetadata& other) const { + const ShardingMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a ShardingMetadata, then it is clearly a no match. + return false; + } + if (sharding_ == nullptr) { + return other_ptr->sharding_ == nullptr; + } + return other_ptr->sharding_ != nullptr + ? ShardingMatches(*sharding_, *other_ptr->sharding_) + : false; +} + +string ShardingMetadata::ToString() const { + return sharding_ != nullptr ? sharding_->ToString() : "{}"; +} + +Status ShardingMetadata::NormalizeInstructions( + const DomainMetadata::Domain& domain) const { + if (sharding_ != nullptr) { + VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_)); + TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_)); + } + return Status::OK(); +} + +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) { + TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + ExtractOriginalCommonSharding(domain.instructions)); + if (sharding != nullptr) { + VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString() + << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); + } else { + VLOG(1) << "Unable to find common sharding"; + } + return Status::OK(); +} + +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand) { + return CreateDomain(instruction, operand); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..ec162c34904ee2dfac3daeeee37133282a9c9698 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// A DomainMetadata implementation that internally wraps a sharding attribute. +class ShardingMetadata : public DomainMetadata { + public: + explicit ShardingMetadata(std::unique_ptr sharding) + : sharding_(std::move(sharding)) {} + + std::unique_ptr Clone() const override; + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override; + + string ToString() const override; + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override; + + static tensorflow::StringPiece KindName() { return "sharding"; } + + private: + std::unique_ptr sharding_; +}; + +// Within a set of instructions which had common sharding attributes before +// entring the HLO passes pipeline, apply sharding heuristics and normalize the +// instructions whose sharding deviates from the one which is inferred as to be +// the original one. +// Policy wise, HLO passes are allowed to create new unassigned instructions, +// but if they do create assigned ones, they have to conform to the ones around. +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain); + +// Given an HLO graph edge between instruction and one of its operands, creates +// a ShardingMetadata based kDomain instruction if the sharding between +// instruction and operand changes. Returns nullptr if there is no need for a +// domain separation. +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 3bf0d25efb7fad78aeccdd9269c289950b2171ab..54b7402b866361748d9eb35182b0bf486c4c9bdc 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_sharding.h" - #include #include #include #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -312,5 +311,50 @@ TEST_F(HloShardingTest, OstreamTest) { EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); } +TEST_F(HloShardingTest, ParseHloString) { + auto check = [](const HloSharding& sharding) { + TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding, + ParseSharding(sharding.ToString())); + EXPECT_EQ(sharding, parsed_sharding); + }; + check(HloSharding::Replicate()); + check(HloSharding::AssignDevice(2)); + check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}}))); + // Empty tuple. One sharding is required for empty tuples, as we need to be + // able to assign sharding to them, even though they have no leaves. + check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), + {HloSharding::Replicate()})); + { + // Non-nested tuple. + auto tuple_shape = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 1, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 7})}); + check(HloSharding::Tuple( + tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}})), + HloSharding::Replicate(), HloSharding::AssignDevice(1)})); + } + { + // Nested tuple. + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 1, 5, 7}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 7})})}); + std::vector leaf_shardings = { + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}})), + HloSharding::Replicate(), HloSharding::AssignDevice(1)}; + ShapeTree sharding_tree(tuple_shape, HloSharding::Replicate()); + // Assign leaf_shardings to sharding_tree leaves. + auto it = leaf_shardings.begin(); + for (auto& index_to_sharding : sharding_tree.leaves()) { + index_to_sharding.second = *it++; + } + check(HloSharding::Tuple(sharding_tree)); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h similarity index 84% rename from tensorflow/compiler/xla/tools/parser/hlo_token.h rename to tensorflow/compiler/xla/service/hlo_token.h index 7928bee5c2097f353b182095a555c334d7b69c95..533429608bc2e13626a3e746fbe465398e1f4bb4 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ #include @@ -22,9 +22,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Defines different kinds of tokens in a hlo module string. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. enum class TokKind { // Markers kEof, @@ -72,7 +74,6 @@ enum class TokKind { string TokKindToString(TokKind kind); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 7b27dbfec376b8ba16d00285f10e2cc291e07a61..4e3c9df3a036890ce25f5b14603d275263e8659b 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -125,7 +125,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, // transparently. CHECK_EQ(operand_number, 0); return index.empty(); - case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: // Select does not use any nested elements of its selected-from operands // (operand 1 and 2) CHECK_GE(operand_number, 0); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 7d6d0d9eaf70969c1a3762959233b561706398c2..f89677372944f2708aa678d2a6a53665ae1752ab 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,6 +15,8 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -39,6 +41,10 @@ Status ShapeVerifier::HandleSelect(HloInstruction* select) { return CheckTernaryShape(select); } +Status ShapeVerifier::HandleTupleSelect(HloInstruction* tuple_select) { + return CheckTernaryShape(tuple_select); +} + Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { std::vector operand_shapes; for (const HloInstruction* operand : concatenate->operands()) { @@ -106,22 +112,57 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); } +namespace { + +Status CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no) { + const HloInstruction* token = instruction->operand(operand_no); + if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { + return InternalError( + "Expected operand %lld to be token-shaped, actual shape is" + "%s:\n%s", + operand_no, ShapeUtil::HumanString(token->shape()).c_str(), + instruction->ToString().c_str()); + } + return Status::OK(); +} + +} // namespace + +Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { + HloInfeedInstruction* infeed = Cast(instruction); + // Infeed has an optional single token operand. + // TODO(b/80000000): Update when token is not optional. + if (infeed->operand_count() == 1) { + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); + } + + // The output of infeed is a tuple containing the data value and a token. + return CheckShape(infeed, + ShapeUtil::MakeTupleShape( + {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()})); +} + +Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { + HloOutfeedInstruction* outfeed = Cast(instruction); + // Outfeed has an optional token operand (operand 1). + // TODO(b/80000000): Update when token is not optional. + if (outfeed->operand_count() == 2) { + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); + } -Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // Outfeed has a separate shape field for the value which is outfed to the - // host. The shape of the instruction itself is always nil because the outfeed - // produces no HLO value in the graph. + // host. The shape of the instruction itself is always a token. if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { return InternalError( - "Expected outfeed to have shape compatible with operand's shape %s, " + "Expected outfeed shape to be compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), outfeed->ToString().c_str()); } - return CheckShape(outfeed, ShapeUtil::MakeNil()); + return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleHostCompute(HloInstruction*) { @@ -137,7 +178,16 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { } Status ShapeVerifier::HandleSort(HloInstruction* sort) { - return CheckUnaryShape(sort); + if (sort->operand_count() == 2 && + !ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(1)->shape())) { + return InternalError( + "Expected sort to have to have the same dimensions for the keys and " + "the values. Keys shape is: %s\n, Values shape is: %s", + ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), + ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + } + return CheckVariadicShape(sort); } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { @@ -299,9 +349,11 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { const HloInstruction* send_done = send->users().front(); TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape( - send, ShapeUtil::MakeTupleShape( - {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); + TF_RETURN_IF_ERROR(CheckIsTokenOperand(send, 1)); + return CheckShape(send, + ShapeUtil::MakeTupleShape({send->operand(0)->shape(), + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { @@ -309,7 +361,8 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { const HloInstruction* send = send_done->operand(0); TF_RET_CHECK(send->opcode() == HloOpcode::kSend); TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape(send_done, ShapeUtil::MakeNil()); + + return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { @@ -317,9 +370,11 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) { const HloInstruction* recv_done = recv->users().front(); TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv, - ShapeUtil::MakeTupleShape( - {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); + TF_RETURN_IF_ERROR(CheckIsTokenOperand(recv, 0)); + return CheckShape( + recv, ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(recv_done->shape(), 0), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { @@ -327,7 +382,9 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { const HloInstruction* recv = recv_done->operand(0); TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv_done, recv->shape().tuple_shapes(0)); + return CheckShape(recv_done, + ShapeUtil::MakeTupleShape({recv->shape().tuple_shapes(0), + ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleBatchNormTraining( @@ -376,6 +433,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: @@ -385,6 +443,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kRecvDone: case HloOpcode::kReducePrecision: case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kTuple: @@ -425,6 +484,14 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } +Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { + std::vector operand_shapes; + for (const HloInstruction* operand : token->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes)); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -439,16 +506,10 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, // We treat BF16 and F32 as compatible types if mixed precision is allowed, // but only when the instruction defines the BF16/F32 buffer. switch (instruction->opcode()) { - case HloOpcode::kSelect: - if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) { - // Select only defines the top-level buffer, which in this case is the - // tuple, so we cannot allow mixed precision. - compatible = - ShapeUtil::Compatible(instruction->shape(), inferred_shape); - } else { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision( - instruction->shape(), inferred_shape); - } + case HloOpcode::kTupleSelect: + // TupleSelect only defines the top-level buffer, which in this case is + // the tuple, so we cannot allow mixed precision. + compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); break; case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: @@ -776,8 +837,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); - if (!ShapeUtil::IsScalar(operand_shape) && - !ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { + if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { return FailedPrecondition( "Implicit broadcast is not allowed in HLO." "Found non-compatible shapes for instruction %s.\n" @@ -790,6 +850,39 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { return Status::OK(); } +namespace { + +// Returns true if the given Shape has a TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. +Status VerifyEntryAndExitShapes(const HloModule& module) { + // Tokens cannot be passed as entry parameters. + // TODO(b/80000000): Remove this constraint. + for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape()).c_str()); + } + } + return Status::OK(); +} + +} // namespace + StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); @@ -831,7 +924,9 @@ StatusOr HloVerifier::Run(HloModule* module) { << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); } else if (instruction->opcode() == HloOpcode::kWhile) { TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); - } else if (instruction->IsElementwise()) { + } else if (instruction->opcode() != + HloOpcode::kRng /* Rng operands are always scalar. */ + && instruction->IsElementwise()) { TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction)); } @@ -850,6 +945,8 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); } + TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 1392a78097aa026b2f7cffa2b0135402d3ca7ae5..12c047850ef7299f24c3f004613df3e66e0af8d6 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -35,6 +35,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleElementwiseBinary(HloInstruction* hlo) override; Status HandleClamp(HloInstruction* clamp) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConvert(HloInstruction* convert) override; Status HandleBitcastConvert(HloInstruction* convert) override; @@ -81,6 +82,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleAfterAll(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index dc3bfce0c495bc40a2df7b985cab67e02a3e15ce..d7458c338e9f1df9fac90270845aae0b8f779ee2 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -169,6 +169,23 @@ string HumanReadableProfileBuilder::ToString() const { StrAppend(&s, table.MakeReport(CyclesToMicroseconds(total_cycles_))); } } + + if (total_bytes > 0) { + MetricTableReport table; + table.SetMetricName("MiB read+written"); + table.SetEntryName("ops"); + table.SetShowCategoryTable(); + for (const auto& op : op_infos_) { + MetricTableReport::Entry entry; + entry.text = op.name; + entry.short_text = op.short_name; + entry.category_text = op.category; + entry.metric = static_cast(op.bytes_accessed) / (1 << 20); + table.AddEntry(std::move(entry)); + } + StrAppend(&s, + table.MakeReport(static_cast(total_bytes) / (1 << 20))); + } return s; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..1985d20578677ae68b244023c4640454b004bf49 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -0,0 +1,986 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gtl = ::tensorflow::gtl; + +namespace { +using Analysis = IndexedArrayAnalysis; +using UnknownArray = Analysis::UnknownArray; +using ConstantArray = Analysis::ConstantArray; +using ReshapedArray = Analysis::ReshapedArray; +using ScalarIndexedArray = Analysis::ScalarIndexedArray; +using tensorflow::gtl::ArraySlice; +using tensorflow::str_util::Join; +} // namespace + +string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { + switch (root->kind()) { + case Array::kUnknown: { + auto* unknown_tensor = root->as(); + return tensorflow::strings::StrCat("%", + unknown_tensor->instruction().name()); + } + + case Array::kConstant: { + if (print_constants) { + string contents = root->as()->literal()->ToString(); + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents, + ")"); + } + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + } + + case Array::kReshaped: { + ReshapedArray* reshaped_array = root->as(); + return tensorflow::strings::StrCat( + "(reshape ", ToString(reshaped_array->operand(), print_constants), + " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")"); + } + + case Array::kScalarIndexedConstant: + case Array::kScalarIndexed: { + auto* indexed_array = root->as(); + string name = root->kind() == Array::kScalarIndexedConstant + ? "scalar-indexed-const" + : "scalar-indexed"; + return tensorflow::strings::StrCat( + "(", name, " ", ToString(indexed_array->source(), print_constants), + " ", ToString(indexed_array->indices(), print_constants), " ", + indexed_array->source_dim(), "->[", + Join(indexed_array->output_dims(), ","), "])"); + } + } +} + +StatusOr IndexedArrayAnalysis::GetArrayFor( + const HloInstruction* instr) { + auto it = cache_.find(instr); + if (it != cache_.end()) { + return it->second; + } + + TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr)); + return FindOrDie(cache_, instr); +} + +Status IndexedArrayAnalysis::TraverseAndPopulateCache( + const HloInstruction* root) { + // Depth first search over the DAG, invoking ComputeArrayFor in post order. + // The HLO instructions already in the cache are considered leaves. + + gtl::InlinedVector stack; + + enum DfsState { kDiscovered, kVisited }; + gtl::FlatMap dfs_state_map; + + stack.push_back(root); + InsertOrDie(&dfs_state_map, root, kDiscovered); + + do { + const HloInstruction* instr = stack.back(); + if (cache_.count(instr)) { + stack.pop_back(); + continue; + } + + switch (FindOrDie(dfs_state_map, instr)) { + case kDiscovered: { + for (const HloInstruction* operand : instr->operands()) { + if (!cache_.count(operand)) { + stack.push_back(operand); + CHECK(!dfs_state_map.count(operand) || + dfs_state_map[operand] == kDiscovered); + dfs_state_map[operand] = kDiscovered; + } + } + dfs_state_map[instr] = kVisited; + break; + } + + case kVisited: + stack.pop_back(); + TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr)); + InsertOrDie(&cache_, instr, array); + break; + } + } while (!stack.empty()); + + return Status::OK(); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayFor( + const HloInstruction* instr) { + Array* computed_array; + if (instr->IsElementwise() && instr->operand_count() == 1) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForElementwiseUnaryOp( + instr->opcode(), FindOrDie(cache_, instr->operand(0)))); + } else if (instr->IsElementwise() && instr->operand_count() == 2) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForElementwiseBinaryOp( + instr->opcode(), FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); + } else if (instr->opcode() == HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN(computed_array, + ComputeArrayForConstant(instr->literal())); + } else if (instr->opcode() == HloOpcode::kGather) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), + instr->gather_window_bounds(), + FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); + } else if (instr->opcode() == HloOpcode::kReshape) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForReshape(instr->shape(), + FindOrDie(cache_, instr->operand(0)))); + } else { + computed_array = nullptr; + } + + if (!computed_array) { + computed_array = Construct(instr); + } + + return computed_array; +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( + const Literal& literal) { + return Construct(&literal); +} + +StatusOr IndexedArrayAnalysis::FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape) { + // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). + // `source` is the inner Gather(A, X). + + Array* a = source->source(); + Array* x = source->indices(); + Array* y = indices; + + // This bit is slightly tricky, so we do a naive "simulation" of the two + // consecutive gather operations to infer what the composed gather should look + // like. + + enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond }; + + std::vector simulated_index(a->shape().dimensions_size(), + IndexComponent::Ungathered); + + // Simulate the first gather. + EraseAt(&simulated_index, source->source_dim()); + for (int64 gather_dim : source->output_dims()) { + simulated_index.insert(simulated_index.begin() + gather_dim, + IndexComponent::GatheredFirst); + } + + // Simulate the second gather. + EraseAt(&simulated_index, source_dim); + for (int64 output_dim : output_dims) { + simulated_index.insert(simulated_index.begin() + output_dim, + IndexComponent::GatheredSecond); + } + + int64 source_dim_for_index_array = + FindIndex(source->output_dims(), source_dim); + CHECK_NE(source_dim_for_index_array, source->output_dims().size()); + + std::vector output_dims_for_index_array; + int64 gathered_index_components_seen = 0; + for (IndexComponent simulation_dim : simulated_index) { + if (simulation_dim == IndexComponent::GatheredSecond) { + output_dims_for_index_array.push_back(gathered_index_components_seen); + } + if (simulation_dim != IndexComponent::Ungathered) { + gathered_index_components_seen++; + } + } + + std::vector dim_sizes_for_composed_index; + std::vector output_dims_for_new_gather; + for (int64 i = 0, e = simulated_index.size(); i < e; i++) { + if (simulated_index[i] != IndexComponent::Ungathered) { + dim_sizes_for_composed_index.push_back(shape.dimensions(i)); + output_dims_for_new_gather.push_back(i); + } + } + + Array* inner_indices = ConstructScalarIndexedArray( + x, y, source_dim_for_index_array, output_dims_for_index_array, + ShapeUtil::MakeShape(x->shape().element_type(), + dim_sizes_for_composed_index)); + return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(), + output_dims_for_new_gather, + std::move(shape)); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices) { + if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { + VLOG(3) << "ComputeArrayForGather: indices are not scalar"; + return nullptr; + } + + CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); + + // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should + // it become relevant. + + if (dim_numbers.elided_window_dims_size() != 1 || + dim_numbers.elided_window_dims(0) != + dim_numbers.gather_dims_to_operand_dims(0)) { + VLOG(3) << "ComputeArrayForGather: gather operations must elide " + "gather_dims_to_operand_dims[0] and " + "gather_dims_to_operand_dims[0] only"; + return nullptr; + } + + // ScalarIndexedArray cannot represent gathers that "slice" along some + // dimensions -- for instance it cannot represent a gather that picks 5 [2,3] + // arrays from an array of size [7,4,6]. We check that condition down below: + + for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) { + if (i != dim_numbers.elided_window_dims(0) && + source->shape().dimensions(i) != window_bounds[i]) { + VLOG(3) << "ComputeArrayForGather: window_bounds[" << i + << "] != source->shape().dimensions(" << i << ") -- " + << source->shape().dimensions(i) << " vs. " << window_bounds[i] + << " with dim_numbers.elided_window_dims(0) = " + << dim_numbers.elided_window_dims(0); + return nullptr; + } + } + + int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); + std::vector output_dims; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + output_dims.push_back(i); + } + } + + if (auto* indexed = dynamic_cast(source)) { + auto it = c_find(indexed->output_dims(), source_dim); + if (it != indexed->output_dims().end()) { + return FoldGatherOfGather(indexed, indices, source_dim, output_dims, + shape); + } + } else if (auto* constant = dynamic_cast(source)) { + return Construct(constant, indices, source_dim, + output_dims, shape); + } + + return Construct(source, indices, source_dim, output_dims, + shape); +} + +namespace { +// Returns an index into `values` such that the product of the range +// [values.begin()+index, values.end()) is equal to `product`. If there is no +// such index, return -1. All integers in `values` must be positive. +int64 FindSuffixWithProduct(ArraySlice values, int64 product) { + DCHECK(c_all_of(values, [](int64 value) { return value > 0; })); + + int64 current_product = 1; + int64 i; + for (i = values.size() - 1; i >= 0 && product > current_product; --i) { + current_product *= values[i]; + } + + if (product == current_product) { + return i + 1; + } + + return -1; +} + +struct ReshapePassthroughDimPair { + int64 result_dim; + int64 operand_dim; +}; + +// Returns a set of dimension pairs such for all (result_dim, operand_dim) in +// the set: +// +// output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim] +// +// The returned vector of pairs is sorted in both the result_dim and the +// operand_dim components. +std::vector ComputeReshapePassthroughDimPairs( + ArraySlice operand_shape, ArraySlice result_shape) { + // A reshape can be seen as an index mapping from output index to input index: + // + // (i_0, ..., i_n) = f(o_0, ..., o_m) + // + // This function returns the pairs (j, k) for which the following invariant + // holds for all indices in the shape: + // + // o_j == i_k + // + // And this occurs when: + // + // O_{j+1} * ... * O_n == I_{k+1} * ... * I_m + // + // (where O_x are the sizes of the output shape and I_x are the sizes of the + // input shape) and the size of the dimension j of the result is the same as + // the size of dimension k in the operand. + // + // These conditions are sufficient because the Reshape HLO is spec'ed such + // that the rightmost dimensions are always minor in the flattening and refine + // operation. + + std::vector result; + int64 result_subarray_size = 1; + for (int64 result_dim = result_shape.size() - 1; result_dim >= 0; + --result_dim) { + int64 candidate_operand_dim = + FindSuffixWithProduct(operand_shape, result_subarray_size); + + // result_subarray_size does not include the elements in the current + // `result_dim` dimension (we multiply in result_shape[result_dim] at the + // end of loop body) so candidate_operand_dim can never be zero. + CHECK_NE(candidate_operand_dim, 0) + << "result_dim = " << result_dim + << ", result_subarray_size = " << result_subarray_size + << ", result_shape = [" << Join(result_shape, ",") << "]" + << ", operand_shape = [" << Join(operand_shape, ",") << "]"; + + if (candidate_operand_dim != -1 && + result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { + result.push_back({/*result_dim=*/result_dim, + /*operand_dim=*/candidate_operand_dim - 1}); + } + result_subarray_size *= result_shape[result_dim]; + } + + c_reverse(result); + + if (VLOG_IS_ON(3)) { + std::vector result_strings; + c_transform(result, std::back_inserter(result_strings), + [](ReshapePassthroughDimPair value) { + return tensorflow::strings::StrCat(value.result_dim, "->", + value.operand_dim); + }); + VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" + << Join(result_shape, ",") << "] passthrough indices are [" + << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; + } + + DCHECK(c_is_sorted( + result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { + return lhs.result_dim < rhs.result_dim; + })); + + DCHECK(c_is_sorted( + result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { + return lhs.operand_dim < rhs.operand_dim; + })); + + return result; +} + +// Return true if `dim` is stated as an passthrough operand dim in +// `passthrough_dims`. +bool IsReshapePassthroughOperandDim( + ArraySlice passthrough_dims, int64 dim) { + return c_any_of(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == dim; + }); +} + +// Maps `operand_dim` which must be an passthrough operand dimension to its +// corresponding passthrough result dimension based on `passthrough_dims`. +int64 MapPassthroughOperandDimToResultDim( + ArraySlice passthrough_dims, int64 operand_dim) { + auto it = c_find_if(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == operand_dim; + }); + CHECK(it != passthrough_dims.end()); + return it->result_dim; +} + +int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, + ArraySlice result_shape, + int64 source_passthrough_dim) { + VLOG(3) << "FindSourcePositionForPassthroughResultDim([" + << Join(operand_shape, ",") << "], [" << Join(result_shape, ",") + << "], " << source_passthrough_dim << ")"; + + int64 indexed_source_subarray_size = + std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, + operand_shape.end(), 1, std::multiplies()); + + return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); +} + +Shape StripDegenerateDimensions(const Shape& shape) { + DimensionVector new_dims; + c_copy_if(shape.dimensions(), std::back_inserter(new_dims), + [](int64 dim) { return dim != 1; }); + return ShapeUtil::MakeShape(shape.element_type(), new_dims); +} +}; // namespace + +StatusOr +IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( + ScalarIndexedArray* operand) { + const Shape& shape = operand->shape(); + if (!ShapeUtil::HasDegenerateDimensions(shape)) { + return operand; + } + + // We only need to reshape out the degenerate dims from the indices and the + // source (except the source dim). + + const Shape& source_shape = operand->source()->shape(); + DimensionVector new_source_shape_dims; + for (int64 i = 0, e = source_shape.dimensions_size(); i < e; i++) { + if (i == operand->source_dim() || source_shape.dimensions(i) != 1) { + new_source_shape_dims.push_back(source_shape.dimensions(i)); + } + } + + Shape new_source_shape = + ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims); + Shape new_indices_shape = + StripDegenerateDimensions(operand->indices()->shape()); + + TF_ASSIGN_OR_RETURN( + Array* const new_source, + ComputeArrayForReshape(new_source_shape, operand->source())); + TF_ASSIGN_OR_RETURN( + Array* const new_indices, + ComputeArrayForReshape(new_indices_shape, operand->indices())); + + // Build the new output dims while keeping track of the degenerate dims that + // will no longer be present. + DimensionVector new_output_dims; + int64 degenerate_dims_seen = 0; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (shape.dimensions(i) == 1) { + degenerate_dims_seen++; + } else if (ArrayContains(operand->output_dims(), i)) { + new_output_dims.push_back(i - degenerate_dims_seen); + } + } + + // Similarly, build the new source dim while keeping track of the degenerate + // dims that will no longer be present. + int64 degenerate_dims_before_source_dim = + std::count(source_shape.dimensions().begin(), + source_shape.dimensions().begin() + operand->source_dim(), 1); + int64 new_source_dim = + operand->source_dim() - degenerate_dims_before_source_dim; + + return ConstructScalarIndexedArray( + new_source, new_indices, new_source_dim, + InlinedVectorToVector(new_output_dims), + StripDegenerateDimensions(operand->shape())); +} + +StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( + ScalarIndexedArray* operand, + tensorflow::gtl::ArraySlice degenerate_dims) { + if (degenerate_dims.empty()) { + return operand; + } + + CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape())); + + DimensionVector new_output_dims = [&]() { + // To make things easy we use a "scratch" buffer of bools where the i'th + // element is true iff the i'th component of the result index is an output + // index. + + gtl::InlinedVector output_dims_bitvector( + operand->shape().dimensions_size()); + for (int64 output_dim : operand->output_dims()) { + output_dims_bitvector[output_dim] = true; + } + + for (int64 degenerate_dim : degenerate_dims) { + InsertAt(&output_dims_bitvector, degenerate_dim, false); + } + + DimensionVector result; + result.reserve(operand->output_dims().size()); + for (int64 i = 0, e = output_dims_bitvector.size(); i < e; i++) { + if (output_dims_bitvector[i]) { + result.push_back(i); + } + } + + return result; + }(); + + DimensionVector new_result_shape_dims; + c_copy(operand->shape().dimensions(), + std::back_inserter(new_result_shape_dims)); + for (int64 degenerate_dim : degenerate_dims) { + InsertAt(&new_result_shape_dims, degenerate_dim, 1); + } + + DimensionVector new_source_shape_dims = new_result_shape_dims; + for (int64 output_dim : new_output_dims) { + EraseAt(&new_source_shape_dims, output_dim); + } + + int64 new_source_dim = [&]() { + for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) { + int64 non_degenerate_dims_seen = 0; + if (non_degenerate_dims_seen == operand->source_dim()) { + return i; + } + if (new_source_shape_dims[new_source_dim] != 1) { + non_degenerate_dims_seen++; + } + } + LOG(FATAL) << "Did not find source dim in " << ToString(operand); + }(); + + int64 source_dim_size = + operand->source()->shape().dimensions(operand->source_dim()); + InsertAt(&new_source_shape_dims, /*index=*/new_source_dim, + /*value=*/source_dim_size); + + Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + new_source_shape_dims); + Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + new_result_shape_dims); + + TF_ASSIGN_OR_RETURN( + Array* const new_source, + ComputeArrayForReshape(new_source_shape, operand->source())); + return ConstructScalarIndexedArray( + new_source, operand->indices(), new_source_dim, + InlinedVectorToVector(new_output_dims), new_result_shape); +} + +StatusOr IndexedArrayAnalysis::FoldReshapeOfGather( + const Shape& shape, ScalarIndexedConstantArray* operand) { + VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")"; + + // To make things easier on ourselves, instead of directly trying to fold the + // reshape of `operand` to `shape`, we call + // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims and + // handle the degenerate dimensions here by inserting reshapes. + + TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims, + ReshapeToRemoveDegenerateDims(operand)); + + Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape); + TF_ASSIGN_OR_RETURN( + ScalarIndexedArray* const folded_reshape_without_degenerate_dims, + FoldReshapeOfGatherNoDegenerateDims( + output_shape_without_degenerate_dims, + operand_without_degenerate_dims->as())); + + if (folded_reshape_without_degenerate_dims == nullptr) { + return nullptr; + } + + DimensionVector degenerate_result_dims; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (shape.dimensions(i) == 1) { + degenerate_result_dims.push_back(i); + } + } + + return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims, + degenerate_result_dims); +} + +StatusOr +IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( + const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) { + VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed) + << ")"; + CHECK(!ShapeUtil::HasDegenerateDimensions(shape)); + CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape())); + + // Try to fold Reshape(ScalarIndexed(Const, Indices)) + // => ScalarIndexed(Const', Indices) + // + // We can view the reshape and the scalar-indexed operations as functions that + // map an output index (i.e. an index into the result) to an input index + // (i.e. an index into the operand). The key idea used here is that the + // output-to-input mapping for some reshape operations may "pass through" some + // output dimensions into the input space unchanged -- i.e. there may exist + // output dimension "O" and input dimension "I" such that OutputIndex[O] is + // always == InputIndexForReshape(OutputIndex)[I]. If these pass-through + // dimensions in the input space of the reshape happen to be include all the + // output dimensions for the scalar-indexed node then, roughly, the following + // holds: + // + // SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx)) + // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs)) + // + // Where Ps are the set of the pass-through components of Idx that are + // also the output dims of the scalar-indexed node, and Qs are the rest. + // For brevity, we're playing fast and loose with the notation here -- we + // don't literally require Idx to be a concatenation of Ps and Qs, as + // suggested by the "++". + // + // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs)) + // + // Again, we're playing fast and loose with the notation around "++". + // Generally this ++ will be a different function that the ++ in the + // previous step. + // + // If the scalar-indexed node has a constant as the source then the + // SourceIndexOfReshape function can be "folded into" the constant itself by + // reshaping it, leaving us with: + // + // == SourceIndexOfScalarIndexed(Ps ++ Qs) + // == SourceIndexOfScalarIndexed(Idx) + // + // which is just a scalar-indexed node (with parameters different from the + // scalar-indexed node we started with) with a reshaped constant as the + // source. + // + // We can't fold SourceIndexOfReshape into the constant without introducing + // another precondition: since the new scalar-indexed node will have a + // reshaped (constant) array as its source it will, in general, have a + // different source dimension than the original scalar-indexed node. This + // source dimension will have to be a passthrough dimension of the + // SourceIndexOfReshape indexing function that is folded into the source. And + // such a dimension need not exist so this is a non-trivial precondition. + + std::vector reshape_passthrough_dims = + ComputeReshapePassthroughDimPairs( + /*operand_shape=*/AsInt64Slice(scalar_indexed->shape().dimensions()), + /*result_shape=*/AsInt64Slice(shape.dimensions())); + + auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) { + return IsReshapePassthroughOperandDim(reshape_passthrough_dims, + operand_dim); + }; + + if (!c_all_of(scalar_indexed->output_dims(), + is_reshape_passthrough_operand_dim)) { + VLOG(3) << "Not all output dims are passthrough dims " + << ToString(scalar_indexed); + return nullptr; + } + + // To compute the shape of the source for the new scalar-indexed node we're + // going to create, we first "undo" the scalar-indexed operation. + std::vector new_scalar_indexed_source_shape(shape.dimensions().begin(), + shape.dimensions().end()); + for (int64 i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) { + int64 output_dim = scalar_indexed->output_dims()[i]; + int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim( + reshape_passthrough_dims, output_dim); + EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape); + } + + // After this, we need to add in the dimension that will be the source + // dimension for the new scalar-indexed node. A scalar-indexed node "removes" + // the source dimensions and "adds" the output dimensions, so to get back to + // the shape for the *source* of the scalar-indexed node we need to remove the + // output dims (which we did above) and then add back the source dim (which we + // are about to do below): + + const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape(); + + int64 source_dim_for_new_scalar_indexed_node = + FindSourcePositionForPassthroughResultDim( + /*operand_shape=*/AsInt64Slice( + scalar_indexed_source_shape.dimensions()), + /*result_shape=*/new_scalar_indexed_source_shape, + scalar_indexed->source_dim()); + + // We may not be able to find a source dim for the new scalar-indexed node. + // For instance consider: + // + // operand = s32[3,5,2] constant({...}) + // indices = s32[7] parameter(0) + // gather = s32[3,2,7] gather(operand, indices), + // output_window_dims={0,1}, + // elided_window_dims={1}, + // gather_dims_to_operand_dims={1}, + // index_vector_dim=1, + // window_bounds={3,1,2} + // reshape = s32[6,7] reshape(gather) + // + // In this case the gather maps to: + // (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2]) + // + // and the reshape passes through dimension 2 from its input into dimension 1 + // in its output. However, we can't rewrite the reshape as a scalar-indexed + // node because then we'd have to reshape the [3,5,2] `operand` array to + // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently + // (a.k.a. isn't pass-through) than the [3,5,2] array. + + if (source_dim_for_new_scalar_indexed_node == -1) { + VLOG(3) << "Could not compute the source dim for the new scalar indexed " + "node: scalar_indexed_source_shape = [" + << Join(scalar_indexed_source_shape.dimensions(), ",") + << "] and new_scalar_indexed_source_shape = [" + << Join(new_scalar_indexed_source_shape, ",") << "]"; + return nullptr; + } + + InsertAt( + &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, + scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); + + CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1l, + std::multiplies()), + ShapeUtil::ElementsIn(scalar_indexed_source_shape)); + + CHECK(IsReshapePassthroughOperandDim( + ComputeReshapePassthroughDimPairs( + /*operand_shape=*/AsInt64Slice( + scalar_indexed_source_shape.dimensions()), + /*result_shape=*/new_scalar_indexed_source_shape), + scalar_indexed->source_dim())); + + auto map_passthrough_operand_dim_to_result_dim = [&](int64 result_dim) { + return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims, + result_dim); + }; + + std::vector output_dims_for_new_scalar_indexed_node; + c_transform(scalar_indexed->output_dims(), + std::back_inserter(output_dims_for_new_scalar_indexed_node), + map_passthrough_operand_dim_to_result_dim); + + TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, + TakeOwnership(scalar_indexed->literal().Reshape( + new_scalar_indexed_source_shape))); + TF_ASSIGN_OR_RETURN( + Array * new_scalar_indexed_source, + ComputeArrayForConstant(*new_scalar_indexed_source_literal)); + + return ConstructScalarIndexedArray( + new_scalar_indexed_source, scalar_indexed->indices(), + source_dim_for_new_scalar_indexed_node, + output_dims_for_new_scalar_indexed_node, shape); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( + const Shape& shape, Array* operand) { + if (ShapeUtil::Compatible(operand->shape(), shape)) { + return operand; + } + + if (auto* scalar_indexed = + dynamic_cast(operand)) { + TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather, + FoldReshapeOfGather(shape, scalar_indexed)); + if (reshape_folded_into_gather) { + return reshape_folded_into_gather; + } + } + + if (auto* constant_array = dynamic_cast(operand)) { + TF_ASSIGN_OR_RETURN(Literal* const new_literal, + TakeOwnership(constant_array->literal()->Reshape( + AsInt64Slice(shape.dimensions())))); + return Construct(new_literal); + } + + return Construct(operand, shape); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, + Array* rhs) { + // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices)) + // => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices) + // + // We can do this if every output dimension from the scalar-indexed node is a + // broadcasted dimension for the broadcast node. Informally, the precondition + // means Broadcast(Const0)[IDX] is solely a function of the components of IDX + // that are not output-dims for the scalar-indexed node. In other words, for + // every assignment to the non-output dims in IDX we have a "constant" LHS to + // the BinaryOp. This transform propagates this "constant" to the source for + // the scalar-indexed node. + + ScalarIndexedConstantArray* lhs_scalar_indexed_const = + dynamic_cast(lhs); + ScalarIndexedConstantArray* rhs_scalar_indexed_const = + dynamic_cast(rhs); + + bool lhs_is_indexed; + + // One of the operands must be scalar-indexed and the other must be a + // broadcast of a constant. + if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) { + lhs_is_indexed = true; + } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) { + lhs_is_indexed = false; + } else { + return nullptr; + } + + ScalarIndexedConstantArray* scalar_indexed_const = + lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const; + UnknownArray* candidate_broadcast_array = + dynamic_cast(lhs_is_indexed ? rhs : lhs); + if (!candidate_broadcast_array || + candidate_broadcast_array->instruction().opcode() != + HloOpcode::kBroadcast) { + return nullptr; + } + + const HloInstruction* broadcast_instr = + &candidate_broadcast_array->instruction(); + const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0); + if (broadcast_const_operand->opcode() != HloOpcode::kConstant) { + return nullptr; + } + + ArraySlice broadcast_dims = broadcast_instr->dimensions(); + auto is_broadcasted_dim = [&](int64 output_dim) { + return c_find(broadcast_dims, output_dim) == broadcast_dims.end(); + }; + + // All of the output dims must be "broadcasted" dims for the other operand. + if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) { + return nullptr; + } + + // To figure out the broadcast dimensions for the (constant) source for the + // scalar-indexed node, we "simulate" the index transformation done by the + // existing broadcsat: + enum class IndexComponent { Broadcasted, NotBroadcasted }; + std::vector simulated_index( + broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted); + for (int64 broadcast_dim : broadcast_dims) { + simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted; + } + + // The scalar-indexed node "removes" the source dim and "inserts" the output + // dims. We do the opposite here to undo the scalar-indexed operation. + ArraySlice output_dims = scalar_indexed_const->output_dims(); + for (int64 i = output_dims.size() - 1; i >= 0; --i) { + CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted); + EraseAt(&simulated_index, output_dims[i]); + } + + InsertAt(&simulated_index, scalar_indexed_const->source_dim(), + IndexComponent::Broadcasted); + + // new_inner_broadcast_dims holds the broadcast dimensions for the inner + // BinaryOp(Broadcast'(Const0), Const1). We now translate simulated_index to + // new_inner_broadcast_dims. + std::vector new_inner_broadcast_dims; + for (int64 i = 0; i < simulated_index.size(); i++) { + if (simulated_index[i] == IndexComponent::NotBroadcasted) { + new_inner_broadcast_dims.push_back(i); + } + } + + // inner_broadcast_result is the Broadcast'(Const0) bit in + // BinaryOp(Broadcast'(Const0), Const1) + TF_ASSIGN_OR_RETURN( + std::unique_ptr inner_broadcast_result, + broadcast_const_operand->literal().Broadcast( + scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); + + // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1) + const Literal* literal_for_new_source; + if (lhs_is_indexed) { + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + } else { + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + } + + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, scalar_indexed_const->indices(), + scalar_indexed_const->source_dim(), + std::vector(scalar_indexed_const->output_dims().begin(), + scalar_indexed_const->output_dims().end()), + scalar_indexed_const->shape()); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand) { + auto* scalar_indexed_const = + dynamic_cast(operand); + if (scalar_indexed_const == nullptr) { + return nullptr; + } + + // Fold UnaryOp(ScalarIndexed(Const, Indices)) + // => ScalarIndexed(UnaryOp(Const), Indices) + + TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp( + opcode, scalar_indexed_const->literal()))); + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, scalar_indexed_const->indices(), + scalar_indexed_const->source_dim(), + std::vector(scalar_indexed_const->output_dims().begin(), + scalar_indexed_const->output_dims().end()), + scalar_indexed_const->shape()); +} + +tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { + return "indexed-array-analysis-printer-pass"; +} + +StatusOr IndexedArrayAnalysisPrinterPass::Run(HloModule* module) { + if (!VLOG_IS_ON(2)) { + return false; + } + + IndexedArrayAnalysis analysis; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instr : computation->instructions()) { + TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr)); + if (!dynamic_cast(t) && !dynamic_cast(t)) { + VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t); + } + } + } + + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..8684430231c1929f82508e3675f1c275c42b6149 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -0,0 +1,368 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { + +// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a +// gather from another array. It does this by mapping HLO instructions to +// instances of IndexedArrayAnalysis::Array, which can be inspected to discover +// whether said HLO is equivalent to a gather. +class IndexedArrayAnalysis { + public: + // IndexedArrayAnalysis maps each HLO instruction to an instance of a Array. + // Array really just a sum type of the classes that inherit from it. The + // meaning of each of the subtypes is documented on the subtype declaration. + // + // Array instances are immutable once created. + class Array { + public: + enum Kind { + kUnknown, + kConstant, + kReshaped, + kScalarIndexedConstant, + kScalarIndexed + }; + + virtual Kind kind() const = 0; + virtual const Shape& shape() const = 0; + + // Does a checked downcast from `Array` to `T` which must be one of its + // subtypes. + template + T* as() { + static_assert((std::is_base_of::value), + "target type not derived from source type"); + // We skip the CHECK and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + CHECK_NE(dynamic_cast(this), nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast(this); + } + + virtual ~Array() = default; + + Array& operator=(const Array& other) = delete; + }; + + // Represents an HLO instruction that was not analyzable by this + // IndexedArrayAnalysis. Instances of UnknownArray just wrap an existing + // HloInstruction. + class UnknownArray : public Array { + public: + Kind kind() const override { return kUnknown; } + const Shape& shape() const override { return instruction().shape(); } + const HloInstruction& instruction() const { return instruction_; } + + private: + explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {} + + const HloInstruction& instruction_; + + friend class IndexedArrayAnalysis; + }; + + // Represents a constant value. This constant value may be present in the HLO + // module being analyzed, or it could have been created on the fly by the + // analysis. + class ConstantArray : public Array { + public: + Kind kind() const override { return kConstant; } + const Shape& shape() const override { return literal()->shape(); } + const Literal* literal() const { return literal_; } + + private: + explicit ConstantArray(const Literal* literal) : literal_(literal) {} + const Literal* literal_; + + friend class IndexedArrayAnalysis; + }; + + // Represents an Array that is a reshape of another Array. + class ReshapedArray : public Array { + public: + Kind kind() const override { return kReshaped; } + + // The array to reshape. + Array* operand() const { return operand_; } + + // The output shape. + const Shape& shape() const override { return shape_; } + + private: + explicit ReshapedArray(Array* operand, Shape shape) + : operand_(operand), shape_(shape) {} + + Array* operand_; + const Shape shape_; + + friend class IndexedArrayAnalysis; + }; + + // --------------------------------------------------------------------------- + // Indexed Array Overview + // --------------------------------------------------------------------------- + // + // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this + // analysis. ScalarIndexedConstantArray is just a specialization of + // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this + // overview. + // + // A ScalarIndexedArray represents an array that can be computed by indexing + // into a "source" array using an "indices" tensor. A simple example is a + // gather operation gathering 12 rows out of a [100,100] matrix -- such an + // operation will be represented by an instance of a ScalarIndexedArray with + // the [100,100] matrix as the "source" array and the [12]-shaped indices + // array as the "indices" tensor. The ScalarIndexedArray operation itself + // will be of shape [12,100] (assuming we were gathering with axis=0). + // + // Gather operations are not the only operation that maps to + // ScalarIndexedArray instances (if that were true there would be little point + // in having a separate analysis). We can often infer ScalarIndexedArrays for + // other operations too. For instance, consider: + // + // %source = f32[100,100] constant + // %indices = s32[12] ... + // %gather = f32[12,100] ... gather from %source using %indices at axis 0 + // %dot = dot(%gather, other_constant) [canonical contracting dims] + // + // The dot operation itself is also a ScalarIndexedArray with source = + // dot(constant, other_constant) and indices = %indices. A reshape of %gather + // to [12,5,20] too is a ScalarIndexedArray with source = an appropriately + // reshaped constant and indices = %indices. + + // Represents the result of a gather operation. This gather operation may + // explicitly be present in the HLO module being analyzed, or it could have + // been created on the fly by the analysis. + // + // An instance of ScalarIndexedArray represents a array whose I'th element can + // be mapped to the J'th element of the `source` array (where I and J are + // multidimensional indices) in this way: + // + // I' = remove components at positions `output_dims` from I + // G' = remove components not at positions `output_dims` from I + // T = indices[G'] + // J = I' with T inserted at position `source_dim` + // + // For example, if source is of shape [11,13,17,19], indices is of shape + // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of + // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the + // input index [B,D,indices[A,C],E]. + class ScalarIndexedArray : public Array { + public: + Kind kind() const override { return kScalarIndexed; } + const Shape& shape() const override { return shape_; } + + Array* source() const { return source_; } + Array* indices() const { return indices_; } + + // `source_dim` is the dimension in the source array that is being indexed + // over using indices from the `indices` array. See the class documentation + // and the overview for more details. + int64 source_dim() const { return source_dim_; } + + // `output_dims` are the dimensions in the output array that are being used + // to compute an index into the `indices` array. See the class + // documentation and the overview for more details. + tensorflow::gtl::ArraySlice output_dims() const { + return output_dims_; + } + + private: + explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim, + std::vector output_dims, Shape shape) + : source_(source), + indices_(indices), + source_dim_(source_dim), + output_dims_(std::move(output_dims)), + shape_(std::move(shape)) {} + + Array* source_; + Array* indices_; + int64 source_dim_; + std::vector output_dims_; + Shape shape_; + + friend class IndexedArrayAnalysis; + }; + + // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to + // have a ConstantArray instance as the source. This is an ergonomic + // concession -- in theory it is possible to just keep ScalarIndexedArray and + // check source()->kind(). + class ScalarIndexedConstantArray : public ScalarIndexedArray { + public: + Kind kind() const override { return kScalarIndexedConstant; } + + const Literal& literal() const { + return *source()->as()->literal(); + } + + private: + explicit ScalarIndexedConstantArray(Array* source, Array* indices, + int64 source_dim, + std::vector output_dims, + Shape shape) + : ScalarIndexedArray(source, indices, source_dim, + std::move(output_dims), std::move(shape)) { + CHECK(dynamic_cast(source)); + } + + friend class IndexedArrayAnalysis; + }; + + // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance + // keeps ownership of the returned Array instance. + // + // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO + // instructions to IndexedArrayAnalysis::Array instances. This entire cache + // becomes stale and may cause the analysis to return incorrect results if any + // transitive operand (stopping at the containing computation) is modified for + // any HLO instruction on which GetArrayFor has been invoked. + // + // NB! By inspecting the implementation, you may be able to infer a stronger + // caching guarantee than what is mentioned above. Nevertheless, what is + // stated above is the contract. + StatusOr GetArrayFor(const HloInstruction* instr); + + // Pretty-prints the expression rooted at `root`. + string ToString(Array* root, bool print_constants = false); + + private: + // Helper function that ensures that every HLO instruction that is + // transitively used by `root` has an entry in `cache_`. + Status TraverseAndPopulateCache(const HloInstruction* root); + + // Creates an Array instance for `instr` under the assumption that all + // operations of `instr` are present in `cache_`. + StatusOr ComputeArrayFor(const HloInstruction* instr); + + StatusOr ComputeArrayForConstant(const Literal& literal); + + StatusOr ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices); + + // This tries to fold a ScalarIndexedArray which has another + // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a + // ScalarIndexedArray as indices. If `source` happened to be a + // ScalarIndexedConstantArray this can result in an expression that is more + // canonical. + // + // As an example, consider a gather operation, G0, gathering 7 elements from + // an array "Arr" of shape [100] resulting in an array of shape [7], and a + // second gather operation, G1, which gathers 3 elements out of the result of + // G0 resulting in an array of shape [3]. Let the indices uses by G0 be I0 + // (of shape [7]) and the indices used by G1 be I1 (of shape [3]). We can + // instead rewrite G1 to gather directly from "Arr" with the three indices + // from I0 as per I1. In other words, we can rewrite: + // + // G0 = [Arr[i] for i in I0] + // G1 = [G0[i] for i in I1] + // + // into + // + // I2 = [I0[i] for i in I1] + // G1 = [Arr[i] for i in I2] + StatusOr FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape); + + // Reshapes a scalar-indexed node to remove the degenerate dimensions in its + // output. The result is always a scalar-indexed node. + StatusOr ReshapeToRemoveDegenerateDims( + ScalarIndexedArray* operand); + + // Reshapes a scalar-indexed node such that the result has the degenerate + // dimensions `degenerate_dims`. The result is always a scalar-indexed node. + StatusOr ReshapeToAddDegenerateDims( + ScalarIndexedArray* operand, + tensorflow::gtl::ArraySlice degenerate_dims); + + StatusOr FoldReshapeOfGather( + const Shape& shape, ScalarIndexedConstantArray* operand); + StatusOr FoldReshapeOfGatherNoDegenerateDims( + const Shape& shape, ScalarIndexedConstantArray* scalar_indexed); + StatusOr ComputeArrayForReshape(const Shape& shape, Array* operand); + + StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, Array* rhs); + StatusOr ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand); + + template + T* Construct(Args&&... args) { + T* new_tensor = new T(std::forward(args)...); + owned_tensors_.push_back(std::unique_ptr(new_tensor)); + return new_tensor; + } + + ScalarIndexedArray* ConstructScalarIndexedArray( + Array* source, Array* indices, int64 source_dim, + std::vector output_dims, Shape shape) { + if (source->kind() == Array::kConstant) { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } else { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } + } + + Literal* TakeOwnership(std::unique_ptr literal) { + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + + StatusOr TakeOwnership( + StatusOr> literal_or_error) { + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + std::move(literal_or_error)); + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + + std::vector> owned_tensors_; + std::vector> owned_literals_; + tensorflow::gtl::FlatMap cache_; +}; + +// A pass that prints all non-trivial results returned by IndexedArrayAnalysis. +// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to +// unconditionally add to the regular HLO pass pipeline. +class IndexedArrayAnalysisPrinterPass : public HloPassInterface { + public: + tensorflow::StringPiece name() const override; + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc2befe05b18651502c42b9892e766145d85f2e8 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -0,0 +1,803 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +namespace xla { +namespace { +class IndexedArrayAnalysisTest : public HloVerifiedTestBase { + protected: + void AssertArrayForRootExpressionIs(const string& hlo_text, + const string& root_expression) { + AssertArrayForRootExpressionIsImpl(hlo_text, root_expression, + /*print_constants=*/false); + } + + void AssertArrayWithConstantsForRootExpressionIs( + const string& hlo_text, const string& root_expression) { + AssertArrayForRootExpressionIsImpl(hlo_text, root_expression, + /*print_constants=*/true); + } + + private: + // Replaces seqences of whitespace with a single space. This makes the + // strings being matched against "whitespace insensitive" which lets us indent + // them for readability. + string CanonicalizeWhitespace(const string& text) { + string result; + + for (char c : text) { + if (!isspace(c)) { + result.push_back(c); + } else if (!result.empty() && result.back() != ' ') { + result.push_back(' '); + } + } + + while (!result.empty() && result.back() == ' ') { + result.pop_back(); + } + + return result; + } + + void AssertArrayForRootExpressionIsImpl(const string& hlo_text, + const string& root_expression, + bool print_constants) { + IndexedArrayAnalysis indexed_tensor_analysis; + ParseAndVerifyModule(hlo_text); + + TF_ASSERT_OK_AND_ASSIGN( + IndexedArrayAnalysis::Array* const array_result, + indexed_tensor_analysis.GetArrayFor( + module().entry_computation()->root_instruction())); + string string_result = CanonicalizeWhitespace( + indexed_tensor_analysis.ToString(array_result, print_constants)); + LOG(INFO) << string_result; + ASSERT_EQ(string_result, CanonicalizeWhitespace(root_expression)); + } +}; + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices = s32[5] parameter(0) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices = s32[5,2] parameter(0) + ROOT gather = s32[5] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed1) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3,1] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,2}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed2) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3,1] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,2,3] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={2}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={2,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed3) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,2} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices_a = s32[5] parameter(0) + indices_b = s32[2] parameter(1) + gather_a = s32[5,3] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} + ROOT gather_b = s32[2,3] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3]) (scalar-indexed %indices_a " + "%indices_b 0->[0]) 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithOneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[2] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b), + output_window_dims={0,1}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=1, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 1->[1]) 1->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,6] parameter(0) + indices_a = s32[2] parameter(1) + indices_b = s32[5,7] parameter(2) + gather_a = s32[2,6] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,6} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 0->[0,1]) 0->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[4,8] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b), + output_window_dims={1,2}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=2, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed %operand (scalar-indexed %indices_a %indices_b " + "1->[0,2]) 1->[0,1,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5] parameter(0) + gather = s32[5,4] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT reshape = s32[5,2,2] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5,7] parameter(0) + gather = s32[5,4,7] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4} + ROOT reshape = s32[5,2,2,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,2,6] constant(s32[3,2,6]{ + {{1,2,3,4,5,6},{1,2,3,4,5,6}}, + {{1,2,3,4,5,6},{1,2,3,4,5,6}}, + {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) + indices = s32[5,7] parameter(0) + gather = s32[5,2,6,7] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,2,6} + ROOT reshape = s32[5,3,4,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,6] constant(s32[2,6]{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}) + indices = s32[1] parameter(0) + gather = s32[1,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT reshape = s32[1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,6]) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) + + i.0 = s64[1,3]{1,0} parameter(0) + g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2}, + elided_window_dims={0}, gather_dims_to_operand_dims={0}, + index_vector_dim=2, window_bounds={1,3} + + i.1 = s64[1] parameter(1) + g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2}, + elided_window_dims={1}, gather_dims_to_operand_dims={1}, + index_vector_dim=1, window_bounds={1,1,3} + + ROOT reshape = s32[1,3]{1,0} reshape(g.1) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,3]) + (reshape + (scalar-indexed %i.0 %i.1 1->[1]) + to s64[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) + indices = s32[1] parameter(0) + gather = s32[1,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT reshape = s32[1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[1,1,1,6]) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[1,2,6] constant(s32[1,2,6]{{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}}) + indices = s32[1] parameter(0) + gather = s32[1,1,6] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={1,1,6} + ROOT reshape = s32[1,1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,1,6] s32[2,1,1,1,6] { + { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } }, + { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } } }) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, + expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,6] constant(s32[2,6]{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}) + indices = s32[1,5] parameter(0) + gather = s32[1,5,6] gather(operand, indices), + output_window_dims={2}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,6} + ROOT reshape = s32[1,1,5,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,6] s32[2,1,1,6] { + { /*i0=0*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } }, + { /*i0=1*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } } }) + (reshape %indices to s32[5]) + 0->[2]) +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, + expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5,6] parameter(0) + gather = s32[5,4,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4} + ROOT reshape = s32[5,2,2,2,3] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,4]) + %indices + 0->[0,2]) + to s32[5,2,2,2,3]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,5,2] constant(s32[3,5,2]{ + {{1,2},{3,4},{5,6},{7,8},{9,10}}, + {{1,2},{3,4},{5,6},{7,8},{9,10}}, + {{1,2},{3,4},{5,6},{7,8},{9,10}}}) + indices = s32[7] parameter(0) + gather = s32[3,2,7] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1,2} + ROOT reshape = s32[6,7] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,5,2]) + %indices + 1->[2]) + to s32[6,7]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4,1] constant(s32[3,4,1]{ + {{1},{2},{3},{4}}, + {{1},{2},{3},{4}}, + {{1},{2},{3},{4}}}) + indices = s32[5,6] parameter(0) + gather = s32[5,4,6,1] gather(operand, indices), + output_window_dims={1,3}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4,1} + ROOT reshape = s32[5,2,2,2,3,1] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,4,1]) + %indices + 0->[0,2]) + to s32[5,2,2,2,3,1]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { + string hlo_text = R"( +HloModule UnaryOpOfGather + +ENTRY main { + operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + indices = s32[5] parameter(0) + gather = f32[5,4] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT tanh = f32[5,4] tanh(gather) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant f32[3,4] f32[3,4] { + { 0.761594176, 0.964027584, 0.995054781, 0.999329329 }, + { 0.761594176, 0.995054781, 0.964027584, 0.999329329 }, + { 0.999329329, 0.995054781, 0.964027584, 0.761594176 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) { + string hlo_text = R"( +HloModule AddBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 6, 7, 8, 9 }, + { 6, 8, 7, 9 }, + { 9, 8, 7, 6 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, + SubtractBroadcastedScalarWithGather_GatherIsLhs) { + string hlo_text = R"( +HloModule SubtractBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT sub = s32[5,4] subtract(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { -4, -3, -2, -1 }, + { -4, -2, -3, -1 }, + { -1, -2, -3, -4 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, + SubtractBroadcastedScalarWithGather_GatherIsRhs) { + string hlo_text = R"( +HloModule SubtractBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT sub = s32[5,4] subtract(constant_broadcasted, gather) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 4, 3, 2, 1 }, + { 4, 2, 3, 1 }, + { 1, 2, 3, 4 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) { + string hlo_text = R"( +HloModule AddBroadcastedVectorWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant_vect = s32[4] constant({10,11,12,13}) + constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 11, 13, 15, 17 }, + { 11, 14, 14, 17 }, + { 14, 14, 14, 14 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) { + string hlo_text = R"( +HloModule AddBroadcastedVectorWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant_vect = s32[5] constant({10,11,12,13,14}) + constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularUnaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input = f32[100] parameter(0) + ROOT tanh = f32[100] tanh(input) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%tanh"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularBinaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input0 = f32[100] parameter(0) + input1 = f32[100] parameter(1) + ROOT add = f32[100] add(input0, input1) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index cb6c98c48171a06539499b723a8d8b7aa0ccc96a..da91262130933b6d47fd95fb30bf89574b9469d6 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -83,6 +83,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kOutfeed: case HloOpcode::kPad: case HloOpcode::kReal: @@ -96,8 +97,10 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: + case HloOpcode::kAfterAll: case HloOpcode::kTranspose: case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: return false; // Cheap instructions for reals, but expensive for complex. @@ -118,6 +121,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kDivide: + case HloOpcode::kDomain: case HloOpcode::kDot: case HloOpcode::kExp: case HloOpcode::kExpm1: @@ -178,8 +182,7 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { bool InstructionFusion::CanFuseOnAllPaths( HloInstruction* producer, HloInstruction* consumer, - const HloReachabilityMap& reachability_map, - const DoNotFuseSet& do_not_fuse) { + const HloInstructionSet& do_not_duplicate) { if (consumer == producer) { return true; } @@ -190,10 +193,11 @@ bool InstructionFusion::CanFuseOnAllPaths( auto* consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter // whether it's fusable. - if (!reachability_map.IsReachable(producer, consumer_operand)) { + if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } - if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) { + if (do_not_duplicate.count(consumer_operand) > 0 || + !ShouldFuse(consumer, i)) { return false; } // The producer is reachable from consumer_operand which means we need @@ -201,18 +205,16 @@ bool InstructionFusion::CanFuseOnAllPaths( // producer to be fusable into consumer on all paths. // Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. - if (!CanFuseOnAllPaths(producer, consumer_operand, reachability_map, - do_not_fuse)) { + if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { return false; } } return true; } -InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( +InstructionFusion::HloInstructionSet +InstructionFusion::ComputeGloballyUnfusable( tensorflow::gtl::ArraySlice post_order) { - auto reachability = computation_->ComputeReachability(); - // Forbid fusion of producers that: // a) Need to be duplicated, unless they can be fused into all consumers // via all paths. @@ -222,10 +224,10 @@ InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( // Note that if we allow fusion by these global rules, we may still forbid // fusing operations that require duplication later depending on // is_expensive_(). - DoNotFuseSet do_not_fuse; + HloInstructionSet do_not_duplicate; for (HloInstruction* consumer : post_order) { for (HloInstruction* producer : consumer->operands()) { - if (do_not_fuse.count(producer) > 0) { + if (do_not_duplicate.count(producer) > 0) { continue; } @@ -237,6 +239,30 @@ InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( if (EffectivelyAtMostUnary(producer)) { continue; } + + // If the total size of the inputs is less than or equal to the total size + // of the outputs for the producer then duplicating it won't increase the + // memory traffic. In that case, we do not forbid fusion of the operation + // here. + auto total_size = [](const Shape& shape) { + int64 size = 0; + ShapeUtil::ForEachSubshape( + shape, + [&size](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + size += ShapeUtil::ElementsIn(subshape); + } + }); + return size; + }; + int64 operands_size = 0; + for (const HloInstruction* op : producer->operands()) { + operands_size += total_size(op->shape()); + } + if (operands_size <= total_size(producer->shape())) { + continue; + } + // Otherwise we will forbid fusing the op unless we can fuse it into // all of its consumers on all paths. // @@ -254,14 +280,14 @@ InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( // A will be not allowed to be fused into B, as it cannot be fused via // all paths. if (producer->IsFusable() && - CanFuseOnAllPaths(producer, consumer, *reachability, do_not_fuse)) { + CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { continue; } - do_not_fuse.insert(producer); + do_not_duplicate.insert(producer); } } - return do_not_fuse; + return do_not_duplicate; } StatusOr InstructionFusion::Run(HloModule* module) { @@ -273,6 +299,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; + reachability_ = computation_->ComputeReachability(); // We want to be able to remove arbitrary instructions from the post order // and also compare positions of instructions in the post order. To make @@ -280,17 +307,15 @@ StatusOr InstructionFusion::Run(HloModule* module) { // map from HloInstruction* to the instruction's index in the vector. An // instruction is "removed" from the vector by setting it's element to // nullptr. - std::list post_order_list = + std::vector post_order = computation_->MakeInstructionPostOrder(); - std::vector post_order(post_order_list.begin(), - post_order_list.end()); tensorflow::gtl::FlatMap post_order_index; for (size_t i = 0; i < post_order.size(); ++i) { InsertOrDie(&post_order_index, post_order[i], i); } - DoNotFuseSet do_not_fuse = ComputeGloballyUnfusable(post_order); + HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -358,9 +383,20 @@ StatusOr InstructionFusion::Run(HloModule* module) { // ensures that B will be considered before A. // // We store the original indices of the operands to pass to ShouldFuse. - std::vector sorted_operand_numbers(instruction->operands().size()); - std::iota(std::begin(sorted_operand_numbers), - std::end(sorted_operand_numbers), 0); + std::vector sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the post-order index. + // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (post_order_index.find(instruction->mutable_operand(i)) == + post_order_index.end()) { + continue; + } + sorted_operand_numbers.push_back(i); + } std::sort( sorted_operand_numbers.begin(), sorted_operand_numbers.end(), [&](int64 i, int64 j) { @@ -377,13 +413,20 @@ StatusOr InstructionFusion::Run(HloModule* module) { if (!operand->IsFusable()) { continue; } - if (!ShouldFuse(instruction, i)) { - continue; - } - if (do_not_fuse.count(operand) > 0) { + + HloInstruction* fusion_instruction; + // Try "regular" fusion if the operand may be duplicated. Otherwise, + // perform multi-output fusion, unless this creates a cycle. + // TODO(tjoerg): Consider making multi-output fusion the default. + if (ShouldFuse(instruction, i) && + do_not_duplicate.count(operand) == 0) { + fusion_instruction = Fuse(operand, instruction); + } else if (ShouldFuseIntoMultiOutput(instruction, i) && + !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_instruction = FuseIntoMultiOutput(operand, instruction); + } else { continue; } - HloInstruction* fusion_instruction = Fuse(operand, instruction); // Fusing an instruction into a fusion instruction can change the // operand set of the fusion instruction. For simplicity just push the @@ -449,6 +492,19 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( return fusion_instruction; } +bool InstructionFusion::MultiOutputFusionCreatesCycle( + HloInstruction* producer, HloInstruction* consumer) { + return c_any_of( + consumer->operands(), [&](const HloInstruction* consumer_operand) { + // The fusion algorithm traverses the HLO graph in reverse post order. + // Thus `cosumers` is visited before its operands (including + // `producer`). Therefore, consumer operands cannot have been fused yet. + // It is thus safe to use the pre-computed reachability map. + return consumer_operand != producer && + reachability_->IsReachable(producer, consumer_operand); + }); +} + bool InstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index c3c2ed0aaa81d6f346ec6e70d9c8b3b923e0a3d2..f73ca9adf768ed26f9ec9f162e01b7b160f50daf 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -61,6 +61,14 @@ class InstructionFusion : public HloPassInterface { // Subtypes can override this with target-specific heuristics. virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index); + // Returns whether multi-output fusion can be applied to fuse `producer` into + // `consumer`. In contrast to "regular" fusion, the `producer` is not + // duplicated by multi-output fusion. + virtual bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + return false; + } + // Chooses a fusion kind for `producer` and `consumer`. // Default method chooses `kLoop`. virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, @@ -97,10 +105,12 @@ class InstructionFusion : public HloPassInterface { // Current HloComputation instance the loop fuser is traversing. HloComputation* computation_; HloModule* module_; + // Reachability information for the current computation. + std::unique_ptr reachability_; private: // The set of producers whose consumers we cannot fuse into. - using DoNotFuseSet = std::unordered_set; + using HloInstructionSet = std::unordered_set; HloInstruction* AddFusionInstruction(HloInstruction* producer, HloInstruction* consumer); @@ -108,18 +118,21 @@ class InstructionFusion : public HloPassInterface { // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, - const HloReachabilityMap& reachability_map, - const DoNotFuseSet& do_not_fuse); + const HloInstructionSet& do_not_fuse); // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. - DoNotFuseSet ComputeGloballyUnfusable( + HloInstructionSet ComputeGloballyUnfusable( tensorflow::gtl::ArraySlice post_order); // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. std::function is_expensive_; + // Whether multi-output fusion would introduce a cycle into the HLO graph. + bool MultiOutputFusionCreatesCycle(HloInstruction* producer, + HloInstruction* consumer); + // Returns whether we may duplicate an instruction if we want to fuse it. bool may_duplicate_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index df109df7877eefe4c337f93cc5a3a7a48e2e76c7..bb7231c8c868ff2fefa3e88c4be036a89ed29118 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { @@ -47,7 +47,7 @@ class InstructionFusionForTesting : public InstructionFusion { }; TEST_F(InstructionFusionTest, FuseInstructions) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY entry_computation { p0 = f32[4,3]{1,0} parameter(0) @@ -67,7 +67,7 @@ TEST_F(InstructionFusionTest, FuseInstructions) { } TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module fused_computation { p1 = f32[4,3] parameter(0) @@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { } TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY entry_computation { p0 = f32[4,3]{1,0} parameter(0) @@ -167,7 +167,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); HloInstruction* binary1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0)); HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); @@ -195,7 +196,7 @@ static int Count(const HloModule& module, HloOpcode op) { } TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -220,7 +221,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // // p0 -> add -------------------------> sub // \-> abs1 -> rng -> abs2 -/ - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -251,14 +252,15 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // p0 -> add -------------------------> sub // \-> abs1 -> log -> abs2 -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) add = f32[4,3]{1,0} add(p0, p0) abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 abs2 = f32[4,3]{1,0} abs(log) ROOT root = f32[4,3]{1,0} subtract(abs2, add) })") @@ -282,13 +284,14 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // \ \-> add2 -/ // \-> log -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) add1 = f32[4,3]{1,0} add(p0, p0) log = f32[4,3]{1,0} log(p0) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 add2 = f32[4,3]{1,0} add(log, add1) ROOT root = f32[4,3]{1,0} subtract(add1, add2) })") @@ -314,14 +317,15 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // \------> sub1 // log -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) add1 = f32[4,3]{1,0} add(p0, p0) add2 = f32[4,3]{1,0} add(add1, add1) log = f32[4,3]{1,0} log(add2) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 sub1 = f32[4,3]{1,0} subtract(log, add2) sub2 = f32[4,3]{1,0} subtract(add2, add1) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) @@ -352,7 +356,8 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); HloInstruction* unary1 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0)); - builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + builder.AddInstruction(HloInstruction::CreateSend(unary1, token, 0)); HloInstruction* unary2 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); @@ -375,7 +380,8 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); HloInstruction* binary1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0)); HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); @@ -390,7 +396,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY Test { p0 = f16[100] parameter(0) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 3ff15512fb0bba22b1fa86269f5a7db993f0cf8c..9f8f4bda875cdff5e20fa8ca8eeecaa1140e2b9c 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -44,8 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->device_entry_computation_layout()); - + hlo_module->mutable_entry_computation_layout()); return pipeline.Run(hlo_module).status(); } @@ -70,7 +69,8 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. std::unique_ptr executable = - xla::MakeUnique(std::move(hlo_module)); + xla::MakeUnique(std::move(hlo_module), + xla::MakeUnique()); return std::move(executable); } @@ -100,17 +100,14 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() return InterpreterExecutable::ShapeSizeBytes; } -static std::unique_ptr CreateComputationPlacer() { - return xla::MakeUnique(); -} - static bool InitModule() { xla::Compiler::RegisterCompilerFactory( se::interpreter::kXlaInterpreterPlatformId, []() { return xla::MakeUnique(); }); xla::ComputationPlacer::RegisterComputationPlacer( - se::interpreter::kXlaInterpreterPlatformId, &CreateComputationPlacer); + se::interpreter::kXlaInterpreterPlatformId, + []() { return xla::MakeUnique(); }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 61f199bc9e8f4f95a2f097af4abf9395a1e05f64..9816acf6507a0ed5391cf4f1c94ccd0f27f5227a 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -32,16 +31,17 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace interpreter { InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module) + std::unique_ptr hlo_module, + std::unique_ptr evaluator) : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, - /*hlo_profile_index_map=*/nullptr) {} + /*hlo_profile_index_map=*/nullptr), + evaluator_(std::move(evaluator)) {} InterpreterExecutable::~InterpreterExecutable() {} @@ -75,17 +75,20 @@ StatusOr InterpreterExecutable::ExecuteOnStream( // consumes. std::vector> arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr arg_literal, - transfer_manager->TransferLiteralFromDevice(executor, *arguments[p])); + TF_ASSIGN_OR_RETURN(std::unique_ptr arg_literal, + transfer_manager->TransferLiteralFromDevice( + run_options->stream(), *arguments[p])); arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, - evaluator.Evaluate>(*computation, arg_literals)); + std::unique_ptr result_literal; + { + tensorflow::mutex_lock lock(evaluator_lock_); + TF_ASSIGN_OR_RETURN(result_literal, + evaluator_->Evaluate>( + *computation, arg_literals)); + } // Transform the result literal back into a ShapedBuffer. TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, @@ -93,7 +96,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( result_literal->shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - executor, *result_literal, result)); + run_options->stream(), *result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index b0b797ca7d6f449a11c662ffba7c2a0a0040e47e..91d8148d26dc8eddbafdaf4870d9efbb73a12816 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -40,13 +42,15 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module); + InterpreterExecutable(std::unique_ptr hlo_module, + std::unique_ptr evaluator); ~InterpreterExecutable() override; StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) override; + HloExecutionProfile* hlo_execution_profile) override + LOCKS_EXCLUDED(evaluator_lock_); StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, @@ -54,6 +58,11 @@ class InterpreterExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); + protected: + // The interpreter interprets executables with an HloEvaluator. + std::unique_ptr evaluator_ PT_GUARDED_BY(evaluator_lock_); + mutable tensorflow::mutex evaluator_lock_; + private: TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); }; diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 97e9fa2c8e8ecd918ffe3df2fd4e731f3b91e6db..4fb67bd0b72fc591c1ffa76ebb0513bf14ed3737 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -53,6 +53,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); }); + AsExecutorStream(stream)->BlockUntilDone(); return true; } @@ -61,6 +62,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); }); + AsExecutorStream(stream)->BlockUntilDone(); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 92e069a8c67c1d441ba9d396dee503c9b3bde0df..42c2c28997d5f3b02f1fe4effca164c893e4071d 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/interpreter/executor.h" -#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" @@ -31,13 +30,13 @@ limitations under the License. namespace stream_executor { namespace interpreter { -XlaInterpreterPlatform::XlaInterpreterPlatform() : name_("Interpreter") {} +XlaInterpreterPlatform::XlaInterpreterPlatform(const string& name, + const Platform::Id& id) + : name_(name), id_(id) {} XlaInterpreterPlatform::~XlaInterpreterPlatform() {} -Platform::Id XlaInterpreterPlatform::id() const { - return kXlaInterpreterPlatformId; -} +Platform::Id XlaInterpreterPlatform::id() const { return id_; } int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } @@ -106,8 +105,6 @@ REGISTER_MODULE_INITIALIZER( interpreter_platform, stream_executor::interpreter::InitializeXlaInterpreterPlatform()); -DECLARE_MODULE_INITIALIZER(multi_platform_manager); - // Note that module initialization sequencing is not supported in the // open-source project, so this will be a no-op there. REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform, diff --git a/tensorflow/compiler/xla/service/interpreter/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h index d68c5aa20dda7ac246ed4aa667851e385a604c04..0187f6d473b19f50136e214708e56f833627d9d1 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.h +++ b/tensorflow/compiler/xla/service/interpreter/platform.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/plugin.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -28,7 +29,8 @@ namespace interpreter { class XlaInterpreterPlatform : public Platform { public: - XlaInterpreterPlatform(); + XlaInterpreterPlatform(const string& name = "Interpreter", + const Platform::Id& id = kXlaInterpreterPlatformId); ~XlaInterpreterPlatform() override; Platform::Id id() const override; @@ -55,6 +57,8 @@ class XlaInterpreterPlatform : public Platform { private: // This platform's name. string name_; + // This platform's id. + Platform::Id id_; // Cache of created StreamExecutors. ExecutorCache executor_cache_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index cfa7ba5e81ddd003978a2bd763384581c55b5c83..fedc83c8f8384a75beba7081e7e9c6094249178f 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -31,10 +31,12 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -173,41 +175,32 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); - const BufferLayoutConstraint* curr_constraint = - GetBufferLayoutConstraint(buffer); - if (curr_constraint != nullptr) { - if (LayoutUtil::Equal(curr_constraint->layout(), layout)) { + auto iter = buffer_constraints_.find(&buffer); + if (iter != buffer_constraints_.end()) { + const BufferLayoutConstraint& curr_constraint = iter->second; + if (LayoutUtil::Equal(curr_constraint.layout(), layout)) { // New constraint matches existing constraint. Nothing to do. return Status::OK(); } - if (curr_constraint->mandatory()) { + if (curr_constraint.mandatory()) { return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", buffer.ToString().c_str(), - LayoutUtil::HumanString(curr_constraint->layout()).c_str(), + LayoutUtil::HumanString(curr_constraint.layout()).c_str(), LayoutUtil::HumanString(layout).c_str()); } - } - - auto iter = buffer_constraints_.find(&buffer); - bool overwrite = iter != buffer_constraints_.end(); - if (!overwrite) { + iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); + } else { + TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1) + << buffer.ToString(); iter = buffer_constraints_ .insert(std::make_pair( &buffer, BufferLayoutConstraint(layout, buffer, mandatory, dfs))) .first; - } else { - iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } added_constraints_.push_back(&iter->second); - - // Remove buffer from the set of unconstrained buffers. - TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == - static_cast(!overwrite)); - unconstrained_buffer_ids_.erase(buffer.id()); - return Status::OK(); } @@ -400,9 +393,9 @@ string LayoutConstraints::ToString() const { } Status LayoutAssignment::AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints) { + const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, HloComputation* computation, + LayoutConstraints* constraints) { VLOG(3) << "Adding mandatory layout constraints to computation " << computation->name(); @@ -424,11 +417,16 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( instruction->outfeed_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kParameter) { - // Parameter layouts must match the respective layout in - // ComputationLayout. - shape_with_layout = - &computation_layout.parameter_layout(instruction->parameter_number()) - .shape(); + if (computation_layout != nullptr) { + const ShapeLayout& parameter_layout = + computation_layout->parameter_layout( + instruction->parameter_number()); + if (parameter_layout.LayoutIsSet()) { + // Parameter layouts must match the respective layout in + // ComputationLayout, if there is one. + shape_with_layout = ¶meter_layout.shape(); + } + } } if (shape_with_layout != nullptr) { TF_RETURN_IF_ERROR( @@ -493,9 +491,8 @@ Status LayoutAssignment::AddMandatoryConstraints( HloComputation* body = instruction->while_body(); HloComputation* condition = instruction->while_condition(); const HloInstruction* init = instruction->operand(0); - const ComputationLayout& body_layout = - FindOrDie(computation_layouts_, body); - const ComputationLayout& condition_layout = + ComputationLayout& body_layout = FindOrDie(computation_layouts_, body); + ComputationLayout& condition_layout = FindOrDie(computation_layouts_, condition); // Check a few invariants irrespective of layout. @@ -508,26 +505,19 @@ Status LayoutAssignment::AddMandatoryConstraints( condition_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape())); - // Return error if earlier layout assignment of the embedded computations - // has produced conflicting layouts. - if (!ShapeUtil::Equal(body_layout.result_shape(), - body_layout.parameter_shape(0))) { - return InternalError( - "Parameter and result of body computation %s of while instruction " - "%s have different layouts: %s vs %s", - body->name().c_str(), instruction->name().c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str(), - ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str()); + if (body_layout.result_layout() != body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while body parameter layout: body=" << body->name() + << " while=" << instruction->name() + << " shape=" << body_layout.result_layout().ToString(); + *body_layout.mutable_parameter_layout(0) = body_layout.result_layout(); } - if (!ShapeUtil::Equal(body->root_instruction()->shape(), - condition->parameter_instruction(0)->shape())) { - return InternalError( - "Parameter of condition computation %s of while instruction " - "%s does not match body computation %s result: %s vs %s", - condition->name().c_str(), instruction->name().c_str(), - body->name().c_str(), - ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str()); + if (condition_layout.parameter_layout(0) != + body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while condition parameter layout: cond=" + << condition->name() << " while=" << instruction->name() + << " shape=" << body_layout.parameter_layout(0).ToString(); + *condition_layout.mutable_parameter_layout(0) = + body_layout.parameter_layout(0); } // Constrain the output and the operand of the while instruction to match @@ -557,7 +547,20 @@ Status LayoutAssignment::AddMandatoryConstraints( true_computation_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible( false_operand->shape(), false_computation_layout.parameter_shape(0))); - + if (true_computation_layout.result_layout() != + false_computation_layout.result_layout()) { + // We assign layouts in DFS fashion, so the true and false computations + // might have negotiated a different layout. But for the conditional + // instruction POV the layout must match, so we run again on the false + // computation, this time with proper computation layout. + VLOG(2) << "Reset %conditional false computation result layout: " + "false_computation=" + << false_computation->name() + << " conditional=" << instruction->name() << " shape=" + << true_computation_layout.result_layout().ToString(); + *false_computation_layout.mutable_result_layout() = + true_computation_layout.result_layout(); + } TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( true_computation_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( @@ -593,10 +596,14 @@ Status LayoutAssignment::AddMandatoryConstraints( } } } - - // Finally set the result layout to match ComputationLayout. - return constraints->SetResultLayout( - computation_layout.result_layout().shape()); + // Finally set the result layout to match ComputationLayout, if there is one. + if (computation_layout != nullptr) { + const ShapeLayout& result_layout = computation_layout->result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape())); + } + } + return Status::OK(); } namespace { @@ -700,7 +707,8 @@ Status CheckParameterLayout(HloInstruction* parameter, const ComputationLayout& computation_layout) { const ShapeLayout& parameter_layout = computation_layout.parameter_layout(parameter->parameter_number()); - if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) { + if (parameter_layout.LayoutIsSet() && + !parameter_layout.MatchesLayoutInShape(parameter->shape())) { return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", @@ -760,6 +768,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( HloInstruction* copy = instruction->parent()->AddInstruction(HloInstruction::CreateUnary( instruction->shape(), HloOpcode::kCopy, instruction)); + RegisterAddedCopy(copy); SetupCopiedInstruction(*instruction, copy, {}); LayoutUtil::ClearLayout(copy->mutable_shape()); TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( @@ -783,13 +792,19 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer( TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + VLOG(5) << "Operand " << operand->ToString() << " layout matches in " + << instruction->ToString(); // Operand layout already matches our constraint. Nothing to do. return Status::OK(); } + VLOG(4) << "Operand " << operand->ToString() << " layout does not match " + << operand_layout.ToString() << " in " << instruction->ToString(); TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, CreateCopyWithNewLayout(operand_layout.shape(), operand)); + VLOG(4) << "New copy of " << operand->ToString() << " is " + << operand_copy->ToString(); return instruction->ReplaceOperandWith(operand_no, operand_copy); } @@ -896,32 +911,32 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { } } } - - // Finally verify the result layout matches the layout of the entry + // Finally verify the result layout, if set, matches the layout of the entry // computation root. - TF_RET_CHECK(ShapeUtil::Equal( - module->entry_computation()->root_instruction()->shape(), + const ShapeLayout& result_layout = FindOrDie(computation_layouts_, module->entry_computation()) - .result_layout() - .shape())); - + .result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RET_CHECK(ShapeUtil::Equal( + module->entry_computation()->root_instruction()->shape(), + result_layout.shape())); + } return Status::OK(); } LayoutAssignment::LayoutAssignment( - const ComputationLayout& entry_computation_layout, + ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), + saved_entry_computation_layout_(*entry_computation_layout), channel_layout_constraints_(channel_constraints) { - VLOG(1) << "entry computation layout given to layout assignment: " - << entry_computation_layout_.ToString(); - // Layouts of all parameter instructions must be set. - for (const ShapeLayout& parameter_layout : - entry_computation_layout_.parameter_layouts()) { - CHECK(parameter_layout.LayoutIsSet()); + if (channel_layout_constraints_ != nullptr) { + // Save a copy of the input ChannelLayoutConstraints so that we can reset it + // if we have to undo previous operations (ClearPreviousPassSideEffects()). + channel_constraints_ = *channel_layout_constraints_; } - // TODO(b/29118294): Choose a better layout if the result layout is not set. - CHECK(entry_computation_layout_.result_layout().LayoutIsSet()); + VLOG(1) << "Entry computation layout given to layout assignment: " + << entry_computation_layout_->ToString(); } std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( @@ -1481,16 +1496,60 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, return Status::OK(); } +Status LayoutAssignment::CalculateComputationLayout( + HloComputation* computation) { + ComputationLayout computation_layout(computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + InsertOrDie(&computation_layouts_, computation, computation_layout); + VLOG(2) << " Calculated ComputationLayout = " + << computation_layout.ToString(); + return Status::OK(); +} + +Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { + // Clear existing layouts of the instructions. All layouts must be assigned + // by the LayoutAssignment pass, except for those on infeeds, parameters, + // and the computation result. The latter two are specified in + // computation_layout, so we only need to keep the existing layouts for + // infeeds. Clearing the layouts here avoids hiding potential bugs in the + // layout assignment pass that may accidently use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBitcast) { + // bitcasts are inherently layout sensitive and so a bitcast instruction + // present in the IR before layout assignment is a bug. + return InternalError( + "Unexpected bitcast operation seen during layout assignment: %s.", + instruction->ToString().c_str()); + } + if (instruction->opcode() != HloOpcode::kInfeed) { + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + } + return Status::OK(); +} + Status LayoutAssignment::RunOnComputation( - const ComputationLayout& computation_layout, + ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints) { - DCHECK(computation_layout.LayoutIsSet()); - InsertOrDie(&computation_layouts_, computation, computation_layout); VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() << ")"; - VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + TF_RETURN_IF_ERROR(ClearComputationLayouts(computation)); + if (computation_layout != nullptr) { + auto it = computation_layouts_.find(computation); + if (it == computation_layouts_.end()) { + VLOG(2) << " New ComputationLayout = " << computation_layout->ToString(); + computation_layouts_.emplace(computation, *computation_layout); + } else { + TF_RET_CHECK(computation_layout == &it->second || + computation_layout == entry_computation_layout_); + VLOG(2) << " Existing ComputationLayout = " + << computation_layout->ToString(); + } + } else { + VLOG(2) << " No ComputationLayout specified (will be calculated)"; + } // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); @@ -1506,6 +1565,13 @@ Status LayoutAssignment::RunOnComputation( // Propagates layouts from mandatory and backend constraints. TF_RETURN_IF_ERROR(PropagateConstraints(&constraints)); + // Prior to applying default layouts, we take note of all HLO instructions + // which lack a layout constraint. + for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) { + unconstrained_layout_instructions_.insert( + points_to_analysis.GetBuffer(buffer_id).instruction()); + } + // While any unconstrained buffers remain, pick an arbitrary buffer, give it a // layout and propagate the change. while (!constraints.unconstrained_buffer_ids().empty()) { @@ -1533,26 +1599,106 @@ Status LayoutAssignment::RunOnComputation( CHECK_LT(constraints.unconstrained_buffer_ids().size(), unconstrained_count); } - // All logical buffers should have constraints at this point. All that // remains is assign the constraints to the buffers and infer layouts for // aliased buffers. TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation)); + // If the computation layout wasn't specified, now it is the time to compute + // it according to the parameters and root instruction layouts. + // This allows the first pass through this API to record the best flowing + // layout to parameters and root instruction. + if (computation_layout == nullptr) { + TF_RETURN_IF_ERROR(CalculateComputationLayout(computation)); + } + // Record the layouts assigned for any communication ops in // channel_constraints so that they are constrained for future modules. + if (channel_constraints != nullptr) { + TF_RETURN_IF_ERROR( + ConstrainChannelLayouts(computation, channel_constraints)); + } + return Status::OK(); +} + +Status LayoutAssignment::ConstrainChannelLayouts( + HloComputation* computation, + ChannelLayoutConstraints* channel_constraints) { + // We go through the kRecvDone before. These must either impose their layout, + // of find a matching one already existing (ConstrainChannel() returns + // nullptr). for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kRecvDone) { + const Layout* layout = channel_constraints->ConstrainChannel( + instruction->channel_id(), + ShapeUtil::GetSubshape(instruction->shape(), {0}).layout()); + TF_RET_CHECK(layout == nullptr) + << instruction->ToString() + << " cannot constrain layout as it was set to " + << LayoutUtil::HumanString(*layout); + } + } + // After that we go through the kSend. These are likely going to have a kCopy + // as operand (otherwise we add it), so in case the constrained layout does + // not match, we can change the kCopy layout (and the kSend one as well). + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kSend) { - channel_constraints->ConstrainChannel( - instruction->channel_id(), instruction->operand(0)->shape().layout()); - } else if (instruction->opcode() == HloOpcode::kRecvDone) { - channel_constraints->ConstrainChannel(instruction->channel_id(), - instruction->shape().layout()); + HloInstruction* operand = instruction->mutable_operand(0); + const Layout* layout = channel_constraints->ConstrainChannel( + instruction->channel_id(), operand->shape().layout()); + if (layout != nullptr) { + // We found an already constrained layout which does not match the one + // the kSend wants to impose. Either add a new kCopy, or use the + // existing one to marshal the correct shape. + Shape shape = operand->shape(); + *shape.mutable_layout() = *layout; + if (operand->opcode() != HloOpcode::kCopy) { + HloInstruction* copy = operand->parent()->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand)); + RegisterAddedCopy(copy); + SetupCopiedInstruction(*operand, copy, {}); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy)); + operand = copy; + } else { + *operand->mutable_shape() = shape; + } + Shape* send_shape = + ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0}); + *send_shape = shape; + } } } return Status::OK(); } +Status LayoutAssignment::PropagateComputationLayouts( + HloComputation* computation, ComputationLayout* computation_layout) { + ComputationLayout computed_computation_layout( + computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) { + ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i); + if (!param_layout->LayoutIsSet()) { + VLOG(4) << "Assigning layout to parameter " << i << " of computation " + << computation->name() << ": " + << computed_computation_layout.parameter_layout(i).ToString(); + *param_layout = computed_computation_layout.parameter_layout(i); + } else { + TF_RET_CHECK(computed_computation_layout.parameter_layout(i) == + *param_layout); + } + } + ShapeLayout* result_layout = computation_layout->mutable_result_layout(); + if (!result_layout->LayoutIsSet()) { + VLOG(4) << "Assigning result layout of computation " << computation->name() + << ": " << computed_computation_layout.result_layout().ToString(); + *result_layout = computed_computation_layout.result_layout(); + } else { + TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout); + } + return Status::OK(); +} + StatusOr LayoutAssignment::Run(HloModule* module) { VLOG(2) << "Running layout assignment on module " << module->name(); XLA_VLOG_LINES(3, module->ToString()); @@ -1561,52 +1707,46 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "before layout assignment", module->config().debug_options()); } - - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // Assign layouts to computations in an order such that a callee computation - // is handled before its caller computation. This ensures that the layout of - // all callers of a computation will agree. - std::list computation_post_order = - module->MakeComputationPostOrder(); - for (auto* computation : module->MakeComputationPostOrder()) { - if (computation->IsFusionComputation()) { - continue; - } - // Clear existing layouts of the instructions. All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidently use the existing layout. - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kBitcast) { - // bitcasts are inherently layout sensitive and so a bitcast instruction - // present in the IR before layout assignment is a bug. - return InternalError( - "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString().c_str()); + TF_RETURN_IF_ERROR(Init()); + + // We do two passes. The first one we pass a nullptr ComputationLayout to + // the RunOnComputation() calls (for non entry computations), and we register + // the ComputationLayout which are naturally flowing in DFS fashion to the + // parameters and root instruction. + // Walking in DFS mode though, means that we can end up with incorrect layouts + // when seen from an outer instruction, which has across-computation + // constraints to impose. + // For example, the kWhile instruction needs to enforce the same layouts for + // the parameters and root of the body, as well as the condition parameters. + // Similarly, the kConditional instruction needs to enforce the same layouts + // for the root of the true and false computations. + // So in the first pass, while allowing the layouts to flow to parameters and + // root, we also fix up the eventually inconsistent ComputationLayout, which + // will be then made mandatory by the second pass. + for (int64 i = 0; i < 2; ++i) { + VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass"; + TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module)); + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(module)); + for (auto* computation : module->MakeComputationPostOrder()) { + if (computation->IsFusionComputation()) { + continue; } - if (instruction->opcode() != HloOpcode::kInfeed) { - LayoutUtil::ClearLayout(instruction->mutable_shape()); + if (computation == module->entry_computation()) { + TF_RETURN_IF_ERROR(RunOnComputation( + entry_computation_layout_, *points_to_analysis, + module->entry_computation(), channel_layout_constraints_)); + } else { + ComputationLayout* computation_layout = + (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation); + TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, + *points_to_analysis, computation, + channel_layout_constraints_)); } } - if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(RunOnComputation( - entry_computation_layout_, *points_to_analysis, - module->entry_computation(), channel_layout_constraints_)); - } else { - ComputationLayout computation_layout(computation->ComputeProgramShape()); - // Setting all embedded computations to the default layout is potentially - // suboptimal. - computation_layout.SetToDefaultLayout(); - TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, - *points_to_analysis, computation, - channel_layout_constraints_)); - } } - + TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(), + entry_computation_layout_)); TF_RETURN_IF_ERROR(CheckLayouts(module)); VLOG(3) << "After layout assignment:"; @@ -1616,9 +1756,58 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "after layout assignment", module->config().debug_options()); } - // All layouts are reset then reassigned by this pass. return true; } +Status LayoutAssignment::Init() { + computation_layouts_.clear(); + *entry_computation_layout_ = saved_entry_computation_layout_; + return Status::OK(); +} + +Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { + VLOG(5) << "Clearing previous side effects"; + // Clear all the copies which have been added, and all the related + // instructions (like GTE and tuples). + int64 removed_copies = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kCopy && + added_copies_.count(instruction) > 0) { + VLOG(5) << "Removing added copy: " << instruction->ToString(); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + ++removed_copies; + } + } + } + added_copies_.clear(); + unconstrained_layout_instructions_.clear(); + if (removed_copies > 0) { + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + ResetChannelConstraints(); + return Status::OK(); +} + +Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction, + int64 operand_number) { + HloInstruction* operand = instruction->mutable_operand(operand_number); + if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + operand->shape(), HloOpcode::kCopy, operand)); + SetupCopiedInstruction(*operand, copy, {}); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy)); + } + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 9663a793fdd7d4968700707a1003319e89ea19a3..b75ecb311a07b996562460fc5d6fbd8e70ac056b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -248,25 +249,30 @@ class ChannelLayoutConstraints { // Given `shape`, apply the layout for `channel_id`. `channel_id` must already // be constrained. Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const { - CHECK(IsChannelConstrained(channel_id)); - *shape.mutable_layout() = constraints_.at(channel_id); + auto it = constraints_.find(channel_id); + CHECK(it != constraints_.end()) << "Channel " << channel_id; + *shape.mutable_layout() = it->second; return shape; } // Returns the layout constraint for `channel_id`, which must already be // constrained. - Layout LayoutForChannel(int64 channel_id) const { - CHECK(IsChannelConstrained(channel_id)); - return constraints_.at(channel_id); + const Layout& LayoutForChannel(int64 channel_id) const { + auto it = constraints_.find(channel_id); + CHECK(it != constraints_.end()) << "Channel " << channel_id; + return it->second; } // Adds a new layout constraint for `channel_id`. If a constraint for - // `channel_id` already exists, this operation requires that the new layout is - // the same as the previously constrained layout. - void ConstrainChannel(int64 channel_id, const Layout& layout) { - CHECK(!IsChannelConstrained(channel_id) || - LayoutUtil::Equal(layout, constraints_[channel_id])); - constraints_[channel_id] = layout; + // `channel_id` has been added, this API returns nullptr, otherwise returns + // the layout which has already been set for the channel. + const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) { + auto it = constraints_.emplace(std::make_pair(channel_id, layout)); + if (it.second) { + return nullptr; + } + return LayoutUtil::Equal(layout, it.first->second) ? nullptr + : &it.first->second; } private: @@ -288,7 +294,7 @@ class LayoutAssignment : public HloPassInterface { // If channel_constraints is nullptr, no kSend or kRecvs must be contained // within any module passed to `Run`. explicit LayoutAssignment( - const ComputationLayout& entry_computation_layout, + ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} tensorflow::StringPiece name() const override { return "layout-assignment"; } @@ -362,12 +368,15 @@ class LayoutAssignment : public HloPassInterface { int64 operand_no); private: + // Initializes the layout assignment object for a new Run() call. + Status Init(); + // Adds constraints which must be satisfied for correctness on all // backends. Called once prior to propagating constraints. - Status AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints); + Status AddMandatoryConstraints(const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, + HloComputation* computation, + LayoutConstraints* constraints); // This method can be overridden to add backend-specific constraints to the // layout of the instructions of a computation. This method is called after @@ -378,10 +387,12 @@ class LayoutAssignment : public HloPassInterface { } // Construct contraints and assign layouts to all instructions in the - // computation satisfying the given ComputationLayout. Layouts constraints are - // added, then propagated until all LogicalBuffers in the computation are - // constrained. - Status RunOnComputation(const ComputationLayout& computation_layout, + // computation satisfying the given ComputationLayout, if not nullptr. + // Otherwise the ComputationLayout will be calculated by propagating the + // computation instruction contraints. + // Layouts constraints are added, then propagated until all LogicalBuffers in + // the computation are constrained. + Status RunOnComputation(ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints); @@ -402,7 +413,31 @@ class LayoutAssignment : public HloPassInterface { // necessary conditions. Status CheckLayouts(HloModule* module); - const ComputationLayout& entry_computation_layout_; + // Computes the ComputationLayout of the given computation based of the + // layouts assigned to parameters and root instruction, and inserts it to the + // computation_layouts_ map. + Status CalculateComputationLayout(HloComputation* computation); + + // Clears all the layouts which can be cleared within a computation. + Status ClearComputationLayouts(HloComputation* computation); + + // Clears the side effects of a previous pass, like added copy instructions. + Status ClearPreviousPassSideEffects(HloModule* module); + + // Propagates the layouts computed by the layout assignment pass on the given + // computation, to the computation layout passed in to this API. + // This API propagates missing layout, and also checks that the caller + // specified have been respected, by comparing those with the parameters and + // root computation instruction. + Status PropagateComputationLayouts(HloComputation* computation, + ComputationLayout* computation_layout); + + // The pointer to the ComputationLayout passed as constructor parameter. + ComputationLayout* entry_computation_layout_; + + // A copy of entry_computation_layout_ used to reset it to the initial values + // during the multiple passes done by the layout assignment operation. + ComputationLayout saved_entry_computation_layout_; protected: // Sets up the copy instruction according to the characteristic (sharding, @@ -418,22 +453,64 @@ class LayoutAssignment : public HloPassInterface { // Creates and returns a copy of the given instruction with a different // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple // instruction producing the copy is returned. - static StatusOr CreateCopyWithNewLayout( + StatusOr CreateCopyWithNewLayout( const Shape& shape_with_layout, HloInstruction* instruction); // Creates a copy of the given operand if the operand's layout does not match // the given layout. This copy replaces the use in the given instruction. // Tuple operands will be deep-copied. - static Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no); + Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no); + + // Registers a copy instruction added by the layout assignment pass. + void RegisterAddedCopy(HloInstruction* copy) { + CHECK_EQ(copy->opcode(), HloOpcode::kCopy); + added_copies_.insert(copy); + } + + // Adds a copy for the operand of an instruction, unless such operand is + // already a copy, and has a single user (which is forcibly the instruction + // itself). + Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number); + + // Apply the channel layout constraints by populating the channel_constraints + // data structure passed in at constructor time. Eventually adds copies in + // case two ends of a channel ended up with a different leyout. + Status ConstrainChannelLayouts(HloComputation* computation, + ChannelLayoutConstraints* channel_constraints); + + // Resets the input ChannelLayoutConstraints to the original copy received + // from the constructor input. + void ResetChannelConstraints() { + if (channel_layout_constraints_ != nullptr) { + *channel_layout_constraints_ = channel_constraints_; + } + } // Map containing the layouts of all computations assigned so // far. Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller // instructions can be set to match the computation. std::map computation_layouts_; - ChannelLayoutConstraints* channel_layout_constraints_; + + // Every copy added to the module by the layout assignment pass is registered + // here. + tensorflow::gtl::FlatSet added_copies_; + + // The pointer to the channel layout constraints passed in with the + // constructor. If not nullptr, this is an input/output argument. + ChannelLayoutConstraints* channel_layout_constraints_ = nullptr; + + // A copy of the input layout constraints used to reset the above pointer in + // case we have to undo operations due to the multiple passes over the + // computations/instructions. + ChannelLayoutConstraints channel_constraints_; + + // The set of HLO instructions which lacked any layout constraint, thus + // receiving propagated default layouts. + tensorflow::gtl::FlatSet + unconstrained_layout_instructions_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 986e177406b634598fea9a1f850fcbfbae1728dc..a673901c756950802884187248f4f0c66aee55ce 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -29,13 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -52,10 +52,18 @@ using ::testing::ElementsAre; class LayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, - ComputationLayout* entry_computation_layout) { - LayoutAssignment layout_assignment(*entry_computation_layout); + ComputationLayout* entry_computation_layout, + ChannelLayoutConstraints* channel_constraints = nullptr) { + LayoutAssignment layout_assignment( + entry_computation_layout, /*channel_constraints=*/channel_constraints); EXPECT_IS_OK(layout_assignment.Run(module).status()); } + + std::vector LayoutOf(HloModule* module, tensorflow::StringPiece name) { + auto minor_to_major = + FindInstruction(module, name)->shape().layout().minor_to_major(); + return std::vector(minor_to_major.begin(), minor_to_major.end()); + } }; TEST_F(LayoutAssignmentTest, ComputationLayout) { @@ -285,7 +293,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - LayoutAssignment layout_assignment(computation_layout); + LayoutAssignment layout_assignment(&computation_layout); AssignLayouts(module.get(), &computation_layout); // Layout assignment should have deep copied the result of the computation to @@ -488,7 +496,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { public: explicit OperandsMustBeTheSameLayoutAssignment( ComputationLayout* entry_computation_layout) - : LayoutAssignment(*entry_computation_layout) {} + : LayoutAssignment(entry_computation_layout) {} protected: Status PropagateBufferConstraint( @@ -651,7 +659,7 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = ParseHloString(module_str).ValueOrDie(); module = backend() @@ -691,7 +699,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = ParseHloString(module_str).ValueOrDie(); ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( @@ -707,17 +715,10 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { LayoutUtil::MakeLayout({2, 1, 0})); AssignLayouts(module.get(), &computation_layout); - auto layout_of = [&](tensorflow::StringPiece name) { - return FindInstruction(module.get(), name) - ->shape() - .layout() - .minor_to_major(); - }; - - EXPECT_THAT(layout_of("gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(layout_of("gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(layout_of("gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(layout_of("fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0)); EXPECT_THAT(FindInstruction(module.get(), "gte1") ->shape() .tuple_shapes(0) @@ -769,9 +770,13 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { false_builder.AddInstruction( HloInstruction::CreateParameter(0, tshape, "param")); // Using infeed as layout assignment does not mess up with it. - auto infeed = - false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, "")); - false_builder.AddInstruction(HloInstruction::CreateTuple({infeed})); + auto token = + false_builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto infeed = false_builder.AddInstruction( + HloInstruction::CreateInfeed(xshape, token, "")); + auto infeed_data = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(xshape, infeed, 0)); + false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data})); } HloComputation* false_computation = module->AddEmbeddedComputation(false_builder.Build()); @@ -807,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - LayoutAssignment layout_assignment(computation_layout); + LayoutAssignment layout_assignment(&computation_layout); Status error_status = layout_assignment.Run(module.get()).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( @@ -816,5 +821,46 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { "Unexpected bitcast operation seen during layout assignment")); } +TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { + // Pin non matching layouts to parameter and root. + const char* module_str = R"( + HloModule test_module + + ENTRY entry_computation { + param = (f32[2,2]) parameter(0) + gte = f32[2,2] get-tuple-element(param), index=0 + token = token[] after-all() + recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1} + recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1, + sharding={maximal device=1} + ROOT root = f32[2,2] get-tuple-element(recv-done), index=0 + send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1, + sharding={maximal device=0} + send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0} + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + Shape param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); + TF_ASSERT_OK( + computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( + param_shape)); + computation_layout.mutable_result_layout()->ResetLayout( + LayoutUtil::MakeLayout({1, 0})); + + ChannelLayoutConstraints channel_constraints; + AssignLayouts(module.get(), &computation_layout, &channel_constraints); + + EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::GetSubshape( + FindInstruction(module.get(), "send")->shape(), {0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 21bca1d6beff5b2804531724b94b123d4523c173..f200a08a3cd7e33351ec4607d67d40e7ab28f3b9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -32,7 +32,8 @@ static const BufferAllocation* kParameterAllocation = new BufferAllocation( LogicalBuffer::Color(0)); void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, - llvm_ir::IrArray* array) { + llvm_ir::IrArray* array, + const ShapeIndex& index) { BufferAllocation::Slice buffer_slice; if (hlo.opcode() == HloOpcode::kParameter) { // Parameters may alias with each other but may not alias with our temporary @@ -40,7 +41,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0); } else { const std::set slices = - assignment_.GetAllSlices(&hlo, /*index=*/{}); + assignment_.GetAllSlices(&hlo, index); if (slices.empty() || slices.size() > 1) { // Skip HLOs which don't have a buffer assigned or for which the // buffer can't be determined statically. We cannot determine their @@ -137,16 +138,18 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( // 2. Operands of users of the given hlo. // 3. Operands of the given hlo. // - // This set can be increased as we need. For now only consider top-level - // buffers (index = {}) not buffers nested within the instruction's - // operands/output which are not typically touched. + // This set can be increased as we need. std::vector worklist; auto add_buffers_to_worklist = [&worklist, &assignment](const HloInstruction* instruction) { - for (const LogicalBuffer* buffer : - assignment.GetSourceBuffers(instruction, /*index=*/{})) { - worklist.push_back(buffer); - } + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& /*shape*/, const ShapeIndex& index) { + for (const LogicalBuffer* buffer : + assignment.GetSourceBuffers(instruction, index)) { + worklist.push_back(buffer); + } + }); }; for (HloInstruction* user : hlo.users()) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 5244ac61e56307857aca659854647bd6c3e991d7..fe9eab93aae95557e3ee27a64c09b78f37ac2348 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -38,7 +38,8 @@ class AliasAnalysis { // Augments IrArray with aliasing information. void AddAliasingInformationToIrArray(const HloInstruction& hlo, - llvm_ir::IrArray* array); + llvm_ir::IrArray* array, + const ShapeIndex& index = {}); private: // Returns a unique alias domain for this emitter. diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index f172b1d87c870270436f7301ed200b47d08431a7..d909845a3a21fc55e44b0037371fca30e577980f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -80,8 +80,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, /*Name=*/""); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); generators_[constant] = [=](const IrArray::Index& index) { - return IrArray(global, constant->shape()) + return IrArray(shape_constant, constant->shape()) .EmitReadArrayElement(index, ir_builder_); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 7323abeb2077154f82828bcda3e90eb45a67138a..ea10cef49a4a9aa048b3e0ea443f052645c4912a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -29,9 +29,9 @@ limitations under the License. namespace xla { namespace llvm_ir { -static void Delinearize(std::vector* multidim, - llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder) { +void IrArray::Index::Delinearize(std::vector* multidim, + llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder) const { int64 divisor = 1; const Layout& layout = shape.layout(); for (int64 i = 0; i < layout.minor_to_major_size(); ++i) { @@ -48,10 +48,11 @@ static void Delinearize(std::vector* multidim, // useful because cuda-memcheck can't help us much in XLA: Most of our // memory lives in one big allocation, so cuda-memcheck can't detect // out-of-bounds accesses. - auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)); + auto* quot = + ir_builder->CreateUDiv(linear, GetConstantWithIndexType(divisor)); if (i < layout.minor_to_major_size() - 1) { (*multidim)[dimension] = ir_builder->CreateURem( - quot, ir_builder->getInt64(size_of_current_dimension)); + quot, GetConstantWithIndexType(size_of_current_dimension)); } else { (*multidim)[dimension] = quot; } @@ -65,6 +66,8 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK_NE(linear, nullptr); + index_type_ = linear->getType(); CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) << " should have a layout."; @@ -77,6 +80,13 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + if (size()) { + index_type_ = multidim_[0]->getType(); + } else { + CHECK_NE(linear_, nullptr); + index_type_ = linear_->getType(); + } + CHECK_NE(index_type_, nullptr); CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) @@ -88,6 +98,9 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, : multidim_(multidim.begin(), multidim.end()), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK_GT(multidim_.size(), 0); + index_type_ = multidim[0]->getType(); + CHECK_NE(index_type_, nullptr); CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)); } @@ -130,15 +143,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); std::vector source_multidim_index( - ShapeUtil::Rank(input_shape), - llvm::UndefValue::get(builder->getInt64Ty())); + ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_)); // We compute the source indices in each common factor from only the target // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { llvm::Value* logical_linear_index = Index(tensorflow::gtl::ArraySlice( multidim_, common_factors[k].second, - common_factors[k + 1].second - common_factors[k].second)) + common_factors[k + 1].second - common_factors[k].second), + index_type_) .Linearize( tensorflow::gtl::ArraySlice( AsInt64Slice(output_shape.dimensions()), @@ -150,9 +163,10 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // linear index by each dimension size. for (int64 i = common_factors[k + 1].first - 1; i >= common_factors[k].first; --i) { - llvm::Value* divisor = builder->getInt64(input_shape.dimensions(i)); + llvm::Value* divisor = + GetConstantWithIndexType(input_shape.dimensions(i)); if (input_shape.dimensions(i) == 1) { - source_multidim_index[i] = builder->getInt64(0); + source_multidim_index[i] = GetConstantWithIndexType(0); } else if (i == common_factors[k].first) { source_multidim_index[i] = logical_linear_index; } else { @@ -168,14 +182,14 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { return Index(source_multidim_index, linear(), input_shape); } - return Index(source_multidim_index); + return Index(source_multidim_index, index_type_); } IrArray::Index IrArray::Index::SourceIndexOfSlice( const Shape& shape, tensorflow::gtl::ArraySlice starts, tensorflow::gtl::ArraySlice strides, llvm::IRBuilder<>* builder) const { - Index source_index(multidim_.size()); + Index source_index(index_type_, multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; auto type = multidim_[i]->getType(); @@ -224,11 +238,12 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( // the physical index of the element in the buffer. This is like Linearize, // but takes the layout into account. int64 scale = 1; - llvm::Value* linear_index = builder->getInt64(0); + llvm::Value* linear_index = GetConstantWithIndexType(0); for (auto dimension : LayoutUtil::MinorToMajor(shape)) { linear_index = builder->CreateAdd( linear_index, - builder->CreateMul(multidim_[dimension], builder->getInt64(scale), "", + builder->CreateMul(multidim_[dimension], + GetConstantWithIndexType(scale), "", /*HasNUW=*/true, /*HasNSW=*/true), "", /*HasNUW=*/true, /*HasNSW=*/true); scale *= shape.dimensions(dimension); @@ -252,7 +267,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( } if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) || !LayoutUtil::HasLayout(shape)) { - return Index(source_index); + return Index(source_index, index_type_); } // High-level idea: we can reuse the linear index if the broadcasted // dimensions are contiguous, and this part of the operation is a bitcast. @@ -274,7 +289,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( bool contiguous_broadcast_dimensions = max_broadcasted_dimension - min_broadcasted_dimension == rank - 1; if (!contiguous_broadcast_dimensions) { - return Index(source_index); + return Index(source_index, index_type_); } // Check if the mapped dimensions are a bitcast. std::vector operand_logical_to_physical = @@ -282,7 +297,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( for (int64 i = 0; i < rank; ++i) { if (operand_logical_to_physical[i] != logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) { - return Index(source_index); + return Index(source_index, index_type_); } } llvm::Value* linear = linear_; @@ -291,7 +306,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } if (divisor > 1) { - linear = builder->CreateUDiv(linear, builder->getInt64(divisor)); + linear = builder->CreateUDiv( + linear, + IrArray::Index(linear->getType()).GetConstantWithIndexType(divisor)); } if (min_broadcasted_dimension > 0) { int64 mod = 1; @@ -299,7 +316,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( ++i) { mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - linear = builder->CreateURem(linear, builder->getInt64(mod)); + linear = builder->CreateURem( + linear, + IrArray::Index(linear->getType()).GetConstantWithIndexType(mod)); } return Index(source_index, linear, operand_shape); } @@ -309,12 +328,13 @@ llvm::Value* IrArray::Index::Linearize( llvm::IRBuilder<>* builder) const { // Each dimension is multiplied by the product of the sizes of all // earlier dimensions and added to the accumulator logical_linear_index. - llvm::Value* logical_linear_index = builder->getInt64(0); + llvm::Value* logical_linear_index = GetConstantWithIndexType(0); int64 multiplier = 1; for (ssize_t i = size() - 1; i >= 0; --i) { llvm::Value* addend = - builder->CreateMul((*this)[i], builder->getInt64(multiplier), "", + builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "", /*HasNUW=*/true, /*HasNSW=*/true); + addend = builder->CreateZExtOrTrunc(addend, index_type_); logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "", /*HasNUW=*/true, /*HasNSW=*/true); multiplier *= dimensions[i]; @@ -349,7 +369,8 @@ llvm::Value* IrArray::EmitArrayElementAddress( // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to // produce better code in some cases. auto dim = shape_->dimensions(i); - actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); + actual_index.push_back( + dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]); } // "base_ptr_" has the type of "*" @@ -357,7 +378,9 @@ llvm::Value* IrArray::EmitArrayElementAddress( // should be computed by // // getelementptr base_ptr_, 0, most major index, ..., most minor index - std::vector gep_indices(1, ir_builder->getInt64(0)); + CHECK_GT(index.size(), 0); + std::vector gep_indices( + 1, llvm::ConstantInt::get(index[0]->getType(), 0)); for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_->layout(), i); gep_indices.push_back(actual_index[dimension]); @@ -410,7 +433,9 @@ IrArray IrArray::CastToShape(const Shape& new_shape, llvm::IRBuilder<>* ir_builder) { Index new_index = index; new_index[which_dimension] = ir_builder->CreateAdd( - index[which_dimension], ir_builder->getInt64(addend), "", /*HasNUW=*/true, + index[which_dimension], + llvm::ConstantInt::get(index[which_dimension]->getType(), addend), "", + /*HasNUW=*/true, /*HasNSW=*/true); return new_index; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 4c3195c29c859c9eef08e3f6531b059edbebfc47..4648c6d7ac089dbea7e660dd9889d557c8ad7318 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -53,18 +53,38 @@ class IrArray { // multidimensional index, which LLVM DCE can delete. class Index { public: - // Constructs an empty zero-dimensional index. - Index() {} - // Constructs an index of rank "size". Each dimension of the index is // initialized to "value". - explicit Index(size_t size, llvm::Value* value = nullptr) - : multidim_(size, value) {} + explicit Index(size_t size, llvm::Value* value) + : multidim_(size, value), index_type_(value->getType()) { + CHECK_NE(index_type_, nullptr); + } + + // Constructs an index of rank "size". Each dimension of the index is + // initialized to nullptr. + explicit Index(llvm::Type* index_ty, size_t size = 0) + : multidim_(size, nullptr), index_type_(index_ty) { + CHECK(index_ty->isIntegerTy()); + } // Constructs an index from multi-dimensional index "multidim". The linear // index is set to nullptr. - explicit Index(tensorflow::gtl::ArraySlice multidim) - : multidim_(multidim.begin(), multidim.end()) {} + explicit Index(tensorflow::gtl::ArraySlice multidim, + llvm::Type* index_ty = nullptr) + : multidim_(multidim.begin(), multidim.end()) { + if (size() == 0) { + index_type_ = index_ty; + } else { + index_type_ = (*this)[0]->getType(); + if (index_ty != nullptr) { + CHECK_EQ(index_type_, index_ty); + } + } + CHECK_NE(index_type_, nullptr); + CHECK(c_all_of(multidim, [&](llvm::Value* v) { + return index_type_ == v->getType(); + })); + } // Constructs an index from linear index "linear" and computes the // multi-dimensional index from "linear" and "shape". "ir_builder" is the IR @@ -154,6 +174,15 @@ class IrArray { llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, llvm::IRBuilder<>* builder) const; + llvm::Type* GetType() const { return index_type_; } + + llvm::Constant* GetConstantWithIndexType(int64 c) const { + // The LLVM function makes sure that the value can be represented by the + // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const + // APInt &V). + return llvm::ConstantInt::get(index_type_, c); + } + private: // Changing the multi-dimensional index invalidates the linear index. std::vector& multidim() { @@ -161,6 +190,9 @@ class IrArray { return multidim_; } + void Delinearize(std::vector* multidim, llvm::Value* linear, + const Shape& shape, llvm::IRBuilder<>* ir_builder) const; + std::vector multidim_; // These values are purely for efficiency; `multidim_` is enough to find the @@ -177,6 +209,8 @@ class IrArray { llvm::Value* linear_ = nullptr; Layout layout_; std::vector dims_; + + llvm::Type* index_type_; }; // Default constructor. Constructs an IrArray in a null status. diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 23d2d4e87d26f4988ebddcf20f5a27af6a7fe0d6..1f6e3c829f890d68aa251b101f0402c120a19d61 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -15,53 +15,57 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, - const std::function& for_body_generator) { - If(ir_builder_->CreateICmpSLT(start, end), [&]() { - for_body_generator(start, /*is_first_iteration=*/true); - For(name, ir_builder_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { for_body_generator(iv, false); }); + const std::function& for_body_generator) { + return If(ir_builder_->CreateICmpSLT(start, end), [&]() -> Status { + TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); + return For(name, ir_builder_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, - const std::function& for_body_generator) { + const std::function& + for_body_generator) { if (peel_first_iteration) { - For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) { - for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration)); - }); + return For(name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator( + indvar, ir_builder_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, ir_builder_, - /*prevent_unrolling=*/prevent_unrolling_, + /*unroll_mode=*/unroll_mode_, /*prevent_vectorization=*/prevent_vectorization_); ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); - for_body_generator(loop->GetIndVarValue(), - /*is_first_iteration=*/ir_builder_->CreateICmpEQ( - loop->GetIndVarValue(), start)); + TF_RETURN_IF_ERROR( + for_body_generator(loop->GetIndVarValue(), + /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + loop->GetIndVarValue(), start))); llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + return Status::OK(); } } -void KernelSupportLibrary::If( - llvm::Value* condition, const std::function& true_block_generator, - const std::function& false_block_generator) { +Status KernelSupportLibrary::If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, "", ir_builder_); ir_builder_->SetInsertPoint(&if_data.true_block->back()); - true_block_generator(); + TF_RETURN_IF_ERROR(true_block_generator()); ir_builder_->SetInsertPoint(&if_data.false_block->back()); - false_block_generator(); + TF_RETURN_IF_ERROR(false_block_generator()); llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); + return Status::OK(); } void KernelSupportLibrary::EmitAndCallOutlinedKernel( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 1c00b2aabd182da72e78d2c9c01cbe70cfd8e33c..6f7a9d94e3b9e59b2dfe12b9673335a904ae78b6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -30,13 +31,14 @@ namespace xla { class KernelSupportLibrary { public: // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. - // If `prevent_unrolling` is true then unrolling is explicitly disabled on - // every loop generated by this instance of KernelSupportLibrary. - explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = true, - bool prevent_vectorization = true) + // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop + // generated by this instance of KernelSupportLibrary. + explicit KernelSupportLibrary( + llvm::IRBuilder<>* ir_builder, + llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, + bool prevent_vectorization = true) : ir_builder_(ir_builder), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} // Generates the following control flow structure: @@ -46,19 +48,41 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - void For( + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator); + + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& - for_body_generator); + for_body_generator) { + CHECK_EQ(Status::OK(), + For(name, start, end, step, + [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& + for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } - void For( + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -75,37 +99,102 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& - for_body_generator); - - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - For(name, /*start=*/start, /*end=*/end, - /*step=*/ir_builder_->getInt64(step), peel_first_iteration, - for_body_generator); + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); + + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, llvm::Value* step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(For( + name, start, end, step, peel_first_iteration, + [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + return For(name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); + } + + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + ForReturnVoid(name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); + } + + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + return For(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void For( + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + ForReturnVoid(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); + } + + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); + } + + void ForReturnVoid( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + ForReturnVoid(name, start, end, + llvm::ConstantInt::get(start->getType(), step), + for_body_generator); } - void For( + Status For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } + + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -114,9 +203,25 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - void If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() {}); + Status If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = + []() -> Status { return Status::OK(); }); + + void IfReturnVoid(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() { + }) { + TF_CHECK_OK(If(condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); + } using ArgumentVector = tensorflow::gtl::ArraySlice; @@ -174,7 +279,7 @@ class KernelSupportLibrary { private: llvm::IRBuilder<>* ir_builder_; - bool prevent_unrolling_; + llvm_ir::UnrollMode unroll_mode_; bool prevent_vectorization_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 497b48ff227d7d1f158080529372df44b6932b24..c9ae7d3afd5cdc21157732f6d0dfa824268e86bd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -34,7 +34,7 @@ namespace llvm_ir { ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - llvm::Value* step, bool prevent_unrolling, + llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) : prefix_(std::string(prefix)), suffix_(std::string(suffix)), @@ -42,15 +42,15 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, end_index_(end_index), step_(step), insert_before_bb_(nullptr), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling, bool prevent_vectorization) { + UnrollMode unroll_mode, bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, - end_index, step, prevent_unrolling, + end_index, step, unroll_mode, prevent_vectorization)); loop->Emit(ir_builder); return loop; @@ -97,7 +97,7 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->SetInsertPoint(&func->getEntryBlock(), func->getEntryBlock().getFirstInsertionPt()); llvm::Value* indvar_address = - ir_builder->CreateAlloca(ir_builder->getInt64Ty(), nullptr, + ir_builder->CreateAlloca(start_index_->getType(), nullptr, AsStringRef(GetQualifiedName("invar_address"))); // Preheader basic block. @@ -147,11 +147,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { std::vector ForLoop::GetLoopMetadata( llvm::IRBuilder<>* ir_builder) { const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; + const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full"; const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; llvm::LLVMContext* ctx = &start_index_->getContext(); std::vector result; - if (prevent_unrolling_) { + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) { result.push_back(llvm::MDNode::get( *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); } @@ -162,6 +163,10 @@ std::vector ForLoop::GetLoopMetadata( llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); } + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)})); + } return result; } @@ -178,25 +183,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { - return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), - prevent_unrolling, prevent_vectorization); + return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1), + unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( - /*prefix=*/name_, suffix, start_index, end_index, stride, - prevent_unrolling, prevent_vectorization)); + /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode, + prevent_vectorization)); loop->Emit(ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { @@ -215,23 +220,23 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), prevent_unrolling, + return AddLoop(suffix, GetConstantWithIndexType(start_index), + GetConstantWithIndexType(end_index), unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), prevent_unrolling, + return AddLoop(suffix, GetConstantWithIndexType(start_index), + GetConstantWithIndexType(end_index), + GetConstantWithIndexType(stride), unroll_mode, prevent_vectorization); } @@ -245,7 +250,7 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( const Shape& shape, tensorflow::gtl::ArraySlice dimensions, tensorflow::StringPiece suffix) { - llvm_ir::IrArray::Index index(shape.dimensions_size(), nullptr); + llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr loop = AddLoop( /*start_index=*/0, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index d915f95db134918a173a9711936bb1e2f1ea0d95..0dd5b9d3b2656af68f76c2adfcb1f3a1385eeb91 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -34,6 +34,12 @@ limitations under the License. namespace xla { namespace llvm_ir { +enum class UnrollMode { + kDefaultUnroll, + kFullyUnroll, + kNoUnroll, +}; + // A class for constructing a for-loop in LLVM IR. class ForLoop { public: @@ -69,12 +75,13 @@ class ForLoop { // LLVM IR. If non-empty, it is prepended to the name of the induction // variable value and each basic block created for the loop. // - // If `prevent_unrolling` is true then emit metadata that directs LLVM to not - // unroll the generated loop. + // `unroll_mode` specifies the desired LLVM unrolling behavior for generated + // loop. static std::unique_ptr EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = false, bool prevent_vectorization = false); + UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // The names of the blocks follow LLVM's conventions. Control flow amongst the // blocks for the example C code looks like: @@ -128,7 +135,7 @@ class ForLoop { ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, - bool prevent_unrolling, bool prevent_vectorization); + UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* ir_builder); @@ -161,7 +168,7 @@ class ForLoop { llvm::BasicBlock* body_bb_; llvm::BasicBlock* exit_bb_; llvm::Value* indvar_; - bool prevent_unrolling_; + UnrollMode unroll_mode_; bool prevent_vectorization_; TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); @@ -170,46 +177,52 @@ class ForLoop { // A simple class for constructing nested for-loops. class ForLoopNest { public: - explicit ForLoopNest(llvm::IRBuilder<>* ir_builder) - : ForLoopNest(/*name=*/"", ir_builder) {} + explicit ForLoopNest(llvm::IRBuilder<>* ir_builder, + llvm::Type* index_ty = nullptr) + : ForLoopNest(/*name=*/"", ir_builder) { + SetIndexType(index_ty); + } - ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) + ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder, + llvm::Type* index_ty = nullptr) : name_(std::string(name)), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), - ir_builder_(ir_builder) {} + ir_builder_(ir_builder) { + SetIndexType(index_ty); + } // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have - // been added then emit loop inside the body of the last added loop. If - // prevent_unrolling is true, then metadata is emitting directing LLVM to not - // unroll this loop. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + // been added then emit loop inside the body of the last added loop. + // unroll_mode is used to emit metadata that controls LLVM unrolling. + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* stride, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. - std::unique_ptr AddLoop(int64 start_index, int64 end_index, - int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, int64 stride, + tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the @@ -245,6 +258,14 @@ class ForLoopNest { llvm::BasicBlock* GetInnerLoopBodyBasicBlock() { return inner_loop_body_bb_; } private: + void SetIndexType(llvm::Type* index_ty) { + index_type_ = index_ty == nullptr ? ir_builder_->getInt64Ty() : index_ty; + } + + llvm::Constant* GetConstantWithIndexType(int64 c) const { + return llvm::ConstantInt::get(index_type_, c); + } + // Human-friendly name of the loop nest. string name_; @@ -259,6 +280,8 @@ class ForLoopNest { llvm::IRBuilder<>* ir_builder_; + llvm::Type* index_type_; + TF_DISALLOW_COPY_AND_ASSIGN(ForLoopNest); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ec04239b4f9112134ba876fdfbb3905a3baf1f72..97bacc34b59118e60100e4749638d469a1ef1378 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -87,18 +88,10 @@ llvm::Value* EmitCallToIntrinsic( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice overloaded_types, llvm::IRBuilder<>* ir_builder) { - std::vector types; - for (auto type : overloaded_types) { - types.push_back(type); - } llvm::Module* module = ModuleFromIRBuilder(ir_builder); - llvm::Function* intrinsic = - llvm::Intrinsic::getDeclaration(module, intrinsic_id, types); - std::vector operands_vec; - for (auto operand : operands) { - operands_vec.push_back(operand); - } - return ir_builder->CreateCall(intrinsic, operands_vec); + llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( + module, intrinsic_id, AsArrayRef(overloaded_types)); + return ir_builder->CreateCall(intrinsic, AsArrayRef(operands)); } llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, @@ -201,6 +194,10 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // An Opaque is like a void*, use i8*. case OPAQUE: return llvm::Type::getInt8PtrTy(module->getContext()); + case TOKEN: + // Tokens do not have a physical representation, but the compiler needs + // some placeholder type, so use int8*. + return llvm::Type::getInt8PtrTy(module->getContext()); default: LOG(FATAL) << "unsupported type " << element_type; } @@ -253,130 +250,14 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, return shape; } -namespace { - -// Recursively construct a multidimensional LLVM constant which represents the -// given literal. The minor-to-major dimension ordering in the constant matches -// that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0 -// has size 4, dimension 1 has size 3, etc) of primitive type F32 with a -// minor_to_major value of [2, 1, 0] (column major), a LLVM constant of type -// [4 x [3 x [2 x float]] will be returned. -// -// multi_index is a multidimensional index into the array. dimension_index is an -// index into the minor_to_major field in the literal shape. This determines -// which dimension is iterated over in this level of the recursion. Dimensions -// are iterated from most major down to most minor (highest dimension_index -// value down to zero). -llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, - std::vector* multi_index, - llvm::Module* module) { - const Shape& shape = literal.shape(); - llvm::Type* ir_element_type = - llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module); - if (dimension_index == -1) { - // Base case of the recursion. Index into the data field of the protobuf - // with the multi index. - llvm::Constant* value; - switch (shape.element_type()) { - case PRED: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U8: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case S32: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U32: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case S64: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U64: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case F32: - value = llvm::ConstantFP::get(ir_element_type, - literal.Get(*multi_index)); - break; - case BF16: - value = llvm::ConstantInt::get( - ir_element_type, - tensorflow::bit_cast(literal.Get(*multi_index))); - break; - case F16: - value = llvm::ConstantFP::get( - ir_element_type, - static_cast(literal.Get(*multi_index))); - break; - case F64: - value = llvm::ConstantFP::get(ir_element_type, - literal.Get(*multi_index)); - break; - case C64: { - complex64 x = literal.Get(*multi_index); - value = llvm::ConstantStruct::get( - static_cast(ir_element_type), - llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), - x.real()), - llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), - x.imag())); - break; - } - default: - LOG(FATAL) << "unsupported type " << shape.element_type(); - } - return value; - } - - // The dimension index starts at the one less than the rank of the array and - // decrements with each recursive call. We want to iterate through the - // dimensions in major-to-minor order as we recurse so just index into - // minor_to_major to get the dimension number for this level of the recursion. - int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index); - - // Recursively call LiteralToConstant to construct subarrays for the - // more-minor dimensions. Gather the subarrays into a vector for bundling into - // a new (higher-dimensional) ConstantArray. - std::vector elements; - for (int64 i = 0; i < shape.dimensions(dimension); ++i) { - (*multi_index)[dimension] = i; - elements.push_back( - LiteralToConstant(literal, dimension_index - 1, multi_index, module)); - } - - llvm::Type* element_type; - if (elements.empty()) { - element_type = ir_element_type; - for (int i = 0; i < dimension_index; ++i) { - int64 index = LayoutUtil::Minor(shape.layout(), i); - element_type = - llvm::ArrayType::get(element_type, shape.dimensions(index)); - } - } else { - element_type = elements[0]->getType(); - } - llvm::ArrayType* aggregate_type = - llvm::ArrayType::get(element_type, shape.dimensions(dimension)); - return llvm::ConstantArray::get(aggregate_type, elements); -} - -} // namespace - llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { - std::vector multi_index(ShapeUtil::Rank(literal.shape()), 0); - llvm::Constant* value = LiteralToConstant( - literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, - &multi_index, module); - return value; + const char* data = static_cast(literal.untyped_data()); + CHECK_EQ(module->getDataLayout().isLittleEndian(), + tensorflow::port::kLittleEndian); + return llvm::ConstantDataArray::getString( + module->getContext(), llvm::StringRef(data, literal.size_bytes()), + /*AddNull=*/false); } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 0728ccfff7b85e3751f33bc5272a5f22d4e5411a..e8b0605b9d75677b34f0973d88d269a5795b7629 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -83,16 +83,19 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, // Sanity check: In multi-output fusion, all shapes produced must have the // same dimensions. for (const IrArray& array : target_arrays) { - CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())); + CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())) + << ": '" << shape_.ShortDebugString() << "' does not match '" + << array.GetShape().ShortDebugString() << "'"; } } std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { + CHECK_NE(index_type, nullptr); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. exit_bb_ = nullptr; - return {IrArray::Index()}; + return {IrArray::Index(index_type)}; } // Create loop nest with one for-loop for each dimension of the target shape. @@ -100,7 +103,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( // class so emit loops in order from most-major dimension down to most-minor // dimension (of the target shape). ForLoopNest loop_nest(loop_name, ir_builder_); - IrArray::Index array_index(shape_.dimensions_size()); + IrArray::Index array_index(index_type, shape_.dimensions_size()); for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_.layout(), i); std::unique_ptr loop = loop_nest.AddLoop( @@ -123,9 +126,14 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { +Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, + llvm::Type* index_type) { + if (index_type == nullptr) { + index_type = ir_builder_->getInt64Ty(); + } + for (const IrArray::Index& array_index : - EmitIndexAndSetExitBasicBlock(loop_name)) { + EmitIndexAndSetExitBasicBlock(loop_name, index_type)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index b70d28ecd3033eb26629718e50ce48f39b162273..6be1c2fba2cbd78a02865901ef8c5b7e2b2a74e6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -65,13 +65,16 @@ class LoopEmitter { // specifies the element, will return multiple indices if the loop is // unrolled. std::vector EmitIndexAndSetExitBasicBlock() { - return EmitIndexAndSetExitBasicBlock(/*loop_name=*/""); + return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", + ir_builder_->getInt64Ty()); } + virtual std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name); + tensorflow::StringPiece loop_name, llvm::Type* index_type); // Emits a complete loop nest for every element in the given shape. - Status EmitLoop(tensorflow::StringPiece loop_name = ""); + Status EmitLoop(tensorflow::StringPiece loop_name = "", + llvm::Type* index_type = nullptr); protected: // An IR emitter that generates the loop body. diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index 34899b7400464e4f4f97d301f35ed3b7b083bca1..3b298f4746d6177da52ba0227705d07fbeba5c19 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -45,26 +45,45 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // Read start indices from start_indices_generator. const int64 rank = ShapeUtil::Rank(output_shape); - IrArray::Index start_index(rank); + IrArray::Index start_index(ir_builder->getInt64Ty(), rank); for (int64 i = 0; i < rank; ++i) { IrArray::Index dim_index({ir_builder->getInt64(i)}); TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), output_shape.dimensions(i)); + llvm::Value* update_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), update_shape.dimensions(i)); + + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, output_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + llvm::Value* max_bound = + ir_builder->CreateSub(output_dim_size, update_dim_size); + llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]), + zero, start_index[i]); + + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound, + start_index[i]), + max_bound, start_index[i]); } auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { // Calculate output_index, where we'll write the value from update. For // each dimension, // - // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size. + // output_index[dim] = start_index[dim] + update_index[dim] // - IrArray::Index output_index(rank); + IrArray::Index output_index(start_index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* dim_size = llvm::ConstantInt::get( - update_index[i]->getType(), output_shape.dimensions(i)); - llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast( + llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast( start_index[i], update_index[i]->getType()); - output_index[i] = ir_builder->CreateURem( - ir_builder->CreateAdd(start_index0, update_index[i]), dim_size); + output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]); } // Do output[output_index] = update[update_index]. diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 3a21eda35757aa706565ee4a5286eee1acea117b..5fc08aab916e377b245b6221108956c06da70767 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -24,15 +24,14 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace llvm_ir { -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, - llvm::Module* module) { +void EmitTupleSelect(const IrArray& select, const IrArray& pred, + llvm::Value* on_true, llvm::Value* on_false, + llvm::IRBuilder<>* ir_builder, llvm::Module* module) { CHECK(ShapeUtil::IsScalar(pred.GetShape())); llvm::LoadInst* pred_value = @@ -47,30 +46,27 @@ void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - std::vector element_index = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; + llvm::Value* const element_index[] = {ir_builder->getInt64(0), + ir_builder->getInt64(i)}; llvm::Value* on_true_element_address = ir_builder->CreateInBoundsGEP(on_true, element_index); llvm::Value* on_true_element = ir_builder->CreateLoad( - on_true_element_address, - tensorflow::strings::Printf("on_true_element_%d", i).c_str()); + on_true_element_address, "on_true_element_" + llvm::Twine(i)); llvm::Value* on_false_element_address = ir_builder->CreateInBoundsGEP(on_false, element_index); llvm::Value* on_false_element = ir_builder->CreateLoad( - on_false_element_address, - tensorflow::strings::Printf("on_false_element_%d", i).c_str()); + on_false_element_address, "on_false_element_" + llvm::Twine(i)); llvm::Value* output_element_address = ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); ir_builder->CreateStore( - ir_builder->CreateSelect( - pred_cond, on_true_element, on_false_element, - tensorflow::strings::Printf("select_output_element_%d", i).c_str()), + ir_builder->CreateSelect(pred_cond, on_true_element, on_false_element, + "select_output_element_" + llvm::Twine(i)), output_element_address); } } -void EmitTuple(IrArray tuple, +void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, llvm::IRBuilder<>* ir_builder, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index dbf9a140068b60505f6798360438f709bfd3feba..352d34ebf839c6c2465abade7c3d3eb3b7a34506 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -59,13 +59,13 @@ namespace llvm_ir { // of the address from the corresponding element in either // tuple_on_true or tuple_on_false: // output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, - llvm::Module* module); +void EmitTupleSelect(const IrArray& select, const IrArray& pred, + llvm::Value* on_true, llvm::Value* on_false, + llvm::IRBuilder<>* ir_builder, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. -void EmitTuple(IrArray tuple, +void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, llvm::IRBuilder<>* ir_builder, llvm::Module* module); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 0fa4061738612df76c72a18a9353f16bf6a42677..53efc30c3653879709fceae3dcdd4f679740f622 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -24,15 +24,12 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -110,6 +107,11 @@ ExecutionOptions CreateExecutionOptions( ->set_xla_dump_optimized_hlo_proto_to( build_options.dump_optimized_hlo_proto_to().value()); } + if (build_options.dump_unoptimized_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_unoptimized_hlo_proto_to( + build_options.dump_unoptimized_hlo_proto_to().value()); + } if (build_options.dump_per_pass_hlo_proto_to().has_value()) { execution_options.mutable_debug_options() ->set_xla_dump_per_pass_hlo_proto_to( @@ -124,75 +126,17 @@ ExecutionOptions CreateExecutionOptions( LayoutUtil::SetToDefaultLayout( execution_options.mutable_shape_with_output_layout()); } - return execution_options; -} - -} // namespace - -StatusOr> LocalService::CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& build_options) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Validate incoming layouts. - if (argument_layouts.size() != program_shape->parameters_size()) { - return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", - program_shape->parameters_size(), argument_layouts.size()); - } - for (int i = 0; i < argument_layouts.size(); ++i) { - const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); - if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { - tensorflow::gtl::optional metadata = - user_computation->ParameterMetadata(i); - auto metadata_string = [&metadata]() -> string { - if (!metadata.has_value()) { - return ""; - } - CHECK(metadata.value() != nullptr); - const OpMetadata& m = *metadata.value(); - if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); - } - return ""; - }; - return InvalidArgument( - "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); - } - } - if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape->result())); + for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) { + execution_options.mutable_debug_options()->add_xla_disable_hlo_passes( + disabled_pass); } - ExecutionOptions execution_options = - CreateExecutionOptions(build_options, program_shape.get()); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, - &execution_options, user_computation)); - - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - execute_backend_->stream_executor(build_options.device_ordinal())); - - return BuildExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), executor, - build_options.device_allocator()); + return execution_options; } +} // namespace + StatusOr> LocalService::CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -210,7 +154,8 @@ StatusOr> LocalService::CompileExecutable( for (int i = 0; i < argument_layouts.size(); ++i) { const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { tensorflow::gtl::optional metadata = ParameterMetadata(computation, /*parameter_number=*/i); @@ -234,8 +179,8 @@ StatusOr> LocalService::CompileExecutable( } } if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape.result())); + TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(), + program_shape.result())); } ExecutionOptions execution_options = @@ -245,6 +190,9 @@ StatusOr> LocalService::CompileExecutable( std::unique_ptr module_config, CreateModuleConfig(program_shape, argument_layouts, &execution_options)); + VLOG(3) << "Computation Layout: " + << module_config->entry_computation_layout().ToString(); + TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); @@ -260,4 +208,15 @@ StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { /*computation_count=*/1); } +StatusOr LocalService::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); + if (replica_number >= buffers.size()) { + return InvalidArgument( + "replica_number %d out of range; must be less than num_replicas = %zu.", + replica_number, buffers.size()); + } + return buffers[replica_number]; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 06567cabd6eb28aae53881613cd6beb78e25e222..39d6734c3fc06df6832cf67edddbc7c14c815cd1 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -41,23 +41,11 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // Builds an Executable with the given argument layouts and options. If - // result_layout is non-null, then the executable is compiled to produce a - // result of the given layout. If device_allocator is non-null, then the - // compiler may use it to allocate temp space on the device. The compiler is - // responsible for freeing any memory it allocates this way. - StatusOr> CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Builds an Executable with the given XlaComputation, argument layouts and // options. If result_layout is non-null, then the executable is compiled to // produce a result of the given layout. If device_allocator is non-null, // then the compiler may use it to allocate temp space on the device. The // compiler is responsible for freeing any memory it allocates this way. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -70,6 +58,11 @@ class LocalService : public Service { // the "easy" case where a single replica is a single device. StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index 6aca6ba38572c5311797fbb91acbbcd6610a3410..d631fb5ee42df6525681a5cd1fe1a8241824121d 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -125,18 +125,29 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) { return Status::OK(); } -Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { - // RecvDone doesn't create a new buffer but rather aliases its input (Recv) - // tuple element at {0} to its output. +Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand. + return Status::OK(); +} + +Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction* recv_done) { + // RecvDone produces a two-element tuple containing the data value (which + // aliases part of its operand) and a token. Only the tuple index table and + // the token are defined by the RecvDone. + NewLogicalBuffer(recv_done, /*index=*/{}); + NewLogicalBuffer(recv_done, /*index=*/{1}); return Status::OK(); } Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) { - // Send creates new buffers for the top-level tuple and the context (tuple - // element at {1}). Tuple element at {0} is an alias of the Send operand, so - // we don't need to create a new Logical Buffer for that. + // Send creates new buffers for the top-level tuple, the context (tuple + // element at {1}), and the token (tuple element at {2}). Tuple element at {0} + // is an alias of the Send operand, so we don't need to create a new Logical + // Buffer for that. NewLogicalBuffer(send, /*index=*/{}); NewLogicalBuffer(send, /*index=*/{1}); + NewLogicalBuffer(send, /*index=*/{2}); return Status::OK(); } @@ -146,10 +157,10 @@ Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -Status LogicalBufferAnalysis::HandleSelect(HloInstruction* select) { +Status LogicalBufferAnalysis::HandleTupleSelect(HloInstruction* tuple_select) { // Select allocates a new buffer and then shallow copies the on_true or // on_false buffer into this new buffer. - NewLogicalBuffer(select, /*index=*/{}); + NewLogicalBuffer(tuple_select, /*index=*/{}); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index f4c63dd86b4d8a6f598d46047012e4e5bc7b3d7e..81f524d84a8091e1fff13dc7c55b401143a02753 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -59,10 +59,11 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; - Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; // A map from the buffer ID to the logical buffer std::vector> logical_buffers_; diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..4166ef5baf9c891968b584a0c498005e9ae87784 --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -0,0 +1,338 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +StatusOr MultiOutputFusion::Run(HloModule* module) { + bool changed = false; + + for (auto* computation : module->MakeNonfusionComputations()) { + computation_ = computation; + RecomputeReachability(); + candidates_.clear(); + candidates_index_.clear(); + all_fusion_candidates_.clear(); + + int64 index = 0; + for (auto it : computation_->MakeInstructionPostOrder()) { + candidates_.emplace_back(it); + InsertOrDie(&candidates_index_, it, index++); + } + + // Create the initial candidate list for each Node. + for (auto& node : candidates_) { + HloInstruction* instruction = node.hlo; + int64 instruction_id = get_candidate_id(instruction); + FusionCandidate& instr_node = candidates_[instruction_id]; + if (!IsFusible(instruction)) { + continue; + } + all_fusion_candidates_.push_back(instruction); + + std::vector candidates; + tensorflow::gtl::FlatSet candidates_set; + VLOG(10) << "Looking at instruction: " << instruction->name(); + for (auto operand : instruction->operands()) { + // Filter out the non-interesting instructions -- they + // will not generate the savings. + if (!IsProfitableOperand(operand)) { + VLOG(10) << "Operand not profitable: " << operand->name(); + continue; + } + VLOG(10) << "Operand profitable: " << operand->name(); + for (auto user : operand->users()) { + VLOG(10) << "User: " << user->name(); + if (user == instruction || !IsFusible(user)) { + VLOG(10) << "User is not fusible, or is the instruction itself: " + << user->name(); + continue; + } + int64 user_id = get_candidate_id(user); + if (is_connected(instruction, user)) { + VLOG(10) << "User is connected: " << user->name(); + continue; + } + if (instruction_id < user_id && + user->opcode() == HloOpcode::kFusion) { + VLOG(10) << "User ID for user: " << user->name() << " is " + << user_id << " which is higher than " << instruction_id; + continue; + } + if (!LegalToFuse(instruction, user)) { + VLOG(10) << "User not legal to fuse: " << user->name(); + continue; + } + if (candidates_set.insert(user).second) { + VLOG(10) << "User added to candidate list: " << user->name(); + candidates.push_back(user); + } + } + } + + // Iterate over candidates rather than candidates_set to avoid + // nondeterminism. + for (auto candidate : candidates) { + int64 profit = GetProfit(instruction, candidate); + if (profit > 0) { + FusionCandidate& candidate_node = + candidates_[get_candidate_id(candidate)]; + instr_node.fusibles.emplace_back(candidate, profit); + candidate_node.fusibles.emplace_back(instruction, profit); + worklist_.emplace(instruction, candidate, profit); + } + } + } + if (Perform()) { + changed = true; + } + } + return changed; +} + +HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, + HloInstruction* instr2) { + HloInstruction* remaining = instr1; + HloInstruction* fused = instr2; + // Make sure that if only one of the instructions is a fusion, or if only one + // of the instructions is a multi-output fusion, it's what will be fused into. + if (fused->opcode() == HloOpcode::kFusion) { + std::swap(remaining, fused); + } + if (fused->IsMultiOutputFusion()) { + std::swap(remaining, fused); + } + + if (fused->opcode() == HloOpcode::kFusion) { + remaining->MergeFusionInstructionIntoMultiOutput(fused); + } else { + remaining->FuseInstructionIntoMultiOutput(fused); + } + return remaining; +} + +bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) { + // kConstant instruction will not have memory reads, so it won't be a profit + // source. Skip them. + if (instr->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(instr->shape())) { + return false; + } + // We don't target to fuse producer/consumer instructions -- this should + // be taken care of by the instruction_fusion pass. If instr has only + // one user, it will not have sibling instructions. We won't consider it. + if (instr->user_count() < 2) { + return false; + } + return true; +} + +void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { + HloInstruction* fusion = instr1; + HloInstruction* fused = instr2; + if (is_fused(instr1)) { + fusion = instr2; + fused = instr1; + } + + // Insert the newly created instruction (if any), to candidates_. + for (auto use : fusion->users()) { + if (candidates_index_.find(use) == candidates_index_.end()) { + int64 index = candidates_.size(); + candidates_.emplace_back(use); + InsertOrDie(&candidates_index_, use, index++); + } + } + FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)]; + FusionCandidate& fused_node = candidates_[get_candidate_id(fused)]; + + // Update the reachability graph. + UpdateReachability(fusion, fused, all_fusion_candidates_, + [this](HloInstruction* instr) { return is_fused(instr); }); + + // Update the fusible list for fusion. Variable new_fusibles keeps + // track of the new or changed entries. + std::vector> new_fusibles; + tensorflow::gtl::FlatSet in_list; + auto it = fusion_node.fusibles.begin(); + while (it != fusion_node.fusibles.end()) { + HloInstruction* instr = it->first; + if (is_fused(instr) || is_connected(fusion, instr)) { + it = fusion_node.fusibles.erase(it); + continue; + } + in_list.insert(instr); + int64 profit = GetProfit(instr, fusion); + if (profit > it->second) { + it->second = profit; + new_fusibles.emplace_back(instr, profit); + } + ++it; + } + + // Fused_node has been fused into fusion_node. Take the fusion candidates + // (fusibles) from fused_nodes and add them to the fusion_node's. Filter + // out those fusibles that no longer valid (or already in the list). + for (const auto& it : fused_node.fusibles) { + HloInstruction* instr = it.first; + if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) { + continue; + } + if (in_list.count(instr) > 0) { + continue; + } + int64 profit = GetProfit(instr, fusion); + fusion_node.fusibles.emplace_back(instr, profit); + new_fusibles.emplace_back(instr, profit); + } + fused_node.fusibles.clear(); + + // Update the worklist_. + for (auto it : new_fusibles) { + worklist_.emplace(fusion, it.first, it.second); + } +} + +bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, + HloInstruction* instr2) { + if (instr1 == instr2) { + return false; + } + if (instr1->opcode() != HloOpcode::kFusion) { + return false; + } + + // Fusing nodes with 0 user makes no sense and the rest of the implementation + // doesn't support it either. + if (instr1->user_count() == 0 || instr2->user_count() == 0) { + return false; + } + + // Check if the users of multioutput fusion is not a get-tuple-element. + // If this is the case, we bail out because the transformation assumes + // the users are get-tuple-element. + auto multioutput_user_is_not_gte = [](HloInstruction* instr) { + if (!instr->IsMultiOutputFusion()) { + return false; + } + for (auto user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return true; + } + } + return false; + }; + if (multioutput_user_is_not_gte(instr1) || + multioutput_user_is_not_gte(instr2)) { + return false; + } + + if (is_connected(instr1, instr2)) { + return false; + } + if (!ShapesCompatibleForFusion(instr1, instr2)) { + return false; + } + + return true; +} + +void MultiOutputFusion::RecomputeReachability() { + reachability_ = computation_->ComputeReachability(); +} + +void MultiOutputFusion::UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip) { + for (auto instr : instrs_to_update) { + if (skip != nullptr && skip(instr)) { + continue; + } + if (reachability_->IsReachable(instr2, instr) && + reachability_->IsReachable(instr1, instr)) { + // If a candidate was already reachable by both, no update needed. + continue; + } + if (reachability_->IsReachable(instr2, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr1}, instr); + } + if (reachability_->IsReachable(instr1, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr2}, instr); + } + } +} + +bool MultiOutputFusion::Perform() { + int changed = false; + // Pick the top candidate from queue and try to merge. + while (!worklist_.empty()) { + if (fuel_ <= 0) { + VLOG(2) << "No fusing: run out of fuel."; + break; + } + ToBeFused candidate = worklist_.top(); + worklist_.pop(); + + HloInstruction* instr1 = candidate.instr1; + HloInstruction* instr2 = candidate.instr2; + + if (is_fused(instr1) || is_fused(instr2)) { + continue; + } + + VLOG(1) << "Considering candidate profit_score=" << candidate.score + << "\n\t\tinstr1 = " << instr1->ToString() + << "\n\t\tinstr2 = " << instr2->ToString(); + + if (LegalToFuse(instr1, instr2)) { + VLOG(1) << "Fuse!"; + VLOG(2) << "Before multi_output_fusion:"; + VLOG(2) << "instr1: " << instr1->ToString(); + VLOG(2) << "\n" + << instr1->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + VLOG(2) << "instr2: " << instr2->ToString(); + if (instr2->opcode() == HloOpcode::kFusion) { + VLOG(2) << "\n" + << instr2->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + } + HloInstruction* ret = Fuse(instr1, instr2); + set_is_fused(ret == instr1 ? instr2 : instr1); + Update(instr1, instr2); + changed = true; + VLOG(2) << "After fusion, \t this: " << ret->name() << "\n" + << ret->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + auto users = ret->users(); + --fuel_; + } + } + if (DoProducerConsumerMultiOutputFusion()) { + changed = true; + } + return changed; +} + +bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion() { return false; } +} // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..0019cd725417d81900974b462c3b05075ce3e893 --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -0,0 +1,169 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// This class implements the fusing of sibling fusion instructions that sharing +// common operands. +// It constructs the following associated data structures. +// (1) candidates_: stores the instruction and the set of instructions it can +// fuse to. +// (2) candidates_index_: maps instruction to id. +// (3) reachability_: reachability map in this computation. +// (4) all_fusion_candidates_: the vector of candidate instructions. +// (5) worklist_: a priority queue that contains pairs of instructions to be +// fused and their fusion profit scores. +// +// Function Perform() applies the optimization. It picks up the most profitable +// pair in the worklist_, check if it's legal to fuse and fuse the pair. +// After fusion, it updates the associated structure such as reachability_, +// candidates_ and worklist_. +// Note that the reachability map is updated based on the original computation. +// This works because the reachability is monotonically increasing with +// instruction fusion. +class MultiOutputFusion : public HloPassInterface { + public: + MultiOutputFusion(int64 fuel) : fuel_(fuel) {} + + tensorflow::StringPiece name() const override { + return "multi_output_fusion"; + } + + // Run multi-output fusion on the given module. Returns whether the module + // was changed. + StatusOr Run(HloModule* module) override; + + protected: + // Main entry for the optimization. Returns true if the optimization happens. + bool Perform(); + + // Test if instr1 and instr2 have the compatible shapes that can be legally + // fused. + virtual bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) = 0; + + // Whether the instruction is a candidate for fusion. + virtual bool IsFusible(HloInstruction* instr) = 0; + + // This function estimates the savings by merging instr1 and instr2 into one + // multi-output fusion instruction. + virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0; + + // Whether fusing the instruction can reduce memory reads. + virtual bool IsProfitableOperand(HloInstruction* instr); + + // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. + virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + + // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. + // The other instruction is removed from its parent computation. + virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + + // Recompute reachability for the current computation. + void RecomputeReachability(); + + // Returns the reachability map for the current computation. + HloReachabilityMap* reachability() const { return reachability_.get(); } + + // Returns the computation for the pass. + HloComputation* computation() const { return computation_; } + + // Update the reachability map after fusing instr1 and instr2. + void UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip = nullptr); + + // Hook for multi-output fusion along producer-consumer edges. + // Returns whether any instructions were fused. + // + // TODO(b/80420762): Perform producer-consumer multi-output fusion in + // InstructionFusion instead. + virtual bool DoProducerConsumerMultiOutputFusion(); + + private: + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); + + // Optimization fuel is a compiler debugging technique that makes an + // optimization pass stop what it is doing after having made N changes to the + // program, where N is the fuel. By varying N, this can be used to find the + // first single change that makes a test fail. + int64 fuel_; + + // Computation for the pass. + HloComputation* computation_; + + // An internal data structure for each instruction in current computation. + // When an instruction is removed, member 'hlo' is set to nullptr. + struct FusionCandidate { + HloInstruction* hlo; + std::list> fusibles; + explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {} + }; + std::vector candidates_; + + // A map that maps an instruction to the index_. + tensorflow::gtl::FlatMap candidates_index_; + + // The reachability map of current computation. + std::unique_ptr reachability_; + + // This stores all the candidate instructions in current computation. + std::vector all_fusion_candidates_; + + // The pair of candidates to be fused and the profit score. + struct ToBeFused { + HloInstruction* instr1; + HloInstruction* instr2; + int64 score; + ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score) + : instr1(instr1), instr2(instr2), score(score) {} + bool operator<(const ToBeFused& rhs) const { return score < rhs.score; } + }; + std::priority_queue worklist_; + + int64 get_candidate_id(HloInstruction* instr) { + return FindOrDie(candidates_index_, instr); + } + + bool is_fused(HloInstruction* instr) { + return candidates_[get_candidate_id(instr)].hlo == nullptr; + } + + void set_is_fused(HloInstruction* instr) { + candidates_[get_candidate_id(instr)].hlo = nullptr; + } + + bool is_connected(HloInstruction* instr1, HloInstruction* instr2) { + return reachability_->IsConnected(instr1, instr2); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 3a6a7c25f4b727c7112dbcbcb4f3d892679a0011..f6e7578a89551ec2f23d4d8c8b488c3c10e0bf1c 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -67,22 +67,17 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { has_numeric_suffix = true; // Remove numeric suffix from root. root = root.substr(0, separator_index); - // Update count to at least the numeric suffix value to avoid future - // colisions with this name. - generated_names_[root] = std::max(generated_names_[root], numeric_suffix); } } - int64* count = &(generated_names_[root]); - if (*count == 0) { - *count = 1; + + SequentialIdGenerator& id_generator = generated_names_[root]; + numeric_suffix = id_generator.RegisterId(numeric_suffix); + if (numeric_suffix == 0) { return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0) : root; - } else { - tensorflow::strings::StrAppend(&root, separator_, *count); - // Increment lookup under old 'root' name. - (*count)++; - return root; } + tensorflow::strings::StrAppend(&root, separator_, numeric_suffix); + return root; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 4139c2700b25e8600182a034a8ac6f4f041c12e6..4423d6106920eaeab830bd9dc08529ff409a5161 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -17,10 +17,11 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_ #include -#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -44,13 +45,40 @@ class NameUniquer { static string GetSanitizedName(const string& name); private: + // Used to track and generate new identifiers for the same instruction name + // root. + class SequentialIdGenerator { + public: + SequentialIdGenerator() = default; + + // Tries to register id as used identifier. If id is not already used, the + // id itself will be returned. Otherwise a new one will be generated, and + // returned. + int64 RegisterId(int64 id) { + if (used_.insert(id).second) { + return id; + } + while (!used_.insert(next_).second) { + ++next_; + } + return next_++; + } + + private: + // The next identifier to be tried. + int64 next_ = 0; + + // Set of all the identifiers which has been used. + tensorflow::gtl::FlatSet used_; + }; + // The string to use to separate the prefix of the name from the uniquing // integer value. string separator_; - // Map from name prefix to the number of names generated using that prefix - // so far. - std::unordered_map generated_names_; + // Map from name prefix to the generator data structure which tracks used + // identifiers and generates new ones. + tensorflow::gtl::FlatMap generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 2ec255558c4ed3695ec6c824458cbedac44dc297..3e2592c6ac626143f1421e545a31d9be91e376bc 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -54,12 +54,13 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); - EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo")); EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1")); - EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1")); - EXPECT_EQ("bar.0", uniquer.GetUniqueName("bar.-1000")); - EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); - EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); + EXPECT_EQ("foo.55.0", uniquer.GetUniqueName("foo.55.1")); + EXPECT_EQ("bar.1000", uniquer.GetUniqueName("bar.1000")); + EXPECT_EQ("bar.2000", uniquer.GetUniqueName("bar.2000")); + EXPECT_EQ("bar.-2000", uniquer.GetUniqueName("bar.-2000")); + EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.1")); } TEST_F(NameUniquerTest, PrefixHasSuffix) { @@ -77,12 +78,12 @@ TEST_F(NameUniquerTest, Sanitize) { EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54")); EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1")); - EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo_2", uniquer.GetUniqueName("foo")); // Invalid characters will be replaced with '_'. - EXPECT_EQ("bar_0", uniquer.GetUniqueName("bar<-1000")); - EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); - EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); + EXPECT_EQ("bar_1000", uniquer.GetUniqueName("bar<1000")); + EXPECT_EQ("bar_2000", uniquer.GetUniqueName("bar<2000")); + EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar_1")); // Separator is only recognized in the middle of the prefix. EXPECT_EQ("_10", uniquer.GetUniqueName( @@ -93,5 +94,15 @@ TEST_F(NameUniquerTest, Sanitize) { EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_")); } +TEST_F(NameUniquerTest, KeepNamesInRandomOrder) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo.11", uniquer.GetUniqueName("foo.11")); + EXPECT_EQ("foo.10", uniquer.GetUniqueName("foo.10")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo.1")); + EXPECT_EQ("foo.12", uniquer.GetUniqueName("foo.12")); + EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index d3bc47e61e0e75fa2ef181988700f88cec9c1d76..ac6ea4c72f61a47726b3ae7dd000837d3fba1b93 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -86,8 +86,8 @@ namespace xla { // are provided below. // // Example nullary instruction: -// Recv() == Op().WithOpcode(HloOpcode::kRecv) -// Recv(&a) == Op(&a).WithOpcode(HloOpcode::kRecv) +// Param() == Op().WithOpcode(HloOpcode::kParam) +// Param(&a) == Op(&a).WithOpcode(HloOpcode::kParam) // // Example unary instruction: // Abs() == Op().WithOpcode(HloOpcode::kAbs) @@ -204,7 +204,7 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. constexpr LayoutPattern> EqualTo( - const Layout* layout) const { + const ::xla::Layout* layout) const { return LayoutPattern>( LayoutPatternEqualImpl(impl_, layout), matched_layout_); } @@ -726,6 +726,32 @@ class HloInstructionPatternFusionKindImpl { ::xla::HloInstruction::FusionKind kind_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// is a kGetTupleElement with a particular tuple index. +template +class HloInstructionPatternTupleIndexImpl { + public: + explicit constexpr HloInstructionPatternTupleIndexImpl( + const Previous& previous, int64 tuple_index) + : previous_(previous), tuple_index_(tuple_index) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && + inst->opcode() == HloOpcode::kGetTupleElement && + inst->tuple_index() == tuple_index_; + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && + inst->opcode() == HloOpcode::kGetTupleElement && + inst->tuple_index() == tuple_index_; + } + + private: + Previous previous_; + int64 tuple_index_; +}; + // A pattern that matches HloInstructions. template class HloInstructionPattern { @@ -841,6 +867,17 @@ class HloInstructionPattern { HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); } + // Modifies the pattern to match only if the instruction is a + // get-tuple-element with the given tuple index. + constexpr HloInstructionPattern> + WithTupleIndex(int64 tuple_index) const { + return HloInstructionPattern>( + HloInstructionPatternTupleIndexImpl(impl_, tuple_index), + matched_inst_); + } + private: Impl impl_; HloInstructionType** matched_inst_; @@ -880,9 +917,7 @@ Op(::xla::HloInstruction** matched_inst) { return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ } XLA_NULLOP_PATTERN(Constant) -XLA_NULLOP_PATTERN(Infeed) XLA_NULLOP_PATTERN(Parameter) -XLA_NULLOP_PATTERN(Recv) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. @@ -919,18 +954,21 @@ XLA_UNOP_PATTERN(Cos) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Fft) XLA_UNOP_PATTERN(Floor) +XLA_UNOP_PATTERN(GetTupleElement) XLA_UNOP_PATTERN(Imag) +XLA_UNOP_PATTERN(Infeed) XLA_UNOP_PATTERN(IsFinite) XLA_UNOP_PATTERN(Log) XLA_UNOP_PATTERN(Not) XLA_UNOP_PATTERN(Negate) -XLA_UNOP_PATTERN(Outfeed) XLA_UNOP_PATTERN(Real) +XLA_UNOP_PATTERN(Recv) +XLA_UNOP_PATTERN(RecvDone) XLA_UNOP_PATTERN(Reduce) XLA_UNOP_PATTERN(ReducePrecision) XLA_UNOP_PATTERN(Reshape) XLA_UNOP_PATTERN(Reverse) -XLA_UNOP_PATTERN(Send) +XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) XLA_UNOP_PATTERN(Sort) @@ -981,8 +1019,10 @@ XLA_BINOP_PATTERN(Maximum) XLA_BINOP_PATTERN(Minimum) XLA_BINOP_PATTERN(Multiply) XLA_BINOP_PATTERN(Ne) +XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Power) XLA_BINOP_PATTERN(Remainder) +XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) XLA_BINOP_PATTERN(And) XLA_BINOP_PATTERN(Or) @@ -1040,6 +1080,32 @@ inline auto NonConstant(HloInstructionType** matched_inst) return Op(matched_inst).IsNonConstant(); } +// Add overloads for GetTupleElement which take a int64 specifying which tuple +// element is selected. +template +inline auto GetTupleElement(Arg&& arg, int64 tuple_index) + -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index)) { + return Op() + .WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index); +} + +template +inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, + int64 tuple_index) + -> decltype(Op(matched_inst) + .WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index)) { + return Op(matched_inst) + .WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index); +} + } // namespace match } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 204e8c99209fa95adb868a676bb9e5144fed432c..a530581c34bf1d699eae3c53203c197f7943cc53 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -29,7 +29,7 @@ TEST(PatternMatcherTest, AddOp) { ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); const HloInstruction* matched_inst; HloInstruction* matched_operand; @@ -182,7 +182,7 @@ TEST(PatternMatcherTest, FusionKind) { p0 = f32[] parameter(0) ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation })"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); EXPECT_TRUE(Match( @@ -193,5 +193,23 @@ TEST(PatternMatcherTest, FusionKind) { HloInstruction::FusionKind::kLoop))); } +TEST(PatternMatcherTest, GetTupleElement) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + ENTRY while.v11 { + p0 = (f32[], f32[], f32[]) parameter(0) + ROOT gte = f32[] get-tuple-element(p0), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(0))); + EXPECT_TRUE(Match(root, match::Op().WithTupleIndex(1))); + EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(2))); + EXPECT_FALSE(Match(root, match::GetTupleElement(match::Op(), 0))); + EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 0f26a025bf125f70199637894741540f89eae7e5..49ec38eb62c7b51c7a2d301d882cef032b288036 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -155,20 +155,15 @@ HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, case HloOpcode::kConstant: { if (first_reshape_operand->opcode() == HloOpcode::kReshape) { VLOG(5) << "Adding reshape to kConstant operand"; - HloInstruction* reshape = computation->AddInstruction( + return computation->AddInstruction( HloInstruction::CreateReshape(new_shape, operand)); - operand->SetupDerivedInstruction(reshape); - return reshape; } else { CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose); VLOG(5) << "Adding transpose to kConstant operand"; std::vector inverse_permutation = InversePermutation(first_reshape_operand->dimensions()); - HloInstruction* transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - new_shape, operand, inverse_permutation)); - operand->SetupDerivedInstruction(transpose); - return transpose; + return computation->AddInstruction(HloInstruction::CreateTranspose( + new_shape, operand, inverse_permutation)); } } case HloOpcode::kRng: { diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 047cadb3d9d5a312f13c1fd03e9574177a7050c9..da3b622bfae8ac5132f9f95070ee41674e79b5b8 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" @@ -62,55 +61,28 @@ namespace xla { namespace { -// Records the arguments used to invoke a computation in a SessionModule -// proto. -Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - se::StreamExecutor* executor, TransferManager* transfer_manager, - SessionModule* module) { - module->clear_arguments(); - for (const ShapedBuffer* argument : arguments) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, *argument)); - *module->add_arguments() = literal->ToProto(); - } - return Status::OK(); -} - -// Records the result of a computation in a SessionModule proto. -Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, - TransferManager* transfer_manager, SessionModule* module) { - module->clear_result(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, result)); - *module->mutable_result() = literal->ToProto(); - return Status::OK(); -} - // Records the arguments used to invoke a computation in an HloSnapshot proto. Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, - se::StreamExecutor* executor, TransferManager* transfer_manager, + se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, *argument)); + transfer_manager->TransferLiteralFromDevice(stream, *argument)); *module->add_arguments() = literal->ToProto(); } return Status::OK(); } // Records the result of a computation in a HloSnapshot proto. -Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, +Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, result)); + transfer_manager->TransferLiteralFromDevice(stream, result)); *module->mutable_result() = literal->ToProto(); return Status::OK(); } @@ -195,20 +167,6 @@ Service::Service(const ServiceOptions& options, } } -Status Service::Computation(const ComputationRequest* arg, - ComputationResponse* result) { - if (arg->name().empty()) { - return InvalidArgument("computation request needs a name"); - } - - *result->mutable_computation() = - computation_tracker_.NewComputation(arg->name()); - VLOG(1) << Printf("Created new computation %s on service %p, name %s", - result->computation().ShortDebugString().c_str(), this, - arg->name().c_str()); - return Status::OK(); -} - Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) { *result->mutable_channel() = channel_tracker_.NewChannel(); @@ -233,21 +191,17 @@ Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, return Status::OK(); } -Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout, - const Shape& result_shape) const { - if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { +Status Service::ValidateResultShape(const Shape& client_shape, + const Shape& result_shape) const { + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape)); + if (!ShapeUtil::Compatible(client_shape, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " "with result shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(client_shape).c_str(), ShapeUtil::HumanString(result_shape).c_str()); } - if (!LayoutUtil::HasLayout(shape_with_layout)) { - return InvalidArgument( - "Shape used to set computation result layout %s does not have layout", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); - } - return ShapeUtil::ValidateShape(shape_with_layout); + return Status::OK(); } StatusOr>> @@ -288,13 +242,10 @@ Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation* user_computation) { + const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); - ComputationLayout* host_computation_layout = - config->mutable_host_entry_computation_layout(); - ComputationLayout* device_computation_layout = - config->mutable_device_entry_computation_layout(); + ComputationLayout* computation_layout = + config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { return InvalidArgument("computation takes %d parameters, but %zu given", program_shape.parameters_size(), @@ -305,43 +256,28 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { - if (user_computation == nullptr) { - return InvalidArgument( - "Argument does not match shape of computation parameter %d: want " - "%s, got %s", - i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*argument_shapes[i]).c_str()); - } - return InvalidParameterArgument( - *user_computation->ParameterMetadata(i).value(), - "Argument does not match shape of computation parameter %d: want %s, " - "got %s", + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } - TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i) - ->CopyLayoutFromShape(*argument_shapes[i])); - TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i) - ->CopyLayoutFromShape(*argument_shapes[i])); + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + *argument_shapes[i])); } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { const auto& shape_with_output_layout = execution_options->shape_with_output_layout(); - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout, - program_shape.result())); TF_RETURN_IF_ERROR( - host_computation_layout->mutable_result_layout()->CopyLayoutFromShape( - shape_with_output_layout)); + ValidateResultShape(shape_with_output_layout, program_shape.result())); TF_RETURN_IF_ERROR( - device_computation_layout->mutable_result_layout()->CopyLayoutFromShape( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( shape_with_output_layout)); } else { // If the result layout is not set, then choose the default. - // TODO(b/29118294): Allow the compiler to choose a better layout in this - // case. - host_computation_layout->mutable_result_layout()->SetToDefaultLayout(); - device_computation_layout->mutable_result_layout()->SetToDefaultLayout(); + computation_layout->mutable_result_layout()->SetToDefaultLayout(); } config->set_replica_count(options_.number_of_replicas()); @@ -363,76 +299,12 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation* user_computation) { + const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); } - return CreateModuleConfig(program_shape, argument_shapes, &execution_options, - user_computation); -} - -StatusOr>> Service::BuildExecutables( - std::vector versioned_handles, - std::vector> module_configs, - Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p", this); - - // Dump computation proto state if flag is set. - std::vector> session_modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { - const string& directory_path = - module_configs[i]->debug_options().xla_dump_computations_to(); - const string& other_directory_path = - module_configs[i]->debug_options().xla_dump_executions_to(); - if (directory_path.empty() && other_directory_path.empty()) { - continue; - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handles[i].handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handles[i].handle.handle(), - session_module->entry().name().c_str(), - versioned_handles[i].version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - session_modules.push_back(std::move(session_module)); - } - } - - VLOG(1) << "Computation handles:"; - for (const VersionedComputationHandle& versioned_handle : versioned_handles) { - VLOG(1) << versioned_handle; - } - - CHECK_EQ(versioned_handles.size(), module_configs.size()); - std::vector> modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { - const VersionedComputationHandle& versioned_handle = versioned_handles[i]; - const HloModuleConfig& config = *module_configs[i]; - TF_ASSIGN_OR_RETURN(auto module, - computation_tracker_.BuildHloModule( - versioned_handle, config, - /*include_unreachable_instructions=*/true)); - modules.push_back(std::move(module)); - } - - TF_ASSIGN_OR_RETURN( - std::vector> executables, - backend->compiler()->Compile(std::move(modules), std::move(executors), - device_allocator)); - - for (size_t i = 0; i < versioned_handles.size(); ++i) { - if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { - executables[i]->set_session_module(std::move(session_modules[i])); - } - } - - return std::move(executables); + return CreateModuleConfig(program_shape, argument_shapes, &execution_options); } StatusOr>> Service::BuildExecutables( @@ -460,8 +332,8 @@ StatusOr>> Service::BuildExecutables( module_protos[i]->entry_computation_name().c_str()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); - hlo_snapshots.push_back(std::move(hlo_snapshot)); } + hlo_snapshots.push_back(std::move(hlo_snapshot)); } VLOG(1) << "Computations:"; @@ -493,114 +365,6 @@ StatusOr>> Service::BuildExecutables( return std::move(executables); } -Status Service::ValidateEntryComputationLayout(HloModule* module) { - const ComputationLayout& on_device = - module->device_entry_computation_layout(); - for (int64 i = 0; i < on_device.parameter_count(); ++i) { - TF_RET_CHECK(ShapeUtil::Equal( - on_device.parameter_shape(i), - execute_backend_->transfer_manager()->HostShapeToDeviceShape( - module->host_entry_computation_layout().parameter_shape(i)))); - } - TF_RET_CHECK(ShapeUtil::Equal( - module->device_entry_computation_layout().result_shape(), - execute_backend_->transfer_manager()->HostShapeToDeviceShape( - module->host_entry_computation_layout().result_shape()))); - return Status::OK(); -} - -StatusOr> Service::BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, - versioned_handle.ToString().c_str()); - - // Dump computation proto state if flag is set. - std::unique_ptr session_module; - const string& directory_path = - module_config->debug_options().xla_dump_computations_to(); - const string& other_directory_path = - module_config->debug_options().xla_dump_executions_to(); - if (!directory_path.empty() || !other_directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handle.handle.handle(), - session_module->entry().name().c_str(), - versioned_handle.version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - } - } - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - true)); - - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); - - TF_ASSIGN_OR_RETURN( - module, backend->compiler()->RunHloPasses(std::move(module), executor, - device_allocator)); - // Check that on-host and on-device shapes are consistent. - TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); - - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - backend->compiler()->RunBackend( - std::move(module), executor, device_allocator)); - - if (!other_directory_path.empty()) { - executable->set_session_module(std::move(session_module)); - } - - return std::move(executable); -} - -StatusOr> Service::BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator) { - std::shared_ptr executable = - compilation_cache_.LookUp(versioned_handle, *module_config); - - if (executable != nullptr) { - // Executable found in the computation cache. - if (profile != nullptr) { - profile->set_compilation_cache_hit(true); - } - return executable; - } - - uint64 start_micros = - // Avoid reading the clock if we don't want timing info - (profile != nullptr) ? tensorflow::Env::Default()->NowMicros() : 0; - - // Take a copy of the module config, as compilation introduces layouts where - // layouts were optional before. - HloModuleConfig original_module_config = *module_config; - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable_unique_ptr, - BuildExecutable(versioned_handle, std::move(module_config), backend, - executor, device_allocator)); - - if (profile != nullptr) { - uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - uint64 milliseconds = (end_micros - start_micros) / 1000; - profile->set_compilation_cache_hit(false); - profile->set_compile_time_ms(milliseconds); - } - - // Insert executable into the cache. - return compilation_cache_.Insert(std::move(executable_unique_ptr), - original_module_config); -} - StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, @@ -621,9 +385,16 @@ Service::ExecuteParallelAndRegisterResult( // profiled. std::map index_to_profiled_streams; - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - backend->computation_placer()->AssignDevices( - options_.number_of_replicas(), executables.size())); + // Build DeviceAssignment for all cores based on the provided device handles. + DeviceAssignment device_assignment(options_.number_of_replicas(), + executables.size()); + for (int64 i = 0; i < executables.size(); i++) { + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + CHECK_EQ(replicas.size(), arguments[i].size()); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + device_assignment(replica, i) = replicas[replica]->device_ordinal(); + } + } for (int64 i = 0; i < executables.size(); i++) { // Stream executors for the replicas of the current computation. @@ -695,7 +466,7 @@ Service::ExecuteParallelAndRegisterResult( HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); TF_RETURN_IF_ERROR( - executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); + executable->PopulateExecutionProfile(&hlo_profile, stream)); XLA_LOG_LINES( tensorflow::INFO, hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); @@ -796,13 +567,6 @@ StatusOr Service::ExecuteAndRegisterResult( result_tag); } -Status Service::SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - return computation->SetReturnValue(arg->operand()); -} - StatusOr> Service::GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, int64 request_index) const { @@ -844,117 +608,6 @@ StatusOr>> Service::GetArguments( return replicated_arguments; } -Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) { - VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - - std::vector>> all_arguments; - std::vector> all_executors; - std::vector versioned_handles; - std::vector> module_configs; - std::vector computation_names; - std::vector device_handles; - - int num_requested_devices = - std::accumulate(arg->requests().begin(), arg->requests().end(), 0, - [](int a, const ExecuteRequest& r) -> int { - return a + r.execution_options().device_handles_size(); - }); - if (num_requested_devices * options_.number_of_replicas() > - execute_backend_->device_count()) { - return FailedPrecondition( - "there are not enough stream executors to execute %d computations", - num_requested_devices); - } - - for (int64 i = 0; i < arg->requests_size(); ++i) { - // Get the stream executor for the i'th computation. This stream executor - // is one of the executors to run the replicated computation. - const ExecutionOptions& execution_options = - arg->requests(i).execution_options(); - - // Get the executors. - TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, - arg->requests_size(), i)); - - // Resolve the UserComputation object associated with the requested - // computation and compute the program shape. - const ExecuteRequest& request = arg->requests(i); - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(request.computation())); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Get the replicated arguments. - TF_ASSIGN_OR_RETURN(auto replicated_arguments, - GetArguments(execution_options, request.arguments())); - - // Create an HloModuleConfig object for the computation, given the shape of - // the program and the argument allocations. Here, we care only about the - // shapes of the arguments, so, it is sufficient to use the arguments of - // replica 0. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - request.execution_options(), user_computation)); - VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - // Adds to the vectors to build and execute the computations after the loop. - all_arguments.push_back(replicated_arguments); - all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); - versioned_handles.push_back(versioned_handle); - module_configs.push_back(std::move(module_config)); - computation_names.insert(computation_names.end(), executors.size(), - user_computation->name()); - all_executors.push_back(executors); - device_handles.insert(device_handles.end(), - execution_options.device_handles().begin(), - execution_options.device_handles().end()); - } - - // Build the user computations into HloModules and compile to generate the - // executables. - // - // TODO(jlebar): There's currently no way to pass a device allocator to - // ExecuteParallel, so we have to pass a null device_allocator below. - TF_ASSIGN_OR_RETURN( - std::vector> executables, - BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), all_executors, - /*device_allocator=*/nullptr)); - std::vector executable_ptrs; - executable_ptrs.reserve(executables.size()); - for (const auto& executable : executables) { - executable_ptrs.push_back(executable.get()); - } - - // Execute the generated executables in parallel and return the device - // handles for each computation's output. - ExecutionProfile profile; - TF_ASSIGN_OR_RETURN( - std::vector outputs, - ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), device_handles, - computation_names, &profile)); - for (const GlobalDataHandle& output : outputs) { - ExecuteResponse response; - *response.mutable_output() = output; - *response.mutable_profile() = profile; - *result->add_responses() = response; - } - - VLOG(1) << "successfully completed 'execute-parallel' request"; - return Status::OK(); -} - Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-graph-parallel request"; @@ -1004,11 +657,10 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, std::unique_ptr module_config, CreateModuleConfig(request.computation().program_shape(), replicated_arguments.front(), - request.execution_options(), - /*user_computation=*/nullptr)); + request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); + << module_config->entry_computation_layout().ToString(); // Adds to the vectors to build and execute the computations after the loop. all_arguments.push_back(replicated_arguments); @@ -1037,6 +689,17 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, executable_ptrs.push_back(executable.get()); } + for (int i = 0; i < executable_ptrs.size(); i++) { + if (executable_ptrs[i]->dumping_snapshot()) { + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream( + all_executors[i][0]->device_ordinal())); + TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(), + execute_backend_->transfer_manager(), + executable_ptrs[i]->hlo_snapshot())); + } + } + // Execute the generated executables in parallel and return the device // handles for each computation's output. ExecutionProfile profile; @@ -1052,6 +715,20 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, *result->add_responses() = response; } + for (int i = 0; i < executable_ptrs.size(); i++) { + if (executable_ptrs[i]->dumping_snapshot()) { + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(outputs[i], 0)); + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream(all_executors[i][0])); + TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), + execute_backend_->transfer_manager(), + executable_ptrs[i]->hlo_snapshot())); + // Dump out the ith snapshot. + TF_RETURN_IF_ERROR(executable_ptrs[i]->DumpHloSnapshot()); + } + } + VLOG(1) << "successfully completed 'execute-graph-parallel' request"; return Status::OK(); } @@ -1080,15 +757,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return Status::OK(); } -Status Service::ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); - return PickParallelResponse(parallel_result, result); -} - Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result) { ExecuteGraphParallelRequest parallel_arg; @@ -1121,80 +789,6 @@ Status Service::PickParallelResponse( return Status::OK(); } -Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { - VLOG(1) << "running execute request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - // If we received multiple device handles, we must partition the module. - if (arg->execution_options().device_handles_size() > 1) { - return ExecuteOneToN(arg, result); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - // Since we care only about the shapes of the arguments, it is sufficient to - // use the arguments of replica 0. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "Execute created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), - execute_backend_->default_stream_executor(), - result->mutable_profile())); - - if (executable->dumping()) { - executable->session_module()->set_execution_platform( - execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), - execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - } - - TF_ASSIGN_OR_RETURN( - *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), replicated_arguments, execute_backend_.get(), - "result of " + user_computation->name(), result->mutable_profile())); - - if (executable->dumping()) { - TF_ASSIGN_OR_RETURN( - const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(result->output(), 0)); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - TF_RETURN_IF_ERROR(executable->DumpSessionModule()); - } - - VLOG(1) << "successfully completed 'execute' request"; - return Status::OK(); -} - StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -1227,13 +821,15 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, device_allocator)); - // Check that on-host and on-device shapes are consistent. - TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, backend->compiler()->RunBackend( std::move(module), executor, device_allocator)); + if (!execution_directory_path.empty()) { + executable->set_hlo_snapshot(std::move(hlo_snapshot)); + } + return std::move(executable); } @@ -1271,12 +867,14 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, execute_backend_->default_stream_executor(), /*device_allocator=*/nullptr)); + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream( + execute_backend_->default_stream_executor())); if (executable->dumping_snapshot()) { executable->hlo_snapshot()->set_execution_platform( execute_backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), - execute_backend_->default_stream_executor(), + replicated_arguments.front(), stream.get(), execute_backend_->transfer_manager(), executable->hlo_snapshot())); } @@ -1290,9 +888,9 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_ASSIGN_OR_RETURN( const ShapedBuffer* result_buffer, allocation_tracker_.ResolveForReplica(result->output(), 0)); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->hlo_snapshot())); + TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), + execute_backend_->transfer_manager(), + executable->hlo_snapshot())); TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); } @@ -1300,86 +898,6 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, return Status::OK(); } -Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_RET_CHECK(!replicas.empty()); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - ExecutionProfile profile; - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable( - versioned_handle, std::move(module_config), execute_backend_.get(), - execute_backend_->default_stream_executor(), &profile)); - - // Set up streams. - std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, - execute_backend_->BorrowStream(executor)); - streams.push_back(std::move(stream)); - } - - std::vector result_buffers; - for (size_t i = 0; i < streams.size(); ++i) { - const auto& stream = streams[i]; - ExecutableRunOptions options; - options.set_stream(stream.get()); - options.set_allocator(execute_backend_->memory_allocator()); - options.set_intra_op_thread_pool( - execute_backend_->eigen_intra_op_thread_pool_device()); - - ServiceExecutableRunOptions service_options( - options, execute_backend_->StreamBorrower()); - - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer this_result_buffer, - executable->ExecuteAsyncOnStream( - &service_options, replicated_arguments[i])); - - result_buffers.emplace_back(std::move(this_result_buffer)); - } - - TF_ASSIGN_OR_RETURN( - GlobalDataHandle output, - allocation_tracker_.RegisterReplicatedBuffers( - std::move(result_buffers), "result of " + user_computation->name())); - - *result->mutable_execution() = execution_tracker_.Register( - execute_backend_.get(), std::move(streams), profile, output); - streams.clear(); - - VLOG(1) << "successfully completed 'execute-async' request"; - return Status::OK(); -} - Status Service::WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { TF_ASSIGN_OR_RETURN(const auto execution, @@ -1410,14 +928,13 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, return_shape = &shaped_buffer->on_host_shape(); } - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - execute_backend_->stream_executor(shaped_buffer->device_ordinal())); + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( + shaped_buffer->device_ordinal())); TF_ASSIGN_OR_RETURN( std::unique_ptr result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( - executor, *shaped_buffer)); + stream.get(), *shaped_buffer)); if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal->shape())) { @@ -1467,9 +984,10 @@ Status Service::TransferToServer(const TransferToServerRequest* arg, execute_backend_->transfer_manager()->AllocateScopedShapedBuffer( shape, execute_backend_->memory_allocator(), executor->device_ordinal())); + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, shaped_buffer)); + stream.get(), *literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), @@ -1546,117 +1064,6 @@ Status Service::ResetDevice(const ResetDeviceRequest* arg, return execute_backend_->ResetDevices(); } -Status Service::IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->num_parameters())); - - result->set_is_constant(is_constant); - return Status::OK(); -} - -Status Service::ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->parameters_size())); - if (!is_constant) { - StatusOr op_request_status = - user_computation->LookUpRequestForErrorReporting(arg->operand()); - string op_request_string = ""; - if (op_request_status.ok()) { - op_request_string = op_request_status.ValueOrDie()->ShortDebugString(); - } - return InvalidArgument( - "Operand to ComputeConstant depends on a parameter.\n\n" - " op requested for constant evaluation: %s\n\n" - "This is an internal error that typically happens when the XLA user " - "(e.g. TensorFlow) is attempting to determine a value that must be a " - "compile-time constant (e.g. an array dimension) but it is not capable " - "of being evaluated at XLA compile time.\n\n" - "Please file a usability bug with the framework being used (e.g. " - "TensorFlow).", - op_request_string.c_str()); - } - - // We can't use ComputeProgramShape because it checks that all parameter - // instructions are present and contiguous. Instead construct ProgramShape - // directly. - ProgramShape program_shape; - TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(), - user_computation->GetShape(arg->operand())); - - TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); - - ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions(); - execution_options.mutable_debug_options()->set_xla_enable_fast_math(false); - execution_options.mutable_debug_options() - ->set_xla_eliminate_hlo_implicit_broadcast(true); - *execution_options.mutable_shape_with_output_layout() = - program_shape.result(); - - Shape shape_with_output_layout(program_shape.result()); - if (arg->has_output_layout()) { - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), execution_options.shape_with_output_layout())); - *execution_options.mutable_shape_with_output_layout()->mutable_layout() = - arg->output_layout(); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options, - user_computation)); - - // Exclude dead parameter instructions for the purpose of computing constants. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - false)); - - std::vector> parameters(arg->parameters_size()); - for (int64 i = 0; i < arg->parameters_size(); ++i) { - TF_ASSIGN_OR_RETURN(parameters[i], - Literal::CreateFromProto(arg->parameters(i))); - } - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - auto result_literal, - evaluator.Evaluate>(*module, parameters)); - - // Since the shape_with_output_layout option in ExecutionOption is - // non-effective to the Evaluator results, explicit relayout here. - // - // TODO(b/77824332): Make HloEvaluator take care of the re-layout. - if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); - } - *result->mutable_literal() = result_literal->ToProto(); - - return Status::OK(); -} - Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { if (!arg->has_computation()) { @@ -1706,60 +1113,6 @@ Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { return Status::OK(); } -Status Service::GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( - versioned_handle.version)); - *result->mutable_program_shape() = *program_shape; - return Status::OK(); -} - -Status Service::GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_shape(), - computation->GetShape(arg->operand())); - return Status::OK(); -} - -Status Service::GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - HloModuleConfig config; - config.set_debug_options(arg->debug_options()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, config)); - - hlo_graph_dumper::MaybeDumpHloModule(*module, - "computation statistics subject"); - - // Run HLO analysis to get the computation statistics. - HloCostAnalysis analysis( - execute_backend_->compiler()->ShapeSizeBytesFunction()); - - TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); - - ComputationStats stats; - stats.set_flop_count(analysis.flop_count()); - stats.set_transcendental_count(analysis.transcendental_count()); - *result->mutable_stats() = stats; - return Status::OK(); -} - Status Service::GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { if (!arg->has_computation()) { @@ -1790,262 +1143,6 @@ Status Service::GetComputationGraphStats( return Status::OK(); } -template -Status Service::AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation)); - return Status::OK(); -} - -Status Service::Op(const OpRequest* arg, OpResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - StatusOr handle_status; - - switch (arg->op_case()) { - case OpRequest::kBatchNormTrainingRequest: - handle_status = computation->AddBatchNormTrainingInstruction( - arg->batch_norm_training_request()); - break; - case OpRequest::kBatchNormInferenceRequest: - handle_status = computation->AddBatchNormInferenceInstruction( - arg->batch_norm_inference_request()); - break; - case OpRequest::kBatchNormGradRequest: - handle_status = computation->AddBatchNormGradInstruction( - arg->batch_norm_grad_request()); - break; - case OpRequest::kBinaryOpRequest: - handle_status = - computation->AddBinaryInstruction(arg->binary_op_request()); - break; - case OpRequest::kBroadcastRequest: - handle_status = - computation->AddBroadcastInstruction(arg->broadcast_request()); - break; - case OpRequest::kCallRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->call_request().to_apply())); - handle_status = - computation->AddCallInstruction(arg->call_request(), *to_apply); - break; - } - case OpRequest::kConcatenateRequest: - handle_status = - computation->AddConcatenateInstruction(arg->concatenate_request()); - break; - case OpRequest::kConditionalRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * true_computation, - computation_tracker_.Resolve( - arg->conditional_request().true_computation())); - TF_ASSIGN_OR_RETURN(UserComputation * false_computation, - computation_tracker_.Resolve( - arg->conditional_request().false_computation())); - handle_status = computation->AddConditionalInstruction( - arg->conditional_request(), *true_computation, *false_computation); - break; - } - case OpRequest::kConstantRequest: - handle_status = - computation->AddConstantInstruction(arg->constant_request()); - break; - case OpRequest::kConvertRequest: - handle_status = - computation->AddConvertInstruction(arg->convert_request()); - break; - case OpRequest::kBitcastConvertRequest: - handle_status = computation->AddBitcastConvertInstruction( - arg->bitcast_convert_request()); - break; - case OpRequest::kConvolveRequest: - handle_status = - computation->AddConvolveInstruction(arg->convolve_request()); - break; - case OpRequest::kCrossReplicaSumRequest: - handle_status = computation->AddCrossReplicaSumInstruction( - arg->cross_replica_sum_request()); - break; - case OpRequest::kCustomCallRequest: - handle_status = - computation->AddCustomCallInstruction(arg->custom_call_request()); - break; - case OpRequest::kDotRequest: - handle_status = computation->AddDotInstruction(arg->dot_request()); - break; - case OpRequest::kDynamicSliceRequest: - handle_status = - computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); - break; - case OpRequest::kDynamicUpdateSliceRequest: - handle_status = computation->AddDynamicUpdateSliceInstruction( - arg->dynamic_update_slice_request()); - break; - case OpRequest::kFftRequest: - handle_status = computation->AddFftInstruction(arg->fft_request()); - break; - case OpRequest::kGatherRequest: - handle_status = computation->AddGatherInstruction(arg->gather_request()); - break; - case OpRequest::kGetTupleElementRequest: - handle_status = computation->AddGetTupleElementInstruction( - arg->get_tuple_element_request()); - break; - case OpRequest::kInfeedRequest: - handle_status = computation->AddInfeedInstruction(arg->infeed_request()); - break; - case OpRequest::kOutfeedRequest: - handle_status = - computation->AddOutfeedInstruction(arg->outfeed_request()); - break; - case OpRequest::kHostComputeRequest: - handle_status = - computation->AddHostComputeInstruction(arg->host_compute_request()); - break; - case OpRequest::kMapRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->map_request().to_apply())); - handle_status = - computation->AddMapInstruction(arg->map_request(), *to_apply); - break; - } - case OpRequest::kPadRequest: - handle_status = computation->AddPadInstruction(arg->pad_request()); - break; - case OpRequest::kParameterRequest: - handle_status = - computation->AddParameterInstruction(arg->parameter_request()); - break; - case OpRequest::kReduceRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->reduce_request().to_apply())); - handle_status = - computation->AddReduceInstruction(arg->reduce_request(), *to_apply); - break; - } - case OpRequest::kReducePrecisionRequest: { - handle_status = computation->AddReducePrecisionInstruction( - arg->reduce_precision_request()); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * to_apply, - computation_tracker_.Resolve( - arg->reduce_window_request().to_apply())); - handle_status = computation->AddReduceWindowInstruction( - arg->reduce_window_request(), *to_apply); - break; - } - case OpRequest::kReshapeRequest: - handle_status = - computation->AddReshapeInstruction(arg->reshape_request()); - break; - case OpRequest::kReverseRequest: - handle_status = - computation->AddReverseInstruction(arg->reverse_request()); - break; - case OpRequest::kRngRequest: - handle_status = computation->AddRngInstruction(arg->rng_request()); - break; - case OpRequest::kSelectAndScatterRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * select, - computation_tracker_.Resolve( - arg->select_and_scatter_request().select())); - TF_ASSIGN_OR_RETURN(UserComputation * scatter, - computation_tracker_.Resolve( - arg->select_and_scatter_request().scatter())); - handle_status = computation->AddSelectAndScatterInstruction( - arg->select_and_scatter_request(), *select, *scatter); - break; - } - case OpRequest::kSliceRequest: - handle_status = computation->AddSliceInstruction(arg->slice_request()); - break; - case OpRequest::kTernaryOpRequest: - handle_status = - computation->AddTernaryInstruction(arg->ternary_op_request()); - break; - case OpRequest::kTraceRequest: - return computation->AddTraceInstruction(arg->trace_request()); - case OpRequest::kTransposeRequest: - handle_status = - computation->AddTransposeInstruction(arg->transpose_request()); - break; - case OpRequest::kUnaryOpRequest: - handle_status = computation->AddUnaryInstruction(arg->unary_op_request()); - break; - case OpRequest::kVariadicOpRequest: - handle_status = - computation->AddVariadicInstruction(arg->variadic_op_request()); - break; - case OpRequest::kWhileRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * condition, - computation_tracker_.Resolve(arg->while_request().condition())); - TF_ASSIGN_OR_RETURN( - UserComputation * body, - computation_tracker_.Resolve(arg->while_request().body())); - handle_status = computation->AddWhileInstruction(arg->while_request(), - *condition, *body); - break; - } - case OpRequest::kSendRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterSend(arg->send_request().channel_handle())); - // Send does not return a value, but we need a handle to be able to - // set OpMetadata and OpSharding (device assignment). - handle_status = computation->AddSendInstruction(arg->send_request()); - break; - } - case OpRequest::kRecvRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterRecv(arg->recv_request().channel_handle())); - handle_status = computation->AddRecvInstruction(arg->recv_request()); - break; - } - case OpRequest::OP_NOT_SET: - return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); - default: - return InvalidArgument("Unsupported operation in XLA service"); - } - TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); - - // We set the debug metadata here, because we slice off part of the OpRequest - // proto in the above switch statement. - TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status); - TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata())); - if (arg->has_sharding()) { - TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); - } - return Status::OK(); -} - -Status Service::SnapshotComputation(const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.SnapshotComputation(arg->computation())); - - result->set_allocated_module(module.release()); - - return Status::OK(); -} - -Status Service::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - TF_ASSIGN_OR_RETURN(*result->mutable_computation(), - computation_tracker_.LoadSessionModule(arg->module())); - return Status::OK(); -} - DeviceHandle Service::SingleComputationDeviceHandle() const { DeviceHandle device_handle; device_handle.set_handle(0); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 81fbd41957887aec763e1cfe165ad0d1d2ac2269..47d196fb2aaee897ce1fd3745129af10bf5b2d2d 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -26,17 +26,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" -#include "tensorflow/compiler/xla/service/compilation_cache.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -83,11 +78,6 @@ class Service : public ServiceInterface { static StatusOr> NewService( const ServiceOptions& options); - // Creates a new computation with the given name. - // A unique ComputationHandle is returned. - Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - // Unregisters a previously-allocated global handle. // // If the handle given is not currently allocated, a NOT_FOUND status is @@ -100,35 +90,15 @@ class Service : public ServiceInterface { Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - // Executes a computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; - // Executes a computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. Status ExecuteGraph(const ExecuteGraphRequest* arg, ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; - - // Executes one or more computations in parallel with the provided global data - // passed as immutable arguments. Returns global data output for each - // computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) override; @@ -143,16 +113,6 @@ class Service : public ServiceInterface { Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) override; - // Asynchronously executes a computation with provided arguments. Invokes - // the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - // - // (Note: The corresponding function in xla::Client was removed as part of - // b/64116060, in an attempt to simplify our API. We're keeping this around - // for now in case we want to expose this to clients in a different way.) - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - // Waits until the specified execution is complete and returns the result. // Calling this API multiple times with the same execution handle returns the // method with an error since the execution handle is destroyed after the @@ -190,13 +150,6 @@ class Service : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - // Tests if an expression is a compile-time constant. - Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; - - // Computes the value of a constant expression. - Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) override; @@ -205,54 +158,15 @@ class Service : public ServiceInterface { Status GetShape(const GetShapeRequest* arg, GetShapeResponse* result) override; - // Returns the program shape of the computation associated with the given - // handle. - Status GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ///// - // Computation-oriented methods. - - // Enqueues an Op on the computation. - Status Op(const OpRequest* arg, OpResponse* result) override; - - // Retrieves the inferred shape for a value within a computation. - Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - // Retrieves the statistics of a computation. - Status GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - // Retrieves the statistics of a computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) override; - // Snapshots the current state of a computation handle into a serializable - // protocol buffer form, so it can be loaded via - // LoadComputationSnapshot. - Status SnapshotComputation(const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - // Loads a computation from a serialized protocol buffer created via - // SnapshotComputation. - Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; - // Creates a unique channel handle that can be used for Send/Recv // instructions. Status CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) override; - // Returns the ComputationTracker of the current service instance. - // Only used in unit tests to access user computations from client. - const ComputationTracker& computation_tracker() { - return computation_tracker_; - } - // Returns the backend used to execute computations. const Backend& backend() const { return *execute_backend_; } Backend* mutable_backend() { return execute_backend_.get(); } @@ -263,8 +177,7 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation* user_computation = nullptr); + const ExecutionOptions& execution_options); // Picks a parallel response and fills the result. Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, @@ -280,9 +193,6 @@ class Service : public ServiceInterface { const ExecutionOptions& execution_options, tensorflow::gtl::ArraySlice arguments); - // Assert that host- and device-shapes are in a consistent state. - Status ValidateEntryComputationLayout(HloModule* module); - protected: friend class LocalExecutable; @@ -305,23 +215,13 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation* user_computation = nullptr); + const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. // // If device_allocator is not null, the compiler may use it to allocate temp // buffers, which the compiler is responsible for freeing. The allocator // given here need not match the allocator used when running the executable. - StatusOr> BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, - DeviceMemoryAllocator* device_allocator = nullptr); - - // Builds an Executable for the given HLO module proto. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -330,26 +230,12 @@ class Service : public ServiceInterface { // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. - StatusOr>> BuildExecutables( - std::vector versioned_handles, - std::vector> module_configs, - Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator); StatusOr>> BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator); - // Similar to BuildExecutable, but look in the compilation cache for the - // executable first. If the executable is not in the cache, it is built and - // inserted into the cache. - StatusOr> BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator = nullptr); - // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is // returned. If the parameter "profile" is not null, it points to an @@ -372,24 +258,16 @@ class Service : public ServiceInterface { tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile); - // Convenience function for adding a function to a user computation. - template - Status AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder); - // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which // will be the result of this computation. - Status ExecuteOneToN(const ExecuteRequest* arg, ExecuteResponse* result); Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); - // Convenience function which checks whether the given shape_with_layout + // Convenience function which checks whether the given client_shape // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. - Status ValidateResultShapeWithLayout(const Shape& shape_with_layout, - const Shape& result_shape) const; + Status ValidateResultShape(const Shape& client_shape, + const Shape& result_shape) const; // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that @@ -405,9 +283,6 @@ class Service : public ServiceInterface { ServiceOptions options_; - // Tracks computations built via the API. - ComputationTracker computation_tracker_; - // Tracks channels created via the API. ChannelTracker channel_tracker_; @@ -417,9 +292,6 @@ class Service : public ServiceInterface { // Tracks asynchronously launched executions via the API. ExecutionTracker execution_tracker_; - // Cache containing previously built Executables. - CompilationCache compilation_cache_; - // Backend to compile and execute computations on. std::unique_ptr execute_backend_; diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto deleted file mode 100644 index bb8d1cd2a106ea3e5bb61eee5052bd60c38cd0e2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/session.proto +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This proto file defines messages which store the state of XLA -// computations within the XLA service. A computation is stored as a record -// of the operation requests used to build it. -syntax = "proto3"; - -import "tensorflow/compiler/xla/xla_data.proto"; - -package xla; - -// Describes a single operation request. -message OperationRequest { - ComputationDataHandle output_handle = 1; - Shape output_shape = 2; - - // For operations which call embedded computations such as "Map", these are - // the version(s) that the embedded computation should be called at. A version - // value of a computation is the ComputationDataHandle of the root of the - // computation at the point in time. - // - // "Call", "Map", "Reduce", and "ReduceWindow" operations take a single - // embedded computation so this field will have a single value for those - // operations. - // - // "While" operation takes two; index 0 is the "condition" version and index 1 - // is the "body" version. - repeated int64 embedded_computation_versions = 3; - - // The actual request, which in itself is a tagged union of all possible - // operation request types. - OpRequest request = 4; -} - -// Describes a sequence of operation requests which define an XLA -// computation. -message SessionComputation { - string name = 1; - - // The ComputationHandle used to refer to this computation in the XLA - // service. - ComputationHandle computation_handle = 2; - - // Map from ComputationDataHandle value to operation request. The highest - // ComputationDataHandle value corresponds to the root of the computation. - map requests = 3; -} - -// Describes a group of SessionComputations with an "entry point" computation -// that may refer to the other non-entry (AKA embedded) computations. -// -// This message is used to serialize a computation that has been built via the -// XLA service API, along with its dependencies, for purposes such as -// analysis/replay/file-storage. -message SessionModule { - // The entry computation, which was requested for serialization. This may have - // referred to embedded computations, which are reflected below. - SessionComputation entry = 1; - - // Embedded computations that are transitively referred to by the entry - // computation. - repeated SessionComputation embedded_computations = 2; - - // The arguments passed to the computation. - repeated LiteralProto arguments = 3; - - // The result of the computation. - LiteralProto result = 4; - - // The name of the platform used to run the computation. - string execution_platform = 5; -} diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 3500978bdd808f0c7684d14a05636d90105aa594..70edf7883f91a0112a9576b639eb0e75b7f471e4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -44,147 +44,18 @@ namespace xla { namespace { -// Return the UnaryOperation proto enum value associated with the given HLO -// opcode. -UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAbs: - return UNOP_ABS; - case HloOpcode::kCeil: - return UNOP_CEIL; - case HloOpcode::kClz: - return UNOP_CLZ; - case HloOpcode::kCos: - return UNOP_COS; - case HloOpcode::kExp: - return UNOP_EXP; - case HloOpcode::kExpm1: - return UNOP_EXPM1; - case HloOpcode::kFloor: - return UNOP_FLOOR; - case HloOpcode::kImag: - return UNOP_IMAG; - case HloOpcode::kIsFinite: - return UNOP_IS_FINITE; - case HloOpcode::kLog: - return UNOP_LOG; - case HloOpcode::kLog1p: - return UNOP_LOG1P; - case HloOpcode::kNot: - return UNOP_NOT; - case HloOpcode::kNegate: - return UNOP_NEGATE; - case HloOpcode::kReal: - return UNOP_REAL; - case HloOpcode::kRoundNearestAfz: - return UNOP_ROUND_NEAREST_AFZ; - case HloOpcode::kSign: - return UNOP_SIGN; - case HloOpcode::kSin: - return UNOP_SIN; - case HloOpcode::kSort: - return UNOP_SORT; - case HloOpcode::kTanh: - return UNOP_TANH; - default: - LOG(FATAL) << "Unhandled opcode for conversion to unary operation: " - << opcode; - } -} - -// Return the BinaryOperation proto enum value associated with the given HLO -// opcode. -BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAtan2: - return BINOP_ATAN2; - case HloOpcode::kComplex: - return BINOP_COMPLEX; - case HloOpcode::kMultiply: - return BINOP_MUL; - case HloOpcode::kAdd: - return BINOP_ADD; - case HloOpcode::kSubtract: - return BINOP_SUB; - case HloOpcode::kDivide: - return BINOP_DIV; - case HloOpcode::kEq: - return BINOP_EQ; - case HloOpcode::kGe: - return BINOP_GE; - case HloOpcode::kGt: - return BINOP_GT; - case HloOpcode::kLe: - return BINOP_LE; - case HloOpcode::kLt: - return BINOP_LT; - case HloOpcode::kNe: - return BINOP_NE; - case HloOpcode::kMaximum: - return BINOP_MAX; - case HloOpcode::kMinimum: - return BINOP_MIN; - case HloOpcode::kPower: - return BINOP_POW; - case HloOpcode::kRemainder: - return BINOP_REM; - case HloOpcode::kOr: - return BINOP_OR; - case HloOpcode::kAnd: - return BINOP_AND; - case HloOpcode::kShiftLeft: - return BINOP_SHIFT_LEFT; - case HloOpcode::kShiftRightArithmetic: - return BINOP_SHIFT_RIGHT_ARITHMETIC; - case HloOpcode::kShiftRightLogical: - return BINOP_SHIFT_RIGHT_LOGICAL; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the TernaryOperation proto enum value associated with the given HLO -// opcode. -TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kClamp: - return TRIOP_CLAMP; - case HloOpcode::kSelect: - return TRIOP_SELECT; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the VariadicOperation proto enum value associated with the given HLO -// opcode. -VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kTuple: - return VAROP_TUPLE; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - // Returns true if no element is present in slice more than once. bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -Status ExpectNotTupleOrOpaque(const Shape& shape, - tensorflow::StringPiece op_type) { - if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Expected non-tuple argument for %s, but got %s.", - std::string(op_type).c_str(), - ShapeUtil::HumanString(shape).c_str()); - } else if (ShapeUtil::IsOpaque(shape)) { - return InvalidArgument("Expected non-opaque argument for %s, but got %s.", +Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument("Expected array argument for %s, but got %s.", std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); - } else { - return Status::OK(); } + return Status::OK(); } Status VerifyReducerShape(const ProgramShape& reducer_shape, @@ -198,11 +69,11 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, } const Shape& accumulator_shape = reducer_shape.result(); - if (ShapeUtil::Rank(accumulator_shape) != 0) { + if (!ShapeUtil::IsArray(accumulator_shape) || + ShapeUtil::Rank(accumulator_shape) != 0) { return InvalidArgument( - "Reduction function must have rank 0 (rank %lld reduction function " - "given).", - ShapeUtil::Rank(accumulator_shape)); + "Reduction function must produce a scalar but has shape: %s", + ShapeUtil::HumanString(accumulator_shape).c_str()); } // Check that the accumulator can be passed in as the first argument. @@ -316,88 +187,84 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. - if (opcode == HloOpcode::kCopy) { + // A domain shape is the same as the input one. + if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) { return shape; } - return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape); -} + TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation")); -/* static */ StatusOr ShapeInference::InferUnaryOpShape( - UnaryOperation operation, const Shape& arg) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation")); - - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg)); - switch (operation) { - case UNOP_FLOOR: - case UNOP_CEIL: - if (!ShapeUtil::ElementIsFloating(arg)) { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + switch (opcode) { + case HloOpcode::kFloor: + case HloOpcode::kCeil: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( "Expected element type in shape to be floating for floor/ceil " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_COS: - case UNOP_SIN: - case UNOP_EXP: - case UNOP_EXPM1: - case UNOP_LOG: - case UNOP_LOG1P: - case UNOP_TANH: - if (!ShapeUtil::ElementIsFloating(arg) && - !ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kCos: + case HloOpcode::kSin: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kTanh: + if (!ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( "Expected element type in shape to be floating or complex for " "sin/cos/exp/log/tanh operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_REAL: - case UNOP_IMAG: - if (!ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kReal: + case HloOpcode::kImag: + if (!ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( "Expected element type in shape to be complex for real/imag " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, F32); - case UNOP_ABS: - if (ShapeUtil::ElementIsComplex(arg)) { + return ShapeUtil::ChangeElementType(shape, F32); + case HloOpcode::kAbs: + if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( - arg, primitive_util::ComplexComponentType(arg.element_type())); + shape, primitive_util::ComplexComponentType(shape.element_type())); } - return arg; - case UNOP_CLZ: - case UNOP_NEGATE: - case UNOP_ROUND_NEAREST_AFZ: - case UNOP_SIGN: - case UNOP_SORT: - return arg; - - case UNOP_NOT: - if (arg.element_type() != PRED && - !primitive_util::IsIntegralType(arg.element_type())) { + return shape; + case HloOpcode::kClz: + case HloOpcode::kNegate: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSign: + return shape; + + case HloOpcode::kNot: + if (shape.element_type() != PRED && + !primitive_util::IsIntegralType(shape.element_type())) { return InvalidArgument( "Expected pred or an integral element type in argument to Not " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; + return shape; - case UNOP_IS_FINITE: - if (!ShapeUtil::ElementIsFloating(arg)) { + case HloOpcode::kIsFinite: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "Expected element type in shape to be floating point for IsFinite " + "Expected element type in shape to be floating " + "point for IsFinite " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, PRED); + return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - UnaryOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -414,8 +281,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const Shape* arg_shape = nullptr; PrimitiveType element_type = PRIMITIVE_TYPE_INVALID; for (const Shape* shape : arg_shapes) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*shape, "operand of concatenation")); + TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation")); if (!arg_shape) { arg_shape = shape; element_type = arg_shape->element_type(); @@ -462,6 +328,17 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } +/* static */ StatusOr ShapeInference::InferAfterAllShape( + tensorflow::gtl::ArraySlice arg_shapes) { + for (const Shape* arg_shape : arg_shapes) { + if (arg_shape->element_type() != TOKEN) { + return InvalidArgument( + "Operands of token instructions must be TOKEN types."); + } + } + return ShapeUtil::MakeTokenShape(); +} + /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); @@ -472,12 +349,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } - if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + if (!ShapeUtil::IsArray(operand_shape) || + !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. return InvalidArgument( - "Convert does not allow tuples, so cannot convert from %s to %s.", + "Convert does not allow non-arrays, so cannot convert from %s to %s.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -494,7 +372,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } - if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + if (!ShapeUtil::IsArray(operand_shape) || + !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. @@ -541,7 +420,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { - if (ShapeUtil::IsTuple(operand_shape)) { + if (!ShapeUtil::IsArray(operand_shape)) { return InvalidArgument( "Pad operation does not support tuple-shape operands."); } @@ -680,8 +559,8 @@ Status ValidateDotDimensionNumbers( /* static */ StatusOr ShapeInference::InferDotOpShape( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); auto fail = [lhs, rhs](const string& addendum) -> Status { string message = tensorflow::strings::Printf( @@ -767,8 +646,9 @@ Status ValidateDotDimensionNumbers( } /* static */ StatusOr -ShapeInference::InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs) { +ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, + const Shape& lhs, + const Shape& rhs) { TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); // The shapes have to be compatible. That is, if some dimension d has a @@ -786,7 +666,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", - BinaryOperation_Name(operation).c_str(), + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -796,8 +676,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions) { if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring @@ -898,18 +777,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Binary op %s with different element types: %s and %s.", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -942,10 +818,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. - TF_ASSIGN_OR_RETURN( - Shape indim_broadcast_shape, - InferInDimBroadcastShape(operation, smaller_shape, larger_shape, - broadcast_dimensions)); + TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, + InferInDimBroadcastShape(smaller_shape, larger_shape, + broadcast_dimensions)); return InferDegenerateDimensionBroadcastShape( operation, indim_broadcast_shape, larger_shape); @@ -954,51 +829,44 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(), - rhs->shape(), /*broadcast_dimensions=*/{}); + return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(), + /*broadcast_dimensions=*/{}); } /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs, - broadcast_dimensions); -} - -/* static */ StatusOr ShapeInference::InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { VLOG(2) << tensorflow::strings::Printf( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), + HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str(), Join(broadcast_dimensions, ", ").c_str()); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - lhs, tensorflow::strings::StrCat("lhs of binary operation ", - BinaryOperation_Name(operation)))); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - rhs, tensorflow::strings::StrCat("rhs of binary operation ", - BinaryOperation_Name(operation)))); - switch (operation) { - case BINOP_MAX: - case BINOP_MIN: - case BINOP_SUB: - case BINOP_ADD: - case BINOP_ATAN2: - case BINOP_POW: - case BINOP_DIV: - case BINOP_REM: - case BINOP_MUL: - case BINOP_SHIFT_LEFT: - case BINOP_SHIFT_RIGHT_ARITHMETIC: - case BINOP_SHIFT_RIGHT_LOGICAL: - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + TF_RETURN_IF_ERROR( + ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ", + HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR( + ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ", + HloOpcodeString(opcode)))); + switch (opcode) { + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kSubtract: + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kPower: + case HloOpcode::kDivide: + case HloOpcode::kRemainder: + case HloOpcode::kMultiply: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_COMPLEX: { + case HloOpcode::kComplex: { if (!ShapeUtil::ElementIsFloating(lhs)) { return InvalidArgument( "Expected element type in shape to be floating for complex compose " @@ -1006,7 +874,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(lhs.element_type()).c_str()); } TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); @@ -1014,8 +882,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return Unimplemented("Complex component type is not implemented."); } } - case BINOP_AND: - case BINOP_OR: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kXor: if (lhs.element_type() != PRED && !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( @@ -1023,24 +892,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "got %s.", PrimitiveType_Name(lhs.element_type()).c_str()); } - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_EQ: - case BINOP_GE: - case BINOP_GT: - case BINOP_LE: - case BINOP_LT: - case BINOP_NE: { + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: { TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); return ShapeUtil::ChangeElementType(shape, PRED); } default: return Unimplemented( "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", - BinaryOperation_Name(operation).c_str(), - lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str()); + HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), + rhs.ShortDebugString().c_str()); } } @@ -1052,23 +921,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { - return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs); -} - -/* static */ StatusOr ShapeInference::InferTernaryOpShape( - TernaryOperation operation, const Shape& lhs, const Shape& rhs, - const Shape& ehs) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs)); - switch (operation) { - case TRIOP_CLAMP: + switch (opcode) { + case HloOpcode::kClamp: return InferClampShape(lhs, rhs, ehs); - case TRIOP_SELECT: + case HloOpcode::kSelect: return InferSelectShape(lhs, rhs, ehs); + case HloOpcode::kTupleSelect: + return InferTupleSelectShape(lhs, rhs, ehs); default: return InvalidArgument("Unknown operation %s.", - TernaryOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -1076,6 +941,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operands) { std::vector operand_shapes; + operand_shapes.reserve(operands.size()); for (const HloInstruction* operand : operands) { operand_shapes.push_back(&operand->shape()); } @@ -1085,27 +951,30 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes) { - return InferVariadicOpShape(OpcodeToVariadicOperation(opcode), - operand_shapes); -} - -/* static */ StatusOr ShapeInference::InferVariadicOpShape( - VariadicOperation operation, - tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } - switch (operation) { - case VAROP_TUPLE: { + switch (opcode) { + case HloOpcode::kTuple: { Shape result = ShapeUtil::MakeTupleShape({}); + result.mutable_tuple_shapes()->Reserve(operand_shapes.size()); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); } return result; } + case HloOpcode::kSort: { + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } else if (operand_shapes.size() == 2) { + return ShapeUtil::MakeTupleShape( + {*operand_shapes[0], *operand_shapes[1]}); + } + return InvalidArgument("Unexpected number of operands for sort"); + } default: return InvalidArgument("Unknown operation %s.", - VariadicOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -1120,15 +989,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // All arguments must have the same shape. const Shape* arg_shape = arg_shapes[0]; for (size_t i = 1; i < arg_shapes.size(); ++i) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map")); + TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map")); if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { continue; } - if (!ShapeUtil::IsTuple(*arg_shapes[i]) && - !ShapeUtil::IsTuple(*arg_shape) && - ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], + if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { if (ShapeUtil::IsScalar(*arg_shapes[i])) { continue; @@ -1211,11 +1077,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, int64 feature_index) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - offset_shape, "offset input of batch norm training")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - scale_shape, "scale input of batch norm training")); + ExpectArray(operand_shape, "operand of batch norm training")); + TF_RETURN_IF_ERROR( + ExpectArray(offset_shape, "offset input of batch norm training")); + TF_RETURN_IF_ERROR( + ExpectArray(scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == Status::OK()); @@ -1317,11 +1183,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& offset_shape, const Shape& mean_shape, const Shape& variance_shape, int64 feature_index) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - offset_shape, "offset input of batch norm inference")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - scale_shape, "scale input of batch norm inference")); + ExpectArray(operand_shape, "operand of batch norm inference")); + TF_RETURN_IF_ERROR( + ExpectArray(offset_shape, "offset input of batch norm inference")); + TF_RETURN_IF_ERROR( + ExpectArray(scale_shape, "scale input of batch norm inference")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == Status::OK()); @@ -1464,16 +1330,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& mean_shape, const Shape& var_shape, const Shape& output_grad_shape, int64 feature_index) { + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad")); + ExpectArray(scale_shape, "scale input of batch norm grad")); + TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad")); + TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - output_grad_shape, "output_grad input of batch norm grad")); + ExpectArray(output_grad_shape, "output_grad input of batch norm grad")); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); @@ -1622,8 +1485,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dnums) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( @@ -1858,7 +1721,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum")); + ExpectArray(*operand_shape, "operand of cross replica sum")); } if (operand_shapes.size() == 1) { return *operand_shapes[0]; @@ -1900,8 +1763,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window")); + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape, operand_shape.element_type())); return InferWindowOutputShape(operand_shape, window, @@ -1914,7 +1776,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Window& window, const Shape& source_shape, const Shape& init_value_shape, const ProgramShape& scatter_shape) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter")); + ExpectArray(operand_shape, "operand of select-and-scatter")); // Check if the select function has a proper shape of (T,T) -> PRED. if (select_shape.parameters_size() != 2) { @@ -1979,7 +1841,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( Join(starts, ",").c_str(), Join(limits, ",").c_str(), Join(strides, ",").c_str()); }; - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice")); + TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s starts={%s} limits={%s}", ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), @@ -2038,10 +1900,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, tensorflow::gtl::ArraySlice slice_sizes) { + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape, - "start indices of dynamic slice")); + ExpectArray(start_indices_shape, "start indices of dynamic slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", @@ -2099,11 +1960,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& operand_shape, const Shape& update_shape, const Shape& start_indices_shape) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice")); + ExpectArray(operand_shape, "operand of dynamic update slice")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - start_indices_shape, "start indices of dynamic update slice")); + ExpectArray(update_shape, "update of dynamic update slice")); + TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, + "start indices of dynamic update slice")); VLOG(2) << tensorflow::strings::Printf( "updating slice of shape %s at dynamic start_indices %s with update " @@ -2171,8 +2032,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /*static */ StatusOr ShapeInference::InferReverseShape( const Shape& operand_shape, tensorflow::gtl::ArraySlice dimensions) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of reverse")); + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { return InvalidArgument("a dimension number is duplicated in reverse"); } @@ -2302,7 +2162,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); for (int64 size : broadcast_sizes) { if (size < 0) { return InvalidArgument("Broadcast with negative dimension size %lld.", @@ -2321,7 +2181,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferReshapeShape( const Shape& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), new_sizes); @@ -2353,7 +2213,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTransposeShape( const Shape& operand, tensorflow::gtl::ArraySlice dimensions) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); std::vector indices(ShapeUtil::Rank(operand)); std::iota(indices.begin(), indices.end(), 0); @@ -2374,9 +2234,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // "degenerate" cases, as with binary elementwise ops. /* static */ StatusOr ShapeInference::InferClampShape( const Shape& min, const Shape& operand, const Shape& max) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); + TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand")); + TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("Clamp with different operand types: %s, %s, %s.", @@ -2409,15 +2269,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // broadcast from all operands, not just the predicate. /* static */ StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { - bool compatible; - if (ShapeUtil::IsTuple(on_true)) { - // Select only defines the top-level buffer, so if it's a tuple, the two - // input must match exactly. - compatible = ShapeUtil::Compatible(on_true, on_false); - } else { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false); - } - if (!compatible) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( "Operands to select must be the same shape; got %s and %s.", ShapeUtil::HumanString(on_true).c_str(), @@ -2429,7 +2281,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::HumanString(pred).c_str()); } if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || - ShapeUtil::Rank(pred) == 0) { + ShapeUtil::IsScalar(pred)) { // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. @@ -2443,6 +2295,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } } +/* static */ StatusOr ShapeInference::InferTupleSelectShape( + const Shape& pred, const Shape& on_true, const Shape& on_false) { + // Select only defines the top-level buffer, so if it's a tuple, the two + // input must match exactly. + if (!ShapeUtil::Compatible(on_true, on_false)) { + return InvalidArgument( + "Operands to tuple-select must be the same shape; got %s and %s.", + ShapeUtil::HumanString(on_true).c_str(), + ShapeUtil::HumanString(on_false).c_str()); + } + if (pred.element_type() != PRED) { + return InvalidArgument( + "TupleSelect's pred operand must have PRED element type; got %s.", + ShapeUtil::HumanString(pred).c_str()); + } + if (!ShapeUtil::IsScalar(pred)) { + return InvalidArgument( + "TupleSelect operation with non-scalar predicate: %s.", + ShapeUtil::HumanString(pred).c_str()); + } + return on_true; +} + /* static */ StatusOr ShapeInference::InferCallShape( tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply) { @@ -2575,9 +2450,9 @@ static Status ValidateGatherDimensionNumbers( const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - gather_indices_shape, "gather indices operand of gather op")); + ExpectArray(input_shape, "input tensor operand gather op")); + TF_RETURN_IF_ERROR( + ExpectArray(gather_indices_shape, "gather indices operand of gather op")); if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 9da2c99b4177f08ece8daabaf2922ddd7e947a1b..1a5684e3c306eef90fd1bfdf4565b0dcde2fbab6 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -46,8 +46,6 @@ class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. - static StatusOr InferUnaryOpShape(UnaryOperation operation, - const Shape& arg); static StatusOr InferUnaryOpShape(HloOpcode opcode, const Shape& shape); static StatusOr InferUnaryOpShape(HloOpcode opcode, @@ -55,9 +53,6 @@ class ShapeInference { // Infers the shape produced by applying the given binary operation to the // given input shapes. - static StatusOr InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); static StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); @@ -67,9 +62,6 @@ class ShapeInference { // Infers the shape produced by applying the given ternary operation to the // given input shapes. - static StatusOr InferTernaryOpShape(TernaryOperation operation, - const Shape& lhs, const Shape& rhs, - const Shape& ehs); static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs); @@ -80,9 +72,6 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. - static StatusOr InferVariadicOpShape( - VariadicOperation operation, - tensorflow::gtl::ArraySlice operand_shapes); static StatusOr InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes); @@ -227,6 +216,13 @@ class ShapeInference { static StatusOr InferConcatOpShape( tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + // Infers the shape produced by a kAfterAll. Trivially this shape is always a + // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes + // and checking operand shapes. This method verifies that the operand shapes + // are all TOKENs. + static StatusOr InferAfterAllShape( + tensorflow::gtl::ArraySlice arg_shapes); + // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. @@ -279,7 +275,7 @@ class ShapeInference { // the LHS and a single element in the RHS to produce a single output element, // even in the presence of broadcasting of one of the operands over the other. static StatusOr InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); // Helper for inferring the shape of Clamp ops. @@ -290,12 +286,16 @@ class ShapeInference { static StatusOr InferSelectShape(const Shape& pred, const Shape& on_true, const Shape& on_false); + // Helper for inferring the shape of TupleSelect ops. + static StatusOr InferTupleSelectShape(const Shape& pred, + const Shape& on_true, + const Shape& on_false); // Helper for inferring shapes of binary operations which use degenerate // dimension broadcasting (a dimension of size 1 in one operand is broadcast // up to match the size of the dimension in the other operand). static StatusOr InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs); + HloOpcode operation, const Shape& lhs, const Shape& rhs); // Helper for inferring shapes of binary operations using "InDim" // broadcasting. This is the broadcasting used in the *InDim binary operations @@ -303,8 +303,7 @@ class ShapeInference { // lower-rank shape than larger_shape. Returns the shape that the // smaller_shape is broadcast to. static StatusOr InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 0e61994a786b53a295ef9c9c2287b28fbf754d9b..bafe14d6f45f851924c37908d4c93bbff2dac459 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -101,8 +101,8 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = ShapeInference::InferUnaryOpShape( - UnaryOperation::UNOP_NEGATE, matrix_shape); + auto inferred_status = + ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie())); } @@ -110,14 +110,14 @@ TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple); + HloOpcode::kSelect, pred_, tuple, tuple); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } @@ -125,34 +125,34 @@ TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { auto predarray = ShapeUtil::MakeShape(PRED, {64, 48}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectBadShapes) { auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("pred operand must have PRED")); auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), - matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, + matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("with non-scalar predicate with dimensionality")); // Tuples have a TUPLE element type and cannot be the pred of a select. auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}), + HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}), ShapeUtil::MakeTupleShape({f32_, f32_}), ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); @@ -162,102 +162,98 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { TEST_F(ShapeInferenceTest, ClampAllMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, - matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampAllScalar) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_); + auto inferred_status = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_); + HloOpcode::kClamp, matrix_64_48_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_); + HloOpcode::kClamp, f32_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_); + HloOpcode::kClamp, f32_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampBadShapes) { // Type mismatch - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_) - .ok()); - // Dimension mismatch ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_64_, vector_32_, vector_32_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_64_, vector_32_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_32_, vector_64_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_) .ok()); - // Dimension mismatch, where one operand is a scalar + // Dimension mismatch ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_) + HloOpcode::kClamp, vector_64_, vector_32_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_64_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_32_, vector_64_) + .ok()); + // Dimension mismatch, where one operand is a scalar + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, vector_32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, f32_, vector_32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, + vector_64_, vector_32_) .ok()); } TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, const tensorflow::gtl::ArraySlice& bcast) { - return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX, - lhs, rhs, bcast); + return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, + bcast); }; // Inputs must be FP. ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok()); @@ -292,8 +288,8 @@ TEST_F(ShapeInferenceTest, Complex) { } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { - StatusOr result = ShapeInference::InferVariadicOpShape( - VariadicOperation::VAROP_TUPLE, {&s32_, &f32_}); + StatusOr result = + ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_}); ASSERT_IS_OK(result.status()); ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), ShapeUtil::MakeTupleShape({s32_, f32_}))); @@ -804,8 +800,8 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) { TEST_F(ShapeInferenceTest, InferPowShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {}); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kPower, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie())); } @@ -813,7 +809,7 @@ TEST_F(ShapeInferenceTest, InferPowShape) { TEST_F(ShapeInferenceTest, InferCompareShapeEq) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -822,7 +818,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeEq) { TEST_F(ShapeInferenceTest, InferCompareShapeGe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -831,7 +827,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGe) { TEST_F(ShapeInferenceTest, InferCompareShapeGt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -840,7 +836,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGt) { TEST_F(ShapeInferenceTest, InferCompareShapeLe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -849,7 +845,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLe) { TEST_F(ShapeInferenceTest, InferCompareShapeLt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -858,7 +854,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLt) { TEST_F(ShapeInferenceTest, InferCompareShapeNe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -1111,22 +1107,22 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); const Shape vec16 = ShapeUtil::MakeShape(F32, {16}); - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {1}); + auto inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {0}); + auto inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0}); ASSERT_FALSE(inferred_status_mismatch.ok()); - inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {0}); + inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {1}); + inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1}); ASSERT_FALSE(inferred_status_mismatch.ok()); } @@ -1138,17 +1134,17 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2}); + HloOpcode::kAdd, cube, matrix8_4, {1, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2}); + HloOpcode::kAdd, cube, matrix16_4, {0, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1}); + HloOpcode::kAdd, cube, matrix16_8, {0, 1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); } @@ -1162,43 +1158,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); // "magical" broadcast rejected - auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {}); + auto inferred_status_error1 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Automatic")); // broadcast_dimension out of bounds for tensor's rank - auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {3}); + auto inferred_status_error2 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension - auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {0}); + auto inferred_status_error3 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); + HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), HasSubstr("broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); + HloOpcode::kAdd, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); + HloOpcode::kAdd, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), HasSubstr("dimension 0 mismatch")); @@ -1207,13 +1203,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { // in a proper (strictly increasing) order, even if the lower-rank array // matches the higher-rank array in many different ways. auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), HasSubstr("dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), HasSubstr("dimensions order is wrong")); @@ -1315,7 +1311,7 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT( inferred_status_error4.status().error_message(), - HasSubstr("Expected non-tuple argument for operand of concatenation")); + HasSubstr("Expected array argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( @@ -1391,7 +1387,7 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { ShapeInference::InferReverseShape(tuple_shape, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("Expected non-tuple argument")); + HasSubstr("Expected array argument")); } TEST_F(ShapeInferenceTest, Call) { @@ -1690,7 +1686,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Expected non-tuple argument for input")) + HasSubstr("Expected array argument for input")) << statusor.status(); } @@ -1704,7 +1700,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Expected non-tuple argument for gather indices")) + HasSubstr("Expected array argument for gather indices")) << statusor.status(); } diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 6bacb37206c3b521ec05a63be73a40fc203f3265..7d7dcac10b65933d1c81b8aca77465932694bfdb 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -123,6 +123,8 @@ ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) } ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { + Deallocate(); + *static_cast(this) = std::move(static_cast(s)); allocator_ = s.allocator_; // Null out s.allocator_ so it doesn't try to free anything in its destructor. @@ -130,7 +132,15 @@ ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { return *this; } -ScopedShapedBuffer::~ScopedShapedBuffer() { +ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); } + +ShapedBuffer ScopedShapedBuffer::release() { + ShapedBuffer shaped_buffer(static_cast(*this)); + buffers_ = ShapeTree(); + return shaped_buffer; +} + +void ScopedShapedBuffer::Deallocate() { // allocator_ will be null if we were moved-from. if (allocator_ == nullptr) { return; @@ -148,10 +158,4 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { } } -ShapedBuffer ScopedShapedBuffer::release() { - ShapedBuffer shaped_buffer(static_cast(*this)); - buffers_ = ShapeTree(); - return shaped_buffer; -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index 25b709523b7cd59a6bac56a478fc9886e1cf0487..905a7e82e621f2bf4588b71be5dbab20f892cafe 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -169,6 +169,8 @@ class ScopedShapedBuffer : public ShapedBuffer { TF_MUST_USE_RESULT ShapedBuffer release(); protected: + void Deallocate(); + DeviceMemoryAllocator* allocator_; }; diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0fc243667911651c788e3c1e5f1d39d86170f1ad --- /dev/null +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/shaped_buffer.h" + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { +namespace { + +TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { + TF_ASSERT_OK_AND_ASSIGN(auto platforms, + xla::PlatformUtil::GetSupportedPlatforms()); + ASSERT_FALSE(platforms.empty()); + auto* platform = platforms[0]; + TF_ASSERT_OK_AND_ASSIGN(auto executors, + xla::PlatformUtil::GetStreamExecutors(platform)); + xla::StreamExecutorMemoryAllocator allocator(platform, executors); + const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + const int kDeviceOrdinal = 0; + auto scoped_buffer = tensorflow::MakeUnique( + shape, shape, &allocator, kDeviceOrdinal); + std::unique_ptr buffer = std::move(scoped_buffer); + buffer = nullptr; +} + +class TestAllocator : public DeviceMemoryAllocator { + public: + TestAllocator() + : DeviceMemoryAllocator(PlatformUtil::GetDefaultPlatform().ValueOrDie()) { + } + + ~TestAllocator() override { + if (!allocations_.empty()) { + ADD_FAILURE() << "Some allocations not freed!"; + } + } + + // Pull in two-arg overload of Allocate. + using DeviceMemoryAllocator::Allocate; + + StatusOr Allocate(int device_ordinal, uint64 size, + bool /*retry_on_failure*/) override { + // By contract, we must return null if size == 0. + if (size == 0) { + return OwningDeviceMemory(); + } + void* buf = malloc(size); + allocations_.insert({device_ordinal, buf}); + return OwningDeviceMemory(se::DeviceMemoryBase(buf, size), device_ordinal, + this); + } + + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override { + if (mem.is_null()) { + return Status::OK(); + } + + auto it = allocations_.find({device_ordinal, mem.opaque()}); + if (it == allocations_.end()) { + ADD_FAILURE() << "Allocation not found (double free?)"; + } else { + free(mem.opaque()); + allocations_.erase(it); + } + return Status::OK(); + } + + bool AllowsAsynchronousDeallocation() const override { return false; } + + private: + std::set> allocations_; +}; + +TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { + Shape s = ShapeUtil::MakeShape(F32, {1}); + TestAllocator allocator; + ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0); + sb1.set_buffer( + allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(), + /*index=*/{}); + + ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1); + sb2.set_buffer( + allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(), + /*index=*/{}); + + sb1 = std::move(sb2); + + // TestAllocator's destructor checks that all memory was freed. +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index c4d01562c4e32225ebb984d8fcd93ec3fa86e403..4c5038a009ba5da4172129980014913f3f4418f4 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -22,8 +22,12 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/notification.h" + +using ::tensorflow::strings::StrCat; namespace xla { /* static */ tensorflow::mutex @@ -36,8 +40,73 @@ TransferManager::GetPlatformTransferManagers() { return r; } +StatusOr> TransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer) { + StatusOr> ret; + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + tensorflow::Notification n; + TransferLiteralFromDevice(substream, device_buffer, + [&](StatusOr> arg) { + ret = std::move(arg); + n.Notify(); + }); + n.WaitForNotification(); + return ret; +} + +Status TransferManager::TransferLiteralToDevice( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + TF_RETURN_IF_ERROR( + TransferLiteralToDeviceAsync(substream, literal, device_buffer)); + return substream->BlockHostUntilDone(); +} + +StatusOr> TransferManager::TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + StatusOr> ret; + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + tensorflow::Notification n; + TransferArrayFromDevice(substream, shape, source, + [&](StatusOr> arg) { + ret = std::move(arg); + n.Notify(); + }); + n.WaitForNotification(); + return ret; +} + Status TransferManager::TransferArrayToDevice( - se::StreamExecutor* executor, const LiteralSlice& literal, + se::Stream* stream, const LiteralSlice& literal, + const se::DeviceMemoryBase& dest) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + TF_RETURN_IF_ERROR(TransferArrayToDeviceAsync(substream, literal, dest)); + return substream->BlockHostUntilDone(); +} + +Status TransferManager::TransferArrayToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) @@ -51,28 +120,32 @@ Status TransferManager::TransferArrayToDevice( dest.size(), GetByteSizeRequirement(on_device_shape)); } ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, - executor->platform(), executor->device_ordinal()); + stream->parent()->platform(), + stream->parent()->device_ordinal()); shaped_buffer.set_buffer(dest, /*index=*/{}); - return TransferLiteralToDevice(executor, literal, shaped_buffer); + return TransferLiteralToDevice(stream, literal, shaped_buffer); } -StatusOr> TransferManager::TransferArrayFromDevice( - se::StreamExecutor* executor, const Shape& shape, - const se::DeviceMemoryBase& source) { - TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) - << "Shape " << ShapeUtil::HumanString(shape) - << " has a differently shaped representation on-device: " - << ShapeUtil::HumanString(HostShapeToDeviceShape(shape)); +void TransferManager::TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, + std::function>)> done) { + if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { + auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), + " has a differently shaped representation on-device: ", + ShapeUtil::HumanString(HostShapeToDeviceShape(shape))); + return done(FailedPrecondition("%s", error.c_str())); + } if (source.size() < GetByteSizeRequirement(shape)) { - return FailedPrecondition( - "Allocation on device not large enough for array: " - "%lld < %lld", - source.size(), GetByteSizeRequirement(shape)); + return done( + FailedPrecondition("Allocation on device not large enough for array: " + "%lld < %lld", + source.size(), GetByteSizeRequirement(shape))); } ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, - executor->platform(), executor->device_ordinal()); + stream->parent()->platform(), + stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); - return TransferLiteralFromDevice(executor, shaped_buffer); + return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done)); } /* static */ void TransferManager::RegisterTransferManager( @@ -108,10 +181,14 @@ StatusOr> TransferManager::TransferArrayFromDevice( } Status TransferManager::WriteTupleIndexTables( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { - VLOG(2) << "Writing tuple index tables for " << device_buffer; + se::Stream* stream, const ShapedBuffer& device_buffer) { + TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); + return stream->BlockHostUntilDone(); +} - TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); +Status TransferManager::WriteTupleIndexTablesAsync( + se::Stream* stream, const ShapedBuffer& device_buffer) { + VLOG(2) << "Writing tuple index tables for " << device_buffer; return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_device_shape(), @@ -129,7 +206,7 @@ Status TransferManager::WriteTupleIndexTables( elements.push_back(device_buffer.buffer(element_index)); element_index.pop_back(); } - return WriteSingleTupleIndexTable(executor, elements, device_subshape, + return WriteSingleTupleIndexTable(stream, elements, device_subshape, &device_memory); } @@ -138,26 +215,20 @@ Status TransferManager::WriteTupleIndexTables( } Status TransferManager::TransferBufferFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - int64 size, void* destination) { + se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, + void* destination) { if (source.size() < size) { return FailedPrecondition( "Source allocation on device not large enough for data tranfer: " "%lld < %lld", source.size(), size); } - auto copy_status = executor->SynchronousMemcpyD2H(source, size, destination); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer from device to buffer"); - } + stream->ThenMemcpy(destination, source, size); return Status::OK(); } Status TransferManager::TransferBufferToDevice( - se::StreamExecutor* executor, int64 size, const void* source, + se::Stream* stream, int64 size, const void* source, se::DeviceMemoryBase* destination) { if (destination->size() < size) { return FailedPrecondition( @@ -165,13 +236,7 @@ Status TransferManager::TransferBufferToDevice( "%lld < %lld", destination->size(), size); } - auto copy_status = executor->SynchronousMemcpyH2D(source, size, destination); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer of buffer to device"); - } + stream->ThenMemcpy(destination, source, size); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 43a8092b06fba0e2495bce0ee1a309c85a908273..e384359642a8fe09e0b8516e342a56259912922a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -52,30 +52,65 @@ class TransferManager { return host_shape; } - // Returns a literal containing the data held in the given ShapedBuffer. - // using the provided executor. The optional literal_shape will be the shape - // for the literal. The shape of the ShapedBuffer and - // DeviceShape(literal_shape) must be compatible, but need not have the same - // layout. + // Returns a literal containing the data held in the given ShapedBuffer + // using the provided executor. This operation is performed synchronously + // without waiting for any other operation on a stream to complete. + // + // This function should be avoided in favor of the asynchronous version below. virtual StatusOr> TransferLiteralFromDevice( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0; + se::Stream* stream, const ShapedBuffer& device_buffer); + + // Begins transferring a literal containing the data held in the given + // ShapedBuffer using the provided executor. + // + // This operation is performed asynchronously on the given stream. It returns + // once the transfer is enqueued. 'done' is invoked with the result when + // complete. + // + // device_buffer is copied by reference and must live at least until done() is + // invoked. + virtual void TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + std::function>)> done) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, - // but need not have the same layout - virtual Status TransferLiteralToDevice(se::StreamExecutor* executor, + // but need not have the same layout. + // + // This operation is performed synchronously without waiting for any other + // operation on a stream to complete. This function should be avoided in favor + // of the asynchronous version below. + virtual Status TransferLiteralToDevice(se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) = 0; + const ShapedBuffer& device_buffer); + + // Transfers the given literal into the previously allocated device memory + // represented by the given ShapedBuffer using the given executor. The shape + // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, + // but need not have the same layout. + // + // This operation is performed asynchronously on the given stream. It returns + // once the transfer is enqueued. + virtual Status TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) = 0; // Convenience methods for transferring an array to or from the device at a // known address. This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. - Status TransferArrayToDevice(se::StreamExecutor* executor, - const LiteralSlice& literal, + Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); + void TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + std::function>)> done); + + Status TransferArrayToDeviceAsync(se::Stream* stream, + const LiteralSlice& literal, + const se::DeviceMemoryBase& dest); StatusOr> TransferArrayFromDevice( - se::StreamExecutor* executor, const Shape& shape, + se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, @@ -96,8 +131,10 @@ class TransferManager { // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the // ShapedBuffer is array-shaped this method does nothing. - Status WriteTupleIndexTables(se::StreamExecutor* executor, + Status WriteTupleIndexTables(se::Stream* stream, const ShapedBuffer& device_buffer); + Status WriteTupleIndexTablesAsync(se::Stream* stream, + const ShapedBuffer& device_buffer); // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory @@ -144,7 +181,7 @@ class TransferManager { // 'destination' buffer. // // size is the size to transfer to destination in bytes. - virtual Status TransferBufferFromDevice(se::StreamExecutor* executor, + virtual Status TransferBufferFromDevice(se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, void* destination); @@ -152,15 +189,15 @@ class TransferManager { // destination of the device. // // size is the size to transfer from source in bytes. - virtual Status TransferBufferToDevice(se::StreamExecutor* executor, - int64 size, const void* source, + virtual Status TransferBufferToDevice(se::Stream* stream, int64 size, + const void* source, se::DeviceMemoryBase* destination); // Writes the given device-memory pointers in 'elements' to the given region // to construct a tuple index table in the platform-specific tuple // representation. virtual Status WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) = 0; diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index ba16dc640e2d2974eab4fc8b134a6e33c03e3b85..49e1f873192f800056a2272f7d4f698898b0f8a1 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -178,7 +178,6 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); - convolution.SetupDerivedInstruction(new_conv.get()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index f73f1227aaf1630a9e7c43bb508732c5518ef929..cccb8f2fbb0266bbf1f40b09170938a1e5d3e78d 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -27,12 +27,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" @@ -69,7 +69,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); @@ -91,7 +91,7 @@ ENTRY entry_computation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TransposeFolding transpose_folding( [](const HloInstruction& dot, @@ -119,7 +119,7 @@ ENTRY entry_computation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TransposeFolding transpose_folding( [](const HloInstruction& dot, @@ -147,7 +147,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); @@ -176,7 +176,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( - {add, sub, mul}, "", entry_computation); + {add, sub, mul}, "entry", entry_computation); EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); // The arguments to the call should be const1, const2, and const3. @@ -205,7 +205,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); const HloComputation* callee = module->GetComputationWithName("callee"); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 8cb654493ca82dc702b2c1e7a4284f4f31d1e5f9..990dfc410ccf6ab84af00f4a16dc783c11985844 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -121,7 +122,6 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, } namespace { - // Gather fusion instructions from 'instruction' into 'fusion_instructions'. void GatherFusionInstructions( HloInstruction* instruction, @@ -273,6 +273,14 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand, so just copy the operands points-to + // set. + CreateCopiedPointsToSet(domain, domain->operand(0)); + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { // A kSlice instruction aliases its operand if the backend lowers it to an // in-place implementation. @@ -284,22 +292,29 @@ Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { } Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { - // RecvDone aliases its input (Recv) tuple element {0} to its output. + // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its + // output. The other indices ({} and {1}) define their own buffers. PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); + points_to_set.AddPointedToBuffer( + logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}), + /*index=*/{}); + points_to_set.AddPointedToBuffer( + logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}), + /*index=*/{1}); + const PointsToSet& operand_points_to_set = GetPointsToSet(recv_done->operand(0)); - // Recursively copy the points to set of the operand tuple {0}. + // Recursively copy the points to set of the operand tuple {0} to the output + // element {0}. points_to_set.ForEachMutableElement( [this, &points_to_set, &operand_points_to_set]( const ShapeIndex& index, PointsToSet::BufferList* buffers) { - ShapeIndex src_index({0}); - for (auto element : index) { - src_index.push_back(element); + if (index.empty() || index[0] != 0) { + return; } - *buffers = operand_points_to_set.element(src_index); - for (auto& tuple_source : - operand_points_to_set.tuple_sources(src_index)) { + *buffers = operand_points_to_set.element(index); + for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) { points_to_set.add_tuple_source(index, tuple_source); } }); @@ -307,7 +322,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { } Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { - // Send creates a tuple of {aliased operand, U32 context}. + // Send creates a tuple of {aliased operand, U32 context, token}. PointsToSet& points_to_set = CreateEmptyPointsToSet(send); // Creates the points to set for the tuple and its element at {1}. @@ -320,6 +335,10 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { context_buffer->push_back( &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1}))); + auto token_buffer = points_to_set.mutable_element(ShapeIndex({2})); + token_buffer->push_back( + &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2}))); + // Recursively copy the points to set of the operand to output tuple {0}. const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0)); operand_points_to_set.ForEachElement( @@ -380,7 +399,7 @@ Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) { +Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) { // Select allocates a new buffer and then shallow copies the on_true or // on_false buffer into this new buffer. Which side is chosen cannot be // determined statically so conservatively set the points-to set to the union @@ -388,9 +407,9 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) { // // First create a copy of the on_true points-to set (and tuple sources), then // add in elements of the on_false points-to set (tuple sources). - auto on_true = select->operand(1); - auto on_false = select->operand(2); - PointsToSet& points_to_set = CreateCopiedPointsToSet(select, on_true); + auto on_true = tuple_select->operand(1); + auto on_false = tuple_select->operand(2); + PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true); const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set; points_to_set.ForEachMutableElement( [&](const ShapeIndex& index, PointsToSet::BufferList* buffers) { @@ -408,7 +427,7 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) { // respective element in the points-to set should contain only itself. points_to_set.mutable_element({})->clear(); points_to_set.AddPointedToBuffer( - logical_buffer_analysis_->GetBuffer(select, /*index=*/{}), + logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}), /*index=*/{}); return Status::OK(); } @@ -715,15 +734,22 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return false; } if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + user->fusion_kind() == HloInstruction::FusionKind::kInput) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + } else { + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + return HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( + fusion_param); + } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -781,8 +807,12 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return param_uses.size() == 1 && param_uses[0].first == callee_root && callee_root->IsElementwiseOnOperand(param_uses[0].second); } - // Check if 'user' is element-wise. - return user->IsElementwise(); + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation. + return user->IsElementwiseOnOperand(user->operand_index(operand)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 1ac713013650d807b15e33565e6d2dec406a5d13..686bb053288fbd6a46ca50a2c65c739354fd2678 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -248,11 +248,12 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; - Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; string ToString() const; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index f558316b05b168a6f100e8ef69adfd9dbc023102..226d0af5d27bb37b08747cb86f0bc4bfa6f3db96 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -318,8 +318,9 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto send = builder.AddInstruction( - HloInstruction::CreateSend(constant, /*channel_id=*/0)); + HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); BuildModuleAndRunAnalysis(builder.Build()); @@ -342,8 +343,9 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { // RecvDone forwards its operand tuple element at {0} to the output. auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto recv = builder.AddInstruction(HloInstruction::CreateRecv( - ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0)); + ShapeUtil::MakeShape(F32, {1, 2, 3}), token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); BuildModuleAndRunAnalysis(builder.Build()); @@ -355,7 +357,7 @@ TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { ExpectHasTopLevelBuffers( points_to_analysis_->GetPointsToSet(recv).element({}), {recv}); - ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}}); + ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}}); } TEST_F(TuplePointsToAnalysisTest, TupleSelect) { @@ -374,7 +376,7 @@ TEST_F(TuplePointsToAnalysisTest, TupleSelect) { auto pred = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); BuildModuleAndRunAnalysis(builder.Build()); @@ -403,7 +405,7 @@ TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) { auto pred = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, param0, param1)); + tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1)); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select)); @@ -452,7 +454,7 @@ TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) { auto pred = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); BuildModuleAndRunAnalysis(builder.Build()); @@ -488,7 +490,7 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { auto pred = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); BuildModuleAndRunAnalysis(builder.Build()); @@ -1148,5 +1150,30 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { call, {})); } +TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) { + Shape full_shape = ShapeUtil::MakeShape(F32, {16, 32}); + Shape broadcast_shape = ShapeUtil::MakeShape(F32, {16}); + + auto builder = HloComputation::Builder(TestName() + "_fusion"); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, full_shape, "full")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, broadcast_shape, "small")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(full_shape, param1, {0})); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + full_shape, HloOpcode::kAdd, param0, broadcast)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, broadcast}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc index 113c2e2bd9f73a2b0c783103d7f2da9534bc97c3..77bdcc9de0d830991208a1db271d009bccaf550e 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -30,10 +30,17 @@ limitations under the License. namespace xla { +TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) : + exclude_entry_computation_(exclude_entry_computation) {} + StatusOr TupleSimplifier::Run(HloModule* module) { // Initially add all GTE and Tuple instructions to the worklist. std::queue worklist; for (auto* computation : module->computations()) { + if (exclude_entry_computation_ && + computation == module->entry_computation()) { + continue; + } for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kTuple || instruction->opcode() == HloOpcode::kGetTupleElement) { @@ -78,7 +85,6 @@ StatusOr TupleSimplifier::Run(HloModule* module) { can_simplify = false; break; } - if (top_tuple == nullptr) { top_tuple = operand->mutable_operand(0); if (!ShapeUtil::Compatible(top_tuple->shape(), @@ -108,10 +114,10 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // | // GTE if (instruction->operand(0)->opcode() == HloOpcode::kTuple) { - changed = true; HloInstruction* element_source = instruction->mutable_operand(0)->mutable_operand( instruction->tuple_index()); + changed = true; TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); for (HloInstruction* user : element_source->users()) { if (user->opcode() == HloOpcode::kTuple || diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index e5e9b10b5bf3f452d1bfec476b8d5c7d74c4f4e8..750950188312c5077d487f2feef0606f07839432 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -27,13 +27,20 @@ namespace xla { // the module. class TupleSimplifier : public HloPassInterface { public: - TupleSimplifier() {} + TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} + explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} tensorflow::StringPiece name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + private: + // When set, this pipeline stage will perform optimization of all computations + // apart from the module's entry computation. This is used by Graphcore's + // backend. + bool exclude_entry_computation_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index ca9ae91281fce5ee061d066fc3e538dbbc09f6b3..d3635eae81ec7017f9bf6a69250d10716309c9ec 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -42,6 +42,12 @@ class TupleSimplifierTest : public HloTestBase { TF_ASSERT_OK(changed_status.status()); EXPECT_EQ(change_expected, changed_status.ValueOrDie()); } + void Run(HloModule* module, bool change_expected, bool exclude_entry) { + TupleSimplifier simplifier(exclude_entry); + auto changed_status = simplifier.Run(module); + TF_ASSERT_OK(changed_status.status()); + EXPECT_EQ(change_expected, changed_status.ValueOrDie()); + } const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( @@ -211,5 +217,76 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); } +TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { + // Verify that the root computation can be excluded + auto module = CreateNewModule(); + + HloInstruction* p0; + HloInstruction* p1; + HloComputation* c0; + HloComputation* c1; + HloComputation* entry; + + { + HloComputation::Builder builder(TestName() + "_1"); + p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c0 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_2"); + p1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c1 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_Entry"); + HloInstruction* tuple_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* call0 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0)); + HloInstruction* call1 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1)); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0)); + HloInstruction* gte3 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3})); + + entry = module->AddEntryComputation(builder.Build()); + } + + Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + + EXPECT_THAT(c0->root_instruction(), p0); + EXPECT_THAT(c1->root_instruction(), p1); + EXPECT_THAT(entry->instruction_count(), 9); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc index 754fd8ef169231827eeb5bfd72aeb596644ca767..d33d5bb8f30c8504aa323d461e5f59709b48e1fc 100644 --- a/tensorflow/compiler/xla/service/tuple_util_test.cc +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -37,7 +37,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc deleted file mode 100644 index 9e62d0acfb98946f1e693fc0310098b4ec99750b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ /dev/null @@ -1,3557 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { -namespace { - -HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { - switch (unop) { - case UNOP_ABS: - return HloOpcode::kAbs; - case UNOP_CEIL: - return HloOpcode::kCeil; - case UNOP_CLZ: - return HloOpcode::kClz; - case UNOP_COS: - return HloOpcode::kCos; - case UNOP_EXP: - return HloOpcode::kExp; - case UNOP_EXPM1: - return HloOpcode::kExpm1; - case UNOP_FLOOR: - return HloOpcode::kFloor; - case UNOP_IMAG: - return HloOpcode::kImag; - case UNOP_IS_FINITE: - return HloOpcode::kIsFinite; - case UNOP_LOG: - return HloOpcode::kLog; - case UNOP_LOG1P: - return HloOpcode::kLog1p; - case UNOP_NOT: - return HloOpcode::kNot; - case UNOP_NEGATE: - return HloOpcode::kNegate; - case UNOP_REAL: - return HloOpcode::kReal; - case UNOP_ROUND_NEAREST_AFZ: - return HloOpcode::kRoundNearestAfz; - case UNOP_SIGN: - return HloOpcode::kSign; - case UNOP_SIN: - return HloOpcode::kSin; - case UNOP_SORT: - return HloOpcode::kSort; - case UNOP_TANH: - return HloOpcode::kTanh; - default: - LOG(FATAL) << "unhandled operation " << unop; - } -} - -HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { - switch (binop) { - case BINOP_ATAN2: - return HloOpcode::kAtan2; - case BINOP_COMPLEX: - return HloOpcode::kComplex; - case BINOP_MUL: - return HloOpcode::kMultiply; - case BINOP_ADD: - return HloOpcode::kAdd; - case BINOP_SUB: - return HloOpcode::kSubtract; - case BINOP_DIV: - return HloOpcode::kDivide; - case BINOP_EQ: - return HloOpcode::kEq; - case BINOP_GE: - return HloOpcode::kGe; - case BINOP_GT: - return HloOpcode::kGt; - case BINOP_LE: - return HloOpcode::kLe; - case BINOP_LT: - return HloOpcode::kLt; - case BINOP_NE: - return HloOpcode::kNe; - case BINOP_MAX: - return HloOpcode::kMaximum; - case BINOP_MIN: - return HloOpcode::kMinimum; - case BINOP_POW: - return HloOpcode::kPower; - case BINOP_REM: - return HloOpcode::kRemainder; - case BINOP_OR: - return HloOpcode::kOr; - case BINOP_AND: - return HloOpcode::kAnd; - case BINOP_SHIFT_LEFT: - return HloOpcode::kShiftLeft; - case BINOP_SHIFT_RIGHT_ARITHMETIC: - return HloOpcode::kShiftRightArithmetic; - case BINOP_SHIFT_RIGHT_LOGICAL: - return HloOpcode::kShiftRightLogical; - default: - LOG(FATAL) << "unhandled operation " << binop; - } -} - -HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) { - switch (triop) { - case TRIOP_CLAMP: - return HloOpcode::kClamp; - case TRIOP_SELECT: - return HloOpcode::kSelect; - default: - LOG(FATAL) << "unhandled operation " << triop; - } -} - -HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) { - switch (varop) { - case VAROP_TUPLE: - return HloOpcode::kTuple; - default: - LOG(FATAL) << "unhandled operation " << varop; - } -} - -} // namespace - -/* static */ StatusOr> -UserComputation::MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new) { - auto user_computation = - MakeUnique(session_computation.name(), handle); - { - tensorflow::mutex_lock lock(user_computation->mutex_); - user_computation->session_computation_ = session_computation; - user_computation->next_handle_value_ = - std::max_element(session_computation.requests().begin(), - session_computation.requests().end(), - [](const std::pair& lhs, - const std::pair& rhs) { - return lhs.first < rhs.first; - }) - ->first + - 1; - TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new)); - } - - return std::move(user_computation); -} - -UserComputation::UserComputation(const string& name, - const ComputationHandle& handle) - : name_(name), next_handle_value_(1) { - *session_computation_.mutable_computation_handle() = handle; - session_computation_.set_name(name); - - VLOG(1) << "New UserComputation \"" << name - << "\", handle: " << handle.handle(); -} - -ComputationDataHandle UserComputation::CreateComputationDataHandle() { - ComputationDataHandle handle; - handle.set_handle(next_handle_value_); - // Handles are used as Version values and *must* be assigned consecutively for - // computation versioning to work. - next_handle_value_++; - return handle; -} - -StatusOr UserComputation::AddParameterInstruction( - const ParameterRequest& parameter_request) { - tensorflow::mutex_lock lock(mutex_); - - int64 parameter_number = parameter_request.parameter(); - if (parameters_.count(parameter_number) != 0) { - return InvalidArgument("parameter %lld already registered", - parameter_number); - } - ComputationDataHandle handle = CreateComputationDataHandle(); - - const Shape& validated_shape = parameter_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_parameter_request() = parameter_request; - - parameters_[parameter_number] = &request; - - VLOG(1) << "AddParameterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << parameter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSendInstruction( - const SendRequest& send_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check if the operand of the instruction is valid. - TF_RETURN_IF_ERROR(LookUpRequest(send_request.operand()).status()); - - // No handle is returned, but a handle must be assigned to this instruction - // for computation versioning. - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_send_request() = send_request; - - VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << send_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddRecvInstruction( - const RecvRequest& recv_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = recv_request.shape(); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_recv_request() = recv_request; - - VLOG(1) << "AddRecvInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << recv_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddPadInstruction( - const PadRequest& pad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(pad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value, - LookUpRequest(pad_request.padding_value())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape( - operand->output_shape(), - padding_value->output_shape(), - pad_request.padding_config())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_pad_request() = pad_request; - - VLOG(1) << "AddPadInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << pad_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConstantInstruction( - const ConstantRequest& constant_request) { - const Shape& validated_shape = constant_request.literal().shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - tensorflow::mutex_lock lock(mutex_); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_constant_request() = constant_request; - - VLOG(1) << "AddConstantInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle(); - return handle; -} - -StatusOr UserComputation::AddGatherInstruction( - const GatherRequest& gather_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* input_request, - LookUpRequest(gather_request.input())); - TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request, - LookUpRequest(gather_request.gather_indices())); - - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferGatherShape( - input_request->output_shape(), gather_indices_request->output_shape(), - gather_request.dimension_numbers(), - AsInt64Slice(gather_request.window_bounds()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_gather_request() = gather_request; - - VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << gather_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(get_tuple_element_request.operand())); - if (!ShapeUtil::IsTuple(operand->output_shape())) { - return InvalidArgument( - "Operand to GetTupleElement() is not a tuple; got %s", - ShapeUtil::HumanString(operand->output_shape()).c_str()); - } - Shape element_shape = ShapeUtil::GetTupleElementShape( - operand->output_shape(), get_tuple_element_request.index()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = element_shape; - *request.mutable_request()->mutable_get_tuple_element_request() = - get_tuple_element_request; - - VLOG(1) << "AddGetTupleElementInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << get_tuple_element_request.ShortDebugString(); - return handle; -} - -Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the operand index is valid. - TF_RETURN_IF_ERROR(LookUpRequest(trace_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_trace_request() = trace_request; - - VLOG(1) << "AddTraceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << trace_request.ShortDebugString(); - return Status::OK(); -} - -StatusOr UserComputation::AddRngInstruction( - const RngRequest& rng_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check the number of parameters per RNG distribution. - switch (rng_request.distribution()) { - case RandomDistribution::RNG_NORMAL: - case RandomDistribution::RNG_UNIFORM: - if (rng_request.parameter_size() != 2) { - return InvalidArgument( - "RNG distribution (%s) expects 2 parameters, but got %d", - RandomDistribution_Name(rng_request.distribution()).c_str(), - rng_request.parameter_size()); - } - break; - default: - LOG(FATAL) << "unhandled distribution " << rng_request.distribution(); - } - - // Verify that the parameter indices are valid; - for (const ComputationDataHandle& param : rng_request.parameter()) { - TF_RETURN_IF_ERROR(LookUpRequest(param).status()); - } - const Shape& validated_shape = rng_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_rng_request() = rng_request; - - VLOG(1) << "AddRngInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << rng_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : map_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape, - AsInt64Slice(map_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_map_request() = map_request; - - VLOG(1) << "AddMapInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << map_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceShape( - operand->output_shape(), init_value->output_shape(), - AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_request() = reduce_request; - - VLOG(1) << "AddReduceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_training_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_training_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_training_request.offset())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormTrainingShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), batch_norm_training_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_training_request() = - batch_norm_training_request; - - VLOG(1) << "AddBatchNormTrainingInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_training_request.ShortDebugString(); - - return handle; -} - -StatusOr -UserComputation::AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_inference_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_inference_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_inference_request.offset())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_inference_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_inference_request.variance())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBatchNormInferenceShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), mean->output_shape(), - variance->output_shape(), - batch_norm_inference_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_inference_request() = - batch_norm_inference_request; - - VLOG(1) << "AddBatchNormInferenceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << batch_norm_inference_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_grad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_grad_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_grad_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_grad_request.variance())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* grad_output, - LookUpRequest(batch_norm_grad_request.grad_output())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormGradShape( - operand->output_shape(), scale->output_shape(), mean->output_shape(), - variance->output_shape(), grad_output->output_shape(), - batch_norm_grad_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_grad_request() = - batch_norm_grad_request; - - VLOG(1) << "AddBatchNormGradInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_grad_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_window_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_window_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceWindowShape( - operand->output_shape(), init_value->output_shape(), - reduce_window_request.window(), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_window_request() = - reduce_window_request; - - VLOG(1) << "AddReduceWindowInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_window_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(select_and_scatter_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* source, - LookUpRequest(select_and_scatter_request.source())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(select_and_scatter_request.init_value())); - - VersionedComputationHandle::Version select_version = - select_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr select_program_shape, - select_computation.ComputeProgramShape(select_version)); - VersionedComputationHandle::Version scatter_version = - scatter_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr scatter_program_shape, - scatter_computation.ComputeProgramShape(scatter_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferSelectAndScatterShape( - operand->output_shape(), *select_program_shape, - select_and_scatter_request.window(), source->output_shape(), - init_value->output_shape(), *scatter_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(select_version); - request.add_embedded_computation_versions(scatter_version); - *request.mutable_request()->mutable_select_and_scatter_request() = - select_and_scatter_request; - - VLOG(1) << "AddSelectAndScatterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << select_and_scatter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReverseInstruction( - const ReverseRequest& reverse_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reverse_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReverseShape( - operand->output_shape(), AsInt64Slice(reverse_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reverse_request() = reverse_request; - VLOG(1) << "AddReverseInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reverse_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* init, - LookUpRequest(while_request.init())); - - VersionedComputationHandle::Version condition_version = - condition_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr condition_program_shape, - condition_computation.ComputeProgramShape(condition_version)); - - VersionedComputationHandle::Version body_version = body_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr body_program_shape, - body_computation.ComputeProgramShape(body_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferWhileShape( - *condition_program_shape, *body_program_shape, init->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(condition_version); - request.add_embedded_computation_versions(body_version); - *request.mutable_request()->mutable_while_request() = while_request; - - VLOG(1) << "AddWhileInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << while_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* pred, - LookUpRequest(conditional_request.predicate())); - TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand, - LookUpRequest(conditional_request.true_operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, - LookUpRequest(conditional_request.false_operand())); - - VersionedComputationHandle::Version true_computation_version = - true_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr true_computation_shape, - true_computation.ComputeProgramShape(true_computation_version)); - - VersionedComputationHandle::Version false_computation_version = - false_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr false_computation_shape, - false_computation.ComputeProgramShape(false_computation_version)); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferConditionalShape( - pred->output_shape(), true_operand->output_shape(), - false_operand->output_shape(), - *true_computation_shape, *false_computation_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(true_computation_version); - request.add_embedded_computation_versions(false_computation_version); - *request.mutable_request()->mutable_conditional_request() = - conditional_request; - - VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << conditional_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBroadcastInstruction( - const BroadcastRequest& broadcast_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(broadcast_request.operand())); - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBroadcastShape( - operand->output_shape(), - AsInt64Slice(broadcast_request.broadcast_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_broadcast_request() = broadcast_request; - - VLOG(1) << "AddBroadcastInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << broadcast_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReshapeInstruction( - const ReshapeRequest& reshape_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reshape_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReshapeShape( - operand->output_shape(), AsInt64Slice(reshape_request.dimensions()), - AsInt64Slice(reshape_request.new_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reshape_request() = reshape_request; - - VLOG(1) << "AddReshapeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reshape_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTransposeInstruction( - const TransposeRequest& transpose_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(transpose_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferTransposeShape( - operand->output_shape(), - AsInt64Slice(transpose_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_transpose_request() = transpose_request; - - VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << transpose_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSliceInstruction( - const SliceRequest& slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(slice_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferSliceShape( - operand->output_shape(), AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_slice_request() = slice_request; - - VLOG(1) << "AddSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices, - LookUpRequest(dynamic_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferDynamicSliceShape( - operand->output_shape(), start_indices->output_shape(), - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_slice_request() = - dynamic_slice_request; - - VLOG(1) << "AddDynamicSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dynamic_slice_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_update_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* update, - LookUpRequest(dynamic_update_slice_request.update())); - - TF_ASSIGN_OR_RETURN( - const OperationRequest* start_indices, - LookUpRequest(dynamic_update_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->output_shape(), update->output_shape(), - start_indices->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_update_slice_request() = - dynamic_update_slice_request; - - VLOG(1) << "AddDynamicUpdateSliceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << dynamic_update_slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : concatenate_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferConcatOpShape( - operand_shapes, concatenate_request.dimension())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_concatenate_request() = - concatenate_request; - - VLOG(1) << "AddConcatenateInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << concatenate_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_convert_request() = convert_request; - - VLOG(1) << "AddConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBitcastConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_bitcast_convert_request() = - convert_request; - - VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_precision_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferReducePrecisionShape( - operand->output_shape(), reduce_precision_request.exponent_bits(), - reduce_precision_request.mantissa_bits())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_reduce_precision_request() = - reduce_precision_request; - - VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_precision_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvolveInstruction( - const ConvolveRequest& convolve_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(convolve_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(convolve_request.rhs())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( - lhs->output_shape(), rhs->output_shape(), - convolve_request.window(), - convolve_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_convolve_request() = convolve_request; - - VLOG(1) << "AddConvolveInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convolve_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddFftInstruction( - const FftRequest& fft_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(fft_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferFftShape( - operand->output_shape(), fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_fft_request() = fft_request; - - VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << fft_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(cross_replica_sum_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - {&operand->output_shape()})); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_cross_replica_sum_request() = - cross_replica_sum_request; - - VLOG(1) << "AddCrossreplicaSumInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << cross_replica_sum_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddInfeedInstruction( - const InfeedRequest& infeed_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = infeed_request.shape(); - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Given shape to Infeed must have a layout"); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_infeed_request() = infeed_request; - - VLOG(1) << "AddInfeedInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << infeed_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddOutfeedInstruction( - const OutfeedRequest& outfeed_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = outfeed_request.shape(); - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Given shape to Outfeed must have a layout"); - } - - // Verify that operand is valid. - TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_outfeed_request() = outfeed_request; - - VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << outfeed_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : call_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_call_request() = call_request; - - VLOG(1) << "AddCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << call_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCustomCallInstruction( - const CustomCallRequest& custom_call_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : custom_call_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - if (tensorflow::str_util::StartsWith(custom_call_request.call_target_name(), - "$")) { - return InvalidArgument( - "Invalid custom_call_target \"%s\": Call targets that start with '$' " - "are reserved for internal use.", - custom_call_request.call_target_name().c_str()); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = custom_call_request.shape(); - *request.mutable_request()->mutable_custom_call_request() = - custom_call_request; - - VLOG(1) << "AddCustomCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << custom_call_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddHostComputeInstruction( - const HostComputeRequest& host_compute_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : host_compute_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = host_compute_request.shape(); - *request.mutable_request()->mutable_host_compute_request() = - host_compute_request; - - VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << host_compute_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDotInstruction( - const DotRequest& dot_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(dot_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(dot_request.rhs())); - - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( - lhs->output_shape(), rhs->output_shape(), - dot_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_dot_request() = dot_request; - - VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dot_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddUnaryInstruction( - const UnaryOpRequest& unary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(unary_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(), - operand->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_unary_op_request() = unary_request; - - VLOG(1) << "AddUnaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << unary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBinaryInstruction( - const BinaryOpRequest& binary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(binary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(binary_request.rhs())); - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferBinaryOpShape( - binary_request.binop(), lhs->output_shape(), rhs->output_shape(), - AsInt64Slice(binary_request.broadcast_dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_binary_op_request() = binary_request; - - VLOG(1) << "AddBinaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << binary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTernaryInstruction( - const TernaryOpRequest& ternary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(ternary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(ternary_request.rhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* ehs, - LookUpRequest(ternary_request.ehs())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferTernaryOpShape( - ternary_request.triop(), lhs->output_shape(), - rhs->output_shape(), ehs->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_ternary_op_request() = ternary_request; - - VLOG(1) << "AddTernaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << ternary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddVariadicInstruction( - const VariadicOpRequest& variadic_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : variadic_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferVariadicOpShape( - variadic_request.varop(), operand_shapes)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_variadic_op_request() = variadic_request; - - VLOG(1) << "AddVariadicInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << variadic_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::GetShape(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - return operand->output_shape(); -} - -Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpMetadata (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_metadata() = metadata; - return Status::OK(); -} - -Status UserComputation::SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpSharding (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_sharding() = sharding; - return Status::OK(); -} - -Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) { - return InvalidArgument("Invalid handle in SetReturnValue"); - } - - handle_to_return_ = handle; - - VLOG(1) << "SetReturnValue of computation \"" << name() << "\" fixed to " - << GetVersionedHandleInternal(); - - return Status::OK(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandle() const { - tensorflow::mutex_lock lock(mutex_); - return GetVersionedHandleInternal(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandleInternal() const { - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - - if (handle_to_return_.handle() > 0) { - // A specific handle has been requested for the result of the computation. - versioned_handle.version = handle_to_return_.handle(); - } else { - // A version value is simply the most recently assigned - // ComputationDataHandle value, ie the handle value of the root of the - // computation. - versioned_handle.version = next_handle_value_ - 1; - } - - return versioned_handle; -} - -VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const { - tensorflow::mutex_lock lock(mutex_); - - // The version at which an operation was added is simply the handle value of - // the ComputationDataHandle. - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - versioned_handle.version = operation.handle(); - return versioned_handle; -} - -VersionedComputationHandle::Version UserComputation::version() const { - return GetVersionedHandle().version; -} - -namespace { - -// Returns true if the operation type corresponding to the given opcase can be -// the root of the computation. -bool CanBeRoot(const OpRequest::OpCase& op_case) { - switch (op_case) { - case OpRequest::kTraceRequest: - case OpRequest::kSendRequest: - case OpRequest::kOutfeedRequest: - return false; - default: - return true; - } -} - -// Returns a pointer to the operation with the given data handle value in the -// given SessionComputation. -StatusOr LookUpRequest( - int64 handle_value, const SessionComputation& session_computation) { - if (session_computation.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation.requests().at(handle_value); -} - -// Returns the OperationRequest corresponding to the root (result) of the -// session computation. -StatusOr GetRoot( - VersionedComputationHandle::Version version, - const SessionComputation& session_computation) { - TF_RET_CHECK(version > 0); - // Not all instructions can be roots. Walk backwards from the operation - // indicated by this version until a valid root is found. - const OperationRequest* root_request = nullptr; - while (version > 0) { - TF_ASSIGN_OR_RETURN(root_request, - LookUpRequest(version, session_computation)); - if (CanBeRoot(root_request->request().op_case())) { - break; - } - version--; - } - if (version == 0) { - return InternalError("Computation contains no root operation"); - } - return root_request; -} - -} // namespace - -StatusOr> -UserComputation::ComputeProgramShape( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - if (program_shape_ == nullptr || program_shape_version_ != version) { - // ProgramShape has not been computed yet, or is for different - // version. Compute it now. - TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version)); - - auto program_shape = MakeUnique(); - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - int64 param_no = parameter_request.parameter(); - // Parameters may be out of order so expand ProgramShape parameters - // until it is at least large enough to hold the current parameter - // number. - while (program_shape->parameters_size() <= param_no) { - program_shape->add_parameters(); - program_shape->add_parameter_names(); - } - *program_shape->mutable_parameters(param_no) = request.output_shape(); - *program_shape->mutable_parameter_names(param_no) = - parameter_request.name(); - } - } - - // The root determines the output shape. - TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version, session_computation_)); - *program_shape->mutable_result() = root_request->output_shape(); - if (ShapeUtil::IsOpaque(program_shape->result())) { - return Unimplemented("Computation results cannot be opaque"); - } - - program_shape_ = std::move(program_shape); - program_shape_version_ = version; - } - - return program_shape_; -} - -namespace { - -// A visitor which checks whether an operation is pure functional meaning that -// it doesn't depend on any parameter with an index higher then num_parameters. -// The visitor walks the computation starting at a given operation and sets -// is_functional to false iff a parameter or RNG operation is encountered. -void PureFunctionalVisitor(const SessionComputation& session_computation, - const ComputationDataHandle& handle, - int64 num_parameters, std::set* visited, - bool* is_functional) { - if (visited->count(handle.handle()) != 0 || !*is_functional) { - return; - } - - const OperationRequest& request = - session_computation.requests().at(handle.handle()); - switch (request.request().op_case()) { - case OpRequest::kRngRequest: - *is_functional = false; - break; - - case OpRequest::kConstantRequest: - break; - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - PureFunctionalVisitor(session_computation, - get_tuple_element_request.operand(), num_parameters, - visited, is_functional); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - PureFunctionalVisitor(session_computation, slice_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - PureFunctionalVisitor(session_computation, - dynamic_slice_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_slice_request.start_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.update(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.start_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - PureFunctionalVisitor(session_computation, convolve_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, convolve_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - PureFunctionalVisitor(session_computation, fft_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - // TODO(b/33009255): Implmement constant folding for cross replica sum. - *is_functional = false; - break; - } - - case OpRequest::kInfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kOutfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kHostComputeRequest: { - *is_functional = false; - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself, - // so we conservatively say that computations containing the Call op - // cannot be constant. We cannot set is_functional=false in other similar - // cases since we're already relying on IsConstant to return true. - *is_functional = false; - break; - } - - case OpRequest::kCustomCallRequest: { - *is_functional = false; - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - PureFunctionalVisitor(session_computation, dot_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, dot_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kSendRequest: { - *is_functional = false; - break; - } - - case OpRequest::kRecvRequest: { - *is_functional = false; - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - PureFunctionalVisitor(session_computation, reduce_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, reduce_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - PureFunctionalVisitor(session_computation, - reduce_window_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - reduce_window_request.init_value(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.source(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the select and scatter - // computations themselves. - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - PureFunctionalVisitor(session_computation, broadcast_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - PureFunctionalVisitor(session_computation, reshape_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - PureFunctionalVisitor(session_computation, reverse_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - PureFunctionalVisitor(session_computation, pad_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, pad_request.padding_value(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - if (parameter_request.parameter() >= num_parameters) { - *is_functional = false; - } - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - PureFunctionalVisitor(session_computation, while_request.init(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the condition and body - // computations themselves. - *is_functional = false; - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - PureFunctionalVisitor(session_computation, - conditional_request.predicate(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.true_operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.false_operand(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the true and false computations - // themselves. - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - PureFunctionalVisitor(session_computation, transpose_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - PureFunctionalVisitor(session_computation, unary_op_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.offset(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.scale(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.offset(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.mean(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.variance(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.variance(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.grad_output(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - PureFunctionalVisitor(session_computation, binary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, binary_op_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kGatherRequest: { - PureFunctionalVisitor(session_computation, - request.request().gather_request().input(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - request.request().gather_request().gather_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - if (!*is_functional) { - VLOG(1) << "Non-functional: " << request.request().DebugString(); - } - visited->insert(handle.handle()); -} - -} // namespace - -StatusOr UserComputation::IsConstant(const ComputationDataHandle& handle, - int64 num_parameters) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the handle is valid. - auto operation_status = LookUpRequest(handle); - if (!operation_status.ok()) { - return operation_status.status(); - } - - bool is_constant = true; - std::set visited; - PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, - &is_constant); - - return is_constant; -} - -std::vector -UserComputation::GetEmbeddedComputations( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(1) - << "GetEmbeddedComputations(" << name() << " " - << VersionedComputationHandle{session_computation_.computation_handle(), - version} - << ")"; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - std::vector computations; - std::vector sorted_handles; - for (const auto& handle_request : session_computation_.requests()) { - sorted_handles.push_back(handle_request.first); - } - std::sort(sorted_handles.begin(), sorted_handles.end()); - for (int64 handle : sorted_handles) { - const auto& handle_request = session_computation_.requests().find(handle); - CHECK(handle_request != session_computation_.requests().end()); - int64 handle_value = handle_request->first; - if (handle_value <= version) { - const OperationRequest& request = handle_request->second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const CallRequest& call_request = request.request().call_request(); - const VersionedComputationHandle versioned_handle = { - call_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kMapRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const MapRequest& map_request = request.request().map_request(); - const VersionedComputationHandle versioned_handle = { - map_request.to_apply(), request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceRequest& reduce_request = - request.request().reduce_request(); - const VersionedComputationHandle versioned_handle = { - reduce_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceWindowRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - const VersionedComputationHandle versioned_handle = { - reduce_window_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - const VersionedComputationHandle select_versioned_handle = { - select_and_scatter_request.select(), - request.embedded_computation_versions(0)}; - computations.push_back(select_versioned_handle); - const VersionedComputationHandle scatter_versioned_handle = { - select_and_scatter_request.scatter(), - request.embedded_computation_versions(1)}; - computations.push_back(scatter_versioned_handle); - break; - } - - case OpRequest::kWhileRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const WhileRequest& while_request = request.request().while_request(); - const VersionedComputationHandle condition_versioned_handle = { - while_request.condition(), - request.embedded_computation_versions(0)}; - computations.push_back(condition_versioned_handle); - const VersionedComputationHandle body_versioned_handle = { - while_request.body(), request.embedded_computation_versions(1)}; - computations.push_back(body_versioned_handle); - break; - } - - case OpRequest::kConditionalRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - const VersionedComputationHandle true_computation_versioned_handle = { - conditional_request.true_computation(), - request.embedded_computation_versions(0)}; - computations.push_back(true_computation_versioned_handle); - const VersionedComputationHandle false_computation_versioned_handle = - {conditional_request.false_computation(), - request.embedded_computation_versions(1)}; - computations.push_back(false_computation_versioned_handle); - break; - } - - default: - // No embedded computation. - break; - } - } - } - VLOG(2) << "Embedded computations: " - << tensorflow::str_util::Join( - computations, ", ", - [](string* out, const VersionedComputationHandle& h) { - out->append(h.ToString()); - }); - return computations; -} - -StatusOr -UserComputation::LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const { - tensorflow::mutex_lock lock(mutex_); - return LookUpRequest(handle); -} - -tensorflow::gtl::optional UserComputation::ParameterMetadata( - int parameter_number) const { - tensorflow::mutex_lock lock(mutex_); - auto it = parameters_.find(parameter_number); - if (it == parameters_.end()) { - return tensorflow::gtl::nullopt; - } - OperationRequest* op = it->second; - return &op->request().metadata(); -} - -Status UserComputation::RemapEmbeddedComputations( - const std::map& old_to_new) { - auto update = [&old_to_new](ComputationHandle* to_update) -> Status { - int64 old = to_update->handle(); - auto it = old_to_new.find(old); - if (it == old_to_new.end()) { - string mapping = tensorflow::str_util::Join( - old_to_new, ", ", - [](string* out, std::pair element) { - tensorflow::strings::Appendf(out, "%lld:%lld", element.first, - element.second.handle()); - }); - return NotFound( - "could not find referenced (old) computation handle in mapping: " - "%lld; mapping: {%s}", - old, mapping.c_str()); - } - VLOG(2) << "remapping " << old << " to " << it->second.handle(); - *to_update = it->second; - return Status::OK(); - }; - TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle())); - for (auto& handle_request : *session_computation_.mutable_requests()) { - OperationRequest& request = handle_request.second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - CallRequest* call_request = - request.mutable_request()->mutable_call_request(); - TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply())); - break; - } - case OpRequest::kMapRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - MapRequest* map_request = - request.mutable_request()->mutable_map_request(); - TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceRequest* reduce_request = - request.mutable_request()->mutable_reduce_request(); - TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceWindowRequest* reduce_window_request = - request.mutable_request()->mutable_reduce_window_request(); - TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply())); - break; - } - case OpRequest::kSelectAndScatterRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - SelectAndScatterRequest* select_and_scatter_request = - request.mutable_request()->mutable_select_and_scatter_request(); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_select())); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_scatter())); - break; - } - case OpRequest::kWhileRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - WhileRequest* while_request = - request.mutable_request()->mutable_while_request(); - TF_RETURN_IF_ERROR(update(while_request->mutable_condition())); - TF_RETURN_IF_ERROR(update(while_request->mutable_body())); - break; - } - case OpRequest::kConditionalRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - ConditionalRequest* conditional_request = - request.mutable_request()->mutable_conditional_request(); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_true_computation())); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_false_computation())); - break; - } - default: - // No embedded computation. - TF_RET_CHECK(0 == request.embedded_computation_versions_size()); - break; - } - } - return Status::OK(); -} - -SessionComputation UserComputation::CloneSessionComputation( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - SessionComputation result = session_computation_; - // Erase all the requests that exceed the version specified. - // There's no lower_bound method on tensorflow::protobuf::Map so we iterate - // all the elements. - auto it = result.mutable_requests()->begin(); - while (it != result.mutable_requests()->end()) { - if (it->first > version) { - it = result.mutable_requests()->erase(it); - } else { - ++it; - } - } - return result; -} - -StatusOr UserComputation::LookUpRequest( - const ComputationDataHandle& handle) const { - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation_.requests().at(handle_value); -} - -Status UserComputation::CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const { - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - // Determine number of parameter inputs at the given version. - std::map parameter_requests; - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - // Duplicate parameters should be checked when parameter requests are - // added. - TF_RET_CHECK(0 == - parameter_requests.count(parameter_request.parameter())); - parameter_requests[parameter_request.parameter()] = ¶meter_request; - } - } - - for (int64 i = 0; i < parameter_requests.size(); ++i) { - auto it = parameter_requests.find(i); - if (it == parameter_requests.end()) { - return FailedPrecondition( - "computation %s does not have all its parameters populated " - "sequentially, missing parameter %lld", - name_.c_str(), i); - } - } - - return Status::OK(); -} - -namespace { - -// Helper class which builds an HLO computation from a SessionComputation. To -// construct the HLO computation, the SessionComputation graph is walked in -// DFS order lowering each OperationRequest to an HLO instruction. -class ComputationLowerer { - public: - static StatusOr> Lower( - const string& computation_name, - const SessionComputation& session_computation, - VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver, - const DebugOptions& debug_options, - bool include_unreachable_instructions) { - ComputationLowerer lowerer(computation_name, session_computation, version, - std::move(hlo_resolver), debug_options, - include_unreachable_instructions); - return lowerer.Lower(); - } - - private: - ComputationLowerer(const string& computation_name, - const SessionComputation& session_computation, - VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver, - const DebugOptions& debug_options, - bool include_unreachable_instructions) - : hlo_builder_(computation_name), - session_computation_(session_computation), - version_(version), - hlo_resolver_(std::move(hlo_resolver)), - debug_options_(debug_options), - include_unreachable_instructions_(include_unreachable_instructions) {} - - // Build an HLO computation from the SessionComputation at the given - // version. - StatusOr> Lower(); - - private: - // Traverses the computation 'root' using a DFS, calling 'visit' in postorder. - void TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit); - - // DFS visitor of the UserComputation operations which lowers the operations - // to HLO instructions. - void Visit(const ComputationDataHandle& handle, - std::unordered_map* instructions); - - // Resolves a ComputationHandle and Version to a previously lowered - // HloComputation using the hlo_resolver_ function. - HloComputation* ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version); - - // This function takes an input value which is being implicitly broadcast into - // an output shape and figures out the right kBroadcast instruction(s) - // necessary to replicate the implicit broadcast semantics explicitly. - HloInstruction* ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape); - - HloComputation::Builder hlo_builder_; - const SessionComputation& session_computation_; - const VersionedComputationHandle::Version version_; - const UserComputation::HloComputationResolver hlo_resolver_; - const DebugOptions& debug_options_; - const bool include_unreachable_instructions_; -}; - -// Calls 'apply' on each operand of 'request'. -static void ForEachOperand( - const OperationRequest& request, - const std::function& apply) { - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - for (const ComputationDataHandle& param : rng_request.parameter()) { - apply(param); - } - break; - } - - case OpRequest::kConstantRequest: - break; - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - apply(get_tuple_element_request.operand()); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - apply(slice_request.operand()); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - apply(dynamic_slice_request.operand()); - apply(dynamic_slice_request.start_indices()); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - apply(dynamic_update_slice_request.operand()); - apply(dynamic_update_slice_request.update()); - apply(dynamic_update_slice_request.start_indices()); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - apply(convolve_request.lhs()); - apply(convolve_request.rhs()); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - apply(fft_request.operand()); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - - apply(batch_norm_training_request.operand()); - apply(batch_norm_training_request.scale()); - apply(batch_norm_training_request.offset()); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - - apply(batch_norm_inference_request.operand()); - apply(batch_norm_inference_request.scale()); - apply(batch_norm_inference_request.offset()); - apply(batch_norm_inference_request.mean()); - apply(batch_norm_inference_request.variance()); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - apply(batch_norm_grad_request.operand()); - apply(batch_norm_grad_request.scale()); - apply(batch_norm_grad_request.mean()); - apply(batch_norm_grad_request.variance()); - apply(batch_norm_grad_request.grad_output()); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - apply(cross_replica_sum_request.operand()); - break; - } - - case OpRequest::kInfeedRequest: - break; - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - apply(outfeed_request.operand()); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - apply(reduce_request.operand()); - apply(reduce_request.init_value()); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - apply(reduce_window_request.operand()); - apply(reduce_window_request.init_value()); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - apply(select_and_scatter_request.operand()); - apply(select_and_scatter_request.source()); - apply(select_and_scatter_request.init_value()); - - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - apply(broadcast_request.operand()); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - apply(reshape_request.operand()); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - apply(transpose_request.operand()); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - apply(reverse_request.operand()); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - apply(pad_request.operand()); - apply(pad_request.padding_value()); - break; - } - - case OpRequest::kRecvRequest: - case OpRequest::kParameterRequest: - break; - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - apply(while_request.init()); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - apply(conditional_request.predicate()); - apply(conditional_request.true_operand()); - apply(conditional_request.false_operand()); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - apply(ternary_op_request.lhs()); - apply(ternary_op_request.rhs()); - apply(ternary_op_request.ehs()); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - for (const ComputationDataHandle& operand : cc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& hc_request = - request.request().host_compute_request(); - for (const ComputationDataHandle& operand : hc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - apply(dot_request.rhs()); - apply(dot_request.lhs()); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - apply(unary_op_request.operand()); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - apply(binary_op_request.rhs()); - apply(binary_op_request.lhs()); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - apply(reduce_precision_request.operand()); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - apply(trace_request.operand()); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - apply(send_request.operand()); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - apply(gather_request.input()); - apply(gather_request.gather_indices()); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } -} - -void ComputationLowerer::TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit) { - // Stack containing {handle, enter} pairs. The 'enter' value describes whether - // we are entering or leaving 'handle'. - std::stack> work; - work.push({root, true}); - while (!work.empty()) { - ComputationDataHandle handle; - bool enter; - std::tie(handle, enter) = work.top(); - work.pop(); - - if (enter) { - // We are entering 'handle'. The first time we enter 'handle', we add it - // to 'visited' with a nullptr value. If 'handle' is already in 'visited', - // we do not visit it again. This algorithm only uses the presence of - // a handle in 'visited', but we use a map so we can use the same data - // structure to store the HloInstruction outputs. - if (visited->emplace(handle.handle(), nullptr).second) { - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - // Push the corresponding 'leave' action onto the stack, followed by - // the operands. - work.push({handle, false}); - ForEachOperand(request, [&work](const ComputationDataHandle& child) { - work.push({child, true}); - }); - } - } else { - // We are leaving 'handle'. We have visited the operands of 'handle', and - // now can visit the 'handle' itself. - visit(handle); - } - } -} - -StatusOr> ComputationLowerer::Lower() { - // Map from ComputationDataHandle to HLO instruction. Serves as a record of - // which operations have been visited as well as a cache for looking up - // ComputationDataHandles as HloInstructions. - std::unordered_map instructions; - - TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version_, session_computation_)); - - auto visit = [&](const ComputationDataHandle& handle) { - Visit(handle, &instructions); - }; - TraversePostorder(root_request->output_handle(), &instructions, visit); - HloInstruction* hlo_root = - instructions.at(root_request->output_handle().handle()); - - if (include_unreachable_instructions_) { - // Iterate through all computation data handles, and visit any unvisited - // operations. - for (int64 request_num = 1; request_num <= version_; ++request_num) { - TF_ASSIGN_OR_RETURN(const OperationRequest* request, - LookUpRequest(request_num, session_computation_)); - TraversePostorder(request->output_handle(), &instructions, visit); - } - } - - return hlo_builder_.Build(hlo_root); -} - -HloComputation* ComputationLowerer::ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version) { - const VersionedComputationHandle checked_handle = {handle, version}; - return hlo_resolver_(checked_handle); -} - -HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape) { - auto fadd = [this](std::unique_ptr x) { - return hlo_builder_.AddInstruction(std::move(x)); - }; - return fadd( - HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd)); -} - -void ComputationLowerer::Visit( - const ComputationDataHandle& handle, - std::unordered_map* instructions) { - CHECK_LE(handle.handle(), version_); - CHECK(instructions->at(handle.handle()) == nullptr); - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - auto add_instruction = [&](std::unique_ptr instruction) { - HloInstruction* hlo_instruction = - hlo_builder_.AddInstruction(std::move(instruction)); - hlo_instruction->set_metadata(request.request().metadata()); - if (request.request().has_sharding()) { - OpSharding op_sharding = request.request().sharding(); - hlo_instruction->set_sharding( - HloSharding::FromProto(op_sharding).ValueOrDie()); - } - return hlo_instruction; - }; - auto lookup_instruction = [&](const ComputationDataHandle& handle) { - return instructions->at(handle.handle()); - }; - HloInstruction* hlo_instruction; - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - std::vector parameters; - for (const ComputationDataHandle& param : rng_request.parameter()) { - parameters.push_back(lookup_instruction(param)); - } - hlo_instruction = add_instruction(HloInstruction::CreateRng( - request.output_shape(), rng_request.distribution(), parameters)); - break; - } - - case OpRequest::kConstantRequest: { - const ConstantRequest& constant_request = - request.request().constant_request(); - hlo_instruction = add_instruction(HloInstruction::CreateConstant( - Literal::CreateFromProto(constant_request.literal()) - .ConsumeValueOrDie())); - break; - } - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - HloInstruction* operand = - lookup_instruction(get_tuple_element_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement( - request.output_shape(), operand, get_tuple_element_request.index())); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - HloInstruction* operand = lookup_instruction(slice_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateSlice( - request.output_shape(), operand, - AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_slice_request.operand()); - HloInstruction* start_indices = - lookup_instruction(dynamic_slice_request.start_indices()); - - hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice( - request.output_shape(), operand, start_indices, - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_update_slice_request.operand()); - HloInstruction* update = - lookup_instruction(dynamic_update_slice_request.update()); - HloInstruction* start_indices = - lookup_instruction(dynamic_update_slice_request.start_indices()); - hlo_instruction = - add_instruction(HloInstruction::CreateDynamicUpdateSlice( - request.output_shape(), operand, update, start_indices)); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - hlo_instruction = add_instruction(HloInstruction::CreateConcatenate( - request.output_shape(), operands, concatenate_request.dimension())); - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - HloInstruction* lhs = lookup_instruction(convolve_request.lhs()); - HloInstruction* rhs = lookup_instruction(convolve_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateConvolve( - request.output_shape(), lhs, rhs, convolve_request.window(), - convolve_request.dimension_numbers())); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - HloInstruction* operand = lookup_instruction(fft_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateFft( - request.output_shape(), operand, fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - HloInstruction* lhs = lookup_instruction(dot_request.lhs()); - HloInstruction* rhs = lookup_instruction(dot_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateDot( - request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - HloInstruction* operand = - lookup_instruction(cross_replica_sum_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), {operand})); - break; - } - - case OpRequest::kInfeedRequest: { - const InfeedRequest& infeed_request = request.request().infeed_request(); - hlo_instruction = add_instruction(HloInstruction::CreateInfeed( - request.output_shape(), infeed_request.config())); - break; - } - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - HloInstruction* operand = lookup_instruction(outfeed_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateOutfeed( - outfeed_request.shape(), operand, outfeed_request.outfeed_config())); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - std::vector operands; - for (const ComputationDataHandle& handle : map_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version map_version = - request.embedded_computation_versions(0); - HloComputation* map_computation = - ResolveComputation(map_request.to_apply(), map_version); - hlo_instruction = add_instruction(HloInstruction::CreateMap( - request.output_shape(), operands, map_computation)); - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - HloInstruction* operand = lookup_instruction(reduce_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_version = - request.embedded_computation_versions(0); - HloComputation* reduce_computation = - ResolveComputation(reduce_request.to_apply(), reduce_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduce( - request.output_shape(), operand, init_value, - AsInt64Slice(reduce_request.dimensions()), reduce_computation)); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - HloInstruction* operand = - lookup_instruction(reduce_window_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_window_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_window_version = - request.embedded_computation_versions(0); - HloComputation* reduce_window_computation = ResolveComputation( - reduce_window_request.to_apply(), reduce_window_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow( - request.output_shape(), operand, init_value, - reduce_window_request.window(), reduce_window_computation)); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - HloInstruction* operand = - lookup_instruction(select_and_scatter_request.operand()); - HloInstruction* source = - lookup_instruction(select_and_scatter_request.source()); - HloInstruction* init_value = - lookup_instruction(select_and_scatter_request.init_value()); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version select_version = - request.embedded_computation_versions(0); - VersionedComputationHandle::Version scatter_version = - request.embedded_computation_versions(1); - HloComputation* select_computation = ResolveComputation( - select_and_scatter_request.select(), select_version); - HloComputation* scatter_computation = ResolveComputation( - select_and_scatter_request.scatter(), scatter_version); - hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter( - request.output_shape(), operand, select_computation, - select_and_scatter_request.window(), source, init_value, - scatter_computation)); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_training_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_training_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_training_request.offset()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining( - request.output_shape(), operand, scale, offset, - batch_norm_training_request.epsilon(), - batch_norm_training_request.feature_index())); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_inference_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_inference_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_inference_request.offset()); - HloInstruction* mean = - lookup_instruction(batch_norm_inference_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_inference_request.variance()); - - hlo_instruction = - add_instruction(HloInstruction::CreateBatchNormInference( - request.output_shape(), operand, scale, offset, mean, variance, - batch_norm_inference_request.epsilon(), - batch_norm_inference_request.feature_index())); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - HloInstruction* operand = - lookup_instruction(batch_norm_grad_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_grad_request.scale()); - HloInstruction* mean = lookup_instruction(batch_norm_grad_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_grad_request.variance()); - HloInstruction* grad_output = - lookup_instruction(batch_norm_grad_request.grad_output()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormGrad( - request.output_shape(), operand, scale, mean, variance, grad_output, - batch_norm_grad_request.epsilon(), - batch_norm_grad_request.feature_index())); - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - HloInstruction* operand = lookup_instruction(broadcast_request.operand()); - std::vector broadcast_dimensions; - // The client-level broadcast instruction just appends dimensions on the - // left (adds lowest numbered dimensions). The HLO broadcast op is more - // flexible and can add new dimensions anywhere. The broadcast_dimensions - // maps operand dimensions to dimensions in the broadcast output, so - // to append dimensions on the left the broadcast_dimensions should just - // be the n highest dimension numbers of the output shape where n is - // the number of input dimensions. - broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape())); - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { - broadcast_dimensions.push_back(i + - ShapeUtil::Rank(request.output_shape()) - - ShapeUtil::Rank(operand->shape())); - } - hlo_instruction = add_instruction(HloInstruction::CreateBroadcast( - request.output_shape(), operand, broadcast_dimensions)); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - HloInstruction* operand = lookup_instruction(reshape_request.operand()); - HloInstruction* transposed; - if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) { - transposed = operand; - } else { - transposed = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(reshape_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(reshape_request.dimensions()))); - } - hlo_instruction = add_instruction( - HloInstruction::CreateReshape(request.output_shape(), transposed)); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - HloInstruction* operand = lookup_instruction(transpose_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(transpose_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(transpose_request.dimensions()))); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - HloInstruction* operand = lookup_instruction(reverse_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateReverse( - request.output_shape(), operand, - AsInt64Slice(reverse_request.dimensions()))); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - HloInstruction* operand = lookup_instruction(pad_request.operand()); - HloInstruction* padding_value = - lookup_instruction(pad_request.padding_value()); - hlo_instruction = add_instruction(HloInstruction::CreatePad( - request.output_shape(), operand, padding_value, - pad_request.padding_config())); - break; - } - - case OpRequest::kRecvRequest: { - const RecvRequest& recv_request = request.request().recv_request(); - HloInstruction* recv = add_instruction(HloInstruction::CreateRecv( - request.output_shape(), recv_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv)); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - hlo_instruction = add_instruction(HloInstruction::CreateParameter( - parameter_request.parameter(), request.output_shape(), - parameter_request.name())); - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateConvert(request.output_shape(), operand)); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert( - request.output_shape(), operand)); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version condition_version = - request.embedded_computation_versions(0); - HloComputation* condition = - ResolveComputation(while_request.condition(), condition_version); - VersionedComputationHandle::Version body_version = - request.embedded_computation_versions(1); - HloComputation* body = - ResolveComputation(while_request.body(), body_version); - HloInstruction* init = lookup_instruction(while_request.init()); - hlo_instruction = add_instruction(HloInstruction::CreateWhile( - request.output_shape(), condition, body, init)); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version true_computation_version = - request.embedded_computation_versions(0); - HloComputation* true_computation = ResolveComputation( - conditional_request.true_computation(), true_computation_version); - VersionedComputationHandle::Version false_computation_version = - request.embedded_computation_versions(1); - HloComputation* false_computation = ResolveComputation( - conditional_request.false_computation(), false_computation_version); - HloInstruction* predicate = - lookup_instruction(conditional_request.predicate()); - HloInstruction* true_operand = - lookup_instruction(conditional_request.true_operand()); - HloInstruction* false_operand = - lookup_instruction(conditional_request.false_operand()); - hlo_instruction = add_instruction(HloInstruction::CreateConditional( - request.output_shape(), predicate, true_operand, true_computation, - false_operand, false_computation)); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - HloInstruction* lhs = lookup_instruction(ternary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); - HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); - auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - !ShapeUtil::IsTuple(request.output_shape())) { - if (!ShapeUtil::IsTuple(lhs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. - lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::IsTuple(rhs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - - if (!ShapeUtil::IsTuple(ehs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { - ehs = - ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); - } - } - - hlo_instruction = add_instruction(HloInstruction::CreateTernary( - request.output_shape(), hlo_opcode, lhs, rhs, ehs)); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - auto hlo_opcode = - VariadicOperationToHloOpcode(variadic_op_request.varop()); - hlo_instruction = add_instruction(HloInstruction::CreateVariadic( - request.output_shape(), hlo_opcode, operands)); - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - std::vector operands; - for (const ComputationDataHandle& handle : call_request.operands()) { - operands.push_back(lookup_instruction(handle)); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version call_version = - request.embedded_computation_versions(0); - HloComputation* call_computation = - ResolveComputation(call_request.to_apply(), call_version); - hlo_instruction = add_instruction(HloInstruction::CreateCall( - request.output_shape(), operands, call_computation)); - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - std::vector operands; - for (const ComputationDataHandle& operand : cc_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - hlo_instruction = add_instruction(HloInstruction::CreateCustomCall( - cc_request.shape(), operands, cc_request.call_target_name())); - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& host_compute_request = - request.request().host_compute_request(); - std::vector operands; - for (const ComputationDataHandle& operand : - host_compute_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - auto output_shape = host_compute_request.shape(); - auto channel_name = host_compute_request.channel_name(); - auto cost_estimate_ns = host_compute_request.cost_estimate_ns(); - hlo_instruction = add_instruction(HloInstruction::CreateHostCompute( - output_shape, operands, channel_name, cost_estimate_ns)); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - HloInstruction* operand = lookup_instruction(unary_op_request.operand()); - auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); - hlo_instruction = add_instruction(HloInstruction::CreateUnary( - request.output_shape(), hlo_opcode, operand)); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - HloInstruction* lhs = lookup_instruction(binary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(binary_op_request.rhs()); - auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop()); - if (binary_op_request.broadcast_dimensions_size() > 0 && - ShapeUtil::Rank(lhs->shape()) != ShapeUtil::Rank(rhs->shape())) { - // Emit a broadcast instruction to perform the "broadcast in dimension" - // operation. - HloInstruction* operand_to_broadcast = - ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs - : rhs; - CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()), - binary_op_request.broadcast_dimensions().size()); - - // Construct the bounds of the shape of the kBroadcast instruction - // responsible for the in-dimension broadcast. - std::vector output_dimensions; - for (int64 size : request.output_shape().dimensions()) { - output_dimensions.push_back(size); - } - for (int64 operand_dim = 0; - operand_dim < ShapeUtil::Rank(operand_to_broadcast->shape()); - ++operand_dim) { - int64 output_dim = - binary_op_request.broadcast_dimensions()[operand_dim]; - output_dimensions[output_dim] = - operand_to_broadcast->shape().dimensions(operand_dim); - } - - Shape broadcast_shape = ShapeUtil::MakeShape( - operand_to_broadcast->shape().element_type(), output_dimensions); - - // The broadcast semantics of a client-level binary op broadcast is - // identical to the HLO broadcast semantics so the broadcast_dimensions - // field can just be passed to the instruction builder. - HloInstruction* broadcasted_operand = - add_instruction(HloInstruction::CreateBroadcast( - broadcast_shape, operand_to_broadcast, - AsInt64Slice(binary_op_request.broadcast_dimensions()))); - - lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; - rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; - } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { - if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. - lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - } - hlo_instruction = add_instruction(HloInstruction::CreateBinary( - request.output_shape(), hlo_opcode, lhs, rhs)); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - HloInstruction* operand = - lookup_instruction(reduce_precision_request.operand()); - auto exponent_bits = reduce_precision_request.exponent_bits(); - auto mantissa_bits = reduce_precision_request.mantissa_bits(); - hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision( - request.output_shape(), operand, exponent_bits, mantissa_bits)); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - HloInstruction* operand = lookup_instruction(trace_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateTrace(trace_request.tag(), operand)); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - HloInstruction* operand = lookup_instruction(send_request.operand()); - HloInstruction* send = add_instruction(HloInstruction::CreateSend( - operand, send_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send)); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - HloInstruction* input_operand = - lookup_instruction(gather_request.input()); - HloInstruction* gather_indices_operand = - lookup_instruction(gather_request.gather_indices()); - std::vector window_bounds; - c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds)); - hlo_instruction = add_instruction(HloInstruction::CreateGather( - request.output_shape(), input_operand, gather_indices_operand, - gather_request.dimension_numbers(), window_bounds)); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - (*instructions)[handle.handle()] = hlo_instruction; -} // NOLINT(readability/fn_size) - -} // namespace - -StatusOr> UserComputation::BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(2) << "Building HloComputation from UserComputation " << name_ - << " at version " << version; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - ComputationLowerer::Lower( - tensorflow::strings::StrCat(name(), ".v", version), - session_computation_, version, std::move(hlo_resolver), debug_options, - include_unreachable_instructions)); - - return std::move(hlo_computation); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h deleted file mode 100644 index 5544c868fe905c1ca7e6cab32738440add2e3b4f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation.h +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// A UserComputation is the built-up computation that users create via the -// XLA Service interface. -// -// The XLA service adds instructions to a user computation via this -// interface. The state of the computation is stored as a SessionComputation -// proto which holds a record of all operation-building requests received by the -// XLA service. -// -// UserComputations are lowered to HloComputations which are passed to the high -// level compiler interface. -class UserComputation { - public: - // Factory used when restoring a computation from serialized session - // computation (computation snapshot) data. Remaps any references to - // computation handle via the old_to_new mapping. - // - // An error will occur if the old_to_new mapping cannot resolve a reference to - // a computation that is present in session_computation. - static StatusOr> MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new); - - // Creates an empty computation with the given name and computation handle. - explicit UserComputation(const string& name, const ComputationHandle& handle); - - // Enqueues a parameter-retrieving instruction onto this user computation. - // Returns an error status if the parameter number is already registered with - // different values. - StatusOr AddParameterInstruction( - const ParameterRequest& parameter_request); - - // Enqueues a pad instruction onto this user computation. - StatusOr AddPadInstruction( - const PadRequest& pad_request); - - // Enqueues a tracing instruction onto this user computation. - // Returns an error status if the operand cannot be resolved. - Status AddTraceInstruction(const TraceRequest& trace_request); - - // Enqueues a random number generation instruction onto this user computation. - StatusOr AddRngInstruction( - const RngRequest& rng_request); - - // Enqueues a unary instruction onto this user computation. - // Returns an error status if the operand index is out of bounds. - StatusOr AddUnaryInstruction( - const UnaryOpRequest& unary_request); - - // Enqueues a batch norm training instruction onto this user computation. - StatusOr AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request); - - // Enqueues a batch norm inference instruction onto this user computation. - StatusOr AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request); - - // Enqueues a batch norm grad instruction onto this user computation. - StatusOr AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request); - - // Enqueues a binary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddBinaryInstruction( - const BinaryOpRequest& binary_request); - - // Enqueues a ternary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddTernaryInstruction( - const TernaryOpRequest& ternary_request); - - // Enqueues a variadic instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddVariadicInstruction( - const VariadicOpRequest& variadic_request); - - // Enqueues a constant instruction onto this user computation. - StatusOr AddConstantInstruction( - const ConstantRequest& constant_request); - - // Enqueues a get tuple element instruction onto this user computation. - StatusOr AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request); - - // Enqueues a map instruction onto this user computation. - StatusOr AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation); - - // Enqueues a reduce-precision instruction onto this user computation. - StatusOr AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request); - - // Enqueues a convolution instruction onto this user computation. - StatusOr AddConvolveInstruction( - const ConvolveRequest& convolve_request); - - // Enqueues an FFT instruction onto this user computation. - StatusOr AddFftInstruction( - const FftRequest& fft_request); - - // Enqueues a cross replica sum instruction onto this user computation. - StatusOr AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request); - - // Enqueues an infeed instruction onto this user computation. - StatusOr AddInfeedInstruction( - const InfeedRequest& infeed_request); - - // Enqueues an outfeed instruction onto this user computation. - StatusOr AddOutfeedInstruction( - const OutfeedRequest& outfeed_request); - - // Enqueues a host compute instruction onto this user computation. - StatusOr AddHostComputeInstruction( - const HostComputeRequest& host_compute_request); - - // Enqueues a call instruction onto this user computation. - StatusOr AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation); - - // Enqueues a custom call instruction onto this user computation. - StatusOr AddCustomCallInstruction( - const CustomCallRequest& custom_call_request); - - // Enqueues a dot instruction onto this user computation. - StatusOr AddDotInstruction( - const DotRequest& dot_request); - - // Enqueues a broadcast instruction onto this user computation. - StatusOr AddBroadcastInstruction( - const BroadcastRequest& broadcast_request); - - // Enqueues a reshape instruction onto this user computation. - StatusOr AddReshapeInstruction( - const ReshapeRequest& reshape_request); - - // Enqueues a transpose instruction onto this user computation. - StatusOr AddTransposeInstruction( - const TransposeRequest& transpose_request); - - // Enqueues a slice instruction onto this user computation. - StatusOr AddSliceInstruction( - const SliceRequest& slice_request); - - // Enqueues a dynamic slice instruction onto this user computation. - StatusOr AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request); - - // Enqueues a dynamic update slice instruction onto this user computation. - StatusOr AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request); - - // Enqueues a concatenate instruction onto this user computation. - StatusOr AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request); - - // Enqueues a convert instruction onto this user computation. - StatusOr AddConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a bitcast element instruction onto this user computation. - StatusOr AddBitcastConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a reduce instruction onto this user computation. - StatusOr AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation); - - // Enqueues a windowed reduce instruction onto this user computation. - StatusOr AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation); - - // Enqueues a select-and-scatter instruction onto this user - // computation. - StatusOr AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation); - - // Enqueues a reverse instruction onto this user computation. - StatusOr AddReverseInstruction( - const ReverseRequest& reverse_request); - - // Enqueues a while instruction onto this user computation. - StatusOr AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation); - - // Enqueues a conditional instruction on this user computation. - StatusOr AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation); - - // Enqueues a Send instruction onto this user computation. - StatusOr AddSendInstruction( - const SendRequest& send_request); - - // Enqueues a Recv instruction onto this user computation. - StatusOr AddRecvInstruction( - const RecvRequest& recv_request); - - // Enqueues a Gather instruction onto this user computation. - StatusOr AddGatherInstruction( - const GatherRequest& gather_request); - - // Returns the user-provided name of this user computation, which is provided - // via the XLA computation-building API. - const string& name() const { return name_; } - - // Subsequent executions of this computation will compute the value - // represented by handle, rather than the last expression enqueued - // on the computation. - Status SetReturnValue(const ComputationDataHandle& handle); - - // Return a versioned handle for this computation. - VersionedComputationHandle GetVersionedHandle() const; - - // Return a versioned handle for this computation with a version equal to the - // point at which given operation was added to the computation. - VersionedComputationHandle GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const; - - // Return a version value representing the current state of the - // computation. - VersionedComputationHandle::Version version() const; - - // Computes and returns the program shape for the user computation -- gathers - // parameters and result type into a single proto. A shared_ptr is used - // because the returned pointer refers to an internally cached value which may - // be discarded by the UserComputation object. This avoid unnecessary copies. - // - // If the parameter space is not dense (i.e. there are holes in the parameter - // numbers provided) then an error status is returned. - StatusOr> ComputeProgramShape( - VersionedComputationHandle::Version version) const; - - // Returns true if the given data handle does not depend on any parameter with - // index higher then num_parameters. That is, the value can be computed at - // compile time if we know the first num_parameters arguments. - StatusOr IsConstant(const ComputationDataHandle& handle, - int64 num_parameters); - - // Returns the output shape of the operation indicated by the given handle. - StatusOr GetShape(const ComputationDataHandle& handle); - - // Sets metadata on the Hlo instruction referenced by the given handle. - Status SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata); - - // Sets the device assignment on the Hlo instruction referenced by 'handle'. - Status SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding); - - // Builds a HLO computation from the UserComputation. The parameter "resolver" - // is a function which returns a pointer to the HloComputation corresponding - // to the given ComputationHandle at the given version. The resolver is used - // for operations, such as map, which call other computations and need a - // pointer to the called HloComputation to construct the respective HLO - // instructions. If include_unreachable_instructions is true, then - // instructions which are not reachable from the root are lowered into - // HloInstructions. - using HloComputationResolver = - std::function; - StatusOr> BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions = true) const; - - // Return a vector containing the embedded computations used by this - // UserComputation. Only embedded computations which are called directly by - // this UserComputation are included. That is, the transitive closure of - // embedded computations is not included. - std::vector GetEmbeddedComputations( - VersionedComputationHandle::Version version) const; - - // Returns the number of OperationRequest objects in this UserComputation. - // The 'version' of a computation is identical to the number of - // OperationRequests in the UserComputation. - int64 request_count(VersionedComputationHandle::Version version) const { - return version; - } - - // Returns a copy of the internal session state for this computation -- this - // is useful for serializing the guts of a user computation, though references - // to other handles (e.g. referred-to computations) must be handled with care - // in the serialization / de-serialization process. - SessionComputation CloneSessionComputation( - VersionedComputationHandle::Version version) const; - - // Warning: typically we don't want to look up computation data handles until - // the computation is finished being built, for consistency purposes. We - // expose this routine for error reporting purposes so that we can provide - // more meaningful error messages from the XLA service layer. - // - // Returns the operation request that the handle comes from. - StatusOr LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const; - - // Retrieves the parameter metadata for the given parameter number. - // - // If the parameter number is invalid for this computation, nullopt is - // returned. When the return value has_value(), nullptr will never be - // the held value. - tensorflow::gtl::optional ParameterMetadata( - int parameter_number) const; - - private: - // Warning: dangerous mutating operation that doesn't respect versioning. - // This is only used at initialization time when constructing from a - // SessionComputation a la MakeWithRemapping. - // - // Remaps references to old computations (with handle values in the keys of - // old_to_new) to the computation handle given in the values. This is useful - // when loading computations from snapshots, to finish initialization, before - // the user computation is released into the wild. - Status RemapEmbeddedComputations( - const std::map& old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Returns the OperationRequest corresponding to the given handle. - StatusOr LookUpRequest( - const ComputationDataHandle& handle) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Creates a new ComputationDataHandle with the next available handle value. - ComputationDataHandle CreateComputationDataHandle() - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Checks whether the parameter numbers of the parameter operations are - // contiguous starting from zero. Returns appropriate error status if not. - Status CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - VersionedComputationHandle GetVersionedHandleInternal() const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Name of the computation. - string name_; - - mutable tensorflow::mutex mutex_; - - // State of the computation as a record of all operation-building requests. - SessionComputation session_computation_ GUARDED_BY(mutex_); - - // Mapping from parameter number to operation request containing the - // respective ParameterRequest. - std::map parameters_ GUARDED_BY(mutex_); - - // The next ComputationDataHandle value to assign. Handle values are assigned - // sequentially. - int64 next_handle_value_ GUARDED_BY(mutex_); - - // If handle_to_return_.has_handle() then an Execution of this Computation - // will compute the value represented by handle_to_return_, otherwise it will - // compute the value of (next_handle_value_ - 1). - ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_); - - // Memoized ProgramShape and its version. A shared_ptr is used because - // references to this object are returned by ComputeProgramShape. - mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0; - mutable std::shared_ptr program_shape_ GUARDED_BY(mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(UserComputation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc deleted file mode 100644 index 2fa163953f638c0038e9f6bb11ce2a3742e0558c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ /dev/null @@ -1,340 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace op = xla::testing::opcode_matchers; - -namespace xla { -namespace { - -using UserComputationTest = ::testing::Test; - -TEST_F(UserComputationTest, SimpleComputation) { - const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); - const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2}); - - // Build a simple three operation computatation: - // - // %constant = Constant({123, 42}) - // %param = Param(0) - // %outfeed = Outfeed(%constant) - // - // Build the computation at two different versions and check invariants. - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest constant_request; - *constant_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle, - computation.AddConstantInstruction(constant_request)); - - ParameterRequest param_request; - *param_request.mutable_shape() = kScalarShape; - param_request.set_parameter(0); - param_request.set_name("param0"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle, - computation.AddParameterInstruction(param_request)); - OpMetadata metadata; - metadata.set_op_name("meta"); - TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); - - OutfeedRequest outfeed_request; - *outfeed_request.mutable_operand() = constant_handle; - *outfeed_request.mutable_shape() = kVectorShape; - outfeed_request.set_outfeed_config("abc"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, - computation.AddOutfeedInstruction(outfeed_request)); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - { - // Test the computation at the latest version. In this case, the most - // recently added operation is an outfeed. However, the outfeed is not the - // root because outfeeds cannot be the root of a computation. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Program shape should have a single scalar parameter and scalar - // result. The outfeed instruction should not affect the program shape. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(latest_version.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - DebugOptions())); - // There should be one HloInstruction per UserComputation operation. - EXPECT_EQ(3, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - - { - // Test the computation at the version right after the parameter instruction - // is added. - VersionedComputationHandle version_at_param = - computation.GetVersionedHandleAtOperation(param_handle); - - // Program shape should have a single scalar parameter, and scalar result. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(version_at_param.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // There should be two instructions, one for the constant and one for the - // parameter. The outfeed instruction should not be included. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(version_at_param.version, hlo_resolver, - DebugOptions())); - EXPECT_EQ(2, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - { - // Test the computation at the latest version, but lowered with - // include_unreachable_instructions set to false. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation( - latest_version.version, hlo_resolver, DebugOptions(), - /*include_unreachable_instructions=*/false)); - // There is only one reachable instruction, the parameter. - EXPECT_EQ(1, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(), - "meta"); - } -} - -TEST_F(UserComputationTest, EliminateScalarBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with scalar broadcast. - // - // %a = Constant({123, 42}) - // %b = Constant(1) - // %add = Add(%a, %b) - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest a_request; - *a_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddConstantInstruction(a_request)); - - ConstantRequest b_request; - *b_request.mutable_literal() = Literal::CreateR0(1.0f)->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddConstantInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - // The binary operation has implicit scalar broadcast, should be converted - // to an explicit broadcast intruction and a binary instruction. - EXPECT_EQ(4, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - LOG(INFO) << hlo_computation->root_instruction()->ToString(); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast || - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with degenerate broadcast. - // - // %a = Param({1, 2, 3}); - // %b = Param({1, 2, 1}); - // %add = Add(%a, %b, {}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - const int64 kDevice = 7; - OpSharding sharding; - sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding.add_tile_assignment_dimensions(1); - sharding.add_tile_assignment_devices(kDevice); - - TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // b a - // | | - // reshape | - // | | - // broadcast | - // \ / - // add - EXPECT_EQ(5, hlo_computation->instruction_count()); - ASSERT_THAT( - hlo_computation->root_instruction(), - op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter())))); - - const HloInstruction* broadcast = - hlo_computation->root_instruction()->operand(1); - EXPECT_TRUE(broadcast->has_sharding()); - - const HloInstruction* reshape = broadcast->operand(0); - EXPECT_TRUE(reshape->has_sharding()); -} - -TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with in-dim broadcast and degenerate broadcast. - // - // %a = Param({2, 3}); - // %b = Param({2, 1, 4}); - // %add = Add(%a, %b, {0, 1}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - add.add_broadcast_dimensions(0); - add.add_broadcast_dimensions(1); - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // The binary operation has in-dim broadcast and degenerate broadcast, should - // first do the in-dim broadcast then convert the degnerate broadcast into a - // reshape and a broadcast. - // - // b a - // | | - // broadcast reshape - // | | - // | broadcast - // \ / - // add - EXPECT_EQ(6, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast && - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h deleted file mode 100644 index 5732a56caffa31dde52dff5c2775f9fde0cacfbd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// A data structure encapsulating a ComputationHandle and version value of that -// computation. This object is used to unambiguously refer to a particular -// computation in the service. -struct VersionedComputationHandle { - // A version value unambiguously specifying the state of the computation at a - // particular point in time as it is being built. This value is the - // ComputationDataHandle of the current root instruction. - using Version = int64; - - ComputationHandle handle; - Version version; - - string ToString() const; - bool operator==(const VersionedComputationHandle& other) const { - return (handle.handle() == other.handle.handle()) && - (version == other.version); - } - bool operator<(const VersionedComputationHandle& other) const { - return ((handle.handle() < other.handle.handle()) || - ((handle.handle() == other.handle.handle()) && - (version < other.version))); - } -}; - -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 0d2288d8ea6ebb0ac4ac9468a211b161438fc5f1..393e75803888d8a642881c4d525b170d1e1180ba 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -55,7 +55,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -95,7 +95,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -136,7 +136,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -184,7 +184,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 321fdeb1ea313d2bc00b0210b422f36915f41453..09ddcffb22c2184262adf87d570870ec000c0e6f 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -98,14 +98,17 @@ static void CreateLoopInvariantCopy( // Returns true if `instruction` is worth hoisting only if it lets us hoist some // instruction using it. The rationale is that hoisting these instructions will // prevent simplification and fusion in the while body. -static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { +bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( + const HloInstruction& instruction) { switch (instruction.opcode()) { default: return false; + case HloOpcode::kConstant: + return !hoist_constants_; + case HloOpcode::kBitcast: case HloOpcode::kBroadcast: - case HloOpcode::kConstant: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: @@ -115,7 +118,8 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { } } -static StatusOr TryHoistingInvariantInstructionsFromWhileBody( +StatusOr +WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -161,12 +165,16 @@ static StatusOr TryHoistingInvariantInstructionsFromWhileBody( } } - if (unhoisted_invariant_instructions.empty()) { + if (unhoisted_invariant_instructions.empty() && !hoist_constants_) { // There are no obviously loop invariant elements in the state being // threaded through the while loop so give up. In theory this precondition // is too strong -- we could have code that e.g. permutes the elements in // the while state but uses a select to pick the same value on every // iteration. + // + // If we were asked to hoist constants, we need to scan the while body for + // constants even if we didn't find any loop invariant values in the while + // state tuple. return false; } @@ -243,6 +251,9 @@ static StatusOr TryHoistingInvariantInstructionsFromWhileBody( } StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; std::vector while_instrs; for (auto* comp : module->computations()) { @@ -270,6 +281,14 @@ StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { TryHoistingInvariantInstructionsFromWhileBody(while_instr)); changed |= result; } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + } + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8c4b765b0003c48cfacb9d28e7c8259ac0927d66..8e6cc8787576e4f041229da5cf8dd2b09194eb2a 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -27,12 +27,28 @@ namespace xla { class WhileLoopInvariantCodeMotion : public HloPassInterface { public: + // If `hoist_constants` is true then constants are always hoisted out of while + // loop bodies. Otherwise they are only hoisted out if they enable other + // non-trivial computations to be hoisted out. + // + // Setting `hoist_constants` to false can be help if LICM is run in the mid + // level HLO pipeline because hoisting constants out of while loop bodies can + // break optimizations like constant folding. + explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) + : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; tensorflow::StringPiece name() const override { return "while-loop-invariant-code-motion"; } StatusOr Run(HloModule* module) override; + + private: + bool NotWorthHoistingIndividually(const HloInstruction& instruction); + StatusOr TryHoistingInvariantInstructionsFromWhileBody( + HloInstruction* while_instr); + + bool hoist_constants_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 799340fda905fb7d40b19b4cb79bb0fcb5629fd3..23519e445ea8a5f578a54708f38059feef3280c0 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -247,7 +248,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + auto token_shape = ShapeUtil::MakeTokenShape(); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -257,25 +260,32 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* in_token = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(token_shape, param, 2)); + HloInstruction* out_token = builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_s32, gte_0, in_token, "")); builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_s32, gte_0, "")); - builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + HloInstruction::CreateTuple({gte_0, gte_1, out_token})); return module().AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); + auto* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_s32, "param")); + auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); + HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( while_shape, MakeAlwaysTrueComputation(while_shape, &module()), while_body, init_value)); - + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); module().AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, WhileLoopInvariantCodeMotion{}.Run(&module())); - EXPECT_FALSE(simplified_loop); + ASSERT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Outfeed())); @@ -286,7 +296,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // bitcast either. auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + auto token_shape = ShapeUtil::MakeTokenShape(); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -296,21 +308,29 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* in_token = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(token_shape, param, 2)); HloInstruction* bitcast_inst = builder.AddInstruction( HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction* out_token = builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, "")); builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, "")); - builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + HloInstruction::CreateTuple({gte_0, gte_1, out_token})); return module().AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); + auto* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_s32, "param")); + auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); + HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( while_shape, MakeAlwaysTrueComputation(while_shape, &module()), while_body, init_value)); + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); module().AddEntryComputation(builder.Build()); @@ -438,5 +458,77 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { EXPECT_FALSE(simplified_loop); } +const char* const kConstantHoistingTestCase = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2]{0}) parameter(0) + p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0 + const = f32[2]{0} constant({3, 4}) + add.0 = f32[2]{0} add(p_body.1, const) + ROOT root = (f32[2]{0}) tuple(add.0) +} + +condition { + p_cond = (f32[2]{0}) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2]{0} constant({1, 2}) + while_init = (f32[2]{0}) tuple(const_0) + ROOT while = (f32[2]{0}) while(while_init), condition=condition, body=body +} +)"; + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloComputation* while_body = module().GetComputationWithName("wide.body"); + ASSERT_NE(while_body, nullptr); + + // We expect the while body to be the equivalent of: + // + // wide.body { + // wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0) + // get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0 + // tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1) + // get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0 + // get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7) + // tuple.3 = (f32[2]{0}) tuple(add.1) + // get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0 + // get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8, + // get-tuple-element.9) + // } + + auto wide_param_1 = op::Parameter(0); + auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0); + auto tuple_1 = op::Tuple(get_tuple_element_1); + auto get_tuple_element_4 = op::GetTupleElement(tuple_1, 0); + auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1); + auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7); + auto tuple_3 = op::Tuple(add_1); + auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0); + auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1); + auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9); + + EXPECT_THAT(while_body->root_instruction(), tuple_4); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 619e87caa5b6d0f6ec3c3b1489b0d4f50ef29963..3c8304921661a486f283ea8c0009db16a81531a4 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -175,9 +175,11 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); + auto* token = while_body->AddInstruction(HloInstruction::CreateAfterAll({})); auto* send = while_body->AddInstruction(HloInstruction::CreateSend( while_body->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(true))), + token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -190,8 +192,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); + auto* token = while_body->AddInstruction(HloInstruction::CreateAfterAll({})); auto* recv = while_body->AddInstruction( - HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), + HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -208,8 +211,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); - while_body->AddInstruction( - HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + auto token = while_body->AddInstruction(HloInstruction::CreateAfterAll({})); + while_body->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index ed20b36292a7f24385603627d74fc72ba6b3b724..473eab2ea84eb8faf745cbe299bc80bcc1b62a35 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -117,9 +117,13 @@ WhileUtil::MakeInstructionsLiveIn( HloInstruction* new_while = containing_computation->AddInstruction( HloInstruction::CreateWhile(new_while_shape, new_while_condition, new_while_body, new_while_init)); - TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( - while_instr, TupleUtil::ExtractPrefix( - new_while, while_instr->shape().tuple_shapes_size()))); + + // We want to get rid of the old while instruction even if it has side + // effecting operations so we do a manual HloComputation::RemoveInstruction + // instead of relying on HloComputation::ReplaceInstruction. + TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()))); + TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr)); HloInstruction* while_body_param = new_while_body->parameter_instruction(0); std::vector live_in_instructions; diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 322d27b88cae60cb051f5fafdde70e2aafedbc1e..e67636d80f4b682fe1335eae535fb86105ac082b 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -38,17 +38,21 @@ class WhileUtil { }; // Replaces `while_instr` with a new while instruction that is equivalent to - // `while_instr`, except that it has all of the HLO instructions in + // `while_instr` except that it has all of the HLO instructions in // `instructions` as live-in, loop invariant values. These new live in values // are represented as new elements appended to the parameter of the while // loop, which must be of tuple shape. GetTupleElement instructions computing // each new live in value is returned in the `while_body_live_in_values` // vector. // - // Precondition: `while_instr` must have a tuple shaped state. + // Deletes `while_instr` after replacing it. // - // Every instruction in `instructions` must be contained in the computation - // that contains `while_instr`. + // Preconditions: + // + // `while_instr` must have a tuple shaped state. + // + // Every instruction in `instructions` must be contained in the computation + // that contains `while_instr`. static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, tensorflow::gtl::ArraySlice instructions); diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 974bc542a34d0af6d41ed29f36df87f4c164a360..2ccb919acf9c4e7c59a1ebaf36f42a6781068b5e 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -16,8 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace { @@ -49,7 +50,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); @@ -150,7 +151,7 @@ ENTRY main { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* while_body = module->GetComputationWithName("body"); @@ -163,5 +164,49 @@ ENTRY main { ASSERT_EQ(gte_list.size(), 1); EXPECT_EQ((*gte_list.begin())->name(), "gte.0"); } + +TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) { + const char* const hlo_string = R"( +HloModule WhileWithSideEffects + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + token = token[] after-all() + infeed = (pred[], token[]) infeed(token) + ROOT condition = pred[] get-tuple-element(infeed), index=0 +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + to_make_live_in = f32[100] parameter(1) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloComputation* main = module->GetComputationWithName("main"); + HloInstruction* while_instr = main->root_instruction(); + HloInstruction* to_make_live_in = main->parameter_instruction(1); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, + /*instructions=*/{to_make_live_in})); + + auto is_while = [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }; + EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index aa40b5cb264803097f52966d6f61f1f41b6b3017..44b0ec5cd4c1d406467007fcc530e919d602c438 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -32,11 +32,11 @@ StatusOr ZeroSizedHloElimination::Run(HloModule* module) { for (HloComputation* comp : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { if (instruction->HasSideEffect() || - ShapeUtil::IsTuple(instruction->shape())) { + !ShapeUtil::IsArray(instruction->shape())) { continue; } if (comp->IsRemovable(instruction) && - ShapeUtil::HasZeroElements(instruction->shape())) { + ShapeUtil::IsZeroElementArray(instruction->shape())) { TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConstant( Literal::CreateFromShape(instruction->shape())))); diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index f5331280ee9f252aa5717baab88f2c203be5c372..c6bd013a1aa59fe99f8f80197f04eb1e8a97cbb7 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -67,7 +67,9 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateParameter) { } TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateSideEffects) { - builder_.AddInstruction(HloInstruction::CreateSend(zero_sized_param_, 0)); + auto token = builder_.AddInstruction(HloInstruction::CreateAfterAll({})); + builder_.AddInstruction( + HloInstruction::CreateSend(zero_sized_param_, token, 0)); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 141347a792c23a2c542d7b564ab76c118409865d..14c35e7b84f07bebac33a9753ac26a8ee1418f1e 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -47,41 +47,22 @@ class ServiceInterface { virtual Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) = 0; - virtual Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) = 0; - - virtual Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) = 0; - virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, ExecuteResponse* result) = 0; - virtual Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) = 0; - virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) = 0; - virtual Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) = 0; - virtual Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0; virtual Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0; - virtual Status GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) = 0; - virtual Status GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) = 0; - virtual Status GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) = 0; - virtual Status GetShape(const GetShapeRequest* arg, GetShapeResponse* result) = 0; @@ -91,31 +72,9 @@ class ServiceInterface { virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0; - // Methods used by ComputationBuilder. - virtual Status Computation(const ComputationRequest* arg, - ComputationResponse* result) = 0; - - virtual Status Op(const OpRequest* arg, OpResponse* result) = 0; - - virtual Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) = 0; - - virtual Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) = 0; - - virtual Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) = 0; - - virtual Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) = 0; - virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) = 0; - // Methods used by Computation. - virtual Status SnapshotComputation(const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) = 0; - // Methods used by GlobalData. virtual Status Unregister(const UnregisterRequest* arg, UnregisterResponse* result) = 0; diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index ffaa40c2d673a2365342371ed8dab59565d1d08f..4aacc87b78e2c271829cdf397cd69bfb490125b8 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -42,36 +42,23 @@ namespace internal { template struct ShapeTreeNode { // Data corresponding to this node. - T data; + std::pair data; - // Children of this node. - std::vector> children; + // Children of this node, as indices into the container's nodes_ array. + std::vector children; - ShapeTreeNode() = default; - explicit ShapeTreeNode(const T& data) : data(data) {} + // Tells whether this is a leaf node. + bool is_leaf = true; - ShapeTreeNode(const ShapeTreeNode& other) - : data(other.data), children(other.children.size()) { - for (size_t i = 0; i < children.size(); ++i) { - children[i] = ::xla::MakeUnique(*other.children[i]); - } - } - - ShapeTreeNode& operator=(const ShapeTreeNode& other) { - if (this != &other) { - data = other.data; - children.resize(other.children.size()); - for (size_t i = 0; i < children.size(); ++i) { - children[i] = ::xla::MakeUnique(*other.children[i]); - } - } - return *this; - } + explicit ShapeTreeNode(ShapeIndex index) + : ShapeTreeNode(std::move(index), T()) {} + ShapeTreeNode(ShapeIndex index, T data) + : data(std::move(index), std::move(data)) {} }; } // namespace internal -template +template class ShapeTreeIterator; // A ShapeTree is a recursive data structure which mirrors the structure of a @@ -95,10 +82,9 @@ class ShapeTreeIterator; // before its ShapeTree goes away. template class ShapeTree { - friend class ShapeTreeIterator; - friend class ShapeTreeIterator; - public: + using Node = internal::ShapeTreeNode; + // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} @@ -110,35 +96,17 @@ class ShapeTree { // alive longer than this ShapeTree. explicit ShapeTree(Shape shape); explicit ShapeTree(const Shape* shape); + explicit ShapeTree(const std::shared_ptr& shape); // Create ShapeTree with the given shape, and init_value for all nodes. ShapeTree(Shape shape, const T& init_value); ShapeTree(const Shape* shape, const T& init_value); - - ShapeTree(const ShapeTree& other) { *this = other; } - ShapeTree(ShapeTree&&) = default; - - ShapeTree& operator=(const ShapeTree& other) { - root_ = other.root_; - - // Fix up internal pointer if necessary. - if (other.shape_storage_) { - CHECK_EQ(other.shape_, other.shape_storage_.get()); - shape_storage_.reset(new Shape(*other.shape_)); - shape_ = shape_storage_.get(); - } else { - shape_ = other.shape_; - } - - return *this; - } - - ShapeTree& operator=(ShapeTree&& other) = default; + ShapeTree(const std::shared_ptr& shape, const T& init_value); // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). - const T& element(const ShapeIndex& index) const; - T* mutable_element(const ShapeIndex& index); + const T& element(ShapeIndexView index) const; + T* mutable_element(ShapeIndexView index); // Return the shape represented with this ShapeTree. const Shape& shape() const { return *shape_; } @@ -157,67 +125,72 @@ class ShapeTree { // Returns true if the node at the given index is a leaf node (an array // shape). - bool IsLeaf(const ShapeIndex& index) const { - return Lookup(index)->children.empty(); - } + bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; } + + ShapeTree(const ShapeTree&) = default; + ShapeTree& operator=(const ShapeTree&) = default; + ShapeTree(ShapeTree&&) = default; + ShapeTree& operator=(ShapeTree&& other) = default; - // iterator implements a forward_iterator with value_type = - // std::pair - using iterator = ShapeTreeIterator; - using const_iterator = ShapeTreeIterator; + // iterator implements a bidirectional_iterator with + // value_type = std::pair. + // + // The iteration order is guaranteed to be a pre-order walk of the ShapeTree. + using iterator = + ShapeTreeIterator, typename std::vector::iterator, + std::pair>; + using const_iterator = + ShapeTreeIterator, + typename std::vector::const_iterator, + const std::pair>; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; // begin/end for iterating over all nodes. iterator begin() { - return iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/false); } iterator end() { - return iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/false); } const_iterator begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/false); } const_iterator end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/false); } // rbegin/rend for iterating over all nodes in reverse. - iterator rbegin() { - return iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/true); - } - iterator rend() { - return iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/true); + reverse_iterator rbegin() { return reverse_iterator(end()); } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); } - const_iterator rbegin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/true); - } - const_iterator rend() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/true); + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); } // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). iterator leaf_begin() { - return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false); + return iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/true); } iterator leaf_end() { - return iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/true); } const_iterator leaf_begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/true); } const_iterator leaf_end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/true); } // range-based iterator for leaf_begin()/leaf_end(). tensorflow::gtl::iterator_range leaves() { @@ -227,22 +200,32 @@ class ShapeTree { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - iterator leaf_rbegin() { - return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true); + reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); } + reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); } + const_reverse_iterator leaf_rbegin() const { + return const_reverse_iterator(leaf_end()); } - iterator leaf_rend() { - return iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/true); + const_reverse_iterator leaf_rend() const { + return const_reverse_iterator(leaf_begin()); } - const_iterator leaf_rbegin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true, - /*reverse=*/true); + + // Returns an iterator pointing to the given ShapeIndex. + // REQUIRES: index must exist in the ShapeTree. + iterator find(ShapeIndexView index) { + Node* element = Lookup(index); + return iterator(&nodes_, typename std::vector::iterator(element), + /*iterate_leaves_only=*/false); } - const_iterator leaf_rend() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/true); + const_iterator find(ShapeIndexView index) const { + Node* element = Lookup(index); + return iterator(&nodes_, + typename std::vector::const_iterator(element), + /*iterate_leaves_only=*/false); } + // Returns the number of leaf nodes in the tree. + int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); } + // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // @@ -282,8 +265,6 @@ class ShapeTree { bool operator!=(const ShapeTree& other) const { return !(*this == other); } private: - using Node = internal::ShapeTreeNode; - // Initialize node->children based on 'shape'. All children are assigned the // the given 'init_value'. void InitChildren(const Shape& shape, const T& init_value, Node* node); @@ -292,136 +273,55 @@ class ShapeTree { // default-constructed data values. void InitChildren(const Shape& shape, Node* node); + // Returns the number of subshapes, including interior nodes, in shape. + int64 CountSubshapes(const Shape& shape); + // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). template - static Status ForEachHelper(const Fn& func, const Node& node, - ShapeIndex* index); + static Status ForEachHelper(const Fn& func, const std::vector& nodes); template - static Status ForEachMutableHelper(const Fn& func, Node* node, - ShapeIndex* index); + static Status ForEachMutableHelper(const Fn& func, std::vector* nodes); // Return the tree node at the given index. - Node* Lookup(const ShapeIndex& index); - const Node* Lookup(const ShapeIndex& index) const; + Node* Lookup(ShapeIndexView index); + const Node* Lookup(ShapeIndexView index) const; - // The root node, which contains all other nodes. - Node root_; + // The nodes in this shape tree. + std::vector nodes_; // If we own our Shape, this field contains it, and shape_ is a pointer into // here. Otherwise if we don't own our shape, this is nullptr. - std::unique_ptr shape_storage_; + std::shared_ptr shape_storage_; // The XLA shape mirrored in this ShapeTree. This is either // shape_storage_.get() or the Shape pointer passed to our constructor. const Shape* shape_; }; -// Internal iterator that performs a pre-order walk. This is copyable, but -// contains a vector so isn't cheap to copy. This also means post-increment is -// expensive. The iterator value_type is equivalent to a std::pair, similar to std::map. The non-const iterator's T& type can be mutated -// in-place. -template -class ShapeTreeIterator : public std::iterator> { +// Internal iterator that performs a pre-order walk. This is cheap to copy. +// The iterator value_type is equivalent to a +// std::pair&, similar to std::map. +template +class ShapeTreeIterator + : public std::iterator { public: - using value_type = - typename std::conditional, - std::pair>::type; - using NodeType = - typename std::conditional::Node, - typename ShapeTree::Node>::type; - - // Construct an iterator pointing at node. Node must either be the tree root - // or nullptr (which is equivalent to end() and should not be dereferenced or - // incremented). If iterate_leaves_only is true, the iterator will not include - // interior tree nodes, only leaves. If reverse is true, the iterator will - // visit nodes in the reverse of pre-order traversal. - ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse) - : node_(node), - iterate_leaves_only_(iterate_leaves_only), - reverse_(reverse) { - if (node_) { - if (reverse_) { - while (!node_->children.empty()) { - const int child_index = node_->children.size() - 1; - stack_.push_back({node_, child_index}); - node_ = node_->children[child_index].get(); - } - } else { - if (!node_->children.empty() && iterate_leaves_only) { - ++*this; - } - } + ShapeTreeIterator(ContainerType* nodes, IteratorType node, + bool iterate_leaves_only) + : nodes_(nodes), + node_(std::move(node)), + iterate_leaves_only_(iterate_leaves_only) { + while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) { + ++node_; } } - ShapeTreeIterator(const ShapeTreeIterator& other) - : node_(other.node_), - stack_(other.stack_), - iterate_leaves_only_(other.iterate_leaves_only_), - reverse_(other.reverse_) {} ShapeTreeIterator& operator++() { - CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!"; - if (reverse_) { - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second - 1; - stack_.pop_back(); - if (next_child_index < 0) { - if (!iterate_leaves_only_) { - // All children are visited, yield . - return *this; - } - } else { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - while (!node_->children.empty()) { - const int child_index = node_->children.size() - 1; - stack_.push_back({node_, child_index}); - node_ = node_->children[child_index].get(); - } - return *this; - } - } - } else { - // We're doing a pre-order walk, so if our current node has children take - // the first child. - if (!node_->children.empty()) { - stack_.push_back({node_, /*child-index=*/0}); - node_ = node_->children[0].get(); - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); - } - } - // Otherwise we are currently at a leaf. Walk back up until a node - // contains a child we haven't visited yet. - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second + 1; - stack_.pop_back(); - if (node_->children.size() > next_child_index) { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); - } - } - } + ++node_; + while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) { + ++node_; } - // We've walked off the end of the tree. Set node_ to nullptr to signify - // end(). - node_ = nullptr; - current_.reset(); return *this; } ShapeTreeIterator operator++(int) { @@ -429,52 +329,62 @@ class ShapeTreeIterator : public std::iterator nodes_->begin() && !node_->is_leaf) { + --node_; + } + return *this; + } + ShapeTreeIterator operator--(int) { + auto i = *this; + --(*this); + return i; + } + bool operator==(const ShapeTreeIterator& other) const { return node_ == other.node_; } bool operator!=(const ShapeTreeIterator& other) const { return node_ != other.node_; } - value_type& operator*() { return UpdateCurrent(); } - value_type* operator->() { return &UpdateCurrent(); } + ValueType& operator*() { return node_->data; } + ValueType* operator->() { return &node_->data; } private: - // Updates the current_ member to reflect the current state. - value_type& UpdateCurrent() { - ShapeIndex index; - for (auto& node_and_index : stack_) { - index.push_back(node_and_index.second); - } - current_ = ::xla::MakeUnique(index, node_->data); - return *current_; - } - - // The node to which this iterator is pointing. This is the source of truth in - // the iterator - the stack only exists to facilitate walking back from - // children to parents. - NodeType* node_; - // Stack of {node, child-index} pairs of the path taken from the root to get - // to node_. This allows us to backtrack and know where to go next. - std::vector> stack_; + ContainerType* nodes_; + IteratorType node_; // True if we should not include interior nodes in our walk. - bool iterate_leaves_only_; - // True if we should yield the reverse of the pre-order traversal. - bool reverse_; - // Placeholder for the current value. Ideally this wouldn't exist and would - // just be an rvalue, but operator -> needs to return a pointer to something. - // We cannot just use a plain old value_type as it contains a reference so - // cannot be default-constructed. - std::unique_ptr current_; + const bool iterate_leaves_only_; }; +template +int64 ShapeTree::CountSubshapes(const Shape& shape) { + int64 current_count = 1; + if (ShapeUtil::IsTuple(shape)) { + int64 count = ShapeUtil::TupleElementCount(shape); + for (int i = 0; i < count; ++i) { + current_count += CountSubshapes(shape.tuple_shapes(i)); + } + } + return current_count; +} + template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, Node* node) { if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - node->children.emplace_back(new Node(init_value)); - InitChildren(shape.tuple_shapes(i), init_value, - node->children.back().get()); + const int64 size = ShapeUtil::TupleElementCount(shape); + node->children.reserve(size); + node->is_leaf = false; + ShapeIndex shape_index = node->data.first; + shape_index.push_back(0); + for (int i = 0; i < size; ++i) { + shape_index[shape_index.size() - 1] = i; + node->children.push_back(nodes_.size()); + nodes_.emplace_back(shape_index, init_value); + InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back()); } } } @@ -482,83 +392,110 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, template void ShapeTree::InitChildren(const Shape& shape, Node* node) { if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - node->children.emplace_back(new Node()); - InitChildren(shape.tuple_shapes(i), node->children.back().get()); + const int64 size = ShapeUtil::TupleElementCount(shape); + node->children.reserve(size); + node->is_leaf = false; + ShapeIndex shape_index = node->data.first; + shape_index.push_back(0); + for (int i = 0; i < size; ++i) { + shape_index[shape_index.size() - 1] = i; + node->children.push_back(nodes_.size()); + nodes_.emplace_back(shape_index); + InitChildren(shape.tuple_shapes(i), &nodes_.back()); } } } template ShapeTree::ShapeTree(Shape shape) - : root_(), - shape_storage_(::xla::MakeUnique(std::move(shape))), + : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - InitChildren(*shape_, &root_); + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); +} + +template +ShapeTree::ShapeTree(const Shape* shape) : shape_(shape) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); } template -ShapeTree::ShapeTree(const Shape* shape) : root_(), shape_(shape) { - InitChildren(*shape_, &root_); +ShapeTree::ShapeTree(const std::shared_ptr& shape) + : shape_storage_(shape), shape_(shape_storage_.get()) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); } template ShapeTree::ShapeTree(Shape shape, const T& init_value) - : root_(init_value), - shape_storage_(::xla::MakeUnique(std::move(shape))), + : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - InitChildren(*shape_, init_value, &root_); + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); } template ShapeTree::ShapeTree(const Shape* shape, const T& init_value) - : root_(init_value), shape_(shape) { - InitChildren(*shape_, init_value, &root_); + : shape_(shape) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); +} + +template +ShapeTree::ShapeTree(const std::shared_ptr& shape, + const T& init_value) + : shape_storage_(shape), shape_(shape_storage_.get()) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); } template -const T& ShapeTree::element(const ShapeIndex& index) const { - return Lookup(index)->data; +const T& ShapeTree::element(ShapeIndexView index) const { + return Lookup(index)->data.second; } template -T* ShapeTree::mutable_element(const ShapeIndex& index) { - return &Lookup(index)->data; +T* ShapeTree::mutable_element(ShapeIndexView index) { + return &Lookup(index)->data.second; } template -internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { - Node* node = &root_; +internal::ShapeTreeNode* ShapeTree::Lookup(ShapeIndexView index) { + Node* node = &nodes_[0]; for (const int64 i : index) { CHECK_GE(i, 0); CHECK_LT(i, node->children.size()); - node = node->children[i].get(); + node = &nodes_[node->children[i]]; } return node; } template const internal::ShapeTreeNode* ShapeTree::Lookup( - const ShapeIndex& index) const { + ShapeIndexView index) const { return const_cast(this)->Lookup(index); } /* static */ template template -Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, node.data)); - for (int64 i = 0; i < node.children.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR(ForEachHelper(func, *node.children[i], index)); - index->pop_back(); +Status ShapeTree::ForEachHelper(const Fn& func, + const std::vector& nodes) { + for (const auto& node : nodes) { + TF_RETURN_IF_ERROR(func(node.data.first, node.data.second)); } return Status::OK(); } @@ -566,14 +503,10 @@ Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, /* static */ template template -Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, &node->data)); - for (int64 i = 0; i < node->children.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR( - ForEachMutableHelper(func, node->children[i].get(), index)); - index->pop_back(); +Status ShapeTree::ForEachMutableHelper(const Fn& func, + std::vector* nodes) { + for (auto& node : *nodes) { + TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second)); } return Status::OK(); } @@ -581,40 +514,36 @@ Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, template template Status ShapeTree::ForEachElementWithStatus(const Fn& func) const { - ShapeIndex index; - return ForEachHelper(func, root_, &index); + return ForEachHelper(func, nodes_); } template template Status ShapeTree::ForEachMutableElementWithStatus(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper(func, &root_, &index); + return ForEachMutableHelper(func, &nodes_); } template template void ShapeTree::ForEachElement(const Fn& func) const { - ShapeIndex index; return ForEachHelper( [&func](const ShapeIndex& index, const T& data) { func(index, data); return Status::OK(); }, - root_, &index) + nodes_) .IgnoreError(); } template template void ShapeTree::ForEachMutableElement(const Fn& func) { - ShapeIndex index; return ForEachMutableHelper( [&func](const ShapeIndex& index, T* data) { func(index, data); return Status::OK(); }, - &root_, &index) + &nodes_) .IgnoreError(); } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 4b6ab772811f4a6c6ffc1d10befc7122f883b8f9..51de82e95746281ed6e587b545dc933b48ce1ad4 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace xla { namespace { @@ -115,6 +116,11 @@ TEST_F(ShapeTreeTest, InitValueConstructor) { TestInitValueConstructor(nested_tuple_shape_, 10); } +TEST_F(ShapeTreeTest, EmptyTupleMustHaveNoLeaves) { + ShapeTree shape_tree{ShapeUtil::MakeTupleShape({})}; + EXPECT_EQ(0, shape_tree.leaf_count()); +} + TEST_F(ShapeTreeTest, ArrayShape) { ShapeTree shape_tree{array_shape_}; *shape_tree.mutable_element({}) = 42; @@ -421,8 +427,8 @@ TEST_F(ShapeTreeTest, IterateAndMutate) { } ++i; } - t.begin()->second = 78; - EXPECT_EQ(78, t.begin()->second); + (*t.begin()).second = 78; + EXPECT_EQ(78, (*t.begin()).second); i = 0; for (auto& index_to_data : t) { if (i == 0) { @@ -434,14 +440,14 @@ TEST_F(ShapeTreeTest, IterateAndMutate) { } ++i; } - EXPECT_EQ(78, t.begin()->second); - EXPECT_EQ(98, std::next(t.begin())->second); + EXPECT_EQ(78, (*t.begin()).second); + EXPECT_EQ(98, (*std::next(t.begin())).second); } TEST_F(ShapeTreeTest, IterateOrder) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; - for (auto& index_to_data : t) { + for (auto index_to_data : t) { v.push_back(index_to_data.first); } EXPECT_EQ(v, (std::vector{{}, @@ -479,7 +485,7 @@ TEST_F(ShapeTreeTest, ReverseIterateOrder) { TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; - for (auto& index_to_data : t.leaves()) { + for (auto index_to_data : t.leaves()) { v.push_back(index_to_data.first); } EXPECT_EQ(v, (std::vector{ @@ -502,5 +508,106 @@ TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) { })); } +void BM_Construct(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + for (int i = 0; i < iters; ++i) { + ShapeTree shape_tree(shape); + } +} + +void BM_ConstructUnowned(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + for (int i = 0; i < iters; ++i) { + ShapeTree shape_tree(&shape); + } +} + +void BM_Copy(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + ShapeTree copy = shape_tree; + tensorflow::testing::DoNotOptimize(copy); + } +} + +void BM_Move(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + ShapeTree copy = std::move(shape_tree); + shape_tree = std::move(copy); + } +} + +void BM_ForEach(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + shape_tree.ForEachMutableElement([](const ShapeIndex& index, int* data) { + tensorflow::testing::DoNotOptimize(index); + }); + } +} + +void BM_Iterate(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + for (auto& iter : shape_tree) { + tensorflow::testing::DoNotOptimize(iter.second); + } + } +} + +BENCHMARK(BM_Construct)->ArgPair(2, 8); +BENCHMARK(BM_ConstructUnowned)->ArgPair(2, 8); +BENCHMARK(BM_Copy)->ArgPair(2, 8); +BENCHMARK(BM_Move)->ArgPair(2, 8); +BENCHMARK(BM_ForEach)->ArgPair(2, 8); +BENCHMARK(BM_Iterate)->ArgPair(2, 8); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 7a897f6f8f99e65285e1be0757a55f703fc81c72..56d24423c428d32c1c65ed7a47aab9691a846559 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -42,17 +43,35 @@ limitations under the License. namespace xla { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + string ShapeIndex::ToString() const { - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(indices_, ","), "}"); + return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); } string ShapeIndexView::ToString() const { - return tensorflow::strings::StrCat( - "{", - tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_), - ","), - "}"); + return StrCat("{", + tensorflow::str_util::Join( + tensorflow::gtl::make_range(begin_, end_), ","), + "}"); +} + +bool ShapeIndexView::operator==(const ShapeIndexView& other) const { + if (size() != other.size()) { + return false; + } + for (auto it = begin(), other_it = other.begin(); it != end(); + ++it, ++other_it) { + if (*it != *other_it) { + return false; + } + } + return true; +} + +bool ShapeIndexView::operator!=(const ShapeIndexView& other) const { + return !(*this == other); } std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { @@ -67,18 +86,34 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { namespace { +// Returns whether the given primitive type corresponds to an array shape. +bool IsArrayPrimitiveType(PrimitiveType primitive_type) { + return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && + primitive_type != OPAQUE && primitive_type != TOKEN; +} + // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. -bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { - return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), +bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, + bool ignore_fp_precision) { + if ((ignore_fp_precision && + !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || + (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) { + VLOG(3) << "CompareShapes: lhs element type != rhs element type"; + return false; + } + + if (ShapeUtil::IsTuple(lhs)) { + return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts); + return CompareShapes(l, r, compare_layouts, + ignore_fp_precision); }); - } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { - return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); + } else if (!ShapeUtil::IsArray(lhs)) { + // Non-tuple, non-array tupes such as opaque and token types are trivially + // the same. + return true; } if (compare_layouts) { @@ -108,10 +143,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; return false; } - if (!ShapeUtil::SameElementType(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs element type != rhs element type"; - return false; - } return true; } @@ -144,7 +175,8 @@ StatusOr MakeShapeWithLayoutInternal( } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true); + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, + /*ignore_fp_precision=*/false); if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -153,9 +185,21 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } +/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, + const Shape& rhs) { + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, + /*ignore_fp_precision=*/true); + if (!equal && VLOG_IS_ON(3)) { + VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = " + << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); + } + + return equal; +} + /* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(!ShapeUtil::IsTuple(shape)) - << "Tuples do not have a rank, shape: " << shape; + CHECK(ShapeUtil::IsArray(shape)) + << "Non-arrays do not have a rank, shape: " << shape; return shape.dimensions_size(); } @@ -182,8 +226,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShape( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape result; PopulateShape(element_type, dimensions, &result); return result; @@ -206,8 +249,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, int64 max_sparse_elements) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); @@ -240,6 +282,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( tensorflow::gtl::ArraySlice shapes) { Shape result; result.set_element_type(TUPLE); + result.mutable_tuple_shapes()->Reserve(shapes.size()); for (const auto& shape : shapes) { AppendShapeToTuple(shape, &result); } @@ -254,6 +297,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return result; } +/* static */ Shape ShapeUtil::MakeTokenShape() { + Shape result; + result.set_element_type(TOKEN); + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); + return result; +} + /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape, Shape* tuple_shape) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); @@ -277,7 +327,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { - if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) { + if (!IsArray(shape)) { return false; } return primitive_util::BitWidth(shape.element_type()) == bits; @@ -303,6 +353,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case C64: case TUPLE: case OPAQUE: + case TOKEN: return false; default: @@ -318,6 +369,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return primitive_util::IsFloatingPointType(shape.element_type()); } +/* static */ bool ShapeUtil::IsArray(const Shape& shape) { + return IsArrayPrimitiveType(shape.element_type()); +} + /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), IsTuple); @@ -328,7 +383,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::IsNil(const Shape& shape) { - return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape); + return IsEmptyTuple(shape); } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { @@ -344,6 +399,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return shape.tuple_shapes(index); } +/* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) { + int64 n = 0; + ForEachSubshape(shape, [&](const Shape& literal_subshape, + const ShapeIndex& index) { ++n; }); + return n; +} + /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start, int64 limit) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple)); @@ -371,39 +433,33 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape); + CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, std::multiplies()); } -/* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) { - return ElementsIn(shape) == 0; +/* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) { + CHECK(IsArray(shape) || IsTuple(shape)); + if (IsArray(shape)) { + return ElementsIn(shape); + } + int64 count = 0; + for (const Shape& element_shape : shape.tuple_shapes()) { + count += ElementsInRecursive(element_shape); + } + return count; +} + +/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { + return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } /* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { return shape.element_type() == F32 && Rank(shape) == 0; } -/* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (IsTuple(shape)) { - string text = "("; - const char* prefix = ""; - for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape)); - prefix = ", "; - } - text += ")"; - return text; - } else { - return tensorflow::strings::StrCat( - tensorflow::str_util::Lowercase( - PrimitiveType_Name(shape.element_type())), - "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]"); - } -} - namespace { // Class to memoize the computation of @@ -453,48 +509,56 @@ StatusOr StringToPrimitiveType(const string& name) { } // namespace -/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { +/* static */ string ShapeUtil::HumanString(const Shape& shape) { if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, - HumanStringWithLayout(elem_shape)); + StrAppend(&text, prefix, HumanString(elem_shape)); prefix = ", "; } text += ")"; return text; - } else { - string result = tensorflow::strings::StrCat( - LowercasePrimitiveTypeName(shape.element_type()), "["); - for (int i = 0; i < shape.dimensions().size(); i++) { - tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "", - shape.dimensions(i)); + } + return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", + tensorflow::str_util::Join(shape.dimensions(), ","), "]"); +} + +/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { + if (IsTuple(shape)) { + string text = "("; + const char* prefix = ""; + for (const Shape& elem_shape : shape.tuple_shapes()) { + StrAppend(&text, prefix, HumanStringWithLayout(elem_shape)); + prefix = ", "; } - result += "]"; - if (!IsScalar(shape) && !IsOpaque(shape)) { - if (LayoutUtil::HasLayout(shape)) { - tensorflow::strings::StrAppend(&result, - LayoutUtil::HumanString(shape.layout())); - } + text += ")"; + return text; + } + string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + for (int i = 0; i < shape.dimensions().size(); i++) { + StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); + } + result += "]"; + if (!IsScalar(shape) && IsArray(shape)) { + if (LayoutUtil::HasLayout(shape)) { + StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } - return result; } + return result; } /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) { std::vector parameters; for (auto& shape : program_shape.parameters()) { const int i = parameters.size(); - parameters.push_back( - tensorflow::strings::StrCat(i < program_shape.parameter_names_size() - ? program_shape.parameter_names(i) - : "(unknown)", - ": ", HumanString(shape))); + parameters.push_back(StrCat(i < program_shape.parameter_names_size() + ? program_shape.parameter_names(i) + : "(unknown)", + ": ", HumanString(shape))); } - return tensorflow::strings::StrCat( - "(", tensorflow::str_util::Join(parameters, ", "), ") -> ", - HumanString(program_shape.result())); + return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + HumanString(program_shape.result())); } namespace { @@ -528,12 +592,11 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so // we convert in to the RE2-consumable type and then consume the corresponding // amount from our StringPiece type. + static LazyRE2 shape_pattern = { + "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"}; tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume( - &s_consumable, - "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?", - &element_type_string, &dimensions_string, &format_string, - &layout_string)) { + if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, + &dimensions_string, &format_string, &layout_string)) { size_t consumed = s->size() - s_consumable.size(); s->remove_prefix(consumed); auto string_to_int64 = [&s](const string& input) -> StatusOr { @@ -564,14 +627,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the primitive element type. TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || - primitive_type == OPAQUE) { + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } Shape result; - if (format_string.empty() && layout_string.empty()) { + if (primitive_type == OPAQUE) { + result = ShapeUtil::MakeOpaqueShape(); + } else if (primitive_type == TOKEN) { + result = ShapeUtil::MakeTokenShape(); + } else if (format_string.empty() && layout_string.empty()) { // Create a shape without a layout set. result = ShapeUtil::MakeShape(primitive_type, dimensions); } else if (format_string == "sparse") { @@ -616,43 +682,37 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); - } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); + return CompareShapes(lhs, rhs, /*compare_layouts=*/false, + /*ignore_fp_precision=*/false); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringElementType); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) && + CompatibleIgnoringElementType(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringFpPrecision); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return CompatibleIgnoringElementType(lhs, rhs); - } - return false; } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, @@ -674,10 +734,6 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { switch (primitive_type) { case PRED: return sizeof(int8); - case TUPLE: - LOG(FATAL) << "tuples have no definitive size"; - case OPAQUE: - LOG(FATAL) << "opaque have no definitive size"; case S8: return sizeof(int8); case S16: @@ -704,6 +760,13 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(double); case C64: return sizeof(complex64); + case TOKEN: + // Tokens require no space. + return 0; + case TUPLE: + case OPAQUE: + LOG(FATAL) << PrimitiveType_Name(primitive_type) + << " primitive type has no definitive size"; default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; } @@ -712,28 +775,32 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); + } else if (IsArray(shape)) { + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); + } + return byte_size; + } else if (shape.element_type() == TOKEN) { + return 0; } - int64 byte_size = ByteSizeOfElements(shape); - if (LayoutUtil::IsSparseArray(shape)) { - byte_size += ByteSizeOfSparseIndices(shape); - } - return byte_size; + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " primitive type has no definitive size"; } /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_EQ(TUPLE, shape.element_type()); CHECK_GT(pointer_size, 0); return pointer_size * shape.tuple_shapes_size(); } /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(ShapeUtil::IsArray(shape)); + CHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; if (LayoutUtil::IsSparseArray(shape)) { @@ -758,13 +825,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(LayoutUtil::IsSparseArray(shape)); + CHECK(LayoutUtil::IsSparseArray(shape)); return LayoutUtil::MaxSparseElements(shape.layout()) * ShapeUtil::Rank(shape) * sizeof(int64); } /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("shape has invalid element type: %s", + shape.ShortDebugString().c_str()); + } if (shape.element_type() == TUPLE) { if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); @@ -780,10 +851,24 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape.tuple_shapes_size() > 0) { return InvalidArgument("non-tuple shape has tuple_shapes field"); } - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { - return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString().c_str()); + + // Tokens and opaques can should not have layout or dimensions. + if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) { + if (shape.dimensions_size() != 0) { + return InvalidArgument( + "shape has %s element type, but has dimensions field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + if (shape.has_layout()) { + return InvalidArgument( + "shape has %s element type, but has layout field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + return Status::OK(); } + if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( "shape's rank is mismatched with dimension count; rank=%lld " @@ -800,6 +885,60 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } } + TF_RETURN_IF_ERROR(ValidateShapeSize(shape)); + return Status::OK(); +} + +/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) { + VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); + + if (!IsArray(shape)) { + return Status::OK(); + } + + int64 shape_size = [&shape]() { + int64 shape_size; + if (LayoutUtil::IsSparseArray(shape)) { + shape_size = LayoutUtil::MaxSparseElements(shape.layout()); + if (shape_size < 0) { + return shape_size; + } + shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape)); + if (shape_size < 0) { + return shape_size; + } + shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64)); + if (shape_size < 0) { + return shape_size; + } + } + + shape_size = 1; + + // This is intentionally unconditional: even if the shape is sparse, we want + // to verify the densified version has a reasonable size. + if (shape.dimensions().empty()) { + return shape_size; + } + + for (int64 dim : shape.dimensions()) { + shape_size = MultiplyWithoutOverflow(shape_size, dim); + if (shape_size < 0) { + return shape_size; + } + } + shape_size = MultiplyWithoutOverflow( + shape_size, ByteSizeOfPrimitiveType(shape.element_type())); + + return shape_size; + }(); + + if (shape_size < 0) { + return InvalidArgument("Shape %s size may overflow int64.", + ShapeUtil::HumanString(shape).c_str()); + } + + VLOG(3) << "Shape size is valid: " << shape_size; return Status::OK(); } @@ -848,6 +987,21 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return *return_shape; } +/* static */ StatusOr ShapeUtil::TryGetSubshape( + const Shape& shape, ShapeIndexView index) { + const Shape* return_shape = &shape; + for (auto i : index) { + if (!IsTuple(*return_shape) || i < 0 || + i >= return_shape->tuple_shapes_size()) { + return InvalidArgument( + "Shape index %s not a valid subshape index for tuple with shape %s", + index.ToString().c_str(), shape.DebugString().c_str()); + } + return_shape = &return_shape->tuple_shapes(i); + } + return return_shape; +} + /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape, ShapeIndexView index) { Shape* return_shape = shape; @@ -863,64 +1017,30 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return !IsTuple(GetSubshape(shape, index)); } -/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { - std::vector dimension_sizes; - std::vector degenerate_dimensions; - for (int64 i = 0; i < shape.dimensions_size(); ++i) { - if (shape.dimensions(i) == 1) { - degenerate_dimensions.push_back(i); - } else { - dimension_sizes.push_back(shape.dimensions(i)); - } - } - - // Construct minor_to_major of stripped shape. The order of the non-degenerate - // dimensions should be preserved from the original shape. First, create - // vector of the non-degenerate dimensions from the original minor_to_major - // array. - std::vector minor_to_major; - for (int64 i : shape.layout().minor_to_major()) { - if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(), - i) == degenerate_dimensions.end()) { - minor_to_major.push_back(i); +/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + int64 count = 0; + ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + ++count; } - } + }); + return count; +} - // The dimensions in minor_to_major need to be renumbered to account for the - // degenerate dimensions which have removed. Decrement each dimension number - // once for each degenerate dimension which has a smaller number. - for (int i = 0; i < minor_to_major.size(); ++i) { - int adjustment = 0; - for (int64 dim : degenerate_dimensions) { - if (minor_to_major[i] > dim) { - adjustment++; - } +/* static */ std::vector ShapeUtil::GetLeafShapes( + const Shape& shape) { + std::vector leaves; + ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + leaves.emplace_back(index, sub_shape); } - minor_to_major[i] -= adjustment; - } - - { - std::vector dims(minor_to_major.size()); - std::iota(dims.begin(), dims.end(), 0); - DCHECK(minor_to_major.size() == dims.size() && - std::is_permutation(minor_to_major.begin(), minor_to_major.end(), - dims.begin())); - } - Shape stripped_shape; - if (LayoutUtil::IsDenseArray(shape)) { - stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes, - minor_to_major); - } else if (LayoutUtil::IsSparseArray(shape)) { - stripped_shape = - MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes, - shape.layout().max_sparse_elements()); - } else { - stripped_shape = MakeShape(shape.element_type(), dimension_sizes); - } + }); + return leaves; +} - VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape); - VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape); - return stripped_shape; +/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) { + CHECK(ShapeUtil::IsArray(shape)); + return ArrayContains(AsInt64Slice(shape.dimensions()), 1); } namespace { @@ -1028,6 +1148,9 @@ Status ForEachMutableSubshapeHelper( /* static */ std::tuple, std::vector> ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, const Shape& shape_post) { + CHECK(IsArray(shape_pre)); + CHECK(IsArray(shape_post)); + auto nil = std::make_tuple(false, std::vector(), std::vector()); std::vector deleted_indices; @@ -1085,6 +1208,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, /* static */ std::vector> ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + // Unmodified dimensions are merely common factors of rank 1. auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); @@ -1138,8 +1264,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - CHECK(LayoutUtil::HasLayout(input_shape) && - LayoutUtil::HasLayout(output_shape)); + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + CHECK(LayoutUtil::HasLayout(input_shape)); + CHECK(LayoutUtil::HasLayout(output_shape)); if (!SameElementType(input_shape, output_shape)) { return false; @@ -1301,6 +1429,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ tensorflow::gtl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + int64 input_rank = Rank(input_shape); int64 output_rank = Rank(output_shape); @@ -1435,6 +1566,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { + CHECK(IsArray(shape)); shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); @@ -1456,6 +1588,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { + CHECK(IsArray(shape)); std::vector dims_to_delete; for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { if (!p(i)) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 82c75f85d838f94cb040e56d59d0e012af5e0db0..5ae04451d32bd733dce55c4a56f5ebc1882d9fbd 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -61,6 +62,8 @@ class ShapeIndex { public: ShapeIndex() = default; ShapeIndex(std::initializer_list init) : indices_(init) {} + template + ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {} bool empty() const { return indices_.empty(); } size_t size() const { return indices_.size(); } @@ -131,6 +134,10 @@ class ShapeIndexView { ++new_begin; return ShapeIndexView(new_begin, end_); } + ShapeIndex ToShapeIndex() const { return ShapeIndex(begin_, end_); } + + bool operator==(const ShapeIndexView& other) const; + bool operator!=(const ShapeIndexView& other) const; string ToString() const; @@ -150,29 +157,40 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); // properties, which do invariant checks before / after the operation. class ShapeUtil { public: + // Data structure which describes the coordinates and the shape, of a tuple + // shaped sub-shape. + struct IndexedShape { + IndexedShape() = default; + IndexedShape(ShapeIndex index, Shape shape) + : index(std::move(index)), shape(std::move(shape)) {} + ShapeIndex index; + Shape shape; + }; + // Returns the number of elements are contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes // may not actually be able to store this number of elements. See // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // elements that can be stored in a sparse shape. - // Precondition: !IsTuple(shape) + // Precondition: IsArray(shape) static int64 ElementsIn(const Shape& shape); - // Returns true if 'shape' has zero elements. - static bool HasZeroElements(const Shape& shape); + // As ElementsIn(), but recurses through tuples. + static int64 ElementsInRecursive(const Shape& shape); + + // Returns true if 'shape' is an array with zero elements. + static bool IsZeroElementArray(const Shape& shape); // Returns the number of bytes required for an allocation of shape. The // |pointer_size| parameter is used for calculating the size of tuple // shapes. This includes only the size of the top-level buffer. For example, a // tuple is stored as an array of pointers to other buffers. In this case, // this method only returns the size of the pointer array. - // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) && - // !ShapeUtil::IsOpaque(shape) static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1); // Returns the number of bytes used to store the primitive_type. // - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) + // Precondition: ShapeUtil::IsArray(shape) static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); // Returns the number of bytes required to store the tuple member pointers for @@ -231,7 +249,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that they have the same element type + // point types; otherwise, checks that that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { @@ -262,6 +280,9 @@ class ShapeUtil { // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); + // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. + static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Returns the rank (number of dimensions) of the given shape. // Precondition: !IsTuple(shape) static int64 Rank(const Shape& shape); @@ -279,10 +300,10 @@ class ShapeUtil { // Scalar-specific static bool IsScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0; + return IsArray(shape) && Rank(shape) == 0; } static bool IsEffectiveScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0; + return IsArray(shape) && TrueRank(shape) == 0; } static bool IsScalarF32(const Shape& shape); @@ -311,13 +332,17 @@ class ShapeUtil { // into a custom operation. static Shape MakeOpaqueShape(); + // Creates a token shape. Values of this shape are used for ordering + // side-effecting operations. + static Shape MakeTokenShape(); + // Appends a shape to the given tuple. static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape); // Appends a major dimension to the shape with the given bound. static void AppendMajorDimension(int bound, Shape* shape); - // Returns an empty tuple shape. Can be used to indicate side-effects. + // Returns an empty tuple shape. Can be used as a sentinel Shape value. static Shape MakeNil() { return MakeTupleShape({}); } // Checks whether the shape is initialized. @@ -410,11 +435,15 @@ class ShapeUtil { return shape.element_type() == OPAQUE; } + // Returns whether the shape is an token value used for ordering + // side-effecting operations. + static bool IsToken(const Shape& shape) { + return shape.element_type() == TOKEN; + } + // Returns whether the shape is an array. Note that scalars are considered // arrays. - static bool IsArray(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape); - } + static bool IsArray(const Shape& shape); // Returns whether the shape is a tuple with at least one element which is // also a tuple. @@ -423,7 +452,7 @@ class ShapeUtil { // Returns true if shape is an empty tuple. static bool IsEmptyTuple(const Shape& shape); - // Returns true if shape is an empty tuple, or is an array with no elements. + // Returns true if shape is the nil shape (an empty tuple). static bool IsNil(const Shape& shape); // Returns the number of elements in the given tuple shape. @@ -434,6 +463,9 @@ class ShapeUtil { // Precondition: IsTuple(shape) && TupleElementCount(shape) > index static const Shape& GetTupleElementShape(const Shape& shape, int64 index); + // Returns the number of elements, recursively, in the given shape. + static int64 SubshapeCount(const Shape& shape); + // Slices tuple elements in the range [start, limit) and returns a new tuple // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32). static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit); @@ -453,14 +485,24 @@ class ShapeUtil { static bool IndexIsValid(const Shape& shape, ShapeIndexView index); // GetSubshape and GetMutableSubshape return a particular nested Shape within - // the given Shape argument. + // the given Shape argument. The non-Try variants check fail if index is + // invalid. static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index); + static StatusOr TryGetSubshape(const Shape& shape, + ShapeIndexView index); static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index); // Returns whether the given index in the given shape is a leaf element of the // shape. static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index); + // Returns the number of leaves in the shape. + static int64 GetLeafCount(const Shape& shape); + + // Retrieves all the leaf shapes and their indexes, in the order walked by + // the ForEachSubshape() API. + static std::vector GetLeafShapes(const Shape& shape); + // Calls the given visitor function for each subshape of the given shape. // Subshapes are visited in DFS pre-order starting with the entire shape // (index {}). @@ -483,25 +525,9 @@ class ShapeUtil { static Status ForEachMutableSubshapeWithStatus( Shape* shape, const MutatingStatusVisitorFunction& func); - // Removes all degenerate dimensions (size one) from the given shape. The - // stripped minor_to_major preserves the relative ordering of non-degenerate - // dimensions. The stripped shape has the property that the underlying - // representation (bits in memory) for the stripped shape is the same as the - // original shape modulo padding. Examples: - // - // input shape: F32 [1, 2, 1], minor_to_major = {0, 1, 2} - // stripped shape: F32 [2], minor_to_major = {0} - // - // input shape: F32 [6, 1, 5], minor_to_major = {2, 0, 1} - // stripped shape: F32 [6, 5], minor_to_major = {1, 0} - // - // input shape: F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1} - // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1} - // - // input shape: F32 [1, 1], minor_to_major = {0, 1} - // stripped shape: F32 [], minor_to_major = {} - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) - static Shape StripDegenerateDimensions(const Shape& shape); + // Returns true if `shape` (which must be an array) with degenerate dimensions + // (dimensions with bound 1). + static bool HasDegenerateDimensions(const Shape& shape); // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i] @@ -626,6 +652,28 @@ class ShapeUtil { .IgnoreError(); } + // These convenience wrappers don't take `base`, `count` and `incr` + // explicitly, but iterate over every element in `shape` instead. + + template + static Status ForEachIndexWithStatus(const Shape& shape, + const FnType& visitor_function) { + std::vector base(shape.dimensions_size()); + std::vector incr(shape.dimensions_size(), 1); + return ForEachIndexWithStatus(shape, base, + /*count=*/AsInt64Slice(shape.dimensions()), + incr, visitor_function); + } + + template + static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { + ForEachIndexWithStatus(shape, + [&](tensorflow::gtl::ArraySlice indices) { + return StatusOr(visitor_function(indices)); + }) + .IgnoreError(); + } + // A parallel version of ForEachIndex(WithStatus). This can only be used if // the visitor_function is thread-safe and the order of iteration does not // matter. @@ -654,6 +702,10 @@ class ShapeUtil { static size_t Hash(const Shape& shape); private: + // Validates the shape size is sane. This makes sure it's safe to do + // calculations in int64 without overflowing. + static Status ValidateShapeSize(const Shape& shape); + // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public method. static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape); @@ -665,7 +717,7 @@ class ShapeUtil { tensorflow::gtl::ArraySlice incr, const FnType& visitor_function, bool parallel = false) { - if (ShapeUtil::HasZeroElements(shape)) { + if (ShapeUtil::IsZeroElementArray(shape)) { return Status::OK(); } CHECK_EQ(Rank(shape), base.size()); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index f7675e97da7b061bde063e5093256c2288f99c98..b6f30af381dd8d24ff28fdf7f729d6cb3df46ec9 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -93,12 +93,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { } TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2]), f32[3])"; + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeShape(F32, {3}), }); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) @@ -136,6 +138,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST(ShapeUtilTest, ParseInvalidShapeString) { string shape_strings[] = { "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", @@ -153,6 +172,41 @@ TEST(ShapeUtilTest, CompatibleIdenticalShapes) { ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2)); } +TEST(ShapeUtilTest, TokenCompatibility) { + EXPECT_TRUE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTokenShape())); + EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShape(F32, {}))); + EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape())); + EXPECT_TRUE(ShapeUtil::Compatible( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}))); +} + +TEST(ShapeUtilTest, TokensEqualShapes) { + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTokenShape())); + EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShape(F32, {}))); + EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape())); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}))); + EXPECT_FALSE(ShapeUtil::Equal( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {1, 0})}))); +} + TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) { Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); auto layout_1 = shape_1.mutable_layout(); @@ -188,6 +242,24 @@ TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2)); } +TEST(ShapeUtilTest, EqualIgnoringFpPrecision) { + EXPECT_TRUE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1}))); +} + +TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) { + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); +} + TEST(ShapeUtilTest, CompatibleTuples) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); @@ -295,6 +367,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); + + EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN)); + EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape())); } TEST(ShapeUtilTest, ByteSizeOfWithPadding) { @@ -307,6 +382,16 @@ TEST(ShapeUtilTest, ByteSizeOfWithPadding) { EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape)); } +TEST(ShapeUtilTest, NilShape) { + EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3}))); + EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1}))); + EXPECT_FALSE(ShapeUtil::IsNil( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})}))); + EXPECT_FALSE(ShapeUtil::IsNil( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})}))); +} + TEST(ShapeUtilTest, NestedTuple) { EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({}))); EXPECT_FALSE(ShapeUtil::IsNestedTuple( @@ -337,25 +422,30 @@ TEST(ShapeUtilTest, ElementsIn) { EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); } -TEST(ShapeUtilTest, HasZeroElements) { - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {}))); - EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0}))); - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1}))); - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5}))); - EXPECT_EQ(true, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5}))); - EXPECT_EQ(true, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17}))); +TEST(ShapeUtilTest, IsZeroElementArray) { + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {}))); + EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 1}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2, 1}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 5}))); + EXPECT_TRUE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 0, 5}))); + EXPECT_TRUE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0, 3, 0}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 3, 5}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {13, 17}))); + + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeTupleShape({}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {0, 3, 0})}))); } TEST(ShapeUtilTest, SameDimensions) { @@ -449,19 +539,21 @@ TEST(ShapeUtilTest, IsLeafIndex) { TEST(ShapeUtilTest, HumanString) { Shape opaque = ShapeUtil::MakeOpaqueShape(); + Shape token = ShapeUtil::MakeTokenShape(); Shape scalar = ShapeUtil::MakeShape(F32, {}); Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); - Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix}); + Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token}); EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); + EXPECT_EQ("token[]", ShapeUtil::HumanString(token)); EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", ShapeUtil::HumanString(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(nested_tuple)); EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); @@ -470,8 +562,10 @@ TEST(ShapeUtilTest, HumanString) { EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", ShapeUtil::HumanStringWithLayout(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})", - ShapeUtil::HumanStringWithLayout(nested_tuple)); + EXPECT_EQ( + "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " + "token[])", + ShapeUtil::HumanStringWithLayout(nested_tuple)); ProgramShape prog = ShapeUtil::MakeProgramShape( {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); @@ -481,8 +575,9 @@ TEST(ShapeUtilTest, HumanString) { "(unknown): u32[1,2], " "(unknown): s32[3,4], " "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " - "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); prog.add_parameter_names("arg0"); @@ -497,8 +592,10 @@ TEST(ShapeUtilTest, HumanString) { "matrix: u32[1,2], " "matrix2: s32[3,4], " "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " - "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " + "token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); } @@ -713,14 +810,15 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } -TEST(ShapeUtilTest, StripDegenerateDimensions) { - EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions( - ShapeUtil::MakeShape(F32, {3, 1, 2})), - ShapeUtil::MakeShape(F32, {3, 2}))); - EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::StripDegenerateDimensions( - ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)), - ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10))); +TEST(ShapeUtilTest, HasDegenerateDimensions) { + EXPECT_TRUE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 1, 2}))); + EXPECT_TRUE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 1, 1}))); + EXPECT_FALSE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 3, 5}))); + EXPECT_FALSE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 0, 5}))); } TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index 0e1387c93938fa520562fcd63ac107a82b089a51..a32e2ad9851b0b5644f7e6f0f9ead6c438934c07 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -12,297 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -// StatusOr is the union of a Status object and a T object. StatusOr models -// the concept of an object that is either a value, or an error Status -// explaining why such a value is not present. To this end, StatusOr does not -// allow its Status value to be Status::OK. -// -// The primary use-case for StatusOr is as the return value of a -// function which may fail. -// -// Example client usage for a StatusOr, where T is not a pointer: -// -// StatusOr result = DoBigCalculationThatCouldFail(); -// if (result.ok()) { -// float answer = result.ValueOrDie(); -// printf("Big calculation yielded: %f", answer); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr: -// -// StatusOr result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo(result.ValueOrDie()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr>: -// -// StatusOr> result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo = std::move(result.ValueOrDie()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example factory implementation returning StatusOr: -// -// StatusOr FooFactory::MakeNewFoo(int arg) { -// if (arg <= 0) { -// return tensorflow::InvalidArgument("Arg must be positive"); -// } else { -// return new Foo(arg); -// } -// } -// -// Note that the assignment operators require that destroying the currently -// stored value cannot invalidate the argument; in other words, the argument -// cannot be an alias for the current value, or anything owned by the current -// value. #ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_H_ #define TENSORFLOW_COMPILER_XLA_STATUSOR_H_ #include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/statusor_internals.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace xla { -#if defined(__clang__) -// Only clang supports warn_unused_result as a type annotation. -template -class TF_MUST_USE_RESULT StatusOr; -#endif - -template -class StatusOr : private internal_statusor::StatusOrData, - private internal_statusor::TraitsBase< - std::is_copy_constructible::value, - std::is_move_constructible::value> { - template - friend class StatusOr; - - typedef internal_statusor::StatusOrData Base; - - public: - typedef T element_type; - - // Constructs a new StatusOr with Status::UNKNOWN status. This is marked - // 'explicit' to try to catch cases like 'return {};', where people think - // StatusOr> will be initialized with an empty vector, - // instead of a Status::UNKNOWN status. - explicit StatusOr(); - - // StatusOr will be copy constructible/assignable if T is copy - // constructible. - StatusOr(const StatusOr&) = default; - StatusOr& operator=(const StatusOr&) = default; - - // StatusOr will be move constructible/assignable if T is move - // constructible. - StatusOr(StatusOr&&) = default; - StatusOr& operator=(StatusOr&&) = default; - - // Conversion copy/move constructor, T must be convertible from U. - template ::value>::type* = nullptr> - StatusOr(const StatusOr& other); - template ::value>::type* = nullptr> - StatusOr(StatusOr&& other); - - // Conversion copy/move assignment operator, T must be convertible from U. - template ::value>::type* = nullptr> - StatusOr& operator=(const StatusOr& other); - template ::value>::type* = nullptr> - StatusOr& operator=(StatusOr&& other); - - // Constructs a new StatusOr with the given value. After calling this - // constructor, calls to ValueOrDie() will succeed, and calls to status() will - // return OK. - // - // NOTE: Not explicit - we want to use StatusOr as a return type - // so it is convenient and sensible to be able to do 'return T()' - // when the return type is StatusOr. - // - // REQUIRES: T is copy constructible. - StatusOr(const T& value); - - // Constructs a new StatusOr with the given non-ok status. After calling - // this constructor, calls to ValueOrDie() will CHECK-fail. - // - // NOTE: Not explicit - we want to use StatusOr as a return - // value, so it is convenient and sensible to be able to do 'return - // Status()' when the return type is StatusOr. - // - // REQUIRES: !status.ok(). This requirement is DCHECKed. - // In optimized builds, passing Status::OK() here will have the effect - // of passing tensorflow::error::INTERNAL as a fallback. - StatusOr(const Status& status); - StatusOr& operator=(const Status& status); - - // TODO(b/62186997): Add operator=(T) overloads. - - // Similar to the `const T&` overload. - // - // REQUIRES: T is move constructible. - StatusOr(T&& value); - - // RValue versions of the operations declared above. - StatusOr(Status&& status); - StatusOr& operator=(Status&& status); - - // Returns this->status().ok() - bool ok() const { return this->status_.ok(); } - - // Returns a reference to our status. If this contains a T, then - // returns Status::OK(). - const Status& status() const &; - Status status() &&; - - // Returns a reference to our current value, or CHECK-fails if !this->ok(). - // - // Note: for value types that are cheap to copy, prefer simple code: - // - // T value = statusor.ValueOrDie(); - // - // Otherwise, if the value type is expensive to copy, but can be left - // in the StatusOr, simply assign to a reference: - // - // T& value = statusor.ValueOrDie(); // or `const T&` - // - // Otherwise, if the value type supports an efficient move, it can be - // used as follows: - // - // T value = std::move(statusor).ValueOrDie(); - // - // The std::move on statusor instead of on the whole expression enables - // warnings about possible uses of the statusor object after the move. - // C++ style guide waiver for ref-qualified overloads granted in cl/143176389 - // See go/ref-qualifiers for more details on such overloads. - const T& ValueOrDie() const &; - T& ValueOrDie() &; - const T&& ValueOrDie() const &&; - T&& ValueOrDie() &&; - - T ConsumeValueOrDie() { return std::move(ValueOrDie()); } - - // Ignores any errors. This method does nothing except potentially suppress - // complaints from any tools that are checking that errors are not dropped on - // the floor. - void IgnoreError() const; -}; - -//////////////////////////////////////////////////////////////////////////////// -// Implementation details for StatusOr - -template -StatusOr::StatusOr() : Base(Status(tensorflow::error::UNKNOWN, "")) {} - -template -StatusOr::StatusOr(const T& value) : Base(value) {} - -template -StatusOr::StatusOr(const Status& status) : Base(status) {} - -template -StatusOr& StatusOr::operator=(const Status& status) { - this->Assign(status); - return *this; -} - -template -StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} - -template -StatusOr::StatusOr(Status&& status) : Base(std::move(status)) {} - -template -StatusOr& StatusOr::operator=(Status&& status) { - this->Assign(std::move(status)); - return *this; -} - -template -template ::value>::type*> -inline StatusOr::StatusOr(const StatusOr& other) - : Base(static_cast::Base&>(other)) {} - -template -template ::value>::type*> -inline StatusOr& StatusOr::operator=(const StatusOr& other) { - if (other.ok()) - this->Assign(other.ValueOrDie()); - else - this->Assign(other.status()); - return *this; -} - -template -template ::value>::type*> -inline StatusOr::StatusOr(StatusOr&& other) - : Base(static_cast::Base&&>(other)) {} - -template -template ::value>::type*> -inline StatusOr& StatusOr::operator=(StatusOr&& other) { - if (other.ok()) { - this->Assign(std::move(other).ValueOrDie()); - } else { - this->Assign(std::move(other).status()); - } - return *this; -} - -template -const Status& StatusOr::status() const & { - return this->status_; -} -template -Status StatusOr::status() && { - return ok() ? Status::OK() : std::move(this->status_); -} - -template -const T& StatusOr::ValueOrDie() const & { - this->EnsureOk(); - return this->data_; -} - -template -T& StatusOr::ValueOrDie() & { - this->EnsureOk(); - return this->data_; -} - -template -const T&& StatusOr::ValueOrDie() const && { - this->EnsureOk(); - return std::move(this->data_); -} - -template -T&& StatusOr::ValueOrDie() && { - this->EnsureOk(); - return std::move(this->data_); -} - +// Use steam_executor's StatusOr so we don't duplicate code. template -void StatusOr::IgnoreError() const { - // no-op -} +using StatusOr = ::stream_executor::port::StatusOr; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 7a528a22473c435e8afe3895a37889334bf4f1ab..02f6fc3a27152afb6085494887f0777c23030263 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -90,11 +90,9 @@ cc_library( "//tensorflow/compiler/xla:error_spec", "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -119,11 +117,11 @@ cc_library( "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -140,8 +138,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -621,6 +619,7 @@ xla_test( xla_test( name = "exhaustive_f32_elementwise_op_test", + size = "enormous", srcs = ["exhaustive_f32_elementwise_op_test.cc"], backends = [ "cpu", @@ -628,7 +627,6 @@ xla_test( ], shard_count = 48, tags = [ - "enormous", "manual", "notap", ], @@ -699,8 +697,9 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -778,30 +777,42 @@ xla_test( ], ) +CONVOLUTION_TEST_DEPS = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", +] + xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = [ - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], + deps = CONVOLUTION_TEST_DEPS, +) + +xla_test( + name = "convolution_test_gpu_alternative_layout", + timeout = "long", + srcs = ["convolution_test.cc"], + backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, + backends = ["gpu"], + shard_count = 25, + deps = CONVOLUTION_TEST_DEPS, ) xla_test( @@ -875,6 +886,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", @@ -1185,9 +1197,25 @@ xla_test( ], deps = [ ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "token_hlo_test", + srcs = ["token_hlo_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1222,6 +1250,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1497,6 +1526,30 @@ xla_test( ], ) +xla_test( + name = "cross_replica_sum_test", + srcs = ["cross_replica_sum_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], @@ -1730,6 +1783,7 @@ xla_test( "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1936,6 +1990,7 @@ xla_test( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", ], ) @@ -1987,6 +2042,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 36a706496918ac8c15780473019e2a8d098ffa22..3bdf98544affca11fd825e28d20f4903188fe920 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -51,16 +51,16 @@ class ArrayElementwiseOpTestParamCount XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Neg(a); + auto a = ConstantR1(&builder, {}); + Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - builder.Neg(a); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); + Neg(a); ComputeAndCompareR1(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {}, error_spec_); @@ -68,10 +68,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-1, 0, 1, 324, - std::numeric_limits::min(), - std::numeric_limits::max()}); - builder.Neg(a); + auto a = ConstantR1(&builder, + {-1, 0, 1, 324, std::numeric_limits::min(), + std::numeric_limits::max()}); + Neg(a); // -min == min for int32 due to an overflow. In C++ it is undefined behavior // to do this calculation. For XLA we have not specified that, so it @@ -84,17 +84,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Neg(a); + auto a = ConstantR1(&builder, {}); + Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); - builder.Neg(a); + auto a = ConstantR1( + &builder, {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); + Neg(a); ComputeAndCompareR1( &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, @@ -103,16 +103,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({ - -1, - 1, - 0, - 0x12345678, - static_cast(0xffffffff12345678l), - static_cast(0x8000000000000000LL), - static_cast(0x8000000000000001LL), - }); - builder.Neg(a); + auto a = + ConstantR1(&builder, { + -1, + 1, + 0, + 0x12345678, + static_cast(0xffffffff12345678l), + static_cast(0x8000000000000000LL), + static_cast(0x8000000000000001LL), + }); + Neg(a); LOG(INFO) << -static_cast(0x7FFFFFFFFFFFFFFFLL); ComputeAndCompareR1(&builder, @@ -130,8 +131,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.IsFinite(a); + auto a = ConstantR1(&builder, {}); + IsFinite(a); ComputeAndCompareR1(&builder, {}, {}); } @@ -141,21 +142,21 @@ static const float kNonCanonicalNaN = tensorflow::bit_cast(0x7FD01234); XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { XlaBuilder builder(TestName()); - builder.IsFinite(builder.ConstantR0(NAN)); + IsFinite(ConstantR0(&builder, NAN)); ComputeAndCompareR0(&builder, false, {}); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); - builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); + IsFinite(ConstantR0(&builder, kNonCanonicalNaN)); ComputeAndCompareR0(&builder, false, {}); const float inf = std::numeric_limits::infinity(); - builder.IsFinite(builder.ConstantR0(inf)); + IsFinite(ConstantR0(&builder, inf)); ComputeAndCompareR0(&builder, false, {}); - builder.IsFinite(builder.ConstantR0(-inf)); + IsFinite(ConstantR0(&builder, -inf)); ComputeAndCompareR0(&builder, false, {}); - builder.IsFinite(builder.ConstantR0(0.0f)); + IsFinite(ConstantR0(&builder, 0.0f)); ComputeAndCompareR0(&builder, true, {}); } @@ -163,9 +164,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { XlaBuilder builder(TestName()); const float inf = std::numeric_limits::infinity(); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); - auto a = builder.ConstantR1( - {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); - builder.IsFinite(a); + auto a = ConstantR1(&builder, + {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); + IsFinite(a); ComputeAndCompareR1(&builder, {false, true, false, true, false, false}, {}); @@ -173,9 +174,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - builder.Add(a, b); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); + Add(a, b); ComputeAndCompareR1(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {}, error_spec_); @@ -183,20 +184,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Add(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); - auto b = builder.ConstantR1( - {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); - builder.Add(a, b); + auto a = ConstantR1( + &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); + auto b = ConstantR1( + &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); + Add(a, b); ComputeAndCompareR1( &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, @@ -205,9 +206,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Add(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -225,7 +226,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0x8000000000000000LL, 1}; std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); - auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); std::unique_ptr lhs_data = client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); @@ -239,11 +240,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 1, 0x8000000000000000LL}; std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); - auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); - b.Add(lhs_param, rhs_param); + Add(lhs_param, rhs_param); std::vector expected(lhs.size()); for (int64 i = 0; i < lhs.size(); ++i) { @@ -265,7 +266,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0, -1}; std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); - auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); std::unique_ptr lhs_data = client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); @@ -278,11 +279,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0x7FFFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL}; std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); - auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); - auto sub = b.Sub(lhs_param, rhs_param); + Sub(lhs_param, rhs_param); std::vector expected(lhs.size()); for (int64 i = 0; i < lhs.size(); ++i) { @@ -305,23 +306,23 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { std::unique_ptr a_literal = Literal::CreateR1({a_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a_constant = builder.ConstantR1(a_values); - auto a_param = builder.Parameter(0, a_literal->shape(), "a_param"); + auto a_constant = ConstantR1(&builder, a_values); + auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param"); std::unique_ptr b_literal = Literal::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param"); - auto b_param = builder.ConstantR1(b_values); + auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param"); + auto b_param = ConstantR1(&builder, b_values); - auto sum1 = builder.Add(a_constant, b_constant); - auto sum2 = builder.Add(a_constant, b_param); - auto sum3 = builder.Add(a_param, b_constant); - auto sum4 = builder.Add(a_param, b_param); + auto sum1 = Add(a_constant, b_constant); + auto sum2 = Add(a_constant, b_param); + auto sum3 = Add(a_param, b_constant); + auto sum4 = Add(a_param, b_param); - auto sum = builder.Add(sum1, sum2); - sum = builder.Add(sum, sum3); - sum = builder.Add(sum, sum4); + auto sum = Add(sum1, sum2); + sum = Add(sum, sum3); + sum = Add(sum, sum4); std::vector expected; for (int64 i = 0; i < count; ++i) { @@ -334,9 +335,9 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); + Sub(a, b); ComputeAndCompareR1(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f}, {}, error_spec_); @@ -344,38 +345,38 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-1, 0, 2, 1000000000}); - auto b = builder.ConstantR1({-1, 2, 1, -1}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {-1, 0, 2, 1000000000}); + auto b = ConstantR1(&builder, {-1, 2, 1, -1}); + Sub(a, b); ComputeAndCompareR1(&builder, {0, -2, 1, 1000000001}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Sub(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); - auto b = builder.ConstantR1( - {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); + auto b = ConstantR1( + &builder, {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); + Sub(a, b); ComputeAndCompareR1( &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, @@ -384,18 +385,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); - builder.Div(a, b); + auto a = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); + Div(a, b); ComputeAndCompareR1(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {}, error_spec_); @@ -403,9 +404,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Div(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -442,7 +443,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Div(dividend, divisor); + Div(dividend, divisor); ComputeAndCompareR1(&builder, quotients, {dividend_data.get(), divisor_data.get()}); @@ -454,7 +455,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Div(dividend, builder.ConstantR1(divisors)); + Div(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); } @@ -467,7 +468,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Rem(dividend, divisor); + Rem(dividend, divisor); ComputeAndCompareR1(&builder, remainders, {dividend_data.get(), divisor_data.get()}); @@ -479,7 +480,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Rem(dividend, builder.ConstantR1(divisors)); + Rem(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); } @@ -513,7 +514,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Div(dividend, divisor); + Div(dividend, divisor); ComputeAndCompareR1(&builder, quotients, {dividend_data.get(), divisor_data.get()}); @@ -524,7 +525,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Div(dividend, builder.ConstantR1(divisors)); + Div(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); } @@ -537,7 +538,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Rem(dividend, divisor); + Rem(dividend, divisor); ComputeAndCompareR1(&builder, remainders, {dividend_data.get(), divisor_data.get()}); @@ -548,7 +549,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Rem(dividend, builder.ConstantR1(divisors)); + Rem(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); } @@ -556,11 +557,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); - auto b = builder.ConstantR1( - {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); - builder.Div(a, b); + auto a = ConstantR1( + &builder, {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); + auto b = ConstantR1(&builder, + {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); + Div(a, b); ComputeAndCompareR1( &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); @@ -568,20 +569,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Div(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); - auto b = builder.ConstantR1( - {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); - builder.Rem(a, b); + auto a = ConstantR1( + &builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); + auto b = ConstantR1( + &builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); + Rem(a, b); ComputeAndCompareR1( &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {}, @@ -590,20 +591,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Rem(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Rem(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); - auto b = builder.ConstantR1( - {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); - builder.Rem(a, b); + auto a = ConstantR1( + &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); + auto b = ConstantR1( + &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); + Rem(a, b); ComputeAndCompareR1( &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, @@ -612,9 +613,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Mul(a, b); ComputeAndCompareR1(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, {}, error_spec_); @@ -622,9 +623,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -648,18 +649,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR1(a_data); - auto b = builder.ConstantR1(b_data); - builder.Mul(a, b); + auto a = ConstantR1(&builder, a_data); + auto b = ConstantR1(&builder, b_data); + Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Mul(a, b); ComputeAndCompareR1(&builder, {}, {}); } @@ -679,20 +680,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR1(a_data); - auto b = builder.ConstantR1(b_data); - builder.Mul(a, b); + auto a = ConstantR1(&builder, a_data); + auto b = ConstantR1(&builder, b_data); + Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); - auto b = builder.ConstantR1( - {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); - builder.Mul(a, b); + auto a = ConstantR1( + &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); + auto b = ConstantR1(&builder, + {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); + Mul(a, b); ComputeAndCompareR1( &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, @@ -701,27 +702,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, false, true, true}); - auto b = builder.ConstantR1({false, true, false, true}); - builder.And(a, b); + auto a = ConstantR1(&builder, {false, false, true, true}); + auto b = ConstantR1(&builder, {false, true, false, true}); + And(a, b); ComputeAndCompareR1(&builder, {false, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{false, false}, {true, true}}); - auto b = builder.ConstantR2({{false, true}, {false, true}}); - builder.And(a, b); + auto a = ConstantR2(&builder, {{false, false}, {true, true}}); + auto b = ConstantR2(&builder, {{false, true}, {false, true}}); + And(a, b); Array2D expected_array({{false, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -729,27 +730,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.And(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, -1, -8}); - auto b = builder.ConstantR1({5, -7, 12}); - builder.And(a, b); + auto a = ConstantR1(&builder, {0, -1, -8}); + auto b = ConstantR1(&builder, {5, -7, 12}); + And(a, b); ComputeAndCompareR1(&builder, {0, -7, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, -5}, {-1, 5}}); - auto b = builder.ConstantR2({{1, -6}, {4, 5}}); - builder.And(a, b); + auto a = ConstantR2(&builder, {{0, -5}, {-1, 5}}); + auto b = ConstantR2(&builder, {{1, -6}, {4, 5}}); + And(a, b); Array2D expected_array({{0, -6}, {4, 5}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -757,27 +758,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.And(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, 1, 8}); - auto b = builder.ConstantR1({5, 7, 12}); - builder.And(a, b); + auto a = ConstantR1(&builder, {0, 1, 8}); + auto b = ConstantR1(&builder, {5, 7, 12}); + And(a, b); ComputeAndCompareR1(&builder, {0, 1, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, 1}, {3, 8}}); - auto b = builder.ConstantR2({{1, 0}, {7, 6}}); - builder.And(a, b); + auto a = ConstantR2(&builder, {{0, 1}, {3, 8}}); + auto b = ConstantR2(&builder, {{1, 0}, {7, 6}}); + And(a, b); Array2D expected_array({{0, 0}, {3, 0}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -785,27 +786,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.And(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, false, true, true}); - auto b = builder.ConstantR1({false, true, false, true}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {false, false, true, true}); + auto b = ConstantR1(&builder, {false, true, false, true}); + Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{false, false}, {true, true}}); - auto b = builder.ConstantR2({{false, true}, {false, true}}); - builder.Or(a, b); + auto a = ConstantR2(&builder, {{false, false}, {true, true}}); + auto b = ConstantR2(&builder, {{false, true}, {false, true}}); + Or(a, b); Array2D expected_array({{false, true}, {true, true}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -813,27 +814,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, -1, 8}); - auto b = builder.ConstantR1({5, -7, 4}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {0, -1, 8}); + auto b = ConstantR1(&builder, {5, -7, 4}); + Or(a, b); ComputeAndCompareR1(&builder, {5, -1, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, -1}, {8, 8}}); - auto b = builder.ConstantR2({{5, -7}, {4, 1}}); - builder.Or(a, b); + auto a = ConstantR2(&builder, {{0, -1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, -7}, {4, 1}}); + Or(a, b); Array2D expected_array({{5, -1}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -841,27 +842,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, 1, 8}); - auto b = builder.ConstantR1({5, 7, 4}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {0, 1, 8}); + auto b = ConstantR1(&builder, {5, 7, 4}); + Or(a, b); ComputeAndCompareR1(&builder, {5, 7, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, 1}, {8, 8}}); - auto b = builder.ConstantR2({{5, 7}, {4, 1}}); - builder.Or(a, b); + auto a = ConstantR2(&builder, {{0, 1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, 7}, {4, 1}}); + Or(a, b); Array2D expected_array({{5, 7}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -869,25 +870,108 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {false, false, true, true}); + auto b = ConstantR1(&builder, {false, true, false, true}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {false, true, true, false}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{false, false}, {true, true}}); + auto b = ConstantR2(&builder, {{false, true}, {false, true}}); + Xor(a, b); + + Array2D expected_array({{false, true}, {true, false}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {0, -1, 8}); + auto b = ConstantR1(&builder, {5, -7, 4}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {5, 6, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{0, -1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, -7}, {4, 1}}); + Xor(a, b); + + Array2D expected_array({{5, 6}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {0, 1, 8}); + auto b = ConstantR1(&builder, {5, 7, 4}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {5, 6, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{0, 1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, 7}, {4, 1}}); + Xor(a, b); + + Array2D expected_array({{5, 6}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, true, true, false}); - builder.Not(a); + auto a = ConstantR1(&builder, {false, true, true, false}); + Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{false, true}, {true, false}}); - builder.Not(a); + auto a = ConstantR2(&builder, {{false, true}, {true, false}}); + Not(a); Array2D expected_array({{true, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -895,24 +979,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Not(a); + auto a = ConstantR1(&builder, {}); + Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-1, 0, 1}); - builder.Not(a); + auto a = ConstantR1(&builder, {-1, 0, 1}); + Not(a); ComputeAndCompareR1(&builder, {0, -1, -2}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{-1, 0}, {1, 8}}); - builder.Not(a); + auto a = ConstantR2(&builder, {{-1, 0}, {1, 8}}); + Not(a); Array2D expected_array({{0, -1}, {-2, -9}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -920,24 +1004,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Not(a); + auto a = ConstantR1(&builder, {}); + Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, 4294967295}); - builder.Not(a); + auto a = ConstantR1(&builder, {0, 4294967295}); + Not(a); ComputeAndCompareR1(&builder, {4294967295, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, 4294967295}, {1, 4294967294}}); - builder.Not(a); + auto a = ConstantR2(&builder, {{0, 4294967295}, {1, 4294967294}}); + Not(a); Array2D expected_array({{4294967295, 0}, {4294967294, 1}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -945,19 +1029,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Not(a); + auto a = ConstantR1(&builder, {}); + Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({static_cast(0x12345678), - static_cast(0xF0001000), 1, 3, 77, - 1, -3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, -1}); - builder.ShiftLeft(a, b); + auto a = ConstantR1( + &builder, {static_cast(0x12345678), static_cast(0xF0001000), + 1, 3, 77, 1, -3, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 15, 32, 100, -1}); + ShiftLeft(a, b); ComputeAndCompareR1(&builder, {static_cast(0x23456780), 0x00100000, 0x4, @@ -967,11 +1051,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77, - 1, -3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, -1}); - builder.ShiftRightArithmetic(a, b); + auto a = ConstantR1( + &builder, {static_cast(0x92345678), static_cast(0x10001000), + 1, 3, 77, 1, -3, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 2, 32, 100, -1}); + ShiftRightArithmetic(a, b); ComputeAndCompareR1( &builder, @@ -982,11 +1066,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77, - 1, -3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, -1}); - builder.ShiftRightLogical(a, b); + auto a = ConstantR1( + &builder, {static_cast(0x92345678), static_cast(0x10001000), + 1, 3, 77, 1, -3, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 5, 32, 100, -1}); + ShiftRightLogical(a, b); ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); @@ -994,10 +1078,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, ~0u}); - builder.ShiftLeft(a, b); + auto a = ConstantR1(&builder, + {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 15, 32, 100, ~0u}); + ShiftLeft(a, b); ComputeAndCompareR1( &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {}); @@ -1005,10 +1089,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, ~0u}); - builder.ShiftRightArithmetic(a, b); + auto a = ConstantR1(&builder, + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 2, 32, 100, ~0u}); + ShiftRightArithmetic(a, b); ComputeAndCompareR1( &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {}); @@ -1016,10 +1100,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, ~0u}); - builder.ShiftRightLogical(a, b); + auto a = ConstantR1(&builder, + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 5, 32, 100, ~0u}); + ShiftRightLogical(a, b); ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); @@ -1028,18 +1112,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 2.25f, 10.0f, NAN}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 2.25f, 10.0f, NAN}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1047,9 +1131,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Ge(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Ge(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } @@ -1057,9 +1141,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Gt(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Gt(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } @@ -1067,9 +1151,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 5.0f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Le(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 5.0f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Le(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, false, false}, {}); } @@ -1077,9 +1161,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Lt(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Lt(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, false, false, false}, {}); } @@ -1088,9 +1172,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Eq(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1099,9 +1184,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1109,26 +1194,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, - {1.0f, 25.5f}, - {2.25f, -3.0f}, - {NAN, 0.0f}, - {1.0f, 6.0f}}); - auto rhs = builder.ConstantR1({{0.0f, 10.0f}, - {1.0f, 5.0f}, - {2.25f, -3.0f}, - {10.0f, 0.0f}, - {1.0f, NAN}}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = ConstantR1(&builder, {{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1138,17 +1223,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, - {1.0f, 25.5f}, - {2.25f, -3.0f}, - {NAN, 0.0f}, - {1.0f, 6.0f}}); - auto rhs = builder.ConstantR1({{0.0f, 10.0f}, - {1.0f, 5.0f}, - {2.25f, -3.0f}, - {10.0f, 0.0f}, - {1.0f, NAN}}); - builder.Ne(lhs, rhs); + auto lhs = ConstantR1(&builder, {{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = ConstantR1(&builder, {{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, true, true}, {}); } @@ -1158,9 +1243,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 25.5f, 1.0f, 10.0f, NAN}); - builder.Ne(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 25.5f, 1.0f, 10.0f, NAN}); + Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, true, true, true}, {}); } @@ -1169,9 +1254,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Ne(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1181,9 +1267,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Ge(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1193,9 +1280,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Gt(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1206,9 +1294,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Le(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1218,9 +1307,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Lt(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1230,9 +1320,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1242,9 +1332,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Ne(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1253,9 +1343,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Ge(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1264,9 +1354,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Gt(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1276,9 +1366,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Le(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1287,9 +1377,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Lt(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1300,10 +1390,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); auto lhs = - builder.ConstantR1({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); + ConstantR1(&builder, {4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); auto rhs = - builder.ConstantR1({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); - builder.Pow(lhs, rhs); + ConstantR1(&builder, {2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); + Pow(lhs, rhs); ComputeAndCompareR1( &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); @@ -1312,9 +1402,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.0f, -0.6f, -0.6f, 0.0f}); - auto rhs = builder.ConstantR1({0.5f, 0.6f, -0.6f, -0.6f}); - builder.Pow(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.0f, -0.6f, -0.6f, 0.0f}); + auto rhs = ConstantR1(&builder, {0.5f, 0.6f, -0.6f, -0.6f}); + Pow(lhs, rhs); ComputeAndCompareR1(&builder, {NAN, NAN, NAN, INFINITY}, {}, error_spec_); @@ -1322,9 +1412,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Pow(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Pow(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -1340,10 +1430,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::unique_ptr param_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); - auto sum = b.ConstantR0(0.0f); - auto param = b.Parameter(0, param_literal->shape(), "param"); + auto sum = ConstantR0(&b, 0.0f); + auto param = Parameter(&b, 0, param_literal->shape(), "param"); for (float exponent : exponents) { - sum = b.Add(sum, b.Pow(param, b.ConstantR0(exponent))); + sum = Add(sum, Pow(param, ConstantR0(&b, exponent))); } std::vector expected; @@ -1370,9 +1460,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { std::unique_ptr literal1 = Literal::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Pow(b.Exp(param0), param1); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Pow(Exp(param0), param1); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1395,9 +1485,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { std::unique_ptr literal1 = Literal::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Log(b.Pow(param0, param1)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Log(Pow(param0, param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1420,9 +1510,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { std::unique_ptr literal1 = Literal::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Mul(b.Exp(param0), b.Exp(param1)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Mul(Exp(param0), Exp(param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1445,9 +1535,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { std::unique_ptr literal1 = Literal::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Div(param0, b.Exp(param1)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Div(param0, Exp(param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1476,10 +1566,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { std::unique_ptr literal2 = Literal::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - b.Div(b.Div(param0, param1), param2); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + Div(Div(param0, param1), param2); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1509,10 +1599,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - b.Div(param0, b.Div(param1, param2)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + Div(param0, Div(param1, param2)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1542,10 +1632,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - b.Div(param0, b.Pow(param1, param2)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + Div(param0, Pow(param1, param2)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1580,11 +1670,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { std::unique_ptr data3 = client_->TransferToServer(*literal3).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - auto param3 = b.Parameter(3, literal3->shape(), "param2"); - b.Div(b.Div(param0, param1), b.Div(param2, param3)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param3 = Parameter(&b, 3, literal3->shape(), "param2"); + Div(Div(param0, param1), Div(param2, param3)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1604,8 +1694,8 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { for (int i = 0; i < count; ++i) { values.push_back(i / static_cast(count)); } - auto x = builder.ConstantR1(values); - builder.Pow(x, builder.ConstantR0(2.0f)); + auto x = ConstantR1(&builder, values); + Pow(x, ConstantR0(&builder, 2.0f)); std::vector expected; expected.reserve(values.size()); @@ -1630,8 +1720,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { Array4D expected(2, 2, 2, 2, expected_vector); - auto x = builder.ConstantR4FromArray4D(values); - builder.Pow(x, builder.ConstantR0(2.0f)); + auto x = ConstantR4FromArray4D(&builder, values); + Pow(x, ConstantR0(&builder, 2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } @@ -1641,8 +1731,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { Array4D values(2, 2, 0, 2); Array4D expected(2, 2, 0, 2); - auto x = builder.ConstantR4FromArray4D(values); - builder.Pow(x, builder.ConstantR0(2.0f)); + auto x = ConstantR4FromArray4D(&builder, values); + Pow(x, ConstantR0(&builder, 2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } @@ -1650,9 +1740,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); - builder.Min(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN}); + Min(lhs, rhs); ComputeAndCompareR1(&builder, {1.0f, -5.0f, 1.0f, NAN, NAN}, {}, error_spec_); @@ -1660,18 +1750,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Min(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Min(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); - builder.Min(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0, 1.0, 2.25, NAN, 6.0}); + auto rhs = ConstantR1(&builder, {2.0, -5.0, 1.0, 10.0, NAN}); + Min(lhs, rhs); ComputeAndCompareR1(&builder, {1.0, -5.0, 1.0, NAN, NAN}, {}, error_spec_); @@ -1680,9 +1770,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); - builder.Max(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN}); + Max(lhs, rhs); ComputeAndCompareR1(&builder, {2.0f, 1.0f, 2.25f, NAN, NAN}, {}, error_spec_); @@ -1690,18 +1780,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Max(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Max(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); - builder.Max(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0, 1.0, 2.25, NAN, 6.0}); + auto rhs = ConstantR1(&builder, {2.0, -5.0, 1.0, 10.0, NAN}); + Max(lhs, rhs); ComputeAndCompareR1(&builder, {2.0, 1.0, 2.25, NAN, NAN}, {}, error_spec_); @@ -1711,11 +1801,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); - auto y = builder.ConstantR1( - {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); - builder.Max(x, y); + auto x = ConstantR1( + &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); + auto y = ConstantR1( + &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); + Max(x, y); std::vector expected = {min, max, 0, -1, 0, 0, 0, 1, 1, 10, max, max, max}; @@ -1726,11 +1816,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); - auto y = builder.ConstantR1( - {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); - builder.Min(x, y); + auto x = ConstantR1( + &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); + auto y = ConstantR1( + &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); + Min(x, y); std::vector expected = {min, min, min, -10, -1, -1, 0, 0, 0, 1, 0, max, min}; @@ -1740,9 +1830,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); - auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); - builder.Max(x, y); + auto x = ConstantR1(&builder, {0, 0, 1, 1, 1, max, max, max}); + auto y = ConstantR1(&builder, {0, 1, 0, 1, 10, 0, 234234, max}); + Max(x, y); std::vector expected = {0, 1, 1, 1, 10, max, max, max}; ComputeAndCompareR1(&builder, expected, {}); @@ -1751,9 +1841,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); - auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); - builder.Min(x, y); + auto x = ConstantR1(&builder, {0, 0, 1, 1, 1, max, max, max}); + auto y = ConstantR1(&builder, {0, 1, 0, 1, 10, 0, 234234, max}); + Min(x, y); std::vector expected = {0, 0, 0, 1, 1, 0, 234234, max}; ComputeAndCompareR1(&builder, expected, {}); @@ -1761,11 +1851,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); - auto y = builder.ConstantR1( - {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); - builder.Max(x, y); + auto x = ConstantR1( + &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = ConstantR1( + &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + Max(x, y); std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; @@ -1774,9 +1864,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XlaBuilder builder(TestName()); - auto u = builder.ConstantR1({3.5}); - auto v = builder.ConstantR1({}); - builder.Max(u, v); + auto u = ConstantR1(&builder, {3.5}); + auto v = ConstantR1(&builder, {}); + Max(u, v); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -1784,9 +1874,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { for (int broadcast_dim : {0, 1}) { XlaBuilder builder(TestName()); - auto u = builder.ConstantR1({3.5}); - auto v = builder.ConstantR2FromArray2D(Array2D(0, 2)); - builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); + auto u = ConstantR1(&builder, {3.5}); + auto v = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); } @@ -1794,10 +1884,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - auto m = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - builder.Max(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + auto m = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + Max(v, m, /*broadcast_dimensions=*/{1}); Array2D expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1805,9 +1895,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({}); - auto m = builder.ConstantR2({{}, {}}); - builder.Max(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {}); + auto m = ConstantR2(&builder, {{}, {}}); + Max(v, m, /*broadcast_dimensions=*/{1}); Array2D expected({{}, {}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1815,10 +1905,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { XlaBuilder builder(TestName()); - auto scalar = builder.ConstantR0(2); + auto scalar = ConstantR0(&builder, 2); Array3D a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}}); - auto array = builder.ConstantR3FromArray3D(a_3d); - builder.Max(array, scalar, /*broadcast_dimensions=*/{}); + auto array = ConstantR3FromArray3D(&builder, a_3d); + Max(array, scalar, /*broadcast_dimensions=*/{}); Array3D expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}}); ComputeAndCompareR3(&builder, expected, {}); @@ -1826,10 +1916,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { XlaBuilder builder(TestName()); - auto scalar = builder.ConstantR0(2); + auto scalar = ConstantR0(&builder, 2); Array3D a_3d(2, 0, 3); - auto array = builder.ConstantR3FromArray3D(a_3d); - builder.Max(array, scalar, /*broadcast_dimensions=*/{}); + auto array = ConstantR3FromArray3D(&builder, a_3d); + Max(array, scalar, /*broadcast_dimensions=*/{}); Array3D expected(2, 0, 3); ComputeAndCompareR3(&builder, expected, {}); @@ -1837,10 +1927,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { XlaBuilder builder(TestName()); - auto m = - builder.ConstantR2({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); - auto v = builder.ConstantR1({-10.2f, 16.4f}); - builder.Min(m, v, /*broadcast_dimensions=*/{0}); + auto m = ConstantR2(&builder, + {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); + auto v = ConstantR1(&builder, {-10.2f, 16.4f}); + Min(m, v, /*broadcast_dimensions=*/{0}); Array2D expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1848,9 +1938,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { XlaBuilder builder(TestName()); - auto m = builder.ConstantR2({{}, {}}); - auto v = builder.ConstantR1({-10.2f, 16.4f}); - builder.Min(m, v, /*broadcast_dimensions=*/{0}); + auto m = ConstantR2(&builder, {{}, {}}); + auto v = ConstantR1(&builder, {-10.2f, 16.4f}); + Min(m, v, /*broadcast_dimensions=*/{0}); Array2D expected({{}, {}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1859,11 +1949,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { XlaBuilder builder(TestName()); auto array2d = - builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); - auto array4d = builder.ConstantR4FromArray4D( - {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}}, - {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}}); - builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); + ConstantR2(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); + auto array4d = ConstantR4FromArray4D( + &builder, {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}}, + {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}}); + Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); Array4D expected( {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}}, @@ -1874,10 +1964,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { XlaBuilder builder(TestName()); auto array2d = - builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); + ConstantR2(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); Array4D arg(2, 2, 0, 3); - auto array4d = builder.ConstantR4FromArray4D(arg); - builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); + auto array4d = ConstantR4FromArray4D(&builder, arg); + Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); Array4D expected(2, 2, 0, 3); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -1885,9 +1975,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); - builder.Min(x, y); + auto x = ConstantR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = ConstantR1(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + Min(x, y); std::vector expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0}; ComputeAndCompareR1(&builder, expected, {}); @@ -1895,9 +1985,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); - builder.Max(x, y); + auto x = ConstantR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = ConstantR1(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + Max(x, y); std::vector expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9}; ComputeAndCompareR1(&builder, expected, {}); @@ -1905,19 +1995,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-3, 26, 2, -1, 1}); - auto b = builder.ConstantR1({10, 5, 1, 10, -10}); - builder.Rem(a, b); + auto a = ConstantR1(&builder, {-3, 26, 2, -1, 1}); + auto b = ConstantR1(&builder, {10, 5, 1, 10, -10}); + Rem(a, b); ComputeAndCompareR1(&builder, {-3, 1, 0, -1, 1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { XlaBuilder builder(TestName()); - auto minimum = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); - auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); - auto maximum = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); - builder.Clamp(minimum, argument, maximum); + auto minimum = ConstantR1(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); + auto argument = + ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); + auto maximum = ConstantR1(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0}); + Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {}, error_spec_); @@ -1925,10 +2016,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XlaBuilder builder(TestName()); - auto minimum = builder.ConstantR0(0.0f); - auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); - auto maximum = builder.ConstantR0(5.0f); - builder.Clamp(minimum, argument, maximum); + auto minimum = ConstantR0(&builder, 0.0f); + auto argument = ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto maximum = ConstantR0(&builder, 5.0f); + Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {}, error_spec_); @@ -1936,16 +2027,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { XlaBuilder builder(TestName()); - auto min_scalar = builder.ConstantR0(0.0f); - auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); - auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); - auto max_scalar = builder.ConstantR0(3.0f); - auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); + auto min_scalar = ConstantR0(&builder, 0.0f); + auto min_vector = + ConstantR1(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); + auto arg_vector = + ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto max_scalar = ConstantR0(&builder, 3.0f); + auto max_vector = + ConstantR1(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. - builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + Add(Add(Clamp(min_vector, arg_vector, max_scalar), + Clamp(min_scalar, arg_vector, max_vector)), + Add(Clamp(min_vector, arg_vector, max_vector), + Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); @@ -1953,52 +2047,52 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { XlaBuilder builder(TestName()); - auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0, -5}); - auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4, 10}); - auto max_vector = builder.ConstantR1({3, 0, 25, 5, 123, -1}); - builder.Clamp(min_vector, arg_vector, max_vector); + auto min_vector = ConstantR1(&builder, {1, -6, 1, 2, 0, -5}); + auto arg_vector = ConstantR1(&builder, {2, 10, -5, 1, 4, 10}); + auto max_vector = ConstantR1(&builder, {3, 0, 25, 5, 123, -1}); + Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { XlaBuilder builder(TestName()); - auto min_scalar = builder.ConstantR0(0); - auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0}); - auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4}); - auto max_scalar = builder.ConstantR0(3); - auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + auto min_scalar = ConstantR0(&builder, 0); + auto min_vector = ConstantR1(&builder, {1, -6, 1, 2, 0}); + auto arg_vector = ConstantR1(&builder, {2, 10, -5, 1, 4}); + auto max_scalar = ConstantR0(&builder, 3); + auto max_vector = ConstantR1(&builder, {3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + Add(Add(Clamp(min_vector, arg_vector, max_scalar), + Clamp(min_scalar, arg_vector, max_vector)), + Add(Clamp(min_vector, arg_vector, max_vector), + Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { XlaBuilder builder(TestName()); - auto min_vector = builder.ConstantR1({1, 2, 1, 2, 0, ~0u - 4}); - auto arg_vector = builder.ConstantR1({2, 10, 5, 1, 4, 10}); - auto max_vector = builder.ConstantR1({3, 5, 25, 5, 123, ~0u}); - builder.Clamp(min_vector, arg_vector, max_vector); + auto min_vector = ConstantR1(&builder, {1, 2, 1, 2, 0, ~0u - 4}); + auto arg_vector = ConstantR1(&builder, {2, 10, 5, 1, 4, 10}); + auto max_vector = ConstantR1(&builder, {3, 5, 25, 5, 123, ~0u}); + Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { XlaBuilder builder(TestName()); - auto min_scalar = builder.ConstantR0(0); - auto min_vector = builder.ConstantR1({1, 0, 1, 2, 0}); - auto arg_vector = builder.ConstantR1({2, 10, 0, 1, 4}); - auto max_scalar = builder.ConstantR0(3); - auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + auto min_scalar = ConstantR0(&builder, 0); + auto min_vector = ConstantR1(&builder, {1, 0, 1, 2, 0}); + auto arg_vector = ConstantR1(&builder, {2, 10, 0, 1, 4}); + auto max_scalar = ConstantR0(&builder, 3); + auto max_vector = ConstantR1(&builder, {3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + Add(Add(Clamp(min_vector, arg_vector, max_scalar), + Clamp(min_scalar, arg_vector, max_vector)), + Add(Clamp(min_vector, arg_vector, max_vector), + Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } @@ -2016,9 +2110,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Add(p0, p1); + auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, {param0_data.get(), param1_data.get()}, @@ -2038,9 +2132,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Add(p0, p1); + auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Add(p0, p1); Array3D expected(0, 7, 0); ComputeAndCompareR3( @@ -2055,9 +2149,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); - auto p = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Add(a, p); + auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); + auto p = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, {param0_data.get()}, error_spec_); @@ -2065,8 +2159,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - builder.Cos(a); + auto a = ConstantR1(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f}); + Cos(a); ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, error_spec_); @@ -2074,8 +2168,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - builder.Sin(a); + auto a = ConstantR1(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f}); + Sin(a); ComputeAndCompareR1(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {}, error_spec_); @@ -2083,9 +2177,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); - auto b = builder.ConstantR1({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); - builder.Atan2(a, b); + auto a = ConstantR1(&builder, {0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); + auto b = ConstantR1(&builder, {6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); + Atan2(a, b); ComputeAndCompareR1( &builder, @@ -2095,8 +2189,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); - builder.Tanh(a); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f}); + Tanh(a); ComputeAndCompareR1(&builder, {-0.986614f, 0.996260f, 0.978026}, {}, error_spec_); @@ -2118,8 +2212,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Tanh(input); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + Tanh(input); ComputeAndCompareR1( &builder, @@ -2164,8 +2258,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Exp(input); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + Exp(input); std::vector expected_result; int64 input_size = input_literal->shape().dimensions(0); @@ -2202,8 +2296,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Log(input); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + Log(input); std::vector expected_result; int64 input_size = input_literal->shape().dimensions(0); @@ -2218,9 +2312,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678}); - builder.Clz(a); + auto a = ConstantR1( + &builder, {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678}); + Clz(a); ComputeAndCompareR1(&builder, {32, 31, 27, 15, 9, 3, 0}, {}); } @@ -2228,8 +2322,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) { XlaBuilder builder(TestName()); auto a = - builder.ConstantR1({0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); - builder.Clz(a); + ConstantR1(&builder, {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); + Clz(a); ComputeAndCompareR1(&builder, {64, 63, 32, 1, 0}, {}); } @@ -2241,12 +2335,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // c---------------------/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); - auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); + auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); + auto c = ConstantR1(&builder, {-3.3f, -15.5f, -7.7f, -29.9f}); - auto add = builder.Add(a, b); - builder.Add(add, c); + auto add = Add(a, b); + Add(add, c); ComputeAndCompareR1(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2259,12 +2353,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) { // a---------------------/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); - auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); + auto a = ConstantR1(&builder, {91.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); + auto c = ConstantR1(&builder, {-3.3f, -15.5f, -7.7f, -29.9f}); - auto add = builder.Add(b, c); - builder.Add(a, add); + auto add = Add(b, c); + Add(a, add); ComputeAndCompareR1(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2276,12 +2370,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) { // b ----- (neg) ----/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); + auto a = ConstantR1(&builder, {91.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); - auto neg_a = builder.Neg(a); - auto neg_b = builder.Neg(b); - builder.Add(neg_a, neg_b); + auto neg_a = Neg(a); + auto neg_b = Neg(b); + Add(neg_a, neg_b); ComputeAndCompareR1(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {}, error_spec_); @@ -2297,14 +2391,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { // d -----/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); - auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); - auto d = builder.ConstantR1({-19.0f, 10.0f, -40.0f, 20.2f}); + auto a = ConstantR1(&builder, {91.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); + auto c = ConstantR1(&builder, {-3.3f, -15.5f, -7.7f, -29.9f}); + auto d = ConstantR1(&builder, {-19.0f, 10.0f, -40.0f, 20.2f}); - auto add_ab = builder.Add(a, b); - auto add_cd = builder.Add(c, d); - builder.Add(add_ab, add_cd); + auto add_ab = Add(a, b); + auto add_cd = Add(c, d); + Add(add_ab, add_cd); ComputeAndCompareR1(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {}, error_spec_); @@ -2312,11 +2406,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto b = - builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - builder.Add(a, b); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = ConstantR2(&builder, + {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + Add(a, b); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2326,10 +2420,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { // Add a scalar + matrix. XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto scalar = builder.ConstantR0(3.0f); - builder.Add(scalar, a); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto scalar = ConstantR0(&builder, 3.0f); + Add(scalar, a); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2338,10 +2432,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { // Add a matrix + scalar. XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto scalar = builder.ConstantR0(3.0f); - builder.Add(a, scalar); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto scalar = ConstantR0(&builder, 3.0f); + Add(a, scalar); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2351,13 +2445,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches // only dim 0 of the matrix. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({20.0f, 40.0f, 60.0f}); + auto v = ConstantR1(&builder, {20.0f, 40.0f, 60.0f}); // clang-format off - auto m = builder.ConstantR2({ + auto m = ConstantR2(&builder, { {-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); // clang-format on - builder.Add(v, m, /*broadcast_dimensions=*/{1}); + Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array( {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2366,14 +2460,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { // Test broadcasting in Eq comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({42, 73}); - auto m = builder.ConstantR2({{42, 73}, {42, 52}}); + auto v = ConstantR1(&builder, {42, 73}); + auto m = ConstantR2(&builder, {{42, 73}, {42, 52}}); // This test exercises both possible broadcast dimensions for a vector/matrix // comparison. - auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1}); - auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0}); - auto result = builder.Tuple({cmp_dim_0, cmp_dim_1}); + auto cmp_dim_0 = Eq(v, m, /*broadcast_dimensions=*/{1}); + auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0}); + Tuple(&builder, {cmp_dim_0, cmp_dim_1}); auto expected = Literal::MakeTuple( {Literal::CreateR2({{true, true}, {true, false}}).get(), @@ -2384,9 +2478,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { // Test broadcasting in Ne comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({42, 73}); - auto m = builder.ConstantR2({{42, 73}, {42, 52}}); - builder.Ne(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {42, 73}); + auto m = ConstantR2(&builder, {{42, 73}, {42, 52}}); + Ne(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,2] { { 00 }, @@ -2398,9 +2492,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { // Test broadcasting in Ge comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Ge(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Ge(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1100 }, @@ -2412,9 +2506,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { // Test broadcasting in Gt comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Gt(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Gt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0100 }, @@ -2426,9 +2520,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { // Test broadcasting in Le comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Le(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Le(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1011 }, @@ -2440,9 +2534,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { // Test broadcasting in Lt comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Lt(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Lt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0011 }, @@ -2455,9 +2549,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) { // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op // arguments is reversed. XlaBuilder builder(TestName()); - auto m = builder.ConstantR2({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); - auto v = builder.ConstantR1({2.0f, 4.0f, 6.0f}); - builder.Mul(m, v, /*broadcast_dimensions=*/{1}); + auto m = + ConstantR2(&builder, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); + auto v = ConstantR1(&builder, {2.0f, 4.0f, 6.0f}); + Mul(m, v, /*broadcast_dimensions=*/{1}); Array2D expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2468,10 +2563,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {3, 1} // The result has shape {3, 2}, where md is broadcast over m - auto m = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto md = builder.ConstantR2({{10.0f, 20.0f, 30.0f}}); - builder.Add(m, md); + auto m = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto md = ConstantR2(&builder, {{10.0f, 20.0f, 30.0f}}); + Add(m, md); Array2D expected_array( {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2483,10 +2578,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) { // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {1, 2} // The result has shape {3, 2}, where md is broadcast over m - auto m = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto md = builder.ConstantR2({{10.0f}, {20.0f}}); - builder.Add(m, md); + auto m = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto md = ConstantR2(&builder, {{10.0f}, {20.0f}}); + Add(m, md); Array2D expected_array( {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2501,9 +2596,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { // a's shape in XLA notation is {1, 4} // b's shape in XLA notation is {3, 1} // The result has shape {3, 4}. - auto a = builder.ConstantR2({{0.0f}, {10.0f}, {20.0f}, {30.0f}}); - auto b = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - builder.Add(a, b); + auto a = ConstantR2(&builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}}); + auto b = ConstantR2(&builder, {{1.0f, 2.0f, 3.0f}}); + Add(a, b); Array2D expected_array({{1.0f, 2.0f, 3.0f}, {11.0f, 12.0f, 13.0f}, {21.0f, 22.0f, 23.0f}, @@ -2515,9 +2610,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { // Add together a (2,2) array and a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({20.0f, 40.0f}); - auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - builder.Add(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {20.0f, 40.0f}); + auto m = ConstantR2(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}}); + Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2526,9 +2621,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) { // Add together a (2,2) array and a (2) array, using dimension 1 for // broadcasting (though there are two ways to broadcast these shapes). XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({20.0f, 40.0f}); - auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - builder.Add(v, m, /*broadcast_dimensions=*/{0}); + auto v = ConstantR1(&builder, {20.0f, 40.0f}); + auto m = ConstantR2(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}}); + Add(v, m, /*broadcast_dimensions=*/{0}); Array2D expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2538,12 +2633,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); - auto a = builder.ConstantR3FromArray3D(a_3d); + auto a = ConstantR3FromArray3D(&builder, a_3d); Array3D b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}}, {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}}); - auto b = builder.ConstantR3FromArray3D(b_3d); - builder.Add(a, b); + auto b = ConstantR3FromArray3D(&builder, b_3d); + Add(a, b); Array3D expected_3d( {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}}, @@ -2565,9 +2660,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { {11.0f, 12.0f}}, }); // clang-format on - auto a = builder.ConstantR3FromArray3D(a_3d); - auto v = builder.ConstantR1({10.0f, 20.0f}); - builder.Add(a, v, /*broadcast_dimensions=*/{2}); + auto a = ConstantR3FromArray3D(&builder, a_3d); + auto v = ConstantR1(&builder, {10.0f, 20.0f}); + Add(a, v, /*broadcast_dimensions=*/{2}); Array3D expected_3d( {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}}, @@ -2589,9 +2684,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { {11.0f, 12.0f}}, }); // clang-format on - auto a = builder.ConstantR3FromArray3D(a_3d); - auto v = builder.ConstantR1({10.0f, 20.0f}); - builder.Add(a, v, /*broadcast_dimensions=*/{0}); + auto a = ConstantR3FromArray3D(&builder, a_3d); + auto v = ConstantR1(&builder, {10.0f, 20.0f}); + Add(a, v, /*broadcast_dimensions=*/{0}); // clang-format off Array3D expected_3d({ @@ -2619,12 +2714,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { {9.0f, 10.0f}, {11.0f, 12.0f}}, }); - auto a = builder.ConstantR3FromArray3D(a_3d); - auto m = builder.ConstantR2({ + auto a = ConstantR3FromArray3D(&builder, a_3d); + auto m = ConstantR2(&builder, { {10.0f, 20.0f, 30.0f}, {40.0f, 50.0f, 60.0f}, }); - builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); + Add(a, m, /*broadcast_dimensions=*/{0, 1}); Array3D expected_3d({ {{11.0f, 12.0f}, @@ -2644,12 +2739,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); - auto a = builder.ConstantR3FromArray3D(a_3d); + auto a = ConstantR3FromArray3D(&builder, a_3d); Array3D b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}}); - auto b = builder.ConstantR3FromArray3D(b_3d); + auto b = ConstantR3FromArray3D(&builder, b_3d); - builder.Gt(a, b); + Gt(a, b); Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); @@ -2684,9 +2779,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { } } - auto a = builder.ConstantR4FromArray4D(*operand_a_4d); - auto b = builder.ConstantR4FromArray4D(*operand_b_4d); - builder.Add(a, b); + auto a = ConstantR4FromArray4D(&builder, *operand_a_4d); + auto b = ConstantR4FromArray4D(&builder, *operand_b_4d); + Add(a, b); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } @@ -2712,9 +2807,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { } } - auto a = builder.ConstantR4FromArray4D(*operand_a_4d); - auto b = builder.ConstantR1(operand_b_1d); - builder.Add(a, b, {1}); + auto a = ConstantR4FromArray4D(&builder, *operand_a_4d); + auto b = ConstantR1(&builder, operand_b_1d); + Add(a, b, {1}); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } @@ -2732,9 +2827,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); - auto a = builder.ConstantLiteral(*a_literal); - auto b = builder.ConstantR1(r1); - builder.Add(a, b, {1}); + auto a = ConstantLiteral(&builder, *a_literal); + auto b = ConstantR1(&builder, r1); + Add(a, b, {1}); for (int i0 = 0; i0 < d0; ++i0) { for (int i1 = 0; i1 < d1; ++i1) { @@ -2752,22 +2847,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { XlaBuilder builder(TestName()); auto shape = ShapeUtil::MakeOpaqueShape(); - auto x = builder.Parameter(0, shape, "x"); - builder.Add(x, x); + auto x = Parameter(&builder, 0, shape, "x"); + Add(x, x); auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), ::testing::ContainsRegex( - "Expected non-opaque argument for lhs of binary operation")); + "Expected array argument for lhs of binary operation")); } XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto b = - builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = ConstantR2(&builder, + {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + Add(a, b, /*broadcast_dimensions=*/{0, 1}); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2776,11 +2871,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto b = - builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = ConstantR2(&builder, + {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + Add(a, b, /*broadcast_dimensions=*/{1, 0}); auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); @@ -2797,10 +2892,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto x = builder.Parameter(0, x_literal->shape(), "x"); - auto y = builder.Parameter(1, y_literal->shape(), "y"); - auto slice = builder.Slice(x, {1}, {2}, {1}); - builder.Sub(slice, y); + auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto y = Parameter(&builder, 1, y_literal->shape(), "y"); + auto slice = Slice(x, {1}, {2}, {1}); + Sub(slice, y); ComputeAndCompareR1(&builder, {-2, -3}, {x_data.get(), y_data.get()}, error_spec_); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index fcd9ff55e393f64476ddd4754e0fa74427f1cb51..8d15b7841bc7298cd6865d8689cc496c0459e4b9 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -29,10 +29,10 @@ class AxpySimpleTest : public ClientLibraryTestBase {}; TEST_F(AxpySimpleTest, AxTenValues) { XlaBuilder builder("ax_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - builder.Mul(alpha, x); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1( + &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + Mul(alpha, x); std::vector expected = { -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796, @@ -42,11 +42,11 @@ TEST_F(AxpySimpleTest, AxTenValues) { XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { XlaBuilder builder("axpy_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1({}); - auto y = builder.ConstantR1({}); - auto ax = builder.Mul(alpha, x); - builder.Add(ax, y); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1(&builder, {}); + auto y = ConstantR1(&builder, {}); + auto ax = Mul(alpha, x); + Add(ax, y); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -54,13 +54,13 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { TEST_F(AxpySimpleTest, AxpyTenValues) { XlaBuilder builder("axpy_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto y = builder.ConstantR1( - {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); - auto ax = builder.Mul(alpha, x); - builder.Add(ax, y); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1( + &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto y = ConstantR1( + &builder, {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); + auto ax = Mul(alpha, x); + Add(ax, y); TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index 22c3394e6f34bd018ffaaaa4d9d68339673c3764..8c227df7f04e79ccc332062d0889d282c0f5e40f 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -35,10 +35,10 @@ class BadRngShapeValidationTest : public ClientLibraryTestBase {}; TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0.0); - auto one = builder.ConstantR0(1.0); + auto zero = ConstantR0(&builder, 0.0); + auto one = ConstantR0(&builder, 1.0); Shape default_constructed; - builder.RngUniform(zero, one, default_constructed); + RngUniform(zero, one, default_constructed); StatusOr computation = builder.Build(); EXPECT_FALSE(computation.ok()); @@ -49,13 +49,13 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0.0); - auto one = builder.ConstantR0(1.0); + auto zero = ConstantR0(&builder, 0.0); + auto one = ConstantR0(&builder, 1.0); Shape sans_layout; sans_layout.set_element_type(F32); sans_layout.add_dimensions(1); - builder.RngUniform(zero, one, sans_layout); + RngUniform(zero, one, sans_layout); StatusOr computation = builder.Build(); ASSERT_TRUE(computation.ok()); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index f3dac75a44b948c4b45b80b93e7462073010979e..217673c8cbc212958fe79b67546f28b0be091803 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" @@ -101,9 +102,9 @@ INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { XlaBuilder builder("subtract_in_z_one_sample"); - auto x = builder.ConstantLiteral(input_literal_); - auto y = builder.ConstantR1({3.14, 4.25}); - builder.Sub(x, y, /*broadcast_dimensions=*/{1}); + auto x = ConstantLiteral(&builder, input_literal_); + auto y = ConstantR1(&builder, {3.14, 4.25}); + Sub(x, y, /*broadcast_dimensions=*/{1}); Array4D expected(kSamples, kZ, kY, kX); Array2D pz({ @@ -117,8 +118,8 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { XlaBuilder builder("square_tesseract_elementwise"); - auto x = builder.ConstantLiteral(input_literal_); - builder.SquareF32(x); + auto x = ConstantLiteral(&builder, input_literal_); + Square(x); using tensorflow::MathUtil; @@ -134,11 +135,10 @@ XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { XLA_TEST_P(BatchNormalizationTest, SumToZ) { XlaBuilder builder("sum_to_z"); - auto input_activations = builder.ConstantLiteral(input_literal_); + auto input_activations = ConstantLiteral(&builder, input_literal_); XlaComputation add = CreateScalarAddComputation(F32, &builder); // Reduce all but the Z dimension. - builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, - {0, 2, 3}); + Reduce(input_activations, ConstantR0(&builder, 0.0f), add, {0, 2, 3}); std::vector expected = {6, 12.6}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -146,13 +146,13 @@ XLA_TEST_P(BatchNormalizationTest, SumToZ) { XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { XlaBuilder builder("square_and_reduce"); - auto input_activations = builder.ConstantLiteral(input_literal_); - auto set_means = builder.ConstantR1({2.f, 4.2f}); - auto activation_deviations = builder.Sub(input_activations, set_means, - /*broadcast_dimensions=*/{1}); + auto input_activations = ConstantLiteral(&builder, input_literal_); + auto set_means = ConstantR1(&builder, {2.f, 4.2f}); + auto activation_deviations = Sub(input_activations, set_means, + /*broadcast_dimensions=*/{1}); XlaComputation add = CreateScalarAddComputation(F32, &builder); - auto dev_squares = builder.SquareF32(activation_deviations); - builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, {0, 2, 3}); + auto dev_squares = Square(activation_deviations); + Reduce(dev_squares, ConstantR0(&builder, 0.0f), add, {0, 2, 3}); std::vector expected = {18, 0.06}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -160,8 +160,8 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { XlaBuilder builder("variance_to_stddev"); - auto variance = builder.ConstantR1({6.f, .02f}); - builder.SqrtF32(variance); + auto variance = ConstantR1(&builder, {6.f, .02f}); + Sqrt(variance); std::vector expected = {2.44948974f, 0.14142136f}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -172,50 +172,50 @@ XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { XlaBuilder builder("batch_normalize_per_spec"); auto input_activations = - CheckShape(&builder, builder.ConstantLiteral(input_literal_), + CheckShape(&builder, ConstantLiteral(&builder, input_literal_), ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); - auto gamma = builder.ConstantR1({1.0, 1.0}); - auto beta = builder.ConstantR1({0.0, 0.0}); + auto gamma = ConstantR1(&builder, {1.0, 1.0}); + auto beta = ConstantR1(&builder, {0.0, 0.0}); XlaComputation add = CreateScalarAddComputation(F32, &builder); // Reduce all dimensions except dimension 1. Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); auto sum = CheckShape( &builder, - builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0, 2, 3}), + Reduce(input_activations, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie(); auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie(); - auto count = builder.ConstantR0(ShapeUtil::ElementsIn(input_shape) / - ShapeUtil::ElementsIn(sum_shape)); - auto set_means = builder.Div(sum, count); + auto count = + ConstantR0(&builder, ShapeUtil::ElementsIn(input_shape) / + ShapeUtil::ElementsIn(sum_shape)); + auto set_means = Div(sum, count); const float kEpsilon = 1e-9f; - auto epsilon = builder.ConstantR0(kEpsilon); - auto epsilon2 = builder.ConstantR1({kEpsilon, kEpsilon}); - auto activation_deviations = builder.Sub(input_activations, set_means, - /*broadcast_dimensions=*/{1}); - auto dev_squares = builder.SquareF32(activation_deviations); - auto sum_of_squares = CheckShape( - &builder, - builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0, 2, 3}), - TwoElementVectorF32); - auto variance = builder.Div(sum_of_squares, count); - auto standard_deviation = builder.SqrtF32(variance); + auto epsilon = ConstantR0(&builder, kEpsilon); + auto epsilon2 = ConstantR1(&builder, {kEpsilon, kEpsilon}); + auto activation_deviations = Sub(input_activations, set_means, + /*broadcast_dimensions=*/{1}); + auto dev_squares = Square(activation_deviations); + auto sum_of_squares = + CheckShape(&builder, + Reduce(dev_squares, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0, 2, 3}), + TwoElementVectorF32); + auto variance = Div(sum_of_squares, count); + auto standard_deviation = Sqrt(variance); auto standard_deviation_above_epsilon = - CheckShape(&builder, builder.Gt(standard_deviation, epsilon), + CheckShape(&builder, Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); - auto gt_eps = builder.Select(standard_deviation_above_epsilon, - standard_deviation, epsilon2); - auto normalization_factors = builder.ReciprocalF32(gt_eps); + auto gt_eps = + Select(standard_deviation_above_epsilon, standard_deviation, epsilon2); + auto normalization_factors = Reciprocal(gt_eps); auto normalized_input_activations = - builder.Mul(activation_deviations, normalization_factors, - /*broadcast_dimensions=*/{1}); - /* auto output_activations = */ builder.Add( - builder.Mul(normalized_input_activations, gamma, - /*broadcast_dimensions=*/{1}), - beta, /*broadcast_dimensions=*/{1}); + Mul(activation_deviations, normalization_factors, + /*broadcast_dimensions=*/{1}); + /* auto output_activations = */ Add(Mul(normalized_input_activations, gamma, + /*broadcast_dimensions=*/{1}), + beta, /*broadcast_dimensions=*/{1}); Array4D expected(kSamples, kZ, kY, kX); Array2D pz({ @@ -232,15 +232,15 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { const int kFeatureIndex = 3; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( - {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); + auto operand = ConstantR4FromArray4D( + &builder, {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); - auto scale = builder.ConstantR1({2.0f, 3.0f}); + auto scale = ConstantR1(&builder, {2.0f, 3.0f}); - auto offset = builder.ConstantR1({1.0f, 2.0f}); + auto offset = ConstantR1(&builder, {1.0f, 2.0f}); - builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, @@ -252,19 +252,20 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } -XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { +XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( + auto operand = ConstantR4FromArray4D( + &builder, {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); - auto scale = builder.ConstantR1({2.0f, 3.0f}); + auto scale = ConstantR1(&builder, {2.0f, 3.0f}); - auto offset = builder.ConstantR1({1.0f, 2.0f}); + auto offset = ConstantR1(&builder, {1.0f, 2.0f}); - builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, @@ -294,8 +295,8 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { CreateR1Parameter(std::vector(260, 1.0f), /*parameter_number=*/2, "offset", &builder, &h2); - builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/1, kFeatureIndex); + BatchNormTraining(h0, h1, h2, + /*epsilon=*/1, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) @@ -327,8 +328,8 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { /*parameter_number=*/2, "offset", &builder, &h2); // var = 125, mean = 15, epsilon = -100 - builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/-100, kFeatureIndex); + BatchNormTraining(h0, h1, h2, + /*epsilon=*/-100, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) @@ -346,19 +347,20 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { XlaBuilder builder(TestName()); auto operand = - builder.ConstantR4FromArray4D(Array4D(2, 2, 2, 1, 0.0f)); + ConstantR4FromArray4D(&builder, Array4D(2, 2, 2, 1, 0.0f)); - auto scale = builder.ConstantR1({1.0f, 1.0f}); + auto scale = ConstantR1(&builder, {1.0f, 1.0f}); - auto mean = builder.ConstantR1({0.0f, 0.0f}); + auto mean = ConstantR1(&builder, {0.0f, 0.0f}); - auto var = builder.ConstantR1({1.0f, 1.0f}); + auto var = ConstantR1(&builder, {1.0f, 1.0f}); - auto grad_output = builder.ConstantR4FromArray4D( + auto grad_output = ConstantR4FromArray4D( + &builder, {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); - builder.BatchNormGrad(operand, scale, mean, var, grad_output, - /*epsilon=*/0.0, kFeatureIndex); + BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, @@ -518,11 +520,11 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto input_literal = Literal::CreateR4FromArray4D(input_array); auto input_activations = - builder.Parameter(0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal->shape(), "input"); auto scale_activations = - builder.Parameter(1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal->shape(), "offset"); auto offset_activations = - builder.Parameter(2, offset_literal->shape(), "scale"); + Parameter(&builder, 2, offset_literal->shape(), "scale"); auto expected = Literal::MakeTuple({expected_normalized.get(), Literal::CreateR1(mean).get(), @@ -535,8 +537,8 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { std::unique_ptr offset_data = client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); - builder.BatchNormTraining(input_activations, scale_activations, - offset_activations, epsilon, feature_index); + BatchNormTraining(input_activations, scale_activations, offset_activations, + epsilon, feature_index); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor @@ -618,14 +620,14 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { auto input_literal = Literal::CreateR4FromArray4D(input_array); auto input_activations = - builder.Parameter(0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal->shape(), "input"); auto scale_activations = - builder.Parameter(1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal->shape(), "offset"); auto offset_activations = - builder.Parameter(2, offset_literal->shape(), "scale"); - auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean"); + Parameter(&builder, 2, offset_literal->shape(), "scale"); + auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean"); auto variance_activations = - builder.Parameter(4, var_literal->shape(), "variance"); + Parameter(&builder, 4, var_literal->shape(), "variance"); Array4D expected = normalized; @@ -640,9 +642,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { std::unique_ptr variance_data = client_->TransferToServer(*var_literal).ConsumeValueOrDie(); - builder.BatchNormInference(input_activations, scale_activations, - offset_activations, mean_activations, - variance_activations, epsilon, feature_index); + BatchNormInference(input_activations, scale_activations, offset_activations, + mean_activations, variance_activations, epsilon, + feature_index); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor @@ -807,12 +809,14 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { auto grad_output_literal = Literal::CreateR4FromArray4D(grad_output_array); - auto input_parameter = builder.Parameter(0, input_literal->shape(), "input"); - auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale"); - auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean"); - auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance"); + auto input_parameter = + Parameter(&builder, 0, input_literal->shape(), "input"); + auto scale_parameter = + Parameter(&builder, 1, scale_literal->shape(), "scale"); + auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean"); + auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance"); auto grad_output_parameter = - builder.Parameter(4, grad_output_literal->shape(), "grad_output"); + Parameter(&builder, 4, grad_output_literal->shape(), "grad_output"); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -825,9 +829,8 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { std::unique_ptr grad_output_data = client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); - builder.BatchNormGrad(input_parameter, scale_parameter, mean_parameter, - var_parameter, grad_output_parameter, epsilon, - feature_index); + BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter, + grad_output_parameter, epsilon, feature_index); auto expected = Literal::MakeTuple({expected_grad_activation.get(), diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index ca337e78840e77377719636cd4cf33af2578210d..f40d03bea79de2a78814a0ad9f6cae6098d1449b 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -51,9 +51,9 @@ class Bfloat16Test : public ClientLibraryTestBase { XLA_TEST_F(Bfloat16Test, ScalarOperation) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR0(static_cast(2.0f)); - auto y = builder.ConstantR0(static_cast(1.0f)); - builder.Add(x, y); + auto x = ConstantR0(&builder, static_cast(2.0f)); + auto y = ConstantR0(&builder, static_cast(1.0f)); + Add(x, y); ComputeAndCompareR0(&builder, static_cast(3.0f), {}, error_spec_); @@ -61,8 +61,8 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) { XLA_TEST_F(Bfloat16Test, LogOperation) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR0(static_cast(4.0f)); - builder.Log(x); + auto x = ConstantR0(&builder, static_cast(4.0f)); + Log(x); ComputeAndCompareR0(&builder, static_cast(1.387f), {}, error_spec_); @@ -70,7 +70,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) { XLA_TEST_F(Bfloat16Test, NegateScalarF16) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(static_cast(2.1f))); + Neg(ConstantR0(&builder, static_cast(2.1f))); ComputeAndCompareR0(&builder, static_cast(-2.1f), {}, error_spec_); @@ -80,20 +80,20 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( + auto operand = ConstantR4FromArray4D( + &builder, {{{{static_cast(1.f)}, {static_cast(2.f)}}, {{static_cast(3.f)}, {static_cast(4.f)}}}, {{{static_cast(5.f)}, {static_cast(6.f)}}, {{static_cast(7.f)}, {static_cast(8.f)}}}}); - auto scale = builder.ConstantR1( - {static_cast(2.0f), static_cast(3.0f)}); + auto scale = ConstantR1( + &builder, {static_cast(2.0f), static_cast(3.0f)}); - auto offset = builder.ConstantR1( - {static_cast(1.0f), static_cast(2.0f)}); + auto offset = ConstantR1( + &builder, {static_cast(1.0f), static_cast(2.0f)}); - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4( @@ -117,26 +117,27 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( - Array4D(2, 2, 2, 1, static_cast(0.0f))); + auto operand = ConstantR4FromArray4D( + &builder, Array4D(2, 2, 2, 1, static_cast(0.0f))); - auto scale = builder.ConstantR1( - {static_cast(1.0f), static_cast(1.0f)}); + auto scale = ConstantR1( + &builder, {static_cast(1.0f), static_cast(1.0f)}); - auto mean = builder.ConstantR1( - {static_cast(0.0f), static_cast(0.0f)}); + auto mean = ConstantR1( + &builder, {static_cast(0.0f), static_cast(0.0f)}); - auto var = builder.ConstantR1( - {static_cast(1.0f), static_cast(1.0f)}); + auto var = ConstantR1( + &builder, {static_cast(1.0f), static_cast(1.0f)}); - auto grad_output = builder.ConstantR4FromArray4D( + auto grad_output = ConstantR4FromArray4D( + &builder, {{{{static_cast(1.f)}, {static_cast(2.f)}}, {{static_cast(3.f)}, {static_cast(4.f)}}}, {{{static_cast(5.f)}, {static_cast(6.f)}}, {{static_cast(7.f)}, {static_cast(8.f)}}}}); - builder.BatchNormGrad(operand, scale, mean, var, grad_output, - /*epsilon=*/0.0, kFeatureIndex); + BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4( diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 48203b1d40ea69ff00a57c2c9e42620739b23d59..20cb989751ad69e2f3cf97c87c43293951f599ab 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -33,9 +33,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 4); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -49,9 +49,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_129x129) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 129); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -65,9 +65,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_9x5) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 9, 1); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -81,9 +81,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 129, 1); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -94,11 +94,12 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { TEST_F(BinopScalingTest, R0PlusR2F32) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR0(42.0); - auto rhs = builder.ConstantR2({ - {1.0, 2.0}, {3.0, 4.0}, - }); - builder.Add(lhs, rhs); + auto lhs = ConstantR0(&builder, 42.0); + auto rhs = ConstantR2(&builder, { + {1.0, 2.0}, + {3.0, 4.0}, + }); + Add(lhs, rhs); Array2D expected(2, 2); expected(0, 0) = 42.0 + 1.0; @@ -129,9 +130,9 @@ TEST_F(BinopScalingTest, R4PlusR0S32) { }); // clang-format on - auto lhs = builder.ConstantR4FromArray4D(lhs_array); - auto rhs = builder.ConstantR0(42); - builder.Add(lhs, rhs); + auto lhs = ConstantR4FromArray4D(&builder, lhs_array); + auto rhs = ConstantR0(&builder, 42); + Add(lhs, rhs); ComputeAndCompareR4(&builder, expected, {}); } diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc index bff60f25ec8f15d372d251ac313200301a04f20f..d531e8fa82e47f7bcd278f10da2c205e44db0ac1 100644 --- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -43,8 +43,8 @@ class BitcastConvertTest : public ClientLibraryTestBase { TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42, 64}); - builder.BitcastConvertType(a, S32); + auto a = ConstantR1(&builder, {42, 64}); + BitcastConvertType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -52,8 +52,8 @@ TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.BitcastConvertType(a, F32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + BitcastConvertType(a, F32); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); @@ -62,10 +62,10 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { XlaBuilder builder(TestName()); auto a = - builder.ConstantR1({0, static_cast(0x80000000), 0x3F800000, - static_cast(0xBF800000), 0x3F000000, - static_cast(0xBF000000)}); - builder.BitcastConvertType(a, F32); + ConstantR1(&builder, {0, static_cast(0x80000000), + 0x3F800000, static_cast(0xBF800000), + 0x3F000000, static_cast(0xBF000000)}); + BitcastConvertType(a, F32); std::vector expected = {0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f}; ComputeAndCompareR1(&builder, expected, {}); @@ -73,8 +73,8 @@ TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.BitcastConvertType(a, F32); + auto a = ConstantR1(&builder, {}); + BitcastConvertType(a, F32); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}); @@ -82,8 +82,8 @@ XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.6, 64.4}); - builder.BitcastConvertType(a, S32); + auto a = ConstantR1(&builder, {42.6, 64.4}); + BitcastConvertType(a, S32); std::vector expected = {0x422a6666, 0x4280cccd}; ComputeAndCompareR1(&builder, expected, {}); @@ -91,9 +91,9 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { TEST_F(BitcastConvertTest, ConvertS32Extremes) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {std::numeric_limits::min(), std::numeric_limits::max()}); - builder.BitcastConvertType(a, F32); + auto a = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::max()}); + BitcastConvertType(a, F32); std::vector expected = {-0.0f, NAN}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0, 0)); @@ -102,10 +102,10 @@ TEST_F(BitcastConvertTest, ConvertS32Extremes) { TEST_F(BitcastConvertTest, ConvertMapToS32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); - b->BitcastConvertType(param, S32); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "in"); + BitcastConvertType(param, S32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {0x42280000, 0x42800000}; ComputeAndCompareR1(&builder, expected, {}); @@ -114,10 +114,10 @@ TEST_F(BitcastConvertTest, ConvertMapToS32) { TEST_F(BitcastConvertTest, ConvertMapToF32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); - b->BitcastConvertType(param, F32); - auto a = builder.ConstantR1({0x42280000, 0x42800000}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(S32, {}), "in"); + BitcastConvertType(param, F32); + auto a = ConstantR1(&builder, {0x42280000, 0x42800000}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); @@ -130,9 +130,9 @@ TEST_F(BitcastConvertTest, ConvertMapToF32) { // the new convert should have the same element type as the old convert. TEST_F(BitcastConvertTest, ConvertReshape) { XlaBuilder builder(TestName()); - auto input = builder.ConstantR1({0x42280000}); - auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); - builder.BitcastConvertType(reshape, F32); + auto input = ConstantR1(&builder, {0x42280000}); + auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); + BitcastConvertType(reshape, F32); ComputeAndCompareR0(&builder, 42.0f, {}); } diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 34c86e007beea1cbac04641bdbdab62dc567f13e..91aba9a8de3f1fe098e8bc8cc9d5378fa67b8385 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -37,17 +37,17 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { XlaBuilder* builder) { switch (op) { case HloOpcode::kMinimum: { - return builder->Min(lhs, rhs); + return Min(lhs, rhs); } case HloOpcode::kMaximum: { - return builder->Max(lhs, rhs); + return Max(lhs, rhs); } case HloOpcode::kMultiply: { - return builder->Mul(lhs, rhs); + return Mul(lhs, rhs); } default: { // Default to Add - return builder->Add(lhs, rhs); + return Add(lhs, rhs); } } } @@ -104,13 +104,13 @@ using ::testing::HasSubstr; XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(1.5), {}); + Broadcast(ConstantR0(&b, 1.5), {}); ComputeAndCompareR0(&b, 1.5, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(2.25), {2, 3}); + Broadcast(ConstantR0(&b, 2.25), {2, 3}); Array2D expected(2, 3, 2.25); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } @@ -122,7 +122,7 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { CreateR0Parameter(2.25f, /*parameter_number=*/0, /*name=*/"src", /*builder=*/&b, /*data_handle=*/&src); - b.Broadcast(src, {2, 3}); + Broadcast(src, {2, 3}); Array2D expected(2, 3, 2.25); ComputeAndCompareR2(&b, expected, {param_data.get()}, ErrorSpec(0.0001)); @@ -130,21 +130,21 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(2.25), {2, 0}); + Broadcast(ConstantR0(&b, 2.25), {2, 0}); Array2D expected(2, 0); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(2.25), {0, 2}); + Broadcast(ConstantR0(&b, 2.25), {0, 2}); Array2D expected(0, 2); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR1({1, 2, 3}), {2}); + Broadcast(ConstantR1(&b, {1, 2, 3}), {2}); Array2D expected(2, 3); expected(0, 0) = 1; @@ -156,6 +156,86 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {2, 2}), {1}); + + Array2D expected(2, 2); + expected(0, 0) = 1; + expected(0, 1) = 2; + expected(1, 0) = 1; + expected(1, 1) = 2; + + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {2, 2}), {0}); + + Array2D expected(2, 2); + expected(0, 0) = 1; + expected(0, 1) = 1; + expected(1, 0) = 2; + expected(1, 1) = 2; + + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 1}); + + Array3D expected(2, 2, 2); + expected(0, 0, 0) = 1.0; + expected(1, 0, 0) = 2.0; + expected(0, 0, 1) = 1.0; + expected(1, 0, 1) = 2.0; + expected(0, 1, 0) = 5.0; + expected(1, 1, 0) = 6.0; + expected(1, 1, 1) = 6.0; + expected(0, 1, 1) = 5.0; + + ComputeAndCompareR3(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 2}); + + Array3D expected(2, 2, 2); + expected(0, 0, 0) = 1.0; + expected(1, 0, 0) = 2.0; + expected(0, 0, 1) = 5.0; + expected(1, 0, 1) = 6.0; + expected(0, 1, 0) = 1.0; + expected(1, 1, 0) = 2.0; + expected(1, 1, 1) = 6.0; + expected(0, 1, 1) = 5.0; + + ComputeAndCompareR3(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {3, 2}), {1}); + + Array2D expected(3, 2); + expected(0, 0) = 1; + expected(0, 1) = 2; + expected(1, 0) = 1; + expected(1, 1) = 2; + expected(2, 0) = 1; + expected(2, 1) = 2; + + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + // Tests implicit broadcasting of PREDs. XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XlaBuilder b(TestName()); @@ -172,7 +252,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XlaOp x, y; auto x_data = CreateR2Parameter(x_vals, 0, "x", &b, &x); auto y_data = CreateR3Parameter(y_vals, 1, "y", &b, &y); - b.And(x, y, /*broadcast_dimensions=*/{1, 2}); + And(x, y, /*broadcast_dimensions=*/{1, 2}); Array3D expected(2, 2, 1); expected(0, 0, 0) = false; @@ -185,7 +265,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR1({}), {2}); + Broadcast(ConstantR1(&b, {}), {2}); Array2D expected(2, 0); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); @@ -193,7 +273,7 @@ XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR1({1, 2, 3}), {0}); + Broadcast(ConstantR1(&b, {1, 2, 3}), {0}); Array2D expected(0, 3); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); @@ -209,10 +289,10 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { // dimensions. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 5.0}}), - b.ConstantLiteral(*Literal::CreateR3( - {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), - /*broadcast_dimensions=*/{1, 2}); + Add(ConstantR2(&b, {{1.0, 5.0}}), + ConstantLiteral(&b, *Literal::CreateR3( + {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), + /*broadcast_dimensions=*/{1, 2}); auto expected = Literal::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, @@ -260,9 +340,10 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape, &r3_implicit_array, 1.0, 0.2, 56789); - auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); - auto r3_parameter = builder.Parameter(1, r3_shape, "input"); - XlaOp op = BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); + auto r3_implicit_parameter = + Parameter(&builder, 0, r3_implicit_shape, "input"); + auto r3_parameter = Parameter(&builder, 1, r3_shape, "input"); + BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], spec.output_bounds[2]); @@ -306,7 +387,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h); auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h); - b.Add(r3h, r1h); + Add(r3h, r1h); auto expected = Literal::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); @@ -317,10 +398,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); @@ -330,10 +411,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1}, {2}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); @@ -343,10 +424,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}, {3, 4}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}, {3, 4}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); @@ -356,10 +437,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}, {{3, 4}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = + ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}}, {{3, 4}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); @@ -370,10 +452,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); auto r1 = - b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + ConstantLiteral(&b, *Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); @@ -383,10 +465,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); @@ -509,14 +591,14 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789); auto r2_implicit_parameter1 = - builder.Parameter(0, r2_implicit_shape1, "input0"); - auto r2_parameter = builder.Parameter(1, r2_shape, "input1"); + Parameter(&builder, 0, r2_implicit_shape1, "input0"); + auto r2_parameter = Parameter(&builder, 1, r2_shape, "input1"); auto r2_implicit_parameter2 = - builder.Parameter(2, r2_implicit_shape2, "input2"); + Parameter(&builder, 2, r2_implicit_shape2, "input2"); XlaOp op1 = BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder); - XlaOp op2 = BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); + BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); Array2D expected_array(spec.output_bounds[0], spec.output_bounds[1]); @@ -544,9 +626,9 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}})); - auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); - b.Add(r2, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}})); + auto r2 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}, {3, 4}})); + Add(r2, r1); auto expected = Literal::CreateR2({{2, 4}, {4, 6}}); @@ -555,9 +637,9 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1}, {2}})); - auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); - b.Add(r2, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR2({{1}, {2}})); + auto r2 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}, {3, 4}})); + Add(r2, r1); auto expected = Literal::CreateR2({{2, 3}, {5, 6}}); @@ -566,10 +648,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); - auto r1 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1, {0}); + auto r1 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1, {0}); auto expected = Literal::CreateR3({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); @@ -579,10 +661,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r1, r3, {1}); + auto r1 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r1, r3, {1}); auto expected = Literal::CreateR3({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); @@ -592,10 +674,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); - auto r1 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r1, r3, {2}); + auto r1 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r1, r3, {2}); auto expected = Literal::CreateR3({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); @@ -605,17 +687,17 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { XlaBuilder b(TestName()); - auto r1_0 = b.ConstantR1({1000, 2000}); - auto r1_1 = b.ConstantR1({100, 200}); - auto r1_2 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + auto r1_0 = ConstantR1(&b, {1000, 2000}); + auto r1_1 = ConstantR1(&b, {100, 200}); + auto r1_2 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { - r3 = b.Add(r1_0, r3, {0}); - r3 = b.Add(r3, r1_1, {1}); - r3 = b.Add(r1_2, r3, {2}); + r3 = Add(r1_0, r3, {0}); + r3 = Add(r3, r1_1, {1}); + r3 = Add(r1_2, r3, {2}); } - r3 = b.Mul(r3, b.ConstantR0(-2)); + r3 = Mul(r3, ConstantR0(&b, -2)); auto expected = Literal::CreateR3( {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, @@ -626,17 +708,17 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { XlaBuilder b(TestName()); - auto r1_0 = b.ConstantR1({1000, 2000}); - auto r1_1 = b.ConstantR1({100, 200}); - auto r1_2 = b.ConstantR1({10, 20}); - auto r0 = b.ConstantR0(3); - auto r3 = b.Broadcast(r0, {2, 2, 2}); + auto r1_0 = ConstantR1(&b, {1000, 2000}); + auto r1_1 = ConstantR1(&b, {100, 200}); + auto r1_2 = ConstantR1(&b, {10, 20}); + auto r0 = ConstantR0(&b, 3); + auto r3 = Broadcast(r0, {2, 2, 2}); for (int i = 0; i < 3; ++i) { - r3 = b.Add(r1_0, r3, {0}); - r3 = b.Add(r3, r1_1, {1}); - r3 = b.Add(r1_2, r3, {2}); + r3 = Add(r1_0, r3, {0}); + r3 = Add(r3, r1_1, {1}); + r3 = Add(r1_2, r3, {2}); } - r3 = b.Mul(r3, b.ConstantR0(-1)); + r3 = Mul(r3, ConstantR0(&b, -1)); auto expected = Literal::CreateR3( {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, @@ -650,10 +732,10 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { // results in a shape incompatible with the lhs [2, 3, 1]. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 5.0}, {1.0, 5.0}}), - b.ConstantLiteral(*Literal::CreateR3( - {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), - /*broadcast_dimensions=*/{1, 2}); + Add(ConstantR2(&b, {{1.0, 5.0}, {1.0, 5.0}}), + ConstantLiteral(&b, *Literal::CreateR3( + {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), + /*broadcast_dimensions=*/{1, 2}); auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); @@ -665,26 +747,26 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { // Test invalid broadcasting with [1, 2] and [2, 3] inputs. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 2.0}}), - b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + Add(ConstantR2(&b, {{1.0, 2.0}}), + ConstantR2(&b, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { // Test invalid broadcasting with [1, 2] and [2, 3] inputs. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 2.0}}), - b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + Add(ConstantR2(&b, {{1.0, 2.0}}), + ConstantR2(&b, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 5fd33b50c94356839bbed58acd43b7d0286f4a7e..bc64a19ce22072152216a7c150fbd16480d261fb 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -34,7 +34,7 @@ class CallOpTest : public ClientLibraryTestBase { protected: XlaComputation CreateR0F32IdentityComputation() { XlaBuilder builder("Identity"); - builder.Parameter(0, r0f32_, "x"); + Parameter(&builder, 0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -42,9 +42,9 @@ class CallOpTest : public ClientLibraryTestBase { XlaComputation CreateR1S0F32AdditionComputation() { XlaBuilder builder("Addition"); - auto x = builder.Parameter(0, r1s0f32_, "x"); - auto y = builder.Parameter(1, r1s0f32_, "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, r1s0f32_, "x"); + auto y = Parameter(&builder, 1, r1s0f32_, "y"); + Add(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -52,9 +52,9 @@ class CallOpTest : public ClientLibraryTestBase { XlaComputation CreateR1S2F32AdditionComputation() { XlaBuilder builder("Addition"); - auto x = builder.Parameter(0, r1s2f32_, "x"); - auto y = builder.Parameter(1, r1s2f32_, "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, r1s2f32_, "x"); + auto y = Parameter(&builder, 1, r1s2f32_, "y"); + Add(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -62,7 +62,7 @@ class CallOpTest : public ClientLibraryTestBase { XlaComputation CreateR0F32TupleComputation() { XlaBuilder builder("Tuple"); - builder.Tuple({builder.Parameter(0, r0f32_, "x")}); + Tuple(&builder, {Parameter(&builder, 0, r0f32_, "x")}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -76,8 +76,8 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32IdentityComputation(); - auto constant = builder.ConstantLiteral(*Literal::CreateR0(42.0)); - builder.Call(callee, {constant}); + auto constant = ConstantLiteral(&builder, *Literal::CreateR0(42.0)); + Call(&builder, callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); } @@ -85,9 +85,9 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); - auto x = builder.ConstantLiteral(*Literal::CreateR1({})); - auto y = builder.ConstantLiteral(*Literal::CreateR1({})); - builder.Call(callee, {x, y}); + auto x = ConstantLiteral(&builder, *Literal::CreateR1({})); + auto y = ConstantLiteral(&builder, *Literal::CreateR1({})); + Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); } @@ -95,9 +95,9 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); - auto x = builder.ConstantLiteral(*Literal::CreateR1({1.0f, 2.0f})); - auto y = builder.ConstantLiteral(*Literal::CreateR1({2.0f, 3.0f})); - builder.Call(callee, {x, y}); + auto x = ConstantLiteral(&builder, *Literal::CreateR1({1.0f, 2.0f})); + auto y = ConstantLiteral(&builder, *Literal::CreateR1({2.0f, 3.0f})); + Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); } @@ -105,26 +105,26 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { XlaBuilder builder("inner"); { - auto x = builder.Parameter(0, r0f32_, "x"); - builder.Add(x, builder.ConstantR0(1.0)); + auto x = Parameter(&builder, 0, r0f32_, "x"); + Add(x, ConstantR0(&builder, 1.0)); } TF_ASSERT_OK_AND_ASSIGN(XlaComputation inner, builder.Build()); XlaBuilder builder2("outer"); { - auto x = builder2.Parameter(0, r0f32_, "x"); - x = builder2.Call(inner, {x}); - x = builder2.Call(inner, {x}); - x = builder2.Call(inner, {x}); + auto x = Parameter(&builder2, 0, r0f32_, "x"); + x = Call(&builder2, inner, {x}); + x = Call(&builder2, inner, {x}); + x = Call(&builder2, inner, {x}); } TF_ASSERT_OK_AND_ASSIGN(XlaComputation outer, builder2.Build()); XlaBuilder builder3("outermost"); { - auto x = builder3.Parameter(0, r0f32_, "x"); - x = builder3.Call(outer, {x}); - x = builder3.Call(outer, {x}); - x = builder3.Call(outer, {x}); + auto x = Parameter(&builder3, 0, r0f32_, "x"); + x = Call(&builder3, outer, {x}); + x = Call(&builder3, outer, {x}); + x = Call(&builder3, outer, {x}); } TF_ASSERT_OK_AND_ASSIGN( @@ -138,7 +138,7 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) { XlaComputation callee = CreateR0F32TupleComputation(); auto elem = Literal::CreateR0(42.0); auto tuple = Literal::MakeTuple({elem.get()}); - builder.Call(callee, {builder.ConstantLiteral(*elem)}); + Call(&builder, callee, {ConstantLiteral(&builder, *elem)}); ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); } diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 660ff0cad5666219a4a7cb1eedbed03f06e651ba..1ad57c075b22c7730ffd8d1beeab60c9d5dc7458 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -38,9 +38,9 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); auto param_literal = Literal::CreateR1({1.1f, 2.2f}); - auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1"); + Add(p0, p1); auto param0_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); @@ -77,9 +77,9 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { XlaBuilder builder("add_two_params"); - auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); - auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1"); - auto add = builder.Mul(p0, p1); + auto p0 = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); + auto p1 = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {4}), "param1"); + Mul(p0, p1); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index bf8ed4d9fb0bc61b86ef0b5872711a122a3d416b..dafd6ebabbe6edafc1c926677b3ea00e775be010 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -486,11 +487,11 @@ ClientLibraryTestBase::ComputeValueAndReference( XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaBuilder builder("relu"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); - auto z_value = builder.Parameter(0, shape, "z_value"); + auto z_value = Parameter(&builder, 0, shape, "z_value"); auto zero = use_bfloat16_ - ? builder.ConstantR0(static_cast(0.0f)) - : builder.ConstantR0(0.0f); - builder.Max(z_value, zero); + ? ConstantR0(&builder, static_cast(0.0f)) + : ConstantR0(&builder, 0.0f); + Max(z_value, zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -499,9 +500,9 @@ XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaBuilder builder("max"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); - auto x = builder.Parameter(0, shape, "x"); - auto y = builder.Parameter(1, shape, "y"); - builder.Max(x, y); + auto x = Parameter(&builder, 0, shape, "x"); + auto y = Parameter(&builder, 1, shape, "y"); + Max(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -510,13 +511,13 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { XlaBuilder builder("relu_sensitivity"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); - auto activation = builder.Parameter(0, shape, "activation"); - auto backprop = builder.Parameter(1, shape, "backprop"); + auto activation = Parameter(&builder, 0, shape, "activation"); + auto backprop = Parameter(&builder, 1, shape, "backprop"); auto zero = use_bfloat16_ - ? builder.ConstantR0(static_cast(0.0f)) - : builder.ConstantR0(0.0f); - auto activation_gtz = builder.Gt(activation, zero); - builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); + ? ConstantR0(&builder, static_cast(0.0f)) + : ConstantR0(&builder, 0.0f); + auto activation_gtz = Gt(activation, zero); + Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -559,8 +560,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { - return builder->ConstantLiteral( - use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal); + return ConstantLiteral( + builder, use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal); } std::unique_ptr @@ -588,7 +589,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( client_->TransferToServer(*param_literal, device_handle) .ConsumeValueOrDie(); *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); + Parameter(builder, parameter_number, param_literal->shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 0499fec5898a42affa0e0a712dee10187355c13e..5361ae6783c4c103cf923ffbda066165545c39a1 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -373,6 +373,13 @@ class ClientLibraryTestBase : public ::testing::Test { // The float type used in this test, BF16 or F32 according to use_bfloat16. PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + // Executes the computation and calculates the expected reference value using + // the reference client. Returns two literals in the order of (expected, + // actual). + StatusOr, std::unique_ptr>> + ComputeValueAndReference(XlaBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + Client* client_; Client* ref_client_; // To compute reference result. ExecutionOptions execution_options_; @@ -390,13 +397,6 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); - // Executes the computation and calculates the expected reference value using - // the reference client. Returns two literals in the order of (expected, - // actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - // Whether to run tests with all float-type input/output converted to // bfloat16. bool use_bfloat16_ = false; @@ -545,7 +545,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } @@ -559,7 +559,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } @@ -573,7 +573,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } @@ -587,7 +587,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 08671cf62445826649b5c97003f998ae98a59d97..831b863998f1cab31d37aa4474be45d8531075ac 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -43,8 +43,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { std::vector> layouts = {{0, 1}, {1, 0}}; for (const std::vector& execute_layout : layouts) { for (const std::vector& transfer_layout : layouts) { - b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})); + Add(ConstantR2(&b, {{1, 2}, {3, 4}}), + ConstantR2(&b, {{10, 20}, {30, 40}})); TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); ExecutionOptions execution_options = execution_options_; @@ -72,8 +72,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { XlaBuilder b(TestName()); - b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})}); + Tuple(&b, {ConstantR2(&b, {{1, 2}, {3, 4}}), + ConstantR2(&b, {{10, 20}, {30, 40}})}); TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); @@ -117,8 +117,8 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); - b.Add(b.Parameter(0, shape, "param_0"), - b.ConstantR2({{1, 2}, {3, 4}})); + Add(Parameter(&b, 0, shape, "param_0"), + ConstantR2(&b, {{1, 2}, {3, 4}})); TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); // We can't really test parallel execution on CPU since all of the cores in a diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 50a006964869b3e5dce431d441f7cd81af9df910..eb211dd8ff376fb0da03b3e68be1d849970d96fd 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -77,7 +77,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { // TODO(b/74197823): Disabled because there is no cache in the new design. XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(42.0)); + Neg(ConstantR0(&builder, 42.0)); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); @@ -99,7 +99,7 @@ XLA_TEST_F(CompilationCacheTest, .ConsumeValueOrDie(); XlaBuilder builder(TestName()); - builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + Neg(Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param")); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, @@ -115,16 +115,16 @@ XLA_TEST_F(CompilationCacheTest, // TODO(b/74197823): Disabled because there is no cache in the new design. XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) { XlaBuilder builder_neg(TestName() + "_neg"); - builder_neg.Neg(builder_neg.ConstantR0(42.0)); + Neg(ConstantR0(&builder_neg, 42.0)); XlaComputation computation_neg = builder_neg.Build().ConsumeValueOrDie(); XlaBuilder builder_exp(TestName() + "_exp"); - builder_exp.Exp(builder_exp.ConstantR0(1.0)); + Exp(ConstantR0(&builder_exp, 1.0)); XlaComputation computation_exp = builder_exp.Build().ConsumeValueOrDie(); XlaBuilder builder_add(TestName() + "_add"); - builder_add.Add(builder_add.ConstantR0(2.0), - builder_add.ConstantR0(3.0)); + Add(ConstantR0(&builder_add, 2.0), + ConstantR0(&builder_add, 3.0)); XlaComputation computation_add = builder_add.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation_neg, {}, -42.0, @@ -154,7 +154,7 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR2F32(computation, {colmaj_handle.get()}, diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index ba22530f1cfee56337f862c25122d399dbf0f1e4..1a396b090c615dbd829964bd68ebda74df29c71e 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -99,7 +99,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.ConstantR0(42); + auto computation = ConstantR0(&b, 42); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -113,7 +113,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); + Add(ConstantR0(&b, 42.5f), ConstantR0(&b, 1.5f)); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -127,8 +127,8 @@ TEST_F(ComputeConstantTest, ScalarRng) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), - ShapeUtil::MakeShape(F32, {})); + RngUniform(ConstantR0(&b, 1.1f), ConstantR0(&b, 2.1f), + ShapeUtil::MakeShape(F32, {})); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -141,7 +141,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); + auto computation = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param"); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -156,8 +156,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0(1.0f), - b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + Add(ConstantR0(&b, 1.0f), + Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param")); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -174,18 +174,18 @@ TEST_F(ComputeConstantTest, UnrelatedParam) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); + auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0"); auto constant_4 = - b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); - auto not_constant_a = b.Add(constant_4, param_a); + Add(ConstantR0(&b, 2.5f), ConstantR0(&b, 1.5f)); + auto not_constant_a = Add(constant_4, param_a); - auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); + auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1"); auto constant_9 = - b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); - auto not_constant_b = b.Add(param_b, constant_9); + Mul(ConstantR0(&b, 2.0f), ConstantR0(&b, 4.5f)); + auto not_constant_b = Add(param_b, constant_9); - auto constant_13 = b.Add(constant_4, constant_9); - b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); + auto constant_13 = Add(constant_4, constant_9); + Add(not_constant_b, Add(constant_13, not_constant_a)); EXPECT_TRUE(IsConstant(constant_13, &b)); @@ -201,7 +201,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); + Add(ConstantR1(&b, {1, 2}), ConstantR1(&b, {3, 4})); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, @@ -216,7 +216,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); + auto computation = Div(ConstantR0(&b, 15), ConstantR0(&b, 3)); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, @@ -237,8 +237,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, ComputeConstantLiteral( client, - b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})), + Add(ConstantR2(&b, {{1, 2}, {3, 4}}), + ConstantR2(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); std::unique_ptr expected_literal = diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index a4c8a83eb15f7cc279b6c8f1bf1394c0afb9f7cf..1161b560b7b0756556911812666c6f4fe9179f72 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -39,7 +39,7 @@ using ::testing::HasSubstr; // Concatenate expects at least one argument. XLA_TEST_F(ConcatTest, Concat_Nothing) { XlaBuilder builder(TestName()); - builder.ConcatInDim({}, 0); + ConcatInDim(&builder, {}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), @@ -49,8 +49,8 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) { // Concatenate with one argument works. XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0, 64.0}); - builder.ConcatInDim({a}, 0); + auto a = ConstantR1(&builder, {42.0, 64.0}); + ConcatInDim(&builder, {a}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -58,8 +58,8 @@ XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.ConcatInDim({a}, 0); + auto a = ConstantR1(&builder, {}); + ConcatInDim(&builder, {a}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -69,9 +69,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { // to concatenate on. XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0(42.0); - auto b = builder.ConstantR0(64.0); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR0(&builder, 42.0); + auto b = ConstantR0(&builder, 64.0); + ConcatInDim(&builder, {a, b}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), @@ -80,9 +80,9 @@ XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -90,9 +90,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({256.0}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {256.0}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -100,9 +100,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0, 64.0}); - auto b = builder.ConstantR1({}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {42.0, 64.0}); + auto b = ConstantR1(&builder, {}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -110,9 +110,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0, 64.0}); - auto b = builder.ConstantR1({256.0}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {42.0, 64.0}); + auto b = ConstantR1(&builder, {256.0}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -130,9 +130,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR1(lhs); - auto b = builder.ConstantR1(rhs); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, lhs); + auto b = ConstantR1(&builder, rhs); + ConcatInDim(&builder, {a, b}, 0); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } @@ -140,9 +140,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { for (int dim : {0, 1}) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 0)); - auto b = builder.ConstantR2FromArray2D(Array2D(0, 0)); - builder.ConcatInDim({a, b}, dim); + auto a = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + auto b = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + ConcatInDim(&builder, {a, b}, dim); ComputeAndCompareR2(&builder, Array2D(0, 0), {}, ErrorSpec(0.0001)); @@ -153,9 +153,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 0); Array2D expected({ {0}, @@ -168,9 +168,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 1); Array2D expected({ {0, 64}, @@ -181,9 +181,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { XLA_TEST_F(ConcatTest, Concat2x0With2x5) { XlaBuilder builder(TestName()); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 1); ComputeAndCompareR2(&builder, *b_array, {}, ErrorSpec(0.0001)); } @@ -192,9 +192,9 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(2, 3); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 1); Array2D expected({ {0, 1, 2, 64, 65, 66, 67, 68}, @@ -206,9 +206,9 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) { XLA_TEST_F(ConcatTest, Concat3x2With0x2) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(Array2D(0, 2)); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + ConcatInDim(&builder, {a, b}, 0); ComputeAndCompareR2(&builder, *a_array, {}, ErrorSpec(0.0001)); } @@ -217,9 +217,9 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 0); Array2D expected({ {0, 1}, @@ -236,9 +236,9 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) { XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR3FromArray3D(Array3D(3, 0, 2)); - auto b = builder.ConstantR3FromArray3D(Array3D(3, 0, 1)); - builder.ConcatInDim({a, b}, 2); + auto a = ConstantR3FromArray3D(&builder, Array3D(3, 0, 2)); + auto b = ConstantR3FromArray3D(&builder, Array3D(3, 0, 1)); + ConcatInDim(&builder, {a, b}, 2); ComputeAndCompareR3(&builder, Array3D(3, 0, 3), {}, ErrorSpec(0.0001)); } @@ -257,9 +257,9 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { {{7}}, {{8}}, }); - auto a = builder.ConstantR3FromArray3D(a_array); - auto b = builder.ConstantR3FromArray3D(b_array); - builder.ConcatInDim({a, b}, 2); + auto a = ConstantR3FromArray3D(&builder, a_array); + auto b = ConstantR3FromArray3D(&builder, b_array); + ConcatInDim(&builder, {a, b}, 2); Array3D expected({ {{0, 1, 6}}, @@ -271,10 +271,10 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0}); - auto b = builder.ConstantR1({64.0}); - auto c = builder.ConstantR1({256.0}); - builder.ConcatInDim({a, b, c}, 0); + auto a = ConstantR1(&builder, {42.0}); + auto b = ConstantR1(&builder, {64.0}); + auto c = ConstantR1(&builder, {256.0}); + ConcatInDim(&builder, {a, b, c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -300,10 +300,10 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { {{7}}, {{11}}, }); - auto a = builder.ConstantR3FromArray3D(a_array); - auto b = builder.ConstantR3FromArray3D(b_array); - auto c = builder.ConstantR3FromArray3D(c_array); - builder.ConcatInDim({a, b, c}, 2); + auto a = ConstantR3FromArray3D(&builder, a_array); + auto b = ConstantR3FromArray3D(&builder, b_array); + auto c = ConstantR3FromArray3D(&builder, c_array); + ConcatInDim(&builder, {a, b, c}, 2); Array3D expected({ {{0, 1, 2, 3}}, @@ -315,11 +315,11 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0}); - auto b = builder.ConstantR1({64.0}); - auto c = builder.ConstantR1({256.0}); + auto a = ConstantR1(&builder, {42.0}); + auto b = ConstantR1(&builder, {64.0}); + auto c = ConstantR1(&builder, {256.0}); // concatenated = (a concat b) concat c - builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); + ConcatInDim(&builder, {ConcatInDim(&builder, {a, b}, 0), c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -327,11 +327,11 @@ XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0}); - auto b = builder.ConstantR1({64.0}); - auto c = builder.ConstantR1({256.0}); + auto a = ConstantR1(&builder, {42.0}); + auto b = ConstantR1(&builder, {64.0}); + auto c = ConstantR1(&builder, {256.0}); // concatenated = a concat (b concat c) - builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); + ConcatInDim(&builder, {a, ConcatInDim(&builder, {b, c}, 0)}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -346,9 +346,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(lhs); - auto b = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, lhs); + auto b = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a, b}, 0); Array2D expected(2, 1024); for (int i = 0; i < 1024; ++i) { @@ -367,9 +367,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(lhs); - auto b = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, lhs); + auto b = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a, b}, 1); Array2D expected(1, 2048); for (int i = 0; i < 1024; ++i) { @@ -392,9 +392,9 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(lhs); - auto b = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, lhs); + auto b = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a, b}, 1); Array2D expected(64, 66); for (int i0 = 0; i0 < 64; ++i0) { @@ -410,22 +410,37 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { XlaBuilder builder(TestName()); auto opaque_shape = ShapeUtil::MakeOpaqueShape(); auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); - auto x = builder.Parameter(0, r1f32, "x"); - auto y = builder.Parameter(1, opaque_shape, "y"); - builder.ConcatInDim({x, y}, 0); + auto x = Parameter(&builder, 0, r1f32, "x"); + auto y = Parameter(&builder, 1, opaque_shape, "y"); + ConcatInDim(&builder, {x, y}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), - HasSubstr("Expected non-opaque argument for operand of concatenation")); + HasSubstr("Expected array argument for operand of concatenation")); +} + +// Show that we can't concatenate with tokens. +XLA_TEST_F(ConcatTest, CannotConcatTokens) { + XlaBuilder builder(TestName()); + auto token_shape = ShapeUtil::MakeTokenShape(); + auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); + auto x = Parameter(&builder, 0, r1f32, "x"); + auto y = Parameter(&builder, 1, token_shape, "y"); + ConcatInDim(&builder, {x, y}, 0); + StatusOr computation_status = builder.Build(); + ASSERT_FALSE(computation_status.ok()); + EXPECT_THAT( + computation_status.status().ToString(), + HasSubstr("Expected array argument for operand of concatenation")); } XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { XlaBuilder builder(TestName()); - auto p0 = builder.ConstantR1({true}); - auto p1 = builder.ConstantR1({false}); - auto p2 = builder.ConstantR1({true}); - builder.ConcatInDim({p0, p1, p2}, 0); + auto p0 = ConstantR1(&builder, {true}); + auto p1 = ConstantR1(&builder, {false}); + auto p2 = ConstantR1(&builder, {true}); + ConcatInDim(&builder, {p0, p1, p2}, 0); bool expected[] = {true, false, true}; ComputeAndCompareR1(&builder, expected, {}); @@ -433,11 +448,11 @@ XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { XlaBuilder builder(TestName()); - auto a0 = builder.ConstantR1({1}); - auto a1 = builder.ConstantR1({2, 3}); - auto a2 = builder.ConstantR1({4, 5, 6}); - auto a3 = builder.ConstantR1({7, 8, 9, 10}); - builder.ConcatInDim({a0, a1, a2, a3}, 0); + auto a0 = ConstantR1(&builder, {1}); + auto a1 = ConstantR1(&builder, {2, 3}); + auto a2 = ConstantR1(&builder, {4, 5, 6}); + auto a3 = ConstantR1(&builder, {7, 8, 9, 10}); + ConcatInDim(&builder, {a0, a1, a2, a3}, 0); std::vector expected(10); std::iota(expected.begin(), expected.end(), 1); @@ -472,7 +487,7 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { auto p1 = CreateR3Parameter(arr1, /*parameter_number=*/1, "p1", &builder, &h1); - builder.ConcatInDim({h0, h1}, 2); + ConcatInDim(&builder, {h0, h1}, 2); ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); } @@ -499,9 +514,9 @@ TEST_P(ConcatR2BinaryTest, DoIt) { rhs.FillUnique(1000); XlaBuilder builder(TestName()); - auto a0 = builder.ConstantR2FromArray2D(lhs); - auto a1 = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a0, a1}, spec.concat_dimension); + auto a0 = ConstantR2FromArray2D(&builder, lhs); + auto a1 = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a0, a1}, spec.concat_dimension); std::unique_ptr> expected = ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension); @@ -525,13 +540,13 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, f32_scalar, "x"); - auto y = builder.Parameter(1, f32_scalar, "y"); - auto mul = builder.Mul(x, y); - auto add1 = builder.Add(mul, builder.ConstantR1({1.f, 2.f})); - auto add2 = builder.Add(mul, builder.ConstantR1({3.f, 4.f})); - auto add3 = builder.Add(mul, builder.ConstantR1({5.f, 6.f})); - builder.ConcatInDim({add1, add2, add3}, /*dimension=*/0); + auto x = Parameter(&builder, 0, f32_scalar, "x"); + auto y = Parameter(&builder, 1, f32_scalar, "y"); + auto mul = Mul(x, y); + auto add1 = Add(mul, ConstantR1(&builder, {1.f, 2.f})); + auto add2 = Add(mul, ConstantR1(&builder, {3.f, 4.f})); + auto add3 = Add(mul, ConstantR1(&builder, {5.f, 6.f})); + ConcatInDim(&builder, {add1, add2, add3}, /*dimension=*/0); ComputeAndCompareR1(&builder, {7., 8., 9., 10., 11., 12.}, {x_data.get(), y_data.get()}, ErrorSpec(1e-4)); @@ -549,13 +564,13 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, x_literal->shape(), "x"); - auto y = builder.Parameter(1, f32_scalar, "y"); - auto z = builder.Parameter(2, f32_scalar, "z"); - auto bcast = builder.Broadcast(y, {5}); - auto bcast2 = builder.Broadcast(z, {3}); - auto concat = builder.ConcatInDim({bcast, x}, /*dimension=*/0); - builder.ConcatInDim({concat, bcast2}, /*dimension=*/0); + auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto y = Parameter(&builder, 1, f32_scalar, "y"); + auto z = Parameter(&builder, 2, f32_scalar, "z"); + auto bcast = Broadcast(y, {5}); + auto bcast2 = Broadcast(z, {3}); + auto concat = ConcatInDim(&builder, {bcast, x}, /*dimension=*/0); + ConcatInDim(&builder, {concat, bcast2}, /*dimension=*/0); ComputeAndCompareR1( &builder, @@ -577,13 +592,13 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, x_literal->shape(), "x"); - auto y = builder.Parameter(1, f32_scalar, "y"); - auto z = builder.Parameter(2, f32_scalar, "y"); - auto y_bcast = builder.Broadcast(y, {1, 5, 7}); - auto z_bcast = builder.Broadcast(z, {4, 1, 7}); - auto concat = builder.ConcatInDim({y_bcast, x}, /*dimension=*/0); - builder.ConcatInDim({concat, z_bcast}, /*dimension=*/1); + auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto y = Parameter(&builder, 1, f32_scalar, "y"); + auto z = Parameter(&builder, 2, f32_scalar, "y"); + auto y_bcast = Broadcast(y, {1, 5, 7}); + auto z_bcast = Broadcast(z, {4, 1, 7}); + auto concat = ConcatInDim(&builder, {y_bcast, x}, /*dimension=*/0); + ConcatInDim(&builder, {concat, z_bcast}, /*dimension=*/1); Array3D y_bcast3d(1, 5, 7, 1.5f); Array3D z_bcast3d(4, 1, 7, 5.5f); auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index 7ff6706935740c7d76ee5cd03eae292386760397..ee3c83039bfc13f6ad78111d92ba0f8387a3ade3 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -26,8 +26,8 @@ class ConditionalOpTest : public ClientLibraryTestBase { protected: XlaComputation CreateR0ConstantComputation(float value) { XlaBuilder builder("Constant"); - builder.Parameter(0, empty_tuple_, "tuple"); - builder.ConstantR0(value); + Parameter(&builder, 0, empty_tuple_, "tuple"); + ConstantR0(&builder, value); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -35,7 +35,7 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateR0IdentityComputation() { XlaBuilder builder("Identity"); - builder.Parameter(0, r0f32_, "x"); + Parameter(&builder, 0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -43,8 +43,8 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateCeilComputation(const Shape& shape) { XlaBuilder builder("Ceil"); - auto param = builder.Parameter(0, shape, "param"); - builder.Ceil(param); + auto param = Parameter(&builder, 0, shape, "param"); + Ceil(param); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -60,8 +60,8 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateFloorComputation(const Shape& shape) { XlaBuilder builder("Floor"); - auto param = builder.Parameter(0, shape, "param"); - builder.Floor(param); + auto param = Parameter(&builder, 0, shape, "param"); + Floor(param); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -78,12 +78,12 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleCeilComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - auto x_ceil = builder.Ceil(x); - auto y_ceil = builder.Ceil(y); - builder.Tuple({x_ceil, y_ceil}); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + auto x_ceil = Ceil(x); + auto y_ceil = Ceil(y); + Tuple(&builder, {x_ceil, y_ceil}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -100,12 +100,12 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleFloorComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - auto x_floor = builder.Floor(x); - auto y_floor = builder.Floor(y); - builder.Tuple({x_floor, y_floor}); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + auto x_floor = Floor(x); + auto y_floor = Floor(y); + Tuple(&builder, {x_floor, y_floor}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -122,10 +122,10 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleAddComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - builder.Add(x, y); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + Add(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -142,10 +142,10 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleSubComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - builder.Sub(x, y); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + Sub(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -172,12 +172,11 @@ class ConditionalOpTest : public ClientLibraryTestBase { // Test true and false computations that do not take any parameters. XLA_TEST_F(ConditionalOpTest, Parameters0) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operands = builder.Tuple({}); + auto pred = ConstantR0(&builder, true); + auto operands = Tuple(&builder, {}); auto true_computation = CreateR0ConstantComputation(56.0f); auto false_computation = CreateR0ConstantComputation(12.0f); - builder.Conditional(pred, operands, true_computation, operands, - false_computation); + Conditional(pred, operands, true_computation, operands, false_computation); ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); } @@ -185,11 +184,11 @@ XLA_TEST_F(ConditionalOpTest, Parameters0) { // Test true and false computations that take in 1 parameter. XLA_TEST_F(ConditionalOpTest, Parameters1) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); auto identity = CreateR0IdentityComputation(); - builder.Conditional(pred, operand1, identity, operand2, identity); + Conditional(pred, operand1, identity, operand2, identity); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -198,11 +197,11 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) { // that take in different arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); - builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, - CreateR0FloorComputation()); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); + Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -211,10 +210,10 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { // that take in the same arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand = builder.ConstantR0(12.6f); - builder.Conditional(pred, operand, CreateR0CeilComputation(), operand, - CreateR0FloorComputation()); + auto pred = ConstantR0(&builder, false); + auto operand = ConstantR0(&builder, 12.6f); + Conditional(pred, operand, CreateR0CeilComputation(), operand, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -223,11 +222,11 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { // take in different arguments. XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); auto floor = CreateR0FloorComputation(); - builder.Conditional(pred, operand1, floor, operand2, floor); + Conditional(pred, operand1, floor, operand2, floor); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -236,10 +235,10 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { // take in the same arguments. XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand = builder.ConstantR0(12.6f); + auto pred = ConstantR0(&builder, false); + auto operand = ConstantR0(&builder, 12.6f); auto floor = CreateR0FloorComputation(); - builder.Conditional(pred, operand, floor, operand, floor); + Conditional(pred, operand, floor, operand, floor); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -248,11 +247,11 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { // and false cases. XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); - builder.Conditional(pred, operand1, CreateR0FloorComputation(), operand2, - CreateR0FloorComputation()); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); + Conditional(pred, operand1, CreateR0FloorComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -261,19 +260,19 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); XlaBuilder inner_builder(TestName() + ".inner_conditional"); - auto pred_cond = inner_builder.Parameter(0, r0bool, "param0"); - auto true_operand = inner_builder.Parameter(1, r0f32_, "param1"); - auto false_operand = inner_builder.Parameter(2, r0f32_, "param2"); - inner_builder.Conditional(pred_cond, true_operand, CreateR0CeilComputation(), - false_operand, CreateR0FloorComputation()); + auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0"); + auto true_operand = Parameter(&inner_builder, 1, r0f32_, "param1"); + auto false_operand = Parameter(&inner_builder, 2, r0f32_, "param2"); + Conditional(pred_cond, true_operand, CreateR0CeilComputation(), false_operand, + CreateR0FloorComputation()); auto inner_builder_result = inner_builder.Build(); XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); - builder.Call(inner_builder_result.ConsumeValueOrDie(), - {pred, operand1, operand2}); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); + Call(&builder, inner_builder_result.ConsumeValueOrDie(), + {pred, operand1, operand2}); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -282,12 +281,12 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { // true. XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, - CreateR0TupleSubComputation()); + auto pred = ConstantR0(&builder, true); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); } @@ -296,12 +295,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { // false. XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, - CreateR0TupleSubComputation()); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); } @@ -310,12 +309,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { // predicate is true. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operand1 = builder.ConstantR1({24.0f, 56.0f}); - auto operand2 = builder.ConstantR1({10.0f, 11.0f}); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, - CreateR1TupleSubComputation()); + auto pred = ConstantR0(&builder, true); + auto operand1 = ConstantR1(&builder, {24.0f, 56.0f}); + auto operand2 = ConstantR1(&builder, {10.0f, 11.0f}); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); } @@ -324,12 +323,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { // predicate is false. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR1({24.0f, 56.0f}); - auto operand2 = builder.ConstantR1({10.0f, 11.0f}); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, - CreateR1TupleSubComputation()); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR1(&builder, {24.0f, 56.0f}); + auto operand2 = ConstantR1(&builder, {10.0f, 11.0f}); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); } @@ -337,11 +336,11 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { // Test true and false computations that return a tuple of scalars. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operands = builder.Tuple( - {builder.ConstantR0(12.2f), builder.ConstantR0(25.6f)}); - builder.Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, - CreateR0TupleFloorComputation()); + auto pred = ConstantR0(&builder, false); + auto operands = Tuple(&builder, {ConstantR0(&builder, 12.2f), + ConstantR0(&builder, 25.6f)}); + Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, + CreateR0TupleFloorComputation()); ComputeAndCompareTuple( &builder, @@ -353,11 +352,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { // Test true and false computations that return a tuple of arrays. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operands = builder.Tuple({builder.ConstantR1({12.2f, 15.8f}), - builder.ConstantR1({25.6f, 29.2f})}); - builder.Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, - CreateR1TupleFloorComputation()); + auto pred = ConstantR0(&builder, true); + auto operands = + Tuple(&builder, {ConstantR1(&builder, {12.2f, 15.8f}), + ConstantR1(&builder, {25.6f, 29.2f})}); + Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, + CreateR1TupleFloorComputation()); ComputeAndCompareTuple( &builder, @@ -371,31 +371,31 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { XlaBuilder true_builder(TestName() + ".true"); { - true_builder.Parameter(0, empty_tuple_, "tuple"); - auto true_pred = true_builder.ConstantR0(true); - auto true_scalar = true_builder.ConstantR0(12.2f); - auto true_array = true_builder.ConstantR1({12.8f, 14.6f}); - true_builder.Tuple({true_pred, true_scalar, true_array}); + Parameter(&true_builder, 0, empty_tuple_, "tuple"); + auto true_pred = ConstantR0(&true_builder, true); + auto true_scalar = ConstantR0(&true_builder, 12.2f); + auto true_array = ConstantR1(&true_builder, {12.8f, 14.6f}); + Tuple(&true_builder, {true_pred, true_scalar, true_array}); } auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); XlaBuilder false_builder(TestName() + ".false"); { - false_builder.Parameter(0, empty_tuple_, "tuple"); - auto false_pred = false_builder.ConstantR0(false); - auto false_scalar = false_builder.ConstantR0(25.6f); - auto false_array = false_builder.ConstantR1({26.4f, 32.6f}); - false_builder.Tuple({false_pred, false_scalar, false_array}); + Parameter(&false_builder, 0, empty_tuple_, "tuple"); + auto false_pred = ConstantR0(&false_builder, false); + auto false_scalar = ConstantR0(&false_builder, 25.6f); + auto false_array = ConstantR1(&false_builder, {26.4f, 32.6f}); + Tuple(&false_builder, {false_pred, false_scalar, false_array}); } auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operands = builder.Tuple({}); - builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), - operands, false_builder_result.ConsumeValueOrDie()); + auto pred = ConstantR0(&builder, true); + auto operands = Tuple(&builder, {}); + Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, + false_builder_result.ConsumeValueOrDie()); ComputeAndCompareTuple( &builder, @@ -409,36 +409,37 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { XlaBuilder true_builder(TestName() + ".true"); { - true_builder.Parameter(0, empty_tuple_, "tuple"); - auto true_constant1 = true_builder.ConstantR0(12.2f); - auto true_constant2 = true_builder.ConstantR1({12.8f, 14.6f}); - auto true_constant3 = true_builder.ConstantR1({25.4f, 29.8f}); - auto true_constant4 = true_builder.ConstantR0(35.6f); - true_builder.Tuple({true_builder.Tuple({true_constant1, true_constant2}), - true_builder.Tuple({true_constant3, true_constant4})}); + Parameter(&true_builder, 0, empty_tuple_, "tuple"); + auto true_constant1 = ConstantR0(&true_builder, 12.2f); + auto true_constant2 = ConstantR1(&true_builder, {12.8f, 14.6f}); + auto true_constant3 = ConstantR1(&true_builder, {25.4f, 29.8f}); + auto true_constant4 = ConstantR0(&true_builder, 35.6f); + Tuple(&true_builder, + {Tuple(&true_builder, {true_constant1, true_constant2}), + Tuple(&true_builder, {true_constant3, true_constant4})}); } auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); XlaBuilder false_builder(TestName() + ".false"); { - false_builder.Parameter(0, empty_tuple_, "tuple"); - auto false_constant1 = false_builder.ConstantR0(46.6f); - auto false_constant2 = false_builder.ConstantR1({54.4f, 58.4f}); - auto false_constant3 = false_builder.ConstantR1({62.1f, 67.4f}); - auto false_constant4 = false_builder.ConstantR0(9.3f); - false_builder.Tuple( - {false_builder.Tuple({false_constant1, false_constant2}), - false_builder.Tuple({false_constant3, false_constant4})}); + Parameter(&false_builder, 0, empty_tuple_, "tuple"); + auto false_constant1 = ConstantR0(&false_builder, 46.6f); + auto false_constant2 = ConstantR1(&false_builder, {54.4f, 58.4f}); + auto false_constant3 = ConstantR1(&false_builder, {62.1f, 67.4f}); + auto false_constant4 = ConstantR0(&false_builder, 9.3f); + Tuple(&false_builder, + {Tuple(&false_builder, {false_constant1, false_constant2}), + Tuple(&false_builder, {false_constant3, false_constant4})}); } auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operands = builder.Tuple({}); - builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), - operands, false_builder_result.ConsumeValueOrDie()); + auto pred = ConstantR0(&builder, false); + auto operands = Tuple(&builder, {}); + Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, + false_builder_result.ConsumeValueOrDie()); ComputeAndCompareTuple( &builder, @@ -464,8 +465,8 @@ XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { CreateR0Parameter(56.3f, 1, "operand1", &builder, &operand1); auto operand2_param = CreateR0Parameter(12.7f, 2, "operand2", &builder, &operand2); - builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, - CreateR0FloorComputation()); + Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0( &builder, 57.0f, @@ -484,8 +485,8 @@ XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { &builder, &operand1); auto operand2_param = CreateR1Parameter({10.2f, 11.6f}, 2, "operand2", &builder, &operand2); - builder.Conditional(pred, operand1, CreateR1CeilComputation(), operand2, - CreateR1FloorComputation()); + Conditional(pred, operand1, CreateR1CeilComputation(), operand2, + CreateR1FloorComputation()); ComputeAndCompareR1( &builder, {10.0f, 11.0f}, @@ -499,27 +500,25 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); - auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); - auto pred_cond = inner_builder.GetTupleElement(param0, 0); - auto true_operand = inner_builder.GetTupleElement(param0, 1); - auto false_operand = inner_builder.GetTupleElement(param0, 2); - inner_builder.Conditional(pred_cond, true_operand, - CreateR0CeilComputation(), false_operand, - CreateR0FloorComputation()); + auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0"); + auto pred_cond = GetTupleElement(param0, 0); + auto true_operand = GetTupleElement(param0, 1); + auto false_operand = GetTupleElement(param0, 2); + Conditional(pred_cond, true_operand, CreateR0CeilComputation(), + false_operand, CreateR0FloorComputation()); } auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); XlaBuilder builder(TestName()); - auto pred1 = builder.ConstantR0(true); - auto pred2 = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(1.1f); - auto operand2 = builder.ConstantR0(12.2f); - auto operand3 = builder.ConstantR0(43.3f); - auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); - builder.Conditional(pred1, tuple_operand, - inner_builder_result.ConsumeValueOrDie(), operand3, - CreateR0IdentityComputation()); + auto pred1 = ConstantR0(&builder, true); + auto pred2 = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 1.1f); + auto operand2 = ConstantR0(&builder, 12.2f); + auto operand3 = ConstantR0(&builder, 43.3f); + auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2}); + Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(), + operand3, CreateR0IdentityComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -529,23 +528,22 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); - auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); - auto pred_cond = inner_builder.GetTupleElement(param0, 0); - auto true_operand = inner_builder.GetTupleElement(param0, 1); - auto false_operand = inner_builder.GetTupleElement(param0, 2); - inner_builder.Conditional(pred_cond, true_operand, - CreateR0CeilComputation(), false_operand, - CreateR0FloorComputation()); + auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0"); + auto pred_cond = GetTupleElement(param0, 0); + auto true_operand = GetTupleElement(param0, 1); + auto false_operand = GetTupleElement(param0, 2); + Conditional(pred_cond, true_operand, CreateR0CeilComputation(), + false_operand, CreateR0FloorComputation()); } auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); XlaBuilder builder(TestName()); - auto pred2 = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(1.1f); - auto operand2 = builder.ConstantR0(12.2f); - auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); - builder.Call(inner_builder_result.ConsumeValueOrDie(), {tuple_operand}); + auto pred2 = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 1.1f); + auto operand2 = ConstantR0(&builder, 12.2f); + auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2}); + Call(&builder, inner_builder_result.ConsumeValueOrDie(), {tuple_operand}); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -553,12 +551,12 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { // Test a mismatch in the shape of the true operand and true computation. XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, - CreateR0TupleSubComputation()); + auto pred = ConstantR0(&builder, true); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR0TupleSubComputation()); auto result = builder.Build(); EXPECT_FALSE(result.ok()); @@ -572,40 +570,40 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { XlaComputation swapper; { XlaBuilder builder(TestName() + ".swapper"); - auto param0 = builder.Parameter(0, tuple_shape, "sp0"); - auto x = builder.GetTupleElement(param0, 0); - auto y = builder.GetTupleElement(param0, 1); - builder.Tuple({y, x}); + auto param0 = Parameter(&builder, 0, tuple_shape, "sp0"); + auto x = GetTupleElement(param0, 0); + auto y = GetTupleElement(param0, 1); + Tuple(&builder, {y, x}); swapper = builder.Build().ConsumeValueOrDie(); } XlaComputation forwarder; { XlaBuilder builder(TestName() + ".forwarder"); - auto param0 = builder.Parameter(0, tuple_shape, "fp0"); - auto x = builder.GetTupleElement(param0, 0); - auto y = builder.GetTupleElement(param0, 1); - builder.Tuple({x, y}); + auto param0 = Parameter(&builder, 0, tuple_shape, "fp0"); + auto x = GetTupleElement(param0, 0); + auto y = GetTupleElement(param0, 1); + Tuple(&builder, {x, y}); forwarder = builder.Build().ConsumeValueOrDie(); } XlaComputation main; { XlaBuilder builder(TestName() + ".main"); - auto param0 = builder.Parameter(0, tuple_shape, "mp0"); - auto x = builder.GetTupleElement(param0, 0); - auto y = builder.GetTupleElement(param0, 1); - auto lt_pred = builder.Lt(x, y); - auto res = builder.Conditional(lt_pred, param0, forwarder, param0, swapper); - auto ge_pred = builder.Ge(x, y); - builder.Conditional(ge_pred, res, swapper, res, forwarder); + auto param0 = Parameter(&builder, 0, tuple_shape, "mp0"); + auto x = GetTupleElement(param0, 0); + auto y = GetTupleElement(param0, 1); + auto lt_pred = Lt(x, y); + auto res = Conditional(lt_pred, param0, forwarder, param0, swapper); + auto ge_pred = Ge(x, y); + Conditional(ge_pred, res, swapper, res, forwarder); main = builder.Build().ConsumeValueOrDie(); } auto test_swap = [&](float a, float b) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR0(a); - auto y = builder.ConstantR0(b); - auto tuple_operand = builder.Tuple({x, y}); - builder.Call(main, {tuple_operand}); + auto x = ConstantR0(&builder, a); + auto y = ConstantR0(&builder, b); + auto tuple_operand = Tuple(&builder, {x, y}); + Call(&builder, main, {tuple_operand}); ComputeAndCompareTuple( &builder, diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 916ffadbc798ec0dd016f45b0bc4c36233455ee7..cc5d3b11767457444d4c199943e689f082d5b199 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -39,7 +40,7 @@ class ConstantsTest : public ClientLibraryTestBase { TEST_F(ConstantsTest, ZeroCellF32) { XlaBuilder builder(TestName()); - builder.ConstantR1({}); + ConstantR1(&builder, {}); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -48,7 +49,7 @@ TEST_F(ConstantsTest, OneCellF32) { std::vector constant = {2.0}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } @@ -57,7 +58,7 @@ TEST_F(ConstantsTest, OneCellS32) { std::vector constant = {2}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}); } @@ -66,7 +67,7 @@ TEST_F(ConstantsTest, OneCellU32) { std::vector constant = {2}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}); } @@ -75,7 +76,7 @@ TEST_F(ConstantsTest, EightCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } @@ -85,14 +86,14 @@ TEST_F(ConstantsTest, SixteenCells) { 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_0x2) { XlaBuilder builder(TestName()); - builder.ConstantR2FromArray2D(Array2D(0, 2)); + ConstantR2FromArray2D(&builder, Array2D(0, 2)); ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); } @@ -102,15 +103,15 @@ TEST_F(ConstantsTest, Small_2x2) { MakeLinspaceArray2D(100.0, 200.0, 2, 2); XlaBuilder builder(TestName()); - builder.ConstantR2FromArray2D(*constant); + ConstantR2FromArray2D(&builder, *constant); ComputeAndCompareR2(&builder, *constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - auto constant = builder.ConstantLiteral( - *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); + ConstantLiteral( + &builder, *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); } @@ -125,8 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - auto constant = - builder.ConstantLiteral(*Literal::CreateR3FromArray3D(array3d)); + ConstantLiteral(&builder, *Literal::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -145,13 +145,13 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { { XlaBuilder builder(TestName()); - builder.ConstantLiteral(*input_literal); + ConstantLiteral(&builder, *input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } { XlaBuilder builder(TestName()); - builder.ConstantR4FromArray4D(input_array); + ConstantR4FromArray4D(&builder, input_array); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } } @@ -159,9 +159,9 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { XlaBuilder builder(TestName()); - builder.ConstantLiteral( - *Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), - Literal::CreateR1({2.0, 42}).get()})); + ConstantLiteral(&builder, *Literal::MakeTuple( + {Literal::CreateR2({{1.0}, {2.0}}).get(), + Literal::CreateR1({2.0, 42}).get()})); std::unique_ptr result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); @@ -172,5 +172,13 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); } +TEST_F(ConstantsTest, Token) { + XlaBuilder builder(TestName()); + ConstantLiteral(&builder, *Literal::CreateToken()); + // TODO(b/80000000): tokens cannot be returned from computations. + Tuple(&builder, {}); + TF_ASSERT_OK(Execute(&builder, {}).status()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 4ef0a77884c90b9fe32f96d3361fa3d80bde623b..292942a49e2f0c4b077dc71c9d0e730909689e3a 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -45,8 +45,8 @@ class ConvertTest : public ClientLibraryTestBase { TEST_F(ConvertTest, ConvertR1S32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42, 64}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -54,8 +54,8 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) { TEST_F(ConvertTest, ConvertR1F32ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + ConvertElementType(a, F32); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -63,8 +63,8 @@ TEST_F(ConvertTest, ConvertR1F32ToR1F32) { TEST_F(ConvertTest, ConvertR1S32ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42, 64}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, F32); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -72,8 +72,8 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) { TEST_F(ConvertTest, ConvertR1PREDToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false, true}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {true, false, true}); + ConvertElementType(a, S32); std::vector expected = {1, 0, 1}; ComputeAndCompareR1(&builder, expected, {}); @@ -81,8 +81,8 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) { TEST_F(ConvertTest, ConvertR1PREDToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false, true}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {true, false, true}); + ConvertElementType(a, F32); std::vector expected = {1., 0., 1.}; ComputeAndCompareR1(&builder, expected, {}); @@ -90,8 +90,8 @@ TEST_F(ConvertTest, ConvertR1PREDToR1F32) { XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {}); + ConvertElementType(a, F32); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -99,8 +99,8 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { TEST_F(ConvertTest, ConvertR1F32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.6, 64.4}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {42.6, 64.4}); + ConvertElementType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -146,11 +146,11 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast(0x8000010000000000LL), }; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, F32); + ConvertElementType(arg_param, F32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -165,11 +165,11 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, F32); + ConvertElementType(arg_param, F32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -183,11 +183,11 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, U32); + ConvertElementType(arg_param, U32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -200,11 +200,11 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, S64); + ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -217,11 +217,11 @@ XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, S64); + ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -249,16 +249,16 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { -1.99f, -2.0f, -2.01f, - 0x1.FFFFFEp+62F, - 0x1.FFFFFCp+62F, - -0x1.FFFFFEp+62F, - -0x1.FFFFFCp+62F}; + 9223371487098961920.f, + 9223370937343148032.f, + -9223371487098961920.f, + -9223370937343148032.f}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, S64); + ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -269,8 +269,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {32, 64}); + ConvertElementType(a, F32); std::vector expected = {32.0, 64.0}; ComputeAndCompareR1(&builder, expected, {}); @@ -278,8 +278,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {32, 64}); + ConvertElementType(a, S32); std::vector expected = {32, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -287,8 +287,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, U32); + auto a = ConstantR1(&builder, {32, 64}); + ConvertElementType(a, U32); std::vector expected = {32, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -296,8 +296,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32.0f, 64.0f}); - builder.ConvertElementType(a, F64); + auto a = ConstantR1(&builder, {32.0f, 64.0f}); + ConvertElementType(a, F64); std::vector expected = {32.0, 64.0}; ComputeAndCompareR1(&builder, expected, {}); @@ -305,8 +305,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32.0, 64.0}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {32.0, 64.0}); + ConvertElementType(a, F32); std::vector expected = {32.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); @@ -314,9 +314,9 @@ XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { TEST_F(ConvertTest, ConvertS32Extremes) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {std::numeric_limits::min(), std::numeric_limits::max()}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::max()}); + ConvertElementType(a, F32); std::vector expected = { static_cast(std::numeric_limits::min()), @@ -327,10 +327,10 @@ TEST_F(ConvertTest, ConvertS32Extremes) { TEST_F(ConvertTest, ConvertMapToS32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); - b->ConvertElementType(param, S32); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "in"); + ConvertElementType(param, S32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -339,10 +339,10 @@ TEST_F(ConvertTest, ConvertMapToS32) { TEST_F(ConvertTest, ConvertMapToF32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); - b->ConvertElementType(param, F32); - auto a = builder.ConstantR1({42, 64}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(S32, {}), "in"); + ConvertElementType(param, F32); + auto a = ConstantR1(&builder, {42, 64}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -355,9 +355,9 @@ TEST_F(ConvertTest, ConvertMapToF32) { // the new convert should have the same element type as the old convert. TEST_F(ConvertTest, ConvertReshape) { XlaBuilder builder(TestName()); - auto input = builder.ConstantR1({42}); - auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); - builder.ConvertElementType(reshape, F32); + auto input = ConstantR1(&builder, {42}); + auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); + ConvertElementType(reshape, F32); ComputeAndCompareR0(&builder, 42.0f, {}, ErrorSpec(0.0001)); } @@ -394,10 +394,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { client_->TransferToServer(*Literal::CreateR1(input))); XlaBuilder builder(TestName()); - builder.ConvertElementType( - builder.Parameter( - 0, ShapeUtil::MakeShape(F16, {static_cast(input.size())}), - "param"), + ConvertElementType( + Parameter(&builder, 0, + ShapeUtil::MakeShape(F16, {static_cast(input.size())}), + "param"), F32); ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); @@ -414,10 +414,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { client_->TransferToServer(*Literal::CreateR1(input))); XlaBuilder builder(TestName()); - builder.ConvertElementType( - builder.Parameter( - 0, ShapeUtil::MakeShape(F32, {static_cast(input.size())}), - "param"), + ConvertElementType( + Parameter(&builder, 0, + ShapeUtil::MakeShape(F32, {static_cast(input.size())}), + "param"), F16); ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); @@ -426,28 +426,28 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { XLA_TEST_F(ConvertTest, ConvertC64ToC64) { XlaBuilder builder(TestName()); std::vector x = {{42.0f, 64.0f}}; - builder.ConvertElementType(builder.ConstantR1(x), C64); + ConvertElementType(ConstantR1(&builder, x), C64); ComputeAndCompareR1(&builder, x, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConvertTest, ConvertS64S64) { XlaBuilder builder(TestName()); std::vector x = {{-42, 64}}; - builder.ConvertElementType(builder.ConstantR1(x), S64); + ConvertElementType(ConstantR1(&builder, x), S64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64U64) { XlaBuilder builder(TestName()); std::vector x = {{42, 64}}; - builder.ConvertElementType(builder.ConstantR1(x), U64); + ConvertElementType(ConstantR1(&builder, x), U64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64S64) { XlaBuilder builder(TestName()); std::vector unsigned_x = {{42, UINT64_MAX}}; - builder.ConvertElementType(builder.ConstantR1(unsigned_x), S64); + ConvertElementType(ConstantR1(&builder, unsigned_x), S64); std::vector signed_x = {{42, -1}}; ComputeAndCompareR1(&builder, signed_x, {}); } @@ -455,11 +455,31 @@ XLA_TEST_F(ConvertTest, ConvertU64S64) { XLA_TEST_F(ConvertTest, ConvertS64U64) { XlaBuilder builder(TestName()); std::vector signed_x = {{42, -1, INT64_MIN}}; - builder.ConvertElementType(builder.ConstantR1(signed_x), U64); + ConvertElementType(ConstantR1(&builder, signed_x), U64); std::vector unsigned_x = { {42, UINT64_MAX, tensorflow::MathUtil::IPow(2, 63)}}; ComputeAndCompareR1(&builder, unsigned_x, {}); } +XLA_TEST_F(ConvertTest, ConvertBF16F32) { + XlaBuilder builder(TestName()); + + std::vector all_bfloats(1 << 16); + for (int i = 0; i < all_bfloats.size(); ++i) { + all_bfloats[i].value = i; + } + + std::vector expected(all_bfloats.size()); + for (int i = 0; i < expected.size(); ++i) { + expected[i] = (1U << 16) * i; + } + + // Exhaustively test all bf16 to f32 conversions. + xla::XlaOp all_bfloats_bf16 = ConstantR1(&builder, all_bfloats); + xla::XlaOp all_bfloats_f32 = ConvertElementType(all_bfloats_bf16, F32); + BitcastConvertType(all_bfloats_f32, U32); + ComputeAndCompareR1(&builder, expected, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index b5a42e305987df030c15d089f5877f73bb61de1b..7605ebf4c0eacd7f44e867e23dbc27c6c1bc3e93 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -97,10 +97,10 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, .ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(*input_array); + auto input = ConstantR4FromArray4D(&builder, *input_array); auto weight = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight"); - auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight"); + auto conv1 = Conv(input, weight, {1, 1}, Padding::kValid); ConvolutionDimensionNumbers dim_nums = XlaBuilder::CreateDefaultConvDimensionNumbers(); @@ -117,8 +117,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, dim_nums.set_kernel_input_feature_dimension( dim_nums.kernel_output_feature_dimension()); dim_nums.set_kernel_output_feature_dimension(old_kernel_input_feature_dim); - builder.ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, - dim_nums); + ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, dim_nums); auto expected_conv1 = ReferenceUtil::ConvArray4D(*input_array, *weight_array, {1, 1}, Padding::kValid); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 947959beb144e1509a77ad2f94b8493de46ba6f2..0f6d54d042dd6af6d82e1eea93a66c2e9be53639 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -47,9 +47,9 @@ class ConvolutionTest : public ClientLibraryTestBase { #if XLA_TEST_BACKEND_GPU // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial // convolution. So relax the absolute error threshold. - ErrorSpec error_spec_ = ErrorSpec(1e-2); + ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-4); #else - ErrorSpec error_spec_ = ErrorSpec(1e-4); + ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-4); #endif }; @@ -89,9 +89,9 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { ASSERT_EQ(2, arhs->height()); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR4FromArray4D(*alhs); - auto rhs = builder.ConstantR4FromArray4D(*arhs); - builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + auto lhs = ConstantR4FromArray4D(&builder, *alhs); + auto rhs = ConstantR4FromArray4D(&builder, *arhs); + Conv(lhs, rhs, {1, 1}, Padding::kValid); ComputeAndCompare(&builder, {}, error_spec_); } @@ -109,9 +109,9 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D({ @@ -140,9 +140,9 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({ @@ -174,9 +174,9 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({ @@ -210,9 +210,9 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 3, 3}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({{1.0f, 2.0f, 3.0f, 4.0f}, @@ -238,9 +238,9 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1}, Padding::kValid); } Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); @@ -268,10 +268,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { { Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -304,10 +304,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, /*lhs_dilation=*/{2}, /*rhs_dilation=*/{1}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -335,10 +335,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, /*lhs_dilation=*/{2}, /*rhs_dilation=*/{2}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -369,10 +369,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { { Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}}, /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -408,8 +408,8 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims); { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Tensorflow dimension numbers for 3D convolution. ConvolutionDimensionNumbers dnums; @@ -429,8 +429,7 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { dnums.set_kernel_input_feature_dimension(3); dnums.set_kernel_output_feature_dimension(4); - builder.ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, - dnums); + ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, dnums); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); @@ -475,8 +474,8 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Tensorflow dimension numbers for 2D convolution. ConvolutionDimensionNumbers dnums; @@ -493,8 +492,7 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { dnums.set_kernel_input_feature_dimension(2); dnums.set_kernel_output_feature_dimension(3); - builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, - dnums); + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); @@ -541,8 +539,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29}); Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); ConvolutionDimensionNumbers dnums; dnums.set_input_feature_dimension(0); @@ -551,7 +549,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, dnums.set_kernel_output_feature_dimension(1); dnums.set_output_batch_dimension(0); dnums.set_output_feature_dimension(1); - builder.ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums); + ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums); Array2D param0(4, 29); param0.FillUnique(); @@ -599,8 +597,8 @@ class Convolve1D1WindowTestBase Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Tensorflow dimension numbers for 1D convolution. ConvolutionDimensionNumbers dnums; @@ -614,8 +612,7 @@ class Convolve1D1WindowTestBase dnums.set_kernel_input_feature_dimension(1); dnums.set_kernel_output_feature_dimension(2); - builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, - dnums); + ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape), @@ -726,9 +723,9 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D({ @@ -754,9 +751,9 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillIota(0); diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index fea850dc135e33fe098aa755c6fdd93319cd2837..c31d033bb0f0e52d40251c4d7b64d52f42d29dc6 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -55,12 +55,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Minimal) { XlaBuilder builder(TestName()); const Array4D input_array(1, 1, 1, 1, {2}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {3}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); const Array4D expected(1, 1, 1, 1, {6}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -70,12 +70,12 @@ XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) { XlaBuilder builder(TestName()); const Array4D input_array(5, 1, 1, 1, {1, 2, 3, 4, 5}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {2}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); const Array4D expected(5, 1, 1, 1, {2, 4, 6, 8, 10}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -86,12 +86,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) { Array4D input_array(2, 1, 3, 4); input_array.FillWithMultiples(1); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {2.3}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(2, 1, 3, 4); expected.FillWithMultiples(2.3); @@ -102,12 +102,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) { XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 1, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 3, 1, 1, {12, 34, 56}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -117,12 +117,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 2, {1, 2}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 1, {12}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -132,12 +132,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {12, 23}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -147,12 +147,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 2, 1, {12, 34}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -162,12 +162,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 2, 1, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {13, 24}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -177,12 +177,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 2, 2, {1000, 100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 1, {1234}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -194,13 +194,13 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { Array4D input_array( 2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 0, 0}); // plane 1 - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array( 2, 2, 1, 2, {1000, 100, 10, 1, 0.1, 0.01, 0.001, 0.0001}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected( 2, 2, 2, 2, @@ -213,12 +213,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {10}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 2, {10, 30}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -228,12 +228,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {10}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 3, {10, 30, 50}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -243,12 +243,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 3, {100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 1, {123}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -258,12 +258,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 3, {100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 2, {123, 345}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -273,12 +273,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {10}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {2, 2}, Padding::kValid); + Conv(input, filter, {2, 2}, Padding::kValid); Array4D expected(1, 1, 2, 2, {10, 30, 70, 90}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -288,12 +288,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 1, {1}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 3, {10, 20, 30}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 1, {20}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -303,12 +303,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 5, {10000, 1000, 100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 3, {123, 1230, 12300}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -318,15 +318,15 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 3, 3, {10000, 0, 1000, // row 0 0, 100, 0, // row 1 10, 0, 1}); // row 2 - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 2, 2, {104, 230, 2300, 10400}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -336,12 +336,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 2, 1, 1, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 2, {13, 24}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -351,12 +351,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 2, 2, {7, 13, 17, 23}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 2, 2, {216, 276, 396, 456}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -366,12 +366,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {7, 13}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {33, 53}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -383,15 +383,15 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { std::vector input_data(64); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(1, 1, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(128); std::fill(filter_data.begin(), filter_data.begin() + 64, 1.0); std::fill(filter_data.begin() + 64, filter_data.begin() + 128, 2.0); const Array4D filter_array(2, 1, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 2, 1, 1, {2016, 4032}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -403,14 +403,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { std::vector input_data(16 * 1 * 1 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(16, 1, 1, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; @@ -432,14 +432,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { } } } - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * ky * kx); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, ky, kx, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data(bs); for (int i = 0; i < bs; ++i) { @@ -463,14 +463,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { } } } - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * ky * kx); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, ky, kx, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { 23, @@ -492,14 +492,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { } } } - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 8 * 8); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { 19664, 21744, 23824, 25904, 27984, 30064, 32144, 34224, @@ -515,7 +515,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { std::vector input_data(2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(1, 2, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(2 * 2 * 8 * 8); std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, @@ -527,9 +527,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), 4.0); const Array4D filter_array(2, 2, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 2, 1, 1, {14240, 30496}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -541,7 +541,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { std::vector input_data(2 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(2, 2, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(2 * 2 * 8 * 8); std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, @@ -553,9 +553,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), 4.0); const Array4D filter_array(2, 2, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(2, 2, 1, 1, {14240, 30496, 38816, 87840}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -567,7 +567,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { std::vector input_data(32 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(32, 2, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(2 * 2 * 8 * 8); std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, @@ -579,9 +579,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), 4.0); const Array4D filter_array(2, 2, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { 14240, 30496, 38816, 87840, 63392, 145184, 87968, @@ -613,9 +613,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { } } - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(16, 16, 1, 1); for (int i0 = 0; i0 < 16; ++i0) { @@ -635,9 +635,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { Array4D input_array(1, 1, 4, 6, input_data); Array4D filter_array(1, 1, 2, 3, {1, 10, 100, 2, 20, 200}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -654,9 +654,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -677,9 +677,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { 200, 20, 2, // 300, 30, 3, // 400, 40, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1}, /*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2}, /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -699,9 +699,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneral( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, -1}}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -718,9 +718,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneral( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, 2}}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -737,9 +737,9 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneral( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {2, -1}}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -756,9 +756,9 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {3, 2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, @@ -781,9 +781,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-3, -2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, @@ -821,9 +821,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -854,9 +854,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -887,9 +887,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -920,9 +920,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -954,9 +954,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -970,12 +970,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 2 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 1.0); Array4D filter_array(1, 2, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -995,7 +995,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests padding sizes that don't correspond either to SAME or VALID padding. - builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); std::vector expected_data = { 0, 0, 0, 0, 0, 0, 0, // @@ -1014,12 +1014,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 2.0); Array4D filter_array(1, 1, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -1039,7 +1039,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests padding sizes that don't correspond either to SAME or VALID padding. - builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); std::vector expected_data = { 0, 0, 0, 0, 0, 0, 0, 0, // @@ -1058,12 +1058,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 2.0); Array4D filter_array(1, 1, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -1083,7 +1083,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests zero padding sizes. This can use matmul for computation. - builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); std::vector expected_data = { 2, 4, 6, // @@ -1099,12 +1099,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { std::vector input_data(1 * 2 * 3 * 2); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 2, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 2 * 3); std::iota(filter_data.begin(), filter_data.end(), 2.0); Array4D filter_array(1, 1, 2, 3, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -1124,7 +1124,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests zero padding sizes. This can use matmul for computation. - builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); std::vector expected_data = { 12, 15, 18, // @@ -1148,14 +1148,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 2, /*values=*/{5, 6})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {1, 0}}); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 2, /*values=*/{5, 6})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 0}}); ComputeAndCompareR4(&builder, {{{{5, 16, 27}}}}, {}, error_spec_); } @@ -1167,16 +1167,16 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 1, /*values=*/{1})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvGeneralDilated(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {0, 3}}, - /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 1, /*values=*/{1})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvGeneralDilated(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 3}}, + /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); ComputeAndCompareR4(&builder, {{{{100, 0}}}}, {}, error_spec_); } @@ -1187,14 +1187,14 @@ XLA_TEST_F(ConvolutionVariantsTest, XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 1, /*values=*/{1})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {1, 1}}); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 1, /*values=*/{1})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 1}}); ComputeAndCompareR4(&builder, {{{{10}}}}, {}, error_spec_); } @@ -1208,14 +1208,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 2, /*values=*/{1, 10})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {0, 2}}); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 2, /*values=*/{1, 10})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 2}}); ComputeAndCompareR4(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_); } @@ -1229,17 +1229,17 @@ XLA_TEST_F(ConvolutionVariantsTest, // weight gradients: 24,130,240 // // This pattern will be fused to backward convolution with padding=(1,2). - auto activations = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {1, 2}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); - builder.Transpose(forward_conv, {0, 1, 2, 3}); + auto activations = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 2}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); + Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{24, 130, 240}}}}, {}, error_spec_); } @@ -1255,17 +1255,17 @@ XLA_TEST_F(ConvolutionVariantsTest, // This pattern will be fused to backward convolution with padding=(2,1). // Note: both (2,1) and (2,0) are valid padding for the backward convolution // because the stride is 2. - auto activations = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {2, 0}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); - builder.Transpose(forward_conv, {0, 1, 2, 3}); + auto activations = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {2, 0}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); + Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{13, 24}}}}, {}, error_spec_); } @@ -1282,17 +1282,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { // because the stride is 2. ConvolutionFolding prefers (2,2) because cuDNN // supports even padding only -- using (2,1) would need extra effort of // canonicalization. - auto activations = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); - builder.Transpose(forward_conv, {0, 1, 2, 3}); + auto activations = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); + Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{13, 24, 130}}}}, {}, error_spec_); } @@ -1300,14 +1300,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR3FromArray3D( - Array3D(1, 1, 1, /*value=*/1)); + auto gradients = ConstantR3FromArray3D( + &builder, Array3D(1, 1, 1, /*value=*/1)); auto weights = - builder.ConstantR3FromArray3D(Array3D({{{1, 10, 100}}})); - auto mirrored_weights = builder.Rev(weights, {2}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1}, - /*padding=*/{{1, 1}}); + ConstantR3FromArray3D(&builder, Array3D({{{1, 10, 100}}})); + auto mirrored_weights = Rev(weights, {2}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1}, + /*padding=*/{{1, 1}}); ComputeAndCompareR3(&builder, {{{10}}}, {}, error_spec_); } @@ -1315,17 +1315,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { XlaBuilder builder(TestName()); auto activations = - builder.ConstantR3FromArray3D(Array3D({{{1, 2, 3, 4}}})); + ConstantR3FromArray3D(&builder, Array3D({{{1, 2, 3, 4}}})); auto gradients = - builder.ConstantR3FromArray3D(Array3D({{{100, 10, 1}}})); + ConstantR3FromArray3D(&builder, Array3D({{{100, 10, 1}}})); auto forward_conv = - builder.ConvGeneralDilated(activations, gradients, - /*window_strides=*/{1}, - /*padding=*/{{2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, - XlaBuilder::CreateDefaultConvDimensionNumbers( - /*num_spatial_dims=*/1)); - builder.Transpose(forward_conv, {0, 1, 2}); + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1}, + /*padding=*/{{2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, + XlaBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/1)); + Transpose(forward_conv, {0, 1, 2}); ComputeAndCompareR3(&builder, {{{13, 24, 130}}}, {}, error_spec_); } @@ -1336,21 +1336,21 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { auto gradients_flat = Literal::CreateR1({1}); auto gradients_literal = gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto gradients = builder.ConstantLiteral(*gradients_literal); + auto gradients = ConstantLiteral(&builder, *gradients_literal); auto weights_flat = Literal::CreateR1({1, 10, 100}); auto weights_literal = weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto weights = builder.ConstantLiteral(*weights_literal); + auto weights = ConstantLiteral(&builder, *weights_literal); auto expected_flat = Literal::CreateR1({10}); auto expected_literal = expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto mirrored_weights = builder.Rev(weights, {2, 3, 4}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1, 1}, - /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); + auto mirrored_weights = Rev(weights, {2, 3, 4}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1, 1}, + /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); } @@ -1360,25 +1360,25 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { auto activations_flat = Literal::CreateR1({1, 2, 3, 4}); auto activations_literal = activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); - auto activations = builder.ConstantLiteral(*activations_literal); + auto activations = ConstantLiteral(&builder, *activations_literal); auto gradients_flat = Literal::CreateR1({100, 10, 1}); auto gradients_literal = gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto gradients = builder.ConstantLiteral(*gradients_literal); + auto gradients = ConstantLiteral(&builder, *gradients_literal); auto expected_flat = Literal::CreateR1({13, 24, 130}); auto expected_literal = expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1, 1}, - /*padding=*/{{0, 0}, {0, 0}, {2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers( - /*num_spatial_dims=*/3)); - builder.Transpose(forward_conv, {0, 1, 2, 3, 4}); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1, 1}, + /*padding=*/{{0, 0}, {0, 0}, {2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/3)); + Transpose(forward_conv, {0, 1, 2, 3, 4}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 2b3390ca98cb2922410d451c06811aa9d4ff8c0b..fef42885e516fa8c8f87756d7a953fe5f37a630f 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -248,7 +248,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { auto empty = Literal::CreateFromShape(in_shape); XlaBuilder builder(TestName()); - auto param0 = builder.Parameter(0, in_shape, "input"); + Parameter(&builder, 0, in_shape, "input"); auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b151187c4b8f01c5b46ccadf27d2e22a7c902e98 --- /dev/null +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class TrivialCrossReplicaSumTest : public HloTestBase {}; + +// Currently the CPU and GPU backends only support CrossReplicaSum with one +// replica. But we can at least check this. + +XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p = f32[3] parameter(0) + ROOT crs = f32[3] cross-replica-sum(p), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal = Literal::CreateR1({1, 2, 3}); + EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); +} + +XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p0 = f32[3] parameter(0) + p1 = f32[2] parameter(1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal0 = Literal::CreateR1({1, 2, 3}); + auto literal1 = Literal::CreateR1({10, 20}); + EXPECT_EQ( + *Literal::MakeTuple({literal0.get(), literal1.get()}), + *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); +} + +// On the GPU backend, constants get special handling. Someone might pass a +// constant to CRS to e.g. count the number of replicas -- we need to make sure +// it works. +XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p0 = f32[3] parameter(0) + p1 = f32[2] constant({10, 20}) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal0 = Literal::CreateR1({1, 2, 3}); + auto literal1 = Literal::CreateR1({10, 20}); + EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}), + *ExecuteAndTransfer(std::move(module), {literal0.get()})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index b43d5c9ff5d75ee0e1b3c9ceb2bc295e631ac107..d1516a28b0bb3857d9aee0922a252e25a8f9d2d5 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" @@ -135,8 +136,8 @@ class CustomCallClientAPITest : public ClientLibraryTestBase {}; // are reserved for internal use. XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) { XlaBuilder builder(TestName()); - builder.CustomCall("$illegal", /*operands=*/{}, - ShapeUtil::MakeShape(F32, {1})); + CustomCall(&builder, "$illegal", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {1})); StatusOr> result = Execute(&builder, /*arguments=*/{}); diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index bfe688e20d182d581c3e3b545ac2289413deef7c..d4b3aac85bff283515088f6e61c9d2bad11f60d3 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -48,7 +48,7 @@ class DeallocationTest : public ClientLibraryTestBase { TEST_F(DeallocationTest, DeallocateScalar) { XlaBuilder builder(TestName()); - builder.ConstantR0(42.0); + ConstantR0(&builder, 42.0); auto global_data = ExecuteAndCheckTransfer(&builder, {}); // A result can be transferred an arbitrary number of times. Add an extra @@ -66,7 +66,7 @@ TEST_F(DeallocationTest, DeallocateScalar) { TEST_F(DeallocationTest, DeallocateVector) { XlaBuilder builder(TestName()); - builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -79,7 +79,7 @@ TEST_F(DeallocationTest, DeallocateVector) { TEST_F(DeallocationTest, DeallocateEmptyVector) { XlaBuilder builder(TestName()); - builder.ConstantR1({}); + ConstantR1(&builder, {}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -92,8 +92,8 @@ TEST_F(DeallocationTest, DeallocateEmptyVector) { XLA_TEST_F(DeallocationTest, DeallocateTuple) { XlaBuilder builder(TestName()); - builder.Tuple({builder.ConstantR0(42.0), - builder.ConstantR1({1.0, 2.0, 3.0})}); + Tuple(&builder, {ConstantR0(&builder, 42.0), + ConstantR1(&builder, {1.0, 2.0, 3.0})}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -106,9 +106,10 @@ XLA_TEST_F(DeallocationTest, DeallocateTuple) { XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { XlaBuilder builder(TestName()); - auto element = builder.ConstantR0(42.0); - auto inner_tuple = builder.Tuple({builder.ConstantR0(42.0), element}); - builder.Tuple({element, inner_tuple, element}); + auto element = ConstantR0(&builder, 42.0); + auto inner_tuple = + Tuple(&builder, {ConstantR0(&builder, 42.0), element}); + Tuple(&builder, {element, inner_tuple, element}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -122,9 +123,9 @@ XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { XlaBuilder builder(TestName()); auto inner_tuple = - builder.Tuple({builder.ConstantR0(42.0), - builder.ConstantR1({1.0, 2.0, 3.0})}); - builder.Tuple({inner_tuple, builder.ConstantR1({0.123, 0.456})}); + Tuple(&builder, {ConstantR0(&builder, 42.0), + ConstantR1(&builder, {1.0, 2.0, 3.0})}); + Tuple(&builder, {inner_tuple, ConstantR1(&builder, {0.123, 0.456})}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 12789fe66530fe03eb33316eda652336f29971ab..acba67491d25007ab774530fd7ca236a4363b6f0 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -54,9 +54,9 @@ class DeconstructTupleTest : public ClientLibraryTestBase { TEST_F(DeconstructTupleTest, DeconstructTuple) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -73,9 +73,9 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status1 = client_->DeconstructTuple(*global_data); @@ -103,9 +103,9 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2, const2, const1}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2, const2, const1}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -129,9 +129,9 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2, const1}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2, const1}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -159,7 +159,7 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XlaBuilder builder(TestName()); - builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -174,8 +174,8 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); - builder.Tuple({p}); + auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); + Tuple(&builder, {p}); auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()}); auto result_status = client_->DeconstructTuple(*global_data); @@ -186,9 +186,9 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({builder.Tuple({const1, const2}), const1}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {Tuple(&builder, {const1, const2}), const1}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc index 085a5105aca1c173a7cbc211aebbeb5b254b0753..810947ab01b69b10b6ae60c551bd7aba10a6313d 100644 --- a/tensorflow/compiler/xla/tests/deep_graph_test.cc +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -30,7 +30,7 @@ TEST_F(ClientLibraryTestBase, DeepGraph) { auto y_data = CreateR0Parameter(1, 1, "y", &b, &y); XlaOp z = x; for (int i = 0; i < kDepth; ++i) { - z = b.Add(z, y); + z = Add(z, y); } ComputeAndCompareR0(&b, /*expected=*/kDepth + 3, {x_data.get(), y_data.get()}); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0fd846cef8095a857dd7b2c12d8afdf409e2bd66..cf2e645d472efab9ca649dbde6602fd4f205d924 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -70,9 +70,9 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { *Literal::MakeTuple({Literal::CreateR2({{1, 2}, {3, 4}}).get(), Literal::CreateR2({{5, 6}, {7, 8}}).get()}), "arg0", &builder, ¶m); - auto lhs = builder.GetTupleElement(param, 0); - auto rhs = builder.GetTupleElement(param, 1); - builder.Dot(lhs, rhs); + auto lhs = GetTupleElement(param, 0); + auto rhs = GetTupleElement(param, 1); + Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, *Literal::CreateR2({{19, 22}, {43, 50}}), @@ -87,9 +87,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Dot(lhs, rhs); this->template ComputeAndCompareR0(&builder, static_cast(0.0), {}, this->error_spec_); @@ -102,9 +102,9 @@ TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64); XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D({{3.0f, 4.0f}}); - auto rhs = builder.ConstantFromArray({3.0f, 4.0f}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, {{3.0f, 4.0f}}); + auto rhs = ConstantFromArray(&builder, {3.0f, 4.0f}); + Dot(lhs, rhs); this->template ComputeAndCompareR1(&builder, {static_cast(25.0f)}, {}, this->error_spec_); @@ -113,9 +113,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR1({static_cast(2.0f)}); - auto rhs = builder.ConstantR1({static_cast(3.0f)}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR1(&builder, {static_cast(2.0f)}); + auto rhs = ConstantR1(&builder, {static_cast(3.0f)}); + Dot(lhs, rhs); this->template ComputeAndCompareR0(&builder, static_cast(6.0f), {}, this->error_spec_); @@ -124,9 +124,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantFromArray({1.0f, 2.5f, 42.0f}); - auto rhs = builder.ConstantFromArray({11.0f, -1.0f, 0.5f}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantFromArray(&builder, {1.0f, 2.5f, 42.0f}); + auto rhs = ConstantFromArray(&builder, {11.0f, -1.0f, 0.5f}); + Dot(lhs, rhs); this->template ComputeAndCompareR0(&builder, static_cast(29.5f), {}, this->error_spec_); @@ -139,9 +139,9 @@ std::vector MinorToMajorForIsRowMajor(bool row_major) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + auto rhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + Dot(lhs, rhs); this->template ComputeAndCompareR2(&builder, Array2D(0, 0), {}, this->error_spec_); @@ -150,10 +150,10 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2FromArray2D( - {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + auto rhs = ConstantR2FromArray2D( + &builder, {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); + Dot(lhs, rhs); this->template ComputeAndCompareR2(&builder, Array2D(0, 3), {}, this->error_spec_); @@ -162,10 +162,10 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D( - {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D( + &builder, {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); + auto rhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + Dot(lhs, rhs); this->template ComputeAndCompareR2(&builder, Array2D(3, 0), {}, this->error_spec_); @@ -174,9 +174,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + auto rhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + Dot(lhs, rhs); this->template ComputeAndCompareR2( &builder, Array2D(2, 2, static_cast(0.0f)), {}, this->error_spec_); @@ -186,11 +186,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto param0 = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); auto param1 = - builder.Parameter(1, ShapeUtil::MakeShapeWithType({4, 1}), "arg1"); - auto exp0 = builder.Exp(param0); - auto result = builder.Dot(exp0, param1); + Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({4, 1}), "arg1"); + auto exp0 = Exp(param0); + Dot(exp0, param1); auto lhs_handle = this->client_ @@ -231,9 +231,8 @@ class SquareMatrixDot : public DotOperationTest { .ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); + Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); Array2D expected({{15.0f, -2.0f}, {-25.0f, 34.0f}}); ComputeAndCompareR2(&builder, expected, @@ -316,26 +315,26 @@ void ParametricDotTest::TestImpl() { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, - ShapeUtil::MakeShapeWithLayout( - prim_type, {param.m, param.k}, - MinorToMajorForIsRowMajor(param.dot_lhs_row_major)), - "dot_lhs"), - builder.Parameter(1, - ShapeUtil::MakeShapeWithLayout( - prim_type, {param.k, param.n}, - MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), - "dot_rhs")); + auto result = + Dot(Parameter(&builder, 0, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.k}, + MinorToMajorForIsRowMajor(param.dot_lhs_row_major)), + "dot_lhs"), + Parameter(&builder, 1, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.k, param.n}, + MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), + "dot_rhs")); if (param.has_addend) { - result = builder.Add( - result, builder.Parameter( - 2, - ShapeUtil::MakeShapeWithLayout( - prim_type, {param.m, param.n}, - MinorToMajorForIsRowMajor(param.addend_row_major)), - "addend")); + result = + Add(result, + Parameter(&builder, 2, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.n}, + MinorToMajorForIsRowMajor(param.addend_row_major)), + "addend")); } std::unique_ptr> expected; @@ -492,9 +491,8 @@ class NonsquareMatrixDot : public DotOperationTest { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); + Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); @@ -524,9 +522,8 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); + Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); Array2D expected({{30.0, -2.0}}); @@ -538,11 +535,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto matrix1 = builder.ConstantR2FromArray2D({{1.0f, 2.0f}, {3.0f, 4.0f}}); - auto matrix2 = builder.ConstantR2FromArray2D({{5.0f, 6.0f}, {7.0f, 8.0f}}); - auto matrix12 = builder.Dot(matrix1, matrix2); - auto matrix21 = builder.Dot(matrix2, matrix1); - builder.Add(matrix12, matrix21); + auto matrix1 = + ConstantR2FromArray2D(&builder, {{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto matrix2 = + ConstantR2FromArray2D(&builder, {{5.0f, 6.0f}, {7.0f, 8.0f}}); + auto matrix12 = Dot(matrix1, matrix2); + auto matrix21 = Dot(matrix2, matrix1); + Add(matrix12, matrix21); Array2D expected({{42.0f, 56.0f}, {74.0f, 96.0f}}); this->template ComputeAndCompareR2(&builder, expected, {}, @@ -559,29 +558,29 @@ TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64); XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto x = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "x"); - auto y = - builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "y"); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "y"); - auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); - auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); + auto x_flat = Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); + auto y_flat = Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); // Slice batches into individual matrices and multiply them. std::vector out_slices; for (int i = 0; i < 4; ++i) { // Slice off individual matrices and reshape to 2D tensors. - auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); - x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2}); - auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); - y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2}); + auto x_slice = Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); + x_slice = Reshape(x_slice, {0, 1, 2}, {2, 2}); + auto y_slice = Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); + y_slice = Reshape(y_slice, {0, 1, 2}, {2, 2}); - auto out = builder.Dot(x_slice, y_slice); - out = builder.Reshape(out, {0, 1}, {1, 2, 2}); + auto out = Dot(x_slice, y_slice); + out = Reshape(out, {0, 1}, {1, 2, 2}); out_slices.push_back(out); } - auto out_flat = builder.ConcatInDim(out_slices, 0); - builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); + auto out_flat = ConcatInDim(&builder, out_slices, 0); + Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ ->TransferToServer(*Literal::CreateR4FromArray4D( @@ -616,9 +615,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { XlaBuilder builder(this->TestName()); auto x = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); auto y = - builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); + Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -626,7 +625,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); - auto out = builder.DotGeneral(x, y, dnums); + DotGeneral(x, y, dnums); auto x_data = this->client_ @@ -678,19 +677,21 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { XlaBuilder builder(this->TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto lhs_arg = builder.Parameter( - 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), + auto lhs_arg = Parameter( + &builder, 0, + ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), "lhs"); - auto rhs_arg = builder.Parameter( - 1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}), + auto rhs_arg = Parameter( + &builder, 1, + ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}), "rhs"); if (transpose_lhs) { - lhs_arg = builder.Transpose(lhs_arg, {1, 0}); + lhs_arg = Transpose(lhs_arg, {1, 0}); } if (transpose_rhs) { - rhs_arg = builder.Transpose(rhs_arg, {1, 0}); + rhs_arg = Transpose(rhs_arg, {1, 0}); } - auto result = builder.Dot(lhs_arg, rhs_arg); + Dot(lhs_arg, rhs_arg); Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " @@ -713,15 +714,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}})); XlaBuilder builder(this->TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), - "rhs_arg_0"); - auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), - "rhs_arg_1"); - auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}), - "rhs_arg_2"); - auto result = builder.Dot( - lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_arg_0 = Parameter( + &builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0"); + auto rhs_arg_1 = Parameter( + &builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs_arg_1"); + auto rhs_arg_2 = Parameter( + &builder, 2, ShapeUtil::MakeShape(prim_type, {1, 2}), "rhs_arg_2"); + Dot(lhs_constant, + ConcatInDim(&builder, {rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); std::unique_ptr> arg_0_value_array( new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); @@ -761,15 +762,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, {2.0f, 1.0f}})); XlaBuilder builder(this->TestName()); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2}), - "lhs_arg_0"); - auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 3}), - "lhs_arg_1"); - auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShapeWithType({2, 1}), - "lhs_arg_2"); - auto result = builder.Dot( - builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto lhs_arg_0 = Parameter( + &builder, 0, ShapeUtil::MakeShapeWithType({2, 2}), "lhs_arg_0"); + auto lhs_arg_1 = Parameter( + &builder, 1, ShapeUtil::MakeShapeWithType({2, 3}), "lhs_arg_1"); + auto lhs_arg_2 = Parameter( + &builder, 2, ShapeUtil::MakeShapeWithType({2, 1}), "lhs_arg_2"); + Dot(ConcatInDim(&builder, {lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), + rhs_constant); std::unique_ptr> arg_0_value_array( new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); @@ -811,16 +812,15 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{96.0, 105.0, 114.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -839,25 +839,23 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{105.0}, {105.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSReverseMM)))) { + + DotOfGatherOptimizationWithConstRHSReverseMM) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -870,25 +868,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{105.0, 105.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSReverseMM)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -901,25 +895,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{96.0}, {105.0}, {114.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0}, {3.0, 4.0}, @@ -937,25 +927,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{126.0, 129.0, 132.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0}, {3.0, 4.0}, @@ -973,25 +959,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{129.0}, {129.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) { std::unique_ptr> constant_lhs_array(new Array2D( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr> constant_rhs_array( @@ -1001,25 +983,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{56.0, 168.0, 91.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) { std::unique_ptr> constant_lhs_array(new Array2D( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr> constant_rhs_array( @@ -1029,19 +1007,41 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{168.0}, {168.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } + +XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { + XlaBuilder builder(TestName()); + + Array2D lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array); + + Array2D rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}}); + auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + DotGeneral(lhs_constant, rhs_constant, dot_dnums); + + Array2D expected({ + {26.f, 30.f}, + {38.f, 44.f}, + }); + + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index bfb83faf5222b8ca5ceceebf7f2f976ec803245e..f3c258a4d4c446c465320ac16ef7c72e299a51a8 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -53,9 +53,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR1Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); + void TestR1OOB() { + // Slice at dimension boundaries, but with out of bounds indices. + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7}); } template @@ -78,10 +78,10 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR2Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR2OOB() { + // Slice at dimension boundaries, but with out of bounds indices. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, - {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); } template @@ -106,11 +106,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR3Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR3OOB() { + // Slice at dimension boundaries, but with out of bounds indices. RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, - {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); + {2, 1, 2}, {{{5, 6}}, {{11, 12}}}); } template @@ -138,8 +138,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - builder.DynamicSlice(input, starts, slice_sizes); + auto input = ConstantLiteral(&builder, input_values); + DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -164,8 +164,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - builder.DynamicSlice(input, starts, slice_sizes); + auto input = ConstantLiteral(&builder, input_values); + DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -190,8 +190,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - builder.DynamicSlice(input, starts, slice_sizes); + auto input = ConstantLiteral(&builder, input_values); + DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -199,19 +199,19 @@ class DynamicSliceTest : public ClientLibraryTestBase { XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } @@ -332,17 +332,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void TestWrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestOOB() { + // // Slice at dimension boundaries, but with out of bounds indices. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, - {10, 1, 2, 3, 4, 5, 8, 9}); + {0, 1, 2, 3, 4, 8, 9, 10}); // R2 Shape: [3, 3] RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, - {{1, 2, 3}, {4, 5, 6}, {11, 8, 10}}); + {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}}); // R3 Shape: [2, 3, 2] RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, - {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}}); + {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}}); } template @@ -367,9 +367,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_value); - auto update = builder.ConstantLiteral(update_value); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_value); + auto update = ConstantLiteral(&builder, update_value); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()}); } @@ -398,9 +398,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - auto update = builder.ConstantLiteral(update_values); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_values); + auto update = ConstantLiteral(&builder, update_values); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -429,9 +429,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - auto update = builder.ConstantLiteral(update_values); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_values); + auto update = ConstantLiteral(&builder, update_values); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -460,9 +460,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - auto update = builder.ConstantLiteral(update_values); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_values); + auto update = ConstantLiteral(&builder, update_values); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -476,20 +476,19 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { Array3D input_values(kSeq, kBatch, kDim); Array3D update_values(size, kBatch, kDim); Array3D expected_values(kSeq, kBatch, kDim); + index = std::min(std::max(0, index), kSeq - size); input_values.FillIota(static_cast(0)); T value = static_cast(10); update_values.FillIota(static_cast(value)); // TODO(b/34128753) Expected values may vary depending on backend when - // the update wraps. According to documentation, the results are technically - // implementation specific where the update is out of bounds, and hence - // we don't really know what to pass into ComputeAndCompareR3. + // the indices are out of bounds. expected_values.FillIota(static_cast(0)); for (int i = 0; i < size; i++) { for (int j = 0; j < kBatch; j++) { for (int k = 0; k < kDim; k++) { - expected_values((index + i) % kSeq, j, k) = value++; + expected_values(index + i, j, k) = value++; } } } @@ -509,8 +508,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaOp update; std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); - auto starts = builder.ConstantR1({index, 0, 0}); - builder.DynamicUpdateSlice(input, update, starts); + auto starts = ConstantR1(&builder, {index, 0, 0}); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareR3(&builder, expected_values, @@ -547,12 +546,10 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32WrapBF16) { - TestWrap(); -} -XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -615,37 +612,37 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { // Tests for simple R3 case where the update is contiguous (i.e. the minor // two dimensions are not sliced). XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { - // Single element, no wrap. + // Single element, index in-bounds std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) { - // Single element, no wrap. + // Single element, index in-bounds std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { - // Multiple element, no wrap. + // Multiples element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) { - // Multiple element, no wrap. + // Multiples element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) { + // Multiple element, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) { + // Multiple element, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } @@ -701,14 +698,14 @@ void BM_DynamicSlice(int num_iters) { auto input_literal = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - auto input = builder.ConstantLiteral(*input_literal); + auto input = ConstantLiteral(&builder, *input_literal); // Create dynamic slice start indices as a parameter: shape [4] auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); auto start_indices = - builder.Parameter(0, start_indices_shape, "start_indices"); + Parameter(&builder, 0, start_indices_shape, "start_indices"); // Add DynamicSlice op to the computatation. - builder.DynamicSlice(input, start_indices, {1, 1, 1, 1}); + DynamicSlice(input, start_indices, {1, 1, 1, 1}); auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. @@ -719,8 +716,10 @@ void BM_DynamicSlice(int num_iters) { .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *start_indices_literal, buffer)); + stream.get(), *start_indices_literal, buffer)); std::unique_ptr executable = client diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index a6ba6db5d3bf86de91f6fda022c46afee01281c2..ddc6a7db18760bf951023f0a684d78739f3e869d 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -34,7 +34,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256))); XlaBuilder b(TestName() + ".add"); - b.Dot(b.Parameter(0, shape, "param_0"), b.Parameter(1, shape, "param_1")); + Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1")); TF_ASSERT_OK_AND_ASSIGN(XlaComputation dot_product, b.Build()); ExecutionProfile execution_profile; diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 0a37e4d423620122f2e109343a86a964f46d778f..74cf8b213e0a03394c84008e7a2919e1a5bf1af2 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -54,7 +54,7 @@ class ExhaustiveF32ElementwiseOpTest TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); enqueue_op(&builder, input); std::vector expected_result; @@ -79,8 +79,8 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { #endif ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { builder->Log(input); }, - std::log, known_incorrect_range); + [](XlaBuilder* builder, const XlaOp& input) { Log(input); }, std::log, + known_incorrect_range); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { @@ -95,14 +95,14 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { #endif ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { builder->Exp(input); }, - std::exp, known_incorrect_range); + [](XlaBuilder* builder, const XlaOp& input) { Exp(input); }, std::exp, + known_incorrect_range); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { builder->Tanh(input); }, - std::tanh, /*known_incorrect_range=*/{0, 0}); + [](XlaBuilder* builder, const XlaOp& input) { Tanh(input); }, std::tanh, + /*known_incorrect_range=*/{0, 0}); } std::vector> CreateExhaustiveParameters() { diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 71eb914a8e5eaef2e38b9e6e7d45b8a10ce1bd7a..30dc639f117b9871238f0bf1628502cf8bef2e0c 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -42,12 +42,12 @@ class FloorCeilTest : public ClientLibraryTestBase { LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") << "}"; XlaBuilder builder(TestName()); - auto c = builder.ConstantR1(input); + auto c = ConstantR1(&builder, input); if (f == kCeil) { - builder.Ceil(c); + Ceil(c); } else { ASSERT_EQ(kFloor, f); - builder.Floor(c); + Floor(c); } ComputeAndCompareR1(&builder, expected, /*arguments=*/{}); } @@ -55,12 +55,12 @@ class FloorCeilTest : public ClientLibraryTestBase { void TestR0F32(float input, float expected, Function f) { LOG(INFO) << "input: " << expected; XlaBuilder builder(TestName()); - auto c = builder.ConstantR0(input); + auto c = ConstantR0(&builder, input); if (f == kCeil) { - builder.Ceil(c); + Ceil(c); } else { ASSERT_EQ(kFloor, f); - builder.Floor(c); + Floor(c); } ComputeAndCompareR0(&builder, expected, /*arguments=*/{}); } diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index 73f029b59bc56aa6c3e86200a49fcae0fd177101..0254ae1baaa864b38c3b217a5c2026d34b7f7d12 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -28,11 +28,11 @@ class FmaxSimpleTest : public ClientLibraryTestBase {}; TEST_F(FmaxSimpleTest, FmaxTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); - auto y = builder.ConstantR1( - {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); - builder.Max(x, y); + auto x = ConstantR1( + &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = ConstantR1( + &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + Max(x, y); std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index e6f79b5ac55dddfbb213a36cadbee53bc9443d9d..f7f9a87413ee3cae50b3aa6518293827d40837ca 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -557,8 +558,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { *ExecuteAndTransfer(std::move(hlo_module), {}))); } -// TODO(b/64070202): Investigate failure. -XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { +XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -765,6 +765,39 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } +// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast. +XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) { + const string hlo_text = R"( +HloModule Cluster + +fusion_c { + fusion.arg = f32[2,2]{1,0} parameter(0) + bitcast.0 = f32[2,2,1]{2,1,0} bitcast(fusion.arg) + tanh.0 = f32[2,2,1]{0,2,1} tanh(bitcast.0) + ROOT bitcast.2 = f32[2,2,1]{1,2,0} bitcast(tanh.0) +} + +ENTRY main { + arg = f32[2,2]{1,0} parameter(0) + ROOT fusion = f32[2,2,1]{1,2,0} fusion(arg), kind=kLoop, calls=fusion_c +} +)"; + + std::unique_ptr operand = + Literal::CreateR2({{0., 0.}, {1., 0.}}); + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + test_runner_.Execute(std::move(module), {operand.get()}, + /*run_hlo_passes=*/false)); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), + *result)); +} + void BM_ParallelFusion(int num_iters) { // Simple element-wise computation to benchmark parallel task partitioning. tensorflow::testing::StopTiming(); @@ -793,14 +826,14 @@ void BM_ParallelFusion(int num_iters) { // Create computation. XlaBuilder builder("ParallelFusion"); Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); - auto param0 = builder.Parameter(0, shape0, "param0"); + auto param0 = Parameter(&builder, 0, shape0, "param0"); Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); - auto param1 = builder.Parameter(1, shape1, "param1"); + auto param1 = Parameter(&builder, 1, shape1, "param1"); Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1}); - auto param2 = builder.Parameter(2, shape2, "param2"); + auto param2 = Parameter(&builder, 2, shape2, "param2"); - auto x = builder.Mul(param0, param1); - auto y = builder.Add(x, param2); + auto x = Mul(param0, param1); + Add(x, param2); auto computation = builder.Build().ConsumeValueOrDie(); // Transfer literals to device. diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 4854c649c15f2ab89bd3b343abd248be6e227c60..b8404826b161b9edbbd260d73c175cce935ace91 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" // NB! TODO(b/74360564): These tests do not test out of bounds behavior since // that hasn't been specced yet. @@ -41,7 +42,7 @@ class GatherOperationTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text, config)); + ParseHloString(hlo_text, config)); EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); } }; @@ -598,14 +599,14 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); - auto operand = builder.Parameter(0, operand_shape, "operand"); - auto indices = builder.Parameter(1, indices_shape, "indices"); + auto operand = Parameter(&builder, 0, operand_shape, "operand"); + auto indices = Parameter(&builder, 1, indices_shape, "indices"); GatherDimensionNumbers dim_numbers; dim_numbers.add_output_window_dims(1); dim_numbers.add_elided_window_dims(0); dim_numbers.add_gather_dims_to_operand_dims(0); dim_numbers.set_index_vector_dim(1); - builder.Gather(operand, indices, dim_numbers, {1, 3}); + Gather(operand, indices, dim_numbers, {1, 3}); std::vector expected = {}; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr operand_arg, @@ -629,8 +630,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { client_->ExecuteParallel(computation_instances)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, client_->Transfer(*(result_data[0]))); - EXPECT_TRUE(LiteralTestUtil::Equal( - *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}))); + LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, + *result_literal); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 76bf47845ca045b4eede9a3b47ae5c2ce93ce577..fd8511884907ae500d8256c3250fe779f8eba83a 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -37,8 +37,7 @@ class HalfTestBase : public ClientLibraryTestBase { static const int kNumElements = 4; }; -using UnaryBuildFuncTy = - std::function; +using UnaryBuildFuncTy = std::function; struct UnaryOpTestParam { std::function compute_func; @@ -62,7 +61,7 @@ XLA_TEST_P(UnaryOpTest, Ops) { } UnaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd); + build_func(x_opnd); ComputeAndCompareR1(&builder, expected, {x_data.get()}, error_spec_); } @@ -79,18 +78,17 @@ half round_imp(half value) { INSTANTIATE_TEST_CASE_P( half, UnaryOpTest, ::testing::Values( - UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs}, - UnaryOpTestParam{[](half x) { return round_imp(x); }, - &XlaBuilder::Round}, - UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil}, - UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos}, - UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp}, - UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor}, - UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log}, - UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg}, - UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign}, - UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin}, - UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh} + UnaryOpTestParam{[](half x) { return abs(x); }, &Abs}, + UnaryOpTestParam{[](half x) { return round_imp(x); }, &Round}, + UnaryOpTestParam{[](half x) { return ceil(x); }, &Ceil}, + UnaryOpTestParam{[](half x) { return cos(x); }, &Cos}, + UnaryOpTestParam{[](half x) { return exp(x); }, &Exp}, + UnaryOpTestParam{[](half x) { return floor(x); }, &Floor}, + UnaryOpTestParam{[](half x) { return log(x); }, &Log}, + UnaryOpTestParam{[](half x) { return -x; }, &Neg}, + UnaryOpTestParam{[](half x) { return sign_imp(x); }, &Sign}, + UnaryOpTestParam{[](half x) { return sin(x); }, &Sin}, + UnaryOpTestParam{[](half x) { return tanh(x); }, &Tanh} )); @@ -118,19 +116,18 @@ XLA_TEST_P(UnaryPredTest, Ops) { } UnaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd); + build_func(x_opnd); ComputeAndCompareR1(&builder, expected, {x_data.get()}); } INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ - [](half x) { return isfinite(x); }, - &XlaBuilder::IsFinite})); + [](half x) { return isfinite(x); }, &IsFinite})); -using BinaryBuildFuncTy = std::function)>; +using BinaryBuildFuncTy = + std::function)>; struct BinaryOpTestParam { std::function compute_func; @@ -159,7 +156,7 @@ XLA_TEST_P(BinaryOpTest, Ops) { } BinaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd, y_opnd, {}); + build_func(x_opnd, y_opnd, {}); ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}, error_spec_); @@ -173,22 +170,15 @@ half atan2_imp(half x, half y) { INSTANTIATE_TEST_CASE_P( half, BinaryOpTest, ::testing::Values( - BinaryOpTestParam{[](half x, half y) { return x + y; }, - &XlaBuilder::Add}, + BinaryOpTestParam{[](half x, half y) { return x + y; }, &Add}, BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, - &XlaBuilder::Atan2}, - BinaryOpTestParam{[](half x, half y) { return x / y; }, - &XlaBuilder::Div}, - BinaryOpTestParam{[](half x, half y) { return max(x, y); }, - &XlaBuilder::Max}, - BinaryOpTestParam{[](half x, half y) { return min(x, y); }, - &XlaBuilder::Min}, - BinaryOpTestParam{[](half x, half y) { return x * y; }, - &XlaBuilder::Mul}, - BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, - &XlaBuilder::Pow}, - BinaryOpTestParam{[](half x, half y) { return x - y; }, - &XlaBuilder::Sub} + &Atan2}, + BinaryOpTestParam{[](half x, half y) { return x / y; }, &Div}, + BinaryOpTestParam{[](half x, half y) { return max(x, y); }, &Max}, + BinaryOpTestParam{[](half x, half y) { return min(x, y); }, &Min}, + BinaryOpTestParam{[](half x, half y) { return x * y; }, &Mul}, + BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, &Pow}, + BinaryOpTestParam{[](half x, half y) { return x - y; }, &Sub} )); @@ -221,27 +211,22 @@ XLA_TEST_P(BinaryPredTest, Ops) { } BinaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd, y_opnd, {}); + build_func(x_opnd, y_opnd, {}); ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}); } INSTANTIATE_TEST_CASE_P( half, BinaryPredTest, - ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; }, - &XlaBuilder::Eq}, - BinaryPredTestParam{[](half x, half y) { return x != y; }, - &XlaBuilder::Ne}, - BinaryPredTestParam{[](half x, half y) { return x >= y; }, - &XlaBuilder::Ge}, - BinaryPredTestParam{[](half x, half y) { return x > y; }, - &XlaBuilder::Gt}, - BinaryPredTestParam{[](half x, half y) { return x <= y; }, - &XlaBuilder::Le}, - BinaryPredTestParam{[](half x, half y) { return x < y; }, - &XlaBuilder::Lt} - - )); + ::testing::Values( + BinaryPredTestParam{[](half x, half y) { return x == y; }, &Eq}, + BinaryPredTestParam{[](half x, half y) { return x != y; }, &Ne}, + BinaryPredTestParam{[](half x, half y) { return x >= y; }, &Ge}, + BinaryPredTestParam{[](half x, half y) { return x > y; }, &Gt}, + BinaryPredTestParam{[](half x, half y) { return x <= y; }, &Le}, + BinaryPredTestParam{[](half x, half y) { return x < y; }, &Lt} + + )); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index cf971dd61b71ad329b20b0bb7c16166126562681..4d82442f7e3630c115eff1f17544e2b892c5e7eb 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -30,9 +30,9 @@ class HloMetadataTest : public LocalClientTestBase { } void BuildAddComputation(XlaBuilder* builder) { - auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder->Add(x, y); + auto x = Parameter(builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); } OpMetadata metadata_; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 12598579c7032e954c4a4875ab8e6475b112f5ae..242cc5db11ff2bdf69209df7537216573d8afbf3 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,11 +23,11 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -94,18 +94,14 @@ HloTestBase::HloTestBase(se::Platform* test_platform, /* static */ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { - HloModuleConfig config; - auto debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_max_kernel_unroll_factor(1); - config.set_debug_options(debug_options); - - return MakeUnique(name, VersionedComputationHandle(), config); + return MakeUnique(name, GetModuleConfigForTest()); } /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); return debug_options; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 9539ae06801628baedaea69024b7760ebefa6e3a..9009d67cea6840235d63724ef76d777c8f693d33 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -66,6 +66,15 @@ namespace xla { // // For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { + public: + // Creates a new HLO module for a test. The module created will have + // TestName() for its name; it will also automatically populate its debug + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. + static std::unique_ptr CreateNewModule( + const string& name = TestName()); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the @@ -80,19 +89,18 @@ class HloTestBase : public ::testing::Test { ~HloTestBase() override {} - // Creates a new HLO module for a test. The module created will have - // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. If you want a fresh HloModule object and - // then add HloComputations to it, it's recommended to use this method in your - // tests. - static std::unique_ptr CreateNewModule( - const string& name = TestName()); - // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. static DebugOptions GetDebugOptionsForTest(); + // Gets an HloModuleConfig with options appropriate for tests. + static HloModuleConfig GetModuleConfigForTest() { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return config; + } + // Executes the given module and return the result as a Literal. StatusOr> Execute( std::unique_ptr module, @@ -177,13 +185,9 @@ class HloTestBase : public ::testing::Test { // 'layout'. void ForceParameterLayout(HloModule* module, int64 param_no, const Layout& layout) { - ASSERT_LT( - param_no, - module->mutable_host_entry_computation_layout()->parameter_count()); - module->mutable_host_entry_computation_layout() - ->mutable_parameter_layout(param_no) - ->ResetLayout(layout); - module->mutable_device_entry_computation_layout() + ASSERT_LT(param_no, + module->mutable_entry_computation_layout()->parameter_count()); + module->mutable_entry_computation_layout() ->mutable_parameter_layout(param_no) ->ResetLayout(layout); } @@ -191,10 +195,7 @@ class HloTestBase : public ::testing::Test { // Convenience method to force the layout of the computation result in a // module. The result layout of 'module' is set to 'layout'. void ForceResultLayout(HloModule* module, const Layout& layout) { - module->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->ResetLayout(layout); - module->mutable_device_entry_computation_layout() + module->mutable_entry_computation_layout() ->mutable_result_layout() ->ResetLayout(layout); } @@ -202,10 +203,7 @@ class HloTestBase : public ::testing::Test { // Convenience method to clear the layout of the computation result in // 'module'. void ForceClearResultLayout(HloModule* module) { - module->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->Clear(); - module->mutable_device_entry_computation_layout() + module->mutable_entry_computation_layout() ->mutable_result_layout() ->Clear(); } diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index da4cf4ae0c31bc194cd2ec9b845df36afbde69b0..ad1f5b9eed8b5b140100c1fa35dc7d698e3db48b 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -41,14 +41,17 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - VerifyModule(); + VerifyModule(module_.get()); + } + for (int i = 0; i < modules_.size(); ++i) { + VerifyModule(modules_.at(i).get()); } HloTestBase::TearDown(); } -void HloVerifiedTestBase::VerifyModule() { - HloVerifier verifier; - xla::StatusOr mutated = verifier.Run(module_.get()); +void HloVerifiedTestBase::VerifyModule(HloModule* module) { + HloVerifier verifier(/*allow_mixed_precision=*/true); + xla::StatusOr mutated = verifier.Run(module); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { @@ -59,15 +62,20 @@ void HloVerifiedTestBase::VerifyModule() { HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = CreateNewModule(); + module_ = HloTestBase::CreateNewModule(); } return *module_; } -void HloVerifiedTestBase::ParseAndVerifyModule( - tensorflow::StringPiece hlo_text) { +HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { + modules_.emplace_back(HloTestBase::CreateNewModule()); + return modules_.back().get(); +} + +void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text)); - VerifyModule(); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); + VerifyModule(module_.get()); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index e5bb14a8839acbdef8fd2b79bb0f574c46ea3d40..5b28c01c369fa1ae1c7941f5c8139882c4dbed08 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -44,7 +44,8 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); - void ParseAndVerifyModule(tensorflow::StringPiece hlo_text); + void ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + const HloModuleConfig& config = HloModuleConfig()); // Sets the shape-size function used during hlo verification. If this isn't // called, a default ShapeVerifier is used instead. @@ -52,11 +53,23 @@ class HloVerifiedTestBase : public HloTestBase { shape_verifier_ = std::move(shape_verifier); } + // Creates a new module for a test, and stores it in modules_ so it can be + // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent + // creation of unverified modules. + HloModule* CreateNewModule(const string& name = TestName()); + + // It is confusing to store modules created by module() and CreateNewModule() + // in different fields, but it allows us to migrate tests to + // HloVerifiedTestBase more easily, so it's a win because we can verify more + // modules. See b/80488902. private: - std::unique_ptr module_; // Lazily populated. Access via module(). + // Lazily populated. Access via module(). + std::unique_ptr module_; + // Populated by calls to CreateNewModule. + std::vector> modules_; std::unique_ptr shape_verifier_; bool tear_down_called_ = false; - void VerifyModule(); + static void VerifyModule(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 2f46ee0be216d7dabf1c476d3cfb7d528f8ab6a4..082bc34136e004795ce300c66591758f47c665fe 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -124,8 +124,7 @@ class LLVMCompilerTest : public ::testing::Test { static std::unique_ptr CreateNewModule() { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return MakeUnique(TestName(), VersionedComputationHandle(), - config); + return MakeUnique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index f21f83992ffb7c07dff31c68a7e9e3f7944bf512..9191be9fd905ab2e0c661042b042c8233d39e4a1 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -38,9 +38,9 @@ class LocalClientAllocationTest : public LocalClientTestBase { XLA_TEST_F(LocalClientAllocationTest, AddVectors) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0.0f, 1.0f, 2.0f}); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = ConstantR1(&builder, {0.0f, 1.0f, 2.0f}); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); @@ -74,9 +74,9 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { // Run a computation on every device on the system. Verify that allocation // occurs on the proper device. XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0.0f, 1.0f, 2.0f}); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = ConstantR1(&builder, {0.0f, 1.0f, 2.0f}); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); auto computation = builder.Build().ConsumeValueOrDie(); TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index a366afe8262e1f537b225e395bba9cb2fc22683a..70612e7c49d2815096cc54fd6ae796148249b4db 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -37,8 +37,8 @@ using xla::string; xla::XlaComputation Doubler() { xla::XlaBuilder builder("doubler"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); - auto x = builder.Parameter(0, r0f32, "x"); - builder.Mul(x, builder.ConstantR0(2.0)); + auto x = xla::Parameter(&builder, 0, r0f32, "x"); + xla::Mul(x, xla::ConstantR0(&builder, 2.0)); return std::move(builder.Build().ValueOrDie()); } @@ -51,10 +51,10 @@ int main(int argc, char** argv) { xla::XlaBuilder builder("aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); - auto opaque_param = builder.Parameter(0, opaque_shape, "x"); + auto opaque_param = Parameter(&builder, 0, opaque_shape, "x"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); - auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32); - builder.Call(Doubler(), {sum}); + auto sum = CustomCall(&builder, "SumStructElements", {opaque_param}, r0f32); + Call(&builder, Doubler(), {sum}); if (argc != 2) { LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU"; diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 96858c00d6bbe59b673a34e7d5ca261756709596..2c6393794ef1b1558f5e651b5cb7bfa2afa961de 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -54,7 +54,7 @@ class LocalClientExecuteTest : public LocalClientTestBase { XLA_TEST_F(LocalClientExecuteTest, Constant) { XlaBuilder builder(TestName()); - auto y = builder.ConstantR0(123.0f); + ConstantR0(&builder, 123.0f); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); @@ -64,9 +64,9 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) { XLA_TEST_F(LocalClientExecuteTest, AddScalars) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.ConstantR0(123.0f); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = ConstantR0(&builder, 123.0f); + Add(x, y); auto x_value = LiteralToShapedBuffer(*Literal::CreateR0(42.0f)); ScopedShapedBuffer result = @@ -77,9 +77,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "x"); - auto y = builder.ConstantR1({}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "x"); + auto y = ConstantR1(&builder, {}); + Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({})); ScopedShapedBuffer result = @@ -90,9 +90,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { XLA_TEST_F(LocalClientExecuteTest, AddVectors) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); @@ -104,9 +104,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); @@ -122,9 +122,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Add(x, y); auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. @@ -155,9 +155,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Add(x, y); auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( @@ -192,9 +192,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { XLA_TEST_F(LocalClientExecuteTest, TupleResult) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Tuple({x, y, x}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Tuple(&builder, {x, y, x}); auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( @@ -209,21 +209,20 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - auto inner_tuple = builder.Tuple({x, y, x}); - builder.Tuple({inner_tuple, x}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + auto inner_tuple = Tuple(&builder, {x, y, x}); + Tuple(&builder, {inner_tuple, x}); auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( @@ -238,25 +237,22 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 0})); - LiteralTestUtil::ExpectR2Equal( - {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {0, 1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 2})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0, 0})); + LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralSlice(*result_literal, {0, 1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { // Verify setting the result layout of a computation with a tuple output. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Tuple({x, y}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Tuple(&builder, {x, y}); auto array = LiteralToShapedBuffer( *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); @@ -273,10 +269,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { options, DefaultExecutableRunOptions()); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -291,15 +287,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { // Computation adds the respective array and vector elements from each tuple // argument and returns the results as a tuple. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, tuple_shape0, "x"); - auto y = builder.Parameter(1, tuple_shape1, "y"); - auto x_0 = builder.GetTupleElement(x, 0); - auto x_1 = builder.GetTupleElement(x, 1); - auto y_0 = builder.GetTupleElement(y, 0); - auto y_1 = builder.GetTupleElement(y, 1); - auto array_sum = builder.Add(x_0, y_1); - auto vector_diff = builder.Sub(x_1, y_0); - builder.Tuple({array_sum, vector_diff}); + auto x = Parameter(&builder, 0, tuple_shape0, "x"); + auto y = Parameter(&builder, 1, tuple_shape1, "y"); + auto x_0 = GetTupleElement(x, 0); + auto x_1 = GetTupleElement(x, 1); + auto y_0 = GetTupleElement(y, 0); + auto y_1 = GetTupleElement(y, 1); + auto array_sum = Add(x_0, y_1); + auto vector_diff = Sub(x_1, y_0); + Tuple(&builder, {array_sum, vector_diff}); auto computation = builder.Build().ConsumeValueOrDie(); auto x_literal = Literal::MakeTuple( @@ -319,11 +315,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -338,15 +333,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { // Computation negates the array element and sums the two vector elements in // the nested tuple. The resulting array and vector are returned as a tuple. XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, nested_tuple_shape, "param"); - auto inner_tuple = builder.GetTupleElement(param, 0); - auto inner_array = builder.GetTupleElement(inner_tuple, 0); - auto inner_vector = builder.GetTupleElement(inner_tuple, 1); - auto outer_vector = builder.GetTupleElement(param, 1); - - auto negate_array = builder.Neg(inner_array); - auto vector_sum = builder.Add(inner_vector, outer_vector); - builder.Tuple({negate_array, vector_sum}); + auto param = Parameter(&builder, 0, nested_tuple_shape, "param"); + auto inner_tuple = GetTupleElement(param, 0); + auto inner_array = GetTupleElement(inner_tuple, 0); + auto inner_vector = GetTupleElement(inner_tuple, 1); + auto outer_vector = GetTupleElement(param, 1); + + auto negate_array = Neg(inner_array); + auto vector_sum = Add(inner_vector, outer_vector); + Tuple(&builder, {negate_array, vector_sum}); auto computation = builder.Build().ConsumeValueOrDie(); auto arg_literal = Literal::MakeTuple( @@ -360,10 +355,10 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -376,10 +371,10 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { ShapeUtil::MakeTupleShape({array_shape, array_shape}); XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, tuple_shape, "param"); - auto element_0 = builder.GetTupleElement(param, 0); - auto element_1 = builder.GetTupleElement(param, 1); - builder.Tuple({builder.Neg(element_0), builder.Add(element_1, element_1)}); + auto param = Parameter(&builder, 0, tuple_shape, "param"); + auto element_0 = GetTupleElement(param, 0); + auto element_1 = GetTupleElement(param, 1); + Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); auto computation = builder.Build().ConsumeValueOrDie(); auto arg_literal = Literal::MakeTuple( @@ -389,18 +384,17 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); - LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralSlice(*result_0_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, + LiteralSlice(*result_0_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, + LiteralSlice(*result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); - LiteralTestUtil::ExpectR2Equal( - {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, + LiteralSlice(*result_1_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, + LiteralSlice(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -420,16 +414,15 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, tuple_shape, "param"); + auto param = Parameter(&builder, 0, tuple_shape, "param"); // Add each element's tuple index value to every element. std::vector result_elements; for (int i = 0; i < kElementCount; ++i) { - auto element = builder.GetTupleElement(param, i); - result_elements.push_back( - builder.Add(element, builder.ConstantR0(i))); + auto element = GetTupleElement(param, i); + result_elements.push_back(Add(element, ConstantR0(&builder, i))); } - builder.Tuple(result_elements); + Tuple(&builder, result_elements); auto computation = builder.Build().ConsumeValueOrDie(); // Feed in a tuple where each two-element vector element is {tuple_index, @@ -447,8 +440,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), - error_spec_); + {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); } } @@ -465,22 +457,22 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes); XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, tuple_shape, "param"); + auto param = Parameter(&builder, 0, tuple_shape, "param"); // The computation increments each leaf value by an amount equal to the leaf's // ordinal position in a traversal of the tuple. std::vector result_elements; for (int i = 0; i < kFanout; ++i) { - auto outer_element = builder.GetTupleElement(param, i); + auto outer_element = GetTupleElement(param, i); std::vector inner_result_elements; for (int j = 0; j < kFanout; ++j) { - auto inner_element = builder.GetTupleElement(outer_element, j); - inner_result_elements.push_back(builder.Add( - inner_element, builder.ConstantR0(i * kFanout + j))); + auto inner_element = GetTupleElement(outer_element, j); + inner_result_elements.push_back( + Add(inner_element, ConstantR0(&builder, i * kFanout + j))); } - result_elements.push_back(builder.Tuple(inner_result_elements)); + result_elements.push_back(Tuple(&builder, inner_result_elements)); } - builder.Tuple(result_elements); + Tuple(&builder, result_elements); auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. @@ -520,14 +512,14 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { } XlaBuilder builder(TestName()); - auto element = builder.Parameter(0, shape, "param"); + auto element = Parameter(&builder, 0, shape, "param"); for (int i = 0; i < kTupleDepth; ++i) { - element = builder.GetTupleElement(element, 0); + element = GetTupleElement(element, 0); } - auto output = builder.Add(element, builder.ConstantR0(42.0)); + auto output = Add(element, ConstantR0(&builder, 42.0)); for (int i = 0; i < kTupleDepth; ++i) { - output = builder.Tuple({output}); + output = Tuple(&builder, {output}); } auto computation = builder.Build().ConsumeValueOrDie(); @@ -547,16 +539,16 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } - LiteralTestUtil::ExpectR0Equal( - 165.0, LiteralSlice(*result_literal, index)); + LiteralTestUtil::ExpectR0Equal(165.0, + LiteralSlice(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { // Test passing in an invalid number of arguments. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {3}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {3}), "y"); + Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({1.0f, 2.0f, 3.0f})); @@ -571,8 +563,8 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { // Test passing in an argument with the wrong shape. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - builder.Neg(x); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + Neg(x); auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); @@ -588,8 +580,8 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { // Test passing in an invalid result layout parameter. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - builder.Neg(x); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + Neg(x); auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); @@ -611,7 +603,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { // Try to run a trivial computation on every device on the system. If a // specific device is not supported, check that the right error is returned. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto computation = builder.Build().ConsumeValueOrDie(); for (int d = 0; d < local_client_->device_count(); ++d) { if (!local_client_->device_ordinal_supported(d)) { @@ -638,7 +630,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { // Try running computations on devices with device ordinal values which do not // exist. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto computation = builder.Build().ConsumeValueOrDie(); auto execute_status = @@ -655,7 +647,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { // Run a computation on a specific stream on each device on the system. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto computation = builder.Build().ConsumeValueOrDie(); for (int d = 0; d < local_client_->device_count(); ++d) { @@ -691,7 +683,7 @@ XLA_TEST_F(LocalClientExecuteTest, wrong_stream.Init(); XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions().set_stream(&wrong_stream)); @@ -708,7 +700,7 @@ XLA_TEST_F(LocalClientExecuteTest, TestAllocator allocator(wrong_platform); XlaBuilder builder(TestName()); - auto y = builder.ConstantR0(123.0f); + ConstantR0(&builder, 123.0f); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), @@ -721,7 +713,7 @@ XLA_TEST_F(LocalClientExecuteTest, XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) { // Try to run a computation on a stream that has not been initialized. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); LOG(INFO) << "default device = " << local_client_->default_device_ordinal(); se::StreamExecutor* executor = @@ -744,26 +736,26 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); + Select(ConstantR0(&builder, false), tuple12, tuple21); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR1Equal( - {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1})); + LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, + LiteralSlice(*tuple_literal, {0})); + LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, + LiteralSlice(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); Shape argument_layout = ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0}); @@ -779,6 +771,10 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); + ASSERT_IS_OK(local_client_->mutable_backend() + ->BorrowStream(0) + .ValueOrDie() + ->BlockHostUntilDone()); LiteralTestUtil::ExpectR1Near( {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); @@ -848,15 +844,40 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { Literal::CreateR0(123456789000LL).get()})); } +XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { + XlaBuilder builder(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {3}); + auto in = Infeed(&builder, shape); + auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); + Add(in, constant); + + std::unique_ptr result; + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + result = ShapedBufferToLiteral(ExecuteLocallyOrDie( + builder.Build().ValueOrDie(), /*arguments=*/{})); + })); + + ASSERT_IS_OK(local_client_->TransferToInfeedLocal( + *Literal::CreateR1({-5.0, 123.0, 42.0}), + local_client_->default_device_ordinal())); + + // Join the thread. + thread.reset(); + + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); +} + // TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel. // 2017-10-18. XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); - auto in = builder.Infeed(shape); - auto constant = builder.ConstantR1({1.0f, 2.0f, 3.0f}); - auto sum = builder.Add(in, constant); - builder.Outfeed(sum, shape, /*outfeed_config=*/""); + auto in = Infeed(&builder, shape); + auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); + auto sum = Add(in, constant); + Outfeed(sum, shape, /*outfeed_config=*/""); std::unique_ptr thread( tensorflow::Env::Default()->StartThread( @@ -891,8 +912,8 @@ void BM_LocalClientOverhead(int num_iters) { // Use a tiny add operation as the computation. XlaBuilder builder("Add"); auto shape = ShapeUtil::MakeShape(F32, {2, 3}); - auto x = builder.Parameter(0, shape, "x"); - builder.Add(x, x); + auto x = Parameter(&builder, 0, shape, "x"); + Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); auto buffer = @@ -900,8 +921,10 @@ void BM_LocalClientOverhead(int num_iters) { ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *literal, buffer)); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, + buffer)); const int kWarmups = 2; @@ -911,11 +934,8 @@ void BM_LocalClientOverhead(int num_iters) { std::unique_ptr executable = executable_status.ConsumeValueOrDie(); - se::Stream stream(executors[client->default_device_ordinal()]); - stream.Init(); - ExecutableRunOptions run_options; - run_options.set_allocator(&allocator).set_stream(&stream); + run_options.set_allocator(&allocator).set_stream(stream.get()); for (int i = 0; i < kWarmups; ++i) { auto result = executable->Run({&buffer}, run_options); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 88797a7d0a7d0567b3a380c5fb1ad0c0ee875587..c31ba0e713a45d18b60bfdb9a47545cf34220333 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -189,7 +189,19 @@ StatusOr LocalClientTestBase::ExecuteLocally( TF_ASSIGN_OR_RETURN( std::unique_ptr executable, local_client_->Compile(computation, argument_layouts, build_options)); - return executable->Run(arguments, run_options); + TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options)); + + auto device_ordinal = + build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal(); + auto* stream = run_options.stream(); + if (!stream) { + stream = local_client_->mutable_backend() + ->BorrowStream(device_ordinal) + .ValueOrDie() + .get(); + } + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + return std::move(ret); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index c0c02e584c2348f64a9d7d0800038f5ca67a2171..cdf70ee4185be2ecd9dcb2d21fbd98c2ab6cc0ad 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -30,8 +30,8 @@ class LogTest : public ClientLibraryTestBase {}; XLA_TEST_F(LogTest, LogZeroValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR3FromArray3D(Array3D(3, 0, 0)); - builder.Log(x); + auto x = ConstantR3FromArray3D(&builder, Array3D(3, 0, 0)); + Log(x); ComputeAndCompareR3(&builder, Array3D(3, 0, 0), {}, ErrorSpec(0.0001)); @@ -42,8 +42,8 @@ TEST_F(LogTest, LogTenValues) { 5.0, 6.0, -7.0, -8.0, 9.0}; XlaBuilder builder(TestName()); - auto x = builder.ConstantR1(input); - builder.Log(x); + auto x = ConstantR1(&builder, input); + Log(x); std::vector expected; expected.reserve(input.size()); diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 7df45bebebdd3eb2e71f27d831a8e2ac9e3b5f7c..1b3bc9d5040e1382f534e00ea2679ebbd48ceb59 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -52,9 +52,9 @@ class MapTest : public ClientLibraryTestBase { // 1.0f ---------/ XlaComputation CreateAdderToOne() { XlaBuilder mapped_builder(TestName()); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto one = mapped_builder.ConstantR0(1.0); - mapped_builder.Add(x, one); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = ConstantR0(&mapped_builder, 1.0); + Add(x, one); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -62,9 +62,9 @@ class MapTest : public ClientLibraryTestBase { XlaComputation CreateMax() { XlaBuilder b(TestName()); - auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - b.Max(lhs, rhs); + auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Max(lhs, rhs); auto computation_status = b.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -75,8 +75,8 @@ class MapTest : public ClientLibraryTestBase { template XlaComputation CreateScalarOne() { XlaBuilder mapped_builder("scalar_one"); - (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - mapped_builder.ConstantR0(1); + (void)Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + ConstantR0(&mapped_builder, 1); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -89,9 +89,9 @@ class MapTest : public ClientLibraryTestBase { // 2.0f ---------/ XlaComputation CreateMulByTwo() { XlaBuilder mapped_builder(TestName()); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto two = mapped_builder.ConstantR0(2.0); - mapped_builder.Mul(x, two); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto two = ConstantR0(&mapped_builder, 2.0); + Mul(x, two); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -107,10 +107,10 @@ class MapTest : public ClientLibraryTestBase { // 1.0f ---------/ XlaComputation CreateAdderToOneTimesItself() { XlaBuilder mapped_builder(TestName()); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto one = mapped_builder.ConstantR0(1.0); - auto adder_to_one = mapped_builder.Add(x, one); - mapped_builder.Mul(x, adder_to_one); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = ConstantR0(&mapped_builder, 1.0); + auto adder_to_one = Add(x, one); + Mul(x, adder_to_one); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -125,10 +125,10 @@ class MapTest : public ClientLibraryTestBase { XlaComputation CreateMapPlusN(const XlaComputation& embedded_computation, float n) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto map = builder.Map({x}, embedded_computation, {}); - auto constant_n = builder.ConstantR0(n); - builder.Add(map, constant_n); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto map = Map(&builder, {x}, embedded_computation, {}); + auto constant_n = ConstantR0(&builder, n); + Add(map, constant_n); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -138,9 +138,9 @@ class MapTest : public ClientLibraryTestBase { // defined by (x, y) -> x > y. XlaComputation CreateGt() { XlaBuilder b("Gt"); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - b.Gt(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Gt(x, y); auto computation_status = b.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -155,11 +155,11 @@ class MapTest : public ClientLibraryTestBase { // z {R0F32} ---------------/ XlaComputation CreateTernaryAdder() { XlaBuilder mapped_builder("TernaryAdder"); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z"); - auto xy = mapped_builder.Add(x, y); - mapped_builder.Add(xy, z); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&mapped_builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + auto z = Parameter(&mapped_builder, 2, ShapeUtil::MakeShape(F32, {}), "z"); + auto xy = Add(x, y); + Add(xy, z); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -173,8 +173,8 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, ErrorSpec(0.01f)); @@ -187,8 +187,8 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -202,8 +202,8 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -216,8 +216,8 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateScalarOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } @@ -229,8 +229,8 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateScalarOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } @@ -243,8 +243,8 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOneTimesItself(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f}, @@ -259,9 +259,9 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); - builder.Map({map1}, CreateMulByTwo(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); + Map(&builder, {map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -276,9 +276,9 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); - builder.Map({map1}, CreateMulByTwo(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); + Map(&builder, {map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {6.4f, 8.6f, 10.8f, 13.0f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -292,8 +292,8 @@ TEST_F(MapTest, MapEachElemPlusOneR2) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {0, 1}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}}); @@ -319,10 +319,10 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { auto embed3 = CreateMapPlusN(embed1, 4.0); XlaBuilder embed4_builder("embed4"); - auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x"); - auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {}); - auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {}); - embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); + auto embed4_param = Parameter(&embed4_builder, 0, scalar_shape, "x"); + auto embed4_map_lhs = Map(&embed4_builder, {embed4_param}, embed2, {}); + auto embed4_map_rhs = Map(&embed4_builder, {embed4_param}, embed3, {}); + Add(embed4_map_lhs, embed4_map_rhs); auto embed4_status = embed4_builder.Build(); ASSERT_IS_OK(embed4_status.status()); auto embed4 = embed4_status.ConsumeValueOrDie(); @@ -330,11 +330,11 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { auto embed5 = CreateMapPlusN(embed2, 6.0); XlaBuilder builder(TestName()); - auto constant_42 = builder.ConstantR0(42.0); - auto constant_7 = builder.ConstantR0(7.0); - auto map_42 = builder.Map({constant_42}, embed5, {}); - auto map_7 = builder.Map({constant_7}, embed4, {}); - builder.Add(map_42, map_7); + auto constant_42 = ConstantR0(&builder, 42.0); + auto constant_7 = ConstantR0(&builder, 7.0); + auto map_42 = Map(&builder, {constant_42}, embed5, {}); + auto map_7 = Map(&builder, {constant_7}, embed4, {}); + Add(map_42, map_7); ComputeAndCompareR0(&builder, 73.0, {}, ErrorSpec(0.01f)); } @@ -351,9 +351,10 @@ TEST_F(MapTest, MapBinaryAdder) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder), {0}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder), + {0}); ComputeAndCompareR1(&builder, {7.3f, 7.7, 4.3f, 0}, {param0_data.get(), param1_data.get()}, @@ -374,10 +375,10 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder), - {0, 1}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), + {0, 1}); Array2D expected(2, 2); expected(0, 0) = 11; @@ -400,10 +401,10 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder), - {0, 1, 2}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), + {0, 1, 2}); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {param0_data.get(), param1_data.get()}); @@ -425,10 +426,10 @@ TEST_F(MapTest, MapTernaryAdder) { std::unique_ptr param2_data = client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto param2 = builder.Parameter(2, param2_literal->shape(), "param2"); - builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2"); + Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( &builder, {-2.7f, -92.3f, -895.7f, -400.0f}, @@ -440,7 +441,8 @@ TEST_F(MapTest, MapGt) { // Maps (x,y) -> x > y onto two R1F32 vectors. XlaBuilder b(TestName()); auto gt = CreateGt(); - b.Map({b.ConstantR1({1, 20}), b.ConstantR1({10, 2})}, gt, {0}); + Map(&b, {ConstantR1(&b, {1, 20}), ConstantR1(&b, {10, 2})}, gt, + {0}); ComputeAndCompareR1(&b, {false, true}, {}); } @@ -449,15 +451,15 @@ TEST_F(MapTest, NestedBinaryMap) { { // max_with_square(x) = do max(x, x^2) via a map. XlaBuilder b("max_with_square"); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - b.Map({x, b.Mul(x, x)}, CreateMax(), {}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + Map(&b, {x, Mul(x, x)}, CreateMax(), {}); auto computation_status = b.Build(); ASSERT_IS_OK(computation_status.status()); max_with_square = computation_status.ConsumeValueOrDie(); } XlaBuilder b(TestName()); - auto input = b.ConstantR1({0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); - b.Map({input}, max_with_square, {0}); + auto input = ConstantR1(&b, {0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); + Map(&b, {input}, max_with_square, {0}); ComputeAndCompareR1(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {}); } @@ -468,9 +470,9 @@ TEST_F(MapTest, MapOperantionWithBuildError) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("ErrorAdd"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y"); - sub_builder->Add(x, y); + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(U16, {}), "y"); + Add(x, y); auto error_add = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = @@ -482,16 +484,15 @@ TEST_F(MapTest, MapOperantionWithBuildError) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, error_add, {0}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, error_add, {0}); StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); - EXPECT_THAT( - computation_status.status().ToString(), - ::testing::HasSubstr("error from: ErrorAdd: Binary op BINOP_ADD with " - "different element types: f32[] and u16[]")); + EXPECT_THAT(computation_status.status().ToString(), + ::testing::HasSubstr("error from: ErrorAdd: Binary op add with " + "different element types: f32[] and u16[]")); } // MapTest disables inline and algsimp. MapTestWithFullOpt runs all @@ -507,9 +508,9 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - sub_builder->Pow(x, y); + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y"); + Pow(x, y); auto power = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = Literal::CreateR0(2.0f); @@ -519,9 +520,9 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, power, {}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, power, {}); ComputeAndCompareR0(&builder, 32.0f, {param0_data.get(), param1_data.get()}, @@ -534,9 +535,9 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - sub_builder->Sub(y, x); // note that this is y - x, not x - y + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y"); + Sub(y, x); // note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = Literal::CreateR0(2.0f); @@ -546,9 +547,9 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, sub_opposite, {}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, sub_opposite, {}); ComputeAndCompareR0( &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f)); @@ -560,16 +561,16 @@ TEST_F(MapTestWithFullOpt, MapSquare) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - sub_builder->Mul(x, x); + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + Mul(x, x); auto square = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = Literal::CreateR0(10.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param0}, square, {}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param0}, square, {}); ComputeAndCompareR0(&builder, 100.0f, {param0_data.get()}, ErrorSpec(0.01f)); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 27fd36e06acdc589f3a84ad561164e4a33b93506..17b1807f44a457786906afc15d8d410f6cf2d4cd 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -56,11 +56,11 @@ TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32); XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { using T = TypeParam; XlaBuilder builder("exp_2x2"); - auto data = builder.ConstantR2FromArray2D({ - {1.0f, 0.0f}, // row 0 - {-1.0f, 0.5f}, // row 1 - }); - builder.Exp(data); + auto data = ConstantR2FromArray2D(&builder, { + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 + }); + Exp(data); std::unique_ptr expected = Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 @@ -76,20 +76,20 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { // add_half(x) = x + 0.5 XlaBuilder builder("add_half"); auto x_value = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({}), "x_value"); - auto half = builder.ConstantR0(static_cast(0.5)); - builder.Add(x_value, half); + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({}), "x_value"); + auto half = ConstantR0(&builder, static_cast(0.5)); + Add(x_value, half); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); add_half = computation_status.ConsumeValueOrDie(); } XlaBuilder builder("map_2x2"); - auto data = builder.ConstantR2FromArray2D({ - {1.0f, 0.0f}, // row 0 - {-1.0f, 0.5f}, // row 1 - }); - auto map = builder.Map({data}, add_half, {0, 1}); + auto data = ConstantR2FromArray2D(&builder, { + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 + }); + Map(&builder, {data}, add_half, {0, 1}); std::unique_ptr expected = Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 @@ -100,15 +100,15 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { using T = TypeParam; XlaBuilder builder("max_2x2"); - auto lhs = builder.ConstantR2FromArray2D({ - {7.0f, 2.0f}, // row 0 - {3.0f, -4.0f}, // row 1 - }); - auto rhs = builder.ConstantR2FromArray2D({ - {5.0f, 6.0f}, // row 0 - {1.0f, -8.0f}, // row 1 - }); - auto max = builder.Max(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, { + {7.0f, 2.0f}, // row 0 + {3.0f, -4.0f}, // row 1 + }); + auto rhs = ConstantR2FromArray2D(&builder, { + {5.0f, 6.0f}, // row 0 + {1.0f, -8.0f}, // row 1 + }); + Max(lhs, rhs); std::unique_ptr expected = Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 @@ -137,9 +137,9 @@ class TestLinspaceMaxParametric XlaBuilder builder( tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - auto max = builder.Max(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Max(lhs, rhs); Array2D expected(rows, cols); for (int row = 0; row < rows; ++row) { @@ -208,23 +208,23 @@ class MatOpsDotAddTest rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); XlaBuilder builder(TestName()); - auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); + auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs"); auto lhs_mat_arg = lhs_arg; if (transpose) { - lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); + lhs_mat_arg = Transpose(lhs_mat_arg, {1, 0}); } - auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); - auto result = builder.Dot(lhs_mat_arg, rhs_arg); + auto rhs_arg = Parameter(&builder, 1, rhs_shape, "rhs"); + auto result = Dot(lhs_mat_arg, rhs_arg); Array2D expected; if (add_lhs) { - result = builder.Add(result, lhs_arg); + result = Add(result, lhs_arg); if (transpose) { expected = Array2D({{47.0f, 52.0f}, {71.0f, 78.0f}}); } else { expected = Array2D({{35.0f, 39.0f}, {81.0f, 89.0f}}); } } else { - result = builder.Add(result, rhs_arg); + result = Add(result, rhs_arg); if (transpose) { expected = Array2D({{56.0f, 61.0f}, {80.0f, 87.0f}}); } else { diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index 0791a71aacf7614286fe964623a3172a174d4722..e576f000ef23e761d6fa818457eec2144d4bcb00 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -33,9 +33,10 @@ class SliceTest : public ClientLibraryTestBase {}; XLA_TEST_F(SliceTest, Slice2D) { XlaBuilder builder("slice_2d"); - auto original = builder.ConstantR2( + auto original = ConstantR2( + &builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}}); - builder.Slice(original, {2, 1}, {4, 3}, {1, 1}); + Slice(original, {2, 1}, {4, 3}, {1, 1}); Array2D expected({{8.0f, 9.0f}, {11.0f, 12.0f}}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); @@ -45,8 +46,8 @@ XLA_TEST_F(SliceTest, Slice3D) { XlaBuilder builder("slice_3d"); Array3D array_3d( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}); - auto original = builder.ConstantR3FromArray3D(array_3d); - builder.Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, array_3d); + Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1}); Array3D expected_3d({{{2.0f}}, {{6.0f}}}); ComputeAndCompareR3(&builder, expected_3d, {}, ErrorSpec(0.000001)); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 39f9bbaa92592e94352f0e6b9d4534a39d65c6f9..6597748c8d1f45391799dbe384a5afc0284de2dd 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -204,10 +204,10 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { Literal::CreateR0(1.0)), Literal::MakeTupleOwned(Literal::CreateR0(3.0), Literal::CreateR0(4))); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); + *Literal::MakeTupleOwned(Literal::CreateR0(42)), *result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -215,27 +215,27 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { HloModule m fused_computation { - p = f32[] parameter(0) - multiply = f32[] multiply(p, p) - less-than = pred[] less-than(p, multiply) - ROOT tuple = (pred[], f32[]) tuple(less-than, multiply) + p = f32[4] parameter(0) + multiply = f32[4] multiply(p, p) + less-than = pred[4] less-than(p, multiply) + ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply) } ENTRY PredFloatMOF { - p0 = f32[] parameter(0) - fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation - gte0 = pred[] get-tuple-element(fusion), index=0 - gte1 = f32[] get-tuple-element(fusion), index=1 - const = f32[] constant(0) - ROOT select = f32[] select(gte0, gte1, const) + p0 = f32[4] parameter(0) + fusion = (pred[4], f32[4]) fusion(p0), kind=kLoop, calls=fused_computation + gte0 = pred[4] get-tuple-element(fusion), index=0 + gte1 = f32[4] get-tuple-element(fusion), index=1 + const = f32[4] constant({0, 0, 0, 0}) + ROOT select = f32[4] select(gte0, gte1, const) })"; auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR0(2.0); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateR0(4.0))); + auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { @@ -266,10 +266,291 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR1({1.0, 2.0, 3.0}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); +} + +const char* const kScalarOps = R"( + HloModule m + + Add { + lhsadd = f32[] parameter(0) + rhsadd = f32[] parameter(1) + ROOT add = f32[] add(lhsadd, rhsadd) + } + + Max { + lhsmax = f32[] parameter(0) + rhsmax = f32[] parameter(1) + ROOT max = f32[] maximum(lhsmax, rhsmax) + } +)"; + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned(Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned(Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR2({{25, 36}, {49, 64}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(1.17549e-38) + r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max + r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add + ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), + Literal::CreateR1({36, 64}), + Literal::CreateR1({66, 138})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) + tuple(p0, r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned( + Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), + Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) + tuple(r1, mul, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned( + Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + Literal::CreateR2({{25, 36}, {49, 64}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={} + mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1) + ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) + tuple(r1, mul, mul2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned( + Literal::CreateR1({14, 22}), + Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + Literal::CreateR3( + {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + init1 = f32[] parameter(1) + init2 = f32[] parameter(2) + r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add + r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + i = f32[] parameter(1) + j = f32[] parameter(2) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto init1 = Literal::CreateR0(5); + auto init2 = Literal::CreateR0(6); + std::unique_ptr result = ExecuteNoHloPasses( + std::move(module), {param.get(), init1.get(), init2.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned( + Literal::CreateR2({{167, 172}, {176, 180}}), + Literal::CreateR2({{6, 6}, {6, 8}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) { + p0 = f16[2,2,2]{2,1,0} parameter(0) + convert = f32[2,2,2]{2,1,0} convert(p0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(convert, convert) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) + tuple(r1, r2, p0) + } + + ENTRY reduce { + p = f16[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3( + {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::CreateR1({0.0, 4.0, 9.0}))); + *Literal::MakeTupleOwned( + Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}}), + Literal::CreateR3({{{Eigen::half(1), Eigen::half(2)}, + {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, + {Eigen::half(7), Eigen::half(8)}}})), + *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index ce295b832d79e4f00656f2893c2ba1162693dd73..2e5081bbcb64ea9416c5a9731dba43891ecceedf 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - b.Pad(AddParam(*Literal::CreateR1({}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*Literal::CreateR1({}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } @@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - b.Pad(AddParam(*Literal::CreateR1({}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*Literal::CreateR1({}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, DefaultErrorSpec()); } @@ -123,16 +123,16 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - b.Pad(AddParam(*Literal::CreateR1({1, 2, 3}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*Literal::CreateR1({1, 2, 3}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { XlaBuilder b(TestName()); - b.Pad(AddParam(Array4D(2, 0, 3, 2), &b), - AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); + Pad(AddParam(Array4D(2, 0, 3, 2), &b), + AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, DefaultErrorSpec()); } @@ -147,8 +147,8 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - b.Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(1.5), &b), - r4_padding_on_dim0_dim1_); + Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(1.5), &b), + r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(1.5); @@ -166,8 +166,8 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - b.Pad(AddParam(input, &b), AddParam(*Literal::CreateR0(pad_value), &b), - r4_padding_on_dim0_dim1_); + Pad(AddParam(input, &b), AddParam(*Literal::CreateR0(pad_value), &b), + r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(8, 5, 1, 1); expected->Fill(pad_value); @@ -208,8 +208,8 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { auto input = Literal::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(AddParam(*input, &b), - AddParam(*Literal::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(pad_value), &b), + padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -254,8 +254,8 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { auto input = Literal::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(AddParam(*input, &b), - AddParam(*Literal::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(pad_value), &b), + padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -275,8 +275,8 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { }); input->FillWithYX(input_xy); - b.Pad(AddParam(*input, &b), b.ConstantR0(35), - r4_padding_on_dim0_dim1_); + Pad(AddParam(*input, &b), ConstantR0(&b, 35), + r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(35); @@ -294,16 +294,16 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { // Since bool is currently not well supported, use Broadcast operation to // create the operand for Pad. - auto input = b.Broadcast(b.ConstantR0(true), {1, 1, 3, 2}); + auto input = Broadcast(ConstantR0(&b, true), {1, 1, 3, 2}); auto padded = - b.Pad(input, b.ConstantR0(false), r4_padding_on_dim0_dim1_); + Pad(input, ConstantR0(&b, false), r4_padding_on_dim0_dim1_); // For the same reason, use Select to convert boolean values to int32. auto zeros = MakeUnique>(2, 3, 3, 2); auto ones = MakeUnique>(2, 3, 3, 2); zeros->Fill(0); ones->Fill(1); - b.Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); + Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(0); @@ -329,7 +329,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - b.Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); + Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -351,7 +351,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - b.Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), padding_config); + Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -376,7 +376,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -403,7 +403,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -430,7 +430,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -446,12 +446,12 @@ XLA_TEST_P(PadTestFloat, ReducePad) { XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = - b.Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); + Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - b.Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); + Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 838f1b4e2f0f0e0871ec717bdeefcbbc653397e3..2620063aa492902a705690d28d8124d16184d635 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -46,7 +46,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); ComputeAndCompareR0(&builder, 3.14159f, {param0_data.get()}, ErrorSpec(0.0001f)); @@ -58,7 +58,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0"); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -71,7 +71,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); ComputeAndCompareR1(&builder, {3.14f, -100.25f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -84,8 +84,9 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter( - 0, ShapeUtil::MakeShape(U8, {static_cast(str.size())}), "param0"); + Parameter(&builder, 0, + ShapeUtil::MakeShape(U8, {static_cast(str.size())}), + "param0"); ComputeAndCompareR1U8(&builder, str, {param0_data.get()}); } @@ -97,7 +98,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); ComputeAndCompareR2(&builder, Array2D(3, 0), {param0_data.get()}, ErrorSpec(0.01f)); @@ -110,7 +111,7 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); Array2D expected_array( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); @@ -124,25 +125,25 @@ XLA_TEST_F(ParamsTest, TwoParameters) { std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, literal0->shape(), "param0"); + auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); std::unique_ptr literal1 = Literal::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = builder.Parameter(1, literal1->shape(), "param1"); + auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); // Use both parameters // // {1, 2} + {10, 20} = {11, 22} - auto sum = builder.Add(param0, param1); - sum = builder.Add(param0, param1); + auto sum = Add(param0, param1); + sum = Add(param0, param1); // Use only the second parameter again, to show that it can be used // twice and to make the computation asymmetric in the two // parameters to test that the parameters are not swapped. // // {11, 22} * {10, 20} = {110, 440} - auto prod = builder.Mul(sum, param1); + Mul(sum, param1); ComputeAndCompareR1(&builder, {110, 440}, {param0_data.get(), param1_data.get()}, @@ -157,7 +158,7 @@ XLA_TEST_F(ParamsTest, MissingParameter) { client_->TransferToServer(*literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); + Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2"); auto computation_status = builder.Build(); ASSERT_NE(computation_status.status(), Status::OK()); @@ -169,12 +170,12 @@ XLA_TEST_F(ParamsTest, UnusedParameter) { std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, literal0->shape(), "param0"); + Parameter(&builder, 0, literal0->shape(), "param0"); std::unique_ptr literal1 = Literal::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = builder.Parameter(1, literal1->shape(), "param1"); + Parameter(&builder, 1, literal1->shape(), "param1"); ComputeAndCompareR1(&builder, {10, 20}, {param0_data.get(), param1_data.get()}, @@ -194,14 +195,14 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, literal0->shape(), "param0"); - auto param1 = builder.Parameter(1, literal1->shape(), "param1"); - auto param2 = builder.Parameter(2, literal1->shape(), "param2"); + auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&builder, 2, literal1->shape(), "param2"); // This add is unused. - builder.Add(param1, param2); + Add(param1, param2); - builder.Neg(param0); + Neg(param0); ComputeAndCompareR1( &builder, {-1, -2}, @@ -215,7 +216,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector init_value = {{0, 1}}; init_value.resize(size); - XlaOp sum_handle = builder.ConstantR1(init_value); + XlaOp sum_handle = ConstantR1(&builder, init_value); std::vector sum = {{0, 1}}; sum.resize(size); @@ -233,8 +234,8 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::unique_ptr literal = Literal::CreateR1(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); - sum_handle = builder.Add(sum_handle, param); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + sum_handle = Add(sum_handle, param); } std::vector param_data; @@ -260,7 +261,7 @@ XLA_TEST_F(ParamsTest, XlaBuilder builder(TestName()); std::vector> param_data_owner; - XlaOp sum_handle = builder.ConstantR0(0.0f); + XlaOp sum_handle = ConstantR0(&builder, 0.0f); float target = 0.0; constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { @@ -268,8 +269,8 @@ XLA_TEST_F(ParamsTest, std::unique_ptr literal = Literal::CreateR0(i); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); - sum_handle = builder.Add(sum_handle, param); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + sum_handle = Add(sum_handle, param); } std::vector param_data; @@ -291,7 +292,7 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( XlaBuilder builder(TestName()); std::vector> param_data_owner; - XlaOp sum_handle = builder.ConstantR1({0, 0}); + XlaOp sum_handle = ConstantR1(&builder, {0, 0}); int32 target = 0; constexpr int kParamCount = 3000; std::vector params; @@ -300,17 +301,17 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::unique_ptr literal = Literal::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); params.push_back(param); - sum_handle = builder.Add(sum_handle, param); + sum_handle = Add(sum_handle, param); } std::vector outputs; for (int i = 0; i < kParamCount; ++i) { - outputs.push_back(builder.Add(params[i], sum_handle)); + outputs.push_back(Add(params[i], sum_handle)); } - builder.Tuple(outputs); + Tuple(&builder, outputs); std::vector param_data; param_data.reserve(param_data_owner.size()); @@ -356,7 +357,7 @@ XLA_TEST_F(ParamsTest, std::unique_ptr literal = Literal::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); params.push_back(param); parameter_shapes.push_back(literal->shape()); } @@ -367,11 +368,11 @@ XLA_TEST_F(ParamsTest, param_data_owner.push_back( std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); XlaOp bool_param = - builder.Parameter(kParamCount, bool_literal->shape(), "bool_param"); + Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param"); params.push_back(bool_param); parameter_shapes.push_back(bool_literal->shape()); - auto init = builder.Tuple(params); + auto init = Tuple(&builder, params); // Create a computation for the condition: while(bool_param). Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes); @@ -379,8 +380,8 @@ XLA_TEST_F(ParamsTest, { XlaBuilder builder("condition"); auto condition_parameter = - builder.Parameter(0, while_shape, "condition_parameter"); - builder.GetTupleElement(condition_parameter, kParamCount); + Parameter(&builder, 0, while_shape, "condition_parameter"); + GetTupleElement(condition_parameter, kParamCount); condition = builder.Build().ConsumeValueOrDie(); } @@ -389,27 +390,27 @@ XLA_TEST_F(ParamsTest, XlaComputation body; { XlaBuilder builder("body"); - auto body_parameter = builder.Parameter(0, while_shape, "body_parameter"); + auto body_parameter = Parameter(&builder, 0, while_shape, "body_parameter"); std::vector updates; for (int i = 0; i < kParamCount; ++i) { - auto add = builder.Add(builder.GetTupleElement(body_parameter, i), - builder.ConstantR1({1, 1})); + auto add = Add(GetTupleElement(body_parameter, i), + ConstantR1(&builder, {1, 1})); updates.push_back(add); } // Add bool parameter. - updates.push_back(builder.GetTupleElement(body_parameter, kParamCount)); + updates.push_back(GetTupleElement(body_parameter, kParamCount)); - builder.Tuple(updates); + Tuple(&builder, updates); body = builder.Build().ConsumeValueOrDie(); } - auto loop = builder.While(condition, body, init); + auto loop = While(condition, body, init); std::vector outputs; for (int i = 0; i < kParamCount; ++i) { - outputs.push_back(builder.GetTupleElement(loop, i)); + outputs.push_back(GetTupleElement(loop, i)); } - builder.Tuple(outputs); + Tuple(&builder, outputs); std::vector param_data; param_data.reserve(param_data_owner.size()); @@ -433,10 +434,10 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r1f32_3, r1f32_3}); - auto input = builder.Parameter(0, tuple_shape, "input"); - auto lhs = builder.GetTupleElement(input, 0); - auto rhs = builder.GetTupleElement(input, 1); - builder.Add(lhs, rhs); + auto input = Parameter(&builder, 0, tuple_shape, "input"); + auto lhs = GetTupleElement(input, 0); + auto rhs = GetTupleElement(input, 1); + Add(lhs, rhs); std::unique_ptr data = client_ @@ -457,7 +458,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { std::unique_ptr literal = Literal::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); - builder.Parameter(0, literal->shape(), "input"); + Parameter(&builder, 0, literal->shape(), "input"); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -469,7 +470,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { std::unique_ptr literal = Literal::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); - builder.Parameter(0, literal->shape(), "input"); + Parameter(&builder, 0, literal->shape(), "input"); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -478,7 +479,8 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { std::unique_ptr literal = Literal::CreateR2({ - {1, 3}, {2, 4}, + {1, 3}, + {2, 4}, }); const Shape original = literal->shape(); { @@ -494,9 +496,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { } // Use the original shape in building the computation. XlaBuilder builder(TestName()); - auto input = builder.Parameter(0, original, "input"); + auto input = Parameter(&builder, 0, original, "input"); // Use the slice operator to get an off-diagonal element. - builder.Slice(input, {0, 1}, {1, 2}, {1, 1}); + Slice(input, {0, 1}, {1, 2}, {1, 1}); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 77159efb26f3b7dd4918f24305f7269a2d6ff647..5c351b2d113709105244de4aafa49d7cc535ced1 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -29,64 +29,63 @@ namespace { class PredTest : public ClientLibraryTestBase { protected: - void TestCompare( - bool lhs, bool rhs, bool expected, - XlaOp (XlaBuilder::*op)(const xla::XlaOp&, const xla::XlaOp&, - tensorflow::gtl::ArraySlice)) { + void TestCompare(bool lhs, bool rhs, bool expected, + std::function)> + op) { XlaBuilder builder(TestName()); - XlaOp lhs_op = builder.ConstantR0(lhs); - XlaOp rhs_op = builder.ConstantR0(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp lhs_op = ConstantR0(&builder, lhs); + XlaOp rhs_op = ConstantR0(&builder, rhs); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; TEST_F(PredTest, ConstantR0PredTrue) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0(true); + ConstantR0(&builder, true); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, ConstantR0PredFalse) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0(false); + ConstantR0(&builder, false); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, ConstantR0PredCompareEq) { - TestCompare(true, false, false, &XlaBuilder::Eq); + TestCompare(true, false, false, &Eq); } TEST_F(PredTest, ConstantR0PredCompareNe) { - TestCompare(true, false, true, &XlaBuilder::Ne); + TestCompare(true, false, true, &Ne); } TEST_F(PredTest, ConstantR0PredCompareLe) { - TestCompare(true, false, false, &XlaBuilder::Le); + TestCompare(true, false, false, &Le); } TEST_F(PredTest, ConstantR0PredCompareLt) { - TestCompare(true, false, false, &XlaBuilder::Lt); + TestCompare(true, false, false, &Lt); } TEST_F(PredTest, ConstantR0PredCompareGe) { - TestCompare(true, false, true, &XlaBuilder::Ge); + TestCompare(true, false, true, &Ge); } TEST_F(PredTest, ConstantR0PredCompareGt) { - TestCompare(true, false, true, &XlaBuilder::Gt); + TestCompare(true, false, true, &Gt); } TEST_F(PredTest, ConstantR1Pred) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false, false, true}); + ConstantR1(&builder, {true, false, false, true}); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } TEST_F(PredTest, ConstantR2Pred) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{false, true, true}, {true, false, false}}); + ConstantR2(&builder, {{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { { 011 }, { 100 } @@ -96,44 +95,44 @@ TEST_F(PredTest, ConstantR2Pred) { TEST_F(PredTest, AnyR1True) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false}); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR1(&builder, {true, false}); + Any(a); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, AnyR1False) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, false}); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR1(&builder, {false, false}); + Any(a); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, AnyR1VacuouslyFalse) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR1(&builder, {}); + Any(a); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, AnyR2True) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({ - {false, false, false}, - {false, false, false}, - {false, false, true}, - }); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR2(&builder, { + {false, false, false}, + {false, false, false}, + {false, false, true}, + }); + Any(a); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, AnyR2False) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({ - {false, false, false}, - {false, false, false}, - {false, false, false}, - }); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR2(&builder, { + {false, false, false}, + {false, false, false}, + {false, false, false}, + }); + Any(a); ComputeAndCompareR0(&builder, false, {}); } diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 1a2de6937c3e134852a730f62f7b56417cf49b28..8e163e885d0d6315341c213577a3beb0180b679a 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -53,8 +53,8 @@ template std::unique_ptr PrngTest::UniformTest( T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { XlaBuilder builder(TestName()); - builder.RngUniform( - builder.ConstantR0(a), builder.ConstantR0(b), + RngUniform( + ConstantR0(&builder, a), ConstantR0(&builder, b), ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dims)); SetSeed(seed); @@ -141,9 +141,9 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, int32 sample_size = range_size * expected_count; XlaBuilder builder(TestName()); - builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(range_size), - ShapeUtil::MakeShape(S32, {sample_size})); + RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, range_size), + ShapeUtil::MakeShape(S32, {sample_size})); SetSeed(seed); auto actual = @@ -184,9 +184,10 @@ XLA_TEST_F(PrngTest, MapUsingRng) { // Build a x -> (x + U[0,1)) computation. auto build_sum_rng = [this](XlaBuilder& builder) { auto b = builder.CreateSubBuilder("sum_with_rng"); - auto x = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "input"); - b->Add(x, b->RngUniform(b->ConstantR0(0), b->ConstantR0(1), - ShapeUtil::MakeShape(F32, {}))); + auto x = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "input"); + Add(x, + RngUniform(ConstantR0(b.get(), 0), ConstantR0(b.get(), 1), + ShapeUtil::MakeShape(F32, {}))); return b->BuildAndNoteError(); }; @@ -196,9 +197,9 @@ XLA_TEST_F(PrngTest, MapUsingRng) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, client_->TransferToServer(*param0_literal)); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); auto fn = build_sum_rng(builder); - builder.Map({param0}, fn, {0}); + Map(&builder, {param0}, fn, {0}); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); @@ -226,9 +227,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { // Build a U[0,1) computation. auto build_computation = [this]() { XlaBuilder builder(TestName()); - builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(1), - ShapeUtil::MakeShape(F32, {10})); + RngUniform(ConstantR0(&builder, 0), ConstantR0(&builder, 1), + ShapeUtil::MakeShape(F32, {10})); return builder.Build(); }; @@ -282,8 +282,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { XLA_TEST_F(PrngTest, TenValuesN01) { XlaBuilder builder(TestName()); - builder.RngNormal(builder.ConstantR0(0), builder.ConstantR0(1), - ShapeUtil::MakeShape(F32, {10})); + RngNormal(ConstantR0(&builder, 0), ConstantR0(&builder, 1), + ShapeUtil::MakeShape(F32, {10})); SetSeed(42); ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); @@ -294,9 +294,9 @@ XLA_TEST_F(PrngTest, RngUniformCrash) { XlaBuilder builder(TestName()); // This used to crash XLA during LLVM IR generation for CPUs. - auto rng_uniform = builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(1000 * 1000), - ShapeUtil::MakeShape(S32, {})); + RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 1000 * 1000), + ShapeUtil::MakeShape(S32, {})); SetSeed(0); ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index f95e75648343aa88bd7c39de4ee9f387f2b60506..526a38e8d1dbed9cdd4a31bfbec49bc5c6bb174b 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -31,8 +31,8 @@ class QueryInferredShapeTest : public ClientLibraryTestBase {}; TEST_F(QueryInferredShapeTest, OnePlusOneShape) { XlaBuilder builder("one_plus_one"); - auto one = builder.ConstantR0(1.0); - auto result = builder.Add(one, one); + auto one = ConstantR0(&builder, 1.0); + auto result = Add(one, one); StatusOr shape_status = builder.GetShape(result); ASSERT_IS_OK(shape_status.status()); auto shape = shape_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index c0a2c0ca4cb8414e0771a541b9f963f9aedc8376..9052b188ed09a715b6ad7c3a40dc853d02cdd70c 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -73,7 +73,7 @@ ENTRY reduce.1 { } )"; - return tools::Parse(hlo_string); + return ParseHloString(hlo_string); } // TODO(b/72454718): XLA:GPU does not support executing code compiled without diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index b311785449f1774c3bc1e4d7ad35c2866e3b4061..4c1aa121067eed465c6128ea7a34e0284f7af43e 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -233,9 +233,9 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { std::unique_ptr a_literal = Literal::CreateR1({input_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); - builder.ReducePrecision(a, exponent_bits, mantissa_bits); + ReducePrecision(a, exponent_bits, mantissa_bits); ComputeAndCompareR1(&builder, expected_values, {a_data.get()}); } @@ -256,15 +256,15 @@ XLA_TEST_F(ReducePrecisionInsertionTest, std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // Abs doesn't affect resolution. - auto abs = builder.Abs(a); + auto abs = Abs(a); // Near 1.0, Log(x) approximates x - 1; this lets us confirm that the // reduce-precision operation showed up in the correct place in the // graph. - builder.Log(abs); + Log(abs); // Insert precision-reduction after the Abs(x) operation, rounding that // result to exactly 1.0f. @@ -285,11 +285,11 @@ XLA_TEST_F(ReducePrecisionInsertionTest, std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass after operation fusion, suffixing kAbs operations. This // should not see into the fusion nodes and thus should not affect the @@ -311,11 +311,11 @@ XLA_TEST_F(ReducePrecisionInsertionTest, std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass after operation fusion, suffixing kFusion operations. auto reduce_precision_pass = execution_options_.mutable_debug_options() @@ -335,11 +335,11 @@ XLA_TEST_F(ReducePrecisionInsertionTest, std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass suffixing fusion nodes containing kCos operations. This // should have no effect. @@ -360,11 +360,11 @@ XLA_TEST_F(ReducePrecisionInsertionTest, std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass suffixing fusion nodes containing kAbs operations. This // should see the kAbs operation within the above fusion node. diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index d671d40456a276a44b462f390c95aa4af301263a..c9f57cbb16729627a5e9ad3d49438295a286989e 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -89,9 +89,9 @@ class ReduceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); std::vector input_data(element_count); for (int64 i = 0; i < element_count; ++i) { @@ -118,20 +118,20 @@ class ReduceTest : public ClientLibraryTestBase { const int element_count = input_data.size(); XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count}); - auto input_par = builder.Parameter(0, input_shape, "input"); + auto input_par = Parameter(&builder, 0, input_shape, "input"); auto pred_values = - builder.Eq(input_par, builder.ConstantR1(element_count, 1)); + Eq(input_par, ConstantR1(&builder, element_count, 1)); XlaOp init_value; XlaComputation reduce; if (and_reduce) { - init_value = builder.ConstantR0(true); + init_value = ConstantR0(&builder, true); reduce = CreateScalarAndComputation(&builder); } else { - init_value = builder.ConstantR0(false); + init_value = ConstantR0(&builder, false); reduce = CreateScalarOrComputation(&builder); } - builder.Reduce(pred_values, init_value, reduce, - /*dimensions_to_reduce=*/{0}); + Reduce(pred_values, init_value, reduce, + /*dimensions_to_reduce=*/{0}); std::unique_ptr input_literal = Literal::CreateR1(input_data); std::unique_ptr input_global_data = @@ -156,21 +156,21 @@ class ReduceTest : public ClientLibraryTestBase { int64 major = 0) { XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto input_pred = builder.Eq(input, builder.ConstantR0(1)); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto input_pred = Eq(input, ConstantR0(&builder, 1)); XlaOp init_value; XlaComputation reduce_op; if (and_reduce) { - init_value = builder.ConstantR0(true); + init_value = ConstantR0(&builder, true); reduce_op = CreateScalarAndComputation(&builder); } else { - init_value = builder.ConstantR0(false); + init_value = ConstantR0(&builder, false); reduce_op = CreateScalarOrComputation(&builder); } - builder.Reduce(input_pred, init_value, reduce_op, - /*dimensions_to_reduce=*/{0}); + Reduce(input_pred, init_value, reduce_op, + /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillRandom(0, 1); @@ -202,9 +202,9 @@ class ReduceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); @@ -230,9 +230,9 @@ class ReduceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); @@ -287,10 +287,10 @@ class ReduceTest : public ClientLibraryTestBase { XlaComputation reduction_function = reduction_function_generator(&builder); const Shape input_shape = ShapeUtil::MakeShape( xla::primitive_util::NativeToPrimitiveType(), {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(initial_value); - builder.Reduce(input, zero, reduction_function, - /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, initial_value); + Reduce(input, zero, reduction_function, + /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillUnique(initial_value); @@ -442,10 +442,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Log(input); - builder.Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + auto log_ = Log(input); + Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); @@ -473,11 +473,11 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Log(input); - auto transpose = builder.Transpose(log_, {1, 0}); - builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + auto log_ = Log(input); + auto transpose = Transpose(log_, {1, 0}); + Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); @@ -505,10 +505,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50}); - XlaOp input = builder.Parameter(0, input_shape, "input"); - XlaOp zero = builder.ConstantR0(0.0); - XlaOp transpose = builder.Transpose(input, /*permutation=*/{1, 0, 2}); - builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); + XlaOp input = Parameter(&builder, 0, input_shape, "input"); + XlaOp zero = ConstantR0(&builder, 0.0); + XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2}); + Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, MakeFakeLiteral(input_shape)); @@ -522,11 +522,11 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Tanh(input); - auto reshape = builder.Reshape(log_, {rows, cols}); - builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + auto log_ = Tanh(input); + auto reshape = Reshape(log_, {rows, cols}); + Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0}); Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); @@ -568,9 +568,9 @@ void PrintTo(const BoundsLayout& spec, std::ostream* os) { XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); - auto scalar = builder.ConstantR0(42.0); - auto broadcasted = builder.Broadcast(scalar, {500, 500}); - builder.Reduce(broadcasted, builder.ConstantR0(0.0f), add, {0, 1}); + auto scalar = ConstantR0(&builder, 42.0); + auto broadcasted = Broadcast(scalar, {500, 500}); + Reduce(broadcasted, ConstantR0(&builder, 0.0f), add, {0, 1}); float expected = 42.0f * static_cast(500 * 500); ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -580,9 +580,9 @@ XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) { XlaBuilder builder(TestName()); auto max = CreateScalarMaxComputation(F32, &builder); - auto scalar = builder.ConstantR0(42.0); - auto broadcasted = builder.Broadcast(scalar, {500, 500}); - builder.Reduce(broadcasted, builder.ConstantR0(0.0f), max, {0, 1}); + auto scalar = ConstantR0(&builder, 42.0); + auto broadcasted = Broadcast(scalar, {500, 500}); + Reduce(broadcasted, ConstantR0(&builder, 0.0f), max, {0, 1}); float expected = 42.0f; ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -595,8 +595,8 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { Array2D input(300, 250); input.FillRandom(214.0f); auto input_literal = Literal::CreateR2FromArray2D(input); - builder.Reduce(builder.ConstantLiteral(*input_literal), - builder.ConstantR0(FLT_MIN), max, {0, 1}); + Reduce(ConstantLiteral(&builder, *input_literal), + ConstantR0(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; input.Each( [&](int64, int64, float* v) { input_max = std::max(input_max, *v); }); @@ -610,8 +610,8 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { Array2D input(150, 130); input.FillRandom(214.0f); auto input_literal = Literal::CreateR2FromArray2D(input); - builder.Reduce(builder.ConstantLiteral(*input_literal), - builder.ConstantR0(FLT_MAX), min, {0, 1}); + Reduce(ConstantLiteral(&builder, *input_literal), + ConstantR0(&builder, FLT_MAX), min, {0, 1}); auto input_min = FLT_MAX; input.Each( @@ -625,10 +625,9 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { auto min = CreateScalarMinComputation(U32, &builder); auto input_literal = Literal::CreateR2FromArray2D(input); auto initial_value = - builder.ConstantR0(std::numeric_limits::max()); + ConstantR0(&builder, std::numeric_limits::max()); - builder.Reduce(builder.ConstantLiteral(*input_literal), initial_value, min, - {0, 1}); + Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1}); ComputeAndCompareR0(&builder, 1, {}); } @@ -638,19 +637,18 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { auto max = CreateScalarMaxComputation(U32, &builder); auto input_literal = Literal::CreateR2FromArray2D(input); auto initial_value = - builder.ConstantR0(std::numeric_limits::min()); + ConstantR0(&builder, std::numeric_limits::min()); - builder.Reduce(builder.ConstantLiteral(*input_literal), initial_value, max, - {0, 1}); + Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1}); ComputeAndCompareR0(&builder, 2, {}); } // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_2d_); + auto m = ConstantLiteral(&builder, *literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); std::vector expected = {6.f, 15.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -659,9 +657,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_2d_); + auto m = ConstantLiteral(&builder, *literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); ComputeAndCompareR0(&builder, 21.0f, {}, ErrorSpec(0.0001, 1e-4)); } @@ -669,9 +667,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XlaBuilder builder("reduce_among_y"); - auto m = builder.ConstantLiteral(*literal_2d_); + auto m = ConstantLiteral(&builder, *literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); std::vector expected = {5.f, 7.f, 9.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -679,9 +677,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {1, 2}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {1, 2}); std::vector expected = {21.f, 21.f, 21.f, 21.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -689,9 +687,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); std::vector expected = {20.f, 28.f, 36.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -699,9 +697,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1, 2}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1, 2}); float expected = 21.0f * 4.0; ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -709,9 +707,9 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); // clang-format off Array2D expected({ @@ -724,9 +722,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); // clang-format off Array2D expected({ @@ -741,9 +739,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {2}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {2}); // clang-format off Array2D expected({ @@ -827,10 +825,10 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { client_->TransferToServer(*input_literal).ConsumeValueOrDie(); auto input_activations = - builder.Parameter(0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal->shape(), "input"); XlaComputation add = CreateScalarAddComputation(F32, &builder); - auto sum = builder.Reduce(input_activations, builder.ConstantR0(0.0f), - add, GetParam().reduce_dims); + Reduce(input_activations, ConstantR0(&builder, 0.0f), add, + GetParam().reduce_dims); auto expected = ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims, @@ -871,14 +869,14 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { XlaBuilder builder(TestName()); XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder); - auto a = builder.ConstantR0(2.0f); - auto a2 = builder.Abs(a); + auto a = ConstantR0(&builder, 2.0f); + auto a2 = Abs(a); std::unique_ptr b_literal = Literal::CreateR1({1.0f, 4.0f}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b = builder.Parameter(0, b_literal->shape(), "b"); - auto max = builder.Reduce(b, a2, max_f32, {0}); + auto b = Parameter(&builder, 0, b_literal->shape(), "b"); + Reduce(b, a2, max_f32, {0}); ComputeAndCompareR0(&builder, 4.0f, {b_data.get()}); } @@ -900,13 +898,13 @@ class ReduceInitializerTest : public ReduceTest { XlaComputation max_fn = CreateScalarMaxComputation( primitive_util::NativeToPrimitiveType(), &builder); - auto init = builder.ConstantR0(initializer); + auto init = ConstantR0(&builder, initializer); std::vector input_arr(num_elems, std::numeric_limits::lowest()); auto input_literal = Literal::CreateR1(input_arr); auto input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - builder.Reduce(builder.Parameter(0, input_literal->shape(), "input"), init, - max_fn, {0}); + Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, + max_fn, {0}); ComputeAndCompareR0(&builder, initializer, {input_data.get()}); } @@ -939,15 +937,15 @@ XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) { XLA_TEST_F(ReduceTest, ReduceIdentity) { XlaBuilder builder(TestName()); Shape single_float = ShapeUtil::MakeShape(F32, {}); - builder.Parameter(0, single_float, "lhs-unused"); - builder.Parameter(1, single_float, "rhs-used"); + Parameter(&builder, 0, single_float, "lhs-unused"); + Parameter(&builder, 1, single_float, "rhs-used"); auto computation_status = builder.Build(); TF_ASSERT_OK(computation_status.status()); Shape operand_shape = ShapeUtil::MakeShape(F32, {1}); - builder.Reduce(builder.Parameter(0, operand_shape, "operand"), - builder.Parameter(1, single_float, "init"), - computation_status.ValueOrDie(), {0}); + Reduce(Parameter(&builder, 0, operand_shape, "operand"), + Parameter(&builder, 1, single_float, "init"), + computation_status.ValueOrDie(), {0}); float operand[] = {42.0f}; float init = 58.5f; diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index ee02f09625ed8947d5d5c3b1d9a4cde6d83e3d5c..741974480c6a862a7794aa6257f131a5893e963d 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -72,9 +72,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, Padding padding) { auto init = CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarAddComputation(FloatType(), &builder_), - window_dimensions, window_strides, padding); + ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } void ReduceWindowMax(const XlaOp& input, @@ -82,9 +82,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_strides, Padding padding) { auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarMaxComputation(FloatType(), &builder_), - window_dimensions, window_strides, padding); + ReduceWindow(input, init, + CreateScalarMaxComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } void ReduceWindowMin(const XlaOp& input, @@ -92,9 +92,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_strides, Padding padding) { auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarMinComputation(FloatType(), &builder_), - window_dimensions, window_strides, padding); + ReduceWindow(input, init, + CreateScalarMinComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } XlaBuilder builder_; @@ -106,10 +106,10 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { const auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); - builder_.ReduceWindow(input, init_value, - CreateScalarAddComputation(FloatType(), &builder_), - /*window_dimensions=*/{1, 2}, - /*window_strides=*/{1}, Padding::kValid); + ReduceWindow(input, init_value, + CreateScalarAddComputation(FloatType(), &builder_), + /*window_dimensions=*/{1, 2}, + /*window_strides=*/{1}, Padding::kValid); ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) << builder_.first_error(); ASSERT_THAT(builder_.first_error().error_message(), @@ -122,10 +122,9 @@ TEST_P(ReduceWindowTest, R0ReduceWindow) { CreateConstantFromLiteral(*Literal::CreateR0(42.0), &builder_); const auto init = CreateConstantFromLiteral(*Literal::CreateR0(1.0), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarAddComputation(FloatType(), &builder_), - /*window_dimensions=*/{}, - /*window_strides=*/{}, Padding::kSame); + ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), + /*window_dimensions=*/{}, + /*window_strides=*/{}, Padding::kSame); ComputeAndCompareLiteral(&builder_, *Literal::CreateR0(43.0), {}, ErrorSpec(0.00001)); } @@ -306,13 +305,13 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Padding padding = Padding::kValid; const Shape scalar = ShapeUtil::MakeShape(FloatType(), {}); auto b = builder_.CreateSubBuilder("unusual"); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - b->Min(b->Add(lhs, rhs), - CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); + auto lhs = Parameter(b.get(), 0, scalar, "lhs"); + auto rhs = Parameter(b.get(), 1, scalar, "rhs"); + Min(Add(lhs, rhs), + CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); - builder_.ReduceWindow( + ReduceWindow( input, CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_), reduce_fn, @@ -356,12 +355,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - std::unique_ptr arg_literal = Literal::CreateFromShape(shape); - auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> float { - return 1.0f; - }; - TF_EXPECT_OK(arg_literal->Populate(generator)); - + auto arg_literal = MakeUnique(shape); + arg_literal->PopulateWithValue(1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); Padding padding = Padding::kValid; @@ -371,13 +366,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - std::unique_ptr expected = Literal::CreateFromShape(result_shape); - auto out_generator = - [&](tensorflow::gtl::ArraySlice indexes) -> float { - return 27.0f; - }; - TF_EXPECT_OK(expected->Populate(out_generator)); - + auto expected = MakeUnique(result_shape); + expected->PopulateWithValue(27.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -551,7 +541,7 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); - XlaOp input = builder_.Broadcast( + XlaOp input = Broadcast( CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; @@ -636,7 +626,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, auto computation = param.reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); - b.ReduceWindowWithGeneralPadding( + ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, @@ -977,11 +967,11 @@ TEST_P(R3ReduceWindowTest, Add) { &b, ¶meter); auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindow(/*operand=*/parameter, - /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); auto expected = ReferenceUtil::ReduceWindow3DAdd( /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, @@ -1118,7 +1108,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, : CreateScalarMaxComputation(FloatType(), &b); auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindowWithGeneralPadding( + ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, @@ -1315,7 +1305,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { : CreateScalarMaxComputation(FloatType(), &b); auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindowWithGeneralPadding( + ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 36d763b0f7f4267ede076c0b25cfaf9654e96e0d..bebd814fa8b863428750dc12a93d1ef5ad7e6685 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -39,8 +39,8 @@ class ReplayTest : public ClientLibraryTestBase {}; TEST_F(ReplayTest, TwoPlusTwoReplay) { // Make 2+2 computation. XlaBuilder builder(TestName()); - auto two = builder.ConstantR0(2); - builder.Add(two, two); + auto two = ConstantR0(&builder, 2); + Add(two, two); XlaComputation computation = builder.Build().ConsumeValueOrDie(); // Serialize it out. @@ -70,9 +70,9 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Make computation. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(S32, {}), "y"); + Add(x, y); XlaComputation computation = builder.Build().ConsumeValueOrDie(); // Serialize it out. @@ -111,13 +111,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { // As above, but with map(+2) over some constant array. XlaBuilder plus_two_builder("plus two"); auto input = - plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input"); - plus_two_builder.Add(input, plus_two_builder.ConstantR0(2)); + Parameter(&plus_two_builder, 0, ShapeUtil::MakeShape(S32, {}), "input"); + Add(input, ConstantR0(&plus_two_builder, 2)); XlaComputation plus_two = plus_two_builder.Build().ConsumeValueOrDie(); XlaBuilder mapper_builder(TestName()); - auto original = mapper_builder.ConstantR1({1, 2, 3}); - mapper_builder.Map({original}, plus_two, {0}); + auto original = ConstantR1(&mapper_builder, {1, 2, 3}); + Map(&mapper_builder, {original}, plus_two, {0}); XlaComputation computation = mapper_builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index da1b588ec41cef711412367e89b2a9b1029bca71..5812fe442b25da1b7e34494d00fe8025d29b2802 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -44,11 +44,11 @@ using ReshapeMotionTest = ClientLibraryTestBase; TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{2, 3, 5}, {7, 11, 13}}); - auto b = builder.ConstantR2({{17, 19}, {23, 29}, {31, 37}}); - auto c = builder.Reshape(a, {6}); - auto d = builder.Reshape(b, {6}); - auto e = builder.Mul(c, d); + auto a = ConstantR2(&builder, {{2, 3, 5}, {7, 11, 13}}); + auto b = ConstantR2(&builder, {{17, 19}, {23, 29}, {31, 37}}); + auto c = Reshape(a, {6}); + auto d = Reshape(b, {6}); + Mul(c, d); ComputeAndCompareR1(&builder, {34, 57, 115, 203, 341, 481}, {}); } diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index a4580cd71d46ad0a0186eddd51291f9c322b6f49..d3d6c3c7d703161e433740acbbd58d51ba1434af 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -59,7 +59,7 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -72,7 +72,7 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{}); + Collapse(/*operand=*/parameter, /*dimensions=*/{}); auto expected_literal = Literal::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -85,7 +85,7 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0}); auto expected_literal = Literal::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -101,8 +101,8 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{}); + auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(1.0f); @@ -117,34 +117,28 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); - auto a = builder.Neg(parameter); - builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); + auto a = Neg(parameter); + Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); auto expected_literal = Literal::CreateR1({-1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { +XLA_TEST_P(ReshapeTest, Trivial0x3) { XlaBuilder builder(TestName()); Array2D input_array(0, 3); auto input_literal = Literal::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-05-15 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { +XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = @@ -152,23 +146,20 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { +XLA_TEST_P(ReshapeTest, Trivial3x0) { XlaBuilder builder(TestName()); Array2D input_array(3, 0); auto input_literal = Literal::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -181,7 +172,7 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -194,25 +185,21 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Splits an empty vector into an empty matrix. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { +XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, - /*new_sizes=*/{2, 0}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 0}); auto expected_literal = Literal::CreateR2({{}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -226,27 +213,23 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, - /*new_sizes=*/{2, 3}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 3}); auto expected_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Transposes a 2x0 array to a 0x2 array. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { +XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 0}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 0}); auto expected_literal = Literal::CreateR2({{}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -260,8 +243,8 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{3, 1}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); auto expected_literal = Literal::CreateFromArray(*expected); @@ -277,8 +260,8 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 4}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = Literal::CreateFromArray(*expected); @@ -286,18 +269,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Transposes a 0x4 array with XlaBuilder::Transpose. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { +XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Transpose(parameter, {1, 0}); + Transpose(parameter, {1, 0}); auto expected_literal = Literal::CreateR2({{}, {}, {}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -311,7 +290,7 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Transpose(parameter, {1, 0}); + Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = Literal::CreateFromArray(*expected); @@ -319,36 +298,29 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 3, 0, 0}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 3, 0, 0}); auto expected_literal = Literal::CreateFromArray(Array4D(2, 3, 0, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{24, 0}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{24, 0}); auto expected_literal = Literal::CreateFromArray(Array2D(24, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -363,8 +335,8 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 6}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); auto expected_literal = Literal::CreateFromArray(*expected); @@ -372,18 +344,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 0}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 0}); auto expected_literal = Literal::CreateFromArray(Array2D(3, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -398,8 +366,8 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, - /*new_sizes=*/{2, 6}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); auto expected_literal = Literal::CreateFromArray(expected); @@ -424,8 +392,8 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{24}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{24}); auto expected_literal = Literal::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); @@ -439,8 +407,8 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{8, 3}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{8, 3}); auto expected_literal = Literal::CreateR2({{10, 11, 12}, {15, 16, 17}, {20, 21, 22}, @@ -459,8 +427,8 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{24}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{24}); auto expected_literal = Literal::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); @@ -474,8 +442,8 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{8, 3}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{8, 3}); auto expected_literal = Literal::CreateR2({{10, 20, 30}, {40, 11, 21}, {31, 41, 12}, @@ -494,8 +462,8 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{2, 6, 2}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{2, 6, 2}); auto expected_literal = Literal::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); @@ -527,7 +495,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); + Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); auto expected_literal = Literal::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, @@ -552,8 +520,8 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{2, 4}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{2, 4}); auto expected_literal = Literal::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); @@ -575,7 +543,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); - b.Reshape(parameter, dimensions, {}); + Reshape(parameter, dimensions, {}); auto expected_literal = Literal::CreateR0(83.0f); ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, @@ -589,7 +557,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); - b.Reshape(parameter, {}, {}); + Reshape(parameter, {}, {}); EXPECT_THAT( ExecuteToString(&b, {}), ::testing::HasSubstr("not a permutation of the operand dimensions")); @@ -601,7 +569,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); - b.Reshape(parameter, {1}, {}); + Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); } @@ -637,7 +605,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); Array2D expected_array({ {0, 1, 2, 3, 100, 101, 102, 103}, @@ -671,7 +639,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); + Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off auto expected_literal = Literal::CreateR4({ @@ -698,7 +666,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); + Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off auto expected_literal = Literal::CreateR4({ @@ -728,7 +696,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal); @@ -750,7 +718,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal); @@ -773,8 +741,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, - /*new_sizes=*/{5, 60}); + Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, + /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { @@ -800,8 +768,8 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, - /*new_sizes=*/{7, 2, 3, 5}); + Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, + /*new_sizes=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; @@ -833,8 +801,8 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{1, 2, 3, 4}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{1, 2, 3, 4}); ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); } @@ -848,8 +816,8 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, - /*new_sizes=*/{2, 4, 3, 1}); + Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, + /*new_sizes=*/{2, 4, 3, 1}); // clang-format off auto expected_2x4x3x1 = Literal::CreateR4( @@ -882,8 +850,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -911,8 +879,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -940,8 +908,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -970,8 +938,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -999,8 +967,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index e7bd142dc9ddefbd8bebfb77d72218d662645c31..662bc42224851ac19c690129f525953e6d410a55 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -87,7 +87,7 @@ TEST_P(FloatReverseTest, Reverses) { XlaBuilder builder(TestName()); auto a = AddParam(*input_literal, &builder); - builder.Rev(a, spec.reversal); + Rev(a, spec.reversal); std::unique_ptr expected = input_literal->CloneToUnique(); std::vector output_indices(spec.input_dims.size()); @@ -127,7 +127,7 @@ XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { }}); // clang-format on - b.Rev(b.ConstantR4FromArray4D(input), {0, 3}); + Rev(ConstantR4FromArray4D(&b, input), {0, 3}); // clang-format off Array4D expected({{ @@ -163,7 +163,7 @@ TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) { }); // clang-format on - b.Rev(b.ConstantR4FromArray4D(input), {0, 1}); + Rev(ConstantR4FromArray4D(&b, input), {0, 1}); // clang-format off Array4D expected({ diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 308d3fc78a51e63c0e3db8c0cda18caf11f665bd..3afd8c8fc88a3879cc524c2d1680e8b176b55f81 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -44,74 +44,75 @@ class ScalarComputationsTest : public ClientLibraryTestBase { protected: // A template for building and running a binary comparison test. template - void TestCompare( - NativeT lhs, NativeT rhs, bool expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice)) { + void TestCompare(NativeT lhs, NativeT rhs, bool expected, + std::function)> + op) { XlaBuilder builder(TestName()); - XlaOp lhs_op = builder.ConstantR0(lhs); - XlaOp rhs_op = builder.ConstantR0(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp lhs_op = ConstantR0(&builder, lhs); + XlaOp rhs_op = ConstantR0(&builder, rhs); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice)) { + std::function)> + op) { XlaBuilder builder(TestName()); - XlaOp lhs_op = builder.ConstantR0(lhs); - XlaOp rhs_op = builder.ConstantR0(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp lhs_op = ConstantR0(&builder, lhs); + XlaOp rhs_op = ConstantR0(&builder, rhs); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) { XlaBuilder builder(TestName()); - builder.ConstantR0(2.1f); + ConstantR0(&builder, 2.1f); ComputeAndCompareR0(&builder, 2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(2.1f)); + Neg(ConstantR0(&builder, 2.1f)); ComputeAndCompareR0(&builder, -2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(2)); + Neg(ConstantR0(&builder, 2)); ComputeAndCompareR0(&builder, -2, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); + Add(ConstantR0(&builder, 2.1f), ConstantR0(&builder, 5.5f)); ComputeAndCompareR0(&builder, 7.6f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(2), builder.ConstantR0(5)); + Add(ConstantR0(&builder, 2), ConstantR0(&builder, 5)); ComputeAndCompareR0(&builder, 7, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); + Add(ConstantR0(&builder, 35), ConstantR0(&builder, 57)); ComputeAndCompareR0(&builder, 92, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); + Add(ConstantR0(&builder, 35), ConstantR0(&builder, 57)); ComputeAndCompareR0(&builder, 92, {}); } @@ -120,7 +121,7 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) { XlaBuilder builder(TestName()); const uint64 a = static_cast(1) << 63; const uint64 b = a + 1; - builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); + Add(ConstantR0(&builder, a), ConstantR0(&builder, b)); ComputeAndCompareR0(&builder, a + b, {}); } @@ -129,37 +130,36 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) { XlaBuilder builder(TestName()); const int64 a = static_cast(1) << 62; const int64 b = a - 1; - builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); + Add(ConstantR0(&builder, a), ConstantR0(&builder, b)); ComputeAndCompareR0(&builder, a + b, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(0.25), - builder.ConstantR0(3.5)); + Add(ConstantR0(&builder, 0.25), ConstantR0(&builder, 3.5)); ComputeAndCompareR0(&builder, 3.75, {}); } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Sub(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); + Sub(ConstantR0(&builder, 2.1f), ConstantR0(&builder, 5.5f)); ComputeAndCompareR0(&builder, -3.4f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) { XlaBuilder builder(TestName()); - builder.Sub(builder.ConstantR0(2), builder.ConstantR0(5)); + Sub(ConstantR0(&builder, 2), ConstantR0(&builder, 5)); ComputeAndCompareR0(&builder, -3, {}); } XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { XlaBuilder builder(TestName()); - auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a"); - builder.ConvertElementType(a, F32); + auto a = Parameter(&builder, 0, ShapeUtil::MakeShape(S64, {}), "a"); + ConvertElementType(a, F32); int64 value = 3LL << 35; std::unique_ptr a_literal = Literal::CreateR0(value); @@ -171,9 +171,8 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) { XlaBuilder builder(TestName()); - builder.Mul(builder.Mul(builder.ConstantR0(2.1f), - builder.ConstantR0(5.5f)), - builder.ConstantR0(0.5f)); + Mul(Mul(ConstantR0(&builder, 2.1f), ConstantR0(&builder, 5.5f)), + ConstantR0(&builder, 0.5f)); ComputeAndCompareR0(&builder, 5.775f, {}, error_spec_); } @@ -190,7 +189,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) { for (int32 x : data) { for (int32 y : data) { XlaBuilder builder(TestName()); - builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); + Mul(ConstantR0(&builder, x), ConstantR0(&builder, y)); // Signed integer overflow is undefined behavior in C++. Convert the input // integers to unsigned, perform the multiplication unsigned, and convert @@ -209,7 +208,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { for (uint32 x : data) { for (uint32 y : data) { XlaBuilder builder(TestName()); - builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); + Mul(ConstantR0(&builder, x), ConstantR0(&builder, y)); uint32 expected = x * y; ComputeAndCompareR0(&builder, expected, {}); @@ -219,9 +218,8 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XlaBuilder builder(TestName()); - builder.Mul( - builder.Mul(builder.ConstantR0(2), builder.ConstantR0(5)), - builder.ConstantR0(1)); + Mul(Mul(ConstantR0(&builder, 2), ConstantR0(&builder, 5)), + ConstantR0(&builder, 1)); ComputeAndCompareR0(&builder, 10, {}); } @@ -239,10 +237,10 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { std::unique_ptr c_data = client_->TransferToServer(*c_literal).ConsumeValueOrDie(); - XlaOp a = builder.Parameter(0, a_literal->shape(), "a"); - XlaOp b = builder.Parameter(1, b_literal->shape(), "b"); - XlaOp c = builder.Parameter(2, c_literal->shape(), "c"); - builder.Mul(builder.Mul(a, b), c); + XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a"); + XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b"); + XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c"); + Mul(Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, {a_data.get(), b_data.get(), c_data.get()}, @@ -251,14 +249,14 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Div(builder.ConstantR0(5.0f), builder.ConstantR0(2.5f)); + Div(ConstantR0(&builder, 5.0f), ConstantR0(&builder, 2.5f)); ComputeAndCompareR0(&builder, 2.0f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Rem(builder.ConstantR0(2.5f), builder.ConstantR0(5.0f)); + Rem(ConstantR0(&builder, 2.5f), ConstantR0(&builder, 5.0f)); ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); } @@ -281,8 +279,8 @@ class DivS32Test : public ClientLibraryTestBase, XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { DivS32Params p = GetParam(); XlaBuilder builder(TestName()); - builder.Div(builder.ConstantR0(p.dividend), - builder.ConstantR0(p.divisor)); + Div(ConstantR0(&builder, p.dividend), + ConstantR0(&builder, p.divisor)); ComputeAndCompareR0(&builder, p.quotient, {}); } @@ -290,8 +288,8 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { DivS32Params p = GetParam(); XlaBuilder builder(TestName()); - builder.Rem(builder.ConstantR0(p.dividend), - builder.ConstantR0(p.divisor)); + Rem(ConstantR0(&builder, p.dividend), + ConstantR0(&builder, p.divisor)); ComputeAndCompareR0(&builder, p.remainder, {}); } @@ -305,7 +303,7 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) { CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); auto divisord = CreateR0Parameter(p.divisor, 1, "divisor", &builder, &divisor); - builder.Div(dividend, divisor); + Div(dividend, divisor); ComputeAndCompareR0(&builder, p.quotient, {dividendd.get(), divisord.get()}); @@ -320,7 +318,7 @@ XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) { CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); auto divisord = CreateR0Parameter(p.divisor, 1, "divisor", &builder, &divisor); - builder.Rem(dividend, divisor); + Rem(dividend, divisor); ComputeAndCompareR0(&builder, p.remainder, {dividendd.get(), divisord.get()}); @@ -367,10 +365,10 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { XlaBuilder builder(TestName()); XlaOp dividend = - builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); + Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend"); XlaOp divisor = - builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); - builder.Div(dividend, divisor); + Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor"); + Div(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build()); } @@ -408,10 +406,10 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { XlaBuilder builder(TestName()); XlaOp dividend = - builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); + Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend"); XlaOp divisor = - builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); - builder.Rem(dividend, divisor); + Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor"); + Rem(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build()); } @@ -439,8 +437,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); - builder.Rem(x, builder.ConstantR0(80000)); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); + Rem(x, ConstantR0(&builder, 80000)); std::unique_ptr literal = Literal::CreateR0(87919); TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); @@ -451,15 +449,15 @@ XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { XlaBuilder builder(TestName()); // This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32 // as S32, it would output -2 / 2 = -1 (0xFFFFFFFF). - builder.Div(builder.ConstantR0(0xFFFFFFFE), - builder.ConstantR0(2)); + Div(ConstantR0(&builder, 0xFFFFFFFE), + ConstantR0(&builder, 2)); ComputeAndCompareR0(&builder, 0x7FFFFFFF, {}); } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { XlaBuilder builder(TestName()); - builder.Rem(builder.ConstantR0(11), builder.ConstantR0(3)); + Rem(ConstantR0(&builder, 11), ConstantR0(&builder, 3)); ComputeAndCompareR0(&builder, 2, {}); } @@ -468,7 +466,7 @@ XLA_TEST_F(ScalarComputationsTest, AndBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { XlaBuilder builder(TestName()); - builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + And(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x && y, {}); } @@ -479,7 +477,7 @@ XLA_TEST_F(ScalarComputationsTest, AndS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { XlaBuilder builder(TestName()); - builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + And(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x & y, {}); } @@ -490,7 +488,7 @@ XLA_TEST_F(ScalarComputationsTest, AndU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { XlaBuilder builder(TestName()); - builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + And(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x & y, {}); } @@ -501,7 +499,7 @@ XLA_TEST_F(ScalarComputationsTest, OrBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { XlaBuilder builder(TestName()); - builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + Or(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x || y, {}); } @@ -512,7 +510,7 @@ XLA_TEST_F(ScalarComputationsTest, OrS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { XlaBuilder builder(TestName()); - builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + Or(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x | y, {}); } @@ -523,7 +521,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { XlaBuilder builder(TestName()); - builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + Or(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x | y, {}); } @@ -533,7 +531,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) { XLA_TEST_F(ScalarComputationsTest, NotBool) { for (bool x : {false, true}) { XlaBuilder builder(TestName()); - builder.Not(builder.ConstantR0(x)); + Not(ConstantR0(&builder, x)); ComputeAndCompareR0(&builder, !x, {}); } @@ -542,7 +540,7 @@ XLA_TEST_F(ScalarComputationsTest, NotBool) { XLA_TEST_F(ScalarComputationsTest, NotS32) { for (int32 x : {-1, 0, 1}) { XlaBuilder builder(TestName()); - builder.Not(builder.ConstantR0(x)); + Not(ConstantR0(&builder, x)); ComputeAndCompareR0(&builder, ~x, {}); } @@ -551,7 +549,7 @@ XLA_TEST_F(ScalarComputationsTest, NotS32) { XLA_TEST_F(ScalarComputationsTest, NotU32) { for (uint32 x : {0, 1, 2}) { XlaBuilder builder(TestName()); - builder.Not(builder.ConstantR0(x)); + Not(ConstantR0(&builder, x)); ComputeAndCompareR0(&builder, ~x, {}); } @@ -559,18 +557,18 @@ XLA_TEST_F(ScalarComputationsTest, NotU32) { XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { XlaBuilder builder(TestName()); - builder.Select(builder.ConstantR0(true), // The predicate. - builder.ConstantR0(123.0f), // The value on true. - builder.ConstantR0(42.0f)); // The value on false. + Select(ConstantR0(&builder, true), // The predicate. + ConstantR0(&builder, 123.0f), // The value on true. + ConstantR0(&builder, 42.0f)); // The value on false. ComputeAndCompareR0(&builder, 123.0f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { XlaBuilder builder(TestName()); - builder.Select(builder.ConstantR0(false), // The predicate. - builder.ConstantR0(123.0f), // The value on true. - builder.ConstantR0(42.0f)); // The value on false. + Select(ConstantR0(&builder, false), // The predicate. + ConstantR0(&builder, 123.0f), // The value on true. + ConstantR0(&builder, 42.0f)); // The value on false. ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); } @@ -579,313 +577,311 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { // templatized comparison tests. XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { XlaBuilder builder(TestName()); - builder.Gt(builder.ConstantR0(2.0f), builder.ConstantR0(1.0f)); + Gt(ConstantR0(&builder, 2.0f), ConstantR0(&builder, 1.0f)); ComputeAndCompareR0(&builder, true, {}); } // S32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) { - TestCompare(2, 1, false, &XlaBuilder::Eq); + TestCompare(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) { - TestCompare(3, 3, true, &XlaBuilder::Eq); + TestCompare(3, 3, true, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeS32) { - TestCompare(2, 1, true, &XlaBuilder::Ne); + TestCompare(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeS32) { - TestCompare(2, 1, true, &XlaBuilder::Ge); + TestCompare(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtS32) { - TestCompare(1, 5, false, &XlaBuilder::Gt); + TestCompare(1, 5, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeS32) { - TestCompare(2, 1, false, &XlaBuilder::Le); + TestCompare(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtS32) { - TestCompare(9, 7, false, &XlaBuilder::Lt); + TestCompare(9, 7, false, &Lt); TestCompare(std::numeric_limits::min(), - std::numeric_limits::max(), true, &XlaBuilder::Lt); + std::numeric_limits::max(), true, &Lt); } // U32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) { - TestCompare(2, 1, false, &XlaBuilder::Eq); + TestCompare(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeU32) { - TestCompare(2, 1, true, &XlaBuilder::Ne); + TestCompare(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) { - TestCompare(2, 1, true, &XlaBuilder::Ge); + TestCompare(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) { - TestCompare(3, 3, true, &XlaBuilder::Ge); + TestCompare(3, 3, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtU32) { - TestCompare(1, 5, false, &XlaBuilder::Gt); - TestCompare(5, 5, false, &XlaBuilder::Gt); - TestCompare(5, 1, true, &XlaBuilder::Gt); + TestCompare(1, 5, false, &Gt); + TestCompare(5, 5, false, &Gt); + TestCompare(5, 1, true, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeU32) { - TestCompare(2, 1, false, &XlaBuilder::Le); + TestCompare(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtU32) { - TestCompare(9, 7, false, &XlaBuilder::Lt); - TestCompare(0, std::numeric_limits::max(), true, - &XlaBuilder::Lt); + TestCompare(9, 7, false, &Lt); + TestCompare(0, std::numeric_limits::max(), true, &Lt); } // F32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) { - TestCompare(2.0, 1.3, false, &XlaBuilder::Eq); + TestCompare(2.0, 1.3, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeF32) { - TestCompare(2.0, 1.3, true, &XlaBuilder::Ne); + TestCompare(2.0, 1.3, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) { - TestCompare(2.0, 1.9, true, &XlaBuilder::Ge); + TestCompare(2.0, 1.9, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) { - TestCompare(3.5, 3.5, true, &XlaBuilder::Ge); + TestCompare(3.5, 3.5, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtF32) { - TestCompare(1.0, 5.2, false, &XlaBuilder::Gt); + TestCompare(1.0, 5.2, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeF32) { - TestCompare(2.0, 1.2, false, &XlaBuilder::Le); + TestCompare(2.0, 1.2, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32) { - TestCompare(9.0, 7.2, false, &XlaBuilder::Lt); + TestCompare(9.0, 7.2, false, &Lt); } // F32 comparisons with exceptional values. The test names encode the // left/right operands at the end, and use Minf and Mzero for -inf and -0.0. XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { - TestCompare(-INFINITY, -0.0, true, &XlaBuilder::Lt); + TestCompare(-INFINITY, -0.0, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, false, &XlaBuilder::Lt); + TestCompare(-0.0, 0.0, false, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { - TestCompare(0.0, INFINITY, true, &XlaBuilder::Lt); + TestCompare(0.0, INFINITY, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { - TestCompare(-INFINITY, -0.0, false, &XlaBuilder::Ge); + TestCompare(-INFINITY, -0.0, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, true, &XlaBuilder::Ge); + TestCompare(-0.0, 0.0, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) { - TestCompare(0.0, INFINITY, false, &XlaBuilder::Ge); + TestCompare(0.0, INFINITY, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, ExpScalar) { XlaBuilder builder(TestName()); - builder.Exp(builder.ConstantR0(2.0f)); + Exp(ConstantR0(&builder, 2.0f)); ComputeAndCompareR0(&builder, 7.3890562, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, LogScalar) { XlaBuilder builder("log"); - builder.Log(builder.ConstantR0(2.0f)); + Log(ConstantR0(&builder, 2.0f)); ComputeAndCompareR0(&builder, 0.6931471, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhScalar) { XlaBuilder builder(TestName()); - builder.Tanh(builder.ConstantR0(2.0f)); + Tanh(ConstantR0(&builder, 2.0f)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) { XlaBuilder builder(TestName()); - builder.Tanh(builder.ConstantR0(2.0)); + Tanh(ConstantR0(&builder, 2.0)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, PowScalar) { XlaBuilder builder(TestName()); - builder.Pow(builder.ConstantR0(2.0f), builder.ConstantR0(3.0f)); + Pow(ConstantR0(&builder, 2.0f), ConstantR0(&builder, 3.0f)); ComputeAndCompareR0(&builder, 8.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(-1), // The lower bound. - builder.ConstantR0(5), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, -1), // The lower bound. + ConstantR0(&builder, 5), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 3, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(-1), // The lower bound. - builder.ConstantR0(2), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, -1), // The lower bound. + ConstantR0(&builder, 2), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 2, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(-1), // The lower bound. - builder.ConstantR0(-5), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, -1), // The lower bound. + ConstantR0(&builder, -5), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, -1, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(1), // The lower bound. - builder.ConstantR0(5), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, 1), // The lower bound. + ConstantR0(&builder, 5), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 3, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(1), // The lower bound. - builder.ConstantR0(2), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, 1), // The lower bound. + ConstantR0(&builder, 2), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 2, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(1), // The lower bound. - builder.ConstantR0(0), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, 1), // The lower bound. + ConstantR0(&builder, 0), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 1, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. - builder.ConstantR0(5.0f), // The operand to be clamped. - builder.ConstantR0(3.0f)); // The upper bound. + Clamp(ConstantR0(&builder, 2.0f), // The lower bound. + ConstantR0(&builder, 5.0f), // The operand to be clamped. + ConstantR0(&builder, 3.0f)); // The upper bound. ComputeAndCompareR0(&builder, 3.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. - builder.ConstantR0(2.5f), // The operand to be clamped. - builder.ConstantR0(3.0f)); // The upper bound. + Clamp(ConstantR0(&builder, 2.0f), // The lower bound. + ConstantR0(&builder, 2.5f), // The operand to be clamped. + ConstantR0(&builder, 3.0f)); // The upper bound. ComputeAndCompareR0(&builder, 2.5, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. - builder.ConstantR0(-5.0f), // The operand to be clamped. - builder.ConstantR0(3.0f)); // The upper bound. + Clamp(ConstantR0(&builder, 2.0f), // The lower bound. + ConstantR0(&builder, -5.0f), // The operand to be clamped. + ConstantR0(&builder, 3.0f)); // The upper bound. ComputeAndCompareR0(&builder, 2.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, MinS32Above) { - TestMinMax(10, 3, 3, &XlaBuilder::Min); + TestMinMax(10, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinS32Below) { - TestMinMax(-100, 3, -100, &XlaBuilder::Min); + TestMinMax(-100, 3, -100, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxS32Above) { - TestMinMax(10, 3, 10, &XlaBuilder::Max); + TestMinMax(10, 3, 10, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxS32Below) { - TestMinMax(-100, 3, 3, &XlaBuilder::Max); + TestMinMax(-100, 3, 3, &Max); } XLA_TEST_F(ScalarComputationsTest, MinU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, 3, &XlaBuilder::Min); + TestMinMax(large, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinU32Below) { - TestMinMax(0, 5, 0, &XlaBuilder::Min); + TestMinMax(0, 5, 0, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, large, &XlaBuilder::Max); + TestMinMax(large, 3, large, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxU32Below) { - TestMinMax(0, 5, 5, &XlaBuilder::Max); + TestMinMax(0, 5, 5, &Max); } XLA_TEST_F(ScalarComputationsTest, MinF32Above) { - TestMinMax(10.1f, 3.1f, 3.1f, &XlaBuilder::Min); + TestMinMax(10.1f, 3.1f, 3.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinF32Below) { - TestMinMax(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min); + TestMinMax(-100.1f, 3.1f, -100.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Min); - TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Min); + TestMinMax(NAN, 3.1f, NAN, &Min); + TestMinMax(-3.1f, NAN, NAN, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { - TestMinMax(10.1f, 3.1f, 10.1f, &XlaBuilder::Max); + TestMinMax(10.1f, 3.1f, 10.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { - TestMinMax(-100.1f, 3.1f, 3.1f, &XlaBuilder::Max); + TestMinMax(-100.1f, 3.1f, 3.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Max); - TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Max); + TestMinMax(NAN, 3.1f, NAN, &Max); + TestMinMax(-3.1f, NAN, NAN, &Max); } XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20. XlaBuilder b(TestName()); - b.Div( - b.Sub(b.Mul(b.ConstantR0(1), - b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), - b.Add(b.ConstantR0(7), b.ConstantR0(0)))), - b.ConstantR0(4)), - b.ConstantR0(20)); + Div(Sub(Mul(ConstantR0(&b, 1), + Mul(Sub(ConstantR0(&b, 3), ConstantR0(&b, 1)), + Add(ConstantR0(&b, 7), ConstantR0(&b, 0)))), + ConstantR0(&b, 4)), + ConstantR0(&b, 20)); ComputeAndCompareR0(&b, 0.5, {}, error_spec_); } @@ -893,30 +889,18 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { // Compute the expression 1 * (3 - 1) * (7 + 0) - 4. XlaBuilder b(TestName()); - b.Sub(b.Mul(b.ConstantR0(1), - b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), - b.Add(b.ConstantR0(7), b.ConstantR0(0)))), - b.ConstantR0(4)); + Sub(Mul(ConstantR0(&b, 1), + Mul(Sub(ConstantR0(&b, 3), ConstantR0(&b, 1)), + Add(ConstantR0(&b, 7), ConstantR0(&b, 0)))), + ConstantR0(&b, 4)); ComputeAndCompareR0(&b, 10, {}); } -XLA_TEST_F(ScalarComputationsTest, SqrtF320) { - XlaBuilder builder(TestName()); - Literal zero_literal = Literal::Zero(PrimitiveType::F32); - - std::unique_ptr zero_data = - client_->TransferToServer(zero_literal).ConsumeValueOrDie(); - - XlaOp zero = builder.Parameter(0, zero_literal.shape(), "zero"); - builder.SqrtF32(zero); - - ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); -} XLA_TEST_F(ScalarComputationsTest, RoundScalar) { XlaBuilder builder(TestName()); - builder.Round(builder.ConstantR0(1.4f)); + Round(ConstantR0(&builder, 1.4f)); ComputeAndCompareR0(&builder, 1.0f, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 7015e5a6a31f506d30c2629d7735482cf354455a..0a173fbbbd5cb5e5005728331561008b8b29af26 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -73,16 +73,16 @@ XLA_TEST_P(SelectAndScatterTest, ParamTest) { auto operand_shape = GetParam().operand_shape; Array o(operand_shape); o.FillRandom(1.5f); - auto operand = builder_.ConstantFromArray(o); + auto operand = ConstantFromArray(&builder_, o); auto source_shape = GetParam().source_shape; Array s(source_shape); s.FillRandom(12.0f); - auto source = builder_.ConstantFromArray(s); + auto source = ConstantFromArray(&builder_, s); - builder_.SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions, - GetParam().window_strides, GetParam().padding_type, - source, builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions, + GetParam().window_strides, GetParam().padding_type, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5)); } @@ -197,110 +197,110 @@ INSTANTIATE_TEST_CASE_P( // Test for F32 1D array, with a zero-element input. XLA_TEST_F(SelectAndScatterTest, R1S0F32) { - const auto operand = builder_.ConstantR1({}); - const auto source = builder_.ConstantR1({}); - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, - /*window_strides=*/{3}, Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + const auto operand = ConstantR1(&builder_, {}); + const auto source = ConstantR1(&builder_, {}); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR1(&builder_, {}, {}, ErrorSpec(1e-7)); } // Test for F32 1D array, when windows do not overlap. XLA_TEST_F(SelectAndScatterTest, R1F32) { const auto operand = - builder_.ConstantR1({1.f, 9.f, 3.f, 7.f, 5.f, 6.f}); - const auto source = builder_.ConstantR1({34.f, 42.f}); + ConstantR1(&builder_, {1.f, 9.f, 3.f, 7.f, 5.f, 6.f}); + const auto source = ConstantR1(&builder_, {34.f, 42.f}); const std::vector expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f}; - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, - /*window_strides=*/{3}, Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); } // Test for S32 1D array, when windows do not overlap and the init value is 1. XLA_TEST_F(SelectAndScatterTest, R1S32) { - const auto operand = builder_.ConstantR1({-1, 0, 6, 4, -4, 10}); - const auto source = builder_.ConstantR1({-10, 20}); + const auto operand = ConstantR1(&builder_, {-1, 0, 6, 4, -4, 10}); + const auto source = ConstantR1(&builder_, {-10, 20}); const std::vector expected = {1, 1, -9, 1, 1, 21}; - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, - /*window_strides=*/{3}, Padding::kValid, source, - builder_.ConstantR0(1), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + ConstantR0(&builder_, 1), add_s32_); ComputeAndCompareR1(&builder_, expected, {}); } // Test for S32 1D array, when windows overlap with each other. XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) { - const auto operand = builder_.ConstantR1({1, 9, 3, 7, 5, 6}); - const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const auto operand = ConstantR1(&builder_, {1, 9, 3, 7, 5, 6}); + const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const std::vector expected = {0, 76, 0, 72, 0, 0}; - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, - /*window_strides=*/{1}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, + /*window_strides=*/{1}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR1(&builder_, expected, {}); } // Test for S32 2D array, when windows do not overlap. XLA_TEST_F(SelectAndScatterTest, R2S32) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}}); - const auto source = builder_.ConstantR2({{2, 6}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}}); + const auto source = ConstantR2(&builder_, {{2, 6}}); Array2D expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, - /*window_strides=*/{2, 3}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{2, 3}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } // Test for tie breaking rule in ge_f32_. When a tie is present, the operand // that has the lower lexicographical order (smaller index) should be chosen. XLA_TEST_F(SelectAndScatterTest, R2F32Tie) { - const auto operand = builder_.ConstantR2( - {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}}); - const auto source = builder_.ConstantR2( - {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); + const auto operand = ConstantR2( + &builder_, {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}}); + const auto source = ConstantR2( + &builder_, {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); Array2D expected( {{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}}); - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3}, - /*window_strides=*/{1, 1}, Padding::kSame, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(1e-7)); } // Similar to SelectAndScatterTest.R2S32 but the input is transposed. XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) { - const auto operand = builder_.ConstantR2( - {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}}); + const auto operand = ConstantR2( + &builder_, {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}}); const auto reshape = - builder_.Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); - const auto source = builder_.ConstantR2({{2, 6}}); + Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); + const auto source = ConstantR2(&builder_, {{2, 6}}); Array2D expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}); - builder_.SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3}, - /*window_strides=*/{2, 3}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{2, 3}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } // Test for S32 2D array, when windows overlap with each other. XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); - const auto source = builder_.ConstantR2({{2, 6, 4}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = ConstantR2(&builder_, {{2, 6, 4}}); Array2D expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, - /*window_strides=*/{1, 1}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{1, 1}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } // Test for S32 2D array, when the padding is Padding::kSAME. XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); - const auto source = builder_.ConstantR2({{2, 6, 4}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = ConstantR2(&builder_, {{2, 6, 4}}); Array2D expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, - /*window_strides=*/{2, 2}, Padding::kSame, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{2, 2}, Padding::kSame, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } @@ -308,25 +308,26 @@ XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) { // with each other. XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); const auto source = - builder_.ConstantR2({{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}}); + ConstantR2(&builder_, {{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}}); Array2D expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, - /*window_strides=*/{1, 1}, Padding::kSame, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) { - const auto operand = builder_.ConstantR2( - {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}}); - const auto source = builder_.ConstantR2({{1.0f, 2.0f}, {3.0f, 4.0f}}); + const auto operand = ConstantR2( + &builder_, {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}}); + const auto source = + ConstantR2(&builder_, {{1.0f, 2.0f}, {3.0f, 4.0f}}); Array2D expected( {{0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}}); - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2}, - /*window_strides=*/{1, 1}, Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kValid, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(1e-7)); } @@ -342,16 +343,16 @@ TEST_F(SelectAndScatterTest, R4F32Valid) { {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}}; Array4D o(4, 6, 15, 220); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D e(4, 6, 15, 220); e.FillWithPZ(pze); Array4D s(2, 2, 15, 220); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); } @@ -367,16 +368,16 @@ TEST_F(SelectAndScatterTest, R4F32Overlap) { {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}}; Array4D o(4, 5, 17, 128); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D e(4, 5, 17, 128); e.FillWithPZ(pze); Array4D s(2, 2, 17, 128); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); } @@ -392,16 +393,16 @@ TEST_F(SelectAndScatterTest, R4F32OverlapSmall) { {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}}; Array4D o(4, 5, 1, 1); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D e(4, 5, 1, 1); e.FillWithPZ(pze); Array4D s(2, 2, 1, 1); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); } @@ -414,39 +415,39 @@ TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) { Array2D pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; Array4D o(4, 6, 4, 4); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D s(2, 2, 4, 4); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1}, {2, 3, 1, 1}, false); ComputeAndCompareR4(&builder_, *e, {}, ErrorSpec(1e-7)); } XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) { - const auto operand = builder_.ConstantR1({1, 2, 3, 100, 3, 2, 1}); - const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const auto operand = ConstantR1(&builder_, {1, 2, 3, 100, 3, 2, 1}); + const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const std::vector expected = {0, 0, 0, 53, 0, 0, 0}; - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, - /*window_strides=*/{1}, Padding::kValid, source, - builder_.ConstantR0(0), max_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, + /*window_strides=*/{1}, Padding::kValid, source, + ConstantR0(&builder_, 0), max_f32_); ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); } XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { - const auto operand = builder_.ConstantR1({1, 2, 3, 100, 3, 2, 1}); - const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const auto operand = ConstantR1(&builder_, {1, 2, 3, 100, 3, 2, 1}); + const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const float max_float = std::numeric_limits::max(); const std::vector expected = {max_float, max_float, max_float, 19, max_float, max_float, max_float}; - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, - /*window_strides=*/{1}, Padding::kValid, source, - builder_.ConstantR0(max_float), min_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, + /*window_strides=*/{1}, Padding::kValid, source, + ConstantR0(&builder_, max_float), min_f32_); ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); } diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 72707f224446c7585d1d90ac6681a7b38c41d5f1..59409ab26e1c19a8271318c18e19caa7b8ddc3b7 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -35,50 +35,52 @@ class SelectTest : public ClientLibraryTestBase { TEST_F(SelectTest, SelectScalarF32True) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto on_true = builder.ConstantR0(123.0f); - auto on_false = builder.ConstantR0(42.0f); - auto result = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, true); + auto on_true = ConstantR0(&builder, 123.0f); + auto on_false = ConstantR0(&builder, 42.0f); + Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, 123.0f, {}, error_spec_); } TEST_F(SelectTest, SelectScalarS32True) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto on_true = builder.ConstantR0(-42); - auto on_false = builder.ConstantR0(42); - auto result = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, true); + auto on_true = ConstantR0(&builder, -42); + auto on_false = ConstantR0(&builder, 42); + Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, -42, {}); } TEST_F(SelectTest, SelectScalarF32False) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto on_true = builder.ConstantR0(123.0f); - auto on_false = builder.ConstantR0(42.0f); - auto result = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, false); + auto on_true = ConstantR0(&builder, 123.0f); + auto on_false = ConstantR0(&builder, 42.0f); + Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); } XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR1({}); - auto on_true = builder.ConstantR1({}); - auto on_false = builder.ConstantR1({}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR1(&builder, {}); + auto on_true = ConstantR1(&builder, {}); + auto on_false = ConstantR1(&builder, {}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR1({false, true, false, true, false}); - auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR1(&builder, {false, true, false, true, false}); + auto on_true = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = + ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, error_spec_); @@ -88,12 +90,12 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector // is not a constant, but rather the result of comparing two other vectors. XlaBuilder builder(TestName()); - auto v1 = builder.ConstantR1({}); - auto v2 = builder.ConstantR1({}); - auto cmp = builder.Eq(v1, v2); - auto on_true = builder.ConstantR1({}); - auto on_false = builder.ConstantR1({}); - auto select = builder.Select(cmp, on_true, on_false); + auto v1 = ConstantR1(&builder, {}); + auto v2 = ConstantR1(&builder, {}); + auto cmp = Eq(v1, v2); + auto on_true = ConstantR1(&builder, {}); + auto on_false = ConstantR1(&builder, {}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -102,12 +104,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is // not a constant, but rather the result of comparing two other vectors. XlaBuilder builder(TestName()); - auto v1 = builder.ConstantR1({1, 2, 3, 4, 5}); - auto v2 = builder.ConstantR1({9, 2, 9, 4, 9}); - auto cmp = builder.Eq(v1, v2); - auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(cmp, on_true, on_false); + auto v1 = ConstantR1(&builder, {1, 2, 3, 4, 5}); + auto v2 = ConstantR1(&builder, {9, 2, 9, 4, 9}); + auto cmp = Eq(v1, v2); + auto on_true = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = + ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, error_spec_); @@ -116,12 +120,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s. XlaBuilder builder(TestName()); - auto v1 = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - auto v2 = builder.ConstantR1({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); - auto cmp = builder.Gt(v1, v2); - auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(cmp, on_true, on_false); + auto v1 = ConstantR1(&builder, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto v2 = ConstantR1(&builder, {-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); + auto cmp = Gt(v1, v2); + auto on_true = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = + ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {}, error_spec_); @@ -140,8 +146,8 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto cmp = builder.Gt(v1, v2); - auto select = builder.Select(cmp, v1, v2); + auto cmp = Gt(v1, v2); + Select(cmp, v1, v2); ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, {param0_data.get(), param1_data.get()}, error_spec_); @@ -181,8 +187,8 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto cmp = builder.Gt(v1, v2); - auto select = builder.Select(cmp, v1, v2); + auto cmp = Gt(v1, v2); + Select(cmp, v1, v2); ComputeAndCompareR1(&builder, expected_vec, {param0_data.get(), param1_data.get()}, error_spec_); @@ -192,14 +198,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to // select between two R1F32s. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, -1, 2, -2}); - auto s = builder.ConstantR0(0); - auto cmp = builder.Gt(v, s); + auto v = ConstantR1(&builder, {1, -1, 2, -2}); + auto s = ConstantR0(&builder, 0); + auto cmp = Gt(v, s); - auto on_true = builder.ConstantR1({11.0f, 22.0f, 33.0f, 44.0f}); + auto on_true = ConstantR1(&builder, {11.0f, 22.0f, 33.0f, 44.0f}); auto on_false = - builder.ConstantR1({-111.0f, -222.0f, -333.0f, -444.0f}); - auto select = builder.Select(cmp, on_true, on_false); + ConstantR1(&builder, {-111.0f, -222.0f, -333.0f, -444.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {}, error_spec_); @@ -209,14 +215,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to // select between two R1F32s. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f}); - auto s = builder.ConstantR0(2.5f); - auto cmp = builder.Gt(v, s); + auto v = ConstantR1(&builder, {1.0f, 2.0f, 3.0f, 4.0f}); + auto s = ConstantR0(&builder, 2.5f); + auto cmp = Gt(v, s); - auto on_true = builder.ConstantR1({11.0f, 22.0f, 33.0f, 44.0f}); + auto on_true = ConstantR1(&builder, {11.0f, 22.0f, 33.0f, 44.0f}); auto on_false = - builder.ConstantR1({-111.0f, -222.0f, -333.0f, -444.0f}); - auto select = builder.Select(cmp, on_true, on_false); + ConstantR1(&builder, {-111.0f, -222.0f, -333.0f, -444.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {}, error_spec_); @@ -225,10 +231,10 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { for (bool which : {false, true}) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(which); - auto on_true = builder.ConstantR1({}); - auto on_false = builder.ConstantR1({}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, which); + auto on_true = ConstantR1(&builder, {}); + auto on_false = ConstantR1(&builder, {}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -236,20 +242,20 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto on_true = builder.ConstantR1({-2.5f, 25.5f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, true); + auto on_true = ConstantR1(&builder, {-2.5f, 25.5f}); + auto on_false = ConstantR1(&builder, {10.0f, 5.0f}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {-2.5f, 25.5f}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto on_true = builder.ConstantR1({-2.5f, 25.5f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, false); + auto on_true = ConstantR1(&builder, {-2.5f, 25.5f}); + auto on_false = ConstantR1(&builder, {10.0f, 5.0f}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 5.0f}, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 52195db2aa74710b901dd7744a670764a034e96b..3e5c01d6d47cc3f3b7d46ce300fe26c5ec9e63fa 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -42,8 +42,8 @@ TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { values.FillIota(0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR3FromArray3D(values); - builder.Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, values); + Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); Array3D expected{ {{0.0}, {3.0}, {6.0}}, {{9.0}, {12.0}, {15.0}}, {{18.0}, {21.0}, {24.0}}}; @@ -55,8 +55,8 @@ TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) { values.FillIota(0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR3FromArray3D(values); - builder.Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, values); + Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); Array3D expected{ {{0.0, 1.0, 2.0}}, {{9.0, 10.0, 11.0}}, {{18.0, 19.0, 20.0}}}; @@ -68,8 +68,8 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { values.FillIota(0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR3FromArray3D(values); - builder.Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, values); + Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); Array3D expected{ {{{0.0, 1.0, 2.0}, {3.0, 4.0, 5.0}, {6.0, 7.0, 8.0}}}}; @@ -78,24 +78,24 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(Array2D(0, 0)); - builder.Slice(original, {0, 0}, {0, 0}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + Slice(original, {0, 0}, {0, 0}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(0, 0), {}); } XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(Array2D(0, 20)); - builder.Slice(original, {0, 15}, {0, 20}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, Array2D(0, 20)); + Slice(original, {0, 15}, {0, 20}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(0, 5), {}); } XLA_TEST_F(SliceTest, Slice3x0to2x0F32) { XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(Array2D(3, 0)); - builder.Slice(original, {1, 0}, {3, 0}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, Array2D(3, 0)); + Slice(original, {1, 0}, {3, 0}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(2, 0), {}); } @@ -109,8 +109,8 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) { } XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {128, 128}, {256, 256}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, values); + Slice(original, {128, 128}, {256, 256}, {1, 1}); Array2D expected(128, 128); for (int row = 0; row < 128; ++row) { @@ -127,8 +127,8 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) { std::iota(values.data(), values.data() + 4096, 0.0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, values); + Slice(original, {0, 3072}, {1, 4096}, {1, 1}); Array2D expected(1, 1024); std::iota(expected.data(), expected.data() + 1024, 3072.0); @@ -148,8 +148,8 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { } } XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {0, 0}, {16, 2}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, values); + Slice(original, {0, 0}, {16, 2}, {1, 1}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); } @@ -160,8 +160,8 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { auto expected = ReferenceUtil::Slice4D( values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}}); XlaBuilder builder(TestName()); - auto original = builder.ConstantR4FromArray4D(values); - builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); + auto original = ConstantR4FromArray4D(&builder, values); + Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } @@ -173,8 +173,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { auto expected_literal = Literal::CreateR4FromArray4DWithLayout( *expected, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); - auto original = builder.ConstantR4FromArray4D(values); - builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); + auto original = ConstantR4FromArray4D(&builder, values); + Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), &expected_literal->shape()); } @@ -197,11 +197,12 @@ class SliceR1Test : public ClientLibraryTestBase, // vector. tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); + auto literal = Literal::CreateR1(input); XlaBuilder builder(TestName()); - auto original = builder.ConstantR1(input); - builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, - {spec.slice_stride}); + auto original = Parameter(&builder, 0, literal->shape(), "p0"); + Slice(original, {spec.slice_start}, {spec.slice_limit}, + {spec.slice_stride}); // Ditto. tensorflow::gtl::InlinedVector expected; @@ -210,7 +211,9 @@ class SliceR1Test : public ClientLibraryTestBase, expected.push_back(i); } - ComputeAndCompareR1(&builder, expected, {}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); + ComputeAndCompareR1(&builder, expected, {arg.get()}); } }; @@ -365,15 +368,18 @@ XLA_TEST_P(SliceR2Test, DoIt) { const R2Spec& spec = GetParam(); Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); + auto literal = Literal::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(spec.layout)); - builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); + auto a = Parameter(&builder, 0, literal->shape(), "p0"); + Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputeAndCompareR2(&builder, *expected, {}); + ComputeAndCompareR2(&builder, *expected, {arg.get()}); } INSTANTIATE_TEST_CASE_P( @@ -453,17 +459,16 @@ class SliceR4Test : public ClientLibraryTestBase, void Run(const R4Spec& spec) { Array4D values(spec.input_dims[0], spec.input_dims[1], spec.input_dims[2], spec.input_dims[3]); - values.FillRandom(3.14f); + values.FillIota(3.14159); auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); XlaBuilder builder(TestName()); auto literal = Literal::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); - auto parameter = builder.Parameter(0, literal->shape(), "p0"); + auto parameter = Parameter(&builder, 0, literal->shape(), "p0"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, client_->TransferToServer(*literal)); - builder.Slice(parameter, spec.slice_starts, spec.slice_limits, - spec.slice_strides); + Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001)); } }; diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 810cc25f1b5b1199984a3229909a70f9548c7dd2..20c7c30878a2821915d47bcf9fa1cc53907df9da 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -26,6 +26,7 @@ namespace { template void PopulateWithRandomFloatingPointDataImpl(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal @@ -59,12 +60,14 @@ void PopulateWithRandomFloatingPointDataImpl(Literal* literal, template void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl(literal, engine); } template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl(literal, engine); } @@ -73,6 +76,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal, template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), BF16); std::uniform_real_distribution generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate( @@ -84,6 +88,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal, template void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); std::uniform_int_distribution generator( @@ -107,7 +112,10 @@ StatusOr> MakeFakeLiteralInternal( } return Literal::MakeTupleOwned(std::move(elements)); } - std::unique_ptr literal = Literal::CreateFromShape(shape); + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } + auto literal = MakeUnique(shape); switch (shape.element_type()) { case BF16: PopulateWithRandomFloatingPointData(literal.get(), engine); @@ -153,6 +161,9 @@ StatusOr> MakeFakeLiteralInternal( })); break; } + // Token requires no data. + case TOKEN: + break; default: return Unimplemented("Unsupported type for fake literal generation: %s", ShapeUtil::HumanString(shape).c_str()); @@ -201,11 +212,13 @@ std::unique_ptr MakeRandomNonwrappingSliceIndex( std::minstd_rand0* engine) { const int64 rank = ShapeUtil::Rank(input_shape); std::vector start_indices(rank); - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution generator(0, upper_bound); - start_indices[i] = generator(*engine); + if (engine != nullptr) { + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(*engine); + } } return Literal::CreateR1(start_indices); } @@ -260,14 +273,22 @@ StatusOr> CreateLiteralForConstrainedUses( switch (use->opcode()) { case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr && - !ShapeUtil::Equal(needs_index->shape(), use->shape())) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + if (needs_index != nullptr) { + auto needs_index_shape = needs_index->shape(); + auto use_shape = use->shape(); + if (needs_index->opcode() == HloOpcode::kDynamicSlice) { + needs_index_shape = needs_index->operand(0)->shape(); + } + if (use->opcode() == HloOpcode::kDynamicSlice) { + use_shape = use->operand(0)->shape(); + } + if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { + return Unimplemented( + "Conflicting operand generation slice index constraints\n"); + } } needs_index = use; break; - case HloOpcode::kReduce: case HloOpcode::kReduceWindow: needs_constant = use; @@ -321,20 +342,21 @@ StatusOr> MakeConstrainedArgument( } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape) { - std::minstd_rand0 engine; - return MakeFakeLiteralInternal(shape, &engine); +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random) { + auto engine = pseudo_random ? MakeUnique() : nullptr; + return MakeFakeLiteralInternal(shape, engine.get()); } StatusOr>> MakeFakeArguments( - HloModule* const module) { + HloModule* const module, bool pseudo_random) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::minstd_rand0 engine; + auto engine = pseudo_random ? MakeUnique() : nullptr; std::vector> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN( - arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); + TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( + *dataflow, *params[i], engine.get())); } return std::move(arguments); } diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index f483cdebea5c7c8a43e73ab57748a93c97bb78d7..a8689f64981569ceb7c8a712f8ece00c99e8cf2d 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -55,16 +55,28 @@ class PseudorandomGenerator { }; // Generates fake data in a literal of the given shape, or returns an error -// status if the element type is currently unhandled for fake data generation. -StatusOr> MakeFakeLiteral(const Shape& shape); +// status if the element type is currently unhandled for fake data +// generation. See below for documentation of pseudo_random. +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. // // Will handle special cases such as making sure that indices used for dynamic // slices are bounded, reduces that call adds use 0 as an init value, etc. +// +// If pseudo_random is true, the generated numbers will be generated +// deterministically in a pseudo random way unless the values are constrated to +// be e.g. init values as above. If pseudo_random is false, the returned values +// will be generated in a faster way that yields less interesting data, e.g. the +// values may all be just the same value. +// +// TODO(b/79942829): Make interesting argument generation fast enough that using +// pseudo_random does not save any noticeable amount of time so that the +// parameter can be removed. StatusOr>> MakeFakeArguments( - HloModule* const module); + HloModule* const module, bool pseudo_random = true); // Check that a given module satisfies various constraints before trying to // execute it. diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 59afd28a80c0fbf3df38457cd05961c883769856..8f424ae81f592bfd8accd8decb8fc363f7561c73 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -31,16 +32,16 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) { XlaBuilder builder(TestName()); // Make the reduction lambda. Shape single_float = ShapeUtil::MakeShape(F32, {}); - builder.Parameter(0, single_float, "unused"); - builder.Parameter(1, single_float, "used"); + Parameter(&builder, 0, single_float, "unused"); + Parameter(&builder, 1, single_float, "used"); auto computation_status = builder.Build(); TF_ASSERT_OK(computation_status.status()); // Make the reduction. Shape pair_float = ShapeUtil::MakeShape(F32, {2}); - builder.Reduce(builder.Parameter(0, pair_float, "operand"), - builder.Parameter(1, single_float, "init"), - computation_status.ValueOrDie(), {0}); + Reduce(Parameter(&builder, 0, pair_float, "operand"), + Parameter(&builder, 1, single_float, "init"), + computation_status.ValueOrDie(), {0}); computation_status = builder.Build(); TF_ASSERT_OK(computation_status.status()); @@ -53,5 +54,23 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) { TF_ASSERT_OK(MakeFakeArguments(&module).status()); } +XLA_TEST_F(TestUtilsTest, Token) { + auto module = ParseHloString( + R"(HloModule outfeed_module + + ENTRY InfeedToOutfeed { + token = token[] parameter(0) + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 + infeed.1.token = token[] get-tuple-element(infeed.1), index=1 + outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) + })") + .ValueOrDie(); + TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9008fa48aa7d0158bd2221791be23c128859098 --- /dev/null +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class TokenHloTest : public HloTestBase {}; + +XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateAfterAll({})); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken())); +} + +XLA_TEST_F(TokenHloTest, TokenTree) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto token0 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token1 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token2 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + builder.AddInstruction( + HloInstruction::CreateAfterAll({token0, token0, token1, token2})); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken())); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 1 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}), + "param")); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction(HloInstruction::CreateAfterAll({param})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr( + "Operands of token instructions must be TOKEN types")); +} + +XLA_TEST_F(TokenHloTest, TokenInWhileLoop) { + // Thread a token around a while loop. Token is created and consumed by a + // AfterAll instruction in the while body. + string module_string = R"( +HloModule TokenInWhileLoop + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %TokenInWhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + DebugOptions debug_options = GetDebugOptionsForTest(); + // Module DCE pass removes the generate token instructions. + debug_options.add_xla_disable_hlo_passes("hlo-module-dce"); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + + EXPECT_TRUE(RunAndCompare(std::move(module), error_spec_)); +} + +XLA_TEST_F(TokenHloTest, TokenInConditional) { + string module_string = R"( +HloModule TokenInConditional + +%True (param.1: token[]) -> (s32[], token[]) { + %param.1 = token[] parameter(0) + %forty_two = s32[] constant(42) + ROOT %tuple = (s32[], token[]) tuple(s32[] %forty_two, token[] %param.1) +} + +%False (param.2: s32[]) -> (s32[], token[]) { + %param.2 = s32[] parameter(0) + %new_token = token[] after-all() + ROOT %tuple = (s32[], token[]) tuple(s32[] %param.2, token[] %new_token) +} + +ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { + %param.3 = pred[] parameter(0) + %init_token = token[] after-all() + %seven = s32[] constant(7) + %cond = (s32[], token[]) conditional(pred[] %param.3, token[] %init_token, s32[] %seven), true_computation=True, false_computation=False + ROOT %root = s32[] get-tuple-element((s32[], token[]) %cond), index=0 +} +)"; + + DebugOptions debug_options = GetDebugOptionsForTest(); + // Module DCE pass removes the generate token instructions. + debug_options.add_xla_disable_hlo_passes("hlo-module-dce"); + + { + // True case. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + auto arg = Literal::CreateR0(true); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {arg.get()})); + EXPECT_EQ(42, result->Get({})); + } + + { + // False case. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + auto arg = Literal::CreateR0(false); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {arg.get()})); + EXPECT_EQ(7, result->Get({})); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 0063e7ad415e9b6718c164f415ced6fb76cbf44a..86babb58c9d4515935a5904e04e8fea1074a2812 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -41,7 +42,12 @@ class TransferManagerTest : public LocalClientTestBase { TransferManagerTest() : shape_size_fn_([this](const Shape& shape) { return transfer_manager_->GetByteSizeRequirement(shape); - }) {} + }) { + stream_ptr_ = local_client_->mutable_backend() + ->BorrowStream(stream_executor_) + .ValueOrDie(); + stream_ = stream_ptr_.get(); + } ~TransferManagerTest() override = default; @@ -53,6 +59,10 @@ class TransferManagerTest : public LocalClientTestBase { .ValueOrDie(); } + protected: + Backend::StreamPtr stream_ptr_; + se::Stream* stream_; + private: std::function shape_size_fn_; }; @@ -63,11 +73,11 @@ XLA_TEST_F(TransferManagerTest, TransferR0U32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR0Equal(42, *result); } @@ -79,11 +89,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, *result); @@ -97,11 +107,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal(test_vector, *result); } @@ -113,11 +123,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_EQ(result->GetR1U8AsString(), test_string); } @@ -129,11 +139,11 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); @@ -149,11 +159,11 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_FALSE( LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); @@ -169,11 +179,11 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -183,11 +193,11 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -203,11 +213,11 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -218,11 +228,11 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -237,14 +247,162 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } +XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { + // "Copy" a token from the device. The token has no physical representation so + // no copying is actually performed, but it shouldn't fail. + // TODO(b/110532604): Add transferring the token to device when this is + // supported. + auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); + EXPECT_TRUE(LiteralTestUtil::Equal(*Literal::CreateToken(), *result)); +} + +XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { + const int64 kIterationCount = 5000; + std::unique_ptr literal1 = Literal::MakeTuple( + {Literal::CreateR0(123.0f).get(), + Literal::MakeTuple( + {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + .get(), + Literal::CreateR1({-10.0f, 123.0f}).get()}); + std::unique_ptr literal2 = Literal::MakeTuple( + {Literal::CreateR0(456.0f).get(), + Literal::MakeTuple( + {Literal::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), + Literal::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) + .get(), + Literal::CreateR1({-98.0f, 153.0f}).get()}); + + auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); + auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); + + auto stream1 = stream_; + auto stream2 = stream_->GetOrCreateSubStream(); + + std::unique_ptr result1, result2; + + // Round trip literals through device in multiple streams asynchronously. + for (int i = 0; i < kIterationCount; ++i) { + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1, + device_buffer1)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2, + device_buffer2)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr this_result1, + transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr this_result2, + transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2)); + result1 = std::move(this_result1); + result2 = std::move(this_result2); + } + + EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1)); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2)); +} + +class TransferDeviceToHostBenchmark : public TransferManagerTest { + public: + using TransferManagerTest::TransferManagerTest; + ~TransferDeviceToHostBenchmark() override {} + + void Run(int iters, int num_tuple_elements, int array_size) { + tensorflow::testing::StopTiming(); + SetUp(); + + std::vector> tuple_elements; + for (int i = 0; i < num_tuple_elements; ++i) { + tuple_elements.push_back( + Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + } + std::unique_ptr literal = + Literal::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); + } + tensorflow::testing::StopTiming(); + TearDown(); + } + + void TestBody() override {} +}; + +class TransferHostToDeviceBenchmark : public TransferManagerTest { + public: + using TransferManagerTest::TransferManagerTest; + ~TransferHostToDeviceBenchmark() override {} + + void Run(int iters, int num_tuple_elements, int array_size) { + tensorflow::testing::StopTiming(); + SetUp(); + + std::vector> tuple_elements; + for (int i = 0; i < num_tuple_elements; ++i) { + tuple_elements.push_back( + Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + } + std::unique_ptr literal = + Literal::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + } + tensorflow::testing::StopTiming(); + TearDown(); + } + + void TestBody() override {} +}; + +void BM_TransferDeviceToHost(int iters, int num_tuple_elements, + int array_size) { + TransferDeviceToHostBenchmark bm; + bm.Run(iters, num_tuple_elements, array_size); +} + +void BM_TransferHostToDevice(int iters, int num_tuple_elements, + int array_size) { + TransferHostToDeviceBenchmark bm; + bm.Run(iters, num_tuple_elements, array_size); +} + +BENCHMARK(BM_TransferHostToDevice) + ->ArgPair(1, 256) + ->ArgPair(1, 257) + ->ArgPair(100, 256) + ->ArgPair(100, 257); + +BENCHMARK(BM_TransferDeviceToHost) + ->ArgPair(1, 256) + ->ArgPair(1, 257) + ->ArgPair(100, 256) + ->ArgPair(100, 257); + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + tensorflow::testing::RunBenchmarks(); + return RUN_ALL_TESTS(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index fe1e3da7eca00e128377e6e56af877868aafa836..6ebb4324f8d20ed9f8886d92b0513441685ed19b 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -38,34 +38,35 @@ class TransposeTest : public ClientLibraryTestBase { XLA_TEST_F(TransposeTest, Transpose0x0) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 0)); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + Transpose(lhs, {1, 0}); ComputeAndCompareR2(&builder, Array2D(0, 0), {}, error_spec_); } XLA_TEST_F(TransposeTest, Transpose0x42) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 42)); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 42)); + Transpose(lhs, {1, 0}); ComputeAndCompareR2(&builder, Array2D(42, 0), {}, error_spec_); } XLA_TEST_F(TransposeTest, Transpose7x0) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2FromArray2D(Array2D(7, 0)); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(7, 0)); + Transpose(lhs, {1, 0}); ComputeAndCompareR2(&builder, Array2D(0, 7), {}, error_spec_); } TEST_F(TransposeTest, Transpose2x2) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2({ - {1.0, 2.0}, {3.0, 4.0}, - }); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2(&builder, { + {1.0, 2.0}, + {3.0, 4.0}, + }); + Transpose(lhs, {1, 0}); Array2D expected({{1.0f, 3.0f}, {2.0f, 4.0f}}); @@ -74,16 +75,18 @@ TEST_F(TransposeTest, Transpose2x2) { XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D(Array3D(0, 2, 3)); - auto result = builder.Transpose(operand, {1, 2, 0}); + auto operand = + ConstantR3FromArray3D(&builder, Array3D(0, 2, 3)); + Transpose(operand, {1, 2, 0}); ComputeAndCompareR3(&builder, Array3D(2, 3, 0), {}); } TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {1, 2, 0}); + auto operand = + ConstantR3FromArray3D(&builder, {{{1, 2, 3}, {4, 5, 6}}}); + Transpose(operand, {1, 2, 0}); Array3D expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}}); @@ -92,8 +95,9 @@ TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {2, 1, 0}); + auto operand = + ConstantR3FromArray3D(&builder, {{{1, 2, 3}, {4, 5, 6}}}); + Transpose(operand, {2, 1, 0}); Array3D expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}}); @@ -102,8 +106,9 @@ TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { TEST_F(TransposeTest, Transpose1x2x3_1x2x3) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {0, 1, 2}); + auto operand = + ConstantR3FromArray3D(&builder, {{{1, 2, 3}, {4, 5, 6}}}); + Transpose(operand, {0, 1, 2}); Array3D expected({{{1, 2, 3}, {4, 5, 6}}}); @@ -116,9 +121,9 @@ TEST_F(TransposeTest, MultiTranspose3x2) { for (int transposes = 0; transposes <= 10; ++transposes) { XlaBuilder builder("Transpose"); - auto computed = builder.ConstantR2FromArray2D(input); + auto computed = ConstantR2FromArray2D(&builder, input); for (int i = 0; i < transposes; ++i) { - computed = builder.Transpose(computed, {1, 0}); + computed = Transpose(computed, {1, 0}); } const Array2D& expected = transposes % 2 == 0 ? input : transposed; ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -130,8 +135,8 @@ TEST_F(TransposeTest, Small_1x1) { auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1); XlaBuilder builder("transpose_1x1"); - auto operand = builder.ConstantR2FromArray2D(*aoperand); - builder.Transpose(operand, {1, 0}); + auto operand = ConstantR2FromArray2D(&builder, *aoperand); + Transpose(operand, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*aoperand); ComputeAndCompareR2(&builder, *expected, {}, ErrorSpec(1e-4)); @@ -142,8 +147,8 @@ TEST_F(TransposeTest, Small_2x2) { auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2); XlaBuilder builder("transpose_2x2"); - auto operand = builder.ConstantR2FromArray2D(*aoperand); - builder.Transpose(operand, {1, 0}); + auto operand = ConstantR2FromArray2D(&builder, *aoperand); + Transpose(operand, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*aoperand); ComputeAndCompareR2(&builder, *expected, {}, ErrorSpec(1e-4)); @@ -162,8 +167,8 @@ void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) { } XlaBuilder builder(TestName()); - auto operand = builder.ConstantR3FromArray3D(aoperand); - builder.Transpose(operand, {0, 2, 1}); + auto operand = ConstantR3FromArray3D(&builder, aoperand); + Transpose(operand, {0, 2, 1}); ComputeAndCompareR3(&builder, expected, {}); } diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 7552224f10dd97c1a0581ef0473edd58f06c28c0..ec11508891d13f8032a1ebec388c756cf6d752c7 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -54,7 +54,7 @@ XLA_TEST_F(TupleTest, TupleConstant) { Literal::CreateR1(constant_vector).get(), Literal::CreateR2(constant_matrix).get()}); - builder.ConstantLiteral(*value); + ConstantLiteral(&builder, *value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } @@ -68,7 +68,7 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { Literal::MakeTuple({Literal::CreateR0(constant_scalar1).get(), Literal::CreateR0(constant_scalar2).get()}); - builder.ConstantLiteral(*value); + ConstantLiteral(&builder, *value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } @@ -82,9 +82,9 @@ XLA_TEST_F(TupleTest, TupleCreate) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - builder.Tuple({builder.ConstantR0(constant_scalar), - builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); + Tuple(&builder, {ConstantR0(&builder, constant_scalar), + ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); auto expected = Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), @@ -97,8 +97,8 @@ XLA_TEST_F(TupleTest, TupleCreate) { XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { XlaBuilder builder(TestName()); - builder.Tuple( - {builder.ConstantR0(7.0), builder.ConstantR1({})}); + Tuple(&builder, + {ConstantR0(&builder, 7.0), ConstantR1(&builder, {})}); auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), Literal::CreateR1({}).get()}); @@ -108,7 +108,7 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { // Tests the creation of an empty tuple. XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); - builder.Tuple({}); + Tuple(&builder, {}); auto expected = Literal::MakeTuple({}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -121,9 +121,10 @@ XLA_TEST_F(TupleTest, GetTupleElement) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - builder.GetTupleElement(tuple_data, 1); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(constant_matrix), {}, error_spec_); } @@ -131,17 +132,18 @@ XLA_TEST_F(TupleTest, GetTupleElement) { // Trivial test for extracting a tuple element with GetTupleElement. XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) { XlaBuilder builder(TestName()); - auto tuple_data = builder.Tuple( - {builder.ConstantR1({}), - builder.ConstantR2FromArray2D(Array2D(0, 101))}); - builder.GetTupleElement(tuple_data, 1); + auto tuple_data = + Tuple(&builder, + {ConstantR1(&builder, {}), + ConstantR2FromArray2D(&builder, Array2D(0, 101))}); + GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(0, 101), {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { XlaBuilder builder(TestName()); - auto value = builder.ConstantR1({4.5f}); - builder.GetTupleElement(value, 1); + auto value = ConstantR1(&builder, {4.5f}); + GetTupleElement(value, 1); auto result_status = builder.Build(); EXPECT_FALSE(result_status.ok()); EXPECT_THAT( @@ -158,14 +160,15 @@ XLA_TEST_F(TupleTest, AddTupleElements) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - auto vector_element = builder.GetTupleElement(tuple_data, 0); - auto matrix_element = builder.GetTupleElement(tuple_data, 1); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + auto vector_element = GetTupleElement(tuple_data, 0); + auto matrix_element = GetTupleElement(tuple_data, 1); auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie(); auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie(); - builder.Add(matrix_element, vector_element, - /*broadcast_dimensions=*/{1}); + Add(matrix_element, vector_element, + /*broadcast_dimensions=*/{1}); Array2D expected({ {2.f, 4.f, 6.f}, // row 0 @@ -185,10 +188,11 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - builder.Tuple({builder.GetTupleElement(tuple_data, 1), - builder.GetTupleElement(tuple_data, 0)}); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + Tuple(&builder, + {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); auto expected = Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), Literal::CreateR1(constant_vector).get()}); @@ -206,11 +210,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { std::unique_ptr v2_data = CreateR0Parameter(1.0f, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&b, /*data_handle=*/&v2); - auto v1_gt = b.Gt(v1, v2); // false - auto v2_gt = b.Gt(v2, v1); // true - auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} - auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} - b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); + auto v1_gt = Gt(v1, v2); // false + auto v2_gt = Gt(v2, v1); // true + auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true} + auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} + Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); auto expected = Literal::MakeTuple({Literal::CreateR0(direction).get(), Literal::CreateR0(!direction).get()}); @@ -243,22 +247,23 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - auto new_tuple01 = builder.Tuple({builder.GetTupleElement(tuple_data, 0), - builder.GetTupleElement(tuple_data, 1)}); - auto new_tuple10 = builder.Tuple({builder.GetTupleElement(tuple_data, 1), - builder.GetTupleElement(tuple_data, 0)}); - auto vector_from_01 = builder.GetTupleElement(new_tuple01, 0); - auto vector_from_10 = builder.GetTupleElement(new_tuple10, 1); - auto matrix_from_01 = builder.GetTupleElement(new_tuple01, 1); - auto matrix_from_10 = builder.GetTupleElement(new_tuple10, 0); - - auto addvectors = builder.Add(vector_from_01, vector_from_10); - auto addmatrices = builder.Add(matrix_from_01, matrix_from_10); - - builder.Add(addmatrices, addvectors, - /*broadcast_dimensions=*/{1}); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + auto new_tuple01 = Tuple(&builder, {GetTupleElement(tuple_data, 0), + GetTupleElement(tuple_data, 1)}); + auto new_tuple10 = Tuple(&builder, {GetTupleElement(tuple_data, 1), + GetTupleElement(tuple_data, 0)}); + auto vector_from_01 = GetTupleElement(new_tuple01, 0); + auto vector_from_10 = GetTupleElement(new_tuple10, 1); + auto matrix_from_01 = GetTupleElement(new_tuple01, 1); + auto matrix_from_10 = GetTupleElement(new_tuple10, 0); + + auto addvectors = Add(vector_from_01, vector_from_10); + auto addmatrices = Add(matrix_from_01, matrix_from_10); + + Add(addmatrices, addvectors, + /*broadcast_dimensions=*/{1}); Array2D expected({ {4.f, 8.f, 12.f}, // row 0 @@ -273,12 +278,12 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + Select(ConstantR0(&builder, false), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); @@ -292,22 +297,22 @@ XLA_TEST_F(TupleTest, TuplesInAMap) { // Need to put a select in there to prevent HLO-level optimizations from // optimizing out the tuples. XlaBuilder b("sort_square"); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto x2 = b.Mul(x, x); - auto x_smaller_tuple = b.Tuple({x, x2}); - auto x2_smaller_tuple = b.Tuple({x2, x}); - auto sorted = b.Select(b.Lt(x, x2), x_smaller_tuple, x2_smaller_tuple); - auto smaller = b.GetTupleElement(sorted, 0); - auto greater = b.GetTupleElement(sorted, 1); - b.Add(greater, b.Mul(b.ConstantR0(100.0f), smaller)); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto x2 = Mul(x, x); + auto x_smaller_tuple = Tuple(&b, {x, x2}); + auto x2_smaller_tuple = Tuple(&b, {x2, x}); + auto sorted = Select(Lt(x, x2), x_smaller_tuple, x2_smaller_tuple); + auto smaller = GetTupleElement(sorted, 0); + auto greater = GetTupleElement(sorted, 1); + Add(greater, Mul(ConstantR0(&b, 100.0f), smaller)); auto computation_status = b.Build(); ASSERT_IS_OK(computation_status.status()); tuple_computation = computation_status.ConsumeValueOrDie(); } XlaBuilder b(TestName()); - auto input = b.ConstantR1({-1.0f, 1.0f, 2.1f}); - b.Map({input}, tuple_computation, {0}); + auto input = ConstantR1(&b, {-1.0f, 1.0f, 2.1f}); + Map(&b, {input}, tuple_computation, {0}); ComputeAndCompareR1(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_); } @@ -317,12 +322,12 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); - builder.Select(builder.ConstantR0(true), tuple12, tuple21); + Select(ConstantR0(&builder, true), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), Literal::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); @@ -335,14 +340,13 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); - auto select = - builder.Select(builder.ConstantR0(false), tuple12, tuple21); - builder.GetTupleElement(select, 0); + auto select = Select(ConstantR0(&builder, false), tuple12, tuple21); + GetTupleElement(select, 0); ComputeAndCompareR1(&builder, vec2, {}, error_spec_); } @@ -371,19 +375,16 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto pred_tuple = builder.Tuple( - {builder.ConstantR0(true), builder.ConstantR0(false)}); - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + auto pred_tuple = Tuple(&builder, {ConstantR0(&builder, true), + ConstantR0(&builder, false)}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); - auto select1 = - builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21); - auto select2 = - builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1); - builder.Add(builder.GetTupleElement(select2, 0), - builder.GetTupleElement(select2, 1)); + auto select1 = Select(GetTupleElement(pred_tuple, 0), tuple12, tuple21); + auto select2 = Select(GetTupleElement(pred_tuple, 1), tuple21, select1); + Add(GetTupleElement(select2, 0), GetTupleElement(select2, 1)); ComputeAndCompareR1(&builder, {3.f, 6.f, 9.f}, {}, error_spec_); } @@ -395,12 +396,12 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto c1 = builder.ConstantR1(vec1); - auto c2 = builder.ConstantR1(vec2); - auto tuple12 = builder.Tuple({c1, c2}); - auto tuple21 = builder.Tuple({c2, c1}); + auto c1 = ConstantR1(&builder, vec1); + auto c2 = ConstantR1(&builder, vec2); + auto tuple12 = Tuple(&builder, {c1, c2}); + auto tuple21 = Tuple(&builder, {c2, c1}); - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + Select(ConstantR0(&builder, false), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), Literal::CreateR1(vec1).get()}); @@ -409,9 +410,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { XLA_TEST_F(TupleTest, NestedTuples) { XlaBuilder builder(TestName()); - auto inner_tuple = builder.Tuple( - {builder.ConstantR1({1.0, 2.0}), builder.ConstantR0(42.0)}); - builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); + auto inner_tuple = Tuple(&builder, {ConstantR1(&builder, {1.0, 2.0}), + ConstantR0(&builder, 42.0)}); + Tuple(&builder, {inner_tuple, ConstantR1(&builder, {22.0, 44.0})}); auto expected_v1 = Literal::CreateR1({1.0, 2.0}); auto expected_s = Literal::CreateR0(42.0); @@ -432,10 +433,10 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { Shape outer_tuple_shape = ShapeUtil::MakeTupleShape({inner_tuple_shape, data_shape}); - auto input = builder.Parameter(0, outer_tuple_shape, "input"); - auto gte0 = builder.GetTupleElement(input, 0); - auto gte1 = builder.GetTupleElement(gte0, 1); - builder.Add(gte1, builder.ConstantR1({10.0, 11.0, 12.0})); + auto input = Parameter(&builder, 0, outer_tuple_shape, "input"); + auto gte0 = GetTupleElement(input, 0); + auto gte1 = GetTupleElement(gte0, 1); + Add(gte1, ConstantR1(&builder, {10.0, 11.0, 12.0})); std::unique_ptr data = client_ @@ -463,16 +464,16 @@ XLA_TEST_F(TupleTest, ComplexTuples) { Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2}); Shape arg0_shape = ShapeUtil::MakeTupleShape( {c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})}); - auto input0 = builder.Parameter(0, arg0_shape, "input0"); - auto t0 = builder.GetTupleElement(input0, 0); - auto t1 = builder.GetTupleElement(input0, 1); - auto t10 = builder.GetTupleElement(t1, 0); - auto t11 = builder.GetTupleElement(t1, 1); - auto sum = builder.Add(builder.Add(t10, t11, {1}), t0); - auto input1 = builder.Parameter(1, c64r1, "input1"); - auto prod = builder.Mul(input1, sum, {1}); - builder.Tuple({builder.Tuple({prod, sum}), - builder.ConstantR0({123, 456})}); + auto input0 = Parameter(&builder, 0, arg0_shape, "input0"); + auto t0 = GetTupleElement(input0, 0); + auto t1 = GetTupleElement(input0, 1); + auto t10 = GetTupleElement(t1, 0); + auto t11 = GetTupleElement(t1, 1); + auto sum = Add(Add(t10, t11, {1}), t0); + auto input1 = Parameter(&builder, 1, c64r1, "input1"); + auto prod = Mul(input1, sum, {1}); + Tuple(&builder, {Tuple(&builder, {prod, sum}), + ConstantR0(&builder, {123, 456})}); } std::unique_ptr arg0 = @@ -495,7 +496,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = Literal::CreateFromShape(sum->shape()); + auto prod = MakeUnique(sum->shape()); ASSERT_TRUE(prod->Populate( [&sum](tensorflow::gtl::ArraySlice indexes) { return sum->Get(indexes) * @@ -532,8 +533,8 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, - *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})))); + *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})), + *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index c3abe22797f5eaa76ced2ad8534bd68c32983e60..929b1ca7fb93c545265bf85fec1ed7dc845405b2 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -38,8 +38,8 @@ class UnaryOpTest : public ClientLibraryTestBase { template void AbsSize0TestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1(&builder, {}); + Abs(arg); if (primitive_util::NativeToPrimitiveType() == C64) { ComputeAndCompareR1(&builder, {}, {}); @@ -51,8 +51,8 @@ class UnaryOpTest : public ClientLibraryTestBase { template void AbsTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({-2, 25, 0, -123, inf(), -inf()}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1(&builder, {-2, 25, 0, -123, inf(), -inf()}); + Abs(arg); ComputeAndCompareR1(&builder, {2, 25, 0, 123, inf(), inf()}, {}); } @@ -60,9 +60,9 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( - {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); - auto sign = builder.Sign(arg); + auto arg = ConstantR1( + &builder, {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); + Sign(arg); ComputeAndCompareR1(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); } @@ -70,10 +70,10 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignAbsTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({-2, 25, 0, -123}); - auto sign = builder.Sign(arg); - auto abs = builder.Abs(arg); - builder.Sub(builder.Mul(sign, abs), arg); + auto arg = ConstantR1(&builder, {-2, 25, 0, -123}); + auto sign = Sign(arg); + auto abs = Abs(arg); + Sub(Mul(sign, abs), arg); ComputeAndCompareR1(&builder, {0, 0, 0, 0}, {}); } @@ -92,13 +92,13 @@ int64 UnaryOpTest::inf() { template <> void UnaryOpTest::AbsTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({{-2, 0}, - {0, 25}, - {0, 0}, - {-0.3f, 0.4f}, - {0, inf()}, - {-inf(), 0}}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1(&builder, {{-2, 0}, + {0, 25}, + {0, 0}, + {-0.3f, 0.4f}, + {0, inf()}, + {-inf(), 0}}); + Abs(arg); std::unique_ptr expected = Literal::CreateR1({2, 25, 0, 0.5, inf(), inf()}); @@ -108,9 +108,10 @@ void UnaryOpTest::AbsTestHelper() { template <> void UnaryOpTest::SignTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( + auto arg = ConstantR1( + &builder, {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); - auto sign = builder.Sign(arg); + Sign(arg); std::unique_ptr expected = Literal::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); @@ -121,10 +122,10 @@ template <> void UnaryOpTest::SignAbsTestHelper() { XlaBuilder builder(TestName()); auto arg = - builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); - auto sign = builder.Sign(arg); - auto abs = builder.Abs(arg); - builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg); + ConstantR1(&builder, {{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); + auto sign = Sign(arg); + auto abs = Abs(arg); + Sub(Mul(sign, ConvertElementType(abs, C64)), arg); std::unique_ptr expected = Literal::CreateR1({0, 0, 0, 0}); @@ -145,34 +146,31 @@ XLA_TEST_F(UnaryOpTest, AbsTestR1) { XLA_TEST_F(UnaryOpTest, AbsTestR0) { XlaBuilder builder(TestName()); - auto argi = builder.ConstantR0(-5); - auto absi = builder.Abs(argi); - auto argf = builder.ConstantR0(-3.0f); - auto absf = builder.Abs(argf); - auto argf0 = builder.ConstantR0(-0.0f); - auto absf0 = builder.Abs(argf0); - auto argc = builder.ConstantR0({-0.3f, 0.4f}); - auto absc = builder.Abs(argc); - builder.Add(builder.Add(absc, absf0), - builder.Add(absf, builder.ConvertElementType(absi, F32))); + auto argi = ConstantR0(&builder, -5); + auto absi = Abs(argi); + auto argf = ConstantR0(&builder, -3.0f); + auto absf = Abs(argf); + auto argf0 = ConstantR0(&builder, -0.0f); + auto absf0 = Abs(argf0); + auto argc = ConstantR0(&builder, {-0.3f, 0.4f}); + auto absc = Abs(argc); + Add(Add(absc, absf0), Add(absf, ConvertElementType(absi, F32))); ComputeAndCompareR0(&builder, 8.5f, {}); } XLA_TEST_F(UnaryOpTest, SignTestR0) { XlaBuilder builder(TestName()); - auto argi = builder.ConstantR0(-5); - auto sgni = builder.Sign(argi); // -1 - auto argf = builder.ConstantR0(-4.0f); - auto sgnf = builder.Sign(argf); // -1 - auto argf0 = builder.ConstantR0(-0.0f); - auto sgnf0 = builder.Sign(argf0); // 0 - auto argc = builder.ConstantR0({-.3, .4}); - auto sgnc = builder.Sign(argc); // (-.6, .8) - builder.Add(sgnc, builder.ConvertElementType( - builder.Add(builder.Add(sgnf0, sgnf), - builder.ConvertElementType(sgni, F32)), - C64)); + auto argi = ConstantR0(&builder, -5); + auto sgni = Sign(argi); // -1 + auto argf = ConstantR0(&builder, -4.0f); + auto sgnf = Sign(argf); // -1 + auto argf0 = ConstantR0(&builder, -0.0f); + auto sgnf0 = Sign(argf0); // 0 + auto argc = ConstantR0(&builder, {-.3, .4}); + auto sgnc = Sign(argc); // (-.6, .8) + Add(sgnc, ConvertElementType( + Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); std::unique_ptr expected = Literal::CreateR0({-2.6f, 0.8f}); @@ -194,9 +192,9 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( - {2, 25, 0, 123, std::numeric_limits::max()}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1( + &builder, {2, 25, 0, 123, std::numeric_limits::max()}); + Abs(arg); ComputeAndCompareR1( &builder, {2, 25, 0, 123, std::numeric_limits::max()}, {}); @@ -204,37 +202,37 @@ XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( - {2, 25, 0, 123, std::numeric_limits::max()}); - auto sign = builder.Sign(arg); + auto arg = ConstantR1( + &builder, {2, 25, 0, 123, std::numeric_limits::max()}); + Sign(arg); ComputeAndCompareR1(&builder, {1, 1, 0, 1, 1}, {}); } XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR2({{1.0, -2.0}, {-3.0, 4.0}}); - auto sign = builder.Sign(arg); - auto abs = builder.Abs(arg); - builder.Sub(builder.Mul(sign, abs), arg); + auto arg = ConstantR2(&builder, {{1.0, -2.0}, {-3.0, 4.0}}); + auto sign = Sign(arg); + auto abs = Abs(arg); + Sub(Mul(sign, abs), arg); ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 1}); - auto rhs = builder.ConstantR1({1, 1}); - builder.ConvertElementType(builder.Eq(lhs, rhs), S32); + auto lhs = ConstantR1(&builder, {0, 1}); + auto rhs = ConstantR1(&builder, {1, 1}); + ConvertElementType(Eq(lhs, rhs), S32); ComputeAndCompareR1(&builder, {0, 1}, {}); } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 1}); - auto rhs = builder.ConstantR1({1, 1}); - builder.ConvertElementType(builder.Eq(lhs, rhs), F32); + auto lhs = ConstantR1(&builder, {0, 1}); + auto rhs = ConstantR1(&builder, {1, 1}); + ConvertElementType(Eq(lhs, rhs), F32); ComputeAndCompareR1(&builder, {0.0, 1.0}, {}); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index 82d301983fc7885ef5c1c1ed05b74fc017bb7727..ea3aba6df1d3fbd492a23b280309322b8524c0bf 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -46,7 +46,7 @@ class VecOpsReduceTest : public ClientLibraryTestBase { {{1.0, 2.0, 3.0}, // } plane 2 in dim 0 {4.0, 5.0, 6.0}}}); // clang-format on - return builder_.ConstantR3FromArray3D(x3d); + return ConstantR3FromArray3D(&builder_, x3d); } XlaBuilder builder_; @@ -56,11 +56,10 @@ class VecOpsReduceTest : public ClientLibraryTestBase { TEST_F(VecOpsReduceTest, AddReduceR1F32) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); - auto x = builder_.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1( + &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0(&builder_, -4.2f, {}, errspec_); } @@ -71,10 +70,9 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) { std::vector input(3000); std::iota(input.begin(), input.end(), 100.0f); - auto x = builder_.ConstantR1(input); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1(&builder_, input); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); float expected = std::accumulate(input.begin(), input.end(), 0.0f); ComputeAndCompareR0(&builder_, expected, {}, errspec_); @@ -83,11 +81,10 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) { TEST_F(VecOpsReduceTest, MaxReduceR1F32) { auto max_reducer = CreateScalarMax(); - auto x = builder_.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto max_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), max_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1( + &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reduce(x, ConstantR0(&builder_, 0.0f), max_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0(&builder_, 2.6f, {}, errspec_); } @@ -95,11 +92,10 @@ TEST_F(VecOpsReduceTest, MaxReduceR1F32) { TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) { auto max_reducer = CreateScalarMax(); - auto x = builder_.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto max_reduce = - builder_.Reduce(x, builder_.ConstantR0(4.0f), max_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1( + &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reduce(x, ConstantR0(&builder_, 4.0f), max_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0(&builder_, 4.0f, {}, errspec_); } @@ -108,15 +104,14 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); // clang-format off - auto x = builder_.ConstantR2({ + auto x = ConstantR2(&builder_, { {1.0, 2.0, 3.0}, // | dim 0 {4.0, 5.0, 6.0}}); // | // ------ dim 1 ---------- // clang-format on - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1}); ComputeAndCompareR1(&builder_, {6.0, 15.0}, {}, errspec_); } @@ -125,13 +120,12 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); // clang-format off - auto x = builder_.ConstantR2({ + auto x = ConstantR2(&builder_, { {1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}); // clang-format on - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR1(&builder_, {5.0, 7.0, 9.0}, {}, errspec_); } @@ -139,9 +133,8 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{2}); Array2D expected_array({{6.0f, 15.0f}, {6.0f, 15.0f}, {6.0f, 15.0f}}); @@ -151,9 +144,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1}); Array2D expected_array( {{5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}}); @@ -164,9 +156,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); Array2D expected_array({{3.0f, 6.0f, 9.0f}, {12.0f, 15.0f, 18.0f}}); @@ -176,9 +167,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1, 2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1, 2}); ComputeAndCompareR1(&builder_, {21.0, 21.0, 21.0}, {}, errspec_); } @@ -186,9 +176,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 2}); ComputeAndCompareR1(&builder_, {18.0, 45.0}, {}, errspec_); } @@ -196,9 +185,8 @@ XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 1}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 1}); ComputeAndCompareR1(&builder_, {15.0, 21.0, 27.0}, {}, errspec_); } @@ -206,9 +194,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 1, 2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 1, 2}); ComputeAndCompareR0(&builder_, 63.0, {}, errspec_); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 5cce7a2bf82c1a8403536a91e67910f949ef185a..79bae22dac9599a38c73ea1dc2e6b4856395ff79 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -50,9 +50,9 @@ class VecOpsSimpleTest : public ClientLibraryTestBase { XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto exp = builder.Exp(x); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Exp(x); std::vector expected = {8.1662, 7.4274e-02, 13.4637, 1.8316e-02, 8.1662, 9.9742, 6.7379e-03, 4.0657e-01, @@ -69,8 +69,8 @@ XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) { for (int i = 0; i < count; ++i) { exponents.push_back(i / static_cast(count)); } - auto x = builder.ConstantR1(exponents); - auto exp = builder.Exp(x); + auto x = ConstantR1(&builder, exponents); + Exp(x); std::vector expected; expected.reserve(exponents.size()); @@ -98,8 +98,8 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { Array4D expected(2, 2, 2, 2, expected_vector); - auto x = builder.ConstantR4FromArray4D(exponents); - auto exp = builder.Exp(x); + auto x = ConstantR4FromArray4D(&builder, exponents); + Exp(x); ComputeAndCompareR4(&builder, expected, {}, ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3)); @@ -107,9 +107,9 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - builder.Neg(x); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Neg(x); std::vector expected = {-2.1, 2.6, -2.6, 4.0, -2.1, -2.3, 5.0, 0.9, 2.4, -1.6}; @@ -118,8 +118,8 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); - builder.Neg(x); + auto x = ConstantR1(&builder, {2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); + Neg(x); std::vector expected = {-2, 2, -12, 4, -5, -20, 15, 0, 2, -1}; ComputeAndCompareR1(&builder, expected, {}); @@ -127,59 +127,19 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {0, 1, 42, static_cast(-1), static_cast(-12)}); - builder.Neg(x); + auto x = ConstantR1( + &builder, {0, 1, 42, static_cast(-1), static_cast(-12)}); + Neg(x); std::vector expected = {0, static_cast(-1), static_cast(-42), 1, 12}; ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) { - XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - builder.SquareF32(x); - - std::vector expected = {4.41, 6.76, 6.76, 16., 4.41, - 5.29, 25., 0.81, 5.76, 2.56}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { - XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - builder.ReciprocalF32(x); - - std::vector expected = { - 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, - 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { - XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0.0, -0.0}); - auto exp = builder.SqrtF32(x); - - ComputeAndCompareR1(&builder, {0, 0}, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { - XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); - auto exp = builder.SqrtF32(x); - - std::vector expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); -} - XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { XlaBuilder builder(TestName()); - auto x = - builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); - auto exp = builder.Pow(x, builder.ConstantR0(-.5f)); + auto x = ConstantR1(&builder, + {16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); + Pow(x, ConstantR0(&builder, -.5f)); std::vector expected = {.25, 1, .03125, 2.5, 2.23607, .009000, .900025}; @@ -191,11 +151,11 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR1( - {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto max = builder.Map({x, y}, add, {0}); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR1( + &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + Map(&builder, {x, y}, add, {0}); std::vector expected = {1.7, -3.2, -0.4, -3.8, 5.9, 0.1, -6.8, 4., -1., 2.2}; @@ -204,11 +164,11 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR1( - {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto max = builder.Max(x, y); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR1( + &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + Max(x, y); std::vector expected = {2.1, -0.6, 2.6, 0.2, 3.8, 2.3, -1.8, 4.9, 1.4, 1.6}; @@ -227,7 +187,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto max = builder.Max(v1, v2); + Max(v1, v2); ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, {param0_data.get(), param1_data.get()}, error_spec_); @@ -267,7 +227,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto max = builder.Max(v1, v2); + Max(v1, v2); ComputeAndCompareR1(&builder, expected_vec, {param0_data.get(), param1_data.get()}, error_spec_); @@ -275,10 +235,10 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR0(0); - auto max = builder.Max(x, y); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR0(&builder, 0); + Max(x, y); std::vector expected = {2.1, 0.0, 2.6, 0.0, 2.1, 2.3, 0.0, 0.0, 0.0, 1.6}; @@ -287,11 +247,11 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR1( - {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto min = builder.Min(x, y); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR1( + &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + Min(x, y); std::vector expected = {-0.4, -2.6, -3.0, -4.0, 2.1, -2.2, -5.0, -0.9, -2.4, 0.6}; @@ -300,11 +260,11 @@ XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0); - auto one = builder.ConstantR0(1); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Min(builder.Max(x, zero), one); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + Min(Max(x, zero), one); std::vector expected = {1.0, 0.0, 1.0, 0.3, 1.0, 0.9, 0.0, 0.1, 0.0, 0.6}; @@ -313,11 +273,11 @@ XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0); - auto one = builder.ConstantR0(1); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Clamp(zero, x, one); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + Clamp(zero, x, one); std::vector expected = {1.0, 0.0, 1.0, 0.3, 1.0, 0.9, 0.0, 0.1, 0.0, 0.6}; @@ -326,10 +286,10 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR1({0.0f, 0.0f}); - auto one = builder.ConstantR1({1.0f, 1.0f}); - auto x = builder.ConstantR1({2.1, -2.6}); - auto clamp = builder.Clamp(zero, x, one); + auto zero = ConstantR1(&builder, {0.0f, 0.0f}); + auto one = ConstantR1(&builder, {1.0f, 1.0f}); + auto x = ConstantR1(&builder, {2.1, -2.6}); + Clamp(zero, x, one); std::vector expected = {1.0, 0.0}; ComputeAndCompareR1(&builder, expected, {}); @@ -337,11 +297,11 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XlaBuilder builder(TestName()); - auto one = builder.ConstantR0(1); - auto two = builder.ConstantR0(2); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Clamp(one, x, two); + auto one = ConstantR0(&builder, 1); + auto two = ConstantR0(&builder, 2); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + Clamp(one, x, two); std::vector expected = {2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0}; @@ -350,10 +310,10 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0); - auto one = builder.ConstantR0(10); - auto x = builder.ConstantR1({-3, 3, 9, 13}); - auto clamp = builder.Clamp(zero, x, one); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 10); + auto x = ConstantR1(&builder, {-3, 3, 9, 13}); + Clamp(zero, x, one); std::vector expected = {0, 3, 9, 10}; ComputeAndCompareR1(&builder, expected, {}); @@ -365,9 +325,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { // add_half(x) = x + 0.5 XlaBuilder builder("add_half"); auto x_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); - auto half = builder.ConstantR0(0.5); - builder.Add(x_value, half); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x_value"); + auto half = ConstantR0(&builder, 0.5); + Add(x_value, half); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); add_half = computation_status.ConsumeValueOrDie(); @@ -378,9 +338,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { // clamp(y) = clamp<0,5>(y) XlaBuilder builder("clamp"); auto y_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value"); - auto zero = builder.ConstantR0(0.0); - auto clamped = builder.Clamp(zero, y_value, builder.ConstantR0(5)); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "y_value"); + auto zero = ConstantR0(&builder, 0.0); + Clamp(zero, y_value, ConstantR0(&builder, 5)); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); clamp = computation_status.ConsumeValueOrDie(); @@ -391,13 +351,13 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { // mult_relu_add(z) = clamp(add_half(2 * max(z, 0))) XlaBuilder builder("mult_relu_add"); auto z_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); - auto zero = builder.ConstantR0(0.0); - auto two = builder.ConstantR0(2.0); - auto max = builder.Max(z_value, zero); - auto mult = builder.Mul(two, max); - auto inner = builder.Map({mult}, add_half, {}); - builder.Map({inner}, clamp, {}); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "z_value"); + auto zero = ConstantR0(&builder, 0.0); + auto two = ConstantR0(&builder, 2.0); + auto max = Max(z_value, zero); + auto mult = Mul(two, max); + auto inner = Map(&builder, {mult}, add_half, {}); + Map(&builder, {inner}, clamp, {}); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); mult_relu_add = computation_status.ConsumeValueOrDie(); @@ -405,9 +365,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { XlaBuilder builder("map10"); { - auto x = builder.ConstantR1( - {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto activations = builder.Map({x}, mult_relu_add, {0}); + auto x = ConstantR1( + &builder, {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Map(&builder, {x}, mult_relu_add, {0}); } std::vector expected = {4.7, 0.5, 5.0, 0.5, 4.7, @@ -417,9 +377,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto y = builder.ConstantR0(3); - builder.Rem(x, y); + auto x = ConstantR1(&builder, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto y = ConstantR0(&builder, 3); + Rem(x, y); std::vector expected = {-2, -1, 0, -2, -1, 0, 1, 2, 0, 1}; ComputeAndCompareR1(&builder, expected, {}); @@ -427,9 +387,9 @@ XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({false, true}); - auto y = builder.ConstantR1({true, false}); - builder.Eq(x, y); + auto x = ConstantR1(&builder, {false, true}); + auto y = ConstantR1(&builder, {true, false}); + Eq(x, y); std::array expected = {{false, false}}; ComputeAndCompareR1(&builder, expected, {}); @@ -437,9 +397,9 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({false, true}); - auto y = builder.ConstantR1({true, false}); - builder.Ne(x, y); + auto x = ConstantR1(&builder, {false, true}); + auto y = ConstantR1(&builder, {true, false}); + Ne(x, y); std::array expected = {{true, true}}; ComputeAndCompareR1(&builder, expected, {}); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index c463f3eac55e5b8ab32dc52d5a38e7840241bc58..bbd67cd8d7c433550deefc38ce28b2b732d354aa 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -55,8 +55,8 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(5), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 5), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -64,16 +64,16 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -91,8 +91,8 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(5), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 5), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -100,16 +100,16 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -122,8 +122,8 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(5), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 5), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -131,18 +131,18 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.Reduce(builder.ConstantR1(2, 1), - builder.ConstantR0(0), - CreateScalarAddComputation(S32, &builder), {0}); - builder.While(condition, body, init); + auto init = + Reduce(ConstantR1(&builder, 2, 1), ConstantR0(&builder, 0), + CreateScalarAddComputation(S32, &builder), {0}); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -154,8 +154,8 @@ TEST_F(WhileTest, WhileWithPredicateResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Ne(builder.ConstantR0(true), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Ne(ConstantR0(&builder, true), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -163,16 +163,16 @@ TEST_F(WhileTest, WhileWithPredicateResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Or(prev, builder.ConstantR0(true)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Or(prev, ConstantR0(&builder, true)); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.Ne(builder.ConstantR0(false), - builder.ConstantR0(true)); - builder.While(condition, body, init); + auto init = + Ne(ConstantR0(&builder, false), ConstantR0(&builder, true)); + While(condition, body, init); ComputeAndCompareR0(&builder, true, {}); } @@ -184,17 +184,16 @@ TEST_F(WhileTest, WhileWithPredicateResult) { // while (result.sum() < 15.5f) { // result = result + vector(0); // } -// TODO(b/29185393): does not terminate on CPU. -TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. XlaComputation add; { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); add = builder.Build().ConsumeValueOrDie(); } @@ -203,10 +202,10 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0}); - builder.Gt(builder.ConstantR0(15.5f), sum); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto sum = Reduce(prev, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0}); + Gt(ConstantR0(&builder, 15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } @@ -215,16 +214,16 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR1({}); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR1(&builder, {}); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.ConstantR1({}); - auto result = builder.While(condition, body, init); + auto init = ConstantR1(&builder, {}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -246,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) { XlaComputation add; { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); add = builder.Build().ConsumeValueOrDie(); } @@ -257,10 +256,10 @@ TEST_F(WhileTest, WhileWithVectorResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0}); - builder.Gt(builder.ConstantR0(15.5f), sum); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto sum = Reduce(prev, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0}); + Gt(ConstantR0(&builder, 15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } @@ -269,16 +268,16 @@ TEST_F(WhileTest, WhileWithVectorResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR1(8, 0.125f); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR1(&builder, 8, 0.125f); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.ConstantR1(8, 0.f); - auto result = builder.While(condition, body, init); + auto init = ConstantR1(&builder, 8, 0.f); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { XlaComputation add; { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); add = builder.Build().ConsumeValueOrDie(); } @@ -317,10 +316,10 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0}); - builder.Gt(builder.ConstantR0(15.5f), sum); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto sum = Reduce(prev, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0}); + Gt(ConstantR0(&builder, 15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } @@ -329,20 +328,20 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR1(8, 0.125f); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR1(&builder, 8, 0.125f); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.ConstantR1(8, 0.f); - auto result = builder.While(condition, body, init); + auto init = ConstantR1(&builder, 8, 0.f); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - builder.Tuple({result}); + Tuple(&builder, {result}); // Individual elements with increase by 1/8 each time through the loop, so // the sum will increase by 1.0. It will first be >15.5 when the elements @@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(N), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, N), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -377,22 +376,23 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto w1 = builder.GetTupleElement(prev, 1); - auto w2 = builder.GetTupleElement(prev, 2); - auto w3 = builder.GetTupleElement(prev, 3); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto w1 = GetTupleElement(prev, 1); + auto w2 = GetTupleElement(prev, 2); + auto w3 = GetTupleElement(prev, 3); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), - builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 3, 1.f), + ConstantR1(&builder, 3, 2.f), + ConstantR1(&builder, 3, 3.f)}); + auto result = While(condition, body, init); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -419,9 +419,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(N), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, N), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -430,26 +430,27 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto w1 = builder.GetTupleElement(prev, 1); - auto w2 = builder.GetTupleElement(prev, 2); - auto w3 = builder.GetTupleElement(prev, 3); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto w1 = GetTupleElement(prev, 1); + auto w2 = GetTupleElement(prev, 2); + auto w3 = GetTupleElement(prev, 3); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), - builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); - auto xla_while = builder.While(condition, body, init); - - auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1), - builder.GetTupleElement(xla_while, 2)); - auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 3, 1.f), + ConstantR1(&builder, 3, 2.f), + ConstantR1(&builder, 3, 3.f)}); + auto xla_while = While(condition, body, init); + + auto add12 = + Add(GetTupleElement(xla_while, 1), GetTupleElement(xla_while, 2)); + auto result = Add(add12, GetTupleElement(xla_while, 3)); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -474,9 +475,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -486,21 +487,21 @@ TEST_F(WhileTest, WhileWithTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -524,9 +525,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -535,21 +536,20 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto pred = builder.GetTupleElement(prev, 1); - auto new_pred = builder.Or(pred, builder.ConstantR0(true)); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto pred = GetTupleElement(prev, 1); + auto new_pred = Or(pred, ConstantR0(&builder, true)); + Tuple(&builder, {Add(iteration, ConstantR0(&builder, 1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple({builder.ConstantR0(0), - builder.Ne(builder.ConstantR0(false), - builder.ConstantR0(true))}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + Ne(ConstantR0(&builder, false), + ConstantR0(&builder, true))}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -571,9 +571,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -583,18 +583,18 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), - builder.ConstantR0(7)}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Tuple(&builder, {Add(iteration, ConstantR0(&builder, 1)), + ConstantR0(&builder, 7)}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR0(7)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR0(&builder, 7)}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -632,9 +632,9 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { const int c1 = 5; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c1)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } @@ -642,9 +642,9 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { const int c2 = 7; { XlaBuilder builder("condition2"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c2)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c2)); TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } @@ -654,43 +654,43 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } XlaComputation body2; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto while1 = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto while1 = While(condition, body, init); - auto while2 = builder.While(condition2, body2, while1); + auto while2 = While(condition2, body2, while1); - auto while_result1 = builder.GetTupleElement(while1, 1); - auto while_result2 = builder.GetTupleElement(while2, 1); + auto while_result1 = GetTupleElement(while1, 1); + auto while_result2 = GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( builder.GetShape(while_result2).ConsumeValueOrDie()); - auto result = builder.Add(while_result1, while_result2); + auto result = Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -711,9 +711,9 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { const int c1 = 5; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c1)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } @@ -721,9 +721,9 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { const int c2 = 7; { XlaBuilder builder("condition2"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c2)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c2)); TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } @@ -733,30 +733,30 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto while1 = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto while1 = While(condition, body, init); - auto while2 = builder.While(condition2, body, while1); + auto while2 = While(condition2, body, while1); - auto while_result1 = builder.GetTupleElement(while1, 1); - auto while_result2 = builder.GetTupleElement(while2, 1); + auto while_result1 = GetTupleElement(while1, 1); + auto while_result2 = GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( builder.GetShape(while_result2).ConsumeValueOrDie()); - auto result = builder.Add(while_result1, while_result2); + auto result = Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -778,9 +778,9 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { const int c1 = 5; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c1)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } @@ -788,9 +788,9 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { const int c2 = 7; { XlaBuilder builder("condition2"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c2)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c2)); TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } @@ -800,29 +800,29 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto while1 = builder.While(condition, body, init); - auto while2 = builder.While(condition2, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto while1 = While(condition, body, init); + auto while2 = While(condition2, body, init); - auto while_result1 = builder.GetTupleElement(while1, 1); - auto while_result2 = builder.GetTupleElement(while2, 1); + auto while_result1 = GetTupleElement(while1, 1); + auto while_result2 = GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( builder.GetShape(while_result2).ConsumeValueOrDie()); - auto result = builder.Add(while_result1, while_result2); + auto result = Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -844,9 +844,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -856,29 +856,28 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); + auto prev = Parameter(&builder, 0, result_shape, "prev"); // TupleElement 0 - auto iteration = builder.GetTupleElement(prev, 0); - auto out0 = builder.Add(iteration, builder.ConstantR0(1)); + auto iteration = GetTupleElement(prev, 0); + auto out0 = Add(iteration, ConstantR0(&builder, 1)); // TupleElement 1 - auto input = builder.GetTupleElement(prev, 1); + auto input = GetTupleElement(prev, 1); // Update. - auto update = builder.ConvertElementType(builder.Broadcast(out0, {2}), F32); + auto update = ConvertElementType(Broadcast(out0, {2}), F32); // Starts = iteration * 2; - auto starts = builder.Reshape( - builder.Mul(iteration, builder.ConstantR0(2)), {1}); + auto starts = Reshape(Mul(iteration, ConstantR0(&builder, 2)), {1}); // UpdateSlice. - auto out1 = builder.DynamicUpdateSlice(input, update, starts); + auto out1 = DynamicUpdateSlice(input, update, starts); - builder.Tuple({out0, out1}); + Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -913,10 +912,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { // Create a computation for the condition: repeat for count iterations. auto build_condition = [this, v6s32](int count) { XlaBuilder builder(TestName()); - auto prev = builder.Reshape( - builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0}, - {}); - builder.Gt(builder.ConstantR0(count), prev); + auto prev = Reshape( + Slice(Parameter(&builder, 0, v6s32, "prev"), {0}, {1}, {1}), {0}, {}); + Gt(ConstantR0(&builder, count), prev); return builder.Build().ConsumeValueOrDie(); }; @@ -924,22 +922,22 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, v6s32, "prev"); - auto inc = builder.ConcatInDim( - {builder.ConstantR1({1}), - builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(100), - ShapeUtil::MakeShape(S32, {5}))}, - 0); - builder.Add(inc, prev); + auto prev = Parameter(&builder, 0, v6s32, "prev"); + auto inc = ConcatInDim(&builder, + {ConstantR1(&builder, {1}), + RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {5}))}, + 0); + Add(inc, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. auto while_loop = [this, &body, build_condition](int count) { XlaBuilder builder(TestName()); - auto init = builder.ConstantR1({0, 0, 0, 0, 0, 0}); - builder.While(build_condition(count), body, init); + auto init = ConstantR1(&builder, {0, 0, 0, 0, 0, 0}); + While(build_condition(count), body, init); return builder.Build(); }; @@ -958,26 +956,23 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); - auto p = outer.Parameter(0, element_shape, "param"); - auto t = outer.Tuple({p, outer.ConstantR1({1, 1})}); + auto p = Parameter(&outer, 0, element_shape, "param"); + auto t = Tuple(&outer, {p, ConstantR1(&outer, {1, 1})}); TF_ASSERT_OK_AND_ASSIGN(Shape tuple_shape, outer.GetShape(t)); XlaBuilder cond("cond"); - auto cond_t = cond.Parameter(0, tuple_shape, "t"); - TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0), - cond.ConstantR1({42, 42})), - &cond) - .status()); + auto cond_t = Parameter(&cond, 0, tuple_shape, "t"); + Any(Eq(GetTupleElement(cond_t, 0), ConstantR1(&cond, {42, 42}))); XlaBuilder body("body"); - auto body_t = body.Parameter(0, tuple_shape, "t"); - auto e = body.GetTupleElement(body_t, 1); - body.Tuple({e, e}); + auto body_t = Parameter(&body, 0, tuple_shape, "t"); + auto e = GetTupleElement(body_t, 1); + Tuple(&body, {e, e}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, t); + While(cond_computation, body_computation, t); auto expected_element = Literal::CreateR1({1, 1}); auto expected = @@ -993,20 +988,19 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); - auto p = outer.Parameter(0, element_shape, "param"); + auto p = Parameter(&outer, 0, element_shape, "param"); XlaBuilder cond("cond"); - auto cond_t = cond.Parameter(0, element_shape, "t"); - TF_ASSERT_OK( - Any(cond.Eq(cond_t, cond.ConstantR1({42, 42})), &cond).status()); + auto cond_t = Parameter(&cond, 0, element_shape, "t"); + Any(Eq(cond_t, ConstantR1(&cond, {42, 42}))); XlaBuilder body("body"); - auto body_t = body.Parameter(0, element_shape, "t"); - auto e = body.Broadcast(body.ConstantR0(1.0), {2}); + Parameter(&body, 0, element_shape, "t"); + Broadcast(ConstantR0(&body, 1.0), {2}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, p); + While(cond_computation, body_computation, p); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, @@ -1019,21 +1013,20 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {}); XlaBuilder outer("outer"); - auto p = outer.Parameter(0, element_shape, "param"); + auto p = Parameter(&outer, 0, element_shape, "param"); XlaBuilder cond("cond"); - auto cond_t = cond.Parameter(0, element_shape, "t"); - cond.Eq(cond_t, cond.ConstantR0(42)); + auto cond_t = Parameter(&cond, 0, element_shape, "t"); + Eq(cond_t, ConstantR0(&cond, 42)); XlaBuilder body("body"); - auto body_t = body.Parameter(0, element_shape, "t"); - auto tuple = - body.Tuple({body_t, body.Add(body_t, body.ConstantR0(1))}); - auto e = body.GetTupleElement(tuple, 1); + auto body_t = Parameter(&body, 0, element_shape, "t"); + auto tuple = Tuple(&body, {body_t, Add(body_t, ConstantR0(&body, 1))}); + GetTupleElement(tuple, 1); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, p); + While(cond_computation, body_computation, p); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, @@ -1056,25 +1049,23 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { XlaBuilder outer("outer"); auto p = - outer.Tuple({outer.ConstantR0(0), - outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")}); + Tuple(&outer, {ConstantR0(&outer, 0), + Parameter(&outer, 0, ShapeUtil::MakeShape(S32, {}), "t")}); XlaBuilder cond("cond"); - auto params = cond.Parameter(0, result_shape, "prev"); - auto cond_t = cond.Add(cond.GetTupleElement(params, 1), - cond.GetTupleElement(params, 0)); - cond.Lt(cond_t, cond.ConstantR0(30)); + auto params = Parameter(&cond, 0, result_shape, "prev"); + auto cond_t = Add(GetTupleElement(params, 1), GetTupleElement(params, 0)); + Lt(cond_t, ConstantR0(&cond, 30)); XlaBuilder body("body"); - auto body_t = body.Parameter(0, result_shape, "t"); + auto body_t = Parameter(&body, 0, result_shape, "t"); - auto tuple = body.Tuple( - {body.Add(body.GetTupleElement(body_t, 0), body.ConstantR0(1)), - body.Add(body.GetTupleElement(body_t, 1), body.ConstantR0(1))}); + Tuple(&body, {Add(GetTupleElement(body_t, 0), ConstantR0(&body, 1)), + Add(GetTupleElement(body_t, 1), ConstantR0(&body, 1))}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, p); + While(cond_computation, body_computation, p); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, @@ -1105,9 +1096,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation inner_condition; { XlaBuilder builder("inner_condition"); - auto params = builder.Parameter(0, inner_result_shape, "prev"); - auto i = builder.GetTupleElement(params, 0); - builder.Lt(i, builder.ConstantR0(7)); + auto params = Parameter(&builder, 0, inner_result_shape, "prev"); + auto i = GetTupleElement(params, 0); + Lt(i, ConstantR0(&builder, 7)); inner_condition = builder.Build().ConsumeValueOrDie(); } @@ -1116,8 +1107,8 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation outer_condition; { XlaBuilder builder("outer_condition"); - auto prev = builder.Parameter(0, outer_result_shape, "prev"); - builder.Lt(prev, builder.ConstantR0(30)); + auto prev = Parameter(&builder, 0, outer_result_shape, "prev"); + Lt(prev, ConstantR0(&builder, 30)); outer_condition = builder.Build().ConsumeValueOrDie(); } @@ -1126,12 +1117,12 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation inner_body; { XlaBuilder builder("inner_body"); - auto params = builder.Parameter(0, inner_result_shape, "prev"); - auto i = builder.GetTupleElement(params, 0); - auto result = builder.GetTupleElement(params, 1); - i = builder.Add(builder.ConstantR0(1), i); - result = builder.Add(builder.ConstantR0(2), result); - builder.Tuple({i, result}); + auto params = Parameter(&builder, 0, inner_result_shape, "prev"); + auto i = GetTupleElement(params, 0); + auto result = GetTupleElement(params, 1); + i = Add(ConstantR0(&builder, 1), i); + result = Add(ConstantR0(&builder, 2), result); + Tuple(&builder, {i, result}); inner_body = builder.Build().ConsumeValueOrDie(); } @@ -1139,17 +1130,17 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation outer_body; { XlaBuilder builder("outer_body"); - auto prev = builder.Parameter(0, outer_result_shape, "prev"); - auto init = builder.Tuple({builder.ConstantR0(0), prev}); - auto result = builder.While(inner_condition, inner_body, init); - builder.GetTupleElement(result, 1); + auto prev = Parameter(&builder, 0, outer_result_shape, "prev"); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), prev}); + auto result = While(inner_condition, inner_body, init); + GetTupleElement(result, 1); outer_body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(outer_condition, outer_body, init); + auto init = ConstantR0(&builder, 0); + While(outer_condition, outer_body, init); ComputeAndCompareR0(&builder, 42, {}); } @@ -1167,8 +1158,8 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { XlaComputation condition_callee; { XlaBuilder builder("condition_callee"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Tuple({builder.Gt(builder.ConstantR0(5), prev)}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Tuple(&builder, {Gt(ConstantR0(&builder, 5), prev)}); condition_callee = builder.Build().ConsumeValueOrDie(); } @@ -1176,9 +1167,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.Call(condition_callee, {prev}); - builder.GetTupleElement(result, 0); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto result = Call(&builder, condition_callee, {prev}); + GetTupleElement(result, 0); condition = builder.Build().ConsumeValueOrDie(); } @@ -1186,16 +1177,16 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -1210,30 +1201,30 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { XlaComputation condition; { XlaBuilder builder("condition"); - auto state = builder.Parameter(0, while_shape, "state"); - builder.Gt(builder.ConstantR0(5), builder.GetTupleElement(state, 0)); + auto state = Parameter(&builder, 0, while_shape, "state"); + Gt(ConstantR0(&builder, 5), GetTupleElement(state, 0)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } XlaComputation body; { XlaBuilder builder("body"); - auto state = builder.Parameter(0, while_shape, "state"); - auto indvar = builder.GetTupleElement(state, 0); - auto input_0 = builder.GetTupleElement(state, 1); - auto input_1 = builder.GetTupleElement(state, 2); - auto output = builder.Tanh(builder.Dot(input_0, input_1)); - auto indvar_next = builder.Add(indvar, builder.ConstantR0(1)); - builder.Tuple({indvar_next, input_0, input_1, output}); + auto state = Parameter(&builder, 0, while_shape, "state"); + auto indvar = GetTupleElement(state, 0); + auto input_0 = GetTupleElement(state, 1); + auto input_1 = GetTupleElement(state, 2); + auto output = Tanh(Dot(input_0, input_1)); + auto indvar_next = Add(indvar, ConstantR0(&builder, 1)); + Tuple(&builder, {indvar_next, input_0, input_1, output}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } XlaBuilder builder(TestName()); - auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); - auto init = builder.Tuple( - {builder.ConstantR0(0), matrix_input, matrix_input, matrix_input}); - auto while_instruction = builder.While(condition, body, init); - builder.GetTupleElement(while_instruction, 3); + auto matrix_input = Parameter(&builder, 0, matrix_shape, "matrix"); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), matrix_input, + matrix_input, matrix_input}); + auto while_instruction = While(condition, body, init); + GetTupleElement(while_instruction, 3); TF_ASSERT_OK_AND_ASSIGN(auto param_value, client_->TransferToServer(*Literal::CreateR2( @@ -1264,9 +1255,9 @@ void BM_WhileLoop(int num_iters) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, loop_state_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(loop_limit)); + auto prev = Parameter(&builder, 0, loop_state_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, loop_limit)); condition = builder.Build().ConsumeValueOrDie(); } @@ -1274,29 +1265,29 @@ void BM_WhileLoop(int num_iters) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, loop_state_shape, "prev"); + auto prev = Parameter(&builder, 0, loop_state_shape, "prev"); // TupleElement 0 - auto iteration = builder.GetTupleElement(prev, 0); - auto out0 = builder.Add(iteration, builder.ConstantR0(1)); + auto iteration = GetTupleElement(prev, 0); + auto out0 = Add(iteration, ConstantR0(&builder, 1)); // TupleElement 1 - auto input = builder.GetTupleElement(prev, 1); + auto input = GetTupleElement(prev, 1); // Update. - auto one = builder.ConstantR0(1.0); - auto update = builder.Broadcast(one, {1, 1024, 1024}); + auto one = ConstantR0(&builder, 1.0); + auto update = Broadcast(one, {1, 1024, 1024}); // Starts = iteration * 2; - auto starts = builder.ConstantR1({0, 0, 0}); + auto starts = ConstantR1(&builder, {0, 0, 0}); // UpdateSlice. - auto out1 = builder.DynamicUpdateSlice(input, update, starts); - builder.Tuple({out0, out1}); + auto out1 = DynamicUpdateSlice(input, update, starts); + Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While instruction. XlaBuilder builder("while"); - auto zero = builder.ConstantR0(0.0); - auto input = builder.Broadcast(zero, {seq_len, 1024, 1024}); - auto init = builder.Tuple({builder.ConstantR0(0), input}); - builder.While(condition, body, init); + auto zero = ConstantR0(&builder, 0.0); + auto input = Broadcast(zero, {seq_len, 1024, 1024}); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), input}); + While(condition, body, init); auto computation = builder.Build().ConsumeValueOrDie(); std::unique_ptr executable = diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 3c9a01653c67203cbc962a3d3d967142f7a2102c..7dba058d407758b42365c3b6883e5e0891e1ab6c 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -128,20 +128,23 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, se::StreamExecutor* executor = backend->default_stream_executor(); DeviceMemoryAllocator* allocator = backend->memory_allocator(); auto* transfer_manager = backend->transfer_manager(); + TF_ASSERT_OK_AND_ASSIGN( + Backend::StreamPtr stream_ptr, + backend->BorrowStream(backend->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer lhs_arg, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); + stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, @@ -153,9 +156,6 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, &executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); - TF_ASSERT_OK_AND_ASSIGN( - Backend::StreamPtr stream_ptr, - backend->BorrowStream(backend->default_device_ordinal())); ExecutableRunOptions exec_run_options; exec_run_options.set_stream(stream_ptr.get()); exec_run_options.set_allocator(backend->memory_allocator()); @@ -168,6 +168,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, auto execution_result, executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg}, &hlo_execution_profile)); + TF_ASSERT_OK(stream_ptr->BlockHostUntilDone()); (void)execution_result; *profile_output = @@ -187,9 +188,9 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { ClientLibrary::GetOrCreateLocalClient(platform)); XlaBuilder builder(TestName()); - auto result = builder.Tanh(builder.Add( - builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); + Tanh(Add( + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); @@ -239,9 +240,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { EXPECT_TRUE(HasTrops(tanh_profile)); } -// TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo -// instructions "interior" to while nodes. -XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { +XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { const int64 size = 256; Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size}); Shape while_result_shape = @@ -255,30 +254,30 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { XlaComputation condition; { XlaBuilder builder("condition"); - auto state = builder.Parameter(0, while_result_shape, "state"); - auto iteration = builder.GetTupleElement(state, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto state = Parameter(&builder, 0, while_result_shape, "state"); + auto iteration = GetTupleElement(state, 0); + Gt(ConstantR0(&builder, 5), iteration); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } XlaComputation body; { XlaBuilder builder("body"); - auto state = builder.Parameter(0, while_result_shape, "state"); - auto matrix = builder.GetTupleElement(state, 1); - auto next_iteration = builder.Add(builder.GetTupleElement(state, 0), - builder.ConstantR0(1)); - builder.Tuple({next_iteration, builder.Add(matrix, matrix)}); + auto state = Parameter(&builder, 0, while_result_shape, "state"); + auto matrix = GetTupleElement(state, 1); + auto next_iteration = + Add(GetTupleElement(state, 0), ConstantR0(&builder, 1)); + Tuple(&builder, {next_iteration, Add(matrix, matrix)}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } XlaBuilder builder(TestName()); auto initial_while_state = - builder.Tuple({builder.ConstantR0(0), - builder.Parameter(0, matrix_shape, "initial_value")}); - auto while_result = builder.While(condition, body, initial_while_state); - builder.Add(builder.GetTupleElement(while_result, 1), - builder.Parameter(1, matrix_shape, "other_value")); + Tuple(&builder, {ConstantR0(&builder, 0), + Parameter(&builder, 0, matrix_shape, "initial_value")}); + auto while_result = While(condition, body, initial_while_state); + Add(GetTupleElement(while_result, 1), + Parameter(&builder, 1, matrix_shape, "other_value")); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); @@ -336,8 +335,11 @@ static std::pair AddXlaHloProfileFlag(int argc, char** argv) { new_argv[argc] = strdup("--xla_hlo_profile"); // Fusion can change the Hlo instructions that show up in the final Hlo - // executable, so block it here. - new_argv[argc + 1] = strdup("--xla_disable_hlo_passes=fusion"); + // executable, so block it here. Also block the WhileLoopInvariantCodeMotion + // pass, otherwise a while loop is transformed and we could not match the + // original name in the ProfileWhileComputation test. + new_argv[argc + 1] = strdup( + "--xla_disable_hlo_passes=fusion,while-loop-invariant-code-motion"); return {argc + 2, new_argv}; } diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index a9f2915b458b1816926de727b3da21982d06f6c0..a075195618c42aaa11f7b1c17730e67889a2c308 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -49,6 +49,7 @@ GTEST_API_ int main(int argc, char** argv) { } // Unfortunately Google's internal benchmark infrastructure has a // different API than Tensorflow's. + testing::InitGoogleTest(&argc, argv); #if defined(PLATFORM_GOOGLE) base::SetFlag(&FLAGS_benchmarks, pattern); RunSpecifiedBenchmarks(); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 78ab2dccafc37aa4f93da0b8d5b39a779ddd5db8..e4a052c8f1c0009619c3a94606f6384d04006e4e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -36,11 +36,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -63,10 +62,9 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -84,12 +82,12 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -138,7 +136,7 @@ tf_cc_binary( deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -165,12 +163,10 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -184,12 +180,11 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -202,13 +197,12 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/tools/convert_computation.cc b/tensorflow/compiler/xla/tools/convert_computation.cc index fe03a6e7bdfe99877c250fe1ae22beee4c8018a2..14d01b5bfb067cc39abc4d6e0605007624b6e0ae 100644 --- a/tensorflow/compiler/xla/tools/convert_computation.cc +++ b/tensorflow/compiler/xla/tools/convert_computation.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/env.h" @@ -33,7 +33,7 @@ namespace xla { namespace tools { void RealMain(const string& mode, const string& path) { - SessionModule module; + HloSnapshot module; tensorflow::Env* env = tensorflow::Env::Default(); if (mode == "txt2bin") { TF_CHECK_OK(tensorflow::ReadTextProto(env, path, &module)); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index 21ae8583d7cd3343230dcaff7dc17456e9e3e702..befb55453777dce30af89bcaad2ffe1647097576 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -17,7 +17,7 @@ limitations under the License. // // Dumps a graphviz URL for a snapshot computation to the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The GraphViz URL is placed into the log stderr, whereas computation @@ -30,11 +30,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,10 +48,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); ComputationStats stats = diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index b82f1c81c84b487c1661af5267b9123da97bb107..cfb8f37487d6499b803438a135be54524fcf17d2 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -21,11 +21,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -66,16 +65,16 @@ void RealMain(tensorflow::gtl::ArraySlice args) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); std::unique_ptr program_shape = client->GetComputationShape(computation).ConsumeValueOrDie(); @@ -89,8 +88,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 05c0fdf97d27c09eb2bbb0f265b5b2a5982ca7b1..5dd5150be339846d0775880931f615b92c5b08d8 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -19,11 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,16 +38,16 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); if (compile) { std::unique_ptr program_shape = @@ -65,8 +63,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); @@ -74,13 +71,11 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { local_service->backend().platform()->Name().c_str(), module.ToString(HloPrintOptions::ShortParsable()).c_str()); } else { - const ComputationTracker& tracker = local_service->computation_tracker(); - UserComputation* user_computation = - tracker.Resolve(computation.handle()).ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); std::unique_ptr module = - tracker.BuildHloModule(versioned_handle, HloModuleConfig()) + HloModule::CreateFromProto(computation.proto(), config) .ConsumeValueOrDie(); fprintf(stdout, "%s\n", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 51f90b07c66f7d839f587350726333b9dbe6a9f0..a5dce20456c6a2402f425ebb3d575d1bb625f839 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -28,11 +28,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -48,10 +47,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); debug_options.set_xla_hlo_dump_as_graphdef(true); diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD deleted file mode 100644 index 0fa4b98d0a41a1e7c681bb2302da3b752315867b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ /dev/null @@ -1,72 +0,0 @@ -# Build file for the Hlo parser. - -licenses(["notice"]) # Apache 2.0 - -package( - default_visibility = [":friends"], -) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "hlo_lexer", - srcs = ["hlo_lexer.cc"], - hdrs = [ - "hlo_lexer.h", - "hlo_token.h", - ], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - ], -) - -cc_library( - name = "hlo_parser", - srcs = ["hlo_parser.cc"], - hdrs = ["hlo_parser.h"], - deps = [ - ":hlo_lexer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "hlo_parser_test", - size = "small", - srcs = ["hlo_parser_test.cc"], - deps = [ - ":hlo_parser", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index d8cedad65ea68ef86b94394a1accf2c08517c0b2..3a7917cf3043de8a77f189f011bdeb3e8d2ddf3c 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -17,13 +17,16 @@ limitations under the License. // // Replays computations and shows the results on the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // Computations that require arguments can be replayed using fake data by // passing --use_fake_data on the command line. If the real data is available // in the proto and --use_fake_data is false, the real data is used. // +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// // The output format is: // // file_path: computation_name :: type:literal_str @@ -36,14 +39,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -65,136 +68,179 @@ namespace { // fields. struct Options { string fake_infeed_shape; + bool generate_fake_infeed = false; bool use_fake_data = false; bool print_result = true; int num_runs = 1; - bool xla_hlo_profile_last_run = false; }; // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // -// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; -// otherwise, no infeed is performed. -template -StatusOr> ReplayComputation(const ModuleT& module, - Client* client, - const Options& opts) { - static_assert(std::is_same::value || - std::is_same::value, - "Proto must be in HloSnapshot or SessionModule format"); - TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module)); - - std::vector> arguments; +// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided. +// If generate_fake_infeed is true, the required infeed shape is derived from +// the computation and then used to provide a fake infeed shape. +// +// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, +// no infeed is performed. +StatusOr ReplayComputation(const HloSnapshot& module, + LocalClient* client, const Options& opts) { + XlaComputation computation(module.hlo().hlo_module()); + + // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our + // arguments. This is a bit involved, because we may have to convert from + // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our + // objects. + std::vector scoped_shaped_buffer_arguments; + std::vector> global_data_arguments; + std::vector argument_ptrs; if (opts.use_fake_data) { - arguments = MakeFakeArgumentsOrDie(computation, client); + global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + for (const auto& data : global_data_arguments) { + argument_ptrs.push_back( + client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) + .ValueOrDie()); + } } else { // use recorded data if available for (const auto& proto : module.arguments()) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(proto)); - TF_ASSIGN_OR_RETURN(std::unique_ptr data, - client->TransferToServer(*literal)); - arguments.push_back(std::move(data)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer data, + client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + scoped_shaped_buffer_arguments.push_back(std::move(data)); + } + for (const auto& argument : scoped_shaped_buffer_arguments) { + argument_ptrs.push_back(&argument); } } + bool provide_infeed = false; + Shape infeed_shape; + if (!opts.fake_infeed_shape.empty()) { + StatusOr shape_status = + ShapeUtil::ParseShapeString(opts.fake_infeed_shape); + TF_CHECK_OK(shape_status.status()); + infeed_shape = std::move(shape_status).ValueOrDie(); + provide_infeed = true; + } else if (opts.generate_fake_infeed) { + for (const auto& comp : computation.proto().computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) { + CHECK(!provide_infeed) + << "--generate_fake_infeed only works if the model has 0 or 1 " + "infeed ops, but this one has >= 2."; + provide_infeed = true; + infeed_shape = instruction.shape(); + LOG(INFO) << "Generating fake infeed shape for inferred shape: " + << ShapeUtil::HumanString(infeed_shape); + } + } + } + } // We only instantiate the thread pool if the user has requested that a - // concurrent infeed occur via the fake_infeed_shape. + // concurrent infeed occur via the fake_infeed_shape, or when + // --generate_fake_infeed is passed and there exists an infeed operation in + // the HloSnapshot. tensorflow::gtl::optional pool; - - if (!opts.fake_infeed_shape.empty()) { + std::unique_ptr data; + if (provide_infeed) { + data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); + } + auto transfer_infeed = [&data, client]() { + TF_CHECK_OK(client->TransferToInfeed(*data)); + }; + if (provide_infeed) { pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); - pool->Schedule([opts, client]() { - StatusOr shape_status = - ShapeUtil::ParseShapeString(opts.fake_infeed_shape); - TF_CHECK_OK(shape_status.status()); - Shape shape = std::move(shape_status).ValueOrDie(); - StatusOr> data_status = MakeFakeLiteral(shape); - TF_CHECK_OK(data_status.status()); - std::unique_ptr data = std::move(data_status).ValueOrDie(); - while (true) { - TF_CHECK_OK(client->TransferToInfeed(*data)); - } + pool->Schedule([transfer_infeed]() { + // There may be several infeed buffers needed, however we don't know how + // many. If we proactively transfer too many infeed buffers, we may run + // out of memory. If we transfer too few infeed buffers, the program will + // hang. Therefore, we register a callback that is called when the infeed + // becomes empty, and in this callback we will transfer another fake + // infeed. + auto infeed_manager = xla::gpu::GetOrCreateInfeedManager(); + infeed_manager->RegisterOnEmptyCallback(transfer_infeed); + transfer_infeed(); }); } - std::vector execute_arguments; - execute_arguments.reserve(arguments.size()); - for (auto& argument : arguments) { - execute_arguments.push_back(argument.get()); + std::vector argument_layouts; + for (const auto& param : computation.proto().program_shape().parameters()) { + argument_layouts.push_back(¶m); + } + std::unique_ptr executable = + client->Compile(computation, argument_layouts, ExecutableBuildOptions()) + .ValueOrDie(); + + // Do not attmept to run the executable, if num_runs is less than 1. + if (opts.num_runs < 1) { + return Cancelled("Cancelled after compilation since --num_runs < 1."); } // Run the computation num_runs times, and return the result from the last // execution. - std::unique_ptr result; + StreamExecutorMemoryAllocator allocator( + client->platform(), + {client->platform()->ExecutorForDevice(0).ValueOrDie()}); + tensorflow::gtl::optional result; for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; - ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) { - execution_options.mutable_debug_options()->set_xla_hlo_profile(true); - } + ExecutableRunOptions run_options; + run_options.set_execution_profile(&profile); + run_options.set_allocator(&allocator); - if (opts.print_result) { - TF_ASSIGN_OR_RETURN( - result, client->ExecuteAndTransfer(computation, execute_arguments, - &execution_options, &profile)); - } else { - // If we're not printing the result, execute the computation but don't - // bother retrieving the result. This can be a significant speedup. - TF_RETURN_IF_ERROR(client - ->Execute(computation, execute_arguments, - &execution_options, &profile) - .status()); - } + TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); LOG(INFO) << "Execution took " << static_cast(profile.compute_time_ns()) / 1e9 << "s"; } - return std::move(result); + TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + client->ShapedBufferToLiteral(*result)); + return std::move(*result_literal); } -int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { - Client* client = ClientLibrary::LocalClientOrDie(); +StatusOr ParseInputFile(const string& filename, + const Options& opts) { tensorflow::Env* env = tensorflow::Env::Default(); - int exit_status = EXIT_SUCCESS; - for (char* arg : args) { - HloSnapshot snapshot; - auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); - if (status.ok()) { - StatusOr> result_status = - ReplayComputation(snapshot, client, opts); - if (!result_status.ok()) { - fprintf(stderr, "%s: error: %s\n", arg, - result_status.status().ToString().c_str()); - exit_status = EXIT_FAILURE; - continue; - } + HloSnapshot snapshot; + if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) { + return snapshot; + } + CHECK(opts.use_fake_data) + << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " + "and textual HLO don't carry real data."; + fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", + filename.c_str()); - std::unique_ptr result = result_status.ConsumeValueOrDie(); - if (result != nullptr) { - fprintf(stdout, "%s: %s :: %s:%s\n", arg, - snapshot.hlo().hlo_module().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); - if (snapshot.has_result()) { - std::unique_ptr literal = - Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); - fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal->ToString().c_str()); - } - } + if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) { + return snapshot; + } + fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str()); + string contents; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents)); + StatusOr> module = ParseHloString(contents); + if (module.ok()) { + *snapshot.mutable_hlo()->mutable_hlo_module() = + module.ValueOrDie()->ToProto(); + return snapshot; + } + fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", + filename.c_str()); + return InvalidArgument("Could not parse %s.", filename.c_str()); +} +int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + int exit_status = EXIT_SUCCESS; + for (char* arg : args) { + StatusOr maybe_snapshot = ParseInputFile(arg, opts); + if (!maybe_snapshot.ok()) { continue; } - fprintf(stderr, "%s: is not HloSnapshot: %s. Trying as SessionModule...\n", - arg, status.ToString().c_str()); - - SessionModule module; - TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); - StatusOr> result_status = - ReplayComputation(module, client, opts); + HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); + StatusOr result_status = ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); @@ -202,16 +248,17 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { continue; } - std::unique_ptr result = result_status.ConsumeValueOrDie(); - if (result != nullptr) { - fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); - if (module.has_result()) { + if (opts.print_result) { + Literal result = std::move(result_status).ValueOrDie(); + fprintf(stdout, "%s: %s :: %s:%s\n", arg, + snapshot.hlo().hlo_module().name().c_str(), + ShapeUtil::HumanString(result.shape()).c_str(), + result.ToString().c_str()); + if (snapshot.has_result()) { std::unique_ptr literal = - Literal::CreateFromProto(module.result()).ConsumeValueOrDie(); + Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(module.result().shape()).c_str(), + ShapeUtil::HumanString(snapshot.result().shape()).c_str(), literal->ToString().c_str()); } } @@ -236,9 +283,9 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), - tensorflow::Flag( - "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run, - "Pass --xla_hlo_profile the last time we run the computation."), + tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed, + "Whether a fake infeed shape should be generated " + "derived from the computation"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 1f3340cbc6afa9bda8bf639d01b8185968f79a4d..4e53fafcc97ff53afc5713e7ed8ee5222fac316b 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -18,7 +18,7 @@ limitations under the License. // Shows the signature (ProgramShape) of binary snapshot proto(s) on the command // line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The output format is: @@ -31,9 +31,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -49,13 +48,14 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + auto computation = client->LoadSnapshot(module).ConsumeValueOrDie(); std::unique_ptr shape = client->GetComputationShape(computation).ConsumeValueOrDie(); - fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(), + fprintf(stdout, "%s: %s :: %s\n", arg, + module.hlo().hlo_module().name().c_str(), ShapeUtil::HumanString(*shape).c_str()); } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index be33bd6dd1304fa8fc6e5aed1d4c4d65bf97e692..b23b968aae6ed8d6fb2b9f61ea5db2690eb5246c 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -218,6 +219,12 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); // Passed-varargs variant of the InvalidArgument factory above. Status InvalidArgumentV(const char* format, va_list args); +template +Status InvalidArgumentStrCat(Args&&... concat) { + return InvalidArgument( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + template Status UnimplementedStrCat(Args&&... concat) { return Unimplemented( @@ -486,6 +493,12 @@ bool c_is_sorted(const C& c) { return std::is_sorted(std::begin(c), std::end(c)); } +template +bool c_is_sorted(const C& c, Compare&& comp) { + return std::is_sorted(std::begin(c), std::end(c), + std::forward(comp)); +} + template auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { return std::adjacent_find(std::begin(c), std::end(c)); @@ -514,12 +527,47 @@ typename std::decay::type c_accumulate(const Sequence& sequence, T&& init, std::forward(binary_op)); } +template +typename std::iterator_traits< + decltype(std::begin(std::declval()))>::difference_type +c_count_if(const C& c, Pred&& pred) { + return std::count_if(std::begin(c), std::end(c), std::forward(pred)); +} + +// Determines whether `value` is present in `c`. +template +bool c_linear_search(const C& c, T&& value) { + auto last = std::end(c); + return std::find(std::begin(c), last, std::forward(value)) != last; +} + template int64 FindIndex(const C& c, Value&& value) { auto it = c_find(c, std::forward(value)); return std::distance(c.begin(), it); } +template +bool ArrayContains(tensorflow::gtl::ArraySlice c, const T& value) { + return c_find(c, value) != c.end(); +} + +template +void InsertAt(C* c, int64 index, Value&& value) { + c->insert(c->begin() + index, std::forward(value)); +} + +template +void EraseAt(C* c, int64 index) { + c->erase(c->begin() + index); +} + +template +std::vector InlinedVectorToVector( + const tensorflow::gtl::InlinedVector& inlined_vector) { + return std::vector(inlined_vector.begin(), inlined_vector.end()); +} + // Returns true if `x` fits in 32-bits. template bool IsInt32(T x) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f619b8dc24038af64a27fc0565c74447ca9d09cf..6f07e4606bef015214f2c564515c8258a906205b 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -17,7 +17,6 @@ syntax = "proto3"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; -import "tensorflow/compiler/xla/service/session.proto"; package xla; @@ -226,22 +225,6 @@ message ExecutionOptions { repeated DeviceHandle device_handles = 5; } -message SnapshotComputationRequest { - ComputationHandle computation = 1; -} - -message SnapshotComputationResponse { - SessionModule module = 1; -} - -message LoadComputationSnapshotRequest { - SessionModule module = 1; -} - -message LoadComputationSnapshotResponse { - ComputationHandle computation = 1; -} - message GetDeviceHandlesRequest { int64 device_count = 1; } @@ -300,11 +283,6 @@ message ResetDeviceRequest { message ResetDeviceResponse { } -message ComputationStatsRequest { - ComputationHandle computation = 1; - DebugOptions debug_options = 2; -} - message ComputationGraphStatsRequest { HloModuleProto computation = 1; DebugOptions debug_options = 2; @@ -314,14 +292,6 @@ message ComputationStatsResponse { ComputationStats stats = 1; } -message ComputationRequest { - string name = 1; -} - -message ComputationResponse { - ComputationHandle computation = 1; -} - message CreateChannelHandleRequest { } @@ -336,24 +306,6 @@ message UnregisterRequest { message UnregisterResponse { } -message SetReturnValueRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message SetReturnValueResponse { -} - -message ExecuteRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 5; -} - message ExecuteGraphRequest { HloModuleProto computation = 1; repeated GlobalDataHandle arguments = 2; @@ -362,10 +314,6 @@ message ExecuteGraphRequest { ExecutionOptions execution_options = 3; } -message ExecuteParallelRequest { - repeated ExecuteRequest requests = 1; -} - message ExecuteGraphParallelRequest { repeated ExecuteGraphRequest requests = 1; } @@ -379,21 +327,6 @@ message ExecuteParallelResponse { repeated ExecuteResponse responses = 1; } -message ExecuteAsyncRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 6; -} - -message ExecuteAsyncResponse { - // A handle to the execution launched asynchronously. - ExecutionHandle execution = 1; -} - message WaitForExecutionRequest { ExecutionHandle execution = 1; } @@ -403,31 +336,13 @@ message WaitForExecutionResponse { ExecutionProfile profile = 2; } -message IsConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - int64 num_parameters = 3; -} - -message IsConstantResponse { - bool is_constant = 1; -} - -message ComputeConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - Layout output_layout = 3; - repeated LiteralProto parameters = 4; -} - message ComputeConstantGraphRequest { HloModuleProto computation = 1; Layout output_layout = 2; } message ComputeConstantResponse { - // A LiteralProto is returned directly for this request, instead of a - // ComputationDataHandle. + // A LiteralProto is returned directly for this request. LiteralProto literal = 1; } @@ -469,14 +384,6 @@ message LoadDataResponse { int64 nanoseconds = 5; } -message SpecializeRequest { - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; -} - -message SpecializeResponse { -} - message GetShapeRequest { GlobalDataHandle data = 1; } @@ -485,14 +392,6 @@ message GetShapeResponse { Shape shape = 1; } -message GetComputationShapeRequest { - ComputationHandle computation = 1; -} - -message GetComputationShapeResponse { - ProgramShape program_shape = 1; -} - message UnpackRequest { GlobalDataHandle data = 1; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index b895ac045c361b2336e0081eadf16334d49d3bee..c7472173a705b7a6e1bee2f5221f23db0a77991d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -66,11 +66,16 @@ enum PrimitiveType { // in the dimensions field. TUPLE = 13; - // An opaque type used for passing context specific data to a custom - // operation. + // An opaque type used for passing context-specific data to a custom + // operation. Shapes of this primitive type will have empty dimensions and + // tuple_shapes fields. OPAQUE = 14; - // Next = 17 + // A token type threaded between side-effecting operations. Shapes of this + // primitive type will have empty dimensions and tuple_shapes fields. + TOKEN = 17; + + // Next = 18 } // Describes the value held inside padding elements. @@ -269,12 +274,9 @@ message ExecutionProfile { // for the input data transfer since the memory is initialized with the proper // values before the execution. int64 compute_and_transfer_time_ns = 5; -} -// Handle given to a user that represents a computation that the user builds up -// before execution. -message ComputationHandle { - int64 handle = 1; + // The size of the binary code in the executable. + int64 executable_size_in_bytes = 6; } // Handle given to a user that represents an execution that the user launched @@ -290,13 +292,6 @@ message GlobalDataHandle { int64 handle = 1; } -// Handle given to a user that represents a data result in a computation. -// This is used to pass to subsequent computations that depends upon the data as -// an operand. -message ComputationDataHandle { - int64 handle = 1; -} - // Handle given to a user that represents a replicated virtual device. Each // replicated device represents N physical devices for execution where N is the // number of replicas. @@ -436,44 +431,6 @@ message GatherDimensionNumbers { int64 index_vector_dim = 4; } -// Operation requests that are all collected as a tagged union with a oneof -// field in OpRequest. - -message ConstantRequest { - LiteralProto literal = 2; -} - -message GetTupleElementRequest { - ComputationDataHandle operand = 2; - int64 index = 3; -} - -message SliceRequest { - ComputationDataHandle operand = 2; - repeated int64 start_indices = 3; - repeated int64 limit_indices = 4; - repeated int64 strides = 5; -} - -message DynamicSliceRequest { - // Operand from which to slice at dynamic 'start_indices'. - ComputationDataHandle operand = 2; - // Dynamically computed 'start_indices' for slice operation. - ComputationDataHandle start_indices = 3; - // Slice sizes for each dimension (note that indices calculations are computed - // modulo dimension sizes to avoid out-of-bound array accesses). - repeated int64 slice_sizes = 4; -} - -message DynamicUpdateSliceRequest { - // Operand on which slice 'update' is to be applied. - ComputationDataHandle operand = 2; - // The slice update to apply to 'operand'. - ComputationDataHandle update = 3; - // Dynamically computed start indices for the update slice operation. - ComputationDataHandle start_indices = 4; -} - message ConvolutionDimensionNumbers { // The number of the dimension that represents batch in the input. int64 input_batch_dimension = 7; @@ -511,13 +468,6 @@ message ConvolutionDimensionNumbers { // Next = 13 }; -message ConvolveRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; // This is the filter/kernel. - Window window = 4; // Describes the filter/kernel. - ConvolutionDimensionNumbers dimension_numbers = 5; -} - enum FftType { FFT = 0; // Forward FFT; complex in, complex out. IFFT = 1; // Inverse FFT; complex in, complex out. @@ -526,56 +476,6 @@ enum FftType { // fft_length real out } -message FftRequest { - FftType fft_type = 1; - repeated int64 fft_length = 2; // Multivalent for higher-order FFT. - ComputationDataHandle operand = 3; -} - -message InfeedRequest { - // The shape of the data returned by reading the device's infeed buffer. - Shape shape = 2; - - // Additional infeed configuration for the backend. - bytes config = 3; -} - -message OutfeedRequest { - // The shape of the data returned by reading the device's outfeed buffer. - Shape shape = 1; - - // Operand to the Outfeed. Supports tuple. - ComputationDataHandle operand = 2; - - // Backend-specific information for how to perform the outfeed. - bytes outfeed_config = 3; -} - -message CallRequest { - ComputationHandle to_apply = 2; - repeated ComputationDataHandle operands = 3; -} - -message CustomCallRequest { - string call_target_name = 2; - repeated ComputationDataHandle operands = 3; - Shape shape = 4; -} - -message HostComputeRequest { - // Operand to the HostCompute. Supports tuple. - repeated ComputationDataHandle operands = 1; - - // Name used to identify HostSend/Recv channels. - string channel_name = 2; - - // Cost estimate in nanoseconds. - int64 cost_estimate_ns = 3; - - // The shape of any data returned by host. - Shape shape = 4; -} - message DotDimensionNumbers { // The dimension numbers that represent the 'lhs' contracting dimensions. repeated int64 lhs_contracting_dimensions = 1; @@ -587,297 +487,6 @@ message DotDimensionNumbers { repeated int64 rhs_batch_dimensions = 4; }; -message DotRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; - DotDimensionNumbers dimension_numbers = 4; -} - -message MapRequest { - repeated ComputationDataHandle operands = 2; - ComputationHandle to_apply = 3; - repeated ComputationDataHandle static_operands = 4; - // The dimensions over which to map. - // Example mapping a Dot operation along the batch dimension 0: - // operand0.shape = [2, 2, 2], operand1.shape = [2,2,3] - // Map({operand0, operand1}, Dot, {0}) - repeated int64 dimensions = 5; -} - -message ReduceRequest { - // Operand to the reduction. - ComputationDataHandle operand = 2; - - // Initial value for the reduction. This must be consistent with the result - // shape of to_apply. - ComputationDataHandle init_value = 3; - - // The dimensions to reduce over. - repeated int64 dimensions = 4; - - // The computation to apply in the reduction. - ComputationHandle to_apply = 5; -} - -message ReduceWindowRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle init_value = 3; - Window window = 4; - ComputationHandle to_apply = 5; -} - -message BatchNormTrainingRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - float epsilon = 4; - int64 feature_index = 5; -} - -message BatchNormInferenceRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - ComputationDataHandle mean = 4; - ComputationDataHandle variance = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message BatchNormGradRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle mean = 3; - ComputationDataHandle variance = 4; - ComputationDataHandle grad_output = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message CrossReplicaSumRequest { - ComputationDataHandle operand = 2; -} - -message SelectAndScatterRequest { - // Operand array on which the windows slide. - ComputationDataHandle operand = 2; - - // Source array for the data to scatter. - ComputationDataHandle source = 3; - - // Initial scalar value for each element in the output. - ComputationDataHandle init_value = 4; - - // Window configuration. - Window window = 5; - - // Binary function used to select an element from each window. - ComputationHandle select = 6; - - // Binary function used to combine each scattered value from source with the - // current output value at the selected location. - ComputationHandle scatter = 7; -} - -message ReverseRequest { - ComputationDataHandle operand = 2; - repeated int64 dimensions = 3; -} - -message BroadcastRequest { - ComputationDataHandle operand = 2; - repeated int64 broadcast_sizes = 3; -} - -message PadRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle padding_value = 3; - PaddingConfig padding_config = 4; -} - -message ReshapeRequest { - ComputationDataHandle operand = 2; - - // The dimension order for collapse (from fastest-changing to slowest). - repeated int64 dimensions = 3; - - // The new dimension sizes (from dimension 0 to n-1). - repeated int64 new_sizes = 4; -} - -message TransposeRequest { - ComputationDataHandle operand = 2; - - // The permutation of the operand's dimensions (in the range 0 to n-1). - repeated int64 dimensions = 3; -} - -message ParameterRequest { - Shape shape = 2; - int64 parameter = 3; - string name = 4; -} - -message GetLocalShapeRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message GetLocalShapeResponse { - Shape shape = 1; -} - -message TraceRequest { - string tag = 2; - ComputationDataHandle operand = 3; -} - -message ConvertRequest { - ComputationDataHandle operand = 2; - PrimitiveType new_element_type = 3; -} - -message ConcatenateRequest { - repeated ComputationDataHandle operands = 2; - // The dimension in which we concatenate; e.g. if you had dimension arrays of - // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1]. - // Attempting to concatenate those in dimension 1 would produce an error, as - // 4 != 5 (and there is no ragged array support). - int64 dimension = 3; -} - -message ConditionalRequest { - ComputationDataHandle predicate = 2; - ComputationDataHandle true_operand = 3; - ComputationHandle true_computation = 4; - ComputationDataHandle false_operand = 5; - ComputationHandle false_computation = 6; -} - -message WhileRequest { - ComputationHandle condition = 2; - ComputationHandle body = 3; - ComputationDataHandle init = 4; -} - -enum UnaryOperation { - UNOP_INVALID = 0; - - // Elementwise, logical negation on booleans and bitwise negation on ints. - UNOP_NOT = 1; - - // Elementwise, computes e^x. - UNOP_EXP = 2; - - // Elementwise, computes -x. - UNOP_NEGATE = 3; - - // Puts the elements in the operand into sorted order. - UNOP_SORT = 4; - - // Elementwise, computes tanh(x). - UNOP_TANH = 5; - - // Elementwise, computes the natural logarithm of x. - UNOP_LOG = 6; - - // Elementwise, computes the floor of x. - UNOP_FLOOR = 7; - - // Elementwise, computes the ceil of x. - UNOP_CEIL = 8; - - // Elementwise, computes the abs of x. - UNOP_ABS = 9; - - // Elementwise, computes the sign of x. - UNOP_SIGN = 10; - - // Elementwise, tests if values are finite (not NaN or inf) - UNOP_IS_FINITE = 11; - - // Elementwise, computes the cosine of x. - UNOP_COS = 12; - - // Elementwise, computes the sine of x. - UNOP_SIN = 13; - - // Elementwise, rounds x to nearest integral value, rounding half-way cases - // away from zero. - UNOP_ROUND_NEAREST_AFZ = 14; - - // Elementwise, extract real component of complex x. - UNOP_REAL = 15; - - // Elementwise, extract real component of complex x. - UNOP_IMAG = 16; - - // Elementwise, computes clz(x). - UNOP_CLZ = 17; - - // Elementwise, computes exp(x)-1. - UNOP_EXPM1 = 18; - - // Elementwise, computes log(x+1). - UNOP_LOG1P = 19; -} - -message UnaryOpRequest { - UnaryOperation unop = 2; - ComputationDataHandle operand = 3; -} - -enum BinaryOperation { - BINOP_INVALID = 0; - - // Arithmetic operations. - BINOP_ADD = 1; - BINOP_DIV = 2; - BINOP_MUL = 3; - BINOP_SUB = 4; - - // Comparison operators. - BINOP_EQ = 5; - BINOP_GE = 6; - BINOP_GT = 7; - BINOP_LE = 8; - BINOP_LT = 9; - BINOP_NE = 10; - - // Element-wise maximum. - BINOP_MAX = 14; - - // Element-wise minimum. - BINOP_MIN = 15; - - // Raises the left-hand-side to the right-hand-side power. - BINOP_POW = 16; - - // Remainder operation. - BINOP_REM = 17; - - // Element-wise, logical operators on booleans and bitwise operators on ints. - BINOP_AND = 18; - BINOP_OR = 19; - - BINOP_SHIFT_LEFT = 20; - BINOP_SHIFT_RIGHT_ARITHMETIC = 21; - BINOP_SHIFT_RIGHT_LOGICAL = 22; - - // Complex from real, imag. - BINOP_COMPLEX = 23; - - // Computes the 4-quadrant arctangent of the y, x input arguments. - BINOP_ATAN2 = 24; -} - -message BinaryOpRequest { - BinaryOperation binop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - repeated int64 broadcast_dimensions = 5; -} - enum RandomDistribution { RNG_INVALID = 0; @@ -892,67 +501,6 @@ enum RandomDistribution { // Next: 4 } -message RngRequest { - RandomDistribution distribution = 2; - repeated ComputationDataHandle parameter = 3; - Shape shape = 4; -} - -enum TernaryOperation { - TRIOP_INVALID = 0; - - // Given a predicate and two operands, selects operand0 if the predicate is - // true and operand1 if the predicate is false. - TRIOP_SELECT = 1; - - // Given a min, max and an operand returns the operand if between min and max, - // else returns min if operand is less than min or max if operand is greater - // than max. - TRIOP_CLAMP = 3; -} - -message TernaryOpRequest { - TernaryOperation triop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - ComputationDataHandle ehs = 5; -} - -enum VariadicOperation { - VAROP_INVALID = 0; - - // Creates a tuple from its operands. - VAROP_TUPLE = 1; -} - -message VariadicOpRequest { - VariadicOperation varop = 2; - repeated ComputationDataHandle operands = 3; -} - -message ReducePrecisionRequest { - ComputationDataHandle operand = 1; - int32 exponent_bits = 2; - int32 mantissa_bits = 3; -} - -message SendRequest { - ComputationDataHandle operand = 1; - ChannelHandle channel_handle = 2; -} - -message RecvRequest { - Shape shape = 1; - ChannelHandle channel_handle = 2; -} - -message GatherRequest { - ComputationDataHandle input = 1; - ComputationDataHandle gather_indices = 2; - GatherDimensionNumbers dimension_numbers = 3; - repeated int64 window_bounds = 4; -} - message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -983,59 +531,3 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } - -message OpRequest { - ComputationHandle computation = 1; - OpMetadata metadata = 33; - OpSharding sharding = 40; - - oneof op { - BinaryOpRequest binary_op_request = 2; - BroadcastRequest broadcast_request = 3; - CallRequest call_request = 4; - ConcatenateRequest concatenate_request = 5; - ConstantRequest constant_request = 6; - ConvertRequest convert_request = 7; - ConvolveRequest convolve_request = 8; - CrossReplicaSumRequest cross_replica_sum_request = 9; - CustomCallRequest custom_call_request = 10; - DotRequest dot_request = 43; - DynamicSliceRequest dynamic_slice_request = 11; - DynamicUpdateSliceRequest dynamic_update_slice_request = 12; - GetTupleElementRequest get_tuple_element_request = 13; - InfeedRequest infeed_request = 14; - MapRequest map_request = 15; - PadRequest pad_request = 16; - ParameterRequest parameter_request = 17; - ReducePrecisionRequest reduce_precision_request = 36; - ReduceRequest reduce_request = 18; - ReduceWindowRequest reduce_window_request = 19; - ReshapeRequest reshape_request = 20; - ReverseRequest reverse_request = 21; - RngRequest rng_request = 22; - SelectAndScatterRequest select_and_scatter_request = 23; - SliceRequest slice_request = 24; - TernaryOpRequest ternary_op_request = 25; - TraceRequest trace_request = 26; - TransposeRequest transpose_request = 34; - UnaryOpRequest unary_op_request = 27; - VariadicOpRequest variadic_op_request = 28; - WhileRequest while_request = 29; - SendRequest send_request = 30; - RecvRequest recv_request = 31; - OutfeedRequest outfeed_request = 32; - BatchNormTrainingRequest batch_norm_training_request = 35; - BatchNormGradRequest batch_norm_grad_request = 37; - BatchNormInferenceRequest batch_norm_inference_request = 38; - FftRequest fft_request = 41; - ConvertRequest bitcast_convert_request = 42; - ConditionalRequest conditional_request = 44; - HostComputeRequest host_compute_request = 45; - GatherRequest gather_request = 46; - // Next: 47 - } -} - -message OpResponse { - ComputationDataHandle output = 1; -} diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 0f9c80404ad33c39ae783e0bfa3cfb26e342fe3d..c039624daa65174b0550ff6a304947e37cf58e1d 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -9,6 +9,7 @@ load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") load("//tensorflow:tensorflow.bzl", "if_not_windows") +load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda") py_library( name = "contrib_py", @@ -26,25 +27,24 @@ py_library( "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/boosted_trees:init_py", "//tensorflow/contrib/checkpoint/python:checkpoint", - "//tensorflow/contrib/cloud:cloud_py", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/contrib/autograph", "//tensorflow/contrib/constrained_optimization", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/data", - "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/deprecated:deprecated_py", + "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/estimator:estimator_py", "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/feature_column:feature_column_py", "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/fused_conv:fused_conv_py", "//tensorflow/contrib/gan", "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/contrib/grid_rnn:grid_rnn_py", @@ -83,7 +83,6 @@ py_library( "//tensorflow/contrib/proto", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", - "//tensorflow/contrib/autograph", "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/recurrent:recurrent_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", @@ -114,6 +113,7 @@ py_library( "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", + "//tensorflow/python/estimator:estimator_py", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ "//tensorflow/contrib/tensorrt:init_py", ]) + select({ @@ -122,7 +122,17 @@ py_library( "//tensorflow/contrib/kafka", ], "//conditions:default": [], - }) + if_not_windows([ + }) + select({ + "//tensorflow:with_aws_support_windows_override": [], + "//tensorflow:with_aws_support": [ + "//tensorflow/contrib/kinesis", + ], + "//conditions:default": [], + }) + if_not_windows_cuda([ + "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols + ]) + if_not_windows([ + "//tensorflow/contrib/bigtable", # depends on bigtable + "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code ]), @@ -153,6 +163,12 @@ cc_library( "//tensorflow/contrib/kafka:dataset_kernels", ], "//conditions:default": [], + }) + select({ + "//tensorflow:with_aws_support_windows_override": [], + "//tensorflow:with_aws_support": [ + "//tensorflow/contrib/kinesis:dataset_kernels", + ], + "//conditions:default": [], }), ) @@ -182,5 +198,11 @@ cc_library( "//tensorflow/contrib/kafka:dataset_ops_op_lib", ], "//conditions:default": [], + }) + select({ + "//tensorflow:with_aws_support_windows_override": [], + "//tensorflow:with_aws_support": [ + "//tensorflow/contrib/kinesis:dataset_ops_op_lib", + ], + "//conditions:default": [], }), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 9aad772f0acd941d50d6ba238d345616195a6939..ded05da71877566781a5fb6d0c21e1c8d43de9ed 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -25,7 +25,8 @@ import os from tensorflow.contrib import batching from tensorflow.contrib import bayesflow from tensorflow.contrib import checkpoint -from tensorflow.contrib import cloud +if os.name != "nt": + from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder from tensorflow.contrib import compiler diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 62d1b1cf079d04d50e4899cfd9ba1d405ee1efb9..881808a98bfd688c2efaa8beb5b8f11a2527fee8 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -11,6 +11,16 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_py_test") +py_library( + name = "all_reduce_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":all_reduce", + "//tensorflow/python:util", + ], +) + py_library( name = "all_reduce", srcs = [ diff --git a/tensorflow/contrib/all_reduce/__init__.py b/tensorflow/contrib/all_reduce/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9824f4cfbf83d9b001a58cafe582226e96c076f --- /dev/null +++ b/tensorflow/contrib/all_reduce/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""All-reduce implementations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.all_reduce.python.all_reduce import * + +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'build_ring_all_reduce', + 'build_recursive_hd_all_reduce', + 'build_shuffle_all_reduce', + 'build_nccl_all_reduce', + 'build_nccl_then_ring', + 'build_nccl_then_recursive_hd', + 'build_nccl_then_shuffle', + 'build_shuffle_then_ring', + 'build_shuffle_then_shuffle' +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index c10179ba8b290b6209f5567d6323df4bcf711585..f0b1c92cf7e4b760381da38febd9682ce2a4f27c 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -1,6 +1,8 @@ # Description: # JNI-based Java inference interface for TensorFlow. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/android/jni/run_stats_jni.cc b/tensorflow/contrib/android/jni/run_stats_jni.cc index 707853b59befc2625145ad96952fbf9f66d62b43..30de7b59af79cb36ee266a15bb6e668c2e3f628a 100644 --- a/tensorflow/contrib/android/jni/run_stats_jni.cc +++ b/tensorflow/contrib/android/jni/run_stats_jni.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/contrib/android/jni/run_stats_jni.h" #include + #include #include "tensorflow/core/protobuf/config.pb.h" @@ -73,7 +74,8 @@ JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz, StatSummarizer* s = requireHandle(env, handle); if (s == nullptr) return nullptr; std::stringstream ret; - ret << s->GetStatsByMetric("Top 10 CPU", StatSummarizer::BY_TIME, 10) + ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME, + 10) << s->GetStatsByNodeType() << s->ShortSummary(); return env->NewStringUTF(ret.str().c_str()); } diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD index 30dd846893c30b9205972bd5216cc1871ab03d76..ad700ac4a0342e2a7bc07a6ecf6710cea892e296 100644 --- a/tensorflow/contrib/autograph/BUILD +++ b/tensorflow/contrib/autograph/BUILD @@ -23,9 +23,9 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/autograph/impl", + "//tensorflow/contrib/autograph/lang", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/utils", - "@gast_archive//:gast", - "@six_archive//:six", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..06fb7b03d5dbbfd2fcb6d6a2ecfe5c817f94a469 --- /dev/null +++ b/tensorflow/contrib/autograph/CONTRIBUTING.md @@ -0,0 +1,95 @@ +# How to contribute + +We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below. + +## TensorFlow Code of Conduct +Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md). + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult [GitHub +Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +After a pull request is approved, we merge it. Note our merging process differs +from GitHub in that we pull and submit the change into an internal version +control system. This system automatically pushes a git commit to the GitHub +repository (with credit to the original author) and closes the pull request. + +## Style + +See the [AutoGraph style guide](STYLE_GUIDE.md). + +## Unit tests + +Please include unit tests when contributing new features ([example here](converters/continue_statements_test.py)), as they help to a) prove that your code works correctly, and b) guard against future breaking +changes to lower the maintenance cost. +It's also helpful to check that any +changes you propose do not break existing unit tests. You can run tests using the command, + +```shell +bazel test --config=opt --copt=-O3 --copt=-march=native \ + //tensorflow/contrib/autograph/... +``` + +from the root of the `tensorflow` repository. For more details see the [main TensorFlow Contributing File](../../CONTRIBUTING.md) + +## Developer info + +### Module structure + +The graph below describes the dependencies between AutoGraph modules (not to be mistaken with the directory structure for these modules, which is flat): + +```dot +digraph d_modules { + autograph [style=filled]; + converters; + core; + impl; + lang; + operators; + + autograph -> impl + autograph -> lang + + impl -> converters + impl -> core + impl -> operators + + lang -> operators + + converters -> core + converters -> lang +} +``` + +`autograph` is the sole user-visible module. + +A short description of the modules: + + * `autograph`: the main module imported by the user and by the generated code; only contains declarations + * `impl`: high level code and the implementation of the api frontend + * `core`: base classes for the AutoGraph source code transformation logic; see in particular `converter.py` + * `lang`: special user-visible functions that serve as extensions to the Python language + * `converters`: collection of source code transformation modules specialized for particular AutoGraph features + * `operators`: collection of operators that AutoGraph overloads; these correspond to Python operators as well as Python syntactic structures, like control flow + +There are two additional modules, `pyct` and `utils`. These are independent of AutoGraph: + + * `pyct`: a general purpose Python source code transformation library + * `utils`: the kitchen sync; deprecated + +Note: we have a long term plan to factor out an implementation of `impl` and `converters` that is independent of autograph, into a general purpose Python operator overloading library. diff --git a/tensorflow/contrib/autograph/LIMITATIONS.md b/tensorflow/contrib/autograph/LIMITATIONS.md new file mode 100644 index 0000000000000000000000000000000000000000..d8b1cb7616ac348981bf2b69d6e2fd8d8a6e6b78 --- /dev/null +++ b/tensorflow/contrib/autograph/LIMITATIONS.md @@ -0,0 +1,50 @@ +# Capabilities and Limitations + +TF AutoGraph converts Eager Python code into TensorFlow graph-mode code. For example, users write code with `if` and `while` and AutoGraph automatically converts it into the equivalent `tf.cond`, and `tf.while_loop`. + +Python is a large language, so hoping to convert arbitrary Python code directly to TF graphs is overly ambitious. However, the Python code written to metaprogram TF graphs is in practice a restricted subset. We aim to support as much of this subset as possible. The table below lays out what we currently handle, what we hope to support, and what we have no plans to support. + +# Python Language Support Status + +Note: as more complex features in TensorFlow are made more accessible using AutoGraph, we expect to come across use cases that haven't been tried before, some of which might reveal rare bugs. If we do find any such bugs, we may add additional restrictions for the affected configurations, until those bugs are resolved. + + Construct | Supported now? | Plan to support? | Notes + :--------- | :--------------: | :----------------: | :----- +If statement | Yes | | Converts to `tf.cond`. If variables are created in one branch that don’t exist in another, which is inexpressible in TF, we throw a clear error. +For statement | Yes | | We will specialize `for` loops with unknown and known lengths, as well as for loops over TF datasets. Converts to `tf.while_loop`, with an additional `maximum_iterations` hint, if that is known. Creating variables inside the loop that are used later outside the loop is not supported, as the loop may have no iterations. +While statement | Yes | | Converts to `tf.while_loop`. Creating variables inside the loop is not supported, as the loop may have no iterations. +Continue and break | Yes | | Converts to boolean flags and extra predicates in loop tests. +Composition of control flow | Yes | | Arbitrary composition of `if`, `while`, `for`, `break`, and `continue`, along with other supported language elements, is supported and tested. +Iterators | Some | Yes | Not all iterators supported, but we plan to support everything that can be desugared, such as `enumerate` and `zip`. +Multiple return values | Yes | | We desugar them into variables, boolean flags and conditionals so that the function has a single return value at the end, and provide a clear error if we are unable to do so. +Print expression | Yes | | Wrapped in `PyFunc`, and given proper control dependencies. Optional support for using tf.Log when py_func is undesirable exists. +Static function calls | Yes | | Non-recursive function calls +Nested call trees | Yes | | For example, `f` calls `g` which calls `h`, all of which need conversion. +Recursive function calls | No | Maybe | Based on available support in TF. Currently `function.Defun` is the best candidate, but it is not reentrant. +Python built-ins | Some | Yes | `print`, `len`, `range`, `xrange`, `int`, `float` are supported, and we plan to support or clearly error on all [Python built-ins](https://docs.python.org/3/library/functions.html). +List operations | Yes | | We convert list creation, append, pop and indexing to their TF TensorArray equivalents. However, we do need some extra type hints to fully convert correctly. We hope to remove this limitation. +Function variables | Yes | | e.g. `f_new = f_orig; f_new()` +Lambda functions | No | Yes | Planned feature. +Classes | Yes | | Classes can be converted all at once, or method-by-method. Some limitations exist around static and class methods. +Subclasses | Yes | | Subclassing library objects like tf.keras.Model is also supported. +Dynamic types | Some | | `o = C1() if foo else C2(); o.bar()`. Some scenarios where types are data-dependent may not be supported. We will raise a meaningful error in that case. +Dynamic code / exec | No | | +Reflection | No | | +Try / Except | No | No | No current sane TF equivalent. +Global variables | Restricted | | In general, we only support read-only access to arguments or variables defined outside the converted code. A few exceptions include TensorFlow library code. +Functions with side effects | Some | | Side effects are allowed, under certain circumstances. +Collections | Some | Yes | We currently support lists. There are currently no TF equivalents of dictionaries or tuples. +List Comprehensions | Yes | | We desugar `ListComp` into the appropriate combination of `For` and `If` statements. Other comprehensions are currently very low priority. +Custom context managers | No | Yes | Currently low priority. Left unconverted currently. +Generators | No | Maybe | Could be achievable using queues; very low priority. +Assertions | Yes | | As `tf.Assert` +Deletion | Yes | Maybe | Currently unconverted. If new semanti cs are required for `del`, we are able to add it in. +Inline imports | No | Yes | For example, `import numpy as np; np.eye(3)`. Currently low priority. +Async | No | No | + +## Extra capabilities + + - We liberally add name scopes to generated functions + - Operations get decent default names everywhere (planned) + - Statements that have no output values are given correct control dependencies. For example, `for i in range(n): print(i)` will have control dependencies to ensure the `print` statements are executed serially. + diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 0ba99c396fc1c8ee1e12fbb4fe0293ee52ed9bc9..7e26f4711851138c1834f881621ebfa227a85821 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,10 +1,10 @@ # AutoGraph -IMPORTANT: AutoGraph is pre-alpha, under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! +IMPORTANT: AutoGraph is alpha software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). AutoGraph is a Python to TensorFlow compiler. -With AutoGraph, you can write [Eager style](https://www.tensorflow.org/programmers_guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. +With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. For example, this Python function: @@ -120,3 +120,15 @@ You can use the functional API to inspect the generated code as well: print(ag.to_code(f)) # Output: ``` + +## Filing bugs and feature requests + +### Reporting a bug + + - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message. + - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message. + - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you. + +### Requesting a feature + +If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there. diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md new file mode 100644 index 0000000000000000000000000000000000000000..7e6b0cc27dd1cf8c0f459a0a34f98092728342a2 --- /dev/null +++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md @@ -0,0 +1,85 @@ +# AutoGraph Style Guide + +This page contains style decisions that developers should follow when +contributing code to AutoGraph. + +## TensorFlow Style + +Follow the [TensorFlow style +guide](https://www.tensorflow.org/community/style_guide), the [documentation +guide](https://www.tensorflow.org/community/documentation) and the +[Google Python style guide](https://google.github.io/styleguide/pyguide.html). + +Naming conventions: + +1. The name is TensorFlow, not Tensorflow. +2. The name is AutoGraph, not Autograph. + +## AutoGraph Style + +Below are AutoGraph-specific conventions. In the event of conflict, +it supercedes all previous conventions. + +1. __Types in docstrings.__ Use [PEP 484][https://www.python.org/dev/peps/pep-0484/] + notation to describe the type for args, return values and attributes. + + Example: + + ``` + Args: + foo: Dict[str, List[int]], a dictionary of sorts + ``` + +2. __Citations in Docstrings.__ Write a `#### References` subsection at the + bottom of any docstring with citations. Use ICLR’s bibliography style to + write references; for example, order entries by the first author's last + name. Add a link to the paper if the publication is open source (ideally, + arXiv). + + Write in-paragraph citations in general, e.g., [(Tran and Blei, 2018)][1]. + Write in-text citations when the citation is a noun, e.g., [Tran and Blei + (2018)][1]. Write citations with more than two authors using et al., e.g., + [(Tran et al., 2018)][1]. Separate multiple citations with semicolon, e.g., + ([Tran and Blei, 2018][1]; [Gelman and Rubin, 1992][2]). + + Examples: + + ```none + #### References + + # technical report + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + + # journal + [2]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation + Using Multiple Sequences. _Statistical Science_, 7(4):457-472, 1992. + + # arXiv preprint + # use "et al." for papers with too many authors to maintain + [3]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech + Synthesis. _arXiv preprint arXiv:1711.10433_, 2017. + https://arxiv.org/abs/1711.10433 + + # conference + [4]: Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, and Roger Grosse. + Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches. In _International Conference on Learning + Representations_, 2018. + https://arxiv.org/abs/1803.04386 + ``` + +3. Avoid LaTeX in docstrings. + + * It is not rendered in many (if not most) editors and can be hard to read + for both LaTeX experts and non-experts. + +4. Write docstring and comment math using ASCII friendly notation; python using + operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`, + `sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx: + x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`. + + * The more we stick to python style, the more someone can + copy/paste/execute. + * Python style is usually easier to read as ASCII. diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 3386c4eca4b93e850f6fe3c6239d29c61d787ece..361cf2d77c7e46912d5bff5881df2ffa897c5179 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -23,18 +23,37 @@ from __future__ import print_function # TODO(mdan): Bring only the relevant symbols to the top level. from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph.impl.api import convert from tensorflow.contrib.autograph.impl.api import converted_call from tensorflow.contrib.autograph.impl.api import do_not_convert from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.impl.api import to_graph +from tensorflow.contrib.autograph.lang.directives import set_element_type +from tensorflow.contrib.autograph.lang.directives import set_loop_options +from tensorflow.contrib.autograph.lang.special_functions import stack from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'utils', 'convert', 'converted_call', 'do_not_convert', 'RunMode', - 'to_code', 'to_graph', 'AutographParseError' + # Main API + 'RunMode', + 'convert', + 'converted_call', + 'do_not_convert', + 'to_code', + 'to_graph', + # Overloaded operators + 'operators', + # Python language "extensions" + 'set_element_type', + 'set_loop_options', + 'stack', + # Exceptions + 'AutographParseError', + # Utilities: to be removed + 'utils', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index 8f9bffa55e44e4942bb3845945b3d440c7957cc9..b2e2e27673dafe290cef40a9fe0a834bfe1ea61f 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -31,29 +31,17 @@ py_library( "name_scopes.py", "side_effect_guards.py", "single_return.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ - "@gast_archive//:gast", - ], -) - -py_library( - name = "test_lib", - srcs = [ - "converter_test_base.py", - ], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:__subpackages__"], - deps = [ - ":converters", - "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/core", + "//tensorflow/contrib/autograph/lang", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/pyct/static_analysis", - "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:util", "@gast_archive//:gast", - "@six_archive//:six", ], ) @@ -63,7 +51,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -73,7 +62,8 @@ py_test( srcs = ["break_statements_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -84,7 +74,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -96,7 +87,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/impl", "//tensorflow/python:client_testlib", ], @@ -107,7 +99,8 @@ py_test( srcs = ["continue_statements_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -117,7 +110,8 @@ py_test( srcs = ["control_flow_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -126,8 +120,13 @@ py_test( name = "decorators_test", srcs = ["decorators_test.py"], srcs_version = "PY2AND3", + tags = [ + "no_pip", + "no_windows", + ], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -136,7 +135,8 @@ py_test( name = "name_scopes_test", srcs = ["name_scopes_test.py"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], @@ -147,7 +147,8 @@ py_test( srcs = ["list_comprehension_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -157,7 +158,8 @@ py_test( srcs = ["lists_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -167,7 +169,8 @@ py_test( srcs = ["logical_expressions_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -182,7 +185,8 @@ py_test( "notap", ], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -192,7 +196,8 @@ py_test( srcs = ["single_return_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], @@ -203,7 +208,20 @@ py_test( srcs = ["ifexp_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py index 3b0db677ce5e417e7afea8d8fe4121a0352bb6d7..e664a403a5fb800e7d0dddfa5695330927aaf4e0 100644 --- a/tensorflow/contrib/autograph/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class AssertsTransformer(transformer.Base): +class AssertsTransformer(converter.Base): """Transforms Print nodes to Call so they can be handled as functions.""" def visit_Assert(self, node): @@ -45,5 +45,5 @@ class AssertsTransformer(transformer.Base): raise NotImplementedError('can only convert string messages for now.') -def transform(node, context): - return AssertsTransformer(context).visit(node) +def transform(node, ctx): + return AssertsTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py index cc913febe8d0f411588af69b87ec52ce58f4469c..2cd0e626bc4552bd40bc94b890fdcc7efcafb3f3 100644 --- a/tensorflow/contrib/autograph/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -21,11 +21,11 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.converters import asserts -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class AssertsTest(converter_test_base.TestCase): +class AssertsTest(converter_testing.TestCase): def test_transform(self): diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 35877224b87c1abda1a270be4869e9dcfd0cf97c..a990e359a2a25a57ee2a4f8a866350633f3b9ea8 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -18,11 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gast - +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -31,17 +29,9 @@ BREAK_USED = 'break_used' CONTROL_VAR_NAME = 'control_var_name' -class BreakStatementTransformer(transformer.Base): +class BreakStatementTransformer(converter.Base): """Canonicalizes break statements into additional conditionals.""" - def _track_body(self, nodes, break_var): - self.enter_local_scope() - self.set_local(CONTROL_VAR_NAME, break_var) - nodes = self.visit_block(nodes) - break_used = self.get_local(BREAK_USED, False) - self.exit_local_scope() - return nodes, break_used - def visit_Break(self, node): self.set_local(BREAK_USED, True) var_name = self.get_local(CONTROL_VAR_NAME) @@ -54,13 +44,9 @@ class BreakStatementTransformer(transformer.Base): def _guard_if_present(self, block, var_name): """Prevents the block from executing if var_name is set.""" - - # If we don't have statements that immediately depend on the break - # we still need to make sure that the break variable remains - # used, in case the break becomes useful in later stages of transformation. - # Not having this broke the break_in_inner_loop test. if not block: - block = [gast.Pass()] + return block + template = """ if not var_name: block @@ -71,9 +57,17 @@ class BreakStatementTransformer(transformer.Base): block=block) return node + def _track_body(self, nodes, break_var): + self.enter_local_scope() + self.set_local(CONTROL_VAR_NAME, break_var) + nodes = self.visit_block(nodes) + break_used = self.get_local(BREAK_USED, False) + self.exit_local_scope() + return nodes, break_used + def visit_While(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break__', scope.referenced) + break_var = self.ctx.namer.new_symbol('break_', scope.referenced) node.test = self.visit(node.test) node.body, break_used = self._track_body(node.body, break_var) @@ -81,6 +75,10 @@ class BreakStatementTransformer(transformer.Base): node.orelse = self.visit_block(node.orelse) if break_used: + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + template = """ var_name = False while test and not var_name: @@ -88,20 +86,18 @@ class BreakStatementTransformer(transformer.Base): else: orelse """ - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). node = templates.replace( template, var_name=break_var, test=node.test, body=node.body, - orelse=self._guard_if_present(node.orelse, break_var)) + orelse=guarded_orelse) return node def visit_For(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break__', scope.referenced) + break_var = self.ctx.namer.new_symbol('break_', scope.referenced) node.target = self.visit(node.target) node.iter = self.visit(node.iter) @@ -110,23 +106,36 @@ class BreakStatementTransformer(transformer.Base): node.orelse = self.visit_block(node.orelse) if break_used: - node.orelse = self._guard_if_present(node.orelse, break_var) + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + extra_test = templates.replace_as_expression( + 'not var_name', var_name=break_var) + + # The extra test is hidden in the AST, which will confuse the static + # analysis. To mitigate that, we insert a no-op statement that ensures + # the control variable is marked as used. + # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) template = """ var_name = False - for_stmt + for target in iter_: + (var_name,) + body + else: + orelse """ - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). node = templates.replace( template, var_name=break_var, - for_stmt=node) - extra_test = templates.replace_as_expression( - 'not var_name', var_name=break_var) + iter_=node.iter, + target=node.target, + body=node.body, + orelse=guarded_orelse) + anno.setanno(node[1], 'extra_test', extra_test) return node -def transform(node, context): - return BreakStatementTransformer(context).visit(node) +def transform(node, ctx): + return BreakStatementTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py index 1af59e9b5260fe0d3a3ef72c7a003dc451e230f3..dcff1c54c2f9300d58d217517e108d634ae85fb4 100644 --- a/tensorflow/contrib/autograph/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import break_statements -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class BreakCanonicalizationTest(converter_test_base.TestCase): +class BreakCanonicalizationTest(converter_testing.TestCase): def test_basic_while(self): diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index 317711a866f731de1b497295a2752dee0eb544f5..b26c52294c2d1c11ce14d8a2903f7f88079a703f 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -20,20 +20,17 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class BuiltinFunctionTransformer(transformer.Base): +class BuiltinFunctionTransformer(converter.Base): """Handles builtin functions. This transformer only covers functions that are translated into a TF equivalent, like `len`. """ - def __init__(self, context): - super(BuiltinFunctionTransformer, self).__init__(context) - def _convert_builtin(self, node): template = """ ag__.utils.dynamic_builtin(func, args) @@ -51,7 +48,7 @@ class BuiltinFunctionTransformer(transformer.Base): # TODO(mdan): This won't work if the function was hidden. # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. if (isinstance(node.func, gast.Name) and - node.func.id in ('len', 'range', 'xrange')): + node.func.id in ('len', 'range', 'xrange', 'float', 'int')): return self._convert_builtin(node) # Print needs to be handled separately because it can be read as statement. if isinstance(node.func, gast.Name) and node.func.id == 'print': @@ -71,5 +68,5 @@ class BuiltinFunctionTransformer(transformer.Base): return self.visit(function_call) -def transform(node, context): - return BuiltinFunctionTransformer(context).visit(node) +def transform(node, ctx): + return BuiltinFunctionTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index 30272409df322560b04ba75b3e1cb6f9ad5ff0af..e9000e518ce14f9e0ea486d5b3e374439b8c78ca 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -23,13 +23,13 @@ import sys import six from tensorflow.contrib.autograph.converters import builtin_functions -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class BuiltinFunctionsTest(converter_test_base.TestCase): +class BuiltinFunctionsTest(converter_testing.TestCase): def test_len(self): diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index 554f0471d44d54194c45c3855b1483796ae65a6a..a36b3d77a9233daed864c616306b2ad27f582a38 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -26,12 +26,12 @@ from collections import namedtuple import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -45,6 +45,9 @@ KNOWN_NUMPY_FUNCTIONS = { } +# TODO(mdan): Get rid of these interfaces. Can now depend directly on Namer. + + class FunctionNamer(object): """Describes the interface for CallTreeTransformer's namer.""" @@ -76,20 +79,18 @@ class FunctionNamer(object): raise NotImplementedError() -class CallTreeTransformer(transformer.Base): - """Transforms the call tree by renaming transformed symbols.""" +# TODO(mdan): Rename to CallsTransformer. + - def __init__(self, context, uncompiled_modules, nocompile_decorators): - super(CallTreeTransformer, self).__init__(context) - self.uncompiled_modules = uncompiled_modules - self.nocompile_decorators = nocompile_decorators +class CallTreeTransformer(converter.Base): + """Transforms the call tree by renaming transformed symbols.""" def _resolve_name(self, node): """Used to resolve decorator info.""" if isinstance(node, gast.Call): return self._resolve_name(node.func) if isinstance(node, gast.Name): - return self.context.namespace.get(node.id) + return self.ctx.namespace.get(node.id) if isinstance(node, gast.Attribute): parent = self._resolve_name(node.value) if parent is not None: @@ -119,12 +120,12 @@ class CallTreeTransformer(transformer.Base): """Determines whether an entity should be compiled in the context.""" # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether. module_name = fqn[0] - for mod in self.uncompiled_modules: + for mod in self.ctx.program.uncompiled_modules: if module_name.startswith(mod[0] + '.'): return False for i in range(1, len(fqn)): - if fqn[:i] in self.uncompiled_modules: + if fqn[:i] in self.ctx.program.uncompiled_modules: return False # Check for local decorations @@ -140,7 +141,7 @@ class CallTreeTransformer(transformer.Base): if hasattr(target_entity, '__pyct_is_compile_decorator'): return False - if target_entity in self.nocompile_decorators: + if target_entity in self.ctx.program.autograph_decorators: return False # Inspect the target function decorators. If any include a @convert @@ -159,7 +160,7 @@ class CallTreeTransformer(transformer.Base): for dec in target_node.decorator_list: decorator_fn = self._resolve_name(dec) if (decorator_fn is not None and - decorator_fn in self.nocompile_decorators): + decorator_fn in self.ctx.program.autograph_decorators): return False return True @@ -174,7 +175,7 @@ class CallTreeTransformer(transformer.Base): return node if anno.hasanno(node, 'is_constructor'): - new_name = self.context.namer.compiled_class_name( + new_name = self.ctx.namer.compiled_class_name( target_fqn, live_entity=target_entity) do_rename = True else: @@ -183,7 +184,7 @@ class CallTreeTransformer(transformer.Base): else: # Fallback - not reliable. owner_type = inspect_utils.getmethodclass(target_entity) - new_name, do_rename = self.context.namer.compiled_function_name( + new_name, do_rename = self.ctx.namer.compiled_function_name( target_fqn, live_entity=target_entity, owner_type=owner_type) if do_rename: @@ -264,15 +265,16 @@ class CallTreeTransformer(transformer.Base): return node def visit_Call(self, node): - # If the function is wrapped by one of the marker decorators, + # If the function call is wrapped by one of the marker decorators, # consider it graph ready. if anno.hasanno(node.func, 'live_val'): target_entity = anno.getanno(node.func, 'live_val') - if target_entity in self.nocompile_decorators: + if target_entity in self.ctx.program.autograph_decorators: if len(node.args) < 1: raise ValueError( 'Found call to decorator function "%s", but it had no arguments. ' - 'A decorator needs at least an argument.') + 'A decorator needs at least one positional argument.' % + target_entity) anno.setanno(node.args[0], 'graph_ready', True) self.generic_visit(node) @@ -292,34 +294,37 @@ class CallTreeTransformer(transformer.Base): raise NotImplementedError( 'py_func with return values (unknown function)') else: + if anno.hasanno(node.func, anno.Basic.QN): + # Special-case a few builtins that otherwise go undetected. This + # normally doesn't pose a problem, but the dict built-in doesn't + # work with inspect.getargspec which is required for dynamic functions. + # Note: expecting this is resilient to aliasing (e.g. + # dict = an_evil_dict), because in those cases the regular mechanisms + # process a simple user function. + qn = anno.getanno(node.func, anno.Basic.QN) + # Add items to this list as needed. + if str(qn) in ('dict',): + return node + if ast_util.matches(node, 'super(_)'): # super() calls are preserved. The class conversion mechanism will # ensure that they return the correct value. - pass - elif self.context.recursive: + return node + + if self.ctx.program.recursive: node = self._insert_dynamic_conversion(node) - else: - # Unresolved functions are allowed in non-recursive mode. - pass return node -def transform(node, context, uncompiled_modules, nocompile_decorators): +def transform(node, ctx): """Transform function call to the compiled counterparts. Args: - node: AST to transform. - context: An EntityContext object. - uncompiled_modules: set of string tuples, each tuple represents the fully - qualified name of a package containing functions that will not be - compiled. - nocompile_decorators: A tuple containing decorators to be stripped from - functions during conversion. + node: AST + ctx: EntityContext Returns: A tuple (node, new_names): node: The transformed AST new_names: set(string), containing any newly-generated names """ - t = CallTreeTransformer(context, uncompiled_modules, nocompile_decorators) - node = t.visit(node) - return node + return CallTreeTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py index 303dd54a4ee49de27fad0c5cdc2d6274abfe0fa8..27d8281b856f505062ceacc8ad50c8cbc2ce6c81 100644 --- a/tensorflow/contrib/autograph/converters/call_trees_test.py +++ b/tensorflow/contrib/autograph/converters/call_trees_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.autograph.converters import call_trees -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -29,7 +29,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class CallTreesTest(converter_test_base.TestCase): +class CallTreesTest(converter_testing.TestCase): def test_basic(self): @@ -43,7 +43,7 @@ class CallTreesTest(converter_test_base.TestCase): return test_fn_1(a) + 1 node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 @@ -60,7 +60,7 @@ class CallTreesTest(converter_test_base.TestCase): return f() + 3 node = self.parse_and_analyze(test_fn_2, {}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: # 10 = 7 (from the mock) + 3 (from test_fn_2) @@ -78,9 +78,9 @@ class CallTreesTest(converter_test_base.TestCase): node = self.parse_and_analyze( TestClass.test_fn_2, {'TestClass': TestClass}, - namer=converter_test_base.FakeNoRenameNamer(), + namer=converter_testing.FakeNoRenameNamer(), arg_types={'self': (TestClass.__name__, TestClass)}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: tc = TestClass() @@ -92,7 +92,7 @@ class CallTreesTest(converter_test_base.TestCase): setattr(a, 'foo', 'bar') node = self.parse_and_analyze(test_fn, {'setattr': setattr}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: with self.test_session() as sess: @@ -115,7 +115,7 @@ class CallTreesTest(converter_test_base.TestCase): return np.random.binomial(2, 0.5) node = self.parse_and_analyze(test_fn, {'np': np}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node, dtypes.int64) as result: result.np = np @@ -130,13 +130,13 @@ class CallTreesTest(converter_test_base.TestCase): a = math_ops.add(a, constant_op.constant(1)) return a - node = self.parse_and_analyze(test_fn, { - 'math_ops': math_ops, - 'constant_op': constant_op - }) - node = call_trees.transform(node, self.ctx, - set(((math_ops.__name__,), - (constant_op.__name__,))), ()) + node = self.parse_and_analyze( + test_fn, { + 'math_ops': math_ops, + 'constant_op': constant_op + }, + arg_types=set(((math_ops.__name__,), (constant_op.__name__,)))) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: result.math_ops = math_ops diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 4299a8a9d59715d032222c47794bbb4393f34ce6..958bde0a58764e705c35ab73ce879b2c11ce7cdc 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -18,110 +18,122 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -class ContinueCanonicalizationTransformer(transformer.Base): - """Canonicalizes continue statements into additional conditionals.""" +# Tags for local state. +CONTROL_VAR_NAME = 'control_var_name' +CONTINUE_USED = 'continue_used' +GUARD_CREATED = 'guard_created' +CREATE_GUARD_NEXT = 'create_guard_next' - def __init__(self, context): - super(ContinueCanonicalizationTransformer, self).__init__(context) - # This is a stack structure, to correctly process nested loops. - self.continuation_uses = [] - def _create_continuation_check(self): - template = """ - if not var_name: - pass - """ - cond, = templates.replace(template, var_name=self.continuation_uses[-1][1]) - cond.body = [] - return cond +class ContinueCanonicalizationTransformer(converter.Base): + """Canonicalizes continue statements into additional conditionals.""" - def _create_continuation_trigger(self): + def visit_Continue(self, node): + self.set_local(CONTINUE_USED, True) template = """ var_name = True """ - assign, = templates.replace( - template, var_name=self.continuation_uses[-1][1]) - return assign - - def _create_continuation_init(self): - template = """ - var_name = False - """ - assign, = templates.replace( - template, var_name=self.continuation_uses[-1][1]) - return assign - - def _visit_and_reindent_if_necessary(self, nodes): - reorganized_nodes = [] - current_dest = reorganized_nodes - continue_used_in_block = False - for i, n in enumerate(nodes): - # TODO(mdan): This could be optimized if control structures are simple. - self.continuation_uses[-1][0] = False - n = self.visit(n) - current_dest.append(n) - if self.continuation_uses[-1][0]: - continue_used_in_block = True - if i < len(nodes) - 1: # Last statement in block needs no protection. - cond = self._create_continuation_check() - current_dest.append(cond) - current_dest = cond.body - self.continuation_uses[-1][0] = continue_used_in_block - return reorganized_nodes - - def _process_loop_block(self, block, scope): - cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced) - self.continuation_uses.append([False, cont_var]) - block = self._visit_and_reindent_if_necessary(block) - if self.continuation_uses[-1][0]: - block.insert(0, self._create_continuation_init()) - self.continuation_uses.pop() - return block + return templates.replace( + template, var_name=self.get_local(CONTROL_VAR_NAME)) + + def _postprocess_statement(self, node): + # Example of how the state machine below works: + # + # 1| stmt # State: CONTINUE_USED = False + # | # Action: none + # 2| if cond: + # 3| continue # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = False + # | # Action: set CREATE_GUARD_NEXT = True + # 4| stmt # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = True + # | # Action: create `if not continue_used`, + # | # set GUARD_CREATED = True + # 5| stmt # State: CONTINUE_USED = True, GUARD_CREATED = True + # | # Action: none (will be wrapped under previously + # | # created if node) + + if self.get_local(CONTINUE_USED, False): + if self.get_local(GUARD_CREATED, False): + return node, None + + elif not self.get_local(CREATE_GUARD_NEXT, False): + self.set_local(CREATE_GUARD_NEXT, True) + return node, None + + else: + self.set_local(GUARD_CREATED, True) + template = """ + if not var_name: + original_node + """ + cond, = templates.replace( + template, + var_name=self.get_local(CONTROL_VAR_NAME), + original_node=node) + return cond, cond.body + return node, None + + def _visit_loop_body(self, node, nodes): + self.enter_local_scope() + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + continue_var = self.ctx.namer.new_symbol('continue_', scope.referenced) + self.set_local(CONTROL_VAR_NAME, continue_var) + + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + + if self.get_local(CONTINUE_USED, False): + template = """ + var_name = False + """ + control_var_init = templates.replace(template, var_name=continue_var) + nodes = control_var_init + nodes + + self.exit_local_scope() + return nodes + + def _visit_non_loop_body(self, nodes): + self.enter_local_scope(inherit=(CONTROL_VAR_NAME,)) + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + continue_used = self.get_local(CONTINUE_USED, False) + self.exit_local_scope(keep=(CONTINUE_USED,)) + return nodes, continue_used def visit_While(self, node): - self.generic_visit(node.test) - node.body = self._process_loop_block(node.body, - anno.getanno(node, - NodeAnno.BODY_SCOPE)) - for n in node.orelse: - self.generic_visit(n) + node.test = self.visit(node.test) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. + node.orelse, _ = self._visit_non_loop_body(node.orelse) return node def visit_For(self, node): - self.generic_visit(node.target) - self.generic_visit(node.iter) - node.body = self._process_loop_block(node.body, - anno.getanno(node, - NodeAnno.BODY_SCOPE)) - for n in node.orelse: - self.generic_visit(n) + node.target = self.generic_visit(node.target) + node.iter = self.generic_visit(node.iter) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. + node.orelse, _ = self._visit_non_loop_body(node.orelse) return node def visit_If(self, node): - if self.continuation_uses: - self.generic_visit(node.test) - node.body = self._visit_and_reindent_if_necessary(node.body) - continue_used_in_body = self.continuation_uses[-1][0] - node.orelse = self._visit_and_reindent_if_necessary(node.orelse) - self.continuation_uses[-1][0] = ( - continue_used_in_body or self.continuation_uses[-1][0]) - else: - node = self.generic_visit(node) + node.test = self.generic_visit(node.test) + node.body, continue_used_body = self._visit_non_loop_body(node.body) + node.orelse, continue_used_orelse = self._visit_non_loop_body(node.orelse) + self.set_local(CONTINUE_USED, continue_used_body or continue_used_orelse) return node - def visit_Continue(self, node): - self.continuation_uses[-1][0] = True - return self._create_continuation_trigger() - - def visit_Break(self, node): - assert False, 'break statement should be desugared at this point' + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body, _ = self._visit_non_loop_body(node.body) + return node -def transform(node, namer): - return ContinueCanonicalizationTransformer(namer).visit(node) +def transform(node, ctx): + return ContinueCanonicalizationTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py index bcbb316d7459aa5a25bb0bd128cd6e359a393288..2ce1837972c50bbc4921487a290f5cb2f782b5f3 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import continue_statements -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class ContinueCanonicalizationTest(converter_test_base.TestCase): +class ContinueCanonicalizationTest(converter_testing.TestCase): def test_basic_continue(self): diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index d7ddbe8a04f64848d6ec21155d8d85f60e19d276..f4a87106279d5658ecaa90a577cbe741711ba22e 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import cfg from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -45,9 +45,8 @@ class SymbolNamer(object): raise NotImplementedError() -class ControlFlowTransformer(transformer.Base): +class ControlFlowTransformer(converter.Base): """Transforms control flow structures like loops an conditionals.""" - def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names, body, returns): if aliased_orig_names: @@ -141,10 +140,10 @@ class ControlFlowTransformer(transformer.Base): aliased_orelse_orig_names = tuple(orelse_scope.modified - orelse_scope.created) aliased_body_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), body_scope.referenced) + self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced) for s in aliased_body_orig_names) aliased_orelse_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced) + self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced) for s in aliased_orelse_orig_names) alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) @@ -165,9 +164,8 @@ class ControlFlowTransformer(transformer.Base): else: results = gast.Tuple([s.ast() for s in modified], None) - body_name = self.context.namer.new_symbol('if_true', body_scope.referenced) - orelse_name = self.context.namer.new_symbol('if_false', - orelse_scope.referenced) + body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) + orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) if modified: def build_returns(aliased_names, alias_map, scope): @@ -235,7 +233,7 @@ class ControlFlowTransformer(transformer.Base): raise ValueError('cannot convert while loop: no outputs') state_ssf = [ - self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state ] ssf_map = { name: ssf @@ -267,11 +265,9 @@ class ControlFlowTransformer(transformer.Base): state=state, state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, - test_name=self.context.namer.new_symbol('loop_test', - body_scope.referenced), + test_name=self.ctx.namer.new_symbol('loop_test', body_scope.referenced), test=test, - body_name=self.context.namer.new_symbol('loop_body', - body_scope.referenced), + body_name=self.ctx.namer.new_symbol('loop_body', body_scope.referenced), body=node_body, extra_deps=tuple(s.ast() for s in cond_closure), ) @@ -288,7 +284,7 @@ class ControlFlowTransformer(transformer.Base): state = list(body_closure) state_ssf = [ - self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state ] ssf_map = { name: ssf @@ -326,17 +322,16 @@ class ControlFlowTransformer(transformer.Base): state_ast_tuple=state_ast_tuple, iter_=node.iter, iterate=node.target, - extra_test_name=self.context.namer.new_symbol('extra_test', - all_referenced), + extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced), extra_test_expr=extra_test, - body_name=self.context.namer.new_symbol('loop_body', all_referenced), + body_name=self.ctx.namer.new_symbol('loop_body', all_referenced), body=node_body) return node -def transform(node, context): - cfg.run_analyses(node, cfg.Liveness(context)) - cfg.run_analyses(node, cfg.Defined(context)) - node = ControlFlowTransformer(context).visit(node) +def transform(node, ctx): + cfg.run_analyses(node, cfg.Liveness(ctx.info)) + cfg.run_analyses(node, cfg.Defined(ctx.info)) + node = ControlFlowTransformer(ctx).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 1a863590f97add9bfa587d1142a09ae26a9fdb44..735eb92a0dd06ee7fd621b92b1a8f894e09cee4a 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import control_flow -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -27,7 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test -class ControlFlowTest(converter_test_base.TestCase): +class ControlFlowTest(converter_testing.TestCase): def test_simple_while(self): @@ -42,7 +42,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.while_loop) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual((10, 5, 5), sess.run(result.test_fn(constant_op.constant(5)))) @@ -57,7 +57,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.while_loop) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) @@ -75,7 +75,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.cond) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual((-1, 0), sess.run(result.test_fn(constant_op.constant(1)))) @@ -92,7 +92,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.cond) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) diff --git a/tensorflow/contrib/autograph/converters/decorators.py b/tensorflow/contrib/autograph/converters/decorators.py index 92445f31746cf94856ea43893f99a2ba60355fb5..3471bd11d6073f57a2703b438df95a60f19e8e0c 100644 --- a/tensorflow/contrib/autograph/converters/decorators.py +++ b/tensorflow/contrib/autograph/converters/decorators.py @@ -24,19 +24,14 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import pretty_printer +from tensorflow.python.util import tf_inspect -class DecoratorsTransformer(gast.NodeTransformer): +class DecoratorsTransformer(converter.Base): """Converts or removes decorators.""" - def __init__(self, remove_decorators): - self.remove_decorators = remove_decorators - self.additional_dependencies = set() - - # pylint:disable=invalid-name - def visit_FunctionDef(self, node): self.generic_visit(node) kept_decorators = [] @@ -58,31 +53,53 @@ class DecoratorsTransformer(gast.NodeTransformer): # This is currently verified by tests. continue - if not anno.hasanno(dec_func, 'live_val'): - raise ValueError( - 'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func)) - + original_dec = anno.getanno(dec_func, anno.Basic.QN) dec_value = anno.getanno(dec_func, 'live_val') - if dec_value not in self.remove_decorators: - kept_decorators.append((dec, dec_value)) - for _, dec_value in kept_decorators: - if dec_value.__module__ == '__main__': + if dec_value in self.ctx.program.autograph_decorators: + # AutoGraph decorators do not need to be preserved. + continue + + # When using foo.bar.baz, we only really need to grab foo and import + # that. + dec_support_node = dec_func + while isinstance(dec_support_node, gast.Attribute): + dec_support_node = dec_support_node.value + + if not anno.hasanno(dec_support_node, 'live_val'): raise ValueError( - 'decorator "%s" was not allowed because it is declared ' - 'in the module "%s". To fix this, declare it in a separate ' - 'module that we can import it from.' % (dec_value, - dec_value.__module__)) + 'could not resolve symbol "%s" when looking up decorator "%s"' % + (anno.getanno(dec_support_node, anno.Basic.QN), original_dec)) + + dec_support = anno.getanno(dec_support_node, 'live_val') + # The tuple contains: + # * the AST that represents the decorator + # * the entity supporting the decorator (i.e., what we need to import) + # * the name of the module that needs to be imported for this decorator + # to properly resolve. + # Examples: + # for foo.bar, the tuple is (, , 'foo') + # for baz, the tuple is (, , 'baz') + kept_decorators.append((dec, dec_support, + anno.getanno(dec_support_node, anno.Basic.QN))) + + for _, dec_support, name in kept_decorators: + if tf_inspect.ismodule(dec_support): + self.ctx.program.additional_imports.add( + 'import %s as %s' % (dec_support.__name__, name)) else: - self.additional_dependencies.add(dec_value) - - node.decorator_list = [dec for dec, _ in kept_decorators] + if dec_support.__module__ == '__main__': + raise ValueError( + 'decorator "%s" was not allowed because it is declared ' + 'in the module "%s". To fix this, declare it in a separate ' + 'module that we can import it from.' % (dec_support, + dec_support.__module__)) + self.ctx.program.additional_imports.add( + 'from %s import %s' % (dec_support.__module__, name)) + + node.decorator_list = [dec for dec, _, _ in kept_decorators] return node - # pylint:enable=invalid-name - -def transform(node, remove_decorators): - transformer = DecoratorsTransformer(remove_decorators) - node = transformer.visit(node) - return node, transformer.additional_dependencies +def transform(node, ctx): + return DecoratorsTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/contrib/autograph/converters/decorators_test.py index 9c01f689127dbedad7669c65b03e7da071b2d64d..d41c7fde2474803a438100e7e00ce8e9f675de45 100644 --- a/tensorflow/contrib/autograph/converters/decorators_test.py +++ b/tensorflow/contrib/autograph/converters/decorators_test.py @@ -20,9 +20,10 @@ from __future__ import print_function from functools import wraps -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.platform import test @@ -39,28 +40,35 @@ def simple_decorator(f): return lambda a: f(a) + 1 -def self_removing_decorator(removing_wrapper): +def self_transform_decorator(transform): + def decorator(f): @wraps(f) def wrapper(*args): # This removing wrapper is defined in the test below. This setup is so - # intricate just to simulate how we use the transformer in practice. - transformed_f = removing_wrapper(f, (self_removing_decorator,)) + # intricate in order to simulate how we use the transformer in practice. + transformed_f = transform(f, (self_transform_decorator,)) return transformed_f(*args) + 1 return wrapper return decorator -class DecoratorsTest(converter_test_base.TestCase): +class DecoratorsTest(converter_testing.TestCase): - def _remover_wrapper(self, f, remove_decorators): + def _transform(self, f, autograph_decorators): namespace = { - 'self_removing_decorator': self_removing_decorator, - 'simple_decorator': simple_decorator + 'self_transform_decorator': self_transform_decorator, + 'simple_decorator': simple_decorator, + 'converter_testing': converter_testing, } - node = self.parse_and_analyze(f, namespace) - node, _ = decorators.transform(node, remove_decorators=remove_decorators) - result, _ = compiler.ast_to_object(node) + node = self.parse_and_analyze( + f, + namespace, + recursive=False, + autograph_decorators=autograph_decorators) + node = decorators.transform(node, self.ctx) + import_line = '\n'.join(self.ctx.program.additional_imports) + result, _ = compiler.ast_to_object(node, source_prefix=import_line) return getattr(result, f.__name__) def test_noop(self): @@ -69,15 +77,14 @@ class DecoratorsTest(converter_test_base.TestCase): return a node = self.parse_and_analyze(test_fn, {}) - node, deps = decorators.transform(node, remove_decorators=()) + node = decorators.transform(node, self.ctx) result, _ = compiler.ast_to_object(node) - self.assertFalse(deps) self.assertEqual(1, result.test_fn(1)) def test_function(self): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(a): return a @@ -88,7 +95,7 @@ class DecoratorsTest(converter_test_base.TestCase): class TestClass(object): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(self, a): return a @@ -101,38 +108,39 @@ class DecoratorsTest(converter_test_base.TestCase): # Note that reversing the order of this two doesn't work. @classmethod - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(cls, a): return a # 2 = 1 (a) + 1 (decorator applied exactly once) self.assertEqual(2, TestClass.test_fn(1)) - def test_nested_decorators(self): + def test_nested_decorators_local(self): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(a): @simple_decorator def inner_fn(b): return b + 11 return inner_fn(a) - with self.assertRaises(ValueError): + # Expected to fail because simple_decorator cannot be imported. + with self.assertRaises(transformer.AutographParseError): test_fn(1) - # TODO(mdan): Uncomment this test once converter_test_base is updated. - # (can't do it now because it has unrelated pending changes) - # def test_nested_decorators(self): - # - # @self_removing_decorator(self._remover_wrapper) - # def test_fn(a): - # @imported_decorator - # def inner_fn(b): - # return b + 11 - # return inner_fn(a) - # - # # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) - # self.assertEqual(14, test_fn(1)) + def test_nested_decorators_imported(self): + + @self_transform_decorator(self._transform) + def test_fn(a): + + @converter_testing.imported_decorator + def inner_fn(b): + return b + 11 + + return inner_fn(a) + + # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) + self.assertEqual(14, test_fn(1)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/ifexp.py b/tensorflow/contrib/autograph/converters/ifexp.py index 616d222762e09feeba1809f119d915dfbe522283..e996138498ab2b7efa76671d8cc67fd4c6a9d9b8 100644 --- a/tensorflow/contrib/autograph/converters/ifexp.py +++ b/tensorflow/contrib/autograph/converters/ifexp.py @@ -18,11 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class IfExp(transformer.Base): +class IfExp(converter.Base): """Canonicalizes all IfExp nodes into plain conditionals.""" def visit_IfExp(self, node): @@ -34,16 +34,16 @@ class IfExp(transformer.Base): return desugared_ifexp -def transform(node, context): +def transform(node, ctx): """Desugar IfExp nodes into plain conditionals. Args: - node: an AST node to transform - context: a context object + node: ast.AST, the node to transform + ctx: converter.EntityContext Returns: new_node: an AST with no IfExp nodes, only conditionals. """ - node = IfExp(context).visit(node) + node = IfExp(ctx).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/ifexp_test.py b/tensorflow/contrib/autograph/converters/ifexp_test.py index ac6849dcb4bd7dacd84bb205f5c65395d8c2f51e..cdd5a2f591edc1138df1c165577ed375131ddf09 100644 --- a/tensorflow/contrib/autograph/converters/ifexp_test.py +++ b/tensorflow/contrib/autograph/converters/ifexp_test.py @@ -19,12 +19,12 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class IfExpTest(converter_test_base.TestCase): +class IfExpTest(converter_testing.TestCase): def compiled_fn(self, test_fn, *args): node = self.parse_and_analyze(test_fn, {}) diff --git a/tensorflow/contrib/autograph/converters/list_comprehension.py b/tensorflow/contrib/autograph/converters/list_comprehension.py index d7f292015164e047d054c5d1fb0b391e960bb73d..c4a13ee822ab84706df83256d9e9684c3f7dacba 100644 --- a/tensorflow/contrib/autograph/converters/list_comprehension.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension.py @@ -31,17 +31,14 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class ListCompCanonicalizationTransformer(transformer.Base): +class ListCompCanonicalizationTransformer(converter.Base): """NodeTransformer to canonicalize list comprehensions.""" - def __init__(self, context): - super(ListCompCanonicalizationTransformer, self).__init__(context) - def make_update_list_node(self, list_, elt): return templates.replace('list_.append(elt)', list_=list_, elt=elt)[0] @@ -76,5 +73,5 @@ class ListCompCanonicalizationTransformer(transformer.Base): return make_list + loop_body -def transform(node, context): - return ListCompCanonicalizationTransformer(context).visit(node) +def transform(node, ctx): + return ListCompCanonicalizationTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/list_comprehension_test.py b/tensorflow/contrib/autograph/converters/list_comprehension_test.py index 4758671f5ec83c26cfa54be0ef68f5f564094f6c..2bbee93412ce3174a14f3d60af9435dcf3b82cc6 100644 --- a/tensorflow/contrib/autograph/converters/list_comprehension_test.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension_test.py @@ -18,12 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import list_comprehension +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class ListCompTest(converter_test_base.TestCase): +class ListCompTest(converter_testing.TestCase): def test_basic(self): diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py index b49521b2c328f418828a5e92890aa1b169384b70..d77a04479826779b8aa859d70f2f7ff51138f841 100644 --- a/tensorflow/contrib/autograph/converters/lists.py +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -32,85 +32,196 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.python.framework import dtypes +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -class ListTransformer(transformer.Base): +# Tags for local state. +POP_USES = 'pop_uses' + + +class ListTransformer(converter.Base): """Converts lists and related operations to their TF counterpart.""" - def _empty_list(self, node): - if not anno.hasanno(node, 'element_type'): - raise NotImplementedError( - 'type inference for empty lists is not yet supported; ' - 'use set_element_type(, ) to continue') - dtype = anno.getanno(node, 'element_type') - if not isinstance(dtype, dtypes.DType): - # TODO(mdan): Allow non-TF dtypes? - # That would be consistent with the dynamic dispatch pattern, but - # we must make sure that doesn't become confusing. - raise NotImplementedError('element type "%s" not yet supported' % dtype) - - dtype_name = dtype.name - # TODO(mdan): Does it ever make sense not to use tensor lists? + def visit_List(self, node): + node = self.generic_visit(node) template = """ - tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True) + ag__.new_list(elements) """ - return templates.replace_as_expression(template, dtype_name=dtype_name) + return templates.replace_as_expression(template, elements=node) - def _pre_populated_list(self, node): - raise NotImplementedError('pre-populated lists') + def _replace_append_call(self, node): + assert len(node.args) == 1 + assert isinstance(node.func, gast.Attribute) + template = """ + target = ag__.list_append(target, element) + """ + return templates.replace( + template, + target=node.func.value, + element=node.args[0]) + + def _replace_pop_call(self, node): + # Expressions that use pop() are converted to a statement + expression. + # + # For example: + # + # print(target.pop()) + # + # ... is converted to: + # + # target, target_pop = ag__.list_pop(target) + # print(target_pop) + # + # Here, we just generate the variable name and swap it in, + # and _generate_pop_operation will handle the rest. + # + # Multiple uses of pop() are allowed: + # + # print(tartget.pop(), target.pop()) + # print(tartget.pop().pop()) + # + assert isinstance(node.func, gast.Attribute) + scope = anno.getanno(node, NodeAnno.ARGS_SCOPE) + target_node = node.func.value + + # Attempt to use a related name if can get one. Otherwise use something + # generic. + if anno.hasanno(target_node, anno.Basic.QN): + target_name = anno.getanno(target_node, anno.Basic.QN).ssf() + else: + target_name = 'list' + pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced) + + pop_uses = self.get_local(POP_USES, []) + pop_uses.append((node, pop_var_name)) + self.set_local(POP_USES, pop_uses) + + return templates.replace_as_expression('var_name', var_name=pop_var_name) + + def _replace_stack_call(self, node): + assert len(node.args) == 1 + dtype = anno.getanno( + node.args[0], + 'element_type', + default=templates.replace_as_expression('None')) + template = """ + ag__.list_stack( + target, + opts=ag__.ListStackOpts( + element_dtype=dtype, + original_call=orig_call)) + """ + return templates.replace_as_expression( + template, + dtype=dtype, + target=node.args[0], + orig_call=node.func) - def visit_Expr(self, node): + def visit_Call(self, node): node = self.generic_visit(node) - if isinstance(node.value, gast.Call): - call_node = node.value - - if not anno.hasanno(call_node.func, anno.Basic.QN): - return node - qn = anno.getanno(call_node.func, anno.Basic.QN) - - if qn.qn[-1] == 'append' and (len(call_node.args) == 1): - template = """ - target = ag__.utils.dynamic_list_append(target, element) - """ - node = templates.replace( - template, - target=qn.parent.ast(), - element=call_node.args[0]) + + # TODO(mdan): This is insufficient if target is a function argument. + # In the case of function arguments, we need to add the list to the + # function's return value, because it is being modified. + # TODO(mdan): Checking just the name is brittle, can it be improved? + if isinstance(node.func, gast.Attribute): + func_name = node.func.attr + if func_name == 'append' and (len(node.args) == 1): + node = self._replace_append_call(node) + elif func_name == 'pop' and (len(node.args) <= 1): + node = self._replace_pop_call(node) + elif func_name == 'stack' and (len(node.args) == 1): + node = self._replace_stack_call(node) + return node - def _replace_list_constructors(self, targets, values): - for target in targets: - if (isinstance(target, (gast.Tuple, gast.List)) and - isinstance(values, (gast.Tuple, gast.List))): - n_targets = len(target.elts) - for i in range(n_targets): - target_el, value_el = target.elts[i], values.elts[i] - values.elts[i] = self._replace_list_constructors( - (target_el,), value_el) - return values - if isinstance(values, gast.List): - if values.elts: - return self._pre_populated_list(values) - else: - return self._empty_list(values) - return values - - def visit_Assign(self, node): - node = self.generic_visit(node) + def _generate_pop_operation(self, original_call_node, pop_var_name): + assert isinstance(original_call_node.func, gast.Attribute) + + if original_call_node.args: + pop_element = original_call_node.args[0] + else: + pop_element = parser.parse_expression('None') + # The call will be something like "target.pop()", and the dtype is hooked to + # target, hence the func.value. + dtype = anno.getanno( + original_call_node.func.value, + 'element_type', + default=templates.replace_as_expression('None')) + shape = anno.getanno( + original_call_node.func.value, + 'element_shape', + default=templates.replace_as_expression('None')) + + template = """ + target, pop_var_name = ag__.list_pop( + target, element, + opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape)) + """ + return templates.replace( + template, + target=original_call_node.func.value, + pop_var_name=pop_var_name, + element=pop_element, + dtype=dtype, + shape=shape) + + def _postprocess_statement(self, node): + """Inserts any separate pop() calls that node may use.""" + pop_uses = self.get_local(POP_USES, None) + if pop_uses: + replacements = [] + for original_call_node, pop_var_name in pop_uses: + replacements.extend( + self._generate_pop_operation(original_call_node, pop_var_name)) + replacements.append(node) + node = replacements + self.exit_local_scope() + return node, None + + # TODO(mdan): Should we have a generic visit_block instead? + # Right now it feels that a visit_block would add too much magic that's + # hard to follow. + + def _visit_and_process_block(self, block): + return self.visit_block( + block, + before_visit=self.enter_local_scope, + after_visit=self._postprocess_statement) + + def visit_FunctionDef(self, node): + node.args = self.generic_visit(node.args) + node.decorator_list = self.visit_block(node.decorator_list) + node.body = self._visit_and_process_block(node.body) + return node + + def visit_For(self, node): + node.target = self.visit(node.target) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_While(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_If(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node - # Only convert lists when they are assigned to a variable, e.g.: - # l = [] - # TODO(mdan): A similar pattern exists in type_info.py - # We should add a generic "unpack_assignment" function to the base - # transformer, that has the same effect as applying some logic to the SSA - # form. - node.value = self._replace_list_constructors(node.targets, node.value) + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body = self._visit_and_process_block(node.body) return node -def transform(node, context): - return ListTransformer(context).visit(node) +def transform(node, ctx): + return ListTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index 74c6dc64f197f75eb3e66c01fb078467e8e8ea89..ea04097b28deedd705164bd95ab62dba3e3c7834 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -19,77 +19,129 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import lists +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import dtypes -from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -class ListTest(converter_test_base.TestCase): +class ListTest(converter_testing.TestCase): - def test_empty_annotated_list(self): + def test_empty_list(self): def test_fn(): - l = [] - utils.set_element_type(l, dtypes.int32) - l.append(1) - return l + return [] - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = self.parse_and_analyze(test_fn, {}) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: - # TODO(mdan): Attach these additional modules automatically. - result.utils = utils - result.dtypes = dtypes + with self.compiled(node) as result: + tl = result.test_fn() + # Empty tensor lists cannot be evaluated or stacked. + self.assertTrue(isinstance(tl, ops.Tensor)) + self.assertEqual(tl.dtype, dtypes.variant) + + def test_initialized_list(self): + + def test_fn(): + return [1, 2, 3] + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: with self.test_session() as sess: - self.assertAllEqual([1], sess.run(result.test_fn().stack())) + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) - def test_empty_annotated_lists_unpacked(self): + def test_list_append(self): def test_fn(): - l, m = [], [] - utils.set_element_type(l, dtypes.int32) - utils.set_element_type(m, dtypes.int32) - l.append(1) - m.append(2) - return l, m + l = [1] + l.append(2) + l.append(3) + return l - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = self.parse_and_analyze(test_fn, {}) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: + with self.compiled(node) as result: + with self.test_session() as sess: + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) + + def test_list_pop(self): + + def test_fn(): + l = [1, 2, 3] + utils.set_element_type(l, dtypes.int32, ()) + s = l.pop() + return s, l + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - res_l, res_m = result.test_fn() - self.assertEqual([1], sess.run(res_l.stack())) - self.assertEqual([2], sess.run(res_m.stack())) + ts, tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2]) + self.assertAllEqual(sess.run(ts), 3) + + def test_double_list_pop(self): - def test_empty_annotated_lists_list_unpacked(self): + def test_fn(l): + s = l.pop().pop() + return s + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: + test_input = [1, 2, [1, 2, 3]] + # TODO(mdan): Pass a list of lists of tensor when we fully support that. + # For now, we just pass a regular Python list of lists just to verify that + # the two pop calls are sequenced properly. + self.assertAllEqual(result.test_fn(test_input), 3) + + def test_list_stack(self): + + tf = None # Will be replaced with a mock. def test_fn(): - [l, m] = [], [] + l = [1, 2, 3] utils.set_element_type(l, dtypes.int32) - utils.set_element_type(m, dtypes.int32) - l.append(1) - m.append(2) - return l, m - - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + return tf.stack(l) + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: + with self.compiled(node, array_ops.stack, dtypes.int32) as result: result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - res_l, res_m = result.test_fn() - self.assertEqual([1], sess.run(res_l.stack())) - self.assertEqual([2], sess.run(res_m.stack())) + self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py index 3a795a315a3c2aa08ac1577a204102755b6e849c..16eb1f0e3f8ad34e615931882ab2896db485f457 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -23,10 +23,10 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer # TODO(mdan): Properly extrack boolean ops according to lazy eval rules. @@ -39,11 +39,11 @@ from tensorflow.contrib.autograph.pyct import transformer SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND' -class LogicalExpressionTransformer(transformer.Base): +class LogicalExpressionTransformer(converter.Base): """Converts logical expressions to corresponding TF calls.""" - def __init__(self, context): - super(LogicalExpressionTransformer, self).__init__(context) + def __init__(self, ctx): + super(LogicalExpressionTransformer, self).__init__(ctx) # TODO(mdan): Look into replacing with bitwise operators instead. # TODO(mdan): Skip replacing if the function is trivial. self.op_mapping = { @@ -128,5 +128,5 @@ class LogicalExpressionTransformer(transformer.Base): return right -def transform(node, context): - return LogicalExpressionTransformer(context).visit(node) +def transform(node, ctx): + return LogicalExpressionTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py index 2814060c4d831e4dddacb3dcbcbe1db42160db20..48186024a9da7b41fa7ff9a8ab18f3477ba09c8f 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import logical_expressions +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class GradientsFunctionTest(converter_test_base.TestCase): +class GradientsFunctionTest(converter_testing.TestCase): def test_equals(self): diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py index dfee529abaa8c14d9b408819b32c5199500a2c2f..dd6c6bf960c52d094a16d4cd72fa84f65b9322a1 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes.py +++ b/tensorflow/contrib/autograph/converters/name_scopes.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class FunctionNameScopeTransformer(transformer.Base): +class FunctionNameScopeTransformer(converter.Base): """Wrap a function body with a `name_scope` of the function name.""" def _name_for_current_scope(self): @@ -70,5 +70,5 @@ class FunctionNameScopeTransformer(transformer.Base): return node -def transform(node, context): - return FunctionNameScopeTransformer(context).visit(node) +def transform(node, ctx): + return FunctionNameScopeTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py index 17692cbd880dbc1db4bb40ad7345e27907499f9d..444d0bcd469f35689d078debe3622f930dbac723 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes_test.py +++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test -class FunctionNameScopeTransformer(converter_test_base.TestCase): +class FunctionNameScopeTransformer(converter_testing.TestCase): def test_basic(self): diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py index 3bcb2d3c42c6e0663c8f78523199a364b6ac231f..b808604f0ab2d42f41a560035ab046ff782a3431 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py @@ -36,11 +36,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -59,14 +59,9 @@ class SymbolNamer(object): raise NotImplementedError() -class SideEffectGuardTransformer(transformer.Base): +class SideEffectGuardTransformer(converter.Base): """Adds control dependencies to functions with side effects.""" - def __init__(self, context): - super(SideEffectGuardTransformer, self).__init__(context) - - # pylint:disable=invalid-name - def _visit_and_reindent(self, nodes): new_nodes = [] current_dest = new_nodes @@ -149,7 +144,7 @@ class SideEffectGuardTransformer(transformer.Base): s for s in guarded_args if s not in args_scope.parent.modified) aliased_new_names = tuple( qual_names.QN( - self.context.namer.new_symbol( + self.ctx.namer.new_symbol( s.ssf(), args_scope.parent.referenced)) for s in need_alias) alias_map = dict(zip(need_alias, aliased_new_names)) if len(guarded_args) == 1: @@ -183,8 +178,6 @@ class SideEffectGuardTransformer(transformer.Base): (node.body, alias_map)) return node - # pylint:enable=invalid-name - -def transform(node, context): - return SideEffectGuardTransformer(context).visit(node) +def transform(node, ctx): + return SideEffectGuardTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py index ce0ce33243a1352107eb8121050ee76474869809..a7ad8efed4c88e15ce9dc14cb02e5e035602013d 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import side_effect_guards +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -29,7 +29,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class SideEffectGuardsTest(converter_test_base.TestCase): +class SideEffectGuardsTest(converter_testing.TestCase): def test_side_effect_on_return_only_variable(self): diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/single_return.py index bcc9ca9dfeb00ef2d2e60edf6a1abfba19a1bad7..a351cd81b82f7fb32f62ac1579355ace0501759d 100644 --- a/tensorflow/contrib/autograph/converters/single_return.py +++ b/tensorflow/contrib/autograph/converters/single_return.py @@ -20,21 +20,21 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Move this logic into transformer_base. -class BodyVisitor(transformer.Base): +class BodyVisitor(converter.Base): """Walks breadth- or depth-first the list-of-nodes bodies of AST nodes.""" - def __init__(self, context, depth_first=False): + def __init__(self, ctx, depth_first=False): + super(BodyVisitor, self).__init__(ctx) self.depth_first = depth_first self.changes_made = False - super(BodyVisitor, self).__init__(context) def visit_nodelist(self, nodelist): for node in nodelist: @@ -144,13 +144,13 @@ def contains_return(node): return False -class LiftReturn(transformer.Base): +class LiftReturn(converter.Base): """Move return statements out of If and With blocks.""" - def __init__(self, context): + def __init__(self, ctx): + super(LiftReturn, self).__init__(ctx) self.changes_made = False self.common_return_name = None - super(LiftReturn, self).__init__(context) def visit_If(self, node): # Depth-first traversal of if statements @@ -195,8 +195,8 @@ class LiftReturn(transformer.Base): last_return_name = self.common_return_name body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) referenced_names = body_scope.referenced - self.common_return_name = self.context.namer.new_symbol( - 'return_', referenced_names) + self.common_return_name = self.ctx.namer.new_symbol('return_', + referenced_names) node = self.generic_visit(node) self.common_return_name = last_return_name return node @@ -265,7 +265,7 @@ class DetectReturnInFunctionDef(gast.NodeVisitor): 'Each function definition should contain at least one return.') -def transform(node, context): +def transform(node, ctx): """Ensure a function has only a single return. This transforms an AST node with multiple returns successively into containing @@ -280,8 +280,8 @@ def transform(node, context): this is an error. Args: - node: an AST node to transform - context: a context object + node: ast.AST + ctx: converter.EntityContext Returns: new_node: an AST with a single return value @@ -301,10 +301,10 @@ def transform(node, context): while True: # Try to lift all returns out of if statements and with blocks - lr = LiftReturn(context) + lr = LiftReturn(ctx) node = lr.visit(node) changes_made = lr.changes_made - fe = FoldElse(context) + fe = FoldElse(ctx) node = fe.visit(node) changes_made = changes_made or fe.changes_made diff --git a/tensorflow/contrib/autograph/converters/single_return_test.py b/tensorflow/contrib/autograph/converters/single_return_test.py index d483005a09537ea8227814f65aa7e6402c853f60..1f0de4310e370235a4a7bfeaa61bd519a81aff47 100644 --- a/tensorflow/contrib/autograph/converters/single_return_test.py +++ b/tensorflow/contrib/autograph/converters/single_return_test.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import single_return +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework.ops import name_scope from tensorflow.python.platform import test -class SingleReturnTest(converter_test_base.TestCase): +class SingleReturnTest(converter_testing.TestCase): def compiled_fn(self, test_fn, *args): node = self.parse_and_analyze(test_fn, {}) diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py new file mode 100644 index 0000000000000000000000000000000000000000..3f5fc57125a8b65faf1e3a377d7984ff05b3245c --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -0,0 +1,83 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for slice operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates + + +class SliceTransformer(converter.Base): + """Converts slicing operations to their TF counterpart. + + Currently, relying on the default slice operator that Tensor uses is + insufficient, because TensorArray and tensor lists use dedicated index read + and write functions. + """ + + def _process_single_assignment(self, target, value): + if not isinstance(target, gast.Subscript): + return None + + template = """ + target = ag__.set_item(target, key, item) + """ + return templates.replace( + template, target=target.value, key=target.slice, item=value) + + def visit_Assign(self, node): + node = self.generic_visit(node) + # TODO(mdan): Support unpackings and multiple assignments. + if len(node.targets) != 1: + raise NotImplementedError('multiple assignment') + replacement = self._process_single_assignment(node.targets[0], node.value) + if replacement is not None: + return replacement + return node + + def visit_Subscript(self, node): + node = self.generic_visit(node) + if not isinstance(node.slice, gast.Index): + # TODO(mdan): It might make more sense to wave them through. + raise NotImplementedError('non-index slice') + + if not isinstance(node.ctx, gast.Load): + # Index writes are handled at a higher level, one at which the rvalue is + # also available. + return node + + dtype = anno.getanno( + node.value, + 'element_type', + default=templates.replace_as_expression('None')) + + template = """ + ag__.get_item( + target, + key, + opts=ag__.GetItemOpts(element_dtype=dtype)) + """ + return templates.replace_as_expression( + template, target=node.value, key=node.slice, dtype=dtype) + + +def transform(node, ctx): + return SliceTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py new file mode 100644 index 0000000000000000000000000000000000000000..df9a4c8bab66f24374605b45bc90bc2730431323 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import slices +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SliceTest(converter_testing.TestCase): + + def test_index_access(self): + + def test_fn(l): + utils.set_element_type(l, dtypes.int32) + return l[1] + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = slices.transform(node, self.ctx) + + with self.compiled(node, dtypes.int32) as result: + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + tl = list_ops.tensor_list_from_tensor( + [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) + y = result.test_fn(tl) + self.assertEqual(2, sess.run(y)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/core/BUILD b/tensorflow/contrib/autograph/core/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..833f9dced81bd651244d281322c830bb1c88b259 --- /dev/null +++ b/tensorflow/contrib/autograph/core/BUILD @@ -0,0 +1,59 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "core", + srcs = [ + "config.py", + "converter.py", + "naming.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", + ], +) + +py_library( + name = "test_lib", + srcs = [ + "converter_testing.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":core", + "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", + "@gast_archive//:gast", + "@six_archive//:six", + ], +) + +py_test( + name = "naming_test", + srcs = ["naming_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/python/keras/applications/inception_resnet_v2/__init__.py b/tensorflow/contrib/autograph/core/annos.py similarity index 62% rename from tensorflow/python/keras/applications/inception_resnet_v2/__init__.py rename to tensorflow/contrib/autograph/core/annos.py index 223660e9bef33896bc83f43ed26c1792e48105b9..b8937ce36a9631739ab3d7e65a4dad4124406a00 100644 --- a/tensorflow/python/keras/applications/inception_resnet_v2/__init__.py +++ b/tensorflow/contrib/autograph/core/annos.py @@ -12,16 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""InceptionResNetV2 Keras application.""" +"""Annotations specific to AutoGraph.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.inception_resnet_v2 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.inception_resnet_v2 import InceptionResNetV2 -from tensorflow.python.keras._impl.keras.applications.inception_resnet_v2 import preprocess_input +from enum import Enum -del absolute_import -del division -del print_function + +class NoValue(Enum): + + def __repr__(self): + return self.name + + +class NodeAnno(NoValue): + """Additional annotations used by AutoGraph converters. + + These are in addition to the basic annotations declared in pyct/anno.py and + pyct/static_analysis/annos.py. + """ + + # The directives collection - see directives.py + DIRECTIVES = ( + 'Dict depicting static directive calls. See the directives converter.') diff --git a/tensorflow/contrib/autograph/impl/config.py b/tensorflow/contrib/autograph/core/config.py similarity index 100% rename from tensorflow/contrib/autograph/impl/config.py rename to tensorflow/contrib/autograph/core/config.py diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..54e6aa0f3bbb9059e044861362407cb5050240b4 --- /dev/null +++ b/tensorflow/contrib/autograph/core/converter.py @@ -0,0 +1,210 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter construction support. + +This module contains a base class for all converters, as well as supporting +structures. These structures are referred to as contexts. + +The class hierarchy is as follows: + + + [extends] converter.Base + [extends] transformer.Base + [extends] gast.nodeTransformer + [uses] transfomer.SourceInfo + [uses] converter.EntityContext + [uses] converter.ProgramContext + [uses] transfomer.SourceInfo + +converter.Base is a specialization of transformer.Base for AutoGraph. It's a +very lightweight subclass that adds a `ctx` attribute holding the corresponding +EntityContext object (see below). Note that converters are not reusable, and +`visit` will raise an error if called more than once. + +converter.EntityContext contains mutable state associated with an entity that +the converter processes. + +converter.ProgramContext contains mutable state across related entities. For +example, when converting several functions that call one another, the +ProgramContext should be shared across these entities. + +Below is the overal flow at conversion: + + program_ctx = ProgramContext(, , ...) + while : + entity, source_info = + entity_ctx = EntityContext(program_ctx, source_info) + for : + converter = ConverterClass(entity_ctx) + + # May update entity_ctx and program_ctx + entity = converter.visit(entity) + + + +Note that pyct contains a small number of transformers used for static analysis. +These implement transformer.Base, rather than converter.Base, to avoid a +dependency on AutoGraph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import naming +from tensorflow.contrib.autograph.pyct import transformer + +# TODO(mdan): These contexts can be refactored into first class objects. +# For example, we could define Program and Entity abstractions that hold on +# to the actual entity and have conversion methods. + + +class ProgramContext(object): + """ProgramContext keeps track of converting function hierarchies. + + This object is mutable, and is updated during conversion. Not thread safe. + + Attributes: + recursive: bool, whether to recursively convert any functions that the + decorator function may call. + autograph_decorators: Tuple[Callable, ...], decorator functions that belong + to AutoGraph. These require special treatment. + dependency_cache: Dict[Any, ast.AST], the original entities mapped to their + converted AST + additional_imports: Set[Any], additional entities which for any reason + cannot be attached after loading and need to be explicitly imported + in the generated code + name_map: Dict[str, str], map of original entity name to the name of + their converted counterparts + autograph_module: Module, a reference to the autograph module. This + needs to be specified by the caller to avoid circular dependencies. + uncompiled_modules: Set[Tuple[str, ...]], with each tuple representing the + fully qualified name of a package containing functions that will not be + compiled. + required_imports: str, containing an import statement on each line. These + are all the imports necessary for the compiled code to run, in addition + to the closures of each entity, which are attached dynamically. + """ + + def __init__( + self, + recursive, + autograph_decorators, + partial_types, + autograph_module, + uncompiled_modules, + ): + self.recursive = recursive + self.autograph_decorators = autograph_decorators + self.partial_types = partial_types if partial_types else () + self.autograph_module = autograph_module + self.uncompiled_modules = uncompiled_modules + + # Required to output dependencies in discovery order, which should match + # the reverse dependency order. + self.dependency_cache = collections.OrderedDict() + self.additional_imports = set() + self.name_map = {} + + @property + def required_imports(self): + """Returns a block containing all imports required by the converted code.""" + # TODO(mdan): Check that these don't clobber one another. + return '\n'.join(config.COMPILED_IMPORT_STATEMENTS + + tuple(self.additional_imports)) + + def new_namer(self, namespace): + return naming.Namer(namespace, self.recursive, self.name_map, + self.partial_types) + + def update_name_map(self, namer): + """Updates renamed_calls based on the recent activity from the namer. + + Whenever we convert a new entity, any references to other entities are being + renamed to match their soon-to-be-converted counterparts. The namer keeps + track of these renames. When conversion is complete, we copy those renames + so that when those referenced entities are being converted, their new name + matches. + + Args: + namer: naming.Namer + + Raises: + ValueError: when an entity was renamed twice and to different names. + """ + # TODO(mdan): Have call_trees do this directly. + # This is done so indirectly, via the namer, for historic reasons. But + # now we can have the converter that does the rename record the new name + # as well and skip this step altogether. + for o, name in namer.renamed_calls.items(): + if o in self.name_map: + if self.name_map[o] != name: + raise ValueError( + 'Calls to %s were converted using multiple names (%s). This is ' + 'possible when an entity with one of these names already ' + 'existed. To fix, avoid using any of these names.' % + (o, (name, self.name_map[o]))) + else: + self.name_map[o] = name + + def add_to_cache(self, original_entity, converted_ast): + self.dependency_cache[original_entity] = converted_ast + + +class EntityContext(object): + """Tracks the conversion of a single entity. + + This object is mutable, and is updated during conversion. Not thread safe. + + Attributes: + namer: Namer + info: transformer.EntityInfo + program: ProgramContext + """ + + def __init__(self, namer, entity_info, program_ctx): + self.namer = namer + self.info = entity_info + self.program = program_ctx + + +class Base(transformer.Base): + """All converters should inherit from this class. + + Attributes: + ctx: EntityContext + """ + + def __init__(self, ctx): + super(Base, self).__init__(ctx.info) + self.ctx = ctx # Keeping this short because it's used frequently. + + self._used = False + self._ast_depth = 0 + + def visit(self, node): + if not self._ast_depth: + if self._used: + raise ValueError('converter objects cannot be reused') + self._used = True + + self._ast_depth += 1 + try: + return super(Base, self).visit(node) + finally: + self._ast_depth -= 1 diff --git a/tensorflow/contrib/autograph/converters/converter_test_base.py b/tensorflow/contrib/autograph/core/converter_testing.py similarity index 80% rename from tensorflow/contrib/autograph/converters/converter_test_base.py rename to tensorflow/contrib/autograph/core/converter_testing.py index 41c2e71702e7e3ee3811a2cbee27c8c988eb3a5c..0e46aacc1216d2dbd9d34ad0e72ca8251094bddc 100644 --- a/tensorflow/contrib/autograph/converters/converter_test_base.py +++ b/tensorflow/contrib/autograph/core/converter_testing.py @@ -23,17 +23,24 @@ import imp from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import compiler -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import pretty_printer from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.platform import test +def imported_decorator(f): + return lambda a: f(a) + 1 + + +# TODO(mdan): We might be able to use the real namer here. class FakeNamer(object): """A fake namer that uses a global counter to generate unique names.""" @@ -114,23 +121,32 @@ class TestCase(test.TestCase): arg_types=None, include_type_analysis=True, owner_type=None, - recursive=True): + recursive=True, + autograph_decorators=()): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=namer or FakeNamer(), + + if namer is None: + namer = FakeNamer() + program_ctx = converter.ProgramContext( + recursive=recursive, + autograph_decorators=autograph_decorators, + partial_types=None, + autograph_module=None, + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + entity_info = transformer.EntityInfo( source_code=source, - source_file=None, + source_file='', namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=owner_type, - recursive=recursive, - type_annotation_func=utils.set_element_type) + owner_type=owner_type) + ctx = converter.EntityContext(namer, entity_info, program_ctx) + node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) + node = activity.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) if include_type_analysis: - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) self.ctx = ctx return node diff --git a/tensorflow/contrib/autograph/impl/naming.py b/tensorflow/contrib/autograph/core/naming.py similarity index 100% rename from tensorflow/contrib/autograph/impl/naming.py rename to tensorflow/contrib/autograph/core/naming.py diff --git a/tensorflow/contrib/autograph/impl/naming_test.py b/tensorflow/contrib/autograph/core/naming_test.py similarity index 98% rename from tensorflow/contrib/autograph/impl/naming_test.py rename to tensorflow/contrib/autograph/core/naming_test.py index 73fc0894655cb49e4f61bf8ca51995b06feb3072..d2bebd0478b1074e421b5da1427a0dbaf91b6c9f 100644 --- a/tensorflow/contrib/autograph/impl/naming_test.py +++ b/tensorflow/contrib/autograph/core/naming_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.core import naming from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb index d62390494b78c415212ba91ac914cdfee324f971..0702273fac15da61a72d66d8344a5add32ad12a6 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb @@ -570,7 +570,7 @@ " autograph.utils.set_element_type(numbers, tf.int32)\n", " for i in range(n):\n", " numbers.append(i)\n", - " return numbers.stack() # Stack the list so that it can be used as a Tensor\n", + " return autograph.stack(numbers) # Stack the list so that it can be used as a Tensor\n", "\n", "\n", "tf_f = autograph.to_graph(f)\n", @@ -648,7 +648,7 @@ " if not is_prime:\n", " continue\n", " primes.append(i)\n", - " all_primes = primes.stack()\n", + " all_primes = autograph.stack(primes)\n", "\n", " print('The prime numbers less than', n, 'are:')\n", " print(all_primes)\n", @@ -953,8 +953,9 @@ " train_accuracies.append(step_train_accuracy)\n", " test_accuracies.append(step_test_accuracy)\n", " i += 1\n", - " return (train_losses.stack(), test_losses.stack(), train_accuracies.stack(),\n", - " test_accuracies.stack())" + " return (autograph.stack(train_losses), autograph.stack(test_losses),\n", + " autograph.stack(train_accuracies),\n", + " autograph.stack(test_accuracies))" ], "execution_count": 0, "outputs": [] @@ -1236,7 +1237,7 @@ " cell_output, (state, output) = cell.call(ch, (state, output))\n", " hidden_outputs.append(cell_output)\n", " i += 1\n", - " hidden_outputs = hidden_outputs.stack()\n", + " hidden_outputs = autograph.stack(hidden_outputs)\n", " if training:\n", " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", " return hidden_outputs\n", diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb index 324b23c24b5a7970d7f20ed955839ba1cf1774fc..44532cb078f9bd1578172f8a7d8a4b55cd21a7cb 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb @@ -190,7 +190,6 @@ " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", "\n", - "\n", " def _rnn_layer(self, chars, cell, batch_size, training):\n", " \"\"\"A single RNN layer.\n", "\n", @@ -203,13 +202,12 @@ " Returns:\n", " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", " \"\"\"\n", - " hidden_outputs = []\n", - " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " hidden_outputs = tf.TensorArray(tf.float32, 0, True)\n", " state, output = cell.zero_state(batch_size, tf.float32)\n", " for ch in chars:\n", " cell_output, (state, output) = cell.call(ch, (state, output))\n", " hidden_outputs.append(cell_output)\n", - " hidden_outputs = hidden_outputs.stack()\n", + " hidden_outputs = autograph.stack(hidden_outputs)\n", " if training:\n", " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", " return hidden_outputs\n", @@ -223,7 +221,7 @@ "\n", "\n", " def call(self, inputs, training=False):\n", - " \"\"\"The RNN model code. Uses Eager and \n", + " \"\"\"The RNN model code. Uses Eager.\n", "\n", " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", " followed by a fully connected layer with ReLU activation.\n", @@ -243,7 +241,8 @@ " seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n", "\n", " # Grab just the end-of-sequence from each output.\n", - " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " indices = (length - 1, range(batch_size))\n", + " indices = tf.stack(indices, 1)\n", " sequence_ends = tf.gather_nd(seq, indices)\n", " return self.relu_layer(sequence_ends)\n", "\n", @@ -381,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 107, "metadata": { "colab": { "autoexec": { @@ -392,9 +391,9 @@ }, "colab_type": "code", "executionInfo": { - "elapsed": 10604, + "elapsed": 5454, "status": "ok", - "timestamp": 1524095272039, + "timestamp": 1529952160455, "user": { "displayName": "", "photoUrl": "", @@ -403,7 +402,7 @@ "user_tz": 240 }, "id": "2pg1AfbxBJQq", - "outputId": "9c924b4f-06e1-4538-976c-a3e1ddac5660", + "outputId": "4aef3052-f7c7-4bb1-a0a2-73fef2e96efb", "slideshow": { "slide_type": "-" } @@ -413,7 +412,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Eval loss at step 100: 0.0674834\n" + "Eval loss at step 100: 0.0705221\n" ] } ], @@ -423,8 +422,8 @@ " 'learning_rate': 0.01,\n", "}\n", "\n", - "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", - "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv\"\n", "data_dir = \"tmp/rnn/data\"\n", "\n", "regressor = tf.estimator.Estimator(\n", @@ -457,7 +456,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 108, "metadata": { "colab": { "autoexec": { @@ -468,9 +467,9 @@ }, "colab_type": "code", "executionInfo": { - "elapsed": 7990, + "elapsed": 3432, "status": "ok", - "timestamp": 1524095280105, + "timestamp": 1529952163923, "user": { "displayName": "", "photoUrl": "", @@ -479,7 +478,7 @@ "user_tz": 240 }, "id": "dxHex2tUN_10", - "outputId": "2b889e5a-b9ed-4645-bf03-d98f26c72101", + "outputId": "1ff438f2-b045-4f4e-86a0-4dae7503f6b2", "slideshow": { "slide_type": "slide" } @@ -491,12 +490,12 @@ "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3f36aa6cd0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a110\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -507,12 +506,12 @@ "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3eca67f7d0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a8d0\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -520,15 +519,15 @@ { "data": { "text/html": [ - "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e" + "\u003cdiv id=\"id3\"\u003e\u003c/div\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3eca67f8d0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a050\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -536,16 +535,16 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa22-4362-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id1\", \"borderColor\": [\"#a7a7a7\"], \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0});\n", - "//# sourceURL=js_71b9087b6d" + "window[\"8a03307e-78a7-11e8-99f9-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id3\", \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0, \"borderColor\": [\"#a7a7a7\"]});\n", + "//# sourceURL=js_dc5d7f2784" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f950\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a190\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -553,16 +552,16 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa23-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_e390445f33" + "window[\"8a03307f-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_be7950150b" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ac90\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -570,17 +569,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_241dd76d85" + "window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_d0c3bd4eaa" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222aad0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -588,17 +587,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_60c64e3d50" + "window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_f10f6eba86" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222aed0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -606,17 +605,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa26-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_14ea437cbd" + "window[\"8a033082-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ff29697179" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222abd0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -624,17 +623,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa27-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_09294c2226" + "window[\"8a033083-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_ff85295dc7" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fcd0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ab90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -642,17 +641,17 @@ { "data": { "application/javascript": [ - "window[\"ec965514-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_e5e8266997" + "window[\"8b18d8dc-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ed7aabfedb" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a110\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -660,17 +659,17 @@ { "data": { "application/javascript": [ - "window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_07a097f0ee" + "window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_c86f8feaf4" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222acd0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -678,17 +677,17 @@ { "data": { "application/javascript": [ - "window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_790d669ca8" + "window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_4d0fde6662" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f8d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ae50\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -696,17 +695,17 @@ { "data": { "application/javascript": [ - "window[\"ec965517-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_d30df771f0" + "window[\"8b18d8df-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_3f66d52720" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a210\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -714,32 +713,32 @@ { "data": { "application/javascript": [ - "window[\"ec965518-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_8a43a2da4b" + "window[\"8b18d8e0-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_375f5ae6d7" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a310\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMBJREFUeJzt3F+I1XX+x/G32zjiFERUpgaFd2JBzOg5joX4h0SiMgmM\n/uhVGIlgFBlERGB3hUEkhkRdtDfRP1ACL6KpLBqcguxCjEAkmGamQcSohFHzsxe7O6zssvsydtff\n+ns8rs758j3f8z7fiyef7/k3o7XWCiDwh4s9APC/QzCAmGAAMcEAYoIBxAQDiAkGF8XTTz9d3W63\n7rvvvhoZGakVK1Zc7JEICMYlbvXq1TU8PHyxxzjPV199VcPDw/XZZ5/V22+/XVVVM2bMuMhTkRAM\n/qt+++23+uGHH+r666+vWbNmXexxuECCcQl76qmnanx8vLZs2VIDAwP1+uuv1zfffFP3339/dTqd\nWr9+fY2MjEzvv2nTpnr55ZfrgQceqIGBgXr44Yfr5MmTVVV1+vTp2r59ey1durQ6nU5t2LChTpw4\nUVVVk5OTtWXLllq6dGmtXbu23nnnnelj7tq1q7Zt21bbt2+vJUuW1HvvvVfPPvtsHTp0qAYGBmrX\nrl1/N/fRo0dr06ZN1el06u67766hoaGqqhodHa1OpzO93zPPPFO33nrr9P3t27fXm2+++e89iZyv\ncUlbtWpVGx4ebq21NjEx0brdbjtw4EBrrbUvvviidbvdduLEidZaaxs3bmxr1qxp33//fZuammob\nN25sO3fubK219tZbb7VHH320TU1NtXPnzrXDhw+3X375pbXW2kMPPdR27NjRTp8+3Y4cOdIGBwen\nn/OVV15pN910U/voo49aa61NTU21999/vz344IPTMx48eLCtWLGitdbamTNn2po1a9qePXvamTNn\n2vDwcOvv72/Hjh2bfj2HDx9urbW2du3advvtt7ejR4+21lpbuXJlO3LkyH/qVNJas8L4f6D95edC\n+/btq5UrV9by5curqmrZsmV1880316effjq977333ls33HBD9fb21h133FFHjhypqqqenp46efJk\nHTt2rGbMmFGLFi2qyy+/vCYmJurrr7+uJ598smbOnFkLFy6sDRs21N69e6eP2d/fX6tXr66qqt7e\n3n8666FDh+rUqVP1yCOPVE9PTw0ODtaqVavqgw8+qKqqJUuW1MjISB0/fryqqtauXVtffvlljY6O\n1q+//loLFy78N501/pGeiz0A/z1jY2O1f//++vjjj6vqzyE5e/ZsLVu2bHqfa665Zvr27Nmz69Sp\nU1VVdc8999TExEQ98cQT9fPPP9e6devq8ccfr8nJybryyitr9uzZ04+bP39+HT58ePr+3Llz4xkn\nJydr3rx5522bP39+TU5OVlVVp9OpoaGhuu6666rb7Va32629e/dWb29vLV68+ALOBr+HYFzi/vbT\nh3nz5tX69etrx44dF3ycnp6e2rp1a23durXGxsZq8+bNtWDBgrrtttvqp59+qlOnTlVfX19VVY2P\nj9ecOXP+4Qz/ypw5c2p8fPy8bWNjY7VgwYKqqup2u/Xiiy/WvHnzqtPp1MDAQD333HPV29tb3W73\ngl8XF8YlySXu2muvrdHR0aqqWrduXQ0NDdXnn39e586dq6mpqRoZGakff/zxXx7n4MGD9d1339W5\nc+eqr6+venp66rLLLqu5c+dWf39/vfTSS3X69On69ttv6913361169b9rnlvueWW6uvrq9dee63O\nnj1bBw8erE8++aTuvPPOqqq68cYba9asWbVv377qdDp1xRVX1NVXX10ffvjheW+I8p8hGJe4zZs3\n1+7du6vb7db+/ftr9+7dtWfPnlq2bFmtWrWq3njjjen3OP7ZSuD48eO1bdu2Wrx4cd111121dOnS\n6Sjs3LmzRkdHa/ny5bVt27Z67LHHzrvMuRAzZ86sV199tQ4cOFCDg4P1/PPP1wsvvDC9wqj68yrj\nqquumr7U+WsoFi1a9Luek9yM1vyBDpCxwgBiggHEBAOICQYQ+z/7PYzjf/QRGVxM12z68u+2WWEA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMIDajtdYu9hDA/wYrDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4j9CY2LTAbbRbWuAAAAAElFTkSuQmCC\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABTFJREFUeJzt3C+LV30eh/HP6EZvbP4ZJmkXDA6oQdZRMIhYLIKCMGVA\nyyaLT2ERLMqEDfoUFA2y3WpRrOKoSUSECePcYUEWdsN1OzfOyr5e8ZwT3unie34cfgvb29vbAxDs\n2e0BwK9DMIBMMIBMMIBMMIBMMIBMMPipXrx4MWfOnNntGfwgweCnW1hY2O0J/CDBYEe2trZ2ewI/\nkWDwh509e3bW19fn0qVLc/z48dnY2Jhbt27NyZMn59y5c/Pw4cPvz25ubs7t27dneXl5Ll68OC9f\nvtzF5ezUX3Z7AL+mJ0+ezPr6+uzfv3+uXr0658+fn7t3787GxsbcuHFjjhw5MqdPn5579+7N27dv\n5/nz5/P169dZXV3d7ensgBMGP+T69etz8ODBef369Xz69GnW1tZm7969s7S0NFeuXJnHjx/PzMzT\np09nbW1tfvvttzl48OBcu3Ztl5ezE04Y/JBDhw7NzMy7d+/mw4cPs7y8PDMz29vb8+3btzlx4sTM\nzHz8+PH7szMzi4uLP38sfxrBYEcOHz48S0tL8+zZs/96/8CBA7OxsTFHjx6dmX8Fhl+XVxJ25Nix\nY7Nv375ZX1+fzc3N2dramjdv3nz/cfPChQvz4MGD+fz587x//34ePXq0y4vZCcHgD/v37yj27Nkz\n9+/fn1evXs3KysqcOnVq7ty5M1++fJmZmZs3b87i4uKsrKzM6urqXL58ebdm8ydY8Ac6QOWEAWSC\nAWSCAWSCAWT/s99h/P3GX3d7Avxf+9s//vkf15wwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgGxhe3t7e7dHAL8GJwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg\nEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg+x1QoZHG4XIe4gAAAABJRU5ErkJggg==\n", "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f3ecc00bf10\u003e" + "\u003cmatplotlib.figure.Figure at 0x7fcd0d02dc90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -748,17 +747,17 @@ { "data": { "application/javascript": [ - "window[\"ec965519-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_893ad561f4" + "window[\"8b18d8e1-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_34b0509660" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55c90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e850\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -766,17 +765,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_2d99e0ac17" + "window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_518a0f26fe" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6ec90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -784,17 +783,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_5c19462e32" + "window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_17eb3ff612" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55dd0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb50\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -802,17 +801,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551c-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_b9c8b7567b" + "window[\"8b18d8e4-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_99da807c8e" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55a50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -820,17 +819,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551d-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_fd05186348" + "window[\"8b18d8e5-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_dee01cb4b6" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55810\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e610\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -838,16 +837,16 @@ { "data": { "text/html": [ - "\u003cdiv class=id_888646481 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" + "\u003cdiv class=id_853612217 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3f32414810\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222aa10\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -856,17 +855,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", - "//# sourceURL=js_efef96e882" + "window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n", + "//# sourceURL=js_8c378be329" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e990\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -875,17 +874,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551f-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", - "//# sourceURL=js_6eca889864" + "window[\"8b18d8e7-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_f0b946600c" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e310\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -894,17 +893,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 input\");\n", - "//# sourceURL=js_f02070cc60" + "window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 input\");\n", + "//# sourceURL=js_9e21b1373a" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b553d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6ea90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -913,17 +912,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea973-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"].remove();\n", - "//# sourceURL=js_ed9faba660" + "window[\"8b18d8ea-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"].remove();\n", + "//# sourceURL=js_a7764968c6" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a95450\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e5d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -932,17 +931,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", - "//# sourceURL=js_f3458d7074" + "window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n", + "//# sourceURL=js_74279d3ff0" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a95250\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e890\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -951,17 +950,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea975-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", - "//# sourceURL=js_3ffd97bd6f" + "window[\"8b18d8ec-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_82b6c34cdb" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a953d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -970,17 +969,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea976-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_7f73e8bcca" + "window[\"8b18d8ed-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ff6144734a" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -1043,28 +1042,6 @@ "kind": "local" }, "name": "RNN Colorbot using Keras and Estimators", - "provenance": [ - { - "file_id": "1CtzefX39ffFibX_BqE6cRbT0UW_DdVKl", - "timestamp": 1523579810961 - }, - { - "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG", - "timestamp": 1523016192637 - }, - { - "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", - "timestamp": 1522238054357 - }, - { - "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", - "timestamp": 1521743157199 - }, - { - "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", - "timestamp": 1520522344607 - } - ], "version": "0.3.2", "views": {} }, diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD index 54424e26472b8466b8fe68ea848b5463c10224c9..a5438592c30021eac7183b65ccc10c36d220bc57 100644 --- a/tensorflow/contrib/autograph/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -18,18 +18,19 @@ py_library( name = "impl", srcs = [ "api.py", - "config.py", "conversion.py", - "naming.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib/autograph/converters", + "//tensorflow/contrib/autograph/core", "//tensorflow/contrib/autograph/operators", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/pyct/static_analysis", "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:platform", + "//tensorflow/python:util", "@gast_archive//:gast", "@six_archive//:six", ], @@ -59,13 +60,3 @@ py_test( "@gast_archive//:gast", ], ) - -py_test( - name = "naming_test", - srcs = ["naming_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":impl", - "//tensorflow/python:client_testlib", - ], -) diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index 24f87b2c14da4a3523f1e580d4362cbd3679a2cd..c7401c7df126b73ca22cdaf74a2f1fd6149d7545 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -27,14 +27,15 @@ import gast import six # pylint:enable=g-bad-import-order -from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import conversion from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import inspect_utils -from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.utils import builtins from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect # TODO(mdan): Properly document the type hints. @@ -70,6 +71,8 @@ def convert(recursive=False, verbose=False, arg_types=None): def wrapper(*args, **kwargs): return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) + wrapper = tf_decorator.make_decorator(f, wrapper) + # Sometimes the decorator is just desugared, making it impossible to detect. # This attribute makes detection easier. setattr(wrapper, '__pyct_is_compile_decorator', True) @@ -230,20 +233,20 @@ def to_graph(e, A function with a signature identical to `o`, but which when executed it creates TF a graph that has the same functionality as the original entity. """ - conversion_map = conversion.ConversionMap( + program_ctx = converter.ProgramContext( recursive=recursive, - nocompile_decorators=(convert, do_not_convert, converted_call), + autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, - api_module=tf_inspect.getmodule(to_graph)) - _, name, namespace = conversion.entity_to_graph(e, conversion_map, arg_values, + autograph_module=tf_inspect.getmodule(to_graph), + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) module = gast.Module([]) - for import_line in config.COMPILED_IMPORT_STATEMENTS: - module.body.extend(parser.parse_str(import_line).body) - for dep in reversed(conversion_map.dependency_cache.values()): + for dep in reversed(program_ctx.dependency_cache.values()): module.body.append(dep) - compiled_node, compiled_src = compiler.ast_to_object(module) + compiled_node, compiled_src = compiler.ast_to_object( + module, source_prefix=program_ctx.required_imports) # The compiled code should see everything the entry entity saw. # TODO(mdan): This might not work well if the call tree spans modules? @@ -280,17 +283,16 @@ def to_code(e, Returns: String. """ - conversion_map = conversion.ConversionMap( + program_ctx = converter.ProgramContext( recursive=recursive, - nocompile_decorators=(convert, do_not_convert, converted_call), + autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, - api_module=tf_inspect.getmodule(to_graph)) - conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) + autograph_module=tf_inspect.getmodule(to_graph), + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) - imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS) code = '\n'.join( compiler.ast_to_source(dep, indentation) - for dep in reversed(tuple( - six.itervalues(conversion_map.dependency_cache)))) + for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache)))) - return imports + '\n\n' + code + return program_ctx.required_imports + '\n\n' + code diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index a7737b7f448131b1c54951efa719b481e1f4d0c9..994309333209586001c9369322ec3ddeee0a508e 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -21,12 +21,13 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.impl import api -from tensorflow.contrib.autograph.impl import config from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.framework import constant_op from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect tf = utils.fake_tf() @@ -154,6 +155,22 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) + def test_decorator_preserves_argspec(self): + + class TestClass(object): + + def called_member(self, a): + if a < 0: + a = -a + return a + + called_member_converted = api.convert()(called_member) + + tc = TestClass() + self.assertListEqual( + list(tf_inspect.getfullargspec(tc.called_member)), + list(tf_inspect.getfullargspec(tc.called_member_converted))) + def test_convert_call_site_decorator(self): class TestClass(object): diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 55a30dc127957b2a9caa053db843380c94bacfbf..776d19f672ebbd6b88985dda157434f2046d87e7 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""High level conversion support.""" +"""Core conversion logic, serves as main point of access.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import imp import gast @@ -38,77 +37,23 @@ from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.contrib.autograph.converters import name_scopes from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.contrib.autograph.converters import single_return -from tensorflow.contrib.autograph.impl import config -from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.converters import slices +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import ast_util -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info -from tensorflow.contrib.autograph.utils import type_hints from tensorflow.python.util import tf_inspect # TODO(mdan): Might we not need any renaming at all? -class ConversionMap(object): - """ConversionMap keeps track of converting function hierarchies. - - This object is mutable, and is updated as functions are converted. - - Attributes: - recursive: Whether to recursively convert any functions that the decorator - function may call. - nocompile_decorators: tuple of decorator functions that toggle compilation - off. - dependency_cache: dict[object]: ast; maps original entities to their - converted AST - additional_imports: set(object); additional entities which for any reason - cannot be attached after loading and need to be explicitly imported - in the generated code - name_map: dict[string]: string; maps original entities to the name of - their converted counterparts - api_module: A reference to the api module. The reference needs to be passed - to avoid circular dependencies. - """ - - # TODO(mdan): Rename to ConversionContext, and pull in additional flags. - - def __init__(self, recursive, nocompile_decorators, partial_types, - api_module): - self.recursive = recursive - self.nocompile_decorators = nocompile_decorators - self.partial_types = partial_types if partial_types else () - # Required to output dependencies in discovery order, which should match - # the reverse dependency order. - self.dependency_cache = collections.OrderedDict() - self.additional_imports = set() - self.name_map = {} - self.api_module = api_module - - def new_namer(self, namespace): - return naming.Namer(namespace, self.recursive, self.name_map, - self.partial_types) - - def update_name_map(self, namer): - for o, name in namer.renamed_calls.items(): - if o in self.name_map: - if self.name_map[o] != name: - raise ValueError( - 'Calls to %s were converted using multiple names (%s). This is ' - 'possible when an entity with one of these names already ' - 'existed. To fix, avoid using any of these names.') - else: - self.name_map[o] = name - - def add_to_cache(self, original_entity, converted_ast): - self.dependency_cache[original_entity] = converted_ast - - def is_whitelisted_for_graph(o): """Check whether an entity is whitelisted for use in graph mode. @@ -127,7 +72,7 @@ def is_whitelisted_for_graph(o): return False -def entity_to_graph(o, conversion_map, arg_values, arg_types): +def entity_to_graph(o, program_ctx, arg_values, arg_types): """Compile a Python entity into equivalent TensorFlow. The function will also recursively compile all the entities that `o` @@ -138,7 +83,7 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): Args: o: A Python entity. - conversion_map: A ConversionMap object. + program_ctx: A ProgramContext object. arg_values: A dict containing value hints for symbols like function parameters. arg_types: A dict containing type hints for symbols like function @@ -156,7 +101,7 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): ValueError: if the entity type is not supported. """ if tf_inspect.isclass(o): - node, name, ns = class_to_graph(o, conversion_map) + node, name, ns = class_to_graph(o, program_ctx) elif tf_inspect.isfunction(o): # TODO(mdan): This is not a reliable mechanism. # The most reliable way is to check the source code, the AST will contain @@ -166,36 +111,35 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): 'lambda functions are not yet supported; declare the function' ' using def instead: %s' % o) else: - node, name, ns = function_to_graph(o, conversion_map, arg_values, - arg_types) + node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) elif tf_inspect.ismethod(o): - node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types) + node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) else: raise ValueError( 'Entity "%s" has unsupported type "%s". Only functions and classes are ' 'supported for now.' % (o, type(o))) - conversion_map.add_to_cache(o, node) - if conversion_map.recursive: + program_ctx.add_to_cache(o, node) + if program_ctx.recursive: while True: candidate = None - for obj in conversion_map.name_map.keys(): - if obj not in conversion_map.dependency_cache: + for obj in program_ctx.name_map.keys(): + if obj not in program_ctx.dependency_cache: candidate = obj break if candidate is None: break if (hasattr(candidate, 'im_class') and - getattr(candidate, 'im_class') not in conversion_map.partial_types): + getattr(candidate, 'im_class') not in program_ctx.partial_types): # Class members are converted with their objects, unless they're # only converted partially. continue - entity_to_graph(candidate, conversion_map, {}, {}) + entity_to_graph(candidate, program_ctx, {}, {}) return node, name, ns -def class_to_graph(c, conversion_map): +def class_to_graph(c, program_ctx): """Specialization of `entity_to_graph` for classes.""" converted_members = {} method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) @@ -210,7 +154,7 @@ def class_to_graph(c, conversion_map): continue node, _, namespace = function_to_graph( m, - conversion_map=conversion_map, + program_ctx=program_ctx, arg_values={}, arg_types={'self': (c.__name__, c)}, owner_type=c) @@ -219,14 +163,14 @@ def class_to_graph(c, conversion_map): else: class_namespace.update(namespace) converted_members[m] = node - namer = conversion_map.new_namer(class_namespace) + namer = program_ctx.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) # TODO(mdan): This needs to be explained more thoroughly. # Process any base classes: if the sueprclass if of a whitelisted type, an # absolute import line is generated. Otherwise, it is marked for conversion # (as a side effect of the call to namer.compiled_class_name() followed by - # conversion_map.update_name_map(namer)). + # program_ctx.update_name_map(namer)). output_nodes = [] renames = {} bases = [] @@ -246,7 +190,7 @@ def class_to_graph(c, conversion_map): alias = namer.compiled_class_name(base.__name__, base) bases.append(alias) renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) - conversion_map.update_name_map(namer) + program_ctx.update_name_map(namer) # Generate the definition of the converted class. output_nodes.append( @@ -278,14 +222,14 @@ def _add_reserved_symbol(namespace, name, entity): ag_internal = None -def _add_self_references(namespace, api_module): +def _add_self_references(namespace, autograph_module): """Adds namespace references to the module that exposes the api itself.""" global ag_internal if ag_internal is None: # Craft a module that exposes parts of the external API as well as certain # internal modules. ag_internal = imp.new_module('autograph') - ag_internal.converted_call = api_module.converted_call + ag_internal.converted_call = autograph_module.converted_call ag_internal.utils = utils # TODO(mdan): Add safeguards against name clashes. # We don't want to create a submodule because we want the operators to be @@ -295,27 +239,24 @@ def _add_self_references(namespace, api_module): _add_reserved_symbol(namespace, 'ag__', ag_internal) -def function_to_graph(f, conversion_map, arg_values, arg_types, - owner_type=None): +def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None): """Specialization of `entity_to_graph` for callable functions.""" node, source = parser.parse_entity(f) node = node.body[0] namespace = inspect_utils.getnamespace(f) - _add_self_references(namespace, conversion_map.api_module) - namer = conversion_map.new_namer(namespace) + _add_self_references(namespace, program_ctx.autograph_module) + namer = program_ctx.new_namer(namespace) - ctx = context.EntityContext( - namer=namer, + entity_info = transformer.EntityInfo( source_code=source, source_file='', namespace=namespace, arg_values=arg_values, arg_types=arg_types, - owner_type=owner_type, - recursive=conversion_map.recursive, - type_annotation_func=type_hints.set_element_type) - node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators) + owner_type=owner_type) + context = converter.EntityContext(namer, entity_info, program_ctx) + node = node_to_graph(node, context) # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) @@ -325,29 +266,28 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, raise NotImplementedError('Strange corner case. Send us offending code!') node.name = new_name - conversion_map.update_name_map(namer) + program_ctx.update_name_map(namer) # TODO(mdan): Use this at compilation. - conversion_map.additional_imports.update(deps) return node, new_name, namespace -def _static_analysis_pass(node, ctx): +def _apply_transformer(node, context, converter_module): + # TODO(mdan): Clear static analysis here. node = qual_names.resolve(node) - node = activity.resolve(node, ctx, None) - node = live_values.resolve(node, ctx, config.PYTHON_LITERALS) - node = type_info.resolve(node, ctx) + node = activity.resolve(node, context.info, None) + node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) + node = type_info.resolve(node, context.info) + node = converter_module.transform(node, context) return node -def node_to_graph(node, ctx, nocompile_decorators): +def node_to_graph(node, context): """Convert Python code to equivalent TF graph mode code. Args: - node: A Python AST node representing the code to convert. - ctx: An EntityContext object. - nocompile_decorators: A tuple containing decorators to be stripped from - functions during conversion. + node: AST, the code to convert. + context: converter.EntityContext Returns: A tuple (node, deps): @@ -357,53 +297,26 @@ def node_to_graph(node, ctx, nocompile_decorators): """ # TODO(mdan): Verify arguments for correctness. - # TODO(mdan): Factor out common elements. - # These include: - # * code move between blocks - # * visiting blocks in transformers - - # Certain steps, especially canonicalization, insert new symbols into the - # tree, which must be accounted. Although less efficient, it is most robust - # to re-run the analysis. - - node = _static_analysis_pass(node, ctx) - - # TODO(mdan): Clean this up. - # Some intermediate analyses are not required, and some comments got orphaned. - + node = _apply_transformer(node, context, ifexp) # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? - ctx.source_code = None - node = ifexp.transform(node, ctx) - node, deps = decorators.transform(node, nocompile_decorators) - node = break_statements.transform(node, ctx) - node = _static_analysis_pass(node, ctx) - - node = asserts.transform(node, ctx) - + context.info.source_code = None + node = _apply_transformer(node, context, decorators) + node = _apply_transformer(node, context, break_statements) + node = _apply_transformer(node, context, asserts) # Note: sequencing continue canonicalization before for loop one avoids # dealing with the extra loop increment operation that the for # canonicalization creates. - node = continue_statements.transform(node, ctx) - ctx.namespace['len'] = len - - node = _static_analysis_pass(node, ctx) - node = single_return.transform(node, ctx) - - node = _static_analysis_pass(node, ctx) - node = lists.transform(node, ctx) - node = builtin_functions.transform(node, ctx) - - node = _static_analysis_pass(node, ctx) - node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES, - nocompile_decorators) - node = control_flow.transform(node, ctx) - - # control_flow may create new symbols and change scopes. - node = _static_analysis_pass(node, ctx) - node = logical_expressions.transform(node, ctx) - node = side_effect_guards.transform(node, ctx) - node = name_scopes.transform(node, ctx) - - return node, deps + node = _apply_transformer(node, context, continue_statements) + context.info.namespace['len'] = len + node = _apply_transformer(node, context, single_return) + node = _apply_transformer(node, context, lists) + node = _apply_transformer(node, context, slices) + node = _apply_transformer(node, context, builtin_functions) + node = _apply_transformer(node, context, call_trees) + node = _apply_transformer(node, context, control_flow) + node = _apply_transformer(node, context, logical_expressions) + node = _apply_transformer(node, context, side_effect_guards) + node = _apply_transformer(node, context, name_scopes) + return node diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index 5edd8e74a8899a25fb51e2a4e133f3cb7933fa26..f5279298afdcd406a9a6762e58367cea8ca63141 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -21,17 +21,24 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import api from tensorflow.contrib.autograph.impl import conversion from tensorflow.python.framework import constant_op -from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras.engine import training from tensorflow.python.platform import test class ConversionTest(test.TestCase): - def _simple_conversion_map(self): - return conversion.ConversionMap(True, (), (), api) + def _simple_program_ctx(self): + return converter.ProgramContext( + recursive=True, + autograph_decorators=(), + partial_types=(), + autograph_module=api, + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) def test_is_whitelisted_for_graph(self): @@ -44,16 +51,16 @@ class ConversionTest(test.TestCase): def test_entity_to_graph_unsupported_types(self): with self.assertRaises(ValueError): - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph('dummy', conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph('dummy', program_ctx, None, None) def test_entity_to_graph_callable(self): b = 2 def f(a): return a + b - conversion_map = self._simple_conversion_map() - ast, name, ns = conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + ast, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) self.assertTrue(isinstance(ast, gast.FunctionDef), ast) self.assertEqual('tf__f', name) self.assertTrue(ns['b'] is b) @@ -66,18 +73,17 @@ class ConversionTest(test.TestCase): def f(a): return g(a) - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(f, program_ctx, None, None) - self.assertTrue(f in conversion_map.dependency_cache) - self.assertTrue(g in conversion_map.dependency_cache) - self.assertEqual('tf__f', conversion_map.dependency_cache[f].name) + self.assertTrue(f in program_ctx.dependency_cache) + self.assertTrue(g in program_ctx.dependency_cache) + self.assertEqual('tf__f', program_ctx.dependency_cache[f].name) # need the extra .body[0] in order to step past the with tf.name_scope('f') # that is added automatically self.assertEqual( - 'tf__g', - conversion_map.dependency_cache[f].body[0].body[0].value.func.id) - self.assertEqual('tf__g', conversion_map.dependency_cache[g].name) + 'tf__g', program_ctx.dependency_cache[f].body[0].body[0].value.func.id) + self.assertEqual('tf__g', program_ctx.dependency_cache[g].name) def test_entity_to_graph_class_hierarchy(self): @@ -104,16 +110,15 @@ class ConversionTest(test.TestCase): def baz(self): return self.y - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(TestSubclass, program_ctx, None, None) - self.assertTrue(TestBase in conversion_map.dependency_cache) - self.assertTrue(TestSubclass in conversion_map.dependency_cache) + self.assertTrue(TestBase in program_ctx.dependency_cache) + self.assertTrue(TestSubclass in program_ctx.dependency_cache) self.assertEqual('TfTestBase', - conversion_map.dependency_cache[TestBase].body[-1].name) - self.assertEqual( - 'TfTestSubclass', - conversion_map.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestBase].body[-1].name) + self.assertEqual('TfTestSubclass', + program_ctx.dependency_cache[TestSubclass].body[-1].name) def test_entity_to_graph_class_hierarchy_whitelisted(self): @@ -126,24 +131,23 @@ class ConversionTest(test.TestCase): def call(self, x): return 3 * x - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(TestSubclass, program_ctx, None, None) - self.assertTrue(TestSubclass in conversion_map.dependency_cache) - self.assertFalse(training.Model in conversion_map.dependency_cache) + self.assertTrue(TestSubclass in program_ctx.dependency_cache) + self.assertFalse(training.Model in program_ctx.dependency_cache) self.assertEqual( 'Model', - conversion_map.dependency_cache[TestSubclass].body[0].names[0].name) - self.assertEqual( - 'TfTestSubclass', - conversion_map.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestSubclass].body[0].names[0].name) + self.assertEqual('TfTestSubclass', + program_ctx.dependency_cache[TestSubclass].body[-1].name) def test_entity_to_graph_lambda(self): f = lambda a: a with self.assertRaises(NotImplementedError): - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(f, program_ctx, None, None) def test_ag_module_cached(self): def callee(): @@ -152,11 +156,11 @@ class ConversionTest(test.TestCase): def caller(a): return a() - conversion_map = self._simple_conversion_map() - _, _, callee_ns = conversion.entity_to_graph( - callee, conversion_map, None, None) - _, _, caller_ns = conversion.entity_to_graph( - caller, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + _, _, callee_ns = conversion.entity_to_graph(callee, program_ctx, None, + None) + _, _, caller_ns = conversion.entity_to_graph(caller, program_ctx, None, + None) self.assertTrue(callee_ns['ag__'] is caller_ns['ag__']) diff --git a/tensorflow/contrib/autograph/lang/BUILD b/tensorflow/contrib/autograph/lang/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..77a2184e229003a3403cbe3bf116ad2570274a1b --- /dev/null +++ b/tensorflow/contrib/autograph/lang/BUILD @@ -0,0 +1,40 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "lang", + srcs = [ + "directives.py", + "special_functions.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/autograph/operators", + ], +) + +py_test( + name = "special_functions_test", + srcs = ["special_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":lang", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/lang/directives.py b/tensorflow/contrib/autograph/lang/directives.py new file mode 100644 index 0000000000000000000000000000000000000000..aabe5d99394a0cb921196d1c6a6b2a9496ea7545 --- /dev/null +++ b/tensorflow/contrib/autograph/lang/directives.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Directives are special no-op functions that serve as compilation markers. + +They provide static information like type hints, compilation and TensorFlow +overrides. + +These serve as annotations in the compiled code, allowing the user some control +over the compilation process. They have no functional role at runtime. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +UNSPECIFIED = object() + + +def set_element_type(entity, dtype, shape=UNSPECIFIED): + """Indicates that the entity is expected hold items of specified type/shape. + + The staged TensorFlow ops will reflect and assert this data type. Ignored + otherwise. + + Args: + entity: The entity to annotate. + dtype: TensorFlow dtype value to assert for entity. + shape: Optional shape to assert for entity. + """ + del entity + del dtype + del shape + + +def set_loop_options( + parallel_iterations=UNSPECIFIED, + back_prop=UNSPECIFIED, + swap_memory=UNSPECIFIED, + maximum_iterations=UNSPECIFIED): + """Specifies additional arguments to be passed to the enclosing while_loop. + + The parameters apply to and only to the immediately enclosing loop. It only + has effect if the loop is staged as a TF while_loop; otherwise the parameters + have no effect. + + Args: + parallel_iterations: See tf.while_loop. + back_prop: See tf.while_loop. + swap_memory: See tf.while_loop. + maximum_iterations: See tf.while_loop. + """ + del parallel_iterations + del back_prop + del swap_memory + del maximum_iterations diff --git a/tensorflow/contrib/autograph/lang/special_functions.py b/tensorflow/contrib/autograph/lang/special_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..11135295a7966bc5d693676fcc71fe43791f2e99 --- /dev/null +++ b/tensorflow/contrib/autograph/lang/special_functions.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Special functions that only make sense for AutoGraph. + +These functions are meant to ensure feature parity between Python and AutoGraph, +so that the exact same code works in both modes. In general, AutoGraph will +replace these calls. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import data_structures + + +def stack(list_or_tensor, element_dtype=None, strict=True): + """Stacks the input, if it admits the notion of stacking. + + For example, a list of tensors can be stacked into a larger tensor. This + function is similar to tf.stack, but it accepts non-lists and lists of + non-tensors as arguments. In the latter case, the function does nothing. + + Args: + list_or_tensor: Any + element_dtype: tf.DType, optional dtypedtype for the elements in the list. + Required if the input is stackable, and the list is untyped. + strict: bool, if True an error is raised if the input is not stackable. + Otherwise the function is a no-op. + + Returns: + Any, if the input is stackable, the result will be a tf.Tensor. Otherwise, + if strict=False, the result will be list_or_tensor. + + Raises: + ValueError: if strict=True and the input is not stackable. + """ + if strict: + def raise_error(x): + raise ValueError('%s must be stackable when strict=True' % x) + original_call = raise_error + else: + original_call = lambda x: x + return data_structures.list_stack( + list_or_tensor, + data_structures.ListStackOpts( + element_dtype=element_dtype, original_call=original_call)) diff --git a/tensorflow/contrib/autograph/lang/special_functions_test.py b/tensorflow/contrib/autograph/lang/special_functions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a49cb6407517b634e0f1259fccda03d4ed18e83f --- /dev/null +++ b/tensorflow/contrib/autograph/lang/special_functions_test.py @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for special_functions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.lang import special_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SpecialFunctionsTest(test.TestCase): + + def test_basic(self): + self.assertEqual(special_functions.stack(1, strict=False), 1) + self.assertListEqual( + special_functions.stack([1, 2, 3], strict=False), [1, 2, 3]) + # TODO(mdan): This should probably forward to tf.stack. + self.assertTrue( + isinstance( + special_functions.stack( + [constant_op.constant(1), + constant_op.constant(2)], strict=False), list)) + + with self.assertRaises(ValueError): + special_functions.stack([1, 2, 3]) + + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor( + t, element_shape=constant_op.constant([], dtype=dtypes.int32)) + self.assertTrue( + tensor_util.is_tensor( + special_functions.stack(l, element_dtype=dtypes.float32))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 18bfec5d9c69912f90414c51ac63ba540cf4d5fc..332d5dab19e7ade1531b564fbdef2fa0dc2d09d5 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -22,13 +22,21 @@ py_library( "__init__.py", "control_flow.py", "data_structures.py", - "dispatch_context.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:list_ops", "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_util", + "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -52,3 +60,13 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index 38b761d97d54bdaee4da91269964469b482895ae..c900fd6af2ea5dfb419f731ee8d8822d68424b27 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -28,6 +28,10 @@ closures for the body. # - the names used in the Python docs, if the operator is a function (e.g. # list_ and x for append, see # https://docs.python.org/3.7/tutorial/datastructures.html) +# +# All operators may accept a final argument named "opts", of a type that +# subclasses namedtuple and contains any arguments that are only required +# for some specializations of the operator. from __future__ import absolute_import from __future__ import division @@ -35,3 +39,12 @@ from __future__ import print_function from tensorflow.contrib.autograph.operators.control_flow import for_stmt from tensorflow.contrib.autograph.operators.control_flow import while_stmt +from tensorflow.contrib.autograph.operators.data_structures import list_append +from tensorflow.contrib.autograph.operators.data_structures import list_pop +from tensorflow.contrib.autograph.operators.data_structures import list_stack +from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts +from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts +from tensorflow.contrib.autograph.operators.data_structures import new_list +from tensorflow.contrib.autograph.operators.slices import get_item +from tensorflow.contrib.autograph.operators.slices import GetItemOpts +from tensorflow.contrib.autograph.operators.slices import set_item diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 671c9ccc13eaa887522cfc248a6d56d7ab9719ca..988df70157170ed0a9ece33976e871e6f7693bbc 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -51,7 +51,7 @@ def for_stmt(iter_, extra_test, body, init_state): Args: iter_: The entity being iterated over. extra_test: Callable with the state as arguments, and boolean return type. - An additionnal loop condition. + An additional loop condition. body: Callable with the iterate and the state as arguments, and state as return type. The actual loop body. init_state: Tuple containing the initial state. diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py index c862306baa9e8114a71a26323ddcbd35c8592c55..06d8727b0fcc30b532b3f11281cd1a83c51ac8bc 100644 --- a/tensorflow/contrib/autograph/operators/data_structures.py +++ b/tensorflow/contrib/autograph/operators/data_structures.py @@ -18,39 +18,250 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables + + +# TODO(mdan): Once control flow supports objects, repackage as a class. + + +def new_list(iterable=None): + """The list constructor. + + Args: + iterable: Optional elements to fill the list with. + + Returns: + A list-like object. The exact return value depends on the initial elements. + """ + if iterable: + elements = tuple(iterable) + else: + elements = () + + # TODO(mdan): Extend these criteria. + if any(isinstance(el, variables.Variable) for el in elements): + return _py_list_new(elements) + return _tf_tensor_list_new(elements) -# TODO(mdan): Add support for TensorList once functional. -# TODO(mdan): Add primitives for empty list, list with elements. +def _tf_tensor_list_new(elements): + """Overload of new_list that stages a Tensor list creation.""" + elements = tuple(ops.convert_to_tensor(el) for el in elements) + all_dtypes = set(el.dtype for el in elements) + if len(all_dtypes) == 1: + element_dtype = tuple(all_dtypes)[0] + else: + # Heterogeneous lists are ok. + element_dtype = dtypes.variant + + # TODO(mdan): This may fail for elements of variable shapes. + all_shapes = set(tuple(el.shape.as_list()) for el in elements) + if len(all_shapes) == 1: + element_shape = array_ops.shape(elements[0]) + else: + # Heterogeneous lists are ok. + element_shape = constant_op.constant(-1) # unknown shape, by convention + + l = list_ops.empty_tensor_list( + element_shape=element_shape, element_dtype=element_dtype) + for el in elements: + l = list_ops.tensor_list_push_back(l, el) + return l -def append(target, element): + +def _py_list_new(elements): + """Overload of new_list that creates a Python list.""" + return list(elements) + + +def list_append(list_, x): """The list append function. - Note: it is unspecified where target will be mutated or not. If target is - a TensorFlow entity, it will not be typically mutated. If target is a plain - list, it will be. In general, if the target is mutated then the return value + Note: it is unspecified where list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value should point to the original entity. Args: - target: An entity that supports append semantics. - element: The element to append. + list_: An entity that supports append semantics. + x: The element to append. Returns: - Same as target, after the append was performed. + Same as list_, after the append was performed. + + Raises: + ValueError: if list_ is not of a known list-like type. """ - if isinstance(target, tensor_array_ops.TensorArray): - return _tf_tensorarray_append(target, element) + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_append(list_, x) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_append(list_, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) else: - return _py_append(target, element) + return _py_list_append(list_, x) + + +def _tf_tensor_list_append(list_, x): + """Overload of list_append that stages a Tensor list write.""" + def empty_list_of_elements_like_x(): + tensor_x = ops.convert_to_tensor(x) + return list_ops.empty_tensor_list( + element_shape=array_ops.shape(tensor_x), + element_dtype=tensor_x.dtype) + + list_ = control_flow_ops.cond( + list_ops.tensor_list_length(list_) > 0, + lambda: list_, + empty_list_of_elements_like_x, + ) + return list_ops.tensor_list_push_back(list_, x) + + +def _tf_tensorarray_append(list_, x): + """Overload of list_append that stages a TensorArray write.""" + return list_.write(list_.size(), x) + + +def _py_list_append(list_, x): + """Overload of list_append that executes a Python list append.""" + # Revert to the original call. + list_.append(x) + return list_ + + +class ListPopOpts( + collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))): + pass + + +def list_pop(list_, i, opts): + """The list pop function. + + Note: it is unspecified where list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value + should point to the original entity. + + Args: + list_: An entity that supports pop semantics. + i: Optional index to pop from. May be None. + opts: A ListPopOpts. + + Returns: + Tuple (x, out_list_): + out_list_: same as list_, after the removal was performed. + x: the removed element value. + + Raises: + ValueError: if list_ is not of a known list-like type or the operation is + not supported for that type. + """ + assert isinstance(opts, ListPopOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + raise ValueError('TensorArray does not support item removal') + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_pop(list_, i, opts) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) + else: + return _py_list_pop(list_, i) + + +def _tf_tensor_list_pop(list_, i, opts): + """Overload of list_pop that stages a Tensor list pop.""" + if i is not None: + raise NotImplementedError('tensor lists only support removing from the end') + + if opts.element_dtype is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'type; use set_element_type to annotate it') + if opts.element_shape is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'shape; use set_element_type to annotate it') + list_out, x = list_ops.tensor_list_pop_back( + list_, element_dtype=opts.element_dtype) + x.set_shape(opts.element_shape) + return list_out, x + + +def _py_list_pop(list_, i): + """Overload of list_pop that executes a Python list append.""" + if i is None: + x = list_.pop() + else: + x = list_.pop(i) + return list_, x + + +# TODO(mdan): Look into reducing duplication between all these containers. +class ListStackOpts( + collections.namedtuple('ListStackOpts', + ('element_dtype', 'original_call'))): + pass + + +def list_stack(list_, opts): + """The list stack function. + + This does not have a direct correspondent in Python. The closest idiom to + this is tf.append or np.stack. It's different from those in the sense that it + accepts a Tensor list, rather than a list of tensors. It can also accept + TensorArray. When the target is anything else, the dispatcher will rely on + ctx.original_call for fallback. + + Args: + list_: An entity that supports append semantics. + opts: A ListStackOpts object. + + Returns: + The output of the stack operation, typically a Tensor. + """ + assert isinstance(opts, ListStackOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_stack(list_) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_stack(list_, opts) + else: + # No-op for primitive Tensor arguments. + return list_ + else: + return _py_list_stack(list_, opts) + + +def _tf_tensorarray_stack(list_): + """Overload of list_stack that stages a TensorArray stack.""" + return list_.stack() -def _tf_tensorarray_append(target, element): - """Overload of append that stages a TensorArray write at the last position.""" - return target.write(target.size(), element) +def _tf_tensor_list_stack(list_, opts): + """Overload of list_stack that stages a Tensor list write.""" + if opts.element_dtype is None: + raise ValueError('cannot stack a list without knowing its element type;' + ' use set_element_type to annotate it') + return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype) -def _py_append(target, element): - """Overload of append that executes a Python list append.""" - target.append(element) - return target +def _py_list_stack(list_, opts): + """Overload of list_stack that executes a Python list append.""" + # Revert to the original call. + return opts.original_call(list_) diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py index 577d28c34da39f1216669513c29a00ac07bec126..8bbb52d6c10b241ec754c7dea599fa15a869595f 100644 --- a/tensorflow/contrib/autograph/operators/data_structures_test.py +++ b/tensorflow/contrib/autograph/operators/data_structures_test.py @@ -19,25 +19,98 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test -class AppendTest(test.TestCase): +class ListTest(test.TestCase): - def test_tf_tensorarray(self): + def test_new_list_empty(self): + l = data_structures.new_list() + # Can't evaluate an empty list. + # TODO(mdan): sess.run should allow tf.variant maybe? + self.assertTrue(isinstance(l, ops.Tensor)) + + def test_new_list_tensor(self): + l = data_structures.new_list([3, 4, 5]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4, 5]) + + def test_append_tensor_list(self): + l = data_structures.new_list() + x = constant_op.constant([1, 2, 3]) + l = data_structures.list_append(l, x) + + t = list_ops.tensor_list_stack(l, element_dtype=x.dtype) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [[1, 2, 3]]) + + def test_append_tensorarray(self): l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) - l1 = data_structures.append(l, 1) - l2 = data_structures.append(l1, 2) + l1 = data_structures.list_append(l, 1) + l2 = data_structures.list_append(l1, 2) with self.test_session() as sess: self.assertAllEqual(sess.run(l1.stack()), [1]) self.assertAllEqual(sess.run(l2.stack()), [1, 2]) - def test_python(self): + def test_append_python(self): l = [] - self.assertAllEqual(data_structures.append(l, 1), [1]) - self.assertAllEqual(data_structures.append(l, 2), [1, 2]) + self.assertAllEqual(data_structures.list_append(l, 1), [1]) + self.assertAllEqual(data_structures.list_append(l, 2), [1, 2]) + + def test_pop_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListPopOpts( + element_dtype=initial_list.dtype, + element_shape=(2,)) + + with self.assertRaises(NotImplementedError): + data_structures.list_pop(l, 0, opts) + + with self.test_session() as sess: + l, x = data_structures.list_pop(l, None, opts) + self.assertAllEqual(sess.run(x), [3, 4]) + + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[1, 2]]) + + def test_pop_python(self): + l = [1, 2, 3] + opts = data_structures.ListPopOpts(element_dtype=None, element_shape=()) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1, 2], 3)) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1], 2)) + + def test_stack_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListStackOpts( + element_dtype=initial_list.dtype, original_call=None) + + with self.test_session() as sess: + t = data_structures.list_stack(l, opts) + self.assertAllEqual(sess.run(t), sess.run(initial_list)) + + def test_stack_fallback(self): + + def dummy_function(l): + # Lazy person's mock: just transform the argument in a way in which we + # can check that this function was indeed called. + return [x * 2 for x in l] + + opts = data_structures.ListStackOpts( + element_dtype=None, original_call=dummy_function) + + self.assertAllEqual(data_structures.list_stack([1, 2], opts), [2, 4]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py new file mode 100644 index 0000000000000000000000000000000000000000..04fbeb2f6e39234cad139442704fd7a8d0f56172 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -0,0 +1,133 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators specific to slicing operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops + + +# TODO(mdan): Support extended slices. + + +class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))): + pass + + +def get_item(target, i, opts): + """The slice read operator (i.e. __getitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports getitem semantics. + i: Index to read from. + opts: A GetItemOpts object. + + Returns: + The read element. + + Raises: + ValueError: if target is not of a supported type. + """ + assert isinstance(opts, GetItemOpts) + + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_get_item(target, i) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_get_item(target, i, opts) + else: + return _tf_tensor_get_item(target, i) + else: + return _py_get_item(target, i) + + +def _tf_tensorarray_get_item(target, i): + """Overload of get_item that stages a TensorArray read.""" + return target.read(i) + + +def _tf_tensor_list_get_item(target, i, opts): + """Overload of get_item that stages a Tensor list read.""" + if opts.element_dtype is None: + raise ValueError('cannot retrieve from a list without knowing its ' + 'element type; use set_element_type to annotate it') + x = list_ops.tensor_list_get_item(target, i, element_dtype=opts.element_dtype) + return x + + +def _tf_tensor_get_item(target, i): + """Overload of get_item that stages a Tensor (not Tensor list) read.""" + return target[i] + + +def _py_get_item(target, i): + """Overload of get_item that executes a Python list modification.""" + return target[i] + + +def set_item(target, i, x): + """The slice write operator (i.e. __setitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports setitem semantics. + i: Index to modify. + x: The new element value. + + Returns: + Same as target, after the update was performed. + + Raises: + ValueError: if target is not of a supported type. + """ + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_set_item(target, i, x) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_set_item(target, i, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % target) + else: + return _py_set_item(target, i, x) + + +def _tf_tensorarray_set_item(target, i, x): + """Overload of set_item that stages a TensorArray write.""" + return target.write(i, x) + + +def _tf_tensor_list_set_item(target, i, x): + """Overload of set_item that stages a Tensor list update.""" + return list_ops.tensor_list_set_item(target, i, x) + + +def _py_set_item(target, i, x): + """Overload of set_item that executes a Python list modification.""" + target[i] = x + return target diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d4aacb9d2015fec56a8df5ad85a20b733765ba26 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import slices +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SlicesTest(test.TestCase): + + def test_set_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + l = slices.set_item(l, 0, [5, 6]) + + with self.test_session() as sess: + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]]) + + def test_get_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + t = slices.get_item( + l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype)) + + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index 796ab445c74128e1123e24b67c288e0e3c5ca24c..a49a4ed05ca99a5c9784cfc132784890e63a94de 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -22,8 +22,8 @@ py_library( "__init__.py", "anno.py", "ast_util.py", + "cfg.py", "compiler.py", - "context.py", "inspect_utils.py", "parser.py", "pretty_printer.py", @@ -38,6 +38,8 @@ py_library( "@gast_archive//:gast", "@six_archive//:six", "@termcolor_archive//:termcolor", + # TODO(mdan): Remove this dependency. + "//tensorflow/python:util", ], ) @@ -62,6 +64,17 @@ py_test( ], ) +py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "compiler_test", srcs = ["compiler_test.py"], @@ -130,6 +143,7 @@ py_test( name = "transformer_test", srcs = ["transformer_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":pyct", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py index cc4a7edf02ed7556c9a552d8730e4c7875038c83..ae861627fd65cca057e7bf1af41424e605d4b7a1 100644 --- a/tensorflow/contrib/autograph/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -46,8 +46,15 @@ class Basic(NoValue): '`name_map` allows renaming symbols.') -def getanno(node, key, field_name='___pyct_anno'): - return getattr(node, field_name)[key] +FAIL = object() + + +def getanno(node, key, default=FAIL, field_name='___pyct_anno'): + if (default is FAIL or + (hasattr(node, field_name) and (key in getattr(node, field_name)))): + return getattr(node, field_name)[key] + else: + return default def hasanno(node, key, field_name='___pyct_anno'): @@ -73,5 +80,9 @@ def delanno(node, key, field_name='___pyct_anno'): def copyanno(from_node, to_node, key, field_name='___pyct_anno'): - if hasanno(from_node, key, field_name): - setanno(to_node, key, getanno(from_node, key, field_name), field_name) + if hasanno(from_node, key, field_name=field_name): + setanno( + to_node, + key, + getanno(from_node, key, field_name=field_name), + field_name=field_name) diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py index 1d4d9d119e0c45c4bf9dd4e5b8156766489a2e4d..f2c0c8cf05ca4b3671eb653ce56f6da61de54aee 100644 --- a/tensorflow/contrib/autograph/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -38,12 +38,14 @@ class AnnoTest(test.TestCase): anno.setanno(node, 'foo', 3) self.assertTrue(anno.hasanno(node, 'foo')) - self.assertEqual(3, anno.getanno(node, 'foo')) + self.assertEqual(anno.getanno(node, 'foo'), 3) + self.assertEqual(anno.getanno(node, 'bar', default=7), 7) anno.delanno(node, 'foo') self.assertFalse(anno.hasanno(node, 'foo')) with self.assertRaises(AttributeError): anno.getanno(node, 'foo') + self.assertIsNone(anno.getanno(node, 'foo', default=None)) def test_copyanno(self): node_1 = ast.Name() diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..666328781f683c9457f6892c0a26088c33ba94a7 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/cfg.py @@ -0,0 +1,733 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph (CFG) structure for Python AST representation. + +The CFG is a digraph with edges representing valid control flow. Each +node is associated with exactly one AST node, but not all AST nodes may have +a corresponding CFG counterpart. + +Once built, the CFG itself is immutable, but the values it holds need not be; +they are usually annotated with information extracted by walking the graph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from enum import Enum + +# pylint:disable=g-bad-import-order +import gast +# pylint:enable=g-bad-import-order + +from tensorflow.contrib.autograph.pyct import compiler + + +class Node(object): + """A node in the CFG. + + Although new instances of this class are mutable, the objects that a user + finds in the CFG are typically not. + + The nodes represent edges in the CFG graph, and maintain pointers to allow + efficient walking in both forward and reverse order. The following property + holds for all nodes: "child in node.next" iff "node in child.prev". + + Attributes: + next: FrozenSet[Node, ...], the nodes that follow this node, in control + flow order + prev: FrozenSet[Node, ...], the nodes that precede this node, in reverse + control flow order + ast_node: ast.AST, the AST node corresponding to this CFG node + """ + + def __init__(self, next_, prev, ast_node): + self.next = next_ + self.prev = prev + self.ast_node = ast_node + + def freeze(self): + self.next = frozenset(self.next) + self.prev = frozenset(self.prev) + + def __repr__(self): + return compiler.ast_to_source(self.ast_node).strip() + + +class Graph( + collections.namedtuple('Graph', ['entry', 'exit', 'error', 'index'])): + """A Control Flow Graph. + + The CFG maintains an index to allow looking up a CFG node by the AST node to + which it is associated. The index can also be enumerated in top-down, depth + first order. + + Walking the graph in forward or reverse order is supported by double + parent-child links. + + Note: the error nodes are not wired to their corresponding finally guards, + because these are shared, and wiring them would create a reverse path from + normal control flow into the error nodes, which we want to avoid. + + Attributes: + entry: Node, the entry node + exit: FrozenSet[Node, ...], the exit nodes + error: FrozenSet[Node, ...], nodes that exit due to an explicitly raised + error (errors propagated from function calls are not accounted) + index: Dict[ast.Node, Node], mapping AST nodes to the respective CFG + node + """ + + def __repr__(self): + result = 'digraph CFG {\n' + for node in self.index.values(): + result += ' %s [label="%s"];\n' % (id(node), node) + for node in self.index.values(): + if node.next: + result += ' %s -> {%s};\n' % (id(node), ', '.join( + repr(id(n)) for n in node.next)) + result += '}' + return result + + +class _WalkMode(Enum): + FORWARD = 1 + REVERSE = 2 + + +class GraphVisitor(object): + """Base class for a CFG visitors. + + This implementation is not thread safe. + + The visitor has some facilities to simplify dataflow analyses. In particular, + it allows revisiting the nodes at the decision of the subclass. This can be + used to visit the graph until the state reaches a fixed point. + + For more details on dataflow analysis, see + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Note: the literature generally suggests visiting successor nodes only when the + state of the current node changed, regardless of whether that successor has + ever been visited. This implementation visits every successor at least once. + + Attributes: + graph: Graph + in_: Dict[Node, Any], stores node-keyed state during a visit + out: Dict[Node, Any], stores node-keyed state during a visit + """ + + def reset(self): + self.in_ = { + node: self.init_state(node) for node in self.graph.index.values() + } + self.out = { + node: self.init_state(node) for node in self.graph.index.values() + } + + def init_state(self, node): + """State initialization function. Optional to overload. + + An in/out state slot will be created for each node in the graph. Subclasses + may overload this to control what that is initialized to. + + Args: + node: Node + """ + del node + return None + + def visit_node(self, node): + """Visitor function. + + Args: + node: Node + Returns: + bool, whether the node should be revisited; subclasses can visit every + reachable node exactly once by always returning False + """ + raise NotImplementedError('Subclasses must implement this.') + + def _visit_internal(self, mode): + """Visits the CFG, depth-first.""" + assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE) + if mode == _WalkMode.FORWARD: + open_ = [self.graph.entry] + elif mode == _WalkMode.REVERSE: + open_ = list(self.graph.exit) + closed = set() + self.reset() + + while open_: + node = open_.pop(0) + closed.add(node) + + should_revisit = self.visit_node(node) + + if mode == _WalkMode.FORWARD: + children = node.next + elif mode == _WalkMode.REVERSE: + children = node.prev + + for next_ in children: + if should_revisit or next_ not in closed: + open_.append(next_) + + def visit_forward(self, graph): + self.graph = graph + self._visit_internal(_WalkMode.FORWARD) + + def visit_reverse(self, graph): + self.graph = graph + self._visit_internal(_WalkMode.REVERSE) + + +class GraphBuilder(object): + """Builder that constructs a CFG from a given AST. + + This GraphBuilder facilitates constructing the DAG that forms the CFG when + nodes + are supplied in lexical order (i.e., top-down, depth first). Under these + conditions, it supports building patterns found in typical structured + programs. + + This builder ignores the flow generated by exceptions, which are assumed to + always be catastrophic and present purely for diagnostic purposes (e.g. to + print debug information). Statements like raise and try/catch sections are + allowed and will generate control flow edges, but ordinaty statements are + assumed not to raise exceptions. + + Finally sections are also correctly interleaved between break/continue/return + nodes and their subsequent statements. + + Important concepts: + * nodes - nodes refer refer to CFG nodes; AST nodes are qualified explicitly + * leaf set - since the graph is constructed gradually, a leaf set maintains + the CFG nodes that will precede the node that the builder expects to + receive next; when an ordinary node is added, it is connected to the + existing leaves and it in turn becomes the new leaf + * jump nodes - nodes that should generate edges other than what + ordinary nodes would; these correspond to break, continue and return + statements + * sections - logical delimiters for subgraphs that require special + edges; there are various types of nodes, each admitting various + types of jump nodes; sections are identified by their corresponding AST + node + """ + + # TODO(mdan): Perhaps detail this in a markdown doc. + # TODO(mdan): Add exception support. + + def __init__(self, parent_ast_node): + self.reset() + self.parent = parent_ast_node + + def reset(self): + """Resets the state of this factory.""" + self.head = None + self.errors = set() + self.node_index = collections.OrderedDict() + + # TODO(mdan): Too many primitives. Use classes. + self.leaves = set() + + self.finally_sections = {} + self.finally_section_subgraphs = {} # Values are [begin_node, exit_nodes] + # Whether the guard section can be reached from the statement that precedes + # it. + self.finally_section_has_direct_flow = {} + # Finally sections that await their first node. + self.pending_finally_sections = set() + + # Exit jumps keyed by the section they affect. + self.exits = {} + + # The entry of loop sections, keyed by the section. + self.section_entry = {} + # Continue jumps keyed by the section they affect. + self.continues = {} + + # The entry of conditional sections, keyed by the section. + self.cond_entry = {} + # Lists of leaf nodes corresponding to each branch in the section. + self.cond_leaves = {} + + def _connect_nodes(self, first, second): + """Connects nodes to signify that control flows from first to second. + + Args: + first: Union[Set[Node, ...], Node] + second: Node + """ + if isinstance(first, Node): + first.next.add(second) + second.prev.add(first) + else: + for node in first: + self._connect_nodes(node, second) + + def _add_new_node(self, ast_node): + """Grows the graph by adding a CFG node following the current leaves.""" + if ast_node is self.node_index: + raise ValueError('%s added twice' % ast_node) + node = Node(next_=set(), prev=set(), ast_node=ast_node) + self.node_index[ast_node] = node + + if self.head is None: + self.head = node + + for leaf in self.leaves: + self._connect_nodes(leaf, node) + + # If any finally section awaits its first node, populate it. + for section_id in self.pending_finally_sections: + self.finally_section_subgraphs[section_id][0] = node + self.pending_finally_sections = set() + + return node + + def add_ordinary_node(self, ast_node): + """Grows the graph by adding an ordinary CFG node. + + Ordinary nodes are followed by the next node, in lexical order, that is, + they become the new leaf set. + + Args: + ast_node: ast.AST + Returns: + Node + """ + node = self._add_new_node(ast_node) + self.leaves = set((node,)) + return node + + def _add_jump_node(self, ast_node, guards): + """Grows the graph by adding a jump node. + + Jump nodes are added to the current leaf set, and the leaf set becomes + empty. If the jump node is the last in a cond section, then it may be added + back to the leaf set by a separate mechanism. + + Args: + ast_node: ast.AST + guards: Tuple[ast.AST, ...], the finally sections active for this node + Returns: + Node + """ + node = self._add_new_node(ast_node) + self.leaves = set() + # The guards themselves may not yet be complete, and will be wired later. + self.finally_sections[node] = guards + return node + + def _connect_jump_to_finally_sections(self, node): + """Connects a jump node to the finally sections protecting it.""" + cursor = set((node,)) + for guard_section_id in self.finally_sections[node]: + guard_begin, guard_ends = self.finally_section_subgraphs[guard_section_id] + self._connect_nodes(cursor, guard_begin) + cursor = guard_ends + del self.finally_sections[node] + # TODO(mdan): Should garbage-collect finally_section_subgraphs. + return cursor + + def add_exit_node(self, ast_node, section_id, guards): + """Grows the graph by adding an exit node. + + This node becomes an exit for the current section. + + Args: + ast_node: ast.AST + section_id: Hashable, the node for which ast_node should be considered + to be an exit node + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.exits[section_id].add(node) + + def add_continue_node(self, ast_node, section_id, guards): + """Grows the graph by adding a reentry node. + + This node causes control flow to go back to the loop section's entry. + + Args: + ast_node: ast.AST + section_id: Hashable, the node for which ast_node should be considered + to be an exit node + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.continues[section_id].add(node) + + def add_error_node(self, ast_node, guards): + """Grows the graph by adding an error node. + + This node becomes an exit for the entire graph. + + Args: + ast_node: ast.AST + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.errors.add(node) + self.leaves = set() + + def enter_section(self, section_id): + """Enters a regular section. + + Regular sections admit exit jumps, which end the section. + + Args: + section_id: Hashable, the same node that will be used in calls to the + ast_node arg passed to add_exit_node + """ + assert section_id not in self.exits + self.exits[section_id] = set() + + def exit_section(self, section_id): + """Exits a regular section.""" + + # Exits are jump nodes, which may be protected. + for exit_ in self.exits[section_id]: + self.leaves |= self._connect_jump_to_finally_sections(exit_) + + del self.exits[section_id] + + def enter_loop_section(self, section_id, entry_node): + """Enters a loop section. + + Loop sections define an entry node. The end of the section always flows back + to the entry node. These admit continue jump nodes which also flow to the + entry node. + + Args: + section_id: Hashable, the same node that will be used in calls to the + ast_node arg passed to add_continue_node + entry_node: ast.AST, the entry node into the loop (e.g. the test node + for while loops) + """ + assert section_id not in self.section_entry + assert section_id not in self.continues + self.continues[section_id] = set() + node = self.add_ordinary_node(entry_node) + self.section_entry[section_id] = node + + def exit_loop_section(self, section_id): + """Exits a loop section.""" + self._connect_nodes(self.leaves, self.section_entry[section_id]) + + # continues are jump nodes, which may be protected. + for reentry in self.continues[section_id]: + guard_ends = self._connect_jump_to_finally_sections(reentry) + self._connect_nodes(guard_ends, self.section_entry[section_id]) + + # Loop nodes always loop back. + self.leaves = set((self.section_entry[section_id],)) + + del self.continues[section_id] + del self.section_entry[section_id] + + def enter_cond_section(self, section_id): + """Enters a conditional section. + + Conditional sections define an entry node, and one or more branches. + + Args: + section_id: Hashable, the same node that will be used in calls to the + section_id arg passed to new_cond_branch + """ + + assert section_id not in self.cond_entry + assert section_id not in self.cond_leaves + self.cond_leaves[section_id] = [] + + def new_cond_branch(self, section_id): + """Begins a new branch in a cond section.""" + assert section_id in self.cond_leaves + + if section_id in self.cond_entry: + # Subsequent splits move back to the split point, and memorize the + # current leaves. + self.cond_leaves[section_id].append(self.leaves) + self.leaves = self.cond_entry[section_id] + else: + # If this is the first time we split a section, just remember the split + # point. + self.cond_entry[section_id] = self.leaves + + def exit_cond_section(self, section_id): + """Exits a conditional section.""" + for split in self.cond_leaves[section_id]: + self.leaves |= split + del self.cond_entry[section_id] + del self.cond_leaves[section_id] + + def enter_finally_section(self, section_id): + """Enters a finally section.""" + # TODO(mdan): This, not the caller, should track the active sections. + self.finally_section_subgraphs[section_id] = [None, None] + if self.leaves: + self.finally_section_has_direct_flow[section_id] = True + else: + self.finally_section_has_direct_flow[section_id] = False + self.pending_finally_sections.add(section_id) + + def exit_finally_section(self, section_id): + """Exits a finally section.""" + assert section_id not in self.pending_finally_sections, 'Empty finally?' + self.finally_section_subgraphs[section_id][1] = self.leaves + # If the guard can only be reached by a jump, then it will not flow + # into the statement that follows it. + if not self.finally_section_has_direct_flow[section_id]: + self.leaves = set() + del self.finally_section_has_direct_flow[section_id] + + def build(self): + """Returns the CFG accumulated so far and resets the builder. + + Returns: + Graph + """ + # Freeze the nodes. + for node in self.node_index.values(): + node.freeze() + + result = Graph( + entry=self.head, + exit=self.leaves, + error=self.errors, + index=self.node_index) + + # Reset the state. + self.reset() + + return result + + +class AstToCfg(gast.NodeVisitor): + """Converts an AST to CFGs. + + A separate CFG will be constructed for each function. + """ + + # TODO(mdan): Figure out how to deal with closures. + + def __init__(self): + super(AstToCfg, self).__init__() + + self.builder_stack = [] + self.builder = None + self.cfgs = {} + + self.lexical_scopes = [] + + def _enter_lexical_scope(self, node): + self.lexical_scopes.append(node) + + def _exit_lexical_scope(self, node): + leaving_node = self.lexical_scopes.pop() + assert node == leaving_node + + def _get_enclosing_scopes(self, include, stop_at): + included = [] + for node in reversed(self.lexical_scopes): + if isinstance(node, include): + included.append(node) + if isinstance(node, stop_at): + return node, included + return None, included + + def _process_basic_statement(self, node): + self.generic_visit(node) + self.builder.add_ordinary_node(node) + + def _process_exit_statement(self, node, *exits_nodes_of_type): + # Note: this is safe because we process functions separately. + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=tuple(exits_nodes_of_type), + ) + if try_node is None: + raise ValueError( + '%s that is not enclosed by any of %s' % (node, exits_nodes_of_type)) + self.builder.add_exit_node(node, try_node, guards) + + def _process_continue_statement(self, node, *loops_to_nodes_of_type): + # Note: this is safe because we process functions separately. + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=tuple(loops_to_nodes_of_type), + ) + if try_node is None: + raise ValueError('%s that is not enclosed by any of %s' % + (node, loops_to_nodes_of_type)) + self.builder.add_continue_node(node, try_node, guards) + + def visit_FunctionDef(self, node): + self.builder_stack.append(self.builder) + self.builder = GraphBuilder(node) + + self._enter_lexical_scope(node) + self.builder.enter_section(node) + + self._process_basic_statement(node.args) + for stmt in node.body: + self.visit(stmt) + + self.builder.exit_section(node) + self._exit_lexical_scope(node) + + self.cfgs[node] = self.builder.build() + self.builder = self.builder_stack.pop() + + def visit_Lambda(self, node): + # TODO(mdan): Treat like FunctionDef? That would be a separate CFG. + raise NotImplementedError() + + def visit_Return(self, node): + self._process_exit_statement(node, gast.FunctionDef) + + def visit_Expr(self, node): + self._process_basic_statement(node) + + def visit_Assign(self, node): + self._process_basic_statement(node) + + def visit_AnnAssign(self, node): + self._process_basic_statement(node) + + def visit_AugAssign(self, node): + self._process_basic_statement(node) + + def visit_Print(self, node): + self._process_basic_statement(node) + + def visit_Raise(self, node): + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=(gast.FunctionDef,), + ) + if try_node is None: + raise ValueError('%s that is not enclosed by any FunctionDef' % node) + self.builder.add_error_node(node, try_node, guards) + + def visit_Assert(self, node): + # Ignoring the effect of exceptions. + self._process_basic_statement(node) + + def visit_Delete(self, node): + self._process_basic_statement(node) + + def visit_If(self, node): + # No need to track ifs as lexical scopes, for now. + # Lexical scopes are generally tracked in order to be able to resolve the + # targets of jump statements like break/continue/etc. Since there is no + # statement that can interrupt a conditional, we don't need to track their + # lexical scope. That may change in the future. + + self.builder.enter_cond_section(node) + self._process_basic_statement(node.test) + + self.builder.new_cond_branch(node) + for stmt in node.body: + self.visit(stmt) + + self.builder.new_cond_branch(node) + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_cond_section(node) + + def visit_While(self, node): + self._enter_lexical_scope(node) + + self.builder.enter_section(node) + + self.builder.enter_loop_section(node, node.test) + for stmt in node.body: + self.visit(stmt) + self.builder.exit_loop_section(node) + + # Note: although the orelse is technically part of the loop node, + # the statements inside it don't affect the loop itself. For example, a + # break in the loop's orelse will not affect the loop itself. + self._exit_lexical_scope(node) + + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_section(node) + + def visit_For(self, node): + self._enter_lexical_scope(node) + + self.builder.enter_section(node) + + # TODO(mdan): Strictly speaking, this should be node.target + node.iter. + # A blind dataflow analysis would have to process both node.target and + # node.iter to properly process read and write access. + self.builder.enter_loop_section(node, node.iter) + for stmt in node.body: + self.visit(stmt) + self.builder.exit_loop_section(node) + + # Note: although the orelse is technically part of the loop node, + # they don't count as loop bodies. For example, a break in the loop's + # orelse will affect the parent loop, not the current one. + self._exit_lexical_scope(node) + + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_section(node) + + def visit_Break(self, node): + self._process_exit_statement(node, gast.While, gast.For) + + def visit_Continue(self, node): + self._process_continue_statement(node, gast.While, gast.For) + + def visit_Try(self, node): + self._enter_lexical_scope(node) + + for stmt in node.body: + self.visit(stmt) + # Unlike loops, the orelse is a simple continuation of the body. + for stmt in node.orelse: + self.visit(stmt) + + if node.handlers: + # TODO(mdan): Should we still support bare try/except? Might be confusing. + raise NotImplementedError('exceptions are not yet supported') + + self._exit_lexical_scope(node) + + self.builder.enter_finally_section(node) + for stmt in node.finalbody: + self.visit(stmt) + self.builder.exit_finally_section(node) + + def visit_With(self, node): + # TODO(mdan): Mark the context manager's exit call as exit guard. + self._process_basic_statement(node.items) + for stmt in node.body: + self.visit(stmt) + + +def build(node): + builder = AstToCfg() + builder.visit(node) + return builder.cfgs diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/contrib/autograph/pyct/cfg_test.py new file mode 100644 index 0000000000000000000000000000000000000000..00afadd5212a3aba8f25cd9a6f111d292635bbce --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/cfg_test.py @@ -0,0 +1,790 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class CountingVisitor(cfg.GraphVisitor): + + def __init__(self): + self.counts = {} + + def visit_node(self, node): + self.counts[node.ast_node] = self.counts.get(node.ast_node, 0) + 1 + return False # visit only once + + +class GraphVisitorTest(test.TestCase): + + def _build_cfg(self, fn): + node, _ = parser.parse_entity(fn) + cfgs = cfg.build(node) + return cfgs, node + + def test_basic_coverage_forward(self): + + def test_fn(a): + while a > 0: + a = 1 + break + return a # pylint:disable=unreachable + a = 2 + + graphs, node = self._build_cfg(test_fn) + graph, = graphs.values() + visitor = CountingVisitor() + visitor.visit_forward(graph) + fn_node = node.body[0] + + self.assertEqual(visitor.counts[fn_node.args], 1) + self.assertEqual(visitor.counts[fn_node.body[0].test], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + # The return node should be unreachable in forward direction. + self.assertTrue(fn_node.body[0].body[2] not in visitor.counts) + self.assertEqual(visitor.counts[fn_node.body[1]], 1) + + def test_basic_coverage_reverse(self): + + def test_fn(a): + while a > 0: + a = 1 + break + return a # pylint:disable=unreachable + a = 2 + + graphs, node = self._build_cfg(test_fn) + graph, = graphs.values() + visitor = CountingVisitor() + visitor.visit_reverse(graph) + fn_node = node.body[0] + + self.assertEqual(visitor.counts[fn_node.args], 1) + self.assertEqual(visitor.counts[fn_node.body[0].test], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + self.assertTrue(visitor.counts[fn_node.body[0].body[2]], 1) + self.assertEqual(visitor.counts[fn_node.body[1]], 1) + + +class AstToCfgTest(test.TestCase): + + def _build_cfg(self, fn): + node, _ = parser.parse_entity(fn) + cfgs = cfg.build(node) + return cfgs + + def _repr_set(self, node_set): + return set(repr(n) for n in node_set) + + def _as_set(self, elements): + if elements is None: + return frozenset() + elif isinstance(elements, str): + return frozenset((elements,)) + else: + return frozenset(elements) + + def assertGraphMatches(self, graph, edges): + """Tests whether the CFG contains the specified edges.""" + for prev, node_repr, next_ in edges: + matched = False + for cfg_node in graph.index.values(): + if repr(cfg_node) == node_repr: + if (self._as_set(prev) == set(map(repr, cfg_node.prev)) and + self._as_set(next_) == set(map(repr, cfg_node.next))): + matched = True + break + if not matched: + self.fail( + 'match failed for node "%s" in graph:\n%s' % (node_repr, graph)) + + def test_straightline(self): + + def test_fn(a): + a += 1 + a = 2 + a = 3 + return + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', 'a += 1'), + ('a += 1', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', 'return'), + ('a = 3', 'return', None), + ), + ) + + def test_straightline_no_return(self): + + def test_fn(a, b): + a = b + 1 + a += max(a) + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a, b', 'a = b + 1'), + ('a = b + 1', 'a += max(a)', None), + ), + ) + + def test_unreachable_code(self): + + def test_fn(a): + return + a += 1 # pylint:disable=unreachable + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', 'return'), + ('a', 'return', None), + (None, 'a += 1', None), + ), + ) + + def test_branch_straightline(self): + + def test_fn(a): + if a > 0: + a = 1 + else: + a += -1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('(a > 0)', 'a = 1', None), + ('(a > 0)', 'a += -1', None), + ), + ) + + def test_branch_nested(self): + + def test_fn(a): + if a > 0: + if a > 1: + a = 1 + else: + a = 2 + else: + if a > 2: + a = 3 + else: + a = 4 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('a', '(a > 0)', ('(a > 1)', '(a > 2)')), + ('(a > 0)', '(a > 1)', ('a = 1', 'a = 2')), + ('(a > 1)', 'a = 1', None), + ('(a > 1)', 'a = 2', None), + ('(a > 0)', '(a > 2)', ('a = 3', 'a = 4')), + ('(a > 2)', 'a = 3', None), + ('(a > 2)', 'a = 4', None), + ), + ) + + def test_branch_straightline_semi(self): + + def test_fn(a): + if a > 0: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('a', '(a > 0)', 'a = 1'), + ('(a > 0)', 'a = 1', None), + ), + ) + + def test_branch_return(self): + + def test_fn(a): + if a > 0: + return + else: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', ('return', 'a = 1')), + ('(a > 0)', 'a = 1', 'a = 2'), + ('(a > 0)', 'return', None), + ('a = 1', 'a = 2', None), + ), + ) + + def test_branch_return_minimal(self): + + def test_fn(a): + if a > 0: + return + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', 'return'), + ('(a > 0)', 'return', None), + ), + ) + + def test_while_straightline(self): + + def test_fn(a): + while a > 0: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')), + ('(a > 0)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', None), + ), + ) + + def test_while_else_straightline(self): + + def test_fn(a): + while a > 0: + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')), + ('(a > 0)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_while_else_continue(self): + + def test_fn(a): + while a > 0: + if a > 1: + continue + else: + a = 0 + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'continue', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('continue', 'a = 0')), + ('(a > 1)', 'continue', '(a > 0)'), + ('a = 0', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_while_else_break(self): + + def test_fn(a): + while a > 0: + if a > 1: + break + a = 1 + else: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('break', 'a = 1')), + ('(a > 1)', 'break', 'a = 3'), + ('(a > 1)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + (('break', 'a = 2'), 'a = 3', None), + ), + ) + + def test_while_else_return(self): + + def test_fn(a): + while a > 0: + if a > 1: + return + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('return', 'a = 1')), + ('(a > 1)', 'return', None), + ('(a > 1)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_while_nested_straightline(self): + + def test_fn(a): + while a > 0: + while a > 1: + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'a = 1'), '(a > 1)', ('a = 1', 'a = 2')), + ('(a > 1)', 'a = 1', '(a > 1)'), + ('(a > 1)', 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + + def test_while_nested_continue(self): + + def test_fn(a): + while a > 0: + while a > 1: + if a > 3: + continue + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'continue', 'a = 1'), '(a > 1)', ('(a > 3)', 'a = 2')), + ('(a > 1)', '(a > 3)', ('continue', 'a = 1')), + ('(a > 3)', 'continue', '(a > 1)'), + ('(a > 3)', 'a = 1', '(a > 1)'), + ('(a > 1)', 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + + def test_while_nested_break(self): + + def test_fn(a): + while a > 0: + while a > 1: + if a > 2: + break + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')), + ('(a > 1)', '(a > 2)', ('break', 'a = 1')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', 'a = 1', '(a > 1)'), + (('(a > 1)', 'break'), 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + + def test_for_straightline(self): + + def test_fn(a): + for a in range(0, a): + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')), + ('range(0, a)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', None), + ), + ) + + def test_for_else_straightline(self): + + def test_fn(a): + for a in range(0, a): + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')), + ('range(0, a)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_for_else_continue(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + continue + else: + a = 0 + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'continue', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('continue', 'a = 0')), + ('(a > 1)', 'continue', 'range(0, a)'), + ('(a > 1)', 'a = 0', 'a = 1'), + ('a = 0', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_for_else_break(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + break + a = 1 + else: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('break', 'a = 1')), + ('(a > 1)', 'break', 'a = 3'), + ('(a > 1)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + (('break', 'a = 2'), 'a = 3', None), + ), + ) + + def test_for_else_return(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + return + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('return', 'a = 1')), + ('(a > 1)', 'return', None), + ('(a > 1)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_for_nested_straightline(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'b += 1'), 'range(1, a)', ('b += 1', 'a = 2')), + ('range(1, a)', 'b += 1', 'range(1, a)'), + ('range(1, a)', 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + + def test_for_nested_continue(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + if a > 3: + continue + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'continue', 'b += 1'), 'range(1, a)', + ('(a > 3)', 'a = 2')), + ('range(1, a)', '(a > 3)', ('continue', 'b += 1')), + ('(a > 3)', 'continue', 'range(1, a)'), + ('(a > 3)', 'b += 1', 'range(1, a)'), + ('range(1, a)', 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + + def test_for_nested_break(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + if a > 2: + break + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'b += 1'), 'range(1, a)', ('(a > 2)', 'a = 2')), + ('range(1, a)', '(a > 2)', ('break', 'b += 1')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', 'b += 1', 'range(1, a)'), + (('range(1, a)', 'break'), 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + + def test_complex(self): + + def test_fn(a): + b = 0 + while a > 0: + for b in range(0, a): + if a > 2: + break + if a > 3: + if a > 4: + continue + else: + max(a) + break + b += 1 + else: # for b in range(0, a): + return a + a = 2 + for a in range(1, a): + return b + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('b = 0', 'a = 2'), '(a > 0)', ('range(0, a)', 'range(1, a)')), + ( + ('(a > 0)', 'continue', 'b += 1'), + 'range(0, a)', + ('(a > 2)', 'return a'), + ), + ('range(0, a)', '(a > 2)', ('(a > 3)', 'break')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', '(a > 3)', ('(a > 4)', 'b += 1')), + ('(a > 3)', '(a > 4)', ('continue', 'max(a)')), + ('(a > 4)', 'max(a)', 'break'), + ('max(a)', 'break', 'a = 2'), + ('(a > 4)', 'continue', 'range(0, a)'), + ('(a > 3)', 'b += 1', 'range(0, a)'), + ('range(0, a)', 'return a', None), + ('break', 'a = 2', '(a > 0)'), + ('(a > 0)', 'range(1, a)', ('return b', 'a = 3')), + ('range(1, a)', 'return b', None), + ('range(1, a)', 'a = 3', None), + ), + ) + + def test_finally_straightline(self): + + def test_fn(a): + try: + a += 1 + finally: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'a += 1', 'a = 2'), + ('a += 1', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_return_finally(self): + + def test_fn(a): + try: + return a + finally: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'return a', 'a = 1'), + ('return a', 'a = 1', None), + (None, 'a = 2', None), + ), + ) + + def test_break_finally(self): + + def test_fn(a): + while a > 0: + try: + break + finally: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', 'break'), + ('(a > 0)', 'break', 'a = 1'), + ('break', 'a = 1', None), + ), + ) + + def test_continue_finally(self): + + def test_fn(a): + while a > 0: + try: + continue + finally: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', 'continue'), + ('(a > 0)', 'continue', 'a = 1'), + ('continue', 'a = 1', '(a > 0)'), + ), + ) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ca1441cf6f8bb034c95b37fcdd9e8158d1db2e39 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD @@ -0,0 +1,38 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "common_transformers", + srcs = [ + "anf.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "@gast_archive//:gast", + ], +) + +py_test( + name = "anf_test", + srcs = ["anf_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":common_transformers", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py new file mode 100644 index 0000000000000000000000000000000000000000..cc039986c219db1febfe610a5078e26eeb2d5a83 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conversion to A-normal form.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import transformer + + +class DummyGensym(object): + """A dumb gensym that suffixes a stem by sequential numbers from 1000.""" + + def __init__(self, entity_info): + del entity_info + # A proper implementation needs to account for: + # * entity_info.namespace + # * all the symbols defined in the AST + # * the symbols generated so far + self._idx = 0 + + def new_name(self, stem): + self._idx += 1 + return stem + '_' + str(1000 + self._idx) + + +class AnfTransformer(transformer.Base): + """Performs the actual conversion.""" + + # TODO(mdan): Link to a reference. + # TODO(mdan): Implement. + + def __init__(self, entity_info): + """Creates a transformer. + + Args: + entity_info: transformer.EntityInfo + """ + super(AnfTransformer, self).__init__(entity_info) + self._gensym = DummyGensym(entity_info) + + +def transform(node, entity_info): + return AnfTransformer(entity_info).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py new file mode 100644 index 0000000000000000000000000000000000000000..81983a5ecb7b8c6216285409f854e27b7154a08b --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for anf module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.common_transformers import anf +from tensorflow.python.platform import test + + +class AnfTransformerTest(test.TestCase): + + def _simple_source_info(self): + return transformer.EntityInfo( + source_code=None, + source_file=None, + namespace=None, + arg_values=None, + arg_types=None, + owner_type=None) + + def test_basic(self): + + def test_function(): + a = 0 + return a + + node, _ = parser.parse_entity(test_function) + node = anf.transform(node, self._simple_source_info()) + result, _ = compiler.ast_to_object(node) + + self.assertEqual(test_function(), result.test_function()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/context.py b/tensorflow/contrib/autograph/pyct/context.py deleted file mode 100644 index b34015cfd2888f0dbeb6492b9e7335d561bf4763..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/pyct/context.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Conversion context containers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -class EntityContext(object): - """Contains information about an entity, like source code. - - In general, objects of this class should be considered immutable. - - Attributes: - namer: Namer that matches the contract of all converters. - source_code: The entity's source code. - source_file: The entity's source file. - namespace: Dict[str->*], containing symbols visible to the entity - (excluding parameters). - arg_values: Dict[str->*], containing parameter values, if known. - arg_types: Dict[str->*], containing parameter types, if known. - owner_type: The surrounding class type of the function, if present. - """ - - # TODO(mdan): Remove the default and update tests. - def __init__(self, namer, source_code, source_file, namespace, arg_values, - arg_types, owner_type, recursive, type_annotation_func=None): - self.namer = namer - self.source_code = source_code - self.source_file = source_file - self.namespace = namespace - self.arg_values = {} if arg_values is None else arg_values - self.arg_types = {} if arg_types is None else arg_types - self.owner_type = owner_type - self.recursive = recursive - self.type_annotation_func = type_annotation_func diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py index 583cf7ecd7bce31c55de58361ab5295abb5d6707..da07013cf4f4309c0e24adda3017575d942861b7 100644 --- a/tensorflow/contrib/autograph/pyct/qual_names.py +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -205,6 +205,7 @@ class QnResolver(gast.NodeTransformer): return node def visit_Subscript(self, node): + # TODO(mdan): This may no longer apply if we overload getitem. node = self.generic_visit(node) s = node.slice if not isinstance(s, gast.Index): @@ -216,7 +217,11 @@ class QnResolver(gast.NodeTransformer): elif isinstance(s.value, gast.Str): subscript = QN(StringLiteral(s.value.s)) else: - subscript = anno.getanno(node.slice.value, anno.Basic.QN) + # The index may be an expression, case in which a name doesn't make sense. + if anno.hasanno(node.slice.value, anno.Basic.QN): + subscript = anno.getanno(node.slice.value, anno.Basic.QN) + else: + return node if anno.hasanno(node.value, anno.Basic.QN): anno.setanno(node, anno.Basic.QN, QN(anno.getanno(node.value, anno.Basic.QN), diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index 8064a967cd389e88d3febbeb21cac87b0fef9e18..bcf2dacec2062704805f1d72ec27a243159d13c1 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -27,6 +27,7 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", ], ) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index fdbd349af9d3325af114a7206d89617134278f14..bc22be0a270bbc9c361aea6d6d9c255ea51796e8 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -21,9 +21,9 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.qual_names import QN from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -112,18 +112,16 @@ class ActivityAnalyzerTest(test.TestCase): def _parse_and_analyze(self, test_fn): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace={}, arg_values=None, arg_types=None, - owner_type=None, - recursive=True) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - return node, ctx + node = activity.resolve(node, entity_info) + return node, entity_info def test_local_markers(self): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py index ad97fdfa8e78d1fd4c38724612d83519c6609cce..4acc4ed66a62b0ccd407d39b1abda00c4c88a9a1 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -276,9 +276,9 @@ class Forward(object): taken). """ - def __init__(self, label, context, transfer_fn=operator.or_): + def __init__(self, label, source_info, transfer_fn=operator.or_): self.transfer_fn = transfer_fn - self.context = context + self.source_info = source_info self.out_label = label + '_out' self.in_label = label + '_in' self.gen_label = label + '_gen' @@ -286,7 +286,7 @@ class Forward(object): # TODO(alexbw): see if we can simplify by visiting breadth-first def visit(self, node): - """Depth-first walking the CFG, applying dataflow information propagtion.""" + """Depth-first walking the CFG, applying dataflow info propagation.""" # node.value is None only for the exit CfgNode. if not node.value: return @@ -399,18 +399,18 @@ class Liveness(Backward): later in the program. """ - def __init__(self, context): - super(Liveness, self).__init__('live', context) + def __init__(self, source_info): + super(Liveness, self).__init__('live', source_info) def get_gen_kill(self, node, _): # A variable's parents are live if it is live # e.g. x is live if x.y is live. This means gen needs to return # all parents of a variable (if it's an Attribute or Subscript). # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x) - gen = activity.get_read(node.value, self.context) + gen = activity.get_read(node.value, self.source_info) gen = functools.reduce(lambda left, right: left | right.support_set, gen, gen) - kill = activity.get_updated(node.value, self.context) + kill = activity.get_updated(node.value, self.source_info) return gen, kill @@ -420,11 +420,11 @@ class ReachingDefinitions(Forward): Each statement is annotated with a set of (variable, definition) pairs. """ - def __init__(self, context): - super(ReachingDefinitions, self).__init__('definitions', context) + def __init__(self, source_info): + super(ReachingDefinitions, self).__init__('definitions', source_info) def get_gen_kill(self, node, incoming): - definitions = activity.get_updated(node.value, self.context) + definitions = activity.get_updated(node.value, self.source_info) gen = frozenset((id_, node.value) for id_ in definitions) kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) return gen, kill @@ -437,9 +437,10 @@ class Defined(Forward): be defined at that point. """ - def __init__(self, context): - super(Defined, self).__init__('defined', context, transfer_fn=operator.and_) + def __init__(self, source_info): + super(Defined, self).__init__( + 'defined', source_info, transfer_fn=operator.and_) def get_gen_kill(self, node, _): - gen = activity.get_updated(node.value, self.context) + gen = activity.get_updated(node.value, self.source_info) return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py index 8d723ce09d689cce0bf9e907633fe004dc0b92b0..428ebbedca85f9b94b4b1db0f3b36a334126196b 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py @@ -23,29 +23,26 @@ import functools import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import cfg from tensorflow.python.platform import test class CFGTest(test.TestCase): - def _parse_and_analyze(self, test_fn, namespace, arg_types=None): - arg_types = arg_types or {} + def _parse_and_analyze(self, test_fn): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, - namespace=namespace, + namespace={}, arg_values=None, - arg_types=arg_types, - owner_type=None, - recursive=True) + arg_types=None, + owner_type=None) node = qual_names.resolve(node) - return node, ctx + return node, entity_info def _check_anno_matches(self, node, anno_name, var_names): if isinstance(var_names, str): @@ -73,7 +70,7 @@ class CFGTest(test.TestCase): x = x return x - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) body = node.body[0].body # Only the argument reaches the expression @@ -106,7 +103,7 @@ class CFGTest(test.TestCase): y = 2 # pylint: disable=unused-variable return x - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Defined(ctx)) body = node.body[0].body # only x is for sure defined at the end @@ -115,20 +112,27 @@ class CFGTest(test.TestCase): if_body = body[0].body self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) - # TODO(alexbw): b/73926938 split this test up - def test_live(self): + def _get_live_annotated_fnbody(self, f): + node, ctx = self._parse_and_analyze(f) + cfg.run_analyses(node, cfg.Liveness(ctx)) + body = node.body[0].body + return body - def get_live_annotated_fnbody(f): - node, ctx = self._parse_and_analyze(f, {}) - cfg.run_analyses(node, cfg.Liveness(ctx)) - body = node.body[0].body - return body + def test_live_straightline(self): def f1(x): a = g(x) # pylint: disable=undefined-variable b = h(a) # pylint: disable=undefined-variable, unused-variable return x + body = self._get_live_annotated_fnbody(f1) + self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) + self._check_anno_matches(body[2], 'live_out', ()) + + def test_live_stacked_conds_with_else(self): + def f2(x, a): # pylint: disable=unused-argument if a > 0: # x should not be live x = 0 @@ -137,6 +141,12 @@ class CFGTest(test.TestCase): else: x = 2 + body = self._get_live_annotated_fnbody(f2) + self._check_anno_matches(body[0], 'live_in', ('a')) + self._check_anno_matches(body[1], 'live_in', ('a')) + + def test_live_stacked_conds(self): + def f3(x, a): if a > 0: # x and a should be live x = 0 @@ -144,58 +154,58 @@ class CFGTest(test.TestCase): x = 1 return x # x should be live + body = self._get_live_annotated_fnbody(f3) + self._check_anno_matches(body[0], 'live_in', ('a', 'x')) + self._check_anno_matches(body[1], 'live_in', ('a', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + + def test_live_possibly_unused_cond(self): + def f4(x, a): if a > 0: # x should be live x = 0 x += 1 + body = self._get_live_annotated_fnbody(f4) + self._check_anno_matches(body[0], 'live_in', ('x', 'a')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + def test_live_attribute_in_cond(self): + def f5(x, a): if a > 0: # x.y should be live x.y = 0 return x.y + body = self._get_live_annotated_fnbody(f5) + self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) + + def test_live_noop(self): + def f6(x): return x # should this cause x.* to be live? + body = self._get_live_annotated_fnbody(f6) + self._check_anno_matches(body[0], 'live_in', ('x')) + + def test_live_loop(self): + def f7(x, n): for i in range(n): x += i return x - def f8(x, f): - with f: - x += 1 - - body = get_live_annotated_fnbody(f1) - self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) - self._check_anno_matches(body[2], 'live_in', ('x')) - self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) - self._check_anno_matches(body[2], 'live_out', ()) - - body = get_live_annotated_fnbody(f2) - self._check_anno_matches(body[0], 'live_in', ('a')) - self._check_anno_matches(body[1], 'live_in', ('a')) - - body = get_live_annotated_fnbody(f3) - self._check_anno_matches(body[0], 'live_in', ('a', 'x')) - self._check_anno_matches(body[1], 'live_in', ('a', 'x')) - self._check_anno_matches(body[2], 'live_in', ('x')) - - body = get_live_annotated_fnbody(f4) - self._check_anno_matches(body[0], 'live_in', ('x', 'a')) + body = self._get_live_annotated_fnbody(f7) + self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) self._check_anno_matches(body[1], 'live_in', ('x')) - body = get_live_annotated_fnbody(f5) - self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) + def test_live_context_manager(self): - body = get_live_annotated_fnbody(f6) - self._check_anno_matches(body[0], 'live_in', ('x')) - - body = get_live_annotated_fnbody(f7) - self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) - self._check_anno_matches(body[1], 'live_in', ('x')) + def f8(x, f): + with f: + x += 1 - body = get_live_annotated_fnbody(f8) + body = self._get_live_annotated_fnbody(f8) self._check_anno_matches(body[0], 'live_in', ('f', 'x')) def test_node_equality(self): @@ -213,7 +223,7 @@ class CFGTest(test.TestCase): return g(x) - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Defined(ctx)) body = node.body[0].body @@ -240,7 +250,7 @@ class CFGTest(test.TestCase): return g() # y is not defined here - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Defined(ctx)) body = node.body[0].body self.assertEqual( @@ -269,7 +279,7 @@ class CFGTest(test.TestCase): return x, y for f in (for_orelse, while_orelse): - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) body = node.body[0].body return_node = body[-1] diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 53ae15459097baff918432a493edd7360ebf209d..9ccb98f79adbe5410a7554548ee75ab95345962d 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -39,7 +39,7 @@ class LiveValueResolver(transformer.Base): def visit_ClassDef(self, node): self.generic_visit(node) - anno.setanno(node, 'live_val', self.context.namespace[node.name]) + anno.setanno(node, 'live_val', self.entity_info.namespace[node.name]) return node def visit_Name(self, node): @@ -55,8 +55,8 @@ class LiveValueResolver(transformer.Base): if not symbol_is_local and not symbol_is_param: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) - elif node.id in self.context.namespace: - obj = self.context.namespace[node.id] + elif node.id in self.entity_info.namespace: + obj = self.entity_info.namespace[node.id] anno.setanno(node, 'live_val', obj) if hasattr(obj, '__name__'): anno.setanno(node, 'fqn', (obj.__name__,)) @@ -80,8 +80,8 @@ class LiveValueResolver(transformer.Base): # TODO(mdan): Use type annotations as fallback. if not symbol_is_modified: - if node.id in self.context.arg_values: - obj = self.context.arg_values[node.id] + if node.id in self.entity_info.arg_values: + obj = self.entity_info.arg_values[node.id] anno.setanno(node, 'live_val', obj) anno.setanno(node, 'fqn', (obj.__class__.__name__,)) return node diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py index 69e428bde109ed43c3cdda1a94970a832dc47852..38af79277779f77ffe31c2f6e26ae88f3e1a7ae9 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py @@ -21,9 +21,9 @@ from __future__ import print_function import six from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info @@ -39,22 +39,19 @@ class LiveValuesResolverTest(test.TestCase): literals=None, arg_types=None): literals = literals or {} - arg_types = arg_types or {} node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=None, - recursive=True) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, literals) - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, literals) + node = activity.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, literals) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, literals) return node def test_literals(self): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index c00946f9c41bc68d5c638d71f356b484db1286d1..a229c288a83e516fc02f3af8df2046c5365e569c 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -17,8 +17,8 @@ This analyzer uses known live values to further infer object types. This may include for instance constructed objects and object member functions. -In addition, the analyzer will also process annotations for TF (staged) type -annotations. +In addition, the analyzer also handles user annotations made in the code (for +example, the autograph.set_element_type function). Requires annotations generated by LiveValuesResolver. """ @@ -43,7 +43,9 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -51,6 +53,7 @@ from tensorflow.python.util import tf_inspect # TODO(mdan): Remove the duplication between this and activity.py. # In particular, the symbol definitions we track here could as well be tracked # there because they follow the same rules for visibility. +# TODO(mdan): Use a CFG based Defined analysis instead. class Scope(object): """Tracks symbol value references. @@ -134,109 +137,100 @@ class TypeInfoResolver(transformer.Base): node.orelse = self._visit_block(node.orelse) return node - def _process_function_arg(self, arg_name): - str_name = str(arg_name) - if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: + def _process_function_arg(self, arg_node): + qn = anno.getanno(arg_node, anno.Basic.QN) + arg_name = str(qn) + self.scope.setval(qn, arg_node) + if (len(self.enclosing_entities) == 1 and + arg_name in self.entity_info.arg_types): # Forge a node to hold the type information, so that method calls on # it can resolve the type. - type_holder = arg_name.ast() - type_string, type_obj = self.context.arg_types[str_name] - anno.setanno(type_holder, 'type', type_obj) - anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) - self.scope.setval(arg_name, type_holder) + type_string, type_obj = self.entity_info.arg_types[arg_name] + anno.setanno(arg_node, 'type', type_obj) + anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.'))) def visit_arg(self, node): - self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) + self._process_function_arg(node.arg) return node def visit_Name(self, node): self.generic_visit(node) - qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Param): - self._process_function_arg(qn) - elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn): - # E.g. if we had - # a = b - # then for future references to `a` we should have definition = `b` - definition = self.scope.getval(qn) - if anno.hasanno(definition, 'type'): - anno.setanno(node, 'type', anno.getanno(definition, 'type')) - anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn')) - if anno.hasanno(definition, 'element_type'): - anno.setanno(node, 'element_type', - anno.getanno(definition, 'element_type')) + self._process_function_arg(node) + elif isinstance(node.ctx, gast.Load): + qn = anno.getanno(node, anno.Basic.QN) + if self.scope.hasval(qn): + # E.g. if we had + # a = b + # then for future references to `a` we should have definition = `b` + definition = self.scope.getval(qn) + anno.copyanno(definition, node, 'type') + anno.copyanno(definition, node, 'type_fqn') + anno.setanno(node, 'definition', definition) + + # TODO(mdan): Remove this when the directives module is in. + anno.copyanno(definition, node, 'element_type') + anno.copyanno(definition, node, 'element_shape') return node - def _process_variable_assignment(self, source, targets): - # Special case: constructors. - if isinstance(source, gast.Call): - func = source.func + def _process_variable_assignment(self, target, value): + # Constructors + if isinstance(value, gast.Call): + func = value.func if anno.hasanno(func, 'live_val'): func_obj = anno.getanno(func, 'live_val') if tf_inspect.isclass(func_obj): - anno.setanno(source, 'is_constructor', True) - anno.setanno(source, 'type', func_obj) - anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn')) + anno.setanno(value, 'is_constructor', True) + anno.setanno(value, 'type', func_obj) + anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn')) # TODO(mdan): Raise an error if constructor has side effects. # We can have a whitelist of no-side-effects constructors. # We can also step inside the constructor and further analyze. - # Multiple targets mean multiple assignment. - for target in targets: - # Tuple target means unpacking. - if isinstance(target, (gast.Tuple, gast.List)): - for i, target_item in enumerate(target.elts): - # Two cases here: - # 1. Static unpacking, e.g. a, b = c, d - # 2. Dynamic unpacking, e.g. a, b = c - # The former case is optimized away. - if isinstance(source, (gast.Tuple, gast.List)): - source_item = source.elts[i] - else: - source_item = gast.Subscript(source, gast.Index(i), ctx=None) - self._process_variable_assignment(source_item, (target_item,)) - elif isinstance(target, (gast.Name, gast.Attribute)): - target_symbol = anno.getanno(target, anno.Basic.QN) - self.scope.setval(target_symbol, source) - else: - raise ValueError('assignment target has unknown type: %s' % target) + if isinstance(target, (gast.Name, gast.Attribute)): + target_symbol = anno.getanno(target, anno.Basic.QN) + self.scope.setval(target_symbol, value) + elif isinstance(target, gast.Subscript): + pass + else: + raise ValueError('assignment target has unknown type: %s' % target) def visit_With(self, node): - for wi in node.items: - if wi.optional_vars is not None: - self._process_variable_assignment(wi.context_expr, (wi.optional_vars,)) + for item in node.items: + if item.optional_vars is not None: + self.apply_to_single_assignments((item.optional_vars,), + item.context_expr, + self._process_variable_assignment) self.generic_visit(node) return node def visit_Assign(self, node): self.generic_visit(node) - self._process_variable_assignment(node.value, node.targets) + self.apply_to_single_assignments( + node.targets, node.value, self._process_variable_assignment) return node + # TODO(mdan): Remove as soon as the new directives module is ready. def visit_Call(self, node): if anno.hasanno(node.func, 'live_val'): # Symbols targeted by the "set_type" marker function are assigned the data # type that it specified. - if (anno.getanno(node.func, 'live_val') is - self.context.type_annotation_func): + if anno.getanno(node.func, 'live_val') is utils.set_element_type: - if len(node.args) != 2: - raise ValueError('"%s" must have exactly two parameters' + if len(node.args) < 2 or len(node.args) > 3: + raise ValueError('"%s" must have either two or three parameters' % self.context.type_annotation_func) - target_arg, type_arg = node.args - if not anno.hasanno(target_arg, anno.Basic.QN): - raise ValueError('the first argument of "%s" must by a symbol' - % self.context.type_annotation_func) - if isinstance(type_arg, gast.Str): - element_type = type_arg.s - elif isinstance(type_arg, gast.Num): - element_type = type_arg.n + if len(node.args) == 2: + target_arg, type_arg = node.args + shape_arg = parser.parse_expression('None') else: - if not anno.hasanno(type_arg, 'live_val'): - raise ValueError( - 'the second argument of "%s" must be statically resolvable' % - self.context.type_annotation_func) - element_type = anno.getanno(type_arg, 'live_val') + target_arg, type_arg, shape_arg = node.args + if not anno.hasanno(target_arg, anno.Basic.QN): + raise ValueError('the first argument of "%s" must by a symbol' % + utils.set_element_type) + # TODO(mdan): This is vulnerable to symbol renaming. + element_type = type_arg + element_shape = shape_arg target_symbol = anno.getanno(target_arg, anno.Basic.QN) # Find the definition of this symbol and annotate it with the given @@ -244,7 +238,9 @@ class TypeInfoResolver(transformer.Base): # to receive the same type annotation. definition = self.scope.getval(target_symbol) anno.setanno(node, 'element_type', element_type) + anno.setanno(node, 'element_shape', element_shape) anno.setanno(definition, 'element_type', element_type) + anno.setanno(definition, 'element_shape', element_shape) # TODO(mdan): Should we update references between definition and here? return self.generic_visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 46b7701624a43073fb7cc612d2678ab851513d91..32b1148ab21809514bc09a31e26f0219017bd088 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -18,11 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info @@ -62,21 +61,18 @@ class TypeInfoResolverTest(test.TestCase): namespace, arg_types=None): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=None, - recursive=True, - type_annotation_func=utils.set_element_type) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) + node = activity.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) return node def test_constructor_detection(self): @@ -147,7 +143,7 @@ class TypeInfoResolverTest(test.TestCase): opt.minimize(0) node = self._parse_and_analyze( - test_fn, {'training': training}, + test_fn, {}, arg_types={ 'opt': (training.GradientDescentOptimizer.__name__, training.GradientDescentOptimizer) @@ -180,22 +176,6 @@ class TypeInfoResolverTest(test.TestCase): method_call = node.body[0].body[1].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) - def test_type_annotation(self): - - class Foo(object): - pass - - def test_fn(): - f = [] - f = utils.set_element_type(f, Foo) - return f - - node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) - f_def = node.body[0].body[0].value - self.assertEqual(anno.getanno(f_def, 'element_type'), Foo) - f_ref = node.body[0].body[1].value - self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) - def test_nested_unpacking(self): class Foo(object): @@ -210,32 +190,13 @@ class TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar}) a, b, c = node.body[0].body[1].value.elts - self.assertEquals(Foo, anno.getanno(a, 'type')) - self.assertEquals(Bar, anno.getanno(b, 'type')) - self.assertEquals(Foo, anno.getanno(c, 'type')) + self.assertEquals(anno.getanno(a, 'type'), Foo) + self.assertEquals(anno.getanno(b, 'type'), Bar) + self.assertEquals(anno.getanno(c, 'type'), Foo) self.assertFalse(anno.hasanno(a, 'live_val')) self.assertFalse(anno.hasanno(b, 'live_val')) self.assertFalse(anno.hasanno(c, 'live_val')) - def test_inner_scope(self): - - def test_fn(): - a = [] - utils.set_element_type(a, 1) - for _ in a: - b = [] - utils.set_element_type(b, 2) - return a, b - - node = self._parse_and_analyze(test_fn, {'utils': utils}) - a, b = node.body[0].body[2].body[2].value.elts - self.assertEquals(1, anno.getanno(a, 'element_type')) - self.assertEquals(2, anno.getanno(b, 'element_type')) - self.assertFalse(anno.hasanno(a, 'type')) - self.assertFalse(anno.hasanno(b, 'type')) - self.assertFalse(anno.hasanno(a, 'live_val')) - self.assertFalse(anno.hasanno(b, 'live_val')) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py index baf7923fff7c786c1abd05e11fa6ffdb8c8f0912..9c479ebc2fa83d27dc363ae306daedb556734a1f 100644 --- a/tensorflow/contrib/autograph/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -239,8 +239,13 @@ def replace_as_expression(template, **replacements): raise ValueError( 'single expression expected; for more general templates use replace') node = replacement[0] - if not isinstance(node, gast.Expr): - raise ValueError( - 'the template is expected to generate an expression node; instead ' - 'found %s' % node) - return node.value + node = qual_names.resolve(node) + + if isinstance(node, gast.Expr): + return node.value + elif isinstance(node, gast.Name): + return node + + raise ValueError( + 'the template is expected to generate an expression or a name node;' + ' instead found %s' % node) diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 4db6cc0adfad90ffc1a6bbcadfc80215688d271e..76558118308c31a2c1a770cad814e96abd6a6063 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -32,15 +32,40 @@ class AutographParseError(SyntaxError): pass -def try_ast_to_source(node): - try: - return compiler.ast_to_source(node) - except AssertionError: - return '' +# TODO(mdan): Use namedtuple. +class EntityInfo(object): + """Contains information about a Python entity. Immutable. + + Examples of entities include functions and classes. + + Attributes: + source_code: The entity's source code. + source_file: The entity's source file. + namespace: Dict[str, ], containing symbols visible to the entity + (excluding parameters). + arg_values: dict[str->*], containing parameter values, if known. + arg_types: dict[str->*], containing parameter types, if known. + owner_type: The surrounding class type of the function, if present. + """ + + # TODO(mdan): Remove the default and update tests. + def __init__(self, source_code, source_file, namespace, arg_values, arg_types, + owner_type): + self.source_code = source_code + self.source_file = source_file + self.namespace = namespace + self.arg_values = {} if arg_values is None else arg_values + self.arg_types = {} if arg_types is None else arg_types + self.owner_type = owner_type class Base(gast.NodeTransformer): - """Base class for specialized transformers. + """Base class for general-purpose code transformers transformers. + + This is an extension of ast.NodeTransformer that provides a few additional + functions, like state tracking within the scope of arbitrary node, helpers + for processing code blocks, debugging, mapping of transformed code to + original code, and others. Scope-local state tracking: to keep state across nodes, at the level of (possibly nested) scopes, use enter/exit_local_scope and set/get_local. @@ -48,15 +73,17 @@ class Base(gast.NodeTransformer): when they are not properly paired. """ - def __init__(self, context): + # TODO(mdan): Document all extra features. + + def __init__(self, entity_info): """Initialize the transformer. Subclasses should call this. Args: - context: An EntityContext. + entity_info: An EntityInfo object. """ self._lineno = 0 self._col_offset = 0 - self.context = context + self.entity_info = entity_info self._enclosing_entities = [] # A stack that allows keeping mutable, scope-local state where scopes may be @@ -70,14 +97,40 @@ class Base(gast.NodeTransformer): return tuple(self._enclosing_entities) @property - def locel_scope_level(self): + def local_scope_level(self): return len(self._local_scope_state) - def enter_local_scope(self): - self._local_scope_state.append({}) + def enter_local_scope(self, inherit=None): + """Marks entry into a new local scope. + + Args: + inherit: Optional enumerable of variable names to copy from the + parent scope. + """ + scope_entered = {} + if inherit: + this_scope = self._local_scope_state[-1] + for name in inherit: + if name in this_scope: + scope_entered[name] = this_scope[name] + self._local_scope_state.append(scope_entered) + + def exit_local_scope(self, keep=None): + """Marks exit from the current local scope. - def exit_local_scope(self): - return self._local_scope_state.pop() + Args: + keep: Optional enumerable of variable names to copy into the + parent scope. + Returns: + A dict containing the scope that has just been exited. + """ + scope_left = self._local_scope_state.pop() + if keep: + this_scope = self._local_scope_state[-1] + for name in keep: + if name in scope_left: + this_scope[name] = scope_left[name] + return scope_left def set_local(self, name, value): self._local_scope_state[-1][name] = value @@ -91,57 +144,181 @@ class Base(gast.NodeTransformer): print(pretty_printer.fmt(node)) return node - def visit_block(self, nodes): - """Helper equivalent to generic_visit, but for node lists.""" + def visit_block(self, nodes, before_visit=None, after_visit=None): + """A more powerful version of generic_visit for statement blocks. + + An example of a block is the body of an if statement. + + This function allows specifying a postprocessing callback (the + after_visit argument) argument which can be used to move nodes to a new + destination. This is done by after_visit by returning a non-null + second return value, e.g. return new_node, new_destination. + + For example, a transformer could perform the following move: + + foo() + bar() + baz() + + foo() + if cond: + bar() + baz() + + The above could be done with a postprocessor of this kind: + + def after_visit(node): + if node_is_function_call(bar): + new_container_node = build_cond() + new_container_node.body.append(node) + return new_container_node, new_container_node.body + else: + # Once we set a new destination, all subsequent items will be + # moved to it, so we don't need to explicitly handle baz. + return node, None + + Args: + nodes: enumerable of AST node objects + before_visit: optional callable that is called before visiting each item + in nodes + after_visit: optional callable that takes in an AST node and + returns a tuple (new_node, new_destination). It is called after + visiting each item in nodes. Is used in the same was as the + visit_* methods: new_node will replace the node; if not None, + new_destination must be a list, and subsequent nodes will be placed + in this list instead of the list returned by visit_block. + Returns: + A list of AST node objects containing the transformed items fron nodes, + except those nodes that have been relocated using after_visit. + """ results = [] + node_destination = results for node in nodes: + if before_visit: + # TODO(mdan): We can modify node here too, if ever needed. + before_visit() + replacement = self.visit(node) + + if after_visit and replacement: + replacement, new_destination = after_visit(replacement) + else: + new_destination = None + if replacement: if isinstance(replacement, (list, tuple)): - results.extend(replacement) + node_destination.extend(replacement) else: - results.append(replacement) + node_destination.append(replacement) + + # Allow the postprocessor to reroute the remaining nodes to a new list. + if new_destination is not None: + node_destination = new_destination return results + # TODO(mdan): Once we have error tracing, we may be able to just go to SSA. + def apply_to_single_assignments(self, targets, values, apply_fn): + """Applies a function to each individual assignment. + + This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. + It tries to break down the unpacking if possible. In effect, it has the same + effect as passing the assigned values in SSA form to apply_fn. + + Examples: + + The following will result in apply_fn(a, c), apply_fn(b, d): + + a, b = c, d + + The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]): + + a, b = c + + The following will result in apply_fn(a, (b, c)): + + a = b, c + + It uses the visitor pattern to allow subclasses to process single + assignments individually. + + Args: + targets: list, tuple of or individual AST node. Should be used with the + targets field of an ast.Assign node. + values: an AST node. + apply_fn: a function of a single argument, which will be called with the + respective nodes of each single assignment. The signature is + apply_fn(target, value), no return value. + """ + if not isinstance(targets, (list, tuple)): + targets = (targets,) + for target in targets: + if isinstance(target, (gast.Tuple, gast.List)): + for i in range(len(target.elts)): + target_el = target.elts[i] + if isinstance(values, (gast.Tuple, gast.List)): + value_el = values.elts[i] + else: + value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store()) + self.apply_to_single_assignments(target_el, value_el, apply_fn) + else: + # TODO(mdan): Look into allowing to rewrite the AST here. + apply_fn(target, values) + + def _get_source(self, node): + try: + return compiler.ast_to_source(node) + except AssertionError: + return '' + def visit(self, node): - source_code = self.context.source_code - source_file = self.context.source_file + source_code = self.entity_info.source_code + source_file = self.entity_info.source_file did_enter_function = False - local_scope_state_size = len(self._local_scope_state) + local_scope_size_at_entry = len(self._local_scope_state) try: if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): - self._enclosing_entities.append(node) did_enter_function = True + if did_enter_function: + self._enclosing_entities.append(node) + if source_code and hasattr(node, 'lineno'): self._lineno = node.lineno self._col_offset = node.col_offset - if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): - return node - return super(Base, self).visit(node) - except (ValueError, AttributeError, KeyError, NotImplementedError, - AssertionError) as e: + if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + result = super(Base, self).visit(node) + + # On exception, the local scope integrity is not guaranteed. + if did_enter_function: + self._enclosing_entities.pop() + + if local_scope_size_at_entry != len(self._local_scope_state): + raise AssertionError( + 'Inconsistent local scope stack. Before entering node %s, the' + ' stack had length %d, after exit it has length %d. This' + ' indicates enter_local_scope and exit_local_scope are not' + ' well paired.' % ( + node, + local_scope_size_at_entry, + len(self._local_scope_state) + )) + return result + + except (ValueError, AttributeError, KeyError, NotImplementedError) as e: msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( - e.__class__.__name__, str(e), try_ast_to_source(node), + e.__class__.__name__, str(e), self._get_source(node), pretty_printer.fmt(node, color=False)) if source_code: line = source_code.splitlines()[self._lineno - 1] else: line = '' + # TODO(mdan): Avoid the printing of the original exception. + # In other words, we need to find how to suppress the "During handling + # of the above exception, another exception occurred" message. six.reraise(AutographParseError, AutographParseError( msg, (source_file, self._lineno, self._col_offset + 1, line)), sys.exc_info()[2]) - finally: - if did_enter_function: - self._enclosing_entities.pop() - - if local_scope_state_size != len(self._local_scope_state): - raise AssertionError( - 'Inconsistent local scope stack. Before entering node %s, the' - ' stack had length %d, after exit it has length %d. This' - ' indicates enter_local_scope and exit_local_scope are not' - ' well paired.') diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index f96b0dc377521a482d347436caa98633a0a32c8a..baf04653ae862b0159fb50a1c67fa675ceb74b9a 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -18,8 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast + from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.platform import test @@ -27,16 +28,14 @@ from tensorflow.python.platform import test class TransformerTest(test.TestCase): - def _context_for_nodetesting(self): - return context.EntityContext( - namer=None, + def _simple_source_info(self): + return transformer.EntityInfo( source_code=None, source_file=None, namespace=None, arg_values=None, arg_types=None, - owner_type=None, - recursive=False) + owner_type=None) def test_entity_scope_tracking(self): @@ -53,7 +52,7 @@ class TransformerTest(test.TestCase): anno.setanno(node, 'enclosing_entities', self.enclosing_entities) return self.generic_visit(node) - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._simple_source_info()) def test_function(): a = 0 @@ -94,7 +93,7 @@ class TransformerTest(test.TestCase): inner_function, lambda_node), anno.getanno(lambda_expr, 'enclosing_entities')) - def test_statement_info_stack(self): + def test_local_scope_info_stack(self): class TestTransformer(transformer.Base): @@ -116,7 +115,7 @@ class TransformerTest(test.TestCase): def visit_For(self, node): return self._annotate_result(node) - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._simple_source_info()) def test_function(a): """Docstring.""" @@ -142,7 +141,7 @@ class TransformerTest(test.TestCase): self.assertFalse(anno.hasanno(while_node, 'string')) self.assertEqual('1', anno.getanno(while_node, 'test')) - def test_statement_info_stack_checks_integrity(self): + def test_local_scope_info_stack_checks_integrity(self): class TestTransformer(transformer.Base): @@ -155,7 +154,7 @@ class TransformerTest(test.TestCase): self.exit_local_scope() return node - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._simple_source_info()) def no_exit(a): if a > 0: @@ -174,6 +173,38 @@ class TransformerTest(test.TestCase): with self.assertRaises(AssertionError): tr.visit(node) + def test_visit_block_postprocessing(self): + + class TestTransformer(transformer.Base): + + def _process_body_item(self, node): + if isinstance(node, gast.Assign) and (node.value.id == 'y'): + if_node = gast.If(gast.Name('x', gast.Load(), None), [node], []) + return if_node, if_node.body + return node, None + + def visit_FunctionDef(self, node): + node.body = self.visit_block( + node.body, after_visit=self._process_body_item) + return node + + def test_function(x, y): + z = x + z = y + return z + + tr = TestTransformer(self._simple_source_info()) + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + node = node.body[0] + + self.assertEqual(len(node.body), 2) + self.assertTrue(isinstance(node.body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1], gast.If)) + self.assertTrue(isinstance(node.body[1].body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1].body[1], gast.Return)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index d3a1b9468892531cbc51bc13de66ef595f1a95f8..d82c17bf2afd01aedf4344f983b02c09abcb9bad 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -33,6 +33,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:dtypes", "//tensorflow/python:list_ops", "//tensorflow/python:script_ops", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 211e8eaee9082dd3e4f035e4379871cd2e154a39..998087e056c2cd264399982220d6e0528aab9edb 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib.autograph.utils import py_func from tensorflow.contrib.autograph.utils import type_check +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops @@ -38,7 +39,13 @@ def dynamic_builtin(f, *args, **kwargs): return dynamic_range(*args, **kwargs) if f is range: return dynamic_range(*args, **kwargs) - raise ValueError('%s is not supported' % f) + if f is int: + return dynamic_int(*args, **kwargs) + if f is float: + return dynamic_float(*args, **kwargs) + + raise NotImplementedError( + 'The "%s" builtin is not yet supported.' % f.__name__) def dynamic_len(list_or_tensor): @@ -52,6 +59,20 @@ def dynamic_len(list_or_tensor): return len(list_or_tensor) +def dynamic_int(num_or_tensor, **kwargs): + """Implementation of int() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs) + return int(num_or_tensor) + + +def dynamic_float(num_or_tensor, **kwargs): + """Implementation of float() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs) + return float(num_or_tensor) + + def dynamic_range(start_or_stop, stop=None, step=None): """Implementation of range using dynamic dispatch.""" if type_check.is_tensor(start_or_stop, stop, step): diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py index 163e6984079fea5c3b3d9aeda0ec8048d651686f..0c2312178a921037fa419818bf309d671c33914d 100644 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib.autograph.utils import builtins from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.platform import test @@ -77,7 +78,7 @@ class BuiltinsTest(test.TestCase): return x # Functions that just have the names of builtins are rejected. - with self.assertRaises(ValueError): + with self.assertRaises(NotImplementedError): self.assertEqual(builtins.dynamic_builtin(range, 1), 1) if six.PY2: self.assertListEqual( @@ -87,6 +88,20 @@ class BuiltinsTest(test.TestCase): self.assertListEqual( list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) + def test_casts(self): + i = constant_op.constant(2, dtype=dtypes.int32) + f = constant_op.constant(1.0, dtype=dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32) + self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, True), 1) + self.assertEqual(builtins.dynamic_builtin(int, False), 0) + self.assertEqual(builtins.dynamic_builtin(float, True), 1.0) + self.assertEqual(builtins.dynamic_builtin(float, False), 0.0) + def test_dynamic_print_tf(self): try: out_capturer = six.StringIO() diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index d65c990c87cbc316472237d183c03765416501e7..b27a19b16c08cb588b45949105a6399623e766e1 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -49,6 +49,14 @@ cc_library( ], ) +cc_library( + name = "serial_device_batch_scheduler", + hdrs = ["serial_device_batch_scheduler.h"], + deps = [ + "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], @@ -96,6 +104,7 @@ py_test( name = "batch_ops_test", size = "small", srcs = ["python/ops/batch_ops_test.py"], + shard_count = 5, srcs_version = "PY2AND3", tags = [ "manual", diff --git a/tensorflow/contrib/batching/__init__.py b/tensorflow/contrib/batching/__init__.py index 44fa5f42a73bfb1bf008f6f4eafd14913c88dcfa..1e503a097a7b72d9244b0a1cf57747c4b4122c81 100644 --- a/tensorflow/contrib/batching/__init__.py +++ b/tensorflow/contrib/batching/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Ops and modules related to batch. +@@batch_function_v1 @@batch_function """ from __future__ import absolute_import diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 921d6917a4e478c3e60771fdc3ae99febc33d2e3..47b80bdf4ad88ebce3603a14ea2aa3cbe5bd345f 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import @@ -83,6 +84,70 @@ def batch_function(num_batch_threads, SparseTensor is not supported. The return value of the decorated function must be a Tensor or a list/tuple of Tensors. + Args: + num_batch_threads: Number of scheduling threads for processing batches + of work. Determines the number of batches processed in parallel. + max_batch_size: Batch sizes will never be bigger than this. + batch_timeout_micros: Maximum number of microseconds to wait before + outputting an incomplete batch. + allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, + does nothing. Otherwise, supplies a list of batch sizes, causing the op + to pad batches up to one of those sizes. The entries must increase + monotonically, and the final entry must equal max_batch_size. + grad_timeout_micros: The timeout to use for the gradient. See the + documentation of the unbatch op for more details. Defaults to 60s. + unbatch_timeout_micros: The timeout to use for unbatching. See the + documentation of the unbatch op for more details. Defaults to 60s. + max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. + + Returns: + The decorated function will return the unbatched computation output Tensors. + """ + + def decorator(fn): # pylint: disable=missing-docstring + + def decorated(*args): # pylint: disable=missing-docstring + types = [arg.dtype for arg in args] + + @function.Defun(*types) + def computation(*computation_args): + return fn(*computation_args) + + with ops.name_scope("batch") as name: + for a in args: + if not isinstance(a, ops.Tensor): + raise ValueError("All arguments to functions decorated with " + "`batch_function` are supposed to be Tensors; " + "found %s" % repr(a)) + return gen_batch_ops.batch_function( + num_batch_threads=num_batch_threads, + max_batch_size=max_batch_size, + batch_timeout_micros=batch_timeout_micros, + allowed_batch_sizes=allowed_batch_sizes, + max_enqueued_batches=max_enqueued_batches, + shared_name=name, + f=computation, + in_tensors=list(args), + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + return decorated + + return decorator + + +def batch_function_v1(num_batch_threads, + max_batch_size, + batch_timeout_micros, + allowed_batch_sizes=None, + grad_timeout_micros=60 * 1000 * 1000, + unbatch_timeout_micros=60 * 1000 * 1000, + max_enqueued_batches=10): + """Batches the computation done by the decorated function. + + This is the older version of batch_function(). Please use the former instead + of this. + Args: num_batch_threads: Number of scheduling threads for processing batches of work. Determines the number of batches processed in parallel. diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index e22f978dde6f1b7febc771d526201579c20292c7..78468145469df216344bc00f116add250dc51dd3 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -23,7 +23,10 @@ import time from tensorflow.contrib.batching.python.ops import batch_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -185,12 +188,62 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBasicUnbatchV1Decorated(self): + """Tests that the batch_function_v1 decorator works.""" + with self.test_session() as sess: + @batch_ops.batch_function_v1(1, 10, 100000) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testBasicUnbatchDecorated(self): """Tests that the batch_function decorator works.""" with self.test_session() as sess: + # TODO(apassos): Removing this line causes test flakiness! Ideally should + # be investigated. + default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable + @batch_ops.batch_function(1, 10, 100000) def computation(in_t): return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchDecoratedWithCapturedInput(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return in_t + captured_inp0 - captured_inp1 + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) result = computation(inp) thread_results = [] @@ -205,6 +258,114 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBatchFunctionOp(self): + """Tests that the batch_function op works.""" + with self.test_session() as sess: + + @function.Defun(dtypes.int32) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = gen_batch_ops.batch_function( + [inp], + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, + Tout=[dtypes.int32], + f=computation, + captured_tensors=computation.captured_inputs) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithCapturedInput(self): + """Tests that batch_function op works with captured input.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32) + def computation(inp): + return inp + captured_inp0 - captured_inp1 + + result = gen_batch_ops.batch_function( + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + allowed_batch_sizes=[3, 10], + batching_queue="", + f=computation, + in_tensors=[inp], + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithInputError(self): + """Tests that batch_function op works with error in the inputs.""" + with self.test_session() as sess: + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32, dtypes.int32) + def computation(in0, in1): + return in0 + in1 + + result = gen_batch_ops.batch_function( + [inp], # computation actually expects 2 inputs. + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + batching_queue="", + f=computation, + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + with self.assertRaisesRegexp(InvalidArgumentError, + ".*2 arguments.*but 1.*"): + sess.run([result], feed_dict={inp: [2]}) + + def testBasicUnbatchDecoratedWithReshape(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return array_ops.reshape(in_t, [-1]) + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [[1]]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [[2]]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testUnbatchTimeout(self): """Tests that the unbatch timeout works.""" with self.test_session() as sess: diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.cc b/tensorflow/contrib/batching/serial_device_batch_scheduler.h similarity index 59% rename from tensorflow/compiler/xla/service/versioned_computation_handle.cc rename to tensorflow/contrib/batching/serial_device_batch_scheduler.h index a693c4695f0e776cf297d0ecd28d6de53bd5c0c6..bf6b7083612018eecf0d1784e60cbbf0c5796fef 100644 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.cc +++ b/tensorflow/contrib/batching/serial_device_batch_scheduler.h @@ -13,20 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h" -namespace xla { - -string VersionedComputationHandle::ToString() const { - return tensorflow::strings::StrCat(handle.handle(), ":v", version); -} - -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle) { - out << versioned_handle.ToString(); - return out; -} - -} // namespace xla +#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py index d9e23646d8334014f1bef0d0744df9310b59909f..9e6a146f67796466202cc5074ddd25e4c2b083a6 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py @@ -29,7 +29,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib -from tensorflow.python.ops.distributions import gamma as gamma_lib from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test @@ -256,50 +255,6 @@ class ExpectationTest(test.TestCase): gradq_approx_kl_normal_normal_, rtol=0.01, atol=0.) - def test_docstring_example_gamma(self): - with self.test_session() as sess: - num_draws = int(1e5) - concentration_p = constant_op.constant(1.) - concentration_q = constant_op.constant(2.) - p = gamma_lib.Gamma(concentration=concentration_p, rate=1.) - q = gamma_lib.Gamma(concentration=concentration_q, rate=3.) - approx_kl_gamma_gamma = monte_carlo_lib.expectation( - f=lambda x: p.log_prob(x) - q.log_prob(x), - samples=p.sample(num_draws, seed=42), - log_prob=p.log_prob, - use_reparametrization=(p.reparameterization_type - == distribution_lib.FULLY_REPARAMETERIZED)) - exact_kl_gamma_gamma = kullback_leibler.kl_divergence(p, q) - [exact_kl_gamma_gamma_, approx_kl_gamma_gamma_] = sess.run([ - exact_kl_gamma_gamma, approx_kl_gamma_gamma]) - self.assertEqual( - False, - p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED) - self.assertAllClose(exact_kl_gamma_gamma_, approx_kl_gamma_gamma_, - rtol=0.01, atol=0.) - - # Compare gradients. (Not present in `docstring`.) - gradp = lambda fp: gradients_impl.gradients(fp, concentration_p)[0] - gradq = lambda fq: gradients_impl.gradients(fq, concentration_q)[0] - [ - gradp_exact_kl_gamma_gamma_, - gradq_exact_kl_gamma_gamma_, - gradp_approx_kl_gamma_gamma_, - gradq_approx_kl_gamma_gamma_, - ] = sess.run([ - gradp(exact_kl_gamma_gamma), - gradq(exact_kl_gamma_gamma), - gradp(approx_kl_gamma_gamma), - gradq(approx_kl_gamma_gamma), - ]) - # Notice that variance (i.e., `rtol`) is higher when using score-trick. - self.assertAllClose(gradp_exact_kl_gamma_gamma_, - gradp_approx_kl_gamma_gamma_, - rtol=0.05, atol=0.) - self.assertAllClose(gradq_exact_kl_gamma_gamma_, - gradq_approx_kl_gamma_gamma_, - rtol=0.03, atol=0.) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py index 5770bcdd706723394bb06196d24aeb32b8b8491a..68fa415eeaf1d1ae7c2ecf1be1c300eddbfa4e69 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monte Carlo integration and helpers. - -See the @{$python/contrib.bayesflow.monte_carlo} guide. -""" +"""Monte Carlo integration and helpers.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 032b859d469ee5039e08e4af4c2f4ebf35c2ff19..68ead2f7609ca987180fe8973cf902f1e56b8388 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -192,7 +192,7 @@ def _logspace_mean(log_values): def expectation(f, samples, log_prob=None, use_reparametrization=True, axis=0, keep_dims=False, name=None): - """Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). + r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). This function computes the Monte-Carlo approximation of an expectation, i.e., diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..5c15d21e35557ba5ff25d9d943aae2809eddba4a --- /dev/null +++ b/tensorflow/contrib/bigtable/BUILD @@ -0,0 +1,196 @@ +# Cloud Bigtable client for TensorFlow + +package( + default_visibility = ["//tensorflow:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_cc_test", + "tf_py_test", +) + +tf_custom_op_py_library( + name = "bigtable", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + dso = [ + ":python/ops/_bigtable.so", + ], + kernels = [ + ":bigtable_kernels", + ":bigtable_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":bigtable_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python/data", + ], +) + +tf_custom_op_library( + name = "python/ops/_bigtable.so", + srcs = [ + "kernels/bigtable_kernels.cc", + "kernels/bigtable_lookup_dataset_op.cc", + "kernels/bigtable_prefix_key_dataset_op.cc", + "kernels/bigtable_range_key_dataset_op.cc", + "kernels/bigtable_scan_dataset_op.cc", + "ops/bigtable_ops.cc", + ], + deps = [ + ":bigtable_lib_cc", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +tf_gen_op_wrapper_py( + name = "bigtable_ops", + deps = [":bigtable_ops_op_lib"], +) + +tf_gen_op_libs( + op_lib_names = [ + "bigtable_ops", + "bigtable_test_ops", + ], +) + +tf_kernel_library( + name = "bigtable_kernels", + srcs = [ + "kernels/bigtable_kernels.cc", + "kernels/bigtable_lookup_dataset_op.cc", + "kernels/bigtable_prefix_key_dataset_op.cc", + "kernels/bigtable_range_key_dataset_op.cc", + "kernels/bigtable_scan_dataset_op.cc", + ], + deps = [ + ":bigtable_lib_cc", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +# A library for use in the bigtable kernels. +cc_library( + name = "bigtable_lib_cc", + srcs = ["kernels/bigtable_lib.cc"], + hdrs = ["kernels/bigtable_lib.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +cc_library( + name = "bigtable_test_client", + srcs = ["kernels/test_kernels/bigtable_test_client.cc"], + hdrs = ["kernels/test_kernels/bigtable_test_client.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "@com_github_googleapis_googleapis//:bigtable_protos", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + "@com_googlesource_code_re2//:re2", + ], +) + +tf_cc_test( + name = "bigtable_test_client_test", + srcs = ["kernels/test_kernels/bigtable_test_client_test.cc"], + tags = ["manual"], + deps = [ + ":bigtable_test_client", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +tf_gen_op_wrapper_py( + name = "bigtable_test_ops", + deps = [":bigtable_test_ops_op_lib"], +) + +tf_custom_op_library( + name = "python/kernel_tests/_bigtable_test.so", + srcs = [ + "kernels/test_kernels/bigtable_test_client_op.cc", + "ops/bigtable_test_ops.cc", + ], + deps = [ + ":bigtable_lib_cc", + ":bigtable_test_client", + "@com_googlesource_code_re2//:re2", + ], +) + +# Don't use tf_kernel_library because it prevents access to strings/stringprintf.h +cc_library( + name = "bigtable_test_kernels", + srcs = [ + "kernels/test_kernels/bigtable_test_client_op.cc", + ], + copts = tf_copts(), + linkstatic = 1, + deps = [ + ":bigtable_lib_cc", + ":bigtable_test_client", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_googlesource_code_re2//:re2", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "bigtable_test_py", + dso = [ + ":python/kernel_tests/_bigtable_test.so", + ], + kernels = [ + ":bigtable_test_kernels", + ":bigtable_test_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":bigtable_test_ops", + # "//tensorflow/contrib/util:util_py", + # "//tensorflow/python:framework_for_generated_wrappers", + # "//tensorflow/python:platform", + # "//tensorflow/python:util", + # "//tensorflow/python/data", + ], +) + +tf_py_test( + name = "bigtable_ops_test", + size = "small", + srcs = ["python/kernel_tests/bigtable_ops_test.py"], + additional_deps = [ + ":bigtable", + ":bigtable_test_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ef3c60069e8a97f7a13457156d20f3f7a4f7eccb --- /dev/null +++ b/tensorflow/contrib/bigtable/README.md @@ -0,0 +1,10 @@ +# Bigtable # + +[Google Cloud Bigtable](https://cloud.google.com/bigtable/) is a high +performance storage system that can store and serve training data. This contrib +package contains an experimental integration with TensorFlow. + +> **Status: Highly experimental.** The current implementation is very much in +> flux. Please use at your own risk! :-) + + diff --git a/tensorflow/python/keras/applications/densenet/__init__.py b/tensorflow/contrib/bigtable/__init__.py similarity index 60% rename from tensorflow/python/keras/applications/densenet/__init__.py rename to tensorflow/contrib/bigtable/__init__.py index 6b8ea83920733a3a442171616ab460ffaf831521..7df054637cdab32f2dd6201dd3488a90495e1cf5 100644 --- a/tensorflow/python/keras/applications/densenet/__init__.py +++ b/tensorflow/contrib/bigtable/__init__.py @@ -12,18 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""DenseNet Keras applications.""" +"""Cloud Bigtable Client for TensorFlow. + +This contrib package allows TensorFlow to interface directly with Cloud Bigtable +for high-speed data loading. + +@@BigtableClient +@@BigTable + +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.densenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201 -from tensorflow.python.keras._impl.keras.applications.densenet import preprocess_input +from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigTable +from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'BigTable', + 'BigtableClient', +] -del absolute_import -del division -del print_function +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..f43b44f2cb412244c47d7feea388b6c1eea417f9 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -0,0 +1,331 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { + +namespace { + +class BigtableClientOp : public OpKernel { + public: + explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_)); + OP_REQUIRES(ctx, !project_id_.empty(), + errors::InvalidArgument("project_id must be non-empty")); + OP_REQUIRES(ctx, !instance_id_.empty(), + errors::InvalidArgument("instance_id must be non-empty")); + + OP_REQUIRES_OK( + ctx, ctx->GetAttr("connection_pool_size", &connection_pool_size_)); + // If left unset by the client code, set it to a default of 100. Note: the + // cloud-cpp default of 4 concurrent connections is far too low for high + // performance streaming. + if (connection_pool_size_ == -1) { + connection_pool_size_ = 100; + } + OP_REQUIRES(ctx, connection_pool_size_ > 0, + errors::InvalidArgument("connection_pool_size must be > 0")); + } + + ~BigtableClientOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + BigtableClientResource* resource; + OP_REQUIRES_OK( + ctx, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx]( + BigtableClientResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto client_options = google::cloud::bigtable::ClientOptions(); + std::shared_ptr client = + google::cloud::bigtable::CreateDefaultDataClient( + project_id_, instance_id_, std::move(client_options)); + *ret = new BigtableClientResource(project_id_, instance_id_, + std::move(client)); + return Status::OK(); + })); + core::ScopedUnref resource_cleanup(resource); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + string project_id_; + string instance_id_; + int64 connection_pool_size_; + + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU), + BigtableClientOp); + +class BigtableTableOp : public OpKernel { + public: + explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_)); + OP_REQUIRES(ctx, !table_.empty(), + errors::InvalidArgument("table_name must be non-empty")); + } + + ~BigtableTableOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + + BigtableClientResource* client_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource)); + core::ScopedUnref unref_client(client_resource); + + BigtableTableResource* resource; + OP_REQUIRES_OK( + ctx, mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, client_resource](BigtableTableResource** ret) { + *ret = new BigtableTableResource(client_resource, table_); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + string table_; // Note: this is const after construction. + + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU), + BigtableTableOp); + +class ToBigtableOp : public AsyncOpKernel { + public: + explicit ToBigtableOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())), + /* num_threads = */ 1, /* low_latency_hint = */ false)) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + // The call to `iterator->GetNext()` may block and depend on an + // inter-op thread pool thread, so we issue the call from the + // owned thread pool. + thread_pool_->Schedule([this, ctx, done]() { + const Tensor* column_families_tensor; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->input("column_families", &column_families_tensor), done); + OP_REQUIRES_ASYNC( + ctx, column_families_tensor->dims() == 1, + errors::InvalidArgument("`column_families` must be a vector."), done); + + const Tensor* columns_tensor; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done); + OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1, + errors::InvalidArgument("`columns` must be a vector."), + done); + OP_REQUIRES_ASYNC( + ctx, + columns_tensor->NumElements() == + column_families_tensor->NumElements(), + errors::InvalidArgument("len(column_families) != len(columns)"), + done); + + std::vector column_families; + column_families.reserve(column_families_tensor->NumElements()); + std::vector columns; + columns.reserve(column_families_tensor->NumElements()); + for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) { + column_families.push_back(column_families_tensor->flat()(i)); + columns.push_back(columns_tensor->flat()(i)); + } + + DatasetBase* dataset; + OP_REQUIRES_OK_ASYNC( + ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done); + + IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr iterator; + OP_REQUIRES_OK_ASYNC( + ctx, + dataset->MakeIterator(&iter_ctx, "ToBigtableOpIterator", &iterator), + done); + + int64 timestamp_int; + OP_REQUIRES_OK_ASYNC( + ctx, ParseScalarArgument(ctx, "timestamp", ×tamp_int), + done); + OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1, + errors::InvalidArgument("timestamp must be >= -1"), + done); + + BigtableTableResource* resource; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done); + core::ScopedUnref resource_cleanup(resource); + + std::vector components; + components.reserve(dataset->output_dtypes().size()); + bool end_of_sequence = false; + do { + ::google::cloud::bigtable::BulkMutation mutation; + // TODO(saeta): Make # of mutations configurable. + for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) { + OP_REQUIRES_OK_ASYNC( + ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), + done); + if (!end_of_sequence) { + OP_REQUIRES_OK_ASYNC( + ctx, + CreateMutation(std::move(components), column_families, columns, + timestamp_int, &mutation), + done); + } + components.clear(); + } + grpc::Status mutation_status; + std::vector<::google::cloud::bigtable::FailedMutation> failures = + resource->table().BulkApply(std::move(mutation), mutation_status); + if (!failures.empty()) { + for (const auto& failure : failures) { + LOG(ERROR) << "Failure applying mutation on row (" + << failure.original_index() + << "): " << failure.mutation().row_key() + << " - error: " << failure.status().error_message() + << " (Details: " << failure.status().error_details() + << ")."; + } + } + OP_REQUIRES_ASYNC( + ctx, failures.empty() && mutation_status.ok(), + errors::Unknown("Failure while writing to BigTable: ", + mutation_status.error_code(), " - ", + mutation_status.error_message(), " (", + mutation_status.error_details(), + "), # of mutation failures: ", failures.size(), + ". See the log for the specific error details."), + done); + } while (!end_of_sequence); + done(); + }); + } + + private: + static string SanitizeThreadSuffix(string suffix) { + string clean; + for (int i = 0; i < suffix.size(); ++i) { + const char ch = suffix[i]; + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') { + clean += ch; + } else { + clean += '_'; + } + } + return clean; + } + + Status CreateMutation( + std::vector tensors, const std::vector& column_families, + const std::vector& columns, int64 timestamp_int, + ::google::cloud::bigtable::BulkMutation* bulk_mutation) { + if (tensors.size() != column_families.size() + 1) { + return errors::InvalidArgument( + "Iterator produced a set of Tensors shorter than expected"); + } + ::google::cloud::bigtable::SingleRowMutation mutation( + std::move(tensors[0].scalar()())); + std::chrono::milliseconds timestamp(timestamp_int); + for (size_t i = 1; i < tensors.size(); ++i) { + if (!TensorShapeUtils::IsScalar(tensors[i].shape())) { + return errors::Internal("Output tensor ", i, " was not a scalar"); + } + if (timestamp_int == -1) { + mutation.emplace_back(::google::cloud::bigtable::SetCell( + column_families[i - 1], columns[i - 1], + std::move(tensors[i].scalar()()))); + } else { + mutation.emplace_back(::google::cloud::bigtable::SetCell( + column_families[i - 1], columns[i - 1], timestamp, + std::move(tensors[i].scalar()()))); + } + } + bulk_mutation->emplace_back(std::move(mutation)); + return Status::OK(); + } + + template + Status ParseScalarArgument(OpKernelContext* ctx, + const StringPiece& argument_name, T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); + } + + std::unique_ptr thread_pool_; +}; + +REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), + ToBigtableOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc new file mode 100644 index 0000000000000000000000000000000000000000..2514575f30831bdcfab87eba07511fd309e8b1c2 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" + +namespace tensorflow { + +Status GrpcStatusToTfStatus(const ::grpc::Status& status) { + if (status.ok()) { + return Status::OK(); + } + auto grpc_code = status.error_code(); + if (status.error_code() == ::grpc::StatusCode::ABORTED || + status.error_code() == ::grpc::StatusCode::UNAVAILABLE || + status.error_code() == ::grpc::StatusCode::OUT_OF_RANGE) { + grpc_code = ::grpc::StatusCode::INTERNAL; + } + return Status( + static_cast<::tensorflow::error::Code>(status.error_code()), + strings::StrCat("Error reading from BigTable: ", status.error_message(), + " (Details: ", status.error_details(), ")")); +} + +string RegexFromStringSet(const std::vector& strs) { + CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty."; + std::unordered_set uniq(strs.begin(), strs.end()); + if (uniq.size() == 1) { + return *uniq.begin(); + } + return str_util::Join(uniq, "|"); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h new file mode 100644 index 0000000000000000000000000000000000000000..12d8256dea72e443826675765369ac6daa99a0ca --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ + +// Note: we use bigtable/client/internal/table.h as this is the no-exception API + +#include "google/cloud/bigtable/data_client.h" +#include "google/cloud/bigtable/internal/table.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { + +Status GrpcStatusToTfStatus(const ::grpc::Status& status); + +string RegexFromStringSet(const std::vector& strs); + +class BigtableClientResource : public ResourceBase { + public: + BigtableClientResource( + string project_id, string instance_id, + std::shared_ptr client) + : project_id_(std::move(project_id)), + instance_id_(std::move(instance_id)), + client_(std::move(client)) {} + + std::shared_ptr get_client() { + return client_; + } + + string DebugString() override { + return strings::StrCat("BigtableClientResource(project_id: ", project_id_, + ", instance_id: ", instance_id_, ")"); + } + + private: + const string project_id_; + const string instance_id_; + std::shared_ptr client_; +}; + +class BigtableTableResource : public ResourceBase { + public: + BigtableTableResource(BigtableClientResource* client, string table_name) + : client_(client), + table_name_(std::move(table_name)), + table_(client->get_client(), table_name_) { + client_->Ref(); + } + + ~BigtableTableResource() override { client_->Unref(); } + + ::google::cloud::bigtable::noex::Table& table() { return table_; } + + string DebugString() override { + return strings::StrCat( + "BigtableTableResource(client: ", client_->DebugString(), + ", table: ", table_name_, ")"); + } + + private: + BigtableClientResource* client_; // Ownes one ref. + const string table_name_; + ::google::cloud::bigtable::noex::Table table_; +}; + +// BigtableReaderDatasetIterator is an abstract class for iterators from +// datasets that are "readers" (source datasets, not transformation datasets) +// that read from Bigtable. +template +class BigtableReaderDatasetIterator : public DatasetIterator { + public: + explicit BigtableReaderDatasetIterator( + const typename DatasetIterator::Params& params) + : DatasetIterator(params), iterator_(nullptr, false) {} + + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(EnsureIteratorInitialized()); + if (iterator_ == reader_->end()) { + grpc::Status status = reader_->Finish(); + if (status.ok()) { + *end_of_sequence = true; + return Status::OK(); + } + return GrpcStatusToTfStatus(status); + } + *end_of_sequence = false; + google::cloud::bigtable::Row& row = *iterator_; + Status s = ParseRow(ctx, row, out_tensors); + // Ensure we always advance. + ++iterator_; + return s; + } + + protected: + virtual ::google::cloud::bigtable::RowRange MakeRowRange() = 0; + virtual ::google::cloud::bigtable::Filter MakeFilter() = 0; + virtual Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) = 0; + + private: + Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (reader_) { + return Status::OK(); + } + + auto rows = MakeRowRange(); + auto filter = MakeFilter(); + + // Note: the this in `this->dataset()` below is necessary due to namespace + // name conflicts. + reader_.reset(new ::google::cloud::bigtable::RowReader( + this->dataset()->table()->table().ReadRows(rows, filter))); + iterator_ = reader_->begin(); + return Status::OK(); + } + + mutex mu_; + std::unique_ptr<::google::cloud::bigtable::RowReader> reader_ GUARDED_BY(mu_); + ::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9e49fa35db4b2cd2c8991100a28a5b9c55f01ffe --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -0,0 +1,221 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { + public: + using UnaryDatasetOpKernel::UnaryDatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + BigtableTableResource* table; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table)); + + std::vector column_families; + std::vector columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); + OP_REQUIRES( + ctx, column_families.size() == columns.size(), + errors::InvalidArgument("len(columns) != len(column_families)")); + + const uint64 num_outputs = columns.size() + 1; + std::vector output_shapes; + output_shapes.reserve(num_outputs); + DataTypeVector output_types; + output_types.reserve(num_outputs); + for (uint64 i = 0; i < num_outputs; ++i) { + output_shapes.push_back({}); + output_types.push_back(DT_STRING); + } + + *output = + new Dataset(ctx, input, table, std::move(column_families), + std::move(columns), output_types, std::move(output_shapes)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, + BigtableTableResource* table, + std::vector column_families, + std::vector columns, + const DataTypeVector& output_types, + std::vector output_shapes) + : GraphDatasetBase(ctx), + input_(input), + table_(table), + column_families_(std::move(column_families)), + columns_(std::move(columns)), + output_types_(output_types), + output_shapes_(std::move(output_shapes)), + filter_(MakeFilter(column_families_, columns_)) { + table_->Ref(); + input_->Ref(); + } + + ~Dataset() override { + table_->Unref(); + input_->Unref(); + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableLookupDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "BigtableLookupDatasetOp::Dataset"; + } + + private: + static ::google::cloud::bigtable::Filter MakeFilter( + const std::vector& column_families, + const std::vector& columns) { + string column_family_regex = RegexFromStringSet(column_families); + string column_regex = RegexFromStringSet(columns); + + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::FamilyRegex(column_family_regex), + ::google::cloud::bigtable::Filter::ColumnRegex(column_regex)); + } + + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); // Sequence requests. + std::vector input_tensors; + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &input_tensors, end_of_sequence)); + if (*end_of_sequence) { + return Status::OK(); + } + if (input_tensors.size() != 1) { + return errors::InvalidArgument( + "Upstream iterator (", dataset()->input_->DebugString(), + ") did not produce a single `tf.string` `tf.Tensor`. It " + "produced ", + input_tensors.size(), " tensors."); + } + if (input_tensors[0].NumElements() == 0) { + return errors::InvalidArgument("Upstream iterator (", + dataset()->input_->DebugString(), + ") return an empty set of keys."); + } + if (input_tensors[0].NumElements() == 1) { + // Single key lookup. + ::grpc::Status status; + auto pair = dataset()->table_->table().ReadRow( + input_tensors[0].scalar()(), dataset()->filter_, status); + if (!status.ok()) { + return GrpcStatusToTfStatus(status); + } + if (!pair.first) { + return errors::DataLoss("Row key '", + input_tensors[0].scalar()(), + "' not found."); + } + TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors)); + } else { + // Batched get. + return errors::Unimplemented( + "BigtableLookupDataset doesn't yet support batched retrieval."); + } + return Status::OK(); + } + + private: + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) { + out_tensors->reserve(dataset()->columns_.size() + 1); + Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); + row_key_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(row_key_tensor)); + + if (row.cells().size() > 2 * dataset()->columns_.size()) { + LOG(WARNING) << "An excessive number of columns (" + << row.cells().size() + << ") were retrieved when reading row: " + << row.row_key(); + } + + for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { + Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); + bool found_column = false; + for (auto cell_itr = row.cells().begin(); + !found_column && cell_itr != row.cells().end(); ++cell_itr) { + if (cell_itr->family_name() == dataset()->column_families_[i] && + string(cell_itr->column_qualifier()) == + dataset()->columns_[i]) { + col_tensor.scalar()() = string(cell_itr->value()); + found_column = true; + } + } + if (!found_column) { + return errors::DataLoss("Column ", dataset()->column_families_[i], + ":", dataset()->columns_[i], + " not found in row: ", row.row_key()); + } + out_tensors->emplace_back(std::move(col_tensor)); + } + return Status::OK(); + } + + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + BigtableTableResource* table_; + const std::vector column_families_; + const std::vector columns_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + const ::google::cloud::bigtable::Filter filter_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU), + BigtableLookupDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e960719614a1c7c6c4af53ea924aef214a09b24d --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = new Dataset(ctx, resource, std::move(prefix)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix) + : GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtablePrefixKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::google::cloud::bigtable::RowRange MakeRowRange() override { + return ::google::cloud::bigtable::RowRange::Prefix(dataset()->prefix_); + } + ::google::cloud::bigtable::Filter MakeFilter() override { + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::CellsRowLimit(1), + ::google::cloud::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string prefix_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU), + BigtablePrefixKeyDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..96d3565d9b90e72f9e25e69e91f1931c982714cd --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableRangeKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = + new Dataset(ctx, resource, std::move(start_key), std::move(end_key)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string start_key, string end_key) + : GraphDatasetBase(ctx), + table_(table), + start_key_(std::move(start_key)), + end_key_(std::move(end_key)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableRangeKeyDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtableRangeKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::google::cloud::bigtable::RowRange MakeRowRange() override { + return ::google::cloud::bigtable::RowRange::Range(dataset()->start_key_, + dataset()->end_key_); + } + ::google::cloud::bigtable::Filter MakeFilter() override { + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::CellsRowLimit(1), + ::google::cloud::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string start_key_; + const string end_key_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU), + BigtableRangeKeyDatasetOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..13cb8681679ec1541b74a20474665f770790201f --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -0,0 +1,219 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableScanDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + + OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()), + errors::InvalidArgument( + "Either prefix or start_key must be specified")); + OP_REQUIRES(ctx, prefix.empty() || start_key.empty(), + errors::InvalidArgument( + "Only one of prefix and start_key can be provided")); + if (!prefix.empty()) { + OP_REQUIRES(ctx, end_key.empty(), + errors::InvalidArgument( + "If prefix is specified, end_key must be empty.")); + } + + std::vector column_families; + std::vector columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); + OP_REQUIRES( + ctx, column_families.size() == columns.size(), + errors::InvalidArgument("len(columns) != len(column_families)")); + OP_REQUIRES(ctx, !column_families.empty(), + errors::InvalidArgument("`column_families` is empty")); + + float probability = 0; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "probability", &probability)); + OP_REQUIRES( + ctx, probability > 0 && probability <= 1, + errors::InvalidArgument( + "Probability outside the range of (0, 1]. Got: ", probability)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + const uint64 num_outputs = columns.size() + 1; + std::vector output_shapes; + output_shapes.reserve(num_outputs); + DataTypeVector output_types; + output_types.reserve(num_outputs); + for (uint64 i = 0; i < num_outputs; ++i) { + output_shapes.push_back({}); + output_types.push_back(DT_STRING); + } + + *output = new Dataset(ctx, resource, std::move(prefix), + std::move(start_key), std::move(end_key), + std::move(column_families), std::move(columns), + probability, output_types, std::move(output_shapes)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix, string start_key, string end_key, + std::vector column_families, + std::vector columns, float probability, + const DataTypeVector& output_types, + std::vector output_shapes) + : GraphDatasetBase(ctx), + table_(table), + prefix_(std::move(prefix)), + start_key_(std::move(start_key)), + end_key_(std::move(end_key)), + column_families_(std::move(column_families)), + columns_(std::move(columns)), + column_family_regex_(RegexFromStringSet(column_families_)), + column_regex_(RegexFromStringSet(columns_)), + probability_(probability), + output_types_(output_types), + output_shapes_(std::move(output_shapes)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableScanDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "BigtableScanDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::google::cloud::bigtable::RowRange MakeRowRange() override { + if (!dataset()->prefix_.empty()) { + DCHECK(dataset()->start_key_.empty()); + return ::google::cloud::bigtable::RowRange::Prefix( + dataset()->prefix_); + } else { + DCHECK(!dataset()->start_key_.empty()) + << "Both prefix and start_key were empty!"; + return ::google::cloud::bigtable::RowRange::Range( + dataset()->start_key_, dataset()->end_key_); + } + } + ::google::cloud::bigtable::Filter MakeFilter() override { + // TODO(saeta): Investigate optimal ordering here. + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::FamilyRegex( + dataset()->column_family_regex_), + ::google::cloud::bigtable::Filter::ColumnRegex( + dataset()->column_regex_), + dataset()->probability_ != 1.0 + ? ::google::cloud::bigtable::Filter::RowSample( + dataset()->probability_) + : ::google::cloud::bigtable::Filter::PassAllFilter()); + } + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) override { + out_tensors->reserve(dataset()->columns_.size() + 1); + Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); + row_key_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(row_key_tensor)); + + if (row.cells().size() > 2 * dataset()->columns_.size()) { + LOG(WARNING) << "An excessive number of columns (" + << row.cells().size() + << ") were retrieved when reading row: " + << row.row_key(); + } + + for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { + Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); + bool found_column = false; + for (auto cell_itr = row.cells().begin(); + !found_column && cell_itr != row.cells().end(); ++cell_itr) { + if (cell_itr->family_name() == dataset()->column_families_[i] && + string(cell_itr->column_qualifier()) == + dataset()->columns_[i]) { + col_tensor.scalar()() = string(cell_itr->value()); + found_column = true; + } + } + if (!found_column) { + return errors::InvalidArgument( + "Column ", dataset()->column_families_[i], ":", + dataset()->columns_[i], " not found in row: ", row.row_key()); + } + out_tensors->emplace_back(std::move(col_tensor)); + } + return Status::OK(); + } + }; + + BigtableTableResource* table_; + const string prefix_; + const string start_key_; + const string end_key_; + const std::vector column_families_; + const std::vector columns_; + const string column_family_regex_; + const string column_regex_; + const float probability_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU), + BigtableScanDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..c164682508cd1ef6ec04162b5206a88628fa5221 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -0,0 +1,369 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" + +#include "google/bigtable/v2/data.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "re2/re2.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/ptr_util.h" +// #include "util/task/codes.pb.h" + +namespace tensorflow { +namespace { + +void UpdateRow(const ::google::bigtable::v2::Mutation& mut, + std::map* row) { + if (mut.has_set_cell()) { + CHECK(mut.set_cell().timestamp_micros() >= -1) + << "Timestamp_micros: " << mut.set_cell().timestamp_micros(); + auto col = + strings::Printf("%s:%s", mut.set_cell().family_name().c_str(), + string(mut.set_cell().column_qualifier()).c_str()); + (*row)[col] = string(mut.set_cell().value()); + } else if (mut.has_delete_from_column()) { + auto col = strings::Printf( + "%s:%s", mut.delete_from_column().family_name().c_str(), + string(mut.delete_from_column().column_qualifier()).c_str()); + row->erase(col); + } else if (mut.has_delete_from_family()) { + auto itr = row->lower_bound(mut.delete_from_family().family_name()); + auto prefix = + strings::Printf("%s:", mut.delete_from_family().family_name().c_str()); + while (itr != row->end() && itr->first.substr(0, prefix.size()) == prefix) { + row->erase(itr); + } + } else if (mut.has_delete_from_row()) { + row->clear(); + } else { + LOG(ERROR) << "Unknown mutation: " << mut.ShortDebugString(); + } +} + +} // namespace + +class SampleRowKeysResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::SampleRowKeysResponse> { + public: + explicit SampleRowKeysResponse(BigtableTestClient* client) + : client_(client) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::SampleRowKeysResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + + mutex_lock l2(client_->mu_); + *resp = google::bigtable::v2::SampleRowKeysResponse(); + resp->set_row_key(client_->table_.rows.begin()->first); + resp->set_offset_bytes(0); + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + mutex mu_; + bool sent_first_message_ GUARDED_BY(mu_) = false; + BigtableTestClient* client_; // Not owned. +}; + +class ReadRowsResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::ReadRowsResponse> { + public: + ReadRowsResponse(BigtableTestClient* client, + google::bigtable::v2::ReadRowsRequest const& request) + : client_(client), request_(request) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::ReadRowsResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + RowFilter filter = MakeRowFilter(); + + mutex_lock l2(client_->mu_); + *resp = google::bigtable::v2::ReadRowsResponse(); + // Send all contents in first response. + for (auto itr = client_->table_.rows.begin(); + itr != client_->table_.rows.end(); ++itr) { + if (filter.AllowRow(itr->first)) { + ::google::bigtable::v2::ReadRowsResponse_CellChunk* chunk = nullptr; + bool sent_first = false; + for (auto col_itr = itr->second.columns.begin(); + col_itr != itr->second.columns.end(); ++col_itr) { + if (filter.AllowColumn(col_itr->first)) { + chunk = resp->add_chunks(); + if (!sent_first) { + sent_first = true; + chunk->set_row_key(itr->first); + } + auto colon_idx = col_itr->first.find(":"); + CHECK(colon_idx != string::npos) + << "No ':' found in: " << col_itr->first; + chunk->mutable_family_name()->set_value( + string(col_itr->first, 0, colon_idx)); + chunk->mutable_qualifier()->set_value( + string(col_itr->first, ++colon_idx)); + if (!filter.strip_values) { + chunk->set_value(col_itr->second); + } + if (filter.only_one_column) { + break; + } + } + } + if (sent_first) { + // We are sending this row, so set the commit flag on the last chunk. + chunk->set_commit_row(true); + } + } + } + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + struct RowFilter { + std::set row_set; + std::vector> row_ranges; + double row_sample = 0.0; // Note: currently ignored. + std::unique_ptr col_filter; + bool strip_values = false; + bool only_one_column = false; + + bool AllowRow(const string& row) { + if (row_set.find(row) != row_set.end()) { + return true; + } + for (const auto& range : row_ranges) { + if (range.first <= row && range.second > row) { + return true; + } + } + return false; + } + + bool AllowColumn(const string& col) { + if (col_filter) { + return RE2::FullMatch(col, *col_filter); + } else { + return true; + } + } + }; + + RowFilter MakeRowFilter() { + RowFilter filter; + for (auto i = request_.rows().row_keys().begin(); + i != request_.rows().row_keys().end(); ++i) { + filter.row_set.insert(string(*i)); + } + for (auto i = request_.rows().row_ranges().begin(); + i != request_.rows().row_ranges().end(); ++i) { + if (i->start_key_case() != + google::bigtable::v2::RowRange::kStartKeyClosed || + i->end_key_case() != google::bigtable::v2::RowRange::kEndKeyOpen) { + LOG(WARNING) << "Skipping row range that cannot be processed: " + << i->ShortDebugString(); + continue; + } + filter.row_ranges.emplace_back(std::make_pair( + string(i->start_key_closed()), string(i->end_key_open()))); + } + if (request_.filter().has_chain()) { + string family_filter; + string qualifier_filter; + for (auto i = request_.filter().chain().filters().begin(); + i != request_.filter().chain().filters().end(); ++i) { + switch (i->filter_case()) { + case google::bigtable::v2::RowFilter::kFamilyNameRegexFilter: + family_filter = i->family_name_regex_filter(); + break; + case google::bigtable::v2::RowFilter::kColumnQualifierRegexFilter: + qualifier_filter = i->column_qualifier_regex_filter(); + break; + case google::bigtable::v2::RowFilter::kCellsPerColumnLimitFilter: + if (i->cells_per_column_limit_filter() != 1) { + LOG(ERROR) << "Unexpected cells_per_column_limit_filter: " + << i->cells_per_column_limit_filter(); + } + break; + case google::bigtable::v2::RowFilter::kStripValueTransformer: + filter.strip_values = i->strip_value_transformer(); + break; + case google::bigtable::v2::RowFilter::kRowSampleFilter: + LOG(INFO) << "Ignoring row sample directive."; + break; + case google::bigtable::v2::RowFilter::kPassAllFilter: + break; + case google::bigtable::v2::RowFilter::kCellsPerRowLimitFilter: + filter.only_one_column = true; + break; + default: + LOG(WARNING) << "Ignoring unknown filter type: " + << i->ShortDebugString(); + } + } + if (family_filter.empty() || qualifier_filter.empty()) { + LOG(WARNING) << "Missing regex!"; + } else { + string regex = strings::Printf("%s:%s", family_filter.c_str(), + qualifier_filter.c_str()); + filter.col_filter.reset(new RE2(regex)); + } + } else { + LOG(WARNING) << "Read request did not have a filter chain specified: " + << request_.filter().DebugString(); + } + return filter; + } + + mutex mu_; + bool sent_first_message_ GUARDED_BY(mu_) = false; + BigtableTestClient* client_; // Not owned. + const google::bigtable::v2::ReadRowsRequest request_; +}; + +class MutateRowsResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::MutateRowsResponse> { + public: + explicit MutateRowsResponse(size_t num_successes) + : num_successes_(num_successes) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::MutateRowsResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + *resp = google::bigtable::v2::MutateRowsResponse(); + for (size_t i = 0; i < num_successes_; ++i) { + auto entry = resp->add_entries(); + entry->set_index(i); + } + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + const size_t num_successes_; + + mutex mu_; + bool sent_first_message_ = false; +}; + +grpc::Status BigtableTestClient::MutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + google::bigtable::v2::MutateRowResponse* response) { + mutex_lock l(mu_); + auto* row = &table_.rows[string(request.row_key())]; + for (int i = 0; i < request.mutations_size(); ++i) { + UpdateRow(request.mutations(i), &row->columns); + } + *response = google::bigtable::v2::MutateRowResponse(); + return grpc::Status::OK; +} +grpc::Status BigtableTestClient::CheckAndMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::CheckAndMutateRowRequest const& request, + google::bigtable::v2::CheckAndMutateRowResponse* response) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "CheckAndMutateRow not implemented."); +} +grpc::Status BigtableTestClient::ReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + google::bigtable::v2::ReadModifyWriteRowResponse* response) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "ReadModifyWriteRow not implemented."); +} +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::ReadRows( + grpc::ClientContext* context, + google::bigtable::v2::ReadRowsRequest const& request) { + return MakeUnique(this, request); +} + +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::SampleRowKeys( + grpc::ClientContext* context, + google::bigtable::v2::SampleRowKeysRequest const& request) { + return MakeUnique(this); +} +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::MutateRows( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowsRequest const& request) { + mutex_lock l(mu_); + for (auto i = request.entries().begin(); i != request.entries().end(); ++i) { + auto* row = &table_.rows[string(i->row_key())]; + for (auto mut = i->mutations().begin(); mut != i->mutations().end(); + ++mut) { + UpdateRow(*mut, &row->columns); + } + } + return MakeUnique(request.entries_size()); +} + +std::shared_ptr BigtableTestClient::Channel() { + LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " + "cause a crash!"; + return nullptr; +} +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h new file mode 100644 index 0000000000000000000000000000000000000000..dac2b16a216d26f02684c7401ed2ddaa4b7baddb --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ + +#include "google/cloud/bigtable/data_client.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +class BigtableTestClient : public ::google::cloud::bigtable::DataClient { + public: + std::string const& project_id() const override { return project_id_; } + std::string const& instance_id() const override { return instance_id_; } + void reset() override { + mutex_lock l(mu_); + table_ = Table(); + } + + grpc::Status MutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + google::bigtable::v2::MutateRowResponse* response) override; + + grpc::Status CheckAndMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::CheckAndMutateRowRequest const& request, + google::bigtable::v2::CheckAndMutateRowResponse* response) override; + + grpc::Status ReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + google::bigtable::v2::ReadModifyWriteRowResponse* response) override; + + std::unique_ptr< + grpc::ClientReaderInterface> + ReadRows(grpc::ClientContext* context, + google::bigtable::v2::ReadRowsRequest const& request) override; + std::unique_ptr< + grpc::ClientReaderInterface> + SampleRowKeys( + grpc::ClientContext* context, + google::bigtable::v2::SampleRowKeysRequest const& request) override; + + std::unique_ptr< + grpc::ClientReaderInterface> + MutateRows(grpc::ClientContext* context, + google::bigtable::v2::MutateRowsRequest const& request) override; + + std::shared_ptr Channel() override; + + private: + friend class SampleRowKeysResponse; + friend class ReadRowsResponse; + friend class MutateRowsResponse; + + struct Row { + string row_key; + std::map columns; + }; + struct Table { + std::map rows; + }; + + mutex mu_; + const std::string project_id_ = "testproject"; + const std::string instance_id_ = "testinstance"; + Table table_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fa3e587b90147bd519586eef0cfb5e048b1b75be --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace tensorflow { + +namespace { + +class BigtableTestClientOp : public OpKernel { + public: + explicit BigtableTestClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~BigtableTestClientOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + BigtableClientResource* resource; + OP_REQUIRES_OK( + ctx, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](BigtableClientResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::shared_ptr client( + new BigtableTestClient()); + // Note: must make explicit copies to sequence + // them before the move of client. + string project_id = client->project_id(); + string instance_id = client->instance_id(); + *ret = new BigtableClientResource(std::move(project_id), + std::move(instance_id), + std::move(client)); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableTestClient").Device(DEVICE_CPU), + BigtableTestClientOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d6b396471941eaa0ca1c13a7386503ed3861e087 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc @@ -0,0 +1,290 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" +#include "google/cloud/bigtable/internal/table.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +void WriteCell(const string& row, const string& family, const string& column, + const string& value, + ::google::cloud::bigtable::noex::Table* table) { + ::google::cloud::bigtable::SingleRowMutation mut(row); + mut.emplace_back(::google::cloud::bigtable::SetCell(family, column, value)); + table->Apply(std::move(mut)); +} + +TEST(BigtableTestClientTest, EmptyRowRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + ::google::cloud::bigtable::RowSet rowset; + rowset.Append("r1"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!"; + EXPECT_TRUE(rows.Finish().ok()) << "Error reading rows."; +} + +TEST(BigtableTestClientTest, SingleRowWriteAndRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + + ::google::cloud::bigtable::RowSet rowset("r1"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + EXPECT_NE(itr, rows.end()) << "No rows were returned in response!"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + EXPECT_EQ(itr, rows.end()); + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + ::google::cloud::bigtable::RowSet rowset("r1"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + ::google::cloud::bigtable::RowSet rowset("r1", "r2", "r3"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = + table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, ColumnFiltering) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + // Extra cells + WriteCell("r1", "f2", "c1", "v1", &table); + WriteCell("r2", "f2", "c1", "v2", &table); + WriteCell("r3", "f1", "c2", "v3", &table); + + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::FamilyRegex("f1"), + ::google::cloud::bigtable::Filter::ColumnRegex("c1")); + auto rows = + table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, RowKeys) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + // Extra cells + WriteCell("r1", "f2", "c1", "v1", &table); + WriteCell("r2", "f2", "c1", "v2", &table); + WriteCell("r3", "f1", "c2", "v3", &table); + + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::CellsRowLimit(1), + ::google::cloud::bigtable::Filter::StripValueTransformer()); + auto rows = + table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..c7ff012ec89db74848b513d614de49664b5724d8 --- /dev/null +++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// TODO(saeta): Add support for setting ClientOptions values. +REGISTER_OP("BigtableClient") + .Attr("project_id: string") + .Attr("instance_id: string") + .Attr("connection_pool_size: int") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("client: resource") + .SetShapeFn(shape_inference::ScalarShape); + +// TODO(saeta): Add support for Application Profiles. +// See https://cloud.google.com/bigtable/docs/app-profiles for more info. +REGISTER_OP("BigtableTable") + .Input("client: resource") + .Attr("table_name: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("table: resource") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("DatasetToBigtable") + .Input("table: resource") + .Input("input_dataset: variant") + .Input("column_families: string") + .Input("columns: string") + .Input("timestamp: int64") + .SetShapeFn(shape_inference::NoOutputs); + +REGISTER_OP("BigtableLookupDataset") + .Input("keys_dataset: variant") + .Input("table: resource") + .Input("column_families: string") + .Input("columns: string") + .Output("handle: variant") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtablePrefixKeyDataset") + .Input("table: resource") + .Input("prefix: string") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtableRangeKeyDataset") + .Input("table: resource") + .Input("start_key: string") + .Input("end_key: string") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +// TODO(saeta): Support continuing despite bad data (e.g. empty string, or +// skip incomplete row.) +REGISTER_OP("BigtableScanDataset") + .Input("table: resource") + .Input("prefix: string") + .Input("start_key: string") + .Input("end_key: string") + .Input("column_families: string") + .Input("columns: string") + .Input("probability: float") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f7d02458f63d547000f00b184b3d5e3c5007fb72 --- /dev/null +++ b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("BigtableTestClient") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("client: resource") + .SetShapeFn(shape_inference::ScalarShape); + +} // namespace tensorflow diff --git a/tensorflow/python/keras/_impl/keras/wrappers/__init__.py b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py similarity index 82% rename from tensorflow/python/keras/_impl/keras/wrappers/__init__.py rename to tensorflow/contrib/bigtable/python/kernel_tests/__init__.py index 20c95929e3d2e1f66e66efe43b9685c5d6ed1c10..292d8f4e51abbbd89d68b47febd86b7297bb8ed2 100644 --- a/tensorflow/python/keras/_impl/keras/wrappers/__init__.py +++ b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras API wrappers. -""" + +"""This module contains tests for the bigtable integration.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function - -from tensorflow.python.keras._impl.keras.wrappers import scikit_learn - diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d33a66f2dfbecd0dc1082fd98973660ce9a93931 --- /dev/null +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -0,0 +1,132 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bigtable Ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import bigtable +from tensorflow.contrib.bigtable.ops import gen_bigtable_ops +from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops +from tensorflow.contrib.util import loader +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +_bigtable_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_bigtable_test.so")) + + +class BigtableOpsTest(test.TestCase): + COMMON_ROW_KEYS = ["r1", "r2", "r3"] + COMMON_VALUES = ["v1", "v2", "v3"] + + def setUp(self): + self._client = gen_bigtable_test_ops.bigtable_test_client() + table = gen_bigtable_ops.bigtable_table(self._client, "testtable") + self._table = bigtable.BigTable("testtable", None, table) + + def _makeSimpleDataset(self): + output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS) + output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES) + return dataset_ops.Dataset.zip((output_rows, output_values)) + + def _writeCommonValues(self, sess): + output_ds = self._makeSimpleDataset() + write_op = self._table.write(output_ds, ["cf1"], ["c1"]) + sess.run(write_op) + + def runReadKeyTest(self, read_ds): + itr = read_ds.make_initializable_iterator() + n = itr.get_next() + expected = list(self.COMMON_ROW_KEYS) + expected.reverse() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i in range(3): + output = sess.run(n) + want = expected.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output), + "Unequal at step %d: want: %s, got: %s" % (i, want, output)) + + def testReadPrefixKeys(self): + self.runReadKeyTest(self._table.keys_by_prefix_dataset("r")) + + def testReadRangeKeys(self): + self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) + + def runScanTest(self, read_ds): + itr = read_ds.make_initializable_iterator() + n = itr.get_next() + expected_keys = list(self.COMMON_ROW_KEYS) + expected_keys.reverse() + expected_values = list(self.COMMON_VALUES) + expected_values.reverse() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i in range(3): + output = sess.run(n) + want = expected_keys.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output[0]), + "Unequal keys at step %d: want: %s, got: %s" % (i, want, output[0])) + want = expected_values.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output[1]), + "Unequal values at step: %d: want: %s, got: %s" % (i, want, + output[1])) + + def testScanPrefixStringCol(self): + self.runScanTest(self._table.scan_prefix("r", cf1="c1")) + + def testScanPrefixListCol(self): + self.runScanTest(self._table.scan_prefix("r", cf1=["c1"])) + + def testScanRangeStringCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1")) + + def testScanRangeListCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"])) + + def testLookup(self): + ds = self._table.keys_by_prefix_dataset("r") + ds = ds.apply(self._table.lookup_columns(cf1="c1")) + itr = ds.make_initializable_iterator() + n = itr.get_next() + expected_keys = list(self.COMMON_ROW_KEYS) + expected_values = list(self.COMMON_VALUES) + expected_tuples = zip(expected_keys, expected_values) + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i, elem in enumerate(expected_tuples): + output = sess.run(n) + self.assertEqual( + compat.as_bytes(elem[0]), compat.as_bytes(output[0]), + "Unequal keys at step %d: want: %s, got: %s" % + (i, compat.as_bytes(elem[0]), compat.as_bytes(output[0]))) + self.assertEqual( + compat.as_bytes(elem[1]), compat.as_bytes(output[1]), + "Unequal values at step %d: want: %s, got: %s" % + (i, compat.as_bytes(elem[1]), compat.as_bytes(output[1]))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/datasets/cifar10/__init__.py b/tensorflow/contrib/bigtable/python/ops/__init__.py similarity index 75% rename from tensorflow/python/keras/datasets/cifar10/__init__.py rename to tensorflow/contrib/bigtable/python/ops/__init__.py index 68d3eb789ea2c410095c0c75e0b79a9b07d209a3..36d75b0d7068a650347a5e17f4727a5432d8752f 100644 --- a/tensorflow/python/keras/datasets/cifar10/__init__.py +++ b/tensorflow/contrib/bigtable/python/ops/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""CIFAR10 small image classification dataset.""" + +"""This module contains the Python API for the Cloud Bigtable integration.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function - -from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data - -del absolute_import -del division -del print_function diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py new file mode 100644 index 0000000000000000000000000000000000000000..39c58ba6659e5e637c31dce419c34bcce9c09838 --- /dev/null +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -0,0 +1,499 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Python API for TensorFlow's Bigtable integration. + +TensorFlow has support for reading from and writing to Cloud Bigtable. To use +the Bigtable TensorFlow integration, first create a BigtableClient (which +configures your connection to Cloud Bigtable), and then open a Table. The Table +object then allows you to create numerous @{tf.data.Dataset}s to read data, or +write a @{tf.data.Dataset} object to the underlying Bigtable Table. + +For background on Google Cloud Bigtable, see: https://cloud.google.com/bigtable. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six import iteritems + +from tensorflow.contrib.bigtable.ops import gen_bigtable_ops +from tensorflow.contrib.util import loader +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import resource_loader + +_bigtable_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_bigtable.so")) + + +class BigtableClient(object): + """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF. + + BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the + `table` method to open a Bigtable Table. + """ + + def __init__(self, project_id, instance_id, connection_pool_size=None): + """Creates a BigtableClient that can be used to open connections to tables. + + Args: + project_id: A string representing the GCP project id to connect to. + instance_id: A string representing the Bigtable instance to connect to. + connection_pool_size: (Optional.) A number representing the number of + concurrent connections to the Cloud Bigtable service to make. + + Raises: + ValueError: if the arguments are invalid (e.g. wrong type, or out of + expected ranges (e.g. negative).) + """ + if not isinstance(project_id, str): + raise ValueError("`project_id` must be a string") + self._project_id = project_id + + if not isinstance(instance_id, str): + raise ValueError("`instance_id` must be a string") + self._instance_id = instance_id + + if connection_pool_size is None: + connection_pool_size = -1 + elif connection_pool_size < 1: + raise ValueError("`connection_pool_size` must be positive") + self._connection_pool_size = connection_pool_size + + self._resource = gen_bigtable_ops.bigtable_client(project_id, instance_id, + connection_pool_size) + + def table(self, name, snapshot=None): + """Opens a table and returns a `BigTable` object. + + Args: + name: A `tf.string` `tf.Tensor` name of the table to open. + snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to + request the creation of a snapshot. (Note: currently unimplemented.) + + Returns: + A `BigTable` python object representing the operations available on the + table. + """ + # TODO(saeta): Implement snapshot functionality. + table = gen_bigtable_ops.bigtable_table(self._resource, name) + return BigTable(name, snapshot, table) + + +class BigTable(object): + """BigTable is the entrypoint for reading and writing data in Cloud Bigtable. + + This BigTable class is the python representation of the Cloud Bigtable table + within TensorFlow. Methods on this class allow data to be read from and + written to the Cloud Bigtable service in flexible and high performance + manners. + """ + + # TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface. + # TODO(saeta): Consider variant tensors instead of resources (while supporting + # connection pooling). + + def __init__(self, name, snapshot, resource): + self._name = name + self._snapshot = snapshot + self._resource = resource + + def lookup_columns(self, *args, **kwargs): + """Retrieves the values of columns for a dataset of keys. + + Example usage: + ``` + table = bigtable_client.table("my_table") + key_dataset = table.get_keys_prefix("imagenet") + images = key_dataset.apply(table.lookup_columns(("cf1", "image"), + ("cf2", "label"), + ("cf2", "boundingbox"))) + training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) + ``` + + Alternatively, you can use keyword arguments to specify the columns to + capture. Example (same as above, rewritten): + ``` + table = bigtable_client.table("my_table") + key_dataset = table.get_keys_prefix("imagenet") + images = key_dataset.apply(table.lookup_columns( + cf1="image", cf2=("label", "boundingbox"))) + training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) + ``` + + Note: certain kwargs keys are reserved, and thus some column families cannot + be identified using the kwargs syntax. Instead, please use the args syntax. + This list includes: + - 'name' + This list can change at any time. + + Args: + *args: A list of tuples containing (column family, column name) pairs. + **kwargs: Column families and + + Returns: + A function that can be passed to `tf.data.Dataset.apply` to retrieve the + values of columns for the rows. + """ + table = self # Capture self + normalized = args + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + def _apply_fn(dataset): + # TODO(saeta): Verify dataset's types are correct! + return _BigtableLookupDataset(dataset, table, normalized) + + return _apply_fn + + def keys_by_range_dataset(self, start, end): + """Retrieves all row keys between start and end. + + Note: it does NOT retrieve the values of columns. + + Args: + start: The start row key. The row keys for rows after start (inclusive) + will be retrieved. + end: (Optional.) The end row key. Rows up to (but not including) end will + be retrieved. If end is None, all subsequent row keys will be retrieved. + + Returns: + A @{tf.data.Dataset} containing `tf.string` Tensors corresponding to all + of the row keys between `start` and `end`. + """ + # TODO(saeta): Make inclusive / exclusive configurable? + if end is None: + end = "" + return _BigtableRangeKeyDataset(self, start, end) + + def keys_by_prefix_dataset(self, prefix): + """Retrieves the row keys matching a given prefix. + + Args: + prefix: All row keys that begin with `prefix` in the table will be + retrieved. + + Returns: + A @{tf.data.Dataset}. containing `tf.string` Tensors corresponding to all + of the row keys matching that prefix. + """ + return _BigtablePrefixKeyDataset(self, prefix) + + def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): + """Retrieves row (including values) from the Bigtable service. + + Rows with row-key prefixed by `prefix` will be retrieved. + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + prefix: The prefix all row keys muat match to be retrieved for prefix- + based scans. + probability: Probabilistically sample rows. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A @{tf.data.Dataset} returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + if probability is None: + probability = 1.0 + if isinstance(probability, float) and (probability <= 0.0 or + probability > 1.0): + raise ValueError("probability must be in the range (0, 1].") + + normalized = columns + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + return _BigtableScanDataset(self, prefix, "", "", normalized, probability) + + def scan_range(self, start, end, probability=None, columns=None, **kwargs): + """Retrieves rows (including values) from the Bigtable service. + + Rows with row-keys between `start` and `end` will be retrieved. + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + start: The start of the range when scanning by range. + end: (Optional.) The end of the range when scanning by range. + probability: Probabilistically sample rows. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A @{tf.data.Dataset} returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + if probability is None: + probability = 1.0 + if isinstance(probability, float) and (probability <= 0.0 or + probability > 1.0): + raise ValueError("probability must be in the range (0, 1].") + + normalized = columns + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + return _BigtableScanDataset(self, "", start, end, normalized, probability) + + def write(self, dataset, column_families, columns, timestamp=None): + """Writes a dataset to the table. + + Args: + dataset: A @{tf.data.Dataset} to be written to this table. It must produce + a list of number-of-columns+1 elements, all of which must be strings. + The first value will be used as the row key, and subsequent values will + be used as cell values for the corresponding columns from the + corresponding column_families and columns entries. + column_families: A @{tf.Tensor} of `tf.string`s corresponding to the + column names to store the dataset's elements into. + columns: A `tf.Tensor` of `tf.string`s corresponding to the column names + to store the dataset's elements into. + timestamp: (Optional.) An int64 timestamp to write all the values at. + Leave as None to use server-provided timestamps. + + Returns: + A @{tf.Operation} that can be run to perform the write. + + Raises: + ValueError: If there are unexpected or incompatible types, or if the + number of columns and column_families does not match the output of + `dataset`. + """ + if timestamp is None: + timestamp = -1 # Bigtable server provided timestamp. + for tensor_type in nest.flatten(dataset.output_types): + if tensor_type != dtypes.string: + raise ValueError("Not all elements of the dataset were `tf.string`") + for shape in nest.flatten(dataset.output_shapes): + if not shape.is_compatible_with(tensor_shape.scalar()): + raise ValueError("Not all elements of the dataset were scalars") + if len(column_families) != len(columns): + raise ValueError("len(column_families) != len(columns)") + if len(nest.flatten(dataset.output_types)) != len(columns) + 1: + raise ValueError("A column name must be specified for every component of " + "the dataset elements. (e.g.: len(columns) != " + "len(dataset.output_types))") + return gen_bigtable_ops.dataset_to_bigtable( + self._resource, + dataset._as_variant_tensor(), # pylint: disable=protected-access + column_families, + columns, + timestamp) + + +class _BigtableKeyDataset(dataset_ops.Dataset): + """_BigtableKeyDataset is an abstract class representing the keys of a table. + """ + + def __init__(self, table): + """Constructs a _BigtableKeyDataset. + + Args: + table: a Bigtable class. + """ + super(_BigtableKeyDataset, self).__init__() + self._table = table + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.TensorShape([]) + + @property + def output_types(self): + return dtypes.string + + +class _BigtablePrefixKeyDataset(_BigtableKeyDataset): + """_BigtablePrefixKeyDataset represents looking up keys by prefix. + """ + + def __init__(self, table, prefix): + super(_BigtablePrefixKeyDataset, self).__init__(table) + self._prefix = prefix + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_prefix_key_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix) + + +class _BigtableRangeKeyDataset(_BigtableKeyDataset): + """_BigtableRangeKeyDataset represents looking up keys by range. + """ + + def __init__(self, table, start, end): + super(_BigtableRangeKeyDataset, self).__init__(table) + self._start = start + self._end = end + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_range_key_dataset( + table=self._table._resource, # pylint: disable=protected-access + start_key=self._start, + end_key=self._end) + + +class _BigtableLookupDataset(dataset_ops.Dataset): + """_BigtableLookupDataset represents a dataset that retrieves values for keys. + """ + + def __init__(self, dataset, table, normalized): + self._num_outputs = len(normalized) + 1 # 1 for row key + self._dataset = dataset + self._table = table + self._normalized = normalized + self._column_families = [i[0] for i in normalized] + self._columns = [i[1] for i in normalized] + + @property + def output_classes(self): + return tuple([ops.Tensor] * self._num_outputs) + + @property + def output_shapes(self): + return tuple([tensor_shape.TensorShape([])] * self._num_outputs) + + @property + def output_types(self): + return tuple([dtypes.string] * self._num_outputs) + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_bigtable_ops.bigtable_lookup_dataset( + keys_dataset=self._dataset._as_variant_tensor(), + table=self._table._resource, + column_families=self._column_families, + columns=self._columns) + + +class _BigtableScanDataset(dataset_ops.Dataset): + """_BigtableScanDataset represents a dataset that retrieves keys and values. + """ + + def __init__(self, table, prefix, start, end, normalized, probability): + self._table = table + self._prefix = prefix + self._start = start + self._end = end + self._column_families = [i[0] for i in normalized] + self._columns = [i[1] for i in normalized] + self._probability = probability + self._num_outputs = len(normalized) + 1 # 1 for row key + + @property + def output_classes(self): + return tuple([ops.Tensor] * self._num_outputs) + + @property + def output_shapes(self): + return tuple([tensor_shape.TensorShape([])] * self._num_outputs) + + @property + def output_types(self): + return tuple([dtypes.string] * self._num_outputs) + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_scan_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix, + start_key=self._start, + end_key=self._end, + column_families=self._column_families, + columns=self._columns, + probability=self._probability) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 8cff1a3bb1d11aff6a264636291a7149b40de516..ef0e80cd0997bc0e95cd0d150e87db144a2dde44 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -15,8 +15,9 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - "custom_export_strategy", + ":custom_export_strategy", ":custom_loss_head", + ":distillation_loss", ":estimator", ":model", ":trainer_hooks", @@ -144,6 +145,7 @@ py_library( srcs = ["dnn_tree_combined_estimator.py"], srcs_version = "PY2AND3", deps = [ + ":distillation_loss", ":estimator_utils", ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", @@ -156,6 +158,17 @@ py_library( ], ) +py_library( + name = "distillation_loss", + srcs = ["distillation_loss.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/learn", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + ], +) + py_test( name = "dnn_tree_combined_estimator_test", size = "medium", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py b/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9aacc5534329d1302b25dcfab678f9adb8f773f6 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py @@ -0,0 +1,75 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utill functions for distillation loss. + +The distillation loss_fn will be called with the following: + +Args: + dnn_logits: Tensor of logits from the dnn, treated as the "target". This will + be the output of a call to tf.stop_gradient(). + tree_logits: Tensor of logits from the tree, treated as the "predictions". + example_weights: Tensor of example weights, or a single scalar. + +Returns: + A scalar indicating the reduced loss for that batch of examples. + +Note: we calls the loss_fn defined in contrib head, which is computing two +losses, first one for training and second one for reporting. We only take the +first one here. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn + + +def _logits_to_label_for_tree(logits, n_classes): + if n_classes == 2: + return math_ops.sigmoid(logits) + else: + return nn.softmax(logits) + + +def create_dnn_to_tree_squared_loss_fn(n_classes): + """Returns a squared loss function for dnn to tree distillation.""" + + def _dnn_to_tree_squared_loss(dnn_logits, tree_logits, example_weights): + return head_lib._mean_squared_loss( # pylint: disable=protected-access + labels=_logits_to_label_for_tree(dnn_logits, n_classes), + logits=_logits_to_label_for_tree(tree_logits, n_classes), + weights=example_weights)[0] + + return _dnn_to_tree_squared_loss + + +def create_dnn_to_tree_cross_entropy_loss_fn(n_classes): + """Returns a cross entropy loss function for dnn to tree distillation.""" + + def _dnn_to_tree_cross_entropy_loss(dnn_logits, tree_logits, example_weights): + if n_classes == 2: + return head_lib._log_loss_with_two_classes( # pylint: disable=protected-access + labels=_logits_to_label_for_tree(dnn_logits, n_classes), + logits=tree_logits, + weights=example_weights)[0] + else: + return head_lib._softmax_cross_entropy_loss( # pylint: disable=protected-access + labels=_logits_to_label_for_tree(dnn_logits, n_classes), + logits=tree_logits, + weights=example_weights)[0] + + return _dnn_to_tree_cross_entropy_loss diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 9994c84ebdb930eea0818188225488eb5eca84eb..7eb429b636a5193a124dd9b0c020dae6cac910cb 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -24,7 +24,9 @@ from __future__ import division from __future__ import print_function import six + from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops @@ -35,16 +37,19 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import training_util _DNN_LEARNING_RATE = 0.001 + def _get_optimizer(optimizer): if callable(optimizer): return optimizer() @@ -73,8 +78,10 @@ def _dnn_tree_combined_model_fn(features, dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, + dnn_to_tree_distillation_param=None, use_core_versions=False): """DNN and GBDT combined model_fn. @@ -108,11 +115,20 @@ def _dnn_tree_combined_model_fn(features, as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. @@ -128,12 +144,17 @@ def _dnn_tree_combined_model_fn(features, if not dnn_feature_columns: raise ValueError("dnn_feature_columns must be specified") + if dnn_to_tree_distillation_param: + if not predict_with_tree_only: + logging.warning("update predict_with_tree_only to True since distillation" + "is specified.") + predict_with_tree_only = True + # Build DNN Logits. dnn_parent_scope = "dnn" dnn_partitioner = dnn_input_layer_partitioner or ( partitioned_variables.min_max_variable_partitioner( - max_partitions=config.num_ps_replicas, - min_slice_size=64 << 20)) + max_partitions=config.num_ps_replicas, min_slice_size=64 << 20)) with variable_scope.variable_scope( dnn_parent_scope, @@ -171,8 +192,7 @@ def _dnn_tree_combined_model_fn(features, _add_hidden_layer_summary(net, hidden_layer_scope.name) previous_layer = net with variable_scope.variable_scope( - "logits", - values=(previous_layer,)) as logits_scope: + "logits", values=(previous_layer,)) as logits_scope: dnn_logits = layers.fully_connected( previous_layer, head.logits_dimension, @@ -190,8 +210,7 @@ def _dnn_tree_combined_model_fn(features, optimizer=_get_optimizer(dnn_optimizer), name=dnn_parent_scope, variables=ops.get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES, - scope=dnn_parent_scope), + ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope), # Empty summaries to prevent optimizers from logging training_loss. summaries=[]) @@ -224,13 +243,41 @@ def _dnn_tree_combined_model_fn(features, def _tree_train_op_fn(loss): """Returns the op to optimize the loss.""" + if dnn_to_tree_distillation_param: + loss_weight, loss_fn = dnn_to_tree_distillation_param + weight_tensor = head_lib._weight_tensor( # pylint: disable=protected-access + features, head.weight_column_name) + dnn_logits_fixed = array_ops.stop_gradient(dnn_logits) + + if loss_fn is None: + # we create the loss_fn similar to the head loss_fn for + # multi_class_head used previously as the default one. + n_classes = 2 if head.logits_dimension == 1 else head.logits_dimension + loss_fn = distillation_loss.create_dnn_to_tree_cross_entropy_loss_fn( + n_classes) + + dnn_to_tree_distillation_loss = loss_weight * loss_fn( + dnn_logits_fixed, tree_logits, weight_tensor) + summary.scalar("dnn_to_tree_distillation_loss", + dnn_to_tree_distillation_loss) + loss += dnn_to_tree_distillation_loss + update_op = gbdt_model.train(loss, predictions_dict, labels) with ops.control_dependencies( [update_op]), (ops.colocate_with(global_step)): update_op = state_ops.assign_add(global_step, 1).op return update_op - tree_train_logits = dnn_logits + tree_logits + if predict_with_tree_only: + if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.INFER: + tree_train_logits = tree_logits + else: + tree_train_logits = control_flow_ops.cond( + global_step > dnn_steps_to_train, + lambda: tree_logits, + lambda: dnn_logits) + else: + tree_train_logits = dnn_logits + tree_logits def _no_train_op_fn(loss): """Returns a no-op.""" @@ -288,10 +335,10 @@ def _dnn_tree_combined_model_fn(features, finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() model_fn_ops.training_hooks.extend([ - trainer_hooks.SwitchTrainOp( - dnn_train_op, dnn_steps_to_train, tree_train_op), - trainer_hooks.StopAfterNTrees( - num_trees, attempted_trees, finalized_trees)]) + trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, + tree_train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees) + ]) return model_fn_ops @@ -318,8 +365,10 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, + dnn_to_tree_distillation_param=None, use_core_versions=False): """Initializes a DNNBoostedTreeCombinedClassifier instance. @@ -360,11 +409,20 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. """ @@ -377,16 +435,33 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedClassifier, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) class DNNBoostedTreeCombinedRegressor(estimator.Estimator): @@ -410,8 +485,10 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, + dnn_to_tree_distillation_param=None, use_core_versions=False): """Initializes a DNNBoostedTreeCombinedRegressor instance. @@ -452,11 +529,20 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. """ @@ -474,16 +560,33 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedRegressor, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) class DNNBoostedTreeCombinedEstimator(estimator.Estimator): @@ -508,8 +611,10 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, + dnn_to_tree_distillation_param=None, use_core_versions=False): """Initializes a DNNBoostedTreeCombinedEstimator instance. @@ -545,23 +650,50 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. """ + def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index f495edc62f0909880c170ccb4cf5d11e3f20f55c..9b7acfa664b0398216b5a7fb904960d8363929d6 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -131,6 +131,30 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.fit(input_fn=_train_input_fn, steps=15) classifier.evaluate(input_fn=_eval_input_fn, steps=1) + def testFitAndEvaluateWithDistillation(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.DNNBoostedTreeCombinedClassifier( + dnn_hidden_units=[1], + dnn_feature_columns=[feature_column.real_valued_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + n_classes=2, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=False, + tree_feature_columns=[feature_column.real_valued_column("x")], + dnn_to_tree_distillation_param=(1, None)) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 89d0d611d2905492cec09e033b8cbc238ec7fac6..9c36c302210185bc390751a0229a61f2f8cd91b8 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -41,7 +41,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -66,6 +67,16 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree + Raises: ValueError: If learner_config is not valid. """ @@ -74,7 +85,9 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): # supports second order derivative. def loss_fn(labels, logits, weights=None): result = losses.per_example_maxent_loss( - labels=labels, logits=logits, weights=weights, + labels=labels, + logits=logits, + weights=weights, num_classes=n_classes) return math_ops.reduce_mean(result[0]) else: @@ -102,6 +115,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'center_bias': center_bias, 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, + 'output_leaf_index': output_leaf_index, }, model_dir=model_dir, config=config, @@ -124,7 +138,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -151,6 +166,13 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree """ head = head_lib.regression_head( label_name=label_name, @@ -173,6 +195,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, 'use_core_libs': use_core_libs, + 'output_leaf_index': False, }, model_dir=model_dir, config=config, @@ -197,7 +220,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -220,6 +244,13 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -233,6 +264,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, 'use_core_libs': use_core_libs, + 'output_leaf_index': False, }, model_dir=model_dir, config=config, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 0d58317bd59331cfcde0e12aeb3a3a03fc45d89b..75ef1b050028b6462b255827c06e836e5c481844 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -68,6 +68,28 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): classifier.evaluate(input_fn=_eval_input_fn, steps=1) classifier.export(self._export_dir_base) + def testThatLeafIndexIsInPredictions(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("leaf_index" in prediction_dict) + self.assertTrue("logits" in prediction_dict) + def testFitAndEvaluateDontThrowExceptionWithCoreForEstimator(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 15ab6d814522ab1dee58dcd71246354fc4d8a483..1ee891198939e53fc5913104b2c2e65dc977823f 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -63,6 +63,8 @@ def model_builder(features, labels, mode, params, config): num_trees = params["num_trees"] use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] + output_leaf_index = params["output_leaf_index"] + if features is None: raise ValueError("At least one feature must be specified.") @@ -96,7 +98,8 @@ def model_builder(features, labels, mode, params, config): feature_columns=feature_columns, logits_dimension=head.logits_dimension, features=training_features, - use_core_columns=use_core_libs) + use_core_columns=use_core_libs, + output_leaf_index=output_leaf_index) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] @@ -127,6 +130,9 @@ def model_builder(features, labels, mode, params, config): labels=labels, train_op_fn=_train_op_fn, logits=logits) + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] if num_trees: if center_bias: num_trees += 1 diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index b3fe38614e05801b223f0c96f7a70ce7e432a70b..9493c1a1394040db3b744f1b382b20bd5bd1988d 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -59,6 +59,7 @@ const char* kApplyDropoutAttributeName = "apply_dropout"; const char* kApplyAveragingAttributeName = "apply_averaging"; const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights"; const char* kPredictionsTensorName = "predictions"; +const char* kLeafIndexTensorName = "leaf_index"; void CalculateTreesToInclude( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, @@ -170,15 +171,22 @@ class GradientTreesPredictionOp : public OpKernel { core::ScopedUnref unref_me(ensemble_resource); if (use_locking_) { tf_shared_lock l(*ensemble_resource->get_mutex()); - DoCompute(context, ensemble_resource); + DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/false); } else { - DoCompute(context, ensemble_resource); + DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/false); } } - private: - void DoCompute(OpKernelContext* context, - DecisionTreeEnsembleResource* ensemble_resource) { + protected: + // return_output_leaf_index is a boolean variable indicating whether to output + // leaf index in prediction. Though this class invokes only with this param + // value as false, the subclass GradientTreesPredictionVerboseOp will invoke + // with the true value. + virtual void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource, + const bool return_output_leaf_index) { // Read dense float features list; OpInputList dense_float_features_list; OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures( @@ -267,6 +275,14 @@ class GradientTreesPredictionOp : public OpKernel { &output_predictions_t)); auto output_predictions = output_predictions_t->matrix(); + // Allocate output leaf index matrix. + Tensor* output_leaf_index_t = nullptr; + if (return_output_leaf_index) { + OP_REQUIRES_OK(context, context->allocate_output( + kLeafIndexTensorName, + {batch_size, ensemble_resource->num_trees()}, + &output_leaf_index_t)); + } // Run predictor. thread::ThreadPool* const worker_threads = context->device()->tensorflow_cpu_worker_threads()->workers; @@ -288,11 +304,13 @@ class GradientTreesPredictionOp : public OpKernel { i, weight * (num_ensembles - i + start_averaging) / num_ensembles); } MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features, - worker_threads, output_predictions); + worker_threads, output_predictions, + output_leaf_index_t); } else { MultipleAdditiveTrees::Predict( ensemble_resource->decision_tree_ensemble(), trees_to_include, - batch_features, worker_threads, output_predictions); + batch_features, worker_threads, output_predictions, + output_leaf_index_t); } // Output dropped trees and original weights. @@ -302,7 +320,6 @@ class GradientTreesPredictionOp : public OpKernel { {2, static_cast(dropped_trees.size())}, &output_dropout_info_t)); auto output_dropout_info = output_dropout_info_t->matrix(); - for (int32 i = 0; i < dropped_trees.size(); ++i) { output_dropout_info(0, i) = dropped_trees[i]; output_dropout_info(1, i) = original_weights[i]; @@ -326,6 +343,27 @@ class GradientTreesPredictionOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("GradientTreesPrediction").Device(DEVICE_CPU), GradientTreesPredictionOp); +// GradientTreesPredictionVerboseOp is derived from GradientTreesPredictionOp +// and have an additional output of tensor of rank 2 containing leaf ids for +// each tree where an instance ended up with. +class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp { + public: + explicit GradientTreesPredictionVerboseOp(OpKernelConstruction* const context) + : GradientTreesPredictionOp(context) {} + + protected: + void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource, + bool return_output_leaf_index) override { + GradientTreesPredictionOp::DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/true); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("GradientTreesPredictionVerbose").Device(DEVICE_CPU), + GradientTreesPredictionVerboseOp); + class GradientTreesPartitionExamplesOp : public OpKernel { public: explicit GradientTreesPartitionExamplesOp(OpKernelConstruction* const context) diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 04e32267cc4a00b3169c3abbcbf549805a0fb462..401bec84a20a0fefcddbfa1039a117e65f853633 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -43,47 +43,60 @@ namespace { const int32 DUMMY_FEATURE_DIMENSION = -1; } // namespace -class BaseBuildSplitOp : public OpKernel { +class SplitBuilderState { public: - explicit BaseBuildSplitOp(OpKernelConstruction* const context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id", - &feature_column_group_id_)); + explicit SplitBuilderState(OpKernelContext* const context) { + const Tensor* l1_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l1_regularization", &l1_regularization_)); + context->input("l1_regularization", &l1_regularization_t)); + const Tensor* l2_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l2_regularization", &l2_regularization_)); - OP_REQUIRES_OK(context, context->GetAttr("tree_complexity_regularization", - &tree_complexity_regularization_)); + context->input("l2_regularization", &l2_regularization_t)); + const Tensor* tree_complexity_regularization_t; + OP_REQUIRES_OK(context, context->input("tree_complexity_regularization", + &tree_complexity_regularization_t)); + const Tensor* min_node_weight_t; OP_REQUIRES_OK(context, - context->GetAttr("min_node_weight", &min_node_weight_)); + context->input("min_node_weight", &min_node_weight_t)); - int strategy; - OP_REQUIRES_OK(context, context->GetAttr("multiclass_strategy", &strategy)); + const Tensor* feature_column_group_id_t; + OP_REQUIRES_OK(context, context->input("feature_column_group_id", + &feature_column_group_id_t)); + + const Tensor* multiclass_strategy_t; + OP_REQUIRES_OK( + context, context->input("multiclass_strategy", &multiclass_strategy_t)); + int strategy = multiclass_strategy_t->scalar()(); OP_REQUIRES( context, boosted_trees::learner::LearnerConfig_MultiClassStrategy_IsValid( strategy), errors::InvalidArgument("Wrong multiclass strategy passed.")); - multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - } - NodeStats ComputeNodeStats(const GradientStats& grad_stats) { - return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, - multiclass_strategy_, grad_stats); - } + multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - void ReadClassId(OpKernelContext* const context, int32* class_id) { const Tensor* class_id_t; OP_REQUIRES_OK(context, context->input("class_id", &class_id_t)); OP_REQUIRES(context, TensorShapeUtils::IsScalar(class_id_t->shape()), errors::InvalidArgument("class_id must be a scalar.")); - *class_id = class_id_t->scalar()(); + class_id_ = class_id_t->scalar()(); + + l1_regularization_ = l1_regularization_t->scalar()(); + l2_regularization_ = l2_regularization_t->scalar()(); + tree_complexity_regularization_ = + tree_complexity_regularization_t->scalar()(); + min_node_weight_ = min_node_weight_t->scalar()(); + feature_column_group_id_ = feature_column_group_id_t->scalar()(); + } + + NodeStats ComputeNodeStats(const GradientStats& grad_stats) { + return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, + multiclass_strategy_, grad_stats); } - void FillLeaf(const int class_id, const NodeStats& best_node_stats, + void FillLeaf(const NodeStats& best_node_stats, boosted_trees::trees::Leaf* leaf) const { - if (class_id == -1) { + if (class_id_ == -1) { // This would be the case either for TREE_PER_CLASS with only 2 classes, // or for other multiclass strategies. for (float f : best_node_stats.weight_contribution) { @@ -93,25 +106,31 @@ class BaseBuildSplitOp : public OpKernel { CHECK(best_node_stats.weight_contribution.size() == 1) << "Weight contribution size = " << best_node_stats.weight_contribution.size(); - leaf->mutable_sparse_vector()->add_index(class_id); + leaf->mutable_sparse_vector()->add_index(class_id_); leaf->mutable_sparse_vector()->add_value( best_node_stats.weight_contribution[0]); } } - protected: + int32 feature_column_group_id() { return feature_column_group_id_; } + float tree_complexity_regularization() { + return tree_complexity_regularization_; + } + + private: LearnerConfig_MultiClassStrategy multiclass_strategy_; - int32 feature_column_group_id_; float l1_regularization_; float l2_regularization_; - float min_node_weight_; float tree_complexity_regularization_; + float min_node_weight_; + int32 class_id_; + int32 feature_column_group_id_; }; -class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildDenseInequalitySplitsOp : public OpKernel { public: explicit BuildDenseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) {} + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -139,9 +158,6 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); - // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; partition_boundaries.push_back(0); @@ -185,6 +201,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[root_idx]; @@ -196,7 +213,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_bucket_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -206,10 +223,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats g(*gradients_t, *hessians_t, bucket_idx); g *= normalizer_ratio; left_gradient_stats += g; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -220,18 +237,18 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* dense_split = split_info.mutable_split_node()->mutable_dense_float_binary_split(); - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); dense_split->set_threshold( bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } @@ -239,13 +256,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), BuildDenseInequalitySplitsOp); -class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildSparseInequalitySplitsOp : public OpKernel { public: explicit BuildSparseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -275,8 +289,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // For each partition (tree node), store starting index for each dimension. PartitionAndDimensionBoundaries partition_boundaries; @@ -354,6 +370,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { const auto& dimension_boundaries = @@ -372,7 +389,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { OP_REQUIRES( context, - bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_, + bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); // Dimension for bias feature is always 0 @@ -388,7 +405,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats root_gradient_stats(*gradients_t, *hessians_t, bias_start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); // Iterate through dimensions. for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { @@ -408,7 +425,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { << bucket_ids_and_dimensions(start_index, 1) << " and for " << bucket_ids_and_dimensions(end_index - 1, 0) << " " << bucket_ids_and_dimensions(end_index - 1, 1); - if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) { + if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) { // 0-dimension case which has a first bucket for catch all feature. CHECK(bucket_ids_and_dimensions(start_index, 1) == 0) << "Dimension of bias feature should be 0"; @@ -447,10 +464,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { present_gradient_stats - left_gradient_stats; { - NodeStats left_stats_default_left = - ComputeNodeStats(root_gradient_stats - right_gradient_stats); + NodeStats left_stats_default_left = state.ComputeNodeStats( + root_gradient_stats - right_gradient_stats); NodeStats right_stats_default_left = - ComputeNodeStats(right_gradient_stats); + state.ComputeNodeStats(right_gradient_stats); if (left_stats_default_left.gain + right_stats_default_left.gain > best_gain) { best_gain = @@ -466,9 +483,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // enough missing examples. if (!fixed_default_direction) { NodeStats left_stats_default_right = - ComputeNodeStats(left_gradient_stats); - NodeStats right_stats_default_right = - ComputeNodeStats(root_gradient_stats - left_gradient_stats); + state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats_default_right = state.ComputeNodeStats( + root_gradient_stats - left_gradient_stats); if (left_stats_default_right.gain + right_stats_default_right.gain > best_gain) { best_gain = left_stats_default_right.gain + @@ -494,7 +511,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { ->mutable_sparse_float_binary_split_default_left() ->mutable_split(); } - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); // Set the feature index for the best feature column. const int64 best_dimension_id = bucket_ids_and_dimensions(best_element_idx, 1); @@ -505,11 +522,11 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(bias_start_index); } } @@ -526,19 +543,14 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // For each partition, store start indices of feature column dimensions. typedef std::vector> PartitionAndDimensionBoundaries; - - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU), BuildSparseInequalitySplitsOp); -class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { +class BuildCategoricalEqualitySplitsOp : public OpKernel { public: explicit BuildCategoricalEqualitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -561,8 +573,10 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; @@ -605,16 +619,17 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[non_empty_partitions[root_idx]]; int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; // First feature ID in each partition should be the bias feature. - OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_, + OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_feature_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -625,8 +640,8 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { left_gradient_stats *= normalizer_ratio; GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -637,21 +652,18 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); - equality_split->set_feature_column(feature_column_group_id_); + equality_split->set_feature_column(state.feature_column_group_id()); equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } - - private: - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py index 56ff00b39062d57c813633c98c765e077dd4c262..1b7f59ea4218355a13f1df7264352bd68503bd19 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py @@ -37,6 +37,7 @@ class BaseSplitHandler(object): gradient_shape, hessian_shape, multiclass_strategy, + loss_uses_sum_reduction=False, name=None): """Constructor for BaseSplitHandler. @@ -51,6 +52,8 @@ class BaseSplitHandler(object): gradient_shape: A TensorShape, containing shape of gradients. hessian_shape: A TensorShape, containing shape of hessians. multiclass_strategy: Strategy describing how to treat multiclass problems. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ self._l1_regularization = l1_regularization @@ -62,6 +65,7 @@ class BaseSplitHandler(object): self._multiclass_strategy = multiclass_strategy self._hessian_shape = hessian_shape self._gradient_shape = gradient_shape + self._loss_uses_sum_reduction = loss_uses_sum_reduction def scheduled_reads(self): """Returns the list of `ScheduledOp`s required for update_stats.""" diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 9f78ab20242800fd8af7ad049d5970fbe26ec0ea..bf686237ff696dadad9713d26bf784d7442b80d0 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -23,6 +23,7 @@ from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -44,6 +45,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -62,6 +64,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(EqualitySplitHandler, self).__init__( @@ -73,6 +77,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): gradient_shape=gradient_shape, hessian_shape=hessian_shape, multiclass_strategy=multiclass_strategy, + loss_uses_sum_reduction=loss_uses_sum_reduction, name=name) self._stats_accumulator = stats_accumulator_ops.StatsAccumulator( init_stamp_token, @@ -173,6 +178,11 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): # pair. num_minibatches, partition_ids, feature_ids, gradients, hessians = ( self._stats_accumulator.flush(stamp_token, next_stamp_token)) + # For sum_reduction, we don't need to divide by number of minibatches. + + num_minibatches = control_flow_ops.cond( + ops.convert_to_tensor(self._loss_uses_sum_reduction), + lambda: math_ops.to_int64(1), lambda: num_minibatches) partition_ids, gains, split_infos = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=num_minibatches, @@ -187,7 +197,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): tree_complexity_regularization=self._tree_complexity_regularization, min_node_weight=self._min_node_weight, bias_feature_id=_BIAS_FEATURE_ID, - multiclass_strategy=self._multiclass_strategy,)) + multiclass_strategy=self._multiclass_strategy)) # There are no warm-up rounds needed in the equality column handler. So we # always return ready. are_splits_ready = constant_op.constant(True) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index 0b65eba2a76273a81f1464ed7639f0c0760e0050..ef253e7cec4e8a96b360ced32b59398c2e2c9680 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -90,7 +90,17 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): empty_hessians, example_weights, is_active=array_ops.constant([True, True])) - with ops.control_dependencies([update_1]): + update_2 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + + with ops.control_dependencies([update_1, update_2]): are_splits_ready, partitions, gains, splits = ( split_handler.make_splits(0, 1, class_id)) are_splits_ready, partitions, gains, splits = (sess.run( @@ -159,6 +169,129 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) + def testGenerateFeatureSplitCandidatesSumReduction(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Feature ID | + # i0 | (0.2, 0.12) | 0 | 1,2 | + # i1 | (-0.5, 0.07) | 0 | | + # i2 | (1.2, 0.2) | 0 | 2 | + # i3 | (4.0, 0.13) | 1 | 1 | + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = [0, 0, 0, 1] + indices = [[0, 0], [0, 1], [2, 0], [3, 0]] + values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = categorical_split_handler.EqualitySplitHandler( + l1_regularization=0.1, + l2_regularization=1, + tree_complexity_regularization=0, + min_node_weight=0, + sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]), + feature_column_group_id=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + init_stamp_token=0, + loss_uses_sum_reduction=True) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + update_2 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1, update_2]): + are_splits_ready, partitions, gains, splits = ( + split_handler.make_splits(0, 1, class_id)) + are_splits_ready, partitions, gains, splits = ( + sess.run([are_splits_ready, partitions, gains, splits])) + self.assertTrue(are_splits_ready) + self.assertAllEqual([0, 1], partitions) + + # Check the split on partition 0. + # -(0.4 + 2.4 - 0.1) / (0.24 + 0.4 + 1) + expected_left_weight = -1.6463414634146338 + + # (0.4 + 2.4 - 0.1) ** 2 / (0.24 + 0.4 + 1) + expected_left_gain = 4.445121951219511 + + # -(-1 + 0.1) / (0.14 + 1) + expected_right_weight = 0.789473684211 + + # (-1 + 0.1) ** 2 / (0.14 + 1) + expected_right_gain = 0.710526315789 + + # (0.4 + -1 + 2.4 - 0.1) ** 2 / (0.24 + 0.14 + 0.4 + 1) + expected_bias_gain = 1.6235955056179772 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.categorical_id_binary_split + + self.assertEqual(0, split_node.feature_column) + + self.assertEqual(2, split_node.feature_id) + + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0], + 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + # Check the split on partition 1. + # (-8 + 0.1) / (0.26 + 1) + expected_left_weight = -6.26984126984 + # (-8 + 0.1) ** 2 / (0.26 + 1) + expected_left_gain = 49.5317460317 + expected_right_weight = 0 + expected_right_gain = 0 + # (-8 + 0.1) ** 2 / (0.26 + 1) + expected_bias_gain = 49.5317460317 + + # Verify candidate for partition 1, there's only one active feature here + # so zero gain is expected. + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.categorical_id_binary_split + self.assertAllClose(0.0, gains[1], 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + self.assertEqual(0, split_node.feature_column) + + self.assertEqual(1, split_node.feature_id) + def testGenerateFeatureSplitCandidatesMulticlass(self): with self.test_session() as sess: # Batch size is 4, 2 gradients per each instance. diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index f06b73c00d0bebb2717a79b7894e2addf914daba..df0bec1fe363e07bbff6b059e86076239bd605e9 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -64,6 +64,8 @@ from __future__ import print_function import re from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops +from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import quantile_ops from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops @@ -72,9 +74,11 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops + _BIAS_FEATURE_ID = -1 # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") @@ -95,6 +99,7 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -113,6 +118,8 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(InequalitySplitHandler, self).__init__( @@ -124,17 +131,21 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): feature_column_group_id=feature_column_group_id, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=multiclass_strategy) + multiclass_strategy=multiclass_strategy, + loss_uses_sum_reduction=loss_uses_sum_reduction) self._stats_accumulator = stats_accumulator_ops.StatsAccumulator( init_stamp_token, gradient_shape, hessian_shape, name="StatsAccumulator/{}".format(self._name)) - self._quantile_accumulator = quantile_ops.QuantileAccumulator( - init_stamp_token, - epsilon=epsilon, - num_quantiles=num_quantiles, - name="QuantileAccumulator/{}".format(self._name)) + # Allocate both stats accumulator and quantile accumulator on the same + # device so that we can build splits with fewer RPCs. + with ops.colocate_with(self._stats_accumulator.resource()): + self._quantile_accumulator = quantile_ops.QuantileAccumulator( + init_stamp_token, + epsilon=epsilon, + num_quantiles=num_quantiles, + name="QuantileAccumulator/{}".format(self._name)) class DenseSplitHandler(InequalitySplitHandler): @@ -153,6 +164,7 @@ class DenseSplitHandler(InequalitySplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -172,6 +184,8 @@ class DenseSplitHandler(InequalitySplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(DenseSplitHandler, self).__init__( @@ -186,7 +200,8 @@ class DenseSplitHandler(InequalitySplitHandler): name=name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=multiclass_strategy) + multiclass_strategy=multiclass_strategy, + loss_uses_sum_reduction=loss_uses_sum_reduction) self._dense_float_column = dense_float_column # Register dense_make_stats_update function as an Op to the graph. g = ops.get_default_graph() @@ -236,45 +251,77 @@ class DenseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - # Get the bucket boundaries - are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) - # After we receive the boundaries from previous iteration we can flush - # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - # Get the aggregated gradients and hessians per - # pair. - # In order to distribute the computation on all the PSs we use the PS that - # had the stats accumulator on. - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. - are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - - partition_ids, gains, split_infos = ( - split_handler_ops.build_dense_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. - _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_dense_split_scalar + else: + handler = make_dense_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight, self._loss_uses_sum_reduction)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _make_dense_split( + quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional, loss_uses_sum_reduction): + """Function that builds splits for a dense feature column.""" + # Get the bucket boundaries + are_splits_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) + # quantile_accumulator_get_buckets returns a list of results per handle that + # we pass to it. In this case we're getting results just for one resource. + are_splits_ready = are_splits_ready[0] + buckets = buckets[0] + + # After we receive the boundaries from previous iteration we can flush + # the quantile accumulator. + with ops.control_dependencies([buckets]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + # For sum_reduction, we don't need to divide by number of minibatches. + num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction, + lambda: math_ops.to_int64(1), + lambda: num_minibatches) + # Put quantile and stats accumulator flushing in the dependency path. + with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_dense_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos class SparseSplitHandler(InequalitySplitHandler): @@ -293,6 +340,7 @@ class SparseSplitHandler(InequalitySplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -312,6 +360,8 @@ class SparseSplitHandler(InequalitySplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(SparseSplitHandler, self).__init__( @@ -326,10 +376,8 @@ class SparseSplitHandler(InequalitySplitHandler): hessian_shape=hessian_shape, multiclass_strategy=multiclass_strategy, init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction, name=name) - # Register sparse_make_stats_update function as an Op to the graph. - g = ops.get_default_graph() - sparse_make_stats_update.add_to_graph(g) self._sparse_float_column = sparse_float_column def scheduled_reads(self): @@ -361,8 +409,8 @@ class SparseSplitHandler(InequalitySplitHandler): are_buckets_ready, buckets = scheduled_reads[0] with ops.name_scope(self._name, "SparseSplitHandler"): (quantile_indices, quantile_values, quantile_shapes, quantile_weights, - example_partition_ids, - feature_ids, gradients, hessians) = sparse_make_stats_update( + example_partition_ids, feature_ids, gradients, + hessians) = sparse_make_stats_update( is_active, are_buckets_ready, self._sparse_float_column.indices, self._sparse_float_column.values, self._sparse_float_column.dense_shape, buckets, @@ -379,42 +427,118 @@ class SparseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - # Get the bucket boundaries - are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) - - # After we receive the boundaries from previous iteration we can flush - # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. - are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - partition_ids, gains, split_infos = ( - split_handler_ops.build_sparse_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. - _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - bias_feature_id=_BIAS_FEATURE_ID, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_sparse_split_scalar + else: + handler = make_sparse_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight, self._loss_uses_sum_reduction)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _make_sparse_split( + quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional, loss_uses_sum_reduction): + """Function that builds splits for a sparse feature column.""" + # Get the bucket boundaries + are_splits_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) + # quantile_accumulator_get_buckets returns a list of results per handle that + # we pass to it. In this case we're getting results just for one resource. + are_splits_ready = are_splits_ready[0] + buckets = buckets[0] + + # After we receive the boundaries from previous iteration we can flush + # the quantile accumulator. + with ops.control_dependencies([buckets]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction, + lambda: math_ops.to_int64(1), + lambda: num_minibatches) + # Put quantile and stats accumulator flushing in the dependency path. + with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + bias_feature_id=_BIAS_FEATURE_ID, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _specialize_make_split(func, is_multi_dimentional): + """Builds a specialized version of the function.""" + + @function.Defun( + dtypes.resource, + dtypes.resource, + dtypes.int64, + dtypes.int64, + dtypes.int32, + dtypes.int32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.bool, + noinline=True) + def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, loss_uses_sum_reduction): + """Function that builds splits for a sparse feature column.""" + return func(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, class_id, + feature_column_id, l1_regularization, l2_regularization, + tree_complexity_regularization, min_node_weight, + is_multi_dimentional, loss_uses_sum_reduction) + + return f + +make_dense_split_scalar = _specialize_make_split(_make_dense_split, + is_multi_dimentional=False) +make_dense_split_tensor = _specialize_make_split(_make_dense_split, + is_multi_dimentional=True) + +make_sparse_split_scalar = _specialize_make_split(_make_sparse_split, + is_multi_dimentional=False) +make_sparse_split_tensor = _specialize_make_split(_make_sparse_split, + is_multi_dimentional=True) @function.Defun( @@ -540,8 +664,9 @@ def sparse_make_stats_update( empty_float = constant_op.constant([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( - [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant( - [0, 1], dtype=dtypes.int64), empty_float) + [], dtype=dtypes.int64, shape=[0, 2]), empty_float, + constant_op.constant([0, 1], dtype=dtypes.int64), + empty_float) handler_active = (sparse_column_indices, sparse_column_values, sparse_column_shape, weights) quantile_indices, quantile_values, quantile_shape, quantile_weights = ( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 54d03018d9e266beabbbabd78ebbb80cfe689c04..d59732cf92eb85e88732ac5a17dccf475ae5342f 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 @@ -65,9 +67,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.scalar() split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -92,7 +94,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -105,7 +109,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -178,6 +182,144 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Dense Quantile | + # i0 | (0.2, 0.12) | 0 | 1 | + # i1 | (-0.5, 0.07) | 0 | 1 | + # i2 | (1.2, 0.2) | 0 | 0 | + # i3 | (4.0, 0.13) | 1 | 1 | + dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + class_id = -1 + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + split_handler = ordinal_split_handler.DenseSplitHandler( + l1_regularization=0.2, + l2_regularization=2., + tree_complexity_regularization=0., + min_node_weight=0., + epsilon=0.001, + num_quantiles=10, + feature_column_group_id=0, + dense_float_column=dense_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + loss_uses_sum_reduction=True) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + update_3 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2, update_3]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([0, 1], partitions) + + # Check the split on partition 0. + # -(2.4 - 0.2) / (0.4 + 2) + expected_left_weight = -0.91666 + + # expected_left_weight * -(2.4 - 0.2) + expected_left_gain = 2.016666666666666 + + # -(-1 + 0.4 + 0.2) / (0.38 + 2) + expected_right_weight = 0.1680672 + + # expected_right_weight * -(-1 + 0.4 + 0.2) + expected_right_gain = 0.0672268907563025 + + # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) + expected_bias_gain = 0.9208633093525178 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.dense_float_binary_split + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0], + 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + self.assertEqual(0, split_node.feature_column) + + self.assertAllClose(0.3, split_node.threshold, 0.00001) + + # Check the split on partition 1. + # (-8 + 0.2) / (0.26 + 2) + expected_left_weight = -3.4513274336283186 + expected_right_weight = 0 + + # Verify candidate for partition 1, there's only one active bucket here + # so zero gain is expected. + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.dense_float_binary_split + self.assertAllClose(0.0, gains[1], 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + self.assertEqual(0, split_node.feature_column) + + self.assertAllClose(0.52, split_node.threshold, 0.00001) + def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): with self.test_session() as sess: dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) @@ -199,10 +341,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.TensorShape([2, 2]) split_handler = ordinal_split_handler.DenseSplitHandler( - l1_regularization=0, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0., + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=3, feature_column_group_id=0, @@ -227,7 +369,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -240,7 +384,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -285,10 +429,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.TensorShape([2]) split_handler = ordinal_split_handler.DenseSplitHandler( - l1_regularization=0, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0., + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=3, feature_column_group_id=0, @@ -313,7 +457,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -326,7 +471,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -369,9 +514,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -396,7 +541,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -409,7 +555,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -443,9 +589,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, + l2_regularization=1., tree_complexity_regularization=0.5, - min_node_weight=0, + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -470,7 +616,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -483,7 +630,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -576,7 +723,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, + l2_regularization=1., tree_complexity_regularization=0.5, min_node_weight=1.5, epsilon=0.001, @@ -603,7 +750,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -616,7 +764,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -685,10 +833,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -713,8 +861,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] - + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -727,7 +875,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -788,6 +936,139 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Sparse Quantile | + # i0 | (0.2, 0.12) | 0 | 1 | + # i1 | (-0.5, 0.07) | 0 | N/A | + # i2 | (1.2, 0.2) | 0 | 0 | + # i3 | (4.0, 0.13) | 1 | 1 | + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + example_partitions = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + indices = array_ops.constant([[0, 0], [2, 0], [3, 0]], dtype=dtypes.int64) + values = array_ops.constant([0.52, 0.3, 0.52]) + sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0.0, + l2_regularization=4.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + loss_uses_sum_reduction=True) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + update_3 = split_handler.update_stats_sync( + 1, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2, update_3]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([0, 1], partitions) + # Check the split on partition 0. + # -(0.4 + 2.4) / (0.24 + 0.4 + 4) + expected_left_weight = -0.603448275862069 + # (0.4 + 2.4) ** 2 / (0.24 + 0.4 + 4) + expected_left_gain = 1.689655172413793 + # 1 / (0.14 + 4) + expected_right_weight = 0.24154589371980678 + # 1 ** 2 / (0.14 + 4) + expected_right_gain = 0.24154589371980678 + # (0.4 + 2.4 - 1) ** 2 / (0.24 + 0.4 + 0.14 + 4) + expected_bias_gain = 0.6778242677824265 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_right + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + + self.assertAllClose(0.52, split_node.split.threshold) + + # Check the split on partition 1. + expected_left_weight = -1.8779342723004695 + expected_right_weight = 0 + + # Verify candidate for partition 1, there's only one active bucket here + # so zero gain is expected. + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_left + + self.assertAllClose(0.0, gains[1]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + + self.assertAllClose(0.52, split_node.split.threshold) + def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): with self.test_session() as sess: # Batch is 4, 2 classes @@ -811,10 +1092,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -839,7 +1120,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -853,7 +1135,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -905,10 +1187,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -933,7 +1215,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -947,7 +1230,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -996,10 +1279,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1024,7 +1307,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1038,7 +1322,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1065,10 +1349,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1096,7 +1380,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1110,7 +1395,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1138,10 +1423,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1166,7 +1451,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1180,7 +1466,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc index 43b00d4c6dc2e0066810012292874314215c41be..c9223afeab233497bce9f680bd44bd10ccfc6491 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc @@ -26,7 +26,8 @@ void MultipleAdditiveTrees::Predict( const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, tensorflow::thread::ThreadPool* const worker_threads, - tensorflow::TTypes::Matrix output_predictions) { + tensorflow::TTypes::Matrix output_predictions, + Tensor* const output_leaf_index) { // Zero out predictions as the model is additive. output_predictions.setZero(); @@ -38,8 +39,13 @@ void MultipleAdditiveTrees::Predict( // Lambda for doing a block of work. auto update_predictions = [&config, &features, &trees_to_include, - &output_predictions](int64 start, int64 end) { + &output_predictions, + &output_leaf_index](int64 start, int64 end) { auto examples_iterable = features.examples_iterable(start, end); + Tensor dummy_tensor(DT_INT32, TensorShape({1, 1})); + tensorflow::TTypes::Matrix output_leaf_index_mat = + output_leaf_index != nullptr ? output_leaf_index->matrix() + : dummy_tensor.matrix(); for (const auto& example : examples_iterable) { for (const int32 tree_idx : trees_to_include) { const boosted_trees::trees::DecisionTreeConfig& tree = @@ -47,6 +53,10 @@ void MultipleAdditiveTrees::Predict( const float tree_weight = config.tree_weights(tree_idx); const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); + // Checks if output leaf tree index is required. + if (output_leaf_index != nullptr) { + output_leaf_index_mat(example.example_idx, tree_idx) = leaf_idx; + } const auto& leaf_node = tree.nodes(leaf_idx); QCHECK(leaf_node.has_leaf()) << "Invalid leaf node: " << leaf_node.DebugString(); diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h index cc3dc226cdbc88fc7010ada1e7f0e6c0a3913c5f..940531c4ba4bcac19fa980deb091e55b48e0693b 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h @@ -33,12 +33,17 @@ class MultipleAdditiveTrees { public: // Predict runs tree ensemble on the given batch and updates // output predictions accordingly, for the given list of trees. + // output_leaf_indices is a pointer to a 2 dimensional tensor. If it is not + // nullptr, this method fills output_leaf_indices with a per-tree leaf id + // where each of the instances from 'features' ended up in. Its shape is num + // examples X num of trees. static void Predict( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, tensorflow::thread::ThreadPool* const worker_threads, - tensorflow::TTypes::Matrix output_predictions); + tensorflow::TTypes::Matrix output_predictions, + Tensor* const output_leaf_index); }; } // namespace models diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc index 4ca18bedb1054ef64c6d4b25bbad04842bab1a6a..462a9ac86fe51d07cfb958d9be49bef84811a52e 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc @@ -62,7 +62,8 @@ TEST_F(MultipleAdditiveTreesTest, Empty) { tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", kNumThreadsSingleThreaded); MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, + /*output_leaf_index=*/nullptr); EXPECT_EQ(0, output_matrix(0, 0)); EXPECT_EQ(0, output_matrix(1, 0)); } @@ -99,17 +100,38 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + /*output_leaf_index=*/nullptr); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } + // Normal case with leaf node. + { + // Initialize output leaf index tensor, since leaf index is positive in this + // case, initialize with the value of -1. Since there are 2 examples and + // there are 2 trees, initialize leaf output index by 2 * 2. + Tensor output_leaf_index_tensor(DT_INT32, TensorShape({2, 2})); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix, + &output_leaf_index_tensor); + EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). + EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). + EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix()( + 0, 0)); // 1st leaf for the first example + EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix()( + 1, 0)); // 1st leaf for the second example + EXPECT_FLOAT_EQ(2, output_leaf_index_tensor.matrix()( + 0, 1)); // 2nd leaf for the first example + EXPECT_FLOAT_EQ(1, output_leaf_index_tensor.matrix()( + 1, 1)); // 2nd leaf for the second example + } // Weighted case { DecisionTreeEnsembleConfig weighted = tree_ensemble_config; weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, - output_matrix); + output_matrix, nullptr); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0)); // -0.4 (bias) + 0.9 (leaf 1). @@ -118,21 +140,21 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { // Drop first tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1). } // Drop second tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias). EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias). } // Drop all trees. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0)); } @@ -172,7 +194,8 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1) @@ -184,7 +207,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, - output_matrix); + output_matrix, nullptr); // bias EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0)); // bias + leaf 2 @@ -197,7 +220,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Dropout first tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2) @@ -206,7 +229,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Dropout second tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias) EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias) @@ -215,7 +238,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Drop both trees. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1)); EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0)); @@ -258,7 +281,8 @@ TEST_F(MultipleAdditiveTreesTest, DenseLeaves) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + nullptr); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2) EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2) diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index 8ad97fedc923ac50bcaad86e0ba2c2e46df6821b..c120dd8a6c156ec9eb7ba0b6c552f5138bd21a16 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -295,7 +295,7 @@ WeightedQuantilesStream::GetQuantileSpecs( if (eps <= std::numeric_limits::epsilon()) { // Exact quantile computation at the expense of RAM. max_level = 1; - block_size = std::max(max_elements, 2LL); + block_size = std::max(max_elements, int64{2}); } else { // The bottom-most level will become full at most // (max_elements / block_size) times, the level above will become full @@ -315,7 +315,7 @@ WeightedQuantilesStream::GetQuantileSpecs( block_size = static_cast(ceil(max_level / eps)) + 1; } } - return std::make_tuple(max_level, std::max(block_size, 2LL)); + return std::make_tuple(max_level, std::max(block_size, int64{2})); } } // namespace quantiles diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index 7576856dc3a6d0b6681ee9745c875cf46d1e2960..a7e7bfc13cadcea4d29d33e0dbd955bdad6ffcb9 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -195,7 +195,7 @@ class WeightedQuantilesSummary { // designed to be cache-friendly. void Compress(int64 size_hint, double min_eps = 0) { // No-op if we're already within the size requirement. - size_hint = std::max(size_hint, 2LL); + size_hint = std::max(size_hint, int64{2}); if (entries_.size() <= size_hint) { return; } @@ -267,7 +267,7 @@ class WeightedQuantilesSummary { if (entries_.empty()) { return output; } - num_quantiles = std::max(num_quantiles, 2LL); + num_quantiles = std::max(num_quantiles, int64{2}); output.reserve(num_quantiles + 1); // Make successive rank queries to get boundaries. diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc index d66f645f62aba84261337eb37d6e3204930f8f15..6491d58794332e9417951753532e018aafb652b1 100644 --- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc @@ -40,6 +40,24 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) { return Status::OK(); } +static Status ApplyGradientTreesPredictionVerboseShapeFn(InferenceContext* c) { + string learner_config_str; + c->GetAttr("learner_config", &learner_config_str).IgnoreError(); + LearnerConfig learner_config; + ParseProtoUnlimited(&learner_config, learner_config_str); + + bool reduce_dim; + c->GetAttr("reduce_dim", &reduce_dim).IgnoreError(); + // Sets the shape of the output as a matrix. + c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, + reduce_dim ? learner_config.num_classes() - 1 + : learner_config.num_classes())}); + c->set_output(1, {c->UnknownShape()}); + c->set_output(2, {c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)}); + return Status::OK(); +} + REGISTER_OP("GradientTreesPrediction") .Attr("learner_config: string") .Attr("num_dense_float_features: int >= 0") @@ -90,6 +108,58 @@ drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices and original weights of those trees during prediction. )doc"); +REGISTER_OP("GradientTreesPredictionVerbose") + .Attr("learner_config: string") + .Attr("num_dense_float_features: int >= 0") + .Attr("num_sparse_float_features: int >= 0") + .Attr("num_sparse_int_features: int >= 0") + .Attr("use_locking: bool = false") + .Attr("apply_dropout: bool") + .Attr("apply_averaging: bool") + .Attr("center_bias: bool") + .Attr("reduce_dim: bool") + .Input("tree_ensemble_handle: resource") + .Input("seed: int64") + .Input("dense_float_features: num_dense_float_features * float") + .Input("sparse_float_feature_indices: num_sparse_float_features * int64") + .Input("sparse_float_feature_values: num_sparse_float_features * float") + .Input("sparse_float_feature_shapes: num_sparse_float_features * int64") + .Input("sparse_int_feature_indices: num_sparse_int_features * int64") + .Input("sparse_int_feature_values: num_sparse_int_features * int64") + .Input("sparse_int_feature_shapes: num_sparse_int_features * int64") + .Output("predictions: float") + .Output("drop_out_tree_indices_weights: float") + .Output("leaf_index: int32") + .SetShapeFn(ApplyGradientTreesPredictionVerboseShapeFn) + .Doc(R"doc( +Runs multiple additive regression forests predictors on input instances +and computes the final prediction for each class, and outputs a matrix of +leaf ids per each tree in an ensemble. + +learner_config: Config for the learner of type LearnerConfig proto. Prediction +ops for now uses only LearningRateDropoutDrivenConfig config from the learner. +num_dense_float_features: Number of dense float features. +num_sparse_float_features: Number of sparse float features. +num_sparse_int_features: Number of sparse int features. +use_locking: Whether to use locking. +seed: random seed to be used for dropout. +reduce_dim: whether to reduce the dimension (legacy impl) or not. +apply_dropout: whether to apply dropout during prediction. +apply_averaging: whether averaging of tree ensembles should take place. If set +to true, will be based on AveragingConfig from learner_config. +tree_ensemble_handle: The handle to the tree ensemble. +dense_float_features: Rank 2 Tensors containing dense float feature values. +sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices. +sparse_float_feature_values: Rank 1 Tensors containing sparse float values. +sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes. +sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices. +sparse_int_feature_values: Rank 1 Tensors containing sparse int values. +sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes. +predictions: Rank 2 Tensor containing predictions per example per class. +drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices +leaf_index: tensor of rank 2 containing leaf ids for each tree where an instance ended up. +)doc"); + REGISTER_OP("GradientTreesPartitionExamples") .Attr("num_dense_float_features: int >= 0") .Attr("num_sparse_float_features: int >= 0") diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 5d0ebbf73ce1272b51a475f67984db3a181b7130..ca5c7f3d8c78a543c63fbfa9f7eb7c3d348f11b8 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -23,12 +23,6 @@ using shape_inference::InferenceContext; using shape_inference::ShapeHandle; REGISTER_OP("BuildDenseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -36,6 +30,12 @@ REGISTER_OP("BuildDenseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -73,6 +73,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -81,13 +92,6 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildSparseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -95,6 +99,13 @@ REGISTER_OP("BuildSparseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -133,6 +144,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -141,19 +163,19 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildCategoricalEqualitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("feature_ids: int64") .Input("gradients: float32") .Input("hessians: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -188,6 +210,17 @@ partition_ids: A rank 1 tensor of partition IDs. feature_ids: A rank 2 tensor of feature IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -196,4 +229,3 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); } // namespace tensorflow - // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 7a5f329b7ab3216972180ccbb4c85f2537175422..843420968ac6a6716fdf6b4967146e131139f67c 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -20,6 +20,8 @@ from __future__ import print_function import abc import collections +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -60,6 +62,7 @@ def _move_tensors(tensors, device): """Moves a list of tensors to a device by concatenating/splitting them.""" # Reset the device setting to avoid weird interactions with device merging # logic. + zero = constant_op.constant(0, dtype=dtypes.int32) with ops.device(None): if all(tensor.shape == tensor_shape.scalar() for tensor in tensors): with ops.device(tensors[0].device): @@ -68,12 +71,11 @@ def _move_tensors(tensors, device): return array_ops.unstack(values) else: with ops.device(tensors[0].device): - sizes = array_ops.stack( - [array_ops.shape(tensor)[0] for tensor in tensors]) - values = array_ops.concat(tensors, axis=0) + sizes = array_ops.stack(array_ops.shape_n(tensors))[:, 0] + values = array_ops.concat(tensors, axis=zero) with ops.device(device): sizes = array_ops.unstack(sizes) - return list(array_ops.split(values, sizes, axis=0)) + return list(array_ops.split(values, sizes, axis=zero)) def _scheduled_stamp_resource_op_runner(batch, stamp): diff --git a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py index 58f0d36b0f78eeed6abcec1c4fa696f4ccffa615..7f6e55ae5888fc4ef50e34690d61c3ed303e971a 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py @@ -21,4 +21,5 @@ from __future__ import print_function from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_partition_examples from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction +from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction_verbose # pylint: enable=unused-import diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 50cc00afdcc77fedc9bf8c94a9a6fcf2a28ebde9..19b6b3296db394b07f57a25dbde187eb9195af38 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -201,3 +201,6 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result + + def resource(self): + return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index e53d86ec612f299c800753d67ceee79acb5db497..1ee7f2395ea2ad71a7d380a1cc8f9a77bd4782b3 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -46,6 +46,7 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import device_setter @@ -58,8 +59,16 @@ NUM_LAYERS_ATTEMPTED = "num_layers" NUM_TREES_ATTEMPTED = "num_trees" NUM_USED_HANDLERS = "num_used_handlers" USED_HANDLERS_MASK = "used_handlers_mask" +LEAF_INDEX = "leaf_index" _FEATURE_NAME_TEMPLATE = "%s_%d" +# Keys in Training state. +GBDTTrainingState = collections.namedtuple("GBDTTrainingState", [ + "num_layer_examples", "num_layer_steps", "num_layers", "active_tree", + "active_layer", "continue_centering", "bias_stats_accumulator", + "steps_accumulator", "handlers" +]) + def _get_column_by_index(tensor, indices): """Returns columns from a 2-D tensor by index.""" @@ -71,18 +80,24 @@ def _get_column_by_index(tensor, indices): return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1]) -def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats, - used_handlers): +def _make_predictions_dict(stamp, + logits, + partition_ids, + ensemble_stats, + used_handlers, + leaf_index=None): """Returns predictions for the given logits and n_classes. Args: stamp: The ensemble stamp. - logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. - that contains predictions when no dropout was applied. + logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. that + contains predictions when no dropout was applied. partition_ids: A rank 1 `Tensor` with shape [batch_size]. ensemble_stats: A TreeEnsembleStatsOp result tuple. used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a - boolean mask.. + boolean mask. + leaf_index: A rank 2 `Tensor` with shape [batch_size, number of trees]. that + contains leaf id for each example prediction. Returns: A dict of predictions. @@ -95,6 +110,8 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats, result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask + if leaf_index is not None: + result[LEAF_INDEX] = leaf_index return result @@ -180,8 +197,7 @@ def extract_features(features, feature_columns, use_core_columns): elif isinstance(fc, feature_column_lib._EmbeddingColumn): # pylint: enable=protected-access transformed_features[fc.name] = fc_core.input_layer( - features, [fc], - weight_collections=[scope]) + features, [fc], weight_collections=[scope]) else: result = feature_column_ops.transform_features(features, [fc]) if len(result) > 1: @@ -268,8 +284,10 @@ class GradientBoostedDecisionTreeModel(object): learner_config, features, logits_dimension, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS, feature_columns=None, - use_core_columns=False): + use_core_columns=False, + output_leaf_index=False): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -277,13 +295,18 @@ class GradientBoostedDecisionTreeModel(object): num_ps_replicas: Number of parameter server replicas, can be 0. ensemble_handle: A handle to the ensemble variable. center_bias: Whether to center the bias before growing trees. - examples_per_layer: Number of examples to accumulate before growing - a tree layer. It can also be a function that computes the number of - examples based on the depth of the layer that's being built. + examples_per_layer: Number of examples to accumulate before growing a tree + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. learner_config: A learner config. features: `dict` of `Tensor` objects. logits_dimension: An int, the dimension of logits. + loss_reduction: Either `SUM_OVER_NONZERO_WEIGHTS` (mean) or `SUM`. feature_columns: A list of feature columns. + use_core_columns: A boolean specifying whether core feature columns are + used. + output_leaf_index: A boolean variable indicating whether to output leaf + index into predictions dictionary. Raises: ValueError: if inputs are not valid. @@ -304,6 +327,13 @@ class GradientBoostedDecisionTreeModel(object): self._center_bias = center_bias self._examples_per_layer = examples_per_layer + # Check loss reduction value. + if (loss_reduction != losses.Reduction.SUM and + loss_reduction != losses.Reduction.SUM_OVER_NONZERO_WEIGHTS): + raise ValueError( + "Invalid loss reduction is provided: %s." % loss_reduction) + self._loss_reduction = loss_reduction + # Fill in the defaults. if (learner_config.multi_class_strategy == learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED): @@ -314,6 +344,19 @@ class GradientBoostedDecisionTreeModel(object): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + if logits_dimension == 1 or learner_config.multi_class_strategy == ( + learner_pb2.LearnerConfig.TREE_PER_CLASS): + self._gradient_shape = tensor_shape.scalar() + self._hessian_shape = tensor_shape.scalar() + else: + self._gradient_shape = tensor_shape.TensorShape([logits_dimension]) + if (learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.FULL_HESSIAN): + self._hessian_shape = tensor_shape.TensorShape( + ([logits_dimension, logits_dimension])) + else: + # Diagonal hessian strategy. + self._hessian_shape = tensor_shape.TensorShape(([logits_dimension])) if (learner_config.growing_mode == learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED): learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER @@ -334,10 +377,12 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") @@ -346,6 +391,7 @@ class GradientBoostedDecisionTreeModel(object): sparse_int_values, sparse_int_shapes) = extract_features( features, self._feature_columns, use_core_columns) logging.info("Active Feature Columns: " + str(fc_names)) + logging.info("Learner config: " + str(learner_config)) self._fc_names = fc_names self._dense_floats = dense_floats self._sparse_float_indices = sparse_float_indices @@ -354,9 +400,11 @@ class GradientBoostedDecisionTreeModel(object): self._sparse_int_indices = sparse_int_indices self._sparse_int_values = sparse_int_values self._sparse_int_shapes = sparse_int_shapes - self._reduce_dim = (self._learner_config.multi_class_strategy == - learner_pb2.LearnerConfig.TREE_PER_CLASS and - learner_config.num_classes == 2) + self._reduce_dim = ( + self._learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.TREE_PER_CLASS and + learner_config.num_classes == 2) + self._output_leaf_index = output_leaf_index def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): """Runs prediction and returns a dictionary of the prediction results. @@ -374,8 +422,8 @@ class GradientBoostedDecisionTreeModel(object): ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, ensemble_stamp) num_handlers = ( - len(self._dense_floats) + len(self._sparse_float_shapes) + - len(self._sparse_int_shapes)) + len(self._dense_floats) + len(self._sparse_float_shapes) + len( + self._sparse_int_shapes)) # Used during feature selection. used_handlers = model_ops.tree_ensemble_used_handlers( ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers) @@ -386,22 +434,44 @@ class GradientBoostedDecisionTreeModel(object): # Make sure ensemble stats run. This will check that the ensemble has # the right stamp. with ops.control_dependencies(ensemble_stats): - predictions, _ = prediction_ops.gradient_trees_prediction( - ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=mode != learn.ModeKeys.TRAIN, - use_locking=True, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim) + leaf_index = None + # Only used in infer (predict), not used in train and eval. + if self._output_leaf_index and mode == learn.ModeKeys.INFER: + predictions, _, leaf_index = ( + prediction_ops).gradient_trees_prediction_verbose( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) + else: + leaf_index = None + predictions, _ = prediction_ops.gradient_trees_prediction( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) partition_ids = prediction_ops.gradient_trees_partition_examples( ensemble_handle, self._dense_floats, @@ -414,7 +484,7 @@ class GradientBoostedDecisionTreeModel(object): use_locking=True) return _make_predictions_dict(ensemble_stamp, predictions, partition_ids, - ensemble_stats, used_handlers) + ensemble_stats, used_handlers, leaf_index) def predict(self, mode): """Returns predictions given the features and mode. @@ -432,8 +502,9 @@ class GradientBoostedDecisionTreeModel(object): # Use the current ensemble to predict on the current batch of input. # For faster prediction we check if the inputs are on the same device # as the model. If not, we create a copy of the model on the worker. - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) if not input_deps: raise ValueError("No input tensors for prediction.") @@ -457,8 +528,8 @@ class GradientBoostedDecisionTreeModel(object): # Determine whether the local ensemble is stale and update it if needed. def _refresh_local_ensemble_fn(): - # Serialize the model from parameter server after reading all inputs. - with ops.control_dependencies(input_deps): + # Serialize the model from parameter server after reading the inputs. + with ops.control_dependencies([input_deps[0]]): (ensemble_stamp, serialized_model) = ( model_ops.tree_ensemble_serialize(self._ensemble_handle)) @@ -484,24 +555,38 @@ class GradientBoostedDecisionTreeModel(object): return self._predict_and_return_dict(self._ensemble_handle, ensemble_stamp, mode) - def train(self, loss, predictions_dict, labels): - """Grows a new tree and adds it to the ensemble. + def _get_class_id(self, predictions_dict): + # Handle different multiclass strategies. + if (self._learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.TREE_PER_CLASS and + self._logits_dimension != 1): + # Choose the class for which the tree is built (one vs rest). + return math_ops.to_int32( + predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) + return constant_op.constant(-1, dtype=dtypes.int32) + + def update_stats(self, loss, predictions_dict): + """Update the accumulators with stats from this batch. Args: loss: A scalar tensor representing average loss of examples. predictions_dict: Dictionary of Rank 2 `Tensor` representing information about predictions per example. - labels: Rank 2 `Tensor` representing labels per example. Returns: - An op that adds a new tree to the ensemble. + Three values: + - An op that adds a new tree to the ensemble, and + - An op that increments the stamp but removes all the trees and resets + the handlers. This can be used to reset the state of the ensemble. + - A dict containing the training state. Raises: ValueError: if inputs are not valid. """ # Get the worker device from input dependencies. - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) worker_device = input_deps[0].device # Get tensors relevant for training and form the loss. @@ -517,13 +602,10 @@ class GradientBoostedDecisionTreeModel(object): aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - class_id = -1 + class_id = self._get_class_id(predictions_dict) # Handle different multiclass strategies. if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS: # We build one vs rest trees. - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() - if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. hessians = gradients_impl.gradients( @@ -540,11 +622,6 @@ class GradientBoostedDecisionTreeModel(object): hessian_list = self._diagonal_hessian(gradients, predictions) # Assemble hessian list into a tensor. hessians = array_ops.stack(hessian_list, axis=1) - - # Choose the class for which the tree is built (one vs rest). - class_id = math_ops.to_int32( - predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) - # Use class id tensor to get the column with that index from gradients # and hessians. squeezed_gradients = array_ops.squeeze( @@ -553,15 +630,10 @@ class GradientBoostedDecisionTreeModel(object): _get_column_by_index(hessians, class_id)) else: # Other multiclass strategies. - gradient_shape = tensor_shape.TensorShape([self._logits_dimension]) - if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: - hessian_shape = tensor_shape.TensorShape( - ([self._logits_dimension, self._logits_dimension])) hessian_list = self._full_hessian(gradients, predictions) else: # Diagonal hessian strategy. - hessian_shape = tensor_shape.TensorShape(([self._logits_dimension])) hessian_list = self._diagonal_hessian(gradients, predictions) squeezed_gradients = gradients @@ -569,34 +641,47 @@ class GradientBoostedDecisionTreeModel(object): squeezed_hessians = hessians # Get the weights for each example for quantiles calculation, - weights = self._get_weights(hessian_shape, squeezed_hessians) + weights = self._get_weights(self._hessian_shape, squeezed_hessians) - regularization_config = self._learner_config.regularization - min_node_weight = self._learner_config.constraints.min_node_weight # Create all handlers ensuring resources are evenly allocated across PS. fc_name_idx = 0 handlers = [] init_stamp_token = constant_op.constant(0, dtype=dtypes.int64) + l1_regularization = constant_op.constant( + self._learner_config.regularization.l1, dtypes.float32) + l2_regularization = constant_op.constant( + self._learner_config.regularization.l2, dtypes.float32) + tree_complexity_regularization = constant_op.constant( + self._learner_config.regularization.tree_complexity, dtypes.float32) + min_node_weight = constant_op.constant( + self._learner_config.constraints.min_node_weight, dtypes.float32) + loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM + loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) + epsilon = 0.01 + num_quantiles = 100 + strategy_tensor = constant_op.constant(strategy) with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns for dense_float_column_idx in range(len(self._dense_floats)): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.DenseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - feature_column_group_id=dense_float_column_idx, - epsilon=0.01, - num_quantiles=100, + feature_column_group_id=constant_op.constant( + dense_float_column_idx), + epsilon=epsilon, + num_quantiles=num_quantiles, dense_float_column=self._dense_floats[dense_float_column_idx], name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, - multiclass_strategy=strategy, - init_stamp_token=init_stamp_token)) + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, + multiclass_strategy=strategy_tensor, + init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction, + )) fc_name_idx += 1 # Create handlers for sparse float columns. @@ -604,23 +689,24 @@ class GradientBoostedDecisionTreeModel(object): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.SparseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - feature_column_group_id=sparse_float_column_idx, - epsilon=0.01, - num_quantiles=100, + feature_column_group_id=constant_op.constant( + sparse_float_column_idx), + epsilon=epsilon, + num_quantiles=num_quantiles, sparse_float_column=sparse_tensor.SparseTensor( self._sparse_float_indices[sparse_float_column_idx], self._sparse_float_values[sparse_float_column_idx], self._sparse_float_shapes[sparse_float_column_idx]), name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, - multiclass_strategy=strategy, - init_stamp_token=init_stamp_token)) + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, + multiclass_strategy=strategy_tensor, + init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction)) fc_name_idx += 1 # Create handlers for sparse int columns. @@ -628,37 +714,24 @@ class GradientBoostedDecisionTreeModel(object): fc_name = self._fc_names[fc_name_idx] handlers.append( categorical_split_handler.EqualitySplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - feature_column_group_id=sparse_int_column_idx, + feature_column_group_id=constant_op.constant( + sparse_int_column_idx), sparse_int_column=sparse_tensor.SparseTensor( self._sparse_int_indices[sparse_int_column_idx], self._sparse_int_values[sparse_int_column_idx], self._sparse_int_shapes[sparse_int_column_idx]), name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, - multiclass_strategy=strategy, - init_stamp_token=init_stamp_token)) + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, + multiclass_strategy=strategy_tensor, + init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction)) fc_name_idx += 1 - # Create steps accumulator. - steps_accumulator = stats_accumulator_ops.StatsAccumulator( - stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar(), - name="StepsAccumulator") - - # Create bias stats accumulator. - bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator( - stamp_token=0, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, - name="BiasAccumulator") - # Create ensemble stats variables. num_layer_examples = variables.Variable( initial_value=array_ops.zeros([], dtypes.int64), @@ -680,7 +753,23 @@ class GradientBoostedDecisionTreeModel(object): initial_value=array_ops.zeros([], dtypes.int64), name="active_layer", trainable=False) - + # Variable that becomes false once bias centering is done. + continue_centering = variables.Variable( + initial_value=self._center_bias, + name="continue_centering", + trainable=False) + # Create bias stats accumulator. + bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, + name="BiasAccumulator") + # Create steps accumulator. + steps_accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=tensor_shape.scalar(), + hessian_shape=tensor_shape.scalar(), + name="StepsAccumulator") # Create ensemble stats summaries. summary.scalar("layer_stats/num_examples", num_layer_examples) summary.scalar("layer_stats/num_steps", num_layer_steps) @@ -689,16 +778,13 @@ class GradientBoostedDecisionTreeModel(object): # Update bias stats. stats_update_ops = [] - continue_centering = variables.Variable( - initial_value=self._center_bias, - name="continue_centering", - trainable=False) + stats_update_ops.append( - control_flow_ops.cond(continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), - control_flow_ops.no_op)) + control_flow_ops.cond( + continue_centering, + self._make_update_bias_stats_fn( + ensemble_stamp, predictions, gradients, + bias_stats_accumulator), control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -720,8 +806,8 @@ class GradientBoostedDecisionTreeModel(object): shape=[len(handlers)], seed=[seed + 1, 1]) active_handlers = array_ops.stack( [active_handlers_current_layer, active_handlers_next_layer], axis=1) - active_handlers = (active_handlers < - self._learner_config.feature_fraction_per_level) + active_handlers = ( + active_handlers < self._learner_config.feature_fraction_per_level) elif subsampling_type == "feature_fraction_per_tree": seed = predictions_dict[NUM_TREES_ATTEMPTED] active_handlers_current_layer = stateless.stateless_random_uniform( @@ -729,9 +815,12 @@ class GradientBoostedDecisionTreeModel(object): active_handlers_current_layer = ( active_handlers_current_layer < self._learner_config.feature_fraction_per_tree) - active_handlers = array_ops.stack([ - active_handlers_current_layer, - array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1) + active_handlers = array_ops.stack( + [ + active_handlers_current_layer, + array_ops.ones([len(handlers)], dtype=dtypes.bool) + ], + axis=1) else: active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool) @@ -752,14 +841,15 @@ class GradientBoostedDecisionTreeModel(object): lambda: active_handlers)) # Prepare empty gradients and hessians when handlers are not ready. - empty_hess_shape = [1] + hessian_shape.as_list() - empty_grad_shape = [1] + gradient_shape.as_list() + empty_hess_shape = [1] + self._hessian_shape.as_list() + empty_grad_shape = [1] + self._gradient_shape.as_list() empty_gradients = constant_op.constant( [], dtype=dtypes.float32, shape=empty_grad_shape) empty_hessians = constant_op.constant( [], dtype=dtypes.float32, shape=empty_hess_shape) + active_handlers = array_ops.unstack(active_handlers, axis=0) for handler_idx in range(len(handlers)): handler = handlers[handler_idx] is_active = active_handlers[handler_idx] @@ -774,34 +864,86 @@ class GradientBoostedDecisionTreeModel(object): per_handler_updates, ensemble_stamp, worker_device) for update in update_results.values(): stats_update_ops += update + + training_state = GBDTTrainingState( + num_layer_examples=num_layer_examples, + num_layer_steps=num_layer_steps, + num_layers=num_layers, + active_tree=active_tree, + active_layer=active_layer, + continue_centering=continue_centering, + bias_stats_accumulator=bias_stats_accumulator, + steps_accumulator=steps_accumulator, + handlers=handlers) + + reset_op = control_flow_ops.no_op() + if self._is_chief: + # Advance the ensemble stamp to throw away staggered workers. + stamp_token, _ = model_ops.tree_ensemble_serialize(self._ensemble_handle) + next_stamp_token = stamp_token + 1 + + reset_ops = [] + for handler in handlers: + reset_ops.append(handler.make_splits(stamp_token, next_stamp_token, 0)) + if self._center_bias: + reset_ops.append( + bias_stats_accumulator.flush(stamp_token, next_stamp_token)) + reset_ops.append(steps_accumulator.flush(stamp_token, next_stamp_token)) + reset_ops.append(self._finalized_trees.assign(0).op) + reset_ops.append(self._attempted_trees.assign(0).op) + reset_ops.append( + model_ops.tree_ensemble_deserialize( + self._ensemble_handle, + stamp_token=next_stamp_token, + tree_ensemble_config="", + name="reset_gbdt")) + + reset_op = control_flow_ops.group([reset_ops]) + + return stats_update_ops, reset_op, training_state + + def increment_step_counter_and_maybe_update_ensemble(self, predictions_dict, + training_state): + """Increments number of visited examples and grows the ensemble. + + If the number of visited examples reaches the target examples_per_layer, + ensemble is updated. + + Args: + predictions_dict: Dictionary of Rank 2 `Tensor` representing information + about predictions per example. + training_state: `dict` returned by update_stats. + + Returns: + An op that updates the counters and potientially grows the ensemble. + """ + batch_size = math_ops.cast( + array_ops.shape(predictions_dict[PREDICTIONS])[0], dtypes.float32) + ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] # Accumulate a step after updating stats. - batch_size = math_ops.cast(array_ops.shape(labels)[0], dtypes.float32) - with ops.control_dependencies(stats_update_ops): - add_step_op = steps_accumulator.add(ensemble_stamp, [0], [[0, 0]], - [batch_size], [1.0]) - # Determine learning rate. - learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof( - "tuner") - if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout": - tuner = getattr(self._learner_config.learning_rate_tuner, - learning_rate_tuner) - learning_rate = tuner.learning_rate - else: - # TODO(nponomareva, soroush) do the line search. - raise ValueError("Line search learning rate is not yet supported.") + steps_accumulator = training_state.steps_accumulator + num_layer_examples = training_state.num_layer_examples + num_layer_steps = training_state.num_layer_steps + active_layer = training_state.active_layer + add_step_op = steps_accumulator.add( + ensemble_stamp, [0], [[0, 0]], [batch_size], [1.0]) # After adding the step, decide if further processing is needed. ensemble_update_ops = [add_step_op] + class_id = self._get_class_id(predictions_dict) + with ops.control_dependencies([add_step_op]): if self._is_chief: dropout_seed = predictions_dict[NUM_TREES_ATTEMPTED] # Get accumulated steps and examples for the current layer. - _, _, _, _, acc_examples, acc_steps = steps_accumulator.serialize() + _, _, _, _, acc_examples, acc_steps = ( + steps_accumulator.serialize()) acc_examples = math_ops.cast(acc_examples[0], dtypes.int64) acc_steps = math_ops.cast(acc_steps[0], dtypes.int64) - ensemble_update_ops.append(num_layer_examples.assign(acc_examples)) + ensemble_update_ops.append( + num_layer_examples.assign(acc_examples)) ensemble_update_ops.append(num_layer_steps.assign(acc_steps)) # Determine whether we need to update tree ensemble. examples_per_layer = self._examples_per_layer @@ -810,18 +952,172 @@ class GradientBoostedDecisionTreeModel(object): ensemble_update_ops.append( control_flow_ops.cond( acc_examples >= examples_per_layer, - self._make_update_ensemble_fn( - ensemble_stamp, steps_accumulator, bias_stats_accumulator, - continue_centering, learning_rate, handlers, num_layers, - active_tree, active_layer, dropout_seed, class_id), + self.make_update_ensemble_fn(ensemble_stamp, training_state, + dropout_seed, class_id), control_flow_ops.no_op)) - # Calculate the loss to be reported. # Note, the loss is calculated from the prediction considering dropouts, so # that the value might look staggering over steps when the dropout ratio is # high. eval_loss might be referred instead in the aspect of convergence. return control_flow_ops.group(*ensemble_update_ops) + def make_update_ensemble_fn(self, ensemble_stamp, training_state, + dropout_seed, class_id): + """A method to create the function which updates the tree ensemble.""" + # Determine learning rate. + learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof( + "tuner") + if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout": + tuner = getattr(self._learner_config.learning_rate_tuner, + learning_rate_tuner) + learning_rate = tuner.learning_rate + else: + # TODO(nponomareva, soroush) do the line search. + raise ValueError("Line search learning rate is not yet supported.") + + def _update_ensemble(): + """A method to update the tree ensemble.""" + # Get next stamp token. + next_ensemble_stamp = ensemble_stamp + 1 + # Finalize bias stats. + _, _, _, bias_grads, bias_hess = ( + training_state.bias_stats_accumulator.flush(ensemble_stamp, + next_ensemble_stamp)) + + # Finalize handler splits. + are_splits_ready_list = [] + partition_ids_list = [] + gains_list = [] + split_info_list = [] + + for handler in training_state.handlers: + (are_splits_ready, + partition_ids, gains, split_info) = handler.make_splits( + ensemble_stamp, next_ensemble_stamp, class_id) + are_splits_ready_list.append(are_splits_ready) + partition_ids_list.append(partition_ids) + gains_list.append(gains) + split_info_list.append(split_info) + # Stack all the inputs to one tensor per type. + # This is a workaround for the slowness of graph building in tf.cond. + # See (b/36554864). + split_sizes = array_ops.reshape( + array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) + partition_ids = array_ops.concat(partition_ids_list, axis=0) + gains = array_ops.concat(gains_list, axis=0) + split_infos = array_ops.concat(split_info_list, axis=0) + + # Determine if all splits are ready. + are_all_splits_ready = math_ops.reduce_all( + array_ops.stack( + are_splits_ready_list, axis=0, name="stack_handler_readiness")) + + # Define bias centering update operation. + def _center_bias_fn(): + # Center tree ensemble bias. + delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess, + array_ops.zeros_like(bias_grads)) + center_bias = training_ops.center_tree_ensemble_bias( + tree_ensemble_handle=self._ensemble_handle, + stamp_token=ensemble_stamp, + next_stamp_token=next_ensemble_stamp, + delta_updates=delta_updates, + learner_config=self._learner_config_serialized) + return training_state.continue_centering.assign(center_bias) + + # Define ensemble growing operations. + def _grow_ensemble_ready_fn(): + # Grow the ensemble given the current candidates. + sizes = array_ops.unstack(split_sizes) + partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0)) + gains_list = list(array_ops.split(gains, sizes, axis=0)) + split_info_list = list(array_ops.split(split_infos, sizes, axis=0)) + return training_ops.grow_tree_ensemble( + tree_ensemble_handle=self._ensemble_handle, + stamp_token=ensemble_stamp, + next_stamp_token=next_ensemble_stamp, + learning_rate=learning_rate, + partition_ids=partition_ids_list, + gains=gains_list, + splits=split_info_list, + learner_config=self._learner_config_serialized, + dropout_seed=dropout_seed, + center_bias=self._center_bias) + + def _grow_ensemble_not_ready_fn(): + # Don't grow the ensemble, just update the stamp. + return training_ops.grow_tree_ensemble( + tree_ensemble_handle=self._ensemble_handle, + stamp_token=ensemble_stamp, + next_stamp_token=next_ensemble_stamp, + learning_rate=0, + partition_ids=[], + gains=[], + splits=[], + learner_config=self._learner_config_serialized, + dropout_seed=dropout_seed, + center_bias=self._center_bias) + + def _grow_ensemble_fn(): + # Conditionally grow an ensemble depending on whether the splits + # from all the handlers are ready. + return control_flow_ops.cond(are_all_splits_ready, + _grow_ensemble_ready_fn, + _grow_ensemble_not_ready_fn) + + # Update ensemble. + update_ops = [are_all_splits_ready] + if self._center_bias: + update_model = control_flow_ops.cond(training_state.continue_centering, + _center_bias_fn, _grow_ensemble_fn) + else: + update_model = _grow_ensemble_fn() + update_ops.append(update_model) + + # Update ensemble stats. + with ops.control_dependencies([update_model]): + stats = training_ops.tree_ensemble_stats( + self._ensemble_handle, stamp_token=next_ensemble_stamp) + update_ops.append(self._finalized_trees.assign(stats.num_trees)) + update_ops.append(self._attempted_trees.assign(stats.attempted_trees)) + update_ops.append(training_state.num_layers.assign(stats.num_layers)) + update_ops.append(training_state.active_tree.assign(stats.active_tree)) + update_ops.append( + training_state.active_layer.assign(stats.active_layer)) + + # Flush step stats. + update_ops.extend( + training_state.steps_accumulator.flush(ensemble_stamp, + next_ensemble_stamp)) + return control_flow_ops.group(*update_ops, name="update_ensemble") + + return _update_ensemble + + def get_number_of_trees_tensor(self): + return self._finalized_trees, self._attempted_trees + + def train(self, loss, predictions_dict, labels): + """Updates the accumalator stats and grows the ensemble. + + Args: + loss: A scalar tensor representing average loss of examples. + predictions_dict: Dictionary of Rank 2 `Tensor` representing information + about predictions per example. + labels: Rank 2 `Tensor` representing labels per example. Has no effect + on the training and is only kept for backward compatibility. + + Returns: + An op that adds a new tree to the ensemble. + + Raises: + ValueError: if inputs are not valid. + """ + del labels # unused; kept for backward compatibility. + update_op, _, training_state = self.update_stats(loss, predictions_dict) + with ops.control_dependencies(update_op): + return self.increment_step_counter_and_maybe_update_ensemble( + predictions_dict, training_state) + def _get_weights(self, hessian_shape, hessians): """Derives weights to be used based on hessians and multiclass strategy.""" if hessian_shape == tensor_shape.scalar(): @@ -901,7 +1197,6 @@ class GradientBoostedDecisionTreeModel(object): "DecisionTreeEnsembleResourceHandleOp", "StatsAccumulatorScalarResourceHandleOp", "StatsAccumulatorTensorResourceHandleOp", - "QuantileStreamResourceHandleOp", ] ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) return device_setter.replica_device_setter( @@ -938,124 +1233,3 @@ class GradientBoostedDecisionTreeModel(object): return control_flow_ops.group(*[add_stats_op], name="update_bias_stats") return _update_bias_stats - - def _make_update_ensemble_fn(self, ensemble_stamp, steps_accumulator, - bias_stats_accumulator, continue_centering, - learning_rate, handlers, num_layers, active_tree, - active_layer, dropout_seed, class_id): - """A method to create the function which updates the tree ensemble.""" - - def _update_ensemble(): - """A method to update the tree ensemble.""" - # Get next stamp token. - next_ensemble_stamp = ensemble_stamp + 1 - # Finalize bias stats. - _, _, _, bias_grads, bias_hess = bias_stats_accumulator.flush( - ensemble_stamp, next_ensemble_stamp) - - # Finalize handler splits. - are_splits_ready_list = [] - partition_ids_list = [] - gains_list = [] - split_info_list = [] - - for handler in handlers: - (are_splits_ready, - partition_ids, gains, split_info) = handler.make_splits( - ensemble_stamp, next_ensemble_stamp, class_id) - are_splits_ready_list.append(are_splits_ready) - partition_ids_list.append(partition_ids) - gains_list.append(gains) - split_info_list.append(split_info) - # Stack all the inputs to one tensor per type. - # This is a workaround for the slowness of graph building in tf.cond. - # See (b/36554864). - split_sizes = array_ops.reshape( - array_ops.shape_n(partition_ids_list), [-1]) - partition_ids = array_ops.concat(partition_ids_list, axis=0) - gains = array_ops.concat(gains_list, axis=0) - split_infos = array_ops.concat(split_info_list, axis=0) - - # Determine if all splits are ready. - are_all_splits_ready = math_ops.reduce_all( - array_ops.stack( - are_splits_ready_list, axis=0, name="stack_handler_readiness")) - - # Define bias centering update operation. - def _center_bias_fn(): - # Center tree ensemble bias. - delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess, - array_ops.zeros_like(bias_grads)) - center_bias = training_ops.center_tree_ensemble_bias( - tree_ensemble_handle=self._ensemble_handle, - stamp_token=ensemble_stamp, - next_stamp_token=next_ensemble_stamp, - delta_updates=delta_updates, - learner_config=self._learner_config_serialized) - return continue_centering.assign(center_bias) - - # Define ensemble growing operations. - def _grow_ensemble_ready_fn(): - # Grow the ensemble given the current candidates. - sizes = array_ops.unstack(split_sizes) - partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0)) - gains_list = list(array_ops.split(gains, sizes, axis=0)) - split_info_list = list(array_ops.split(split_infos, sizes, axis=0)) - return training_ops.grow_tree_ensemble( - tree_ensemble_handle=self._ensemble_handle, - stamp_token=ensemble_stamp, - next_stamp_token=next_ensemble_stamp, - learning_rate=learning_rate, - partition_ids=partition_ids_list, - gains=gains_list, - splits=split_info_list, - learner_config=self._learner_config_serialized, - dropout_seed=dropout_seed, - center_bias=self._center_bias) - - def _grow_ensemble_not_ready_fn(): - # Don't grow the ensemble, just update the stamp. - return training_ops.grow_tree_ensemble( - tree_ensemble_handle=self._ensemble_handle, - stamp_token=ensemble_stamp, - next_stamp_token=next_ensemble_stamp, - learning_rate=0, - partition_ids=[], - gains=[], - splits=[], - learner_config=self._learner_config_serialized, - dropout_seed=dropout_seed, - center_bias=self._center_bias) - - def _grow_ensemble_fn(): - # Conditionally grow an ensemble depending on whether the splits - # from all the handlers are ready. - return control_flow_ops.cond(are_all_splits_ready, - _grow_ensemble_ready_fn, - _grow_ensemble_not_ready_fn) - - # Update ensemble. - update_ops = [are_all_splits_ready] - update_model = control_flow_ops.cond(continue_centering, _center_bias_fn, - _grow_ensemble_fn) - update_ops.append(update_model) - - # Update ensemble stats. - with ops.control_dependencies([update_model]): - stats = training_ops.tree_ensemble_stats( - self._ensemble_handle, stamp_token=next_ensemble_stamp) - update_ops.append(self._finalized_trees.assign(stats.num_trees)) - update_ops.append(self._attempted_trees.assign(stats.attempted_trees)) - update_ops.append(num_layers.assign(stats.num_layers)) - update_ops.append(active_tree.assign(stats.active_tree)) - update_ops.append(active_layer.assign(stats.active_layer)) - - # Flush step stats. - update_ops.extend( - steps_accumulator.flush(ensemble_stamp, next_ensemble_stamp)) - return control_flow_ops.group(*update_ops, name="update_ensemble") - - return _update_ensemble - - def get_number_of_trees_tensor(self): - return self._finalized_trees, self._attempted_trees diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index f9c22283b7f5136777bfa60a12c94974adfbd245..f7867d882d6813a8701065ad0ce8d27f8bb9c301 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -19,20 +19,17 @@ from __future__ import division from __future__ import print_function from google.protobuf import text_format - from tensorflow.contrib import layers from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch from tensorflow.contrib.boosted_trees.python.utils import losses - -from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn - - +from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -97,8 +94,8 @@ class GbdtTest(test_util.TensorFlowTestCase): array_ops.zeros([2], dtypes.int64)) features["sparse_int"] = sparse_tensor.SparseTensor( array_ops.zeros([2, 2], dtypes.int64), - array_ops.zeros([2], dtypes.int64), - array_ops.zeros([2], dtypes.int64)) + array_ops.zeros([2], dtypes.int64), array_ops.zeros([2], + dtypes.int64)) (fc_names, dense_floats, sparse_float_indices, sparse_float_values, sparse_float_shapes, sparse_int_indices, sparse_int_values, sparse_int_shapes) = ( @@ -139,8 +136,8 @@ class GbdtTest(test_util.TensorFlowTestCase): array_ops.zeros([2], dtypes.int64)) features["sparse_categorical"] = sparse_tensor.SparseTensor( array_ops.zeros([2, 2], dtypes.int64), - array_ops.zeros( - [2], dtypes.string), array_ops.zeros([2], dtypes.int64)) + array_ops.zeros([2], dtypes.string), array_ops.zeros([2], + dtypes.int64)) feature_columns = set() feature_columns.add(layers.real_valued_column("dense_float")) feature_columns.add( @@ -235,7 +232,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -316,6 +314,113 @@ class GbdtTest(test_util.TensorFlowTestCase): }""" self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefSparseAndDense(self): + """Tests the train function with sparse and dense features.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + features["sparse_float"] = sparse_tensor.SparseTensor( + array_ops.zeros([2, 2], dtypes.int64), + array_ops.zeros([2], dtypes.float32), + array_ops.constant([4, 1], dtypes.int64)) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + sparse_float_binary_split_default_right { + split{ + left_id: 1 + right_id: 2 + } + } + node_metadata { + gain: 1.125 + } + } + nodes { + leaf { + vector { + value: 1.0 + } + } + } + nodes { + leaf { + vector { + value: -0.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefScalingNumberOfExamples(self): """Tests the train function running on chief without bias centering.""" with self.test_session() as sess: @@ -339,7 +444,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=num_examples_fn, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -442,7 +548,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -513,7 +620,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -576,7 +684,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -622,7 +731,8 @@ class GbdtTest(test_util.TensorFlowTestCase): with self.test_session() as sess: # Create ensemble with one bias node. ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - text_format.Merge(""" + text_format.Merge( + """ trees { nodes { leaf { @@ -659,15 +769,128 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) # Create predict op. mode = model_fn.ModeKeys.EVAL predictions_dict = sess.run(gbdt_model.predict(mode)) self.assertEquals(predictions_dict["ensemble_stamp"], 3) - self.assertAllClose(predictions_dict["predictions"], [[0.25], [0.25], - [0.25], [0.25]]) + self.assertAllClose(predictions_dict["predictions"], + [[0.25], [0.25], [0.25], [0.25]]) + self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) + + def testPredictFnWithLeafIndexAdvancedLeft(self): + """Tests the predict function with output leaf ids.""" + with self.test_session() as sess: + # Create ensemble with one bias node. + ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + dense_float_binary_split { + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.15 + } + } + } + } + trees { + nodes { + dense_float_binary_split { + threshold: 0.99 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 00 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.23 + } + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + }""", ensemble_config) + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=3, + tree_ensemble_config=ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.constant( + [[0.0], [1.0], [1.1], [2.0]], dtype=dtypes.float32) + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=False, + num_ps_replicas=0, + center_bias=True, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features, + output_leaf_index=True) + + # Create predict op. + mode = model_fn.ModeKeys.INFER + predictions_dict = sess.run(gbdt_model.predict(mode)) + self.assertEquals(predictions_dict["ensemble_stamp"], 3) + # here are how the numbers in expected results are calculated, + # 0.5 = 0.25 + 0.25 + # 0.48 = 0.25 + 0.23 + # 0.38 = 0.15 + 0.23 + # 0.38 = 0.15 + 0.23 + self.assertAllClose(predictions_dict["predictions"], + [[0.5], [0.48], [0.38], [0.38]]) self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) + self.assertAllClose(predictions_dict["leaf_index"], + [[1, 1], [1, 2], [2, 2], [2, 2]]) def testTrainFnMulticlassFullHessian(self): """Tests the GBDT train for multiclass full hessian.""" @@ -698,7 +921,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -801,7 +1025,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -893,8 +1118,8 @@ class GbdtTest(test_util.TensorFlowTestCase): learner_config.constraints.max_tree_depth = 1 learner_config.constraints.min_node_weight = 0 features = { - "dense_float": array_ops.constant( - [[1.0], [1.5], [2.0]], dtypes.float32), + "dense_float": + array_ops.constant([[1.0], [1.5], [2.0]], dtypes.float32), } gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( @@ -904,7 +1129,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) batch_size = 3 predictions = array_ops.constant( @@ -986,7 +1212,8 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertAllClose( 0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0], - atol=1e-4, rtol=1e-4) + atol=1e-4, + rtol=1e-4) def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self): """Tests the train function running on chief with feature selection.""" @@ -1230,9 +1457,9 @@ class GbdtTest(test_util.TensorFlowTestCase): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() tree = tree_ensemble_config.trees.add() - _set_float_split(tree.nodes.add() - .sparse_float_binary_split_default_right.split, 2, 4.0, - 1, 2) + _set_float_split( + tree.nodes.add().sparse_float_binary_split_default_right.split, 2, + 4.0, 1, 2) _append_to_leaf(tree.nodes.add().leaf, 0, 0.5) _append_to_leaf(tree.nodes.add().leaf, 1, 1.2) tree_ensemble_config.tree_weights.append(1.0) @@ -1241,7 +1468,8 @@ class GbdtTest(test_util.TensorFlowTestCase): metadata.num_layers_grown = 1 tree_ensemble_config = tree_ensemble_config.SerializeToString() ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, tree_ensemble_config=tree_ensemble_config, + stamp_token=0, + tree_ensemble_config=tree_ensemble_config, name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 @@ -1333,5 +1561,301 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertEquals(output.growing_metadata.num_layers_attempted, 2) + def testResetModelBeforeAndAfterSplit(self): + """Tests whether resetting works.""" + with self.test_session(): + # First build a small tree and train it to verify training works. + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + "max_tree_depth": 4, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions)) + + # Create train op. + update_op, reset_op, training_state = gbdt_model.update_stats( + loss, predictions_dict) + with ops.control_dependencies(update_op): + train_op = gbdt_model.increment_step_counter_and_maybe_update_ensemble( + predictions_dict, training_state) + + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + original_stamp = ensemble_stamp.eval() + expected_tree = """ + nodes { + dense_float_binary_split { + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + }""" + + def _train_once_and_check(expect_split): + stamp = ensemble_stamp.eval() + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(stamp_token.eval(), stamp + 1) + if expect_split: + # State of the ensemble after a split occurs. + self.assertEquals(len(output.trees), 1) + self.assertProtoEquals(expected_tree, output.trees[0]) + else: + # State of the ensemble after a single accumulation but before any + # splitting occurs + self.assertEquals(len(output.trees), 0) + self.assertProtoEquals(""" + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + }""", output) + + def _run_reset(): + stamp_before_reset = ensemble_stamp.eval() + reset_op.run() + stamp_after_reset = ensemble_stamp.eval() + self.assertNotEquals(stamp_after_reset, stamp_before_reset) + + _, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertProtoEquals("", output) + + return stamp_after_reset + + # Exit after one train_op, so no new layer are created but the handlers + # contain enough information to split on the next call to train. + _train_once_and_check(expect_split=False) + self.assertEquals(ensemble_stamp.eval(), original_stamp + 1) + + # Reset the handlers so it still requires two training calls to split. + stamp_after_reset = _run_reset() + + _train_once_and_check(expect_split=False) + _train_once_and_check(expect_split=True) + self.assertEquals(ensemble_stamp.eval(), stamp_after_reset + 2) + + # This time, test that the reset_op works right after splitting. + stamp_after_reset = _run_reset() + + # Test that after resetting, the tree can be trained as normal. + _train_once_and_check(expect_split=False) + _train_once_and_check(expect_split=True) + self.assertEquals(ensemble_stamp.eval(), stamp_after_reset + 2) + + def testResetModelNonChief(self): + """Tests the reset function on a non-chief worker.""" + with self.test_session(): + # Create ensemble with one bias node. + ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + leaf { + vector { + value: 0.25 + } + } + } + } + tree_weights: 1.0 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: false + }""", ensemble_config) + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=ensemble_config.SerializeToString(), + name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=False, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions)) + + # Create reset op. + _, reset_op, _ = gbdt_model.update_stats( + loss, predictions_dict) + + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # Reset op doesn't do anything because this is a non-chief worker. + reset_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertEquals(len(output.tree_weights), 1) + self.assertEquals(stamp_token.eval(), 0) + + def testResetModelWithCenterBias(self): + """Tests the reset function running on chief with bias centering.""" + with self.test_session(): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=True, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions)) + + # Create train op. + update_op, reset_op, training_state = gbdt_model.update_stats( + loss, predictions_dict) + with ops.control_dependencies(update_op): + train_op = gbdt_model.increment_step_counter_and_maybe_update_ensemble( + predictions_dict, training_state) + + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect bias to be centered. + def train_and_check(): + train_op.run() + _, serialized = model_ops.tree_ensemble_serialize(ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + expected_tree = """ + nodes { + leaf { + vector { + value: 0.25 + } + } + }""" + self.assertEquals(len(output.trees), 1) + self.assertAllEqual(output.tree_weights, [1.0]) + self.assertProtoEquals(expected_tree, output.trees[0]) + + train_and_check() + self.assertEquals(ensemble_stamp.eval(), 1) + + reset_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 2) + + train_and_check() + self.assertEquals(ensemble_stamp.eval(), 3) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index af8df72618b7255e182e98e6e4b96a0333b3dce6..2fbaa31d5e19b58c335cd0a894e1db9af2c34d08 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -16,13 +16,20 @@ Visualization and inspection: @@dot_graph_from_checkpoint +@@list_objects @@object_metadata -Creating and managing dependencies: +Managing dependencies: +@@capture_dependencies @@Checkpointable +@@CheckpointableBase @@CheckpointableObjectGraph @@NoDependency @@split_dependency + +Checkpointable data structures: +@@List +@@Mapping @@UniqueNameTracker """ @@ -34,8 +41,13 @@ from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph -from tensorflow.python.training.checkpointable.base import Checkpointable -from tensorflow.python.training.checkpointable.base import NoDependency +from tensorflow.python.training.checkpointable.base import CheckpointableBase +from tensorflow.python.training.checkpointable.data_structures import List +from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.data_structures import NoDependency +from tensorflow.python.training.checkpointable.tracking import Checkpointable +from tensorflow.python.training.checkpointable.util import capture_dependencies +from tensorflow.python.training.checkpointable.util import list_objects from tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index 53f4e97f9932104933b3ecf80142e5af82cd487a..7b200a29bf60087d6da1010b0be05c04faec80cd 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -11,6 +11,7 @@ py_library( ":containers", ":split_dependency", ":visualize", + "//tensorflow/python/training/checkpointable:data_structures", ], ) @@ -19,7 +20,10 @@ py_library( srcs = ["containers.py"], srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], - deps = ["//tensorflow/python/training/checkpointable:base"], + deps = [ + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:data_structures", + ], ) py_test( @@ -30,8 +34,8 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", "@six_archive//:six", ], ) @@ -44,6 +48,7 @@ py_library( deps = [ "//tensorflow/python:control_flow_ops", "//tensorflow/python:training", + "//tensorflow/python/training/checkpointable:base", ], ) @@ -55,8 +60,9 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", "//tensorflow/python/eager:test", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", ], ) @@ -67,6 +73,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", ], ) @@ -75,10 +83,13 @@ py_test( srcs = ["visualize_test.py"], deps = [ ":visualize", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:constant_op", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:layers", + "//tensorflow/python/training/checkpointable:util", ], ) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 9807abae1f5106bb84f858c3725f096aaa4eaca9..4d3d5312993740636709cb732c0b8e3e2626262d 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -18,9 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import data_structures -class UniqueNameTracker(checkpointable_lib.CheckpointableBase): +class UniqueNameTracker(data_structures.CheckpointableDataStructure): """Adds dependencies on checkpointable objects with name hints. Useful for creating dependencies with locally unique names. @@ -41,6 +42,7 @@ class UniqueNameTracker(checkpointable_lib.CheckpointableBase): """ def __init__(self): + super(UniqueNameTracker, self).__init__() self._maybe_initialize_checkpointable() self._name_counts = {} @@ -74,4 +76,5 @@ class UniqueNameTracker(checkpointable_lib.CheckpointableBase): count += 1 candidate = _format_name(base_name, count) self._name_counts[base_name] = count + 1 - return self._track_checkpointable(checkpointable, name=candidate) + self._track_value(checkpointable, name=candidate) + return checkpointable diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index 851a80058852bd917aec075b4bf63264318603a7..ac85c7be803cd4c2f8ba19d3ef887a3c65a15933 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -22,15 +22,18 @@ import six from tensorflow.contrib.checkpoint.python import containers from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers +from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util class UniqueNameTrackerTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNames(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") @@ -46,11 +49,11 @@ class UniqueNameTrackerTests(test.TestCase): slots.track(y, "y") self.evaluate((x1.initializer, x2.initializer, x3.initializer, y.initializer)) - save_root = checkpointable_utils.Checkpoint(slots=slots) + save_root = util.Checkpoint(slots=slots) save_path = save_root.save(checkpoint_prefix) - restore_slots = checkpointable.Checkpointable() - restore_root = checkpointable_utils.Checkpoint( + restore_slots = tracking.Checkpointable() + restore_root = util.Checkpoint( slots=restore_slots) status = restore_root.restore(save_path) restore_slots.x = resource_variable_ops.ResourceVariable(0.) @@ -63,9 +66,9 @@ class UniqueNameTrackerTests(test.TestCase): self.assertEqual(4., self.evaluate(restore_slots.x_1_1)) self.assertEqual(5., self.evaluate(restore_slots.y)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testExample(self): - class SlotManager(checkpointable.Checkpointable): + class SlotManager(tracking.Checkpointable): def __init__(self): self.slotdeps = containers.UniqueNameTracker() @@ -77,15 +80,15 @@ class UniqueNameTrackerTests(test.TestCase): resource_variable_ops.ResourceVariable(4.), "y")) slots.append(slotdeps.track( resource_variable_ops.ResourceVariable(5.), "x")) - self.slots = slots + self.slots = data_structures.NoDependency(slots) manager = SlotManager() self.evaluate([v.initializer for v in manager.slots]) - checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager) + checkpoint = util.Checkpoint(slot_manager=manager) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") save_path = checkpoint.save(checkpoint_prefix) - metadata = checkpointable_utils.object_metadata(save_path) + metadata = util.object_metadata(save_path) dependency_names = [] for node in metadata.nodes: for child in node.children: @@ -95,5 +98,12 @@ class UniqueNameTrackerTests(test.TestCase): dependency_names, ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) + @test_util.run_in_graph_and_eager_modes + def testLayers(self): + tracker = containers.UniqueNameTracker() + tracker.track(layers.Dense(3), "dense") + tracker.layers[0](array_ops.zeros([1, 1])) + self.assertEqual(2, len(tracker.trainable_weights)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index 69dc0b9be2d5548852c37552a64a0d31c9557b43..00a805af25d5d0ea723db5d015fb12bf45c53857 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -23,8 +23,9 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util def _split_variable_closure(variable): @@ -43,7 +44,7 @@ def _combine_variable_closure(variable): return _consume_restore_buffer_fn -class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): +class SaveTensorSlicesAsDeps(base.CheckpointableBase): def __init__(self): self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) @@ -58,14 +59,14 @@ class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): self._track_checkpointable(dep, name=name) -class HasRegularDeps(checkpointable.Checkpointable): +class HasRegularDeps(tracking.Checkpointable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) -class OnlyOneDep(checkpointable.Checkpointable): +class OnlyOneDep(tracking.Checkpointable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) @@ -73,9 +74,9 @@ class OnlyOneDep(checkpointable.Checkpointable): class SplitTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestoreSplitDep(self): - save_checkpoint = checkpointable_utils.Checkpoint( + save_checkpoint = util.Checkpoint( dep=SaveTensorSlicesAsDeps()) self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.])) checkpoint_directory = self.get_temp_dir() @@ -83,7 +84,7 @@ class SplitTests(test.TestCase): save_path = save_checkpoint.save(checkpoint_prefix) regular_deps = HasRegularDeps() - regular_restore_checkpoint = checkpointable_utils.Checkpoint( + regular_restore_checkpoint = util.Checkpoint( dep=regular_deps) regular_restore_checkpoint.restore( save_path).assert_consumed().run_restore_ops() @@ -91,7 +92,7 @@ class SplitTests(test.TestCase): self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half)) one_dep = OnlyOneDep() - one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep) + one_dep_restore_checkpoint = util.Checkpoint(dep=one_dep) status = one_dep_restore_checkpoint.restore(save_path) with self.assertRaises(AssertionError): # Missing the second dependency. @@ -99,7 +100,7 @@ class SplitTests(test.TestCase): status.run_restore_ops() self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half)) - restore_checkpoint = checkpointable_utils.Checkpoint() + restore_checkpoint = util.Checkpoint() status = restore_checkpoint.restore(save_path) restore_checkpoint.dep = SaveTensorSlicesAsDeps() status.assert_consumed().run_restore_ops() diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py index a72a78b89f6875158460c6b68d541e3916f20910..583e3bc442893d825c337d73fb999d1e586738a1 100644 --- a/tensorflow/contrib/checkpoint/python/visualize_test.py +++ b/tensorflow/contrib/checkpoint/python/visualize_test.py @@ -24,8 +24,8 @@ from tensorflow.contrib.checkpoint.python import visualize from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import adam from tensorflow.python.training.checkpointable import util as checkpointable_utils diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index f3a75e8688ece19a6e6fd53ee9faf7f4144d76cf..523a9efcf05f5d32589f6e1734f866bf8b4b9cdc 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -15,7 +15,10 @@ load( ) tf_gen_op_libs( - op_lib_names = ["bigquery_reader_ops"], + op_lib_names = [ + "bigquery_reader_ops", + "gcs_config_ops", + ], deps = [ "//tensorflow/core:lib", ], @@ -28,15 +31,26 @@ tf_gen_op_wrapper_py( deps = [":bigquery_reader_ops_op_lib"], ) +tf_gen_op_wrapper_py( + name = "gen_gcs_config_ops", + out = "python/ops/gen_gcs_config_ops.py", + require_shape_functions = True, + visibility = ["//tensorflow:internal"], + deps = [":gcs_config_ops_op_lib"], +) + py_library( name = "cloud_py", srcs = [ "__init__.py", "python/ops/bigquery_reader_ops.py", + "python/ops/gcs_config_ops.py", ], srcs_version = "PY2AND3", deps = [ ":gen_bigquery_reader_ops", + ":gen_gcs_config_ops", + "//tensorflow/contrib/bigtable", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", @@ -61,3 +75,14 @@ tf_py_test( ], tags = ["manual"], ) + +tf_py_test( + name = "gcs_config_ops_test", + size = "small", + srcs = ["python/ops/gcs_config_ops_test.py"], + additional_deps = [ + ":cloud_py", + "//tensorflow/python:client_testlib", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/cloud/README.md b/tensorflow/contrib/cloud/README.md new file mode 100644 index 0000000000000000000000000000000000000000..134ce057f4334096b4fbbec29cc85f0ea42c9f86 --- /dev/null +++ b/tensorflow/contrib/cloud/README.md @@ -0,0 +1,18 @@ +# Cloud # + +## BigTable ## + +[Google Cloud BigTable](https://cloud.google.com/bigtable/) is a high +performance storage system that can store and serve training data. This contrib +package contains an experimental integration with TensorFlow. + +> **Status: Highly experimental.** The current implementation is very much in +> flux. Please use at your own risk! :-) + + + +## Cloud Storage (GCS) ## + +The Google Cloud Storage ops allow the user to configure the GCS File System. + + diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index 8870264b95dfd9f8c4b1655c475fe23e0639924f..af81106a6848bfd8c91108b56c8150d47c3eb501 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -18,11 +18,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=line-too-long,wildcard-import +import os + +# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * -# pylint: enable=line-too-long,wildcard-import +from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * + +if os.name != 'nt': + from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigTable + from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient + +del os from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['BigQueryReader'] +_allowed_symbols = [ + 'BigQueryReader', + 'BigTable', + 'BigtableClient', + 'BlockCacheParams', + 'configure_colab_session', + 'configure_gcs', + 'ConfigureGcsHook', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index ff46f0daa80a70badedf73e15bfaf4dca85fdd89..1311063ec023bdaa2588d6f1c826bf900f7dea09 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -73,3 +73,18 @@ tf_proto_library( srcs = ["bigquery_table_partition.proto"], cc_api_version = 2, ) + +tf_kernel_library( + name = "gcs_config_ops", + srcs = ["gcs_config_ops.cc"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/cloud:gcs_config_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform/cloud:curl_http_request", + "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/platform/cloud:oauth_client", + "@jsoncpp_git//:jsoncpp", + ], +) diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..648a219fb87a6ebc64767a7da780013ef6b95443 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc @@ -0,0 +1,205 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "include/json/json.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/cloud/gcs_file_system.h" +#include "tensorflow/core/platform/cloud/oauth_client.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { + +// The default initial delay between retries with exponential backoff. +constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec + +// The minimum time delta between now and the token expiration time +// for the token to be re-used. +constexpr int kExpirationTimeMarginSec = 60; + +// The URL to retrieve the auth bearer token via OAuth with a refresh token. +constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token"; + +// The URL to retrieve the auth bearer token via OAuth with a private key. +constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token"; + +// The authentication token scope to request. +constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform"; + +Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) { + DCHECK(fs != nullptr); + *fs = nullptr; + + FileSystem* filesystem = nullptr; + TF_RETURN_IF_ERROR( + ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem)); + if (filesystem == nullptr) { + return errors::FailedPrecondition("The GCS file system is not registered."); + } + + *fs = dynamic_cast(filesystem); + if (*fs == nullptr) { + return errors::Internal( + "The filesystem registered under the 'gs://' scheme was not a " + "tensorflow::RetryingGcsFileSystem*."); + } + return Status::OK(); +} + +template +Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, + T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); +} + +// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem. +class GcsCredentialsOpKernel : public OpKernel { + public: + explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + string json_string; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "json", &json_string)); + + Json::Value json; + Json::Reader reader; + std::stringstream json_stream(json_string); + OP_REQUIRES(ctx, reader.parse(json_stream, json), + errors::InvalidArgument("Could not parse json: ", json_string)); + + OP_REQUIRES( + ctx, json.isMember("refresh_token") || json.isMember("private_key"), + errors::InvalidArgument("JSON format incompatible; did not find fields " + "`refresh_token` or `private_key`.")); + + auto provider = + tensorflow::MakeUnique(json, ctx->env()); + + // Test getting a token + string dummy_token; + OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token)); + OP_REQUIRES(ctx, !dummy_token.empty(), + errors::InvalidArgument( + "Could not retrieve a token with the given credentials.")); + + // Set the provider. + gcs->underlying()->SetAuthProvider(std::move(provider)); + } + + private: + class ConstantAuthProvider : public AuthProvider { + public: + ConstantAuthProvider(const Json::Value& json, + std::unique_ptr oauth_client, Env* env, + int64 initial_retry_delay_usec) + : json_(json), + oauth_client_(std::move(oauth_client)), + env_(env), + initial_retry_delay_usec_(initial_retry_delay_usec) {} + + ConstantAuthProvider(const Json::Value& json, Env* env) + : ConstantAuthProvider(json, tensorflow::MakeUnique(), env, + kInitialRetryDelayUsec) {} + + ~ConstantAuthProvider() override {} + + Status GetToken(string* token) override { + mutex_lock l(mu_); + const uint64 now_sec = env_->NowSeconds(); + + if (!current_token_.empty() && + now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) { + *token = current_token_; + return Status::OK(); + } + if (json_.isMember("refresh_token")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson( + json_, kOAuthV3Url, ¤t_token_, &expiration_timestamp_sec_)); + } else if (json_.isMember("private_key")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson( + json_, kOAuthV4Url, kOAuthScope, ¤t_token_, + &expiration_timestamp_sec_)); + } else { + return errors::FailedPrecondition( + "Unexpected content of the JSON credentials file."); + } + + *token = current_token_; + return Status::OK(); + } + + private: + Json::Value json_; + std::unique_ptr oauth_client_; + Env* env_; + + mutex mu_; + string current_token_ GUARDED_BY(mu_); + uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0; + + // The initial delay for exponential backoffs when retrying failed calls. + const int64 initial_retry_delay_usec_; + TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider); + }; +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureCredentials").Device(DEVICE_CPU), + GcsCredentialsOpKernel); + +class GcsBlockCacheOpKernel : public OpKernel { + public: + explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + size_t max_cache_size, block_size, max_staleness; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "max_cache_size", + &max_cache_size)); + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_size", &block_size)); + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "max_staleness", &max_staleness)); + + if (gcs->underlying()->block_size() == block_size && + gcs->underlying()->max_bytes() == max_cache_size && + gcs->underlying()->max_staleness() == max_staleness) { + LOG(INFO) << "Skipping resetting the GCS block cache."; + return; + } + gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size, + max_staleness); + } +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureBlockCache").Device(DEVICE_CPU), + GcsBlockCacheOpKernel); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..9cf85f5f1811d873075b6d2e1931d8badfd6e32c --- /dev/null +++ b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("GcsConfigureCredentials") + .Input("json: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Configures the credentials used by the GCS client of the local TF runtime. + +The json input can be of the format: + +1. Refresh Token: +{ + "client_id": "", + "client_secret": "", + "refresh_token: "", + "type": "authorized_user", +} + +2. Service Account: +{ + "type": "service_account", + "project_id": "", + "private_key_id": "", + "private_key": "------BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY------\n", + "client_email": "@.iam.gserviceaccount.com", + "client_id": "", + # Some additional fields elided +} + +Note the credentials established through this method are shared across all +sessions run on this runtime. + +Note be sure to feed the inputs to this op to ensure the credentials are not +stored in a constant op within the graph that might accidentally be checkpointed +or in other ways be persisted or exfiltrated. +)doc"); + +REGISTER_OP("GcsConfigureBlockCache") + .Input("max_cache_size: uint64") + .Input("block_size: uint64") + .Input("max_staleness: uint64") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Re-configures the GCS block cache with the new configuration values. + +If the values are the same as already configured values, this op is a no-op. If +they are different, the current contents of the block cache is dropped, and a +new block cache is created fresh. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..95e7e744d34391a511cdba7702aad369b8d9d9c0 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -0,0 +1,193 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GCS file system configuration for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.training import training + + +# @tf_export('contrib.cloud.BlockCacheParams') +class BlockCacheParams(object): + """BlockCacheParams is a struct used for configuring the GCS Block Cache.""" + + def __init__(self, block_size=None, max_bytes=None, max_staleness=None): + self._block_size = block_size or 128 * 1024 * 1024 + self._max_bytes = max_bytes or 2 * self._block_size + self._max_staleness = max_staleness or 0 + + @property + def block_size(self): + return self._block_size + + @property + def max_bytes(self): + return self._max_bytes + + @property + def max_staleness(self): + return self._max_staleness + + +# @tf_export('contrib.cloud.ConfigureGcsHook') +class ConfigureGcsHook(training.SessionRunHook): + """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Example: + + ``` + sess = tf.Session() + refresh_token = raw_input("Refresh token: ") + client_secret = raw_input("Client secret: ") + client_id = "" + creds = { + "client_id": client_id, + "refresh_token": refresh_token, + "client_secret": client_secret, + "type": "authorized_user", + } + tf.contrib.cloud.configure_gcs(sess, credentials=creds) + ``` + + """ + + def _verify_dictionary(self, creds_dict): + if 'refresh_token' in creds_dict or 'private_key' in creds_dict: + return True + return False + + def __init__(self, credentials=None, block_cache=None): + """Constructs a ConfigureGcsHook. + + Args: + credentials: A json-formatted string. + block_cache: A `BlockCacheParams` + + Raises: + ValueError: If credentials is improperly formatted or block_cache is not a + BlockCacheParams. + """ + if credentials is not None: + if isinstance(credentials, str): + try: + data = json.loads(credentials) + except ValueError as e: + raise ValueError('credentials was not a well formed JSON string.', e) + if not self._verify_dictionary(data): + raise ValueError( + 'credentials has neither a "refresh_token" nor a "private_key" ' + 'field.') + elif isinstance(credentials, dict): + if not self._verify_dictionary(credentials): + raise ValueError('credentials has neither a "refresh_token" nor a ' + '"private_key" field.') + credentials = json.dumps(credentials) + else: + raise ValueError('credentials is of an unknown type') + + self._credentials = credentials + + if block_cache and not isinstance(block_cache, BlockCacheParams): + raise ValueError('block_cache must be an instance of BlockCacheParams.') + self._block_cache = block_cache + + def begin(self): + if self._credentials: + self._credentials_placeholder = array_ops.placeholder(dtypes.string) + self._credentials_op = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_placeholder) + else: + self._credentials_op = None + + if self._block_cache: + self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=self._block_cache.max_bytes, + block_size=self._block_cache.block_size, + max_staleness=self._block_cache.max_staleness) + else: + self._block_cache_op = None + + def after_create_session(self, session, coord): + del coord + if self._credentials_op: + session.run( + self._credentials_op, + feed_dict={self._credentials_placeholder: self._credentials}) + if self._block_cache_op: + session.run(self._block_cache_op) + + +def configure_gcs(session, credentials=None, block_cache=None, device=None): + """Configures the GCS file system for a given a session. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Args: + session: A `tf.Session` session that should be used to configure the GCS + file system. + credentials: [Optional.] A JSON string + block_cache: [Optional.] A BlockCacheParams to configure the block cache . + device: [Optional.] The device to place the configure ops. + """ + + def configure(credentials, block_cache): + """Helper function to actually configure GCS.""" + if credentials: + if isinstance(credentials, dict): + credentials = json.dumps(credentials) + placeholder = array_ops.placeholder(dtypes.string) + op = gen_gcs_config_ops.gcs_configure_credentials(placeholder) + session.run(op, feed_dict={placeholder: credentials}) + if block_cache: + op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=block_cache.max_bytes, + block_size=block_cache.block_size, + max_staleness=block_cache.max_staleness) + session.run(op) + + if device: + with ops.device(device): + return configure(credentials, block_cache) + return configure(credentials, block_cache) + + +def configure_colab_session(session): + """ConfigureColabSession configures the GCS file system in Colab. + + Args: + session: A `tf.Session` session. + """ + # Read from the application default credentials (adc). + with open('/content/datalab/adc.json') as f: + data = json.load(f) + configure_gcs(session, credentials=data) diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6c056d6c8adfa50b95aefb8e9740631327a572 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py @@ -0,0 +1,44 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the gcs_config_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cloud.python.ops import gcs_config_ops +from tensorflow.python.platform import test + + +class GcsConfigOpsTest(test.TestCase): + + def testSetBlockCache(self): + cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024) + with self.test_session() as sess: + gcs_config_ops.configure_gcs(sess, block_cache=cfg) + + def testConfigureGcsHook(self): + creds = {'client_id': 'fake_client', + 'refresh_token': 'fake_token', + 'client_secret': 'fake_secret', + 'type': 'authorized_user'} + hook = gcs_config_ops.ConfigureGcsHook(credentials=creds) + hook.begin() + with self.test_session() as sess: + sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None + hook.after_create_session(sess, None) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 8ede28602fd6cf7a2239772f37b4ea6c0ffd7b4a..8f521ffee4d31e090c13bac98290656d6e1d330e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -36,7 +36,9 @@ except ImportError: _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' +_ENDPOINTS_SEPARATOR = ',' _DEFAULT_ENV_VARIABLE = 'TPU_NAME' +_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' class TPUClusterResolver(ClusterResolver): @@ -68,8 +70,8 @@ class TPUClusterResolver(ClusterResolver): return _GKE_ENV_VARIABLE in os.environ @staticmethod - def _gkeMaster(): - return os.environ[_GKE_ENV_VARIABLE].split(',')[0] + def _gkeEndpoints(): + return os.environ[_GKE_ENV_VARIABLE] @staticmethod def _envVarFallback(): @@ -77,6 +79,10 @@ class TPUClusterResolver(ClusterResolver): return os.environ[_DEFAULT_ENV_VARIABLE] return None + @staticmethod + def _discoveryUrl(): + return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) + def __init__(self, tpu=None, zone=None, @@ -85,7 +91,8 @@ class TPUClusterResolver(ClusterResolver): coordinator_name=None, coordinator_address=None, credentials='default', - service=None): + service=None, + discovery_url=None): """Creates a new TPUClusterResolver object. The ClusterResolver will then use the parameters to query the Cloud TPU APIs @@ -115,6 +122,11 @@ class TPUClusterResolver(ClusterResolver): service: The GCE API object returned by the googleapiclient.discovery function. If you specify a custom service object, then the credentials parameter will be ignored. + discovery_url: A URL template that points to the location of + the discovery service. It should have two parameters {api} and + {apiVersion} that when filled in produce an absolute URL to the + discovery document for that service. The environment variable + 'TPU_API_DISCOVERY_URL' will override this. Raises: ImportError: If the googleapiclient is not installed. @@ -132,7 +144,7 @@ class TPUClusterResolver(ClusterResolver): # When using GKE with Cloud TPUs, the env variable will be set. if tpu is None: if in_gke: - tpu = self._gkeMaster() + tpu = self._gkeEndpoints() else: tpu = self._envVarFallback() @@ -159,14 +171,22 @@ class TPUClusterResolver(ClusterResolver): if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient must be installed before using the ' - 'TPU cluster resolver. Execute: `pip install ' - '--upgrade google-api-python-client` to install with ' - 'pip.') - - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2client` to ' + 'install with pip.') + + final_discovery_url = self._discoveryUrl() or discovery_url + if final_discovery_url: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials, + discoveryServiceUrl=final_discovery_url) + else: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service @@ -195,7 +215,7 @@ class TPUClusterResolver(ClusterResolver): ValueError: If none of the TPUs specified exists. """ if not self._shouldResolve(): - return self._tpu + return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] job_tasks = self.cluster_spec().job_tasks(self._job_name) if not job_tasks: @@ -237,6 +257,10 @@ class TPUClusterResolver(ClusterResolver): request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() + if 'state' in response and response['state'] != 'READY': + raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % + (self._tpu, response['state'])) + if 'health' in response and response['health'] != 'HEALTHY': raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, response['health'])) @@ -257,8 +281,12 @@ class TPUClusterResolver(ClusterResolver): # Case 3. return None # Case 2. - cluster_spec = {self._job_name: [self._tpu[len( - compat.as_bytes('grpc://')):]]} + cluster_spec = { + self._job_name: [ + x[len(compat.as_bytes('grpc://')):] + for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) + ] + } if self._coordinator_address: # {1, 2}.a diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5b3f9be5a11237f9dceebefa1db294efaf7e482d..ad4f6432630be44a7de6e778f55f1fb7fd66f307 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -158,6 +158,50 @@ class TPUClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testUnhealthyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'health': 'UNHEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testNotReadyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'state': 'CREATING' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { @@ -358,15 +402,67 @@ class TPUClusterResolverTest(test.TestCase): compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) self.assertEqual(None, tpu_cluster_resolver.cluster_spec()) - def testGkeEnvironment(self): + def testGkeEnvironmentForDonut(self): os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' - self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ) + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) + self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + + def testGkeEnvironmentForPod(self): + os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470') + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() self.assertEqual( compat.as_bytes('grpc://10.120.27.5:8470'), - compat.as_bytes(TPUClusterResolver._gkeMaster())) + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + tasks { key: 1 value: '10.120.27.6:8470' } + tasks { key: 2 value: '10.120.27.7:8470' } + tasks { key: 3 value: '10.120.27.8:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + def testDiscoveryUrl(self): + os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' + self.assertEqual('https://{api}.internal/{apiVersion}', + TPUClusterResolver._discoveryUrl()) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 0708d6b7b9f0ba549aea091a265f42890e50d223..a0a5b0e00c1979ebf8850408785135b9ceac7d2a 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,16 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) + +if(WIN32) +# BoringSSL is disabled for windows as it currently doesn't build with +# MSBuild. (Ninja is required.) option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) +else() +# BoringSSL is enabled for gRPC. +option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" ON) +endif() + option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF) @@ -290,17 +299,20 @@ include_directories( ${double_conversion_INCLUDE_DIR} ) -if(tensorflow_ENABLE_SSL_SUPPORT) - include(boringssl) - list(APPEND tensorflow_EXTERNAL_LIBRARIES ${boringssl_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl) - include_directories(${boringssl_INCLUDE_DIR}) -endif() if(tensorflow_ENABLE_GRPC_SUPPORT) + if(tensorflow_ENABLE_SSL_SUPPORT) + include(boringssl) + include_directories(${boringssl_INCLUDE_DIR}) + endif() include(grpc) + include_directories(${GRPC_INCLUDE_DIRS}) + # Place boringssl after grpc as grpc depends on boringssl. list(APPEND tensorflow_EXTERNAL_LIBRARIES ${grpc_STATIC_LIBRARIES}) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES grpc) - include_directories(${GRPC_INCLUDE_DIRS}) + if(tensorflow_ENABLE_SSL_SUPPORT) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${boringssl_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl) + endif() endif() if(tensorflow_ENABLE_JEMALLOC_SUPPORT) include(jemalloc) @@ -327,40 +339,14 @@ endif() # MKL Support if (tensorflow_ENABLE_MKL_SUPPORT) add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) - if (WIN32) - find_path(MKL_HOME_PLATFORM mkl - PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ - $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ - PATH_SUFFIXES windows) - set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) - set(MKL_LINK_DIRS - ${MKL_HOME_PLATFORM}/mkl/lib/intel64 - ${MKL_HOME_PLATFORM}/tbb/lib/intel64/vc_mt - ${MKL_HOME_PLATFORM}/compiler/lib/intel64 - ${MKL_HOME_PLATFORM}/mkl/tools/builder/lib) - set(MKL_REDIST_DLL_DIRS - ${MKL_HOME_PLATFORM}/redist/intel64/mkl - ${MKL_HOME_PLATFORM}/redist/intel64/tbb/vc_mt - ${MKL_HOME_PLATFORM}/redist/intel64/compiler) - list(APPEND tensorflow_EXTERNAL_LIBRARIES - mkl_intel_lp64_dll mkl_sequential_dll mkl_core_dll mkl_rt mkl_cdll_intel64) - endif() - if (UNIX) - # Fix me: complete the path on linux - find_path(MKL_HOME_PLATFORM mkl - HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ - $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ - PATH_SUFFIXES linux) - set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) - set(MKL_LINK_DIRS) # incompleted - set(MKL_REDIST_SO_DIRS) # incompleted - endif() - include_directories(${MKL_INCLUDE_DIRS}) - link_directories(${MKL_LINK_DIRS}) + include(mkl) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination) + include_directories(${mkl_INCLUDE_DIRS}) if (tensorflow_ENABLE_MKLDNN_SUPPORT) include(mkldnn) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination) include_directories(${mkldnn_INCLUDE_DIRS}) else (tensorflow_ENABLE_MKLDNN_SUPPORT) add_definitions(-DINTEL_MKL_ML) diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake index 3c4bb01e24fd121c9d0fc3594cc25de37af0e8a1..fbb14b2515a656f1dfc0e3f63ac367e9b7738a23 100644 --- a/tensorflow/contrib/cmake/external/boringssl.cmake +++ b/tensorflow/contrib/cmake/external/boringssl.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(boringssl_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src/boringssl/include) #set(boringssl_EXTRA_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src) set(boringssl_URL https://boringssl.googlesource.com/boringssl) -set(boringssl_TAG ee7aa02) +set(boringssl_TAG 7f8c553d7f4db0a6ce727f2986d41bf8fe8ec4bf) set(boringssl_BUILD ${CMAKE_BINARY_DIR}/boringssl/src/boringssl-build) #set(boringssl_LIBRARIES ${boringssl_BUILD}/obj/so/libboringssl.so) set(boringssl_STATIC_LIBRARIES diff --git a/tensorflow/contrib/cmake/external/double_conversion.cmake b/tensorflow/contrib/cmake/external/double_conversion.cmake index 527ccdc8d887cb4c2e7d2412c99a8bc682568472..5c5adaf5798289fba1c5d0b3f9e0489dc242043e 100644 --- a/tensorflow/contrib/cmake/external/double_conversion.cmake +++ b/tensorflow/contrib/cmake/external/double_conversion.cmake @@ -16,15 +16,15 @@ include (ExternalProject) set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion) set(double_conversion_URL https://github.com/google/double-conversion.git) -set(double_conversion_TAG 5664746) +set(double_conversion_TAG 3992066a95b823efc8ccc1baf82a1cfc73f6e9b8) set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR}) set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so) set(double_conversion_INCLUDES ${double_conversion_BUILD}) if(WIN32) - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/$(Configuration)/double-conversion.lib) + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/$(Configuration)/double-conversion.lib) else() - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.a) + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/libdouble-conversion.a) endif() set(double_conversion_HEADERS diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 693dc7cd673233b889b35a3f3170b57581da9a9f..b1e64aa55c80ad59cfdc0f4767c0282b4f73367f 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -20,6 +20,10 @@ set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) if(WIN32) + # We use unsecure gRPC because boringssl does not build on windows + set(grpc_TARGET grpc++_unsecure) + set(grpc_DEPENDS protobuf zlib) + set(grpc_SSL_PROVIDER NONE) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib @@ -32,9 +36,12 @@ if(WIN32) ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib) endif() else() + set(grpc_TARGET grpc++) + set(grpc_DEPENDS boringssl protobuf zlib) + set(grpc_SSL_PROVIDER module) set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) @@ -44,13 +51,13 @@ add_definitions(-DGRPC_ARES=0) ExternalProject_Add(grpc PREFIX grpc - DEPENDS protobuf zlib + DEPENDS ${grpc_DEPENDS} GIT_REPOSITORY ${GRPC_URL} GIT_TAG ${GRPC_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target ${grpc_TARGET} COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin INSTALL_COMMAND "" CMAKE_CACHE_ARGS @@ -59,7 +66,7 @@ ExternalProject_Add(grpc -DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS} -DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} - -DgRPC_SSL_PROVIDER:STRING=NONE + -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} ) # grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h. diff --git a/tensorflow/contrib/cmake/external/mkl.cmake b/tensorflow/contrib/cmake/external/mkl.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a172e3a41a283359b9a8c823ddcb2b1973b5b3cc --- /dev/null +++ b/tensorflow/contrib/cmake/external/mkl.cmake @@ -0,0 +1,68 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +# NOTE: Different from mkldnn.cmake, this file is meant to download mkl libraries +set(mkl_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include) +set(mkl_BIN_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/bin) +set(mkl_WIN mklml_win_2018.0.3.20180406.zip) # match for v0.14 +set(mkl_MAC mklml_mac_2018.0.3.20180406.tgz) +set(mkl_LNX mklml_lnx_2018.0.3.20180406.tgz) +set(mkl_TAG v0.14) +set(mkl_URL https://github.com/intel/mkl-dnn/releases) + +if (WIN32) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_WIN}) + list(APPEND mkl_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.lib) + list(APPEND mkl_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.lib) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.dll) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.dll) +elseif (UNIX) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_LNX}) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5.so) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_gnu.so) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_intel.so) +elseif (APPLE) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_MAC}) + #TODO need more information +endif () + +ExternalProject_Add(mkl + PREFIX mkl + URL ${mkl_DOWNLOAD_URL} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "") + +# put mkl dynamic libraries in one bin directory +add_custom_target(mkl_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${mkl_BIN_DIRS} + DEPENDS mkl) + +add_custom_target(mkl_copy_shared_to_destination DEPENDS mkl_create_destination_dir) + +foreach(dll_file ${mkl_SHARED_LIBRARIES}) + add_custom_command(TARGET mkl_copy_shared_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dll_file} ${mkl_BIN_DIRS}) +endforeach() diff --git a/tensorflow/contrib/cmake/external/mkldnn.cmake b/tensorflow/contrib/cmake/external/mkldnn.cmake index a639fdee367f060d4c8a79267803da6ffe3dc503..8123ee1f393ab8e3a52f13915ea2a65decc188d9 100644 --- a/tensorflow/contrib/cmake/external/mkldnn.cmake +++ b/tensorflow/contrib/cmake/external/mkldnn.cmake @@ -22,8 +22,11 @@ set(mkldnn_TAG 3063b2e4c943983f6bf5f2fb9a490d4a998cd291) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib) + set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.dll) + set(mkldnn_BUILD ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release) else() set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib) + set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.dll) endif() else() set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a) @@ -31,6 +34,7 @@ endif() ExternalProject_Add(mkldnn PREFIX mkldnn + DEPENDS mkl GIT_REPOSITORY ${mkldnn_URL} GIT_TAG ${mkldnn_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" @@ -40,5 +44,11 @@ ExternalProject_Add(mkldnn CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DMKLINC:STRING=${MKL_INCLUDE_DIRS} + -DMKLINC:STRING=${mkl_INCLUDE_DIRS} ) + +# since mkldnn depends on mkl, copy the mkldnn.dll together with mklml.dll to mkl_bin_dirs +add_custom_target(mkldnn_copy_shared_to_destination DEPENDS mkldnn) + +add_custom_command(TARGET mkldnn_copy_shared_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${mkldnn_SHARED_LIBRARIES} ${mkl_BIN_DIRS}) diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index b9d1dd88d4c2d3c9141ba56e14911e06b4d33f7c..eba3bcfc79efe87d0a45c979c5accfa1b6511ed0 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 0559ce013feac8db639ee1bf776aca0325d28777) +set(nsync_TAG 1.20.0) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index ab464bc99a43138130bb2758ae28ecef29805c31..f56fb35a0f71250f00b84e5cf94a24682bda6c82 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) +set(PROTOBUF_TAG v3.6.0) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index a9fd298449bf014dd538ba596aadca01f0b80ed8..a5eba5a8c94d6ddfa820ae371841f764b628c4b5 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -32,51 +32,14 @@ tensorflow/python/feature_column tensorflow/python/framework tensorflow/python/grappler tensorflow/python/keras -tensorflow/python/keras/activations tensorflow/python/keras/applications -tensorflow/python/keras/applications/densenet -tensorflow/python/keras/applications/inception_resnet_v2 -tensorflow/python/keras/applications/inception_v3 -tensorflow/python/keras/applications/mobilenet -tensorflow/python/keras/applications/nasnet -tensorflow/python/keras/applications/resnet50 -tensorflow/python/keras/applications/vgg16 -tensorflow/python/keras/applications/vgg19 -tensorflow/python/keras/applications/xception -tensorflow/python/keras/backend -tensorflow/python/keras/callbacks -tensorflow/python/keras/constraints tensorflow/python/keras/datasets -tensorflow/python/keras/datasets/boston_housing -tensorflow/python/keras/datasets/cifar10 -tensorflow/python/keras/datasets/cifar100 -tensorflow/python/keras/datasets/fashion_mnist -tensorflow/python/keras/datasets/imdb -tensorflow/python/keras/datasets/mnist -tensorflow/python/keras/datasets/reuters -tensorflow/python/keras/initializers +tensorflow/python/keras/engine +tensorflow/python/keras/estimator tensorflow/python/keras/layers -tensorflow/python/keras/losses -tensorflow/python/keras/metrics -tensorflow/python/keras/models -tensorflow/python/keras/optimizers tensorflow/python/keras/preprocessing -tensorflow/python/keras/preprocessing/image -tensorflow/python/keras/preprocessing/sequence -tensorflow/python/keras/preprocessing/text -tensorflow/python/keras/regularizers tensorflow/python/keras/utils tensorflow/python/keras/wrappers -tensorflow/python/keras/wrappers/scikit_learn -tensorflow/python/keras/_impl -tensorflow/python/keras/_impl/keras -tensorflow/python/keras/_impl/keras/applications -tensorflow/python/keras/_impl/keras/datasets -tensorflow/python/keras/_impl/keras/engine -tensorflow/python/keras/_impl/keras/layers -tensorflow/python/keras/_impl/keras/preprocessing -tensorflow/python/keras/_impl/keras/utils -tensorflow/python/keras/_impl/keras/wrappers tensorflow/python/kernel_tests tensorflow/python/kernel_tests/boosted_trees tensorflow/python/kernel_tests/distributions @@ -123,6 +86,8 @@ tensorflow/contrib/batching/python/ops tensorflow/contrib/bayesflow tensorflow/contrib/bayesflow/python tensorflow/contrib/bayesflow/python/ops +# tensorflow/contrib/bigtable/python +# tensorflow/contrib/bigtable/python/ops tensorflow/contrib/boosted_trees tensorflow/contrib/boosted_trees/estimator_batch tensorflow/contrib/boosted_trees/kernels @@ -167,6 +132,7 @@ tensorflow/contrib/data tensorflow/contrib/data/kernels tensorflow/contrib/data/python tensorflow/contrib/data/python/kernel_tests +tensorflow/contrib/data/python/kernel_tests/serialization tensorflow/contrib/data/python/ops tensorflow/contrib/decision_trees tensorflow/contrib/decision_trees/proto @@ -274,6 +240,8 @@ tensorflow/contrib/keras/api/keras/wrappers/scikit_learn tensorflow/contrib/kernel_methods tensorflow/contrib/kernel_methods/python tensorflow/contrib/kernel_methods/python/mappers +tensorflow/contrib/kinesis/python +tensorflow/contrib/kinesis/python/ops tensorflow/contrib/kfac tensorflow/contrib/kfac/examples tensorflow/contrib/kfac/python diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index d63c41db844af243f0c6600b1565635ac9b91cac..cf1ee2ad76f2cc9f58dbe90182a3e17f1edc7ed3 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -11,7 +11,6 @@ tensorflow/contrib/mpi tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle tensorflow/contrib/tensor_forest/proto -tensorflow/contrib/tensorboard/graph_explorer/proto tensorflow/contrib/tensorboard/plugins/projector tensorflow/contrib/tensorboard/plugins/trace tensorflow/contrib/tpu/proto diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index c6a15f2ca075c8de96786a580c7ddb89541df5bc..7a30eb94f54b18a2a517615a315e23e09e1170d0 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,9 +21,8 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" + "${tensorflow_source_dir}/tensorflow/c/eager/c_api_debug.cc" "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h" "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc" @@ -37,14 +36,3 @@ add_dependencies( tf_cc_while_loop tf_core_lib tf_protos_cc) - -add_library(tf_c_python_api OBJECT - "${tensorflow_source_dir}/tensorflow/c/python_api.cc" - "${tensorflow_source_dir}/tensorflow/c/python_api.h" -) -add_dependencies( - tf_c_python_api - tf_c - tf_core_lib - tf_core_framework - tf_protos_cc) diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index f73da0b8ab18af1eca4c2bd577604595f8b8ec6d..6c90cf398c69c8c1b22ea75e0c407f258e2535f9 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -155,7 +155,7 @@ if (WIN32) set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib") endif() else (WIN32) - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so") + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX}") endif (WIN32) add_custom_target(tf_extension_ops) diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index b47c32f1c48b3d42fe5b4ba115cc2a511b7ee5f4..872b016d2b6c1b8fb5875c9568a1b7b6201507c0 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -49,43 +49,48 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) set(${HDRS} ${${HDRS}} PARENT_SCOPE) endfunction() -if(NOT WIN32) - function(RELATIVE_PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS ROOT_DIR) - if(NOT ARGN) - message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_GRPC_CPP() called without any proto files") - return() +function(RELATIVE_PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS ROOT_DIR) + if(NOT ARGN) + message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_GRPC_CPP() called without any proto files") + return() + endif() + + set(${SRCS}) + set(${HDRS}) + foreach(FIL ${ARGN}) + set(ABS_FIL ${ROOT_DIR}/${FIL}) + get_filename_component(FIL_WE ${FIL} NAME_WE) + get_filename_component(FIL_DIR ${ABS_FIL} PATH) + file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR}) + + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc") + list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h") + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc") + list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h") + + # We adust the path of the gRPC code generation accordingly. + if(WIN32) + set(GRPC_PROTOC_PLUGIN_PATH ${GRPC_BUILD}/Release/grpc_cpp_plugin.exe) + else() + set(GRPC_PROTOC_PLUGIN_PATH ${GRPC_BUILD}/grpc_cpp_plugin) endif() - set(${SRCS}) - set(${HDRS}) - foreach(FIL ${ARGN}) - set(ABS_FIL ${ROOT_DIR}/${FIL}) - get_filename_component(FIL_WE ${FIL} NAME_WE) - get_filename_component(FIL_DIR ${ABS_FIL} PATH) - file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR}) - - list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc") - list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h") - list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc") - list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h") - - add_custom_command( - OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc" - "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h" - "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc" - "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h" - COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} - ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR} --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --plugin protoc-gen-grpc=${GRPC_BUILD}/grpc_cpp_plugin -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS} - DEPENDS ${ABS_FIL} protobuf grpc - COMMENT "Running C++ protocol buffer grpc compiler on ${FIL}" - VERBATIM ) - endforeach() - - set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) - set(${SRCS} ${${SRCS}} PARENT_SCOPE) - set(${HDRS} ${${HDRS}} PARENT_SCOPE) - endfunction() -endif() + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h" + "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR} --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --plugin=protoc-gen-grpc=${GRPC_PROTOC_PLUGIN_PATH} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS} + DEPENDS ${ABS_FIL} protobuf grpc + COMMENT "Running C++ protocol buffer grpc compiler on ${FIL}" + VERBATIM ) + endforeach() + + set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) + set(${SRCS} ${${SRCS}} PARENT_SCOPE) + set(${HDRS} ${${HDRS}} PARENT_SCOPE) +endfunction() function(RELATIVE_PROTOBUF_TEXT_GENERATE_CPP SRCS HDRS ROOT_DIR) if(NOT ARGN) @@ -125,6 +130,7 @@ endfunction() file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" + "${tensorflow_source_dir}/tensorflow/compiler/xla/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" ) @@ -174,17 +180,14 @@ RELATIVE_PROTOBUF_TEXT_GENERATE_CPP(PROTO_TEXT_SRCS PROTO_TEXT_HDRS ${tensorflow_source_dir} ${tf_proto_text_srcs} ) -if(WIN32) - add_library(tf_protos_cc ${PROTO_SRCS} ${PROTO_HDRS}) -else() - file(GLOB_RECURSE tf_protos_grpc_cc_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/debug/*.proto" - ) - RELATIVE_PROTOBUF_GENERATE_GRPC_CPP(PROTO_GRPC_SRCS PROTO_GRPC_HDRS - ${tensorflow_source_dir} ${tf_protos_grpc_cc_srcs} - ) - add_library(tf_protos_cc ${PROTO_GRPC_SRCS} ${PROTO_GRPC_HDRS} ${PROTO_SRCS} ${PROTO_HDRS}) -endif() +file(GLOB_RECURSE tf_protos_grpc_cc_srcs RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/tensorflow/core/debug/*.proto" + "${tensorflow_source_dir}/tensorflow/core/protobuf/master_service.proto" +) +RELATIVE_PROTOBUF_GENERATE_GRPC_CPP(PROTO_GRPC_SRCS PROTO_GRPC_HDRS + ${tensorflow_source_dir} ${tf_protos_grpc_cc_srcs} +) +add_library(tf_protos_cc ${PROTO_GRPC_SRCS} ${PROTO_GRPC_HDRS} ${PROTO_SRCS} ${PROTO_HDRS}) ######################################################## # tf_core_lib library @@ -213,10 +216,6 @@ else() list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() -file(GLOB tf_core_platform_exclude_srcs - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc") -list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_exclude_srcs}) - list(APPEND tf_core_lib_srcs ${tf_core_platform_srcs}) if(UNIX) @@ -237,15 +236,6 @@ if(WIN32) list(APPEND tf_core_lib_srcs ${tf_core_platform_windows_srcs}) endif(WIN32) -if(tensorflow_ENABLE_SSL_SUPPORT) - # Cloud libraries require boringssl. - file(GLOB tf_core_platform_cloud_srcs - "${tensorflow_source_dir}/tensorflow/core/platform/cloud/*.h" - "${tensorflow_source_dir}/tensorflow/core/platform/cloud/*.cc" - ) - list(APPEND tf_core_lib_srcs ${tf_core_platform_cloud_srcs}) -endif() - if (tensorflow_ENABLE_HDFS_SUPPORT) list(APPEND tf_core_platform_hdfs_srcs "${tensorflow_source_dir}/tensorflow/core/platform/hadoop/hadoop_file_system.cc" @@ -286,8 +276,6 @@ set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.c file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/framework/*.h" "${tensorflow_source_dir}/tensorflow/core/framework/*.cc" - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.h" - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.h" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 2d76bf530a2100b2afa80a16a5d64b6ec51ffc68..844f62649d970506f1b4b4c5718fab8d1f0856e1 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -134,14 +134,13 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs}) endif(tensorflow_BUILD_CONTRIB_KERNELS) -if(NOT tensorflow_ENABLE_SSL_SUPPORT) - # Cloud libraries require boringssl. - file(GLOB tf_core_kernels_cloud_srcs - "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.h" - "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.cc" - ) +# Cloud libraries require curl and boringssl. +# Curl is not supported yet anyway so we remove for now. +file(GLOB tf_core_kernels_cloud_srcs + "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.h" + "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.cc" +) list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_cloud_srcs}) -endif() file(GLOB_RECURSE tf_core_kernels_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*test*.h" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index e558691de4b74988031f7b2204aad92e8c7af68b..bc753333dba4f67eee0114c4022743dd59a05982 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -113,6 +113,7 @@ GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensor GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}") GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(gcs_config "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/gcs_config_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(reduce_slice_ops "${tensorflow_source_dir}/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc") ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 8d24a7ae38f5b0d1008038978b735c1a723c0d3e..e3b59001bcb4f081eb2db3443ee9ad714c822ac8 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -420,6 +420,8 @@ GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py) GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py) GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" @@ -454,6 +456,18 @@ add_custom_command( COMMENT "Running SWIG to generate Python wrappers" VERBATIM ) +add_library(tf_c_python_api OBJECT + "${tensorflow_source_dir}/tensorflow/c/python_api.cc" + "${tensorflow_source_dir}/tensorflow/c/python_api.h" +) +add_dependencies( + tf_c_python_api + tf_c + tf_core_lib + tf_core_framework + tf_protos_cc + tf_python_protos_cc) + set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc" @@ -713,7 +727,7 @@ if(WIN32) endif() else() add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.so) endif() @@ -723,7 +737,7 @@ endif() ######################################################## # Parse tensorflow/tools/api/generator/BUILD to get list of generated files. -FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/BUILD api_generator_BUILD_text) +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) @@ -734,40 +748,119 @@ foreach(api_init_file ${api_init_files_list}) string(STRIP "${api_init_file}" api_init_file) if(api_init_file) string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes - list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/${api_init_file}") + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/${api_init_file}") endif() endforeach(api_init_file) set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt") file(WRITE "${api_init_list_file}" "${api_init_files}") +# Run create_python_api.py to generate __init__.py files. + +### TODO +# In order to download and compile MKL/MKL-DNN automatically in cmake script, mkl-built libraries should be added to system path +# to be loaded by python executor. However `add_custom_command` has an issue with `COMMAND ${CMAKE_COMMAND} -E env PATH=`, where +# arguments of multiple paths (such as D:/;D:/mkl) will be parsed in to seperate string without semicolon and that command fail to +# recongnize paths. As CUDA isn't built with MKL, the MKL built directory is the only path to this command to work around that issue. +# To not override the CUDA and system path in other circumstances, `if-else` branch used here to handle this problem, +# and should be removed if the path issue can be resolved. +### + +if (tensorflow_ENABLE_MKL_SUPPORT) + # add mkl dist dlls to system path for python + # TODO: In current cmake version, PY_RUNTIME_ENV behaves strange with multiple paths, + # so we have to specify only one path in it to work around the issue. We need this if/else + # to protect overwriting CUDA environments + set(PY_RUNTIME_ENV ${mkl_BIN_DIRS}) + add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python PATH=${PY_RUNTIME_ENV} ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" + VERBATIM + ) +else (tensorflow_ENABLE_MKL_SUPPORT) + add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" + ) +endif (tensorflow_ENABLE_MKL_SUPPORT) + +add_custom_target(tf_python_api SOURCES ${api_init_files}) +add_dependencies(tf_python_api tf_python_ops) + +# TODO(mikecase): This can be removed once tf.estimator is moved +# out of TensorFlow. +######################################################## +# Generate API __init__.py files for tf.estimator. +######################################################## + +# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) + +set(api_init_files "") +foreach(api_init_file ${api_init_files_list}) + string(STRIP "${api_init_file}" api_init_file) + if(api_init_file) + string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api/${api_init_file}") + endif() +endforeach(api_init_file) +set(estimator_api_init_list_file "${tensorflow_source_dir}/estimator_api_init_files_list.txt") +file(WRITE "${estimator_api_init_list_file}" "${api_init_files}") + # Run create_python_api.py to generate __init__.py files. add_custom_command( OUTPUT ${api_init_files} DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops - # tensorflow/__init__.py depends on files generated in this step. So, remove it while - # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py - COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - # Run create_python_api.py to generate API init files. COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" "${api_init_list_file}" - - # Re-add tensorflow/__init__.py back. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api" + "--package=tensorflow.python.estimator" + "--apiname=estimator" + "${estimator_api_init_list_file}" COMMENT "Generating __init__.py files for Python API." WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" ) -add_custom_target(tf_python_api SOURCES ${api_init_files}) -add_dependencies(tf_python_api tf_python_ops) - - +add_custom_target(estimator_python_api SOURCES ${api_init_files}) +add_dependencies(estimator_python_api tf_python_ops) ############################################################ # Build a PIP package containing the TensorFlow runtime. ############################################################ @@ -778,6 +871,7 @@ add_dependencies(tf_python_build_pip_package tf_python_touchup_modules tf_python_ops tf_python_api + estimator_python_api tf_extension_ops) # Fix-up Python files that were not included by the add_python_module() macros. @@ -789,7 +883,6 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/testing/python/framework/util_test.py ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/testing/python/framework/) - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/tools/pip_package/README ${CMAKE_CURRENT_BINARY_DIR}/tf_python/) diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 38f40452b533fdc0dba6ac686a0ff43a2ef13cb8..fdf522f1fd90ffc64acbe82381ef57a389645d61 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -145,3 +145,8 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ # unsupported Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ DESTINATION include/unsupported/Eigen) +# mkl +if (tensorflow_ENABLE_MKL_SUPPORT) + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/ + DESTINATION include/mkl) +endif (tensorflow_ENABLE_MKL_SUPPORT) diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake index 9a37b681194d4ef82b27a0160dd969f733ecad67..6d634cb1709910f366c7ca538d28bd802b2a7c63 100644 --- a/tensorflow/contrib/cmake/tf_stream_executor.cmake +++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake @@ -64,8 +64,6 @@ file(GLOB tf_stream_executor_srcs if (tensorflow_ENABLE_GPU) file(GLOB tf_stream_executor_gpu_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc" ) if (NOT tensorflow_BUILD_CC_TESTS) file(GLOB tf_stream_executor_gpu_tests @@ -76,11 +74,11 @@ if (tensorflow_ENABLE_GPU) list(APPEND tf_stream_executor_srcs ${tf_stream_executor_gpu_srcs}) endif() -#file(GLOB_RECURSE tf_stream_executor_test_srcs -# "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.cc" -# "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.h" -#) -#list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) +file(GLOB_RECURSE tf_stream_executor_test_srcs + "${tensorflow_source_dir}/tensorflow/stream_executor/*test.cc" + "${tensorflow_source_dir}/tensorflow/stream_executor/lib/*test.h" +) +list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) if (NOT WIN32) set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lgomp") diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 5942ff3363a96de70df7e13d0857e4ad82e35fee..eb9482dc25f2be8ce46cc38bf3dd28889b09a9d4 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -212,6 +212,10 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py" # Disable following manual tag in BUILD. "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py" + # These tests depend on a .so file + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py ) if (WIN32) diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index cffe069aa352f8a6f2c436bc70b62f54e2336ac6..4f957f1e0b46fde5daacbc59657af994e13c42d5 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -44,7 +44,8 @@ UNDNAME = "undname.exe" DUMPBIN = "dumpbin.exe" # Exclude if matched -EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::") +EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::|Internal|" + r"python_op_gen_internal|grappler") # Include if matched before exclude INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" @@ -56,6 +57,10 @@ INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" r"tensorflow::ops::internal::Enter|" r"tensorflow::strings::internal::AppendPieces|" r"tensorflow::strings::internal::CatPieces|" + r"tensorflow::errors::Internal|" + r"tensorflow::Tensor::CopyFromInternal|" + r"tensorflow::kernel_factory::" + r"OpKernelRegistrar::InitInternal|" r"tensorflow::io::internal::JoinPathImpl") # Include if matched after exclude @@ -64,7 +69,7 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|" r"tensorflow::|" r"functor::|" r"\?nsync_|" - r"perftools::gputools") + r"stream_executor::") # We want to identify data members explicitly in the DEF file, so that no one # can implicitly link against the DLL if they use one of the variables exported diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py index f039cb0f5265b920200f63c5bd5ebeb4e23826be..0c997bd4fdfa4233117c9fec2c4397301b1c8cb9 100644 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py +++ b/tensorflow/contrib/coder/python/layers/entropybottleneck.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import engine +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import init_ops @@ -40,7 +40,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary -class EntropyBottleneck(engine.Layer): +class EntropyBottleneck(base_layer.Layer): """Entropy bottleneck layer. This layer can be used to model the entropy (the amount of information @@ -262,7 +262,7 @@ class EntropyBottleneck(engine.Layer): self._range_coder_precision = int(range_coder_precision) self._data_format = data_format self._channel_axis(2) # trigger ValueError early - self.input_spec = engine.InputSpec(min_ndim=2) + self.input_spec = base_layer.InputSpec(min_ndim=2) @property def init_scale(self): @@ -357,7 +357,7 @@ class EntropyBottleneck(engine.Layer): channels = input_shape[channel_axis].value if channels is None: raise ValueError("The channel dimension of the inputs must be defined.") - self.input_spec = engine.InputSpec( + self.input_spec = base_layer.InputSpec( ndim=input_shape.ndims, axes={channel_axis: channels}) filters = (1,) + self.filters + (1,) scale = self.init_scale ** (1 / (len(self.filters) + 1)) diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md index c65a150464efc1e77419040f66f36fc6756325aa..cb1dd7d836ae11700b2ffaaff4fda5b7f943f87d 100644 --- a/tensorflow/contrib/constrained_optimization/README.md +++ b/tensorflow/contrib/constrained_optimization/README.md @@ -46,7 +46,7 @@ document. Imagine that we want to constrain the recall of a binary classifier to be at least 90%. Since the recall is proportional to the number of true positive classifications, which itself is a sum of indicator functions, this constraint -is non-differentible, and therefore cannot be used in a problem that will be +is non-differentiable, and therefore cannot be used in a problem that will be optimized using a (stochastic) gradient-based algorithm. For this and similar problems, TFCO supports so-called *proxy constraints*, diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py index 04014ab4aebd6d9cd70653c53f9361320e803329..3791dae8d7f6b03bc1115bca97811dfc4775c45b 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -169,8 +169,8 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): del old_inactive # Needed by the condition, but not the body. iteration += 1 scale = (1.0 - standard_ops.reduce_sum( - matrix, axis=0, keep_dims=True)) / standard_ops.maximum( - 1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True)) + matrix, axis=0, keepdims=True)) / standard_ops.maximum( + 1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True)) matrix += scale * inactive new_inactive = standard_ops.to_float(matrix > 0) matrix *= new_inactive @@ -206,10 +206,10 @@ def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix): # For numerical reasons, make sure that the largest matrix element is zero # before exponentiating. - log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True) + log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keepdims=True) log_matrix -= standard_ops.log( standard_ops.reduce_sum( - standard_ops.exp(log_matrix), axis=0, keep_dims=True)) + standard_ops.exp(log_matrix), axis=0, keepdims=True)) return log_matrix diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 102bc460fdadb0ad5dc9a2960b8655c55357108e..a0dd3881a86c19e47ccb65f84a2477a55626b81c 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -218,7 +218,6 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): new_control_inputs, input_types, new_original_op, op_def) #Use Graph's hidden methods to add the op - to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 8285ea04926d3a24e9c22bd6d69eb7a48f5e3a85..252ea1560d7f5be3799686d6d91ae9a6d262ac0a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -768,7 +768,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLSTMCheckpointableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION @@ -781,7 +781,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGRUCheckpointableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION @@ -826,7 +826,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCudnnCompatibleLSTMCheckpointablMultiLayer(self): num_units = 2 num_layers = 3 diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 125da7df5de2ad7cc14ace59abf795c7564701e0..748d7cd011f32fdebd781176b560b9b7498f327e 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -20,11 +20,10 @@ from __future__ import print_function import os from tensorflow.contrib.checkpoint.python import split_dependency from tensorflow.contrib.rnn.python.ops import lstm_ops -from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_cudnn_rnn_ops from tensorflow.python.ops import init_ops @@ -34,7 +33,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import tracking as checkpointable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -1647,10 +1646,3 @@ class CudnnRNNRelu(_CudnnRNNNoInputC): # 1 set of weight and bias parameters for the recurrent input, and 1 for the # previous layer input. _NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER - - -ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNN")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNBackprop")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index a25aa85251083c24ca6685c4ffef267955f66f63..156538b4e01bf1a1ccca0fca1e309b1d37b6dbc0 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -20,23 +20,31 @@ be used in conjunction with the @{tf.data.Dataset} API. Note that the guarantees as `tf.data`, but we will provide deprecation advice in advance of removing existing functionality. -See the @{$datasets$Importing Data} Programmer's Guide for an overview. +See @{$guide/datasets$Importing Data} for an overview. @@Counter @@CheckpointInputPipelineHook @@CsvDataset +@@RandomDataset +@@Reducer @@SqlDataset +@@TFRecordWriter @@assert_element_shape @@batch_and_drop_remainder @@bucket_by_sequence_length +@@choose_from_datasets @@dense_to_sparse_batch @@enumerate_dataset + +@@get_single_element +@@group_by_reducer @@group_by_window @@ignore_errors @@make_batched_features_dataset @@make_csv_dataset @@make_saveable_from_iterator + @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave @@ -49,8 +57,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@sliding_window_batch @@sloppy_interleave @@unbatch - -@@get_single_element +@@unique """ from __future__ import absolute_import @@ -70,13 +77,17 @@ from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.get_single_element import get_single_element from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length +from tensorflow.contrib.data.python.ops.grouping import group_by_reducer from tensorflow.contrib.data.python.ops.grouping import group_by_window +from tensorflow.contrib.data.python.ops.grouping import Reducer +from tensorflow.contrib.data.python.ops.interleave_ops import choose_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.random_ops import RandomDataset from tensorflow.contrib.data.python.ops.readers import CsvDataset from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import make_csv_dataset @@ -86,6 +97,8 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch +from tensorflow.contrib.data.python.ops.unique import unique +from tensorflow.contrib.data.python.ops.writers import TFRecordWriter # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 76e54a284e07ec1bab9b0f364a44997a39bce78a..4657807785d58727d34f37172bd30c56a5b7cde6 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" namespace tensorflow { @@ -103,12 +102,11 @@ class CSVDatasetOp : public DatasetOpKernel { OP_REQUIRES( ctx, select_cols.empty() || select_cols.front() >= 0, errors::InvalidArgument("select_cols should be non-negative indices")); - bool select_all_cols = select_cols.empty(); - *output = new Dataset( - ctx, std::move(filenames), header, buffer_size, output_types_, - output_shapes_, std::move(record_defaults), std::move(select_cols), - select_all_cols, use_quote_delim, delim[0], std::move(na_value)); + *output = new Dataset(ctx, std::move(filenames), header, buffer_size, + output_types_, output_shapes_, + std::move(record_defaults), std::move(select_cols), + use_quote_delim, delim[0], std::move(na_value)); } private: @@ -118,8 +116,7 @@ class CSVDatasetOp : public DatasetOpKernel { int64 buffer_size, const DataTypeVector& output_types, const std::vector& output_shapes, std::vector record_defaults, std::vector select_cols, - bool select_all_cols, bool use_quote_delim, char delim, - string na_value) + bool use_quote_delim, char delim, string na_value) : GraphDatasetBase(ctx), filenames_(std::move(filenames)), header_(header), @@ -128,12 +125,11 @@ class CSVDatasetOp : public DatasetOpKernel { output_shapes_(output_shapes), record_defaults_(std::move(record_defaults)), select_cols_(std::move(select_cols)), - select_all_cols_(select_all_cols), use_quote_delim_(use_quote_delim), delim_(delim), na_value_(std::move(na_value)) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::CSV")})); @@ -145,7 +141,7 @@ class CSVDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { return "CSVDatasetOp::Dataset"; } + string DebugString() const override { return "CSVDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, @@ -166,11 +162,24 @@ class CSVDatasetOp : public DatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); + bool select_all = dataset()->select_cols_.empty(); do { // We are currently processing a file, so try to read the next record - if (buffered_input_stream_) { - Status s = ReadRecord(ctx, out_tensors); - if (s.ok() || !errors::IsOutOfRange(s)) { + if (input_stream_) { + Status s = ReadRecord(ctx, out_tensors, select_all, + dataset()->select_cols_); + if (s.ok()) { + // Validate output + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument( + "Expect ", dataset()->out_type_.size(), " fields but have ", + out_tensors->size(), " in record"); + } + + *end_of_sequence = false; + return s; + } + if (!errors::IsOutOfRange(s)) { // Not at the end of file, return OK or non-EOF errors to caller. *end_of_sequence = false; return s; @@ -203,145 +212,317 @@ class CSVDatasetOp : public DatasetOpKernel { } private: - // Reads a record by parsing the input buffer, and converting extracted + // Reads an entire CSV row from the input stream, either from the + // existing buffer or by filling the buffer as needed. Converts extracted // fields to output tensors as we go. - Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors) + // + // When this function is called, pos_ should be the index of the first + // character of the record in buffer_, or past the end of the buffer. + // Note: ctx and out_tensors are only used in this function + // when fields are included in the record. + Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors, + bool select_all, const std::vector& selected) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // Extracts fields from line(s) from the buffered input stream. - out_tensors->reserve(dataset()->record_defaults_.size()); - - string input; - TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input)); - - size_t current_idx = 0; - size_t num_fields_parsed = 0; - size_t selector_idx = 0; // Keep track of index into select_cols - - while (current_idx < input.size()) { - // In each iteration, parse one field - if (input[current_idx] == '\n' || input[current_idx] == '\r') { - // This should never happen, because buffered input reader splits - // input on newlines. - return errors::InvalidArgument("Parsing error."); - } + if (pos_ >= buffer_.size()) { + // At the end of the file, this will return errors::OutOfRange + TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); + pos_ = 0; + } + + // The first character may be \n if this is the continuation of a + // \r\n linebreak between this and the previous record. If so, skip it. + + bool end_of_record = false; // Keep track of when we find \n, \r or EOF + size_t num_parsed = 0; + size_t num_selected_parsed = 0; + + Status result; - bool quoted = false; + while (!end_of_record) { // Read till we reach \n, \r or EOF bool include = - (dataset()->select_all_cols_ || - dataset()->select_cols_[selector_idx] == num_fields_parsed); + select_all || (num_selected_parsed < selected.size() && + selected[num_selected_parsed] == num_parsed); - if (dataset()->use_quote_delim_ && input[current_idx] == '"') { - quoted = true; - current_idx++; + // Don't fail fast, so that the next call to GetNext may still return + // a valid record + result.Update( + ParseOneField(ctx, out_tensors, &end_of_record, include)); + + num_parsed++; + if (include) num_selected_parsed++; + } + + return result; + } + + // Parses one field from position pos_ in the buffer. Fields are + // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of + // the next field. + Status ParseOneField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // If we get here, this means the previous field's end coincided + // with the end of the buffer. We can fill the buffer without abandon. + Status s = FillBuffer(&buffer_); + + if (errors::IsOutOfRange(s)) { + // Reached EOF, and last field is empty + *end_of_record = true; + if (include) { + return FieldToOutput(ctx, StringPiece(), out_tensors); + } else { + return Status::OK(); + } + } else if (!s.ok()) { + return s; // Surface other errors back to caller } - // Parse the body of the field - string field; - if (!quoted) { - while (current_idx < input.size() && - input[current_idx] != dataset()->delim_) { - if ((dataset()->use_quote_delim_ && input[current_idx] == '"') || - input[current_idx] == '\n' || input[current_idx] == '\r') { - return errors::InvalidArgument( - "Unquoted fields cannot have quotes/CRLFs inside"); + pos_ = 0; + } + + if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { + return ParseQuotedField(ctx, out_tensors, end_of_record, include); + } + + return ParseUnquotedField(ctx, out_tensors, end_of_record, include); + } + + // For keeping track of relevant parts of a field from a previous buffer + struct Piece { + size_t start; + size_t len; + string buffer; + + Piece(string buffer, size_t start, size_t len) + : start(start), len(len), buffer(std::move(buffer)) {} + }; + + // Given that pos_ exceeds the buffer, saves the relevant part of the + // current buffer (if necessary), fills the buffer, and resets indices to + // 0. + Status SaveAndFillBuffer(std::vector* earlier_pieces, + size_t* start, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + string temp_buffer; + + buffer_.swap(temp_buffer); + if (include && pos_ > *start) { + earlier_pieces->push_back( + Piece(std::move(temp_buffer), *start, pos_ - *start)); + } + pos_ = 0; + *start = 0; + return FillBuffer(&buffer_); + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseQuotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + pos_++; // Starting quotation mark + + Status parse_result; + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument( + "Reached end of file without closing quoted field in " + "record"); + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + if (ch == '"') { + // When we encounter a quote, we look ahead to the next character to + // decide what to do + pos_++; + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + // This was the last field. We are done + *end_of_record = true; + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(), out_tensors, earlier_pieces, include)); + return parse_result; + } else if (!s.ok()) { + return s; } - if (include) field += input[current_idx]; - current_idx++; - } // Exit condition: end of input, or current index at delim + } + + char next = buffer_[pos_]; + pos_++; + if (next == dataset()->delim_) { + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include)); + return parse_result; + + } else if (next == '\n' || next == '\r') { + *end_of_record = true; + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include)); + if (next == '\r') SkipNewLineIfNecessary(); + return parse_result; + } else if (next != '"') { + // Take note of the error, but keep going to end of field. + include = false; // So we don't get funky errors when trying to + // unescape the quotes. + parse_result.Update(errors::InvalidArgument( + "Quote inside a string has to be escaped by another quote")); + } - // Go to next field or the end - current_idx++; } else { - // Quoted field needs to be ended with '"' and delim or end - while (true) { - if (current_idx >= input.size() - 1 || input.empty()) { - if (current_idx == input.size() - 1 && - input[current_idx] == '"') { - // We're at the end of the input, and the quote terminates the - // record. Go to end. - current_idx++; - break; - } - // If there's no terminating quote, it means our buffered record - // line reader split a record up. This can happen if there is a - // newline encased in quotes. The next line is also part of the - // record, so we read it and reset the index. - if (include && current_idx == input.size() - 1) { - // TODO(rachelim): Instead of building up a string, keep track - // of terminal indices (or starting char* and length) - // Also look into using /lib/strings/Scanner - field += input[current_idx]; - } - if (include) { - field += '\n'; - } - current_idx = 0; - Status s = buffered_input_stream_->ReadLine(&input); - if (!s.ok()) { - return errors::InvalidArgument( - "Quoted field has to end with quote followed by delim, " - "CRLF, or EOF"); - } - } else if (input[current_idx] == '"' && - input[current_idx + 1] == dataset()->delim_) { - // End of field, go to next field or end - current_idx += 2; - break; - } else if (input[current_idx] == '"') { - // Current char is a quote. Since we're not at end of field, - // the next character must also be a quote. - if (input[current_idx + 1] != '"') { - return errors::InvalidArgument( - "Quote inside a string has to be escaped by another " - "quote"); - } - if (include) field += '"'; - current_idx += 2; - } else { - if (include) field += input[current_idx]; - current_idx++; - } + pos_++; + } + } + } + + // Converts quoted field to an output tensor, removing the starting + // and ending quotes from it and unescaping double quotations if + // necessary. + Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + if (field.find('\"', 1) == field.size() - 1) { + // `field` contains no escaped quotation marks. + // Exclude framing quotation marks + field.remove_prefix(1); + field.remove_suffix(1); + return FieldToOutput(ctx, field, out_tensors); + } + } + string field_complete; + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + field_complete.reserve(str_len); + + // This bool flips every time we see a quote, so that we skip the second + // quote of every pair of adjacent quotes in the field. We need to track + // this across iterations of the for loop because adjacent double quotes + // may be in different buffers. Initialize to true because we also skip + // the opening quotation mark of the quoted field. + bool skip_next_quote = true; + for (const Piece& p : earlier_pieces) { + AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), + &field_complete, &skip_next_quote); + } + AppendUnescapedPiece(field, &field_complete, &skip_next_quote); + StringPiece result = StringPiece(field_complete); + result.remove_suffix(1); // Skip final quote + + return FieldToOutput(ctx, result, out_tensors); + } + + void AppendUnescapedPiece(StringPiece piece, string* field_complete, + bool* skip_next_quote) { + size_t from = 0; + size_t found = piece.find('\"', from); + while (found != string::npos) { + if (!*skip_next_quote) { + // This is the first quote in a pair of adjacent double quotes + field_complete->append(piece.data() + from, found + 1 - from); + } + *skip_next_quote = !*skip_next_quote; + from = found + 1; + found = piece.find('\"', from); + } + // Include the chunk after the last quotation mark in the string + if (from < piece.size()) { + field_complete->append(piece.data() + from, piece.size() - from); + } + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseUnquotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + Status parse_result; + + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + // Handle errors + if (errors::IsOutOfRange(s)) { + // Whatever we have is the last field of the last record + *end_of_record = true; + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + return parse_result; + } else if (!s.ok()) { + return s; // Surface all other errors to caller } } - num_fields_parsed++; + char ch = buffer_[pos_]; - if (include) { - // Add the tensor to the result - TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field), - selector_idx, out_tensors)); - selector_idx++; - // Terminate early if we have all the fields we want - if (selector_idx == dataset()->select_cols_.size()) - return Status::OK(); + if (ch == dataset()->delim_) { + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + pos_++; + return parse_result; } - } // Exit condition: current_idx has reached the end of record - - // Check if the last field is empty, and include it if necessary - bool include = - (dataset()->select_all_cols_ || - dataset()->select_cols_[selector_idx] == num_fields_parsed); - if (include && !input.empty() && - input[input.size() - 1] == dataset()->delim_) { - TF_RETURN_IF_ERROR( - FieldToOutput(ctx, string(), selector_idx, out_tensors)); + if (ch == '\n' || ch == '\r') { + // need special case to skip over first \n of record if the line + // breaks are \r\n + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + *end_of_record = true; + pos_++; + if (ch == '\r') SkipNewLineIfNecessary(); + return parse_result; + } + if (dataset()->use_quote_delim_ && ch == '"') { + // Take note of the error, but keep going to end of field. + parse_result.Update(errors::InvalidArgument( + "Unquoted fields cannot have quotes inside")); + } + // Otherwise, go to next character + pos_++; } + } - // Check that number of fields matches - if (out_tensors->size() != dataset()->out_type_.size()) { - return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), - " fields but have ", - out_tensors->size(), " in record"); + Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + result->clear(); + Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); + + if (errors::IsOutOfRange(s) && !result->empty()) { + // Ignore OutOfRange error when ReadNBytes read < N bytes. + return Status::OK(); } - return Status::OK(); + return s; } - // Given a string field, and its index in the output, - // converts it to a Tensor of the right type and adds it to the - // out_tensors vector. - Status FieldToOutput(IteratorContext* ctx, string field, - size_t output_idx, + // Given a field, converts it to the right output tensor type + Status FieldToOutput(IteratorContext* ctx, StringPiece field, std::vector* out_tensors) { + size_t output_idx = out_tensors->size(); if (output_idx >= dataset()->out_type_.size()) { // We can get here if we're selecting all columns, but the number of // fields exceeds the number of defaults provided @@ -397,7 +578,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat()(0); } else { float value; - if (!strings::safe_strtof(field.c_str(), &value)) { + if (!strings::safe_strtof(field, &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid float: ", field); @@ -412,7 +593,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat()(0); } else { double value; - if (!strings::safe_strtod(field.c_str(), &value)) { + if (!strings::safe_strtod(field, &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid double: ", field); @@ -426,7 +607,7 @@ class CSVDatasetOp : public DatasetOpKernel { component.scalar()() = dataset()->record_defaults_[output_idx].flat()(0); } else { - component.scalar()() = std::move(field); + component.scalar()() = field.ToString(); } break; } @@ -439,6 +620,50 @@ class CSVDatasetOp : public DatasetOpKernel { return Status::OK(); } + // Records can be delimited by "\r\n" line breaks. When we encounter a + // '\r', we have to check the next character to see if it is part of the + // linebreak, and ignore it if so. + void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + // If we failed to fill buffer, it doesn't matter because we're done + // with the record + if (!s.ok()) return; + } + if (buffer_[pos_] == '\n') { + pos_++; + } + } + + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + return FieldToOutput(ctx, field, out_tensors); + } + + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + string field_complete; + field_complete.reserve(str_len); + + for (const Piece& p : earlier_pieces) { + field_complete.append(p.buffer, p.start, p.len); + } + + field_complete.append(field.data(), field.size()); + return FieldToOutput(ctx, field_complete, out_tensors); + } + // Sets up reader streams to read from the file at `current_file_index_`. Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (current_file_index_ >= dataset()->filenames_.size()) { @@ -452,16 +677,18 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->filenames_[current_file_index_], &file_)); input_stream_.reset( new io::RandomAccessInputStream(file_.get(), false)); - // TODO(rachelim): Maintain our own buffer so we don't read every record - // twice - buffered_input_stream_.reset(new io::BufferedInputStream( - input_stream_.get(), dataset()->buffer_size_, false)); + buffer_.clear(); + pos_ = 0; if (dataset()->header_) { - // Ignore header line - string str; - Status s = buffered_input_stream_->ReadLine(&str); - if (errors::IsOutOfRange(s)) { - return errors::InvalidArgument("Can't read header of empty file"); + // Read one line, but don't include it. Pass nullptrs as dummy + // pointers to objects that shouldn't be invoked anyway + // We need to process this as a record here instead of just finding + // the first newline because it might contain quoted fields with + // newlines in the header as well + std::vector empty; + Status s = ReadRecord(nullptr, nullptr, false, empty); + if (!s.ok()) { + return errors::InvalidArgument("Can't read header of file"); } } return Status::OK(); @@ -470,15 +697,15 @@ class CSVDatasetOp : public DatasetOpKernel { // Resets all reader streams. void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { input_stream_.reset(); - buffered_input_stream_.reset(); file_.reset(); } mutex mu_; + string buffer_ GUARDED_BY(mu_); // Maintain our own buffer + size_t pos_ GUARDED_BY( + mu_); // Index into the buffer must be maintained between iters std::unique_ptr input_stream_ GUARDED_BY(mu_); - std::unique_ptr buffered_input_stream_ - GUARDED_BY(mu_); size_t current_file_index_ GUARDED_BY(mu_) = 0; std::unique_ptr file_ GUARDED_BY(mu_); // must outlive input_stream_ @@ -491,7 +718,6 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector output_shapes_; const std::vector record_defaults_; const std::vector select_cols_; - const bool select_all_cols_; const bool use_quote_delim_; const char delim_; const string na_value_; diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index 48d3734162525ffc6ace076e4f0523c1d0cae511..6a12ca06f4d6cc2096aaf8191a01a899881b43db 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -91,7 +91,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( {this, strings::StrCat(prefix, "::DirectedInterleave")})); @@ -105,7 +105,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); } @@ -130,15 +130,21 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - selector_input_impl_(params.dataset->selector_input_->MakeIterator( - params.prefix + ".selector")), - num_active_inputs_(params.dataset->data_inputs_.size()) { - data_input_impls_.reserve(params.dataset->data_inputs_.size()); - for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) { - const DatasetBase* data_input = params.dataset->data_inputs_[i]; - data_input_impls_.push_back(data_input->MakeIterator( - strings::StrCat(params.prefix, "[", i, "]"))); + num_active_inputs_(params.dataset->data_inputs_.size()) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( + ctx, strings::StrCat(prefix(), ".selector"), + &selector_input_impl_)); + data_input_impls_.resize(dataset()->data_inputs_.size()); + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const DatasetBase* data_input = dataset()->data_inputs_[i]; + TF_RETURN_IF_ERROR(data_input->MakeIterator( + ctx, strings::StrCat(prefix(), "[", i, "]"), + &data_input_impls_[i])); } + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index bb29df60e8f114aaa50f578c43e73874f72ab0a3..bbec50681c6f5decec5a3b5fbf09cc3011a21199 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -44,7 +44,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")})); @@ -57,7 +57,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "IgnoreErrorsDatasetOp::Dataset"; } + string DebugString() const override { + return "IgnoreErrorsDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, @@ -72,8 +74,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index a2bfce03620a1482f5b21cbf23c66833bc5cd480..b3d464d7165d53cf198072e06214f7d5e982073d 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -40,7 +40,8 @@ class FunctionBufferingResource : public ResourceBase { const NameAttrList& func, int64 buffer_size, const string& source_device, const string& target_device, - const std::vector& func_args) + const std::vector& func_args, + const DataTypeVector& output_types) : lib_(lib), pflr_(std::move(pflr)), func_(func), @@ -48,6 +49,7 @@ class FunctionBufferingResource : public ResourceBase { source_device_(source_device), target_device_(target_device), func_args_(func_args), + output_types_(output_types), handle_(kInvalidHandle), is_buffering_(false), end_of_sequence_(false), @@ -176,6 +178,13 @@ class FunctionBufferingResource : public ResourceBase { AllocatorAttributes arg_alloc_attr; arg_alloc_attr.set_on_host(true); opts.args_alloc_attrs.push_back(arg_alloc_attr); + for (const auto& dtype : output_types_) { + AllocatorAttributes ret_alloc_attrs; + if (DataTypeAlwaysOnHost(dtype)) { + ret_alloc_attrs.set_on_host(true); + } + opts.rets_alloc_attrs.push_back(ret_alloc_attrs); + } if (opts.source_device != target_device_) { opts.remote_execution = true; } @@ -233,6 +242,7 @@ class FunctionBufferingResource : public ResourceBase { const string source_device_; const string target_device_; const std::vector func_args_; + const DataTypeVector output_types_; FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); std::deque requests_ GUARDED_BY(mu_); @@ -250,6 +260,7 @@ class FunctionBufferResourceHandleOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); } ~FunctionBufferResourceHandleOp() override { @@ -269,18 +280,20 @@ class FunctionBufferResourceHandleOp : public OpKernel { std::vector func_args; func_args.push_back(*string_arg); + const string& source_device = ctx->device()->name(); + // Obtain and canonicalize target_device. const Tensor* target_arg; OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg)); - const string& target_device = - DeviceNameUtils::CanonicalizeDeviceName(target_arg->scalar()()); + string target_device; + OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName( + target_arg->scalar()(), source_device, + &target_device)); FunctionLibraryRuntime* lib = ctx->function_library(); OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library is provided.")); - const string& source_device = ctx->device()->name(); - mutex_lock l(mu_); if (!initialized_) { OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); @@ -297,7 +310,7 @@ class FunctionBufferResourceHandleOp : public OpKernel { this](FunctionBufferingResource** ptr) { *ptr = new FunctionBufferingResource( clone_lib, std::move(pflr), func_, buffer_size_, - source_device, target_device, func_args); + source_device, target_device, func_args, output_types_); return Status::OK(); })); core::ScopedUnref s(buffer); @@ -319,6 +332,7 @@ class FunctionBufferResourceHandleOp : public OpKernel { int64 buffer_size_; string container_; string name_; + DataTypeVector output_types_; }; REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 63e19ae3f837c9d3cfb1221df64360ee74117f13..141706f393b076d9f55898ca4bdbe7438f7c3625 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace { @@ -24,19 +25,32 @@ namespace { class ThreadPoolResource : public ResourceBase { public: ThreadPoolResource(Env* env, const ThreadOptions& thread_options, - const string& name, int num_threads, bool low_latency_hint) - : thread_pool_(env, thread_options, name, num_threads, low_latency_hint) { - } + const string& name, int num_threads, bool low_latency_hint, + int max_intra_op_parallelism) + : thread_pool_(env, thread_options, name, num_threads, low_latency_hint), + max_intra_op_parallelism_(max_intra_op_parallelism) {} // Schedules fn() for execution in the pool of threads. void Schedule(std::function fn) { - thread_pool_.Schedule(std::move(fn)); + if (max_intra_op_parallelism_ < 0) { + thread_pool_.Schedule(std::move(fn)); + } else { + thread_pool_.Schedule(std::bind( + [this](std::function bound_fn) { + // TODO(mrry): Consider moving this thread-local configuration to + // the threads themselves. + ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_); + bound_fn(); + }, + std::move(fn))); + } } string DebugString() override { return "ThreadPoolResource"; } private: thread::ThreadPool thread_pool_; + const int max_intra_op_parallelism_; }; // Creates a handle to a ThreadPool resource. Note that we don't use @@ -48,6 +62,8 @@ class ThreadPoolHandleOp : public OpKernel { explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism", + &max_intra_op_parallelism_)); OP_REQUIRES( ctx, num_threads_ > 0, errors::InvalidArgument("`num_threads` must be greater than zero.")); @@ -78,7 +94,7 @@ class ThreadPoolHandleOp : public OpKernel { EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new ThreadPoolResource( ctx->env(), {}, display_name_, - num_threads_, + num_threads_, max_intra_op_parallelism_, false /* low_latency_hint */); return Status::OK(); })); @@ -95,6 +111,7 @@ class ThreadPoolHandleOp : public OpKernel { bool initialized_ GUARDED_BY(mu_) = false; string display_name_; int num_threads_; + int max_intra_op_parallelism_; }; class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { @@ -127,7 +144,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { threadpool_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::ThreadPool")})); @@ -140,7 +157,9 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "ThreadPoolDatasetOp::Dataset"; } + string DebugString() const override { + return "ThreadPoolDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, @@ -154,8 +173,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index 69fbb0fcdcce87951d2c9b84210fda378081b103..67c237799c10a2724f18bb0df99e4bf8f5cd2b8a 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -56,7 +56,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Unique")})); @@ -70,7 +70,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { + string DebugString() const override { return strings::StrCat("UniqueDatasetOp::Dataset"); } @@ -87,8 +87,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const typename Iterator::Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index f271d269ab1b9339de4657e459dcbbd462890f0a..8413fcaf872f49f654c6a1327a14d5c44bdd815a 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -104,6 +104,7 @@ REGISTER_OP("FunctionBufferingResource") .Attr("container: string") .Attr("f: func") .Attr("buffer_size: int") + .Attr("output_types: list(type)") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Creates a resource that fills up a buffer by making function calls. @@ -117,6 +118,7 @@ container: If non-empty, this resource is placed in the given container. Otherwise, a default container is used. shared_name: If non-empty, this resource will be shared under the given name across multiple sessions. +output_types: The type list for the return values. )doc"); REGISTER_OP("FunctionBufferingResourceGetNext") @@ -158,6 +160,7 @@ REGISTER_OP("ThreadPoolHandle") .Output("handle: resource") .SetShapeFn(shape_inference::ScalarShape) .Attr("num_threads: int") + .Attr("max_intra_op_parallelism: int = 1") .Attr("display_name: string") .Attr("container: string = ''") .Attr("shared_name: string = ''") @@ -166,6 +169,8 @@ Creates a custom thread pool with the given number of threads. handle: A resource that can be consumed by one or more ThreadPoolDataset ops. num_threads: The number of threads in the thread pool. +max_intra_op_parallelism: The maximum degree of parallelism to use within + operations that execute on this threadpool. display_name: A human-readable name for the threads that may be visible in some visualizations. )doc"); diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 2ebc80fa637476889c6ee183e73ebaf45e8d0410..079c8bbd8ee4360a847bda14d17a0b48a14c45a5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test") py_test( name = "batch_dataset_op_test", @@ -12,23 +12,27 @@ py_test( srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", tags = [ + "no_oss", # (b/79552534) "no_pip", ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -38,7 +42,6 @@ py_test( srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -47,24 +50,33 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "concatenate_dataset_op_test", + name = "csv_dataset_op_test", size = "small", - srcs = ["concatenate_dataset_op_test.py"], + srcs = ["csv_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], ) @@ -79,103 +91,44 @@ py_test( "nomac", # b/62040583 ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", - "//third_party/py/numpy", ], ) -py_library( - name = "dataset_serialization_test", - srcs = [ - "dataset_serialization_test_base.py", - ], +py_test( + name = "directed_interleave_dataset_test", + size = "medium", + srcs = ["directed_interleave_dataset_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:interleave_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "csv_dataset_op_test", - size = "small", - srcs = ["csv_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:random_seed", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "filter_dataset_op_test", + name = "get_single_element_test", size = "small", - srcs = ["filter_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "optonly", - ], + srcs = ["get_single_element_test.py"], deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:functional_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "flat_map_dataset_op_test", - size = "medium", - srcs = ["flat_map_dataset_op_test.py"], - additional_deps = [ - ":dataset_serialization_test", - "//third_party/py/numpy", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:get_single_element", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:function", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", ], - grpc_enabled = True, - tags = ["no_pip"], ) py_test( @@ -190,10 +143,8 @@ py_test( "notap", ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:interleave_ops", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -201,26 +152,28 @@ py_test( "//tensorflow/python:script_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", + "@six_archive//:six", ], ) -tf_py_test( - name = "get_single_element_test", +py_test( + name = "iterator_ops_test", size = "small", - srcs = ["get_single_element_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow/contrib/data/python/ops:get_single_element", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python:array_ops", + srcs = ["iterator_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", ], ) @@ -235,91 +188,112 @@ py_test( "optonly", ], deps = [ - ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/contrib/data/python/ops:error_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", "//tensorflow/python:io_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", "//tensorflow/python:util", - "//tensorflow/python:variable_scope", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "prefetch_dataset_op_test", + name = "optimize_dataset_op_test", size = "small", - srcs = ["prefetch_dataset_op_test.py"], + srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], deps = [ - ":dataset_serialization_test", - "//tensorflow/python:platform", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", ], ) +cuda_py_test( + name = "prefetching_ops_test", + size = "small", + srcs = ["prefetching_ops_test.py"], + additional_deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + py_test( name = "range_dataset_op_test", size = "small", srcs = ["range_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:counter", "//tensorflow/contrib/data/python/ops:enumerate_ops", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", - "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", - "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", ], ) +py_library( + name = "reader_dataset_ops_test_base", + testonly = 1, + srcs = [ + "reader_dataset_ops_test_base.py", + ], + srcs_version = "PY2AND3", + visibility = [ + "//tensorflow/contrib/data/python/kernel_tests:__pkg__", + "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", + ], + deps = [ + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:lib", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", + ], +) + py_test( name = "reader_dataset_ops_test", size = "medium", srcs = ["reader_dataset_ops_test.py"], - shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:string_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], ) @@ -346,6 +320,7 @@ py_test( "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", + "@six_archive//:six", ], ) @@ -356,13 +331,14 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:scan_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", "//third_party/py/numpy", @@ -370,55 +346,55 @@ py_test( ) py_test( - name = "sequence_dataset_op_test", + name = "shuffle_dataset_op_test", size = "medium", - srcs = ["sequence_dataset_op_test.py"], + srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "optonly", + ], deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "serialization_integration_test", + name = "slide_dataset_op_test", size = "small", - srcs = ["serialization_integration_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], + srcs = ["slide_dataset_op_test.py"], deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:sliding", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) -py_test( - name = "shuffle_dataset_op_test", - size = "medium", - srcs = ["shuffle_dataset_op_test.py"], +py_library( + name = "sql_dataset_op_test_base", + srcs = ["sql_dataset_op_test_base.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + visibility = [ + "//tensorflow/contrib/data/python/kernel_tests:__pkg__", + "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", + ], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:shuffle_ops", + "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", + "@org_sqlite//:python", ], ) @@ -427,14 +403,12 @@ py_test( size = "small", srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:array_ops", + ":sql_dataset_op_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "@org_sqlite//:python", ], ) @@ -445,11 +419,15 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) @@ -463,8 +441,12 @@ py_test( "//tensorflow/contrib/data/python/ops:threadpool", "//tensorflow/contrib/data/python/ops:unique", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:script_ops", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -475,87 +457,49 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:unique", - "//tensorflow/contrib/stateless", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", ], ) py_test( - name = "zip_dataset_op_test", - size = "small", - srcs = ["zip_dataset_op_test.py"], + name = "window_dataset_op_test", + size = "medium", + srcs = ["window_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -cuda_py_test( - name = "prefetching_ops_test", - size = "small", - srcs = ["prefetching_ops_test.py"], - additional_deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", + tags = [ + "no_pip", ], -) - -tf_py_test( - name = "slide_dataset_op_test", - size = "small", - srcs = ["slide_dataset_op_test.py"], - additional_deps = [ - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:sliding", + deps = [ + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", "//tensorflow/python:math_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) -tf_py_test( +py_test( name = "writer_ops_test", size = "small", srcs = ["writer_ops_test.py"], - additional_deps = [ + deps = [ "//tensorflow/contrib/data/python/ops:writers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", "//tensorflow/python:lib", - "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 2568b899d7ea1be685036ad8af93f584f861c951..42adfd17f07e508f25d8b351c791fa519eca8bd9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import math import time +from absl.testing import parameterized import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops @@ -40,7 +40,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class BatchDatasetTest(test.TestCase): +class BatchDatasetTest(test.TestCase, parameterized.TestCase): def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) @@ -293,7 +293,7 @@ class BatchDatasetTest(test.TestCase): ph2: np.arange(8).astype(np.int32) }) with self.assertRaises(errors.InvalidArgumentError): - print(sess.run(next_element)) + sess.run(next_element) # No 0th dimension (i.e. scalar value) for one component. sess.run( @@ -303,7 +303,7 @@ class BatchDatasetTest(test.TestCase): ph2: 7 }) with self.assertRaises(errors.InvalidArgumentError): - print(sess.run(next_element)) + sess.run(next_element) def testBatchAndDropRemainder(self): components = (np.arange(7), @@ -427,9 +427,13 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, - num_parallel_calls=None, - num_parallel_batches=None): + @parameterized.named_parameters( + ("default", None, None), + ("sequential_calls", 1, None), + ("parallel_calls", 2, None), + ("parallel_batches", None, 10), + ) + def testMapAndBatch(self, num_parallel_calls, num_parallel_batches): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). @@ -500,19 +504,11 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatch(self): - return self._testMapAndBatchDatasetHelper() - - def testMapAndBatchWithParallelBatches(self): - return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) - - def testMapAndBatchWithSequentialCalls(self): - return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) - - def testMapAndBatchWithParallelCalls(self): - return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) - - def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): + @parameterized.named_parameters( + ("even", False), + ("uneven", True), + ) + def testMapAndBatchPartialBatch(self, drop_remainder): iterator = ( dataset_ops.Dataset.range(10).apply( batching.map_and_batch( @@ -532,12 +528,6 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) - def testMapAndBatchPartialBatch(self): - return self._testMapAndBatchPartialBatchHelper() - - def testMapAndBatchPartialBatchDropRemainder(self): - return self._testMapAndBatchPartialBatchHelper(drop_remainder=True) - def testMapAndBatchYieldsPartialBatch(self): iterator = (dataset_ops.Dataset.range(10) .apply(batching.map_and_batch( @@ -552,6 +542,44 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testMapAndBatchParallelGetNext(self): + iterator = (dataset_ops.Dataset.range(50000) + .apply(batching.map_and_batch(lambda x: x, batch_size=100)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(5): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + + def testMapAndBatchParallelGetNextDropRemainder(self): + iterator = ( + dataset_ops.Dataset.range(49999).apply( + batching.map_and_batch( + lambda x: x, batch_size=100, drop_remainder=True)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(4): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + def testMapAndBatchSparse(self): def _sparse(i): @@ -576,7 +604,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testMapAndBatchDatasetFails(self): + def testMapAndBatchFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( @@ -590,7 +618,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) - def testMapAndBatchDatasetShapeMismatch(self): + def testMapAndBatchShapeMismatch(self): """Test a dataset that maps a TF function across its input elements.""" def generator(): @@ -613,173 +641,79 @@ class BatchDatasetTest(test.TestCase): "number of elements does not match"): sess.run(get_next) + def testMapAndBatchImplicitDispose(self): + # Tests whether a map and batch dataset will be cleaned up correctly when + # the pipeline does not run it until exhaustion. + # The pipeline is TensorSliceDataset -> RepeatDataset(1000) -> + # MapAndBatchDataset(f=square_3, batch_size=100). + components = (np.arange(1000), + np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], + np.array(37.0) * np.arange(1000)) -class BatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): - components = ( - np.arange(tensor_slice_len), - np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(tensor_slice_len)) + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) - return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) + dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat( + 1000).apply(batching.map_and_batch(_map_fn, batch_size=100)) + dataset = dataset.prefetch(5) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() - def testCore(self): - tensor_slice_len = 8 - batch_size = 2 - num_outputs = tensor_slice_len // batch_size - self.run_core_tests( - lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), - lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), - num_outputs) + with self.test_session() as sess: + for _ in range(3): + sess.run(get_next) - def _build_dataset_dense_to_sparse(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, [12])) + @parameterized.parameters(0, 5, 10, 90, 95, 99) + def testMapAndBatchOutOfRangeError(self, threshold): - def testDenseToSparseBatchDatasetCore(self): - components = np.random.randint(5, size=(40,)).astype(np.int32) - diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) - - num_outputs = len(components) // 4 - self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), - lambda: self._build_dataset_dense_to_sparse(diff_comp), - num_outputs) - - def _sparse(self, i): - return sparse_tensor.SparseTensorValue( - indices=[[0]], values=(i * [1]), dense_shape=[1]) + def raising_py_fn(i): + if i >= threshold: + raise StopIteration() + else: + return i - def _build_dataset_sparse(self, batch_size=5): - return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) - - def testSparseCore(self): - self.run_core_tests(self._build_dataset_sparse, - lambda: self._build_dataset_sparse(2), 2) - - def _build_dataset_nested_sparse(self): - return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) - - def testNestedSparseCore(self): - self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + iterator = ( + dataset_ops.Dataset.range(100).apply( + batching.map_and_batch( + lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64), + batch_size=10)).make_one_shot_iterator()) + get_next = iterator.get_next() + with self.test_session() as sess: + for i in range(threshold // 10): + self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next)) + if threshold % 10 != 0: + self.assertAllEqual( + [threshold // 10 * 10 + j for j in range(threshold % 10)], + sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) -class UnbatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + @parameterized.parameters( + (False, dtypes.bool), + (-42, dtypes.int8), + (-42, dtypes.int16), + (-42, dtypes.int32), + (-42, dtypes.int64), + (42, dtypes.uint8), + (42, dtypes.uint16), + (42.0, dtypes.float16), + (42.0, dtypes.float32), + (42.0, dtypes.float64), + (b"hello", dtypes.string), + ) + def testMapAndBatchTypes(self, element, dtype): + def gen(): + yield element + + dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply( + batching.map_and_batch(lambda x: x, batch_size=10)) + + get_next = dataset.make_one_shot_iterator().get_next() - def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): - components = ( - np.arange(tensor_slice_len), - np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(tensor_slice_len)) - - return dataset_ops.Dataset.from_tensor_slices(components).batch( - batch_size).apply(batching.unbatch()) - - def testCore(self): - tensor_slice_len = 8 - batch_size = 2 - num_outputs = tensor_slice_len - self.run_core_tests( - lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), - lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), - num_outputs) - - -class MapAndBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testNumParallelBatches(self): - range_size = 11 - num_repeats = 2 - batch_size = 5 - total_outputs = range_size * num_repeats - num_outputs_drop_remainder = total_outputs // batch_size - num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) - num_parallel_batches = 2 - - def build_ds(range_start, drop_remainder=False): - - def _map_fn(x): - return math_ops.square(x) - - return dataset_ops.Dataset.range( - range_start, range_start + range_size).repeat(num_repeats).apply( - batching.map_and_batch( - map_func=_map_fn, - batch_size=batch_size, - num_parallel_batches=num_parallel_batches, - drop_remainder=drop_remainder)) - - self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), - num_outputs_keep_remainder) - self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), - num_outputs_drop_remainder) - - def testNumParallelCalls(self): - range_size = 11 - num_repeats = 2 - batch_size = 5 - total_outputs = range_size * num_repeats - num_outputs_drop_remainder = total_outputs // batch_size - num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) - num_parallel_calls = 7 - - def build_ds(range_start, drop_remainder=False): - - def _map_fn(x): - return math_ops.square(x) - - return dataset_ops.Dataset.range( - range_start, range_start + range_size).repeat(num_repeats).apply( - batching.map_and_batch( - map_func=_map_fn, - batch_size=batch_size, - num_parallel_calls=num_parallel_calls, - drop_remainder=drop_remainder)) - - self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), - num_outputs_keep_remainder) - self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), - num_outputs_drop_remainder) - - -class PaddedBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testPaddedBatch(self): - - def build_dataset(seq_lens): - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - lambda x: array_ops.fill([x], x)).padded_batch( - 4, padded_shapes=[-1]) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - def testPaddedBatchNonDefaultPadding(self): - - def build_dataset(seq_lens): - - def fill_tuple(x): - filled = array_ops.fill([x], x) - return (filled, string_ops.as_string(filled)) - - padded_shape = [-1] - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - fill_tuple).padded_batch( - 4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, "")) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) + with self.test_session() as sess: + for _ in range(10): + self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) class RestructuredDatasetTest(test.TestCase): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index bd3e034211c4aa454e4f8f6b09f14935d7a3b35c..2022c1f2bdd09cdf43a993b3666335ce468a40ba 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -21,7 +21,6 @@ import random import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -68,7 +67,7 @@ class GroupByReducerTest(test.TestCase): reducer = grouping.Reducer( init_func=lambda _: (0.0, 0.0), reduce_func=reduce_fn, - finalize_func=lambda x: x[0]) + finalize_func=lambda x, _: x) for i in range(1, 11): dataset = dataset_ops.Dataset.range(2 * i).apply( grouping.group_by_reducer( @@ -121,7 +120,7 @@ class GroupByReducerTest(test.TestCase): reducer = grouping.Reducer( init_func=lambda x: ([0], 1), reduce_func=reduce_fn, - finalize_func=lambda x: x) + finalize_func=lambda x, y: (x, y)) for i in range(1, 11): dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( @@ -176,37 +175,27 @@ class GroupByReducerTest(test.TestCase): dataset.apply( grouping.group_by_reducer(lambda _: "wrong", reducer)) + def testTuple(self): + def init_fn(_): + return np.array([], dtype=np.int64), np.int64(0) -class GroupByReducerSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + def reduce_fn(state, value): + s1, s2 = state + v1, v2 = value + return array_ops.concat([s1, [v1]], 0), s2 + v2 - def _build_dataset(self, components): - reducer = grouping.Reducer( - init_func=lambda _: np.int64(0), - reduce_func=lambda x, y: x + y, - finalize_func=lambda x: x) + def finalize_fn(s1, s2): + return s1, s2 - return dataset_ops.Dataset.from_tensor_slices(components).apply( - grouping.group_by_reducer(lambda x: x % 5, reducer)) - - def testCoreGroupByReducer(self): - components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) - self.verify_unused_iterator( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_init_before_restore( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_multiple_breaks( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_reset_restored_iterator( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_restore_in_empty_graph( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) - self.verify_restore_in_modified_graph( - lambda: self._build_dataset(components), - lambda: self._build_dataset(diff_components), - 5, - verify_exhausted=True) + reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) + dataset = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( + grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + x, y = sess.run(get_next) + self.assertAllEqual(x, np.asarray([x for x in range(10)])) + self.assertEqual(y, 45) class GroupByWindowTest(test.TestCase): @@ -353,34 +342,6 @@ class GroupByWindowTest(test.TestCase): self.assertEqual(len(components), sum(counts)) -class GroupByWindowSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( - grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) - - def testCoreGroupByWindow(self): - components = np.array( - [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) - self.verify_unused_iterator( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_init_before_restore( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_multiple_breaks( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_reset_restored_iterator( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_restore_in_empty_graph( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) - self.verify_restore_in_modified_graph( - lambda: self._build_dataset(components), - lambda: self._build_dataset(diff_components), - 12, - verify_exhausted=False) - - # NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. # Currently, they use a constant batch size, though should be made to use a # different batch size per key. @@ -655,7 +616,44 @@ class BucketBySequenceLength(test.TestCase): batch_sizes = batch_sizes[:-1] self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual(sorted(boundaries), sorted(lengths_val)) + self.assertEqual([boundary - 1 for boundary in sorted(boundaries)], + sorted(lengths_val)) + + def testPadToBoundaryNoExtraneousPadding(self): + + boundaries = [3, 7, 11] + batch_sizes = [2, 2, 2, 2] + lengths = range(1, 11) + + def element_gen(): + for length in lengths: + yield ([1] * length,) + + element_len = lambda element: array_ops.shape(element)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes, + pad_to_bucket_boundary=True)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + batches = [] + for _ in range(5): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + + self.assertAllEqual(batches[0], [[1, 0], + [1, 1]]) + self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0]]) + self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1]]) + self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) def testTupleElements(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 641a389c033687ebe081963182390b00230e4cb5..df115175f5046803ada036563be1ca802f7ad0cd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -19,11 +19,13 @@ from __future__ import division from __future__ import print_function import os +import string import tempfile import time import numpy as np +from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.python.client import session from tensorflow.python.data.ops import readers as core_readers @@ -31,7 +33,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest from tensorflow.python.platform import test @@ -60,12 +62,12 @@ class CsvDatasetOpTest(test.TestCase): op2 = sess.run(next2) self.assertAllEqual(op1, op2) - def setup_files(self, inputs): + def setup_files(self, inputs, linebreak='\n'): filenames = [] for i, ip in enumerate(inputs): - fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i) - with open(fn, 'w') as f: - f.write('\n'.join(ip)) + fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) + with open(fn, 'wb') as f: + f.write(linebreak.join(ip).encode('utf-8')) filenames.append(fn) return filenames @@ -74,7 +76,7 @@ class CsvDatasetOpTest(test.TestCase): filenames = self.setup_files(inputs) dataset_expected = core_readers.TextLineDataset(filenames) dataset_expected = dataset_expected.map( - lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) + lambda l: parsing_ops.decode_csv(l, **kwargs)) dataset_actual = readers.CsvDataset(filenames, **kwargs) return (dataset_actual, dataset_expected) @@ -85,38 +87,47 @@ class CsvDatasetOpTest(test.TestCase): inputs, **kwargs) self._assert_datasets_equal(g, dataset_actual, dataset_expected) + def _verify_output_or_err(self, + sess, + dataset, + expected_output=None, + expected_err_re=None): + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + def _test_dataset(self, inputs, expected_output=None, expected_err_re=None, + linebreak='\n', **kwargs): """Checks that elements produced by CsvDataset match expected output.""" # Convert str type because py3 tf strings are bytestrings - filenames = self.setup_files(inputs) + filenames = self.setup_files(inputs, linebreak) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, **kwargs) - nxt = dataset.make_one_shot_iterator().get_next() - if expected_err_re is None: - # Verify that output is expected, without errors - expected_output = [[ - v.encode('utf-8') if isinstance(v, str) else v for v in op - ] for op in expected_output] - for value in expected_output: - op = sess.run(nxt) - self.assertAllEqual(op, value) - with self.assertRaises(errors.OutOfRangeError): - sess.run(nxt) - else: - # Verify that OpError is produced as expected - with self.assertRaisesOpError(expected_err_re): - while True: - try: - sess.run(nxt) - except errors.OutOfRangeError: - break - - def testCsvDataset_floatRequired(self): + self._verify_output_or_err(sess, dataset, expected_output, + expected_err_re) + + def testCsvDataset_requiredFields(self): record_defaults = [[]] * 4 inputs = [['1,2,3,4']] self._test_by_comparison(inputs, record_defaults=record_defaults) @@ -136,10 +147,55 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withQuoted(self): - record_defaults = [['']] * 4 - inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']] - self._test_by_comparison(inputs, record_defaults=record_defaults) + def testCsvDataset_withEmptyFields(self): + record_defaults = [[0]] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + def testCsvDataset_errWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_dataset( + inputs, + expected_err_re='Unquoted fields cannot have quotes inside', + record_defaults=record_defaults) + + def testCsvDataset_errWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['"a"b","c","d"']] + self._test_dataset( + inputs, + expected_err_re= + 'Quote inside a string has to be escaped by another quote', + record_defaults=record_defaults) + + def testCsvDataset_ignoreErrWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_ignoreErrWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) def testCsvDataset_mixedTypes(self): record_defaults = [ @@ -163,11 +219,6 @@ class CsvDatasetOpTest(test.TestCase): self._test_by_comparison( inputs, record_defaults=record_defaults, field_delim=':') - def testCsvDataset_withEmptyValues(self): - record_defaults = [[0]] * 4 - inputs = [['1,,3,4', ',6,7,8']] - self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withNaValue(self): record_defaults = [[0]] * 4 inputs = [['1,NA,3,4', 'NA,6,7,8']] @@ -175,8 +226,8 @@ class CsvDatasetOpTest(test.TestCase): inputs, record_defaults=record_defaults, na_value='NA') def testCsvDataset_withSelectCols(self): - record_defaults = [[0]] * 2 - inputs = [['1,2,3,4', '5,6,7,8']] + record_defaults = [['']] * 2 + inputs = [['1,2,3,4', '"5","6","7","8"']] self._test_by_comparison( inputs, record_defaults=record_defaults, select_cols=[1, 2]) @@ -189,27 +240,17 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, select_cols=[3, 4]) + def testCsvDataset_withOneCol(self): + record_defaults = [['NA']] + inputs = [['0', '', '2']] + self._test_dataset( + inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) + def testCsvDataset_withMultipleFiles(self): record_defaults = [[0]] * 4 inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withNewLine(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - - def testCsvDataset_withMultipleNewLines(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - def testCsvDataset_withLeadingAndTrailingSpaces(self): record_defaults = [[0.0]] * 4 inputs = [['0, 1, 2, 3']] @@ -265,9 +306,10 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_errorWithHeaderEmptyFile(self): record_defaults = [[0]] * 2 inputs = [[]] + expected_err_re = "Can't read header of file" self._test_dataset( inputs, - expected_err_re="Can't read header of empty file", + expected_err_re=expected_err_re, record_defaults=record_defaults, header=True, ) @@ -283,7 +325,7 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['', '1,2']] # First record is empty self._test_dataset( inputs, - expected_err_re='Expect 2 fields but have 0 in record', + expected_err_re='Expect 2 fields but have 1 in record', record_defaults=record_defaults) def testCsvDataset_withChainedOps(self): @@ -300,7 +342,7 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_withTypeDefaults(self): # Testing using dtypes as record_defaults for required fields - record_defaults = [dtypes.float32, dtypes.float32] + record_defaults = [dtypes.float32, [0.0]] inputs = [['1.0,2.0', '3.0,4.0']] self._test_dataset( inputs, @@ -308,71 +350,270 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, ) + def testMakeCsvDataset_fieldOrder(self): + data = [[ + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19' + ]] + file_path = self.setup_files(data) + + with ops.Graph().as_default() as g: + ds = readers.make_csv_dataset( + file_path, batch_size=1, shuffle=False, num_epochs=1) + next_batch = ds.make_one_shot_iterator().get_next() + + with self.test_session(graph=g) as sess: + result = list(sess.run(next_batch).values()) + + self.assertEqual(result, sorted(result)) + +## The following tests exercise parsing logic for quoted fields + + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withOneColAndQuotes(self): + record_defaults = [['']] + inputs = [['"0"', '"1"', '"2"']] + self._test_dataset( + inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withNewLineInUnselectedCol(self): + record_defaults = [['']] + inputs = [['1,"2\n3",4', '5,6,7']] + self._test_dataset( + inputs, + expected_output=[['1'], ['5']], + record_defaults=record_defaults, + select_cols=[0]) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithTerminateMidRecord(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,"a']] + self._test_dataset( + inputs, + expected_err_re= + 'Reached end of file without closing quoted field in record', + record_defaults=record_defaults) + + def testCsvDataset_withEscapedQuotes(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + +## Testing that parsing works with all buffer sizes, quoted/unquoted fields, +## and different types of line breaks + + def testCsvDataset_withInvalidBufferSize(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,d']] + self._test_dataset( + inputs, + expected_err_re='buffer_size should be positive', + record_defaults=record_defaults, + buffer_size=0) + + def testCsvDataset_withBufferSize(self): + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) + + def testCsvDataset_withCR(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withCRLF(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withBufferSizeAndQuoted(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\n', record_defaults=record_defaults) + + def testCsvDataset_withCRAndQuoted(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r', record_defaults=record_defaults) + + def testCsvDataset_withCRLFAndQuoted(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. """ + FLOAT_VAL = '1.23456E12' + STR_VAL = string.ascii_letters * 10 - def _setUp(self): + def _setUp(self, str_val): # Since this isn't test.TestCase, have to manually create a test dir gfile.MakeDirs(googletest.GetTempDir()) self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir()) self._num_cols = [4, 64, 256] - self._batch_size = 500 + self._num_per_iter = 5000 self._filenames = [] for n in self._num_cols: fn = os.path.join(self._temp_dir, 'file%d.csv' % n) - with open(fn, 'w') as f: - # Just write 10 rows and use `repeat`... - row = ','.join(['1.23456E12' for _ in range(n)]) - f.write('\n'.join([row for _ in range(10)])) + with open(fn, 'wb') as f: + # Just write 100 rows and use `repeat`... Assumes the cost + # of creating an iterator is not significant + row = ','.join([str_val for _ in range(n)]) + f.write('\n'.join([row for _ in range(100)])) self._filenames.append(fn) def _tearDown(self): gfile.DeleteRecursively(self._temp_dir) def _runBenchmark(self, dataset, num_cols, prefix): - next_element = dataset.make_one_shot_iterator().get_next() - with session.Session() as sess: - for _ in range(5): - sess.run(next_element) - deltas = [] - for _ in range(10): + dataset = dataset.skip(self._num_per_iter - 1) + deltas = [] + for _ in range(10): + next_element = dataset.make_one_shot_iterator().get_next() + with session.Session() as sess: start = time.time() + # NOTE: This depends on the underlying implementation of skip, to have + # the net effect of calling `GetNext` num_per_iter times on the + # input dataset. We do it this way (instead of a python for loop, or + # batching N inputs in one iter) so that the overhead from session.run + # or batch doesn't dominate. If we eventually optimize skip, this has + # to change. sess.run(next_element) end = time.time() - deltas.append(end - start) - median_wall_time = np.median(deltas) / 100 + deltas.append(end - start) + # Median wall time per CSV record read and decoded + median_wall_time = np.median(deltas) / self._num_per_iter print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols, median_wall_time)) self.report_benchmark( - iters=self._batch_size, + iters=self._num_per_iter, wall_time=median_wall_time, name='%s_with_cols_%d' % (prefix, num_cols)) - def benchmarkBatchThenMap(self): - self._setUp() + def benchmarkMapWithFloats(self): + self._setUp(self.FLOAT_VAL) for i in range(len(self._filenames)): num_cols = self._num_cols[i] kwargs = {'record_defaults': [[0.0]] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() - dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop - dataset = dataset.batch(self._batch_size) - self._runBenchmark(dataset, num_cols, 'csv_map_then_batch') + dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv') + self._tearDown() + + def benchmarkMapWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv') self._tearDown() - def benchmarkCsvDataset(self): - self._setUp() + def benchmarkCsvDatasetWithFloats(self): + self._setUp(self.FLOAT_VAL) for i in range(len(self._filenames)): num_cols = self._num_cols[i] kwargs = {'record_defaults': [[0.0]] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop - dataset = dataset.batch(self._batch_size) - self._runBenchmark(dataset, num_cols, 'csv_fused_dataset') + self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset') self._tearDown() + def benchmarkCsvDatasetWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset') + self._tearDown() if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index a842502cc6fe3605dde0be5f50cf46e3e37d7ed4..a2ab3de52e8e512e3cba399f7a1725e5570cfd01 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,14 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes -from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -70,63 +66,5 @@ class DatasetConstructorTest(test.TestCase): # pylint: enable=protected-access -class DatasetConstructorSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_tensor_dataset(self, variable_array): - components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) - - return dataset_ops.Dataset.from_tensors(components) - - def testFromTensorsCore(self): - # Equal length components - arr = np.array(1) - num_outputs = 1 - diff_arr = np.array(2) - self.run_core_tests(lambda: self._build_tensor_dataset(arr), - lambda: self._build_tensor_dataset(diff_arr), - num_outputs) - - def _build_tensor_slices_dataset(self, components): - return dataset_ops.Dataset.from_tensor_slices(components) - - def testFromTensorSlicesCore(self): - # Equal length components - components = (np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 22), - np.array([37.0, 38.0, 39.0, 40.0])) - - diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[5], [6], [7], [8]]), 22), - np.array([1.0, 2.0, 3.0, 4.0])) - - dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} - - self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), - lambda: self._build_tensor_slices_dataset(diff_comp), 4) - self.run_core_tests( - lambda: self._build_tensor_slices_dataset(dict_components), None, 3) - - def _build_sparse_tensor_slice_dataset(self, slices): - indices = np.array( - [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], - dtype=np.int64) - values = np.array([val for s in slices for val in s], dtype=np.float64) - dense_shape = np.array( - [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) - sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) - return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) - - def testFromSparseTensorSlicesCore(self): - slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] - diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] - - self.run_core_tests( - lambda: self._build_sparse_tensor_slice_dataset(slices), - lambda: self._build_sparse_tensor_slice_dataset(diff_slices), - 9, - sparse_tensors=True) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9b1857de1a96c8f71788a1bf5085ef0605417fe7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -0,0 +1,147 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed +from tensorflow.python.platform import test + + +class DirectedInterleaveDatasetTest(test.TestCase): + + def testBasic(self): + selector_dataset = dataset_ops.Dataset.range(10).repeat(100) + input_datasets = [ + dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) + ] + dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset, + input_datasets) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(100): + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def _normalize(self, vec): + return vec / vec.sum() + + def _chi2(self, expected, actual): + actual = np.asarray(actual) + expected = np.asarray(expected) + diff = actual - expected + chi2 = np.sum(diff * diff / expected, axis=0) + return chi2 + + def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): + # Create a dataset that samples each integer in `[0, num_datasets)` + # with probability given by `weights[i]`. + dataset = interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(num_datasets) + ], weights) + dataset = dataset.take(num_samples) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + freqs = np.zeros([num_datasets]) + for _ in range(num_samples): + freqs[sess.run(next_element)] += 1 + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + return freqs + + def testSampleFromDatasets(self): + random_seed.set_random_seed(1619) + num_samples = 5000 + rand_probs = self._normalize(np.random.random_sample((15,))) + + # Use chi-squared test to assert that the observed distribution matches the + # expected distribution. Based on the implementation in + # "tensorflow/python/kernel_tests/multinomial_op_test.py". + for probs in [[.85, .05, .1], rand_probs]: + probs = np.asarray(probs) + classes = len(probs) + freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + # Also check that `weights` as a dataset samples correctly. + probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() + freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + def testSelectFromDatasets(self): + words = [b"foo", b"bar", b"baz"] + datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words] + choice_array = np.random.randint(3, size=(15,), dtype=np.int64) + choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array) + dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in choice_array: + self.assertEqual(words[i], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testErrors(self): + with self.assertRaisesRegexp(ValueError, + r"vector of length `len\(datasets\)`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[0.25, 0.25, 0.25, 0.25]) + + with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[1, 1]) + + with self.assertRaisesRegexp(TypeError, "must have the same type"): + interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(0.0) + ]) + + with self.assertRaisesRegexp(TypeError, "tf.int64"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0)) + + with self.assertRaisesRegexp(TypeError, "scalar"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 43aa4b1bd02791ff304a990c0bbe8e45534c0c77..44c3325a3db84bb844b7f860a7c925982f1e3d6a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -22,15 +22,12 @@ import math import threading import time -import numpy as np from six.moves import zip_longest -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -39,132 +36,6 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, input_values, cycle_length, block_length): - repeat_count = 2 - return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( - repeat_count).interleave( - lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length) - - def testSerializationCore(self): - input_values = np.array([4, 5, 6], dtype=np.int64) - num_outputs = np.sum(input_values) * 2 - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 - # pylint: disable=g-long-lambda - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - lambda: self._build_iterator_graph( - input_values, cycle_length * 2, block_length * 1), - num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # pylint: enable=g-long-lambda - - def testSparseCore(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _interleave_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - def _build_dataset(): - return dataset_ops.Dataset.range(10).map(_map_fn).interleave( - _interleave_fn, cycle_length=1) - - self.run_core_tests(_build_dataset, None, 20) - - -class ParallelInterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self.input_values = np.array([4, 5, 6], dtype=np.int64) - self.num_repeats = 2 - self.num_outputs = np.sum(self.input_values) * 2 - - def _build_ds(self, cycle_length, block_length, sloppy=False): - return (dataset_ops.Dataset.from_tensor_slices( - self.input_values).repeat(self.num_repeats).apply( - interleave_ops.parallel_interleave( - lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), - cycle_length, block_length, sloppy))) - - def testSerializationCore(self): - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 - self.run_core_tests( - lambda: self._build_ds(cycle_length, block_length), - lambda: self._build_ds(cycle_length * 2, block_length * 1), - self.num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), - None, self.num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), - None, self.num_outputs) - - def testSerializationWithSloppy(self): - break_points = self.gen_break_points(self.num_outputs, 10) - expected_outputs = np.repeat( - np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), - self.num_repeats).tolist() - - def run_test(cycle_length, block_length): - actual = self.gen_outputs( - lambda: self._build_ds(cycle_length, block_length, True), - break_points, self.num_outputs) - self.assertSequenceEqual(sorted(actual), expected_outputs) - - # cycle_length > 1, block_length > 1 - run_test(2, 3) - # cycle_length = 1 - run_test(1, 3) - # block_length = 1 - run_test(2, 1) - - def testSparseCore(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _interleave_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - def _build_dataset(): - return dataset_ops.Dataset.range(10).map(_map_fn).apply( - interleave_ops.parallel_interleave(_interleave_fn, 1)) - - self.run_core_tests(_build_dataset, None, 20) - - class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): @@ -907,114 +778,5 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(self.next_element) -class DirectedInterleaveDatasetTest(test.TestCase): - - def testBasic(self): - selector_dataset = dataset_ops.Dataset.range(10).repeat(100) - input_datasets = [ - dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) - ] - dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, - input_datasets) - iterator = dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - sess.run(iterator.initializer) - for _ in range(100): - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def _normalize(self, vec): - return vec / vec.sum() - - def _chi2(self, expected, actual): - actual = np.asarray(actual) - expected = np.asarray(expected) - diff = actual - expected - chi2 = np.sum(diff * diff / expected, axis=0) - return chi2 - - def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): - # Create a dataset that samples each integer in `[0, num_datasets)` - # with probability given by `weights[i]`. - dataset = interleave_ops.sample_from_datasets([ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(num_datasets) - ], weights) - dataset = dataset.take(num_samples) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - freqs = np.zeros([num_datasets]) - for _ in range(num_samples): - freqs[sess.run(next_element)] += 1 - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - return freqs - - def testSampleFromDatasets(self): - random_seed.set_random_seed(1619) - num_samples = 10000 - rand_probs = self._normalize(np.random.random_sample((15,))) - - # Use chi-squared test to assert that the observed distribution matches the - # expected distribution. Based on the implementation in - # "tensorflow/python/kernel_tests/multinomial_op_test.py". - for probs in [[.85, .05, .1], rand_probs]: - probs = np.asarray(probs) - classes = len(probs) - freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) - self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) - - # Also check that `weights` as a dataset samples correctly. - probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() - freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) - self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) - - def testErrors(self): - with self.assertRaisesRegexp(ValueError, - r"vector of length `len\(datasets\)`"): - interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.range(10), - dataset_ops.Dataset.range(20)], - weights=[0.25, 0.25, 0.25, 0.25]) - - with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): - interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.range(10), - dataset_ops.Dataset.range(20)], - weights=[1, 1]) - - with self.assertRaisesRegexp(TypeError, "must have the same type"): - interleave_ops.sample_from_datasets([ - dataset_ops.Dataset.from_tensors(0), - dataset_ops.Dataset.from_tensors(0.0) - ]) - - -class SampleFromDatasetsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, probs, num_samples): - dataset = interleave_ops.sample_from_datasets( - [ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(len(probs)) - ], - probs, - seed=1813) - return dataset.take(num_samples) - - def testSerializationCore(self): - self.run_core_tests( - lambda: self._build_dataset([0.5, 0.5], 100), - lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py similarity index 100% rename from tensorflow/contrib/data/python/ops/iterator_ops_test.py rename to tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 8d4042927970cab2f5a518fc0da49b38444dbcdf..a075dfd8b56079c7b2509bb5795521b8b9eb3127 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -17,24 +17,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools import os +import time import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import io_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -143,229 +140,82 @@ class MapDatasetTest(test.TestCase): sess.run(get_next) -class MapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): +class MapDatasetBenchmark(test.Benchmark): - def setUp(self): - self._tensor_slice_len = 7 - self._num_epochs = 14 - self._num_outputs = self._tensor_slice_len * self._num_epochs + def benchmarkMapAndBatch(self): + small = itertools.product([1, 4], [1, 4], [1, 4], [16, 64], [100]) + large = itertools.product([16, 64], [16, 64], [16, 64], [256, 1024], [10]) - def _build_ds(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) + num_iters = 100 - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) + def benchmark(series): - return ( - dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(self._num_epochs)) - - def testSaveRestoreCore(self): - self.run_core_tests( - self._build_ds, - lambda: self._build_ds(multiplier=15.0), - self._num_outputs) - - def testSaveStatefulFunction(self): - - def _build_ds(): - - def _map_fn(x): - return random_ops.random_uniform( - (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - - return dataset_ops.Dataset.range(100).map(_map_fn) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureVariableInMapFn(self): - - def _build_ds(): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1))) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureConstantInMapFn(self): - - def _build_ds(): - constant_var = constant_op.constant(5) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda x: x + constant_var)) - - self.run_core_tests(_build_ds, None, 10) - - def testCaptureDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testBuildDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - - @function.Defun(dtypes.int32) - def defun_fn_deep(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testSparseCore(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1])), - dense_shape=np.array([1, 1])) - - def _build_ds(num_outputs): - return dataset_ops.Dataset.range(num_outputs).map(_sparse) - - num_outputs = 10 - self.run_core_tests(lambda: _build_ds(num_outputs), - lambda: _build_ds(int(num_outputs / 2)), num_outputs) - - -class ParallelMapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self._tensor_slice_len = 7 - self._num_epochs = 1 - self._num_outputs = self._tensor_slice_len * self._num_epochs - - def _build_ds(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return (dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) - - def _build_ds_with_prefetch(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return (dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) - - def testSaveRestoreCore(self): - for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: - self.run_core_tests( - ds_fn, - lambda: ds_fn(multiplier=15.0), - self._num_outputs) - - def testSaveStatefulFunction(self): - - def _build_ds(): - - def _map_fn(x): - return random_ops.random_uniform( - (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - - return dataset_ops.Dataset.range(100).map( - _map_fn, num_parallel_calls=2).prefetch(2) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureVariableInMapFn(self): - - def _build_ds(): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1), - num_parallel_calls=2).prefetch(2)) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureConstantInMapFn(self): - - def _build_ds(): - constant_var = constant_op.constant(5) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) - - self.run_core_tests(_build_ds, None, 10) - - def testCaptureDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return dataset_ops.Dataset.range(num_outputs).map( - defun_fn, num_parallel_calls=2).prefetch(2) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testBuildDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - - @function.Defun(dtypes.int32) - def defun_fn_deep(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - - return dataset_ops.Dataset.range(num_outputs).map( - defun_fn, num_parallel_calls=2).prefetch(2) - - self.run_core_tests(_build_ds, None, num_outputs) - - -class IgnoreErrorsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_ds(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.check_numerics(x, "message")).apply( - error_ops.ignore_errors()) - - def testIgnoreErrorsCore(self): - components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) - num_outputs = 4 - self.run_core_tests(lambda: self._build_ds(components), - lambda: self._build_ds(diff_components), num_outputs) + for num_calls, inter_op, element_size, batch_size, num_steps in series: + dataset = dataset_ops.Dataset.from_tensors( + np.random.randint(100, size=element_size)).repeat().map( + lambda x: x, + num_parallel_calls=num_calls).batch(batch_size=batch_size) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + fused_dataset = dataset_ops.Dataset.from_tensors( + np.random.randint(100, size=element_size)).repeat(None).apply( + batching.map_and_batch( + lambda x: x, + num_parallel_calls=num_calls, + batch_size=batch_size)) + fused_iterator = fused_dataset.make_one_shot_iterator() + fused_get_next = fused_iterator.get_next() + + fused_deltas = [] + with session.Session( + config=config_pb2.ConfigProto( + inter_op_parallelism_threads=inter_op)) as sess: + + for _ in range(5): + sess.run(fused_get_next) + for _ in range(num_iters): + start = time.time() + for _ in range(num_steps): + sess.run(fused_get_next) + end = time.time() + fused_deltas.append(end - start) + + chained_deltas = [] + with session.Session( + config=config_pb2.ConfigProto( + inter_op_parallelism_threads=inter_op)) as sess: + for _ in range(5): + sess.run(get_next) + for _ in range(num_iters): + start = time.time() + for _ in range(num_steps): + sess.run(get_next) + end = time.time() + chained_deltas.append(end - start) + + chained_wall_time = np.median(chained_deltas) / num_iters + fused_wall_time = np.median(fused_deltas) / num_iters + print( + "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, " + "element size: %d, chained wall time: %f, fused wall time: %f" % + (batch_size, num_calls, inter_op, element_size, chained_wall_time, + fused_wall_time)) + + self.report_benchmark( + iters=num_iters, + wall_time=chained_wall_time, + name="chained_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d" + % (batch_size, num_calls, inter_op, element_size)) + + self.report_benchmark( + iters=num_iters, + wall_time=fused_wall_time, + name="fused_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d" + % (batch_size, num_calls, inter_op, element_size)) + + benchmark(small) + benchmark(large) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e35be8a23f3706bd170c09b967b4f419fc9a626e --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class OptimizeDatasetTest(test.TestCase): + + def testDefaultOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize()) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testEmptyOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize([])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimization(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + any([node.op == "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index b08132cd72254326d965907a1fdafb8a820926a1..40a8e4667678710251a25f906a917ca1eadd21c2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -68,6 +68,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): with ops.device(device1): buffer_resource_handle = prefetching_ops.function_buffering_resource( f=_remote_fn, + output_types=[dtypes.float32], target_device=target, string_arg=ds_iterator_handle, buffer_size=3, @@ -201,6 +202,49 @@ class PrefetchingKernelsOpsTest(test.TestCase): sess.run(destroy_op) + def testStringsGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/gpu:0" + + ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"]) + ds_iterator = ds.make_one_shot_iterator() + ds_iterator_handle = ds_iterator.string_handle() + + @function.Defun(dtypes.string) + def _remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, ds.output_types, ds.output_shapes) + return remote_iterator.get_next() + + target = constant_op.constant(device0) + with ops.device(device1): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_remote_fn, + output_types=[dtypes.string], + target_device=target, + string_arg=ds_iterator_handle, + buffer_size=3, + shared_name="strings") + + with ops.device(device1): + prefetch_op = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=buffer_resource_handle, + output_types=[dtypes.string]) + destroy_op = resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True) + + with self.test_session() as sess: + self.assertEqual([b"a"], sess.run(prefetch_op)) + self.assertEqual([b"b"], sess.run(prefetch_op)) + self.assertEqual([b"c"], sess.run(prefetch_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + sess.run(destroy_op) + class PrefetchToDeviceTest(test.TestCase): @@ -235,6 +279,36 @@ class PrefetchToDeviceTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testPrefetchToSameDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device( + "/job:localhost/replica:0/task:0/device:CPU:0")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testPrefetchDictToDevice(self): host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) device_dataset = host_dataset.apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 80e1cb0041024b68bd5268b5de5d69c88c839896..592642da0cfd84e50cb20d9b2e534411faf927e8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -17,21 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import counter from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -81,88 +73,5 @@ class RangeDatasetTest(test.TestCase): self.assertEqual(-2, sess.run(negative_get_next)) -class RangeDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _iterator_checkpoint_prefix_local(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(self, iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - self._iterator_checkpoint_prefix_local(), - parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(self, iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(self._iterator_checkpoint_prefix_local()), - dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def testSaveRestore(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. - start = 2 - stop = 10 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Saving and restoring in same session. - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def _build_range_dataset(self, start, stop): - return dataset_ops.Dataset.range(start, stop) - - def testRangeCore(self): - start = 2 - stop = 10 - stop_1 = 8 - self.run_core_tests(lambda: self._build_range_dataset(start, stop), - lambda: self._build_range_dataset(start, stop_1), - stop - start) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 1fcb78a69b0b61668f5130931e5822811a439c0d..9df403ef50e459d94b8edf3f651c7c95baf3ec42 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -17,426 +17,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gzip import os -import zlib import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers -from tensorflow.core.example import example_pb2 -from tensorflow.core.example import feature_pb2 -from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.lib.io import python_io -from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -from tensorflow.python.util import compat -class TextLineDatasetTestBase(test.TestCase): - - def _lineText(self, f, l): - return compat.as_bytes("%d: %d" % (f, l)) - - def _createFiles(self, - num_files, - num_lines, - crlf=False, - compression_type=None): - filenames = [] - for i in range(num_files): - fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) - filenames.append(fn) - contents = [] - for j in range(num_lines): - contents.append(self._lineText(i, j)) - # Always include a newline after the record unless it is - # at the end of the file, in which case we include it - if j + 1 != num_lines or i == 0: - contents.append(b"\r\n" if crlf else b"\n") - contents = b"".join(contents) - - if not compression_type: - with open(fn, "wb") as f: - f.write(contents) - elif compression_type == "GZIP": - with gzip.GzipFile(fn, "wb") as f: - f.write(contents) - elif compression_type == "ZLIB": - contents = zlib.compress(contents) - with open(fn, "wb") as f: - f.write(contents) - else: - raise ValueError("Unsupported compression_type", compression_type) - - return filenames - - -class TextLineDatasetSerializationTest( - TextLineDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, test_filenames, compression_type=None): - return core_readers.TextLineDataset( - test_filenames, compression_type=compression_type, buffer_size=10) - - def testTextLineCore(self): - compression_types = [None, "GZIP", "ZLIB"] - num_files = 5 - lines_per_file = 5 - num_outputs = num_files * lines_per_file - for compression_type in compression_types: - test_filenames = self._createFiles( - num_files, - lines_per_file, - crlf=True, - compression_type=compression_type) - # pylint: disable=cell-var-from-loop - self.run_core_tests( - lambda: self._build_iterator_graph(test_filenames, compression_type), - lambda: self._build_iterator_graph(test_filenames), num_outputs) - # pylint: enable=cell-var-from-loop - - -class FixedLengthRecordReaderTestBase(test.TestCase): - - def setUp(self): - super(FixedLengthRecordReaderTestBase, self).setUp() - self._num_files = 2 - self._num_records = 7 - self._header_bytes = 5 - self._record_bytes = 3 - self._footer_bytes = 2 - - def _record(self, f, r): - return compat.as_bytes(str(f * 2 + r) * self._record_bytes) - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) - filenames.append(fn) - with open(fn, "wb") as f: - f.write(b"H" * self._header_bytes) - for j in range(self._num_records): - f.write(self._record(i, j)) - f.write(b"F" * self._footer_bytes) - return filenames - - -class FixedLengthRecordDatasetSerializationTest( - FixedLengthRecordReaderTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, num_epochs, compression_type=None): - filenames = self._createFiles() - return core_readers.FixedLengthRecordDataset( - filenames, self._record_bytes, self._header_bytes, - self._footer_bytes).repeat(num_epochs) - - def testFixedLengthRecordCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), - lambda: self._build_iterator_graph(num_epochs * 2), - num_outputs) - - -class TFRecordDatasetTestBase(test.TestCase): - - def setUp(self): - super(TFRecordDatasetTestBase, self).setUp() - self._num_files = 2 - self._num_records = 7 - - self.test_filenames = self._createFiles() - - self.filenames = array_ops.placeholder(dtypes.string, shape=[None]) - self.num_epochs = array_ops.placeholder_with_default( - constant_op.constant(1, dtypes.int64), shape=[]) - self.compression_type = array_ops.placeholder_with_default("", shape=[]) - self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = core_readers.TFRecordDataset( - self.filenames, self.compression_type).repeat(self.num_epochs) - batch_dataset = repeat_dataset.batch(self.batch_size) - - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) - self.init_op = iterator.make_initializer(repeat_dataset) - self.init_batch_op = iterator.make_initializer(batch_dataset) - self.get_next = iterator.get_next() - - def _record(self, f, r): - return compat.as_bytes("Record %d of file %d" % (r, f)) - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = python_io.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._record(i, j)) - writer.close() - return filenames - - -class TFRecordDatasetSerializationTest( - TFRecordDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, - num_epochs, - batch_size=1, - compression_type=None, - buffer_size=None): - filenames = self._createFiles() - if compression_type is "ZLIB": - zlib_files = [] - for i, fn in enumerate(filenames): - with open(fn, "rb") as f: - cdata = zlib.compress(f.read()) - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) - with open(zfn, "wb") as f: - f.write(cdata) - zlib_files.append(zfn) - filenames = zlib_files - - elif compression_type is "GZIP": - gzip_files = [] - for i, fn in enumerate(self.test_filenames): - with open(fn, "rb") as f: - gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) - with gzip.GzipFile(gzfn, "wb") as gzf: - gzf.write(f.read()) - gzip_files.append(gzfn) - filenames = gzip_files - - return core_readers.TFRecordDataset( - filenames, compression_type, - buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) - - def testTFRecordWithoutBufferCore(self): - num_epochs = 5 - batch_size = num_epochs - num_outputs = num_epochs * self._num_files * self._num_records // batch_size - # pylint: disable=g-long-lambda - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, batch_size, - buffer_size=0), - lambda: self._build_iterator_graph(num_epochs * 2, batch_size), - num_outputs) - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, - num_outputs * batch_size) - # pylint: enable=g-long-lambda - - def testTFRecordWithBufferCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), - lambda: self._build_iterator_graph(num_epochs * 2), - num_outputs) - - def testTFRecordWithCompressionCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), - lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), - lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) - - -def _interleave(iterators, cycle_length): - pending_iterators = iterators - open_iterators = [] - num_open = 0 - for i in range(cycle_length): - if pending_iterators: - open_iterators.append(pending_iterators.pop(0)) - num_open += 1 - - while num_open: - for i in range(min(cycle_length, len(open_iterators))): - if open_iterators[i] is None: - continue - try: - yield next(open_iterators[i]) - except StopIteration: - if pending_iterators: - open_iterators[i] = pending_iterators.pop(0) - else: - open_iterators[i] = None - num_open -= 1 - - -class ReadBatchFeaturesTest(test.TestCase): - - def setUp(self): - super(ReadBatchFeaturesTest, self).setUp() - self._num_files = 2 - self._num_records = 7 - self.test_filenames = self._createFiles() - - def _read_batch_features(self, - filenames, - num_epochs, - batch_size, - reader_num_threads=1, - parser_num_threads=1, - shuffle=False, - shuffle_seed=None, - drop_final_batch=False): - self.filenames = filenames - self.num_epochs = num_epochs - self.batch_size = batch_size - - return readers.make_batched_features_dataset( - file_pattern=self.filenames, - batch_size=self.batch_size, - features={ - "file": parsing_ops.FixedLenFeature([], dtypes.int64), - "record": parsing_ops.FixedLenFeature([], dtypes.int64), - "keywords": parsing_ops.VarLenFeature(dtypes.string) - }, - reader=core_readers.TFRecordDataset, - num_epochs=self.num_epochs, - shuffle=shuffle, - shuffle_seed=shuffle_seed, - reader_num_threads=reader_num_threads, - parser_num_threads=parser_num_threads, - drop_final_batch=drop_final_batch).make_one_shot_iterator( - ).get_next() - - def _record(self, f, r): - example = example_pb2.Example( - features=feature_pb2.Features( - feature={ - "file": - feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[f])), - "record": - feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[r])), - "keywords": - feature_pb2.Feature( - bytes_list=feature_pb2.BytesList( - value=self._get_keywords(f, r))) - })) - return example.SerializeToString() - - def _get_keywords(self, f, r): - num_keywords = 1 + (f + r) % 2 - keywords = [] - for index in range(num_keywords): - keywords.append(compat.as_bytes("keyword%d" % index)) - return keywords - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = python_io.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._record(i, j)) - writer.close() - return filenames - - def _run_actual_batch(self, outputs, sess): - file_op = outputs["file"] - keywords_indices_op = outputs["keywords"].indices - keywords_values_op = outputs["keywords"].values - keywords_dense_shape_op = outputs["keywords"].dense_shape - record_op = outputs["record"] - return sess.run([ - file_op, keywords_indices_op, keywords_values_op, - keywords_dense_shape_op, record_op - ]) - - def _next_actual_batch(self, sess): - return self._run_actual_batch(self.outputs, sess) - - def _next_expected_batch(self, - file_indices, - batch_size, - num_epochs, - cycle_length=1): - - def _next_record(file_indices): - for j in file_indices: - for i in range(self._num_records): - yield j, i - - def _next_record_interleaved(file_indices, cycle_length): - return _interleave([_next_record([i]) for i in file_indices], - cycle_length) - - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - for _ in range(num_epochs): - if cycle_length == 1: - next_records = _next_record(file_indices) - else: - next_records = _next_record_interleaved(file_indices, cycle_length) - for record in next_records: - f = record[0] - r = record[1] - file_batch.append(f) - record_batch.append(r) - keywords = self._get_keywords(f, r) - keywords_batch_values.extend(keywords) - keywords_batch_indices.extend( - [[batch_index, i] for i in range(len(keywords))]) - batch_index += 1 - keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) - if len(file_batch) == batch_size: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [batch_size, keywords_batch_max_len], record_batch - ] - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - if file_batch: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [len(file_batch), keywords_batch_max_len], record_batch - ] - - def _verify_records(self, - sess, - batch_size, - file_index=None, - num_epochs=1, - interleave_cycle_length=1): - if file_index is not None: - file_indices = [file_index] - else: - file_indices = range(self._num_files) - - for expected_batch in self._next_expected_batch( - file_indices, batch_size, num_epochs, interleave_cycle_length): - actual_batch = self._next_actual_batch(sess) - for i in range(len(expected_batch)): - self.assertAllEqual(expected_batch[i], actual_batch[i]) +class ReadBatchFeaturesTest( + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testRead(self): for batch_size in [1, 2]: @@ -444,33 +42,33 @@ class ReadBatchFeaturesTest(test.TestCase): with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 0, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 0, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 1. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[1], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 1, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 1, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from both files. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) @@ -504,18 +102,18 @@ class ReadBatchFeaturesTest(test.TestCase): # Test that shuffling with same seed produces the same result. with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - outputs1 = self._read_batch_features( + outputs1 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) - outputs2 = self._read_batch_features( + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) + shuffle_seed=5).make_one_shot_iterator().get_next() for _ in range(total_records // batch_size): batch1 = self._run_actual_batch(outputs1, sess) batch2 = self._run_actual_batch(outputs2, sess) @@ -525,18 +123,18 @@ class ReadBatchFeaturesTest(test.TestCase): # Test that shuffling with different seeds produces a different order. with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - outputs1 = self._read_batch_features( + outputs1 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) - outputs2 = self._read_batch_features( + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=15) + shuffle_seed=15).make_one_shot_iterator().get_next() all_equal = True for _ in range(total_records // batch_size): batch1 = self._run_actual_batch(outputs1, sess) @@ -552,13 +150,14 @@ class ReadBatchFeaturesTest(test.TestCase): for parser_num_threads in [2, 4]: with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, batch_size=batch_size, reader_num_threads=reader_num_threads, - parser_num_threads=parser_num_threads) - self._verify_records( + parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() + self.verify_records( sess, batch_size, num_epochs=num_epochs, @@ -571,11 +170,11 @@ class ReadBatchFeaturesTest(test.TestCase): for num_epochs in [1, 10]: with ops.Graph().as_default(): # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, - drop_final_batch=True) + drop_final_batch=True).make_one_shot_iterator().get_next() for _, tensor in self.outputs.items(): if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) @@ -622,14 +221,12 @@ class MakeCsvDatasetTest(test.TestCase): f.close() return fn - def _create_file(self, fileno, header=True, comment=True): + def _create_file(self, fileno, header=True): rows = [] if header: rows.append(self.COLUMNS) for recno in range(self._num_records): rows.append(self._csv_values(fileno, recno)) - if comment: - rows.append("# Some comment goes here. Ignore me.") return self._write_file("csv_file%d.csv" % fileno, rows) def _create_files(self): @@ -650,9 +247,7 @@ class MakeCsvDatasetTest(test.TestCase): shuffle=False, shuffle_seed=None, header=True, - comment="#", na_value="", - default_float_type=dtypes.float32, ): return readers.make_csv_dataset( filenames, @@ -664,9 +259,7 @@ class MakeCsvDatasetTest(test.TestCase): shuffle=shuffle, shuffle_seed=shuffle_seed, header=header, - comment=comment, na_value=na_value, - default_float_type=default_float_type, select_columns=select_cols, ) @@ -788,29 +381,6 @@ class MakeCsvDatasetTest(test.TestCase): num_epochs=10, label_name=None) - def testMakeCSVDataset_withNoComments(self): - """Tests that datasets can be created from CSV files with no header line. - """ - defaults = self.DEFAULTS - file_without_header = self._create_file( - len(self._test_filenames), comment=False) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - file_without_header, - defaults, - batch_size=2, - num_epochs=10, - comment=None, - ) - self._verify_records( - sess, - dataset, - [len(self._test_filenames)], - batch_size=2, - num_epochs=10, - ) - def testMakeCSVDataset_withNoHeader(self): """Tests that datasets can be created from CSV files with no header line. """ @@ -878,7 +448,7 @@ class MakeCsvDatasetTest(test.TestCase): In that case, we should infer the types from the first N records. """ - # Test that it works with standard test files (with comments, header, etc) + # Test that it works with standard test files (with header, etc) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = self._make_csv_dataset( @@ -891,7 +461,9 @@ class MakeCsvDatasetTest(test.TestCase): num_epochs=10, defaults=[[], [], [], [], [""]]) - # Test on a deliberately tricky file + def testMakeCSVDataset_withTypeInferenceTricky(self): + # Test on a deliberately tricky file (type changes as we read more rows, and + # there are null values) fn = os.path.join(self.get_temp_dir(), "file.csv") expected_dtypes = [ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, @@ -916,20 +488,29 @@ class MakeCsvDatasetTest(test.TestCase): column_names=None, label_name=None, na_value="NAN", - default_float_type=dtypes.float32, ) features = dataset.make_one_shot_iterator().get_next() # Check that types match for i in range(len(expected_dtypes)): + print(features["col%d" % i].dtype, expected_dtypes[i]) assert features["col%d" % i].dtype == expected_dtypes[i] for i in range(len(rows)): assert sess.run(features) == dict(zip(col_names, expected[i])) - # With float64 as default type for floats + def testMakeCSVDataset_withTypeInferenceAllTypes(self): + # Test that we make the correct inference for all types with fallthrough + fn = os.path.join(self.get_temp_dir(), "file.csv") expected_dtypes = [ - dtypes.int32, dtypes.int64, dtypes.float64, dtypes.float64, + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string, dtypes.string ] + col_names = ["col%d" % i for i in range(len(expected_dtypes))] + rows = [[1, 2**31 + 1, 1.0, 4e40, "abc", ""]] + expected = [[ + 1, 2**31 + 1, 1.0, 4e40, "abc".encode("utf-8"), "".encode("utf-8") + ]] + self._write_file("file.csv", [col_names] + rows) + with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = self._make_csv_dataset( @@ -938,7 +519,6 @@ class MakeCsvDatasetTest(test.TestCase): column_names=None, label_name=None, na_value="NAN", - default_float_type=dtypes.float64, ) features = dataset.make_one_shot_iterator().get_next() # Check that types match @@ -1088,7 +668,30 @@ class MakeCsvDatasetTest(test.TestCase): self.assertFalse(all_equal) -class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): +class MakeTFRecordDatasetTest( + reader_dataset_ops_test_base.TFRecordDatasetTestBase): + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 def _next_expected_batch(self, file_indices, @@ -1104,8 +707,8 @@ class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): yield j, i def _next_record_interleaved(file_indices, cycle_length): - return _interleave([_next_record([i]) for i in file_indices], - cycle_length) + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) record_batch = [] batch_index = 0 diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..e63bc4c72049c61aa40314ffebe5c4366a818d46 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -0,0 +1,331 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing reader datasets.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class FixedLengthRecordDatasetTestBase(test.TestCase): + """Base class for setting up and testing FixedLengthRecordDataset.""" + + def setUp(self): + super(FixedLengthRecordDatasetTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self._header_bytes = 5 + self._record_bytes = 3 + self._footer_bytes = 2 + + def _record(self, f, r): + return compat.as_bytes(str(f * 2 + r) * self._record_bytes) + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) + filenames.append(fn) + with open(fn, "wb") as f: + f.write(b"H" * self._header_bytes) + for j in range(self._num_records): + f.write(self._record(i, j)) + f.write(b"F" * self._footer_bytes) + return filenames + + +class ReadBatchFeaturesTestBase(test.TestCase): + """Base class for setting up and testing `make_batched_feature_dataset`.""" + + def setUp(self): + super(ReadBatchFeaturesTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self.test_filenames = self._createFiles() + + def make_batch_feature(self, + filenames, + num_epochs, + batch_size, + reader_num_threads=1, + parser_num_threads=1, + shuffle=False, + shuffle_seed=None, + drop_final_batch=False): + self.filenames = filenames + self.num_epochs = num_epochs + self.batch_size = batch_size + + return readers.make_batched_features_dataset( + file_pattern=self.filenames, + batch_size=self.batch_size, + features={ + "file": parsing_ops.FixedLenFeature([], dtypes.int64), + "record": parsing_ops.FixedLenFeature([], dtypes.int64), + "keywords": parsing_ops.VarLenFeature(dtypes.string) + }, + reader=core_readers.TFRecordDataset, + num_epochs=self.num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads, + drop_final_batch=drop_final_batch) + + def _record(self, f, r): + example = example_pb2.Example( + features=feature_pb2.Features( + feature={ + "file": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[f])), + "record": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[r])), + "keywords": + feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=self._get_keywords(f, r))) + })) + return example.SerializeToString() + + def _get_keywords(self, f, r): + num_keywords = 1 + (f + r) % 2 + keywords = [] + for index in range(num_keywords): + keywords.append(compat.as_bytes("keyword%d" % index)) + return keywords + + def _sum_keywords(self, num_files): + sum_keywords = 0 + for i in range(num_files): + for j in range(self._num_records): + sum_keywords += 1 + (i + j) % 2 + return sum_keywords + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames + + def _run_actual_batch(self, outputs, sess): + file_op = outputs["file"] + keywords_indices_op = outputs["keywords"].indices + keywords_values_op = outputs["keywords"].values + keywords_dense_shape_op = outputs["keywords"].dense_shape + record_op = outputs["record"] + return sess.run([ + file_op, keywords_indices_op, keywords_values_op, + keywords_dense_shape_op, record_op + ]) + + def _next_actual_batch(self, sess): + return self._run_actual_batch(self.outputs, sess) + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length=1): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for record in next_records: + f = record[0] + r = record[1] + file_batch.append(f) + record_batch.append(r) + keywords = self._get_keywords(f, r) + keywords_batch_values.extend(keywords) + keywords_batch_indices.extend( + [[batch_index, i] for i in range(len(keywords))]) + batch_index += 1 + keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) + if len(file_batch) == batch_size: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [batch_size, keywords_batch_max_len], record_batch + ] + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + if file_batch: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [len(file_batch), keywords_batch_max_len], record_batch + ] + + def verify_records(self, + sess, + batch_size, + file_index=None, + num_epochs=1, + interleave_cycle_length=1): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length): + actual_batch = self._next_actual_batch(sess) + for i in range(len(expected_batch)): + self.assertAllEqual(expected_batch[i], actual_batch[i]) + + +class TextLineDatasetTestBase(test.TestCase): + """Base class for setting up and testing TextLineDataset.""" + + def _lineText(self, f, l): + return compat.as_bytes("%d: %d" % (f, l)) + + def _createFiles(self, + num_files, + num_lines, + crlf=False, + compression_type=None): + filenames = [] + for i in range(num_files): + fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) + filenames.append(fn) + contents = [] + for j in range(num_lines): + contents.append(self._lineText(i, j)) + # Always include a newline after the record unless it is + # at the end of the file, in which case we include it + if j + 1 != num_lines or i == 0: + contents.append(b"\r\n" if crlf else b"\n") + contents = b"".join(contents) + + if not compression_type: + with open(fn, "wb") as f: + f.write(contents) + elif compression_type == "GZIP": + with gzip.GzipFile(fn, "wb") as f: + f.write(contents) + elif compression_type == "ZLIB": + contents = zlib.compress(contents) + with open(fn, "wb") as f: + f.write(contents) + else: + raise ValueError("Unsupported compression_type", compression_type) + + return filenames + + +class TFRecordDatasetTestBase(test.TestCase): + """Base class for setting up and testing TFRecordDataset.""" + + def setUp(self): + super(TFRecordDatasetTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + self.test_filenames = self._createFiles() + + self.filenames = array_ops.placeholder(dtypes.string, shape=[None]) + self.num_epochs = array_ops.placeholder_with_default( + constant_op.constant(1, dtypes.int64), shape=[]) + self.compression_type = array_ops.placeholder_with_default("", shape=[]) + self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = core_readers.TFRecordDataset( + self.filenames, self.compression_type).repeat(self.num_epochs) + batch_dataset = repeat_dataset.batch(self.batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + self.init_op = iterator.make_initializer(repeat_dataset) + self.init_batch_op = iterator.make_initializer(batch_dataset) + self.get_next = iterator.get_next() + + def _record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index bdc003a8a5bd646e1d5c598befa2694da512d0a9..c5cfddb72b56a1bcffc80c0dd34994def3ee45cd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -17,10 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin import time + from absl.testing import parameterized +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index eb2ceff893543f710d4f0246adf4e6367a2deeb0..42cada0b97bcd9ab755297e8b1f0667766f7999e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -21,7 +21,6 @@ import itertools import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context @@ -64,7 +63,7 @@ class ScanDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFibonacci(self): iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply( scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])) @@ -168,18 +167,5 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) -class ScanDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, num_elements): - return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( - scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) - - def testScanCore(self): - num_output = 5 - self.run_core_tests(lambda: self._build_dataset(num_output), - lambda: self._build_dataset(2), num_output) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..686788522acdf1c5e91132c38bdc81d10d2a0cc2 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -0,0 +1,526 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "dataset_serialization_test_base", + srcs = [ + "dataset_serialization_test_base.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "batch_dataset_serialization_test", + size = "medium", + srcs = ["batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "cache_dataset_serialization_test", + size = "small", + srcs = ["cache_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "concatenate_dataset_serialization_test", + size = "small", + srcs = ["concatenate_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "dataset_constructor_serialization_test", + size = "medium", + srcs = ["dataset_constructor_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "filter_dataset_serialization_test", + size = "medium", + srcs = ["filter_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "fixed_length_record_dataset_serialization_test", + size = "medium", + srcs = ["fixed_length_record_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "flat_map_dataset_serialization_test", + size = "medium", + srcs = ["flat_map_dataset_serialization_test.py"], + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "group_by_reducer_serialization_test", + size = "medium", + srcs = ["group_by_reducer_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:grouping", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "group_by_window_serialization_test", + size = "medium", + srcs = ["group_by_window_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:grouping", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "ignore_errors_serialization_test", + size = "small", + srcs = ["ignore_errors_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "interleave_dataset_serialization_test", + size = "medium", + srcs = ["interleave_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "map_and_batch_dataset_serialization_test", + size = "medium", + srcs = ["map_and_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "map_dataset_serialization_test", + size = "medium", + srcs = ["map_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "optimize_dataset_serialization_test", + size = "small", + srcs = ["optimize_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "padded_batch_dataset_serialization_test", + size = "medium", + srcs = ["padded_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:string_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parallel_interleave_dataset_serialization_test", + size = "medium", + srcs = ["parallel_interleave_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parallel_map_dataset_serialization_test", + size = "medium", + srcs = ["parallel_map_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "prefetch_dataset_serialization_test", + size = "small", + srcs = ["prefetch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "range_dataset_serialization_test", + size = "small", + srcs = ["range_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sample_from_datasets_serialization_test", + size = "medium", + srcs = ["sample_from_datasets_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "scan_dataset_serialization_test", + size = "small", + srcs = ["scan_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:scan_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sequence_dataset_serialization_test", + size = "medium", + srcs = ["sequence_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "serialization_integration_test", + size = "small", + srcs = ["serialization_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "shuffle_and_repeat_dataset_serialization_test", + size = "medium", + srcs = ["shuffle_and_repeat_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:shuffle_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "shuffle_dataset_serialization_test", + size = "medium", + srcs = ["shuffle_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sql_dataset_serialization_test", + size = "small", + srcs = ["sql_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:sql_dataset_op_test_base", + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + ], +) + +py_test( + name = "stats_dataset_serialization_test", + size = "medium", + srcs = ["stats_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "textline_dataset_serialization_test", + size = "medium", + srcs = ["textline_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "tf_record_dataset_serialization_test", + size = "medium", + srcs = ["tf_record_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "unbatch_dataset_serialization_test", + size = "medium", + srcs = ["unbatch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "unique_dataset_serialization_test", + size = "small", + srcs = ["unique_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:unique", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "zip_dataset_serialization_test", + size = "small", + srcs = ["zip_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..af87d8b6083de268fafd4346d2871f14e0f4e7c9 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py @@ -0,0 +1,83 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the BatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class BatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len // batch_size + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + def _build_dataset_dense_to_sparse(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, [12])) + + def testDenseToSparseBatchDatasetCore(self): + components = np.random.randint(5, size=(40,)).astype(np.int32) + diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) + + num_outputs = len(components) // 4 + self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), + lambda: self._build_dataset_dense_to_sparse(diff_comp), + num_outputs) + + def _sparse(self, i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + def _build_dataset_sparse(self, batch_size=5): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) + + def testSparseCore(self): + self.run_core_tests(self._build_dataset_sparse, + lambda: self._build_dataset_sparse(2), 2) + + def _build_dataset_nested_sparse(self): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) + + def testNestedSparseCore(self): + self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a1100893c7384b0e2bd9fcfdaa8d3698b95d28 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the CacheDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class CacheDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.range_size = 10 + self.num_repeats = 3 + self.num_outputs = self.range_size * self.num_repeats + self.cache_file_prefix = 'test' + + def ds_fn(self): + return dataset_ops.Dataset.range(self.range_size).cache( + os.path.join(self.get_temp_dir(), + self.cache_file_prefix)).repeat(self.num_repeats) + + def expected_outputs(self): + return list(range(self.range_size)) * self.num_repeats + + def testCheckpointBeforeOneEpoch(self): + # Generate 5 entries from iterator and save checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + def testCheckpointBeforeOneEpochThenRunFewSteps(self): + # Generate 8 entries from iterator but save checkpoint after producing + # 5. + outputs = self.gen_outputs( + self.ds_fn, [5], + 8, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, range(8)) + + # Restoring from checkpoint and running GetNext should return a + # `AlreadExistsError` now because the lockfile already exists. + with self.assertRaises(errors.AlreadyExistsError): + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + + def testCheckpointAfterOneEpoch(self): + # Generate 15 entries from iterator and save checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 15, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + def testCheckpointAfterOneEpochThenRunFewSteps(self): + # Generate 18 entries from iterator but save checkpoint after producing + # 15. + outputs = self.gen_outputs( + self.ds_fn, [15], + 18, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(8))) + + outputs = list(range(10)) + list(range(5)) + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 15, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointBeforeOneEpochButRunCompleteEpoch(self): + # Generate 13 entries from iterator but save checkpoint after producing + # 5. + outputs = self.gen_outputs( + self.ds_fn, [5], + 13, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(3))) + + # Since we ran for more than one epoch, the cache was completely written. + # The ckpt was saved when the iterator was in cache-write mode. Test that + # the iterator falls back to read mode after restoring if the cache has + # been completely written. + + outputs = list(range(5)) + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointUnusedWriterIterator(self): + # Checkpoint before get_next is called even once. + outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False) + self.assertSequenceEqual(outputs, []) + + outputs = self.gen_outputs( + self.ds_fn, [], + self.num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointUnusedMidwayWriterIterator(self): + # Produce 5 elements and checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint, then produce no elements and checkpoint. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce rest of the elements. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testUnusedCheckpointError(self): + # Produce 5 elements and save ckpt. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + self.ds_fn, [], self.num_outputs, verify_exhausted=False) + + def testIgnoreCheckpointIfCacheWritten(self): + # Produce 15 elements and save ckpt. This will write the complete cache. + outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Build the iterator again but do not restore from ckpt. Since the cache + # has already been written we should be able to use it. + outputs = self.gen_outputs( + self.ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py similarity index 92% rename from tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py index 17f2980157ddd0350dafd1d745cbb9b64e65f7c5..96f13d75a31b6762b35062e6cf8c0cdb4d61d2c5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the ConcatenateDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2139b5c33db69a7ffbdebee74e5824928004b407 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the dataset constructors serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test + + +class FromTensorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_tensor_dataset(self, variable_array): + components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) + + return dataset_ops.Dataset.from_tensors(components) + + def testFromTensorsCore(self): + # Equal length components + arr = np.array(1) + num_outputs = 1 + diff_arr = np.array(2) + self.run_core_tests(lambda: self._build_tensor_dataset(arr), + lambda: self._build_tensor_dataset(diff_arr), + num_outputs) + + +class FromTensorSlicesSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_tensor_slices_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components) + + def testFromTensorSlicesCore(self): + # Equal length components + components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), + np.array([37.0, 38.0, 39.0, 40.0])) + + diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[5], [6], [7], [8]]), 22), + np.array([1.0, 2.0, 3.0, 4.0])) + + dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} + + self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), + lambda: self._build_tensor_slices_dataset(diff_comp), 4) + self.run_core_tests( + lambda: self._build_tensor_slices_dataset(dict_components), None, 3) + + +class FromSparseTensorSlicesSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_sparse_tensor_slice_dataset(self, slices): + indices = np.array( + [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], + dtype=np.int64) + values = np.array([val for s in slices for val in s], dtype=np.float64) + dense_shape = np.array( + [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) + sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) + return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) + + def testFromSparseTensorSlicesCore(self): + slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] + diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] + + self.run_core_tests( + lambda: self._build_sparse_tensor_slice_dataset(slices), + lambda: self._build_sparse_tensor_slice_dataset(diff_slices), + 9, + sparse_tensors=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py similarity index 97% rename from tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py index 78ecce8f7daaf84002ae78d8d77820755b967d89..393f08850b1865180a8b94e9209b2445b54c8b69 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py @@ -467,7 +467,8 @@ class DatasetSerializationTestBase(test.TestCase): ckpt_saved=False, init_before_restore=False, sparse_tensors=False, - verify_exhausted=True): + verify_exhausted=True, + save_checkpoint_at_end=True): """Generates elements from input dataset while stopping at break points. Produces `num_outputs` outputs and saves the state of the iterator in the @@ -490,6 +491,10 @@ class DatasetSerializationTestBase(test.TestCase): sparse_tensors: Whether dataset is built from SparseTensor(s). verify_exhausted: Whether to verify that the iterator has been exhausted after producing `num_outputs` elements. + save_checkpoint_at_end: Whether to save a checkpoint after producing all + outputs. If False, checkpoints are saved each break point but not at the + end. Note that checkpoints overwrite each other so there is always only + a single checkpoint available. Defaults to True. Returns: A list of `num_outputs` items. @@ -526,8 +531,9 @@ class DatasetSerializationTestBase(test.TestCase): if i == len(break_points) and verify_exhausted: with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) - self._save(sess, saver) - ckpt_saved = True + if save_checkpoint_at_end or i < len(break_points): + self._save(sess, saver) + ckpt_saved = True return outputs diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py similarity index 91% rename from tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py index b572d6ed770fc0fe0f852359baf343c55966eddd..7c170078a11aadce9e5730437e4c25209bd58edb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the FilterDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops @@ -35,7 +35,7 @@ class FilterDatasetSerializationTest( def testFilterCore(self): div = 3 - num_outputs = np.sum([x % 3 is not 2 for x in range(100)]) + num_outputs = np.sum([x % 3 != 2 for x in range(100)]) self.run_core_tests(lambda: self._build_filter_range_graph(div), lambda: self._build_filter_range_graph(div * 2), num_outputs) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34392d88d4505175c4562e23d5f0c4116e00b022 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py @@ -0,0 +1,45 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the FixedLengthRecordDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class FixedLengthRecordDatasetSerializationTest( + reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, num_epochs, compression_type=None): + filenames = self._createFiles() + return core_readers.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, + self._footer_bytes).repeat(num_epochs) + + def testFixedLengthRecordCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py similarity index 96% rename from tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py index f3feecef32e587045be25056815315136a883ca7..16051ffd3fd1e1e7ff419f28109df7bc1f165257 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the FlatMapDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..571e0899bbc1f856d66f85c4f6f3ac78aa0b1368 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py @@ -0,0 +1,61 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the GroupByReducer serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class GroupByReducerSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + return dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_reducer(lambda x: x % 5, reducer)) + + def testCoreGroupByReducer(self): + components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 5, + verify_exhausted=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f86af4084ef61c2f20dbe2fb388a20287676f39d --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the GroupByWindow serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class GroupByWindowSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( + grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) + + def testCoreGroupByWindow(self): + components = np.array( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 12, + verify_exhausted=False) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..65ae9923b8f64dddcd54afc53e2fa67bc770fc2a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the IgnoreErrors input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class IgnoreErrorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors()) + + def testIgnoreErrorsCore(self): + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) + diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) + num_outputs = 4 + self.run_core_tests(lambda: self._build_ds(components), + lambda: self._build_ds(diff_components), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3892fe81a1c0d325ddc5f501c2caed4b53f5d5 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -0,0 +1,86 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the InterleaveDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class InterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, input_values, cycle_length, block_length): + repeat_count = 2 + return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + repeat_count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length) + + def testSerializationCore(self): + input_values = np.array([4, 5, 6], dtype=np.int64) + num_outputs = np.sum(input_values) * 2 + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + lambda: self._build_iterator_graph( + input_values, cycle_length * 2, block_length * 1), + num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # pylint: enable=g-long-lambda + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1) + + self.run_core_tests(_build_dataset, None, 20) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cd211328fa595c0ce0efe3509e8ba9dc06af80 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py @@ -0,0 +1,88 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapAndBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapAndBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testNumParallelBatches(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_batches = 2 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ab783e5cce95ed63fe64c273abb3846121c7a274 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py @@ -0,0 +1,140 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class MapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 14 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) + + def testSaveRestoreCore(self): + self.run_core_tests( + self._build_ds, + lambda: self._build_ds(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(_map_fn) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1))) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var)) + + self.run_core_tests(_build_ds, None, 10) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testSparseCore(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def _build_ds(num_outputs): + return dataset_ops.Dataset.range(num_outputs).map(_sparse) + + num_outputs = 10 + self.run_core_tests(lambda: _build_ds(num_outputs), + lambda: _build_ds(int(num_outputs / 2)), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c03495e34e73018bf9832bf77cdcf038449488 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the OptimizeDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac42a461afcb6803a0e033892e74fb84d1e5e58 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the PaddedBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class PaddedBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testPaddedBatch(self): + + def build_dataset(seq_lens): + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=[-1]) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + def testPaddedBatchNonDefaultPadding(self): + + def build_dataset(seq_lens): + + def fill_tuple(x): + filled = array_ops.fill([x], x) + return (filled, string_ops.as_string(filled)) + + padded_shape = [-1] + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + fill_tuple).padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "")) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1f8a584df902180aa7ab020b47ecc749912a3a3a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParallelInterleaveDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class ParallelInterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.input_values = np.array([4, 5, 6], dtype=np.int64) + self.num_repeats = 2 + self.num_outputs = np.sum(self.input_values) * 2 + + def _build_ds(self, cycle_length, block_length, sloppy=False): + return (dataset_ops.Dataset.from_tensor_slices( + self.input_values).repeat(self.num_repeats).apply( + interleave_ops.parallel_interleave( + lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), + cycle_length, block_length, sloppy))) + + def testSerializationCore(self): + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + self.run_core_tests( + lambda: self._build_ds(cycle_length, block_length), + lambda: self._build_ds(cycle_length * 2, block_length * 1), + self.num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + + def testSerializationWithSloppy(self): + break_points = self.gen_break_points(self.num_outputs, 10) + expected_outputs = np.repeat( + np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), + self.num_repeats).tolist() + + def run_test(cycle_length, block_length): + actual = self.gen_outputs( + lambda: self._build_ds(cycle_length, block_length, True), + break_points, self.num_outputs) + self.assertSequenceEqual(sorted(actual), expected_outputs) + + # cycle_length > 1, block_length > 1 + run_test(2, 3) + # cycle_length = 1 + run_test(1, 3) + # block_length = 1 + run_test(2, 1) + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).apply( + interleave_ops.parallel_interleave(_interleave_fn, 1)) + + self.run_core_tests(_build_dataset, None, 20) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb7605be1f230cef4cdae30aa672842a678edf7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParallelMapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class ParallelMapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 1 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) + + def _build_ds_with_prefetch(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) + + def testSaveRestoreCore(self): + for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: + self.run_core_tests( + ds_fn, + lambda: ds_fn(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map( + _map_fn, num_parallel_calls=2).prefetch(2) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1), + num_parallel_calls=2).prefetch(2)) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) + + self.run_core_tests(_build_ds, None, 10) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) + + self.run_core_tests(_build_ds, None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py similarity index 90% rename from tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py index 3d120a3071ef730f21221e3291d8c84385b51aa3..c802402461216de33e7d3232ba38063c27f33557 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the PrefetchDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f5b6cf5db788ad2fd09b7e93d0ae5ebb530a11 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the RangeDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class RangeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _iterator_checkpoint_prefix_local(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix_local(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix_local()), + dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def testSaveRestore(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Saving and restoring in same session. + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def _build_range_dataset(self, start, stop): + return dataset_ops.Dataset.range(start, stop) + + def testRangeCore(self): + start = 2 + stop = 10 + stop_1 = 8 + self.run_core_tests(lambda: self._build_range_dataset(start, stop), + lambda: self._build_range_dataset(start, stop_1), + stop - start) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb35ea624c22ad0a9561d774c86247119c4c837 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SampleFromDatasets serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..af9ef48c0f3b92f61c097410ef4dfd787292e76a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ScanDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import scan_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ScanDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_elements): + return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) + + def testScanCore(self): + num_output = 5 + self.run_core_tests(lambda: self._build_dataset(num_output), + lambda: self._build_dataset(2), num_output) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py similarity index 91% rename from tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py index d0cb203a3afd2775756c8542a1e86faedc5cee53..2afebca0f5849c640044830fff05ebff131e0875 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the sequence datasets serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test -class SequenceDatasetSerializationTest( +class SkipDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_skip_dataset(self, count): @@ -52,6 +52,10 @@ class SequenceDatasetSerializationTest( 'Shape must be rank 0 but is rank 1'): self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0) + +class TakeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + def _build_take_dataset(self, count): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take(count) @@ -79,6 +83,10 @@ class SequenceDatasetSerializationTest( 'Shape must be rank 0 but is rank 1'): self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0) + +class RepeatDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + def _build_repeat_dataset(self, count, take_count=3): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take( @@ -117,5 +125,5 @@ class SequenceDatasetSerializationTest( None, 0) -if __name__ == "__main__": +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py similarity index 96% rename from tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py index 0a6b74dc3eb80a6168117beed06935737198cecb..992d996a485de94ad55305552e42c7fbc92ec64b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Integration test for input pipeline serialization.""" +"""Integration test for dataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -26,7 +26,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import saver as saver_lib -class MultipleInputPipelinesTest(test.TestCase): +class SerializationIntegrationTest(test.TestCase): def _build_input_pipeline(self, name, num_outputs): with ops.name_scope(name): diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f199ec835ef1c72e2c3f8b3b1cc4f5fe6ea0b6f4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ShuffleAndRepeatDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d46c762aaaadc4314a10acc5aeb7ace7df5002a8 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py @@ -0,0 +1,148 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ShuffleDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class ShuffleDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_shuffle_dataset( + self, + range_limit=10, + num_repeats=5, + buffer_size=5, + seed=None, + reshuffle_each_iteration=None, + ): + return dataset_ops.Dataset.range(range_limit).shuffle( + buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) + + def testShuffleCore(self): + + seed = 55 + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + # pylint: disable=cell-var-from-loop + # pylint: disable=g-long-lambda + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + def testNonDeterministicSeeding(self): + + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + # We checkpoint the initial state of the Dataset so that we can restore + # the seeds in the next run. Since the seeding is non-deterministic + # the dataset gets initialized with different seeds each time. + expected = self.gen_outputs( + ds_fn, + break_points=[0], + num_outputs=num_outputs, + ckpt_saved=False, + verify_exhausted=False, + save_checkpoint_at_end=False) + actual = self.gen_outputs( + ds_fn, + break_points=self.gen_break_points(num_outputs), + num_outputs=num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.match(expected, actual) + + def testMultipleIterators(self): + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + with ops.Graph().as_default() as g: + ds = ds_fn() + iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] + get_next_ops = [it.get_next() for it in iterators] + saveables = [ + contrib_iterator_ops.make_saveable_from_iterator(it) + for it in iterators + ] + for saveable in saveables: + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver(allow_empty=True) + with self.test_session(graph=g) as sess: + self._save(sess, saver) + expected = [sess.run(get_next_ops) for _ in range(num_outputs)] + self._restore(saver, sess) + actual = [sess.run(get_next_ops) for _ in range(num_outputs)] + self.match(expected, actual) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..93b26ed58a065de2074906528a0f49d696a813ff --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SqlDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetSerializationTest( + sql_dataset_op_test_base.SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..14cd3e9c4a72cc7832f9bb1cb49c72a8a7cb2dcd --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the StatsDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import stats_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the +# transformation `stats_ops.set_stats_aggregator`, since we don't support +# serializing StatsAggregator yet. +class StatsDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset_bytes_stats(self, num_elements): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + + def test_bytes_produced_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.bytes_produced_stats(["bytes_produced"])), + None, 100) + # pylint: enable=g-long-lambda + + def testBytesStatsDatasetSaveableCore(self): + num_outputs = 100 + self.run_core_tests( + lambda: self._build_dataset_bytes_stats(num_outputs), + lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) + + def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag)) + + def _build_dataset_multiple_tags(self, + num_elements, + tag1="record_latency", + tag2="record_latency_2"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + + def test_latency_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats(["record_latency", "record_latency_2"])), + None, 100) + # pylint: enable=g-long-lambda + + def testLatencyStatsDatasetSaveableCore(self): + num_outputs = 100 + + self.run_core_tests( + lambda: self._build_dataset_latency_stats(num_outputs), + lambda: self._build_dataset_latency_stats(num_outputs // 10), + num_outputs) + + self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), + None, num_outputs) + + tag1 = "record_latency" + tag2 = "record_latency" + self.run_core_tests( + lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), + None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2483787f44f913199e3f2aa46d181d609a4a9a8f --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the TextLineDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class TextLineDatasetSerializationTest( + reader_dataset_ops_test_base.TextLineDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, test_filenames, compression_type=None): + return core_readers.TextLineDataset( + test_filenames, compression_type=compression_type, buffer_size=10) + + def testTextLineCore(self): + compression_types = [None, "GZIP", "ZLIB"] + num_files = 5 + lines_per_file = 5 + num_outputs = num_files * lines_per_file + for compression_type in compression_types: + test_filenames = self._createFiles( + num_files, + lines_per_file, + crlf=True, + compression_type=compression_type) + # pylint: disable=cell-var-from-loop + self.run_core_tests( + lambda: self._build_iterator_graph(test_filenames, compression_type), + lambda: self._build_iterator_graph(test_filenames), num_outputs) + # pylint: enable=cell-var-from-loop + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..55a6257a274cd7f78e3818943627cfa09a185fd7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the TFRecordDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class TFRecordDatasetSerializationTest( + reader_dataset_ops_test_base.TFRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, + num_epochs, + batch_size=1, + compression_type=None, + buffer_size=None): + filenames = self._createFiles() + if compression_type == "ZLIB": + zlib_files = [] + for i, fn in enumerate(filenames): + with open(fn, "rb") as f: + cdata = zlib.compress(f.read()) + zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) + with open(zfn, "wb") as f: + f.write(cdata) + zlib_files.append(zfn) + filenames = zlib_files + + elif compression_type == "GZIP": + gzip_files = [] + for i, fn in enumerate(self.test_filenames): + with open(fn, "rb") as f: + gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) + with gzip.GzipFile(gzfn, "wb") as gzf: + gzf.write(f.read()) + gzip_files.append(gzfn) + filenames = gzip_files + + return core_readers.TFRecordDataset( + filenames, compression_type, + buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) + + def testTFRecordWithoutBufferCore(self): + num_epochs = 5 + batch_size = num_epochs + num_outputs = num_epochs * self._num_files * self._num_records // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, batch_size, + buffer_size=0), + lambda: self._build_iterator_graph(num_epochs * 2, batch_size), + num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, + num_outputs * batch_size) + # pylint: enable=g-long-lambda + + def testTFRecordWithBufferCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + def testTFRecordWithCompressionCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a5a8a20dd7a9f891b07351570006636ca34bd0 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the UnbatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class UnbatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch( + batch_size).apply(batching.unbatch()) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..22f15b88464a770207dc7c6f0387d73ea3d5c2e4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the UniqueDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class UniqueDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testUnique(self): + + def build_dataset(num_elements, unique_elem_range): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: x % unique_elem_range).apply(unique.unique()) + + self.run_core_tests(lambda: build_dataset(200, 100), + lambda: build_dataset(40, 100), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py similarity index 92% rename from tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py index e39fa957f0bbb9d3671274d5f58b993e8399814b..340a6ff72e6813c3743d3d83a72ac12d4a392b66 100644 --- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the ZipDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index bcc644c0971854d948025009dc7add2fea214048..3c11d7a97fc9a4b2b8b19a8e82ad5e9037d6bbcd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -19,7 +19,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors @@ -27,60 +26,25 @@ from tensorflow.python.framework import ops from tensorflow.python.platform import test -class ShuffleDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_shuffle_dataset( - self, - range_limit=10, - num_repeats=5, - buffer_size=5, - seed=None, - reshuffle_each_iteration=None, - ): - return dataset_ops.Dataset.range(range_limit).shuffle( - buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) - - def testShuffleCore(self): - - seed = 55 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 8, 10, 25, 50] - reshuffle_each_iteration = False - # pylint: disable=cell-var-from-loop - # pylint: disable=g-long-lambda - for buffer_size in buffer_sizes: - self.run_core_tests( - lambda: self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration), - lambda: self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=10, - reshuffle_each_iteration=reshuffle_each_iteration), - num_outputs) - # pylint: enable=cell-var-from-loop - # pylint: enable=g-long-lambda - - -class ShuffleAndRepeatTest( - dataset_serialization_test_base.DatasetSerializationTestBase): +class ShuffleAndRepeatTest(test.TestCase): def _build_ds(self, seed, count=5, num_elements=20): return dataset_ops.Dataset.range(num_elements).apply( shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True): + get_next = ds_fn().make_one_shot_iterator().get_next() + outputs = [] + with self.test_session() as sess: + for _ in range(num_outputs): + outputs.append(sess.run(get_next)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + return outputs + def testCorrectOutput(self): - output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output = self._gen_outputs(lambda: self._build_ds(10), 100) self.assertSequenceEqual( sorted(output), sorted( np.array([range(20) for _ in range(5)]).flatten())) @@ -89,53 +53,53 @@ class ShuffleAndRepeatTest( def testReshuffling(self): # Check that the output orders of different epochs are indeed different. - output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output = self._gen_outputs(lambda: self._build_ds(10), 100) for i in range(4): epoch1 = output[i * 20:(i + 1) * 20] epoch2 = output[(i + 1) * 20:(i + 2) * 20] self.assertNotEqual(epoch1, epoch2) def testSameOrderForSameSeeds(self): - output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) - output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output1 = self._gen_outputs(lambda: self._build_ds(10), 100) + output2 = self._gen_outputs(lambda: self._build_ds(10), 100) self.assertEqual(output1, output2) def testDifferentOrderForDifferentSeeds(self): - output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) - output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) + output1 = self._gen_outputs(lambda: self._build_ds(10), 100) + output2 = self._gen_outputs(lambda: self._build_ds(20), 100) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testCountNone(self): - output1 = self.gen_outputs( - lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) - output2 = self.gen_outputs( - lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) + output1 = self._gen_outputs( + lambda: self._build_ds(10, count=None), 100, verify_exhausted=False) + output2 = self._gen_outputs( + lambda: self._build_ds(20, count=None), 100, verify_exhausted=False) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testCountMinusOne(self): - output1 = self.gen_outputs( - lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) - output2 = self.gen_outputs( - lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) + output1 = self._gen_outputs( + lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False) + output2 = self._gen_outputs( + lambda: self._build_ds(20, count=-1), 100, verify_exhausted=False) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testInfiniteOutputs(self): # Asserting the iterator is exhausted after producing 100 items should fail. with self.assertRaises(AssertionError): - self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=None), 100) with self.assertRaises(AssertionError): - self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=-1), 100) def testInfiniteEmpty(self): with self.assertRaises(errors.OutOfRangeError): - self.gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), - [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), + 100) with self.assertRaises(errors.OutOfRangeError): - self.gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), [], - 100) + self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), + 100) def testLargeBufferSize(self): with ops.Graph().as_default() as g: @@ -146,17 +110,5 @@ class ShuffleAndRepeatTest( sess.run(get_next_op) -class ShuffleAndRepeatSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_ds(self, seed): - return dataset_ops.Dataset.range(20).apply( - shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) - - def testCore(self): - self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), - 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 33c48e20bea53b88d69a59e715af38b22dd2cbd4..5590a4bf783d12b0d0710c0130b0b1df921c9baa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -58,6 +58,7 @@ class SlideDatasetTest(test.TestCase): [t.shape.as_list() for t in get_next]) with self.test_session() as sess: + # stride < window_size. # Slide over a finite input, where the window_size divides the # total number of elements. sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7}) @@ -71,11 +72,9 @@ class SlideDatasetTest(test.TestCase): result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - # Slide over a finite input, where the window_size does not # divide the total number of elements. sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9}) - num_batches = (20 * 7 - 17) // 9 + 1 for i in range(num_batches): result = sess.run(get_next) @@ -86,6 +85,41 @@ class SlideDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + # stride == window_size. + sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 14}) + num_batches = 20 * 7 // 14 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i*14 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # stride > window_size. + sess.run(init_op, feed_dict={count: 20, window_size: 10, stride: 14}) + num_batches = 20 * 7 // 14 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(10): + self.assertAllEqual(component[(i*14 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + # Drop the last batch which is smaller than window_size. + sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 19}) + num_batches = (20 * 7 - 7) // 19 # = 19 * 7 // 19 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i*19 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + # Slide over a finite input, which is less than window_size, # should fail straight away. sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4}) @@ -108,10 +142,6 @@ class SlideDatasetTest(test.TestCase): # Invalid stride should be an initialization time error. with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0}) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3}) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5}) def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index 4148addf2878c99f47ebe1454edf69ad7f38dfbc..2c2cfbebff5d3eba00f120467102b4185d81ab24 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -18,83 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -import sqlite3 - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import readers +from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTestBase(test.TestCase): - - def _createSqlDataset(self, output_types, num_repeats=1): - dataset = readers.SqlDataset(self.driver_name, self.data_source_name, - self.query, output_types).repeat(num_repeats) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - return init_op, get_next - - def setUp(self): - self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") - self.driver_name = array_ops.placeholder_with_default( - array_ops.constant("sqlite", dtypes.string), shape=[]) - self.query = array_ops.placeholder(dtypes.string, shape=[]) - - conn = sqlite3.connect(self.data_source_name) - c = conn.cursor() - c.execute("DROP TABLE IF EXISTS students") - c.execute("DROP TABLE IF EXISTS people") - c.execute("DROP TABLE IF EXISTS townspeople") - c.execute( - "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, " - "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), " - "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), " - "desk_number INTEGER, income INTEGER, favorite_number INTEGER, " - "favorite_big_number INTEGER, favorite_negative_number INTEGER, " - "favorite_medium_sized_number INTEGER, brownie_points INTEGER, " - "account_balance INTEGER, registration_complete INTEGER)") - c.executemany( - "INSERT INTO students (first_name, last_name, motto, school_id, " - "favorite_nonsense_word, desk_number, income, favorite_number, " - "favorite_big_number, favorite_negative_number, " - "favorite_medium_sized_number, brownie_points, account_balance, " - "registration_complete) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647, - 9223372036854775807, -2, 32767, 0, 0, 1), - ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000, - -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)]) - c.execute( - "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " - "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") - c.executemany( - "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)", - [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe", - "California")]) - c.execute( - "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY " - "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories " - "FLOAT, accolades FLOAT, triumphs FLOAT)") - c.executemany( - "INSERT INTO townspeople (first_name, last_name, victories, " - "accolades, triumphs) VALUES (?, ?, ?, ?, ?)", - [("George", "Washington", 20.00, - 1331241.321342132321324589798264627463827647382647382643874, - 9007199254740991.0), - ("John", "Adams", -19.95, - 1331241321342132321324589798264627463827647382647382643874.0, - 9007199254740992.0)]) - conn.commit() - conn.close() - - -class SqlDatasetTest(SqlDatasetTestBase): +class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # Test that SqlDataset can read from a database table. def testReadResultSet(self): @@ -656,27 +586,5 @@ class SqlDatasetTest(SqlDatasetTestBase): sess.run(get_next) -class SqlDatasetSerializationTest( - SqlDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, num_repeats): - data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") - driver_name = array_ops.placeholder_with_default( - array_ops.constant("sqlite", dtypes.string), shape=[]) - query = ("SELECT first_name, last_name, motto FROM students ORDER BY " - "first_name DESC") - output_types = (dtypes.string, dtypes.string, dtypes.string) - return readers.SqlDataset(driver_name, data_source_name, query, - output_types).repeat(num_repeats) - - def testSQLSaveable(self): - num_repeats = 4 - num_outputs = num_repeats * 2 - self.run_core_tests(lambda: self._build_dataset(num_repeats), - lambda: self._build_dataset(num_repeats // 2), - num_outputs) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5c725a9269e80311f3e73c51c28ab80e7c4815 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing SqlDataset.""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import sqlite3 + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetTestBase(test.TestCase): + """Base class for setting up and testing SqlDataset.""" + + def _createSqlDataset(self, output_types, num_repeats=1): + dataset = readers.SqlDataset(self.driver_name, self.data_source_name, + self.query, output_types).repeat(num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + return init_op, get_next + + def setUp(self): + self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + self.driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + self.query = array_ops.placeholder(dtypes.string, shape=[]) + + conn = sqlite3.connect(self.data_source_name) + c = conn.cursor() + c.execute("DROP TABLE IF EXISTS students") + c.execute("DROP TABLE IF EXISTS people") + c.execute("DROP TABLE IF EXISTS townspeople") + c.execute( + "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), " + "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), " + "desk_number INTEGER, income INTEGER, favorite_number INTEGER, " + "favorite_big_number INTEGER, favorite_negative_number INTEGER, " + "favorite_medium_sized_number INTEGER, brownie_points INTEGER, " + "account_balance INTEGER, registration_complete INTEGER)") + c.executemany( + "INSERT INTO students (first_name, last_name, motto, school_id, " + "favorite_nonsense_word, desk_number, income, favorite_number, " + "favorite_big_number, favorite_negative_number, " + "favorite_medium_sized_number, brownie_points, account_balance, " + "registration_complete) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647, + 9223372036854775807, -2, 32767, 0, 0, 1), + ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000, + -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)]) + c.execute( + "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") + c.executemany( + "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)", + [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe", + "California")]) + c.execute( + "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY " + "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories " + "FLOAT, accolades FLOAT, triumphs FLOAT)") + c.executemany( + "INSERT INTO townspeople (first_name, last_name, victories, " + "accolades, triumphs) VALUES (?, ?, ?, ?, ?)", + [("George", "Washington", 20.00, + 1331241.321342132321324589798264627463827647382647382643874, + 9007199254740991.0), + ("John", "Adams", -19.95, + 1331241321342132321324589798264627463827647382647382643874.0, + 9007199254740992.0)]) + conn.commit() + conn.close() + + diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 5c74ed6ae7210e8e22efb6e8fdb773397459ce1e..b4945685c1d1062bf416b73f1541f351adf45604 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -19,7 +19,7 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops @@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTest(test.TestCase): +class StatsDatasetTestBase(test.TestCase): def _assertSummaryHasCount(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() @@ -49,6 +49,9 @@ class StatsDatasetTest(test.TestCase): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + +class StatsDatasetTest(StatsDatasetTestBase): + def testBytesProduced(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).map( @@ -193,68 +196,44 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) -class StatsDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset_bytes_stats(self, num_elements): - return dataset_ops.Dataset.range(num_elements).map( - lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( - stats_ops.bytes_produced_stats("bytes_produced")) - - def test_bytes_produced_stats_invalid_tag_shape(self): - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): - self.run_core_tests( - lambda: dataset_ops.Dataset.range(100).apply( - stats_ops.bytes_produced_stats(["bytes_produced"])), - None, 100) - - def testBytesStatsDatasetSaveableCore(self): - num_outputs = 100 - self.run_core_tests( - lambda: self._build_dataset_bytes_stats(num_outputs), - lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) +class FeatureStatsDatasetTest( + StatsDatasetTestBase, + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): - def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): - return dataset_ops.Dataset.range(num_elements).apply( - stats_ops.latency_stats(tag)) - - def _build_dataset_multiple_tags(self, - num_elements, - tag1="record_latency", - tag2="record_latency_2"): - return dataset_ops.Dataset.range(num_elements).apply( - stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) - - def test_latency_stats_invalid_tag_shape(self): - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): - self.run_core_tests( - lambda: dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats(["record_latency", "record_latency_2"])), - None, 100) - - def testLatencyStatsDatasetSaveableCore(self): - num_outputs = 100 - - self.run_core_tests( - lambda: self._build_dataset_latency_stats(num_outputs), - lambda: self._build_dataset_latency_stats(num_outputs // 10), - num_outputs) - - self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), - None, num_outputs) + def testFeaturesStats(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + batch_size = 2 + stats_aggregator = stats_ops.StatsAggregator() + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5, + drop_final_batch=True).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() - tag1 = "record_latency" - tag2 = "record_latency" - self.run_core_tests( - lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), - None, num_outputs) + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(total_records // batch_size): + sess.run(next_element) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:features", total_records) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:feature-values", total_records) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:features", total_records * 3) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:feature-values", + self._sum_keywords(1) * num_epochs + 2 * total_records) -# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the -# transformation `stats_ops.set_stats_aggregator`, since we don't support -# serializing StatsAggregator yet. if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 9167cb3379bba5cb1ba76a96549395c45dca9e35..0486e2bce20e9dcf81dcb5ac49fe5b397e44bf0c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import threading +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import threadpool @@ -30,9 +31,11 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class OverrideThreadpoolDatasetTest(test.TestCase): +class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): - def testNumThreads(self): + @parameterized.parameters((1, None), (2, None), (4, None), (8, None), + (16, None), (4, -1), (4, 0), (4, 1), (4, 4)) + def testNumThreads(self, num_threads, max_intra_op_parallelism): def get_thread_id(_): # Python creates a dummy thread object to represent the current @@ -42,35 +45,35 @@ class OverrideThreadpoolDatasetTest(test.TestCase): # identifier that maps one-to-one with the underlying OS thread. return np.array(threading.current_thread().ident).astype(np.int64) - for num_threads in [1, 2, 4, 8, 16]: + dataset = ( + dataset_ops.Dataset.range(1000).map( + lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), + num_parallel_calls=32).apply(unique.unique())) - dataset = ( - dataset_ops.Dataset.range(1000).map( - lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), - num_parallel_calls=32).apply(unique.unique())) + dataset = threadpool.override_threadpool( + dataset, + threadpool.PrivateThreadPool( + num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, + display_name="private_thread_pool_%d" % num_threads)) - dataset = threadpool.override_threadpool( - dataset, - threadpool.PrivateThreadPool( - num_threads, display_name="private_thread_pool_%d" % num_threads)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() - iterator = dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - sess.run(iterator.initializer) - thread_ids = [] - try: - while True: - thread_ids.append(sess.run(next_element)) - except errors.OutOfRangeError: - pass - self.assertEqual(len(thread_ids), len(set(thread_ids))) - self.assertGreater(len(thread_ids), 0) - # NOTE(mrry): We don't control the thread pool scheduling, and - # so cannot guarantee that all of the threads in the pool will - # perform work. - self.assertLessEqual(len(thread_ids), num_threads) + with self.test_session() as sess: + sess.run(iterator.initializer) + thread_ids = [] + try: + while True: + thread_ids.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + self.assertEqual(len(thread_ids), len(set(thread_ids))) + self.assertGreater(len(thread_ids), 0) + # NOTE(mrry): We don't control the thread pool scheduling, and + # so cannot guarantee that all of the threads in the pool will + # perform work. + self.assertLessEqual(len(thread_ids), num_threads) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index 3c436f7a0b45a13109960e87dd97ca56b10bb871..d79a842e7a5d816e2e6a52fc83acbd6b260cf64b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import unique from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes @@ -79,18 +78,5 @@ class UniqueDatasetTest(test.TestCase): ]) -class UniqueSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testUnique(self): - - def build_dataset(num_elements, unique_elem_range): - return dataset_ops.Dataset.range(num_elements).map( - lambda x: x % unique_elem_range).apply(unique.unique()) - - self.run_core_tests(lambda: build_dataset(200, 100), - lambda: build_dataset(40, 100), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..33d95d67549e1c8d1d9af578fcebbb4f939c418a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -0,0 +1,523 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class WindowDatasetTest(test.TestCase, parameterized.TestCase): + + def _structuredDataset(self, structure, shape, dtype): + if structure is None: + return dataset_ops.Dataset.from_tensors( + array_ops.zeros(shape, dtype=dtype)) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredDataset(substructure, shape, dtype) + for substructure in structure + ])) + + def _structuredElement(self, structure, shape, dtype): + if structure is None: + return array_ops.zeros(shape, dtype=dtype) + else: + return tuple([ + self._structuredElement(substructure, shape, dtype) + for substructure in structure + ]) + + def _assertEqual(self, xs, ys): + self.assertEqual(type(xs), type(ys)) + if isinstance(xs, tuple) and isinstance(ys, tuple): + self.assertEqual(len(xs), len(ys)) + for x, y in zip(xs, ys): + self._assertEqual(x, y) + elif isinstance(xs, np.ndarray) and isinstance(ys, np.ndarray): + self.assertAllEqual(xs, ys) + else: + self.assertEqual(xs, ys) + + @parameterized.parameters( + (None, np.int32([]), dtypes.bool), + (None, np.int32([]), dtypes.int32), + (None, np.int32([]), dtypes.float32), + (None, np.int32([]), dtypes.string), + (None, np.int32([2]), dtypes.int32), + (None, np.int32([2, 2]), dtypes.int32), + ((None, None, None), np.int32([]), dtypes.int32), + ((None, (None, None)), np.int32([]), dtypes.int32), + ) + def testWindowDatasetFlatMap(self, structure, shape, dtype): + """Tests windowing by chaining it with flat map. + + Args: + structure: the input structure + shape: the input shape + dtype: the input data type + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return args[0] + return dataset_ops.Dataset.zip( + tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args])) + + dataset = self._structuredDataset(structure, shape, dtype).apply( + grouping.window_dataset(5)).flat_map(fn) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run(self._structuredElement(structure, shape, dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (None, np.int32([]), dtypes.bool), + (None, np.int32([]), dtypes.int32), + (None, np.int32([]), dtypes.float32), + (None, np.int32([]), dtypes.string), + (None, np.int32([2]), dtypes.int32), + (None, np.int32([2, 2]), dtypes.int32), + ((None, None, None), np.int32([]), dtypes.int32), + ((None, (None, None)), np.int32([]), dtypes.int32), + ) + def testWindowDatasetBatchDense(self, structure, shape, dtype): + """Tests batching of dense tensor windows. + + Args: + structure: the input structure + shape: the input shape + dtype: the input data type + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.batch_window(args[0]) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg) + for arg in args + ]) + + dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply( + grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run( + self._structuredElement(structure, np.concatenate( + ([5], shape), axis=0), dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([]),), + (np.int32([1]),), + (np.int32([1, 2, 3]),), + ) + def testWindowDatasetBatchDenseDynamicShape(self, shape): + """Tests batching of dynamically shaped dense tensor windows. + + Args: + shape: the input shape + """ + + shape_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensors( + array_ops.zeros(shape_t)).repeat(5).apply( + grouping.window_dataset(5)).apply( + grouping._map_x_dataset(batching.batch_window)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shape_t: shape}) + expected = sess.run( + self._structuredElement(None, np.concatenate(([5], shape), axis=0), + dtypes.int32)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + def _make_dense_to_sparse_fn(self, is_scalar): + + def dense_to_sparse_scalar(tensor): + indices = [[]] + values = array_ops.expand_dims(tensor, 0) + shape = [] + return sparse_tensor.SparseTensorValue(indices, values, shape) + + def dense_to_sparse_non_scalar(tensor): + indices = array_ops.where(array_ops.ones_like(tensor, dtype=dtypes.bool)) + values = array_ops.gather_nd(tensor, indices) + shape = array_ops.shape(tensor, out_type=dtypes.int64) + return sparse_tensor.SparseTensorValue(indices, values, shape) + + if is_scalar: + return dense_to_sparse_scalar + return dense_to_sparse_non_scalar + + def _structuredSparseDataset(self, structure, shape, dtype): + dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test + if structure is None: + return dataset_ops.Dataset.from_tensors( + dense_to_sparse(array_ops.zeros(shape, dtype=dtype))) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredSparseDataset(substructure, shape, dtype) + for substructure in structure + ])) + + def _structuredSparseElement(self, structure, shape, dtype): + dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test + if structure is None: + return dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) + else: + return tuple([ + self._structuredSparseElement(substructure, shape, dtype) + for substructure in structure + ]) + + @parameterized.parameters( + (None, np.int32([]), dtypes.bool), + (None, np.int32([]), dtypes.int32), + (None, np.int32([]), dtypes.float32), + (None, np.int32([]), dtypes.string), + (None, np.int32([2]), dtypes.int32), + (None, np.int32([2, 2]), dtypes.int32), + ((None, None, None), np.int32([]), dtypes.int32), + ((None, (None, None)), np.int32([]), dtypes.int32), + ) + def testWindowDatasetBatchSparse(self, structure, shape, dtype): + """Tests batching of sparse tensor windows. + + Args: + structure: the input structure + shape: the input shape + dtype: the input data type + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.batch_window(args[0]) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg) + for arg in args + ]) + + dataset = self._structuredSparseDataset( + structure, shape, dtype).repeat(5).apply( + grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run( + self._structuredSparseElement(structure, + np.concatenate(([5], shape), axis=0), + dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([]),), + (np.int32([1]),), + (np.int32([1, 2, 3]),), + ) + def testWindowDatasetBatchSparseDynamicShape(self, shape): + """Tests batching of dynamically shaped sparse tensor windows. + + Args: + shape: the input shape + """ + + shape_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensors(array_ops.zeros(shape_t)).map( + self._make_dense_to_sparse_fn(len(shape) == 0)).repeat(5).apply( # pylint: disable=g-explicit-length-test + grouping.window_dataset(5)).apply( + grouping._map_x_dataset(batching.batch_window)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shape_t: shape}) + expected = sess.run( + self._structuredSparseElement(None, + np.concatenate(([5], shape), axis=0), + dtypes.int32)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + def _structuredRaggedDataset(self, structure, shapes, dtype): + + if structure is None: + return dataset_ops.Dataset.from_tensor_slices(shapes).map( + lambda shape: array_ops.zeros(shape, dtype=dtype)) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredRaggedDataset(substructure, shapes, dtype) + for substructure in structure + ])) + + @parameterized.parameters( + (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), + (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), + ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), + ) + def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype, + padded_shape): + """Tests padded batching of dense tensor windows. + + Args: + structure: the input structure + shapes: the input shapes + dtype: the input data type + padded_shape: the shape to pad the output to + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.padded_batch_window(args[0], padded_shape) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window( + arg, padded_shape) for arg in args + ]) + + dataset = self._structuredRaggedDataset(structure, shapes, dtype).apply( + grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) + expected = sess.run( + self._structuredElement( + structure, + np.concatenate((np.int32([len(shapes)]), expected_shape)), dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([[1], [2], [3]]), [-1]), + (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + ) + def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape): + """Tests padded batching of dynamically shaped dense tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + shapes_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply( + grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shapes_t: shapes}) + expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) + expected = sess.run( + self._structuredElement( + None, np.concatenate((np.int32([len(shapes)]), expected_shape)), + dtypes.int32)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([[1]]), np.int32([0])), + (np.int32([[10], [20]]), np.int32([15])), + ) + def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape): + """Tests invalid padded batching of dense tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply( + grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + def _structuredRaggedSparseDataset(self, structure, shapes, dtype): + + def map_fn(shape): + dense_to_sparse = self._make_dense_to_sparse_fn(False) + return dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) + + if structure is None: + return dataset_ops.Dataset.from_tensor_slices(shapes).map(map_fn) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredRaggedSparseDataset(substructure, shapes, dtype) + for substructure in structure + ])) + + def _structuredRaggedSparseElement(self, structure, shapes, dtype, + padded_shape): + if structure is None: + dense_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) + values = [] + for shape in shapes: + dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test + sparse = dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) + padded_sparse = sparse_tensor.SparseTensor(sparse.indices, + sparse.values, dense_shape) + reshaped_sparse = sparse_ops.sparse_reshape( + padded_sparse, + array_ops.concat([np.array([1], dtype=np.int64), dense_shape], 0)) + values.append(reshaped_sparse) + return sparse_ops.sparse_concat(0, values) + else: + return tuple([ + self._structuredRaggedSparseElement(substructure, shapes, dtype, + padded_shape) + for substructure in structure + ]) + + @parameterized.parameters( + (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), + (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), + ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), + ) + def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype, + padded_shape): + """Tests padded batching of sparse tensor windows. + + Args: + structure: the input structure + shapes: the input shapes + dtype: the input data type + padded_shape: the shape to pad the output to + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.padded_batch_window(args[0], padded_shape) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window( + arg, padded_shape) for arg in args + ]) + + dataset = self._structuredRaggedSparseDataset( + structure, shapes, dtype).apply(grouping.window_dataset( + len(shapes))).apply(grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run( + self._structuredRaggedSparseElement(structure, shapes, dtype, + padded_shape)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int64([[1], [2], [3]]), [-1]), + (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + ) + def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes, + padded_shape): + """Tests padded batching of dynamically shaped sparse tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + shapes_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map( + self._make_dense_to_sparse_fn(False) + ).apply(grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shapes_t: shapes}) + expected = sess.run( + self._structuredRaggedSparseElement(None, shapes, dtypes.int32, + padded_shape)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int64([[1]]), [0]), + (np.int64([[10], [20]]), [15]), + ) + def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape): + """Tests invalid padded batching of sparse tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map( + self._make_dense_to_sparse_fn(False) + ).apply(grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index eceecfd1744d0ae28953a4504450653efa473569..160d7fe22a9f127f7ee23d7a988c22cc4430ce11 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -49,26 +49,6 @@ py_library( ], ) -py_test( - name = "iterator_ops_test", - size = "small", - srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":iterator_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", - ], -) - py_library( name = "random_ops", srcs = [ @@ -96,8 +76,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":batching", + ":gen_dataset_ops", ":interleave_ops", ":shuffle_ops", + ":stats_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", @@ -106,12 +88,12 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -133,6 +115,8 @@ py_library( srcs = ["batching.py"], srcs_version = "PY2AND3", deps = [ + ":get_single_element", + ":grouping", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", @@ -142,6 +126,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", ], @@ -208,6 +193,20 @@ py_library( ], ) +py_library( + name = "optimization", + srcs = ["optimization.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "resampling", srcs = ["resampling.py"], @@ -368,6 +367,7 @@ py_library( ":get_single_element", ":grouping", ":interleave_ops", + ":optimization", ":prefetching_ops", ":readers", ":resampling", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index b9393de4e90ae2597045b29070934b94e18cfcbd..a4914f4cde71925af477636c91d98b54ce0cce0e 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,18 +17,133 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.contrib.data.python.ops import get_single_element +from tensorflow.contrib.data.python.ops import grouping from tensorflow.contrib.framework import with_shape from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.util import deprecation + + +def batch_window(dataset): + """Batches a window of tensors. + + Args: + dataset: the input dataset. + + Returns: + A `Tensor` representing the batch of the entire input dataset. + """ + if isinstance(dataset.output_classes, tuple): + raise TypeError("Input dataset expected to have a single component") + if dataset.output_classes is ops.Tensor: + return _batch_dense_window(dataset) + elif dataset.output_classes is sparse_tensor.SparseTensor: + return _batch_sparse_window(dataset) + else: + raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) + + +def _batch_dense_window(dataset): + """Batches a window of dense tensors.""" + + def key_fn(_): + return np.int64(0) + + def shape_init_fn(_): + return array_ops.shape(first_element) + + def shape_reduce_fn(state, value): + check_ops.assert_equal(state, array_ops.shape(value)) + return state + + def finalize_fn(state): + return state + + if dataset.output_shapes.is_fully_defined(): + shape = dataset.output_shapes + else: + first_element = get_single_element.get_single_element(dataset.take(1)) + shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, + finalize_fn) + shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) + + def batch_init_fn(_): + batch_shape = array_ops.concat([[0], shape], 0) + return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) + + def batch_reduce_fn(state, value): + return array_ops.concat([state, [value]], 0) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer))) + + +def _batch_sparse_window(dataset): + """Batches a window of sparse tensors.""" + + def key_fn(_): + return np.int64(0) + + def shape_init_fn(_): + return first_element.dense_shape + + def shape_reduce_fn(state, value): + check_ops.assert_equal(state, value.dense_shape) + return state + + def finalize_fn(state): + return state + + if dataset.output_shapes.is_fully_defined(): + shape = dataset.output_shapes + else: + first_element = get_single_element.get_single_element(dataset.take(1)) + shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, + finalize_fn) + shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) + + def batch_init_fn(_): + indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0) + return sparse_tensor.SparseTensor( + indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), + values=constant_op.constant([], shape=[0], dtype=dataset.output_types), + dense_shape=array_ops.concat( + [np.array([0], dtype=np.int64), + math_ops.cast(shape, dtypes.int64)], 0)) + + def batch_reduce_fn(state, value): + return sparse_ops.sparse_concat(0, [state, value]) + + def reshape_fn(value): + return sparse_ops.sparse_reshape( + value, + array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0)) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.map(reshape_fn).apply( + grouping.group_by_reducer(key_fn, batch_reducer))) def dense_to_sparse_batch(batch_size, row_shape): @@ -75,17 +190,168 @@ def dense_to_sparse_batch(batch_size, row_shape): """ def _apply_fn(dataset): - return DenseToSparseBatchDataset(dataset, batch_size, row_shape) + return _DenseToSparseBatchDataset(dataset, batch_size, row_shape) return _apply_fn -class UnbatchDataset(dataset_ops.Dataset): +def padded_batch_window(dataset, padded_shape, padding_value=None): + """Batches a window of tensors with padding. + + Args: + dataset: the input dataset. + padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like + object representing the shape to which the input elements should be padded + prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a + `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the + maximum size of that dimension in each batch. + padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the + padding value to use. Defaults are `0` for numeric types and the empty + string for string types. If `dataset` contains `tf.SparseTensor`, this + value is ignored. + + Returns: + A `Tensor` representing the batch of the entire input dataset. + + Raises: + ValueError: if invalid arguments are provided. + """ + if not issubclass(dataset.output_classes, + (ops.Tensor, sparse_tensor.SparseTensor)): + raise TypeError("Input dataset expected to have a single tensor component") + if issubclass(dataset.output_classes, (ops.Tensor)): + return _padded_batch_dense_window(dataset, padded_shape, padding_value) + elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)): + if padding_value is not None: + raise ValueError("Padding value not allowed for sparse tensors") + return _padded_batch_sparse_window(dataset, padded_shape) + else: + raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) + + +def _padded_batch_dense_window(dataset, padded_shape, padding_value=None): + """Batches a window of dense tensors with padding.""" + + padded_shape = math_ops.cast( + convert.partial_shape_to_tensor(padded_shape), dtypes.int32) + + def key_fn(_): + return np.int64(0) + + def max_init_fn(_): + return padded_shape + + def max_reduce_fn(state, value): + """Computes the maximum shape to pad to.""" + condition = math_ops.reduce_all( + math_ops.logical_or( + math_ops.less_equal(array_ops.shape(value), padded_shape), + math_ops.equal(padded_shape, -1))) + assert_op = control_flow_ops.Assert(condition, [ + "Actual shape greater than padded shape: ", + array_ops.shape(value), padded_shape + ]) + with ops.control_dependencies([assert_op]): + return math_ops.maximum(state, array_ops.shape(value)) + + def finalize_fn(state): + return state + + # Compute the padded shape. + max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) + padded_shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) + + if padding_value is None: + if dataset.output_types == dtypes.string: + padding_value = "" + elif dataset.output_types == dtypes.bool: + padding_value = False + elif dataset.output_types == dtypes.variant: + raise TypeError("Unable to create padding for field of type 'variant'") + else: + padding_value = 0 + + def batch_init_fn(_): + return array_ops.fill( + array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0), + constant_op.constant(padding_value, dtype=dataset.output_types)) + + def batch_reduce_fn(state, value): + return array_ops.concat([state, [value]], 0) + + def pad_fn(value): + shape = array_ops.shape(value) + left = array_ops.zeros_like(shape) + right = padded_shape - shape + return array_ops.pad( + value, array_ops.stack([left, right], 1), constant_values=padding_value) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.map(pad_fn).apply( + grouping.group_by_reducer(key_fn, batch_reducer))) + + +def _padded_batch_sparse_window(dataset, padded_shape): + """Batches a window of sparse tensors with padding.""" + + def key_fn(_): + return np.int64(0) + + def max_init_fn(_): + return convert.partial_shape_to_tensor(padded_shape) + + def max_reduce_fn(state, value): + """Computes the maximum shape to pad to.""" + condition = math_ops.reduce_all( + math_ops.logical_or( + math_ops.less_equal(value.dense_shape, padded_shape), + math_ops.equal(padded_shape, -1))) + assert_op = control_flow_ops.Assert(condition, [ + "Actual shape greater than padded shape: ", value.dense_shape, + padded_shape + ]) + with ops.control_dependencies([assert_op]): + return math_ops.maximum(state, value.dense_shape) + + def finalize_fn(state): + return state + + # Compute the padded shape. + max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) + padded_shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) + + def batch_init_fn(_): + indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]], + 0) + return sparse_tensor.SparseTensor( + indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), + values=constant_op.constant([], shape=[0], dtype=dataset.output_types), + dense_shape=array_ops.concat( + [np.array([0], dtype=np.int64), padded_shape], 0)) + + def batch_reduce_fn(state, value): + padded_value = sparse_tensor.SparseTensor( + indices=value.indices, values=value.values, dense_shape=padded_shape) + reshaped_value = sparse_ops.sparse_reshape( + padded_value, + array_ops.concat( + [np.array([1], dtype=np.int64), padded_value.dense_shape], 0)) + return sparse_ops.sparse_concat(0, [state, reshaped_value]) + + reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, reducer))) + + +class _UnbatchDataset(dataset_ops.Dataset): """A dataset that splits the elements of its input into multiple elements.""" def __init__(self, input_dataset): """See `unbatch()` for more details.""" - super(UnbatchDataset, self).__init__() + super(_UnbatchDataset, self).__init__() flat_shapes = nest.flatten(input_dataset.output_shapes) if any(s.ndims == 0 for s in flat_shapes): raise ValueError("Cannot unbatch an input with scalar components.") @@ -101,10 +367,7 @@ class UnbatchDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.unbatch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): @@ -145,7 +408,7 @@ def unbatch(): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" if not sparse.any_sparse(dataset.output_classes): - return UnbatchDataset(dataset) + return _UnbatchDataset(dataset) # NOTE(mrry): We must ensure that any SparseTensors in `dataset` # are normalized to the rank-1 dense representation, so that the @@ -171,12 +434,12 @@ def unbatch(): dataset.output_shapes, dataset.output_classes, allow_unsafe_cast=True) - return UnbatchDataset(restructured_dataset) + return _UnbatchDataset(restructured_dataset) return _apply_fn -def filter_irregular_batches(batch_size): +def _filter_irregular_batches(batch_size): """Transformation that filters out batches that are not of size batch_size.""" def _apply_fn(dataset): @@ -218,6 +481,8 @@ def filter_irregular_batches(batch_size): return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.") def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). @@ -250,12 +515,16 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time + # after 6/30/2018. batched = dataset.batch(batch_size) - return filter_irregular_batches(batch_size)(batched) + return _filter_irregular_batches(batch_size)(batched) return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.padded_batch(..., drop_remainder=True)`.") def padded_batch_and_drop_remainder(batch_size, padded_shapes, padding_values=None): @@ -284,19 +553,21 @@ def padded_batch_and_drop_remainder(batch_size, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)` + # any time after 6/30/2018. batched = dataset.padded_batch( batch_size, padded_shapes=padded_shapes, padding_values=padding_values) - return filter_irregular_batches(batch_size)(batched) + return _filter_irregular_batches(batch_size)(batched) return _apply_fn -class DenseToSparseBatchDataset(dataset_ops.Dataset): +class _DenseToSparseBatchDataset(dataset_ops.Dataset): """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" def __init__(self, input_dataset, batch_size, row_shape): """See `Dataset.dense_to_sparse_batch()` for more details.""" - super(DenseToSparseBatchDataset, self).__init__() + super(_DenseToSparseBatchDataset, self).__init__() if not isinstance(input_dataset.output_types, dtypes.DType): raise TypeError("DenseToSparseDataset requires an input whose elements " "have a single component, whereas the input has %r." % @@ -309,11 +580,8 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): return gen_dataset_ops.dense_to_sparse_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._batch_size, - row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + row_shape=convert.partial_shape_to_tensor(self._row_shape), + **dataset_ops.flat_structure(self)) @property def output_classes(self): @@ -490,10 +758,7 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): batch_size=self._batch_size_t, num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 6c21e489f7c35484ebacd465e3b46d6920df5933..d46d96c461ad4cc0ac25a8ddc285cec23d09c682 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse def ignore_errors(): @@ -48,26 +46,23 @@ def ignore_errors(): """ def _apply_fn(dataset): - return IgnoreErrorsDataset(dataset) + return _IgnoreErrorsDataset(dataset) return _apply_fn -class IgnoreErrorsDataset(dataset_ops.Dataset): +class _IgnoreErrorsDataset(dataset_ops.Dataset): """A `Dataset` that silently ignores errors when computing its input.""" def __init__(self, input_dataset): """See `Dataset.ignore_errors()` for details.""" - super(IgnoreErrorsDataset, self).__init__() + super(_IgnoreErrorsDataset, self).__init__() self._input_dataset = input_dataset def _as_variant_tensor(self): return gen_dataset_ops.ignore_errors_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index 3a07df572748e464284f580d67e3a664e71acdfe..0f4cd8e20c5727a5bcfa1dce4dadbfa8f90bd551 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -64,10 +64,7 @@ def get_single_element(dataset): nested_ret = nest.pack_sequence_as( dataset.output_types, gen_dataset_ops.dataset_to_single_element( dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(sparse.as_dense_types( - dataset.output_types, dataset.output_classes)), - output_shapes=nest.flatten(sparse.as_dense_shapes( - dataset.output_shapes, dataset.output_classes)))) + **dataset_ops.flat_structure(dataset))) return sparse.deserialize_sparse_tensors( nested_ret, dataset.output_types, dataset.output_shapes, dataset.output_classes) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index ea229b5b27b117984e508fa4edc6f1cf713008b4..bd8d398c58cc1825616c1ab5337cf6668c66697e 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -21,12 +21,9 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -58,7 +55,7 @@ def group_by_reducer(key_func, reducer): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return GroupByReducerDataset(dataset, key_func, reducer) + return _GroupByReducerDataset(dataset, key_func, reducer) return _apply_fn @@ -116,8 +113,8 @@ def group_by_window(key_func, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return GroupByWindowDataset(dataset, key_func, reduce_func, - window_size_func) + return _GroupByWindowDataset(dataset, key_func, reduce_func, + window_size_func) return _apply_fn @@ -152,9 +149,9 @@ def bucket_by_sequence_length(element_length_func, @{tf.data.Dataset.padded_batch}. Defaults to padding with 0. pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown size to maximum length in batch. If `True`, will pad dimensions with - unknown size to bucket boundary, and caller must ensure that the source - `Dataset` does not contain any elements with length longer than - `max(bucket_boundaries)`. + unknown size to bucket boundary minus 1 (i.e., the maximum length in each + bucket), and caller must ensure that the source `Dataset` does not contain + any elements with length longer than `max(bucket_boundaries)`. Returns: A `Dataset` transformation function, which can be passed to @@ -206,7 +203,7 @@ def bucket_by_sequence_length(element_length_func, none_filler = None if pad_to_bucket_boundary: err_msg = ("When pad_to_bucket_boundary=True, elements must have " - "length <= max(bucket_boundaries).") + "length < max(bucket_boundaries).") check = check_ops.assert_less( bucket_id, constant_op.constant(len(bucket_batch_sizes) - 1, @@ -216,7 +213,7 @@ def bucket_by_sequence_length(element_length_func, boundaries = constant_op.constant(bucket_boundaries, dtype=dtypes.int64) bucket_boundary = boundaries[bucket_id] - none_filler = bucket_boundary + none_filler = bucket_boundary - 1 shapes = make_padded_shapes( padded_shapes or grouped_dataset.output_shapes, none_filler=none_filler) @@ -230,39 +227,56 @@ def bucket_by_sequence_length(element_length_func, return _apply_fn -class _VariantDataset(dataset_ops.Dataset): - """A Dataset wrapper for a tf.variant-typed function argument.""" +def _map_x_dataset(map_func): + """A transformation that maps `map_func` across its input. - def __init__(self, dataset_variant, output_types, output_shapes, - output_classes): - super(_VariantDataset, self).__init__() - self._dataset_variant = dataset_variant - self._output_types = output_types - self._output_shapes = output_shapes - self._output_classes = output_classes + This transformation is similar to `tf.data.Dataset.map`, but in addition to + supporting dense and sparse tensor inputs, it also supports dataset inputs. - def _as_variant_tensor(self): - return self._dataset_variant + Args: + map_func: A function mapping a nested structure of tensors and/or datasets + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to another nested structure of tensors and/or + datasets. - @property - def output_classes(self): - return self._output_classes + Returns: + Dataset: A `Dataset`. + """ - @property - def output_shapes(self): - return self._output_shapes + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _MapXDataset(dataset, map_func) - @property - def output_types(self): - return self._output_types + return _apply_fn + + +def window_dataset(window_size): + """A transformation that creates window datasets from the input dataset. + + The resulting datasets will contain `window_size` elements (or + `N % window_size` for the last dataset if `window_size` does not divide the + number of input elements `N` evenly). + + Args: + window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of the input dataset to combine into a window. + + Returns: + Dataset: A `Dataset`. + """ + def _apply_fn(dataset): + return _WindowDataset(dataset, window_size) -class GroupByReducerDataset(dataset_ops.Dataset): + return _apply_fn + + +class _GroupByReducerDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a reduction.""" def __init__(self, input_dataset, key_func, reducer): """See `group_by_reducer()` for details.""" - super(GroupByReducerDataset, self).__init__() + super(_GroupByReducerDataset, self).__init__() self._input_dataset = input_dataset @@ -273,67 +287,27 @@ class GroupByReducerDataset(dataset_ops.Dataset): def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_key_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - # pylint: disable=protected-access - if dataset_ops._should_unpack_args(nested_args): - ret = key_func(*nested_args) - # pylint: enable=protected-access - else: - ret = key_func(nested_args) - ret = ops.convert_to_tensor(ret) - if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar(): - raise ValueError( - "`key_func` must return a single tf.int64 tensor. " - "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape())) - return ret - - self._key_func = tf_key_func - self._key_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + key_func, "tf.contrib.data.group_by_reducer()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`key_func` must return a single tf.int64 tensor. " + "Got type=%s and shape=%s" + % (wrapped_func.output_types, wrapped_func.output_shapes)) + self._key_func = wrapped_func.function def _make_init_func(self, init_func): """Make wrapping Defun for init_func.""" - - @function.Defun(dtypes.int64) - def tf_init_func(key): - """A wrapper for Defun that facilitates shape inference.""" - key.set_shape([]) - ret = init_func(key) - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._state_classes = sparse.get_classes(ret) - self._state_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._state_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._init_func = tf_init_func - self._init_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + init_func, "tf.contrib.data.group_by_reducer()", + input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), + input_types=dtypes.int64) + self._init_func = wrapped_func.function + self._state_classes = wrapped_func.output_classes + self._state_shapes = wrapped_func.output_shapes + self._state_types = wrapped_func.output_types def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping Defun for reduce_func.""" @@ -343,83 +317,47 @@ class GroupByReducerDataset(dataset_ops.Dataset): need_to_rerun = True while need_to_rerun: - # Create a list in which `tf_reduce_func` will store the new shapes. - flat_new_state_shapes = [] - - @function.Defun(*(nest.flatten( - sparse.as_dense_types( - self._state_types, self._state_classes)) + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) - def tf_reduce_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes)) - + nest.flatten( - sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes))): - arg.set_shape(shape) - - pivot = len(nest.flatten(self._state_shapes)) - nested_state_args = nest.pack_sequence_as(self._state_types, - args[:pivot]) - nested_state_args = sparse.deserialize_sparse_tensors( - nested_state_args, self._state_types, self._state_shapes, - self._state_classes) - nested_input_args = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) - nested_input_args = sparse.deserialize_sparse_tensors( - nested_input_args, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - - ret = reduce_func(nested_state_args, nested_input_args) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - # Extract shape information from the returned values. - flat_new_state = nest.flatten(ret) - flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state]) - - # Extract and validate type information from the returned values. - for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)): - if t.dtype != dtype: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_types, - nest.pack_sequence_as(self._state_types, - [t.dtype for t in flat_new_state]))) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, - [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - # Use the private method that will execute `tf_reduce_func` but delay - # adding it to the graph in case we need to rerun the function. - tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access - + wrapped_func = dataset_ops.StructuredFunctionWrapper( + reduce_func, "tf.contrib.data.group_by_reducer()", + input_classes=(self._state_classes, input_dataset.output_classes), + input_shapes=(self._state_shapes, input_dataset.output_shapes), + input_types=(self._state_types, input_dataset.output_types), + add_to_graph=False) + + # Extract and validate class information from the returned values. + for new_state_class, state_class in zip( + nest.flatten(wrapped_func.output_classes), + nest.flatten(self._state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, wrapped_func.output_classes)) + + # Extract and validate type information from the returned values. + for new_state_type, state_type in zip( + nest.flatten(wrapped_func.output_types), + nest.flatten(self._state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, wrapped_func.output_types)) + + # Extract shape information from the returned values. flat_state_shapes = nest.flatten(self._state_shapes) + flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes) weakened_state_shapes = [ - old.most_specific_compatible_shape(new) - for old, new in zip(flat_state_shapes, flat_new_state_shapes) + original.most_specific_compatible_shape(new) + for original, new in zip(flat_state_shapes, flat_new_state_shapes) ] need_to_rerun = False - for old_shape, weakened_shape in zip(flat_state_shapes, - weakened_state_shapes): - if old_shape.ndims is not None and ( + for original_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if original_shape.ndims is not None and ( weakened_shape.ndims is None or - old_shape.as_list() != weakened_shape.as_list()): + original_shape.as_list() != weakened_shape.as_list()): need_to_rerun = True break @@ -427,50 +365,19 @@ class GroupByReducerDataset(dataset_ops.Dataset): self._state_shapes = nest.pack_sequence_as(self._state_shapes, weakened_state_shapes) - self._reduce_func = tf_reduce_func + self._reduce_func = wrapped_func.function self._reduce_func.add_to_graph(ops.get_default_graph()) def _make_finalize_func(self, finalize_func): """Make wrapping Defun for finalize_func.""" - - @function.Defun(*(nest.flatten( - sparse.as_dense_types(self._state_types, self._state_classes)))) - def tf_finalize_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes))): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(self._state_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, self._state_types, self._state_shapes, - self._state_classes) - - ret = finalize_func(nested_args) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._output_classes = sparse.get_classes(ret) - self._output_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._output_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._finalize_func = tf_finalize_func - self._finalize_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + finalize_func, "tf.contrib.data.group_by_reducer()", + input_classes=self._state_classes, input_shapes=self._state_shapes, + input_types=self._state_types) + self._finalize_func = wrapped_func.function + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types @property def output_classes(self): @@ -495,18 +402,15 @@ class GroupByReducerDataset(dataset_ops.Dataset): init_func=self._init_func, reduce_func=self._reduce_func, finalize_func=self._finalize_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) -class GroupByWindowDataset(dataset_ops.Dataset): +class _GroupByWindowDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a windowed reduction.""" def __init__(self, input_dataset, key_func, reduce_func, window_size_func): """See `group_by_window()` for details.""" - super(GroupByWindowDataset, self).__init__() + super(_GroupByWindowDataset, self).__init__() self._input_dataset = input_dataset @@ -516,74 +420,48 @@ class GroupByWindowDataset(dataset_ops.Dataset): def _make_window_size_func(self, window_size_func): """Make wrapping Defun for window_size_func.""" - - @function.Defun(dtypes.int64) - def tf_window_size_func(key): - key.set_shape([]) - window_size = ops.convert_to_tensor( - window_size_func(key), dtype=dtypes.int64) - if window_size.dtype != dtypes.int64: - raise ValueError( - "`window_size_func` must return a single tf.int64 tensor.") - return window_size - - self._window_size_func = tf_window_size_func - self._window_size_func.add_to_graph(ops.get_default_graph()) + def window_size_func_wrapper(key): + return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + window_size_func_wrapper, "tf.contrib.data.group_by_window()", + input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), + input_types=dtypes.int64) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`window_size_func` must return a single tf.int64 scalar tensor.") + self._window_size_func = wrapped_func.function def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_key_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - # pylint: disable=protected-access - if dataset_ops._should_unpack_args(nested_args): - ret = key_func(*nested_args) - # pylint: enable=protected-access - else: - ret = key_func(nested_args) - ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) - if ret.dtype != dtypes.int64: - raise ValueError("`key_func` must return a single tf.int64 tensor.") - return ret - - self._key_func = tf_key_func - self._key_func.add_to_graph(ops.get_default_graph()) + def key_func_wrapper(*args): + return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`key_func` must return a single tf.int64 scalar tensor.") + self._key_func = wrapped_func.function def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping Defun for reduce_func.""" - - @function.Defun(dtypes.int64, dtypes.variant) - def tf_reduce_func(key, window_dataset_variant): - """A wrapper for Defun that facilitates shape inference.""" - key.set_shape([]) - window_dataset = _VariantDataset( - window_dataset_variant, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - if not isinstance(window_dataset, dataset_ops.Dataset): - raise TypeError("`window_dataset` must return a `Dataset` object.") - output_dataset = reduce_func(key, window_dataset) - if not isinstance(output_dataset, dataset_ops.Dataset): - raise TypeError("`reduce_func` must return a `Dataset` object.") - self._output_classes = output_dataset.output_classes - self._output_types = output_dataset.output_types - self._output_shapes = output_dataset.output_shapes - return output_dataset._as_variant_tensor() # pylint: disable=protected-access - - self._reduce_func = tf_reduce_func - self._reduce_func.add_to_graph(ops.get_default_graph()) + nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access + wrapped_func = dataset_ops.StructuredFunctionWrapper( + reduce_func, "tf.contrib.data.reduce_by_window()", + input_classes=(ops.Tensor, nested_dataset), + input_shapes=(tensor_shape.scalar(), nested_dataset), + input_types=(dtypes.int64, nested_dataset), + experimental_nested_dataset_support=True) + if not isinstance( + wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access + raise TypeError("`reduce_func` must return a `Dataset` object.") + self._output_classes = wrapped_func.output_classes.output_classes + self._output_types = wrapped_func.output_types.output_types + self._output_shapes = wrapped_func.output_shapes.output_shapes + self._reduce_func = wrapped_func.function @property def output_classes(self): @@ -606,10 +484,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): key_func=self._key_func, reduce_func=self._reduce_func, window_size_func=self._window_size_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) class Reducer(object): @@ -637,3 +512,85 @@ class Reducer(object): @property def finalize_func(self): return self._finalize_func + + +class _MapXDataset(dataset_ops.Dataset): + """A `Dataset` that maps a function over elements in its input.""" + + def __init__(self, input_dataset, map_func): + """See `map_x_dataset()` for details.""" + super(_MapXDataset, self).__init__() + self._input_dataset = input_dataset + + wrapped_func = dataset_ops.StructuredFunctionWrapper( + map_func, + "tf.contrib.data.map_x_dataset()", + input_dataset, + experimental_nested_dataset_support=True) + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types + self._map_func = wrapped_func.function + + def _as_variant_tensor(self): + input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + return gen_dataset_ops.map_dataset( + input_t, + self._map_func.captured_inputs, + f=self._map_func, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +class _WindowDataset(dataset_ops.Dataset): + """A dataset that creates window datasets from the input elements.""" + + def __init__(self, input_dataset, window_size): + """See `window_dataset()` for more details.""" + super(_WindowDataset, self).__init__() + self._input_dataset = input_dataset + self._window_size = ops.convert_to_tensor( + window_size, dtype=dtypes.int64, name="window_size") + self._output_classes = nest.pack_sequence_as( + input_dataset.output_classes, + [ + dataset_ops._NestedDatasetComponent( # pylint: disable=protected-access + output_classes=output_class, + output_shapes=output_shape, + output_types=output_type) + for output_class, output_shape, output_type in zip( + nest.flatten(input_dataset.output_classes), + nest.flatten(input_dataset.output_shapes), + nest.flatten(input_dataset.output_types)) + ]) + self._output_shapes = self._output_classes + self._output_types = self._output_classes + + def _as_variant_tensor(self): + return gen_dataset_ops.window_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._window_size, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 812a50ecbf105393f7e422edbbdf5c87311d72c1..bcc959594a6b311a3c60bb4696ac97be5c448756 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -24,9 +24,9 @@ from tensorflow.contrib.data.python.ops import random_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation @@ -153,7 +153,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): return _apply_fn -class DirectedInterleaveDataset(dataset_ops.Dataset): +class _DirectedInterleaveDataset(dataset_ops.Dataset): """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" def __init__(self, selector_input, data_inputs): @@ -170,10 +170,7 @@ class DirectedInterleaveDataset(dataset_ops.Dataset): return gen_dataset_ops.directed_interleave_dataset( self._selector_input._as_variant_tensor(), [data_input._as_variant_tensor() for data_input in self._data_inputs], - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property @@ -239,4 +236,48 @@ def sample_from_datasets(datasets, weights=None, seed=None): selector_input = dataset_ops.Dataset.zip( (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) - return DirectedInterleaveDataset(selector_input, datasets) + return _DirectedInterleaveDataset(selector_input, datasets) + + +def choose_from_datasets(datasets, choice_dataset): + """Creates a dataset that deterministically chooses elements from `datasets`. + + For example, given the following datasets: + + ```python + datasets = [tf.data.Dataset.from_tensors("foo").repeat(), + tf.data.Dataset.from_tensors("bar").repeat(), + tf.data.Dataset.from_tensors("baz").repeat()] + + # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. + choice_dataset = tf.data.Dataset.range(3).repeat(3) + + result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset) + ``` + + The elements of `result` will be: + + ``` + "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz" + ``` + + Args: + datasets: A list of @{tf.data.Dataset} objects with compatible structure. + choice_dataset: A @{tf.data.Dataset} of scalar `tf.int64` tensors between + `0` and `len(datasets) - 1`. + + Returns: + A dataset that interleaves elements from `datasets` according to the values + of `choice_dataset`. + + Raises: + TypeError: If the `datasets` or `choice_dataset` arguments have the wrong + type. + """ + if not (choice_dataset.output_types == dtypes.int64 + and choice_dataset.output_shapes.is_compatible_with( + tensor_shape.scalar()) + and choice_dataset.output_classes == ops.Tensor): + raise TypeError("`choice_dataset` must be a dataset of scalar " + "`tf.int64` tensors.") + return _DirectedInterleaveDataset(choice_dataset, datasets) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index f1d0e5cddc2d757e98d5f6d0d73372ebc11eefd5..0d71be66018eeebe60de9deff24ceb6854d209d9 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -170,6 +170,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # `checkpoint_dir` is the same as the model checkpoint directory, there are # no conflicts during restore. self._latest_filename = "checkpoint_" + checkpoint_prefix + self._first_run = True def begin(self): # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS` @@ -184,7 +185,25 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # pylint: enable=protected-access self._checkpoint_saver_hook.begin() - def after_create_session(self, session, coord): + def _restore_or_save_initial_ckpt(self, session): + # Ideally this should be run in after_create_session but is not for the + # following reason: + # Currently there is no way of enforcing an order of running the + # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook` + # is run *after* this hook. That is troublesome because + # 1. If a checkpoint exists and this hook restores it, the initializer hook + # will override it. + # 2. If no checkpoint exists, this hook will try to save an initialized + # iterator which will result in an exception. + # + # As a temporary fix we enter the following implicit contract between this + # hook and the _DatasetInitializerHook. + # 1. The _DatasetInitializerHook initializes the iterator in the call to + # after_create_session. + # 2. This hook saves the iterator on the first call to `before_run()`, which + # is guaranteed to happen after `after_create_session()` of all hooks + # have been run. + # Check if there is an existing checkpoint. If so, restore from it. # pylint: disable=protected-access latest_checkpoint_path = saver_lib.latest_checkpoint( @@ -202,6 +221,9 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # pylint: enable=protected-access def before_run(self, run_context): + if self._first_run: + self._restore_or_save_initial_ckpt(run_context.session) + self._first_run = False return self._checkpoint_saver_hook.before_run(run_context) def after_run(self, run_context, run_values): diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..cf896572262929add5ac34d4fc8e4192c1049da3 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -0,0 +1,75 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +def optimize(optimizations=None): + """A transformation that applies optimizations. + + Args: + optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying + optimizations to use. If not specified, the default set of optimizations + is applied. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _OptimizeDataset(dataset, optimizations) + + return _apply_fn + + +class _OptimizeDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and applies optimizations.""" + + def __init__(self, input_dataset, optimizations): + """See `optimize()` for details.""" + super(_OptimizeDataset, self).__init__() + self._input_dataset = input_dataset + if optimizations is None: + optimizations = [] + self._optimizations = ops.convert_to_tensor( + optimizations, dtype=dtypes.string, name="optimizations") + + def _as_variant_tensor(self): + return gen_dataset_ops.optimize_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._optimizations, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index e4c9f8b58a2a4390004b0ad318163526b443d44f..21fc17102e16a1f98f2c2e8aa0aeec89989edf67 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -32,15 +32,32 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops -# TODO(rohanj): Add a python class that constructs resource in the __init__ -# method and provides a get_next() that calls the prefetch op. def function_buffering_resource(string_arg, target_device, f, buffer_size, + output_types, container="", shared_name=None, name=None): + """Creates a FunctionBufferingResource. + + A FunctionBufferingResource fills up a buffer by calling a function `f` on + `target_device`. `f` should take in only a single string argument as input. + + Args: + string_arg: The single string argument to the function. + target_device: The device to run `f` on. + f: The function to be executed. + buffer_size: Size of the buffer to be populated. + output_types: The output types generated by the function. + container: (Optional) string. Defaults to "". + shared_name: (Optional) string. + name: (Optional) string to name the op. + + Returns: + Handle to a FunctionBufferingResource. + """ if shared_name is None: shared_name = "" return gen_dataset_ops.function_buffering_resource( @@ -50,7 +67,8 @@ def function_buffering_resource(string_arg, f=f, buffer_size=buffer_size, container=container, - name=name) + name=name, + output_types=output_types) def function_buffering_resource_get_next(function_buffer_resource, @@ -123,7 +141,10 @@ class _PrefetchToDeviceIterator(object): target_device=iterator_device, string_arg=input_iterator_handle, buffer_size=buffer_size, - shared_name=shared_name) + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes))) if not self._one_shot: reset_op = function_buffering_resource_reset(self._buffering_resource) @@ -212,6 +233,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): with ops.device(device): self._buffering_resource = function_buffering_resource( f=_prefetch_fn, + output_types=self._flat_output_types, target_device=gen_dataset_ops.iterator_get_device(self._resource), string_arg=input_iterator_handle, buffer_size=buffer_size, diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py index 28ef5e50f39dd7d1b6f124e58e068fc968ddd6dc..e670c4c8354f4067eb21c9b1fce708147c162967 100644 --- a/tensorflow/contrib/data/python/ops/random_ops.py +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -18,9 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -39,10 +37,7 @@ class RandomDataset(dataset_ops.Dataset): return gen_dataset_ops.random_dataset( seed=self._seed, seed2=self._seed2, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 2c57d11cbbdcf40a0afd07f8114f83561e711b1c..83095c7ba1c6465d18490e5197f71bf7f1fe2497 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,8 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import csv -from math import ceil import numpy as np @@ -26,6 +26,7 @@ from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import convert @@ -36,9 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import string_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation @@ -70,7 +69,7 @@ def _is_valid_float(str_val, float_dtype): return False -def _infer_type(str_val, na_value, prev_type, float_dtype): +def _infer_type(str_val, na_value, prev_type): """Given a string, infers its tensor type. Infers the type of a value by picking the least 'permissive' type possible, @@ -81,29 +80,33 @@ def _infer_type(str_val, na_value, prev_type, float_dtype): na_value: Additional string to recognize as a NA/NaN CSV value. prev_type: Type previously inferred based on values of this column that we've seen up till now. - float_dtype: Either `tf.float32` or `tf.float64`. Denotes what float type - to parse float strings as. Returns: Inferred dtype. """ if str_val in ("", na_value): + # If the field is null, it gives no extra information about its type return prev_type - if _is_valid_int32(str_val) and prev_type in (None, dtypes.int32): - return dtypes.int32 + type_list = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string + ] # list of types to try, ordered from least permissive to most - if _is_valid_int64(str_val) and prev_type in (None, dtypes.int32, - dtypes.int64): - return dtypes.int64 + type_functions = [ + _is_valid_int32, + _is_valid_int64, + lambda str_val: _is_valid_float(str_val, dtypes.float32), + lambda str_val: _is_valid_float(str_val, dtypes.float64), + lambda str_val: True, + ] # Corresponding list of validation functions - if _is_valid_float(str_val, float_dtype) and prev_type != dtypes.string: - return float_dtype + for i in range(len(type_list)): + validation_fn = type_functions[i] + if validation_fn(str_val) and (prev_type is None or + prev_type in type_list[:i + 1]): + return type_list[i] - return dtypes.string - -def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, - comment): +def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header): """Generator that yields rows of CSV file(s) in order.""" for fn in filenames: with file_io.FileIO(fn, "r") as f: @@ -115,9 +118,6 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, next(rdr) # Skip header lines for csv_row in rdr: - if comment is not None and csv_row[0].startswith(comment): - continue # Skip comment lines - if len(csv_row) != num_cols: raise ValueError( "Problem inferring types: CSV row has different number of fields " @@ -126,22 +126,21 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, - na_value, header, comment, float_dtype, - num_rows_for_inference, select_columns): + na_value, header, num_rows_for_inference, + select_columns): """Infers column types from the first N valid CSV records of files.""" if select_columns is None: select_columns = range(num_cols) inferred_types = [None] * len(select_columns) for i, csv_row in enumerate( - _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, - comment)): + _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)): if num_rows_for_inference is not None and i >= num_rows_for_inference: break for j, col_index in enumerate(select_columns): inferred_types[j] = _infer_type(csv_row[col_index], na_value, - inferred_types[j], float_dtype) + inferred_types[j]) # Replace None's with a default type inferred_types = [t or dtypes.string for t in inferred_types] @@ -318,7 +317,6 @@ def make_csv_dataset( use_quote_delim=True, na_value="", header=True, - comment=None, num_epochs=None, shuffle=True, shuffle_buffer_size=10000, @@ -327,7 +325,6 @@ def make_csv_dataset( num_parallel_reads=1, num_parallel_parser_calls=2, sloppy=False, - default_float_type=dtypes.float32, num_rows_for_inference=100, ): """Reads CSV files into a dataset. @@ -381,9 +378,6 @@ def make_csv_dataset( header: A bool that indicates whether the first rows of provided CSV files correspond to header lines with column names, and should not be included in the data. - comment: An optional character string that marks lines that should not be - parsed as csv records. If this is provided, all lines that start with - this character will not be parsed. num_epochs: An int specifying the number of times this dataset is repeated. If None, cycles through the dataset forever. shuffle: A bool that indicates whether the input should be shuffled. @@ -402,8 +396,6 @@ def make_csv_dataset( produced is deterministic prior to shuffling (elements are still randomized if `shuffle=True`. Note that if the seed is set, then order of elements after shuffling is deterministic). Defaults to `False`. - default_float_type: Either `tf.float32` or `tf.float64`. If defaults are - not provided, float-like strings are interpreted to be this type. num_rows_for_inference: Number of rows of a file to use for type inference if record_defaults is not provided. If None, reads all the rows of all the files. Defaults to 100. @@ -425,8 +417,6 @@ def make_csv_dataset( dataset = dataset.shuffle(len(filenames), shuffle_seed) # Clean arguments; figure out column names and defaults - if comment is not None and len(comment) != 1: - raise ValueError("`comment` arg must be a single-character string or None") if column_names is None: if not header: @@ -449,8 +439,7 @@ def make_csv_dataset( # construction time column_defaults = _infer_column_defaults( filenames, len(column_names), field_delim, use_quote_delim, na_value, - header, comment, default_float_type, num_rows_for_inference, - select_columns) + header, num_rows_for_inference, select_columns) if select_columns is not None and len(column_defaults) != len(select_columns): raise ValueError( @@ -464,43 +453,33 @@ def make_csv_dataset( if label_name is not None and label_name not in column_names: raise ValueError("`label_name` provided must be one of the columns.") - # Define map and filter functions - def filter_fn(line): - return math_ops.not_equal(string_ops.substr(line, 0, 1), comment) - def filename_to_dataset(filename): - ds = core_readers.TextLineDataset(filename) - if header: - ds = ds.skip(1) - if comment is not None: - ds = ds.filter(filter_fn) - return ds + return CsvDataset( + filename, + record_defaults=column_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + na_value=na_value, + select_cols=select_columns, + header=header) - def decode_csv(line): - """Decodes CSV line into features. + def map_fn(*columns): + """Organizes columns into a features dictionary. Args: - line: String tensor corresponding to one csv record. + *columns: list of `Tensor`s corresponding to one csv record. Returns: - A dictionary of feature names to values for that particular record. If + An OrderedDict of feature names to values for that particular record. If label_name is provided, extracts the label feature to be returned as the second element of the tuple. """ - columns = parsing_ops.decode_csv( - line, - column_defaults, - field_delim=field_delim, - use_quote_delim=use_quote_delim, - na_value=na_value, - select_cols=select_columns, - ) - features = dict(zip(column_names, columns)) + features = collections.OrderedDict(zip(column_names, columns)) if label_name is not None: label = features.pop(label_name) return features, label return features - # Read files sequentially or in parallel + # Read files sequentially (if num_parallel_reads=1) or in parallel dataset = dataset.apply( interleave_ops.parallel_interleave( filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) @@ -508,17 +487,12 @@ def make_csv_dataset( dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) - # Use map_and_batch for perf - # TODO(b/76425672): use num_parallel_calls for better performance tuning when - # that is added - dataset = dataset.apply( - batching.map_and_batch( - map_func=decode_csv, - batch_size=batch_size, - num_parallel_batches=int( - ceil(num_parallel_parser_calls / batch_size)))) - + # Apply batch before map for perf, because map has high overhead relative + # to the size of the computation in each map + dataset = dataset.batch(batch_size=batch_size) + dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) dataset = dataset.prefetch(prefetch_buffer_size) + return dataset @@ -781,6 +755,8 @@ def make_batched_features_dataset(file_pattern, dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + dataset = dataset.apply(stats_ops.feature_stats("record_stats")) + if drop_final_batch: dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) else: diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index bad6edd5147d832228c412919f1e6e782aafc40f..182a5c6ff36fcda8c9e2c522cce07bed0c2daec9 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -291,4 +291,4 @@ def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): # TODO(joelshor): Simplify fraction, if possible. a_i = (ratio_l - m) / (max_ratio - m) - return a_i, m \ No newline at end of file + return a_i, m diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index e911ad0fa0541f2d8b991d66182dd002c2ecaab0..ea9dcfe68fa2630d915323fa295031af7d48cdfb 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -22,7 +22,6 @@ import collections from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse -from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import gen_dataset_ops @@ -67,102 +66,45 @@ class _ScanDataset(dataset_ops.Dataset): need_to_rerun = True while need_to_rerun: - # Create a list in which `tf_scan_func` will store the new shapes. - flat_new_state_shapes = [] - - @function.Defun(*(nest.flatten( - sparse.as_dense_types( - self._state_types, self._state_classes)) + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) - def tf_scan_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the state and input_dataset. - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes)) - + nest.flatten( - sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes))): - arg.set_shape(shape) - - pivot = len(nest.flatten(self._state_shapes)) - print(self._state_classes) - nested_state_args = nest.pack_sequence_as(self._state_types, - args[:pivot]) - nested_state_args = sparse.deserialize_sparse_tensors( - nested_state_args, self._state_types, self._state_shapes, - self._state_classes) - print(input_dataset.output_classes) - nested_input_args = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) - nested_input_args = sparse.deserialize_sparse_tensors( - nested_input_args, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - - ret = scan_func(nested_state_args, nested_input_args) - if not isinstance(ret, collections.Sequence) or len(ret) != 2: - raise TypeError("The scan function must return a pair comprising the " - "new state and the output value.") - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - new_state, output_value = ret - - # Extract and validate class information from the returned values. - for t, clazz in zip( - nest.flatten(new_state), nest.flatten(self._state_classes)): - if not isinstance(t, clazz): - raise TypeError( - "The element classes for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_classes, - nest.pack_sequence_as( - self._state_types, - [type(t) for t in nest.flatten(new_state)]))) - self._output_classes = sparse.get_classes(output_value) - - # Extract shape information from the returned values. - flat_new_state_shapes.extend( - [t.get_shape() for t in nest.flatten(new_state)]) - self._output_shapes = nest.pack_sequence_as( - output_value, [t.get_shape() for t in nest.flatten(output_value)]) - - # Extract and validate type information from the returned values. - for t, dtype in zip( - nest.flatten(new_state), nest.flatten(self._state_types)): - if t.dtype != dtype: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_types, - nest.pack_sequence_as( - self._state_types, - [t.dtype for t in nest.flatten(new_state)]))) - self._output_types = nest.pack_sequence_as( - output_value, [t.dtype for t in nest.flatten(output_value)]) - - # Serialize any sparse tensors. - new_state = nest.pack_sequence_as(new_state, [ - t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state)) - ]) - output_value = nest.pack_sequence_as(output_value, [ - t for t in nest.flatten( - sparse.serialize_sparse_tensors(output_value)) - ]) - return nest.flatten(new_state) + nest.flatten(output_value) - - # Use the private method that will execute `tf_scan_func` but delay - # adding it to the graph in case we need to rerun the function. - tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access + wrapped_func = dataset_ops.StructuredFunctionWrapper( + scan_func, "tf.contrib.data.scan()", + input_classes=(self._state_classes, input_dataset.output_classes), + input_shapes=(self._state_shapes, input_dataset.output_shapes), + input_types=(self._state_types, input_dataset.output_types), + add_to_graph=False) + if not ( + isinstance(wrapped_func.output_types, collections.Sequence) and + len(wrapped_func.output_types) == 2): + raise TypeError("The scan function must return a pair comprising the " + "new state and the output value.") + + new_state_classes, self._output_classes = wrapped_func.output_classes + + # Extract and validate class information from the returned values. + for new_state_class, state_class in zip( + nest.flatten(new_state_classes), + nest.flatten(self._state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, new_state_classes)) + + # Extract and validate type information from the returned values. + new_state_types, self._output_types = wrapped_func.output_types + for new_state_type, state_type in zip( + nest.flatten(new_state_types), nest.flatten(self._state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, new_state_types)) + + # Extract shape information from the returned values. + new_state_shapes, self._output_shapes = wrapped_func.output_shapes flat_state_shapes = nest.flatten(self._state_shapes) + flat_new_state_shapes = nest.flatten(new_state_shapes) weakened_state_shapes = [ original.most_specific_compatible_shape(new) for original, new in zip(flat_state_shapes, flat_new_state_shapes) @@ -178,12 +120,10 @@ class _ScanDataset(dataset_ops.Dataset): break if need_to_rerun: - # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun - # `tf_scan_func`. self._state_shapes = nest.pack_sequence_as(self._state_shapes, weakened_state_shapes) - self._scan_func = tf_scan_func + self._scan_func = wrapped_func.function self._scan_func.add_to_graph(ops.get_default_graph()) def _as_variant_tensor(self): @@ -193,10 +133,7 @@ class _ScanDataset(dataset_ops.Dataset): nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)), self._scan_func.captured_inputs, f=self._scan_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index f35795abd38000b13cec0f08596e2ff66e86286c..d7f8a73fe3d67bb83e44e962832ce34c116aef66 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -18,9 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed -from tensorflow.python.data.util import sparse from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -56,10 +54,7 @@ class _ShuffleAndRepeatDataset(dataset_ops.Dataset): count=self._count, seed=self._seed, seed2=self._seed2, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 19cc3cb89fc5c494f79ce1d25ed57c92099c8bd2..3f3c5ca17cf6ae22a719ed1d593d98eec37413fb 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -19,7 +19,6 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -43,10 +42,7 @@ class _SlideDataset(dataset_ops.Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access window_size=self._window_size, stride=self._stride, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): @@ -90,7 +86,7 @@ def sliding_window_batch(window_size, stride=1): elements in the sliding window. stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the steps moving the sliding window forward for one iteration. The default - is `1`. It must be in `[1, window_size)`. + is `1`. It must be positive. Returns: A `Dataset` transformation function, which can be passed to diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 3cbaab5affd7397213b0fbb6b0682db92b99d591..97931f75bd37d9e45864fe477c6e1620b5e4f193 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -18,13 +18,13 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. class StatsAggregator(object): """A stateful resource that aggregates statistics from one or more iterators. @@ -97,10 +97,7 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset): return gen_dataset_ops.set_stats_aggregator_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._stats_aggregator._resource, # pylint: disable=protected-access - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): @@ -115,7 +112,8 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset): return self._input_dataset.output_classes -# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`. +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def set_stats_aggregator(stats_aggregator): """Set the given stats_aggregator for aggregating the input dataset stats. @@ -133,6 +131,8 @@ def set_stats_aggregator(stats_aggregator): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def bytes_produced_stats(tag): """Records the number of bytes produced by each element of the input dataset. @@ -155,6 +155,8 @@ def bytes_produced_stats(tag): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def latency_stats(tag): """Records the latency of producing each element of the input dataset. @@ -176,6 +178,29 @@ def latency_stats(tag): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. +def feature_stats(tag): + """Records the features stats from `Example` records of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with the output + dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will be + associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag) + + return _apply_fn + + class _StatsDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and also records statistics.""" @@ -189,10 +214,7 @@ class _StatsDataset(dataset_ops.Dataset): return self._op_function( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._tag, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index 56f67e1766bbaff680bdff6b939df0c3ba68c679..9af1e784ffb4f6d71da25f09d60343b649c5079b 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -22,8 +22,6 @@ import threading from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.ops import resource_variable_ops @@ -39,22 +37,28 @@ def _generate_shared_name(prefix): return "{}{}".format(prefix, uid) +# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. class PrivateThreadPool(object): """A stateful resource that represents a private thread pool.""" - def __init__(self, num_threads, display_name=None): + def __init__(self, num_threads, display_name=None, + max_intra_op_parallelism=1): """Creates a `PrivateThreadPool` with the given number of threads.""" if context.executing_eagerly(): shared_name = _generate_shared_name("privatethreadpool") self._resource = gen_dataset_ops.thread_pool_handle( num_threads=num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, display_name=display_name, shared_name=shared_name) self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device=context.context().device_name) else: self._resource = gen_dataset_ops.thread_pool_handle( - num_threads=num_threads, display_name=display_name) + num_threads=num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, + display_name=display_name) class _ThreadPoolDataset(dataset_ops.Dataset): @@ -69,10 +73,7 @@ class _ThreadPoolDataset(dataset_ops.Dataset): return gen_dataset_ops.thread_pool_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._thread_pool._resource, # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): @@ -87,6 +88,8 @@ class _ThreadPoolDataset(dataset_ops.Dataset): return self._input_dataset.output_classes +# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def override_threadpool(dataset, thread_pool): """Returns a new dataset that uses the given thread pool for its operations. diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index 765ef3f9b6d42c9d7af3ce4916731d37d65c9260..e0ce0a4ef15f6b9181bce92fb4d73bf1fab2e66c 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -44,17 +42,17 @@ def unique(): """ def _apply_fn(dataset): - return UniqueDataset(dataset) + return _UniqueDataset(dataset) return _apply_fn -class UniqueDataset(dataset_ops.Dataset): +class _UniqueDataset(dataset_ops.Dataset): """A `Dataset` contains the unique elements from its input.""" def __init__(self, input_dataset): """See `unique()` for details.""" - super(UniqueDataset, self).__init__() + super(_UniqueDataset, self).__init__() self._input_dataset = input_dataset if input_dataset.output_types not in (dtypes.int32, dtypes.int64, dtypes.string): @@ -65,10 +63,7 @@ class UniqueDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.unique_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 64a77bbed1d55c3d95329d9c7783c2b468bde745..eba0dd0ea330e29db0ea8e68ee14767fcb8ddad0 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -77,6 +77,7 @@ py_library( "//tensorflow/python:device_util", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -148,6 +149,7 @@ py_library( ], deps = [ ":mirrored_strategy", + ":multi_worker_strategy", ":one_device_strategy", ":tpu_strategy", "//tensorflow/contrib/optimizer_v2:training", @@ -445,10 +447,31 @@ py_library( srcs = ["cross_tower_utils.py"], srcs_version = "PY2AND3", deps = [ + ":values", + "//tensorflow/contrib/all_reduce:all_reduce_py", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + ], +) + +cuda_py_test( + name = "cross_tower_utils_test", + srcs = ["cross_tower_utils_test.py"], + additional_deps = [ + ":combinations", + ":cross_tower_utils", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "no_pip", ], ) @@ -476,6 +499,7 @@ cuda_py_test( additional_deps = [ ":combinations", ":cross_tower_ops", + ":multi_worker_test_base", ":values", "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", @@ -485,6 +509,7 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], + shard_count = 15, tags = [ "multi_and_single_gpu", "no_pip", @@ -547,3 +572,40 @@ cuda_py_test( "no_pip", ], ) + +cuda_py_test( + name = "keras_test", + srcs = ["keras_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", + "//tensorflow/python/estimator:keras", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/keras", + ], + tags = [ + "multi_and_single_gpu", + "notsan", + ], +) + +cuda_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index d719234cf69fbc6189a8979762521e97c9bca408..9a8ea4aa48b8cf4c5906f18d8bddacc224e0b644 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -41,11 +41,15 @@ from __future__ import print_function from collections import OrderedDict import sys +import types +import unittest from absl.testing import parameterized +import six -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import one_device_strategy -from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib +from tensorflow.contrib.distribute.python import multi_worker_strategy +from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib +from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.eager import context @@ -67,29 +71,35 @@ def generate(combinations): combinations: a list of dictionaries created using combine() and times(). Restrictions: - -- there should always be a "mode" argument. Accepted values are "eager" - and "graph". + -- the "mode" argument can be either "eager" or "graph". It's "graph" by + default. -- arguments of the test method must match by name to get the corresponding - value of the combination. Tests must accept all arguments (except "mode", - which is optional). - -- distribution argument is special. It is meant for passing instances of - DistributionStrategy. Each instance is to be passed as `(, - )` tuple, where is the number of required - GPUs. If the required number of GPUs for the DistributionStrategy isn't - available then the test case is going to be skipped. + value of the combination. Tests must accept all arguments except the + "mode", "required_tpu" and "required_gpus". + -- "distribution" argument is special and optional. It is meant for passing + instances of DistributionStrategy. Each instance is to be passed as via + `NamedDistribution`. If using "distribution", "required_gpus" and + "required_tpu" should be specified via the NamedDistribution instance, + rather than as separate arguments. + -- "required_tpu" argument is special and optional. If not `None`, then the + test will be skipped if TPUs aren't available. + -- "required_gpus" argument is special and optional. If not `None`, then the + test will be skipped if the specified number of GPUs aren't available. Returns: - a decorator that will cause the test method to be run under the specified - conditions. + a decorator that will cause the test method or the test class to be run + under the specified conditions. Raises: - ValueError - if "mode" argument wasn't either "eager" or "graph. + ValueError - if "mode" argument wasn't either "eager" or "graph" or if other + arguments were not accepted by the test method. """ - def decorator(test_function): + def decorator(test_method_or_class): """The decorator to be returned.""" # Generate good test names that can be used with --test_filter. + named_combinations = [] for combination in combinations: # We use OrderedDicts in `combine()` and `times()` to ensure stable # order of keys in each dictionary. @@ -100,59 +110,96 @@ def generate(combinations): "".join(filter(str.isalnum, str(value)))) for key, value in combination.items() ]) - combination.update({"testcase_name": "_test{}".format(name)}) - - @parameterized.named_parameters(*combinations) - def decorated(self, **kwargs): - """A wrapped test method that sets up `test_function`.""" - assert "mode" in kwargs - mode = kwargs["mode"] - - if "distribution" in kwargs: - distribution = kwargs["distribution"] - kwargs["distribution"] = distribution.strategy - if distribution.required_tpu and not TPU_TEST: - self.skipTest("Test requires a TPU, but it's not available.") - if not distribution.required_tpu and TPU_TEST: - self.skipTest("Test that doesn't require a TPU.") - - if not distribution.required_gpus: - if GPU_TEST: - self.skipTest("Test that doesn't require GPUs.") - elif context.num_gpus() < distribution.required_gpus: - self.skipTest( - "{} GPUs are not available for this test. {} GPUs are available". - format(distribution.required_gpus, context.num_gpus())) - - requested_arguments = tf_inspect.getfullargspec(test_function).args - missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( - set(requested_arguments + ["mode"])) - if missing_arguments: - raise ValueError("The test is missing arguments {} .".format( - missing_arguments)) - - kwargs_to_pass = {} - for arg in requested_arguments: - if arg == "self": - kwargs_to_pass[arg] = self - else: - kwargs_to_pass[arg] = kwargs[arg] - - if mode == "eager": - with context.eager_mode(), ops.Graph().as_default(): - test_function(**kwargs_to_pass) - elif mode == "graph": - with context.graph_mode(), ops.Graph().as_default(): - test_function(**kwargs_to_pass) - else: - raise ValueError( - "'mode' has to be either 'eager' or 'graph' and not {}".format( - mode)) + named_combinations.append( + OrderedDict( + list(combination.items()) + [("testcase_name", + "_test{}".format(name))])) + + if isinstance(test_method_or_class, type): + class_object = test_method_or_class + class_object._test_method_ids = test_method_ids = {} + for name, test_method in six.iteritems(class_object.__dict__.copy()): + if (name.startswith(unittest.TestLoader.testMethodPrefix) and + isinstance(test_method, types.FunctionType)): + delattr(class_object, name) + methods = {} + parameterized._update_class_dict_for_param_test_case( + class_object.__name__, methods, test_method_ids, name, + parameterized._ParameterizedTestIter( + _augment_with_special_arguments(test_method), + named_combinations, parameterized._NAMED, name)) + for method_name, method in six.iteritems(methods): + setattr(class_object, method_name, method) + + return class_object + else: + test_method = _augment_with_special_arguments(test_method_or_class) + return parameterized.named_parameters(*named_combinations)(test_method) - return decorated return decorator +def _augment_with_special_arguments(test_method): + def decorated(self, **kwargs): + """A wrapped test method that treats some arguments in a special way.""" + mode = kwargs.pop("mode", "graph") + + distribution = kwargs.pop("distribution", None) + required_tpu = kwargs.pop("required_tpu", False) + required_gpus = kwargs.pop("required_gpus", None) + + if distribution: + assert required_gpus is None, ( + "Do not use `required_gpus` and `distribution` together.") + assert required_tpu is False, ( + "Do not use `required_tpu` and `distribution` together.") + kwargs["distribution"] = distribution.strategy + required_gpus = distribution.required_gpus + required_tpu = distribution.required_tpu + + if required_tpu and not TPU_TEST: + self.skipTest("Test requires a TPU, but it's not available.") + if not required_tpu and TPU_TEST: + self.skipTest("Test that doesn't require a TPU.") + + if not required_gpus: + if GPU_TEST: + self.skipTest("Test that doesn't require GPUs.") + elif context.num_gpus() < required_gpus: + self.skipTest( + "{} GPUs are not available for this test. {} GPUs are available". + format(required_gpus, context.num_gpus())) + + # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu` + # that the user might have specified. `kwargs` still has `mode`, which + # the test is allowed to accept or ignore. + requested_arguments = tf_inspect.getfullargspec(test_method).args + missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( + set(requested_arguments + ["mode"])) + if missing_arguments: + raise ValueError("The test is missing arguments {} .".format( + missing_arguments)) + + kwargs_to_pass = {} + for arg in requested_arguments: + if arg == "self": + kwargs_to_pass[arg] = self + else: + kwargs_to_pass[arg] = kwargs[arg] + + if mode == "eager": + with ops.Graph().as_default(), context.eager_mode(): + test_method(**kwargs_to_pass) + elif mode == "graph": + with ops.Graph().as_default(), context.graph_mode(): + test_method(**kwargs_to_pass) + else: + raise ValueError( + "'mode' has to be either 'eager' or 'graph' and not {}".format( + mode)) + return decorated + + def combine(**kwargs): """Generate combinations based on its keyword arguments. @@ -160,7 +207,8 @@ def combine(**kwargs): can be computed using `times()`. Args: - **kwargs: keyword arguments of form `option=[possibilities, ...]`. + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. Returns: a list of dictionaries for each combination. Keys in the dictionaries are @@ -179,6 +227,8 @@ def combine(**kwargs): key = first[0] values = first[1] + if not isinstance(values, list): + values = [values] return [ OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) @@ -240,9 +290,9 @@ class NamedObject(object): class NamedDistribution(object): """Translates DistributionStrategy and its data into a good name.""" - def __init__(self, name, distribution, required_gpus=None, + def __init__(self, name, distribution_fn, required_gpus=None, required_tpu=False): - self._distribution = distribution + self._distribution_fn = distribution_fn self._name = name self._required_gpus = required_gpus self._required_tpu = required_tpu @@ -252,7 +302,7 @@ class NamedDistribution(object): @property def strategy(self): - return self._distribution + return self._distribution_fn() @property def required_gpus(self): @@ -263,32 +313,56 @@ class NamedDistribution(object): return self._required_tpu +# pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", - distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = NamedDistribution( - "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), + "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) -tpu_strategy_single_iteration = NamedDistribution( - "TPUSingleIteration", - tpu_strategy.TPUStrategy(iterations_per_step=1), - required_tpu=True) -tpu_strategy = NamedDistribution( - "TPU", tpu_strategy.TPUStrategy(), required_tpu=True) +tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - mirrored_strategy.MirroredStrategy( + lambda: mirrored_lib.MirroredStrategy( ["/gpu:0", "/cpu:0"], prefetch_on_device=False), required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - mirrored_strategy.MirroredStrategy( + lambda: mirrored_lib.MirroredStrategy( ["/gpu:0", "/gpu:1"], prefetch_on_device=False), required_gpus=2) +multi_worker_strategy_with_cpu = NamedDistribution( + "MultiWorkerCPU", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=0), 0) +multi_worker_strategy_with_one_gpu = NamedDistribution( + "MultiWorker1GPU", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=1), 1) +multi_worker_strategy_with_two_gpus = NamedDistribution( + "MultiWorker2GPUs", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=2), 2) + adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) gradient_descent_optimizer_v1_fn = NamedObject( diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py index 219b24160f3902fcfa5363cc39a8fc5b30d00308..86aa48cea889c6c2ce169b18bcabb6d08890fbed 100644 --- a/tensorflow/contrib/distribute/python/combinations_test.py +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from collections import OrderedDict +from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.python.eager import test @@ -41,6 +42,15 @@ class TestingCombinationsTest(test.TestCase): "b": 3 }], combinations.combine(a=[1, 2], b=[2, 3])) + def test_combine_single_parameter(self): + self.assertEqual([{ + "a": 1, + "b": 2 + }, { + "a": 2, + "b": 2 + }], combinations.combine(a=[1, 2], b=2)) + def test_add(self): self.assertEqual( [{ @@ -111,5 +121,28 @@ class TestingCombinationsTest(test.TestCase): _ = combinations.times(c1, c2) +@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1])) +class CombineTheTestSuite(parameterized.TestCase): + + def test_add_things(self, a, b, c): + self.assertLessEqual(3, a + b + c) + self.assertLessEqual(a + b + c, 5) + + def test_add_things_one_more(self, a, b, c): + self.assertLessEqual(3, a + b + c) + self.assertLessEqual(a + b + c, 5) + + def not_a_test(self, a=0, b=0, c=0): + del a, b, c + self.fail() + + def _test_but_private(self, a=0, b=0, c=0): + del a, b, c + self.fail() + + # Check that nothing funny happens to a non-callable that starts with "_test". + test_member = 0 + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index c6a1bf6a9f65828c45617ae18a1b0989f9d46225..b0baf0dad1d55eafac5338d1eb43465927e428a1 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import six from tensorflow.contrib.distribute.python import cross_tower_utils @@ -27,11 +28,12 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_util -def _validate_destinations(destinations): +def validate_destinations(destinations): if not isinstance(destinations, (value_lib.DistributedValues, six.string_types, list)): raise ValueError("destinations must be one of a `DistributedValues` object," @@ -54,7 +56,7 @@ def _validate_value_destination_pairs(value_destination_pairs): # TODO(yuefengz): consider calling this function in the caller of CrossTowerOps. -def _get_devices_from(destinations): +def get_devices_from(destinations): if isinstance(destinations, value_lib.DistributedValues): return list(destinations.devices) elif isinstance(destinations, six.string_types): @@ -64,7 +66,7 @@ def _get_devices_from(destinations): def _devices_match(left, right): - return set(_get_devices_from(left)) == set(_get_devices_from(right)) + return set(get_devices_from(left)) == set(get_devices_from(right)) def _all_devices_match(value_destination_pairs): @@ -77,17 +79,17 @@ def _all_devices_match(value_destination_pairs): return True -def _simple_broadcast(tensor, destinations): +def _simple_broadcast(value, destinations): index = {} - devices = _get_devices_from(destinations) + devices = get_devices_from(destinations) for d in devices: - with ops.device(d): - index[d] = array_ops.identity(tensor) + index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + value, d) return value_lib.Mirrored(index) def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, - method_string): + aggregation): # pylint: disable=g-missing-docstring all_values = [] count = 0 @@ -98,7 +100,9 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, continue count += len(v_list) # Sum within each device before aggregating across devices. - v = math_ops.add_n(v_list) + # TODO(yuefengz): Check whether it helps to use accumulation_fn here. + v = cross_tower_utils.aggregate_tensors_or_indexed_slices( + v_list, math_ops.add_n) else: count += 1 all_values.append(v) @@ -107,12 +111,14 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, with ops.device(reduce_to_device): with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - if method_string == "sum": - reduced = accumulation_fn(all_values) - elif method_string == "mean": - reduced = accumulation_fn(all_values) / count - else: - raise ValueError("`method_string` must be 'sum' or 'mean'") + reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices( + all_values, accumulation_fn) + if aggregation == vs.VariableAggregation.MEAN: + reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices( + reduced, count) + elif aggregation != vs.VariableAggregation.SUM: + raise ValueError("`aggregation` must be VariableAggregation.SUM " + "or VariableAggregation.MEAN.") return reduced @@ -122,14 +128,15 @@ class CrossTowerOps(object): def __init__(self): pass - def reduce(self, method_string, per_device_value, destinations=None): + def reduce(self, aggregation, per_device_value, destinations=None): """Reduce `per_device_value` to `destinations`. - It runs the reduction operation defined by `method_string` and put the + It runs the reduction operation defined by `aggregation` and put the result on `destinations`. Args: - method_string: either 'sum' or 'mean' specifying the reduction method. + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. per_device_value: a PerDevice object. destinations: the reduction destinations. @@ -142,17 +149,18 @@ class CrossTowerOps(object): if not isinstance(per_device_value, value_lib.PerDevice): raise ValueError("`per_device_value` must be a `PerDevice` object.") if destinations is not None: - _validate_destinations(destinations) - return self._reduce(method_string, per_device_value, destinations) + validate_destinations(destinations) + return self._reduce(aggregation, per_device_value, destinations) - def batch_reduce(self, method_string, value_destination_pairs): + def batch_reduce(self, aggregation, value_destination_pairs): """Reduce PerDevice objects in a batch. Reduce each first element in `value_destination_pairs` to each second element which indicates the destinations. Args: - method_string: either 'sum' or 'mean' specifying the reduction method. + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. value_destination_pairs: a list or a tuple of tuples of PerDevice objects and destinations. If a destination is None, then the destinations are set to match the devices of the input PerDevice object. @@ -169,9 +177,9 @@ class CrossTowerOps(object): "tuples of PerDevice objects and destinations") for _, d in value_destination_pairs: if d is not None: - _validate_destinations(d) + validate_destinations(d) - return self._batch_reduce(method_string, value_destination_pairs) + return self._batch_reduce(aggregation, value_destination_pairs) def broadcast(self, tensor, destinations): """Broadcast the `tensor` to destinations. @@ -183,14 +191,14 @@ class CrossTowerOps(object): Returns: a Mirrored object. """ - _validate_destinations(destinations) + validate_destinations(destinations) return self._broadcast(tensor, destinations) - def _reduce(self, method_string, per_device_value, destinations): + def _reduce(self, aggregation, per_device_value, destinations): raise NotImplementedError( "_reduce method must be implemented in descendants.") - def _batch_reduce(self, method_string, value_destination_pairs): + def _batch_reduce(self, aggregation, value_destination_pairs): raise NotImplementedError( "_batch_reduce method must be implemented in descendants.") @@ -216,22 +224,30 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): self.accumulation_fn = accumulation_fn super(ReductionToOneDeviceCrossTowerOps, self).__init__() - def _reduce(self, method_string, per_device_value, destinations): - devices = _get_devices_from(destinations or per_device_value) + def _reduce(self, aggregation, per_device_value, destinations): + devices = get_devices_from(destinations or per_device_value) reduce_to_device = self.reduce_to_device or devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, - self.accumulation_fn, method_string) + self.accumulation_fn, aggregation) return self.broadcast(reduced, devices) - def _batch_reduce(self, method_string, value_destination_pairs): - return [self._reduce(method_string, t, destinations=v) - for t, v in value_destination_pairs] + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self._reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] def _group_value_by_device(per_device_values): """Group values into sublists by their devices. - This grouping is needed to call the all-reduce library. + This grouping is needed to call the all-reduce library because it expects a + list of the following form: + [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ... + (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ... + (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ... + ... + ] Args: per_device_values: a list of PerDevice obejcts. @@ -250,18 +266,19 @@ def _group_value_by_device(per_device_values): return grouped -def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string): +def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation): """Ungroup results from all-reduce and make Mirrored objects. Each all-reduce result will be divided by the number of destinations before - Mirrored objects are created if method_string is "mean". + Mirrored objects are created if aggregation is "mean". Args: grouped_reduced: a list of lists, each sublist has components for each device, paired with a None. It is the result from cross_tower_utils.aggregate_gradients_using*. destinations: a list of device strings for returned Mirrored objects. - method_string: "mean" or "sum". + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. Returns: a list of Mirrored objects. @@ -269,7 +286,7 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string): index = [{} for _ in range(len(grouped_reduced[0]))] for d, per_device_reduced in enumerate(grouped_reduced): for i, (v, _) in enumerate(per_device_reduced): - if method_string == "mean": + if aggregation == vs.VariableAggregation.MEAN: index[i][destinations[d]] = v / len(destinations) else: index[i][destinations[d]] = v @@ -319,7 +336,17 @@ class ConcatAndSplitPacker(object): # TODO(zhengxq): it is also possible to optimize away all the concat # as well. num_splits = self.num_packs - total_grad_size = array_ops.size(concat_grads) + + # The array_ops.size function will sometimes remove static shapes. So if + # all gradient shapes are defined, we use another method to get the + # total size. + # TODO(yuefengz): move this logic to array_ops.size. + if all([g.shape.is_fully_defined() for g, _ in tower_grads_and_vars]): + total_grad_size = sum( + [g.shape.num_elements() for g, _ in tower_grads_and_vars]) + else: + total_grad_size = array_ops.size(concat_grads) + split_size = total_grad_size // num_splits split_size_last = total_grad_size - split_size * (num_splits - 1) split_sizes = [split_size] * (num_splits - 1) + [split_size_last] @@ -409,6 +436,31 @@ class AggregateSmallTensorPacker(object): self.packing) +def _pack_tensors(device_grads, + num_packs=0, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=0): + """Pack tensors if specified.""" + if num_packs > 0: + tensor_packer = ConcatAndSplitPacker(num_packs) + device_grad_packs = tensor_packer.pack(device_grads) + elif agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0: + tensor_packer = AggregateSmallTensorPacker(agg_small_grads_max_bytes, + agg_small_grads_max_group) + device_grad_packs = tensor_packer.pack(device_grads) + else: + tensor_packer = None + device_grad_packs = device_grads + return device_grad_packs, tensor_packer + + +def _unpack_tensors(reduced, tensor_packer=None): + """Unpack tensors if they are packed before all-reduce.""" + if tensor_packer: + return tensor_packer.unpack(reduced) + return reduced + + class AllReduceCrossTowerOps(CrossTowerOps): """Reduction using all reduce.""" @@ -437,70 +489,69 @@ class AllReduceCrossTowerOps(CrossTowerOps): agg_small_grads_max_group: see above. tensors. """ - self.all_reduce_alg = all_reduce_alg - self.num_packs = num_packs - self.agg_small_grads_max_bytes = agg_small_grads_max_bytes - self.agg_small_grads_max_group = agg_small_grads_max_group + self._all_reduce_alg = all_reduce_alg + self._num_packs = num_packs + self._agg_small_grads_max_bytes = agg_small_grads_max_bytes + self._agg_small_grads_max_group = agg_small_grads_max_group super(AllReduceCrossTowerOps, self).__init__() - def _reduce(self, method_string, per_device_value, destinations): + def _reduce(self, aggregation, per_device_value, destinations): + contains_indexed_slices = cross_tower_utils.contains_indexed_slices( + per_device_value) if ((destinations is None or _devices_match(per_device_value, destinations)) - and not context.executing_eagerly()): - return self._batch_all_reduce(method_string, [per_device_value])[0] + and not context.executing_eagerly() + and not contains_indexed_slices): + return self._batch_all_reduce(aggregation, [per_device_value])[0] else: - devices = _get_devices_from(destinations or per_device_value) + if contains_indexed_slices: + logging.log_first_n( + logging.WARN, + "Efficient allreduce is not supported for IndexedSlices.", 10) + + devices = get_devices_from(destinations or per_device_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, - math_ops.add_n, method_string) + math_ops.add_n, aggregation) return self.broadcast(reduced, devices) - def _batch_reduce(self, method_string, value_destination_pairs): - if (_all_devices_match(value_destination_pairs) and - not context.executing_eagerly()): - return self._batch_all_reduce(method_string, + def _batch_reduce(self, aggregation, value_destination_pairs): + all_devices_match = _all_devices_match(value_destination_pairs) + contains_indexed_slices = cross_tower_utils.contains_indexed_slices( + value_destination_pairs) + if (all_devices_match and not context.executing_eagerly() + and not contains_indexed_slices): + return self._batch_all_reduce(aggregation, [v[0] for v in value_destination_pairs]) else: - if not context.executing_eagerly(): + if not all_devices_match: logging.warning("Efficient batch_reduce is not supported if " "destinations are different.") + return [ - self._reduce(method_string, t, destinations=v) + self._reduce(aggregation, t, destinations=v) for t, v in value_destination_pairs ] - def _batch_all_reduce(self, method_string, per_device_values): + def _batch_all_reduce(self, aggregation, per_device_values): """All reduce algorithm in a batch.""" + logging.info( + "batch_all_reduce invoked for batches size = %d with " + "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " + "agg_small_grads_max_group = %d", len(per_device_values), + self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) destinations = per_device_values[0].devices grouped = _group_value_by_device(per_device_values) - if self.num_packs > 0: - logging.info( - "batch_all_reduce invoked for batches size = %d with " - "algorithm = %s and num_packs = %d", len(per_device_values), - self.all_reduce_alg, self.num_packs) - tensor_packer = ConcatAndSplitPacker(self.num_packs) - device_grad_packs = tensor_packer.pack(grouped) - elif (self.agg_small_grads_max_bytes > 0 and - self.agg_small_grads_max_group > 0): - logging.info( - "batch_all_reduce invoked for batches size = %d with " - "algorithm = %s, agg_small_grads_max_bytes = %d and " - "agg_small_grads_max_group = %d", len(per_device_values), - self.all_reduce_alg, self.agg_small_grads_max_bytes, - self.agg_small_grads_max_group) - tensor_packer = AggregateSmallTensorPacker( - self.agg_small_grads_max_bytes, self.agg_small_grads_max_group) - device_grad_packs = tensor_packer.pack(grouped) - else: - logging.info( - "batch_all_reduce invoked for batches size = %d with algorithm = %s", - len(per_device_values), self.all_reduce_alg) - tensor_packer = None - device_grad_packs = grouped + + device_grad_packs, tensor_packer = _pack_tensors( + grouped, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) # The actual aggregation of the repacked gradients. Note that they are # sharded among different aggregation trees. So it is important to strike # the balance on num_splits. - if self.all_reduce_alg == "nccl": + if self._all_reduce_alg == "nccl": + # TODO(yuefengz): merge this into the all-reduce library. reduced = cross_tower_utils.aggregate_gradients_using_nccl( device_grad_packs) else: @@ -510,11 +561,135 @@ class AllReduceCrossTowerOps(CrossTowerOps): cross_tower_utils.aggregate_gradients_using_hierarchical_copy( destinations, device_grad_packs)) - if tensor_packer: - reduced = tensor_packer.unpack(reduced) - + reduced = _unpack_tensors(reduced, tensor_packer) return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, - method_string) + aggregation) + + +AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", + "alg shards limit") + + +class MultiWorkerAllReduce(AllReduceCrossTowerOps): + """All-reduce algorithms for distributed TensorFlow.""" + + def __init__(self, + worker_devices, + num_gpus_per_worker, + all_reduce_spec=("pscpu/pscpu", 2, -1), + num_packs=0, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=10): + """Initialize the all-reduce algorithm. + + Args: + worker_devices: a list of device strings for workers participating in + all-reduce. + num_gpus_per_worker: number of GPU devices per worker. + all_reduce_spec: a tuple or a named tuple or a list of tuples specifying + the all-reduce algorithm. + 1. The first element of a tuple is the name of the all-reduce algorithm. + Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd", + "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with + a "/" are hierarchical, so two all-reduces are executed, the first one + aggregates tensors within a worker and the second aggregates across + workers. + 2. The second element of a tuple is the number of shards when doing + all-reduce. Let's say its values is M, each tensor after packing will be + split into M shards and then M parallel all-reduces would be performed + before finally they are concatenated backed into a complete tensor. + 3. The third element is the maximum size of tensors that will be + applicable for the algorithm specified by the first element. For + example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)], + tensors with size not larger than 1024 bytes will be applied a 2-shard + "nccl" all-reduce and other tensors will be applied a 2-shard + "pscpu/pscpu" algorithm. The third elements should be in increasing + order across tuples and end with -1 which indicates infinity. + num_packs: see AllReduceCrossTowerOps. + agg_small_grads_max_bytes: see AllReduceCrossTowerOps. + agg_small_grads_max_group: see AllReduceCrossTowerOps. + """ + self._worker_devices = worker_devices + self._num_gpus_per_worker = num_gpus_per_worker + super(MultiWorkerAllReduce, self).__init__( + num_packs=num_packs, + agg_small_grads_max_bytes=agg_small_grads_max_bytes, + agg_small_grads_max_group=agg_small_grads_max_group) + + def validate_and_complete_spec(spec): + """Validate and complete the all-reduce spec.""" + # TODO(yuefengz): support namedtuple. + if not isinstance(spec, tuple): + raise ValueError( + "A tuple is expected for all-reduce spec: %r" % all_reduce_spec) + if not spec or len(spec) > 3: + raise ValueError( + "Too many elements in the all-reduce spec tuple: %r" % spec) + if len(spec) == 1: + return AllReduceSpecTuple(spec[0], 1, -1) + elif len(spec) == 2: + return AllReduceSpecTuple(spec[0], spec[1], -1) + else: + return AllReduceSpecTuple(*spec) + + self._all_reduce_spec = [] + if isinstance(all_reduce_spec, six.string_types): + self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1)) + elif isinstance(all_reduce_spec, tuple): + self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec)) + elif isinstance(all_reduce_spec, list): + self._all_reduce_spec = [ + validate_and_complete_spec(spec) for spec in all_reduce_spec + ] + + def _batch_all_reduce(self, aggregation, per_device_values): + """All reduce algorithm in a batch.""" + logging.info( + "distributed batch_all_reduce invoked for batches size = %d with " + "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " + "and agg_small_grads_max_group = %d", len(per_device_values), + self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + + destinations = sorted(per_device_values[0].devices) + device_grads = _group_value_by_device(per_device_values) + + # The all reduce library requires fully defined shapes. + # TODO(yuefengz): when tensor sharding is not needed, static shapes are not + # required as well. + for device_grad in device_grads: + for grad, _ in device_grad: + if not grad.shape.is_fully_defined(): + raise ValueError("Shape is unknown for node %r" % grad) + + remaining_grads = device_grads + aggregated_grads = [] + for spec_tuple in self._all_reduce_spec: + if spec_tuple.limit < 0: + this_grads = remaining_grads + remaining_grads = [] + else: + (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size( + spec_tuple.limit, remaining_grads) + if this_grads: + device_grad_packs, tensor_packer = _pack_tensors( + this_grads, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + range_agg_grads = cross_tower_utils.sum_gradients_all_reduce( + self._worker_devices, device_grad_packs, len(self._worker_devices), + spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker)) + range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer) + + if not aggregated_grads: + aggregated_grads = range_agg_grads + else: + assert len(aggregated_grads) == len(range_agg_grads) + for i in range(len(aggregated_grads)): + aggregated_grads[i] += range_agg_grads[i] + assert not remaining_grads + + return _ungroup_and_make_mirrored(aggregated_grads, destinations, + aggregation) _dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 7c7b0870887465ec2fe40007695d099277db38bf..6a780ff60ffcd59d416278bfde6d005d7ad37a68 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -31,10 +32,12 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.training import device_util def _make_per_device(values, devices): - devices = cross_tower_ops_lib._get_devices_from(devices) + devices = cross_tower_ops_lib.get_devices_from(devices) assert len(values) == len(devices) index = {} for d, v in zip(devices, values): @@ -51,24 +54,51 @@ def _fake_mirrored(value, devices): All components of the returned Mirrored have the same objects, which is not true in reality. """ - devices = cross_tower_ops_lib._get_devices_from(devices) + devices = cross_tower_ops_lib.get_devices_from(devices) return value_lib.Mirrored( {d: v for d, v in zip(devices, [value] * len(devices))}) +def _make_indexed_slices(values, indices, dense_shape, device): + with ops.device(device): + tensor = ops.IndexedSlices( + values=constant_op.constant(values), + indices=constant_op.constant(indices), + dense_shape=constant_op.constant(dense_shape)) + return tensor + + +def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): + return value_lib.Mirrored({ + d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices + }) + + _cpu_device = "/device:CPU:0" -class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): +class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase): - def _assert_value_equal(self, left, right): + def _assert_indexed_slices_equal(self, left, right): + self.assertIsInstance(left, ops.IndexedSlices) + self.assertIsInstance(right, ops.IndexedSlices) + self.assertEqual(device_util.resolve(left.device), + device_util.resolve(right.device)) + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + def _assert_values_equal(self, left, right): if isinstance(left, list): for l, r in zip(left, right): - self._assert_value_equal(l, r) + self._assert_values_equal(l, r) else: self.assertEqual(type(left), type(right)) - self.assertEqual(left.devices, right.devices) - if context.executing_eagerly(): + self.assertEqual(set(left.devices), set(right.devices)) + if isinstance(list(left._index.values())[0], ops.IndexedSlices): + for (d, v) in left._index.items(): + self._assert_indexed_slices_equal(v, right._index[d]) + elif context.executing_eagerly(): self.assertEqual([v.numpy() for v in left._index.values()], list(right._index.values())) else: @@ -76,6 +106,81 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): self.assertEqual( sess.run(list(left._index.values())), list(right._index.values())) + def _testReductionAndBroadcast(self, cross_tower_ops, distribution): + devices = distribution.worker_devices + + values = [constant_op.constant(float(d)) for d in range(len(devices))] + per_device = _make_per_device(values, devices) + mean = (len(devices) - 1.) / 2. + + values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] + per_device_2 = _make_per_device(values_2, devices) + mean_2 = mean + 1. + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + destination_list = devices + + all_destinations = [ + None, destination_mirrored, destination_different, destination_str, + destination_list + ] + + # test reduce() + for destinations in all_destinations: + self._assert_values_equal( + cross_tower_ops.reduce( + vs.VariableAggregation.MEAN, + per_device, + destinations=destinations), + _fake_mirrored(mean, destinations or per_device)) + self._assert_values_equal( + cross_tower_ops.reduce( + vs.VariableAggregation.MEAN, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2, destinations or per_device)) + self._assert_values_equal( + cross_tower_ops.reduce( + vs.VariableAggregation.SUM, per_device, + destinations=destinations), + _fake_mirrored(mean * len(devices), destinations or per_device)) + self._assert_values_equal( + cross_tower_ops.reduce( + vs.VariableAggregation.SUM, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2 * len(devices), destinations or per_device)) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_values_equal( + cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN, + [(per_device, d1), (per_device_2, d2)]), + [ + _fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2) + ]) + self._assert_values_equal( + cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM, + [(per_device, d1), (per_device_2, d2)]), + [ + _fake_mirrored(mean * len(devices), d1 or per_device), + _fake_mirrored(mean_2 * len(devices), d2 or per_device_2) + ]) + + # test broadcast() + for destinations in all_destinations: + if destinations is None: + continue + else: + self._assert_values_equal( + cross_tower_ops.broadcast(constant_op.constant(1.), destinations), + _fake_mirrored(1., destinations)) + + +class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): # TODO(yuefengz): decouple the num_gpus check from distribution in # combinations module so that we can pass in devices instead of a distribution # strategy. @@ -121,100 +226,154 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): @combinations.generate(reduction_to_one_combinations + allreduce_combinations) def testReductionAndBroadcast(self, cross_tower_ops, distribution): - devices = distribution.worker_devices - - values = [constant_op.constant(float(d)) for d in range(len(devices))] - per_device = _make_per_device(values, devices) - mean = (len(devices) - 1.) / 2. - - values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] - per_device_2 = _make_per_device(values_2, devices) - mean_2 = mean + 1. - - destination_mirrored = _fake_mirrored(1., devices) - destination_different = _fake_mirrored(1., _cpu_device) - destination_str = _cpu_device - destination_list = devices - - all_destinations = [ - None, destination_mirrored, destination_different, destination_str, - destination_list - ] - - # test reduce() - for destinations in all_destinations: - self._assert_value_equal( - cross_tower_ops.reduce("mean", per_device, destinations=destinations), - _fake_mirrored(mean, destinations or per_device)) - self._assert_value_equal( - cross_tower_ops.reduce( - "mean", per_device_2, destinations=destinations), - _fake_mirrored(mean_2, destinations or per_device)) - self._assert_value_equal( - cross_tower_ops.reduce("sum", per_device, destinations=destinations), - _fake_mirrored(mean * len(devices), destinations or per_device)) - self._assert_value_equal( - cross_tower_ops.reduce( - "sum", per_device_2, destinations=destinations), - _fake_mirrored(mean_2 * len(devices), destinations or per_device)) - - # test batch_reduce() - for d1, d2 in itertools.product(all_destinations, all_destinations): - self._assert_value_equal( - cross_tower_ops.batch_reduce( - "mean", [(per_device, d1), (per_device_2, d2)]), - [_fake_mirrored(mean, d1 or per_device), - _fake_mirrored(mean_2, d2 or per_device_2)]) - self._assert_value_equal( - cross_tower_ops.batch_reduce( - "sum", [(per_device, d1), (per_device_2, d2)]), - [_fake_mirrored(mean * len(devices), d1 or per_device), - _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)]) - - # test broadcast() - for destinations in all_destinations: - if destinations is None: - continue - else: - self._assert_value_equal( - cross_tower_ops.broadcast(constant_op.constant(1.), destinations), - _fake_mirrored(1., destinations)) + with distribution.scope(): + self._testReductionAndBroadcast(cross_tower_ops, distribution) def testChooseAlgorithm(self): device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "hierarchical_copy") - self.assertEqual(result.num_packs, 8) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) # if there are only 4 devices device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "nccl") - self.assertEqual(result.num_packs, 1) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) # if devices links contain each device itself device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "hierarchical_copy") - self.assertEqual(result.num_packs, 8) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) # if not dgx1-like links device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "nccl") - self.assertEqual(result.num_packs, 1) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testSimpleReduceWithIndexedSlices(self): + devices = ["/cpu:0", "/gpu:0"] + t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) + t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) + per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + result = cross_tower_ops_lib._simple_reduce( + per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices with and without duplicate indices. + total_with_dups = _make_indexed_slices( + [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0]) + total_without_dups = _make_indexed_slices( + [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0]) + self._assert_indexed_slices_equal(total_with_dups, result) + self._assert_indexed_slices_equal(total_without_dups, result) + + @combinations.generate( + combinations.combine( + cross_tower_ops_instance=[ + combinations.NamedObject( + "ReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "AllReduceCrossTowerOps", + cross_tower_ops_lib.AllReduceCrossTowerOps()) + ], + aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN], + batch_reduce=[True, False], + mode=["graph", "eager"], + required_gpus=1)) + def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation, + batch_reduce): + devices = ["/cpu:0", "/gpu:0"] + dense_shape = [5, 2] + t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) + t1 = _make_indexed_slices( + [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) + per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + + if batch_reduce: + result = cross_tower_ops_instance.batch_reduce(aggregation, + [(per_device, devices)]) + else: + result = cross_tower_ops_instance.reduce(aggregation, per_device, devices) + + total_indices_with_dups = [1, 1, 3] + total_indices_without_dups = [1, 3] + + if aggregation == vs.VariableAggregation.SUM: + total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] + total_values_without_dups = [[4., 6.], [5., 6.]] + else: + assert aggregation == vs.VariableAggregation.MEAN + total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] + total_values_without_dups = [[2., 3.], [2.5, 3.]] + + total_mirrored_with_dups = _make_mirrored_indexed_slices( + devices, total_values_with_dups, total_indices_with_dups, dense_shape) + total_mirrored_without_dups = _make_mirrored_indexed_slices( + devices, total_values_without_dups, total_indices_without_dups, + dense_shape) + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices, as well as when the duplicate indices are summed up. + if batch_reduce: + total_mirrored_with_dups = [total_mirrored_with_dups] + total_mirrored_without_dups = [total_mirrored_without_dups] + + self._assert_values_equal(total_mirrored_with_dups, result) + self._assert_values_equal(total_mirrored_without_dups, result) + + +class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, + CrossTowerOpsTestBase): + + worker_devices = [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + multi_worker_allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "MultiWorkerAllReduce", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReducePack", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReduceAggregation", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)), + combinations.NamedObject( + "MultiWorkerAllReduceMultipleSpecs", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, [("pscpu/pscpu", 2, 100), + ("xring", 2, -1)], 0, 0, 0)), + ], + distribution=[ + combinations.multi_worker_strategy_with_cpu, + combinations.multi_worker_strategy_with_one_gpu, + combinations.multi_worker_strategy_with_two_gpus + ], + mode=["graph"]) + + @combinations.generate(multi_worker_allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + with distribution.scope(): + self._testReductionAndBroadcast(cross_tower_ops, distribution) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index fc04e2195f6d305e0f7c642f24c355286f1a8cfa..2bb088e704c584598b863b1b836166af2a5bb12c 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -21,9 +21,12 @@ from __future__ import print_function import collections as pycoll from tensorflow.contrib import nccl +from tensorflow.contrib.all_reduce.python import all_reduce +from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops @@ -156,6 +159,148 @@ def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, return (grad, v), None +def group_device_names(devices, group_size): + """Group device names into groups of group_size. + + Args: + devices: a list of canonical device strings. + group_size: integer which is equal to or greater than 1. + + Returns: + list of lists of devices, where each inner list is group_size long, + and each device appears at least once in an inner list. If + len(devices) % group_size == 0 then each device will appear exactly once. + + Raises: + ValueError: if group_size > len(devices) + """ + num_devices = len(devices) + if group_size > num_devices: + raise ValueError( + 'only %d devices, but group_size=%d' % (num_devices, group_size)) + num_groups = ( + num_devices // group_size + (1 if (num_devices % group_size != 0) else 0)) + groups = [[] for i in range(num_groups)] + for i in range(num_groups * group_size): + groups[i % num_groups].append(devices[i % num_devices]) + return groups + + +def split_grads_by_size(threshold_size, device_grads): + """Break gradients into two sets according to tensor size. + + Args: + threshold_size: int size cutoff for small vs large tensor. + device_grads: List of lists of (gradient, variable) tuples. The outer + list is over devices. The inner list is over individual gradients. + + Returns: + small_grads: Subset of device_grads where shape is <= threshold_size + elements. + large_grads: Subset of device_grads where shape is > threshold_size + elements. + """ + small_grads = [] + large_grads = [] + for dl in device_grads: + small_dl = [] + large_dl = [] + for (g, v) in dl: + tensor_size = g.get_shape().num_elements() + if tensor_size <= threshold_size: + small_dl.append([g, v]) + else: + large_dl.append([g, v]) + if small_dl: + small_grads.append(small_dl) + if large_dl: + large_grads.append(large_dl) + return small_grads, large_grads + + +def sum_grad_and_var_all_reduce(grad_and_vars, + num_workers, + alg, + gpu_indices, + aux_devices=None, + num_shards=1): + """Apply all-reduce algorithm over specified gradient tensors.""" + with ops.name_scope('allreduce'): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) + scaled_grads = [g for g, _ in grad_and_vars] + if alg == 'nccl': + summed_grads = nccl.all_sum(scaled_grads) + elif alg == 'xring': + summed_grads = all_reduce.build_ring_all_reduce( + scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add) + elif alg == 'nccl/xring': + summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards, + math_ops.add) + elif alg == 'nccl/rechd': + summed_grads = all_reduce.build_nccl_then_recursive_hd( + scaled_grads, math_ops.add) + elif alg == 'nccl/pscpu': + summed_grads = all_reduce.build_nccl_then_shuffle( + scaled_grads, aux_devices, math_ops.add, math_ops.add_n) + elif alg == 'pscpu/pscpu': + second_gather_devices = aux_devices[:num_shards] + summed_grads = all_reduce.build_shuffle_then_shuffle( + scaled_grads, aux_devices, second_gather_devices, math_ops.add_n) + elif alg in ['pscpu', 'psgpu']: + summed_grads = all_reduce.build_shuffle_all_reduce( + scaled_grads, aux_devices, math_ops.add_n) + else: + raise ValueError('unsupported all_reduce alg: ', alg) + + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result + + +def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, + num_shards, gpu_indices): + """Apply all-reduce algorithm over specified gradient tensors. + + Args: + dev_prefixes: list of prefix strings to use to generate PS device names. + tower_grads: the gradients to reduce. + num_workers: number of worker processes across entire job. + alg: the all-reduce algorithm to apply. + num_shards: alg-specific sharding factor. + gpu_indices: indices of local GPUs in order usable for ring-reduce. + + Returns: + list of reduced tensors + """ + alg_contains_shuffle = any([n in alg for n in ['pscpu', 'psgpu']]) + is_hierarchical = '/' in alg + if 'pscpu' in alg: + aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes] + elif 'psgpu' in alg: + aux_devices = [ + prefix + '/gpu:%d' % i + for i in range(len(gpu_indices)) + for prefix in dev_prefixes + ] + else: + aux_devices = ['/job:localhost/cpu:0'] + # Auxiliary devices for hierarchical all-reduces. + aux_device_groups = group_device_names( + aux_devices, num_shards if alg_contains_shuffle else 1) + group_index = 0 + reduced_gv_list = [] + for grad_and_vars in zip(*tower_grads): + reduced_gv_list.append( + sum_grad_and_var_all_reduce( + grad_and_vars, num_workers, alg, gpu_indices, aux_devices + if is_hierarchical else aux_device_groups[group_index], num_shards)) + group_index = (group_index + 1) % len(aux_device_groups) + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return new_tower_grads + + def extract_ranges(index_list, range_size_limit=32): """Extract consecutive ranges and singles from index_list. @@ -328,7 +473,7 @@ def unpack_small_tensors(tower_grads, packing): for dev_idx, gv_list in enumerate(tower_grads): gv_list = list(gv_list) new_gv_list = gv_list[num_packed:] - for i in xrange(0, num_packed): + for i in range(num_packed): k = '%d:%d' % (dev_idx, i) gpt = packing[k] gv = unpack_grad_tuple(gv_list[i], gpt) @@ -337,3 +482,46 @@ def unpack_small_tensors(tower_grads, packing): new_gv_list.insert(idx, gv[gi]) new_tower_grads.append(new_gv_list) return new_tower_grads + + +def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): + """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" + if any(isinstance(v, ops.IndexedSlices) for v in values): + return gradients_impl._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access + else: + return accumulation_fn(values) + + +def divide_by_n_tensors_or_indexed_slices(value, n): + if isinstance(value, ops.IndexedSlices): + value = gradients_impl._HandleNestedIndexedSlices(value) # pylint: disable=protected-access + return ops.IndexedSlices( + value.values / n, value.indices, value.dense_shape) + else: + return value / n + + +def copy_tensor_or_indexed_slices_to_device(value, device): + with ops.device(device): + if isinstance(value, ops.IndexedSlices): + copied_values = array_ops.identity(value.values) + copied_indices = array_ops.identity(value.indices) + copied_shape = array_ops.identity(value.dense_shape) + result = ops.IndexedSlices(copied_values, copied_indices, copied_shape) + else: + result = array_ops.identity(value) + return result + + +def contains_indexed_slices(value): + """Check whether the value is `IndexedSlices` or contains `IndexedSlices`.""" + if isinstance(value, ops.IndexedSlices): + return True + elif isinstance(value, (list, tuple)) and value: + return any(contains_indexed_slices(v) for v in value) + elif isinstance(value, value_lib.DistributedValues): + return contains_indexed_slices(list(value._index.values())) # pylint: disable=protected-access + elif isinstance(value, value_lib.MapOutput): + return contains_indexed_slices(value.get()) + else: + return False diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d25964fa41adc7b1c9164a4ffe49c4c5532f76ac --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py @@ -0,0 +1,152 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cross_tower_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.training import device_util + + +class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): + + def _assert_values_equal(self, left, right): + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + @test_util.run_in_graph_and_eager_modes + def testAggregateTensors(self): + t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes + def testAggregateIndexedSlices(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes + def testDivideTensor(self): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes + def testDivideIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes + def testIsIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) + + @test_util.run_in_graph_and_eager_modes + def testContainsIndexedSlices_List(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) + + @test_util.run_in_graph_and_eager_modes + def testContainsIndexedSlices_Tuple(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) + + @test_util.run_in_graph_and_eager_modes + def testContainsIndexedSlices_PerDevice(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + + @test_util.run_in_graph_and_eager_modes + def testContainsIndexedSlices_PerDeviceMapOutput(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_device = value_lib.PerDevice({ + "/gpu:0": value_lib.MapOutput([t0]), + "/cpu:0": value_lib.MapOutput([t1])}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyTensor(self): + with ops.device("/cpu:0"): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + destination = "/gpu:0" + result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyIndexedSlices(self): + with ops.device("/cpu:0"): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + destination = "/gpu:0" + result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75ecd90dcffa7a786b78238ef453c4c8e4346afa --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -0,0 +1,148 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Keras Sequential and Functional models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import keras as keras_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import test_util +from tensorflow.python.keras import testing_utils +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import rmsprop + +_RANDOM_SEED = 1337 +_TRAIN_SIZE = 200 +_INPUT_SIZE = (10,) +_NUM_CLASS = 2 + + +def simple_sequential_model(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) + model.add(keras.layers.Dropout(0.1)) + model.add(keras.layers.Dense(_NUM_CLASS, activation='softmax')) + return model + + +def simple_functional_model(): + a = keras.layers.Input(shape=_INPUT_SIZE) + b = keras.layers.Dense(16, activation='relu')(a) + b = keras.layers.Dropout(0.1)(b) + b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b) + model = keras.models.Model(inputs=[a], outputs=[b]) + return model + + +def get_ds_train_input_fn(): + np.random.seed(_RANDOM_SEED) + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_train = keras.utils.to_categorical(y_train) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset = dataset.batch(32) + return dataset + + +def get_ds_test_input_fn(): + np.random.seed(_RANDOM_SEED) + _, (x_test, y_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_test = keras.utils.to_categorical(y_test) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_test, y_test)) + dataset = dataset.batch(32) + return dataset + + +class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): + + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), + 'keras_mirrored_strategy_test') + gfile.MakeDirs(self._base_dir) + self._config = run_config_lib.RunConfig( + tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) + + def tearDown(self): + writer_cache.FileWriterCache.clear() + if os.path.isdir(self._base_dir): + gfile.DeleteRecursively(self._base_dir) + + def test_train_functional_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_functional_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + def test_train_sequential_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6bf143098c1bba64d47efce1bfface7682683d --- /dev/null +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -0,0 +1,438 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for V1 metrics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import variables + + +def _labeled_dataset_fn(): + # First four batches of x: labels, predictions -> (labels == predictions) + # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False + # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False + # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False + # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True + return dataset_ops.Dataset.range(1000).map( + lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4) + + +def _boolean_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # T, T -> TP; F, T -> FP; T, F -> FN + # F, F -> TN; T, T -> TP; F, T -> FP + # T, F -> FN; F, F -> TN; T, T -> TP + # F, T -> FP; T, F -> FN; F, F -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [True, True, False, False]}).repeat().batch(3) + + +def _threshold_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN + # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP + # True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP + # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3) + + +def _regression_dataset_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [1., .5, 1., 0.], + "predictions": [1., .75, .25, 0.]}).repeat() + + +def all_combinations(): + return combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + mode=["graph"]) + + +# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k, +# metrics.precision_at_k +class MetricsV1Test(test.TestCase, parameterized.TestCase): + + def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): + with ops.Graph().as_default(), distribution.scope(): + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + value, update = distribution.call_for_each_tower( + metric_fn, iterator.get_next()) + update = distribution.group(update) + self.evaluate(variables.local_variables_initializer()) + # TODO(josh11b): Once we switch to using a global batch size for input, + # replace "distribution.num_towers" with "1". + batches_per_update = distribution.num_towers + + # Update variables using the first `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value), + 0.001, msg="After first update") + + # Update variables using the second `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(2 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After second update") + + if batches_per_update == 1: # Consume 4 input batches + self.evaluate(update) + self.assertAllClose(expected_fn(3 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After third update") + self.evaluate(update) + self.assertAllClose(expected_fn(4 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After fourth update") + + @combinations.generate(all_combinations()) + def testMean(self, distribution): + def _dataset_fn(): + return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4) + + def _expected_fn(num_batches): + # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc. + return num_batches * 2 - 0.5 + + self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn) + + @combinations.generate(all_combinations()) + def testAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.accuracy(labels, predictions) + + def _expected_fn(num_batches): + return [3./4, 3./8, 3./12, 4./16][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanPerClassAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_per_class_accuracy( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1., 1., 1., 0., 0.]), + mean([0.5, 0.5, 0.5, 0., 0.]), + mean([1./3, 1./3, 0.5, 0., 0.]), + mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanIOU(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_iou( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch + mean([1./4, 1./4, 1./3, 0., 0.]), + mean([1./6, 1./6, 1./5, 0., 0.]), + mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanTensor(self, distribution): + def _dataset_fn(): + dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) + # Want to produce a fixed, known shape, so drop remainder when batching. + dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + return dataset + + def _expected_fn(num_batches): + # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2 + # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1 + # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches + # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1 + first = 2. * num_batches - 2. + return [first, first + 1., first + 2., first + 3.] + + self._test_metric( + distribution, _dataset_fn, metrics.mean_tensor, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCROC(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.5, 7./9, 0.8, 0.75][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCPR(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="PR", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[0.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 3., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [3.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecision(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 0.5, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecisionAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecall(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecallAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1./32, 0.208333, 0.15625][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRootMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.root_mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 0.176777, 0.456435, 0.395285][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSensitivityAtSpecificity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.sensitivity_at_specificity(labels, predictions, 0.8) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSpecificityAtSensitivity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.specificity_at_sensitivity(labels, predictions, 0.95) + + def _expected_fn(num_batches): + return [0., 1./3, 0.5, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 5c056a7c73def2f1fb4bbe0df4d3f82fdabda3df..aeeb9553e6044a0a928936597400e582e0329b95 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -56,6 +56,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): is_tpu=[True])) def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, is_tpu): + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + with distribution.scope(): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) @@ -84,8 +88,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) if is_tpu: with self.test_session() as sess: @@ -111,6 +115,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): is_tpu=[True])) def testOptimizerInsideModelFn(self, distribution, optimizer_fn, is_tpu): + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + created_variables = [] trainable_variables = [] @@ -186,7 +194,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # towers will re-execute UPDATE_OPS of previous towers. update_ops_in_cross_tower_mode=[True])) + combinations.combine( - distribution=[combinations.tpu_strategy_single_iteration], + distribution=[combinations.tpu_strategy], optimizer_fn=[ combinations.gradient_descent_optimizer_v1_fn, combinations.gradient_descent_optimizer_v2_fn @@ -198,6 +206,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm, is_tpu, update_ops_in_cross_tower_mode): """Verifies that moving mean updates are reduced across towers.""" + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + with distribution.scope(): num_towers = len(distribution.worker_devices) model_fn, dataset_fn, batchnorm = batchnorm_example( @@ -242,7 +254,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean)) + moving_means = self.evaluate(batchnorm.moving_mean) # We make sure that the moving_mean is updated as if the sample mean is # calculated over all towers. @@ -279,12 +291,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): mode=["graph"], use_callable_loss=[True, False]) + combinations.combine(mode=["eager"], use_callable_loss=[True])) + combinations.combine( - distribution=[combinations.tpu_strategy_single_iteration], + distribution=[combinations.tpu_strategy], is_tpu=[True], mode=["graph"], use_callable_loss=[True, False]))) def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, use_callable_loss, is_tpu): + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + with distribution.scope(): all_vars = [] @@ -329,7 +345,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): v = all_vars[0] self.assertTrue(all([v is vi for vi in all_vars[1:]])) - weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) + weight = numpy.squeeze(self.evaluate(v)) # Our model is: # predict = x * w # loss = (predict - y)^2 diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 89f2c431fece63269928fec6aa6d23b5a79ba0b9..dcbc6b0878b89cbb5b9779de315429e6f9478d15 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib import threading import six @@ -39,6 +40,16 @@ from tensorflow.python.training import distribute as distribute_lib # TODO(josh11b): Replace asserts in this file with if ...: raise ... +@contextlib.contextmanager +def _enter_graph(g): + if context.executing_eagerly(): + with g.as_default(), context.eager_mode(): + yield + else: + with g.as_default(): + yield + + def _cpu_device(device): cpu_device = tf_device.DeviceSpec.from_string(device) cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) @@ -73,9 +84,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): assert len(set(devices)) == len(devices), ( "No duplicates allowed in `devices` argument.") # TODO(josh11b): Require at least 2 devices? - self._devices = devices - self._canonical_device_set = set( - [device_util.canonicalize(d) for d in devices]) + self._devices = [device_util.resolve(d) for d in devices] + self._canonical_device_set = set(self._devices) self._device_index = values.PerDevice( dict((d, i) for i, d in enumerate(devices))) self._cross_tower_ops = cross_tower_ops @@ -94,9 +104,39 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): colocate_with = kwargs.pop("colocate_with", None) devices = self._get_devices_from(colocate_with) - tower_local = kwargs.pop("tower_local_reduce_method", None) - if tower_local is not None: + # Get synchronization value + synchronization = kwargs.get( + "synchronization", variable_scope.VariableSynchronization.ON_WRITE) + if synchronization == variable_scope.VariableSynchronization.NONE: + raise ValueError("`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please" + " change the `synchronization` for variable: " + + kwargs["name"]) + elif synchronization == variable_scope.VariableSynchronization.ON_READ: + # Variables that are to be synced on read are tower local. + is_tower_local = True kwargs["trainable"] = False + elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or + synchronization == variable_scope.VariableSynchronization.AUTO): + # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. + is_tower_local = False + else: + raise ValueError("Invalid variable synchronization mode: " + + synchronization + " for variable: " + kwargs["name"]) + + # Get aggregation value + aggregation = kwargs.pop("aggregation", + variable_scope.VariableAggregation.NONE) + if aggregation not in [ + variable_scope.VariableAggregation.NONE, + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ]: + raise ValueError("Invalid variable aggregation mode: " + aggregation + + " for variable: " + kwargs["name"]) + + # Ignore user-specified caching device, not needed for mirrored variables. + kwargs.pop("caching_device", None) # TODO(josh11b,apassos): It would be better if variable initialization # was never recorded on the tape instead of having to do this manually @@ -108,7 +148,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if i > 0: # Give replicas meaningful distinct names: var0name = index[devices[0]].name.split(":")[0] - kwargs["name"] = "%s/replica_%d" % (var0name, i) + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: if context.executing_eagerly(): kwargs["initial_value"] = array_ops.identity( @@ -123,11 +166,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): assert not isinstance(v, values.DistributedVariable) index[d] = v - if tower_local is None: - result = values.MirroredVariable(index, index[devices[0]]) + if is_tower_local: + result = values.TowerLocalVariable(index, index[devices[0]], + aggregation) else: - result = values.TowerLocalVariable( - index, index[devices[0]], tower_local) + result = values.MirroredVariable(index, index[devices[0]], aggregation) if not context.executing_eagerly(): g = ops.get_default_graph() @@ -248,8 +291,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): {t.device: t.merge_args for t in threads}) merge_kwargs = values.regroup( {t.device: t.merge_kwargs for t in threads}) - merge_result = threads[0].merge_fn( - self, *merge_args, **merge_kwargs) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. + mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn( + self, *merge_args, **merge_kwargs) for t in threads: t.merge_result = values.select_device(t.device, merge_result) finally: @@ -262,8 +312,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def map(self, map_over, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. index = {} - i = 0 - for m in map_over: + for i, m in enumerate(map_over): d = self._devices[i % len(self._devices)] with ops.device(d): l = index.get(d, []) @@ -286,27 +335,46 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) return self._cross_tower_ops - def _reduce(self, method_string, value, destinations): - if len(self._devices) == 1 and not isinstance(value, values.PerDevice): - value = values.PerDevice({self._devices[0]: value}) - assert isinstance(value, values.PerDevice) + def _reduce(self, aggregation, value, destinations): + assert not isinstance(value, values.Mirrored) + if not isinstance(value, values.PerDevice): + if value == 0: + return 0 + if aggregation == variable_scope.VariableAggregation.MEAN: + return self._broadcast(value, destinations) + + cross_tower_ops_lib.validate_destinations(destinations) + if len(self._devices) == 1: + if destinations: + # TODO(anjalisridhar): Moves these methods to a device utility file? + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + with ops.device(devices[0]): + return array_ops.identity(value) + else: + value_updates = {} + for d in devices: + with ops.device(d): + value_updates[d] = array_ops.identity(value) + return values.Mirrored(value_updates) + raise ValueError("A non PerDevice value cannot be reduced with the given " + "aggregation.") return self._get_cross_tower_ops().reduce( - method_string, value, destinations=destinations) + aggregation, value, destinations=destinations) - def _batch_reduce(self, method_string, value_destination_pairs): - return self._get_cross_tower_ops().batch_reduce(method_string, + def _batch_reduce(self, aggregation, value_destination_pairs): + return self._get_cross_tower_ops().batch_reduce(aggregation, value_destination_pairs) def _update(self, var, fn, *args, **kwargs): - # TODO(josh11b): Also support TowerLocalVariables here? If so, args and - # kwargs don't need to be mirrored. - assert isinstance(var, values.MirroredVariable) # TODO(josh11b): In eager mode, use one thread per device. + assert isinstance(var, values.DistributedVariable) updates = {} for d, v in var._index.items(): # pylint: disable=protected-access name = "update_%d" % self._device_index.get(d) with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + # If args and kwargs are not mirrored, the value is returned as is. updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) @@ -323,32 +391,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): **values.select_device_mirrored(d, kwargs)) return values.regroup(updates, values.Mirrored) - def _fetch(self, val, destination, fn): - """Return a copy of `val` or `fn(val)` on `destination`.""" - if isinstance(val, values.TowerLocalVariable): - val = self.reduce(val.reduce_method, val, destinations=destination) - with ops.device(destination): - return fn(self.unwrap(val)[0]) - - assert isinstance(val, values.Mirrored), ( - "val = %s (type %s)" % (val, val.__class__.__name__)) - if val.on_device(destination): - with ops.device(destination): - # Use an identity here to make sure we are returning a tensor - # instead of e.g. a variable object. - return array_ops.identity(fn(val.get(destination))) - device = None - for d in self._devices: - if val.on_device(d): - device = d - break - assert device is not None, ( - "Could not find destination %s in list of devices %s." % - (destination, val.devices)) - with ops.device(device): - v = fn(val.get(device)) - with ops.device(destination): - return array_ops.identity(v) + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + if isinstance(tower_local_var, values.TowerLocalVariable): + return tower_local_var._get_cross_tower() # pylint: disable=protected-access + assert isinstance(tower_local_var, values.Mirrored) + return array_ops.identity(tower_local_var.get()) def _unwrap(self, val): if isinstance(val, values.DistributedValues): @@ -389,7 +437,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # pylint: disable=protected-access return list(colocate_with._index.keys()) elif isinstance(colocate_with, six.string_types): - return [colocate_with] + return [device_util.resolve(colocate_with)] + elif isinstance(colocate_with, list): + return [device_util.resolve(d) for d in colocate_with] else: return colocate_with @@ -416,6 +466,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self.merge_args = None self.merge_kwargs = None self.merge_result = None + self.captured_name_scope = None # We use a thread.Event for the main thread to signal when this # thread should start running (`should_run`), and another for # this thread to transfer control back to the main thread @@ -439,13 +490,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._variable_creator_stack = self.graph._variable_creator_stack[:] self._captured_var_scope = variable_scope.get_variable_scope() # Adding a "/" at end lets us re-enter this scope later. - self._captured_name_scope = self.graph.get_name_scope() - if self._captured_name_scope: - self._captured_name_scope += "/" + self._name_scope = self.graph.get_name_scope() + if self._name_scope: + self._name_scope += "/" if self.tower_id > 0: - if not self._captured_name_scope: - self._captured_name_scope = "" - self._captured_name_scope += "tower_%d/" % self.tower_id + if not self._name_scope: + self._name_scope = "" + self._name_scope += "tower_%d/" % self.tower_id def run(self): # pylint: disable=protected-access @@ -458,10 +509,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with self.coord.stop_on_exception(), \ context.context()._mode(self.context_mode), \ context.context().device_policy(self.context_device_policy), \ - self.graph.as_default(), \ + _enter_graph(self.graph), \ MirroredTowerContext(self.distribution, self.tower_id), \ ops.device(self.device), \ - ops.name_scope(self._captured_name_scope), \ + ops.name_scope(self._name_scope), \ variable_scope.variable_scope( self._captured_var_scope, reuse=self.tower_id > 0), \ variable_scope.variable_creator_scope(self.variable_creator_fn): @@ -487,6 +538,10 @@ class MirroredTowerContext(distribute_lib.TowerContext): t.merge_fn = fn t.merge_args = args t.merge_kwargs = kwargs + t.captured_name_scope = t.graph.get_name_scope() + # Adding a "/" at end lets us re-enter this scope later. + if t.captured_name_scope: + t.captured_name_scope += "/" t.has_paused.set() t.should_run.wait() t.should_run.clear() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 3f9a02b249dde9a66056ed8952b664bbc3f74ead..b597bce035493891c3f492bca04abda60c6e8e22 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -32,12 +32,14 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core +from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib + GPU_TEST = "test_gpu" in sys.argv[0] @@ -83,13 +85,13 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): self.skipTest("Not GPU test") self.assertEqual(2, self._get_distribution_strategy().num_towers) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallAndMergeExceptions(self): if not GPU_TEST: self.skipTest("Not GPU test") self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRunRegroupError(self): def run_fn(device_id): @@ -101,7 +103,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): with dist.scope(), self.assertRaises(AssertionError): dist.call_for_each_tower(run_fn, dist.worker_device_index) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceToCpu(self): if not GPU_TEST: self.skipTest("Not GPU test") @@ -112,12 +114,35 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = self._get_distribution_strategy() with dist.scope(): result = dist.call_for_each_tower(run_fn, dist.worker_device_index) - reduced = dist.reduce("sum", result, destinations="/device:CPU:0") + reduced = dist.reduce( + variable_scope.VariableAggregation.SUM, + result, + destinations="/device:CPU:0") unwrapped = dist.unwrap(reduced) self.assertEqual(1, len(unwrapped)) expected = sum(range(len(dist.worker_devices))) self.assertEqual(expected, self.evaluate(unwrapped[0])) + @test_util.run_in_graph_and_eager_modes() + def testReduceToMultipleDestinations(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + + devices = ["/device:GPU:0"] + if GPU_TEST: + self.assertGreater(context.num_gpus(), 0) + print(self.id().split(".")[-1], "devices:", ", ".join(devices)) + + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + reduced = dist.reduce( + variable_scope.VariableAggregation.SUM, + 1.0, + destinations=["/device:CPU:0", "/device:GPU:0"]) + unwrapped = dist.unwrap(reduced) + self.assertEqual(2, len(unwrapped)) + self.assertEqual(1.0, self.evaluate(unwrapped[0])) + class MirroredStrategyVariableCreationTest(test.TestCase): @@ -263,19 +288,69 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertIsInstance(bias, values.MirroredVariable) self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithVariableAndVariableScope(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v0 = variable_scope.variable(1.0, name="var0", aggregation=None) + with variable_scope.variable_scope("common"): + v1 = variable_scope.variable(1.0, name="var1") + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + v2 = variable_scope.variable( + 1.0, + name="var2", + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v3 = variable_scope.variable( + 1.0, + name="var3", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) + + return v0, v1, v2, v3 + + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + v = variable_scope.variable(1.0, name="var-main0") + self.assertEquals("var-main0:0", v.name) + + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(4, len(result)) + v0, v1, v2, v3 = result + self.assertIsInstance(v0, values.MirroredVariable) + self.assertEquals("var0:0", v0.name) + self.assertIsInstance(v1, values.MirroredVariable) + self.assertEquals("common/var1:0", v1.name) + self.assertIsInstance(v2, values.TowerLocalVariable) + self.assertEquals("common/var2:0", v2.name) + self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation) + self.assertIsInstance(v3, values.MirroredVariable) + self.assertEquals("common/var3:0", v3.name) + self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation) + @test_util.run_in_graph_and_eager_modes(config=config) def testWithGetVariableAndVariableScope(self): self._skip_eager_if_gpus_less_than(1) def model_fn(): - v0 = variable_scope.get_variable("var-thread0", [1]) + v0 = variable_scope.get_variable("var0", [1]) with variable_scope.variable_scope("common"): - v1 = variable_scope.get_variable("var-thread1", [1]) + v1 = variable_scope.get_variable("var1", [1]) # This will pause the current thread, and execute the other thread. distribute_lib.get_tower_context().merge_call(lambda _: _) - v2 = variable_scope.get_variable("var-thread2", [1]) + v2 = variable_scope.get_variable( + "var2", [1], + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v3 = variable_scope.get_variable( + "var3", [1], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) - return v0, v1, v2 + return v0, v1, v2, v3 devices = ["/device:CPU:0", "/device:GPU:0"] dist = mirrored_strategy.MirroredStrategy(devices) @@ -285,14 +360,89 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("main/var-main0:0", v.name) result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertEquals(3, len(result)) - v0, v1, v2 = result + self.assertEquals(4, len(result)) + v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("main/var-thread0:0", v0.name) + self.assertEquals("main/var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("main/common/var-thread1:0", v1.name) - self.assertIsInstance(v2, values.MirroredVariable) - self.assertEquals("main/common/var-thread2:0", v2.name) + self.assertEquals("main/common/var1:0", v1.name) + self.assertIsInstance(v2, values.TowerLocalVariable) + self.assertEquals("main/common/var2:0", v2.name) + self.assertEquals(variable_scope.VariableAggregation.SUM, + v2.aggregation) + self.assertIsInstance(v3, values.MirroredVariable) + self.assertEquals("main/common/var3:0", v3.name) + self.assertEquals(variable_scope.VariableAggregation.MEAN, + v3.aggregation) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testNoneSynchronizationWithGetVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please change " + "the `synchronization` for variable: v"): + variable_scope.get_variable( + "v", [1], + synchronization=variable_scope.VariableSynchronization.NONE) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testNoneSynchronizationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please change " + "the `synchronization` for variable: v"): + variable_scope.variable( + 1.0, + name="v", + synchronization=variable_scope.VariableSynchronization.NONE) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidSynchronizationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable synchronization mode: Invalid for " + "variable: v"): + variable_scope.variable(1.0, name="v", synchronization="Invalid") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidAggregationWithGetVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable aggregation mode: invalid for " + "variable: v"): + variable_scope.get_variable( + "v", [1], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation="invalid") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidAggregationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable aggregation mode: invalid for " + "variable: v"): + variable_scope.variable( + 1.0, + name="v", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation="invalid") @test_util.run_in_graph_and_eager_modes(config=config) def testThreeDevices(self): @@ -337,12 +487,16 @@ class MirroredStrategyVariableCreationTest(test.TestCase): all_v_sum = {} all_v_mean = {} + components_sum = {} + components_mean = {} def model_fn(device_id): tower_context = distribute_lib.get_tower_context() - with tower_context.tower_local_var_scope("sum"): + with tower_context.tower_local_var_scope( + variable_scope.VariableAggregation.SUM): v_sum = variable_scope.variable(1.0) - with tower_context.tower_local_var_scope("mean"): + with tower_context.tower_local_var_scope( + variable_scope.VariableAggregation.MEAN): v_mean = variable_scope.variable(4.0) self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) self.assertTrue(isinstance(v_mean, values.TowerLocalVariable)) @@ -350,21 +504,33 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v_mean.assign(6.0 * device_id)] all_v_sum[device_id] = v_sum all_v_mean[device_id] = v_mean - return updates, v_sum, v_mean + c_sum = v_sum.get() + c_mean = v_mean.get() + components_sum[device_id] = c_sum + components_mean[device_id] = c_mean + self.assertIsNot(v_sum, c_sum) + self.assertIsNot(v_mean, c_mean) + return updates, v_sum, v_mean, c_sum, c_mean dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): # Create "sum" and "mean" versions of TowerLocalVariables. - ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower( - model_fn, dist.worker_device_index, run_concurrently=False) + ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( + dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False)) # Should see the same wrapping instance in all towers. self.assertIs(all_v_sum[0], ret_v_sum) self.assertIs(all_v_mean[0], ret_v_mean) - for i in range(1, dist.num_towers): - self.assertIs(all_v_sum[0], all_v_sum[1]) - self.assertIs(all_v_mean[0], all_v_mean[1]) + self.assertIs(all_v_sum[0], all_v_sum[1]) + self.assertIs(all_v_mean[0], all_v_mean[1]) + + # Regroup should recover the same wrapper. + self.assertIs(ret_v_sum, regrouped_sum) + self.assertIs(ret_v_mean, regrouped_mean) + self.assertIsNot(components_sum[0], components_sum[1]) + self.assertIsNot(components_mean[0], components_mean[1]) # Apply updates self.evaluate(variables.global_variables_initializer()) @@ -385,14 +551,13 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Without get(device), should return the value you get by # applying the reduction across all towers (whether you use - # fetch(), get(), or nothing). - self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) - self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + # read_var(), get(), or nothing). + self.assertEqual(expected_sum, self.evaluate(dist.read_var(ret_v_sum))) + self.assertEqual(expected_mean, self.evaluate(dist.read_var(ret_v_mean))) self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) - if not context.executing_eagerly(): - self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) - self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. @@ -438,6 +603,74 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("foo/" + name + ":0", v0.name) self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + # variable_scope.variable() respects name scopes when creating + # variables. On the other hand variable_scope.get_variable() ignores name + # scopes when creating variables. We test both methods of creating variables + # to make sure that we have the same variable names in both cases. + def testNameScopeWithVariable(self): + def in_cross_tower(_): + c = variable_scope.variable(1.0, name="c") + return c + + def model_fn(): + b = variable_scope.variable(1.0, name="b") + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.variable(1.0, name="a") + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("main/a:0", a0.name) + self.assertEquals("main/a/replica_1:0", a1.name) + self.assertEquals("main/b:0", b0.name) + self.assertEquals("main/b/replica_1:0", b1.name) + self.assertEquals("main/foo/c:0", c0.name) + self.assertEquals("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self): + def in_cross_tower(_): + c = variable_scope.get_variable("c", [1]) + return c + + def model_fn(): + b = variable_scope.get_variable("b", [1]) + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.get_variable("a", [1]) + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("a:0", a0.name) + self.assertEquals("a/replica_1:0", a1.name) + self.assertEquals("b:0", b0.name) + self.assertEquals("b/replica_1:0", b1.name) + self.assertEquals("c:0", c0.name) + self.assertEquals("c/replica_1:0", c1.name) + def testDynamicRnnVariables(self): def model_fn(): inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) @@ -462,6 +695,232 @@ class MirroredStrategyVariableCreationTest(test.TestCase): _, v1 = dist.unwrap(v) self.assertStartsWith(v1.name, "tower_1/") + @test_util.run_in_graph_and_eager_modes(config=config) + def testTowerLocalVariableUpdate(self): + with context.graph_mode(): + + def model_fn(): + tower_context = distribute_lib.get_tower_context() + with tower_context.tower_local_var_scope( + variable_scope.VariableAggregation.SUM): + v_sum = variable_scope.variable(1.0) + self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + return v_sum + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]) + + def update(var, value): + return var.assign(value) + + with dist.scope(): + ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False) + update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0)) + + # Initialize variables. + self.evaluate(variables.global_variables_initializer()) + # Assert that the aggregated value of the tower local vars is the sum of + # the individual values before running the update ops. + self.assertEquals(1.0, self.evaluate( + ret_v_sum.get(dist._devices[0]).read_value())) + self.assertEquals(2.0, self.evaluate(ret_v_sum)) + + # Apply updates. + self.evaluate(update_ops) + # Assert that the aggregated value of the tower local vars is the sum of + # the individual values after running the update ops. + self.assertEquals(5.0, self.evaluate( + ret_v_sum.get(dist._devices[0]).read_value())) + self.assertEquals(10.0, self.evaluate(ret_v_sum)) + + +class MirroredVariableUpdateTest(test.TestCase): + # The following tests check assign, assign_add and assign_sub on Mirrored + # variables in tower and cross tower context. + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Enough GPUs not available for this test in eager mode.") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContextWithoutAggregationType(self): + # Test that we always have an aggregation type set on the mirrored variable + # if we assign to it in tower mode. + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + v = variable_scope.variable(1.0, name="foo") + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + return mirrored_var.assign(5.0) + + with self.assertRaisesRegexp( + ValueError, "You must specify an aggregation method to update a " + "MirroredVariable in Tower Context."): + self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContextWithSum(self): + # Test that we don't reduce a non-per-device value with the "sum" + # aggregation type. + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + v = variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + return mirrored_var.assign(5.0) + + with self.assertRaisesRegexp( + ValueError, "A non PerDevice value cannot be reduced with the given " + "aggregation."): + self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarCrossTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(1.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + mirrored_var_result = self.evaluate(mirrored_var.assign(6.0)) + self.assertEquals(6.0, mirrored_var_result) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + value = math_ops.cast(distribute_lib.get_tower_context().tower_id, + mirrored_var.dtype) + return mirrored_var.assign(value) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(0.5, self.evaluate(mirrored_var)) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignAddMirroredVarCrossTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(1.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0)) + self.assertEquals(7.0, mirrored_var_result) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignAddMirroredVarTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + value = math_ops.cast(distribute_lib.get_tower_context().tower_id, + mirrored_var.dtype) + return mirrored_var.assign_add(value) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(1.5, self.evaluate(mirrored_var)) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignSubMirroredVarCrossTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(5.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) + self.assertEquals(3.0, mirrored_var_result) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignSubMirroredVarTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + + def model_fn(): + value = math_ops.cast(distribute_lib.get_tower_context().tower_id, + mirrored_var.dtype) + return mirrored_var.assign_sub(value) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(4.5, self.evaluate(mirrored_var)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index 61cbe6df813bb28bf8baa83d9e28ffafc4f0cbb8..a066adf1246ecd9ab8bd6a85be1f1e9be2c35b17 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -47,7 +47,7 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): def testTowerId(self): self._test_tower_id(self._get_distribution_strategy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallAndMergeExceptions(self): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py index 8277e1e7919e86ef616b31d0986589dcc9c49bbd..2892ce439494320a115b8eae0025a132841c4a8f 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import monitor as monitor_lib from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import ops @@ -51,11 +52,11 @@ class MonitorTest(test.TestCase, parameterized.TestCase): self.assertEqual(1, len(layer.trainable_variables)) mirrored_weight_variable = layer.trainable_variables[0] - start_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + start_error = self.evaluate(mirrored_weight_variable) start_error = abs(numpy.array(start_error) - 1) monitor.run_steps(9) - end_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + end_error = self.evaluate(mirrored_weight_variable) end_error = abs(numpy.array(end_error) - 1) self.assertGreaterEqual(start_error, end_error) @@ -65,7 +66,7 @@ class MonitorTest(test.TestCase, parameterized.TestCase): step_function, _ = single_loss_example( lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) - with self.test_session() as sess: + with session.Session() as sess, context.eager_mode(): with self.assertRaisesRegexp(ValueError, "Should not provide"): _ = monitor_lib.Monitor(step_function, sess) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py index a552b370ebf359464afcaf3211119e73434e0dfb..0f21a427320510635279f80c11711e81715ec37c 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -121,7 +121,7 @@ class MultiWorkerMirroredStrategy(MirroredStrategy): worker: [device_util.canonicalize(worker, '/device:CPU:0')] for worker in self._workers } - self._devices = nest.flatten(self._worker_device_map.values()) + self._devices = nest.flatten(self._worker_device_map) super(MultiWorkerMirroredStrategy, self).__init__( devices=self._devices, prefetch_on_device=prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 09b6d4a515ab46879520f304cd5ef60469512380..dbd3514aec7d40d9a04dba4bcbc5c14be639aa33 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python import values from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import distribute as distribute_lib @@ -43,11 +44,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): - # No need to distinguish tower-local variables when not mirroring, - # we just enforce that they are not trainable. - if kwargs.pop("tower_local_reduce_method", None) is not None: - kwargs["trainable"] = False - colocate_with = kwargs.pop("colocate_with", None) if colocate_with is None: with ops.device(self._device): @@ -80,15 +76,15 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.device(self._device): return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): if not isinstance(value, values.MapOutput): return value l = value.get() assert l with ops.device(self._device): - if method_string == "sum": + if aggregation == vs.VariableAggregation.SUM: return math_ops.add_n(l) - elif method_string == "mean": + elif aggregation == vs.VariableAggregation.MEAN: return math_ops.add_n(l) / len(l) else: assert False @@ -102,12 +98,9 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.device(self._device), distribute_lib.UpdateContext(self._device): return fn(*args, **kwargs) - def _fetch(self, val, destination, fn): - """Return a copy of `val` or `fn(val)` on `destination`.""" - with ops.device(self._device): - v = fn(val) - with ops.device(destination): - return array_ops.identity(v) + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + return array_ops.identity(tower_local_var) def _unwrap(self, value): return [value] diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 7aad8a953cbedd30b48739416e74b3dc164dc4cd..4fdc0f72e6745b7ef25c591157955f214e0b2c79 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -44,7 +44,7 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testTowerId(self): self._test_tower_id(self._get_distribution_strategy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallAndMergeExceptions(self): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index abd3a65ac4e19ece6b69b9834f4218fde55b60c2..a2d736e42271ab1627240949b99088ed3f0746f6 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -59,8 +59,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py index 7b3670b45aba801cf8c18e04bfea03e23eb67184..24cdc627a35f4455cb92484566dc13fa1bbaf2cc 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -89,6 +89,9 @@ class _PrefetchToDeviceIterator(object): with ops.device(device): buffer_resource_handle = prefetching_ops.function_buffering_resource( f=_prefetch_fn, + output_types=data_nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes)), target_device=target_device, string_arg=input_iterator_handle, buffer_size=buffer_size, diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py index a0b452fc2d445d1cf7dbf5e8fe0e29edef516207..2a9ab51fcfd29a8ae5b37b5c513415af29b277dc 100644 --- a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -46,7 +46,7 @@ class CanonicalizeVariableNameTest(test.TestCase): class SharedVariableCreatorTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSharedVariable(self): shared_variable_store = {} diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 75c5ec9659d193e77d219ba79977615d58841d64..2ee94d8f70868c07ca217dd4d433585458efa8d8 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -50,8 +50,8 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 2b4ad9f146bc1d6a987fbeecbb05122946137154..baed0ebaae8a3f41c55f309d28203b363336dd16 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer @@ -106,13 +107,14 @@ class DistributionTestBase(test.TestCase): before_list = [] after_list = [] for g, v in g_v: - fetched = d.fetch(v) + fetched = d.read_var(v) before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): - g = d.reduce("sum", g, destinations=v) + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): - after_list.append(d.fetch(v)) + after_list.append(d.read_var(v)) return before_list, after_list for i in range(10): @@ -159,12 +161,13 @@ class DistributionTestBase(test.TestCase): before_list = [] after_list = [] for g, v in g_v: - fetched = d.fetch(v) + fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): - g = d.reduce("sum", g, destinations=v) + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): - after_list.append(d.fetch(v)) + after_list.append(d.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -184,7 +187,7 @@ class DistributionTestBase(test.TestCase): with d.scope(): map_in = [constant_op.constant(i) for i in range(10)] map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.fetch(d.reduce("sum", map_out)) + observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out) expected = 90 # 2 * (0 + 1 + ... + 9) self.assertEqual(expected, observed.numpy()) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 75441786a615fc0d87b4c4b0b45b9384d678c1d3..bc53898539d76320e331784f9a717be9491365e1 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,104 +21,126 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools - from tensorflow.contrib import tpu from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import nest class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" - def __init__(self, - num_cores_per_host=2, - iterations_per_step=2): + def __init__(self, num_cores_per_host=2): # TODO(isaprykin): Generalize the defaults. They are currently tailored for # the unit test. super(TPUStrategy, self).__init__('/cpu:0') # TODO(isaprykin): Auto-detect number of cores and hosts. self._num_cores_per_host = num_cores_per_host - # TODO(isaprykin): This might have to be per-call. - self._iterations_per_step = iterations_per_step + # TODO(priyag): This should not be hardcoded here. + self._host = '/task:0/device:CPU:0' def distribute_dataset(self, dataset_fn): - return values.PerIterationDataset( - self._call_dataset_fn(dataset_fn), self._iterations_per_step, - self._num_cores_per_host) - - def _call_for_each_tower(self, fn, *args, **kwargs): - kwargs.pop('run_concurrently', None) - - inputs = {'args': args, 'kwargs': kwargs} - flat_inputs = nest.flatten(inputs) - - feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs] - - feeds = lambda: itertools.compress(flat_inputs, feed_mask) - shapes = [f.get_shape() for f in feeds()] + # TODO(priyag): Perhaps distribute across cores here. + return self._call_dataset_fn(dataset_fn) + + # TODO(priyag): Deal with OutOfRange errors. + # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have + # a mechanism to infer the outputs of `fn`. Pending b/110550782. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): + # Enqueue ops + shapes = nest.flatten(iterator.output_shapes) if any([not s.is_fully_defined() for s in shapes]): raise ValueError( 'TPU currently requires fully defined shapes. Either use ' 'set_shape() on the input tensors or use ' 'dataset.apply(map_and_batch(..., drop_remainder=True)).') - types = [f.get_dtype() for f in feeds()] - - def infeed_input(i): - """Get input, split it and then enqueue.""" - iteration_inputs = [f.get(i) for f in feeds()] - infeed_inputs = [[inputs_per_core[core_id] - for inputs_per_core in iteration_inputs] - for core_id in range(self._num_cores_per_host)] - - infeed_ops = [] - for core_id, infeed_input in enumerate(infeed_inputs): - infeed_ops.append( + types = nest.flatten(iterator.output_types) + + def enqueue_ops_fn(): + """Enqueue ops for one iteration.""" + control_deps = [] + sharded_inputs = [] + with ops.device(self._host): + for _ in range(self._num_cores_per_host): + # Use control dependencies to ensure a deterministic ordering. + with ops.control_dependencies(control_deps): + inputs = nest.flatten(iterator.get_next()) + control_deps.extend(inputs) + sharded_inputs.append(inputs) + + enqueue_ops = [] + for core_id, shard_input in enumerate(sharded_inputs): + enqueue_ops.append( tpu_ops.infeed_enqueue_tuple( - inputs=infeed_input, shapes=shapes, device_ordinal=core_id)) + inputs=shard_input, shapes=shapes, device_ordinal=core_id)) + return enqueue_ops - with ops.control_dependencies(infeed_ops): + def enqueue_ops_loop_body(i): + with ops.control_dependencies(enqueue_ops_fn()): return i + 1 - with ops.device('/task:0/device:CPU:0'): + with ops.device(self._host): enqueue_ops = control_flow_ops.while_loop( - lambda i: i < self._iterations_per_step, - infeed_input, [constant_op.constant(0)], + lambda i: i < iterations, + enqueue_ops_loop_body, + [constant_op.constant(0)], parallel_iterations=1) - def dequeueing_fn(*args, **kwargs): - """Dequeue input arguments and supply them to `fn`.""" - del args, kwargs + # Dequeue ops + def dequeue_fn(): dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes) - dequeued = iter(dequeued) + return nest.pack_sequence_as(iterator.output_shapes, dequeued) - fn_inputs = [] - for inp, is_feed in zip(flat_inputs, feed_mask): - if is_feed: - fn_inputs.append(next(dequeued)) - else: - fn_inputs.append(inp) + # Wrap `fn` for repeat. + if initial_loop_values is None: + initial_loop_values = [] + ctx = values.MultiStepContext(initial_loop_values) + def run_fn(*args, **kwargs): + del args, kwargs + fn_result = fn(ctx, dequeue_fn()) + if ctx.last_step_outputs is None: + ctx.last_step_outputs = [] + with ops.control_dependencies([fn_result]): + return array_ops.identity(ctx.last_step_outputs) + + # Repeat + # TODO(sourabhbajaj): The input to while loop should be based on the output + # type of the step_fn + def iterate_on_tpu(): + return tpu.repeat(iterations, run_fn, [initial_loop_values]) - fn_inputs = nest.pack_sequence_as(inputs, fn_inputs) - return fn(*fn_inputs['args'], **fn_inputs['kwargs']) + # Re-write and distribute computation. + # TODO(sourabhbajaj): Convert the output to PerDevice variable and + # implement support for that in reduce. + last_step_tensor_outputs = tpu.batch_parallel( + iterate_on_tpu, [], num_shards=self._num_cores_per_host) - def iterate_on_tpu(): - return tpu.repeat(self._iterations_per_step, dequeueing_fn, []) + # Take index [0] of last_step_tensor_outputs as we wrapped + # initial_loop_values in a list in the `repeat` call. + return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops), + last_step_tensor_outputs[0], ctx) + def _call_for_each_tower(self, fn, *args, **kwargs): + kwargs.pop('run_concurrently', None) with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access - tpu_result = tpu.batch_parallel( - iterate_on_tpu, [], num_shards=self._num_cores_per_host) + return fn(*args, **kwargs) + + def get_initialization_ops(self): + return [tpu.initialize_system()] - return control_flow_ops.group(tpu_result, enqueue_ops) + def get_finalize_ops(self): + return [tpu.shutdown_system()] - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): del destinations # TPU is graph mode only. Rely on implicit Send/Recv. - if method_string == 'mean': + if aggregation == vs.VariableAggregation.MEAN: # TODO(jhseu): Revisit once we support model-parallelism. value *= (1. / self._num_cores_per_host) return tpu_ops.cross_replica_sum(value) diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 49b4e24daa4ffe417712bc854aa29995d5afc408..b36ac563d29fc9157873796a845fefba3651edda 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -23,10 +23,8 @@ from __future__ import print_function import collections import weakref - import six -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import input_ops from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context @@ -35,6 +33,8 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver @@ -43,7 +43,7 @@ from tensorflow.python.util import nest # pylint: disable=line-too-long -# TODO(josh11b): Should device values be strings or DeviceSpec objects +# TODO(josh11b): Should device values be strings or DeviceSpec objects? # Not sure DeviceSpec objects are usable as a dict key. class DistributedValues(object): """Holds a map from device to values. Either PerDevice or Mirrored.""" @@ -65,9 +65,10 @@ class DistributedValues(object): device = device_util.canonicalize(device) try: return self._index[device] - except KeyError: - raise ValueError("Device %s not found in %s (current device %s)" % - (device, self._index.keys(), device_util.current())) + except KeyError as e: + six.raise_from( + ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())), e) def on_device(self, device): device = device_util.canonicalize(device) @@ -162,9 +163,16 @@ class PerDevice(DistributedValues): pass -class Mirrored(DistributedValues): +# Note that unlike PerDevice, Mirrored values inherit from +# DistributedDelegate and so can be used directly in cross-tower mode. +class Mirrored(DistributedDelegate): """Holds a map from device to values which are kept in sync.""" - pass + + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return self._index[device] + return list(self._index.values())[0] def _assign_on_device(device, variable, tensor): @@ -185,6 +193,10 @@ class DistributedVariable(DistributedDelegate): # Child class must set self._primary_var before calling # super(...).__init__(index). self._common_name = self._primary_var.name.split(":")[0] + # Use a weakref to make it easy to map from the contained values + # to the container without introducing a reference cycle. + for v in six.itervalues(index): + v._distributed_container = weakref.ref(self) # pylint: disable=protected-access super(DistributedVariable, self).__init__(index) @property @@ -237,35 +249,9 @@ class DistributedVariable(DistributedDelegate): pass -# Register a conversion function which reads the value of the variable, -# allowing instances of the class to be used as tensors. -def _tensor_conversion(var, dtype=None, name=None, as_ref=False): - # Try to avoid assignments to and other mutations of MirroredVariable - # state except through a DistributionStrategy.update() call. - assert not as_ref - return ops.internal_convert_to_tensor( - var.get(), dtype=dtype, name=name, as_ref=as_ref) - - -ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion) ops.register_dense_tensor_like_type(DistributedVariable) -class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): - """Class for defining how to restore a MirroredVariable.""" - - def __init__(self, mirrored_variable, primary_variable, name): - self._mirrored_variable = mirrored_variable - super(_MirroredSaveable, self).__init__(primary_variable, "", name) - - def restore(self, restored_tensors, restored_shapes): - """Restore the same value into all variables.""" - tensor, = restored_tensors - return control_flow_ops.group([ - _assign_on_device(d, v, tensor) - for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access - - def _get_update_device(): """Validate we are in update/update_non_slot() and return current device. @@ -286,34 +272,85 @@ def _get_update_device(): return device +class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): + """Class for defining how to restore a MirroredVariable.""" + + def __init__(self, mirrored_variable, primary_variable, name): + self._mirrored_variable = mirrored_variable + super(_MirroredSaveable, self).__init__(primary_variable, "", name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access + + class MirroredVariable(DistributedVariable, Mirrored, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are kept in sync.""" - def __init__(self, index, primary_var): + def __init__(self, index, primary_var, aggregation): # Use a weakref to make it easy to map from the contained values # to the container without introducing a reference cycle. for v in six.itervalues(index): v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access self._primary_var = primary_var + self._aggregation = aggregation super(MirroredVariable, self).__init__(index) - # We use _get_update_device() for the assign* methods to enforce - # that we are in an update() function. The arguments to update() are - # automatically unwrapped so the update() function would normally - # see regular variables, not MirroredVariables. However, the update - # function can still operate on wrapped MirroredVariables through - # object members, captured arguments, etc. This is more likely in an + # The arguments to update() are automatically unwrapped so the update() + # function would normally see regular variables, not MirroredVariables. + # However, the update function can still operate on wrapped MirroredVariables + # through object members, captured arguments, etc. This is more likely in an # update_non_slot() function (like OptimizerV2._finish), which can # update several non-slot variables in one call. + def _assign_func(self, *args, **kwargs): + f = kwargs.pop("f") + if distribute_lib.get_cross_tower_context(): + update_device = distribute_lib.get_update_device() + # We are calling update on the mirrored variable in cross tower context. + if update_device is not None: + # We are calling an assign function on the mirrored variable in cross + # tower context. + v = self.get(device=update_device) + return f(v, *args, **kwargs) + + return distribute_lib.get_distribution_strategy().update( + self, f, *args, **kwargs) + else: + # We are calling an assign function on the mirrored variable in tower + # context. + # We reduce the value we want to assign/add/sub. More details about how we + # handle the different use cases can be found in the _reduce method. + # We call the function on each of the mirrored variables with the reduced + # value. + if self._aggregation == vs.VariableAggregation.NONE: + raise ValueError("You must specify an aggregation method to update a " + "MirroredVariable in Tower Context.") + + def merge_fn(strategy, value): + return strategy.update( + self, f, + strategy.reduce( + aggregation=self._aggregation, value=value, destinations=self)) + + return distribute_lib.get_tower_context().merge_call(merge_fn, *args, + **kwargs) + def assign_sub(self, *args, **kwargs): - return self.get(device=_get_update_device()).assign_sub(*args, **kwargs) + return self._assign_func(f=state_ops.assign_sub, *args, **kwargs) def assign_add(self, *args, **kwargs): - return self.get(device=_get_update_device()).assign_add(*args, **kwargs) + return self._assign_func(f=state_ops.assign_add, *args, **kwargs) def assign(self, *args, **kwargs): - return self.get(device=_get_update_device()).assign(*args, **kwargs) + return self._assign_func(f=state_ops.assign, *args, **kwargs) + + @property + def aggregation(self): + return self._aggregation def _get_cross_tower(self): device = device_util.canonicalize(device_util.current()) @@ -341,6 +378,20 @@ class MirroredVariable(DistributedVariable, Mirrored, return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): + # Try to avoid assignments to and other mutations of MirroredVariable + # state except through a DistributionStrategy.update() call. + assert not as_ref + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(MirroredVariable, + _tensor_conversion_mirrored) + + class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): """Class for defining how to restore a TowerLocalVariable.""" @@ -349,7 +400,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): # We use a callable so that we don't have to evaluate this expression # in the case where we are trying to restore instead of save. def tensor(): - return distribute_lib.get_distribution_strategy().fetch( + return distribute_lib.get_distribution_strategy().read_var( tower_local_variable) spec = saver.BaseSaverBuilder.SaveSpec( tensor=tensor, @@ -364,7 +415,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): # To preserve the sum across save and restore, we have to divide the # total across all devices when restoring a variable that was summed # when saving. - if self._tower_local_variable.reduce_method == "sum": + if self._tower_local_variable.aggregation == vs.VariableAggregation.SUM: tensor *= 1. / len(self._tower_local_variable.devices) return control_flow_ops.group([ _assign_on_device(d, v, tensor) @@ -381,9 +432,9 @@ class TowerLocalVariable(DistributedVariable, PerDevice, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are reduced on save.""" - def __init__(self, index, primary_var, reduce_method): + def __init__(self, index, primary_var, aggregation): self._primary_var = primary_var - self._reduce_method = reduce_method + self._aggregation = aggregation super(TowerLocalVariable, self).__init__(index) def assign_sub(self, *args, **kwargs): @@ -399,14 +450,14 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return self.get().assign(*args, **kwargs) @property - def reduce_method(self): - return self._reduce_method + def aggregation(self): + return self._aggregation def _get_cross_tower(self): all_components = tuple(self._index.values()) # TODO(josh11b): Use a strategy-specific method. total = math_ops.add_n(all_components) - if self._reduce_method == "mean": + if self._aggregation == vs.VariableAggregation.MEAN: return total * (1./ len(all_components)) return total @@ -430,6 +481,17 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} +# Register a conversion function for TowerLocalVariable which allows as_ref to +# be true. +def _tensor_conversion_tower_local(var, dtype=None, name=None, as_ref=False): + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(TowerLocalVariable, + _tensor_conversion_tower_local) + + def _devices_match(d1, d2): return device_util.canonicalize(d1) == device_util.canonicalize(d2) @@ -477,40 +539,40 @@ def regroup(per_device, wrap_class=PerDevice): same_id = False break # Consider three cases where same_id is true: - # * If v0 is a MirroredVariable (and same_id means it is the same - # across all devices), we want to return it. We check - # MirroredVariable specifically since it can look like it - # has a _mirrored_container member since its members do. - # * If v0 is a member of a mirrored variable, in which case - # hasattr(v0, "_mirrored_container") is true, we want to - # return the MirroredVariable that contains it using the - # _mirrored_container logic below. This case can trigger + # * If v0 is a DistributedVariable (a MirroredVariable or + # TowerLocalVariable, and same_id means it is the same across all + # devices), we want to return it. We check DistributedVariable + # specifically since it can look like it has a + # _distributed_container member since its members do. + # * If v0 is a member of a distributed variable, in which case + # hasattr(v0, "_distributed_container") is true, we want to + # return the DistributedVariable that contains it using the + # _distributed_container logic below. This case can trigger # same_id when there is only one device. # * In any other situation, same_id means we return v0. - if same_id and (isinstance(v0, MirroredVariable) or - not hasattr(v0, "_mirrored_container")): + if same_id and (isinstance(v0, DistributedVariable) or + not hasattr(v0, "_distributed_container")): return v0 # Detect the case where each device has a parallel component of the - # same MirroredVariable. In this case we want to return the - # containing MirroredVariable, after a bunch of sanity checking. - # In particular, each component should have the same container, - # and the devices of the variables should match the keys of the - # per-device dictionary. - # TODO(josh11b): Do we need similar logic for TowerLocalVariables? - if hasattr(v0, "_mirrored_container"): + # same MirroredVariable (or TowerLocalVariable). In this case we + # want to return the containing MirroredVariable, after a bunch of + # sanity checking. In particular, each component should have the + # same container, and the devices of the variables should match the + # keys of the per-device dictionary. + if hasattr(v0, "_distributed_container"): # pylint: disable=protected-access assert not isinstance(v0, MirroredVariable), ( "ids = %s, items = %s" % ([id(v[1]) for v in items], items)) assert _devices_match(v0.device, items[0][0]), ( "v0.device = %s, items = %s" % (v0.device, items)) - mirrored_container = v0._mirrored_container() - assert mirrored_container is not None + distributed_container = v0._distributed_container() + assert distributed_container is not None for d, v in items[1:]: assert _devices_match(v.device, d), ( "v.device = %s, d = %s, items = %s" % (v.device, d, items)) - assert mirrored_container is v._mirrored_container() - return mirrored_container + assert distributed_container is v._distributed_container() + return distributed_container # pylint: enable=protected-access return wrap_class(per_device) @@ -592,8 +654,7 @@ class PerDeviceDataset(object): # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. # Possibly not an issue when we start using PartitionedDataset. - self._dataset = dataset.apply( - batching.batch_and_drop_remainder(len(devices))) + self._dataset = dataset.batch(len(devices), drop_remainder=True) def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" @@ -804,3 +865,71 @@ class MapOutput(object): def get(self): return self._l + + +class MultiStepContext(object): + """A context object that can be used to capture things when running steps. + + This context object is useful when running multiple steps at a time using the + `run_steps_on_dataset` API. For e.g. it allows the user's step function to + specify which outputs to emit at what frequency. Currently it only supports + capturing output from the last step, but will soon be augmented to support + other use cases such as output each N steps. + """ + + def __init__(self, initial_loop_values=None): + """Initializes an output context. + + Args: + initial_loop_values: Initial values passed to the run steps + while loop. The only purpose is to verify the shapes and types + when the actual output is set. This will be removed once we + automatically infer the output shapes and types (and do not need to + check for user error in specifying them manually). + Returns: + A context object. + """ + self._last_step_outputs = None + self._non_tensor_outputs = None + self._initial_loop_values = initial_loop_values + + @property + def last_step_outputs(self): + """Return the last step's outputs.""" + return self._last_step_outputs + + @last_step_outputs.setter + def last_step_outputs(self, outputs): + """Set the last step's outputs.""" + self._verify_structure_shapes_types(outputs, self._initial_loop_values) + self._last_step_outputs = outputs + + @property + def non_tensor_outputs(self): + """Return the non tensor outputs.""" + return self._non_tensor_outputs + + @non_tensor_outputs.setter + def non_tensor_outputs(self, outputs): + """Set any non tensor outputs.""" + self._non_tensor_outputs = outputs + + def _verify_structure_shapes_types(self, left, right): + """Verify that the structure, shapes and types of left are same as right.""" + nest.assert_same_structure(left, right) + flat_left = nest.flatten(left) + flat_right = nest.flatten(right) + assert len(flat_left) == len(flat_right), ( + "Length of left {} and right {} should be same.". + format(len(flat_left), len(flat_right))) + + for o, i in zip(flat_left, flat_right): + # TODO(priyag): Add checks for other types like IndexedSlices. + if isinstance(o, ops.Tensor): + assert isinstance(i, ops.Tensor) + assert o.shape == i.shape, ( + "Shape {} of left {} doesn't match shape {} of right {}.". + format(o.shape, o, i.shape, i)) + assert o.dtype == i.dtype, ( + "Dtype {} of left {} doesn't match dtype {} of right {}.". + format(o.dtype, o, i.dtype, i)) diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 1c95758d96aba47e9581dde6411763e98b99a968..8e44f2fea16ac851c124b573948ee14ec0640556 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -82,7 +82,7 @@ class DistributedValuesTest(test.TestCase): class DistributedDelegateTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetAttr(self): with ops.device("/device:CPU:0"): @@ -97,7 +97,7 @@ class DistributedDelegateTest(test.TestCase): with self.assertRaises(AttributeError): _ = v.y - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOperatorOverride(self): with ops.device("/device:CPU:0"): v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8}) @@ -158,7 +158,8 @@ def _make_mirrored(): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) index[d] = v[-1] - mirrored = values.MirroredVariable(index, v[0]) + mirrored = values.MirroredVariable(index, v[0], + variable_scope.VariableAggregation.SUM) return v, devices, mirrored @@ -277,7 +278,8 @@ class RegroupAndSelectDeviceTest(test.TestCase): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) index = {d: v} - mirrored = values.MirroredVariable(index, v) + mirrored = values.MirroredVariable(index, v, + variable_scope.VariableAggregation.SUM) result = values.regroup(index) self.assertIs(mirrored, result) @@ -363,7 +365,7 @@ class PerDeviceDatasetTest(test.TestCase): self._test_iterator_no_prefetch(devices, dataset, expected_values) self._test_iterator_with_prefetch(devices, dataset, expected_values) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOneDevice(self): devices = ["/device:CPU:0"] dataset = dataset_ops.Dataset.range(10) @@ -581,7 +583,8 @@ class MirroredVariableTest(test.TestCase): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) index = {"/job:foo/device:CPU:0": v} - mirrored = values.MirroredVariable(index, v) + mirrored = values.MirroredVariable(index, v, + variable_scope.VariableAggregation.MEAN) self.assertEquals(v.name, mirrored.name) self.assertEquals(v.dtype, mirrored.dtype) @@ -716,7 +719,9 @@ class MirroredVariableTest(test.TestCase): with ops.device("/device:GPU:0"): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) - mirrored = values.MirroredVariable({"/device:GPU:0": v}, v) + mirrored = values.MirroredVariable({ + "/device:GPU:0": v + }, v, variable_scope.VariableAggregation.MEAN) sess.run(variables_lib.global_variables_initializer()) sess.run({"complicated": mirrored}) @@ -746,24 +751,27 @@ class TowerLocalVariableTest(test.TestCase): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - v, tower_local = _make_tower_local("sum") + v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) self.assertEquals(v[0].name, tower_local.name) self.assertEquals(v[0].dtype, tower_local.dtype) self.assertEquals(v[0].shape, tower_local.shape) - self.assertEquals("sum", tower_local.reduce_method) + self.assertEquals(variable_scope.VariableAggregation.SUM, + tower_local.aggregation) @test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) index = {"/job:foo/device:CPU:0": v} - tower_local = values.TowerLocalVariable(index, v, "mean") + tower_local = values.TowerLocalVariable( + index, v, variable_scope.VariableAggregation.MEAN) self.assertEquals(v.name, tower_local.name) self.assertEquals(v.dtype, tower_local.dtype) self.assertEquals(v.shape, tower_local.shape) - self.assertEquals("mean", tower_local.reduce_method) + self.assertEquals(variable_scope.VariableAggregation.MEAN, + tower_local.aggregation) def _assign_tower_local(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -789,7 +797,7 @@ class TowerLocalVariableTest(test.TestCase): self.skipTest("A GPU is not available for this test in eager mode.") with self.test_session() as sess: - v, tower_local = _make_tower_local("sum") + v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) # Overwrite the initial values. self._assign_tower_local(_devices, v, [3., 4.]) @@ -812,7 +820,8 @@ class TowerLocalVariableTest(test.TestCase): self.skipTest("A GPU is not available for this test in eager mode.") with self.test_session() as sess: - v, tower_local = _make_tower_local("mean") + v, tower_local = _make_tower_local( + variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_tower_local(_devices, v, [3., 4.]) @@ -831,7 +840,8 @@ class TowerLocalVariableTest(test.TestCase): def _save_tower_local_mean(self): """Save variables with mirroring, returns save_path.""" with self.test_session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local("mean") + v, tower_local = _make_tower_local( + variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_tower_local(_devices, v, [3., 4.]) @@ -893,7 +903,8 @@ class TowerLocalVariableTest(test.TestCase): def _restore_tower_local_mean(self, save_path): """Restore to variables with mirroring in a fresh graph.""" with self.test_session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local("mean") + v, tower_local = _make_tower_local( + variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_tower_local(_devices, v, [7., 8.]) @@ -907,7 +918,7 @@ class TowerLocalVariableTest(test.TestCase): def _restore_tower_local_sum(self, save_path): """Restore to variables with mirroring in a fresh graph.""" with self.test_session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local("sum") + v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) # Overwrite the initial values. self._assign_tower_local(_devices, v, [7., 8.]) @@ -966,6 +977,18 @@ class TowerLocalVariableTest(test.TestCase): save_path = self._save_normal() self._restore_tower_local_sum(save_path) + def testTensorConversion(self): + with context.graph_mode(): + _, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) + converted = ops.internal_convert_to_tensor(tower_local, as_ref=False) + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, tower_local.dtype) + + converted = ops.internal_convert_to_tensor(tower_local, as_ref=True) + # Resources variable are converted to tensors as well when as_ref is True. + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, tower_local.dtype) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 6192f04c8b695d124b498850ad430823b44fd472..ad00d1734dd14ed846522a33d888a5387cb25cc6 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -16,6 +16,13 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "bijectors_py", srcs = glob(["python/ops/bijectors/*.py"]), + deprecation = ("TensorFlow Distributions has migrated to " + + "TensorFlow Probability " + + "(https://github.com/tensorflow/probability). " + + "Deprecated copies remaining in tf.contrib.distributions " + + "are unmaintained, unsupported, and will be removed by " + + "late 2018. You should update all usage of " + + "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/linalg:linalg_py", @@ -42,6 +49,13 @@ py_library( py_library( name = "distributions_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + deprecation = ("TensorFlow Distributions has migrated to " + + "TensorFlow Probability " + + "(https://github.com/tensorflow/probability). " + + "Deprecated copies remaining in tf.contrib.distributions " + + "are unmaintained, unsupported, and will be removed by " + + "late 2018. You should update all usage of " + + "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ ":bijectors_py", @@ -940,6 +954,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "fill_triangular_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/fill_triangular_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "gumbel_test", size = "small", @@ -1032,6 +1065,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "matrix_inverse_tril_test", + size = "medium", + srcs = ["python/kernel_tests/bijectors/matrix_inverse_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "real_nvp_test", size = "small", @@ -1099,6 +1151,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "scale_tril_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/scale_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sigmoid_test", size = "small", @@ -1216,6 +1287,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "transform_diagonal_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/transform_diagonal_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "weibull_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index ddf59891e626a85e6c917ac74b3cfaabf16eb15d..5cec93c4df2e970f203253be6342bb292f296eb0 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """Classes representing statistical distributions and ops for working with them. - -See the @{$python/contrib.distributions} guide. """ from __future__ import absolute_import from __future__ import division @@ -32,6 +30,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_distribution import from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.deterministic import * from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular +from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse @@ -156,6 +155,7 @@ _allowed_symbols = [ 'kl_divergence', 'RegisterKL', 'fill_triangular', + 'fill_triangular_inverse', 'matrix_diag_transform', 'reduce_weighted_logsumexp', 'softplus_inverse', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index e281e81bdf0698c1f7b2f60fb27783dd1351773f..d1ce273499c8a646c0757844c91a785fa8d56ce4 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -61,6 +61,28 @@ class CholeskyOuterProductBijectorTest(test.TestCase): atol=0., rtol=1e-7) + def testNoBatchStaticJacobian(self): + x = np.eye(2) + bijector = bijectors.CholeskyOuterProduct() + + # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4. + self.assertAllClose( + np.log(4), + self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=2))) + + def testNoBatchDynamicJacobian(self): + x = np.eye(2) + bijector = bijectors.CholeskyOuterProduct() + x_pl = array_ops.placeholder(dtypes.float32) + + with self.test_session(): + log_det_jacobian = bijector.forward_log_det_jacobian(x_pl, event_ndims=2) + + # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4. + self.assertAllClose( + np.log(4), + log_det_jacobian.eval({x_pl: x})) + def testNoBatchStatic(self): x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py index 8b279ebcd908b6f375b35594ac5f3db9228a1e31..f8a52615b0f3f5ad0c7e01e0f76c7d7a6b455ef7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py @@ -59,7 +59,7 @@ class ConditionalBijectorTest(test.TestCase): for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]: method = getattr(b, name) with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): - method(1., event_ndims=0., arg1="b1", arg2="b2") + method(1., event_ndims=0, arg1="b1", arg2="b2") if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3530e142e4d1545e80a3b1bf1e8ddbf7819ba58a --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class FillTriangularBijectorTest(test.TestCase): + """Tests the correctness of the FillTriangular bijector.""" + + @test_util.run_in_graph_and_eager_modes + def testBijector(self): + x = np.float32(np.array([1., 2., 3.])) + y = np.float32(np.array([[3., 0.], + [2., 1.]])) + + b = bijectors.FillTriangular() + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + self.assertAllClose(fldj, 0.) + + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(ildj, 0.) + + @test_util.run_in_graph_and_eager_modes + def testShape(self): + x_shape = tensor_shape.TensorShape([5, 4, 6]) + y_shape = tensor_shape.TensorShape([5, 4, 3, 3]) + + b = bijectors.FillTriangular(validate_args=True) + + x = array_ops.ones(shape=x_shape, dtype=dtypes.float32) + y_ = b.forward(x) + self.assertAllEqual(y_.shape.as_list(), y_shape.as_list()) + x_ = b.inverse(y_) + self.assertAllEqual(x_.shape.as_list(), x_shape.as_list()) + + y_shape_ = b.forward_event_shape(x_shape) + self.assertAllEqual(y_shape_.as_list(), y_shape.as_list()) + x_shape_ = b.inverse_event_shape(y_shape) + self.assertAllEqual(x_shape_.as_list(), x_shape.as_list()) + + y_shape_tensor = self.evaluate( + b.forward_event_shape_tensor(x_shape.as_list())) + self.assertAllEqual(y_shape_tensor, y_shape.as_list()) + x_shape_tensor = self.evaluate( + b.inverse_event_shape_tensor(y_shape.as_list())) + self.assertAllEqual(x_shape_tensor, x_shape.as_list()) + + @test_util.run_in_graph_and_eager_modes + def testShapeError(self): + + b = bijectors.FillTriangular(validate_args=True) + + x_shape_bad = tensor_shape.TensorShape([5, 4, 7]) + with self.assertRaisesRegexp(ValueError, "is not a triangular number"): + b.forward_event_shape(x_shape_bad) + with self.assertRaisesOpError("is not a triangular number"): + self.evaluate(b.forward_event_shape_tensor(x_shape_bad.as_list())) + + y_shape_bad = tensor_shape.TensorShape([5, 4, 3, 2]) + with self.assertRaisesRegexp(ValueError, "Matrix must be square"): + b.inverse_event_shape(y_shape_bad) + with self.assertRaisesOpError("Matrix must be square"): + self.evaluate(b.inverse_event_shape_tensor(y_shape_bad.as_list())) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py new file mode 100644 index 0000000000000000000000000000000000000000..85d604e34ac25cf94b601470b7f166d9d414a8e3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class MatrixInverseTriLBijectorTest(test.TestCase): + """Tests the correctness of the Y = inv(tril) transformation.""" + + @test_util.run_in_graph_and_eager_modes + def testComputesCorrectValues(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + self.assertEqual("matrix_inverse_tril", inv.name) + x_ = np.array([[0.7, 0., 0.], + [0.1, -1., 0.], + [0.3, 0.25, 0.5]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -6. * np.sum(np.log(np.abs(np.diag(x_)))) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes + def testOneByOneMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[5.]], dtype=np.float32) + x_inv_ = np.array([[0.2]], dtype=np.float32) + expected_fldj_ = np.log(0.04) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes + def testZeroByZeroMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.eye(0, dtype=np.float32) + x_inv_ = np.eye(0, dtype=np.float32) + expected_fldj_ = 0. + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes + def testBatch(self): + # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape + # (2, 1). + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[[[1., 0.], + [2., 3.]]], + [[[4., 0.], + [5., -6.]]]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -4. * np.sum( + np.log(np.abs(np.diagonal(x_, axis1=-2, axis2=-1))), axis=-1) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3) + self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3) + + @test_util.run_in_graph_and_eager_modes + def testErrorOnInputRankTooLow(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([0.1], dtype=np.float32) + rank_error_msg = "must have rank at least 2" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + # TODO(b/80481923): Figure out why these assertions fail, and fix them. + ## def testErrorOnInputNonSquare(self): + ## inv = bijectors.MatrixInverseTriL(validate_args=True) + ## x_ = np.array([[1., 2., 3.], + ## [4., 5., 6.]], dtype=np.float32) + ## square_error_msg = "must be a square matrix" + ## with self.test_session(): + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes + def testErrorOnInputNotLowerTriangular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 2.], + [3., 4.]], dtype=np.float32) + triangular_error_msg = "must be lower triangular" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes + def testErrorOnInputSingular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 0.], + [0., 0.]], dtype=np.float32) + nonsingular_error_msg = "must have all diagonal entries nonzero" + with self.test_session(): + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py index a5f5219588fb3be67beb797ba68ed8148e9e9fd2..cb42331a21a6acdd5244c311a7def5359bb6c574 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py @@ -36,7 +36,7 @@ class OrderedBijectorTest(test.TestCase): def setUp(self): self._rng = np.random.RandomState(42) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorVector(self): with self.test_session(): ordered = Ordered() @@ -82,7 +82,7 @@ class OrderedBijectorTest(test.TestCase): atol=0., rtol=1e-7) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShapeGetters(self): with self.test_session(): x = tensor_shape.TensorShape([4]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b3367f9a31a9c602e0b138e617db68834b8229 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py @@ -0,0 +1,69 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class ScaleTriLBijectorTest(test.TestCase): + """Tests the correctness of the ScaleTriL bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testComputesCorrectValues(self): + shift = 1.61803398875 + x = np.float32(np.array([-1, .5, 2])) + y = np.float32(np.array([[np.exp(2) + shift, 0.], + [.5, np.exp(-1) + shift]])) + + b = bijectors.ScaleTriL(diag_bijector=bijectors.Exp(), + diag_shift=shift) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + @test_util.run_in_graph_and_eager_modes + def testInvertible(self): + + # Generate random inputs from an unconstrained space, with + # event size 6 to specify 3x3 triangular matrices. + batch_shape = [2, 1] + x = np.float32(np.random.randn(*(batch_shape + [6]))) + b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(), + diag_shift=3.14159) + y = self.evaluate(b.forward(x)) + self.assertAllEqual(y.shape, batch_shape + [3, 3]) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(fldj, -ildj) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 45760a29ee42835da69ef63803ccec7ce82a5a8f..795f1993ba5c31bf5a26333f31f1bc73125bff07 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -151,16 +151,24 @@ class SinhArcsinhBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval(), rtol=1e-4, atol=0.) self.assertAllClose(x, bijector.inverse(y).eval(), rtol=1e-4, atol=0.) - # Do the numpy calculation in float128 to avoid inf/nan. - y_float128 = np.float128(y) - self.assertAllClose( - np.log(np.cosh( - np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( - y_float128**2 + 1)) - - np.log(tailweight), - bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), - rtol=1e-4, - atol=0.) + # On IBM PPC systems, longdouble (np.float128) is same as double except that it can have more precision. + # Type double being of 8 bytes, can't hold square of max of float64 (which is also 8 bytes) and + # below test fails due to overflow error giving inf. So this check avoids that error by skipping square + # calculation and corresponding assert. + + if np.amax(y) <= np.sqrt(np.finfo(np.float128).max) and \ + np.fabs(np.amin(y)) <= np.sqrt(np.fabs(np.finfo(np.float128).min)): + + # Do the numpy calculation in float128 to avoid inf/nan. + y_float128 = np.float128(y) + self.assertAllClose( + np.log(np.cosh( + np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( + y_float128**2 + 1)) - + np.log(tailweight), + bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + rtol=1e-4, + atol=0.) self.assertAllClose( -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py index 2ac06fce55b448a5f3da7ccb7f8766b5b1404ad7..d0098c3c105626da1da5855710169069ebeffbd9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py @@ -40,7 +40,7 @@ class SoftsignBijectorTest(test.TestCase): def setUp(self): self._rng = np.random.RandomState(42) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorBounds(self): bijector = Softsign(validate_args=True) with self.test_session(): @@ -54,7 +54,7 @@ class SoftsignBijectorTest(test.TestCase): with self.assertRaisesOpError("less than 1"): bijector.inverse_log_det_jacobian(3., event_ndims=0).eval() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorForwardInverse(self): bijector = Softsign(validate_args=True) self.assertEqual("softsign", bijector.name) @@ -64,7 +64,7 @@ class SoftsignBijectorTest(test.TestCase): self.assertAllClose(y, self.evaluate(bijector.forward(x))) self.assertAllClose(x, self.evaluate(bijector.inverse(y))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorLogDetJacobianEventDimsZero(self): bijector = Softsign(validate_args=True) y = self._rng.rand(2, 10) @@ -74,7 +74,7 @@ class SoftsignBijectorTest(test.TestCase): self.assertAllClose(ildj, self.evaluate( bijector.inverse_log_det_jacobian(y, event_ndims=0))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorForwardInverseEventDimsOne(self): bijector = Softsign(validate_args=True) self.assertEqual("softsign", bijector.name) @@ -83,7 +83,7 @@ class SoftsignBijectorTest(test.TestCase): self.assertAllClose(y, self.evaluate(bijector.forward(x))) self.assertAllClose(x, self.evaluate(bijector.inverse(y))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorLogDetJacobianEventDimsOne(self): bijector = Softsign(validate_args=True) y = self._rng.rand(2, 10) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py new file mode 100644 index 0000000000000000000000000000000000000000..efc9f266d1fb6bcc53ae318e218b0697825c0155 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class TransformDiagonalBijectorTest(test.TestCase): + """Tests correctness of the TransformDiagonal bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + @test_util.run_in_graph_and_eager_modes + def testBijector(self): + x = np.float32(np.random.randn(3, 4, 4)) + + y = x.copy() + for i in range(x.shape[0]): + np.fill_diagonal(y[i, :, :], np.exp(np.diag(x[i, :, :]))) + + exp = bijectors.Exp() + b = bijectors.TransformDiagonal(diag_bijector=exp) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=2)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllEqual( + fldj, + self.evaluate(exp.forward_log_det_jacobian( + np.array([np.diag(x_mat) for x_mat in x]), + event_ndims=1))) + self.assertAllEqual( + ildj, + self.evaluate(exp.inverse_log_det_jacobian( + np.array([np.diag(y_mat) for y_mat in y]), + event_ndims=1))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index 31d24aa9ea09007b8db40e4869371b1f62639ac7..181c46d2e52552e641bc59c0fe94743f1af42845 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -29,7 +29,9 @@ from tensorflow.contrib.distributions.python.ops import mvn_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.linalg import linear_operator_diag @@ -540,5 +542,51 @@ class PadDynamicTest(_PadTest, test.TestCase): return False +class TestMoveDimension(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def test_move_dimension_static_shape(self): + + x = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1]) + + @test_util.run_in_graph_and_eager_modes + def test_move_dimension_dynamic_shape(self): + + x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + x = array_ops.placeholder_with_default(input=x_, shape=None) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + x_perm = distribution_util.move_dimension(x, -1, 2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py index 968057331787059240110b90545f70c0ab128aa8..b91a610acf1a9094d612504d63030b3bffb873ac 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py @@ -65,6 +65,16 @@ class SeedStreamTest(test.TestCase): self.assertAllUnique( outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)]) + def testInitFromOtherSeedStream(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(strm1, salt="salt") + strm3 = seed_stream.SeedStream(strm1, salt="another salt") + out1 = [strm1() for _ in range(50)] + out2 = [strm2() for _ in range(50)] + out3 = [strm3() for _ in range(50)] + self.assertAllEqual(out1, out2) + self.assertAllUnique(out1 + out3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py index ce6cf702d522792f1ad26066a3d9be42003a0e3c..9c4dfed83631e9f0815fb674d650cac2e570b923 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py @@ -98,23 +98,21 @@ class StatisticalTestingTest(test.TestCase): num_samples = 5000 # 5000 samples is chosen to be enough to find discrepancies of # size 0.1 or more with assurance 1e-6, as confirmed here: - with self.test_session() as sess: - d = st.min_discrepancy_of_true_means_detectable_by_dkwm( - num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) - d = sess.run(d) - self.assertLess(d, 0.1) + d = st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) + d = self.evaluate(d) + self.assertLess(d, 0.1) # Test that the confidence interval computed for the mean includes # 0.5 and excludes 0.4 and 0.6. - with self.test_session() as sess: - samples = rng.uniform(size=num_samples).astype(np.float32) - (low, high) = st.true_mean_confidence_interval_by_dkwm( - samples, 0., 1., error_rate=1e-6) - low, high = sess.run([low, high]) - self.assertGreater(low, 0.4) - self.assertLess(low, 0.5) - self.assertGreater(high, 0.5) - self.assertLess(high, 0.6) + samples = rng.uniform(size=num_samples).astype(np.float32) + (low, high) = st.true_mean_confidence_interval_by_dkwm( + samples, 0., 1., error_rate=1e-6) + low, high = self.evaluate([low, high]) + self.assertGreater(low, 0.4) + self.assertLess(low, 0.5) + self.assertGreater(high, 0.5) + self.assertLess(high, 0.6) def test_dkwm_mean_one_sample_assertion(self): rng = np.random.RandomState(seed=0) @@ -123,21 +121,45 @@ class StatisticalTestingTest(test.TestCase): # Test that the test assertion agrees that the mean of the standard # uniform distribution is 0.5. samples = rng.uniform(size=num_samples).astype(np.float32) - with self.test_session() as sess: - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.5, false_fail_rate=1e-6)) - - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is not 0.4. - with self.assertRaisesOpError("Mean confidence interval too high"): - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.4, false_fail_rate=1e-6)) - - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is not 0.6. - with self.assertRaisesOpError("Mean confidence interval too low"): - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.6, false_fail_rate=1e-6)) + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.5, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.4. + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.6. + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.6, false_fail_rate=1e-6)) + + def test_dkwm_mean_in_interval_one_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 5000 + + # Test that the test assertion agrees that the mean of the standard + # uniform distribution is between 0.4 and 0.6. + samples = rng.uniform(size=num_samples).astype(np.float32) + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not between 0.2 and 0.4. + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.2, expected_high=0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not between 0.6 and 0.8. + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.6, expected_high=0.8, false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion(self): rng = np.random.RandomState(seed=0) @@ -145,20 +167,18 @@ class StatisticalTestingTest(test.TestCase): # 4000 samples is chosen to be enough to find discrepancies of # size 0.2 or more with assurance 1e-6, as confirmed here: - with self.test_session() as sess: - d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( - num_samples, 0., 1., num_samples, 0., 1., - false_fail_rate=1e-6, false_pass_rate=1e-6) - d = sess.run(d) - self.assertLess(d, 0.2) + d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + num_samples, 0., 1., num_samples, 0., 1., + false_fail_rate=1e-6, false_pass_rate=1e-6) + d = self.evaluate(d) + self.assertLess(d, 0.2) # Test that the test assertion agrees that the standard # uniform distribution has the same mean as itself. samples1 = rng.uniform(size=num_samples).astype(np.float32) samples2 = rng.uniform(size=num_samples).astype(np.float32) - with self.test_session() as sess: - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion_beta_2_1_false(self): rng = np.random.RandomState(seed=0) @@ -168,15 +188,14 @@ class StatisticalTestingTest(test.TestCase): # As established above, 4000 samples is enough to find discrepancies # of size 0.2 or more with assurance 1e-6. - with self.test_session() as sess: - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is different from the mean of beta(2, 1). - beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) - with self.assertRaisesOpError("samples1 has a smaller mean"): - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., - beta_high_samples, 0., 1., - false_fail_rate=1e-6)) + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(2, 1). + beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_high_samples, 0., 1., + false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion_beta_1_2_false(self): rng = np.random.RandomState(seed=0) @@ -186,15 +205,14 @@ class StatisticalTestingTest(test.TestCase): # As established above, 4000 samples is enough to find discrepancies # of size 0.2 or more with assurance 1e-6. - with self.test_session() as sess: - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is different from the mean of beta(1, 2). - beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) - with self.assertRaisesOpError("samples2 has a smaller mean"): - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., - beta_low_samples, 0., 1., - false_fail_rate=1e-6)) + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(1, 2). + beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_low_samples, 0., 1., + false_fail_rate=1e-6)) def test_dkwm_argument_validity_checking(self): rng = np.random.RandomState(seed=0) @@ -203,18 +221,17 @@ class StatisticalTestingTest(test.TestCase): # Test that the test library complains if the given samples fall # outside the purported bounds. - with self.test_session() as sess: - with self.assertRaisesOpError("maximum value exceeds expectations"): - sess.run(st.true_mean_confidence_interval_by_dkwm( - samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5)) - with self.assertRaisesOpError("minimum value falls below expectations"): - sess.run(st.true_mean_confidence_interval_by_dkwm( - samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5)) - - # But doesn't complain if they don't. - op = st.true_mean_confidence_interval_by_dkwm( - samples, [[0., 1.]], [[1., 2.]], error_rate=0.5) - _ = sess.run(op) + with self.assertRaisesOpError("maximum value exceeds expectations"): + self.evaluate(st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5)) + with self.assertRaisesOpError("minimum value falls below expectations"): + self.evaluate(st.true_mean_confidence_interval_by_dkwm( + samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5)) + + # But doesn't complain if they don't. + op = st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[1., 2.]], error_rate=0.5) + _ = self.evaluate(op) def test_do_maximum_mean(self): n = 117 @@ -223,10 +240,9 @@ class StatisticalTestingTest(test.TestCase): samples = rng.uniform(size=n).astype(np.float32) # Compute the answer in TF using the code under test - with self.test_session() as sess: - envelope_t = ops.convert_to_tensor(envelope) - max_mean = st._do_maximum_mean(samples, envelope_t, 1) - max_mean = sess.run(max_mean) + envelope_t = ops.convert_to_tensor(envelope) + max_mean = st._do_maximum_mean(samples, envelope_t, 1) + max_mean = self.evaluate(max_mean) # Compute the correct answer for this case in numpy. In this # example, `n` and `envelope` are such that `samples[2]` is the diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..42ecea034d77430924bd6f597bf42ec3f64fec92 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD @@ -0,0 +1,51 @@ +# Description: +# Internal testing utilities, e.g., computing the correct answer to +# put in a unit test. + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "correlation_matrix_volumes_py", + srcs = [ + "correlation_matrix_volumes_lib.py", + ], + deps = [ + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//third_party/py/numpy", + ], +) + +py_binary( + name = "correlation_matrix_volumes", + srcs = [ + "correlation_matrix_volumes.py", + ], + deps = [ + ":correlation_matrix_volumes_py", + ], +) + +py_test( + name = "correlation_matrix_volumes_test", + size = "medium", + srcs = ["correlation_matrix_volumes_test.py"], + tags = [ + "no_pip", + "optonly", + ], + deps = [ + ":correlation_matrix_volumes_py", + # For statistical testing + "//tensorflow/contrib/distributions:distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + ], +) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py new file mode 100644 index 0000000000000000000000000000000000000000..2eab51cd3053ea55f2e03619fd002fbf48251fb1 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Executable to estimate the volume of various sets of correlation matrices. + +See correlation_matrix_volumes_lib.py for purpose and methodology. + +Invocation example: +``` +python correlation_matrix_volumes.py --num_samples 1e7 +``` + +This will compute 10,000,000-sample confidence intervals for the +volumes of several sets of correlation matrices. Which sets, and the +desired statistical significance, are hard-coded in this source file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pprint + +from absl import app +from absl import flags + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr + +FLAGS = flags.FLAGS + +# Float to support giving the number of samples in scientific notation. +# The production run used for the LKJ test used 1e7 samples. +flags.DEFINE_float('num_samples', 1e4, 'Number of samples to use.') + + +def ctv_debatched(det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + # This wrapper undoes the batching in compute_true_volumes, because + # apparently several 5x5x9x1e7 Tensors of float32 can strain RAM. + bounds = {} + for db in det_bounds: + bounds[db] = corr.compute_true_volumes( + [db], dim, num_samples, error_rate=error_rate, seed=seed)[db] + return bounds + + +# The particular bounds in all three of these functions were chosen by +# a somewhat arbitrary walk through an empirical tradeoff, for the +# purpose of testing the LKJ distribution. Setting the determinant +# bound lower +# - Covers more of the testee's sample space, and +# - Increases the probability that the rejection sampler will hit, thus +# - Decreases the relative error (at a fixed sample count) in the +# rejection-based volume estimate; +# but also +# - Increases the variance of the estimator used in the LKJ test. +# This latter variance is also affected by the dimension and the +# tested concentration parameter, and can be compensated for with more +# compute (expensive) or a looser discrepancy limit (unsatisfying). +# The values here are the projection of the points in that test design +# space that ended up getting chosen. +def compute_3x3_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 3, num_samples, error_rate=5e-7, seed=46) + + +def compute_4x4_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 4, num_samples, error_rate=5e-7, seed=47) + + +def compute_5x5_volumes(num_samples): + det_bounds = [0.01, 0.2, 0.25, 0.3, 0.35, 0.4] + return ctv_debatched( + det_bounds, 5, num_samples, error_rate=5e-7, seed=48) + + +def main(_): + full_bounds = {} + full_bounds[3] = compute_3x3_volumes(int(FLAGS.num_samples)) + full_bounds[4] = compute_4x4_volumes(int(FLAGS.num_samples)) + full_bounds[5] = compute_5x5_volumes(int(FLAGS.num_samples)) + pprint.pprint(full_bounds) + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..455e71f00c96e799c4aaae25050c77a9ae36df06 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py @@ -0,0 +1,323 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Estimating the volume of the correlation matrices with bounded determinant. + +Why? Because lkj_test.py tests the sampler for the LKJ distribution +by estimating the same volume another way. + +How? Rejection sampling. Or, more precisely, importance sampling, +proposing from the uniform distribution on symmetric matrices with +diagonal 1s and entries in [-1, 1]. Such a matrix is a correlation +matrix if and only if it is also positive semi-definite. + +The samples can then be converted into a confidence interval on the +volume in question by the [Clopper-Pearson +method](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval), +also implemented here. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import sys + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import uniform +from tensorflow.python.ops.distributions import util +from tensorflow.python.platform import tf_logging + +__all__ = [ + "correlation_matrix_volume_rejection_samples", + "compute_true_volumes", +] + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +optimize = try_import("scipy.optimize") +stats = try_import("scipy.stats") + + +def _psd_mask(x): + """Computes whether each square matrix in the input is positive semi-definite. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + + Returns: + mask: A floating-point `Tensor` of shape `[B1, ... Bn]`. Each + scalar is 1 if the corresponding matrix was PSD, otherwise 0. + """ + # Allegedly + # https://scicomp.stackexchange.com/questions/12979/testing-if-a-matrix-is-positive-semi-definite + # it is more efficient to test for positive semi-definiteness by + # trying to compute the Cholesky decomposition -- the matrix is PSD + # if you succeed and not PSD if you fail. However, TensorFlow's + # Cholesky raises an exception if _any_ of the input matrices are + # not PSD, from which I don't know how to extract _which ones_, so I + # proceed by explicitly computing all the eigenvalues and checking + # whether they are all positive or not. + # + # Also, as was discussed in the answer, it is somewhat dangerous to + # treat SPD-ness as binary in floating-point arithmetic. Cholesky + # factorization can complete and 'look' like everything is fine + # (e.g., O(1) entries and a diagonal of all ones) but the matrix can + # have an exponential condition number. + eigenvalues, _ = linalg_ops.self_adjoint_eig(x) + return math_ops.cast( + math_ops.reduce_min(eigenvalues, axis=-1) >= 0, dtype=x.dtype) + + +def _det_large_enough_mask(x, det_bounds): + """Returns whether the input matches the given determinant limit. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + det_bounds: A floating-point `Tensor` that must broadcast to shape + `[B1, ..., Bn]`, giving the desired lower bound on the + determinants in `x`. + + Returns: + mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each + scalar is 1 if the corresponding matrix had determinant above + the corresponding bound, otherwise 0. + """ + # For the curious: I wonder whether it is possible and desirable to + # use a Cholesky decomposition-based algorithm for this, since the + # only matrices whose determinant this code cares about will be PSD. + # Didn't figure out how to code that in TensorFlow. + # + # Expert opinion is that it would be about twice as fast since + # Cholesky is roughly half the cost of Gaussian Elimination with + # Partial Pivoting. But this is less of an impact than the switch in + # _psd_mask. + return math_ops.cast( + linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype) + + +def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed): + """Returns a uniformly random `Tensor` of "correlation-like" matrices. + + A "correlation-like" matrix is a symmetric square matrix with all entries + between -1 and 1 (inclusive) and 1s on the main diagonal. Of these, + the ones that are positive semi-definite are exactly the correlation + matrices. + + Args: + num_rows: Python `int` dimension of the correlation-like matrices. + batch_shape: `Tensor` or Python `tuple` of `int` shape of the + batch to return. + dtype: `dtype` of the `Tensor` to return. + seed: Random seed. + + Returns: + matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]` + and dtype `dtype`. Each entry is in [-1, 1], and each matrix + along the bottom two dimensions is symmetric and has 1s on the + main diagonal. + """ + num_entries = num_rows * (num_rows + 1) / 2 + ones = array_ops.ones(shape=[num_entries], dtype=dtype) + # It seems wasteful to generate random values for the diagonal since + # I am going to throw them away, but `fill_triangular` fills the + # diagonal, so I probably need them. + # It's not impossible that it would be more efficient to just fill + # the whole matrix with random values instead of messing with + # `fill_triangular`. Then would need to filter almost half out with + # `matrix_band_part`. + unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed) + tril = util.fill_triangular(unifs) + symmetric = tril + array_ops.matrix_transpose(tril) + diagonal_ones = array_ops.ones( + shape=util.pad(batch_shape, axis=0, back=True, value=num_rows), + dtype=dtype) + return array_ops.matrix_set_diag(symmetric, diagonal_ones) + + +def correlation_matrix_volume_rejection_samples( + det_bounds, dim, sample_shape, dtype, seed): + """Returns rejection samples from trying to get good correlation matrices. + + The proposal being rejected from is the uniform distribution on + "correlation-like" matrices. We say a matrix is "correlation-like" + if it is a symmetric square matrix with all entries between -1 and 1 + (inclusive) and 1s on the main diagonal. Of these, the ones that + are positive semi-definite are exactly the correlation matrices. + + The rejection algorithm, then, is to sample a `Tensor` of + `sample_shape` correlation-like matrices of dimensions `dim` by + `dim`, and check each one for (i) being a correlation matrix (i.e., + PSD), and (ii) having determinant at least the corresponding entry + of `det_bounds`. + + Args: + det_bounds: A `Tensor` of lower bounds on the determinants of + acceptable matrices. The shape must broadcast with `sample_shape`. + dim: A Python `int` dimension of correlation matrices to sample. + sample_shape: Python `tuple` of `int` shape of the samples to + compute, excluding the two matrix dimensions. + dtype: The `dtype` in which to do the computation. + seed: Random seed. + + Returns: + weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the + corresponding matrix was not a correlation matrix, or had too + small of a determinant. Otherwise, the entry is the + multiplicative inverse of the density of proposing that matrix + uniformly, i.e., the volume of the set of `dim` by `dim` + correlation-like matrices. + volume: The volume of the set of `dim` by `dim` correlation-like + matrices. + """ + with ops.name_scope("rejection_sampler"): + rej_proposals = _uniform_correlation_like_matrix( + dim, sample_shape, dtype, seed=seed) + rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.) + # The density of proposing any given point is 1 / rej_proposal_volume; + # The weight of that point should be scaled by + # 1 / density = rej_proposal_volume. + rej_weights = rej_proposal_volume * _psd_mask( + rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds) + return rej_weights, rej_proposal_volume + + +def _clopper_pearson_confidence_interval(samples, error_rate): + """Computes a confidence interval for the mean of the given 1-D distribution. + + Assumes (and checks) that the given distribution is Bernoulli, i.e., + takes only two values. This licenses using the CDF of the binomial + distribution for the confidence, which is tighter (for extreme + probabilities) than the DKWM inequality. The method is known as the + [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Assumes: + + - The given samples were drawn iid from the distribution of interest. + + - The given distribution is a Bernoulli, i.e., supported only on + low and high. + + Guarantees: + + - The probability (over the randomness of drawing the given sample) + that the true mean is outside the returned interval is no more + than the given error_rate. + + Args: + samples: `np.ndarray` of samples drawn iid from the distribution + of interest. + error_rate: Python `float` admissible rate of mistakes. + + Returns: + low: Lower bound of confidence interval. + high: Upper bound of confidence interval. + + Raises: + ValueError: If `samples` has rank other than 1 (batch semantics + are not implemented), or if `samples` contains values other than + `low` or `high` (as that makes the distribution not Bernoulli). + """ + # TODO(b/78025336) Migrate this confidence interval function + # to statistical_testing.py. In order to do that + # - Get the binomial CDF from the Binomial distribution + # - Implement scalar root finding in TF. Batch bisection search + # shouldn't be too hard, and is definitely good enough for this + # problem. Batching the Brent algorithm (from scipy) that is used + # here may be more involved, but may also not be necessary---it's + # only used here because scipy made it convenient. In particular, + # robustness is more important than speed here, which may make + # bisection search actively better. + # - The rest is just a matter of rewriting in the appropriate style. + if optimize is None or stats is None: + raise ValueError( + "Scipy is required for computing Clopper-Pearson confidence intervals") + if len(samples.shape) != 1: + raise ValueError("Batch semantics not implemented") + n = len(samples) + low = np.amin(samples) + high = np.amax(samples) + successes = np.count_nonzero(samples - low) + failures = np.count_nonzero(samples - high) + if successes + failures != n: + uniques = np.unique(samples) + msg = ("Purportedly Bernoulli distribution had distinct samples" + " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2])) + raise ValueError(msg) + def p_small_enough(p): + prob = stats.binom.logcdf(successes, n, p) + return prob - np.log(error_rate / 2.) + def p_big_enough(p): + prob = stats.binom.logsf(successes, n, p) + return prob - np.log(error_rate / 2.) + high_p = optimize.brentq( + p_small_enough, float(successes) / n, 1., rtol=1e-9) + low_p = optimize.brentq( + p_big_enough, 0., float(successes) / n, rtol=1e-9) + low_interval = low + (high - low) * low_p + high_interval = low + (high - low) * high_p + return (low_interval, high_interval) + + +def compute_true_volumes( + det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + """Returns confidence intervals for the desired correlation matrix volumes. + + The confidence intervals are computed by the [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Args: + det_bounds: A rank-1 numpy array of lower bounds on the + determinants of acceptable matrices. Entries must be unique. + dim: A Python `int` dimension of correlation matrices to sample. + num_samples: The number of samples to draw. + error_rate: The statistical significance of the returned + confidence intervals. The significance is broadcast: Each + returned interval separately may be incorrect with probability + (under the sample of correlation-like matrices drawn internally) + at most `error_rate`. + seed: Random seed. + + Returns: + bounds: A Python `dict` mapping each determinant bound to the low, high + tuple giving the confidence interval. + """ + bounds = {} + with session.Session() as sess: + rej_weights, _ = correlation_matrix_volume_rejection_samples( + det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed) + rej_weights = sess.run(rej_weights) + for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds): + template = ("Estimating volume of {}x{} correlation " + "matrices with determinant >= {}.") + print(template.format(dim, dim, det)) + sys.stdout.flush() + bounds[det] = _clopper_pearson_confidence_interval( + rw, error_rate=error_rate) + return bounds diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8f99300e63871119800a42f122c8321e5986541a --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py @@ -0,0 +1,150 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for correlation_matrix_volumes_lib.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr +from tensorflow.contrib.distributions.python.ops import statistical_testing as st +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.platform import test + + +# NxN correlation matrices are determined by the N*(N-1)/2 +# lower-triangular entries. In addition to being between -1 and 1, +# they must also obey the constraint that the determinant of the +# resulting symmetric matrix is non-negative. In 2x2, we can even +# analytically compute the volume when the determinant is bounded to > +# epsilon, as that boils down to the one lower-triangular entry being +# less than 1 - epsilon in absolute value. +def two_by_two_volume(det_bound): + return 2 * np.sqrt(1.0 - det_bound) + + +# The post +# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/ +# derives (with elementary calculus) that the volume (with respect to +# Lebesgue^3 measure) of the set of 3x3 correlation matrices is +# pi^2/2. The same result is also obtained by [1]. +def three_by_three_volume(): + return np.pi**2 / 2. + + +# The volume of the unconstrained set of correlation matrices is also +# the normalization constant of the LKJ distribution from [2]. As +# part of defining the distribution, that reference a derives general +# formula for this volume for all dimensions. A TensorFlow +# computation thereof gave the below result for 4x4: +def four_by_four_volume(): + # This constant computed as math_ops.exp(lkj.log_norm_const(4, [1.0])) + return 11.6973076 + +# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of +# correlation matrices." The American Statistician, 48(4), 276-279. + +# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating +# random correlation matrices based on vines and extended onion +# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001. + + +class CorrelationMatrixVolumesTest(test.TestCase): + + def testRejection2D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + exact_volumes = two_by_two_volume(det_bounds) + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43) + # shape of rej_weights: [num_samples, 9, 2, 2] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + # Correct the false fail rate due to different broadcasting + false_fail_rate=1.1e-7, false_pass_rate=1e-6), + 0.036) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testRejection3D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array([0.0], dtype=np.float32) + exact_volumes = np.array([three_by_three_volume()], dtype=np.float32) + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44) + # shape of rej_weights: [num_samples, 1, 3, 3] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + false_fail_rate=1e-6, false_pass_rate=1e-6), + # Going for about a 3% relative error + 0.15) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testRejection4D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array([0.0], dtype=np.float32) + exact_volumes = [four_by_four_volume()] + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45) + # shape of rej_weights: [num_samples, 1, 4, 4] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + false_fail_rate=1e-6, false_pass_rate=1e-6), + # Going for about a 10% relative error + 1.1) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testVolumeEstimation2D(self): + # Test that the confidence intervals produced by + # corr.compte_true_volumes are sound, in the sense of containing + # the exact volume. + num_samples = int(1e5) # Chosen by symmetry with testRejection2D + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + volume_bounds = corr.compute_true_volumes( + det_bounds, 2, num_samples, error_rate=1e-6, seed=47) + exact_volumes = two_by_two_volume(det_bounds) + for det, volume in zip(det_bounds, exact_volumes): + computed_low, computed_high = volume_bounds[det] + self.assertLess(computed_low, volume) + self.assertGreater(computed_high, volume) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index d813831bef803a22c095d9c98e7163aa4861a15d..bb9b8043b2233b2109f51b5dde188d088fdb0d39 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Autoregressive(distribution_lib.Distribution): @@ -107,6 +108,14 @@ class Autoregressive(distribution_lib.Distribution): https://arxiv.org/abs/1606.05328 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution_fn, sample0=None, @@ -144,7 +153,7 @@ class Autoregressive(distribution_lib.Distribution): `distribution_fn(sample0).event_shape.num_elements()` are both `None`. ValueError: if `num_steps < 1`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: self._distribution_fn = distribution_fn self._sample0 = sample0 diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index c709318f76552e1188f735f5bafff4be0537baed..519077bc9ab1063a1135486cfae34656f3f68157 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -72,6 +72,14 @@ class BatchReshape(distribution_lib.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution, batch_shape, @@ -103,7 +111,7 @@ class BatchReshape(distribution_lib.Distribution): ValueError: if `batch_shape` size is not the same as a `distribution.batch_shape` size. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) name = name or "BatchReshape" + distribution.name with ops.name_scope(name, values=[batch_shape]) as name: # The unexpanded batch shape may contain up to one dimension of -1. @@ -353,6 +361,14 @@ class BatchReshape(distribution_lib.Distribution): return runtime_assertions +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def calculate_reshape(original_shape, new_shape, validate=False, name=None): """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" batch_shape_static = tensor_util.constant_value_as_shape(new_shape) @@ -385,6 +401,14 @@ def calculate_reshape(original_shape, new_shape, validate=False, name=None): return expanded_new_shape, batch_shape_static, validations +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def validate_init_args_statically(distribution, batch_shape): """Helper to __init__ which makes or raises assertions.""" if batch_shape.shape.ndims is not None: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 51478dbeffaabc58ce3662f25f06bc579e8a407e..e141f8b5c6423bd6cce4d09da6f49d55b3e25a24 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -24,23 +24,27 @@ @@CholeskyOuterProduct @@ConditionalBijector @@Exp +@@FillTriangular @@Gumbel @@Identity @@Inline @@Invert @@Kumaraswamy @@MaskedAutoregressiveFlow +@@MatrixInverseTriL @@Ordered @@Permute @@PowerTransform @@RealNVP @@Reshape +@@ScaleTriL @@Sigmoid @@SinhArcsinh @@SoftmaxCentered @@Softplus @@Softsign @@Square +@@TransformDiagonal @@Weibull @@masked_autoregressive_default_template @@ -63,22 +67,26 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * +from tensorflow.contrib.distributions.python.ops.bijectors.fill_triangular import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * +from tensorflow.contrib.distributions.python.ops.bijectors.matrix_inverse_tril import * from tensorflow.contrib.distributions.python.ops.bijectors.ordered import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * +from tensorflow.contrib.distributions.python.ops.bijectors.scale_tril import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * from tensorflow.contrib.distributions.python.ops.bijectors.softsign import * from tensorflow.contrib.distributions.python.ops.bijectors.square import * +from tensorflow.contrib.distributions.python.ops.bijectors.transform_diagonal import * from tensorflow.python.ops.distributions.bijector import * from tensorflow.python.ops.distributions.identity_bijector import Identity diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index c9e31d7712f09f6c4b4cc6ae51a34c42a19c291d..4d6a46e7358933fdf512f49eae2673f35953c90a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -23,6 +23,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "AbsoluteValue", @@ -70,6 +71,14 @@ class AbsoluteValue(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="absolute_value"): """Instantiates the `AbsoluteValue` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index b4c2939eb914d50475ba6b1c1e979a804090f641..25f29452c3949600b8a4153a8585dd7269bd3b2b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -36,6 +37,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _as_tensor(x, name): """Convenience to convert to `Tensor` or leave as `None`.""" return None if x is None else ops.convert_to_tensor(x, name=name) @@ -97,6 +106,14 @@ class Affine(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale_identity_multiplier=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index 59f9742d576a7804f401d3a47ba31ae61d6c6e54..91301f15ad87e133777371b346864ecf7b964f27 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.util import deprecation __all__ = [ @@ -88,6 +89,14 @@ class AffineLinearOperator(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py index cd792e2c8cf48602daf9fb5eb56b8c34bac050c7..460d906231bd30f8cec4fe21d42afe7b2a05805e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -52,6 +53,14 @@ class AffineScalar(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py index 224cec8a63dba53a528490117efac890312fe8d5..f19f147dd645b4f805f1905899b44293284d4225 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -34,6 +35,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _undo_batch_normalization(x, mean, variance, @@ -128,6 +137,14 @@ class BatchNormalization(bijector.Bijector): Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, batchnorm_layer=None, training=True, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index b158a51bb022b5e2ea3afda74e97b9dc131665a6..910774ea5bb4106a948567144c46c6db23a2c6e0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -31,10 +32,26 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _use_static_shape(input_tensor, ndims): return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _compute_min_event_ndims(bijector_list, compute_forward=True): """Computes the min_event_ndims associated with the give list of bijectors. @@ -142,6 +159,14 @@ class Chain(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, bijectors=None, validate_args=False, name=None): """Instantiates `Chain` bijector. @@ -234,7 +259,7 @@ class Chain(bijector.Bijector): if not self.bijectors: return ildj - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( self.inverse_min_event_ndims) if _use_static_shape(y, event_ndims): @@ -248,12 +273,15 @@ class Chain(bijector.Bijector): if _use_static_shape(y, event_ndims): event_shape = b.inverse_event_shape(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( event_shape.ndims) else: event_shape = b.inverse_event_shape_tensor(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( - array_ops.size(event_shape)) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ + y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -270,7 +298,7 @@ class Chain(bijector.Bijector): if not self.bijectors: return fldj - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( self.forward_min_event_ndims) if _use_static_shape(x, event_ndims): @@ -283,21 +311,14 @@ class Chain(bijector.Bijector): x, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(x, event_ndims): event_shape = b.forward_event_shape(event_shape) - event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_static_event_ndims(event_shape.ndims) else: event_shape = b.forward_event_shape_tensor(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( - array_ops.size(event_shape)) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ x = b.forward(x, **kwargs.get(b.name, {})) return fldj - - def _maybe_get_event_ndims_statically(self, event_ndims): - event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically( - event_ndims) - if event_ndims_ is None: - return event_ndims - return event_ndims_ - - diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 268c8d03426d435dc38412ac1bd05c674bd05d2b..3e1e4fc82971b71792d193ea8518dd402e4a4d9d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -69,6 +70,14 @@ class CholeskyOuterProduct(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="cholesky_outer_product"): """Instantiates the `CholeskyOuterProduct` bijector. @@ -173,7 +182,20 @@ class CholeskyOuterProduct(bijector.Bijector): axis=-1) fldj = p_float * np.log(2.) + sum_weighted_log_diag - return fldj + # We finally need to undo adding an extra column in non-scalar cases + # where there is a single matrix as input. + if x.get_shape().ndims is not None: + if x.get_shape().ndims == 2: + fldj = array_ops.squeeze(fldj, axis=-1) + return fldj + + shape = array_ops.shape(fldj) + maybe_squeeze_shape = array_ops.concat([ + shape[:-1], + distribution_util.pick_vector( + math_ops.equal(array_ops.rank(x), 2), + np.array([], dtype=np.int32), shape[-1:])], 0) + return array_ops.reshape(fldj, maybe_squeeze_shape) def _make_columnar(self, x): """Ensures non-scalar input has at least one column. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index 9fc1bbf052b419d07a9db149b990c2b80190d72b..07627e1e45eae6b63d830b2adf036bdc3b1d2895 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops.bijectors import power_transform +from tensorflow.python.util import deprecation __all__ = [ @@ -47,6 +48,14 @@ class Exp(power_transform.PowerTransform): over the event space. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="exp"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py new file mode 100644 index 0000000000000000000000000000000000000000..31a9ca27e519bc312813668bf621a875838f12a0 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py @@ -0,0 +1,165 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as dist_util +from tensorflow.python.util import deprecation + + +__all__ = [ + "FillTriangular", +] + + +class FillTriangular(bijector.Bijector): + """Transforms vectors to triangular. + + Triangular matrix elements are filled in a clockwise spiral. + + Given input with shape `batch_shape + [d]`, produces output with + shape `batch_shape + [n, n]`, where + `n = (-1 + sqrt(1 + 8 * d))/2`. + This follows by solving the quadratic equation + `d = 1 + 2 + ... + n = n * (n + 1)/2`. + + #### Example + + ```python + b = tfb.FillTriangular(upper=False) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[4, 0, 0], + # [6, 5, 0], + # [3, 2, 1]] + + b = tfb.FillTriangular(upper=True) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[1, 2, 3], + # [0, 5, 6], + # [0, 0, 4]] + + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + upper=False, + validate_args=False, + name="fill_triangular"): + """Instantiates the `FillTriangular` bijector. + + Args: + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._upper = upper + super(FillTriangular, self).__init__( + forward_min_event_ndims=1, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return dist_util.fill_triangular(x, upper=self._upper) + + def _inverse(self, y): + return dist_util.fill_triangular_inverse(y, upper=self._upper) + + def _forward_log_det_jacobian(self, x): + return array_ops.zeros_like(x[..., 0]) + + def _inverse_log_det_jacobian(self, y): + return array_ops.zeros_like(y[..., 0, 0]) + + def _forward_event_shape(self, input_shape): + batch_shape, d = input_shape[:-1], input_shape[-1].value + if d is None: + n = None + else: + n = vector_size_to_square_matrix_size(d, self.validate_args) + return batch_shape.concatenate([n, n]) + + def _inverse_event_shape(self, output_shape): + batch_shape, n1, n2 = (output_shape[:-2], + output_shape[-2].value, + output_shape[-1].value) + if n1 is None or n2 is None: + m = None + elif n1 != n2: + raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2)) + else: + m = n1 * (n1 + 1) / 2 + return batch_shape.concatenate([m]) + + def _forward_event_shape_tensor(self, input_shape_tensor): + batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1] + n = vector_size_to_square_matrix_size(d, self.validate_args) + return array_ops.concat([batch_shape, [n, n]], axis=0) + + def _inverse_event_shape_tensor(self, output_shape_tensor): + batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1] + if self.validate_args: + is_square_matrix = check_ops.assert_equal( + n, output_shape_tensor[-2], message="Matrix must be square.") + with ops.control_dependencies([is_square_matrix]): + n = array_ops.identity(n) + d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype) + return array_ops.concat([batch_shape, [d]], axis=0) + + +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def vector_size_to_square_matrix_size(d, validate_args, name=None): + """Convert a vector size to a matrix size.""" + if isinstance(d, (float, int, np.generic, np.ndarray)): + n = (-1 + np.sqrt(1 + 8 * d)) / 2. + if float(int(n)) != n: + raise ValueError("Vector length is not a triangular number.") + return int(n) + else: + with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name: + n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2. + if validate_args: + with ops.control_dependencies([check_ops.assert_equal( + math_ops.to_float(math_ops.to_int32(n)), n, + message="Vector length is not a triangular number")]): + n = array_ops.identity(n) + return math_ops.cast(n, d.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index e656a258e56e71898ecb719dd2af876f158cf799..71e562a927a30a17d695b81c566f981db7553ad9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Gumbel", @@ -45,6 +46,14 @@ class Gumbel(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=0., scale=1., diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index 2bde956d1345129285acae4684256c5ac828b9a1..1504bd27204f728c0cb519159230e945128c4740 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -43,6 +44,14 @@ class Inline(bijector.Bijector): The above example is equivalent to the `Bijector` `Exp()`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, forward_fn=None, inverse_fn=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index 84a3289ba2160ed22a2bc7030dd612ba9ca6f6df..a648676d4b1956e5c27f67a71e6bd93d0d7fc97d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Invert", @@ -40,6 +41,14 @@ class Invert(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, bijector, validate_args=False, name=None): """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py index 97000c17262d3efdef10274711364c2bc2083bd4..33b75a04d34fdd01bc0d854d4e5b9c45a737b122 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Kumaraswamy", @@ -44,6 +45,14 @@ class Kumaraswamy(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration1=None, concentration0=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 83667b0e80cfcc1c4f0617cdc739221f24439665..b8f2a4b2c731bdaee78692c036fb9f2fba4e3760 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import template as template_ops from tensorflow.python.ops import variable_scope as variable_scope_lib from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -186,6 +187,14 @@ class MaskedAutoregressiveFlow(bijector.Bijector): Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift_and_log_scale_fn, is_constant_jacobian=False, @@ -296,6 +305,14 @@ MASK_INCLUSIVE = "inclusive" MASK_EXCLUSIVE = "exclusive" +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): """Generate the slices for building an autoregressive mask.""" # TODO(b/67594795): Better support of dynamic shape. @@ -313,6 +330,14 @@ def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): return slices +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _gen_mask(num_blocks, n_in, n_out, @@ -327,6 +352,14 @@ def _gen_mask(num_blocks, return mask +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def masked_dense(inputs, units, num_blocks=None, @@ -399,6 +432,14 @@ def masked_dense(inputs, return layer.apply(inputs) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def masked_autoregressive_default_template( hidden_layers, shift_only=False, @@ -515,6 +556,14 @@ def masked_autoregressive_default_template( "masked_autoregressive_default_template", _fn) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): """Clips input while leaving gradient unaltered.""" with ops.name_scope(name, "clip_by_value_preserve_grad", diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..49e6192f067edec4890dcfa107876a5104c14dd4 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py @@ -0,0 +1,154 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + + +__all__ = [ + "MatrixInverseTriL", +] + + +class MatrixInverseTriL(bijector.Bijector): + """Computes `g(L) = inv(L)`, where `L` is a lower-triangular matrix. + + `L` must be nonsingular; equivalently, all diagonal entries of `L` must be + nonzero. + + The input must have `rank >= 2`. The input is treated as a batch of matrices + with batch shape `input.shape[:-2]`, where each matrix has dimensions + `input.shape[-2]` by `input.shape[-1]` (hence `input.shape[-2]` must equal + `input.shape[-1]`). + + #### Examples + + ```python + tfd.bijectors.MatrixInverseTriL().forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 0], [-2, 1]], i.e., inv(x) + + tfd.bijectors.MatrixInverseTriL().inverse(y=[[1., 0], [-2, 1]]) + # Result: [[1., 0], [2, 1]], i.e., inv(y). + ``` + + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, validate_args=False, name="matrix_inverse_tril"): + """Instantiates the `MatrixInverseTriL` bijector. + + Args: + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + super(MatrixInverseTriL, self).__init__( + forward_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + with ops.control_dependencies(self._assertions(x)): + shape = array_ops.shape(x) + return linalg_ops.matrix_triangular_solve( + x, linalg_ops.eye(shape[-1], batch_shape=shape[:-2]), lower=True) + + def _inverse(self, y): + return self._forward(y) + + def _forward_log_det_jacobian(self, x): + # Calculation of the Jacobian: + # + # Let X = (x_{ij}), 0 <= i,j < n, be a matrix of indeterminates. Let Z = + # X^{-1} where Z = (z_{ij}). Then + # + # dZ/dx_{ij} = (d/dt | t=0) Y(t)^{-1}, + # + # where Y(t) = X + t*E_{ij} and E_{ij} is the matrix with a 1 in the (i,j) + # entry and zeros elsewhere. By the product rule, + # + # 0 = d/dt [Identity matrix] + # = d/dt [Y Y^{-1}] + # = Y d/dt[Y^{-1}] + dY/dt Y^{-1} + # + # so + # + # d/dt[Y^{-1}] = -Y^{-1} dY/dt Y^{-1} + # = -Y^{-1} E_{ij} Y^{-1}. + # + # Evaluating at t=0, + # + # dZ/dx_{ij} = -Z E_{ij} Z. + # + # Taking the (r,s) entry of each side, + # + # dz_{rs}/dx_{ij} = -z_{ri}z_{sj}. + # + # Now, let J be the Jacobian dZ/dX, arranged as the n^2-by-n^2 matrix whose + # (r*n + s, i*n + j) entry is dz_{rs}/dx_{ij}. Considering J as an n-by-n + # block matrix with n-by-n blocks, the above expression for dz_{rs}/dx_{ij} + # shows that the block at position (r,i) is -z_{ri}Z. Hence + # + # J = -KroneckerProduct(Z, Z), + # det(J) = (-1)^(n^2) (det Z)^(2n) + # = (-1)^n (det X)^(-2n). + with ops.control_dependencies(self._assertions(x)): + return (-2. * math_ops.cast(array_ops.shape(x)[-1], x.dtype.base_dtype) * + math_ops.reduce_sum( + math_ops.log(math_ops.abs(array_ops.matrix_diag_part(x))), + axis=-1)) + + def _assertions(self, x): + if not self.validate_args: + return [] + shape = array_ops.shape(x) + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must have rank at least 2.") + is_square = check_ops.assert_equal( + shape[-2], shape[-1], message="Input must be a square matrix.") + above_diagonal = array_ops.matrix_band_part( + array_ops.matrix_set_diag( + x, array_ops.zeros(shape[:-1], dtype=dtypes.float32)), + 0, -1) + is_lower_triangular = check_ops.assert_equal( + above_diagonal, array_ops.zeros_like(above_diagonal), + message="Input must be lower triangular.") + # A lower triangular matrix is nonsingular iff all its diagonal entries are + # nonzero. + diag_part = array_ops.matrix_diag_part(x) + is_nonsingular = check_ops.assert_none_equal( + diag_part, array_ops.zeros_like(diag_part), + message="Input must have all diagonal entries nonzero.") + return [is_matrix, is_square, is_lower_triangular, is_nonsingular] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py index 3f03592f314cc13e8a9ea7e2ae18c5bb1f14e74f..fb393218b6b47764f45b5055bbf15cc17aba219e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -57,6 +58,14 @@ class Ordered(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="ordered"): super(Ordered, self).__init__( forward_min_event_ndims=1, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index 12a16a3f2ba3da53077307fd97d3f10d99b2c81f..f182a1adcbb6b11af2376cd271f903d50e50f1a0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -74,6 +75,14 @@ class Permute(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, permutation, validate_args=False, name=None): """Creates the `Permute` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index 71f123f2a998458edaa9c8da07ea2932f62625ca..16264fe728a334db347304500767ce5876f9db7e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -41,6 +42,14 @@ class PowerTransform(bijector.Bijector): This bijector is equivalent to the `Exp` bijector when `c=0`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, power=0., validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 66e8a5b9b356867424d1d47efaf848fc6903c371..773ae2446118051a61636bc21de6b81dfacda746 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import template as template_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -126,6 +127,14 @@ class RealNVP(bijector.Bijector): Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, num_masked, shift_and_log_scale_fn, @@ -228,6 +237,14 @@ class RealNVP(bijector.Bijector): return math_ops.reduce_sum(log_scale, axis=-1) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def real_nvp_default_template( hidden_layers, shift_only=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 5497c422e4d51e259435692dac722f801e8844ac..c8282229a30fabff0c4c267d0bdfcdbce4f5f3d9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -36,10 +37,26 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _static_ndims_from_shape(shape): return shape.shape.with_rank_at_least(1)[0].value +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _ndims_from_shape(shape): return array_ops.shape(shape)[0] @@ -86,6 +103,14 @@ class Reshape(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, event_shape_out, event_shape_in=(-1,), validate_args=False, name=None): """Creates a `Reshape` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..6fbe8665781211ca803feb8bf5a8c04fb0b969e8 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops.bijectors import affine_scalar +from tensorflow.contrib.distributions.python.ops.bijectors import chain +from tensorflow.contrib.distributions.python.ops.bijectors import fill_triangular +from tensorflow.contrib.distributions.python.ops.bijectors import softplus +from tensorflow.contrib.distributions.python.ops.bijectors import transform_diagonal +from tensorflow.python.util import deprecation + +__all__ = [ + "ScaleTriL", +] + + +class ScaleTriL(chain.Chain): + """Transforms unconstrained vectors to TriL matrices with positive diagonal. + + This is implemented as a simple `tfb.Chain` of `tfb.FillTriangular` + followed by `tfb.TransformDiagonal`, and provided mostly as a + convenience. The default setup is somewhat opinionated, using a + Softplus transformation followed by a small shift (`1e-5`) which + attempts to avoid numerical issues from zeros on the diagonal. + + #### Examples + + ```python + tfb = tf.contrib.distributions.bijectors + b = tfb.ScaleTriL( + diag_bijector=tfb.Exp(), + diag_shift=None) + b.forward(x=[0., 0., 0.]) + # Result: [[1., 0.], + # [0., 1.]] + b.inverse(y=[[1., 0], + [.5, 2]]) + # Result: [log(2), .5, log(1)] + + # Define a distribution over PSD matrices of shape `[3, 3]`, + # with `1 + 2 + 3 = 6` degrees of freedom. + dist = tfd.TransformedDistribution( + tfd.Normal(tf.zeros(6), tf.ones(6)), + tfb.Chain([tfb.CholeskyOuterProduct(), tfb.ScaleTriL()])) + + # Using an identity transformation, ScaleTriL is equivalent to + # tfb.FillTriangular. + b = tfb.ScaleTriL( + diag_bijector=tfb.Identity(), + diag_shift=None) + + # For greater control over initialization, one can manually encode + # pre- and post- shifts inside of `diag_bijector`. + b = tfb.ScaleTriL( + diag_bijector=tfb.Chain([ + tfb.AffineScalar(shift=1e-3), + tfb.Softplus(), + tfb.AffineScalar(shift=0.5413)]), # softplus_inverse(1.) + # = log(expm1(1.)) = 0.5413 + diag_shift=None) + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + diag_bijector=None, + diag_shift=1e-5, + validate_args=False, + name="scale_tril"): + """Instantiates the `ScaleTriL` bijector. + + Args: + diag_bijector: `Bijector` instance, used to transform the output diagonal + to be positive. + Default value: `None` (i.e., `tfb.Softplus()`). + diag_shift: Float value broadcastable and added to all diagonal entries + after applying the `diag_bijector`. Setting a positive + value forces the output diagonal entries to be positive, but + prevents inverting the transformation for matrices with + diagonal entries less than this value. + Default value: `1e-5` (i.e., no shift is applied). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + Default value: `False` (i.e., arguments are not validated). + name: Python `str` name given to ops managed by this object. + Default value: `scale_tril`. + """ + + if diag_bijector is None: + diag_bijector = softplus.Softplus(validate_args=validate_args) + + if diag_shift is not None: + diag_bijector = chain.Chain([affine_scalar.AffineScalar(shift=diag_shift), + diag_bijector]) + + super(ScaleTriL, self).__init__( + [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector), + fill_triangular.FillTriangular()], + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index 5df8c886315ff75cdc884e3b9b4665fb64bb109d..194b318fce31a13f84e7b664b58cebb24fc9a264 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -31,6 +32,14 @@ __all__ = [ class Sigmoid(bijector.Bijector): """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="sigmoid"): super(Sigmoid, self).__init__( forward_min_event_ndims=0, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index 2a32e8abcde940b0056b0faf2955ec1b3bd71803..241fba2cb7ec33b7b02c1ca79051f1b826d7d2aa 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -26,12 +26,21 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "SinhArcsinh", ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _sqrtx2p1(x): """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" return array_ops.where( @@ -88,6 +97,14 @@ class SinhArcsinh(bijector.Bijector): `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, skewness=None, tailweight=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index f52b91550edff7390d8094a4508d862674e85d59..20ee0d340833d5c5275e2ab52a89dcdf7198add1 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -60,6 +61,14 @@ class SoftmaxCentered(bijector.Bijector): makes the (forward) image non-open and the theorem does not directly apply. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="softmax_centered"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 96a938c803418ff818f9c531754b47ba1eb8667a..3df84ef8b04c2c8f6be91ecd1c972ad1484b4285 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -80,6 +81,14 @@ class Softplus(bijector.Bijector): "hinge_softness": ( "Nonzero floating point `Tensor`. Controls the softness of what " "would otherwise be a kink at the origin. Default is 1.0")}) + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, hinge_softness=None, validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py index b4a658c171b8313358754228aabbfa4bf93fd84d..f96a4bb01de59a21107b9e7c14f929e13e358ac9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py @@ -22,6 +22,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -51,6 +52,14 @@ class Softsign(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="softsign"): super(Softsign, self).__init__( forward_min_event_ndims=0, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py index 2ccfdc95970e387e708603e2614ad29fb6a18db3..294460a80f6209797831ea361e64efe677f71e59 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/square.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -49,6 +50,14 @@ class Square(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="square"): """Instantiates the `Square` bijector. @@ -81,4 +90,3 @@ class Square(bijector.Bijector): is_valid = check_ops.assert_non_negative( t, message="All elements must be non-negative.") return control_flow_ops.with_dependencies([is_valid], t) - diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7a3b026b8dcc31bed49c489d77b9c184f463cb --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py @@ -0,0 +1,111 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + +__all__ = [ + "TransformDiagonal", +] + + +class TransformDiagonal(bijector.Bijector): + """Applies a Bijector to the diagonal of a matrix. + + #### Example + + ```python + b = tfb.TransformDiagonal(diag_bijector=tfb.Exp()) + + b.forward([[1., 0.], + [0., 1.]]) + # ==> [[2.718, 0.], + [0., 2.718]] + ``` + + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + diag_bijector, + validate_args=False, + name="transform_diagonal"): + """Instantiates the `TransformDiagonal` bijector. + + Args: + diag_bijector: `Bijector` instance used to transform the diagonal. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._diag_bijector = diag_bijector + super(TransformDiagonal, self).__init__( + forward_min_event_ndims=2, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + diag = self._diag_bijector.forward(array_ops.matrix_diag_part(x)) + return array_ops.matrix_set_diag(x, diag) + + def _inverse(self, y): + diag = self._diag_bijector.inverse(array_ops.matrix_diag_part(y)) + return array_ops.matrix_set_diag(y, diag) + + def _forward_log_det_jacobian(self, x): + # We formulate the Jacobian with respect to the flattened matrices + # `vec(x)` and `vec(y)`. Suppose for notational convenience that + # the first `n` entries of `vec(x)` are the diagonal of `x`, and + # the remaining `n**2-n` entries are the off-diagonals in + # arbitrary order. Then the Jacobian is a block-diagonal matrix, + # with the Jacobian of the diagonal bijector in the first block, + # and the identity Jacobian for the remaining entries (since this + # bijector acts as the identity on non-diagonal entries): + # + # J_vec(x) (vec(y)) = + # ------------------------------- + # | J_diag(x) (diag(y)) 0 | n entries + # | | + # | 0 I | n**2-n entries + # ------------------------------- + # n n**2-n + # + # Since the log-det of the second (identity) block is zero, the + # overall log-det-jacobian is just the log-det of first block, + # from the diagonal bijector. + # + # Note that for elementwise operations (exp, softplus, etc) the + # first block of the Jacobian will itself be a diagonal matrix, + # but our implementation does not require this to be true. + return self._diag_bijector.forward_log_det_jacobian( + array_ops.matrix_diag_part(x), event_ndims=1) + + def _inverse_log_det_jacobian(self, y): + return self._diag_bijector.inverse_log_det_jacobian( + array_ops.matrix_diag_part(y), event_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index a22560fe80298b762795e7b0e7aea2db55823065..8903a70d98ae144731b12047e5074d0450b59378 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -47,6 +48,14 @@ class Weibull(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, scale=1., concentration=1., diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py index 24b26bf124c78c8320b9a6bc3b900e6c7a93f5e4..b349e5966dd750fdf96c0b211dce02658c9400b7 100644 --- a/tensorflow/contrib/distributions/python/ops/binomial.py +++ b/tensorflow/contrib/distributions/python/ops/binomial.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation _binomial_sample_note = """ @@ -42,6 +43,14 @@ to integer values. """ +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _bdtr(k, n, p): """The binomial cumulative distribution function. @@ -130,6 +139,14 @@ class Binomial(distribution.Distribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, total_count, logits=None, @@ -163,7 +180,7 @@ class Binomial(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._total_count = self._maybe_assert_valid_total_count( ops.convert_to_tensor(total_count, name="total_count"), diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index f5ffdd873124d6626dca26f603592bd0b030d7b3..cb5223b0557080e10bf24c3e1cb432f15fd5e7e3 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -29,7 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ "Cauchy", @@ -93,6 +93,14 @@ class Cauchy(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -121,7 +129,7 @@ class Cauchy(distribution.Distribution): Raises: TypeError: if `loc` and `scale` have different `dtype`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index 08cdc1582892cc7d308bd60f082dde082704f57f..e9a7b39070f3d76693ad54852ed0847a0980d2a6 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -25,7 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import gamma -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -64,6 +64,14 @@ class Chi2(gamma.Gamma): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, validate_args=False, @@ -84,7 +92,7 @@ class Chi2(gamma.Gamma): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) # Even though all stats of chi2 are defined for valid parameters, this is # not true in the parent class "gamma." therefore, passing # allow_nan_stats=True @@ -115,12 +123,20 @@ class Chi2(gamma.Gamma): class Chi2WithAbsDf(Chi2): """Chi2 with parameter transform `df = floor(abs(df))`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, validate_args=False, allow_nan_stats=True, name="Chi2WithAbsDf"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[df]) as name: super(Chi2WithAbsDf, self).__init__( df=math_ops.floor( diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 10b45361358b40a3c8fd725f27ad84ef9b8a37f5..3598c8d23ea9007fb359ae4931738fb61ede4ccc 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import conditional_distribution from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import transformed_distribution @@ -106,7 +105,7 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access @@ -131,7 +130,7 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access @@ -220,14 +219,14 @@ class ConditionalTransformedDistribution( inv_cdf = self.distribution.quantile(value, **distribution_kwargs) return self.bijector.forward(inv_cdf, **bijector_kwargs) - def _maybe_get_event_ndims_statically(self): + def _maybe_get_static_event_ndims(self): if self.event_shape.ndims is not None: return self.event_shape.ndims event_ndims = array_ops.size(self.event_shape_tensor()) - static_event_ndims = tensor_util.constant_value(event_ndims) + event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) - if static_event_ndims is not None: - return static_event_ndims + if event_ndims_ is not None: + return event_ndims_ return event_ndims diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index 6d7d6d307bd0f815344c8a0e347f45ae11ba6462..ad853ee293f86565c1af601214522f53d936b70a 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -32,7 +32,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ "Deterministic", @@ -44,6 +44,14 @@ __all__ = [ class _BaseDeterministic(distribution.Distribution): """Base class for Deterministic distributions.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, @@ -87,7 +95,7 @@ class _BaseDeterministic(distribution.Distribution): Raises: ValueError: If `loc` is a scalar. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, atol, rtol]) as name: loc = ops.convert_to_tensor(loc, name="loc") if is_vector and validate_args: @@ -204,6 +212,14 @@ class Deterministic(_BaseDeterministic): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, @@ -309,6 +325,14 @@ class VectorDeterministic(_BaseDeterministic): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 289e1d50e1146a641c0cc433ece3465aed73b1c2..6959b3e8775d2dd488b4ee3252d143ef376d58f9 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -21,12 +21,19 @@ from __future__ import print_function from tensorflow.contrib import linalg from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib + +# The following two lines are redundant, in a sense. The first enables +# good coding practice *within* this file (`util.prefer_static_value` +# rather than `prefer_static_value`). The second ensures that users +# also get the core utils when they import this file. +from tensorflow.python.ops.distributions import util from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import @@ -484,3 +491,75 @@ def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution, def static_value(x): """Returns the static value of a `Tensor` or `None`.""" return tensor_util.constant_value(ops.convert_to_tensor(x)) + + +def move_dimension(x, source_idx, dest_idx): + """Move a single tensor dimension within its shape. + + This is a special case of `tf.transpose()`, which applies + arbitrary permutations to tensor dimensions. + + Args: + x: Tensor of rank `ndims`. + source_idx: Integer index into `x.shape` (negative indexing is + supported). + dest_idx: Integer index into `x.shape` (negative indexing is + supported). + + Returns: + x_perm: Tensor of rank `ndims`, in which the dimension at original + index `source_idx` has been moved to new index `dest_idx`, with + all other dimensions retained in their original order. + + Example: + + ```python + x = tf.placeholder(shape=[200, 30, 4, 1, 6]) + x_perm = _move_dimension(x, 1, 1) # no-op + x_perm = _move_dimension(x, 0, 3) # result shape [30, 4, 1, 200, 6] + x_perm = _move_dimension(x, 0, -2) # equivalent to previous + x_perm = _move_dimension(x, 4, 2) # result shape [200, 30, 6, 4, 1] + ``` + """ + ndims = util.prefer_static_rank(x) + if isinstance(source_idx, int): + dtype = dtypes.int32 + else: + dtype = dtypes.as_dtype(source_idx.dtype) + + # Handle negative indexing. Since ndims might be dynamic, this makes + # source_idx and dest_idx also possibly dynamic. + if source_idx < 0: + source_idx = ndims + source_idx + if dest_idx < 0: + dest_idx = ndims + dest_idx + + # Construct the appropriate permutation of dimensions, depending + # whether the source is before or after the destination. + def move_left_permutation(): + return util.prefer_static_value( + array_ops.concat([ + math_ops.range(0, dest_idx, dtype=dtype), + [source_idx], + math_ops.range(dest_idx, source_idx, dtype=dtype), + math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0)) + + def move_right_permutation(): + return util.prefer_static_value( + array_ops.concat([ + math_ops.range(0, source_idx, dtype=dtype), + math_ops.range(source_idx+1, dest_idx+1, dtype=dtype), + [source_idx], + math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0)) + + def x_permuted(): + return array_ops.transpose( + x, perm=smart_cond.smart_cond(source_idx < dest_idx, + move_right_permutation, + move_left_permutation)) + + # One final conditional to handle the special case where source + # and destination indices are equal. + return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx), + lambda: x, + x_permuted) diff --git a/tensorflow/contrib/distributions/python/ops/estimator.py b/tensorflow/contrib/distributions/python/ops/estimator.py index 98edd337fe02ffbf53c6ecd9ebda9424231ea2fe..bdec6527d5378d6e86aa8e6279cc6ee672083e56 100644 --- a/tensorflow/contrib/distributions/python/ops/estimator.py +++ b/tensorflow/contrib/distributions/python/ops/estimator.py @@ -23,6 +23,7 @@ from tensorflow.contrib.learn.python.learn.estimators.head import _RegressionHea from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.util import deprecation __all__ = [ @@ -30,6 +31,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def estimator_head_distribution_regression(make_distribution_fn, label_dimension=1, logits_dimension=None, @@ -77,6 +86,14 @@ def estimator_head_distribution_regression(make_distribution_fn, class _DistributionRegressionHead(_RegressionHead): """Creates a _RegressionHead instance from an arbitrary `Distribution`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, make_distribution_fn, label_dimension, diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py index 446cff6ec242f25178fed0c6a424791fa9f176ad..d62f024aa2a081f0ec231015af1f26a8851518e9 100644 --- a/tensorflow/contrib/distributions/python/ops/geometric.py +++ b/tensorflow/contrib/distributions/python/ops/geometric.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Geometric(distribution.Distribution): @@ -55,6 +56,14 @@ class Geometric(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, logits=None, probs=None, @@ -85,7 +94,7 @@ class Geometric(distribution.Distribution): name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index ed9ea6f4f3ffe18fb6bf1e0a7d57728d010e0f01..acdea4d61d3ada7e9f4f0aa7bc58c5643db2802b 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -29,7 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class _Gumbel(distribution.Distribution): @@ -97,6 +97,14 @@ class _Gumbel(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -125,7 +133,7 @@ class _Gumbel(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index 7e12767f6d8f6c61565ecf266d3b222de68c0e40..b02c4031069191592b8acc1a90313450f98af6d7 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import special_math -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -86,6 +86,14 @@ class HalfNormal(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, scale, validate_args=False, @@ -106,7 +114,7 @@ class HalfNormal(distribution.Distribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index fa89fff3b7b2f8266a44c446a0c9807790b3aed8..0672702b96c1eb81c176774554df3f5922a0319e 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -29,7 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import kullback_leibler -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Independent(distribution_lib.Distribution): @@ -95,6 +95,14 @@ class Independent(distribution_lib.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, distribution, reinterpreted_batch_ndims=None, validate_args=False, name=None): @@ -117,7 +125,7 @@ class Independent(distribution_lib.Distribution): ValueError: if `reinterpreted_batch_ndims` exceeds `distribution.batch_ndims` """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) name = name or "Independent" + distribution.name self._distribution = distribution with ops.name_scope(name) as name: @@ -259,6 +267,14 @@ class Independent(distribution_lib.Distribution): @kullback_leibler.RegisterKL(Independent, Independent) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_independent(a, b, name="kl_independent"): """Batched KL divergence `KL(a || b)` for Independent distributions. diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 85e8e10466038e5e55ef4b754f82c0c2c2543b6d..70d050d7a647b38928ddb1c788db0e6957ac0f03 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -95,6 +96,14 @@ class InverseGamma(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration, rate, @@ -125,7 +134,7 @@ class InverseGamma(distribution.Distribution): Raises: TypeError: if `concentration` and `rate` are different dtypes. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration, rate]) as name: with ops.control_dependencies([ check_ops.assert_positive(concentration), @@ -274,13 +283,21 @@ class InverseGamma(distribution.Distribution): class InverseGammaWithSoftplusConcentrationRate(InverseGamma): """`InverseGamma` with softplus of `concentration` and `rate`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration, rate, validate_args=False, allow_nan_stats=True, name="InverseGammaWithSoftplusConcentrationRate"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration, rate]) as name: super(InverseGammaWithSoftplusConcentrationRate, self).__init__( concentration=nn.softplus(concentration, diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 66682b2ff5493f8565410138e770b45ffc6b5d77..e3712dd84e36609d6bba4a5a39866046c0c8d1d8 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import uniform -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util import deprecation __all__ = [ "Kumaraswamy", @@ -41,6 +41,14 @@ _kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in `[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _harmonic_number(x): """Compute the harmonic number from its analytic continuation. @@ -59,7 +67,6 @@ def _harmonic_number(x): return math_ops.digamma(x + one) - math_ops.digamma(one) -@tf_export("distributions.Kumaraswamy") class Kumaraswamy(transformed_distribution.TransformedDistribution): """Kumaraswamy distribution. @@ -125,6 +132,14 @@ class Kumaraswamy(transformed_distribution.TransformedDistribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration1=None, concentration0=None, diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 0103283259b0526b5a108ea1836f95709eedc067..02e3bad51ee48188acf83cb09359861c9e6932c7 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Logistic(distribution.Distribution): @@ -92,6 +92,14 @@ class Logistic(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -120,7 +128,7 @@ class Logistic(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index d54f30dc634ab5c8aa82066056266747b63eec21..3b7114ef067c0aaede23fff04c40d1dc6e830f1c 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Mixture(distribution.Distribution): @@ -66,6 +67,14 @@ class Mixture(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, cat, components, @@ -116,7 +125,7 @@ class Mixture(distribution.Distribution): matching static batch shapes, or all components do not have matching static event shapes. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) if not isinstance(cat, categorical.Categorical): raise TypeError("cat must be a Categorical distribution, but saw: %s" % cat) diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index c7c90cf875484a1753577227bf22de878d00a502..8ffee940d03c9a5204f2ac6f7acd9ea482adae1a 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class MixtureSameFamily(distribution.Distribution): @@ -95,6 +96,14 @@ class MixtureSameFamily(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, mixture_distribution, components_distribution, @@ -130,7 +139,7 @@ class MixtureSameFamily(distribution.Distribution): ValueError: if `mixture_distribution` categories does not equal `components_distribution` rightmost batch shape. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: self._mixture_distribution = mixture_distribution self._components_distribution = components_distribution @@ -321,6 +330,14 @@ class MixtureSameFamily(distribution.Distribution): return x +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _outer_squared_difference(x, y): """Convenience function analogous to tf.squared_difference.""" z = x - y diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index cad398582b9c939e8e96cf498638869ccd3701bd..cd0c282ba6cebf784261a4e821f36ce4eed98fe0 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -22,6 +22,7 @@ from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops import nn +from tensorflow.python.util import deprecation __all__ = [ @@ -134,6 +135,14 @@ class MultivariateNormalDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -193,7 +202,7 @@ class MultivariateNormalDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): @@ -218,13 +227,21 @@ class MultivariateNormalDiag( class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag): """MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale_diag, validate_args=False, allow_nan_stats=True, name="MultivariateNormalDiagWithSoftplusScale"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[scale_diag]) as name: super(MultivariateNormalDiagWithSoftplusScale, self).__init__( loc=loc, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 1c11594df3ad2612dd8746bb8785d86390b69937..d8401801f21afbe8fd042053c6a38a31a2539438 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -22,6 +22,7 @@ from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -141,6 +142,14 @@ class MultivariateNormalDiagPlusLowRank( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -215,7 +224,7 @@ class MultivariateNormalDiagPlusLowRank( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 47d7d13cf357f1ac657641420602c92eefdad197..dbc4c1b3dc956641f3e38ffafe3a3410bd3e2097 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -24,7 +24,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -113,6 +113,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, covariance_matrix=None, @@ -156,7 +164,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): Raises: ValueError: if neither `loc` nor `covariance_matrix` are specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) # Convert the covariance_matrix up to a scale_tril and call MVNTriL. with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 79916fef8d7b752649dcc673a84ea45ccf460905..efe5a6d0d99ca8fa9e0274049423bb3c4eef2d6f 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -27,6 +27,7 @@ from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = [ @@ -133,6 +134,14 @@ class MultivariateNormalLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -170,7 +179,7 @@ class MultivariateNormalLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: @@ -266,6 +275,14 @@ class MultivariateNormalLinearOperator( @kullback_leibler.RegisterKL(MultivariateNormalLinearOperator, MultivariateNormalLinearOperator) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_brute_force(a, b, name=None): """Batched KL divergence `KL(a || b)` for multivariate Normals. diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index d6b0ed994ec0a62e9b7684e7478130052a1fd300..d9110947ecdbba1a63669573f46db17b02e512ab 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -22,6 +22,7 @@ from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -134,6 +135,14 @@ class MultivariateNormalTriL( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_tril=None, @@ -179,7 +188,7 @@ class MultivariateNormalTriL( Raises: ValueError: if neither `loc` nor `scale_tril` are specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) if loc is None and scale_tril is None: diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index 1085c56dc86c8d45bdab2e7cecedf44663e5c408..6acfc5746a0cc20e916de81b71f90e08d8d91ad5 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class NegativeBinomial(distribution.Distribution): @@ -51,6 +52,14 @@ class NegativeBinomial(distribution.Distribution): * `n!` is the factorial of `n`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, total_count, logits=None, @@ -90,7 +99,7 @@ class NegativeBinomial(distribution.Distribution): name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index a4b9f3b78d4fdcc328bac84623114b921b9ded49..214c6dca4a7f2b4cd6242e1b7ca78be9eeffb851 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class OneHotCategorical(distribution.Distribution): @@ -83,6 +84,14 @@ class OneHotCategorical(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, logits=None, @@ -115,7 +124,7 @@ class OneHotCategorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( name=name, logits=logits, probs=probs, validate_args=validate_args, @@ -226,13 +235,21 @@ class OneHotCategorical(distribution.Distribution): return x return control_flow_ops.with_dependencies([ check_ops.assert_non_positive(x), - distribution_util.assert_close( + check_ops.assert_near( array_ops.zeros([], dtype=self.dtype), math_ops.reduce_logsumexp(x, axis=[-1])), ], x) @kullback_leibler.RegisterKL(OneHotCategorical, OneHotCategorical) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_categorical_categorical(a, b, name=None): """Calculate the batched KL divergence KL(a || b) with a, b OneHotCategorical. diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index b34539402102b8f289d4eb289fcb82f4030f4e8c..3d055085cc7386e57a71aa310458b7666bb9a396 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ "Poisson", @@ -65,6 +66,14 @@ class Poisson(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, rate=None, log_rate=None, @@ -93,7 +102,7 @@ class Poisson(distribution.Distribution): TypeError: if `rate` is not a float-type. TypeError: if `log_rate` is not a float-type. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[rate]) as name: if (rate is None) == (log_rate is None): raise ValueError("Must specify exactly one of `rate` and `log_rate`.") diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index fe72091d7d759e54c51eb666f2ceacc8371e55fd..7a7ad1be35b80ff0f000181ea0778ab282a8220f 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -33,6 +33,7 @@ from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.distributions import transformed_distribution as transformed_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -42,6 +43,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_lognormal_gauss_hermite( loc, scale, quadrature_size, validate_args=False, name=None): # pylint: disable=unused-argument @@ -85,6 +94,14 @@ def quadrature_scheme_lognormal_gauss_hermite( return grid, probs +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_lognormal_quantiles( loc, scale, quadrature_size, validate_args=False, name=None): @@ -214,6 +231,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): validate_args=True) """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -255,7 +280,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): TypeError: if `quadrature_grid` and `quadrature_probs` have different base `dtype`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: if loc is not None: loc = ops.convert_to_tensor(loc, name="loc") @@ -417,6 +442,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): axis=[-2, -1]) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [distribution_util.static_value(x) for x in args] diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index 584d2c385fced95ec496bb8dae9556e5c376b66d..ef3bdfa75fcaa8df17db1238ceadadf788601356 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -27,10 +27,19 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distributions from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = ["QuantizedDistribution"] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _logsum_expbig_minus_expsmall(big, small): """Stable evaluation of `Log[exp{big} - exp{small}]`. @@ -228,6 +237,14 @@ class QuantizedDistribution(distributions.Distribution): https://arxiv.org/abs/1711.10433 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution, low=None, @@ -263,7 +280,7 @@ class QuantizedDistribution(distributions.Distribution): `Distribution` or continuous. NotImplementedError: If the base distribution does not implement `cdf`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) values = ( list(distribution.parameters.values()) + [low, high]) diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index 0362996e684fb34b15cd98a2fc40df58087fbe95..7e1f64dc425e6a576bfbe1bb456901fddfac26e1 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -19,15 +19,16 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import logistic +from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid # Bijectors must be directly imported because `remove_undocumented` prevents # individual file imports. -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class RelaxedBernoulli(transformed_distribution.TransformedDistribution): @@ -131,6 +132,14 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): Gumbel-Softmax. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, temperature, logits=None, @@ -165,7 +174,7 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): Raises: ValueError: If both `probs` and `logits` are passed, or if neither. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs, temperature]) as name: with ops.control_dependencies([check_ops.assert_positive(temperature)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 910c430ae7f026a3ac9ce50d1d5936d4454cba41..25aaac379a7c54c832bdcf962e16f339522d61fc 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class ExpRelaxedOneHotCategorical(distribution.Distribution): @@ -125,6 +126,14 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): A Continuous Relaxation of Discrete Random Variables. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, temperature, @@ -162,7 +171,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs, temperature]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( @@ -290,7 +299,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): return x return control_flow_ops.with_dependencies([ check_ops.assert_non_positive(x), - distribution_util.assert_close( + check_ops.assert_near( array_ops.zeros([], dtype=self.dtype), math_ops.reduce_logsumexp(x, axis=[-1])), ], x) @@ -368,6 +377,14 @@ class RelaxedOneHotCategorical( A Continuous Relaxation of Discrete Random Variables. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, temperature, diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py index 056d349688511e19a4fa3d58a5b3c1c8355671a3..cf505ac627b62ae0a3d1ec1ce2a237c3c2ff1b74 100644 --- a/tensorflow/contrib/distributions/python/ops/seed_stream.py +++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py @@ -169,7 +169,7 @@ class SeedStream(object): and TensorFlow Probability code base. See class docstring for rationale. """ - self._seed = seed + self._seed = seed.original_seed if isinstance(seed, SeedStream) else seed self._salt = salt self._counter = 0 diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index 6a7f28713acefd2285b07a212e2e47a6db1ae5e1..4f348be2806aa3ade7c1ea2a7bc68ca26db6447f 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class _DistributionShape(object): @@ -166,6 +167,14 @@ class _DistributionShape(object): "free," i.e., during graph construction. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, batch_ndims=None, event_ndims=None, diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index f04dc8da39140240edbe4efb75de30e321436d55..a9d0fb4ccfb1803873f7fe17089f3e7c7f10f4b7 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation __all__ = [ "SinhArcsinh", @@ -94,6 +95,14 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -132,7 +141,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale, skewness, tailweight]) as name: diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py index 9c69435fac109914ff29b307dfad105f62849339..c25e8c51d7705b641699fb05623c7b0fb4950e1b 100644 --- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py +++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py @@ -140,6 +140,7 @@ __all__ = [ "assert_true_mean_equal_by_dkwm", "min_discrepancy_of_true_means_detectable_by_dkwm", "min_num_samples_for_dkwm_mean_test", + "assert_true_mean_in_interval_by_dkwm", "assert_true_mean_equal_by_dkwm_two_sample", "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample", "min_num_samples_for_dkwm_mean_two_sample_test", @@ -209,17 +210,17 @@ def _maximum_mean(samples, envelope, high, name=None): separately. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `envelope` and `high`. - envelope: Floating-point tensor of sizes of admissible CDF + envelope: Floating-point `Tensor` of sizes of admissible CDF envelopes (i.e., the `eps` above). - high: Floating-point tensor of upper bounds on the distributions' - supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. `samples <= high`. name: A name for this operation (optional). Returns: - bound: Floating-point tensor of upper bounds on the true means. + bound: Floating-point `Tensor` of upper bounds on the true means. Raises: InvalidArgumentError: If some `sample` is found to be larger than @@ -254,17 +255,17 @@ def _minimum_mean(samples, envelope, low, name=None): separately. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `envelope` and `low`. - envelope: Floating-point tensor of sizes of admissible CDF + envelope: Floating-point `Tensor` of sizes of admissible CDF envelopes (i.e., the `eps` above). - low: Floating-point tensor of lower bounds on the distributions' - supports. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. `samples >= low`. name: A name for this operation (optional). Returns: - bound: Floating-point tensor of lower bounds on the true means. + bound: Floating-point `Tensor` of lower bounds on the true means. Raises: InvalidArgumentError: If some `sample` is found to be smaller than @@ -300,12 +301,12 @@ def _dkwm_cdf_envelope(n, error_rate, name=None): probability above. Args: - n: Tensor of numbers of samples drawn. - error_rate: Floating-point tensor of admissible rates of mistakes. + n: `Tensor` of numbers of samples drawn. + error_rate: Floating-point `Tensor` of admissible rates of mistakes. name: A name for this operation (optional). Returns: - eps: Tensor of maximum distances the true CDF can be from the + eps: `Tensor` of maximum distances the true CDF can be from the empirical CDF. This scales as `O(sqrt(-log(error_rate)))` and as `O(1 / sqrt(n))`. The shape is the broadcast of `n` and `error_rate`. @@ -324,8 +325,8 @@ def _check_shape_dominates(samples, parameters): sample counts end up inflated. Args: - samples: A Tensor whose shape is to be protected against broadcasting. - parameters: A list of Tensors who are parameters for the statistical test. + samples: A `Tensor` whose shape is to be protected against broadcasting. + parameters: A list of `Tensor`s who are parameters for the statistical test. Returns: samples: Return original `samples` with control dependencies attached @@ -369,19 +370,23 @@ def true_mean_confidence_interval_by_dkwm( members. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low` and `high`. - low: Floating-point tensor of lower bounds on the distributions' + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' supports. - high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - error_rate: *Scalar* admissible total rate of mistakes. + error_rate: *Scalar* floating-point `Tensor` admissible total rate + of mistakes. name: A name for this operation (optional). Returns: - low: A floating-point tensor of stochastic lower bounds on the true means. - high: A floating-point tensor of stochastic upper bounds on the true means. + low: A floating-point `Tensor` of stochastic lower bounds on the + true means. + high: A floating-point `Tensor` of stochastic upper bounds on the + true means. """ with ops.name_scope( name, "true_mean_confidence_interval_by_dkwm", @@ -436,15 +441,17 @@ def assert_true_mean_equal_by_dkwm( the assertion will insist on stronger evidence to fail any one member. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low` and `high`. - low: Floating-point tensor of lower bounds on the distributions' + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' supports. - high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - expected: Floating-point tensor of expected true means. - false_fail_rate: *Scalar* admissible total rate of mistakes. + expected: Floating-point `Tensor` of expected true means. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. name: A name for this operation (optional). Returns: @@ -454,20 +461,8 @@ def assert_true_mean_equal_by_dkwm( with ops.name_scope( name, "assert_true_mean_equal_by_dkwm", [samples, low, high, expected, false_fail_rate]): - samples = ops.convert_to_tensor(samples, name="samples") - low = ops.convert_to_tensor(low, name="low") - high = ops.convert_to_tensor(high, name="high") - expected = ops.convert_to_tensor(expected, name="expected") - false_fail_rate = ops.convert_to_tensor( - false_fail_rate, name="false_fail_rate") - samples = _check_shape_dominates(samples, [low, high, expected]) - min_mean, max_mean = true_mean_confidence_interval_by_dkwm( - samples, low, high, error_rate=false_fail_rate) - less_op = check_ops.assert_less( - min_mean, expected, message="Mean confidence interval too high") - with ops.control_dependencies([less_op]): - return check_ops.assert_greater( - max_mean, expected, message="Mean confidence interval too low") + return assert_true_mean_in_interval_by_dkwm( + samples, low, high, expected, expected, false_fail_rate) def min_discrepancy_of_true_means_detectable_by_dkwm( @@ -487,30 +482,35 @@ def min_discrepancy_of_true_means_detectable_by_dkwm( with the same `false_pass_rate`. Args: - n: Tensor of numbers of samples to be drawn from the distributions + n: `Tensor` of numbers of samples to be drawn from the distributions of interest. - low: Floating-point tensor of lower bounds on the distributions' + low: Floating-point `Tensor` of lower bounds on the distributions' supports. - high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - discr: Tensor of lower bounds on the distances between true + discr: `Tensor` of lower bounds on the distances between true means detectable by a DKWM-based test. For each batch member `i`, of `K` total, drawing `n[i]` samples from some scalar distribution supported on `[low[i], high[i]]` is enough to detect a difference in means of size `discr[i]` or more. Specifically, we guarantee that (a) if the true mean is the expected - mean, `assert_true_mean_equal_by_dkwm` will fail with probability at - most `false_fail_rate / K` (which amounts to `false_fail_rate` if - applied to the whole batch at once), and (b) if the true mean - differs from the expected mean by at least `discr[i]`, - `assert_true_mean_equal_by_dkwm` will pass with probability at most - `false_pass_rate`. + mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with + probability at most `false_fail_rate / K` (which amounts to + `false_fail_rate` if applied to the whole batch at once), and (b) if + the true mean differs from the expected mean (resp. falls outside + the expected interval) by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with + probability at most `false_pass_rate`. The detectable discrepancy scales as @@ -558,17 +558,19 @@ def min_num_samples_for_dkwm_mean_test( on a scalar distribution supported on `[low, high]`. Args: - discrepancy: Floating-point tensor of desired upper limits on mean + discrepancy: Floating-point `Tensor` of desired upper limits on mean differences that may go undetected with probability higher than `1 - false_pass_rate`. - low: Tensor of lower bounds on the distributions' support. - high: Tensor of upper bounds on the distributions' support. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + low: `Tensor` of lower bounds on the distributions' support. + high: `Tensor` of upper bounds on the distributions' support. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - n: Tensor of numbers of samples to be drawn from the distributions + n: `Tensor` of numbers of samples to be drawn from the distributions of interest. The `discrepancy`, `low`, and `high` tensors must have @@ -578,12 +580,15 @@ def min_num_samples_for_dkwm_mean_test( some scalar distribution supported on `[low[i], high[i]]` is enough to detect a difference in means of size `discrepancy[i]` or more. Specifically, we guarantee that (a) if the true mean is the expected - mean, `assert_true_mean_equal_by_dkwm` will fail with probability at - most `false_fail_rate / K` (which amounts to `false_fail_rate` if - applied to the whole batch at once), and (b) if the true mean - differs from the expected mean by at least `discrepancy[i]`, - `assert_true_mean_equal_by_dkwm` will pass with probability at most - `false_pass_rate`. + mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with + probability at most `false_fail_rate / K` (which amounts to + `false_fail_rate` if applied to the whole batch at once), and (b) if + the true mean differs from the expected mean (resp. falls outside + the expected interval) by at least `discrepancy[i]`, + `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with + probability at most `false_pass_rate`. The required number of samples scales as `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`, @@ -610,6 +615,76 @@ def min_num_samples_for_dkwm_mean_test( return math_ops.maximum(n1, n2) +def assert_true_mean_in_interval_by_dkwm( + samples, low, high, expected_low, expected_high, + false_fail_rate=1e-6, name=None): + """Asserts the mean of the given distribution is in the given interval. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the mean of the distribution from which the given samples are + drawn is _outside_ the given interval with statistical significance + `false_fail_rate` or stronger, otherwise passes. If you also want + to check that you are gathering enough evidence that a pass is not + spurious, see `min_num_samples_for_dkwm_mean_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples: Floating-point `Tensor` of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. + expected_low: Floating-point `Tensor` of lower bounds on the + expected true means. + expected_high: Floating-point `Tensor` of upper bounds on the + expected true means. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any expected mean + interval does not overlap with the corresponding confidence + interval. + """ + with ops.name_scope( + name, "assert_true_mean_in_interval_by_dkwm", + [samples, low, high, expected_low, expected_high, false_fail_rate]): + samples = ops.convert_to_tensor(samples, name="samples") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + expected_low = ops.convert_to_tensor(expected_low, name="expected_low") + expected_high = ops.convert_to_tensor(expected_high, name="expected_high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + samples = _check_shape_dominates( + samples, [low, high, expected_low, expected_high]) + min_mean, max_mean = true_mean_confidence_interval_by_dkwm( + samples, low, high, false_fail_rate) + # Assert that the interval [min_mean, max_mean] intersects the + # interval [expected_low, expected_high]. This is true if + # max_mean >= expected_low and min_mean <= expected_high. + # By DeMorgan's law, that's also equivalent to + # not (max_mean < expected_low or min_mean > expected_high), + # which is a way of saying the two intervals are not disjoint. + check_confidence_interval_can_intersect = check_ops.assert_greater_equal( + max_mean, expected_low, message="Confidence interval does not " + "intersect: true mean smaller than expected") + with ops.control_dependencies([check_confidence_interval_can_intersect]): + return check_ops.assert_less_equal( + min_mean, expected_high, message="Confidence interval does not " + "intersect: true mean greater than expected") + + def assert_true_mean_equal_by_dkwm_two_sample( samples1, low1, high1, samples2, low2, high2, false_fail_rate=1e-6, name=None): @@ -630,23 +705,26 @@ def assert_true_mean_equal_by_dkwm_two_sample( the assertion will insist on stronger evidence to fail any one member. Args: - samples1: Floating-point tensor of samples from the + samples1: Floating-point `Tensor` of samples from the distribution(s) A. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low1`, `high1`, `low2`, and `high2`. - low1: Floating-point tensor of lower bounds on the supports of the + The support is bounded: `low1 <= samples1 <= high1`. + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - samples2: Floating-point tensor of samples from the + samples2: Floating-point `Tensor` of samples from the distribution(s) B. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low1`, `high1`, `low2`, and `high2`. - low2: Floating-point tensor of lower bounds on the supports of the + The support is bounded: `low2 <= samples2 <= high2`. + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of mistakes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. name: A name for this operation (optional). Returns: @@ -676,20 +754,10 @@ def assert_true_mean_equal_by_dkwm_two_sample( # and sample counts should be valid; however, because the intervals # scale as O(-log(false_fail_rate)), there doesn't seem to be much # room to win. - min_mean_1, max_mean_1 = true_mean_confidence_interval_by_dkwm( - samples1, low1, high1, false_fail_rate / 2.) min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm( samples2, low2, high2, false_fail_rate / 2.) - # I want to assert - # not (max_mean_1 < min_mean_2 or min_mean_1 > max_mean_2), - # but I think I only have and-combination of asserts, so use DeMorgan. - check_confidence_intervals_can_intersect = check_ops.assert_greater_equal( - max_mean_1, min_mean_2, message="Confidence intervals do not " - "intersect: samples1 has a smaller mean than samples2") - with ops.control_dependencies([check_confidence_intervals_can_intersect]): - return check_ops.assert_less_equal( - min_mean_1, max_mean_2, message="Confidence intervals do not " - "intersect: samples2 has a smaller mean than samples1") + return assert_true_mean_in_interval_by_dkwm( + samples1, low1, high1, min_mean_2, max_mean_2, false_fail_rate / 2.) def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( @@ -710,22 +778,24 @@ def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( with the same `false_pass_rate`. Args: - n1: Tensor of numbers of samples to be drawn from the distributions A. - low1: Floating-point tensor of lower bounds on the supports of the + n1: `Tensor` of numbers of samples to be drawn from the distributions A. + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - n2: Tensor of numbers of samples to be drawn from the distributions B. - low2: Floating-point tensor of lower bounds on the supports of the + n2: `Tensor` of numbers of samples to be drawn from the distributions B. + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - discr: Tensor of lower bounds on the distances between true means + discr: `Tensor` of lower bounds on the distances between true means detectable by a two-sample DKWM-based test. For each batch member `i`, of `K` total, drawing `n1[i]` samples @@ -776,24 +846,26 @@ def min_num_samples_for_dkwm_mean_two_sample_test( (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). Args: - discrepancy: Floating-point tensor of desired upper limits on mean + discrepancy: Floating-point `Tensor` of desired upper limits on mean differences that may go undetected with probability higher than `1 - false_pass_rate`. - low1: Floating-point tensor of lower bounds on the supports of the + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - low2: Floating-point tensor of lower bounds on the supports of the + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - n1: Tensor of numbers of samples to be drawn from the distributions A. - n2: Tensor of numbers of samples to be drawn from the distributions B. + n1: `Tensor` of numbers of samples to be drawn from the distributions A. + n2: `Tensor` of numbers of samples to be drawn from the distributions B. For each batch member `i`, of `K` total, drawing `n1[i]` samples from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]` diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index cd6d7499595d88d18de339371d4a07fe780662d9..ece03fe4aab3cc3046e0958d883ca9388517b94b 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -40,6 +40,7 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -49,6 +50,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_softmaxnormal_gauss_hermite( normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): @@ -111,6 +120,14 @@ def quadrature_scheme_softmaxnormal_gauss_hermite( return grid, probs +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_softmaxnormal_quantiles( normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): @@ -318,6 +335,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): https://arxiv.org/abs/1801.03080 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, mix_loc, temperature, @@ -395,7 +420,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): ValueError: if `not distribution.is_scalar_batch`. ValueError: if `not distribution.is_scalar_event`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[mix_loc, temperature]) as name: if not scale or len(scale) < 2: raise ValueError("Must specify list (or list-like object) of scale " @@ -779,6 +804,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): return array_ops.reshape(p, shape=expand_shape) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def maybe_check_quadrature_param(param, name, validate_args): """Helper which checks validity of `loc` and `scale` init args.""" with ops.name_scope(name="check_" + name, values=[param]): @@ -812,6 +845,14 @@ def maybe_check_quadrature_param(param, name, validate_args): return param +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def determine_batch_event_shapes(grid, endpoint_affine): """Helper to infer batch_shape and event_shape.""" with ops.name_scope(name="determine_batch_event_shapes"): @@ -850,6 +891,14 @@ def determine_batch_event_shapes(grid, endpoint_affine): return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def interpolate_loc(grid, loc): """Helper which interpolates between two locs.""" if len(loc) != 2: @@ -876,6 +925,14 @@ def interpolate_loc(grid, loc): return [x[..., k] for k in range(deg)] # list(shape:[B, e]) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def interpolate_scale(grid, scale): """Helper which interpolates between two scales.""" if len(scale) != 2: @@ -892,6 +949,14 @@ def interpolate_scale(grid, scale): ])[0] for q in range(deg)] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def linop_scale(w, op): # We assume w > 0. (This assumption only relates to the is_* attributes.) with ops.name_scope("linop_scale", values=[w]): @@ -927,6 +992,14 @@ def linop_scale(w, op): "Unsupported Linop type ({})".format(type(op).__name__)) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [distribution_util.static_value(x) for x in args] @@ -935,6 +1008,14 @@ def concat_vectors(*args): return [val for vec in args_ for val in vec] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def add(x, y): """Adds inputs; interprets `None` as zero.""" if x is None: @@ -944,11 +1025,27 @@ def add(x, y): return x + y +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def vec_osquare(x): """Computes the outer-product of a (batch of) vector, i.e., x.T x.""" return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def softmax(x, axis, name=None): """Equivalent to tf.nn.softmax but works around b/70297725.""" with ops.name_scope(name, "softmax", [x, axis]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index 3465d66b30501e7aebd9904d2ae2206d628c10b7..73356a3625c9a1aa15af5b6c1cf2ccb0c514b39a 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import vector_exponential_linear_operator as vector_exponential_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -116,6 +117,14 @@ class VectorExponentialDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -175,7 +184,7 @@ class VectorExponentialDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index 2c31b019845d7e4558eb3047af84732a2ae03986..9a47b4855763a25b484ad04a3415d191f19256f7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import exponential from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = ["VectorExponentialLinearOperator"] @@ -138,6 +139,14 @@ class VectorExponentialLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -175,7 +184,7 @@ class VectorExponentialLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 6a36018d6f1b83955ef9080ec11c74c08a670075..e68ddc569c95ff63760b4b2f6d7a92f17240a558 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import vector_laplace_linear_operator as vector_laplace_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -151,6 +152,14 @@ class VectorLaplaceDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -210,7 +219,7 @@ class VectorLaplaceDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name): with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index 97e5c76d800acd800e34a9e66a3c5fdd7ce4f660..3923161a332a77e4eaab8d65d96fd8c278c872ec 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import laplace from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = [ @@ -154,6 +155,14 @@ class VectorLaplaceLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -191,7 +200,7 @@ class VectorLaplaceLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index ff5ca4525700aedc88d75e391bf0c2415c2afa13..49ffff24caec8d6c525f65f06796d10548d5ec40 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation __all__ = [ "VectorSinhArcsinhDiag", @@ -95,6 +96,14 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -163,7 +172,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope( name, diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 4742f7521816d4643354017495f3380c78ac7bc2..f289b39e51aff36780541a0545ed9e6cfe21dd4e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import student_t from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation class _VectorStudentT(transformed_distribution.TransformedDistribution): @@ -121,6 +122,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, loc=None, @@ -175,7 +184,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) graph_parents = [df, loc, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_factor, scale_perturb_diag] with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index f555867e7f3c2a6bc797e9b3d56da2fa434aba6f..f1accaaa4c920344608015c792a2c3606de1337f 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "WishartCholesky", @@ -73,6 +74,14 @@ class _WishartLinearOperator(distribution.Distribution): this class. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale_operator, @@ -107,7 +116,7 @@ class _WishartLinearOperator(distribution.Distribution): ValueError: if df < k, where scale operator event shape is `(k, k)` """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) self._cholesky_input_output_matrices = cholesky_input_output_matrices with ops.name_scope(name) as name: with ops.name_scope("init", values=[df, scale_operator]): @@ -501,6 +510,14 @@ class WishartCholesky(_WishartLinearOperator): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale, @@ -530,7 +547,7 @@ class WishartCholesky(_WishartLinearOperator): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[scale]) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) @@ -617,6 +634,14 @@ class WishartFull(_WishartLinearOperator): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale, @@ -646,7 +671,7 @@ class WishartFull(_WishartLinearOperator): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index 4384431e7b9c3e6ef259391fa9efa5a35d23c86a..86d203452e24d6d73f3ebb17b989867905a61382 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -44,7 +44,7 @@ Installation instructions at https://www.tensorflow.org/install/ For an introduction to eager execution in TensorFlow, see: -- [User Guide](https://www.tensorflow.org/programmers_guide/eager) ([source](../../docs_src/programmers_guide/eager.md)) +- [User Guide](https://www.tensorflow.org/guide/eager) ([source](../../docs_src/guide/eager.md)) - Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb) - Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb) - Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index d7909dd5a2691a015a6afed2caa475b39ca7ebc3..58c548d798178a2848006cbf301f7d5cb2143f24 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -102,11 +102,13 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): with ops.device(self._device): self._buffer_resource_handle = prefetching_ops.function_buffering_resource( # pylint: disable=line-too-long string_arg=iter_string_handle, + output_types=self._flat_output_types, f=remote_fn, target_device=target, buffer_size=10, container="", - shared_name=_generate_shared_name("function_buffer_resource")) + shared_name=_generate_shared_name( + "contrib_eager_iterator_function_buffer_resource")) self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long handle=self._buffer_resource_handle, handle_device=self._device) diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index c1fd9e0ed020beeb722204edf1adfe1dfcf8ff03..12155a459c29c353c57679c407e7dda25047a35c 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -7,10 +7,16 @@ py_library( name = "examples_pip", deps = [ "//tensorflow/contrib/eager/python/examples/gan:mnist", + "//tensorflow/contrib/eager/python/examples/l2hmc", + "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", "//tensorflow/contrib/eager/python/examples/linear_regression", "//tensorflow/contrib/eager/python/examples/resnet50", + "//tensorflow/contrib/eager/python/examples/revnet", + "//tensorflow/contrib/eager/python/examples/revnet:config", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/sagan", + "//tensorflow/contrib/eager/python/examples/sagan:config", "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index b80c90902353709b7f739585291ec3b5890c27c7..cc9cf53410f641cc3303b4450e9eaa1301904a64 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -227,7 +227,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer, maxval=1., seed=batch_index) - with tfe.GradientTape(persistent=True) as g: + with tf.GradientTape(persistent=True) as g: generated_images = generator(noise) tf.contrib.summary.image( 'generated_images', @@ -306,7 +306,7 @@ def main(_): if __name__ == '__main__': - tfe.enable_eager_execution() + tf.enable_eager_execution() parser = argparse.ArgumentParser() parser.add_argument( diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py index bd35e50c1f434d167c5a8c5aa7d224912523ce28..81ac05e26d23c2fc53f63d64bb28bdea6072e396 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -111,5 +111,5 @@ class MnistEagerGanBenchmark(tf.test.Benchmark): if __name__ == '__main__': - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..7bdf9053de749af9d09b12ba7b848e21c1fdb8f0 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD @@ -0,0 +1,39 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_library( + name = "neural_nets", + srcs = ["neural_nets.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + ], +) + +py_library( + name = "l2hmc", + srcs = ["l2hmc.py"], + srcs_version = "PY2AND3", + deps = [ + ":neural_nets", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", + ], +) + +cuda_py_test( + name = "l2hmc_test", + size = "large", + srcs = ["l2hmc_test.py"], + additional_deps = [ + ":l2hmc", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py new file mode 100644 index 0000000000000000000000000000000000000000..729d8525fab31ee214178ca1bcb18dbd069f767a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py @@ -0,0 +1,326 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""L2HMC compatible with TensorFlow's eager execution. + +Reference [Generalizing Hamiltonian Monte Carlo with Neural +Networks](https://arxiv.org/pdf/1711.09268.pdf) + +Code adapted from the released TensorFlow graph implementation by original +authors https://github.com/brain-research/l2hmc. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import numpy.random as npr +import tensorflow as tf +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets + + +class Dynamics(tf.keras.Model): + """Dynamics engine of naive L2HMC sampler. + + Args: + x_dim: dimensionality of observed data + loglikelihood_fn: log-likelihood function of conditional probability + n_steps: number of leapfrog steps within each transition + eps: initial value learnable scale of step size + """ + + def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1): + super(Dynamics, self).__init__() + + self.x_dim = x_dim + self.potential = loglikelihood_fn + self.n_steps = n_steps + + self._construct_time() + self._construct_masks() + + self.position_fn = neural_nets.GenericNet(x_dim, factor=2.) + self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.) + + self.eps = tfe.Variable( + initial_value=eps, name="eps", dtype=tf.float32, trainable=True) + + def apply_transition(self, position): + """Propose a new state and perform the accept or reject step.""" + + # Simulate dynamics both forward and backward; + # Use sampled Bernoulli masks to compute the actual solutions + position_f, momentum_f, accept_prob_f = self.transition_kernel( + position, forward=True) + position_b, momentum_b, accept_prob_b = self.transition_kernel( + position, forward=False) + + # Decide direction uniformly + forward_mask = tf.cast( + tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32) + backward_mask = 1. - forward_mask + + # Obtain proposed states + position_post = ( + forward_mask[:, None] * position_f + + backward_mask[:, None] * position_b) + momentum_post = ( + forward_mask[:, None] * momentum_f + + backward_mask[:, None] * momentum_b) + + # Probability of accepting the proposed states + accept_prob = forward_mask * accept_prob_f + backward_mask * accept_prob_b + + # Accept or reject step + accept_mask = tf.cast( + accept_prob > tf.random_uniform(tf.shape(accept_prob)), tf.float32) + reject_mask = 1. - accept_mask + + # Samples after accept/reject step + position_out = ( + accept_mask[:, None] * position_post + reject_mask[:, None] * position) + + return position_post, momentum_post, accept_prob, position_out + + def transition_kernel(self, position, forward=True): + """Transition kernel of augmented leapfrog integrator.""" + + lf_fn = self._forward_lf if forward else self._backward_lf + + # Resample momentum + momentum = tf.random_normal(tf.shape(position)) + position_post, momentum_post = position, momentum + sumlogdet = 0. + # Apply augmented leapfrog steps + for i in range(self.n_steps): + position_post, momentum_post, logdet = lf_fn(position_post, momentum_post, + i) + sumlogdet += logdet + + accept_prob = self._compute_accept_prob(position, momentum, position_post, + momentum_post, sumlogdet) + + return position_post, momentum_post, accept_prob + + def _forward_lf(self, position, momentum, i): + """One forward augmented leapfrog step. See eq (5-6) in paper.""" + + t = self._get_time(i) + mask, mask_inv = self._get_mask(i) + sumlogdet = 0. + + momentum, logdet = self._update_momentum_forward(position, momentum, t) + sumlogdet += logdet + + position, logdet = self._update_position_forward(position, momentum, t, + mask) + sumlogdet += logdet + + position, logdet = self._update_position_forward(position, momentum, t, + mask_inv) + sumlogdet += logdet + + momentum, logdet = self._update_momentum_forward(position, momentum, t) + sumlogdet += logdet + + return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + + def _backward_lf(self, position, momentum, i): + """One backward augmented leapfrog step. See Appendix A in paper.""" + + # Reversed index/sinusoidal time + t = self._get_time(self.n_steps - i - 1) + mask, mask_inv = self._get_mask(self.n_steps - i - 1) + sumlogdet = 0. + + momentum, logdet = self._update_momentum_backward(position, momentum, t) + sumlogdet += logdet + + position, logdet = self._update_position_backward(position, momentum, t, + mask) + sumlogdet += logdet + + position, logdet = self._update_position_backward(position, momentum, t, + mask_inv) + sumlogdet += logdet + + momentum, logdet = self._update_momentum_backward(position, momentum, t) + sumlogdet += logdet + + return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + + def _update_momentum_forward(self, position, momentum, t): + """Update v in the forward leapfrog step.""" + + grad = self.grad_potential(position) + scale, translation, transformed = self.momentum_fn([position, grad, t]) + scale *= .5 * self.eps + transformed *= self.eps + momentum = ( + momentum * tf.exp(scale) - + .5 * self.eps * (tf.exp(transformed) * grad - translation)) + + return momentum, scale + + def _update_position_forward(self, position, momentum, t, mask): + """Update x in the forward leapfrog step.""" + + mask_inv = 1. - mask + scale, translation, transformed = self.position_fn( + [momentum, mask * position, t]) + scale *= self.eps + transformed *= self.eps + position = ( + mask * position + + mask_inv * (position * tf.exp(scale) + self.eps * + (tf.exp(transformed) * momentum + translation))) + + return position, mask_inv * scale + + def _update_momentum_backward(self, position, momentum, t): + """Update v in the backward leapfrog step. Inverting the forward update.""" + + grad = self.grad_potential(position) + scale, translation, transformed = self.momentum_fn([position, grad, t]) + scale *= -.5 * self.eps + transformed *= self.eps + momentum = ( + tf.exp(scale) * (momentum + .5 * self.eps * + (tf.exp(transformed) * grad - translation))) + + return momentum, scale + + def _update_position_backward(self, position, momentum, t, mask): + """Update x in the backward leapfrog step. Inverting the forward update.""" + + mask_inv = 1. - mask + scale, translation, transformed = self.position_fn( + [momentum, mask_inv * position, t]) + scale *= -self.eps + transformed *= self.eps + position = ( + mask_inv * position + mask * tf.exp(scale) * + (position - self.eps * tf.exp(transformed) * momentum + translation)) + + return position, mask * scale + + def _compute_accept_prob(self, position, momentum, position_post, + momentum_post, sumlogdet): + """Compute the prob of accepting the proposed state given old state.""" + + old_hamil = self.hamiltonian(position, momentum) + new_hamil = self.hamiltonian(position_post, momentum_post) + + return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.)) + + def _construct_time(self): + """Convert leapfrog step index into sinusoidal time.""" + + self.ts = [] + for i in range(self.n_steps): + t = tf.constant( + [ + np.cos(2 * np.pi * i / self.n_steps), + np.sin(2 * np.pi * i / self.n_steps) + ], + dtype=tf.float32) + self.ts.append(t[None, :]) + + def _get_time(self, i): + """Get sinusoidal time for i-th augmented leapfrog step.""" + + return self.ts[i] + + def _construct_masks(self): + """Construct different binary masks for different time steps.""" + + self.masks = [] + for _ in range(self.n_steps): + idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2] + mask = np.zeros((self.x_dim,)) + mask[idx] = 1. + mask = tf.constant(mask, dtype=tf.float32) + self.masks.append(mask[None, :]) + + def _get_mask(self, i): + """Get binary masks for i-th augmented leapfrog step.""" + + m = self.masks[i] + return m, 1. - m + + def kinetic(self, v): + """Compute the kinetic energy.""" + + return .5 * tf.reduce_sum(v**2, axis=1) + + def hamiltonian(self, position, momentum): + """Compute the overall Hamiltonian.""" + + return self.potential(position) + self.kinetic(momentum) + + def grad_potential(self, position, check_numerics=True): + """Get gradient of potential function at current location.""" + + if not tf.executing_eagerly(): + # TODO(lxuechen): Change this to tfe.gradients_function when it works + grad = tf.gradients(self.potential(position), position)[0] + else: + grad = tfe.gradients_function(self.potential)(position)[0] + + if check_numerics: + return tf.check_numerics(grad, message="gradient of potential") + + return grad + + +# Examples of unnormalized log density/probabilities +def get_scg_energy_fn(): + """Get energy function for 2d strongly correlated Gaussian.""" + + # Avoid recreating tf constants on each invocation of gradients + mu = tf.constant([0., 0.]) + sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]]) + sigma_inv = tf.matrix_inverse(sigma) + + def energy(x): + """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + + xmmu = x - mu + return .5 * tf.diag_part( + tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + + return energy + + +def get_multivariate_gaussian_energy_fn(x_dim=2): + """Get energy function for 2d strongly correlated Gaussian.""" + + mu = tf.random_normal(shape=[x_dim]) + # Lower triangularize and positive diagonal + l = tf.sigmoid( + tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0)) + # Exploit Cholesky decomposition + sigma = tf.matmul(l, tf.transpose(l)) + sigma *= 100. # Small covariance causes extreme numerical instability + sigma_inv = tf.matrix_inverse(sigma) + + def energy(x): + """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + + xmmu = x - mu + return .5 * tf.diag_part( + tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + + return energy diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e33b4cae4c73388dfd78542c9907953f137ad710 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py @@ -0,0 +1,264 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests l2hmc fit to 2D strongly correlated Gaussian executed eagerly.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy.random as npr +import tensorflow as tf +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc + + +def get_default_hparams(): + return tf.contrib.training.HParams( + x_dim=2, + n_samples=200, + n_steps=10, + eps=.1, + n_iters=10, + learning_rate=.0003, + n_warmup_iters=3) + + +# Relevant functions for benchmarking +def compute_loss(dynamics, x, scale=.1, eps=1e-4): + """Compute loss defined in equation (8).""" + + z = tf.random_normal(tf.shape(x)) + x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) + z_, _, z_accept_prob, _ = dynamics.apply_transition(z) + + # Add eps for numerical stability; following released impl + x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps + z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps + + loss = tf.reduce_mean( + (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0) + + return loss, x_out + + +def loss_and_grads(dynamics, x, loss_fn=compute_loss): + """Obtain loss value and gradients.""" + + with tf.GradientTape() as tape: + loss_val, x_out = loss_fn(dynamics, x) + grads = tape.gradient(loss_val, dynamics.variables) + + return loss_val, grads, x_out + + +def warmup(dynamics, optimizer, n_iters=1, n_samples=200, loss_fn=compute_loss): + """Warmup optimization to reduce overhead.""" + + samples = tf.random_normal( + shape=[n_samples, dynamics.x_dim], dtype=tf.float32) + + for _ in range(n_iters): + _, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + optimizer.apply_gradients(zip(grads, dynamics.variables)) + + +def fit(dynamics, + samples, + optimizer, + loss_fn=compute_loss, + n_iters=5000, + verbose=True, + logdir=None, + decay_lr=True): + """Fit L2HMC sampler with given log-likelihood function.""" + + if logdir: + summary_writer = tf.contrib.summary.create_file_writer(logdir) + + for i in range(n_iters): + loss, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + # TODO(lxuechen): Proper learning rate decay + if decay_lr: + grads = [grad * .96**(i // 1000) for grad in grads] + optimizer.apply_gradients(zip(grads, dynamics.variables)) + if verbose: + print("Iteration %d: loss %.4f" % (i, loss)) + + if logdir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("loss", loss) + + +class L2hmcTest(tf.test.TestCase): + """Unit tests for l2hmc in both eager and graph mode.""" + + def test_apply_transition(self): + """Testing function `Dynamics.apply_transition` in graph and eager mode.""" + + # Eager mode testing + hparams = get_default_hparams() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim]) + x_, v_, x_accept_prob, x_out = dynamics.apply_transition(samples) + + self.assertEqual(x_.shape, v_.shape) + self.assertEqual(x_out.shape, samples.shape) + self.assertEqual(x_.shape, x_out.shape) + self.assertEqual(x_accept_prob.shape, (hparams.n_samples,)) + + # Graph mode testing + with tf.Graph().as_default(): + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) + x_, v_, x_accept_prob, x_out = dynamics.apply_transition(x) + samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + np_x_, np_v_, np_x_accept_prob, np_x_out = sess.run( + [x_, v_, x_accept_prob, x_out], feed_dict={x: samples}) + + self.assertEqual(np_x_.shape, np_v_.shape) + self.assertEqual(samples.shape, np_x_out.shape) + self.assertEqual(np_x_.shape, np_x_out.shape) + self.assertEqual(np_x_accept_prob.shape, (hparams.n_samples,)) + + +class L2hmcBenchmark(tf.test.Benchmark): + """Eager and graph benchmarks for l2hmc.""" + + def _get_energy_fn(self): + """Get specific energy function according to FLAGS.""" + + if FLAGS.energy_fn == "scg": + energy_fn = l2hmc.get_scg_energy_fn() + elif FLAGS.energy_fn == "multivariate_gaussian": + energy_fn = l2hmc.get_multivariate_gaussian_energy_fn(x_dim=FLAGS.x_dim) + else: + raise ValueError("No such energy function %s" % FLAGS.energy_fn) + + return energy_fn + + def benchmark_graph(self): + """Benchmark Graph performance.""" + + hparams = get_default_hparams() + tf.reset_default_graph() + with tf.Graph().as_default(): + energy_fn = self._get_energy_fn() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) + loss, x_out = compute_loss(dynamics, x) + + global_step = tf.Variable(0., name="global_step", trainable=False) + learning_rate = tf.train.exponential_decay( + hparams.learning_rate, global_step, 1000, 0.96, staircase=True) + optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) + train_op = optimizer.minimize(loss, global_step=global_step) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + # Warmup to reduce initialization effect when timing + samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) + for _ in range(hparams.n_warmup_iters): + _, _, _, _ = sess.run( + [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + + # Training + start_time = time.time() + for i in range(hparams.n_iters): + samples, loss_np, _, _ = sess.run( + [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + print("Iteration %d: loss %.4f" % (i, loss_np)) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="graph_train_%s" % ("gpu" + if tf.test.is_gpu_available() else "cpu"), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + def benchmark_eager(self): + self._benchmark_eager() + + def benchmark_eager_defun(self): + self._benchmark_eager(defun=True) + + def _benchmark_eager(self, defun=False): + """Benchmark Eager performance.""" + + hparams = get_default_hparams() + energy_fn = self._get_energy_fn() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) + loss_fn = tfe.defun(compute_loss) if defun else compute_loss + + # Warmup to reduce initialization effect when timing + warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn) + + # Training + samples = tf.random_normal( + shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32) + start_time = time.time() + fit(dynamics, + samples, + optimizer, + loss_fn=loss_fn, + n_iters=hparams.n_iters, + decay_lr=True) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="eager_train_%s%s" % ("gpu" if tf.test.is_gpu_available() else + "cpu", "_defun" if defun else ""), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + del dynamics + del loss_fn + + +if __name__ == "__main__": + tf.flags.DEFINE_string("energy_fn", "scg", + ("The energy function/unnormalized log-probability. " + "Either be `scg` or `multivariate_gaussian`")) + tf.flags.DEFINE_integer("x_dim", 2, "Dimensionality of observation space.") + FLAGS = tf.flags.FLAGS + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py new file mode 100644 index 0000000000000000000000000000000000000000..e230ad5e259df5b450897bd815e901e3934cd293 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py @@ -0,0 +1,84 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Neural nets utility for L2HMC compatible with TensorFlow's eager execution. + +Reference [Generalizing Hamiltonian Monte Carlo with Neural +Networks](https://arxiv.org/pdf/1711.09268.pdf) + +Code adapted from the released TensorFlow graph implementation by original +authors https://github.com/brain-research/l2hmc. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import tensorflow.contrib.eager as tfe + + +class GenericNet(tf.keras.Model): + """Generic neural net with different initialization scale based on input. + + Args: + x_dim: dimensionality of observed data + factor: factor of variance scaling initializer + n_hidden: number of hidden units + """ + + def __init__(self, x_dim, factor, n_hidden=10): + super(GenericNet, self).__init__() + + self.v_layer = _custom_dense(n_hidden, 1. / 3.) + self.x_layer = _custom_dense(n_hidden, factor / 3.) + self.t_layer = _custom_dense(n_hidden, 1. / 3.) + self.h_layer = _custom_dense(n_hidden) + + # Scale + self.scale_layer = _custom_dense(x_dim, .001) + self.coeff_scale = tfe.Variable( + initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True) + # Translation + self.translation_layer = _custom_dense(x_dim, factor=.001) + # Transformation + self.transformation_layer = _custom_dense(x_dim, .001) + self.coeff_transformation = tfe.Variable( + initial_value=tf.zeros([1, x_dim]), + name='coeff_transformation', + trainable=True) + + def call(self, inputs): + v, x, t = inputs + h = self.v_layer(v) + self.x_layer(x) + self.t_layer(t) + h = tf.nn.relu(h) + h = self.h_layer(h) + h = tf.nn.relu(h) + scale = tf.nn.tanh(self.scale_layer(h)) * tf.exp(self.coeff_scale) + translation = self.translation_layer(h) + transformation = ( + tf.nn.tanh(self.transformation_layer(h)) * tf.exp( + self.coeff_transformation)) + + return scale, translation, transformation + + +def _custom_dense(units, factor=1.): + """Custom dense layer with specified weight initialization.""" + + return tf.keras.layers.Dense( + units=units, + use_bias=True, + kernel_initializer=tf.contrib.layers.variance_scaling_initializer( + factor=factor * 2., mode='FAN_IN', uniform=False), + bias_initializer=tf.constant_initializer(0., dtype=tf.float32)) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 4e1380afb2e6e722de65c691d4fbf44621072e87..099b712fc06d1d3eb9ab4095f8db7283690bda76 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -75,7 +75,6 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): mse = lambda xs, ys: mean_square_loss(model, xs, ys) loss_and_grads = tfe.implicit_value_and_gradients(mse) - tf.train.get_or_create_global_step() if logdir: # Support for TensorBoard summaries. Once training has started, use: # tensorboard --logdir= @@ -87,12 +86,13 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): if verbose: print("Iteration %d: loss = %s" % (i, loss.numpy())) - optimizer.apply_gradients(grads, global_step=tf.train.get_global_step()) + optimizer.apply_gradients(grads) if logdir: with summary_writer.as_default(): with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar("loss", loss) + tf.contrib.summary.scalar("loss", loss, step=i) + tf.contrib.summary.scalar("step", i, step=i) def synthetic_dataset(w, b, noise_level, batch_size, num_batches): @@ -119,7 +119,7 @@ def synthetic_dataset_helper(w, b, num_features, noise_level, batch_size, def main(_): - tfe.enable_eager_execution() + tf.enable_eager_execution() # Ground-truth constants. true_w = [[-2.0], [4.0], [1.0]] true_b = [0.5] diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py index e53234b51a7dccc11e548ac81a7ef070c628aa52..2bc2fc2aa9150a3181db612439d0c37c8e76d1e3 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py @@ -117,5 +117,5 @@ class EagerLinearRegressionBenchmark(tf.test.Benchmark): if __name__ == "__main__": - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..34ce5e0cc349bfe71f2e6faad497e6c149754d14 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -0,0 +1,909 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "nmt_with_attention.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [ + { + "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U", + "timestamp": 1527858391290 + }, + { + "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv", + "timestamp": 1527776041613 + } + ], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "metadata": { + "id": "AOpGoE2T-YXS", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Neural Machine Translation with Attention\n", + "\n", + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "CiwtNgENbx2g", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n", + "\n", + "After training the model in this notebook, you will be able to input a Spanish sentence, such as *\"¿todavia estan en casa?\"*, and return the English translation: *\"are you still at home?\"*\n", + "\n", + "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n", + "\n", + "\"spanish-english\n", + "\n", + "Note: This example takes approximately 10 mintues to run on a single P100 GPU." + ] + }, + { + "metadata": { + "id": "tnxXKDjq3jEL", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "# Import TensorFlow >= 1.9 and enable eager execution\n", + "import tensorflow as tf\n", + "\n", + "tf.enable_eager_execution()\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "import unicodedata\n", + "import re\n", + "import numpy as np\n", + "import os\n", + "import time\n", + "\n", + "print(tf.__version__)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "wfodePkj3jEa", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Download and prepare the dataset\n", + "\n", + "We'll use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the format:\n", + "\n", + "```\n", + "May I borrow this book?\t¿Puedo tomar prestado este libro?\n", + "```\n", + "\n", + "There are a variety of languages available, but we'll use the English-Spanish dataset. For convenience, we've hosted a copy of this dataset on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps we'll take to prepare the data:\n", + "\n", + "1. Add a *start* and *end* token to each sentence.\n", + "2. Clean the sentences by removing special characters.\n", + "3. Create a word index and reverse word index (dictionaries mapping from word → id and id → word).\n", + "4. Pad each sentence to a maximum length." + ] + }, + { + "metadata": { + "id": "kRVATYOgJs1b", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Download the file\n", + "path_to_zip = tf.keras.utils.get_file(\n", + " 'spa-eng.zip', origin='http://download.tensorflow.org/data/spa-eng.zip', \n", + " extract=True)\n", + "\n", + "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "rd0jw-eC3jEh", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Converts the unicode file to ascii\n", + "def unicode_to_ascii(s):\n", + " return ''.join(c for c in unicodedata.normalize('NFD', s)\n", + " if unicodedata.category(c) != 'Mn')\n", + "\n", + "\n", + "def preprocess_sentence(w):\n", + " w = unicode_to_ascii(w.lower().strip())\n", + " \n", + " # creating a space between a word and the punctuation following it\n", + " # eg: \"he is a boy.\" => \"he is a boy .\" \n", + " # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", + " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", + " w = re.sub(r'[\" \"]+', \" \", w)\n", + " \n", + " # replacing everything with space except (a-z, A-Z, \".\", \"?\", \"!\", \",\")\n", + " w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n", + " \n", + " w = w.rstrip().strip()\n", + " \n", + " # adding a start and an end token to the sentence\n", + " # so that the model know when to start and stop predicting.\n", + " w = ' ' + w + ' '\n", + " return w" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "OHn4Dct23jEm", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# 1. Remove the accents\n", + "# 2. Clean the sentences\n", + "# 3. Return word pairs in the format: [ENGLISH, SPANISH]\n", + "def create_dataset(path, num_examples):\n", + " lines = open(path, encoding='UTF-8').read().strip().split('\\n')\n", + " \n", + " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", + " \n", + " return word_pairs" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "9xbqO7Iie9bb", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n", + "# (e.g., 5 -> \"dad\") for each language,\n", + "class LanguageIndex():\n", + " def __init__(self, lang):\n", + " self.lang = lang\n", + " self.word2idx = {}\n", + " self.idx2word = {}\n", + " self.vocab = set()\n", + " \n", + " self.create_index()\n", + " \n", + " def create_index(self):\n", + " for phrase in self.lang:\n", + " self.vocab.update(phrase.split(' '))\n", + " \n", + " self.vocab = sorted(self.vocab)\n", + " \n", + " self.word2idx[''] = 0\n", + " for index, word in enumerate(self.vocab):\n", + " self.word2idx[word] = index + 1\n", + " \n", + " for word, index in self.word2idx.items():\n", + " self.idx2word[index] = word" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "eAY9k49G3jE_", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def max_length(tensor):\n", + " return max(len(t) for t in tensor)\n", + "\n", + "\n", + "def load_dataset(path, num_examples):\n", + " # creating cleaned input, output pairs\n", + " pairs = create_dataset(path, num_examples)\n", + "\n", + " # index language using the class defined above \n", + " inp_lang = LanguageIndex(sp for en, sp in pairs)\n", + " targ_lang = LanguageIndex(en for en, sp in pairs)\n", + " \n", + " # Vectorize the input and target languages\n", + " \n", + " # Spanish sentences\n", + " input_tensor = [[inp_lang.word2idx[s] for s in sp.split(' ')] for en, sp in pairs]\n", + " \n", + " # English sentences\n", + " target_tensor = [[targ_lang.word2idx[s] for s in en.split(' ')] for en, sp in pairs]\n", + " \n", + " # Calculate max_length of input and output tensor\n", + " # Here, we'll set those to the longest sentence in the dataset\n", + " max_length_inp, max_length_tar = max_length(input_tensor), max_length(target_tensor)\n", + " \n", + " # Padding the input and output tensor to the maximum length\n", + " input_tensor = tf.keras.preprocessing.sequence.pad_sequences(input_tensor, \n", + " maxlen=max_length_inp,\n", + " padding='post')\n", + " \n", + " target_tensor = tf.keras.preprocessing.sequence.pad_sequences(target_tensor, \n", + " maxlen=max_length_tar, \n", + " padding='post')\n", + " \n", + " return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "GOi42V79Ydlr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Limit the size of the dataset to experiment faster (optional)\n", + "\n", + "Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" + ] + }, + { + "metadata": { + "id": "cnxC7q-j3jFD", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Try experimenting with the size of that dataset\n", + "num_examples = 30000\n", + "input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "4QILQkOs3jFG", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Creating training and validation sets using an 80-20 split\n", + "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", + "\n", + "# Show length\n", + "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "rgCLkfv5uO3d", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Create a tf.data dataset" + ] + }, + { + "metadata": { + "id": "TqHsArVZ3jFS", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "BUFFER_SIZE = len(input_tensor_train)\n", + "BATCH_SIZE = 64\n", + "embedding_dim = 256\n", + "units = 1024\n", + "vocab_inp_size = len(inp_lang.word2idx)\n", + "vocab_tar_size = len(targ_lang.word2idx)\n", + "\n", + "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n", + "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "TNfHIF71ulLu", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Write the encoder and decoder model\n", + "\n", + "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", + "\n", + "\"attention\n", + "\n", + "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n", + "\n", + "Here are the equations that are implemented:\n", + "\n", + "\"attention\n", + "\"attention\n", + "\n", + "We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n", + "\n", + "* FC = Fully connected (dense) layer\n", + "* EO = Encoder output\n", + "* H = hidden state\n", + "* X = input to the decoder\n", + "\n", + "And the pseudo-code:\n", + "\n", + "* `score = FC(tanh(FC(EO) + FC(H)))`\n", + "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, hidden_size)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n", + "* `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.\n", + "* `embedding output` = The input to the decoder X is passed through an embedding layer.\n", + "* `merged vector = concat(embedding output, context vector)`\n", + "* This merged vector is then given to the GRU\n", + " \n", + "The shapes of all the vectors at each step have been specified in the comments in the code:" + ] + }, + { + "metadata": { + "id": "avyJ_4VIUoHb", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def gru(units):\n", + " # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n", + " # the code automatically does that.\n", + " if tf.test.is_gpu_available():\n", + " return tf.keras.layers.CuDNNGRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " return tf.keras.layers.GRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "nZ2rI24i3jFg", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class Encoder(tf.keras.Model):\n", + " def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n", + " super(Encoder, self).__init__()\n", + " self.batch_sz = batch_sz\n", + " self.enc_units = enc_units\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.enc_units)\n", + " \n", + " def call(self, x, hidden):\n", + " x = self.embedding(x)\n", + " output, state = self.gru(x, initial_state = hidden) \n", + " return output, state\n", + " \n", + " def initialize_hidden_state(self):\n", + " return tf.zeros((self.batch_sz, self.enc_units))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "yJ_B3mhW3jFk", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class Decoder(tf.keras.Model):\n", + " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", + " super(Decoder, self).__init__()\n", + " self.batch_sz = batch_sz\n", + " self.dec_units = dec_units\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.dec_units)\n", + " self.fc = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " # used for attention\n", + " self.W1 = tf.keras.layers.Dense(self.dec_units)\n", + " self.W2 = tf.keras.layers.Dense(self.dec_units)\n", + " self.V = tf.keras.layers.Dense(1)\n", + " \n", + " def call(self, x, hidden, enc_output):\n", + " # enc_output shape == (batch_size, max_length, hidden_size)\n", + " \n", + " # hidden shape == (batch_size, hidden size)\n", + " # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n", + " # we are doing this to perform addition to calculate the score\n", + " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", + " \n", + " # score shape == (batch_size, max_length, hidden_size)\n", + " score = tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))\n", + " \n", + " # attention_weights shape == (batch_size, max_length, 1)\n", + " # we get 1 at the last axis because we are applying score to self.V\n", + " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " \n", + " # context_vector shape after sum == (batch_size, hidden_size)\n", + " context_vector = attention_weights * enc_output\n", + " context_vector = tf.reduce_sum(context_vector, axis=1)\n", + " \n", + " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", + " x = self.embedding(x)\n", + " \n", + " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", + " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", + " \n", + " # passing the concatenated vector to the GRU\n", + " output, state = self.gru(x)\n", + " \n", + " # output shape == (batch_size * max_length, hidden_size)\n", + " output = tf.reshape(output, (-1, output.shape[2]))\n", + " \n", + " # output shape == (batch_size * max_length, vocab)\n", + " x = self.fc(output)\n", + " \n", + " return x, state, attention_weights\n", + " \n", + " def initialize_hidden_state(self):\n", + " return tf.zeros((self.batch_sz, self.dec_units))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "P5UY8wko3jFp", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", + "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "_ch_71VbIRfK", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Define the optimizer and the loss function" + ] + }, + { + "metadata": { + "id": "WmTHr5iV3jFr", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "\n", + "def loss_function(real, pred):\n", + " mask = 1 - np.equal(real, 0)\n", + " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", + " return tf.reduce_mean(loss_)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "hpObfY22IddU", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Training\n", + "\n", + "1. Pass the *input* through the *encoder* which return *encoder output* and the *encoder hidden state*.\n", + "2. The encoder output, encoder hidden state and the decoder input (which is the *start token*) is passed to the decoder.\n", + "3. The decoder returns the *predictions* and the *decoder hidden state*.\n", + "4. The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", + "5. Use *teacher forcing* to decide the next input to the decoder.\n", + "6. *Teacher forcing* is the technique where the *target word* is passed as the *next input* to the decoder.\n", + "7. The final step is to calculate the gradients and apply it to the optimizer and backpropagate." + ] + }, + { + "metadata": { + "id": "ddefjBMa3jF0", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "EPOCHS = 10\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " \n", + " hidden = encoder.initialize_hidden_state()\n", + " total_loss = 0\n", + " \n", + " for (batch, (inp, targ)) in enumerate(dataset):\n", + " loss = 0\n", + " \n", + " with tf.GradientTape() as tape:\n", + " enc_output, enc_hidden = encoder(inp, hidden)\n", + " \n", + " dec_hidden = enc_hidden\n", + " \n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']] * BATCH_SIZE, 1) \n", + " \n", + " # Teacher forcing - feeding the target as the next input\n", + " for t in range(1, targ.shape[1]):\n", + " # passing enc_output to the decoder\n", + " predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n", + " \n", + " loss += loss_function(targ[:, t], predictions)\n", + " \n", + " # using teacher forcing\n", + " dec_input = tf.expand_dims(targ[:, t], 1)\n", + " \n", + " total_loss += (loss / int(targ.shape[1]))\n", + " \n", + " variables = encoder.variables + decoder.variables\n", + " \n", + " gradients = tape.gradient(loss, variables)\n", + " \n", + " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", + "\n", + " if batch % 100 == 0:\n", + " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", + " batch,\n", + " loss.numpy() / int(targ.shape[1])))\n", + " \n", + " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", + " total_loss/len(input_tensor)))\n", + " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "mU3Ce8M6I3rz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Translate\n", + "\n", + "* The evaluate function is similar to the training loop, except we don't use *teacher forcing* here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", + "* Stop predicting when the model predicts the *end token*.\n", + "* And store the *attention weights for every time step*.\n", + "\n", + "Note: The encoder output is calculated only once for one input." + ] + }, + { + "metadata": { + "id": "EbQpyYs13jF_", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", + " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", + " \n", + " sentence = preprocess_sentence(sentence)\n", + "\n", + " inputs = [inp_lang.word2idx[i] for i in sentence.split(' ')]\n", + " inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], maxlen=max_length_inp, padding='post')\n", + " inputs = tf.convert_to_tensor(inputs)\n", + " \n", + " result = ''\n", + "\n", + " hidden = [tf.zeros((1, units))]\n", + " enc_out, enc_hidden = encoder(inputs, hidden)\n", + "\n", + " dec_hidden = enc_hidden\n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']], 0)\n", + "\n", + " for t in range(max_length_targ):\n", + " predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n", + " \n", + " # storing the attention weigths to plot later on\n", + " attention_weights = tf.reshape(attention_weights, (-1, ))\n", + " attention_plot[t] = attention_weights.numpy()\n", + "\n", + " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n", + "\n", + " result += targ_lang.idx2word[predicted_id] + ' '\n", + "\n", + " if targ_lang.idx2word[predicted_id] == '':\n", + " return result, sentence, attention_plot\n", + " \n", + " # the predicted ID is fed back into the model\n", + " dec_input = tf.expand_dims([predicted_id], 0)\n", + "\n", + " return result, sentence, attention_plot" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "s5hQWlbN3jGF", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# function for plotting the attention weights\n", + "def plot_attention(attention, sentence, predicted_sentence):\n", + " fig = plt.figure(figsize=(10,10))\n", + " ax = fig.add_subplot(1, 1, 1)\n", + " ax.matshow(attention, cmap='viridis')\n", + " \n", + " fontdict = {'fontsize': 14}\n", + " \n", + " ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n", + " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n", + "\n", + " plt.show()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "sl9zUHzg3jGI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", + " result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n", + " \n", + " print('Input: {}'.format(sentence))\n", + " print('Predicted translation: {}'.format(result))\n", + " \n", + " attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n", + " plot_attention(attention_plot, sentence.split(' '), result.split(' '))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "WrAM0FDomq3E", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "zSx2iM36EZQZ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "A3LLCx3ZE0Ls", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "DUQVLVqUE1YW", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# wrong translation\n", + "translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "RTe5P5ioMJwN", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Next steps\n", + "\n", + "* [Download a different dataset](http://www.manythings.org/anki/) to experiment with translations, for example, English to German, or English to French.\n", + "* Experiment with training on a larger dataset, or using more epochs\n" + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb deleted file mode 100644 index 9fd2d8d1254e32ae75ab5b085986c6e1c05e76f4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb +++ /dev/null @@ -1,495 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Eager Execution Tutorial: Basics", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg", - "timestamp": 1504118841551 - } - ] - } - }, - "cells": [ - { - "metadata": { - "id": "U9i2Dsh-ziXr", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Eager Execution Tutorial: Basics\n", - "\n", - "This notebook introduces the basics of using TensorFlow's eager execution capabilities. It covers concepts such as:\n", - "\n", - "* Importing required packages\n", - "* Enabling eager execution\n", - "* Creating and using TensorFlow Tensors and Variables\n", - "* Using TensorFlow interactively\n", - "* Using GPUs with eager execution enabled\n", - "\n", - "This notebook does *not* cover modeling topics, such as gradients." - ] - }, - { - "metadata": { - "id": "z1JcS5iBXMRO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Step 1: Import Eager\n", - "\n", - "The key imports for eager execution are the following:" - ] - }, - { - "metadata": { - "id": "RlIWhyeLoYnG", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "# Import TensorFlow.\n", - "import tensorflow as tf\n", - "\n", - "# Import TensorFlow eager execution support (subject to future changes).\n", - "tfe = tf.contrib.eager" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "H9UySOPLXdaw", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Step 2: Enable eager execution\n", - "\n", - "All future TensorFlow calls will execute the\n", - "underlying TensorFlow ops immediately:" - ] - }, - { - "metadata": { - "id": "WPTUfGq6kJ5w", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "tf.enable_eager_execution()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "twBfWd5xyu_d", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Step 3: Interactively Use TensorFlow!\n", - "\n", - "Now you can call TensorFlow functions and get results, immediately! No more `tf.Sessions`!\n", - "\n", - "TensorFlow will automatically wrap native Python types for you with operator overloading for TensorFlow Tensors." - ] - }, - { - "metadata": { - "id": "ngUe237Wt48W", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "print(tf.add(1, 2))\n", - "print(tf.add([1, 2], [3, 4]))\n", - "print(tf.square(5))\n", - "print(tf.reduce_sum([1, 2, 3]))\n", - "print(tf.encode_base64(\"hello world\"))\n", - "print(\"\")\n", - "\n", - "x = tf.constant(2)\n", - "y = tf.constant(3)\n", - "print(x * y + 1)\n", - "\n", - "# Most TensorFlow ops are directly usable with eager execution, giving\n", - "# results immediately.\n", - "print(tf.contrib.signal.hamming_window(x * y + 1))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "IDY4WsYRhP81", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "Numpy arrays are supported, too:" - ] - }, - { - "metadata": { - "id": "lCUWzso6mbqR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "import numpy as np\n", - "\n", - "ones = np.ones([3, 3])\n", - "\n", - "print(\"numpy 3x3 matrix of 1s:\")\n", - "print(ones)\n", - "print(\"\")\n", - "\n", - "print(\"Multiplied by 42:\")\n", - "print(tf.multiply(ones, 42))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "PBNP8yTRfu_X", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Step 4: Define and Print TensorFlow Variables\n", - "\n", - "To define TensorFlow variables, use the `get_variable()` function as follows:" - ] - }, - { - "metadata": { - "id": "3Twf_Rw-gQFM", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "x = tfe.Variable(0.)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "45G7094TxsMb", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Printing TensorFlow Variables" - ] - }, - { - "metadata": { - "id": "UJBJeZ5XxuwA", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "# This does NOT print the Variable's actual value:\n", - "print(\"Printing a TensorFlow Variable:\")\n", - "print(x)\n", - "print(\"\")\n", - "\n", - "\n", - "print(\"Printing a TensorFlow Variable's value as a numpy array:\")\n", - "print(x.numpy())" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "2njjWHcTpBEn", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Changing a TensorFlow Variable's value\n", - "\n", - "To change a TensorFlow Variable's value, use its `.assign()` or `.assign_add()` method:" - ] - }, - { - "metadata": { - "id": "v3wr6Erbo_hB", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "x.assign(42)\n", - "print(x)\n", - "\n", - "x.assign_add(3)\n", - "print(x)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "uhtynjHVpTB5", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Use a Variable just like any other Tensor" - ] - }, - { - "metadata": { - "id": "7PbktdnHoehR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "print(x + 3)\n", - "\n", - "# This code will broadcast the value across the list of numbers:\n", - "print(x * [1, 2, 4])" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "GVChqwlwy1SI", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Step 5: Debug Errors with Instant Feedback\n", - "\n", - "TensorFlow's eager execution helps you identify and debug runtime issues through interactive exploration of code snippets.\n", - "\n", - "Below, we'll define a length-4 vector, and attempt two `tf.slice()` operations,\n", - "one being legal and the other being illegal, leading to a runtime error that is\n", - "raised immediately." - ] - }, - { - "metadata": { - "id": "23ap04N0v4k0", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "vector = tf.constant([10.0, 20.0, 30.0, 40.0])" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "FCUMsIYxxRRa", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n", - "# arguments) are within the bound of `vector`.\n", - "print(tf.slice(vector, [1], [3]))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "T8me2oCNxpFp", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "# The following does NOT work, because the value of `size` (the 3rd\n", - "# argument) causes the indices to go out of the bounds of `vector`. The\n", - "# error is raised immediately.\n", - "try:\n", - " print(tf.slice(vector, [1], [4]))\n", - "except tf.OpError as e:\n", - " print(\"Caught error: %s\" % e)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "irxJhAgar84v", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Step 6: Using the GPU\n", - "\n", - "You can explicitly place Tensors on the GPU by calling a Tensor's `.gpu()` method. The `.device` property tells you whether the Tensor is backed by CPU or GPU memory.\n", - "\n", - "The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster." - ] - }, - { - "metadata": { - "id": "7J4N9baqaKCL", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Create some Tensors\n", - "SIZE = 1000\n", - "tensor = tf.random_normal([SIZE, SIZE])\n", - "print(tensor.device)\n", - "\n", - "\n", - "if tf.test.is_gpu_available():\n", - " gpu_tensor = tensor.gpu()\n", - " cpu_tensor = tensor.cpu()\n", - "else:\n", - " print(\"GPU not available.\")\n", - " cpu_tensor = tensor" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "4E-2n7VbzY1n", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Time a CPU-based matrix multiplication\n", - "\n", - "print(\"Time to conduct matmul on CPU:\")\n", - "%time tf.matmul(cpu_tensor, cpu_tensor)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "vbSFW-T5zhZF", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Time GPU-based matrix multiplications.\n", - "\n", - "if tf.test.is_gpu_available():\n", - " # First use of the GPU will be slow:\n", - " print(\"Time to conduct first matmul on GPU:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)\n", - " print()\n", - "\n", - " # Subsequent uses are much faster:\n", - " print(\"Time to conduct second matmul on GPU:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)" - ], - "execution_count": 0, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb deleted file mode 100644 index 1e65b27bc8be8b05fefa38dffae7799b1e503bd3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb +++ /dev/null @@ -1,583 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "vDJ4XzMqodTy" - }, - "source": [ - "# Eager Execution: Working with Gradients\n", - "\n", - "This notebook demonstrates:\n", - "\n", - "* How to get gradients using TensorFlow's eager execution capabilities\n", - "* How to apply the gradients so you can update your variables" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "GQJysDM__Qb0" - }, - "source": [ - "# Setup: Import eager and enable eager execution.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "OiMPZStlibBv" - }, - "outputs": [], - "source": [ - "# Import TensorFlow.\n", - "import tensorflow as tf\n", - "\n", - "\n", - "# Enable eager execution.\n", - "tf.enable_eager_execution()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1CLWJl0QliB0" - }, - "source": [ - "# Fitting a Simple Linear Model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-39gouo7mtgu" - }, - "source": [ - "## Step 1: Synthesize some data\n", - "\n", - "To demonstrate fitting a model with TensorFlow's eager execution, we'll fit a linear model to some synthesized data (which includes some noise).\n", - "\n", - "In the code, we use the variable names `w` and `b` to represent the single weight and bias we'll use to fit our model." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "rQsdCg9PfIL-" - }, - "outputs": [], - "source": [ - "# The constants we'll try to fit our variables to:\n", - "true_w = 3\n", - "true_b = 2\n", - "\n", - "NUM_EXAMPLES = 1000\n", - "\n", - "# Our inputs:\n", - "inputs = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", - "\n", - "# Our labels, with noise:\n", - "noise = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", - "labels = inputs * true_w + true_b + noise" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 347 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 374, - "status": "ok", - "timestamp": 1525154227149, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "O4lsC4ckAcar", - "outputId": "f8becb3f-498b-4cb7-9ef3-608a68cb65d0" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAecAAAFKCAYAAAAnj5dkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xt8VPWdP/7X3M5MkpkkM8mEAAER\nQoICgUBALkUEQ7FucekDEeWL3VZXu121dler39pu1Vbb77b+2m1/3277qNXa2kUptGttt/tDEWqp\nyDWBiC6ES8slXDJJJpfJ3C+/P8JM5nLOmTOTmWQm83r+RebMnJyTAO/z+Xzen/dbFQqFQiAiIqKc\noR7rCyAiIqJYDM5EREQ5hsGZiIgoxzA4ExER5RgGZyIiohzD4ExERJRjtGN9AWE220DWzm02F8Nu\nd2bt/LmukO+/kO8d4P0X8v0X8r0D+XH/VqtJ8lhBjJy1Ws1YX8KYKuT7L+R7B3j/hXz/hXzvQP7f\nf0EEZyIionzC4ExERJRjGJyJiIhyDIMzERFRjmFwJiIiyjEMzkRERDmGwZmIiCjHMDgTERHlGAZn\nIiKiJDy+ADrtTnh8gVH5fjlTvpOIiCjXBIJBbNt9Gq3tNvT0e2Ap1aOxzopNq2uhUWdvfMvgTERE\nJGHb7tPYdfhi5Ovufk/k683NdVn7vpzWJiIiEuHxBdDabhM91treldUpbgZnIiIiEX0OD3r6PaLH\n7ANu9DnEj2UCgzMREZGIMqMellK96DGzyYAyo/ixTGBwJiIiEqHXadBYZxU91lhXCb0ue20pmRBG\nREQkYdPqWgBDa8z2ATfMJgMa6yojr2cLgzMREZEEjVqNzc112LByBvocHpQZ9VkdMYcxOBMRESWh\n12lQZS4ete/HNWciIsqa0a6sNV5w5ExERBk3VpW1xgsGZyIiyrixqqw1XvDxhYiIMmosK2uNFwzO\nRESUUWNZWWu8YHAmIqKMGsvKWuMFgzMREWXUWFbWGi+YEEZERBk3VpW1xgsGZyIiyrixqqw1XjA4\nExFR1ox2Za3xgmvORESUMawIlhmKRs7t7e34x3/8R3zmM5/Bli1bcPnyZTzxxBMIBAKwWq34zne+\nA0EQYj7zzW9+E8eOHYNKpcJTTz2FhoaGrNwAERGNPVYEy6ykPzGn04lvfOMbWLp0aeS1H/zgB9i8\neTO2bt2K6667Djt27Ij5zMGDB3Hu3Dls27YNzz//PJ5//vnMXzkREeWMcEWw7n4PQhiuCLZt9+mx\nvrS8lDQ4C4KAF198EVVVVZHXDhw4gFtvvRUAsGrVKrz//vsxn3n//ffR3NwMAJgxYwb6+vrgcDgy\ned1ERJQjlFQE43R3apJOa2u1Wmi1sW9zuVyRaeyKigrYbLG/lK6uLsyePTvytcVigc1mg9FozMQ1\nExFRCjy+QFYzppNVBHt150mcPG/ndHcKRpytHQqFMvIes7kYWm320uytVlPWzp0PCvn+C/neAd5/\nId+/xVKCl3/3IfYfvwxbrwvW8iIsmTMR962bDY0mc4HRVFYEq7kInXZXwjG9oMW+41ciX4enu4uL\nBDywfm7GrkFMPv/u0wrOxcXFcLvdMBgMuHr1asyUNwBUVVWhq6sr8nVnZyesVvFqMWF2uzOdS1HE\najXBZhvI2vlzXSHffyHfO8D7L+T7t1pN+L+/ao3pDNVpd+HNvWfhdHkz3hmqYUZFzPcKC4WCou9/\n79glfGLxlKztfc6H373cw0Naj07Lli3Dzp07AQBvvfUWVqxYEXN8+fLlkeMffvghqqqqOKVNRDSK\n3F7/qHaG2rS6Fs1NNagoNUCtAipKDVg+pxpur3hwZgMMeUlHzsePH8e//uu/oqOjA1qtFjt37sQL\nL7yA//2//ze2bduGSZMmYf369QCAf/qnf8K3vvUtLFiwALNnz8bdd98NlUqFp59+Ous3QkREw+z9\nyTtDZaI4SPR6dnxFMAA4cd6ObpHrYAMMeUmD85w5c/Dqq68mvP6zn/0s4bXvfe97kT8//vjjI7w0\nIiJKl7l0qDNUssCYLFlM6rjcvubooN9YZxWd7mYDDHks30lENA54fAHYel1AKASruRhWQSsbGLUa\nFbbuapcsGpKsqEh4X3NYONELQMx6tlgDjIYZFqxqnAyPL8AALYHBmYgojwWCQbz+zim898EVuL1D\n68gGQY3mxdfhzlumAxDvDJUsuMod37Byhsx6tg0bVs6IBN3oBhg9/W7sOnIRbae78MfWS9xWJYPB\nmYgoj23bfRrvHOmIec3tDeL3f/4L3G6faGeoZEVD1i2bJnv85nmTJNezu/s9eHXnSXz29lkxAVev\n02BPawf2tHTEvFdstE1sfEFElLfkgiwAtJy0RaaOq8zFkdFssqIhFzsdsscRCsFSKp3Mte/4lYSy\nnUqqiNEwBmciojE0krKWckEWAOwDHtHtSmVGvWRwNZsMqKkySh4XdBpYyorQWCdfuyI+4CZ7IOC2\nqlic1iYiGgPpdHGKz5wOB1mxjGwAMJv0otuV9DqNbLKYqViQPO72BvDG3rPYtLoWLrcf70VV/4oW\nv11L7lq5rSoRgzMR0RhQmu0MyAdyqSAKAAvqrZLZ0GJZ1OFkMQBYv+J6/LntciTJLFprexc2rJyB\nLWvr8T/netAz4E14T3zATfZAwKztWAzORESjLNn6a3S2MyAfyDetrkUoFIrL1tagefFU/O2y6ySv\nITqLWmwfs8Ppg0ckMAOxo+IF9VWKA26yBwIaxuBMRDTKlKy/hqeDlQTy/7WmHnfeUhuzz7lmUrmi\n2tLhZLF4SqehUwm4yR4IaBiDMxHRKEtl/VVpINfrNKixKu9hkKwymNJp6HQCrtQDAQ1jcCYiGmWp\nrL9mOpHK6fFh69uncOJcD+wDXlhK9WiorUTzwhpYSg0x3zuVUTEDbmYxOBMRjQGlgS9TiVThpLL4\nJK/ufg/2tAwVB6mIyxgXGxUDQHefO/JnTk9nB4MzEdEYSGU6eP2K6+F0+3HinB29Do9oIE82TR2f\nVCZGKmNcr9OgoswQyRjv7vfAIKgBqODxBliGMwsYnImIxlB4v7LSzk9LZ1fjnjV1KNZrJd/TWGfF\nw3c1Rs6TrJJYPCUZ49F9mlmGM/MYnImIRlH0CFerUaXc+em941fg8vrxd7fNgqlYkNxmVVwkYP3y\naQCSVxKLl0rGeDSxoE7pYXAmIsqAZNPKYiPcYoMOFzodkfco7fzU0t6F1vY/Y7K1BE63T/Q9+49f\nxicWT1FUSSye2aSH1xeI1OVWGtzjgzqlj8GZiGgElJbhFBvhSgXLZJ2fACAE4KJtUPJ4V68LZzv6\nMH1ymWxSmZhBtw9Pv3woci/rV0xXFNxZhjNzGJyJiEZASRnOVNd8ozs/KR3tJlAB33n9aCQDO7G3\nsx71U83QaVU4ftYO+4Abgk4DtzcQWU+OvhclwZ1lODOHwZmIKE1Ky3CmuuZbbtQDKhUaaitj+h+n\nIngtXyv+YUEsO9zjC8DW68K//eqoZC3tZ+9fhEAgiNZTXeh1eGEQhj7r9QVYhjMLGJyJiNKktHpX\nqmu+To8fT790EGaTgImWYlzucY74WsMPC2L0Og0ErRp2kQYWwNC9bH37FE6et6PP4YXZqMf8ukps\nWDkdDqeP+5yzgMGZiChNSqt3KV3zFbQqeP2hyOh1qNuTF3qdGh5fUPazydgH3Hh150mcPG8XXRuX\nuxdBp8G+qNaQdsdQ4RKNWsWtU1nC3eJERCny+ALotA+NZhvrrKLviV9/3bS6FsvmVMue1+sPib6u\nUqV5oVEEnRr7jl9Bd78HIQxPd7/+zikAww8Q4sSvq7W9Cx6feOcqGhmOnImIFBLLzJ4/sxKrF07G\nsVPdsmU4NWo17l1bj5Pn7Sknebm9QSyfU40j7TbRNWElfH7xkfd7H1zBnbfUQq/TYNPq2si6cp/D\nC0upAbOmluO9qFFzNG6dyh4GZyIihcQys9850oHmpho898BNSctwprqlKUytGirh+T/netIKztWW\nIlzpcYkec3uHksEmVhRj2+7TaDvTjT6HF+VGPRpqK7Bh5QyckHig4Nap7OG0NhGNC+Gp5mxNsybL\nzAYQad0oJRAMIhQKRTKdlQqGgE67SzJhKxmXxy//hlAo8uARnvYOryu/sfes4ql7yhyOnIkor8kV\nAckkpZnZcrbtPo13jqS+Ncpi0qOmypj2vue+QR/0WjU8IlPbBmGogpjcg8ez9y+K/DlZ60jKDAZn\nIsprckVAHr1nYca+z0j7KqdaiCRaSZEOGo0K9VPNMVnTSpUbBcybWYl3Wy8lHFs2txouj1/2wcPh\n9CnuoEWZweBMRHkr2VSz25tkOjcFep0GDTMqsEckwCmZ3k21EEm0C50OPP7DffB4AzAIGoRCoZS2\nVjXOrMTmNXXQadRoOWlDz4AHZSU6LKiz4p5bZ8IfkK5GFr8ljMlfoyPt4Lx9+3a8+eabka+PHz+O\n1tbWyNezZ8/GggULIl+/8sor0Gj4pEVEmZNsqtne78nICCQ8dd52phvAUIJWMDQ03bygXnoKPboZ\nRqqFSOKFE8FSTQibaCnG5jV10KjVQ9nYwRCOtneh1+FB25luaDSnsWl1rWSiGteVx0baf283btyI\njRs3AgAOHjyI//7v/445bjQa8eqrr47s6oiIZCSbajaX6jHQJ56lrEQ4uO48dCGmjGbw2rbfeTMr\nRYtwSK2Dz59Zmdaas5jwA4Icg6DGV/6uKdKAY9vu0zH3Eb0EEH7A4LpybsjItPYPf/hDvPDCC5k4\nFRGRYnJbkxrrKmEQtBgQ+Vwq7R27+z1QSxQBaTvdDc+qQMI5tr7dHjP9HQ6CtyyYhBpriWw3KRWk\nSn7EShaYAeBjDZNQrB/6b37A6cXh/+kUfV+4tCfXlXPHiINzW1sbJk6cCKs1NtXe6/XiscceQ0dH\nB9auXYvPfvazsucxm4uh1WbvL4LVasraufNBId9/Id87MP7v/+G7GlFcJGD/8cuw2V0wl+qxZM5E\nPLh+LoDY+w8Egnj5dx8OvbfXBWt5EZbMmYj71s2GRjO8s/TFNz6ICfhSgdA+4IZG0MFaWRI5/0/e\n+ADvHktclwaAAx9ehcsjPy2tJDADgLXcgAX1VXj70PlIk4toRXotHlg/F3pBi5++eRxvHzwHj1d8\nnTr+PmoUXkOuy+e/+yMOzjt27MCnPvWphNefeOIJ3HHHHVCpVNiyZQuampowd+5cyfPY7SMv7C7F\najXBZhN7fi4MhXz/hXzvQOHc/7qlUzEw6MFRXxfs/R4cOH4ZXq8fD9/ViJ6e4VHq1l3tMUG30+7C\nm3vPwunyxrR3fO+Ysqlns8mAgNcX+RnHnz9essAMDK1jz5tZibbT3TFtHOPNq63E8jnVeOvAedHz\neLx+/OWCHbuOXExa9CT+PsaDfPi7L/fwMOLgfODAAXz1q19NeP2ee+6J/HnJkiVob2+XDc5EROmS\nWkstLhKwfvk0ANlp7xidLDWSrVLRFtRbsbm5Dp5VQ1PvxmIBb+w9G7MWPG9mBUKhEP7tV0clR9qV\n5UUo0mvRclJ8KlvqPig3jKhC2NWrV1FSUgJBEGJeP3v2LB577DGEQiH4/X60tLRg5syZI7pQIiIx\nckFx//HLkYphSoqIAMNJZmLUqqEmFBWlBjQ31WDT6tpIZTKb3Zk0qMtVBrOY9JFzAsPblor1Wmxu\nrsNzD9yEbz64BF/7TBM8ngDeOdJxrWuVuCVzJg7tX05SVWzZnGomfeWgEY2cbTYbLBZL5Ouf/OQn\nWLRoERobG1FdXY0777wTarUaq1evRkNDw4gvlojGv2TJWvHkgq7N7sLZjj5Mn1yWkfaOK+dPwtrF\nU1Fm1EOrUSVkZAs6FTw+8bGsXqfGktlV+GPr5YRjy+dUY8vaetn71WpU2HXkIlpOdsoG3IprmeH3\nrZuNy1f7YTEJku+3lOpx79r6SDY35Y4RBec5c+bgpz/9aeTrBx98MPLnL33pSyM5NREVGLkynHLB\nQy7oqtTAC68fjZxr3sxK7BbZyiTW3hEQ31YUvpb49eVk+5c9viBuXVgDrUYje14p8ZXQxKgAPHpn\nA2qqTNBo1NDrNFhQXyX5uQV1Vk5n5yhWCCOinCBXhlNsL3GY3Eg3nMUcPtetCyejualGci9v9Kg9\nflsRAHT3uSN/Tmd9efeRDty7dlbS7UrxswceX0DR2rGl1ABrXAWvTatrEQyFsO+DK5HEMoOgwfK5\nnM7OZQzORDTmlCZrSYke6fb0u6GSKNBx9FQ3nnvgpoTgGAgGsXVXu+iovaLMkDCir59qTqsUZ9uZ\nHnh8AckymFKzB6saJyddOwbEE7s0ajW2rKnHxltqYbM7AZUK1vIijphzHIMzEY25kXZ80qjVkZHu\n2Y4+vPD6UdH39Qy4I2vQ0eeTG7UHAsGEgiL7jl+BRF0SWcnuRap4idcfkK0IJuhUWNEwSXYkrNdp\nUFOVv/t+Cw2DMxGNuZF2fArT6zSYPrlMeg0awHdePxpJmtq0uhb+QEhy1L637RK8EoU7lBYLiSZ1\nL0Mj91N496h48ZK2092yFcG8vhBUKhUTu8YR/iaJaMyF143FpLoHV+5c4QAXHpFu231adtTu8QbT\nCsJSpO4lvE9bKgD3ObwoNwriB69pbe+KbBsDALfXj067M+Y1yh8cORNRTshk44XwZ9rOdMPW64IK\n4lPCre1dWLds2oi6RSlRbhTQNKtK9F6UFC+xlBrQUFsRU2glXnjKPLxG3namGza7S3HWO+UWBmci\nygnR68YjbbwQPtfnNhTh4LEOfEdiDdo+4IbL45fM9s6EcqOAZ+9bDFOx+MhXSUWy6IeUd1vFR9jh\nKfN0s94pt/AxiohySjiTOZXAHK7SJTaFayrWoUKi4lc4oG1aXYtbF05GNgaWbq8fb+w9i8vdg6LX\nl6wi2c3zqrGqcTL8gRDuWlWLxTdMEH1vY10lAOktXvHT3pTbOHImorwltfXozlumY8cfz0amdvWC\neNSNXgMOBkOi3Z1SMVTeU4VA1NDW7R3K9t7TeikmES08xSy3T3tSZQk+/Isde49dgV7QAAjB7Q3C\nIKgBqOD1BWKm/7v73CPKeqfcweBMRHlLagr35PleXOh0RF53X8u41qiBwLUAbBA0CIVCCASDQxnb\np7pGdC2L6q24Z00dnvv5Ick9yVJTzGLr7cUGbdw9RCd7Dd3EsjnVuDeq7Gemst5p7DE4E1Fekkuk\n6rA5RF8PRI2M3d6h5hHBENBUZ0WvI3mRDznGEgFeXwB2BcVC4gurxK+3F+m1+Porh5Ke5+T53piv\n5Ubh7DyVXxiciSjjUm1ekY6efrdkhrXcnuB477Z2yGZBp3Ieh9MLs0yjiTD7gBu2XhcErTrmZxRe\nb+9U0OEKGPoZxE9VR2eqd/W6RpT1TmOHwZmIMibd5hXp2HVEOrtarppWvFQCebLzHDphg0bBbQo6\nDb63rRV2hw8Wk4AF9VUxPyO56eloekGTMFUdnal+5q/dWX1AouxhtjYRZUx4Dbi734MQYot9ZJLH\nF0Dbaek14urKsUt6CihIKnN7A7A7fACAngEvdh2+iNfeORXJOgcgWUhFKYOgTTnrnXIHR85ElBEj\nbV6RCrkpbQCwlhlwyebMyPfKJLUKQAgQi99/bO3A0XYb7ANeWEr1mDezErcunIyWk12wO8Tv1Xtt\n+YAZ2OMPR85ElBFKmldkytuHz0seU6uAY6d7Mva9MikoEZiBofaWPQPeyIzD7iMdUKlUeOa+RZKl\nO5mBPX4xOBNRRsgV05ALIh5fABc7B3DR5lBUJMPjC2D/h9K9jTO1hpwLWtu7IOg0aJpVJXqcGdjj\nF6e1iQjAyDOsU93GEwgG8do7p7Dvg8uRfbsGQYPlc6tx960zJRPIbL2umD2/41l4xiGTdccpPzA4\nExW4TGZYpxJEtu0+jd1HYrcwhfceq1QqbG6uE39gCOXe0DiV7PBUhGccMll3nPIDgzNRgctkowSl\nQcTjC6DlpPTUdGu7DYFAEG1nuhMeGKzmYhgEdWS0nQuyNZUeP+MQ3gdN4x/XnIkKWLIMa6WNEuIb\nTyRrXtHn8MgW6uju92BP6yXRLVl6nQY33Sje/GG8ELRqNDfVcNq6gHHkTFTAlGRYy43U0p0SL9Jr\nIWhU8AbEh5xS08QtJ20IBEM4fjY3s7GTUauGZuXNJj2cHr/o2rleq8a3/mEpypmFXdAYnIkKWCqN\nEsTWf1OdEo8O5lKBGZCeJu4Z8GSk1OZYWXzDBKxfcT3KjHr8+t0zoslzK+ZPYmAmBmeiQqYkw1pq\ndLx+xfUpFx2JD+ZSykp06Bv0JbyercSr0WAQNNiyth7F+qH/dpmBTXIYnIkKXLIgITU6drr9KU2J\ny61vR6soNWDuDDP+2Ho54Vi+BmYA+FjDxEhgBpQnz1FhYnAmKnByQUIuoJ44Z1c0JR6eDvf6Aoo6\nLTXWVSKYg9ul0lVRKj8iZgY2iWFwJiIA4kFCLmGs1+HB4hsmoPujqwnH6qeWIxAMYeuu9pjpcL3M\nFihTkRY3Xl+B25dch+d/cXjkN5QDyo0CvvaZJpiKxctvEklJOzgfOHAAjz76KGbOnAkAqKurw7/8\ny79Eju/btw/f/e53odFocPPNN+Ohhx4a+dUS0aiSSxgTdBq0X+wFMLwWrNepgJAK7x+/gpb2zphA\nnKz94YDLjwMfXR1KFvPnzh7mkegf9MLl8TM4U8pGNHJevHgxfvCDH4gee+655/DSSy9hwoQJ2LJl\nC9auXYvaWiY6EOWD6MxsqYQxtzcQ2QoUXgv2+EIAQteOiwdYg6BGKAR4fOLHMxGY9Tq15PnV6qEm\nE+kyCBrF5UPZmILSlZVp7QsXLqCsrAwTJ04EAKxcuRLvv/8+gzNRjhPLzJ4/sxKrF07GsVPdsA+4\nUW6U3qOrRCYre5mNAubXWdF2ujsmmW353Go8+zPxqfFQEBC0Knj9qa1rq1XAyvmToFar8M4RZdu5\nGmZYmORFaRlRcD59+jT+4R/+AX19fXj44YexfPlyAIDNZoPFYom8z2Kx4MKFCyO7UiLKOrHM7HeO\ndKC5qQbPPXDTUGKXP4inXzo4hlc5zHXtAeHhDXOgUathLS+CVqPC1rfbJbddhQD4ZfZYSwmGgOam\nKagyF0GlUqG1vQs9A27ZUt/NTVNS/j5EwAiC87Rp0/Dwww/jE5/4BC5cuIBPf/rTeOuttyAI6a2t\nmM3F0Gqz94RptZqydu58UMj3X8j3Dii/f6fLiz+3JW5fAoDWU134zLo5qJlUDrfXD6u5CJ12VyYv\nMy1ubwB7Wjqwp6UDVnMR5s6ohKBTY0/rJdnPpbsl670Pr+LzG+bh0XsWwu3140r3IL7+0/2w9boT\n3qtWD73/wfVzodGMTaVk/t3P3/tPOzhPmDABt99+OwBg6tSpqKysxNWrVzFlyhRUVVWhq6sr8t6r\nV6+iqkq8H2mY3e5M91KSslpNsNkGsnb+XFfI91/I9w4ov/9AMIiv/fSg5FR1d58bD39nN5pmVWHT\n6lo0zKhQVExkNNnsLuw+nN0Zuv0fXMa6pddFpqpLtGrMq60U/VkEg8Af9v0VXq8/5QYimcC/+7l/\n/3IPD2k/zr355pt46aWXAAxNY3d3d2PChKFi9DU1NXA4HLh48SL8fj/27NkTmfImotwQ3axi665T\nuNwj/4Dc6/BGmk9sWl2L5qYaVJQaoFYNJUllmirjZxy5ngEP+hyxWeebVtdiVeMkqCUuOJUGIkRh\naY+cV69ejccffxzvvPMOfD4fnnnmGfz+97+HyWTCmjVr8Mwzz+Cxxx4DANx+++24/vrrM3bRRJS+\n+KQvs0nAoMuv+PPh0pzRhUuMxTq8sfcvOHyiE70O6W5TqcjFMiRq1VDTjmgatRprF0/FHyWm0pU0\nECGKl3ZwNhqN+PGPfyx5fNGiRdi2bVu6pyeiLIlP+pJr3SgmOthEFy7Z3FyHdcum4emXD2YsQOea\nYAii+5ZTaSBCpAT7ORMVEKX1reWIdasKT4+bigUUG8Zv4cGKUr1ooA03EBETbiBClIrx+6+IiBLI\nleNUSq5bVZFei0td2UvuHGuNdVbJQMsuU5RJDM5EBaTMqIfZJIhOZet1ahiLdOgZ8KC8RI95MysQ\nCoVw7HQ3+hxeWEqTd6sCRhb4c9nK+ZNkAy27TFEmMTgTFRC9ToOSIvHgXGUuxlP3LoxJ8Gptt6HP\n4UW5UY+G2gpsWl0LjVqdkenxfPOJm6ZCo06+EsguU5QJDM5EBcTjC8Dp9okec7p9sNmdsJqL8et3\nz8SMiu0OD/a0dECjVmFzc11GpsfzidmoY1IXjSoGZ6ICIhdUu/s9+NrLh2AxCXB6xPfltrZ3Yf2K\n6fjDgXNQqSBbunI8MRbLT1FHNwrhVDZlAoMzUQGR2/ITJre1yj7gxvM/P5y0YMl4M+jyweMLJARe\nsaS4xjprZPqfKF3820NUQOS2/Cih1agKLjADQK8jsTIYMJwU193vQQhDsw/hKmpEI8HgTDTORO87\nFhNdelOVYo1Mf7odI/KcWCERuaQ4luykkeK0NlEOS2UtU2yKdfm8yVi3NDbLOLzlZ/2K6/HLnSdx\n4KNOxaUyg5lrxZxXxAqJyK3fs2QnjRSDM1EOSmctU2zf8Zt7z8Lp8op2RXpj71+w/6POrN1DPlOr\nhmp7W2QKibBkJ2UTgzNRDhILtOGvxQJtsinWDStnxIz8CnGfcipWzp+EtYunxsxYDDi9uNjpQE2V\nEaZiIbJ+L9YukiU7aaQYnIlyTKqBFkg+xWrrdUHQqiPBps/hkc3YLiRTqoxwuv0JJTfDMxRevx/P\n/6IFHTYHgqGhUfVkqxFf+fQCluykrGFwJsox6axlyk2xCjoN/u1XR2Ef8Eamx29fMhVq1VCXpZEQ\ntGp4/bm/EF1jLcGDd8zGntYOtJ3uTgik/kBIcm3/+V+04EKnI/J1MARc6HTg+V+04Nn7FrNkJ2UF\ngzNRjklnLVNuitXtDcDtHcocDk+PO93+EQdmAPjnTQ048FEn3j16KSPny4abGybi3tvqoVGrce/H\n6+FZlZhkp1FDNHlrwOlFh82R8DoAdNgcGHB6I1PcTP6iTOJWKqIck077wUAwiFAoBIMwfMwgqGEQ\nxP+Jnzhnh8UkiB5LxSt/OIlFCLMMAAAgAElEQVSb508as0phgjb5XjCNJvY94UCqZIR7sdMh+dAR\nDA0dJ8oGBmeiHBS9F1mtAipKDWhuqpFcy9y2+zTeOdIRGSEDgNsbhNsrPuXc6/DghussI77OK3YX\n/s8vWyDoUtwwnQHlJTo01FZCp5X/b2xP66W0i4LUVBmhlrg1tWroOFE2cFqbKAeF9yKvWzYtJkM4\nnscXgM3uTDnzWtBpcM+aOpzvdMSsp6bD4xubNefeQR8On1B231KJdMmYigVMthpFf0aTreK/E6JM\nYHAmyiHhoiPRLRvF9jlH74NOJ+s6FArhavcgBl3SdbTHk5EUBfnKpxdIZmsTZQuDM1EOiC86ohc0\nMVPU8fuc4/dBp8rjC+Ibvzgy4uvOFyMpCiJotXj2vsUJ+5yJsolrzkQ5YOvb7TENFKIDc7TW9i4M\nOL0sIJKiTBQFMRULuGGahYGZRgVHzkRjKBAMYuuuU3j36CVF7+/pd+Nip0NyH3Suq7YUodPuyvq2\nK/W1XtNWcxEaZlSwKAjlHQZnojHi8QXwy50n8d7xK4o/oxc0qKkyJu3JnItunj8RaxdNRZFei1f/\nvxM4cd4OlzeYkWIo8UIAHr97PhbPm4yBPldmT040ChiciUZZeH35yImrsDt8KX46hDf2nsWgO9XP\njb332i7jT0cvJ7yejVF0eYke0yeXwSBoMZD50xNlHdeciUZZOJkr9cA8tHd5T+ulhP3LGjWwsnEi\nKkpztxNSIMmOK71O+X9HS26sgtmokzw+n40nKM8xOBMl4fEF0Gl3wuMLiH6d6rmykcwVCAJqlVqy\nslg+ULJf2iBo0NxUg/s/eSMWzpog+p4pVUZsbp6Z6csjGlWc1iaSEL+9yWwSUFIkwOn2Ke6xHC+b\n3aBaT9owt9aSN80o4pmKdBB0atmfT4lBiw0rZ0CjVsd0hOrpd6PMKKBxZiU2r6lT/PsgylUMzkQS\n4vcS9wx40TMwXLQjWY/leIFgEDsPXchKAhQA9A56sfeY8uSyXNNYXwlBq5Hdv20f8ESKiYSrqLEj\nFI1HIwrO3/72t3HkyBH4/X587nOfw8c//vHIsdWrV6O6uhoazdA/lhdeeAETJohPQxHlmlSmn5WW\nhty2+zT2tHSM6LoErQpef462f7rm5saJOHuxHxdtg4o/o9WocO/H6wEAgWAI77Z2iD7AiBUTYUco\nGo/SDs779+/HqVOnsG3bNtjtdnzqU5+KCc4A8OKLL6KkpGTEF0k02uR6KsdTUhoyU2vNC+qtOH6m\nBw63f8TnyhatSoWnP7sIW99uR+upLvQ5vBB0atk15dJiHfyBELQaFTRqFXRa8fdnopgIUT5IOzgv\nWrQIDQ0NAIDS0lK4XC4EAoHISJkon8n1VI6npDSkXLBXqQBTkYB+p3yda4OggVarzunADADvHb+C\njatm4t61s3DX6qFa4V5fAF97+ZDkZ+wOL/ocHuw6clF0WnsoG30yi4lQwUg7OGs0GhQXD40UduzY\ngZtvvjkhMD/99NPo6OjAwoUL8dhjj0Glkm4rZzYXQ6vNXmC3Wk1ZO3c+KOT7T/fel8+bjDf3nlXw\nvkmomVQu+x5TWRGs5qHqWAnXV16EuuvK8WeRPcDRqiuKse+D3F9T9niD8KtUqCwrwmC3EyUmA2pM\nBljLDbD1ukU/Yy0vQs2kcrT96pjo8UAQMOh1qJ5QlvL18O9+4crn+x9xQtiuXbuwY8cOvPzyyzGv\nf+ELX8CKFStQVlaGhx56CDt37sRtt90meR673TnSS5FktZpgsxVuKYJCvv+R3Pu6pVPhdHnR2t4F\n+4Ab5UY9Sop0cLp9sA94YDYZ0FhXiXVLp8JmG4h0lJJKTGqYUSE6Kuwf9CQNzBMtxfjr5fz5Hf7i\nDx/hg9Ndkf3YBkGDynKD5PsbZlTg4qVe0YeXsPfbLmPd0utSmtbm3/3CvHcgP+5f7uFhRMF57969\n+PGPf4yf/vSnMJliv8n69esjf7755pvR3t4uG5yJco1GrcaGlTNwc8NEQKWCtbwIep0mIQgP1cdu\nl2zvGHbnLdNx8nxvpPVgWHxBkXhmowCvP/U91WPp0EedMV+7vQFc7BxEtaUI9gFPZD3ZIGiwfG41\nNq2uhT8QQrlRQK9DfHq/d9CTdttHonyTdnAeGBjAt7/9bbzyyisoLy9POPbFL34RP/rRjyAIAg4d\nOoS1a9eO+GKJRkv8HufogBufHRy/5Sp+i1U4mO88eB4XOh0pX8sN0yx4P4X627nsSo8LZpMejTPL\nsPam61BtKY6MhDVqoHFmJfa0ijcBsYyg7SNRvkk7OP/hD3+A3W7HF7/4xchrN910E+rr67FmzRrc\nfPPN2LRpE/R6PW688UaOmimvSAXcQDAU2fIDAANOLw6f6BQ7BVrbbQgEgmg7042efg9kUi4kLZ9T\njXvWzMTJ8/a8a3QhxT7gwf6POqFRq7FlbX3Msc1r6nC6o1/0IYaZ2lRIVKFQKCc2TWZzbSAf1h6y\nqZDvP5179/gC+OqL+0WDoVoFLLphAjavmYk39v4FR050ot+ZnSYUpmItHr2zAZOtppS7V+ULi0nA\ngvqqmCWAcBvNo+1d6B30wHJtbT+VSmxh/LtfmPcO5Mf9Z23NmSjfhaeci/RauDx+lBn1stuegiHg\nwEdXceCjq0nPrbrWUzhdA04/nvtFCwyCGotvqIJBUCddn843PQPehCprGrUad62qxarGyUAoBKu5\nmCNmKjgMzlSQwmvKLSc70TPgjZTUrCjVY870CpSVCOgdlN93nEym5qTc3iD+dOwKplQZ01qzzgfh\nKmtajUpyrZ/1sqmQMDhTQYpfUw5nT3f3e/DuUfGEpLE2MDg+1pzFhKusxRchSbV+OdF4wUdRKjjZ\natuYbb2D2VnbzgVmkwFFeq3k76W1vSutFp1E+YrBmcalcM9ltzex1GUqdbMzqdwoYGXjJOi16f2z\ny9VlV4Mw8v9GGusq4fL4JX8v4ZE1UaHgtDaNK/H7k63mIjTMqIhZs0ylbnaYkh7J1RVF6OxxiXZT\nUqmAL909H5ayIrSf68XlntQr4uXqwLGi1ICOLvn7MQgaeH0BmE16FBt0GHT50OsYrrIWLkIi9XtR\nUr+caDxhcKZxJX4tudPuSliz1Os0aKyzyvYNjmYx6THnegsOfHQVnmsBWqMeanPo8YWgAhAC4HL7\nJfs0h0LAsz8/BBVUst2Z8pHD5cOqBZPRdrob3f3uoZkBNeDzBSPBd/2K6XA4vZGqamKlTjVqSP5e\nuMeZCg2DM40bcmvJ8T2Xw92NWk7a0DPgiWRriykp0uFPbbG1rwNBYIKlCJe6nAh/rC/JmrDXFwKQ\nE2UFMqp/0Ie1i6bgrlW1kYALICH4FuuH/7uR6sEc/r2E65lHj6yJCgmDM40bcmvJ8T2XNWo1NjfX\nYcPKGZF9zg6XD7sOX0DbmZ5IYGiYYUHbmW7Rc15KMpVbKCylhkgQjg646dTAjv+9SDURIRrvGJxp\n3JBbS5Zas4wOKKZiAfeunRUz5WqzOyVrPcspLdZlrXJYrsnGlLPUyJqoUDBbm8aN8FqymPgAEs7m\njt+eEw7MxmIBv373DL63vS3l67CY9Nh068yUP5ePjEVaTjkTZQFHzjSuxK9ZVpYPZ2sD0t2m7rxl\nOnb88Sxa223o7vcoys6W4nB58dPffZSxe8plTrcfTrcfpmJhrC+FaFxhcKZxJX7Ncsa0Cgz0uSLH\npbpNnTzfG1MaM93APPTZ8Zf0JSUYAi52OnDDNMtYXwrRuMJpbRqXwmuWBmH4+VMum3u81qzONrUK\nqKkyjvVlEI07DM5UMMaqMthYaaqvhEFQlqhVbhSgVg01/qixlkCvG/6vwSBoYCwSn2SbbDVySpso\nCzitTQWjzKiH2SSgZ2Bk3abyhVarwavP3ob/OW2D1+fHd147KloAxSBo8Ox9iyMtM8NFQmx2J6BS\nwVpeBJUqhOd/0YIOmwPB0NCIebLViK98esEY3BnR+MfgTAVDr9Ng1nUW7Dt+ZawvZVS0n+8FANRY\njRhwehGSqrICQNBpYkbAep0GNVWxjeCfvW8xBpxeXOx0oKaKI2aibGJwpoKyec1MtLTb4PbmaKHq\nDOp1eNDV68L2XSfx52OX4Q2IB2ePNxBToEWOqVhg8hfRKOCaM41LUl2p9DoNKssMY3RVo8tsMuB3\ne89i95EO2ezzcIUvIsodHDmTLLEGBbkm+hq1GpVsV6ptu0/jom0w49dQXiKgpEiLrj53zjS2mD3d\njP0fJK9uJlbhKx9+70TjGYMziZIq1hHdenGsiV1jsUEXsy0quivVhpUzJLdSjVTvoBcurz9nArOx\nSIu2093odcgnvy2bUx1T4Ssffu9EhYDBmURJFesAhlsvjjWxa5Tq0dza3oWlN07I6laqXAnMAOBw\n+ZO/CYBOp4r5Oh9+70SFgI/ClCBZ68X4etRjQe4axXT3u/GDX7eNw4aNI/Nu62Vs230aQH783okK\nBYMzJVDSenGspVNQJFm/5ULVctIWWWPO9d87UaFgcKYE4daLYqRaL442uWuk1NgHPJHkr1z/vRMV\nCgZnSpBK68WxIneNwFDVK7UKqCiQbVMjYTbpI1nZuf57JyoUDM4katPqWjQ31aCi1HCt5rIBzU01\nOdW7d/2K6yVrR5cYtHjms4vw/X++BRUcYctaUG+NBN58+L0TFYK0s7W/+c1v4tixY1CpVHjqqafQ\n0NAQObZv3z5897vfhUajwc0334yHHnooIxdLoye+9WIu7nd1OH3wSFT6Cq+dlhn1mDXVjPcKpGQn\nANRPLcPJ831J32cQNFg2N3YrVT783okKQVrB+eDBgzh37hy2bduGM2fO4KmnnsK2bdsix5977jm8\n9NJLmDBhArZs2YK1a9eitpZP3vko3HpxrMgVwwivkYptnwoB+P6ONiw/1Y31N08vmOCsVgGf/cQN\n+MqL+xEQ2dmlF9T40j2NEDRqWM3FkoF3rH/vRIUureD8/vvvo7m5GQAwY8YM9PX1weFwwGg04sKF\nCygrK8PEiRMBACtXrsT777/P4EwpUVIMI7xGGr0vN1p3vwdv7j2LQx8WRmAGhjpFVZmLcUvjZLxz\npCPh+MfmTsT0iWVjcGVElIq0gnNXVxdmz54d+dpiscBms8FoNMJms8FiscQcu3DhQtJzms3F0Gqz\nN31mtZqSv2kcy7f7f/GND0SLYRQXCXhg/dzI6w/f1YjiIgHvf3AJtl636Lku9zizfr1jTa0GplWX\n4juPrIAgaPHIpgUoKdZj//HLsPW6YC0vwpI5E3HfutnQaAor1STf/u5nUiHfO5Df95+RCmGh0MhL\nO9jt2fsP1Go1wWYbyNr5c12u33/81LXT48NbB86Jvve9Y5fwicVTYqZjP7F4CqZXG/Fv29tG65Jz\nyv23z0JDbSVMxQL6+lyR19cvn4Z7b78BZ/7aHfnZ9vRkvq54Lsv1v/vZVMj3DuTH/cs9PKQVnKuq\nqtDV1RX5urOzE1arVfTY1atXUVVVlc63oXFOaura4fZJtnQMF8OoMhcnfL4QVZQa0HTDBMm1Y4Og\n5doxUR5Ka35r+fLl2LlzJwDgww8/RFVVFYxGIwCgpqYGDocDFy9ehN/vx549e7B8+fLMXTGNG+E6\nzt39HoQwPHXderJT8jPlJj28vgA8vkDC5wsR9x8TjU9pjZwXLFiA2bNn4+6774ZKpcLTTz+N3/zm\nNzCZTFizZg2eeeYZPPbYYwCA22+/Hddff31GL5ryn1wdZ49POtQ6nD48/fIhmE0CnJ7Cq/WsUSOS\nhW0Q1AiGQggEg+wYRTTOpL3m/Pjjj8d8PWvWrMifFy1aFLO1igqT3DaodGpjA4DXPxSZegbkWyGO\nFwZBA68vALPJgCK9JqYXtdsbxO4jHVCrVOwYRTTOsGUkZZySbVBye5QNgkZyzXm8U6uG9mhbTAY0\n1lVi/Yrp6HN4sPPQefz52GXRz7S2d2HDyhmc3iYaRxicKeOU9ASW26NcWW5AV687EqD1WjU8/uz3\nSv7nuxrQM+DBG386g95BZf2QM23l/ElYu3hqzGzDG3vP4k9HxQMzEJskR0TjA4MzZZTcWvKREzas\nWzYNpmIBACJlI1vbu2AfcMNsMqDYoMWFTkfsOf1BGAQ13N7EAG0QNCjSa2Af4TS3SgW88t8n0DPg\nRaa3AWvUgFajhscn/4BhEDTYcEstivXD/yyV9K1mxyii8YfBmTJKtieww4OnXz6IpllVkSnu6DrO\nRXotvv7KIYkzq0RftZYXYVJlMQ58JJ3hrUQoNLyOLVb2MhUq1dD5wl2xnti8AIJWjadfPoheh/RD\nhNcXgMPpjQnOStbmmbFNNP4wxZMyKlmf5V6HF7sOX8R/vH0y8lq4jrPL45cMRF5fANWWooTXL3Q6\ncPRUl8gnxk64Jk8wBNh63fiXn+7H7/b9FQvrpVtcAuIjYLmfp1oFrGqcxI5RROMQgzNlVLI+y2F/\nbL2MV986iUBweJhqLNZBL9ECUqdV42qPS/RYsuniseb2BrHr8EWEADQ31Ui2uWyYYUGfwwOPbzgZ\nTu7nubJxMu5dO4vbqIjGIU5rU0Z5fAGsapyMQDCElnYb+mSmcfe0dECjHt4G9Js/nZXM0s71AKzE\nsVPdeO6Bm7B+xfXY+vYpnDhnR6/Dg3KjHiVFOrSd6cYfWy8lZLeLrc031lVyxEw0jjE4U0ZEb5/q\n7vdA0Krg9Sev2xXeBgQA+z6QzkiWky9br6Kzqv/+kzdG9oHvPHQBe1qGO0jFZ7ezxzJR4eF8GGVE\ndClNAIoCMzAcsGx2p2g2thJL50xAjbUkrc+Opvg1Zb1OgzKjHm2nxdfMW9u7Eqa4q2R6MBPR+MHg\nTCOmZLuPlEjAUolnYyvh8wfh8ozNvuRUiGVVy2a3X3twIaLCw+BMafH4Aui0OyNTs+l2hQoHLGt5\nETRq8QCdLGy/13ZFtNKYEjoNUG4U0vqsHL1ODUE7fOUGQYPQtTrY0eSysbl/mahwcc2ZUiJWmrOh\nthJmk6Co3nV8ecropCadVoWAN3E6XC+oMb/Wiv0fXRU950g6UvkCkN17nK4qc3FMMRW3N4B3jnRA\nFVcHW6tRodigE3244P5losLF4EwpESvNuaelA1OqjIqC88caqnHTDdWoqTJGKoUBQ9O7UmvObm8Q\nK+dPwoGPruZ0a0iVauiho6G2AsdOiU/zRyfA9Tk82HnwfEJFNACYUmVkNjZRAWNwJsXk1padbh9W\nNU7C+x9eFc2cVquBSRUl+PAvduw9diVhu5CxWCebdf3j3x5XFJjD1blGm8WkxxfvmoeyEgEXOx0x\n2dfRevrd+OXOkzhx3o6efo/kUrvT7Yc/EMp4KVEiyg8MzqSYfPKSB2sXT8WGW2rx2tvtQ8FnwIOy\nEgGzppZDr9fi3dZLkffHbxd6Y+9fZLdD9Q36FF3jWARmAJhfV4k/HbsUme5Xq4YqhMXT6dR47/iV\nyNdS18tmFkSFjcGZFJNr8xhOXtLrNLj/2h5eW68LCIVQZtRL1sxube/CumXT0HJyZLWxR5v62gjd\nUjq0dh4KhWKm+6WCrldhMRUmgxEVNgZnUkyuzWN08lIgGMSv3z0TGUWWGQXJpCv7gBsXOx2K1qtz\nycrGyVg1fxKgUqGsRJB8+JAaQSfDZDCiwsbgTEmFt0uVGfWKSknGJ43JZUObTQZUmYvSDmIjoVYD\nQZmBrFoFlJYMPViEr89i0mN+XSVUAL6/oy3pw0cIQGmxgH6n/MNH/EicyWBEhY3BmSSJbZsKJ3FJ\nlZJMtSBJY10lXB7/qAdmYCgw11SVoMM2KDoNfcuCydh4S22knaXL40eZUY9fv3tG8cOHqViH/sHk\nswIrGydj7aIpLM1JRAAYnEmG2Lap6CQusWSlPodHtiCI2ahH36AHZpMBc6ab4XT78b3tbZm/eIVc\n7gD+n4eW4Ve7T+PDv3RjwBWA2ShgYVTP6fB9moqFlB8++gd9srMCFpMeC+qHs9aJiAAGZ5IgF4TC\ne3XjR3iBYBA7D12QDEYVpQZ87TNNcLh82HXkIt4/fjntetqZYh9ww+UJwFgsQNBpoXIFoJaoVAbI\nZ6wDQw8f9riSm1KBefmcamxZW8+RMhEl4KM6iVJa8zm6jOe23aexp6VDMhg11lXCVCxgT2sH9rR0\njHlgBoCyEj12HjofadoRwvAMwbbdp2PeGwgGsfPgecm9yRWlBjx17wLJcqBq1VAp0opSA5qbavCZ\n22cxMBORKI6cCUBs0le4W5LctiljsYCtu9qHM7JLBLi80s0nJltLsGl17YiaZGSD3eHBn4+Jt6qM\nnyHYtvs09kTt1Y7XWFeJQDAk2cM6BODxu+dj+uQyBmUiksXgXODkkr6ktk0V6TX4j7dO4v0Ph2td\n9yZJehp0+eDxBbD17VNpN6nIFqmRfnQhELmHCrVqKKFr0+pa+AMhyYcai8nAwExEijA457H40W46\n5JK+Nq2uxcnzvQm1ny/aBnHRNpjS9+lzeLH17VPYF1UdK9eVG/WRQiBy0/yhELB20RRo1Gpo1FC0\nF5yISA6Dcx6SG+2mkvGbLOlr3bJpcLqVlc1Mxlyqx4lzPSM6h0EYCmweXwAqZH9fdEmRLhJM5ab5\nLaWx1byU7AUnIpLD4JyHkm1xUipZ0tfFTkfafZrjOZw+eP3KE8AMggZeXwDma12emhfWwFJqiFz3\nHw6cw5+Oiq8VZ4rTPTQVr9dpFFdHAwCNWo3NzXWSe8GJiJJJKzj7/X585Stfwfnz5xEIBPDEE0+g\nqakp5j2zZ8/GggULIl+/8sor0Gj4H9RIpbPFSUqypK+aKqPkcTla9VBWcnQZaaWBWaUCbmmcjA0r\nZ8Dh9IoGtipzMdYumpr14Gwf8MQ0n0h1RKzXadi4gojSklZw/u1vf4uioiK89tprOHXqFL785S9j\nx44dMe8xGo149dVXM3KRNEzJFielASHZaFDQaVA/1ZzyOnEKA+QEEyxFuPfj9QCAYr30X09LqQEG\nQZ32dixBq4I/EJKdGo9vPpFsRJyJHAAiIiDN4HzHHXfgk5/8JADAYrGgt7c3oxdF0pR0hkqF2Ghw\n/swKBEMhfPXF/ejp90TWeuVaOmaKvd+DAac3UipTPshJFwvR69TwyHSA8vmTL1hLJXDFj4gzlQNA\nRBSWVnDW6XSRP//85z+PBOpoXq8Xjz32GDo6OrB27Vp89rOfTf8qKSKVtU8lxEaDv373DN6JOn84\nKC+bUw29To22Mz2wD7gh6DIftD2+IJ5+6SD6Br2yQa7P4YFH4vuqVEBTfVVM3+R4ZpMeKhVEH3LU\nKmDl/EmKE7gylQNARBSWNDhv374d27dvj3ntkUcewYoVK/Af//Ef+PDDD/HjH/844XNPPPEE7rjj\nDqhUKmzZsgVNTU2YO3eu5Pcxm4uh1WZvKtBqNWXt3KPt4bsaUVwkYP/xy+jqdaGyvAhL5kzEfetm\nQ6MRH6kpuf8aAG6vH21nukWPn+7oww+fWA0AuNLtRCAYwH/vO4cDx6+g15G5vcvhPdPhIFdcJOCB\n9bF/d0xlRbCai9BpdyV83lpehEfubkTFzpN4++A5uDyJQfxj8ycDAN7cezbh2G1Lp+HzG+Ypula5\nn1fbmW58bkMRDMLY5l2Op7/76Sjk+y/kewfy+/6T/q+xceNGbNy4MeH17du3Y/fu3fj3f//3mJF0\n2D333BP585IlS9De3i4bnO12p9JrTpnVaoLNNpC18482jy+AZTdW4dbGSTHTvz094nuPU7n/TrsT\nNpGABwBdvS60n+3CntYOtLbb0i4mkmp7yPeOXcInFk9JmBWYO92Cd450JLx/7nQLnA4P1i+fhs1r\n6/H/vt6KE+ftsA94Iklc65ZOBQA4Xd6EBK9PfWxaxn5eZ/7aPaZJYePt736qCvn+C/negfy4f7mH\nh7Qe6S9cuIDXX38dv/zlL6HXJ65xnj17Fj/84Q/xwgsvIBAIoKWlBbfddls634qiiK1tNsyoQHPT\nFFhKDRlJQkq2pr3r8AXZEpZKTLYaEwqbyOnpF090k4rv/mAQnXYnyox6WIsE3P/JGyWTtUa65SnT\nOQBERECawXn79u3o7e3Fgw8+GHntpZdewiuvvIJFixahsbER1dXVuPPOO6FWq7F69Wo0NDRk7KIL\nldja5p7WS9jTegkVGUpC0us0aKitxJ4WkRHpDEtMyc50LJ9TjXtvq8P2PWfw3gdXIuvVekENny8o\nOqLWC5qEIOfxBXBUYkvZ3qOX8afWy7CU6rF83mSsWzpVdlvTSLY8ZToHgIgIAFShkFib+dGXzemH\nfJjeAOS34nh8AXz1xf1Jp5Kbm2oSkpCU3n94ZN5yshM9A97I9HNFqR6zpprhDwZx4KPOpOdRqYZK\nWsazmPR4/sElkXvz+AKw9boQCASx52iH5L5lg6DB9x75WORzgWAQL/3+I+xXcC2A+M8kk4ZnNBL3\nP491tna+/N3PlkK+/0K+dyA/7j/j09qUWUq24iTrIxyWaiGSaPEj8/Ao1uHy4r3jV2Q2LsWqNhfj\nck9iDsGCemvMdel1GtRYjdi6q122oIj32kNLlbkYgWAQX3/lcErT4iP5mSjBimBElGnchJkDwkFR\nrp9weG0zmehey6mQqzzm8Q1F6WRTLAZBA4OgwZUeZ+TP4f7FqxZMxqrGyfD4YjOnlbSQjF673fp2\ne0qBGUj/Z5Kq8PQ4AzMRjRSD8xhLVo4zHMzCa5vJpJuEpHRkLkavU+OmG6vg9gbg9gYQAiJ/Xjqn\nGg0zLGg73YWvvngAX31xP7buakcgGFT8fcNrtx5fAK2nulK+PiZmEVG+4bT2GEulHGd0Na/ufrfo\nZ9JNQpLLOk5m2Zxqyb2+Le22mCIl8QU65L5vdJ9kYOhn1euQ7xstholZRJRvOHIeY3LT1VK1nZ97\n4CY8/8BNWLVgMipKDVCrhqaOm5tq0m5LqHRkDgwFTVXU92xumiL5gCFVPSw8KyD3fVfOn4R7P14f\nWXcv0mtRbhQUXWNYkXhq+nIAABFXSURBVF6D9Sump/QZIqKxxpHzGEtnK45ep8HEihLc+/F6eFZl\nrtlCfJ1tQacRDa4r50/C2sVTI9/T4wukPOruiZoVSNbtKTphLtWRs8cbgMPplW2iQUSUa/g/Vg5I\ntRVhtEy2JYzOOu7pd+Otw+dx4MOrkc5PBkGDJbOr0Nw0JeFhYNZUs2wt63gqADsPnsfmNXVJs53j\ns8hTUVlexPVmIso7DM45YKRbcTLdqlCv02BPawfebY3d3uT2BrD/w068e63Ax/yZlQgBOHaqC939\nHhgENQAVvL6A5Kg7LBgC9rRegkajjuxBFnvQkEuY0+vUKDFo0evwSn6/JXMmcr2ZiPIOg3MOSXUU\nnK1WhXIBMRwAu/s9CXWtwyPsJbMnoP28XVG3qmR7kOUS5nz+IL5413wIWjWMxQLe2Hs2YfbhvnWz\nJWuOExHlKgbnPJatVoUj2VYFACfP9cKucG04PiM9nnztaj2s5UWRwC42+yDVpYuIKJfxf648pXR/\ntNznO+1O0fcpLXgixZ5CwY9ke5D1Og2KDYldzwCg2KBLGHGzEAgRjQccOecpudFtd78bPf1uTKwo\nSTgWPxVebtRjfl0lNjfPjEyFy2WQK5FKS8hke5A9vgAGXeKj8EGXL7Idi4hoPOHIOU8lG93uOnxB\n9PX4UqF2hwd7Wjrw9VcOR6p2AUMZ5M1NNZF91AZBeQBUEpjVKmDVgslJM9L7HB7YB8SDc6/DMypl\nOYmIRhtHzjkqWQa2XGtHAGg705MwqpSbCr/Q6cDWt9tx79pZAGIzyMOdo/7Udhltp7sjCVfzZ1Zc\ny9Yefq2htgLHTtnQIxFQw0IhYO2iKUkT19gvmYgKEYNzjkklA7t5YY1kcBZLtEqW6PXe8SvYcEtt\npGBHIBjEr989E3MtDTMq0Nw0BZZSQyTwb7wl9kFCo1YlnRK3lIoH1vBDSZFeC5fHjzKjnv2Siajg\nMDjnmFQysC2lBlSkMKosM+pRbtRLJmx5fUG89nY77v/kjZLXEr83GUjcApZODXC5XtLzZ1Zi9cLJ\nMSN0pUVaiIjyEYNzDpGbdj58ohPrlk2DqXi4tnSqpT/1Og3m10lPhQPAifP2SAa3XDa43N7k+Epj\nuw5fQNuZHtnAKtVLOryfurmpBs89cBP7JRNRQWBwziFy0869Di+eefkQFs6KneJOtfTn5uaZOPFX\nOy73OEWP2weGk6yUdsuSEqkBvnaW7Bq6kp7O4QeCTJUqJSLKZQzOOSRZ20a7I3GKO9XSnxq1Gl/5\nu4V47P++B48vmHA8PB0eCIagF9SRql9i70mlbKhc9TMlRU+UPhAQEY0HDM45ROn+4tb2LqxbNi2S\nMKXXaVIq/Vms12HFvEmS0+EAsPXtdtHADADzZlYkJIqNpGyokl7SzMwmokLC4JxjwtPRh090SrZH\n7O534+mXD6LP4U07MIpNh8+bWYFQKISvvrhfMlAaBA2CwRB2tw6vW4+0bKiShxJmZhNRIWFwzjHh\naep1y6bhmZcPSWZWhwN3KoExfho6fjr81++eSTpq93gDOHaqW/RYskQxOeGHhZaTNvQMeGKytcMP\nH0REhYLBOUeZigUsnKW8hGZ8YIwOxIFAEFt3tYtOQ4enw5UkZQFAmVFAr8QDw0jWhePXzqP3OXPE\nTESFhsE5h8VPPZeVSO9RDgfGijJDQhGTMqMeZy/1R94rNtpW2omqcWYl2s50Z61iV/TaefS2MSKi\nQsLgnMPERpNff+WQbGAUKxwitX4cPdpOlpRlMemxoP7a2rbmNCt2ERFlEYNzHogeTcoVHQGkC4eI\niZ6GlkvKWj6nGlvW1kcCb6p7q4mIKDUMznkkEAzCHwhCr1XD4x/a5mQQNFg2txqbVteiu8+taGo6\nLH4aWi7oRmeCp7q3moiIUsPgnMOik7q0GhW+/sphXOh0xLzH7Q3A6fLjctegov3C0eKnoVMNuqns\nrSYiIuXSCs6/+c1v8P3vfx9Tp04FACxbtgyf//znY97z5ptv4uc//znUajXuuusubNy4ceRXWyDE\nOlMZ9Fp02AZF37//o6vY/9FVGAQ1KsuKACQG5ylVRjjdfkXT0Ay6RERjK+2R8+23344nn3xS9JjT\n6cQPf/hD7NixAzqdDnfeeSfWrFmD8vLytC+0kIgldYkF3HhubxAXbYMJgXj5vElYt3Qq/IEQp6GJ\niPJAVqa1jx07hrlz58JkMgEAFixYgJaWFqxevTob3y4vKK1D7fEF0HKyc0Tfa9Dlw9OfXRTZJ1wz\nqRw22wA0anBETESUB9IOzgcPHsT9998Pv9+PJ598EjfeeGPkWFdXFywWS+Rri8UCm00+i9hsLoZW\nm73RnNVqysp53V4/7P0emEv1MAiJP85AIIiXf/ch9h+/DFuvC9byIiyZMxH3rZsNjUad8N4f/Ooo\negbEy3YqZR/woKjEgOnXlURey9b954NCvneA91/I91/I9w7k9/0nDc7bt2/H9u3bY177m7/5Gzzy\nyCO45ZZb0NraiieffBK/+93vJM8RCoWSXojdLt7CMBOsVhNstoGMnlNsXVisxvXWXe0xU9Sddhfe\n3HsWTpc3odzm1l3t2K2wIpgcs0mPgNcXueds3H++KOR7B3j/hXz/hXzvQH7cv9zDQ9LgvHHjRtlk\nrsbGRvT09CAQCECjGRr5VlVVoaurK/Kezs5OzJ8/P5Vrznli68LxVbfkSmKKldtMtkc5vJbc0++G\nTqeGV6TlIwAsqLdyTZmIKI+l3t8PwIsvvojf//73AID29nZYLJZIYAaAefPm4YMPPkB/fz8GBwfR\n0tKCpqamzFxxDkgWdD2+AAD5kpjhAiBhycpnLptTja99pgnPPXATvvW5Jfjuwx/DrQsnwyAM/9wN\nggarF05mMRAiojyX1przunXr8KUvfQmvv/46/H4/nn/+eQDAT37yEyxatAiNjY147LHHcP/990Ol\nUuGhhx6KJIeNB0qCbpW5WHbfcXwBELn3VpTqce/aemjU6pikrv+1ph533lILW68LCIVgvVbpi4iI\n8ltawbm6uhqvvvpqwusPPvhg5M+33XYbbrvttvSvLIcpDbpyJTHjC4DIv1d6mlqv06DGakz3VoiI\nKAexQlgaUgm6qdShZs1qIiICGJzTpjSQplISkzWriYgIYHBOWzbrULN8JhFRYUsrW5uGhQNpvoxw\nPb4AOu3OSEY5ERHlHo6cC4TSoilERDT2GJwLhJKiKURElBvG7ZCJ07fD3F6/oqIpRESUG8bdyFls\n+nb5vMlYt3Rqzk/fKu1clSp7v7KiKURElBvGXXAWm76VajSRK7K9HmwuVV6pjIiIxl5uDyVTpLTm\nda4JP1B093sQwvB68LbdpzNyfoOgRWOdVfRYfNEUIiIae+MqOKfSaCJXjNYDxabVtWhuqkFFqQFq\nFVBRakBzUw2rjxER5aBxNa2dSqOJXKG0icZIsfoYEVH+GFcj53DNazG5On0bfqAQk40HinwrmkJE\nVIjGVXAGxKdv71gxPWenb/PxgYKIiLJrXE1rA+LTtzWTymGzDYz1pUliNyoiIoo27oJzWD41j8jm\nerDHF8DlrkEEfAGOwomI8sS4Dc75KJMPFDF7pwc8sJhYS5uIKF8wOI9TrKVNRJS/OIQah/K1GAsR\nEQ1hcB6H8rEYCxERDWNwHodGe+80ERFlFoPzOMS900RE+Y0JYRmWrbaPqeLeaSKi/MXgnCHZbvuY\nqui90xpBh4DXxxEzEVGe4LR2hmS77WO69DoNJlaWMDATEeURBucM4NYlIiLKJAbnDODWJSIiyqS0\n1px/9KMfYd++fQCAYDCIrq4u7Ny5M3L84sWLWLduHebMmQMAMJvN+MEPfpCBy81N+dhHmoiIclda\nwfnzn/88Pv/5zwMA/vM//xPd3d0J77n++uvx6quvjuzq8kR461J0ucwwbl0iIqJUjShb2+/347XX\nXsMvfvGLTF1P3uLWJSIiypQRBee33noLH/vYx2AwGBKOdXV14Qtf+AI6OzuxefNm3HHHHSP5Vjkv\nm20fiYiosKhCoVBI7g3bt2/H9u3bY1575JFHsGLFCtx///149tlnUVNTE3Pc4XBg586duOOOOzAw\nMICNGzfitddeQ1VVleT38fsD0GoZzIiIiJIGZylOpxMbN27Ef/3XfyV976OPPop77rkHS5YskXyP\nzTaQzmUoYrWasnr+XFfI91/I9w7w/gv5/gv53oH8uH+r1SR5LO2tVCdOnMD06dNFj+3fvx/f+ta3\nAAwF8RMnTuD6669P91sREREVlLSDs81mg8ViiXnt+eefx4ULF9DU1IS+vj5s2rQJn/70p/Hggw9i\nwoQJI75YIiKiQpD2tHamcVo7ewr5/gv53gHefyHffyHfO5Af95+VaW0iIiLKDgZnIiKiHMPgTERE\nlGMYnImIiHJMziSEERER0RCOnImIiHIMgzMREVGOYXAmIiLKMQzOREREOYbBmYiIKMcwOBMREeWY\nggjO3d3d+Pu//3vce++9uPvuu3Hs2LGxvqRR4/f78eSTT+Kee+7BXXfdhcOHD4/1JY26gwcPYunS\npdizZ89YX8qo+uY3v4lNmzbh7rvvRltb21hfzqhrb29Hc3MzfvnLX471pYy6b3/729i0aRM2bNiA\nt956a6wvZ1S5XC48+uij2LJlCzZu3Ji3/+61Y30Bo+HNN9/E3/7t32LdunU4ePAgvv/97+Pll18e\n68saFb/97W9RVFSE1157DadOncKXv/xl7NixY6wva9ScP38eP/vZz7BgwYKxvpRRdfDgQZw7dw7b\ntm3DmTNn8NRTT2Hbtm1jfVmjxul04hvf+AaWLl061pcy6vbv349Tp05h27ZtsNvt+NSnPoWPf/zj\nY31Zo2bPnj3/f3v3D5JaFIAB/BNvRtHfK9ewLVqKIlqaoqJoimgTWguChhqL4g7NRrQooZiDQ2Bo\nBEFDEVE0BOGoREtLiFEXScqSQHhDcHnCe5EP3j3q+X7TuWf6DlzOxz2IB/39/VhYWEA6ncb8/DzG\nx8dFxyqbFOU8NzdnjjOZjFTXV87MzGB6ehoAoKoqXl5eBCeylqZp8Pv90HVddBRLXV9fY3JyEgDQ\n3d2NXC6Ht7c3NDU1CU5mDYfDgVAohFAoJDqK5YaGhjAwMAAAaGlpwcfHB4rFIux2u+Bk1piamjLH\n1bzfS1HOwNf904uLi8jn84hEIqLjWKaurs4cRyIRs6hl0dDQIDqCEIZhoK+vz3xWVRXPz8/SlLOi\nKFAUaba3Ena7HY2NjQCAeDyO0dFRaYr5d7Ozs3h8fEQgEBAd5Z/U3Nsbi8UQi8VK5paXlzEyMoKD\ngwNcXl5ifX29Jo+1v1v73t4eUqlU1b6oP/Hd+mXHf+mVz9nZGeLxeE3udT8RjUZxe3uLlZUVHB0d\nwWaziY5UlporZ4/HA4/HUzJ3c3ODXC6H1tZWjI2NYXV1VVC6/+tPawe+Suv8/Bw7OzslX9K15m/r\nl5HL5YJhGObz09MTNE0TmIisdHV1hUAggN3dXTQ3N4uOY6lkMgmn0wm3243e3l4Ui0Vks1k4nU7R\n0coixa+1T09PcXh4CAC4u7uD2+0WnMg6Dw8PiEaj8Pv9qK+vFx2HLDI8PIyTkxMAQCqVgsvlkuZI\nW3avr6/Y3NxEMBhEW1ub6DiWSyQS5mmBYRh4f39He3u74FTlk+JWqmw2i7W1NeTzeXx+fkLXdQwO\nDoqOZYnt7W0cHx+js7PTnAuHw3A4HAJTWefi4gLhcBj39/dQVRWapklzzLe1tYVEIgGbzYaNjQ30\n9PSIjmSZZDIJr9eLdDoNRVHQ0dEBn88nRVnt7+/D5/Ohq6vLnPN6vSV7QC0rFArQdR2ZTAaFQgFL\nS0uYmJgQHatsUpQzERFRNZHiWJuIiKiasJyJiIgqDMuZiIiowrCciYiIKgzLmYiIqMKwnImIiCoM\ny5mIiKjCsJyJiIgqzC8iivHPF8qqogAAAABJRU5ErkJggg==\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f7a18dfb8d0\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "# Plot the Data (Optional)\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "plt.scatter(inputs, labels)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "JaFHyAG9nDET" - }, - "source": [ - "## Step 2: Define our TensorFlow variables\n", - "\n", - "We'll use Keras's object-oriented [`Dense`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense) layer to create our variables. In this case, we'll create a `Dense` layer with a single weight and bias." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 332, - "status": "ok", - "timestamp": 1525154229931, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "z9r-ZeyrXu3A", - "outputId": "e19a698e-5892-4fcd-80d3-1394605ee72c" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 48, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# Create TensorFlow Variables using Keras's Dense layer.\n", - "\n", - "wb = tf.keras.layers.Dense(units=1, use_bias=True)\n", - "\n", - "# We can access the underlying TensorFlow variables using wb.variables.\n", - "# However, the variables won't exist until the dimensions of the input\n", - "# tensors are known. Once the dimensions of the input tensors are known,\n", - "# Keras can create and initialize the variables. Until then, Keras will\n", - "# report the variables as an empty list: [].\n", - "\n", - "wb.variables" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "docKLUaonYG_" - }, - "source": [ - "## Step 3: *Define the loss function*\n", - "\n", - "Our loss function is the standard L2 loss (where we reduce the loss to its mean across its inputs)." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "0_w8ZJSCtuY7" - }, - "outputs": [], - "source": [ - "def loss_fn(predictions, labels):\n", - " \"\"\"Calculates the mean L2 loss for our linear model.\"\"\"\n", - " return tf.reduce_mean(tf.square(predictions - labels))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 348, - "status": "ok", - "timestamp": 1525154234538, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "RkNbXoXkpjVH", - "outputId": "e4688f3c-e29f-416d-f541-6d81953b5660" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\u003ctf.Tensor: id=1252, shape=(), dtype=float32, numpy=16.979801\u003e" - ] - }, - "execution_count": 50, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# Test loss function (optional).\n", - "\n", - "loss_fn(wb(inputs), labels)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 51 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 418, - "status": "ok", - "timestamp": 1525154260083, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "K_7beXoHOU7t", - "outputId": "8f55c028-fe2b-4edb-ad68-a849afc60623" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "w: -0.311619\n", - "b: 0.000000\n" - ] - } - ], - "source": [ - "# At this point, the variables exist, and can now be queried:\n", - "\n", - "w, b = wb.variables\n", - "print(\"w: %f\" % w.numpy())\n", - "print(\"b: %f\" % b.numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "JVDWpL9VYWdP" - }, - "source": [ - "## Step 4: Create an optimizer\n", - "\n", - "We'll use a `GradientDescentOptimizer` to fit our model." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "DudNEebMKDWN" - }, - "outputs": [], - "source": [ - "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "YBeJYxY8YaiO" - }, - "source": [ - "### Step 5: Define a training step\n", - "\n", - "To fit model variables to the data we'll need to:\n", - "\n", - "1. Calculate the gradients of the loss with respect to the model variables.\n", - "2. Use `optimizer` to compute updates to the variable values based on those gradients.\n", - "\n", - "To calculate the gradients, we use the [`tf.GradientTape`](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context manager\n", - "and its `gradient` function to compute gradients through computation conducted within its context:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "diDZfrMJM3OC" - }, - "outputs": [], - "source": [ - "def run_step(inputs, labels):\n", - " with tf.GradientTape() as g:\n", - " loss = loss_fn(wb(inputs), labels)\n", - " # Compute the partial derivatives of loss with respect to the variables\n", - " grads = g.gradient(loss, wb.variables)\n", - " optimizer.apply_gradients(zip(grads, wb.variables))\n", - " return loss" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1WWepgmJQOzc" - }, - "source": [ - "Repeatedly running the training step will nudge the variables towards the values that best fit the data (i.e., \"w\" will move closer to 3.0, while \"b\" will tend to 2.0):\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 51 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 380, - "status": "ok", - "timestamp": 1525154412590, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "ya5Qxz5XQlhU", - "outputId": "8dd47155-a6c1-44c5-c279-617c803f1723" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Values of w, b BEFORE applying gradients: 2.725763, 1.894334\n", - "Values of w, b AFTER applying gradients: 2.774932, 1.922555\n" - ] - } - ], - "source": [ - "w, b = wb.variables\n", - "print(\"Values of w, b BEFORE applying gradients: %f, %f\" % (w.numpy(), b.numpy()))\n", - "run_step(inputs, labels)\n", - "print(\"Values of w, b AFTER applying gradients: %f, %f\" % (w.numpy(), b.numpy()))\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "61TgeLVlKEQp" - }, - "source": [ - "## Step 6: Create a training loop\n", - "\n", - "Of course, now we can simply turn all of this code into a self-standing training loop. We'll also capture our loss and approximations of `w` and `b` and plot them over time." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 364 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 580, - "status": "ok", - "timestamp": 1525154278709, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "VukGe-huNaJ4", - "outputId": "c79c8e63-c781-451e-f74f-20815d8da49f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.9409681558609009, 1.3733772039413452, 1.7128530740737915, 1.9793939590454102, 2.188689708709717, 2.3530514240264893, 2.4821391105651855, 2.583533763885498, 2.6631851196289062, 2.7257626056671143]\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAFKCAYAAABcq1WoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xd4U2X/BvD7ZLRpumlLS6EDgbKh\niIggU7aAgPhDRKsIUoYgiK++ioAguBARXmZBEARFUBGhiChIEQcqe+/RMlpGd9KRcX5/nDZtaFra\nkuY07f25rlw5zXmSfPMk5OY5Oec8giiKIoiIiMhhFHIXQEREVN0wfImIiByM4UtERORgDF8iIiIH\nY/gSERE5GMOXiIjIwVSOeJJbtzLs/pi+vlqkpOjt/rhkjf3sGOxnx2A/Owb7WRIQ4FnsOqcd+apU\nSrlLqBbYz47BfnYM9rNjsJ/vzWnDl4iIyFkxfImIiByM4UtERORgDF8iIiIHY/gSERE5GMOXiIjI\nwRi+REREDsbwJSIih/vxx61YtGi+3GXIhuFLRETkYA45vSQREZEtGzeux65dPwMAOnbsjOeeG45/\n/tmHFSuWwNVVA1/fGnjnndk4eHB/kdtUKueNMKesPDZ2C7p37wSNxkfuUoiInN6MGVOxdetmuz2e\nQiGgb98BmDFjdontbty4hgMH/sGKFV8AAKKjX0DXrt3x3XcbMH78q2jZshX27PkVaWmpNm/z8/O3\nW82O5nSbnTMy0jFixHN48cUX5S6FiIjuw9mzZ9G0aXOoVCqoVCo0b94S58+fRdeu3fHxxx/giy9W\noUGDhvDz87d5mzNzupGvp6cXOnTohF27duHkyRNo0qSp3CURETm1GTNm33OUWhYBAZ6lms1OEABR\nFC1/GwwGCIICvXv3Rdu27fDbb3H4739fxezZc2zeFhYWbreaHc3pRr4AEB09DgCwYsVSmSshIqLy\niohoiOPHj8FoNMJoNOLkyROIiGiI1as/g1KpwoABT6Jbt564fPmizducmdONfAGgR49eqFevHr79\ndgPefnsG/P2de/MDEVF1FBQUjFatHsKECdEwm0X07z8AQUG1EBgYhEmTxsHT0wuenp4YOvQ56PX6\nIrc5M0EsPOavIKXZ/FBW69d/jokTJ+LNN6di8uQ37P74JCnt5iO6P+xnx2A/Owb7WRIQ4FnsOqfc\n7AwAL774Ijw9vbBq1Qrk5ubKXQ4REVGpOW34enp6YtiwKNy8mYQfftgkdzlERESl5rThCwAvvTQa\nCoUCMTFL4ICt50RERHbh1OEbFhaO3r374ujRw/j7731yl0NERFQqTh2+ADB6tHTY0fLlS2SuhIiI\nqHScPnwfeaQ9mjdviR9/3Ir4+Ctyl0NERHRPTh++giAgOnoszGYzVq5cLnc5REQkk/Pnz1kGYe+8\n8xZycrLL/ViHDx9ESkqyvUorwunDFwAGDhyMgICa+PLLL5CZmSl3OUREJIM9e35FQkI8AGDmzA/g\n6qop92Nt27alQsPXKc9wdTdXV1e8+OJLmDPnfWzY8BVGjoyWuyQiIrqHYcMGY+3ajRBFEX36PIaF\nC5ehUaMmmDx5PN54420EBdWCyWTCnDnv4fr1azAajXjppTFo3boNtm+PxaZNG6FSqVG/fgQGDhyM\nH37YhD17foWvry+mT38LX3yxAZ9+Oge+vr44c+Y0UlNT8OyzL2Dbtq1IS0vFokXLIQjAzJlTkZWV\nhezsbLz66uvQ6TKxd28cLl26iNmz5+DMmZP4+ut1UCpVaNiwMSZMePW+X3uVCF8AeOGFkZg/fy5W\nrFiKF198CQpFlRjUExFVOPcZU+FqxykFoRDg3ncAdPeYrKFhw8a4ePECjEYDGjVqjOPHjyIiohGS\nk5MRFFQLAPDLLz/Bz88fb701HampqZg4cQzWrPkaX3+9DnPmzEdgYBC2bduCOnXqoG3bdujSpRua\nNGlm9TxKpQoLFizFzJlTcezYUSxYsASzZk3DwYP7ER5eF/36DUSnTl1w4MC/+PLLNXjvvY9Rv34E\nJk9+A15eXlizZiWWLfscLi4umDbtTRw9ehgtWkTeVxdVmfANCAjA4MFDsH79Ouza9TN69Ogtd0lE\nRFSCyMgHceLEMeTm5uCpp57Gnj270bLleURENLS0OX78KI4cOYSjRw8DAHJycmAwGNC9ey9MmfI6\nevXqg+7de5W4iblxY2n2Oz8/f8tMSL6+ftDpMlGjhh/WrPkM69evhcFggEZj/TiXLl1EUlIiJk8e\nDwDQ6TKRmJiIFi3u77VXmfAFgFGjxmL9+nWIiVnK8CUiKiXdjNn3HKWWRUCAJ3SlOLdzq1atsW7d\nauTkZKNfvwHYtm0rjh07ggcffMjSRqVS4/nnRxT5To+KehE9evRBXNxOvPLKWCxeXPwOt0ql0uay\nKIrYuPEr+PvXxLRps3D69EksWjTf6r5qtbSped68Rfd8PWVRpbbNNmvWHB06dMJvv+3GqVMn5S6H\niIhKEBoahqSkJGRm6qDVusPPzw9798ZZhW+TJs3w++97AAApKcmIiVkMs9mMmJjF8Pf3x9Chz6FZ\ns+ZITEyEIAgwmUxlqiEtLRW1a9cBAOzZsxtGoxEAoFAoYDKZEBoajsuXL1l2vlq5Mga3bt2879de\nqvA9e/YsunfvjnXr1gEAbty4gaioKAwbNgwTJ06sVBMbcK5fIiLn4evri6CgIABS0N64cQM1awZa\n1j/2WHe4uWkxZswIvPHGq2jRIhIKhQJarTtGj34REyeOhSAIaNAgAi1btsL8+R9j//5/Sv38vXv3\nxYYNX+LVV19G06bNcOfOHWzbtgWRkQ9i6tT/4vr1a5g48TX85z8TMXbsCKSlpcLfP+C+X/c9pxTU\n6/UYPXo0wsPD0bBhQzz33HN466230KlTJ/Tp0wfz5s1DUFAQhg0bVuxjVMTUUsVNWWUymdCu3YO4\nceM6Dh06xbl+7xOnBnMM9rNjsJ8dg/0sua8pBV1cXLBixQrUrFnTctvff/+Nbt26AQC6du2Kv/76\nyw5l2odSqcSoUWOQk5ODtWs/l7scIiKiIu4ZviqVqsjeX1lZWXBxcQEA+Pn54datWxVTXTk988xz\nnOuXiIgqrfve27k0U/n5+mqhUinv2a6sihvSBwR44qWXRuLTTz9FXNxPePbZZ+3+3NVJSZtOyH7Y\nz47BfnYM9nPJyhW+Wq0W2dnZ0Gg0SEpKstokbUtKir5cxZXkXr8pDBv2IhYsWIC5cz9Bjx79IQiC\n3WuoDvjbjWOwnx2D/ewY7GfJff3ma0v79u2xY8cOAMDPP/+Mjh07lq+yChQWFo5evR7H4cOH8M8/\nf8tdDhERkcU9w/f48eOIiorC999/jy+++AJRUVEYP348Nm/ejGHDhiE1NRUDBw50RK1lxrl+iYio\nMrrnoUb24MhDjQoTRRHdunXEyZPH8e+/RxESEmr3Oqo6bj5yDPazY7CfHcPe/RwXtwtdunSz2+M5\nit03OzsLzvVLROTcbty4jp07d8hdht1V6fAFgEGDnoK/fwDWrVvDuX6JiCqRYcMGw2QywWg0okeP\nTjh9Wjot8OTJ45GYeAMAMG/eRzh8+CA+/3wFVq6MwaxZ0zFu3EvYv/8fTJ36huWx+vaVRsaXLl3E\nK6+MwcSJY/HWW68hI6Nybumo8uGbP9dvenoaNmz4Su5yiIgqpRqtm9m8aAptNfQcN8pmG8/o4ZY2\nmrWrgfDwUj1n/pSC586dsUwpaDabraYUfOaZKERGPogXXxwFADAaDViy5LNip42dP/9jvP76FCxY\nsBRt2jyCTZs2lqs/KlqVD19AmutXOlPXUpjNZrnLISIiFEwpeOzYETz11NM4efIELlywnlLwbvnT\nAxbn5MkT+Oij2Rg/Pho7dvxomRChsqlSUwoWp2bNmnjyyf/D119/ybl+iYhsSD5w/J5tMpasuGeb\n7Kjh8Jw8AbDTlIJ3U6vVAFDk3A35sxFpNBosXBhT6c/tUC1GvoA01y8AxMRwtiMiosqgNFMK5k/t\ndzd3d3fcuXMbAHD+/Dno9dLJnOrXb4B9+/4EAOzcuaNMMxw5UrUJ3+bNW+DRRztyrl8iokrkXlMK\nhoXVxZkzp/G//31idb/69SOg0bhhzJgR2LHjRwQFBQMAJk78D9au/Rzjx0fjxx9jS9yELacqfZzv\n3bZv34YXXngGzz33AubNW2j3mqoiHhfpGOxnx2A/Owb7WVJtj/O9W8+evREWFo5vvvkat2/flrsc\nIiKqpqpV+HKuXyIiqgyqVfgC0ly/Hh6enOuXiIhkU+3C19PTC88+G4WkpERs2fK93OUQEVE1VO3C\nFwBGjhwNQRCwfPkSOGB/MyIiIivVMnzDw+uid+++OHz4EP79t3IeA0ZERFVXtQxfgHP9EhHJ6ccf\nt2LRovl2eSydLhP//LMPALB27WocP3603I+VmJiIkyfvfbav+1Vtw7ddu0fRrFkLxMb+gISEeLnL\nISKicjpz5rQlfKOihqNZsxblfqyDB//FqVMn7FVasarFuZ1tyZ/r95VXxmLVqhV4551ZcpdERFSt\n3LhxDf/5zyu4eTMJQ4YMQ79+A6zWf/fdRuzc+RMEQYGOHbvgmWeew9mzp/HJJx9BrVbDxcUFM2d+\ngHnz5kCv1yEkJBTHjx9Fly7dkJaWisOHDyI1NRWXLl1EdPRY7Ny5A5cvX8L06bPRtGkzLFw4DydP\nnkBubi4GDhyMDh06Y9Wq5VCpVAgMDELt2iH49NM5EAQBWq0WU6bMgKdn8SfOKItqG76ANNfvu+9O\nx7p1a/Daa/+Fh4eH3CURETncjBmu2LrVfnGgUAB9+7pixoycEtslJMRj1aovodNlYvjwYejb9wnL\nhAjXr19DXNwuLFmyEgAwduxIdO3aHT/+uBWDBj2F3r374sCBf5GcfAfDhkXh4sULGDDgSatNzgkJ\n8Viy5DNs3boZ69atxqpVX2L79q3YuXMH6tdvgKCgYEyYMBk5OdkYMmQg+vcfiD59+sHHxwcdOnTG\nxIlj8frrUxASEopNm77Bpk0b8cILI+3SR9U6fPPn+v344w+wceN6jBgxSu6SiIiqjRYtIqFSqeDt\n7QN3d3ekpaXBx8cHAHDq1AlcvZqACRNGAwD0eh0SE6+jQ4fOmDv3QyQkxKNbtx4ICwvHiRPHbD5+\no0ZNIAgC/Pz8Ua9eAyiVSvj6+kGnOwJXV1ekp6dhzJgRUKlUSE1NKXL//OkJAcBgMKBx4yZ2e+3V\nOnwBaa7fBQs+wYoVSzF8+MhiJ2gmIqqqZszIuecotSykczuX5vGsp/0rPAugSqVGu3aP4o033i5y\nr88++wJ//rkXs2fPwPjxk4p9dKVSaXNZFEUcOnQABw/ux6JF0mbmHj06Frl/RU5PWO2TJn+u3wsX\nzuPXX3+RuxwiomrjxImjMJlMSElJQVZWFry8vC3rGjZsjIMHDyA7OxuiKGL+/LnIycnGd99tQHp6\nGnr27IOnnx6Gs2dPQxAEm9MOliQtLRU1awZCpVLh99/3wGQyw2AwWE1hWJHTE1b7kS8gzfX79ddf\nIiZmCbp37yV3OURE1UJoaDimTXsT164lIDp6nNUIMygoCEOGPIOXXx4FhUKBTp26wNVVg9q1QzBt\n2pvw8PCAWq3GlCnvIDU1BcuWLURAQM1SP/dDD7XFl1+uwfjx0ejYsTPat++AuXM/QPfuPTF79gz4\n+Phi4sT/YM6c9/Dll2vg4uKKGTNm2+21V6spBUsyaFBf/PHHXvz2299o1Kix3R7X2XFqMMdgPzsG\n+9kx2M8STilYCtHR0kk3VqxYKnMlRERU1TF88xSe6/fOnTtyl0NERFUYwzdP/ly/2dnZnOuXiIgq\nFMO3EM71S0REjsDwLSR/rt/ExBvYunWz3OUQEVEVxfC9S/5cvzExiznXLxERVQiG71041y8RUcUr\nzZSCu3fvdFA1jsfwtYFz/RIRyW/dujVyl1BhGL42cK5fIqKKlz+l4PPPP43Y2B+s1n311Rc4f/4s\npkx5HQcP7scbb0zC+PHROH36FPr27WZpN3XqGzh4cD/0eh2mTn0DEyeOxfjx0Th//pyjX06ZMHxt\nyJ/r12w2Y9WqFXKXQ0RU4Vq3drd5WblSbWkzbpzGZpvoaI2lzdq1aoSHl+45ExLi8eGH87BwYQxW\nroyx2s9m2LDn4eHhgfff/xgAcOHCecybt6jYMxBu3Lgebdu2x4IFS/Haa29i0aJPy94JDsTwLcbA\ngYPh7x+AdevWIDMzU+5yiIiqHFtTChanfv0GcHFxKXb9sWNHsXnzdxg/PhqffPIhdLrK/b3NiRWK\nodFoMHz4SMyd+yHn+iWiKu/AAd092yxZkn3PNlFRBkyerMGtW6V51uKnFLybWq22ebvRaMxbr8Kr\nr76OZs1alOaJZceRbwleeGEkXFxcsGLFUpjNZrnLISKqUkqaUhAAzGbbh3sKgoDs7GxkZ2fj7Nkz\nAIAmTZrht9/iAACXLl3E11+vq9Da7xfDtwSBgYEYNOgpzvVLRFQB8qcUnDRpbJEpBQEgIqIhRo16\nvsj9Bg58CtHRL+D992eiYUPpN+Cnnnoa164lYNy4l/DRR7MRGfmgQ15DeXFKwXs4duwIunXriM6d\nu+Kbb3649x2qGE4N5hjsZ8dgPzsG+1nCKQXvQ/PmLdG+fQfs2bMbp0+fkrscIiKqAhi+pcC5fomI\nyJ4YvqXQq1cfhIZyrl8iIrIPhm8pSHP9jkZ2djbWrVstdzlEROTkGL6lNGxYFDw8PLFy5XIYDAa5\nyyEiIifG8C0lT08vDBv2HOf6JSKi+8bwLQPO9UtERPbA8C2DunUfQK9ej+PQoYPYv59z/RIRUfmU\nK3x1Oh3Gjx+PqKgoDB06FHv37rV3XZVWwVy/POyIiIjKp1zh+/3336Nu3bpYu3YtFixYgPfee8/e\ndVVa7dt3QNOmzREb+wOuXk2QuxwiInJC5QpfX19fpKamAgDS09Ph6+tr16IqM0EQMHr0OJhMJs71\nS0RE5VLuczuPHDkS8fHxSE9PR0xMDCIjI4ttazSaoFIpy11kZZOdnY2wsDDk5ubi6tWrcHd3l7sk\nIiJyIuWaz/eHH35AcHAwVq5cidOnT2PKlCnYtGlTse1TUvTlLrA4cp+4+/nnR2Du3A+xePFyvPji\nS7LVUdHk7ufqgv3sGOxnx2A/S+w+scLBgwfRoUMHAECjRo1w8+ZNmEym8lXnpDjXLxERlVe5wjcs\nLAxHjhwBAFy7dg3u7u5QKqvOZuXSyJ/r9/z5c9i9e6fc5RARkRMpV/g+/fTTuHbtGp577jm89tpr\nmDFjhp3Lcg7R0WMBADExS2SuhIiInEm5fvN1d3fHggUL7F2L08mf6zcu7lecPn0KjRo1lrskIiJy\nAjzD1X0qmOt3mcyVEBGRs2D43qeCuX7XIzmZc/0SEdG9MXzvU+G5fteuXS13OURE5AQYvnbAuX6J\niKgsGL52wLl+iYioLBi+dsK5fomIqLQYvnbCuX6JiKi0GL52xLl+iYioNBi+dsS5fomIqDQYvnbE\nuX6JiKg0GL52NnDgYPj7B2Dt2tXQ6XRyl0NERJUQw9fONBoNhg8fibS0VGzcuF7ucoiIqBJi+FYA\nzvVLREQlYfhWgMDAQAwcOBjnz59DXNwuucshIqJKhuFbQTjXLxERFYfhW0FatIhEu3aPYvfuXThz\n5rTc5RARUSXC8K1AnOuXiIhsYfhWoN69H0doaBjn+iUiIisM3wqkVCrx0kujkZWVhXXr1shdDhER\nVRIM3wo2bFgU3N09ONcvERFZMHwrmJeXN4YNew43blxHbOwPcpdDRESVAMPXAV56aQwEQcAHH8xC\nZmam3OUQEZHMGL4OULfuA3j55Ym4fPkSpk79r9zlEBGRzBi+DvLmm1PRokUkvvpqLbZu3Sx3OURE\nJCOGr4O4uLhg2bKVcHNzw+TJr+Datatyl0RERDJh+DpQ/foNMGvWh0hLS8X48aNhMpnkLomIiGTA\n8HWwqKjh6NOnH/74Yy8WL/6f3OUQEZEMGL4OJggC5s1biMDAIHz44SwcPnxQ7pKIiMjBGL4y8PPz\nw6JFMTAajRgzZiR0Op3cJRERkQMxfGXSuXNXjBv3Ci5evIBp096UuxwiInIghq+M3nprGpo1a4F1\n69Zg61ae/YqIqLpg+MrI1dXVcvjRa69NwPXr1+QuiYiIHIDhK7OIiIaYOfN9pKamYsKEMTCbzXKX\nREREFYzhWwm88MII9O79OPbu3YMlSxbKXQ4REVUwhm8lIB1+tAg1awbigw/exdGjh+UuiYiIKhDD\nt5Lw9/fHwoXLYDAYePgREVEVx/CtRLp27YbRo1/G+fPnMH36FLnLISKiCsLwrWSmTp2Bpk2bY+3a\nz/Hjj7Fyl0NERBWA4VvJ5B9+pNFoMHnyeCQm3pC7JCIisjOGbyXUsGEjzJjxHpKTk/Hyy6N5+BER\nURXD8K2kXnzxJfTs2Rt798Zh2bLFcpdDRER2xPCtpARBwKefLkZAQE28994MHDt2RO6SiIjIThi+\nlVhAQAAWLlxqOfxIr9fLXRIREdkBw7eSe+yxHoiOHotz587inXfelrscIiKyA4avE5g6dSYaN26K\nNWtW4qeffpS7HCIiuk/lDt8tW7bgiSeewJNPPom4uDg7lkR302g0WLZsJVxdXfHqqy8jKSlR7pKI\niOg+lCt8U1JSsHjxYnz11VdYtmwZdu3aZe+66C6NGzfBjBmzcefOHc5+RETk5MoVvn/99RfatWsH\nDw8P1KxZE7NmzbJ3XWTDiBHR6N69J+LifsXy5UvkLoeIiMqpXOF79epVZGdnY8yYMRg2bBj++usv\ne9dFNgiCgAULlsLfPwCzZ8/AsWNH5S6JiIjKQRBFUSzrnZYvX46DBw9i0aJFuH79Op5//nns3r0b\ngiDYbG80mqBSKe+7WJJs374djz/+OBo3boz9+/dDq9XKXRIREZWBqjx38vPzQ6tWraBSqRAaGgp3\nd3ckJyfDz8/PZvuUFPsfnxoQ4IlbtzLs/rjO4KGHOuCll0bjs89iMH78RHz00bwKe67q3M+OxH52\nDPazY7CfJQEBnsWuK9dm5w4dOmDfvn0wm81ISUmBXq+Hr69vuQuksps+fRYaN26Czz//DD//vF3u\ncoiIqAzKFb6BgYHo1asXhgwZglGjRmHq1KlQKHjIsCNpNBosXSodfjRx4jgkJSXJXRIREZVSuX7z\nLauK2PzAzRqSFSuW4u23/4uuXbth/frv7P6fIPazY7CfHYP97BjsZ4ndNztT5fHSS2Pw2GPdsXv3\nLnz22TK5yyEiolJg+Dq5gsOP/PHuu9Nx4sRxuUsiIqJ7YPhWAYGBgZg/fzFyc3MxduxIZGVlyV0S\nERGVgOFbRfTs2QcjRozC6dOn8O670+Quh4iISsDwrULeeWc2GjZshJUrl2Pnzh1yl0NERMVg+FYh\nbm5uWLZsFVxcXPDKK+Nw8+ZNuUsiIiIbGL5VTNOmzTBt2kzcvn0LEyeOhQOOJCMiojJi+FZBo0aN\nRZcuj2HXrl+wcmWM3OUQEdFdGL5VkEKhwMKFy+Dn54eZM6fh1KmTcpdERESFMHyrqMDAIHz66WLk\n5ORgzJgRyM7OlrskIiLKw/Ctwnr3fhzDh4/EqVMnMWvWdLnLISKiPAzfKm7GjPcQEdEQK1Ysw65d\nP8tdDhERgeFb5Wm1WixdutJy+NGtW7fkLomIqNpj+FYDzZu3wNtvz8CtWzcxadI4Hn5ERCQzhm81\nMXr0OHTu3BW//LIDq1atkLscIqJqjeFbTeQfflSjRg3MnDkVp0+fkrskIqJqi+FbjQQF1cKnny5G\ndnY2xowZycOPiIhkwvCtZvr06Yvnnx+BkyeP4733ZspdDhFRtcTwrYZmznwP9es3QEzMYuzevUvu\ncoiIqh2GbzXk7u6OZctWQq1WY8KEMbh9+7bcJRERVSsM32qqRYtIvPXWdNy8mYRXX32Zhx8RETkQ\nw7caGzduAjp27IIdO7Zj9eqVcpdDRFRtMHyrMYVCgUWLlsHX1xfvvDMFZ8+ekbskIqJqgeFbzdWq\nFYx58xYhOzsbo0ePQE5OjtwlERFVeQxfQt++/REVNRwnThzD+++/K3c5RERVHsOXAADvvvsB6tWr\nj6VLFyIu7le5yyEiqtIYvgSg4PAjlUqFCRPG4M6dO3KXRERUZTF8yaJly1Z4881pSEpKxKuvjufh\nR0REFYThS1bGj5+IDh064aeftuGLLz6XuxwioiqJ4UtWpMOPYuDj44Pp09/C6dOn5S6JiKjKYfhS\nEcHBtfHJJwuRlZWF/v374+LFC3KXRERUpTB8yab+/Qdg8uTXcf78eTz+eDf8/fc+uUsiIqoyGL5U\nrDffnIbly5cjLS0Ngwf3w/fffyt3SUREVQLDl0o0atQorF//HVxdNRg9egQ+/fRj7gVNRHSfGL50\nT126PIbY2J9Rp04IPvhgFiZNehm5ublyl0VE5LQYvlQqjRs3wfbtuxAZ2Qrr16/DM88MRlpaqtxl\nERE5JYYvlVpgYBC+//5H9O7dF3v37kHfvj1w5cplucsiInI6DF8qE3d3d3z++TqMGTMeZ8+eQZ8+\nj+HAgX/lLouIyKkwfKnMlEol3n33fXz44SdITk7GoEF9sXXrD3KXRUTkNBi+VG4jRozCunUboFSq\nMHJkFBYtWsA9oYmISoHhS/ele/de2LLlJ9SqFYx3352G//xnEgwGg9xlERFVagxfum/Nm7fATz/9\nimbNWmDt2s/x7LP/h/T0NLnLIiKqtBi+ZBe1agVjy5af0KNHL8TF/Yr+/Xvh6tUEucsiIqqUGL5k\nNx4eHlizZj1GjozGqVMn0bv3Yzhy5JDcZRERVToMX7IrlUqFDz6Yi9mzP8StWzcxYEAfbN++Te6y\niIgqFYYvVYjo6HFYvforAMDw4cMQE7OYe0ITEeW5r/DNzs5G9+7dsWnTJnvVQ1VInz598cMP2xEQ\nUBPTpr2FKVNeh9FolLssIiLZ3Vf4Ll26FN7e3vaqhaqgli1b4aeffkXjxk2xcuVyvPDCM8jMzJS7\nLCIiWZU7fC9cuIDz58+jS5cudiyHqqI6dUIQG7sDXbo8hl9+2YEnnuiNGzeuy10WEZFsBLGcP8RF\nR0dj2rRp2Lx5M2rXro0nn3xblRT0AAAgAElEQVSy2LZGowkqlbLcRVLVYDAYMH78eCxfvhy1a9dG\nbGwsIiMj5S6LiMjhVOW50+bNmxEZGYmQkJBStU9J0ZfnaUoUEOCJW7cy7P64ZM3e/Txr1seoVSsU\nM2dOxaOPdsBnn61G9+697Pb4zoqfZ8dgPzsG+1kSEOBZ7LpyhW9cXBwSEhIQFxeHxMREuLi4ICgo\nCO3bty93kVQ9CIKAl19+BaGhYXj55VF47rmn8f77H2PEiFFyl0ZE5DDlCt/58+dblhcuXIjatWsz\neKlM+vcfgODgYERFDcWbb76GS5cuYsaM2VAq+fMEEVV9PM6XZNO6dRts374LERENEROzGC+++Bx0\nOp3cZRERVbj7Dt8JEyaUuLMVUUnCwsKxbdsv6NixM376aRsGDnwcSUmJcpdFRFShOPIl2Xl7+2D9\n+u/wzDPP4ciRQ+jTpxtOnTopd1lERBWG4UuVgouLC+bPX4wpU6bj6tUE9OvXE7t375K7LCKiCsHw\npUpDEARMmvQfxMSsQm5uDoYNewpr166WuywiIrtj+FKlM2jQU/j2263w9vbGa6+9glmz3oHZbJa7\nLCIiu2H4UqXUtu0j+PHHXahXrz4WLvwUo0YNR1ZWltxlERHZBcOXKq0HHqiHH3/ciXbtHsXWrZvx\n5JN9cevWLbnLIiK6bwxfqtR8fWtg48bNeOqpp3HgwH706dMNZ8+ekbssIqL7wvClSs/V1RWLFy/H\n66+/hfj4y+jbtwd+//03ucsiIio3hi85BUEQ8Prrb2HRohjo9ToMGTIQX3/9pdxlERGVC8OXnMqQ\nIc/gm29+gIeHB155ZSw+/HAWyjkrJhGRbBi+5HTat++AH3/chbCwcMyb9zHGjh2J7OxsucsiIio1\nhi85pfr1G2D79l/Rpk1bbNr0Lf7v/wbgzp07cpdFRFQqDF9yWv7+/vjuu60YOPBJ/P33X3j88W64\nePG83GUREd0Tw5ecmkajwbJlqzBp0n9w6dJF9OnTjXtCE1Glx/Alp6dQKDBlynTMn78YGRkZePLJ\nfnjhhWE4ffqU3KUREdnE8KUqY9iwKGzZ8hPatGmL7dtj0aVLO0yYMAbx8VfkLo2IyArDl6qUhx56\nGLGxP2Pdug1o1KgJNmz4Cu3aPYgpU17HzZs35S6PiAgAw5eqIEEQ0LNnH/z66+9YuvQzBAfXxmef\nxeDhh1vigw/eRVpaqtwlElE1x/ClKkuhUGDw4CH4888DmDPnU3h6euLTT+eiTZsWWLhwPvR6vdwl\nElE1xfClKk+tVmP48JH4++/DmDp1JgBg1qzpaNs2EqtXr4TBYJC5QiKqbgTRAefmu3Urw+6PGdCm\nOUzmoqXrx72C7JHRAADPcaOg/vuvIm0MrR9CxvLVAADN2tXQzp9r8zmS/zoIuLhAee4svIc+abNN\nxryFMHTuCgDw6dUFitu3i7TJHvIM9P99GwDg/s7bcI39oUgbU2gY0r7fBgBw2b4NHlP/a/P5Urfu\ngDm4NoTUFPh262izjW7KdOQMHgIA8Hr2/6CysddvbtfuyJw7HwDgtnA+3FZ/VqSNqNVCdfoUbt3K\ngGr/P/AaPcLm86WvWgtjy1YAAN+2kRCMxiJtsqLHImv0ywAAj0kvw2XvniJtjM1bIn21dL5m16+/\nhPvHH9h8vuQ9+wAPDyguX4LP4P4222TOmYfcbj0BAD79ekJx47plndlsRkZ6OlZl6fG60Yjw8Lr4\nrmEjtDxxHBAEq8cx1wpGauzPAACXXT/D443JNp8v9butMIfXBTIzUaPzIzbb6F5/CzlDnwUAeA1/\nFqpjRyzrlAoBJrOI3I6dkTl/MQDALWYx3JYvLfI4okqFlL8PAwBURw7Ba0SUzedLj1kF40MPAwB8\nOz4MwcZIP2v4S8iaMAkA4PGfSXDZvbNIG2Ojxkj/8hsAgOt3G+H+/rs2ny9l116IPr5QXL8Gn/69\nbLbJnP0Rcvv0BQB4D+oLpY2d4XL6DYBu5nsAAO1H70GzcX2RNmZ/f6TuiAMAqPfshufkCTafL+3r\nTTA1iAByc1Gj3YOWfi5MP+k/yI4aDgDwjB4O9YH9RR7H0LYdMpasAABoVi6Hdsn/bD5f8oHjAADl\nyRPwjnraZpuMRTEwtHsUAODb9VEI6WlF2mQ/+zz0k98AALhPeR2uO7YXaWOqVx9pGzcDAFy2bobH\njKk2ny9l+68Qa9aEcPMmfPs8ZrNN5ozZyO0/EADgPWQglBeKHi+f06sPdO9/DADQzpsDzZdfFGkj\nenkjZfcfCAjwROqWn+A5frTN50tbuwGmJk0BADVaN7PZRs7vcnsJCPAsdp3Krs9E5AQUCgW8fXzw\nwpBncBoivvjic+y4fAmBajW8vX3g5uYmd4lEVMU578g3wLNCHpesVYd+vnLlMj7++AN8883XEEUR\nDz/8CKZOnYFHHmnvsBqqQz9XBuxnx2A/S0oa+fI3X6r2wsLCsWhRDPbs2Yc+ffrhn3/24YknemPo\n0CdxrNCmYSIie2H4EuVp1Kgx1qz5Ctu370KHDp3w66870a1bR0RHD+c5o4nIrhi+RHdp3boNvvtu\nKzZu3IzIyFbYvHkTHn20DV577RVcv35N7vKIqApg+BLZIAgCunR5DDt2xGHlyrV44IF6WLt2Ndq2\njcQ777yN5GROX0hE5cfwJSqBIAjo338A9uzZhwULliAgoCaWLl2Ihx5qgblzP0RmJncqIXJqZjOg\n0wGZmQ59Wu7tTCViP1vLycnBmjUrMX/+XNy+fRv+/v6YOPE1vPDCSGg0mnI/LvvZMdjPjnHf/SyK\ngMEAITsLQlYWoNdDyM6GkKWHkJUFITsL0GdJfxe6HdlZEPRZBW2yCrXRF2pT+PbsbOkpFQqkffUt\nDI91t1MvlLy3M8OXSsR+ti0zMwMxMUuwZMlCZGSko3btOnj99bcwZMgzUKnKfvg8+9kx2M92IoqA\nTgdBp4Ogy4Sg00Ghy4SgywR0OngrTMhISraEoJCVBdwVggXhWNBG0OuB/DA1mexbsloN0U0L0c0N\n0GggarUQNRrLbaKPL3RvvwNznRC7PSfDl8qN/Vyy5OQ7+N//PsWqVcuRnZ2NBg0i8Oab09Cv3xMQ\n7jpbVknYz45RLftZFKWQKxSUQmZmwbLuruXM/GUdhMyMguXC6/Q6CHaIDlEQADc3KfzcCsIQbm4Q\nNW4QtXnrNG557Qq3KRScGqkd7g5UjRugzbsux3+K7xfDl8qN/Vw6169fwyeffISvvloLk8mEyMhW\nmDLlHXTu3LVUIcx+dgyn6Oe8sFSkp0FIS4OQnlZ8GBYXmlZ/Z0Iwm8tfjiBAdPeA6O4uXTw8Cy17\nWK9zl9Z5BvkhzaQoCNG84LQEZn6AuroWOaVrVcLwpXJjP5fNhQvn8NFH72Hz5k0AgA4dOuHtt99B\n69ZtSrwf+9kxHNLPZjOEjPS84Ey3CtGC5XTp70LLVuttnB+9tEStuyUMzR6eQKHl/Nvh7gGzR35o\neuSFqDtErUeRUIVWW+aA5OdZwvClcmM/l8+xY0fw/vvvYteuXwAAvXv3xVtvTUPjxk1stmc/O0ap\n+jk3F0JaGhQZxQVlqnVopqVBUWhZyMwo8yZZUauF2csbopcXRC9vmL3zl32k68JBefeIs/Cy1h1Q\nyH8QCz/PEoYvlRv7+f7s2/cnZs+egX/+2QdBEPDUU0/jjTemICws3Kod+9nOjEYIKSlQJN+BIvkO\nhDvStWeuDvrrN4sP0Yx0aeegMhAFAaKnF0Rvbyk4ve5a9vKG6C2FqNnLO29d/rJ0O9TqCuoIefDz\nLGH4Urmxn++fKIrYuXMH3nvvXZw8eRxqtRpRUcPx6qtvIDAwEAD7uURmM4TUFCiSky0hWjhQFcl3\nICTfgeJO3nVyMhRpqaV+eFGthujtDbMlQPOCMi9ArZel8LSMUr29IXp4VorRZmXCz7OE4Uvlxn62\nH7PZjM2bv8OHH87G5cuXoNVqER09Di+//Arq1w+pHv0sitIo0yo4k6G4Yx2iluWUZAgpKaXaYUhU\nqWCu4QfRz0+6riFdm/1qWJa9wmsjRVRLI868wIWbW5Xe6UcO/N6QMHyp3NjP9mcwGPDVV2sxd+6H\nSEpKhI+PDyZNmoQePfqhQYMIucsrPVGU9qS1MRq1CtSUZGn9HWm5NDsTiQoFxBo1pPDMD1K/QoFa\no4YUsr41LIErenrdM0T5eXYM9rOE4Uvlxn6uOHq9HitXLsfChfOQmiptJm3YsBH69n0C/foNQNOm\nzcp0rLDdmM0Qbt+GMukGFEmJUCQlQZFYaDnphnR96yaE3NzSPaSvr1WIWo9M85Z9a0D0ywtTb58K\n2ZTLz7NjsJ8lDF8qN/ZzxcvISMcff/yKr77agLi4XcjOO91deHhd9Os3AP36PYFWrVrffxAbjVDc\nvpUXoolQJCbeFah5t926WeLZhUQXF5gDg2AOCIDZz79oiNbwsx61+vjIcoIDW/h5dgz2s4ThS+XG\nfnaM/H7OzMzErl0/IzZ2C375ZQf0eh0AIDi4Nvr27Y9+/Qbg4YcfgVKpLLizwQDFrZt5o9OkvBC9\nAcXNJOuQvX2rxN9ORY0G5ppBMAcGwhxUC6bAQClk8y9BtWAODIToW8NpfyPl59kx2M8Shi+VG/vZ\nMWz1c1ZqKvbH/oCD27bg8l9/wFuvRy0AdTUaNK/hh1C1Gp46HRR3bpd4XKmo1cJcMxCmoFp5IRpk\nFbJSuAZKm3qdNFRLi59nx2A/S0oK38qxLYiousnKgjIhHsqEK1DExwOpt+B58Yr1ZuDkZIQCePLu\n+2ZnA9evIQPARYUCBv8AuNVvAP9mzSEE15HC1TJaDZIOhanioUrkbBi+RBUhOxvKawlQXLmSF7Lx\nUMRflpbj46G4dbPIXfInJDR7ecMcGAhj0+Yw1wwsGK3mBaohoCb+jr+CH3b9jG3btuLGjevArZvw\nOHYUPXr0RL/QAXisVWu4u7s79jUTUalxszOViP1cjNxcKK4m5IXpFSjyrqWQvQJlUqLNu4lqNcy1\n68AUGg5TaCjMIaEwhYbBq2kE7rh6wRwYJJ1Lt5TMZjMOHtyP2NgtiI3dgvj4ywAANzc3dO3aHX37\n9kevXn3g5eVtj1ft9Ph5dgz2s6RCfvOdM2cODhw4AKPRiNGjR6Nnz57FtmX4Oq9q288GAxTXrlqP\nWuPzlhPiobhx3ebvrKJSCXPtEJhCpVA1h4TCFBIKU2g4zKGhUrgW3lkqjz36WRRFHD9+DNu2/YDY\n2C04e/YMAECtVqNTpy7o128AevfuCz8/v/t6HmdWbT/PDsZ+ltg9fPft24eVK1dixYoVSElJwaBB\ngxAXF1dse4av86qy/Ww0QnHjuvWoNX85IR6K69ds7hksKhTSyDUktFCwhsEcGibdViu4XIfVVEQ/\nnz17BrGxUhAfP34UAKBUKtG+fQf07fsEHn+8H4KCatn1OSu7Kvt5rmTYzxK7h6/JZEJOTg60Wi1M\nJhPat2+PP//80/rwh0IYvs7LafvZZIIi8YYUpFcuW0asls3E167aPJZVFASYawVbNgebQkKlYM1f\nDq5dISfBr+h+vnz5ErZt24rY2B9w4MC/AABBEPDQQw+jX78B6Nu3P0JDwyrs+SsLp/08Oxn2s6RC\nDzXasGED9u/fj48//rjYNhXxJrRp4wmzjZHJuHG5GDnSkLeswd9/F/0PQevWJixfLp3IYO1aNebP\nd7H5HH/9pYOLC3DunAJDh7rZbDNvXjY6d5a+xHv10uL27aJ7lQ4ZYsB//yudCeidd1wRG1t0ZBQa\nasb330uzqWzfrsLUqa42n2/rVj2Cg0WkpgLdutneoWbKlBwMHiydwu/ZZ91w+nTRMwV17WrE3Lk5\nAICFC12wenXRQNFqRZw+rcStWxnYv1+B0aNt98GqVVlo2VJ6L9q2dYetswdGR+di9GjpfZk0yRV7\n9xbtg+bNTVi9Wnpfvv5ahY8/tt0He/bo4OEBXL4sYPBADWA0QDAYAEPetdGIJRiLx02xAIAO2Iur\nqFPwAAoloFLh/8L3YUbfP2EOCcP033rhu31hgEpptWdwrVpmxMZK78uuXUq88YYGtnz3nR7h4SIy\nM4HOnW2/L6+/noOhQ6XOGT5cg2PHCj6bCoUCZrMZHTsaMX++9L7ExKixfHnRz6ZKBfz9t3T875Ej\nCowYYft9iYnJwkMPSe9Lx45a6PXS6zIaTcjK0iMrS4/c3PkQxTkAAD+/b2A0doebmxvUhf6D0aiR\nGV9+mZX3OlV4/33b78uuXTr4+ADXrwvo39/279azZ+egTx+pDwYNckN8fNHPZr9+RsycKfXBRx+5\nYOPGop9Nf38RO3boAQB79igxebLt9+Xrr7PQoIEZublAu3buln4ubNKkXERFSZ/N6GgNDhwo+p3R\ntq0JS5ZIn82VK9VYssT2d8aBA9L7cvKkAlFRtt+XRYuy0a6d9J3RtasW6elFvzOefdaAyZOl74wp\nU1yxY0fRfy/16pmxcaP0vmzdqsKMGbbfl+3b9ahZU8TNmwL69LH9vsyYkYP+/aX3ZcgQN1y4UPR9\n6dXLiPffl96XefNc8OWXRd8XLy8Ru3frERDgiS1b9Bg/3vb7snZtFpo0kd6H1q1t/3uR87vcXirs\nUKOdO3fi22+/xapVq0ps5+urhUple1R8PxQ2Tj/n6alBQID0hms0ts9Q5+qqQECAOq998WexCwjw\nhIsLcOdO8W18fLQICJCWVSrb7dzdXREQIP3D0Gptt1GrFZY3ytu7+Ofz8/NAQEDxzwUAXl5ulppc\nXGy3c3NzQUCA9EH18LDdJn9DRkCAJ3x9i38+X193y/MplYCt8zh4eNzP+yICJhOQKwVswKyp8Dh/\nGBnH9VCkfFv0gRQKCA3qA62GAuHhwDf1gMy8sywpVZZwVT05CO4fDJJqugkoDhV9qLK+L25uxbfx\n9Cx4X1xdi7ZTKBTQaEr3vuTXVJb3Jb+di4sCLi7e8Pb2xvPPT0OdOvWwadMm/PxzMkQxFWlpqVCr\n1dBqtXB3d4eLi9ryfF5exT+fv7/0OcnJKb6Nt3dBH6jVtttptQV94F7M9LQqVUEf+JRwJsoaNaQ+\nyM0taHP390bh7wxb7wsAaDSl/86Qntd+3xnFfaZcXBSlfF+kz6bZbL/vjNK9L9p7vi9ASf9e5Psu\nd4Ryj3z37t2LBQsW4LPPPoOPj0+JbbnZ2Xk5tJ+NRigvX4Ly7Bkoz52BKu9aee4cFLpMq6aiQgFz\naBiMEQ1himgkXTeIgKlBBEQn3LO3MnyeU1NTsGPHdmzbtgW7d+9CTo40yqlb9wHLaS4jIx+U53zT\ndlIZ+rk6qMh+FkXAaATyNnbBYBBgMEj/wTIagdxcoci6/PWF/zYYhEL3kf7TMXSoAV5e9qvV7pud\nMzIyMGzYMKxevbpUe04yfJ1XhfRzVhaU589Bde4MlGfPQHXurBSyFy8UOVG/6OICU70GBeEa0RDG\nBg1hqldf+u9wFVHZPs+ZmRnYuVM6zeXOnT9bTnNZu3Ydy2ku27RpW+x+HpVVZevnyshkks7jkpsr\nBVlODpCTI11b3ybdLt0GZGdLyzk5AtRqV6Sm5tgIwMLhd3cAFl1XEKrW7SrKypVZls3v9mD38N2w\nYQMWLlyIunXrWm776KOPEBwcbLM9w9d53U8/C2mpBeFaaDSriL9S5DAds4cnTBERMDXIC9eIhtLf\noeGV5qT8Fakyf56zsrIQF/crYmN/wI4d25GengYA8PX1RYsWkWjZshVatoxEixaRCA0Nq9Qj48rc\nz2YzkJUlhVh2trSclSXkBaFgFXZSEBaEoK1l27cVF6QFj2syyfP+qdUi1GppBKpSiXnXgIuLdHvB\npaBd/nL+xcVFzLtP0XVqtfVjFrST1nl5AW3amGwdCVhuPLczlds9+1kUobiZJIXr2TPSaDY/bG2c\naMLsH5A3ipXC1digIUwNG8EcVKtanwLRWT7Pubm5+P33PYiN3YK9e/fgypXLVut9fHzQvHkkWraM\ntARyeHjdShPIZelnUcwf0UmBqNcXBGN2tmAVjkWvC9rq9db3KQhX679zciq+jwRBhKurFDouLiI0\nGmnZ1VXMuw3QaETL+vy2rq62lotfX7OmFnq9/q4ALByu1gGoUlXNf/4MXyo3Sz+bzVDEX8nbVHy2\n0G+yZ6FISy1yP1NIKEwNCsLVmBe2om8NGV5F5eesn+fU1BQcPXoER44cxtGj0uXSpYtWbby9fdCi\nRUs0b97SEsrh4Q/Y3GGyLERRGqnpdAJ0uruvbd8mii5ISTFYQq9wgNoKUbPZ/ong5iaFnkYjXbu5\niXBzk/7Ovy58e36wFQ486/C71/qCx3BUyDnr59neGL5UJkJKMlRHDkN15BA8Lp6F4dgJqM6fhZA3\nz2w+UaWCqe4DeTs8ReSNZhvCWK+BtDsklVpV+jynpaXi2LGjeYF8CEePHsGFCxcAuAPwAOABrTYI\nDzzQEqGhTREcHIGaNetBqw1EVpaiSGjq9cWHqz02kapURcPPVgjmL+cHp5tb6f8uuL1qjvDuVpU+\nz/eDsxpRsYSMdKiOHoHq8CGojhyE+tBBKO/alKjSamGMaGS9w1NEQ5jqPlAhJ5ygysFgANLTBaSn\nAxkZAtLSBMvf6ekCMjKKjjCloNRCpwuGTtfHchtgnTh6PXD8uHQpLa1WhLu7CHd3oEYNs2XZ+rrk\n22rXdkdWViY0GunxNJpqsUsBVUL82FUnej1Ux49BfeQgVIcOQnXkEJTnz1nt/GSuUQO5XbvB0OpB\nGFs+CO9Oj+C2WwkHk1KlZDYDmZn54WkdmmlpUnCmp8OynB+sGRkFt+WflKOslEoRHh5S2Pn6iqhT\n5+5QlJbV6hxkZNzAnTuXkZh4Htevn8aNG+cgiukAMgFkws1NRNOmddGqVSO0bNkSLVu2Qv36Dcq9\nl3VAAHDrVoVv7CO6J4ZvVZWbC9XJ49KI9vBBqA8fgvLMKatTKpo9vWB4tCOMkQ/CENkKxsgHYQ4J\ntd4uFuAJcPORQ4mi9Ptj4dC0DkncFZgC0tJQaFkKUVEsW3hKe3yK8PQEgoLM8PIS8y4otFxwm6en\nCA8PEVqtdai6uJRl02qtvEs7AIBOp8Px48dw9Oghy+/IBw/GYf/+Xy330Gq1aNq0uWWHrpYtW6FB\ngwioOIQlJ8LffKsCoxHKM6ehPnIob0R7EKqTJ6yOmRXd3GBs3jJvRCsFremBevcc0bKf7092NpCc\nLFhd7tyRrlNSCv7OzFQhOdlsGZ0aDGULTkGQQlMKTxHe3sWHZuG/vb0L7uPmVjl/j9Tr9Thx4hiO\nHj2MI0eky9mzp2Eq9B9JNzc3NGnSLG+HrlZo0SISDRs2KhLI/Dw7BvtZwh2uqhKzGcoL56E6fNAy\nolUdPwohK8vSRHRxgbFps7wRrRS2poiGlWa2HWeVk1NykOYvF/67tJtu3dwALy9zMSPNoiHq7Y1C\nISsWeyrKqiorKwsnTx63jI6PHDmMM2dOwVjoxOIajQZNmzbL28taCuSOHR9Gamp2CY9M9sDvDQnD\n11mJIhRXLhca0R6C6shhKDILXreoVMLUqIlls7ExshWMjZtK2/7soKr2c04OLCPPwkFaeDR6d5Dq\ndKULUq1WRI0a0sXXV4SfX/F/+/lJt4WEVM1+dqTs7GycOnXCKpBPnz4Jg8FgaaNUKlG7dh3UqRNi\nuYSEhFqua9euA1dX2xMUUOlV1e+NsuLezs5AFKG4cd0SsurD0rUiJaWgiSDA1CACuS1bFWw+btZC\nGjZVczodkJQkIClJgdu3bQdp4dFpZmbpglSjkQLygQfMRYLz7kt+kPLtkIdGo0GrVq3RqlVry205\nOTk4ffqkZXP1hQtncOnSZfz11x8obtwRGBiUF8YhCAkJsyzXqSOFtIeHh6NeElVhHPnKRLh1C+rD\nB6x2iFLcumnVxhReN29E21oa0bZoCdGj+P9JVQS5+zkzsyBUExMFJCUJSExU5N0mWNZlZNw7TF1d\nC8KzpCAt3EZrewY2u5O7n6uL/H7Ozc3FtWtXcfVqAq5eTUBCQrxlOT4+HtevX7XahF2Yr6/vXaEs\nBbMU1qHw8fGtNGf0kgs/zxKOfOWWlQX1/n+gOrgf6vxDfK5dtWpiql0HOY/3LxjRtoyssmeDEkUp\nVPNDND9Uk5IKQjV/3b029fr7mxESYkZgoIigIBGBgWYEBBQfpNX8O5HyuLi4oG7dB1C37gM215tM\nJiQlJSIhIQFXr8YjISHesnz1agLOnTuDo0cP27yvu7tHoVCWRs/5f4eEhCIgoOZ9n92LnB/DtyIY\njVAdPgiXvXug3rsH6n//hpA3PRsgnd84p0cvy2+0hpYPQqxZU8aC7UMUgfR0WI1MExMVuHlTsBq1\n3rxZ8o5IgiCFZt26+aEqXdesWRCwQUEiAgJEe/20TWRFqVQiOLg2goNro23bR4qsF0URd+7cQULC\nlbyRc0Ewx8dL16dPn7L52K6urggOrm0VyvnBHBISilq1gnnYVDXAd9geRBHKUyfhsjdOCts//7Da\nKcrYtDlyO3aG4eFHYGz1IMzBtZ1qCCaKQGoqrDb9Wo9SC/7Ozi7+dSkUIvz9RdSrZ7aEaGCgaBWw\ngYFSqPLEWVSZCYIAf39/+Pv7W/3GXFh6elpeKCcgIeGKZVkaSSfgt99227yfUqlErVrBVjuF1ahR\nA76+NQpd+6FGjRrw8vLmKNpJ8TffclJcvpQ3so2Dy++/QXH7tmWdse4DMHTsgtxOnWFo3xGiv79s\ndZaG0Qhcvy4gPl6BhAQBV64oLMtJSSrcuCGWOOOKQiGNSvM3/dasab0ZWLqWgpf/obdN7s9zdVGZ\n+lmv1+Patat3/d58xbKcmHgDZrO5xMdQKBTw9fWFr68Uyn5+fpbl/KDOX65RI3+dL1wqeJNRZepn\nOfE3XzsQkpLg8ru0Gdnl99+gjL9iWWcKDEL2U08jt1MXGDp0grlOiIyVFmU2SzstXbkiBWp+sMbH\nSyF77Zpg8wT1CoWIWolNSOEAAAkTSURBVLWAJk3Md41SrUet/v6iXefAJKoOtFotGjSIQIMGETbX\nGwwGXL9+DdevX0NycjJSUpILXd+x+jslJRmXLl20OvFISTw8PG2MpouGduEwd3d3r/Y7ktkTw7cY\nQloq1H/+IY1s9+6B6sxpyzqztw9yHu8vbUru1AWm+g1k3YwsisDt24JVoMbHFyxfvSogN9d2fUFB\nZjz4oBmhoWaEhZkREiIiNFT6OzhYRHCwJ27d0jv4FRGRWq1GWFg4wsLCS9XebDYjPT3NKpALL9+5\nU/T2s2dPI6vQCXpK4uLiYmMUbTu869ULQW6uAHd3d2i17uU+F3dVxvDNl5UF9T/7LJuSVUcOQ8jb\n5CO6uSG3y2PI7dgFhk6dpWNrHfxhSk0FEhIUVqPXwiPY4nZg8vc3o2lTsyVQ88M1LMyM2rWlWV2I\nyPkpFAr4+PjCx8cXQL1S30+v1xcJ6sIj7Ltvv379Ok6dOlmm2jQaDdzd3eHu7gGtVpsXyh5511q4\nuxddzg/u/PvZauvMv3dX3/A1GqE6dMB6j+S8cyGLKhWMDz1sGdkaHnxImp26AmVmSuEaHy9YQjZ/\nOT5egfR02+Hq5SWdACI/WMPCCpZDQszg+QCIqCRarRZarRa1a9cp9X2MRiNSUlKKDe2srAwkJ6dC\np9PlXTKh1+uh0+mQlJQInU6H3ELnnr+/2u8O6sIhbyvIrf8TkL/s6+sLb2+f+66ptKpP+JrN1nsk\n//WnZY9kURBgbNYChg6dYOjUGblt28PeqZWTgyKbhfODNT5ewJ07tv8Hp9VKI9VHHpHCVBrBFmwa\n9va2a5lERPekUqkQEBCAgIAAm+tLs8OVwWCAXq+zBHTBcmbe33rLsvV66zDPb5OamorMzIxS/+59\nN4VCga+++haPPda9XPcvq6obvqJYsEfy73uK7pFcrz5yBg+R9kh+tCPEGn52eVq9Hjh/XoEzZxQ4\nezb/WonLlwWYzUVHry4uIkJCRDRvbiwSrKGh0vGu3MeBiKoatVoNb28fu442RVFEbm6uzXC2DvOi\n6wEgIqKh3Wq5lyoVvoqkRGlUm79HckK8ZZ2pVjCyhzyD3A6dYOjYGeYybGKxJTMTOHs2P2CVlqBN\nSBCKzKNao4YZbdqYUK+eFKjSCFbaRFyzplitZqMhIqoogiDA1dUVrq6uqGGnAVVFcerwFdJSof7j\nd2lT8u+/We+R7OuLnH4DpLDt1AWmevXLtUdyaipw5owS584VjGbPnlXg2rWiiVmzphkdOpgQEWFG\nRIQZDRtK1/7+FX4oNRERORHnC19RhNvC+cCOWPgdOFCwR7JWi9zHukt7JHfsJO2RXIYh5e3bQqHN\nxAWbjG/eLPoYwcFmdOlitISrdDHB19dur5KIiKow5wtfnQ7un3wIGI0wPPwIDB07S5cHH7rnHLai\nCNy8Kdz1e6x0sbXDU2ioGd27G/NGsdKItkEDM7y8KurFERFRdeB84evhgTv7j8M/LBBpetunXhNF\n6XSJ1qNY6XfZtDTrTc+CICI8XESbNgarzcX165vh7u6IF0RERNWN84UvADEgAHB3hzkzAwkJgtVe\nxfnLd09Fp1RKx8N26GC22lxcr56Zk58TEZFDOV34iiIwfbor/v0XOHXKA1lZ1iGrVouoX99cZKen\nBx4wc/o5IiKqFJwufPV6YMMGNbKzYQnZ/IBt2NCE8HDOnENERJWb08WUuztw4kQmAgM9kZzME/4T\nEZHzccrTO6jVDp/XgIiIyG6cMnyJiIicGcOXiIjIwRi+REREDsbwJSIicjCGLxERkYMxfImIiByM\n4UtERORgDF8iIiIHY/gSERE5GMOXiIjIwRi+REREDiaIoijKXQQREVF1wpEvERGRgzF8iYiIHIzh\nS0RE5GAMXyIiIgdj+BIRETkYw5eIiMjBnC5833//fTz99NMYOnQojh49Knc5VdqcOXPw9NNPY/Dg\nwfj555/lLqdKy87ORvfu3bFp0ya5S6mytmzZgieeeAJPPvkk4uLi5C6nStLpdBg/fjyioqIwdOhQ\n7N27V+6SKi2V3AWUxT///IMrV65gw4YNuHDhAqZMmYINGzbIXVaVtG/fPpw7dw4bNmxASkoKBg0a\nhJ49e8pdVpW1dOlSeHv/f3v398r6H8Bx/LkzubBxzDJaIblRSigXWHJBLlz7kRa3cqVc0FKUq7lS\nKAp/gLZwI0pZuZgr5UJRXGExy8evxgU6d6fOt9x8a3vbp9fjbrt61i5ee38+n7bfpjNsy7IslpaW\niEajpNNpFhYW6OjoMJ1lO5ubm1RXVzM+Ps7d3R3Dw8Ps7u6azvqRcmp84/E4nZ2dANTU1PD09MTr\n6ytut9twmf00NzdTX18PQFFREW9vb3x+fuJ0Og2X2c/l5SUXFxcagwyKx+O0tLTgdrtxu93Mzs6a\nTrIlj8fD+fk5AM/Pz3g8HsNFP1dOXXZOpVL/fJglJSXc398bLLIvp9NJQUEBAJFIhPb2dg1vhoTD\nYSYnJ01n2Nr19TXv7++MjIwwODhIPB43nWRLPT09JBIJurq6CAaDTExMmE76sXLq5Ptf+mXMzNvf\n3ycSibC+vm46xZa2trZoaGigoqLCdIrtPT4+sri4SCKRYGhoiIODAxwOh+ksW9ne3sbv97O2tsbZ\n2RmhUEjPMXwjp8bX5/ORSqX+vk4mk5SWlhossrfDw0OWl5dZXV2lsLDQdI4txWIxrq6uiMVi3N7e\nkp+fT3l5Oa2trabTbMXr9dLY2EheXh6VlZW4XC4eHh7wer2m02zl+PiYQCAAQG1tLclkUrervpFT\nl53b2trY29sD4PT0FJ/Pp/u9GfLy8sLc3BwrKysUFxebzrGt+fl5otEoGxsb9Pb2Mjo6quHNgEAg\nwNHREV9fX1iWRTqd1v3IDKiqquLk5ASAm5sbXC6XhvcbOXXybWpqoq6ujoGBARwOB9PT06aTbGtn\nZwfLshgbG/v7Xjgcxu/3G6wS+X/Kysro7u6mr68PgKmpKX79yqmzR07o7+8nFAoRDAb5+PhgZmbG\ndNKPpb8UFBERyTJ99RMREckyja+IiEiWaXxFRESyTOMrIiKSZRpfERGRLNP4ioiIZJnGV0REJMs0\nviIiIln2BzQKNGAGnBgwAAAAAElFTkSuQmCC\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f7a18df6b50\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "# Train our variables.\n", - "\n", - "# numpy is used for its asscalar() function.\n", - "import numpy as np\n", - "\n", - "num_training_steps = 10\n", - "\n", - "def train_model(inputs, labels, wb, optimizer, num_training_steps):\n", - " loss_at_step = []\n", - " w_at_step = []\n", - " b_at_step = []\n", - " for step_num in range(num_training_steps):\n", - " loss_at_step.append(run_step(inputs, labels))\n", - " w, b = wb.variables\n", - " w_at_step.append(np.asscalar(w.numpy()))\n", - " b_at_step.append(np.asscalar(b.numpy()))\n", - "\n", - " print(w_at_step)\n", - " t = range(0, num_training_steps)\n", - " plt.plot(t, loss_at_step, 'k',\n", - " t, w_at_step, 'r',\n", - " t, [true_w] * num_training_steps, 'r--',\n", - " t, b_at_step, 'b',\n", - " t, [true_b] * num_training_steps, 'b--')\n", - " plt.legend(['loss', 'w estimate', 'w true', 'b estimate', 'b true'])\n", - " plt.show()\n", - "\n", - "train_model(inputs, labels, wb, optimizer, num_training_steps)" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "Eager Execution Tutorial: Working with Gradients", - "provenance": [], - "version": "0.3.2", - "views": {} - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb deleted file mode 100644 index bfcc7feb075c403d024772e0d715339d58877a51..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb +++ /dev/null @@ -1,209 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "U9i2Dsh-ziXr" - }, - "source": [ - "# Eager Execution Tutorial: Importing Data\n", - "\n", - "This notebook demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/programmers_guide/datasets) to build pipelines to feed data to your program. It covers:\n", - "\n", - "* Creating a `Dataset`.\n", - "* Iteration over a `Dataset` with eager execution enabled.\n", - "\n", - "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n", - "\n", - "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n", - "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n", - "As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "z1JcS5iBXMRO" - }, - "source": [ - "# Setup: Enable eager execution\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RlIWhyeLoYnG" - }, - "outputs": [], - "source": [ - "# Import TensorFlow.\n", - "import tensorflow as tf\n", - "\n", - "# Enable eager execution\n", - "tf.enable_eager_execution()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "H9UySOPLXdaw" - }, - "source": [ - "# Step 1: Create a source `Dataset`\n", - "\n", - "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [Programmer's Guide](https://www.google.com/url?sa=D\u0026q=https%3A%2F%2Fwww.tensorflow.org%2Fprogrammers_guide%2Fdatasets%23reading_input_data) for more information." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "WPTUfGq6kJ5w" - }, - "outputs": [], - "source": [ - "ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n", - "\n", - "# Create a CSV file\n", - "import tempfile\n", - "_, filename = tempfile.mkstemp()\n", - "with open(filename, 'w') as f:\n", - " f.write(\"\"\"Line 1\n", - "Line 2\n", - "Line 3\n", - " \"\"\")\n", - "ds_file = tf.data.TextLineDataset(filename)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "twBfWd5xyu_d" - }, - "source": [ - "# Step 2: Apply transformations\n", - "\n", - "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for details." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "ngUe237Wt48W" - }, - "outputs": [], - "source": [ - "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n", - "ds_file = ds_file.batch(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "IDY4WsYRhP81" - }, - "source": [ - "# Step 3: Iterate\n", - "\n", - "When eager execution is enabled `Dataset` objects support iteration.\n", - "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 153 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 388, - "status": "ok", - "timestamp": 1525154629129, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "lCUWzso6mbqR", - "outputId": "8e4b0298-d27d-4ac7-e26a-ef94af0594ec" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Elements of ds_tensors:\n", - "tf.Tensor([1 9], shape=(2,), dtype=int32)\n", - "tf.Tensor([16 25], shape=(2,), dtype=int32)\n", - "tf.Tensor([ 4 36], shape=(2,), dtype=int32)\n", - "\n", - "Elements in ds_file:\n", - "tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)\n", - "tf.Tensor(['Line 3' ' '], shape=(2,), dtype=string)\n" - ] - } - ], - "source": [ - "print('Elements of ds_tensors:')\n", - "for x in ds_tensors:\n", - " print(x)\n", - "\n", - "print('\\nElements in ds_file:')\n", - "for x in ds_file:\n", - " print(x)" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "Eager Execution Tutorial: Importing Data", - "provenance": [], - "version": "0.3.2", - "views": {} - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/README.md b/tensorflow/contrib/eager/python/examples/notebooks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0d5ed848946d1eee643a57bf8c341520268c56b1 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/README.md @@ -0,0 +1,11 @@ +## Research and experimentation + +Eager execution provides an imperative, define-by-run interface for advanced +operations. Write custom layers, forward passes, and training loops with auto +differentiation. Start with these notebooks, then read the +[eager execution guide](https://www.tensorflow.org/guide/eager). + +1. [Eager execution basics](./eager_basics.ipynb) +2. [Automatic differentiation and gradient tapes](./automatic_differentiation.ipynb) +3. [Custom training: basics](./custom_training.ipynb) +4. [Custom layers](./custom_layers.ipynb) diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a18882fafa192bc4d4277d9d76fcd676b8295e04 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb @@ -0,0 +1,364 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "automatic_differentiation.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "metadata": { + "id": "t09eeeR5prIJ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "metadata": { + "id": "GCCk8_dHpuNf", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "form" + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "xh8WkEwWpnm7", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Automatic differentiation and gradient tape" + ] + }, + { + "metadata": { + "id": "idv0bPeCp325", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "
\n", + "\n", + " Run in Google Colab\n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "vDJ4XzMqodTy", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." + ] + }, + { + "metadata": { + "id": "GQJysDM__Qb0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Setup\n" + ] + }, + { + "metadata": { + "id": "OiMPZStlibBv", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "tfe = tf.contrib.eager # Shorthand for some symbols" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "1CLWJl0QliB0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Derivatives of a function\n", + "\n", + "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: " + ] + }, + { + "metadata": { + "id": "9FViq92UX7P8", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "from math import pi\n", + "\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", + "\n", + "assert f(pi/2).numpy() == 1.0\n", + "\n", + "\n", + "# grad_f will return a list of derivatives of f\n", + "# with respect to its arguments. Since f() has a single argument,\n", + "# grad_f will return a list with a single element.\n", + "grad_f = tfe.gradients_function(f)\n", + "assert tf.abs(grad_f(pi/2)[0]).numpy() < 1e-7" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "v9fPs8RyopCf", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Higher-order gradients\n", + "\n", + "The same API can be used to differentiate as many times as you like:\n" + ] + }, + { + "metadata": { + "id": "3D0ZvnGYo0rW", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def f(x):\n", + " return tf.square(tf.sin(x))\n", + "\n", + "def grad(f):\n", + " return lambda x: tfe.gradients_function(f)(x)[0]\n", + "\n", + "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(x, f(x), label=\"f\")\n", + "plt.plot(x, grad(f)(x), label=\"first derivative\")\n", + "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n", + "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", + "plt.legend()\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "-39gouo7mtgu", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Gradient tapes\n", + "\n", + "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", + "\n", + "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n", + "\n" + ] + }, + { + "metadata": { + "id": "MH0UfjympWf7", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def f(x, y):\n", + " output = 1\n", + " for i in range(y):\n", + " output = tf.multiply(output, x)\n", + " return output\n", + "\n", + "def g(x, y):\n", + " # Return the gradient of `f` with respect to it's first parameter\n", + " return tfe.gradients_function(f)(x, y)[0]\n", + "\n", + "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n", + "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", + "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", + "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "aNmR5-jhpX2t", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", + "\n", + "For example:" + ] + }, + { + "metadata": { + "id": "bAFeIE8EuVIq", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "x = tf.ones((2, 2))\n", + " \n", + "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n", + "# a single t.gradient() call when the bug is resolved.\n", + "with tf.GradientTape(persistent=True) as t:\n", + " # TODO(ashankar): Explain with \"watch\" argument better?\n", + " t.watch(x)\n", + " y = tf.reduce_sum(x)\n", + " z = tf.multiply(y, y)\n", + "\n", + "# Use the same tape to compute the derivative of z with respect to the\n", + "# intermediate value y.\n", + "dz_dy = t.gradient(z, y)\n", + "assert dz_dy.numpy() == 8.0\n", + "\n", + "# Derivative of z with respect to the original input tensor x\n", + "dz_dx = t.gradient(z, x)\n", + "for i in [0, 1]:\n", + " for j in [0, 1]:\n", + " assert dz_dx[i][j].numpy() == 8.0" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "DK05KXrAAld3", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Higher-order gradients\n", + "\n", + "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:" + ] + }, + { + "metadata": { + "id": "cPQgthZ7ugRJ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", + "\n", + "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n", + "\n", + "with tf.GradientTape() as t:\n", + " with tf.GradientTape() as t2:\n", + " t2.watch(x)\n", + " y = x * x * x\n", + " # Compute the gradient inside the 't' context manager\n", + " # which means the gradient computation is differentiable as well.\n", + " dy_dx = t2.gradient(y, x)\n", + "d2y_dx2 = t.gradient(dy_dx, x)\n", + "\n", + "assert dy_dx.numpy() == 3.0\n", + "assert d2y_dx2.numpy() == 6.0" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "4U1KKzUpNl58", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Next Steps\n", + "\n", + "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..54fbf2a7e18da0e8ec21ff6e01ea13b3a6a57ca4 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb @@ -0,0 +1,399 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "custom_layers.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "cells": [ + { + "metadata": { + "id": "tDnwEv8FtJm7", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "metadata": { + "id": "JlknJBWQtKkI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "form" + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "60RdWsg1tETW", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Custom layers" + ] + }, + { + "metadata": { + "id": "BcJg7Enms86w", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "
\n", + "\n", + " Run in Google Colab\n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "UEu3q4jmpKVT", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n" + ] + }, + { + "metadata": { + "id": "pwX7Fii1rwsJ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "tfe = tf.contrib.eager\n", + "\n", + "tf.enable_eager_execution()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "zSFfVVjkrrsI", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "metadata": { + "id": "8PyXlPl-4TzQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Fn69xxPO5Psr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "metadata": { + "id": "E3XKNknP5Mhb", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# To use a layer, simply call it.\n", + "layer(tf.zeros([10, 5]))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Wt_Nsv-L5t2s", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Layers have many useful methods. For example, you can inspect all variables\n", + "# in a layer by calling layer.variables. In this case a fully-connected layer\n", + "# will have variables for weights and biases.\n", + "layer.variables" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "6ilvKjz8_4MQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# The variables are also accessible through nice accessors\n", + "layer.kernel, layer.bias" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "O0kDbE54-5VS", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Implementing custom layers\n", + "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n", + " * `__init__` , where you can do all input-independent initialization\n", + " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", + " * `call`, where you do the forward computation\n", + "\n", + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." + ] + }, + { + "metadata": { + "id": "5Byl3n1k5kIy", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class MyDenseLayer(tf.keras.layers.Layer):\n", + " def __init__(self, num_outputs):\n", + " super(MyDenseLayer, self).__init__()\n", + " self.num_outputs = num_outputs\n", + " \n", + " def build(self, input_shape):\n", + " self.kernel = self.add_variable(\"kernel\", \n", + " shape=[input_shape[-1].value, \n", + " self.num_outputs])\n", + " \n", + " def call(self, input):\n", + " return tf.matmul(input, self.kernel)\n", + " \n", + "layer = MyDenseLayer(10)\n", + "print(layer(tf.zeros([10, 5])))\n", + "print(layer.variables)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "tk8E2vY0-z4Z", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", + "\n", + "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" + ] + }, + { + "metadata": { + "id": "Qhg4KlbKrs3G", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model." + ] + }, + { + "metadata": { + "id": "N30DTXiRASlb", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + "\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "wYfucVw65PMj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "metadata": { + "id": "L9frk7Ur4uvJ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "c5YwYcnuK-wc", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Next steps\n", + "\n", + "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..0a781d215308f04290aac2a74b5f0b1faf8b5406 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb @@ -0,0 +1,478 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Custom training: basics", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "metadata": { + "id": "5rmpybwysXGV", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "metadata": { + "id": "m8y3rGtQsYP2", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "form" + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "hrXv0rU9sIma", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Custom training: basics" + ] + }, + { + "metadata": { + "id": "7S0BwJ_8sLu7", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "
\n", + "\n", + " Run in Google Colab\n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "k2o3TTG4TFpt", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n", + "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n", + "\n", + "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation." + ] + }, + { + "metadata": { + "id": "3LXMVuV0VhDr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Setup" + ] + }, + { + "metadata": { + "id": "PJ64L90aVir3", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "tfe = tf.contrib.eager # Shorthand for some symbols\n", + "\n", + "tf.enable_eager_execution()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "eMAWbDJFVmMk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Variables\n", + "\n", + "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" + ] + }, + { + "metadata": { + "id": "VkJwtLS_Jbn8", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Using python state\n", + "x = tf.zeros([10, 10])\n", + "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", + " # value of x\n", + "print(x)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "wfneTXy7JcUz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", + "\n", + "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." + ] + }, + { + "metadata": { + "id": "itxmrMil6DQi", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "v = tfe.Variable(1.0)\n", + "assert v.numpy() == 1.0\n", + "\n", + "# Re-assign the value\n", + "v.assign(3.0)\n", + "assert v.numpy() == 3.0\n", + "\n", + "# Use `v` in a TensorFlow operation like tf.square() and reassign\n", + "v.assign(tf.square(v))\n", + "assert v.numpy() == 9.0" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "-paSaeq1JzwC", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", + "\n", + "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." + ] + }, + { + "metadata": { + "id": "BMiFcDzE7Qu3", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Example: Fitting a linear model\n", + "\n", + "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n", + "\n", + "1. Define the model.\n", + "2. Define a loss function.\n", + "3. Obtain training data.\n", + "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", + "\n", + "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`." + ] + }, + { + "metadata": { + "id": "gFzH64Jn9PIm", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Define the model\n", + "\n", + "Let's define a simple class to encapsulate the variables and the computation." + ] + }, + { + "metadata": { + "id": "_WRu7Pze7wk8", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class Model(object):\n", + " def __init__(self):\n", + " # Initialize variable to (5.0, 0.0)\n", + " # In practice, these should be initialized to random values.\n", + " self.W = tfe.Variable(5.0)\n", + " self.b = tfe.Variable(0.0)\n", + " \n", + " def __call__(self, x):\n", + " return self.W * x + self.b\n", + " \n", + "model = Model()\n", + "\n", + "assert model(3.0).numpy() == 15.0" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "xa6j_yXa-j79", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Define a loss function\n", + "\n", + "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss." + ] + }, + { + "metadata": { + "id": "Y0ysUFGY924U", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def loss(predicted_y, desired_y):\n", + " return tf.reduce_mean(tf.square(predicted_y - desired_y))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "qutT_fkl_CBc", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Obtain training data\n", + "\n", + "Let's synthesize the training data with some noise." + ] + }, + { + "metadata": { + "id": "gxPTb-kt_N5m", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "TRUE_W = 3.0\n", + "TRUE_b = 2.0\n", + "NUM_EXAMPLES = 1000\n", + "\n", + "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "outputs = inputs * TRUE_W + TRUE_b + noise" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "-50nq-wPBsAW", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue." + ] + }, + { + "metadata": { + "id": "_eb83LtrB4nt", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.scatter(inputs, outputs, c='b')\n", + "plt.scatter(inputs, model(inputs), c='r')\n", + "plt.show()\n", + "\n", + "print('Current loss: '),\n", + "print(loss(model(inputs), outputs).numpy())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "sSDP-yeq_4jE", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Define a training loop\n", + "\n", + "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves." + ] + }, + { + "metadata": { + "id": "MBIACgdnA55X", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def train(model, inputs, outputs, learning_rate):\n", + " with tf.GradientTape() as t:\n", + " current_loss = loss(model(inputs), outputs)\n", + " dW, db = t.gradient(current_loss, [model.W, model.b])\n", + " model.W.assign_sub(learning_rate * dW)\n", + " model.b.assign_sub(learning_rate * db)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "RwWPaJryD2aN", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve." + ] + }, + { + "metadata": { + "id": "XdfkR223D9dW", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "model = Model()\n", + "\n", + "# Collect the history of W-values and b-values to plot later\n", + "Ws, bs = [], []\n", + "epochs = range(10)\n", + "for epoch in epochs:\n", + " Ws.append(model.W.numpy())\n", + " bs.append(model.b.numpy())\n", + " current_loss = loss(model(inputs), outputs)\n", + "\n", + " train(model, inputs, outputs, learning_rate=0.1)\n", + " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", + " (epoch, Ws[-1], bs[-1], current_loss))\n", + "\n", + "# Let's plot it all\n", + "plt.plot(epochs, Ws, 'r',\n", + " epochs, bs, 'b')\n", + "plt.plot([TRUE_W] * len(epochs), 'r--',\n", + " [TRUE_b] * len(epochs), 'b--')\n", + "plt.legend(['W', 'b', 'true W', 'true_b'])\n", + "plt.show()\n", + " " + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "vPnIVuaSJwWz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Next Steps\n", + "\n", + "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n", + "\n", + "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n", + "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n", + "\n", + "The [next tutorial](TODO) will cover these higher level APIs." + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..b37a18c9a6091c927767a814c1131ef5739c810b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb @@ -0,0 +1,491 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "eager_basics.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "metadata": { + "id": "iPpI7RaYoZuE", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "metadata": { + "id": "hro2InpHobKk", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "form" + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "U9i2Dsh-ziXr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Eager execution basics" + ] + }, + { + "metadata": { + "id": "Hndw-YcxoOJK", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "
\n", + "\n", + " Run in Google Colab\n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "6sILUVbHoSgH", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This is an introductory tutorial for using TensorFlow. It will cover:\n", + "\n", + "* Importing required packages\n", + "* Creating and using Tensors\n", + "* Using GPU acceleration\n", + "* Datasets" + ] + }, + { + "metadata": { + "id": "z1JcS5iBXMRO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Import TensorFlow\n", + "\n", + "To get started, import the `tensorflow` module and enable eager execution.\n", + "Eager execution enables a more interactive frontend to TensorFlow, the details of which we will discuss much later." + ] + }, + { + "metadata": { + "id": "RlIWhyeLoYnG", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "code" + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "\n", + "tf.enable_eager_execution()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "H9UySOPLXdaw", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Tensors\n", + "\n", + "A Tensor is a multi-dimensional array. Similar to NumPy `ndarray` objects, `Tensor` objects have a data type and a shape. Additionally, Tensors can reside in accelerator (like GPU) memory. TensorFlow offers a rich library of operations ([tf.add](https://www.tensorflow.org/api_docs/python/tf/add), [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul), [tf.linalg.inv](https://www.tensorflow.org/api_docs/python/tf/linalg/inv) etc.) that consume and produce Tensors. These operations automatically convert native Python types. For example:\n" + ] + }, + { + "metadata": { + "id": "ngUe237Wt48W", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "code" + }, + "cell_type": "code", + "source": [ + "print(tf.add(1, 2))\n", + "print(tf.add([1, 2], [3, 4]))\n", + "print(tf.square(5))\n", + "print(tf.reduce_sum([1, 2, 3]))\n", + "print(tf.encode_base64(\"hello world\"))\n", + "\n", + "# Operator overloading is also supported\n", + "print(tf.square(2) + tf.square(3))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "IDY4WsYRhP81", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Each Tensor has a shape and a datatype" + ] + }, + { + "metadata": { + "id": "srYWH1MdJNG7", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "x = tf.matmul([[1]], [[2, 3]])\n", + "print(x.shape)\n", + "print(x.dtype)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "eBPw8e8vrsom", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n", + "\n", + "1. Tensors can be backed by accelerator memory (like GPU, TPU).\n", + "2. Tensors are immutable." + ] + }, + { + "metadata": { + "id": "Dwi1tdW3JBw6", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### NumPy Compatibility\n", + "\n", + "Conversion between TensorFlow Tensors and NumPy ndarrays is quite simple as:\n", + "* TensorFlow operations automatically convert NumPy ndarrays to Tensors.\n", + "* NumPy operations automatically convert Tensors to NumPy ndarrays.\n", + "\n", + "Tensors can be explicitly converted to NumPy ndarrays by invoking the `.numpy()` method on them.\n", + "These conversions are typically cheap as the array and Tensor share the underlying memory representation if possible. However, sharing the underlying representation isn't always possible since the Tensor may be hosted in GPU memory while NumPy arrays are always backed by host memory, and the conversion will thus involve a copy from GPU to host memory." + ] + }, + { + "metadata": { + "id": "lCUWzso6mbqR", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import numpy as np\n", + "\n", + "ndarray = np.ones([3, 3])\n", + "\n", + "print(\"TensorFlow operations convert numpy arrays to Tensors automatically\")\n", + "tensor = tf.multiply(ndarray, 42)\n", + "print(tensor)\n", + "\n", + "\n", + "print(\"And NumPy operations convert Tensors to numpy arrays automatically\")\n", + "print(np.add(tensor, 1))\n", + "\n", + "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n", + "print(tensor.numpy())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "PBNP8yTRfu_X", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## GPU acceleration\n", + "\n", + "Many TensorFlow operations can be accelerated by using the GPU for computation. Without any annotations, TensorFlow automatically decides whether to use the GPU or CPU for an operation (and copies the tensor between CPU and GPU memory if necessary). Tensors produced by an operation are typically backed by the memory of the device on which the operation executed. For example:" + ] + }, + { + "metadata": { + "id": "3Twf_Rw-gQFM", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "code" + }, + "cell_type": "code", + "source": [ + "x = tf.random_uniform([3, 3])\n", + "\n", + "print(\"Is there a GPU available: \"),\n", + "print(tf.test.is_gpu_available())\n", + "\n", + "print(\"Is the Tensor on GPU #0: \"),\n", + "print(x.device.endswith('GPU:0'))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "vpgYzgVXW2Ud", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Device Names\n", + "\n", + "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:` if the tensor is placed on the `N`-th tensor on the host." + ] + }, + { + "metadata": { + "id": "ZWZQCimzuqyP", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "\n", + "\n", + "### Explicit Device Placement\n", + "\n", + "The term \"placement\" in TensorFlow refers to how individual operations are assigned (placed on) a device for execution. As mentioned above, when there is no explicit guidance provided, TensorFlow automatically decides which device to execute an operation, and copies Tensors to that device if needed. However, TensorFlow operations can be explicitly placed on specific devices using the `tf.device` context manager. For example:" + ] + }, + { + "metadata": { + "id": "RjkNZTuauy-Q", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def time_matmul(x):\n", + " %timeit tf.matmul(x, x)\n", + "\n", + "# Force execution on CPU\n", + "print(\"On CPU:\")\n", + "with tf.device(\"CPU:0\"):\n", + " x = tf.random_uniform([1000, 1000])\n", + " assert x.device.endswith(\"CPU:0\")\n", + " time_matmul(x)\n", + "\n", + "# Force execution on GPU #0 if available\n", + "if tf.test.is_gpu_available():\n", + " with tf.device(\"GPU:0\"): # Or GPU:1 for the 2nd GPU, GPU:2 for the 3rd etc.\n", + " x = tf.random_uniform([1000, 1000])\n", + " assert x.device.endswith(\"GPU:0\")\n", + " time_matmul(x)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "o1K4dlhhHtQj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Datasets\n", + "\n", + "This section demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/guide/datasets) to build pipelines to feed data to your model. It covers:\n", + "\n", + "* Creating a `Dataset`.\n", + "* Iteration over a `Dataset` with eager execution enabled.\n", + "\n", + "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n", + "\n", + "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n", + "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n", + "As a result, the discussion on iterators in the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets) is not relevant when eager execution is enabled." + ] + }, + { + "metadata": { + "id": "zI0fmOynH-Ne", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Create a source `Dataset`\n", + "\n", + "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets#reading_input_data) for more information." + ] + }, + { + "metadata": { + "id": "F04fVOHQIBiG", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n", + "\n", + "# Create a CSV file\n", + "import tempfile\n", + "_, filename = tempfile.mkstemp()\n", + "\n", + "with open(filename, 'w') as f:\n", + " f.write(\"\"\"Line 1\n", + "Line 2\n", + "Line 3\n", + " \"\"\")\n", + "\n", + "ds_file = tf.data.TextLineDataset(filename)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "vbxIhC-5IPdf", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Apply transformations\n", + "\n", + "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for details." + ] + }, + { + "metadata": { + "id": "uXSDZWE-ISsd", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n", + "\n", + "ds_file = ds_file.batch(2)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "A8X1GNfoIZKJ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Iterate\n", + "\n", + "When eager execution is enabled `Dataset` objects support iteration.\n", + "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls." + ] + }, + { + "metadata": { + "id": "ws-WKRk5Ic6-", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "print('Elements of ds_tensors:')\n", + "for x in ds_tensors:\n", + " print(x)\n", + "\n", + "print('\\nElements in ds_file:')\n", + "for x in ds_file:\n", + " print(x)" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index 0c0e28dd95c68dc300384a128eb5aa2208f63a0d..68a84d5fbb4f13e4ebe0d71e3f5caebe97e2101c 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -51,5 +51,6 @@ cuda_py_test( "noasan", "nomsan", "notsan", + "optonly", ], ) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index b8f352d5f5b72ffb8ae81a2bb72974c7fd65bd5a..b14ef1df8ff4c660b9b6f2abfd5df6572d10b1e8 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -49,15 +49,17 @@ def random_batch(batch_size, data_format): return images, one_hot -def train_one_step(model, images, labels, optimizer): - - with tfe.GradientTape() as tape: +def compute_gradients(model, images, labels): + with tf.GradientTape() as tape: logits = model(images, training=True) loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) tf.contrib.summary.scalar(name='loss', tensor=loss) - grads = tape.gradient(loss, model.variables) - optimizer.apply_gradients(zip(grads, model.variables)) + return tape.gradient(loss, model.variables) + + +def apply_gradients(model, optimizer, gradients): + optimizer.apply_gradients(zip(gradients, model.variables)) class ResNet50Test(tf.test.TestCase): @@ -114,7 +116,8 @@ class ResNet50Test(tf.test.TestCase): with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) images, labels = random_batch(2, data_format) - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) self.assertEqual(320, len(model.variables)) tfe.async_wait() events = summary_test_util.events_from_logdir(logdir) @@ -138,14 +141,16 @@ class ResNet50Test(tf.test.TestCase): # garbage to be collected. The hope is that this is a build-only effect, # and a subsequent training loop will create nothing which needs to be # collected. - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) gc.collect() previous_gc_debug_flags = gc.get_debug() gc.set_debug(gc.DEBUG_SAVEALL) for _ in range(2): # Run twice to ensure that garbage that is created on the first # iteration is no longer accessible. - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) gc.collect() # There should be no garbage requiring collection. self.assertEqual(0, len(gc.garbage)) @@ -180,9 +185,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): return (16, 32, 64) if tf.DeviceSpec.from_string(device.name).device_type == 'TPU': - # TODO(iga): Training fails with batch size of 16, probably because of - # no layout optimizations with op-by-op mode. Investigate more. - return (8,) + return (32,) return (16, 32) def _report(self, label, start, num_iters, device, batch_size, data_format): @@ -248,18 +251,21 @@ class ResNet50Benchmarks(tf.test.Benchmark): device, data_format = device_and_format for batch_size in self._train_batch_sizes(): (images, labels) = random_batch(batch_size, data_format) - num_burn = 3 - num_iters = 10 model = resnet50.ResNet50(data_format) + optimizer = tf.train.GradientDescentOptimizer(0.1) + apply_grads = apply_gradients if defun: model.call = tfe.defun(model.call, compiled=compiled) - optimizer = tf.train.GradientDescentOptimizer(0.1) + apply_grads = tfe.defun(apply_gradients, compiled=compiled) + num_burn = 3 + num_iters = 10 with tf.device(device): iterator = make_iterator((images, labels)) for _ in xrange(num_burn): (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) if execution_mode: tfe.async_wait() self._force_device_sync() @@ -268,7 +274,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): start = time.time() for _ in xrange(num_iters): (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) if execution_mode: tfe.async_wait() self._force_device_sync() diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..81c9facfb5f00c45c8f26c1cd4284b98fb73dd23 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -0,0 +1,115 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# Model +py_library( + name = "ops", + srcs = ["ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "config", + srcs = ["config.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "blocks", + srcs = ["blocks.py"], + srcs_version = "PY2AND3", + deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "revnet", + srcs = ["revnet.py"], + srcs_version = "PY2AND3", + deps = [ + ":blocks", + "//tensorflow:tensorflow_py", + ], +) + +# Tests +cuda_py_test( + name = "ops_test", + size = "large", + srcs = ["ops_test.py"], + additional_deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "blocks_test", + size = "large", + srcs = ["blocks_test.py"], + additional_deps = [ + ":blocks", + "//tensorflow:tensorflow_py", + ], + tags = [ + "optonly", + ], +) + +cuda_py_test( + name = "revnet_test", + size = "large", + srcs = ["revnet_test.py"], + additional_deps = [ + ":blocks_test", + ":config", + ":revnet", + "//tensorflow:tensorflow_py", + ], + tags = [ + "no_pip", + "optonly", + ], +) + +# Training +py_library( + name = "cifar_input", + srcs = ["cifar_input.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "cifar_tfrecords", + srcs = ["cifar_tfrecords.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "main", + srcs = ["main.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar_input", + ":config", + ":revnet", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/revnet/README.md b/tensorflow/contrib/eager/python/examples/revnet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..21fc44febc8abdc30daad1b35d8434b083360bdf --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/README.md @@ -0,0 +1,45 @@ +# RevNet with TensorFlow eager execution + +This folder contains an TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran both in eager and graph mode. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the step of reconstructing the outputs. This saves us from using `tf.stop_gradient` and makes the model run faster. + +## Content + +- `revnet.py`: The RevNet model. +- `blocks.py`: The relevant reversible blocks. +- `cifar_tfrecords.py`: Script to generate the TFRecords for both CIFAR-10 and CIFAR-100. +- `cifar_input.py`: Script to read from TFRecords and generate dataset objects with the `tf.data` API. +- `config.py`: Configuration file for network architectures and training hyperparameters. +- `main.py`: Main training and evaluation script. +- `ops.py`: Auxiliary downsampling operation. + +## To run +- Make sure you have installed TensorFlow 1.9+ or the latest `tf-nightly` +or `tf-nightly-gpu` pip package in order to access the eager execution feature. + +- First run + +```bash +python cifar_tfrecords.py --data_dir ${PWD}/cifar +``` +to download the cifar dataset and convert them +to TFRecords. This produces TFRecord files for both CIFAR-10 and CIFAR-100. + +- To train a model run + +```bash +python main.py --data_dir ${PWD}/cifar +``` + +- Optional arguments for `main.py` include + - `train_dir`: Directory to store eventfiles and checkpoints. + - `restore`: Restore the latest checkpoint. + - `validate`: Use validation set for training monitoring. + - `manual_grad`: Use the manually defined gradient map given by the authors. + - `dataset`: Use either `cifar-10` or `cifar-100` + +## Performance +- With the current implementation, RevNet-38 achieves >92% on CIFAR-10 and >71% on CIFAR-100. + +## Reference +The Reversible Residual Network: Backpropagation Without Storing Activations. +Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse. Neural Information Processing Systems (NIPS), 2017. diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..306096e9f8c4da0ed7f893ae75067cd24e7274b1 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -0,0 +1,357 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Building blocks with manual backward gradient computation. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import ops + + +class RevBlock(tf.keras.Model): + """Single reversible block containing several `_Residual` blocks. + + Each `_Residual` block in turn contains two _ResidualInner blocks, + corresponding to the `F`/`G` functions in the paper. + """ + + def __init__(self, + n_res, + filters, + strides, + input_shape, + batch_norm_first=False, + data_format="channels_first", + bottleneck=False, + fused=True, + dtype=tf.float32): + """Initialize RevBlock. + + Args: + n_res: number of residual blocks + filters: list/tuple of integers for output filter sizes of each residual + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + bottleneck: use bottleneck residual if True + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + """ + super(RevBlock, self).__init__() + self.blocks = tf.contrib.checkpoint.List() + for i in range(n_res): + curr_batch_norm_first = batch_norm_first and i == 0 + curr_strides = strides if i == 0 else (1, 1) + block = _Residual( + filters, + curr_strides, + input_shape, + batch_norm_first=curr_batch_norm_first, + data_format=data_format, + bottleneck=bottleneck, + fused=fused, + dtype=dtype) + self.blocks.append(block) + + if data_format == "channels_first": + input_shape = (filters, input_shape[1] // curr_strides[0], + input_shape[2] // curr_strides[1]) + else: + input_shape = (input_shape[0] // curr_strides[0], + input_shape[1] // curr_strides[1], filters) + + def call(self, h, training=True): + """Apply reversible block to inputs.""" + + for block in self.blocks: + h = block(h, training=training) + return h + + def backward_grads_and_vars(self, x, y, dy, training=True): + """Apply reversible block backward to outputs.""" + + grads_all = [] + vars_all = [] + + for i in reversed(range(len(self.blocks))): + block = self.blocks[i] + if i == 0: + # First block usually contains downsampling that can't be reversed + with tf.GradientTape() as tape: + x = tf.identity(x) + tape.watch(x) + y = block(x, training=training) + + grads_combined = tape.gradient( + y, [x] + block.trainable_variables, output_gradients=dy) + dy = grads_combined[0] + grads_all += grads_combined[1:] + vars_all += block.trainable_variables + else: + y, dy, grads, vars_ = block.backward_grads_and_vars( + y, dy, training=training) + grads_all += grads + vars_all += vars_ + + return dy, grads_all, vars_all + + +class _Residual(tf.keras.Model): + """Single residual block contained in a _RevBlock. Each `_Residual` object has + two _ResidualInner objects, corresponding to the `F` and `G` functions in the + paper. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC", + bottleneck: use bottleneck residual if True + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + """ + + def __init__(self, + filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + bottleneck=False, + fused=True, + dtype=tf.float32): + super(_Residual, self).__init__() + + self.filters = filters + self.strides = strides + self.axis = 1 if data_format == "channels_first" else 3 + if data_format == "channels_first": + f_input_shape = (input_shape[0] // 2,) + input_shape[1:] + g_input_shape = (filters // 2, input_shape[1] // strides[0], + input_shape[2] // strides[1]) + else: + f_input_shape = input_shape[:2] + (input_shape[2] // 2,) + g_input_shape = (input_shape[0] // strides[0], + input_shape[1] // strides[1], filters // 2) + + factory = _BottleneckResidualInner if bottleneck else _ResidualInner + self.f = factory( + filters=filters // 2, + strides=strides, + input_shape=f_input_shape, + batch_norm_first=batch_norm_first, + data_format=data_format, + fused=fused, + dtype=dtype) + self.g = factory( + filters=filters // 2, + strides=(1, 1), + input_shape=g_input_shape, + batch_norm_first=batch_norm_first, + data_format=data_format, + fused=fused, + dtype=dtype) + + def call(self, x, training=True, concat=True): + """Apply residual block to inputs.""" + + x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) + f_x2 = self.f(x2, training=training) + x1_down = ops.downsample( + x1, self.filters // 2, self.strides, axis=self.axis) + x2_down = ops.downsample( + x2, self.filters // 2, self.strides, axis=self.axis) + y1 = f_x2 + x1_down + g_y1 = self.g(y1, training=training) + y2 = g_y1 + x2_down + if not concat: # For correct backward grads + return y1, y2 + + return tf.concat([y1, y2], axis=self.axis) + + def backward_grads_and_vars(self, y, dy, training=True): + """Manually compute backward gradients given input and output grads.""" + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis) + + with tf.GradientTape(persistent=True) as tape: + y = tf.identity(y) + tape.watch(y) + y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis) + z1 = y1 + gz1 = self.g(z1, training=training) + x2 = y2 - gz1 + fx2 = self.f(x2, training=training) + x1 = z1 - fx2 + + grads_combined = tape.gradient( + gz1, [z1] + self.g.trainable_variables, output_gradients=dy2) + dz1 = dy1 + grads_combined[0] + dg = grads_combined[1:] + dx1 = dz1 + + grads_combined = tape.gradient( + fx2, [x2] + self.f.trainable_variables, output_gradients=dz1) + dx2 = dy2 + grads_combined[0] + df = grads_combined[1:] + + del tape + + grads = df + dg + vars_ = self.f.trainable_variables + self.g.trainable_variables + + x = tf.concat([x1, x2], axis=self.axis) + dx = tf.concat([dx1, dx2], axis=self.axis) + + return x, dx, grads, vars_ + + +def _BottleneckResidualInner(filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True, + dtype=tf.float32): + """Single bottleneck residual inner function contained in _Resdual. + + Corresponds to the `F`/`G` functions in the paper. + Suitable for training on ImageNet dataset. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + + Returns: + A keras model + """ + + axis = 1 if data_format == "channels_first" else 3 + model = tf.keras.Sequential() + if batch_norm_first: + model.add( + tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused, dtype=dtype)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=1, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype)) + + model.add( + tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype)) + + model.add( + tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=1, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype)) + + return model + + +def _ResidualInner(filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True, + dtype=tf.float32): + """Single residual inner function contained in _ResdualBlock. + + Corresponds to the `F`/`G` functions in the paper. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + + Returns: + A keras model + """ + + axis = 1 if data_format == "channels_first" else 3 + model = tf.keras.Sequential() + if batch_norm_first: + model.add( + tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused, dtype=dtype)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype)) + + model.add( + tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype)) + + return model diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d74785c8fe1c170ee95172974141c1cfe18b9502 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py @@ -0,0 +1,304 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for basic building blocks used in eager mode RevNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import blocks + + +def compute_degree(g1, g2, eps=1e-7): + """Compute the degree between two vectors using their usual inner product.""" + + def _dot(u, v): + return tf.reduce_sum(u * v) + + g1_norm = tf.sqrt(_dot(g1, g1)) + g2_norm = tf.sqrt(_dot(g2, g2)) + if g1_norm.numpy() == 0 and g2_norm.numpy() == 0: + cosine = 1. - eps + else: + g1_norm = 1. if g1_norm.numpy() == 0 else g1_norm + g2_norm = 1. if g2_norm.numpy() == 0 else g2_norm + cosine = _dot(g1, g2) / g1_norm / g2_norm + # Restrict to arccos range + cosine = tf.minimum(tf.maximum(cosine, eps - 1.), 1. - eps) + degree = tf.acos(cosine) * 180. / 3.141592653589793 + + return degree + + +def _validate_block_call_channels_last(block_factory, test): + """Generic testing function for `channels_last` data format. + + Completes a set of tests varying data format, stride, and batch normalization + configured train vs test time. + Args: + block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock, + blocks._ResidualInner + test: tf.test.TestCase object + """ + with tf.device("/cpu:0"): # NHWC format + input_shape = (8, 8, 128) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + block = block_factory( + filters=128, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 8, 8, 128)) + test.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = block_factory( + filters=128, + strides=(2, 2), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 4, 4, 128)) + test.assertNotAllClose(y_tr, y_ev) + + +def _validate_block_call_channels_first(block_factory, test): + """Generic testing function for `channels_first` data format. + + Completes a set of tests varying data format, stride, and batch normalization + configured train vs test time. + Args: + block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock, + blocks._ResidualInner + test: tf.test.TestCase object + """ + if not tf.test.is_gpu_available(): + test.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride of 1 + block = block_factory(filters=128, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 128, 8, 8)) + test.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = block_factory(filters=128, strides=(2, 2), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 128, 4, 4)) + test.assertNotAllClose(y_tr, y_ev) + + +class RevBlockTest(tf.test.TestCase): + + def test_call_channels_first(self): + """Test `call` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride of 1 + block = blocks.RevBlock( + n_res=3, filters=128, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, (16, 128, 8, 8)) + self.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = blocks.RevBlock( + n_res=3, filters=128, strides=(2, 2), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, [16, 128, 4, 4]) + self.assertNotAllClose(y_tr, y_ev) + + def test_call_channels_last(self): + """Test `call` function with `channels_last` data format.""" + with tf.device("/cpu:0"): # NHWC format + input_shape = (8, 8, 128) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, (16, 8, 8, 128)) + self.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(2, 2), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, (16, 4, 4, 128)) + self.assertNotAllClose(y_tr, y_ev) + + def _check_grad_angle(self, grads, grads_true, atol=1e0): + """Check the angle between two list of vectors are all close.""" + for g1, g2 in zip(grads, grads_true): + degree = compute_degree(g1, g2) + self.assertLessEqual(degree, atol) + + def test_backward_grads_and_vars_channels_first(self): + """Test `backward` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + # Stride 1 + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(1, 1), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + y = block(x, training=True) + # Compute grads from reconstruction + dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True) + # Compute true grads + grads = tape.gradient(y, [x] + vars_, output_gradients=dy) + dx_true, dw_true = grads[0], grads[1:] + self.assertAllClose(dx_true, dx) + self.assertAllClose(dw_true, dw) + self._check_grad_angle(dx_true, dx) + self._check_grad_angle(dw_true, dw) + + # Stride 2 + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=(16, 128, 4, 4), dtype=tf.float64) + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(2, 2), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + y = block(x, training=True) + # Compute grads from reconstruction + dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True) + # Compute true grads + grads = tape.gradient(y, [x] + vars_, output_gradients=dy) + dx_true, dw_true = grads[0], grads[1:] + self.assertAllClose(dx_true, dx) + self.assertAllClose(dw_true, dw) + self._check_grad_angle(dx_true, dx) + self._check_grad_angle(dw_true, dw) + + +class _ResidualTest(tf.test.TestCase): + + def test_call(self): + """Test `call` function. + + Varying downsampling and data format options. + """ + + _validate_block_call_channels_first(blocks._Residual, self) + _validate_block_call_channels_last(blocks._Residual, self) + + def test_backward_grads_and_vars_channels_first(self): + """Test `backward_grads` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + # Use double precision for testing + x_true = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + residual = blocks._Residual( + filters=128, + strides=(1, 1), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + + with tf.GradientTape() as tape: + x_true = tf.identity(x_true) + tape.watch(x_true) + y = residual(x_true, training=True) + + # Gradients computed due to reversibility + x, dx, dw, vars_ = residual.backward_grads_and_vars( + y, dy=dy, training=True) + + # True gradients computed by the tape + grads = tape.gradient(y, [x_true] + vars_, output_gradients=dy) + dx_true, dw_true = grads[0], grads[1:] + + self.assertAllClose(x_true, x) + self.assertAllClose(dx_true, dx) + self.assertAllClose(dw_true, dw) + + +class _ResidualInnerTest(tf.test.TestCase): + + def test_call(self): + """Test `call` function.""" + + _validate_block_call_channels_first(blocks._ResidualInner, self) + _validate_block_call_channels_last(blocks._ResidualInner, self) + + +class _BottleneckResidualInner(tf.test.TestCase): + + def test_call(self): + """Test `call` function.""" + + _validate_block_call_channels_first(blocks._BottleneckResidualInner, self) + _validate_block_call_channels_last(blocks._BottleneckResidualInner, self) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d4c35bfd21f9d651c4f059c019cf2e585da8b2 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py @@ -0,0 +1,116 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script for reading and loading CIFAR-10.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf + +# Global constants describing the CIFAR data set. +IMAGE_HEIGHT = 32 +IMAGE_WIDTH = 32 +NUM_CHANNEL = 3 + + +def get_ds_from_tfrecords(data_dir, + split, + data_aug=True, + batch_size=100, + epochs=None, + shuffle=True, + data_format="channels_first", + num_parallel_calls=12, + prefetch=0, + div255=True, + dtype=tf.float32): + """Returns a tf.train.Dataset object from reading tfrecords. + + Args: + data_dir: Directory of tfrecords + split: "train", "validation", or "test" + data_aug: Apply data augmentation if True + batch_size: Batch size of dataset object + epochs: Number of epochs to repeat the dataset; default `None` means + repeating indefinitely + shuffle: Shuffle the dataset if True + data_format: `channels_first` or `channels_last` + num_parallel_calls: Number of threads for dataset preprocess + prefetch: Buffer size for prefetch + div255: Divide the images by 255 if True + dtype: Data type of images + Returns: + A tf.train.Dataset object + + Raises: + ValueError: Unknown split + """ + + if split not in ["train", "validation", "test", "train_all"]: + raise ValueError("Unknown split {}".format(split)) + + def _parser(serialized_example): + """Parses a single tf.Example into image and label tensors.""" + features = tf.parse_single_example( + serialized_example, + features={ + "image": tf.FixedLenFeature([], tf.string), + "label": tf.FixedLenFeature([], tf.int64), + }) + image = tf.decode_raw(features["image"], tf.uint8) + # Initially reshaping to [H, W, C] does not work + image = tf.reshape(image, [NUM_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH]) + # This is needed for `tf.image.resize_image_with_crop_or_pad` + image = tf.transpose(image, [1, 2, 0]) + + image = tf.cast(image, dtype) + label = tf.cast(features["label"], tf.int32) + + if data_aug: + image = tf.image.resize_image_with_crop_or_pad(image, IMAGE_HEIGHT + 4, + IMAGE_WIDTH + 4) + image = tf.random_crop(image, [IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNEL]) + image = tf.image.random_flip_left_right(image) + + if data_format == "channels_first": + image = tf.transpose(image, [2, 0, 1]) + + if div255: + image /= 255. + + return image, label + + filename = os.path.join(data_dir, split + ".tfrecords") + dataset = tf.data.TFRecordDataset(filename) + dataset = dataset.repeat(epochs) + dataset = dataset.map(_parser, num_parallel_calls=num_parallel_calls) + dataset = dataset.prefetch(prefetch) + + if shuffle: + # Find the right size according to the split + size = { + "train": 40000, + "validation": 10000, + "test": 10000, + "train_all": 50000 + }[split] + dataset = dataset.shuffle(size) + + dataset = dataset.batch(batch_size) + + return dataset diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py new file mode 100644 index 0000000000000000000000000000000000000000..377844ad8fbca92629a4d71f5df2aab67b570c3c --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py @@ -0,0 +1,154 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Read CIFAR data from pickled numpy arrays and writes TFRecords. + +Generates tf.train.Example protos and writes them to TFRecord files from the +python version of the CIFAR dataset downloaded from +https://www.cs.toronto.edu/~kriz/cifar.html. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import tarfile + +from absl import flags +from six.moves import cPickle as pickle +from six.moves import urllib +import tensorflow as tf + +BASE_URL = 'https://www.cs.toronto.edu/~kriz/' +CIFAR_FILE_NAMES = ['cifar-10-python.tar.gz', 'cifar-100-python.tar.gz'] +CIFAR_DOWNLOAD_URLS = [BASE_URL + name for name in CIFAR_FILE_NAMES] +CIFAR_LOCAL_FOLDERS = ['cifar-10', 'cifar-100'] +EXTRACT_FOLDERS = ['cifar-10-batches-py', 'cifar-100-python'] + + +def download_and_extract(data_dir, file_name, url): + """Download CIFAR if not already downloaded.""" + filepath = os.path.join(data_dir, file_name) + if tf.gfile.Exists(filepath): + return filepath + if not tf.gfile.Exists(data_dir): + tf.gfile.MakeDirs(data_dir) + + urllib.request.urlretrieve(url, filepath) + tarfile.open(os.path.join(filepath), 'r:gz').extractall(data_dir) + return filepath + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _get_file_names(folder): + """Returns the file names expected to exist in the input_dir.""" + assert folder in ['cifar-10', 'cifar-100'] + + file_names = {} + if folder == 'cifar-10': + file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)] + file_names['validation'] = ['data_batch_5'] + file_names['train_all'] = ['data_batch_%d' % i for i in range(1, 6)] + file_names['test'] = ['test_batch'] + else: + file_names['train_all'] = ['train'] + file_names['test'] = ['test'] + # Split in `convert_to_tfrecord` function + file_names['train'] = ['train'] + file_names['validation'] = ['train'] + return file_names + + +def read_pickle_from_file(filename): + with tf.gfile.Open(filename, 'rb') as f: + if sys.version_info >= (3, 0): + data_dict = pickle.load(f, encoding='bytes') + else: + data_dict = pickle.load(f) + return data_dict + + +def convert_to_tfrecord(input_files, output_file, folder): + """Converts files with pickled data to TFRecords.""" + assert folder in ['cifar-10', 'cifar-100'] + + print('Generating %s' % output_file) + with tf.python_io.TFRecordWriter(output_file) as record_writer: + for input_file in input_files: + data_dict = read_pickle_from_file(input_file) + data = data_dict[b'data'] + try: + labels = data_dict[b'labels'] + except KeyError: + labels = data_dict[b'fine_labels'] + + if folder == 'cifar-100' and input_file.endswith('train.tfrecords'): + data = data[:40000] + labels = labels[:40000] + elif folder == 'cifar-100' and input_file.endswith( + 'validation.tfrecords'): + data = data[40000:] + labels = labels[40000:] + + num_entries_in_batch = len(labels) + + for i in range(num_entries_in_batch): + example = tf.train.Example( + features=tf.train.Features( + feature={ + 'image': _bytes_feature(data[i].tobytes()), + 'label': _int64_feature(labels[i]) + })) + record_writer.write(example.SerializeToString()) + + +def main(_): + for file_name, url, folder, extract_folder in zip( + CIFAR_FILE_NAMES, CIFAR_DOWNLOAD_URLS, CIFAR_LOCAL_FOLDERS, + EXTRACT_FOLDERS): + print('Download from {} and extract.'.format(url)) + data_dir = os.path.join(FLAGS.data_dir, folder) + download_and_extract(data_dir, file_name, url) + file_names = _get_file_names(folder) + input_dir = os.path.join(data_dir, extract_folder) + + for mode, files in file_names.items(): + input_files = [os.path.join(input_dir, f) for f in files] + output_file = os.path.join(data_dir, mode + '.tfrecords') + try: + os.remove(output_file) + except OSError: + pass + convert_to_tfrecord(input_files, output_file, folder) + + print('Done!') + + +if __name__ == '__main__': + FLAGS = flags.FLAGS + flags.DEFINE_string( + 'data_dir', + default=None, + help='Directory to download, extract and store TFRecords.') + + tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py new file mode 100644 index 0000000000000000000000000000000000000000..3d93fa955a29718fdec52b04500c41f77351dd8d --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/config.py @@ -0,0 +1,140 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Configuration in format of tf.contrib.training.HParams. +Supports CIFAR-10, CIFAR-100, and ImageNet datasets. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +tfe = tf.contrib.eager + + +def get_hparams_cifar_38(): + """RevNet-38 configurations for CIFAR-10/CIFAR-100.""" + + config = tf.contrib.training.HParams() + config.add_hparam("init_filters", 32) + config.add_hparam("init_kernel", 3) + config.add_hparam("init_stride", 1) + config.add_hparam("n_classes", 10) + config.add_hparam("n_rev_blocks", 3) + config.add_hparam("n_res", [3, 3, 3]) + config.add_hparam("filters", [32, 64, 112]) + config.add_hparam("strides", [1, 2, 2]) + config.add_hparam("batch_size", 100) + config.add_hparam("bottleneck", False) + config.add_hparam("fused", True) + config.add_hparam("init_max_pool", False) + if tfe.num_gpus() > 0: + config.add_hparam("input_shape", (3, 32, 32)) + config.add_hparam("data_format", "channels_first") + else: + config.add_hparam("input_shape", (32, 32, 3)) + config.add_hparam("data_format", "channels_last") + + # Training details + config.add_hparam("weight_decay", 2e-4) + config.add_hparam("momentum", .9) + config.add_hparam("lr_decay_steps", [40000, 60000]) + config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3]) + config.add_hparam("max_train_iter", 80000) + config.add_hparam("seed", 1234) + config.add_hparam("shuffle", True) + config.add_hparam("log_every", 500) + config.add_hparam("save_every", 500) + config.add_hparam("dtype", tf.float32) + config.add_hparam("eval_batch_size", 1000) + config.add_hparam("div255", True) + # This is imprecise, when training with validation set, + # we only have 40k images in training data + config.add_hparam("iters_per_epoch", 50000 // config.batch_size) + config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) + + return config + + +def get_hparams_cifar_110(): + config = get_hparams_cifar_38() + config.filters = [32, 64, 128] + config.n_res = [9, 9, 9] + + return config + + +def get_hparams_cifar_164(): + config = get_hparams_cifar_38() + config.filters = [32, 64, 128] + config.n_res = [9, 9, 9] + config.use_bottleneck = True + # Due to bottleneck residual blocks + filters = [f * 4 for f in config.filters] + config.filters = filters + + return config + + +def get_hparams_imagenet_56(): + """RevNet-56 configurations for ImageNet.""" + + config = tf.contrib.training.HParams() + config.add_hparam("init_filters", 128) + config.add_hparam("init_kernel", 7) + config.add_hparam("init_stride", 2) + config.add_hparam("n_classes", 1000) + config.add_hparam("n_rev_blocks", 4) + config.add_hparam("n_res", [2, 2, 2, 2]) + config.add_hparam("filters", [128, 256, 512, 832]) + config.add_hparam("strides", [1, 2, 2, 2]) + config.add_hparam("batch_size", 16) + config.add_hparam("bottleneck", True) + config.add_hparam("fused", True) + config.add_hparam("init_max_pool", True) + if tf.test.is_gpu_available(): + config.add_hparam("input_shape", (3, 224, 224)) + config.add_hparam("data_format", "channels_first") + else: + config.add_hparam("input_shape", (224, 224, 3)) + config.add_hparam("data_format", "channels_last") + + # Training details + config.add_hparam("weight_decay", 1e-4) + config.add_hparam("momentum", .9) + config.add_hparam("lr_decay_steps", [160000, 320000, 480000]) + config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3, 1e-4]) + config.add_hparam("max_train_iter", 600000) + config.add_hparam("seed", 1234) + config.add_hparam("shuffle", True) + config.add_hparam("log_every", 50) + config.add_hparam("save_every", 50) + config.add_hparam("dtype", tf.float32) + config.add_hparam("eval_batch_size", 1000) + config.add_hparam("div255", True) + # TODO(lxuechen): Update this according to ImageNet data + config.add_hparam("iters_per_epoch", 50000 // config.batch_size) + config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) + # Due to bottleneck residual blocks + filters = [f * 4 for f in config.filters] + config.filters = filters + + return config diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f43b03f90ef6db01db1f85943e10ce8c9b582a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -0,0 +1,256 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Eager execution workflow with RevNet train on CIFAR-10.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from absl import flags +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import cifar_input +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import revnet +tfe = tf.contrib.eager + + +def main(_): + """Eager execution workflow with RevNet trained on CIFAR-10.""" + config = get_config() + ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(config) + model = revnet.RevNet(config=config) + global_step = tf.train.get_or_create_global_step() # Ensure correct summary + global_step.assign(1) + learning_rate = tf.train.piecewise_constant( + global_step, config.lr_decay_steps, config.lr_list) + optimizer = tf.train.MomentumOptimizer( + learning_rate, momentum=config.momentum) + checkpointer = tf.train.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=global_step) + + if FLAGS.train_dir: + summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) + if FLAGS.restore: + latest_path = tf.train.latest_checkpoint(FLAGS.train_dir) + checkpointer.restore(latest_path) + print("Restored latest checkpoint at path:\"{}\" " + "with global_step: {}".format(latest_path, global_step.numpy())) + sys.stdout.flush() + + if FLAGS.manual_grad: + print("Using manual gradients.") + else: + print("Not using manual gradients.") + sys.stdout.flush() + + for x, y in ds_train: + train_one_iter(model, x, y, optimizer, global_step=global_step) + + if global_step.numpy() % config.log_every == 0: + it_train = ds_train_one_shot.make_one_shot_iterator() + it_test = ds_test.make_one_shot_iterator() + acc_train, loss_train = evaluate(model, it_train) + acc_test, loss_test = evaluate(model, it_test) + + if FLAGS.validate: + it_validation = ds_validation.make_one_shot_iterator() + acc_validation, loss_validation = evaluate(model, it_validation) + print("Iter {}, " + "training set accuracy {:.4f}, loss {:.4f}; " + "validation set accuracy {:.4f}, loss {:4.f}" + "test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_train, loss_train, acc_validation, + loss_validation, acc_test, loss_test)) + else: + print("Iter {}, " + "training set accuracy {:.4f}, loss {:.4f}; " + "test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_train, loss_train, acc_test, + loss_test)) + sys.stdout.flush() + + if FLAGS.train_dir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("Training accuracy", acc_train) + tf.contrib.summary.scalar("Test accuracy", acc_test) + tf.contrib.summary.scalar("Training loss", loss_train) + tf.contrib.summary.scalar("Test loss", loss_test) + if FLAGS.validate: + tf.contrib.summary.scalar("Validation accuracy", acc_validation) + tf.contrib.summary.scalar("Validation loss", loss_validation) + + if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir: + saved_path = checkpointer.save( + file_prefix=os.path.join(FLAGS.train_dir, "ckpt")) + print("Saved checkpoint at path: \"{}\" " + "with global_step: {}".format(saved_path, global_step.numpy())) + sys.stdout.flush() + + +def get_config(): + """Return configuration.""" + print("Config: {}".format(FLAGS.config)) + sys.stdout.flush() + config = { + "revnet-38": config_.get_hparams_cifar_38(), + "revnet-110": config_.get_hparams_cifar_110(), + "revnet-164": config_.get_hparams_cifar_164(), + }[FLAGS.config] + + if FLAGS.dataset == "cifar-100": + config.n_classes = 100 + + return config + + +def get_datasets(config): + """Return dataset.""" + if FLAGS.data_dir is None: + raise ValueError("No supplied data directory") + if not os.path.exists(FLAGS.data_dir): + raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir)) + if FLAGS.dataset not in ["cifar-10", "cifar-100"]: + raise ValueError("Unknown dataset {}".format(FLAGS.dataset)) + + print("Training on {} dataset.".format(FLAGS.dataset)) + sys.stdout.flush() + data_dir = os.path.join(FLAGS.data_dir, FLAGS.dataset) + if FLAGS.validate: + # 40k Training set + ds_train = cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split="train", + data_aug=True, + batch_size=config.batch_size, + epochs=config.epochs, + shuffle=config.shuffle, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.batch_size) + # 10k Training set + ds_validation = cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split="validation", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + shuffle=False, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.eval_batch_size) + else: + # 50k Training set + ds_train = cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split="train_all", + data_aug=True, + batch_size=config.batch_size, + epochs=config.epochs, + shuffle=config.shuffle, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.batch_size) + ds_validation = None + + # Always compute loss and accuracy on whole training and test set + ds_train_one_shot = cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split="train_all", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + shuffle=False, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.eval_batch_size) + + ds_test = cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split="test", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + shuffle=False, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.eval_batch_size) + + return ds_train, ds_train_one_shot, ds_validation, ds_test + + +def train_one_iter(model, inputs, labels, optimizer, global_step=None): + """Train for one iteration.""" + if FLAGS.manual_grad: + grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + else: # For correctness validation + with tf.GradientTape() as tape: + logits, _ = model(inputs, training=True) + loss = model.compute_loss(logits=logits, labels=labels) + tf.logging.info("Logits are placed on device: {}".format(logits.device)) + grads = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + + return loss.numpy() + + +def evaluate(model, iterator): + """Compute accuracy with the given dataset iterator.""" + mean_loss = tfe.metrics.Mean() + accuracy = tfe.metrics.Accuracy() + for x, y in iterator: + logits, _ = model(x, training=False) + loss = model.compute_loss(logits=logits, labels=y) + accuracy( + labels=tf.cast(y, tf.int64), + predictions=tf.argmax(logits, axis=1, output_type=tf.int64)) + mean_loss(loss) + + return accuracy.result().numpy(), mean_loss.result().numpy() + + +if __name__ == "__main__": + flags.DEFINE_string( + "data_dir", default=None, help="Directory to load tfrecords") + flags.DEFINE_string( + "train_dir", + default=None, + help="[Optional] Directory to store the training information") + flags.DEFINE_boolean( + "restore", + default=False, + help="[Optional] Restore the latest checkpoint from `train_dir` if True") + flags.DEFINE_boolean( + "validate", + default=False, + help="[Optional] Use the validation set or not for hyperparameter search") + flags.DEFINE_boolean( + "manual_grad", + default=False, + help="[Optional] Use manual gradient graph to save memory") + flags.DEFINE_string( + "dataset", + default="cifar-10", + help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") + flags.DEFINE_string( + "config", default="revnet-38", help="[Optional] Architecture of network.") + FLAGS = flags.FLAGS + tf.enable_eager_execution() + tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops.py b/tensorflow/contrib/eager/python/examples/revnet/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed5d363e6c8bffd817357c006abee7ac0d1dbba --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/ops.py @@ -0,0 +1,70 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Customized basic operations. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def downsample(x, filters, strides, axis=1): + """Downsample feature map with avg pooling, if filter size doesn't match.""" + + def pad_strides(strides, axis=1): + """Convert length 2 to length 4 strides. + + Needed since `tf.layers.Conv2D` uses length 2 strides, whereas operations + such as `tf.nn.avg_pool` use length 4 strides. + + Args: + strides: length 2 list/tuple strides for height and width + axis: integer specifying feature dimension according to data format + Returns: + length 4 strides padded with 1 on batch and channel dimension + """ + + assert len(strides) == 2 + + if axis == 1: + return [1, 1, strides[0], strides[1]] + return [1, strides[0], strides[1], 1] + + assert len(x.shape) == 4 and (axis == 1 or axis == 3) + + data_format = "NCHW" if axis == 1 else "NHWC" + strides_ = pad_strides(strides, axis=axis) + + if strides[0] > 1: + x = tf.nn.avg_pool( + x, strides_, strides_, padding="VALID", data_format=data_format) + + in_filter = x.shape[axis] + out_filter = filters + + if in_filter < out_filter: + pad_size = [(out_filter - in_filter) // 2, (out_filter - in_filter) // 2] + if axis == 1: + x = tf.pad(x, [[0, 0], pad_size, [0, 0], [0, 0]]) + else: + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], pad_size]) + # In case `tape.gradient(x, [x])` produces a list of `None` + return x + 0. diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops_test.py b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc2641faf5a5d26262de683e52e36b1f42b3a7b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for basic ops used in eager mode RevNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import ops +tfe = tf.contrib.eager + + +class OpsTest(tf.test.TestCase): + + def test_downsample(self): + """Test `possible_down_sample` function with mock object.""" + + batch_size = 100 + # NHWC format + x = tf.random_normal(shape=[batch_size, 32, 32, 3]) + # HW doesn't change but number of features increased + y = ops.downsample(x, filters=5, strides=(1, 1), axis=3) + self.assertEqual(y.shape, [batch_size, 32, 32, 5]) + # Feature map doesn't change but HW reduced + y = ops.downsample(x, filters=3, strides=(2, 2), axis=3) + self.assertEqual(y.shape, [batch_size, 16, 16, 3]) + # Number of feature increased and HW reduced + y = ops.downsample(x, filters=5, strides=(2, 2), axis=3) + self.assertEqual(y.shape, [batch_size, 16, 16, 5]) + + # Test gradient flow + x = tf.random_normal(shape=[batch_size, 32, 32, 3]) + with tfe.GradientTape() as tape: + tape.watch(x) + y = ops.downsample(x, filters=3, strides=(1, 1)) + self.assertEqual(y.shape, x.shape) + dy = tf.random_normal(shape=[batch_size, 3, 32, 32]) + grad, = tape.gradient(y, [x], output_gradients=[dy]) + self.assertEqual(grad.shape, x.shape) + + # Default NCHW format + if tf.test.is_gpu_available(): + x = tf.random_normal(shape=[batch_size, 3, 32, 32]) + # HW doesn't change but feature map reduced + y = ops.downsample(x, filters=5, strides=(1, 1)) + self.assertEqual(y.shape, [batch_size, 5, 32, 32]) + # Feature map doesn't change but HW reduced + y = ops.downsample(x, filters=3, strides=(2, 2)) + self.assertEqual(y.shape, [batch_size, 3, 16, 16]) + # Both feature map and HW reduced + y = ops.downsample(x, filters=5, strides=(2, 2)) + self.assertEqual(y.shape, [batch_size, 5, 16, 16]) + + # Test gradient flow + x = tf.random_normal(shape=[batch_size, 3, 32, 32]) + with tfe.GradientTape() as tape: + tape.watch(x) + y = ops.downsample(x, filters=3, strides=(1, 1)) + self.assertEqual(y.shape, x.shape) + dy = tf.random_normal(shape=[batch_size, 3, 32, 32]) + grad, = tape.gradient(y, [x], output_gradients=[dy]) + self.assertEqual(grad.shape, x.shape) + + +if __name__ == '__main__': + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py new file mode 100644 index 0000000000000000000000000000000000000000..af0d20fa729836b12036d5d54a9b5b0b68d719d2 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -0,0 +1,301 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Code for main model. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import operator + +import six +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import blocks + + +class RevNet(tf.keras.Model): + """RevNet that depends on all the blocks.""" + + def __init__(self, config): + """Initialize RevNet with building blocks. + + Args: + config: tf.contrib.training.HParams object; specifies hyperparameters + """ + super(RevNet, self).__init__() + self.axis = 1 if config.data_format == "channels_first" else 3 + self.config = config + + self._init_block = self._construct_init_block() + self._block_list = self._construct_intermediate_blocks() + self._final_block = self._construct_final_block() + + def _construct_init_block(self): + init_block = tf.keras.Sequential( + [ + tf.keras.layers.Conv2D( + filters=self.config.init_filters, + kernel_size=self.config.init_kernel, + strides=(self.config.init_stride, self.config.init_stride), + data_format=self.config.data_format, + use_bias=False, + padding="SAME", + input_shape=self.config.input_shape, + dtype=self.config.dtype), + tf.keras.layers.BatchNormalization( + axis=self.axis, + fused=self.config.fused, + dtype=self.config.dtype), + tf.keras.layers.Activation("relu"), + ], + name="init") + if self.config.init_max_pool: + init_block.add( + tf.keras.layers.MaxPooling2D( + pool_size=(3, 3), + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format, + dtype=self.config.dtype)) + return init_block + + def _construct_final_block(self): + f = self.config.filters[-1] # Number of filters + r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio + r *= self.config.init_stride + if self.config.init_max_pool: + r *= 2 + + if self.config.data_format == "channels_first": + w, h = self.config.input_shape[1], self.config.input_shape[2] + input_shape = (f, w // r, h // r) + elif self.config.data_format == "channels_last": + w, h = self.config.input_shape[0], self.config.input_shape[1] + input_shape = (w // r, h // r, f) + else: + raise ValueError("Data format should be either `channels_first`" + " or `channels_last`") + + final_block = tf.keras.Sequential( + [ + tf.keras.layers.BatchNormalization( + axis=self.axis, + input_shape=input_shape, + fused=self.config.fused, + dtype=self.config.dtype), + tf.keras.layers.Activation("relu"), + tf.keras.layers.GlobalAveragePooling2D( + data_format=self.config.data_format, dtype=self.config.dtype), + tf.keras.layers.Dense( + self.config.n_classes, dtype=self.config.dtype) + ], + name="final") + return final_block + + def _construct_intermediate_blocks(self): + # Precompute input shape after initial block + stride = self.config.init_stride + if self.config.init_max_pool: + stride *= 2 + if self.config.data_format == "channels_first": + w, h = self.config.input_shape[1], self.config.input_shape[2] + input_shape = (self.config.init_filters, w // stride, h // stride) + else: + w, h = self.config.input_shape[0], self.config.input_shape[1] + input_shape = (w // stride, h // stride, self.config.init_filters) + + # Aggregate intermediate blocks + block_list = tf.contrib.checkpoint.List() + for i in range(self.config.n_rev_blocks): + # RevBlock configurations + n_res = self.config.n_res[i] + filters = self.config.filters[i] + if filters % 2 != 0: + raise ValueError("Number of output filters must be even to ensure" + "correct partitioning of channels") + stride = self.config.strides[i] + strides = (self.config.strides[i], self.config.strides[i]) + + # Add block + rev_block = blocks.RevBlock( + n_res, + filters, + strides, + input_shape, + batch_norm_first=(i != 0), # Only skip on first block + data_format=self.config.data_format, + bottleneck=self.config.bottleneck, + fused=self.config.fused, + dtype=self.config.dtype) + block_list.append(rev_block) + + # Precompute input shape for the next block + if self.config.data_format == "channels_first": + w, h = input_shape[1], input_shape[2] + input_shape = (filters, w // stride, h // stride) + else: + w, h = input_shape[0], input_shape[1] + input_shape = (w // stride, h // stride, filters) + + return block_list + + def call(self, inputs, training=True): + """Forward pass.""" + + if training: + saved_hidden = [inputs] + + h = self._init_block(inputs, training=training) + if training: + saved_hidden.append(h) + + for block in self._block_list: + h = block(h, training=training) + if training: + saved_hidden.append(h) + + logits = self._final_block(h, training=training) + + return (logits, saved_hidden) if training else (logits, None) + + def compute_loss(self, logits, labels): + """Compute cross entropy loss.""" + + if self.config.dtype == tf.float32 or self.config.dtype == tf.float16: + cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + else: + # `sparse_softmax_cross_entropy_with_logits` does not have a GPU kernel + # for float64, int32 pairs + labels = tf.one_hot( + labels, depth=self.config.n_classes, axis=1, dtype=self.config.dtype) + cross_ent = tf.nn.softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + + return tf.reduce_mean(cross_ent) + + def compute_gradients(self, inputs, labels, training=True, l2_reg=True): + """Manually computes gradients. + + When eager execution is enabled, this method also SILENTLY updates the + running averages of batch normalization when `training` is set to True. + + Args: + inputs: Image tensor, either NHWC or NCHW, conforming to `data_format` + labels: One-hot labels for classification + training: Use the mini-batch stats in batch norm if set to True + l2_reg: Apply l2 regularization + + Returns: + list of tuples each being (grad, var) for optimizer to use + """ + + # Run forward pass to record hidden states; avoid updating running averages + vars_and_vals = self.get_moving_stats() + _, saved_hidden = self.call(inputs, training=training) + self.restore_moving_stats(vars_and_vals) + + grads_all = [] + vars_all = [] + + # Manually backprop through last block + x = saved_hidden[-1] + with tf.GradientTape() as tape: + x = tf.identity(x) + tape.watch(x) + # Running stats updated below + logits = self._final_block(x, training=training) + loss = self.compute_loss(logits, labels) + + grads_combined = tape.gradient(loss, + [x] + self._final_block.trainable_variables) + dy, grads_ = grads_combined[0], grads_combined[1:] + grads_all += grads_ + vars_all += self._final_block.trainable_variables + + # Manually backprop through intermediate blocks + for block in reversed(self._block_list): + y = saved_hidden.pop() + x = saved_hidden[-1] + dy, grads, vars_ = block.backward_grads_and_vars( + x, y, dy, training=training) + grads_all += grads + vars_all += vars_ + + # Manually backprop through first block + saved_hidden.pop() + x = saved_hidden.pop() + assert not saved_hidden # Cleared after backprop + + with tf.GradientTape() as tape: + x = tf.identity(x) + # Running stats updated below + y = self._init_block(x, training=training) + + grads_all += tape.gradient( + y, self._init_block.trainable_variables, output_gradients=dy) + vars_all += self._init_block.trainable_variables + + # Apply weight decay + if l2_reg: + grads_all = self._apply_weight_decay(grads_all, vars_all) + + return grads_all, vars_all, loss + + def _apply_weight_decay(self, grads, vars_): + """Update gradients to reflect weight decay.""" + # Don't decay bias + return [ + g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g + for g, v in zip(grads, vars_) + ] + + def get_moving_stats(self): + """Get moving averages of batch normalization. + + This is needed to avoid updating the running average twice in one iteration. + + Returns: + A dictionary mapping variables for batch normalization moving averages + to their current values. + """ + vars_and_vals = {} + + def _is_moving_var(v): + n = v.name + return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") + + for v in filter(_is_moving_var, self.variables): + vars_and_vals[v] = v.read_value() + + return vars_and_vals + + def restore_moving_stats(self, vars_and_vals): + """Restore moving averages of batch normalization. + + This is needed to avoid updating the running average twice in one iteration. + + Args: + vars_and_vals: The dictionary mapping variables to their previous values. + """ + for var_, val in six.iteritems(vars_and_vals): + var_.assign(val) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b2ac4b67c926951672996df5564b9b57def0ea13 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -0,0 +1,332 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for basic building blocks used in eager mode RevNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc +import time + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import blocks_test +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import revnet +from tensorflow.python.client import device_lib +tfe = tf.contrib.eager + + +def train_one_iter(model, inputs, labels, optimizer, global_step=None): + """Train for one iteration.""" + grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + + return loss + + +class RevNetTest(tf.test.TestCase): + + def setUp(self): + super(RevNetTest, self).setUp() + config = config_.get_hparams_cifar_38() + # Reconstruction could cause numerical error, use double precision for tests + config.dtype = tf.float64 + config.fused = False # Fused batch norm does not support tf.float64 + shape = (config.batch_size,) + config.input_shape + self.model = revnet.RevNet(config=config) + self.x = tf.random_normal(shape=shape, dtype=tf.float64) + self.t = tf.random_uniform( + shape=[config.batch_size], + minval=0, + maxval=config.n_classes, + dtype=tf.int64) + self.config = config + + def tearDown(self): + del self.model + del self.x + del self.t + del self.config + super(RevNetTest, self).tearDown() + + def test_call(self): + """Test `call` function.""" + + y, _ = self.model(self.x, training=False) + self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes]) + + def _check_grad_angle_combined(self, grads, grads_true): + """Verify that the reconstructed gradients has correct direction. + + Due to numerical imprecision, the magnitude may be slightly different. + Yet according to the paper, the angle should be roughly the same. + + Args: + grads: list of gradients from reconstruction + grads_true: list of true gradients + """ + + def _combine(gs): + return [tf.reshape(g, [-1]) for g in gs] + + g1_all = tf.concat(_combine(grads), axis=0) + g2_all = tf.concat(_combine(grads_true), axis=0) + + self.assertEqual(len(g1_all.shape), 1) + self.assertEqual(len(g2_all.shape), 1) + + degree = blocks_test.compute_degree(g1_all, g2_all) + self.assertLessEqual(degree, 1e0) + + def test_compute_gradients(self): + """Test `compute_gradients` function.""" + self.model(self.x, training=False) # Initialize model + grads, vars_, loss = self.model.compute_gradients( + inputs=self.x, labels=self.t, training=True, l2_reg=True) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + self.assertEqual(len(grads), len(vars_)) + for grad, var in zip(grads, vars_): + self.assertEqual(grad.shape, var.shape) + + # Compare against the true gradient computed by the tape + with tf.GradientTape() as tape: + logits, _ = self.model(self.x, training=True) + loss_true = self.model.compute_loss(logits=logits, labels=self.t) + grads_true = tape.gradient(loss_true, vars_) + self.assertAllClose(loss, loss_true) + self.assertAllClose(grads, grads_true, rtol=1e-4, atol=1e-4) + self._check_grad_angle_combined(grads, grads_true) + + def test_call_defun(self): + """Test `call` function with defun.""" + y, _ = tfe.defun(self.model.call)(self.x, training=False) + self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes]) + + def test_compute_gradients_defun(self): + """Test `compute_gradients` function with defun.""" + compute_gradients = tfe.defun(self.model.compute_gradients) + grads, vars_, _ = compute_gradients(self.x, self.t, training=True) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + self.assertEqual(len(grads), len(vars_)) + for grad, var in zip(grads, vars_): + if grad is not None: + self.assertEqual(grad.shape, var.shape) + + def test_training_graph(self): + """Test model training in graph mode.""" + with tf.Graph().as_default(): + config = config_.get_hparams_cifar_38() + x = tf.random_normal( + shape=(self.config.batch_size,) + self.config.input_shape) + t = tf.random_uniform( + shape=(self.config.batch_size,), + minval=0, + maxval=self.config.n_classes, + dtype=tf.int32) + global_step = tfe.Variable(0., trainable=False) + model = revnet.RevNet(config=config) + model(x) + updates = model.get_updates_for(x) + + x_ = tf.identity(x) + grads_all, vars_all, _ = model.compute_gradients(x_, t, training=True) + optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) + with tf.control_dependencies(updates): + train_op = optimizer.apply_gradients( + zip(grads_all, vars_all), global_step=global_step) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for _ in range(1): + sess.run(train_op) + + +# Benchmark related +def device_and_data_format(): + return ("/gpu:0", + "channels_first") if tf.test.is_gpu_available() else ("/cpu:0", + "channels_last") + + +def random_batch(batch_size, config): + shape = (batch_size,) + config.input_shape + images = tf.random_uniform(shape) + labels = tf.random_uniform( + [batch_size], minval=0, maxval=config.n_classes, dtype=tf.int32) + + return images, labels + + +class MockIterator(object): + + def __init__(self, tensors): + self._tensors = [tf.identity(x) for x in tensors] + + def next(self): + return self._tensors + + +class RevNetBenchmark(tf.test.Benchmark): + """Eager and graph benchmarks for RevNet.""" + + def _train_batch_sizes(self): + """Shamelessly copied from `resnet50_test.py`. + + Note: This is targeted towards ImageNet. CIFAR-10 should allow more + aggressive batch sizes. + + Returns: + A tuple of possible batch sizes + """ + for device in device_lib.list_local_devices(): + if tf.DeviceSpec.from_string(device.name).device_type == "GPU": + if "K20" in device.physical_device_desc: + return (16,) + if "P100" in device.physical_device_desc: + return (16, 32, 64) + if tf.DeviceSpec.from_string(device.name).device_type == "TPU": + return (32,) + return (16, 32) + + def _force_device_sync(self): + """Shamelessly copied from `resnet50_test.py`.""" + tf.constant(1.).cpu() + + def _report(self, label, start, num_iters, device, batch_size, data_format): + avg_time = (time.time() - start) / num_iters + dev = tf.DeviceSpec.from_string(device).device_type.lower() + name = "%s_%s_batch_%d_%s" % (label, dev, batch_size, data_format) + extras = {"examples_per_sec": batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def _benchmark_eager_apply(self, + label, + device_and_format, + defun=False, + execution_mode=None, + compiled=False): + config = config_.get_hparams_imagenet_56() + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + model = revnet.RevNet(config=config) + if defun: + model.call = tfe.defun(model.call, compiled=compiled) + batch_size = 64 + num_burn = 5 + num_iters = 10 + with tf.device(device): + images, _ = random_batch(batch_size, config) + for _ in range(num_burn): + model(images, training=False) + if execution_mode: + tfe.async_wait() + gc.collect() + start = time.time() + for _ in range(num_iters): + model(images, training=False) + if execution_mode: + tfe.async_wait() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_apply_sync(self): + self._benchmark_eager_apply( + "eager_apply_sync", device_and_data_format(), defun=False) + + def benchmark_eager_apply_async(self): + self._benchmark_eager_apply( + "eager_apply_async", + device_and_data_format(), + defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_call_defun(self): + self._benchmark_eager_apply( + "eager_apply_with_defun", device_and_data_format(), defun=True) + + def _benchmark_eager_train(self, + label, + make_iterator, + device_and_format, + defun=False, + execution_mode=None, + compiled=False): + config = config_.get_hparams_imagenet_56() + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + for batch_size in self._train_batch_sizes(): + (images, labels) = random_batch(batch_size, config) + model = revnet.RevNet(config=config) + optimizer = tf.train.GradientDescentOptimizer(0.1) + if defun: + model.call = tfe.defun(model.call) + + num_burn = 3 + num_iters = 10 + with tf.device(device): + iterator = make_iterator((images, labels)) + for _ in range(num_burn): + (images, labels) = iterator.next() + train_one_iter(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + gc.collect() + + start = time.time() + for _ in range(num_iters): + (images, labels) = iterator.next() + train_one_iter(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_train_sync(self): + self._benchmark_eager_train( + "eager_train_sync", MockIterator, device_and_data_format(), defun=False) + + def benchmark_eager_train_async(self): + self._benchmark_eager_train( + "eager_train_async", + MockIterator, + device_and_data_format(), + defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_train_defun(self): + self._benchmark_eager_train( + "eager_train", MockIterator, device_and_data_format(), defun=False) + + def benchmark_eager_train_datasets_with_defun(self): + + def make_iterator(tensors): + with tf.device("/device:CPU:0"): + ds = tf.data.Dataset.from_tensors(tensors).repeat() + return tfe.Iterator(ds) + + self._benchmark_eager_train( + "eager_train_dataset_with_defun", + make_iterator, + device_and_data_format(), + defun=True) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 492adbe1d80941f9df96d6636e4933d11239408e..5ee2176154ec7011dcb3d7b384a86213e778014f 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -152,7 +152,7 @@ class RNNColorbot(tf.keras.Model): self.label_dimension = label_dimension self.keep_prob = keep_prob - self.cells = self._add_cells( + self.cells = tf.contrib.checkpoint.List( [tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes]) self.relu = layers.Dense( label_dimension, activation=tf.nn.relu, name="relu") @@ -204,14 +204,6 @@ class RNNColorbot(tf.keras.Model): hidden_states = tf.gather_nd(chars, indices) return self.relu(hidden_states) - def _add_cells(self, cells): - # "Magic" required for keras.Model classes to track all the variables in - # a list of layers.Layer objects. - # TODO(ashankar): Figure out API so user code doesn't have to do this. - for i, c in enumerate(cells): - setattr(self, "cell-%d" % i, c) - return cells - def loss(labels, predictions): """Computes mean squared loss.""" diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py index 75b342ba78bd5de5c2827296f6fba01ffa86d560..b7d8395e277b526ba40ccafa323ba453a8667b62 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py @@ -67,5 +67,5 @@ class RNNColorbotTest(tf.test.TestCase): if __name__ == "__main__": - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index be5d60449d7e08c99cc28e76befce56f468c77fd..c2340a293a80924f2dfa90e2fb23134b0f1feb6b 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -50,7 +50,7 @@ class RNN(tf.keras.Model): def __init__(self, hidden_dim, num_layers, keep_ratio): super(RNN, self).__init__() self.keep_ratio = keep_ratio - self.cells = self._add_cells([ + self.cells = tf.contrib.checkpoint.List([ tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim) for _ in range(num_layers) ]) @@ -74,14 +74,6 @@ class RNN(tf.keras.Model): # tuple (output, output_states). return [input_seq] - def _add_cells(self, cells): - # "Magic" required for keras.Model classes to track all the variables in - # a list of Layer objects. - # TODO(ashankar): Figure out API so user code doesn't have to do this. - for i, c in enumerate(cells): - setattr(self, "cell-%d" % i, c) - return cells - class Embedding(layers.Layer): """An Embedding layer.""" @@ -304,7 +296,7 @@ def test_model(use_cudnn_rnn): def main(_): - tfe.enable_eager_execution() + tf.enable_eager_execution() if not FLAGS.data_path: raise ValueError("Must specify --data-path") diff --git a/tensorflow/contrib/eager/python/examples/sagan/BUILD b/tensorflow/contrib/eager/python/examples/sagan/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b470a41d815ce650731680065cc7341f844e3fdc --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/BUILD @@ -0,0 +1,59 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# Model +py_library( + name = "config", + srcs = ["config.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "ops", + srcs = ["ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "sagan", + srcs = ["sagan.py"], + srcs_version = "PY2AND3", + deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +# Tests +cuda_py_test( + name = "ops_test", + size = "small", + srcs = ["ops_test.py"], + additional_deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "sagan_test", + size = "large", + srcs = ["sagan_test.py"], + additional_deps = [ + ":config", + ":sagan", + "//tensorflow:tensorflow_py", + ], + tags = [ + "optonly", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/sagan/config.py b/tensorflow/contrib/eager/python/examples/sagan/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1967bbd867447d9deaf9a7cb3b22a38889276a50 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/config.py @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Self-attention generative adversarial with eager execution. + +Configuration in format of tf.contrib.training.HParams. +Supports default 128x128 ImageNet. + +Reference [Self-Attention Generative Adversarial +Networks](https://arxiv.org/pdf/1805.08318.pdf) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +tfe = tf.contrib.eager + + +def get_hparams_imagenet(): + """Configurations to train SAGAN on 128x128 ImageNet dataset.""" + config = tf.contrib.training.HParams() + if tf.test.is_gpu_available(): + config.add_hparam("image_shape", (3, 128, 128)) + config.add_hparam("data_format", "channels_first") + config.add_hparam("g_init_shape", (512, 4, 4)) + else: + config.add_hparam("image_shape", (128, 128, 3)) + config.add_hparam("data_format", "channels_first") + config.add_hparam("g_init_shape", (4, 4, 512)) + + config.add_hparam("latent_dim", 128) + config.add_hparam("update_g_once_every", 1) + config.add_hparam("batch_size", 64) + config.add_hparam("d_init_filters", 32) + config.add_hparam("num_upsamples", 5) + # (512, 4, 4) -> (3, 128, 128) + return config + + +def get_hparams_mock(): + """Configurations of smaller networks for testing.""" + config = tf.contrib.training.HParams() + if tf.test.is_gpu_available(): + config.add_hparam("image_shape", (3, 16, 16)) + config.add_hparam("data_format", "channels_first") + config.add_hparam("g_init_shape", (32, 2, 2)) + else: + config.add_hparam("image_shape", (16, 16, 3)) + config.add_hparam("data_format", "channels_last") + config.add_hparam("g_init_shape", (2, 2, 32)) + + config.add_hparam("latent_dim", 16) + config.add_hparam("update_g_once_every", 1) + config.add_hparam("batch_size", 2) + config.add_hparam("d_init_filters", 4) + config.add_hparam("num_upsamples", 3) + # (32, 2, 2) -> (3, 16, 16) + return config diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops.py b/tensorflow/contrib/eager/python/examples/sagan/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9a03cab1d12fc16baa7343f72ac58ccd39f698bc --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/ops.py @@ -0,0 +1,71 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Self-attention generative adversarial with eager execution. + +Auxiliary operations. + +Reference [Self-Attention Generative Adversarial +Networks](https://arxiv.org/pdf/1805.08318.pdf) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def flatten_hw(x, data_format="channels_first"): + """Flatten the input tensor across height and width dimensions.""" + if data_format == "channels_last": + x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first` + + old_shape = tf.shape(x) + new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]] + + return tf.reshape(x, new_shape) + + +def broaden_hw(x, h, w, c, data_format="channels_first"): + """Broaden dimension so that output has height and width.""" + if data_format == "channels_first": + shape = [-1, c, h, w] + else: + shape = [-1, h, w, c] + + return tf.reshape(x, shape) + + +class BroadenHW(tf.keras.layers.Layer): + """Wrapper class so that `broaden_hw` can be used in `tf.keras.Sequential`.""" + + def __init__(self, h, w, c, data_format="channels_first"): + super(BroadenHW, self).__init__() + self.h = h + self.w = w + self.c = c + self.data_format = data_format + + def call(self, x): + return broaden_hw( + x, h=self.h, w=self.w, c=self.c, data_format=self.data_format) + + def compute_output_shape(self, input_shape): + input_shape = tf.TensorShape(input_shape).as_list() + if self.data_format == "channels_first": + output_shape = (input_shape[0], self.c, self.h, self.w) + else: + output_shape = (input_shape[0], self.h, self.w, self.c) + + return tf.TensorShape(output_shape) diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3454985904215b59d27fc4b76ccb4a8c2c2eff00 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py @@ -0,0 +1,59 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for auxiliary operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.sagan import ops + + +class OpsTest(tf.test.TestCase): + + def test_flatten_hw(self): + """Test `flatten_hw` function with mock object.""" + + batch_size = 1 + # Default NCHW format + if tf.test.is_gpu_available(): + x = tf.random_normal(shape=(batch_size, 3, 4, 4)) + y = ops.flatten_hw(x, data_format="channels_first") + self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) + + # NHWC format + x = tf.random_normal(shape=(batch_size, 4, 4, 3)) + y = ops.flatten_hw(x, data_format="channels_last") + self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) + + def test_broaden_hw(self): + """Test `broaden_hw` function with mock object.""" + + batch_size = 1 + # NHWC format + x = tf.random_normal(shape=[batch_size, 4 * 4 * 16]) + y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_last") + self.assertEqual(y.shape, (batch_size, 4, 4, 16)) + + # Default NCHW format + if tf.test.is_gpu_available(): + y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_first") + self.assertEqual(y.shape, (batch_size, 16, 4, 4)) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py new file mode 100644 index 0000000000000000000000000000000000000000..561be36c911d7145e2d4a5ed12eccd8ceb054f45 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/sagan.py @@ -0,0 +1,232 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Self-attention generative adversarial with eager execution. + +Code for main model. + +Reference [Self-Attention Generative Adversarial +Networks](https://arxiv.org/pdf/1805.08318.pdf) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.sagan import ops +tfe = tf.contrib.eager + + +class SelfAttentionModule(tf.keras.Model): + """Self-attention module composed of convolutional layers.""" + + def __init__(self, + attention_features, + original_features, + data_format="channels_first"): + """Initialize the module. + + Args: + attention_features: Number of filters for the attention computation. + original_features: Number of filters of the original Tensor. + data_format: Either 'channels_first' or 'channels_last' + """ + super(SelfAttentionModule, self).__init__() + self.data_format = data_format + # Matrix multiplication implemented as 2D Convolution + self.f = tf.keras.layers.Conv2D( + filters=attention_features, + kernel_size=1, + strides=(1, 1), + data_format=data_format) + self.g = tf.keras.layers.Conv2D( + filters=attention_features, + kernel_size=1, + strides=(1, 1), + data_format=data_format) + self.h = tf.keras.layers.Conv2D( + filters=original_features, + kernel_size=1, + strides=(1, 1), + data_format=data_format) + self.scale = tfe.Variable(0., trainable=True) + + def call(self, x): + f = self.f(x) + g = self.g(x) + h = self.h(x) + + f_flatten = ops.flatten_hw(f, data_format=self.data_format) + g_flatten = ops.flatten_hw(g, data_format=self.data_format) + h_flatten = ops.flatten_hw(h, data_format=self.data_format) + + s = tf.matmul(g_flatten, f_flatten, transpose_b=True) + b = tf.nn.softmax(s, axis=-1) + o = tf.matmul(b, h_flatten) + y = self.scale * tf.reshape(o, tf.shape(x)) + x + + return y + + def compute_output_shape(self, input_shape): + return input_shape + + +class SAGAN(tf.contrib.checkpoint.Checkpointable): + """Self-attention generative adversarial network.""" + + def __init__(self, config): + """Initialize the model. + + Args: + config: tf.contrib.training.HParams object; specifies hyperparameters + """ + super(SAGAN, self).__init__() + self.config = config + self.generator = self._construct_generator() + self.discriminator = self._construct_discriminator() + + def _construct_generator(self): + """Construct generator.""" + # TODO(lxuechen): Add spectral normalization for WGAN + axis = 1 if self.config.data_format == "channels_first" else 3 + + generator = tf.keras.Sequential() + generator.add( + tf.keras.layers.InputLayer(input_shape=(self.config.latent_dim,))) + generator.add( + tf.keras.layers.Dense( + units=np.prod(self.config.g_init_shape), activation=tf.nn.relu)) + + if self.config.data_format == "channels_first": + c, h, w = self.config.g_init_shape + else: + h, w, c = self.config.g_init_shape + + # Reshape to NHWC/NCHW + generator.add( + ops.BroadenHW(h=h, w=w, c=c, data_format=self.config.data_format)) + + filters_list = [c // 2**p for p in range(1, self.config.num_upsamples + 1)] + filters_list[-1] = 3 # Standard RGB images + + for filters in filters_list[:len(filters_list) // 2]: + generator.add( + tf.keras.layers.Conv2DTranspose( + filters=filters, + kernel_size=4, + strides=(2, 2), + use_bias=False, + padding="SAME", + data_format=self.config.data_format)) + generator.add(tf.keras.layers.BatchNormalization(axis=axis)) + generator.add(tf.keras.layers.Activation("relu")) + + # pylint: disable=undefined-loop-variable + generator.add( + SelfAttentionModule( + original_features=filters, + attention_features=filters // 8, + data_format=self.config.data_format)) + # pylint: enable=undefined-loop-variable + + for filters in filters_list[len(filters_list) // 2:]: + generator.add( + tf.keras.layers.Conv2DTranspose( + filters=filters, + kernel_size=4, + strides=(2, 2), + use_bias=False, + padding="SAME", + data_format=self.config.data_format)) + if filters == 3: + # Assume Image rescaled to [-1, 1] + generator.add(tf.keras.layers.Activation("tanh")) + else: + generator.add(tf.keras.layers.BatchNormalization(axis=axis)) + generator.add(tf.keras.layers.Activation("relu")) + + return generator + + def _construct_discriminator(self): + """Construct discriminator.""" + # TODO(lxuechen): Add spectral normalization for WGAN + discriminator = tf.keras.Sequential() + discriminator.add( + tf.keras.layers.InputLayer(input_shape=self.config.image_shape)) + + filters_list = [ + self.config.d_init_filters * 2**p + for p in range(self.config.num_upsamples) + ] + + for filters in filters_list[:(len(filters_list) + 1) // 2]: + discriminator.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=4, + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format)) + discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) + + # pylint: disable=undefined-loop-variable + discriminator.add( + SelfAttentionModule( + original_features=filters, + attention_features=filters // 8, + data_format=self.config.data_format)) + # pylint: enable=undefined-loop-variable + + for filters in filters_list[(len(filters_list) + 1) // 2:]: + discriminator.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=4, + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format)) + discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) + + discriminator.add(tf.keras.layers.Flatten()) + discriminator.add(tf.keras.layers.Dense(units=1)) + + return discriminator + + def compute_loss_and_grads(self, real_images, noise, training=True): + """Compute loss and gradients for both generator and discriminator.""" + # TODO(lxuechen): Add gradient penalty for discriminator + with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: + real_logits = self.discriminator(real_images, training=training) + + fake_images = self.generator.call(noise, training=training) + fake_logits = self.discriminator.call(fake_images) + + g_loss = self.compute_g_loss(fake_logits) + d_loss = self.compute_d_loss(fake_logits, real_logits) + + g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables) + d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables) + + return g_loss, d_loss, g_grads, d_grads + + def compute_g_loss(self, fake_logits): + return -tf.reduce_mean(fake_logits) # Hinge loss + + def compute_d_loss(self, fake_logits, real_logits): + # Hinge loss + real_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits)) + fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake_logits)) + return real_loss + fake_loss diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py new file mode 100644 index 0000000000000000000000000000000000000000..18345945108111b57c5401c26b7dca0bfc8f8316 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for self-attention generative adversarial network.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.sagan import config as config_ +from tensorflow.contrib.eager.python.examples.sagan import sagan +tfe = tf.contrib.eager + + +class SAGANTest(tf.test.TestCase): + + def setUp(self): + super(SAGANTest, self).setUp() + config = config_.get_hparams_mock() + self.noise_shape = (config.batch_size, config.latent_dim) + self.logits_shape = (config.batch_size, 1) + self.images_shape = (config.batch_size,) + config.image_shape + + self.model = sagan.SAGAN(config=config) + self.noise = tf.random_normal(shape=self.noise_shape) + self.real_images = tf.random_normal(shape=self.images_shape) + self.config = config + + def tearDown(self): + del self.model + del self.noise + del self.real_images + super(SAGANTest, self).tearDown() + + def test_generator_call(self): + """Test `generator.__call__` function.""" + fake_images = self.model.generator(self.noise, training=False) + self.assertEqual(fake_images.shape, self.images_shape) + + def test_generator_call_defun(self): + """Test `generator.__call__` function with defun.""" + call_ = tfe.defun(self.model.generator.__call__) + fake_images = call_(self.noise, training=False) + self.assertEqual(fake_images.shape, self.images_shape) + + def test_discriminator_call(self): + """Test `discriminator.__call__` function.""" + real_logits = self.model.discriminator(self.real_images) + self.assertEqual(real_logits.shape, self.logits_shape) + + def test_discriminator_call_defun(self): + """Test `discriminator.__call__` function with defun.""" + call_ = tfe.defun(self.model.discriminator.__call__) + real_logits = call_(self.real_images) + self.assertEqual(real_logits.shape, self.logits_shape) + + def test_compute_loss_and_grads(self): + """Test `compute_loss_and_grads` function.""" + g_loss, d_loss, g_grads, d_grads = self.model.compute_loss_and_grads( + self.real_images, self.noise, training=False) + self.assertEqual(g_loss.shape, ()) + self.assertEqual(d_loss.shape, ()) + self.assertTrue(isinstance(g_grads, list)) + self.assertTrue(isinstance(d_grads, list)) + g_vars = self.model.generator.trainable_variables + d_vars = self.model.discriminator.trainable_variables + + self.assertEqual(len(g_grads), len(g_vars)) + self.assertEqual(len(d_grads), len(d_vars)) + + def test_compute_loss_and_grads_defun(self): + """Test `compute_loss_and_grads` function with defun.""" + compute_loss_and_grads = tfe.defun(self.model.compute_loss_and_grads) + g_loss, d_loss, g_grads, d_grads = compute_loss_and_grads( + self.real_images, self.noise, training=False) + self.assertEqual(g_loss.shape, ()) + self.assertEqual(d_loss.shape, ()) + self.assertTrue(isinstance(g_grads, list)) + self.assertTrue(isinstance(d_grads, list)) + g_vars = self.model.generator.trainable_variables + d_vars = self.model.discriminator.trainable_variables + + self.assertEqual(len(g_grads), len(g_vars)) + self.assertEqual(len(d_grads), len(d_vars)) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py index 4661dafbed19c753da71b749d176c78bd25de1e2..d4b8c8941ec411912f3089315d038fc4bcd049ae 100644 --- a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py @@ -38,20 +38,17 @@ class ScanBenchmark(tf.test.Benchmark): iters=n, wall_time=wall_time) + def benchmarkScan16000(self): + self.runScan(16000) + def benchmarkScan32000(self): self.runScan(32000) - def benchmarkScan1M(self): - self.runScan(1000000) - - def benchmarkScan2M(self): - self.runScan(2000000) - - def benchmarkScan4M(self): - self.runScan(4000000) + def benchmarkScan64000(self): + self.runScan(64000) - def benchmarkScan8M(self): - self.runScan(8000000) + def benchmarkScan128000(self): + self.runScan(128000) if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py index b8c7cf1fe5bcb7f25b2fa72bdd40ed625de87931..a02fc24c79dae6c2565db8b138b1d7391d169ed8 100644 --- a/tensorflow/contrib/eager/python/examples/scan/scan_test.py +++ b/tensorflow/contrib/eager/python/examples/scan/scan_test.py @@ -36,21 +36,19 @@ class ScanBenchmark(tf.test.Benchmark): iters=n, wall_time=wall_time) - def benchmarkScan2000(self): - self.runScan(2000) - - def benchmarkScan4000(self): - self.runScan(4000) - - def benchmarkScan8000(self): - self.runScan(8000) - def benchmarkScan16000(self): self.runScan(16000) def benchmarkScan32000(self): self.runScan(32000) + def benchmarkScan64000(self): + self.runScan(64000) + + def benchmarkScan128000(self): + self.runScan(128000) + + if __name__ == '__main__': tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3e7abe952d63610b14967d41be0a36430fcd29c6 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb @@ -0,0 +1,282 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "TFE Workshop: control flow", + "version": "0.3.2", + "provenance": [], + "include_colab_link": true + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "[View in Colaboratory](https://colab.research.google.com/gist/alextp/664b2f8700485ff6801f4d26293bd567/tfe-workshop-control-flow.ipynb)" + ] + }, + { + "metadata": { + "id": "9BpQzh9BvJlj", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 37 + }, + "outputId": "0b336886-8204-4815-89fa-5291a49d5784" + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "import numpy as np\n", + "tf.enable_eager_execution()" + ], + "execution_count": 1, + "outputs": [] + }, + { + "metadata": { + "id": "0roIB19GvOjI", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Eager execution basics\n", + "\n", + "When eager execution is enabled TensorFlow immediately executes operations, and Tensors are always available. " + ] + }, + { + "metadata": { + "id": "jeO8F-V-vN24", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "outputId": "aeb3bdec-50b7-440d-93d8-5a171f091081" + }, + "cell_type": "code", + "source": [ + "t = tf.constant([[1, 2], [3, 4]])\n", + "t" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 2 + } + ] + }, + { + "metadata": { + "id": "Y17RwSFxvlDL", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "outputId": "cfcc10c7-707b-4997-99b3-a5f382c5166b" + }, + "cell_type": "code", + "source": [ + "tf.matmul(t, t)" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 3 + } + ] + }, + { + "metadata": { + "id": "Dab1bS3TvmRE", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "8a624f3d-a658-4359-c586-1c5f6bf4c8b7" + }, + "cell_type": "code", + "source": [ + "# It's also possible to have Python control flow which depends on the value of tensors.\n", + "if t[0, 0] > 0.5:\n", + " print(\"T is bigger\")\n", + "else:\n", + " print(\"T is smaller\")" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "T is bigger\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "dPgptJcGwIon", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "c4f27f2b-0848-4475-dde5-2534dac65a5c" + }, + "cell_type": "code", + "source": [ + "# Tensors are also usable as numpy arrays\n", + "np.prod(t)" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "24" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 6 + } + ] + }, + { + "metadata": { + "id": "p3DTfQXnwXzj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Exercise\n", + "\n", + "The algorithm for bisecting line search is a pretty simple way to find a zero of a continuous scalar function in an interval [a,b] where f(a) and f(b) have different signs. Simply evaluate f((a+b)/2), and narrow the interval by replacing either a or b with (a+b)/2 such that the function when applied on the boundary of the interval still has different signs.\n", + "\n", + "Implement a python function `bisecting_line_search(f, a, b, epsilon)` which returns a value such that `tf.abs(f(value)) < epsilon`.\n", + "\n", + "One thing to keep in mind: python's `==` opertor is not overloaded on Tensors, so you need to use `tf.equal` to compare for equality." + ] + }, + { + "metadata": { + "id": "6eq0YuI6ykm5", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Example test harness to get you going\n", + "\n", + "def test_f(x):\n", + " return x - 0.1234\n", + "def bisecting_line_search(f, a, b, epsilon):\n", + " # Return x such that f(x) <= epsilon.\n", + " pass\n", + "a = tf.constant(0.0)\n", + "b = tf.constant(1.0)\n", + "epsilon = tf.constant(0.001)\n", + "x = bisecting_line_search(test_f, a, b, epsilon)\n", + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "LcMmEfd_xvej", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170 + }, + "outputId": "f402aa50-8ce3-4416-f755-8bbcd1af7809" + }, + "cell_type": "code", + "source": [ + "#@title Double-click to see the solution\n", + "\n", + "def bisecting_line_search(f, a, b, epsilon):\n", + " f_a = f(a)\n", + " f_b = f(b)\n", + " probe = (a + b) / 2\n", + " f_probe = f(probe)\n", + " while tf.abs(f_probe) > epsilon:\n", + " if tf.equal(tf.sign(f_probe), tf.sign(f_a)):\n", + " a = probe\n", + " f_a = f_probe\n", + " else:\n", + " b = probe\n", + " f_b = f_probe\n", + " probe = (a + b) / 2\n", + " f_probe = f(probe)\n", + " print(\"new probe\", probe)\n", + " return probe\n", + "\n", + "bisecting_line_search(test_f, 0., 1., 0.001)" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "text": [ + "('new probe', 0.25)\n", + "('new probe', 0.125)\n", + "('new probe', 0.0625)\n", + "('new probe', 0.09375)\n", + "('new probe', 0.109375)\n", + "('new probe', 0.1171875)\n", + "('new probe', 0.12109375)\n", + "('new probe', 0.123046875)\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.123046875" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4f1410e00bb986f68f3c4c8494aa97bf66284510 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb @@ -0,0 +1,1018 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "TFE Workshop: Models.ipynb", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "[View in Colaboratory](https://colab.research.google.com/gist/alextp/5cfcffd408bd5103f5ae747bc97ab0b5/tfe-workshop-models.ipynb)" + ] + }, + { + "metadata": { + "id": "BMxv1O6Q0SJL", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "8be9c556-ac7f-4142-e35e-19dc2b097121" + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager" + ], + "execution_count": 1, + "outputs": [] + }, + { + "metadata": { + "id": "lE1vJhxp0WR9", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Variables\n", + "\n", + "TensorFlow variables are useful to store the state in your program. They are integrated with other parts of the API (taking gradients, checkpointing, graph functions)." + ] + }, + { + "metadata": { + "id": "C4ztQNgc0VpW", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "8b63ae1f-2670-49c0-a31b-8cf7fc4194a1" + }, + "cell_type": "code", + "source": [ + "# Creating variables\n", + "v = tfe.Variable(1.0)\n", + "v" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 2 + } + ] + }, + { + "metadata": { + "id": "H0daItGg1IAp", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "e47d5aab-16a1-4e29-c27d-7fbc0b94b5d3" + }, + "cell_type": "code", + "source": [ + "v.assign_add(1.0)\n", + "v" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 3 + } + ] + }, + { + "metadata": { + "id": "BJvBzcIG1hyK", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "metadata": { + "id": "iSQTS3QW1YQQ", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "c5d8aa10-dcad-44f7-f0eb-0faf5249fd7e" + }, + "cell_type": "code", + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))\n" + ], + "execution_count": 4, + "outputs": [] + }, + { + "metadata": { + "id": "nRuUogoS1liV", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "outputId": "c352ce79-d519-45e4-a12e-1eaba76871a2" + }, + "cell_type": "code", + "source": [ + "layer(tf.zeros([2, 2]))" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 5 + } + ] + }, + { + "metadata": { + "id": "JH4Kf4ka1mht", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 136 + }, + "outputId": "c34e2378-f83d-42c5-d30a-ebe55620368a" + }, + "cell_type": "code", + "source": [ + "layer.variables" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[,\n", + " ]" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 6 + } + ] + }, + { + "metadata": { + "id": "DSI4NF0_1vn-", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "metadata": { + "id": "hMgDBftJ12Bp", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model.\n" + ] + }, + { + "metadata": { + "id": "K3gVY6gj1nbe", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 190 + }, + "outputId": "6e9be0c4-960e-46c2-cdd9-7e94ad09d46b" + }, + "cell_type": "code", + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[[[0. 0. 0.]\n", + " [0. 0. 0.]\n", + " [0. 0. 0.]]\n", + "\n", + " [[0. 0. 0.]\n", + " [0. 0. 0.]\n", + " [0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n", + "['resnet_identity_block/conv2d/kernel:0', 'resnet_identity_block/conv2d/bias:0', 'resnet_identity_block/batch_normalization/gamma:0', 'resnet_identity_block/batch_normalization/beta:0', 'resnet_identity_block/conv2d_1/kernel:0', 'resnet_identity_block/conv2d_1/bias:0', 'resnet_identity_block/batch_normalization_1/gamma:0', 'resnet_identity_block/batch_normalization_1/beta:0', 'resnet_identity_block/conv2d_2/kernel:0', 'resnet_identity_block/conv2d_2/bias:0', 'resnet_identity_block/batch_normalization_2/gamma:0', 'resnet_identity_block/batch_normalization_2/beta:0', 'resnet_identity_block/batch_normalization/moving_mean:0', 'resnet_identity_block/batch_normalization/moving_variance:0', 'resnet_identity_block/batch_normalization_1/moving_mean:0', 'resnet_identity_block/batch_normalization_1/moving_variance:0', 'resnet_identity_block/batch_normalization_2/moving_mean:0', 'resnet_identity_block/batch_normalization_2/moving_variance:0']\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "LPXhHUIc1-sO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "metadata": { + "id": "5pXgzNAU17xk", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 173 + }, + "outputId": "03b7eaf8-9b35-482b-bcf0-a99af6c2c6a4" + }, + "cell_type": "code", + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))\n" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + }, + { + "metadata": { + "id": "MZrns6p22GEQ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Exercise!\n", + "\n", + "Make a simple convolutional neural network model, useful for things such as MNIST which don't need too many parameters. A sequence of two or three convolutions with small output channels (say, 32 and 64) plus one or two fully connected layers is probably enough.\n", + "\n", + "The input shape should be [batch_size, 28, 28, 1]." + ] + }, + { + "metadata": { + "id": "8CAUa3KNN916", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "97c0ff3c-c962-4c13-eee8-406101465761" + }, + "cell_type": "code", + "source": [ + "# TODO: Implement a convolutional model as described above, and assign it to\n", + "# model.\n", + "model = tf.keras.Sequential([\n", + " \n", + "])" + ], + "execution_count": 9, + "outputs": [] + }, + { + "metadata": { + "id": "vLDDduR32E82", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "09bb1d43-b4c6-44b5-916e-0d2903d10cf4" + }, + "cell_type": "code", + "source": [ + "#@title Click to see the answer\n", + "\n", + "max_pool = tf.keras.layers.MaxPooling2D(\n", + " (2, 2), (2, 2), padding='same')\n", + " # The model consists of a sequential chain of layers, so tf.keras.Sequential\n", + " # (a subclass of tf.keras.Model) makes for a compact description.\n", + "model = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.Conv2D(\n", + " 32,\n", + " 5,\n", + " padding='same',\n", + " activation=tf.nn.relu),\n", + " max_pool,\n", + " tf.keras.layers.Conv2D(\n", + " 64,\n", + " 5,\n", + " padding='same',\n", + " activation=tf.nn.relu),\n", + " max_pool,\n", + " tf.keras.layers.Flatten(),\n", + " tf.keras.layers.Dense(1024, activation=tf.nn.relu),\n", + " tf.keras.layers.Dropout(0.4),\n", + " tf.keras.layers.Dense(10)\n", + " ])\n", + "\n", + "model(tf.zeros([1, 28, 28, 1]))" + ], + "execution_count": 10, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 10 + } + ] + }, + { + "metadata": { + "id": "H_CKVBroik4M", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Stop here for now" + ] + }, + { + "metadata": { + "id": "_yRwuE6MMmzC", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Training\n", + "\n", + "When eager execution is enabled, you can write Pythonic training loops. Simply\n", + "\n", + "1. load your data into a `tf.data.Dataset`, which lets you construct functional pipelines for processing, shuffling, and batching your data,\n", + "2. iterate over the dataset using a Python `for` loop, and\n", + "3. perform an optimization step in the body of your `for` loop.\n", + "\n", + "This workflow is exemplified in the following exercise." + ] + }, + { + "metadata": { + "id": "gj0-EkTc_Xt1", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "\n", + "\n", + "## Exercise!\n", + "\n", + "In this exercise, you'll train the convolutional model you implemented for the previous exericse on the MNIST dataset. " + ] + }, + { + "metadata": { + "id": "WOGm9HHn_byR", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "bbccc7ad-33cd-446e-bcda-f358c7547e1b" + }, + "cell_type": "code", + "source": [ + "#@title Utilities for downloading MNIST data (double-click to show code)\n", + "import gzip\n", + "import os\n", + "import tempfile\n", + "from six.moves import urllib\n", + "import shutil\n", + "\n", + "import numpy as np\n", + "\n", + "def read32(bytestream):\n", + " \"\"\"Read 4 bytes from bytestream as an unsigned 32-bit integer.\"\"\"\n", + " dt = np.dtype(np.uint32).newbyteorder('>')\n", + " return np.frombuffer(bytestream.read(4), dtype=dt)[0]\n", + "\n", + "\n", + "def check_image_file_header(filename):\n", + " \"\"\"Validate that filename corresponds to images for the MNIST dataset.\"\"\"\n", + " with tf.gfile.Open(filename, 'rb') as f:\n", + " magic = read32(f)\n", + " read32(f) # num_images, unused\n", + " rows = read32(f)\n", + " cols = read32(f)\n", + " if magic != 2051:\n", + " raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,\n", + " f.name))\n", + " if rows != 28 or cols != 28:\n", + " raise ValueError(\n", + " 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %\n", + " (f.name, rows, cols))\n", + "\n", + "\n", + "def check_labels_file_header(filename):\n", + " \"\"\"Validate that filename corresponds to labels for the MNIST dataset.\"\"\"\n", + " with tf.gfile.Open(filename, 'rb') as f:\n", + " magic = read32(f)\n", + " read32(f) # num_items, unused\n", + " if magic != 2049:\n", + " raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,\n", + " f.name))\n", + " \n", + "def download(directory, filename):\n", + " \"\"\"Download (and unzip) a file from the MNIST dataset if not already done.\"\"\"\n", + " filepath = os.path.join(directory, filename)\n", + " if tf.gfile.Exists(filepath):\n", + " return filepath\n", + " if not tf.gfile.Exists(directory):\n", + " tf.gfile.MakeDirs(directory)\n", + " # CVDF mirror of http://yann.lecun.com/exdb/mnist/\n", + " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n", + " _, zipped_filepath = tempfile.mkstemp(suffix='.gz')\n", + " print('Downloading %s to %s' % (url, zipped_filepath))\n", + " urllib.request.urlretrieve(url, zipped_filepath)\n", + " with gzip.open(zipped_filepath, 'rb') as f_in, \\\n", + " tf.gfile.Open(filepath, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " os.remove(zipped_filepath)\n", + " return filepath\n", + "\n", + "\n", + "def dataset(directory, images_file, labels_file):\n", + " \"\"\"Download and parse MNIST dataset.\"\"\"\n", + "\n", + " images_file = download(directory, images_file)\n", + " labels_file = download(directory, labels_file)\n", + "\n", + " check_image_file_header(images_file)\n", + " check_labels_file_header(labels_file)\n", + "\n", + " def decode_image(image):\n", + " # Normalize from [0, 255] to [0.0, 1.0]\n", + " image = tf.decode_raw(image, tf.uint8)\n", + " image = tf.cast(image, tf.float32)\n", + " image = tf.reshape(image, [28, 28, 1])\n", + " return image / 255.0\n", + "\n", + " def decode_label(label):\n", + " label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]\n", + " label = tf.reshape(label, []) # label is a scalar\n", + " return tf.to_int32(label)\n", + "\n", + " images = tf.data.FixedLengthRecordDataset(\n", + " images_file, 28 * 28, header_bytes=16).map(decode_image)\n", + " labels = tf.data.FixedLengthRecordDataset(\n", + " labels_file, 1, header_bytes=8).map(decode_label)\n", + " return tf.data.Dataset.zip((images, labels))\n", + "\n", + "\n", + "def get_training_data(directory):\n", + " \"\"\"tf.data.Dataset object for MNIST training data.\"\"\"\n", + " return dataset(directory, 'train-images-idx3-ubyte',\n", + " 'train-labels-idx1-ubyte').take(1024)\n", + "\n", + "def get_test_data(directory):\n", + " \"\"\"tf.data.Dataset object for MNIST test data.\"\"\"\n", + " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')" + ], + "execution_count": 11, + "outputs": [] + }, + { + "metadata": { + "id": "4ejmJ2dv_f0R", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 85 + }, + "outputId": "274c0381-e505-4e69-f910-3def6f8572a7" + }, + "cell_type": "code", + "source": [ + "# Don't forget to run the cell above!\n", + "training_data = get_training_data(\"/tmp/mnist/train\")\n", + "test_data = get_test_data(\"/tmp/mnist/test\")" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/tmp4ull1xwa.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/tmp1eikhj1v.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/tmpcp8xah9c.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/tmpqww_1e74.gz\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "TANpFS6GKLMC", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Fill in the implementation of `train_one_epoch` below and run the cell to train your model. " + ] + }, + { + "metadata": { + "id": "btKL0Ss9_rmC", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + }, + "outputId": "56858516-86fc-424a-f00d-6f088f98bf9b" + }, + "cell_type": "code", + "source": [ + "EPOCHS = 5\n", + "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.5)\n", + "\n", + "def loss_fn(logits, labels):\n", + " return tf.reduce_mean(\n", + " tf.nn.sparse_softmax_cross_entropy_with_logits(\n", + " logits=tf.squeeze(logits), labels=labels))\n", + "\n", + "def train_one_epoch(model, training_data, optimizer):\n", + " # TODO: Implement an optimization step and return the average loss.\n", + " #\n", + " # Hint: Use `tf.GradientTape` to compute the gradient of the loss, and use\n", + " # `optimizer.apply_gradients` to update the model's variables, which are\n", + " # accessible as `model.variables`\n", + " average_loss = tfe.metrics.Mean('loss')\n", + " for images, labels in training_data.shuffle(buffer_size=10000).batch(64):\n", + " pass\n", + " return average_loss.result()\n", + "\n", + "for epoch in range(EPOCHS):\n", + " loss = train_one_epoch(model, training_data, optimizer)\n", + " print(\"Average loss after epoch %d: %.4f\" % (epoch, loss))" + ], + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Average loss after epoch 0: 2.2847\n", + "Average loss after epoch 1: 2.2305\n", + "Average loss after epoch 2: 2.1334\n", + "Average loss after epoch 3: 1.9115\n", + "Average loss after epoch 4: 1.4285\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "yAOFupJN_htg", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + }, + "outputId": "67e711e4-76c9-4e3f-bb49-a14955dba03a" + }, + "cell_type": "code", + "source": [ + "#@title Double-click to see a solution.\n", + "EPOCHS = 5\n", + "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.5)\n", + "\n", + "def _loss_fn(logits, labels):\n", + " return tf.reduce_mean(\n", + " tf.nn.sparse_softmax_cross_entropy_with_logits(\n", + " logits=tf.squeeze(logits), labels=labels))\n", + "\n", + "def _train_one_epoch(model, training_data):\n", + " average_loss = tfe.metrics.Mean(\"loss\")\n", + " for images, labels in training_data.shuffle(buffer_size=10000).batch(64):\n", + " with tf.GradientTape() as tape:\n", + " logits = model(images, training=True)\n", + " loss = _loss_fn(logits, labels)\n", + " average_loss(loss)\n", + " gradients = tape.gradient(loss, model.variables)\n", + " optimizer.apply_gradients(zip(gradients, model.variables))\n", + " return average_loss.result()\n", + " \n", + "for epoch in range(EPOCHS):\n", + " loss = _train_one_epoch(model, training_data)\n", + " print(\"Average loss after epoch %d: %.4f\" % (epoch, loss))" + ], + "execution_count": 15, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Average loss after epoch 0: 1.0563\n", + "Average loss after epoch 1: 0.8013\n", + "Average loss after epoch 2: 0.6306\n", + "Average loss after epoch 3: 0.5543\n", + "Average loss after epoch 4: 0.5037\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "uDy1DrYA_2Jz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Run the below cell to qualitatively evaluate your model. Note how eager execution interoperates seamlessly with `matplotlib`." + ] + }, + { + "metadata": { + "id": "vR7rMtpu_3nB", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1752 + }, + "outputId": "b212aefa-f4b3-425c-f34d-2491429fa521" + }, + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "sampled_data = test_data.batch(1).shuffle(buffer_size=10000).take(5)\n", + "for image, label in sampled_data:\n", + " plt.figure()\n", + " plt.imshow(tf.reshape(image, (28, 28)))\n", + " plt.show()\n", + " logits = model(image, training=False)\n", + " prediction = tf.argmax(logits, axis=1, output_type=tf.int64)\n", + " print(\"Prediction: %d\" % prediction)" + ], + "execution_count": 16, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEwpJREFUeJzt3X1Ilff/x/HXmScxV2GZOmLVohXK\nKmLQjbUsy+pbI7rbaEm1IFhRSU1aE+kO3LqxCGrBMlsNkq0zZIM2Cu1mUTg1itXQbVnBQqKZNtcN\n2d3J3x9ffpLrNN/ndM65jn6fj7/m5cfrvI9XPHedc7zOcTU3NzcLAPCvXnJ6AABoD4glABgQSwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAG7kB/cOPGjbpw4YJcLpdyc3M1ZMiQYM4FABEloFieOXNGV69e\nlcfj0ZUrV5SbmyuPxxPs2QAgYgT0MLy8vFwZGRmSpP79++vWrVu6e/duUAcDgEgSUCwbGhrUvXv3\nlq979Oih+vr6oA0FAJEmKC/w8F4cADq6gGKZmJiohoaGlq9v3LihhISEoA0FAJEmoFiOHj1aJSUl\nkqTq6molJiaqS5cuQR0MACJJQK+Gv/nmm3rjjTf03nvvyeVyaf369cGeCwAiios3/wWAtnEFDwAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMHA7\nPQAQiAcPHpjX3rlzx+f2nj17qqGhodW2kydPmvb566+/mm//xx9/NK+13r4kjRgx4pltFRUVGjly\nZKttP/30k3mfL73E+dPz8JsBAIOAziwrKyu1YsUKDRgwQJI0cOBArV27NqiDAUAkCfhh+PDhw7Vz\n585gzgIAEYuH4QBgEHAsL1++rCVLlmju3LkqKysL5kwAEHFczc3Nzf7+UF1dnc6dO6cpU6aotrZW\nCxYsUGlpqaKjo0MxIwA4LqDnLJOSkjR16lRJUp8+fdSzZ0/V1dWpd+/eQR0OeB7+dIg/HQq3gH4z\nhw4d0hdffCFJqq+v182bN5WUlBTUwQAgkgR0Zjl+/HitWrVKx48f16NHj7RhwwYeggPo0AKKZZcu\nXbR79+5gzwIAESugF3gAf1RVVZnXfvfdd6Z1hw8fNu/zzJkzPrd7vV5FRUWZ99Me+LpPDx8+NP98\nR/t9BBPP5gKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAAM+3RGtPO/qV5fL\n1ep7BQUF5n1mZWWZ1z558sS8NhRcLpdpnT9vZebPJYT9+vUzry0pKfG5/Y8//mj1NW+7Fhz8FgHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgCt40MrBgwd9bp87d26r7y1btsy8z1de\necW89q233jKte//99837/Dfff/99q68TExNNP/fqq6+ab8Of+x8MvXv3Duvt/a/gzBIADIglABgQ\nSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4mp/3CVXoMB49emRe+/rrr/vcfvXqVfXt\n27fl68zMTPM+P/74Y/PauLg481ognDizBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGx\nBAADYgkABny6YztVX19vXjthwgTz2oEDB5q+l5eXZ96n223/Z/b48WPTuuvXr5v3efz4cZ/bFy5c\nqC+//NK8n0CNHTvWvLZfv34hnAQvwnRmWVNTo4yMDBUVFUn67z/U+fPnKzMzUytWrNDDhw9DOiQA\nOK3NWN67d095eXlKTU1t2bZz505lZmbqq6++Ut++fVVcXBzSIQHAaW3GMjo6WoWFha0+fL6ysrLl\noV16errKy8tDNyEARIA2n0xyu93PPOfU1NSk6OhoSVJ8fLxfz58BQHv0wi/w8HaYzkhISDCv/eWX\nX4Jym0ePHg3Kfv6N9cWg3r17m/e5cOHCgL4HPC2gWMbGxur+/fuKiYlRXV1dq4foCI9QvRqelJTk\nc/vRo0c1ceLElq+PHDli3ievhvNqeEcQ0N9Zjho1SiUlJZKk0tJSjRkzJqhDAUCkafN/+VVVVdqy\nZYuuXbsmt9utkpISbdu2TTk5OfJ4POrVq5dmzJgRjlkBwDFtxnLQoEE6cODAM9v3798fkoEAIBLx\ngWXt1A8//GBeO3v2bPPa572Ik5aWplOnTrV8ff78efM+J02aZF5rnfX333837/N5vF6voqKiAvrZ\nd99917x20KBB5rWrVq0yr42JiTGvxYvj2nAAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAY\nEEsAMCCWAGDA5Y7tlD+X23377bcvfHv/vDTQn7cS8+ft1NLS0kzr/Ln/o0aN8rk9OTn5mcsmO3Xq\nZNrn7du3zbc/YsQI89q9e/ea1y5YsMC8Fi+OM0sAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyI\nJQAYEEsAMCCWAGDQ5kfhIjItXrzYvHb06NHmtRcvXnzu9z744IOW//bnUruhQ4ea11ovN3S7g/NP\nNzk5OaCfe/qTLtvi9XrNa/351E4udwwvziwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIAreNqpjIyMkKz9N59//nlQ9tMRPHjwwOkREGacWQKAAbEEAANiCQAGxBIADIglABgQSwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADEyxrKmpUUZGhoqKiiRJOTk5mjZtmubPn6/58+fr5MmT\noZwRABzX5rsO3bt3T3l5eUpNTW21PTs7W+np6SEbDAAiSZtnltHR0SosLFRiYmI45gGAiNTmmaXb\n7Zbb/eyyoqIi7d+/X/Hx8Vq7dq169OgRkgGBSDRx4kTzWq/XG8JJEC4Bvfnv9OnTFRcXp5SUFO3Z\ns0e7du3SunXrgj0bELGOHj1qXvuf//zHvHb27Nnmtd988415LV5cQK+Gp6amKiUlRZI0fvx41dTU\nBHUoAIg0AcUyKytLtbW1kqTKykoNGDAgqEMBQKRp82F4VVWVtmzZomvXrsntdqukpETz5s3TypUr\n1blzZ8XGxmrTpk3hmBUAHNNmLAcNGqQDBw48s33y5MkhGQgAIhGf7ggEgAsx/vdwuSMAGBBLADAg\nlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDgckcgAKdPnw7JfqdNmxaS/eLFcWYJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAZcwQM85dSpU6Z1P//8s3mfL7/8snntuHHj\nzGsRXpxZAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAAy53RIf3999/+9we\nFxf3zPcyMjJM+/R6vebbP3jwoHlt7969zWsRXpxZAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAAy53DIMnT56Y1+bm5prWbdiwwbzPmJgY89r24u7du+a1b7/9ts/tZWVlz3zP\nehnjO++8Y7792bNnm9cicplimZ+fr3Pnzunx48davHixBg8erNWrV8vr9SohIUFbt25VdHR0qGcF\nAMe0GcuKigpdunRJHo9HjY2NmjlzplJTU5WZmakpU6Zo+/btKi4uVmZmZjjmBQBHtPmc5bBhw7Rj\nxw5JUrdu3dTU1KTKykpNmDBBkpSenq7y8vLQTgkADmszllFRUYqNjZUkFRcXKy0tTU1NTS0Pu+Pj\n41VfXx/aKQHAYeYXeI4dO6bi4mLt27dPkyZNatne3NwcksE6kpdesv/RwebNm0M4ScfRpUsX89qy\nsrKAvgc8zRTL06dPa/fu3dq7d6+6du2q2NhY3b9/XzExMaqrq1NiYmKo52zXeDU8+Px5NXzy5Mk+\nt5eVlWn06NGttlVUVJj26c+r4V9//bV5rT//Y0V4tXlk7ty5o/z8fBUUFCguLk6SNGrUKJWUlEiS\nSktLNWbMmNBOCQAOa/PM8vDhw2psbNTKlStbtm3evFlr1qyRx+NRr169NGPGjJAOCQBOazOWc+bM\n0Zw5c57Zvn///pAMBACRyNXMKzQh58+HW1n/uP/TTz817zM7Ozvotx8qv/32m2nd0qVLzfs8deqU\nz+1er1dRUVHm/TyturravDY5OTmg20Bk4dlkADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUA\nGBBLADAglgBgwOWOYeDP5Y4JCQmmdbdu3TLvc+LEiea148aN87k9Jycn4PfavH//vnntJ598Ylrn\nzz/bbt26+dze2Nio7t27t9p28eJF0z6tx0mSXC6XeS0iF2eWAGBALAHAgFgCgAGxBAADYgkABsQS\nAAyIJQAYEEsAMCCWAGBALAHAgMsdI0xxcbFp3bJly8z7bGhoCHScFi/ySYj++Oflh88zefJk8z4/\n+ugjn9uHDh2q8+fPP7MN8IUzSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAw4Aqe\ndqqmpsa8Njs727z2yJEjPre/yBU8q1evNq8dPHiwaV1mZmZAswCB4swSAAyIJQAYEEsAMCCWAGBA\nLAHAgFgCgAGxBAADYgkABsQSAAyIJQAYcLkjABi4LYvy8/N17tw5PX78WIsXL9aJEydUXV2tuLg4\nSdKiRYs0bty4UM4JAI5qM5YVFRW6dOmSPB6PGhsbNXPmTI0cOVLZ2dlKT08Px4wA4Lg2Yzls2DAN\nGTJEktStWzc1NTXJ6/WGfDAAiCR+PWfp8Xh09uxZRUVFqb6+Xo8ePVJ8fLzWrl2rHj16hHJOAHCU\nOZbHjh1TQUGB9u3bp6qqKsXFxSklJUV79uzRn3/+qXXr1oV6VgBwjOlPh06fPq3du3ersLBQXbt2\nVWpqqlJSUiRJ48eP9+uNaAGgPWozlnfu3FF+fr4KCgpaXv3OyspSbW2tJKmyslIDBgwI7ZQA4LA2\nX+A5fPiwGhsbtXLlypZts2bN0sqVK9W5c2fFxsZq06ZNIR0SAJzGH6UDgAGXOwKAAbEEAANiCQAG\nxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKA\nAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4nbjRjRs3\n6sKFC3K5XMrNzdWQIUOcGCOoKisrtWLFCg0YMECSNHDgQK1du9bhqQJXU1OjpUuXauHChZo3b56u\nX7+u1atXy+v1KiEhQVu3blV0dLTTY/rln/cpJydH1dXViouLkyQtWrRI48aNc3ZIP+Xn5+vcuXN6\n/PixFi9erMGDB7f74yQ9e79OnDjh+LEKeyzPnDmjq1evyuPx6MqVK8rNzZXH4wn3GCExfPhw7dy5\n0+kxXti9e/eUl5en1NTUlm07d+5UZmampkyZou3bt6u4uFiZmZkOTukfX/dJkrKzs5Wenu7QVC+m\noqJCly5dksfjUWNjo2bOnKnU1NR2fZwk3/dr5MiRjh+rsD8MLy8vV0ZGhiSpf//+unXrlu7evRvu\nMfAvoqOjVVhYqMTExJZtlZWVmjBhgiQpPT1d5eXlTo0XEF/3qb0bNmyYduzYIUnq1q2bmpqa2v1x\nknzfL6/X6/BUDsSyoaFB3bt3b/m6R48eqq+vD/cYIXH58mUtWbJEc+fOVVlZmdPjBMztdismJqbV\ntqamppaHc/Hx8e3umPm6T5JUVFSkBQsW6MMPP9Rff/3lwGSBi4qKUmxsrCSpuLhYaWlp7f44Sb7v\nV1RUlOPHypHnLJ/W3Nzs9AhB8dprr2n58uWaMmWKamtrtWDBApWWlrbL54va0lGO2fTp0xUXF6eU\nlBTt2bNHu3bt0rp165wey2/Hjh1TcXGx9u3bp0mTJrVsb+/H6en7VVVV5fixCvuZZWJiohoaGlq+\nvnHjhhISEsI9RtAlJSVp6tSpcrlc6tOnj3r27Km6ujqnxwqa2NhY3b9/X5JUV1fXIR7OpqamKiUl\nRZI0fvx41dTUODyR/06fPq3du3ersLBQXbt27TDH6Z/3KxKOVdhjOXr0aJWUlEiSqqurlZiYqC5d\nuoR7jKA7dOiQvvjiC0lSfX29bt68qaSkJIenCp5Ro0a1HLfS0lKNGTPG4YleXFZWlmprayX99znZ\n//9Lhvbizp07ys/PV0FBQcurxB3hOPm6X5FwrFzNDpyrb9u2TWfPnpXL5dL69euVnJwc7hGC7u7d\nu1q1apVu376tR48eafny5Ro7dqzTYwWkqqpKW7Zs0bVr1+R2u5WUlKRt27YpJydHDx48UK9evbRp\n0yZ16tTJ6VHNfN2nefPmac+ePercubNiY2O1adMmxcfHOz2qmcfj0WeffaZ+/fq1bNu8ebPWrFnT\nbo+T5Pt+zZo1S0VFRY4eK0diCQDtDVfwAIABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwOD/\nAKCzFeFbFn4BAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 5\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEQ1JREFUeJzt3W9Ilff/x/HXSSd2VmKaRwiqjTBy\nq9gfap2iliaFQfRvsCXW1rpRRJGTCJG0MSHLIpbF8M9qN3L7cjZvNQiOVAQt7LQcBLqB1Y0QaXYs\naUa2mZ3fjS9ff7Vcvj2ec65jez7ueZ1P57wPlzy7Li8vjysUCoUEAHihcU4PAABjAbEEAANiCQAG\nxBIADIglABgQSwAwIJYAYEAsAcAgMdx/uH//fl27dk0ul0ulpaWaO3duJOcCgLgSViyvXLmiW7du\nyefz6ebNmyotLZXP54v0bAAQN8I6DW9ublZeXp4kacaMGbp//74ePHgQ0cEAIJ6EFcvu7m5NmjRp\n8Ou0tDQFg8GIDQUA8SYiF3j4WxwAXnZhxdLj8ai7u3vw6zt37igjIyNiQwFAvAkrlosWLZLf75ck\ntbW1yePxaMKECREdDADiSVhXw9955x29+eab+uijj+RyubRv375IzwUAccXFH/8FgOFxBw8AGBBL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMEp0eAIgnP/30k2nd+vXrzc+Zl5dnXvvtt9+a1yK2OLIEAANiCQAGxBIADIgl\nABgQSwAwIJYAYEAsAcCAWAKAAbEEAAPu4AGecuzYMdO6YDBofk6XyxXuOIgjHFkCgEFYR5aBQEC7\ndu1SVlaWJGnmzJkqKyuL6GAAEE/CPg2fP3++qqurIzkLAMQtTsMBwCDsWN64cUPbtm3Thg0bdOnS\npUjOBABxxxUKhUIj/UddXV1qaWlRfn6+Ojo6tGnTJjU1NSkpKSkaMwKA48L6mWVmZqZWrlwpSZo2\nbZomT56srq4uTZ06NaLDAbH24Ycfmtb98MMP5ucsKCgwr21oaDCvRWyFdRp++vRpnThxQtJ/f9/s\n7t27yszMjOhgABBPwjqyzM3N1e7du3Xu3Dn19/fr888/5xQcwEstrFhOmDBBNTU1kZ4FAOIWtzsC\nT7lw4ULEn3PVqlURf07EHr9nCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIA\nDLjdES89v98/5PYVK1Y899hIPrXRqre3N+LPidjjyBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEE\nAANiCQAGxBIADLiDB2NSKBQyr21oaBhy+4oVK/7xsUh6++23o/4aiD6OLAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgIErNJL7xoA40dnZaV47derUIbc/efJE48aFd7zw7rvv\nmtf+/PPPYb0G4gtHlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBPd8SY\nVFlZ6ejrb9682dHXR+yZjizb29uVl5c3+LGht2/f1saNG1VQUKBdu3bpr7/+iuqQAOC0YWP58OFD\nVVRUyOv1Dm6rrq5WQUGBvvvuO02fPl2NjY1RHRIAnDZsLJOSklRfXy+PxzO4LRAIaNmyZZKknJwc\nNTc3R29CAIgDw/7MMjExUYmJzy7r6+tTUlKSJCk9PV3BYDA60wFAnBj1BR7+HCaccPz48YisffLk\nSSTGwb9AWLF0u9169OiRkpOT1dXV9cwpOhALO3bsMK/96quvhtw+mj/+O5JYb9++PazXQHwJ6ztl\n4cKF8vv9kqSmpiYtXrw4okMBQLwZ9siytbVVBw8eVGdnpxITE+X3+3X48GGVlJTI5/NpypQpWrNm\nTSxmBQDHDBvL2bNn69SpU89t/+abb6IyEADEI+7gQVyxXnCJ1oeAWX/+XlhYGJXXR/zi3nAAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDA7Y6IKxUVFaZ10brd8dVXXzWt6+3t\nNT9nSkpKuOMgjnBkCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADLjdEXHl\nyy+/dPT1BwYGTOv8fr/5OT/99NNwx0Ec4cgSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAy4gwdR99tvv5nXjuSDwKzcbrf5sV9++cX0nGlpaaOaCWMPR5YAYEAsAcCAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAwIJYAYEAsAcCA2x0RFusHe0kj+xCyJ0+ehDPOC507d878GLcx4p9wZAkA\nBqZYtre3Ky8vTw0NDZKkkpISrVq1Shs3btTGjRt14cKFaM4IAI4b9jT84cOHqqiokNfrfWZ7cXGx\ncnJyojYYAMSTYY8sk5KSVF9fL4/HE4t5ACAuuUKhUMiy8NixY5o0aZIKCwtVUlKiYDCo/v5+paen\nq6ysjB+MA3iphXU1fPXq1UpNTVV2drbq6up0/PhxlZeXR3o2xLGRXA3fvn27eW19fX0447xQc3Pz\nkNvfe+89BQKB57YBQwnrarjX61V2drYkKTc3V+3t7REdCgDiTVix3Llzpzo6OiRJgUBAWVlZER0K\nAOLNsKfhra2tOnjwoDo7O5WYmCi/36/CwkIVFRVp/PjxcrvdqqysjMWsAOCYYWM5e/ZsnTp16rnt\nK1asiMpAABCPzFfDgafdu3fPvHby5MkRf/0PPvjAvPY///nPkNsTEhKeu1CVkJAwqrnw8uJ2RwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYMCnO+IZ//TpiuPGjXvmsc2bN0fl\n9V0ul2ndF198YX7OF93CyO2NsOLIEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\nuIMHz/jf58H/3fTp05957Mcff4zK6xcWFprWzZo1KyqvD/wTjiwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBgQCwBwIBYAoABtzviGRcuXBhy+8cff/zMY6FQKCqvX15eHpXnBUaLI0sAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDgCkXrvjXEjV9//dW8ds6cOUNuHxgY\nUEJCwuDXI/m2Wb9+vXmtz+czrRs3jv/nEVume8OrqqrU0tKix48fa+vWrZozZ4727NmjgYEBZWRk\n6NChQ0pKSor2rADgmGFjefnyZV2/fl0+n089PT1au3atvF6vCgoKlJ+fryNHjqixsVEFBQWxmBcA\nHDHsucy8efN09OhRSVJKSor6+voUCAS0bNkySVJOTo6am5ujOyUAOGzYWCYkJMjtdkuSGhsbtWTJ\nEvX19Q2edqenpysYDEZ3SgBwmPnvWZ49e1aNjY06efKkli9fPrid60Px74033jCvHRgYCOsx4GVn\niuXFixdVU1Ojr7/+WhMnTpTb7dajR4+UnJysrq4ueTyeaM+JUeBqODB6w37H9fb2qqqqSrW1tUpN\nTZUkLVy4UH6/X5LU1NSkxYsXR3dKAHDYsEeWZ86cUU9Pj4qKiga3HThwQHv37pXP59OUKVO0Zs2a\nqA4JAE7jl9L/BTgNB0aPDyz7F7AGSHpxBJ9+LCUlxfycJ06cMK8lgohXfGcCgAGxBAADYgkABsQS\nAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADbnf8F7hx44Z5rfV2x+TkZPNzjuTWSCBecWQJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMuN3xX6C4uNi89vvvv//HxxIT///b\n5a233hrVTMBYw5ElABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4Qi/6hCoAgCSO\nLAHAhFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAAD06c7VlVVqaWlRY8f\nP9bWrVt1/vx5tbW1KTU1VZK0ZcsWLV26NJpzAoCjho3l5cuXdf36dfl8PvX09Gjt2rVasGCBiouL\nlZOTE4sZAcBxw8Zy3rx5mjt3riQpJSVFfX19GhgYiPpgABBPRvQn2nw+n65evaqEhAQFg0H19/cr\nPT1dZWVlSktLi+acAOAocyzPnj2r2tpanTx5Uq2trUpNTVV2drbq6ur0+++/q7y8PNqzAoBjTFfD\nL168qJqaGtXX12vixInyer3Kzs6WJOXm5qq9vT2qQwKA04aNZW9vr6qqqlRbWzt49Xvnzp3q6OiQ\nJAUCAWVlZUV3SgBw2LAXeM6cOaOenh4VFRUNblu3bp2Kioo0fvx4ud1uVVZWRnVIAHAan8EDAAbc\nwQMABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGCQ6MSL7t+/X9euXZPL5VJpaanmzp3rxBgRFQgEtGvXLmVlZUmSZs6cqbKyMoenCl97e7u2\nb9+uTz75RIWFhbp9+7b27NmjgYEBZWRk6NChQ0pKSnJ6zBH5+3sqKSlRW1ubUlNTJUlbtmzR0qVL\nnR1yhKqqqtTS0qLHjx9r69atmjNnzpjfT9Lz7+v8+fOO76uYx/LKlSu6deuWfD6fbt68qdLSUvl8\nvliPERXz589XdXW102OM2sOHD1VRUSGv1zu4rbq6WgUFBcrPz9eRI0fU2NiogoICB6ccmaHekyQV\nFxcrJyfHoalG5/Lly7p+/bp8Pp96enq0du1aeb3eMb2fpKHf14IFCxzfVzE/DW9ublZeXp4kacaM\nGbp//74ePHgQ6zHwAklJSaqvr5fH4xncFggEtGzZMklSTk6OmpubnRovLEO9p7Fu3rx5Onr0qCQp\nJSVFfX19Y34/SUO/r4GBAYenciCW3d3dmjRp0uDXaWlpCgaDsR4jKm7cuKFt27Zpw4YNunTpktPj\nhC0xMVHJycnPbOvr6xs8nUtPTx9z+2yo9yRJDQ0N2rRpkz777DPdu3fPgcnCl5CQILfbLUlqbGzU\nkiVLxvx+koZ+XwkJCY7vK0d+Zvm0UCjk9AgR8dprr2nHjh3Kz89XR0eHNm3apKampjH586LhvCz7\nbPXq1UpNTVV2drbq6up0/PhxlZeXOz3WiJ09e1aNjY06efKkli9fPrh9rO+np99Xa2ur4/sq5keW\nHo9H3d3dg1/fuXNHGRkZsR4j4jIzM7Vy5Uq5XC5NmzZNkydPVldXl9NjRYzb7dajR48kSV1dXS/F\n6azX61V2drYkKTc3V+3t7Q5PNHIXL15UTU2N6uvrNXHixJdmP/39fcXDvop5LBctWiS/3y9Jamtr\nk8fj0YQJE2I9RsSdPn1aJ06ckCQFg0HdvXtXmZmZDk8VOQsXLhzcb01NTVq8eLHDE43ezp071dHR\nIem/P5P9328yjBW9vb2qqqpSbW3t4FXil2E/DfW+4mFfuUIOHKsfPnxYV69elcvl0r59+zRr1qxY\njxBxDx480O7du/XHH3+ov79fO3bs0Pvvv+/0WGFpbW3VwYMH1dnZqcTERGVmZurw4cMqKSnRn3/+\nqSlTpqiyslKvvPKK06OaDfWeCgsLVVdXp/Hjx8vtdquyslLp6elOj2rm8/l07Ngxvf7664PbDhw4\noL17947Z/SQN/b7WrVunhoYGR/eVI7EEgLGGO3gAwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAg\nlgBg8H/nb4OLnfGqVAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 1\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAE1ZJREFUeJzt3X1olfX/x/HXccc1DyrLuY1GaRGL\nRqZSaE7zZmqKgnhDsVwqkYGRE29QW8tp4M102solNJ03fzSqgyPoBmFDIlg1Jw0xNsrZDbKGranD\nG5x3x33/+NF+rp153js751znrOfjv13n43Xex4NPrrPL61yujo6ODgEA7muA0wMAQCwglgBgQCwB\nwIBYAoABsQQAA2IJAAbEEgAMiCUAGLiD/YM7duzQ6dOn5XK5lJ+fr9GjR4dyLgCIKkHF8uTJkzp3\n7py8Xq9+++035efny+v1hno2AIgaQX0Mr6mp0cyZMyVJjz/+uC5fvqxr166FdDAAiCZBxfLChQt6\n8MEHO38eNmyYWltbQzYUAESbkJzg4bs4APR3QcUyJSVFFy5c6Pz577//VnJycsiGAoBoE1QsJ02a\npMrKSklSQ0ODUlJSNHjw4JAOBgDRJKiz4c8884yeeuopvfzyy3K5XNqyZUuo5wKAqOLiy38BIDCu\n4AEAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUA\nGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJ\nAAbEEgAMiCUAGLiD+UO1tbVavXq10tPTJUlPPPGECgoKQjoYAESToGIpSePHj1dJSUkoZwGAqMXH\ncAAwCDqWv/76q9544w0tXrxY33//fShnAoCo4+ro6Ojo7R9qaWlRXV2d5syZo6amJi1btkxVVVWK\nj48Px4wA4LigjixTU1M1d+5cuVwujRgxQsOHD1dLS0uoZwOAqBFULL/88ksdOnRIktTa2qqLFy8q\nNTU1pIMBQDQJ6mP4tWvXtH79el25ckW3b99Wbm6upk6dGo75ACAqBBVLAPivCfr/WQL90alTp0zr\nSktLzfssKysLdpz78nec09HRIZfL1WVbbm6ueZ+9+b/T/36e/o7/ZwkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAy4Nhz93tmzZ/1uT09P7/bY4sWLTfu0XhYZaT6fT3FxcUH/\n+Vu3bpnX9uV5YhFHlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgwA3LEHa9uUjs\nzJkzpnXz588377Opqcnv9uvXr2vMmDFdtt28edO8Xyu32/7PrKCgwLw2Pj7e7/bCwsIuPz/77LPm\nfQ4YwPFTT/ibAQADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtywDEG5ffu2\nee1bb71lXrt3795gxgmKv5t7PfTQQ6Y/u3r1avPzLF++3Lz2yJEj5rW5ubndtj3wwAPdLtl84IEH\nzPtEzziyBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtzdEV3cvXvX7/YB\nAwZ0eSwvL8+8z0hewujPokWLzI999NFHpn16PB7z8y9evNi89uuvvzavbW5u7ratuLhYb7/9drdt\n6DvTkWVjY6Nmzpyp8vJySdL58+e1dOlS5eTkaPXq1bp161ZYhwQApwWM5fXr17V161ZlZmZ2bisp\nKVFOTo4++eQTjRw5UhUVFWEdEgCcFjCW8fHxKisrU0pKSue22tpazZgxQ5KUlZWlmpqa8E0IAFEg\n4O8s3W633O6uy9rb2xUfHy9JSkpKUmtra3imA4Ao0ecTPHwdZv8yYEDPHzbufey9994z77M3ayPt\n6NGjYX+OL774IuzPcS9O6IRHULH0eDy6ceOGEhIS1NLS0uUjOmKb9Wz4hg0bzPv84IMP+jxXX/R0\nNvzo0aN66aWXumyLpbPh/r6AuLi4WOvWreu2DX0X1P+znDhxoiorKyVJVVVVmjx5ckiHAoBoE/DI\nsr6+Xrt27VJzc7PcbrcqKyu1Z88e5eXlyev1Ki0tTQsWLIjErADgmICxHDVqlD7++ONu23tzrxAA\niHVcwfMf8Ndff5nXzpo1y+/2n376SWPHju38uaGhoc9z+TN06FDTutLSUvM+X3zxxR4f++yzz7r8\nfL8TXPf69NNPzc/fm99D9kZaWlqvtqNvuDYcAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYuDr4QsqYdPXqVfPaUaNGmdf++eeffrf7fD7FxcWZ93Ovf75V3+LQoUOmdY888khQ\nswRivUVKdnZ2WJ7/ny/Vtjh16lS3bU8++aR++eWXbtvQdxxZAoABsQQAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA+7uGKPKy8vNa3u6hLEvlixZYl67Z88e89rk5GTTupaWFvM+X3/9\ndb/bv/rqK82bN6/LtsrKSvN+w6E3d43s6TJGLm8MD44sAcCAWAKAAbEEAANiCQAGxBIADIglABgQ\nSwAwIJYAYEAsAcCAG5ZFmbt375rWvfDCC+Z9fvvtt+a1Pd0wq729XYMGDer8ubGx0bzPtLQ089qf\nf/7ZtG7Dhg3mfVZVVfnd3pebsIXLjRs3zGsHDhwYxknwbxxZAoABsQQAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA25YFmWsV5/25hLG3vD5fKbHiouLzfv8448/zGu/+uor89pYsWDB\nAvPaaLv8Ev+PI0sAMDDFsrGxUTNnzuy8/WpeXp7mzZunpUuXaunSpWE7ygGAaBHwY/j169e1detW\nZWZmdtm+bt06ZWVlhW0wAIgmAY8s4+PjVVZWppSUlEjMAwBRKeCRpdvtltvdfVl5ebmOHDmipKQk\nFRQUaNiwYWEZ8L/G+gv++52ICZdbt25F/DnDzYm/R8SmoM6Gz58/X4mJicrIyNCBAwe0b98+bd68\nOdSz/SdZ//H29CW9fdVTrG/dutXlOVeuXGneZ7SeDY/Ul//25mz40aNHzWsHDOD8bCQF9bedmZmp\njIwMSdL06dN79a3ZABCLgorlqlWr1NTUJEmqra1Venp6SIcCgGgT8GN4fX29du3apebmZrndblVW\nVmrJkiVas2aNBg0aJI/Ho8LCwkjMCgCOCRjLUaNG6eOPP+62ffbs2WEZCACiEZc7ogvr5Y4lJSWR\nGKdf6M0JHk7aRC/eGQAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYMDljlHG\nernbsWPHzPvszeV24fiC3958MfT69etN6/Lz84MdJyS2bdtmXvvKK6+EcRJECkeWAGBALAHAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDAFTxRxuVymdb15u6ap06dMq+9dOlSj49VV1eb93Ov\nsWPHmtfW1dUF9RyhMmbMGNO6lStXmvfJTcj6B95FADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBg4Oro6Ohwegj0b21tbea1kyZNMq07c+ZMsON08vl8iouL67Lthx9+MP3Z5557\nrs/Pj9jCkSUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDg7o4Iu5MnT5rX\nhuIyxn/Ly8szPzZ+/PiQPz/6B1Msi4qKVFdXpzt37mjFihV6+umntXHjRvl8PiUnJ2v37t2Kj48P\n96wA4JiAsTxx4oTOnj0rr9ertrY2LVy4UJmZmcrJydGcOXNUXFysiooK5eTkRGJeAHBEwN9Zjhs3\nTnv37pUkDR06VO3t7aqtrdWMGTMkSVlZWaqpqQnvlADgsICxjIuLk8fjkSRVVFRoypQpam9v7/zY\nnZSUpNbW1vBOCQAOM5/gOX78uCoqKnT48GHNmjWrcztfh4lAZs+ebV7r8/nCOEl327dvj+jzIXaZ\nYlldXa3S0lIdPHhQQ4YMkcfj0Y0bN5SQkKCWlhalpKSEe07EsMrKSvPauXPnhvz5ezobvn37dr3z\nzjtdtm3bts20T5fL1ee5EFsCfgy/evWqioqKtH//fiUmJkqSJk6c2PkPoKqqSpMnTw7vlADgsIBH\nlseOHVNbW5vWrFnTuW3nzp3atGmTvF6v0tLStGDBgrAOCQBOCxjL7OxsZWdnd9t+5MiRsAwEANGI\nG5YhKL25CVlGRoZ5bTj+Z8Xvv//ud/vIkSN17ty5btsAf7g2HAAMiCUAGBBLADAglgBgQCwBwIBY\nAoABsQQAA2IJAAbEEgAMiCUAGHDDMgSlrKzMvDYclzDm5uaa16alpQX1GHAvjiwBwIBYAoABsQQA\nA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABlzuiizt37vjd7na7uzz2+eefh+X5V61aZVr3\n/vvvm/fpcrl6fGzgwIHm/eC/jSNLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADBw\ndXR0dDg9BKLHd99953f7888/3+WxqVOnmvf58MMPm9eeOXPGtC4hIcG8TyAUOLIEAANiCQAGxBIA\nDIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAG3LAMXQwZMiSox+5ny5Yt5rVcxohoZYplUVGR\n6urqdOfOHa1YsULffPONGhoalJiYKElavny5pk2bFs45AcBRAWN54sQJnT17Vl6vV21tbVq4cKEm\nTJigdevWKSsrKxIzAoDjAsZy3LhxGj16tCRp6NCham9vl8/nC/tgABBNAp7giYuLk8fjkSRVVFRo\nypQpiouLU3l5uZYtW6a1a9fq0qVLYR8UAJxk/j7L48ePa//+/Tp8+LDq6+uVmJiojIwMHThwQH/9\n9Zc2b94c7lkBwDGmEzzV1dUqLS3VwYMHNWTIEGVmZnY+Nn36dL377rvhmg8Rdvr0ab/bx4wZ0+Wx\nZ555xrzPsrIy89rXXnvNvBaIpIAfw69evaqioiLt37+/8+z3qlWr1NTUJEmqra1Venp6eKcEAIcF\nPLI8duyY2tratGbNms5tixYt0po1azRo0CB5PB4VFhaGdUgAcFrAWGZnZys7O7vb9oULF4ZlIACI\nRlzuCAAG3N0RAAw4sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQA\nA2IJAAbEEgAM3E486Y4dO3T69Gm5XC7l5+dr9OjRTowRUrW1tVq9erXS09MlSU888YQKCgocnip4\njY2NevPNN/Xqq69qyZIlOn/+vDZu3Cifz6fk5GTt3r1b8fHxTo/ZK/9+TXl5eWpoaFBiYqIkafny\n5Zo2bZqzQ/ZSUVGR6urqdOfOHa1YsUJPP/10zL9PUvfX9c033zj+XkU8lidPntS5c+fk9Xr122+/\nKT8/X16vN9JjhMX48eNVUlLi9Bh9dv36dW3dulWZmZmd20pKSpSTk6M5c+aouLhYFRUVysnJcXDK\n3vH3miRp3bp1ysrKcmiqvjlx4oTOnj0rr9ertrY2LVy4UJmZmTH9Pkn+X9eECRMcf68i/jG8pqZG\nM2fOlCQ9/vjjunz5sq5duxbpMXAf8fHxKisrU0pKSue22tpazZgxQ5KUlZWlmpoap8YLir/XFOvG\njRunvXv3SpKGDh2q9vb2mH+fJP+vy+fzOTyVA7G8cOGCHnzwwc6fhw0bptbW1kiPERa//vqr3njj\nDS1evFjff/+90+MEze12KyEhocu29vb2zo9zSUlJMfee+XtNklReXq5ly5Zp7dq1unTpkgOTBS8u\nLk4ej0eSVFFRoSlTpsT8+yT5f11xcXGOv1eO/M7yXh0dHU6PEBKPPvqocnNzNWfOHDU1NWnZsmWq\nqqqKyd8XBdJf3rP58+crMTFRGRkZOnDggPbt26fNmzc7PVavHT9+XBUVFTp8+LBmzZrVuT3W36d7\nX1d9fb3j71XEjyxTUlJ04cKFzp///vtvJScnR3qMkEtNTdXcuXPlcrk0YsQIDR8+XC0tLU6PFTIe\nj0c3btyQJLW0tPSLj7OZmZnKyMiQJE2fPl2NjY0OT9R71dXVKi0tVVlZmYYMGdJv3qd/v65oeK8i\nHstJkyapsrJSktTQ0KCUlBQNHjw40mOE3JdffqlDhw5JklpbW3Xx4kWlpqY6PFXoTJw4sfN9q6qq\n0uTJkx2eqO9WrVqlpqYmSf/3O9l//idDrLh69aqKioq0f//+zrPE/eF98ve6ouG9cnU4cKy+Z88e\n/fjjj3K5XNqyZYuefPLJSI8QcteuXdP69et15coV3b59W7m5uZo6darTYwWlvr5eu3btUnNzs9xu\nt1JTU7Vnzx7l5eXp5s2bSktLU2FhoQYOHOj0qGb+XtOSJUt04MABDRo0SB6PR4WFhUpKSnJ6VDOv\n16sPP/xQjz32WOe2nTt3atOmTTH7Pkn+X9eiRYtUXl7u6HvlSCwBINZwBQ8AGBBLADAglgBgQCwB\nwIBYAoABsQQAA2IJAAbEEgAM/gepgR0uaefKmwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 4\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEelJREFUeJzt3W9MlfX/x/HXEWJyhg5BIG1ZfR0u\nKr3hhopOE2Q23FxiN0xCdNmGa5pG6hhTtNn8g85NtI0/aS1Z29moG96wILM2dYDKDRu0hrpyzCkC\nkUocDeH8brQfk8R4czyH64DPx624+Hid99nFnl2H61wHl8/n8wkA8J/GOD0AAIwExBIADIglABgQ\nSwAwIJYAYEAsAcCAWAKAAbEEAINwf//h7t27denSJblcLhUUFGjGjBmBnAsAQopfsTx//ryuXbsm\nj8ejq1evqqCgQB6PJ9CzAUDI8OtleE1NjdLT0yVJU6dO1e3bt9XZ2RnQwQAglPgVy7a2Nk2YMKHv\n65iYGLW2tgZsKAAINQG5wMNncQAY7fyKZXx8vNra2vq+vnXrluLi4gI2FACEGr9iOW/ePFVVVUmS\nGhsbFR8fr6ioqIAOBgChxK+r4TNnztSrr76qt99+Wy6XSzt27Aj0XAAQUlx8+C8ADI47eADAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAz8+lO4AJz3yy+/PLLtlVdeeWT777//bt7ne++9Z147f/58\n0zqPx2PeZyjjzBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4fD6fz+kh\ngNHsr7/+Mq+tr683r33rrbce2dba2qq4uLh+29rb2837XL16tXntp59+alrndrvN+wxlnFkCgAGx\nBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAF/sAzww/37981rMzMzzWtPnTplXvu4O2O8\nXm+/rysrK837XLJkiXnt2LFjzWtHA84sAcDArzPLuro6bdy4UYmJiZKkadOmafv27QEdDABCid8v\nw2fNmqXi4uJAzgIAIYuX4QBg4Hcsr1y5onXr1mnlypU6d+5cIGcCgJDj1+dZtrS0qL6+XhkZGWpu\nblZOTo6qq6sVERERjBkBwHF+/c4yISGh7y0GU6ZM0cSJE9XS0qLnn38+oMMBoWoobx1aunSpee2T\nvnWos7NTUVFR/bZ9+eWX5n3y1qHH8+tl+IkTJ3T06FFJ/3wyc3t7uxISEgI6GACEEr/OLNPS0rR5\n82b98MMP6u7u1s6dO3kJDmBU8yuWUVFRKikpCfQsABCyuN0ReIj1vcNbtmwx77O7u9u8dii/9//x\nxx8H3P7zzz/3+/p///ufeZ94PN5nCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAG\nxBIADPz6PEvAaT09Pea1x48fH3D7mjVr9MUXX/Tblpuba9pnb2+v+fE/+eQT89qcnBzz2kmTJpnX\n4slxZgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtzBgxHpcXflDGT16tUDbu/t\n7dWYMf6dL+zcudO8trCw0K/HQGjhzBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIA\nDIglABhwuyNCSnFxsWndRx99ZN7n4/642UC3O77zzjumff77D539l7CwMPNahC7OLAHAgFgCgAGx\nBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAG3OyLovF6vee2kSZNM6+7cuePvOH0Gut2x\npqbG9G9nz579xI+PkcV0ZtnU1KT09HRVVFRIkm7cuKFVq1YpKytLGzdu1N9//x3UIQHAaYPGsqur\nS7t27VJKSkrftuLiYmVlZemrr77SCy+8oMrKyqAOCQBOGzSWERERKi8vV3x8fN+2uro6LVq0SJKU\nmppqfukCACNV+KALwsMVHt5/mdfrVUREhCQpNjZWra2twZkOAELEoLEcDNeHMJjIyEjz2j///DOI\nkzyqt7d3WB8PI5dfsXS73bp3757Gjh2rlpaWfi/RgX/jajhGA7/eZzl37lxVVVVJkqqrqzV//vyA\nDgUAoWbQM8uGhgbt27dP169fV3h4uKqqqnTgwAHl5+fL4/Fo8uTJWrZs2XDMCgCO4U3pCDpehmM0\neOILPHg6ffvtt+a1hw4dMq8NRASfRElJiWkdsXz6cG84ABgQSwAwIJYAYEAsAcCAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAw4HZH+MV6W6D0zydTWU2ZMsW07v79++Z9trS0mNcCj8OZJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMOB2R/Rz4cKFAbcnJyf3+15tbW1QHv/77783\nrRvKX4FMTk72dxygD2eWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDAHTzoZ8GC\nBQNu93q9/b43lD8YNhTWP1jm9XqD8vjA43BmCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEE\nAANiCQAGxBIADLjd8Slw5MgR89r/uo3R31scZ8yYYV7rcrn8eoxAuXnzpmldV1eXeZ9ut9vfcRBC\nOLMEAANTLJuampSenq6KigpJUn5+vpYuXapVq1Zp1apV+umnn4I5IwA4btCX4V1dXdq1a5dSUlL6\nbc/Ly1NqamrQBgOAUDLomWVERITKy8sVHx8/HPMAQEhy+Xw+n2Xh4cOHNWHCBGVnZys/P1+tra3q\n7u5WbGystm/frpiYmGDPCgCO8etq+Jtvvqno6GglJSWprKxMR44cUWFhYaBnQ4AM5Wr4Bx98MOD2\n3t5ejRnj3/XAoVwNP3/+vGndUK5GP+5/5AM9pzfeeMO0z6+//tr8+FwNHx38+ulPSUlRUlKSJCkt\nLU1NTU0BHQoAQo1fsdywYYOam5slSXV1dUpMTAzoUAAQagZ9Gd7Q0KB9+/bp+vXrCg8PV1VVlbKz\ns7Vp0yZFRkbK7XZrz549wzErADhm0Fi+9tprOn78+CPbrb/bAYDRgNsdnwLt7e2OPv6WLVvMayMi\nIkzrhnKBZyiqqqpM63799VfzPmfOnOnvOAgh3O4IAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMuN0RfomNjTWvTU5ODvjjnz17NuD7lNT30YODee6554Ly+AhdnFkCgAGxBAAD\nYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAF38MAv48ePN6999tlnA/74FRUVAd+nJM2aNcu0\nLiEhISiPj9DFmSUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDgdkf45bff\nfjOv/eabb8xrs7OzTet6e3vN+/T5fH59D3gYZ5YAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQ\nSwAwIJYAYEAsAcCA2x0RdO+++25Q1lq5XC6/vgc8zBTLoqIi1dfX68GDB8rNzdX06dO1detW9fT0\nKC4uTvv371dERESwZwUAxwway9raWl2+fFkej0cdHR3KzMxUSkqKsrKylJGRoYMHD6qyslJZWVnD\nMS8AOGLQ31kmJyfr0KFDkqTx48fL6/Wqrq5OixYtkiSlpqaqpqYmuFMCgMMGjWVYWJjcbrckqbKy\nUgsWLJDX6+172R0bG6vW1tbgTgkADjNf4Dl16pQqKyt17NgxLV68uG87nwcY+nbs2BGQtUP5DMmR\nYjQ+JwSHKZZnzpxRSUmJPvvsM40bN05ut1v37t3T2LFj1dLSovj4+GDPiSfw8ccfP/Ha3t5ejRkz\nut5pNtBzWr16tenffv7558EYCSFs0J/+u3fvqqioSKWlpYqOjpYkzZ07V1VVVZKk6upqzZ8/P7hT\nAoDDBj2zPHnypDo6OrRp06a+bXv37tW2bdvk8Xg0efJkLVu2LKhDAoDTBo3lihUrtGLFike28zIE\nwNOEO3ieAnl5eea1Fy5ceOz3lixZ0vffZ8+eNe/zzp075rVAqBpdv7EHgCAhlgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg4PLxgZTww3fffWde+/Btkk543I+4z+d75A+W1dbWmvY5\ne/bsJ54LIwtnlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIDbHQHAgDNL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAg3DLoqKiItXX1+vBgwfKzc3V6dOn1djYqOjoaEnS\n2rVrtXDhwmDOCQCOGjSWtbW1unz5sjwejzo6OpSZmak5c+YoLy9PqampwzEjADhu0FgmJydrxowZ\nkqTx48fL6/Wqp6cn6IMBQChx+Xw+n3Wxx+PRxYsXFRYWptbWVnV3dys2Nlbbt29XTExMMOcEAEeZ\nY3nq1CmVlpbq2LFjamhoUHR0tJKSklRWVqabN2+qsLAw2LMCgGNMV8PPnDmjkpISlZeXa9y4cUpJ\nSVFSUpIkKS0tTU1NTUEdEgCcNmgs7969q6KiIpWWlvZd/d6wYYOam5slSXV1dUpMTAzulADgsEEv\n8Jw8eVIdHR3atGlT37bly5dr06ZNioyMlNvt1p49e4I6JAA4bUgXeADgacUdPABgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbhTjzo7t27\ndenSJblcLhUUFGjGjBlOjBFQdXV12rhxoxITEyVJ06ZN0/bt2x2eyn9NTU16//33tWbNGmVnZ+vG\njRvaunWrenp6FBcXp/379ysiIsLpMYfk388pPz9fjY2Nio6OliStXbtWCxcudHbIISoqKlJ9fb0e\nPHig3NxcTZ8+fcQfJ+nR53X69GnHj9Wwx/L8+fO6du2aPB6Prl69qoKCAnk8nuEeIyhmzZql4uJi\np8d4Yl1dXdq1a5dSUlL6thUXFysrK0sZGRk6ePCgKisrlZWV5eCUQzPQc5KkvLw8paamOjTVk6mt\nrdXly5fl8XjU0dGhzMxMpaSkjOjjJA38vObMmeP4sRr2l+E1NTVKT0+XJE2dOlW3b99WZ2fncI+B\n/xAREaHy8nLFx8f3baurq9OiRYskSampqaqpqXFqPL8M9JxGuuTkZB06dEiSNH78eHm93hF/nKSB\nn1dPT4/DUzkQy7a2Nk2YMKHv65iYGLW2tg73GEFx5coVrVu3TitXrtS5c+ecHsdv4eHhGjt2bL9t\nXq+37+VcbGzsiDtmAz0nSaqoqFBOTo4+/PBD/fHHHw5M5r+wsDC53W5JUmVlpRYsWDDij5M08PMK\nCwtz/Fg58jvLh/l8PqdHCIgXX3xR69evV0ZGhpqbm5WTk6Pq6uoR+fuiwYyWY/bmm28qOjpaSUlJ\nKisr05EjR1RYWOj0WEN26tQpVVZW6tixY1q8eHHf9pF+nB5+Xg0NDY4fq2E/s4yPj1dbW1vf17du\n3VJcXNxwjxFwCQkJWrJkiVwul6ZMmaKJEyeqpaXF6bECxu126969e5KklpaWUfFyNiUlRUlJSZKk\ntLQ0NTU1OTzR0J05c0YlJSUqLy/XuHHjRs1x+vfzCoVjNeyxnDdvnqqqqiRJjY2Nio+PV1RU1HCP\nEXAnTpzQ0aNHJUmtra1qb29XQkKCw1MFzty5c/uOW3V1tebPn+/wRE9uw4YNam5ulvTP72T//50M\nI8Xdu3dVVFSk0tLSvqvEo+E4DfS8QuFYuXwOnKsfOHBAFy9elMvl0o4dO/Tyyy8P9wgB19nZqc2b\nN+vOnTvq7u7W+vXr9frrrzs9ll8aGhq0b98+Xb9+XeHh4UpISNCBAweUn5+v+/fva/LkydqzZ4+e\neeYZp0c1G+g5ZWdnq6ysTJGRkXK73dqzZ49iY2OdHtXM4/Ho8OHDeumll/q27d27V9u2bRuxx0ka\n+HktX75cFRUVjh4rR2IJACMNd/AAgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA4P8ALqDX\nN3rmU3AAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 1\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEqVJREFUeJzt3W9Ilff/x/HX+eWkpMQ0dQRrZdgm\nq24Miiz6Y0nrFKPVjZqiMgiW/SMX0ZxlDYJMiyALZrnqRlKc4a1u5B9cjIWZUbDA7ljWQqJMm1iR\nbSbne2P8/H7NY77P8Ryvoz0f97y8us777BpPrnMuP+e4vF6vVwCAd/o/pwcAgNGAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAwIJYAYBAR6D88dOiQbt++LZfLpYKCAs2dOzeYcwFAWAkoljdu3NDDhw/l\n8XjU0tKigoICeTyeYM8GAGEjoJfhDQ0NSk9PlyTNnDlTXV1devnyZVAHA4BwElAsOzo6NHny5L6f\nY2Nj1d7eHrShACDcBOUGD5/FAWCsCyiWCQkJ6ujo6Pv56dOnio+PD9pQABBuAorlokWLVFNTI0m6\nc+eOEhISNHHixKAOBgDhJKC74Z9//rk+++wzff3113K5XDpw4ECw5wKAsOLiw38BYGis4AEAA2IJ\nAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBY\nAoABsQQAA2IJAAbEEgAMiCUAGBBLADAI6KtwgVC5ePGiab+9e/eaj/ngwQOf271er1wul/k4gWpp\naTHvm5SUFMJJMBxcWQKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAAOWOyIg\n9+/fD8lxMzMzTfutWrXKfMzBljv6MmPGjKAf88mTJ+Z9We4YvriyBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGBALAHAgFgCgAGxBAADVvAgIOnp6eZ9/VntYrV06VLzvh6PZ9DfdXV19fs5OjradMwtW7aY\nH3/27NnmfRG+uLIEAIOAriwbGxu1c+dOJScnS5JmzZqlwsLCoA4GAOEk4Jfh8+fPV2lpaTBnAYCw\nxctwADAIOJb37t1Tbm6uMjIyVF9fH8yZACDsuLxer9fff9TW1qZbt27J7XartbVVOTk5qq2tVWRk\nZChmBADHBfSeZWJiolavXi1JmjZtmqZMmaK2tjZ99NFHQR0O4cufD6kNxZ8OFRUVmffdunWrz+3R\n0dF6/vz5gG0W/vzpUHFxsXlf6+Nj5AX0MvzSpUs6c+aMJKm9vV3Pnj1TYmJiUAcDgHAS0JXl8uXL\ntXv3bv3666/q6enRjz/+yEtwAGNaQLGcOHGiysrKgj0LAIStgG7wYHR5+325d9m4caPP7VVVVXK7\n3X0/V1dXD3suX6zvRebn54fk8YHB8HeWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgOWO7wF/Pk5ssDX/Xq9XLpcroMf35+PUWMaIcMWVJQAYEEsAMCCWAGBALAHAgFgCgAGx\nBAADYgkABsQSAAyIJQAYsIJnlLp27Zp530WLFg378d5ewXPhwgXzv83IyBj24wNO48oSAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYRDg9APp7/vy5ab9gLGH0JTc31/Q7ljDi\nfcOVJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMODbHcOM2+027VddXW0+\n5qpVq8z7ejwen9ujo6P7LcWMjo42HxMYC0xXls3NzUpPT1dFRYUk6fHjx8rOzlZmZqZ27typf/75\nJ6RDAoDThozlq1evdPDgQaWmpvZtKy0tVWZmpi5cuKCPP/5YlZWVIR0SAJw2ZCwjIyNVXl6uhISE\nvm2NjY1asWKFJCktLU0NDQ2hmxAAwsCQH9EWERGhiIj+u3V3dysyMlKSFBcXp/b29tBMBwBhYtif\nZ8n9oeCqqqpyeoRBcVMH77OAYhkVFaXXr19r/Pjxamtr6/cSHcPD3XAgPAX0d5YLFy5UTU2NJKm2\ntlaLFy8O6lAAEG6GvLJsampScXGxHj16pIiICNXU1Ojo0aPKz8+Xx+PR1KlT9dVXX43ErADgmCFj\nOXv2bJ0/f37A9nPnzoVkIAAIR6zgGQH379837ztz5sygP35LS4t536SkpKA/PjAWsDYcAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYDPvzLDG0I0eOBP2Yubm55n1ZwggMH1eW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgOWOI6Cmpibox8zOzg76Mceq\nwb5dMykpacDvrEtT//zzT/PjT58+3byvP/+vfPLJJwO2VVVVye1299uWk5NjPuaaNWvM+0ZHR5v3\nHQu4sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA5fX6/U6PcRY588Xhj148MC0\nX0tLS0ge32kXL1407bd3717zMQf7b+r1euVyuczHGQ2G+5xWrVpl3tfj8Zj2GysrfbiyBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABix3HAFbtmwx71tWVmbabzSdtlAs9wyG\n4SwN9GdZYHV1dUCPEYiRXMJpXXI7mpbbvgtXlgBgYIplc3Oz0tPTVVFRIUnKz8/Xl19+qezsbGVn\nZ+u3334L5YwA4Lghvzf81atXOnjwoFJTU/tt37Vrl9LS0kI2GACEkyGvLCMjI1VeXq6EhISRmAcA\nwpL5Bs+JEyc0efJkZWVlKT8/X+3t7erp6VFcXJwKCwsVGxsb6lkBwDFDvgz3Ze3atYqJiVFKSopO\nnz6tkydPav/+/cGebczgbjh3w0cKd8NDJ6C74ampqUpJSZEkLV++XM3NzUEdCgDCTUCx3LFjh1pb\nWyVJjY2NSk5ODupQABBuhnwZ3tTUpOLiYj169EgRERGqqalRVlaW8vLyNGHCBEVFRamoqGgkZgUA\nxwwZy9mzZ+v8+fMDtn/xxRchGQgAwlFAN3gAt9tt3jcUN238eTWzYcOGQX/39k2KKVOmBDzTYEL1\n7YbPnz/3ub2rq6vfz99//735mNYbjJK0bds2035VVVXmY4YzljsCgAGxBAADYgkABsQSAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADljuOUteuXTPvu3DhwmEfd+HChf1+F6rPaKyvrzft589zepfR\n/FmLgy2jfHv7Tz/9ZD6mP8sd3zdcWQKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKA\nASt4RkBxcbF535qaGtN+WVlZ5mP+8ccf5n19fZOn9O+KmcF+NxR/vlwsWCtz8F/+rPbyR2FhYUiO\nG664sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAYur9frdXoI/Jd1adqi\nRYtCPEl/Xq9XLpcroH/b1dVl3newL+HCQBcvXhywLSMjY8D2zMxM8zEvXLhg3nfNmjWm/cbKOeXK\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGLDccZTy5xv7grE0cjjLHf35\ndseHDx+a9svOzjYf88MPP/S5PSkpSffv3++37ZdffjEdc8mSJebH98fBgwfN+1ZXVw/YNpzzJEn1\n9fXmfd+3b+I0fRVuSUmJbt26pTdv3mjz5s2aM2eO9uzZo97eXsXHx+vIkSOKjIwM9awA4JghY3n9\n+nXdvXtXHo9HnZ2dWrdunVJTU5WZmSm3261jx46psrLSr8X6ADDaDPme5bx583T8+HFJ/356SHd3\ntxobG7VixQpJUlpamhoaGkI7JQA4bMhYjhs3TlFRUZKkyspKLVmyRN3d3X0vu+Pi4tTe3h7aKQHA\nYab3LCWprq5OlZWVOnv2rFauXNm3nftDzvDnzfVgnaOxeK6TkpL6/Zyfn+/QJP+qqqoa9jHG4nkK\nB6ZYXr16VWVlZfr55581adIkRUVF6fXr1xo/frza2tqUkJAQ6jnxFu6Gczecu+Eja8iX4S9evFBJ\nSYlOnTqlmJgYSf/+R6qpqZEk1dbWavHixaGdEgAcNuSV5eXLl9XZ2am8vLy+bYcPH9a+ffvk8Xg0\ndepUffXVVyEdEgCcNmQsN27cqI0bNw7Yfu7cuZAMBADhiBU874G335d7l23btvncXlVVJbfb3fez\nr/fLRpvhvr/ntBkzZgzYdv/+/QE3rerq6szHnDJlinnfsfJFZFasDQcAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAYsd0RA/PmIuPPnz5v3tX702u+//24+5g8//OBzu6/ljr6W\nEPry7bffmh9/w4YN5n398fayRoQWV5YAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYA\nYEAsAcCA5Y4AYMCVJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBA\nLAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgEGEZaeSkhLdunVLb9680ebN\nm3XlyhXduXNHMTExkqRNmzZp2bJloZwTABw1ZCyvX7+uu3fvyuPxqLOzU+vWrdOCBQu0a9cupaWl\njcSMAOC4IWM5b948zZ07V5IUHR2t7u5u9fb2hnwwAAgnLq/X67Xu7PF4dPPmTY0bN07t7e3q6elR\nXFycCgsLFRsbG8o5AcBR5ljW1dXp1KlTOnv2rJqamhQTE6OUlBSdPn1aT5480f79+0M9KwA4xnQ3\n/OrVqyorK1N5ebkmTZqk1NRUpaSkSJKWL1+u5ubmkA4JAE4bMpYvXrxQSUmJTp061Xf3e8eOHWpt\nbZUkNTY2Kjk5ObRTAoDDhrzBc/nyZXV2diovL69v2/r165WXl6cJEyYoKipKRUVFIR0SAJzm1w0e\nAHhfsYIHAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGBALAHAIMKJBz106JBu374tl8ulgoICzZ0714kxgqqxsVE7d+5UcnKyJGnWrFkqLCx0eKrA\nNTc3a+vWrfrmm2+UlZWlx48fa8+ePert7VV8fLyOHDmiyMhIp8f0y9vPKT8/X3fu3FFMTIwkadOm\nTVq2bJmzQ/qppKREt27d0ps3b7R582bNmTNn1J8naeDzunLliuPnasRjeePGDT18+FAej0ctLS0q\nKCiQx+MZ6TFCYv78+SotLXV6jGF79eqVDh48qNTU1L5tpaWlyszMlNvt1rFjx1RZWanMzEwHp/SP\nr+ckSbt27VJaWppDUw3P9evXdffuXXk8HnV2dmrdunVKTU0d1edJ8v28FixY4Pi5GvGX4Q0NDUpP\nT5ckzZw5U11dXXr58uVIj4F3iIyMVHl5uRISEvq2NTY2asWKFZKktLQ0NTQ0ODVeQHw9p9Fu3rx5\nOn78uCQpOjpa3d3do/48Sb6fV29vr8NTORDLjo4OTZ48ue/n2NhYtbe3j/QYIXHv3j3l5uYqIyND\n9fX1To8TsIiICI0fP77ftu7u7r6Xc3FxcaPunPl6TpJUUVGhnJwcfffdd/rrr78cmCxw48aNU1RU\nlCSpsrJSS5YsGfXnSfL9vMaNG+f4uXLkPcv/5fV6nR4hKKZPn67t27fL7XartbVVOTk5qq2tHZXv\nFw1lrJyztWvXKiYmRikpKTp9+rROnjyp/fv3Oz2W3+rq6lRZWamzZ89q5cqVfdtH+3n63+fV1NTk\n+Lka8SvLhIQEdXR09P389OlTxcfHj/QYQZeYmKjVq1fL5XJp2rRpmjJlitra2pweK2iioqL0+vVr\nSVJbW9uYeDmbmpqqlJQUSdLy5cvV3Nzs8ET+u3r1qsrKylReXq5JkyaNmfP09vMKh3M14rFctGiR\nampqJEl37txRQkKCJk6cONJjBN2lS5d05swZSVJ7e7uePXumxMREh6cKnoULF/adt9raWi1evNjh\niYZvx44dam1tlfTve7L//5cMo8WLFy9UUlKiU6dO9d0lHgvnydfzCodz5fI6cK1+9OhR3bx5Uy6X\nSwcOHNCnn3460iME3cuXL7V79249f/5cPT092r59u5YuXer0WAFpampScXGxHj16pIiICCUmJuro\n0aPKz8/X33//ralTp6qoqEgffPCB06Oa+XpOWVlZOn36tCZMmKCoqCgVFRUpLi7O6VHNPB6PTpw4\noRkzZvRtO3z4sPbt2zdqz5Pk+3mtX79eFRUVjp4rR2IJAKMNK3gAwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBg8B9OkjtgR8VvdgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 6\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "4SJizeJtNaAs", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Profiling\n", + "\n", + "If you want to drill down into the performance characteristics of your code, you can use native Python profilers like [`cProfile`](https://docs.python.org/3/library/profile.html). In the next exercise, you'll do just that." + ] + }, + { + "metadata": { + "id": "_2v0QnG8__PJ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Exercise!\n", + "\n", + "This exercise does not require coding. If you have not completed the training exercise, replace `train_one_epoch` below with `_train_one_epoch`.\n", + "\n", + "Run the below cell and inspect the printed profiles. What parts of the code appear to be hotspots or\n", + "bottlenecks? How does sorting the profile by total time compare to sorting it\n", + "by cumulative time?\n", + "\n" + ] + }, + { + "metadata": { + "id": "IFypaYbG_9fB", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 714 + }, + "outputId": "d9c3596b-a165-4edd-fc6b-53ccd0d01d19" + }, + "cell_type": "code", + "source": [ + "import cProfile\n", + "import pstats\n", + "\n", + "cProfile.run(\"train_one_epoch(model, training_data, optimizer)\", \"training_profile\")\n", + "\n", + "stats = pstats.Stats(\"training_profile\").strip_dirs().sort_stats(\"tottime\")\n", + "stats.print_stats(10)\n", + "\n", + "stats.sort_stats(\"cumtime\").print_stats(10)" + ], + "execution_count": 17, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Thu Jun 7 12:25:04 2018 training_profile\n", + "\n", + " 92209 function calls (91817 primitive calls) in 3.446 seconds\n", + "\n", + " Ordered by: internal time\n", + " List reduced from 672 to 10 due to restriction <10>\n", + "\n", + " ncalls tottime percall cumtime percall filename:lineno(function)\n", + " 1080 2.552 0.002 2.552 0.002 {built-in method _pywrap_tensorflow_internal.TFE_Py_FastPathExecute}\n", + " 83 0.753 0.009 0.753 0.009 {built-in method _pywrap_tensorflow_internal.TFE_Py_Execute}\n", + " 16 0.006 0.000 1.019 0.064 network.py:736(_run_internal_graph)\n", + " 16 0.005 0.000 2.253 0.141 {built-in method _pywrap_tensorflow_internal.TFE_Py_TapeGradient}\n", + " 2321 0.004 0.000 0.007 0.000 abc.py:178(__instancecheck__)\n", + " 288 0.004 0.000 0.009 0.000 inspect.py:2092(_signature_from_function)\n", + " 878 0.004 0.000 0.005 0.000 ops.py:5936(__enter__)\n", + " 288 0.004 0.000 0.016 0.000 inspect.py:1079(getfullargspec)\n", + " 11006 0.003 0.000 0.005 0.000 {built-in method builtins.isinstance}\n", + " 768 0.003 0.000 0.008 0.000 {built-in method _pywrap_tensorflow_internal.Flatten}\n", + "\n", + "\n", + "Thu Jun 7 12:25:04 2018 training_profile\n", + "\n", + " 92209 function calls (91817 primitive calls) in 3.446 seconds\n", + "\n", + " Ordered by: cumulative time\n", + " List reduced from 672 to 10 due to restriction <10>\n", + "\n", + " ncalls tottime percall cumtime percall filename:lineno(function)\n", + " 1 0.000 0.000 3.446 3.446 {built-in method builtins.exec}\n", + " 1 0.000 0.000 3.446 3.446 :1()\n", + " 1 0.001 0.001 3.446 3.446 :9(train_one_epoch)\n", + " 1080 2.552 0.002 2.552 0.002 {built-in method _pywrap_tensorflow_internal.TFE_Py_FastPathExecute}\n", + " 16 0.000 0.000 2.255 0.141 backprop.py:739(gradient)\n", + " 16 0.000 0.000 2.253 0.141 imperative_grad.py:31(imperative_grad)\n", + " 16 0.005 0.000 2.253 0.141 {built-in method _pywrap_tensorflow_internal.TFE_Py_TapeGradient}\n", + " 400 0.002 0.000 2.246 0.006 backprop.py:145(grad_fn)\n", + " 400 0.002 0.000 2.239 0.006 backprop.py:95(_magic_gradient_function)\n", + " 32 0.001 0.000 1.601 0.050 nn_grad.py:497(_Conv2DGrad)\n", + "\n", + "\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 17 + } + ] + }, + { + "metadata": { + "id": "8ixpnyCNNTI4", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb b/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..64d19ec5c9bfccd07eabb21ce8fbb62b21f23efa --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb @@ -0,0 +1,443 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Debugging \"graph-first\" models with eager execution", + "version": "0.3.2", + "provenance": [], + "include_colab_link": true + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "[View in Colaboratory](https://colab.research.google.com/gist/alextp/9568ab40f6ed6f9a3ba4736f6aef6127/debugging-graph-first-models-with-eager-execution.ipynb)" + ] + }, + { + "metadata": { + "id": "mm-t0GuIu1Dt", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This colab uses eager execution and the Python debugger to modify the execution of a translation model. This combination lets you quickly explore counterfactuals when researching and designing modifications to a model.\n", + "\n", + "The model, Transformer from [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), was originally written with graph building in mind. Executing it eagerly can still be helpful!" + ] + }, + { + "metadata": { + "id": "gxb1DvIDg4sv", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "#@title License (double click to show)\n", + "# Copyright 2018 The TensorFlow Authors.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Gx3HA9N1ui64", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 37 + }, + "outputId": "f6986f34-f3e1-44e1-c902-2eb33081acad" + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "import pdb\n", + "tfe = tf.contrib.eager\n", + "\n", + "tf.enable_eager_execution()" + ], + "execution_count": 1, + "outputs": [] + }, + { + "metadata": { + "id": "3LkOm2ct-Lmc", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 37 + }, + "outputId": "2edc74d9-6bc0-4e78-ab4e-83bf96099ef4" + }, + "cell_type": "code", + "source": [ + "!pip install -q -U tensor2tensor\n", + "from tensor2tensor.models import transformer" + ], + "execution_count": 2, + "outputs": [] + }, + { + "metadata": { + "id": "1Z3oMsqV0zB6", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170 + }, + "outputId": "0a8186ee-c688-457f-c9f6-9a6c1477a93b" + }, + "cell_type": "code", + "source": [ + "#@title Create a tensor2tensor translation model, fetch a checkpoint (double click to show)\n", + "from tensor2tensor import problems\n", + "from tensor2tensor.utils import trainer_lib\n", + "from tensor2tensor.utils import registry\n", + "\n", + "import numpy as np\n", + "import os\n", + "\n", + "# Setup some directories\n", + "data_dir = os.path.expanduser(\"~/t2t/data\")\n", + "tmp_dir = os.path.expanduser(\"~/t2t/tmp\")\n", + "train_dir = os.path.expanduser(\"~/t2t/train\")\n", + "checkpoint_dir = os.path.expanduser(\"~/t2t/checkpoints\")\n", + "tf.gfile.MakeDirs(data_dir)\n", + "tf.gfile.MakeDirs(tmp_dir)\n", + "tf.gfile.MakeDirs(train_dir)\n", + "tf.gfile.MakeDirs(checkpoint_dir)\n", + "gs_data_dir = \"gs://tensor2tensor-data\"\n", + "gs_ckpt_dir = \"gs://tensor2tensor-checkpoints/\"\n", + "\n", + "# Fetch the problem\n", + "ende_problem = problems.problem(\"translate_ende_wmt32k\")\n", + "\n", + "# Copy the vocab file locally so we can encode inputs and decode model outputs\n", + "# All vocabs are stored on GCS\n", + "vocab_name = \"vocab.ende.32768\"\n", + "vocab_file = os.path.join(gs_data_dir, vocab_name)\n", + "!gsutil cp {vocab_file} {data_dir}\n", + "\n", + "# Get the encoders from the problem\n", + "encoders = ende_problem.feature_encoders(data_dir)\n", + "\n", + "# Setup helper functions for encoding and decoding\n", + "def encode(input_str, output_str=None):\n", + " \"\"\"Input str to features dict, ready for inference\"\"\"\n", + " inputs = encoders[\"inputs\"].encode(input_str) + [1] # add EOS id\n", + " batch_inputs = tf.reshape(inputs, [1, -1, 1]) # Make it 3D.\n", + " return {\"inputs\": batch_inputs}\n", + "\n", + "def decode(integers):\n", + " \"\"\"List of ints to str\"\"\"\n", + " integers = list(np.squeeze(integers))\n", + " if 1 in integers:\n", + " integers = integers[:integers.index(1)]\n", + " return encoders[\"inputs\"].decode(np.squeeze(integers))\n", + "\n", + "# Copy the pretrained checkpoint locally\n", + "ckpt_name = \"transformer_ende_test\"\n", + "gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)\n", + "!gsutil -q cp -R {gs_ckpt} {checkpoint_dir}\n", + "checkpoint_path = tf.train.latest_checkpoint(\n", + " os.path.join(checkpoint_dir, ckpt_name))\n", + "\n", + "# Create hparams and the model\n", + "model_name = \"transformer\"\n", + "hparams_set = \"transformer_base\"\n", + "\n", + "hparams = trainer_lib.create_hparams(hparams_set, data_dir=data_dir, problem_name=\"translate_ende_wmt32k\")\n", + "\n", + "# NOTE: Only create the model once when restoring from a checkpoint; it's a\n", + "# Layer and so subsequent instantiations will have different variable scopes\n", + "# that will not match the checkpoint.\n", + "translate_model = registry.model(model_name)(hparams, tf.estimator.ModeKeys.EVAL)" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Copying gs://tensor2tensor-data/vocab.ende.32768...\n", + "/ [1 files][316.4 KiB/316.4 KiB] \n", + "Operation completed over 1 objects/316.4 KiB. \n", + "INFO:tensorflow:Setting T2TModel mode to 'eval'\n", + "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "4IblPXLGjuCl", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We've created a Transformer model and fetched an existing training checkpoint. It hasn't created variables yet, and we want to load them from the checkpoint before they're used (restore-on-create) so the first run of the model outputs the correct value. The `tfe.restore_variables_on_create` API looks up variables by name on creation and restores their values." + ] + }, + { + "metadata": { + "id": "o3MWxcAqJoqG", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "outputId": "fbc1b1bf-ffbe-4621-b3cb-5eb855fec3a8" + }, + "cell_type": "code", + "source": [ + "with tfe.restore_variables_on_create(checkpoint_path):\n", + " model_output = translate_model.infer(encode(\"Eager execution\"))\n", + "print(decode(model_output[\"outputs\"]))" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "INFO:tensorflow:Greedy Decoding\n", + "Hinrichtung\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "xk5HV9Hhu9zO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Using global variable names can get somewhat fragile, so for new code we recommend the object-based `tf.keras.Model.save_weights` or `tf.train.Checkpoint`. However, these require some small code changes to work with existing graph building code.\n", + "\n", + "The Transformer model translates \"Eager execution\" in English to \"Hinrichtung\" in German, which refers to capital punishment rather than getting things done. Transformer first encodes the English, then decodes to German. We'll add a debugging hook at the start of the decode phase (once the encodings have been finalized) and see if we can correct the translation." + ] + }, + { + "metadata": { + "id": "GUGwbYvXZ9-7", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "previous_fast_decode = transformer.fast_decode\n", + "def debug_fn(*args, **kwargs):\n", + " pdb.set_trace()\n", + " return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n", + "transformer.fast_decode = debug_fn # Add our debugging hook to Transformer" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "f61HlvECxJn0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Now that we've \"monkey patched\" the model, we'll drop into a debugger just before decoding starts. In most cases it'd be simpler to add the `pdb.set_trace()` call to the code directly, but in this case we're working with prepackaged library code.\n", + "\n", + "First, let's find an encoding which represents the correct sense of \"execution\". Then we'll patch part of that encoding into the encoding of \"Eager execution\" to fix the translation. Feel free to poke around with the debugger (e.g. print a Tensor's value), but your main task is to save the encodings by assigning them to an attribute of the function:\n", + "\n", + "```\n", + "(running the next cell drops you into a pdb shell)\n", + "step\n", + "fast_decode.previous_encoding = encoder_output\n", + "continue\n", + "\n", + "```\n", + "\n", + "You can type `next` (or `n`) a few times before `continue` to watch the decoding ops run." + ] + }, + { + "metadata": { + "id": "dX4CPOGSpZrb", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 179 + }, + "outputId": "6de38c31-836f-40ef-b701-e42908172619" + }, + "cell_type": "code", + "source": [ + "model_output = translate_model.infer(encode(\"Immediate running\"))\n", + "print(decode(model_output[\"outputs\"]))" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "text": [ + "> (4)debug_fn()\n", + "-> return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n", + "(Pdb) step\n", + "--Call--\n", + "> /usr/local/lib/python2.7/dist-packages/tensor2tensor/models/transformer.py(427)fast_decode()\n", + "-> def fast_decode(encoder_output,\n", + "(Pdb) fast_decode.previous_encoding = encoder_output\n", + "(Pdb) continue\n", + "Sofortige Durchführung\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "-ZEZciV4FpLo", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Now we have an encoding saved which gets the correct sense for \"execution\"." + ] + }, + { + "metadata": { + "id": "QeC_oDVqHD_v", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 179 + }, + "outputId": "253c9af1-003e-46bd-8bf5-db968cf6a8cf" + }, + "cell_type": "code", + "source": [ + "# Assumes you followed the pdb instructions above!\n", + "transformer.fast_decode.previous_encoding" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + }, + { + "metadata": { + "id": "bC9JjeDcHEav", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Let's replace part of the encoding for \"Eager execution\" with the encoding of \"Immediate running\".\n", + "\n", + "Again we'll drop into a pdb shell. This time we'll run some TensorFlow operations to patch the encodings while the model is running.\n", + "\n", + "```\n", + "(running the next cell again drops you into a pdb shell)\n", + "step\n", + "encoder_output = tf.concat([fast_decode.previous_encoding[:, :3], encoder_output[:, 3:]], axis=1)\n", + "continue\n", + "```" + ] + }, + { + "metadata": { + "id": "t2as_Kn1h65G", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 179 + }, + "outputId": "5b4e546e-3bb4-4761-c545-467b631e3ffe" + }, + "cell_type": "code", + "source": [ + "model_output = translate_model.infer(encode(\"Eager execution\"))\n", + "print(decode(model_output[\"outputs\"]))" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "text": [ + "> (4)debug_fn()\n", + "-> return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n", + "(Pdb) step\n", + "--Call--\n", + "> /usr/local/lib/python2.7/dist-packages/tensor2tensor/models/transformer.py(427)fast_decode()\n", + "-> def fast_decode(encoder_output,\n", + "(Pdb) encoder_output = tf.concat([fast_decode.previous_encoding[:, :3], encoder_output[:, 3:]], axis=1)\n", + "(Pdb) continue\n", + "sofortige Ausführung\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "rK6tYZ23I2cm", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We get a different decoding, with the correct sense of \"execution\". Likely we're keeping just the encoding of \"tion\" from \"Eager execution\", so no great breakthrough in translation modeling.\n", + "\n", + "Similarly it's possible to modify attention vectors, or change words during decoding to help debug a beam search." + ] + }, + { + "metadata": { + "id": "Nb-4ipYNRWxA", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This colab was adapted from the [Tensor2Tensor colab](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb). Credit to Ankur Taly for its concept." + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index 2d2aba6908b168e0bf63f4706b6344cbb4ca82bd..23f33d0230b0b9fa906636a9df4e046c6873d90b 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -4,8 +4,8 @@ Eager execution is a feature that makes TensorFlow execute operations immediately: concrete values are returned, instead of creating a computational graph that is executed later. -A user guide is available: https://www.tensorflow.org/programmers_guide/eager -([source file](../../../../docs_src/programmers_guide/eager.md)) +A user guide is available: https://www.tensorflow.org/guide/eager +([source file](../../../../docs_src/guide/eager.md)) We welcome feedback through [GitHub issues](https://github.com/tensorflow/tensorflow/labels/comp:eager). diff --git a/tensorflow/contrib/eager/python/metrics.py b/tensorflow/contrib/eager/python/metrics.py index 3e3100427376ddd480b50d967cf53e7831aaefb2..04b7b1165e19612be2fa878f83effbe814fc5c46 100644 --- a/tensorflow/contrib/eager/python/metrics.py +++ b/tensorflow/contrib/eager/python/metrics.py @@ -22,5 +22,6 @@ from __future__ import print_function from tensorflow.contrib.eager.python.metrics_impl import * from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['Accuracy', 'Mean', 'Metric'] +_allowed_symbols = ['Accuracy', 'Mean', 'Metric', 'CategoricalAccuracy', + 'BinaryAccuracy', 'SparseAccuracy'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 1ae6415d5ecb03ef97cdf734c808e3f728dafcb0..efa6ba062631500bd7cd16620ebec23d15b93b62 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -25,6 +25,7 @@ from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -344,9 +345,14 @@ class Mean(Metric): class Accuracy(Mean): - """Calculates how often `predictions` matches `labels`.""" + """Calculates how often `predictions` matches `labels`. + Attributes: + name: name of the accuracy object + dtype: data type of the tensor + """ def __init__(self, name=None, dtype=dtypes.float64): + """Inits Accuracy class with name and dtype.""" super(Accuracy, self).__init__(name=name, dtype=dtype) def call(self, labels, predictions, weights=None): @@ -367,9 +373,155 @@ class Accuracy(Mean): Returns: The arguments, for easy chaining. """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") matches = math_ops.equal(labels, predictions) matches = math_ops.cast(matches, dtypes.float64) super(Accuracy, self).call(matches, weights=weights) if weights is None: return labels, predictions return labels, predictions, weights + + +class CategoricalAccuracy(Mean): + """Calculates how often `predictions` matches `labels`. + + This class is compatible with `tf.keras.losses.categorical_crossentropy`, + `tf.nn.softmax_cross_entropy_with_logits_v2`, + `tf.losses.softmax_cross_entropy`. + + Attributes: + name: name of the accuracy object. + dtype: data type of tensor. + """ + + def __init__(self, name=None, dtype=dtypes.float64): + """Inits CategoricalAccuracy with name and dtype.""" + super(CategoricalAccuracy, self).__init__(name=name, dtype=dtype) + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + `labels` and `predictions` should have the same shape. + As argmax is being done here, labels and predictions type + can be different. + + Args: + labels: One-hot Tensor. + predictions: Tensor with the logits or probabilities for each example. + weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. + """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") + labels = math_ops.argmax(labels, axis=-1) + predictions = math_ops.argmax(predictions, axis=-1) + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, dtypes.float64) + super(CategoricalAccuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights + + +class BinaryAccuracy(Mean): + """Calculates how often `predictions` matches `labels`. + + This class is compatible with `tf.keras.losses.binary_crossentropy`, + `tf.losses.sigmoid_cross_entropy`, + `tf.nn.sigmoid_cross_entropy_with_logits`. + If there is more than one label, this will become multi-label classification. + + Attributes: + name: name of the accuracy object. + threshold: Used for rounding off the predictions. + If the predictions are, + 1. probabilities then set the threshold to 0.5. + 2. logits then set the threshold to 0. + You can set the threshold appropriately, + to trade off with precision and recall. + dtype: data type of tensor. + """ + + def __init__(self, threshold, name=None, dtype=dtypes.float64): + """Inits BinaryAccuracy with name, threshold and dtype.""" + + super(BinaryAccuracy, self).__init__(name=name, dtype=dtype) + self.threshold = threshold + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + `labels` and `predictions` should have the same shape and type. + + Args: + labels: Binary Tensor(containing 0 or 1). + predictions: Tensor with probabilities or logits. + weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. + """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") + predictions = ops.convert_to_tensor(predictions) + predictions = predictions > self.threshold + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, dtypes.float64) + super(BinaryAccuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights + + +class SparseAccuracy(Mean): + """Calculates how often `predictions` matches `labels`. + + This class is compatible with + `tf.keras.losses.sparse_categorical_crossentropy`, + `tf.nn.sparse_softmax_cross_entropy_with_logits`, + `tf.losses.sparse_softmax_cross_entropy`. + + Attributes: + name: name of the accuracy object + dtype: data type of tensor. + """ + + def __init__(self, name=None, dtype=dtypes.float64): + """Inits SparseAccuracy with name and dtype.""" + + super(SparseAccuracy, self).__init__(name=name, dtype=dtype) + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + `labels` and `predictions` should have the same shape except the + predictions must have one additional trailing dimension equal to the + number of classes(you want to predict). + + Type of labels and predictions can be different. + + Args: + labels: Tensor of shape (batch_size, ) containing integers + predictions: Tensor with the logits or probabilities for each example. + weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. + """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions)[0], + message="First axis of labels and predictions is unequal") + predictions = math_ops.argmax(predictions, axis=-1) + labels = math_ops.cast(labels, dtypes.int64) + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, dtypes.float64) + super(SparseAccuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index aad672344313f6238ec2646cd0c09760c6f5e3fe..20d938d492bf78fab852c638ba675d7ee6ed9073 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -117,6 +118,44 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testCategoricalAccuracy(self): + m = metrics.CategoricalAccuracy() + m([[1, 0, 0, 0], [0, 1, 0, 0]], + [[0.6, 0.1, 0.25, 0.05], [0.4, 0.05, 0.45, 0.0]]) # 1/2 correct + m([[0, 0, 0, 1]], [[0.25, 0.95, 0.25, 0.0]]) # 0/1 correct + m([[1, 0, 0, 0], [0, 1, 0, 0]], + [[0.99, 0.01, 0.0, 0.0], [0.35, 0.35, 0.3, 0.0]]) # 1/2 correct + self.assertEqual(2.0/5, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testBinaryAccuracy(self): + m = metrics.BinaryAccuracy(threshold=0) + # as threshold is 0 hence the predictions are logits + m([[0, 0, 0, 0]], + [[-4.2, 4.5, 1.2, -1.1]]) # 2/4 correct + m([[0, 1]], [[-5.3, 11.65]]) # 2/2 correct + m([[0, 1], [1, 1]], + [[-5.3, 11.65], [-10.32, 56.38]]) # 3/4 correct + self.assertEqual(7.0/10, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testSparseAccuracy(self): + m = metrics.SparseAccuracy() + m([0, 2], + [[0.6, 0.1, 0.25, 0.05], [0.4, 0.05, 0.45, 0.0]]) # 2/2 correct + m([1], [[0.25, 0.95, 0.25, 0.0]]) # 1/1 correct + m([0, 3], [[0.99, 0.01, 0.0, 0.0], [0.35, 0.35, 0.3, 0.0]]) # 1/2 correct + self.assertEqual(4.0/5, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testAccuracyDifferentShapes(self): + m = metrics.Accuracy() + with self.assertRaises(errors.InvalidArgumentError): + m([[0], [0]], [0, 1]) + def testWeightedAccuracy(self): m = metrics.Accuracy() # 1 correct, total weight of 2 @@ -146,8 +185,6 @@ class MetricsTest(test.TestCase): self.assertAllEqual(2.0, m2.result()) def testNamesWithSpaces(self): - # Verify two metrics with the same class and name don't - # accidentally share state. m1 = metrics.Mean("has space") m1(0) self.assertEqual(m1.name, "has space") @@ -169,7 +206,7 @@ class MetricsTest(test.TestCase): sess.run(accumulate, feed_dict={p: 7}) self.assertAllEqual(m.result().eval(), 7) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphAndEagerTensor(self): m = metrics.Mean() inputs = ops.convert_to_tensor([1.0, 2.0]) @@ -186,8 +223,8 @@ class MetricsTest(test.TestCase): self.assertEqual(self.evaluate(value), 2.5) def testTwoMeansGraph(self): - # Verify two metrics with the same class and name don't - # accidentally share state. + # Verify two metrics with the same name in the same graph raises a + # ValueError. with context.graph_mode(): m1 = metrics.Mean() m1(0) @@ -217,7 +254,7 @@ class MetricsTest(test.TestCase): self.assertAllEqual(m2.result().eval(), 2.0) self.assertAllEqual(m1.result().eval(), 1.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestore(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 9af50ee1464c7c52922ed08ef09b1ef098af7bb4..f801d9a47b2f831a48d9b6335c69612c1356d800 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -24,7 +24,7 @@ import weakref from tensorflow.python.eager import context from tensorflow.python.framework import ops -from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer +from tensorflow.python.keras.engine import base_layer as keras_base_layer from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index c92bd15b253b67a3301cd562046a4467e1bf877d..240f213c602395b8589d39c3ecd90b602ffa9848 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -126,7 +126,7 @@ class NetworkTest(test.TestCase): self.assertAllEqual([[17.0], [34.0]], self.evaluate(result)) # TODO(allenl): This test creates garbage in some Python versions - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNetworkSaveRestoreAlreadyBuilt(self): net = MyNetwork(name="abcd") with self.assertRaisesRegexp( @@ -138,7 +138,7 @@ class NetworkTest(test.TestCase): self._save_modify_load_network_built(net, global_step=10) # TODO(allenl): This test creates garbage in some Python versions - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestoreDefaultGlobalStep(self): net = MyNetwork(name="abcd") net(constant_op.constant([[2.0]])) @@ -149,7 +149,7 @@ class NetworkTest(test.TestCase): self.assertIn("abcd-4242", save_path) # TODO(allenl): This test creates garbage in some Python versions - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNetworkSaveAndRestoreIntoUnbuilt(self): save_dir = self.get_temp_dir() net1 = MyNetwork() @@ -166,7 +166,7 @@ class NetworkTest(test.TestCase): self.assertAllEqual(self.evaluate(net1.variables[0]), self.evaluate(net2.variables[0])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNetworkMatchesLayerVariableNames(self): zero = constant_op.constant([[0.]]) layer_one = core.Dense(1, use_bias=False) @@ -193,7 +193,7 @@ class NetworkTest(test.TestCase): self.assertEqual("two_layer_net/" + layer_two.variables[0].name, net.second.variables[0].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoadIntoUnbuiltSharedLayer(self): class Owner(network.Network): @@ -272,7 +272,7 @@ class NetworkTest(test.TestCase): network.restore_network_checkpoint( load_into, save_path, map_func=_restore_map_func) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRestoreIntoSubNetwork(self): class Parent(network.Network): @@ -327,7 +327,7 @@ class NetworkTest(test.TestCase): # The checkpoint is incompatible. network.restore_network_checkpoint(save_into_parent, checkpoint) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCustomMapCollisionErrors(self): class Parent(network.Network): @@ -372,7 +372,7 @@ class NetworkTest(test.TestCase): network.restore_network_checkpoint( loader, checkpoint, map_func=lambda n: "foo") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDefaultMapCollisionErrors(self): one = constant_op.constant([[1.]]) @@ -571,7 +571,7 @@ class NetworkTest(test.TestCase): expected_start="my_network_1/dense/", actual=outside_net_after.trainable_weights[0].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVariableScopeStripping(self): with variable_scope.variable_scope("scope1"): with variable_scope.variable_scope("scope2"): @@ -596,7 +596,7 @@ class NetworkTest(test.TestCase): self.assertAllEqual([[42.]], self.evaluate(restore_net.variables[0])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerNamesRespected(self): class ParentNetwork(network.Network): @@ -677,7 +677,7 @@ class NetworkTest(test.TestCase): self.assertStartsWith(expected_start="my_network_1/dense/", actual=net2.trainable_weights[0].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableAnonymous(self): # The case where no explicit names are specified. We make up unique names, @@ -721,7 +721,7 @@ class NetworkTest(test.TestCase): self.assertEqual("my_network", net2.first.name) self.assertEqual("my_network_1", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicit(self): # We have explicit network names and everything is globally unique. @@ -750,7 +750,7 @@ class NetworkTest(test.TestCase): self.assertEqual("first_unique_child_name", net.first.name) self.assertEqual("second_unique_child_name", net.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerNetworkNameInteractions(self): # Same base name as core.Dense; Networks and non-Network Layers with the @@ -801,7 +801,7 @@ class NetworkTest(test.TestCase): actual=net.trainable_weights[4].name) self.assertEqual("mixed_layer_network", net.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicitCollisions(self): # We have explicit network names and they are unique within the layer @@ -831,7 +831,7 @@ class NetworkTest(test.TestCase): self.assertEqual("nonunique_name", net.first.name) self.assertEqual("second_unique_child_name", net.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicitWithAnonymousParent(self): # A parent network is instantiated multiple times with explicitly named @@ -873,7 +873,7 @@ class NetworkTest(test.TestCase): self.assertEqual("first_unique_child_name", net2.first.name) self.assertEqual("second_unique_child_name", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicitSameLayerCollisions(self): # We have explicit network names and they are _not_ unique within the layer @@ -891,7 +891,7 @@ class NetworkTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "nonunique_name"): ParentNetwork() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAnonymousVariableSharing(self): # Two "owned" Networks @@ -989,7 +989,7 @@ class NetworkTest(test.TestCase): self.assertEqual("my_network", net4.first.name) self.assertEqual("my_network", net4.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRecursiveLayerRenaming(self): core.Dense(1) # Under default Layer naming, would change subsequent names. @@ -1041,7 +1041,7 @@ class NetworkTest(test.TestCase): self.assertEqual("dense", net.second.first.name) self.assertEqual("dense_1", net.second.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallInDifferentOrderThanConstruct(self): shared_network = MyNetwork() @@ -1091,7 +1091,7 @@ class NetworkTest(test.TestCase): self.assertTrue(net2.first is net1.first) self.assertEqual("my_network", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerCallInDifferentOrderThanConstruct(self): # Same idea as testCallInDifferentOrderThanConstruct, but this time with a # non-Network Layer shared between two Networks rather than a @@ -1144,7 +1144,7 @@ class NetworkTest(test.TestCase): self.assertTrue(net2.first is net1.first) self.assertEqual("dense", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerAlreadyBuilt(self): one = constant_op.constant([[1.]]) core.Dense(1, use_bias=False) # pre-built layers use global naming diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 4032e755f6e7dea9dcb42587f14e8386e5db2338..90a3711475719a7f991473c6c9067da1e76ab9f2 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -60,15 +60,9 @@ class SaverTest(test.TestCase): def testSameNameNoClobbering(self): with ops.device(self._dev()): - # Note that this test purposefully uses Graphs rather than - # IsolateTest. Users are more likely to accidentally create the same - # variable name this way. - first_graph = ops.Graph() - with first_graph.as_default(): - v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1') - with ops.Graph().as_default(): - v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1') - saver = _saver.Saver([v1_first_graph, v1_second_graph]) + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + v2 = resource_variable_ops.ResourceVariable(2.0, name='v1') + saver = _saver.Saver([v1, v2]) ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') with self.assertRaisesRegexp(ValueError, 'v1'): saver.save(ckpt_prefix) @@ -126,12 +120,11 @@ class SaverTest(test.TestCase): saver = _saver.Saver([v1]) saver.save(ckpt_prefix) - with ops.Graph().as_default(): - saver = _saver.Saver([v1]) - with _saver.restore_variables_on_create(ckpt_prefix): - # Value is from checkpoint, but not from argument. - ret, _ = model(2.0) - self.assertEqual(ret.numpy(), 1.0) + saver = _saver.Saver([v1]) + with _saver.restore_variables_on_create(ckpt_prefix): + # Value is from checkpoint, but not from argument. + ret, _ = model(2.0) + self.assertEqual(ret.numpy(), 1.0) def testRestoreNotFound(self): with ops.device(self._dev()): @@ -184,17 +177,17 @@ class SaverTest(test.TestCase): 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) # reset the graph and reload on create, so that 1 + 2 = 3 - with ops.Graph().as_default(): - with _saver.restore_variables_on_create(ckpt_prefix): - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def model2(x): - v = variable_scope.get_variable( - 'v', initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - self.assertEqual( - 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + ops.reset_default_graph() + with _saver.restore_variables_on_create(ckpt_prefix): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model2(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + self.assertEqual( + 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) class GetOptimizerTests(test.TestCase): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 5826700c73e255198e9a6974ca240ba55e438a26..ca6430253b67d825290b6a376ba3f29b3ae67577 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -68,6 +68,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@async_clear_error @@run_test_in_graph_and_eager_modes +@@run_all_tests_in_graph_and_eager_modes @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@ -115,12 +116,13 @@ from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes +from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.tracking import Checkpointable from tensorflow.python.training.checkpointable.util import CheckpointableSaver from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index df08dc2be650376dd4248feece539dddee7e9557..30d297a5fb2dd2f844093d790d051a79105984dd 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -21,6 +21,7 @@ py_library( ":export", ":extenders", ":head", + ":hooks", ":linear", ":logit_fns", ":multi_head", @@ -116,7 +117,7 @@ py_library( py_test( name = "dnn_test", - size = "small", + size = "medium", srcs = ["python/estimator/dnn_test.py"], srcs_version = "PY2AND3", tags = [ @@ -311,6 +312,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", + "//tensorflow/python:variables", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", @@ -321,6 +323,37 @@ py_test( ], ) +py_library( + name = "hooks", + srcs = [ + "python/estimator/hooks.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "hooks_test", + size = "medium", + srcs = ["python/estimator/hooks_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":hooks", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "linear", srcs = ["python/estimator/linear.py"], diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 32a0f2545dd0ea97b8578f558ad8869199ca7d81..788ac5ca7046d6dd30a3d5520b243944532622fa 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -26,6 +26,7 @@ from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * +from tensorflow.contrib.estimator.python.estimator.hooks import * from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * @@ -40,6 +41,7 @@ _allowed_symbols = [ 'binary_classification_head', 'clip_gradients_by_norm', 'forward_features', + 'InMemoryEvaluatorHook', 'logistic_regression_head', 'multi_class_head', 'multi_head', diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py index bd641014e9eec6623d66574bccd08ff03ebc28ac..43bfcffd790e7b3c716c3f70820851a8819af225 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -49,7 +49,8 @@ class _BoostedTreesEstimator(estimator.Estimator): l2_regularization=0., tree_complexity=0., min_node_weight=0., - config=None): + config=None, + center_bias=False): """Initializes a `BoostedTreesEstimator` instance. Args: @@ -82,17 +83,30 @@ class _BoostedTreesEstimator(estimator.Estimator): considered. The value will be compared with sum(leaf_hessian)/ (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + """ # pylint:disable=protected-access # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( - features, labels, mode, head, feature_columns, tree_hparams, - n_batches_per_layer, config) + features, + labels, + mode, + head, + feature_columns, + tree_hparams, + n_batches_per_layer, + config=config) super(_BoostedTreesEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) @@ -114,7 +128,8 @@ def boosted_trees_classifier_train_in_memory( tree_complexity=0., min_node_weight=0., config=None, - train_hooks=None): + train_hooks=None, + center_bias=False): """Trains a boosted tree classifier with in memory dataset. Example: @@ -186,7 +201,13 @@ def boosted_trees_classifier_train_in_memory( considered. The value will be compared with sum(leaf_hessian)/ (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. - train_hooks: a list of Hook instances to be passed to estimator.train(). + train_hooks: a list of Hook instances to be passed to estimator.train() + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. Returns: a `BoostedTreesClassifier` instance created with the given arguments and @@ -207,7 +228,7 @@ def boosted_trees_classifier_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( @@ -247,7 +268,8 @@ def boosted_trees_regressor_train_in_memory( tree_complexity=0., min_node_weight=0., config=None, - train_hooks=None): + train_hooks=None, + center_bias=False): """Trains a boosted tree regressor with in memory dataset. Example: @@ -313,6 +335,12 @@ def boosted_trees_regressor_train_in_memory( (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. train_hooks: a list of Hook instances to be passed to estimator.train(). + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. Returns: a `BoostedTreesClassifier` instance created with the given arguments and @@ -332,7 +360,7 @@ def boosted_trees_regressor_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py index 76cbefe5e94502188388df6fc2816d130ac896d5..999c2aa5e28242f996e12da3807a74c6acf31df9 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -115,6 +115,27 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): eval_res = est.evaluate(input_fn=input_fn, steps=1) self.assertAllClose(eval_res['average_loss'], 1.008551) + def testTrainAndEvaluateEstimatorWithCenterBias(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + center_bias=True) + + # It will stop after 11 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + # 10 steps for training and 2 step for bias centering. + self._assert_checkpoint( + est.model_dir, global_step=12, finalized_trees=2, attempted_layers=10) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 0.614642) + def testInferEstimator(self): train_input_fn = _make_train_input_fn(is_classification=False) predict_input_fn = numpy_io.numpy_input_fn( @@ -139,6 +160,33 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], [pred['predictions'] for pred in predictions]) + def testInferEstimatorWithCenterBias(self): + train_input_fn = _make_train_input_fn(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + center_bias=True, + head=self._head) + + # It will stop after 6 steps because of the max depth and num trees (5 for + # training and 2 for bias centering). + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(train_input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=7, finalized_trees=1, attempted_layers=5) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + + self.assertAllClose( + [[1.634501], [1.325703], [1.187431], [2.019683], [2.832683]], + [pred['predictions'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self): train_input_fn = _make_train_input_fn(is_classification=True) predict_input_fn = numpy_io.numpy_input_fn( @@ -159,14 +207,40 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertAllClose([[0], [1], [1], [0], [0]], [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryAndEvalAndInferWithCenterBias(self): + train_input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5, + center_bias=True) + # It will stop after 5 steps + 3 for bias, because of the max depth and num + # trees. + self._assert_checkpoint( + est.model_dir, global_step=8, finalized_trees=1, attempted_layers=5) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryWithDataset(self): train_input_fn = _make_train_input_fn_dataset(is_classification=True) predict_input_fn = numpy_io.numpy_input_fn( x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) est = boosted_trees.boosted_trees_classifier_train_in_memory( - train_input_fn=train_input_fn, feature_columns=self._feature_columns, - n_trees=1, max_depth=5) + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5) # It will stop after 5 steps because of the max depth and num trees. self._assert_checkpoint( est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index 7ff25b95c079c7e06d29e874bcaa0d2c13e7167e..9efa8f474d865a36788cba40a15404bf0b30a17e 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -53,6 +53,25 @@ class DNNEstimator(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + + # Or estimator with warm-starting from a previous checkpoint. + estimator = DNNEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], + hidden_units=[1024, 512, 256], + warm_start_from="/path/to/checkpoint/dir") + # Input builders def input_fn_train: # returns x, y pass @@ -92,7 +111,9 @@ class DNNEstimator(estimator.Estimator): activation_fn=nn.relu, dropout=None, input_layer_partitioner=None, - config=None): + config=None, + warm_start_from=None, + batch_norm=False): """Initializes a `DNNEstimator` instance. Args: @@ -107,8 +128,9 @@ class DNNEstimator(estimator.Estimator): model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given @@ -116,6 +138,12 @@ class DNNEstimator(estimator.Estimator): input_layer_partitioner: Optional. Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. + warm_start_from: A string filepath to a checkpoint to warm-start from, or + a `WarmStartSettings` object to fully configure warm-starting. If the + string filepath is provided instead of a `WarmStartSettings`, then all + weights are warm-started, and it is assumed that vocabularies and Tensor + names are unchanged. + batch_norm: Whether to use batch normalization after each hidden layer. """ def _model_fn(features, labels, mode, config): return dnn_lib._dnn_model_fn( # pylint: disable=protected-access @@ -129,6 +157,8 @@ class DNNEstimator(estimator.Estimator): activation_fn=activation_fn, dropout=dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm) super(DNNEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, model_dir=model_dir, config=config, + warm_start_from=warm_start_from) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py index ccaf1128bf23af734f7a5722a4dd8c1f0304fab7..2eef60c39f54bfb464b7da0eb57a47e9eee9b800 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -53,12 +53,19 @@ class DNNLinearCombinedEstimator(estimator.Estimator): dnn_hidden_units=[1000, 500, 100], dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -103,7 +110,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator): dnn_activation_fn=nn.relu, dnn_dropout=None, input_layer_partitioner=None, - config=None): + config=None, + linear_sparse_combiner='sum'): """Initializes a DNNLinearCombinedEstimator instance. Args: @@ -116,12 +124,16 @@ class DNNLinearCombinedEstimator(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. + the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, @@ -131,6 +143,11 @@ class DNNLinearCombinedEstimator(estimator.Estimator): input_layer_partitioner: Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: RunConfig object to configure the runtime settings. + linear_sparse_combiner: A string specifying how to reduce the linear model + if a categorical column is multivalent. One of "mean", "sqrtn", and + "sum" -- these are effectively different ways to do example-level + normalization, which can be useful for bag-of-words features. For more + details, see @{tf.feature_column.linear_model$linear_model}. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are @@ -158,7 +175,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator): dnn_activation_fn=dnn_activation_fn, dnn_dropout=dnn_dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + linear_sparse_combiner=linear_sparse_combiner) super(DNNLinearCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py index dd009a6753f3231638f93e50fc8f19eae8820139..51b9ce7005cec3910ba73db62a674e4628ca30a2 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py @@ -100,7 +100,8 @@ def _linear_only_estimator_fn( weight_column=None, optimizer='Ftrl', config=None, - partitioner=None): + partitioner=None, + sparse_combiner='sum'): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( weight_column=weight_column, label_dimension=label_dimension, @@ -110,7 +111,8 @@ def _linear_only_estimator_fn( linear_feature_columns=feature_columns, linear_optimizer=optimizer, input_layer_partitioner=partitioner, - config=config) + config=config, + linear_sparse_combiner=sparse_combiner) class LinearOnlyEstimatorEvaluateTest( diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py index 75e3107670d658e55ce23d983e47311f1c180104..050b0428bf7b685229e12561cfb0682d931299d2 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py @@ -38,7 +38,7 @@ from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache -def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): +def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg """Returns a DNNEstimator that uses regression_head.""" return dnn.DNNEstimator( head=head_lib.regression_head( @@ -48,6 +48,12 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): *args, **kwargs) +def _dnn_estimator_classifier_fn(n_classes=3, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg + """Returns a DNNEstimator that uses multi_class_head.""" + return dnn.DNNEstimator(head=head_lib.multi_class_head(n_classes=n_classes), + *args, **kwargs) + + class DNNEstimatorEvaluateTest( dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): @@ -75,6 +81,15 @@ class DNNEstimatorTrainTest( self, _dnn_estimator_fn) +class DNNEstimatorWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNWarmStartingTest.__init__( + self, _dnn_estimator_classifier_fn, _dnn_estimator_fn) + + class DNNEstimatorIntegrationTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 8b97f86db19a1bc2d9f17c9935e6678844daf177..c9d86ef4ab89950b0c7b0414ba60d9e0a1cbe476 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -529,11 +529,13 @@ def multi_label_head(n_classes, applications, the shape is `[batch_size, n_classes]`. Labels can be: + * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` * An integer `SparseTensor` of class indices. The `dense_shape` must be `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape` - must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`. + must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary` or a + multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`. If `weight_column` is specified, weights must be of shape `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. @@ -845,6 +847,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = head_lib._append_update_ops(train_op) # pylint:disable=protected-access # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index d6c158608b5c564f24bc90583084306aa7084742..7b884402d4650636bc9fe053994246aabb9c312d 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -567,6 +568,33 @@ class MultiLabelHead(test.TestCase): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_label_vocabulary_with_multi_hot_input(self): + n_classes = 2 + head = head_lib.multi_label_head( + n_classes, label_vocabulary=['class0', 'class1']) + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + } + self._test_eval( + head=head, + logits=logits, + labels=labels_multi_hot, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + def test_eval_with_thresholds(self): n_classes = 2 thresholds = [0.25, 0.5, 0.75] @@ -989,6 +1017,34 @@ class MultiLabelHead(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) + def test_train_with_update_ops(self): + head = head_lib.multi_label_head(n_classes=2) + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=np.array([[1, 0], [1, 1]], dtype=np.int64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd6aa442f82bad2d4714dbcdc85b20b34773068 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Some useful session run hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import training + + +# pylint: disable=protected-access +class InMemoryEvaluatorHook(training.SessionRunHook): + """Hook to run evaluation in training without a checkpoint. + + Example: + + ```python + def train_input_fn(): + ... + return train_dataset + + def eval_input_fn(): + ... + return eval_dataset + + estimator = tf.estimator.DNNClassifier(...) + + evaluator = tf.contrib.estimator.InMemoryEvaluatorHook( + estimator, eval_input_fn) + estimator.train(train_input_fn, hooks=[evaluator]) + ``` + + Current limitations of this approach are: + * It doesn't support multi-node distributed mode. + * It doesn't support saveable objects other than variables (such as boosted + tree support) + * It doesn't support custom saver logic (such as ExponentialMovingAverage + support) + + """ + + def __init__(self, + estimator, + input_fn, + steps=None, + hooks=None, + name=None, + every_n_iter=100): + """Initializes a `InMemoryEvaluatorHook`. + + Args: + estimator: A `tf.estimator.Estimator` instance to call evaluate. + input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A + function that constructs the input data for evaluation. + See @{$premade_estimators#create_input_functions} for more + information. The function should construct and return one of + the following: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + steps: Equivalent to the `steps` arg to `estimator.evaluate`. Number of + steps for which to evaluate model. If `None`, evaluates until `input_fn` + raises an end-of-input exception. + hooks: Equivalent to the `hooks` arg to `estimator.evaluate`. List of + `SessionRunHook` subclass instances. Used for callbacks inside the + evaluation call. + name: Equivalent to the `name` arg to `estimator.evaluate`. Name of the + evaluation if user needs to run multiple evaluations on different data + sets, such as on training data vs test data. Metrics for different + evaluations are saved in separate folders, and appear separately in + tensorboard. + every_n_iter: `int`, runs the evaluator once every N training iteration. + + Raises: + ValueError: if `every_n_iter` is non-positive or it's not a single machine + training + """ + if every_n_iter is None or every_n_iter <= 0: + raise ValueError('invalid every_n_iter=%s.' % every_n_iter) + if (estimator.config.num_ps_replicas > 0 or + estimator.config.num_worker_replicas > 1): + raise ValueError( + 'InMemoryEvaluator supports only single machine (aka Local) setting.') + self._estimator = estimator + self._input_fn = input_fn + self._steps = steps + self._name = name + self._every_n_iter = every_n_iter + self._eval_dir = os.path.join(self._estimator.model_dir, 'eval' + if not name else 'eval_' + name) + + self._graph = None + self._hooks = estimator_lib._check_hooks_type(hooks) + self._hooks.extend(self._estimator._convert_eval_steps_to_hooks(steps)) + self._timer = training.SecondOrStepTimer(every_steps=every_n_iter) + + def begin(self): + """Build eval graph and restoring op.""" + self._timer.reset() + self._iter_count = 0 + self._graph = ops.Graph() + with self._graph.as_default(): + (self._scaffold, self._update_op, self._eval_dict, + self._all_hooks) = self._estimator._evaluate_build_graph( + self._input_fn, self._hooks, checkpoint_path=None) + + if self._scaffold.saver is not None: + raise ValueError('InMemoryEvaluator does not support custom saver') + if self._scaffold.init_fn is not None: + raise ValueError('InMemoryEvaluator does not support custom init_fn') + + self._var_name_to_eval_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + self._var_name_to_placeholder = { + v.name: array_ops.placeholder(v.dtype) + for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + + def after_create_session(self, session, coord): # pylint: disable=unused-argument + """Does first run which shows the eval metrics before training.""" + if ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS): + raise ValueError( + 'InMemoryEvaluator does not support saveables other than global ' + 'variables.') + self._var_name_to_train_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + var_names_to_transfer = set(self._var_name_to_placeholder.keys()) & set( + self._var_name_to_train_var.keys()) + # Filter training var names that are not exist in evaluation + self._var_name_to_train_var = { + v_name: self._var_name_to_train_var[v_name] + for v_name in var_names_to_transfer + } + # Filter eval var names that are not exist in training + self._var_name_to_eval_var = { + v_name: self._var_name_to_eval_var[v_name] + for v_name in var_names_to_transfer + } + + with self._graph.as_default(): + self._var_feed_op = control_flow_ops.group([ + state_ops.assign(self._var_name_to_eval_var[v_name], + self._var_name_to_placeholder[v_name]) + for v_name in var_names_to_transfer + ]) + + self._evaluate(session) + + def _evaluate(self, train_session): + var_name_to_value = train_session.run(self._var_name_to_train_var) + placeholder_to_value = { + self._var_name_to_placeholder[v_name]: var_name_to_value[v_name] + for v_name in var_name_to_value + } + + def feed_variables(scaffold, session): + del scaffold + session.run(self._var_feed_op, feed_dict=placeholder_to_value) + + scaffold = training.Scaffold( + init_fn=feed_variables, copy_from_scaffold=self._scaffold) + + with self._graph.as_default(): + return self._estimator._evaluate_run( + checkpoint_path=None, + scaffold=scaffold, + update_op=self._update_op, + eval_dict=self._eval_dict, + all_hooks=self._all_hooks, + output_dir=self._eval_dir) + + self._timer.update_last_triggered_step(self._iter_count) + + def after_run(self, run_context, run_values): # pylint: disable=unused-argument + """Runs evaluator.""" + self._iter_count += 1 + if self._timer.should_trigger_for_step(self._iter_count): + self._evaluate(run_context.session) + + def end(self, session): # pylint: disable=unused-argument + """Runs evaluator for final model.""" + self._evaluate(session) + + +# pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py new file mode 100644 index 0000000000000000000000000000000000000000..95ae971852ee6dffb6174fc243686721c30ef685 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -0,0 +1,318 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import json +import os + +from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.summary import summary_iterator +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import training + + +def summary_step_keyword_to_value_mapping(dir_): + writer_cache.FileWriterCache.clear() + + # Get last Event written. + event_paths = glob.glob(os.path.join(dir_, 'events*')) + step_keyword_to_value = {} + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step not in step_keyword_to_value: + step_keyword_to_value[last_event.step] = {} + if last_event.summary is not None: + for value in last_event.summary.value: + step_keyword_to_value[last_event.step][value.tag] = value.simple_value + + return step_keyword_to_value + + +def get_summary_value(dir_, step, keyword): + """Get summary value for given step and keyword.""" + + writer_cache.FileWriterCache.clear() + # Get last Event written. + event_paths = glob.glob(os.path.join(dir_, 'events*')) + print('XXX', event_paths) + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step == step and last_event.summary is not None: + for value in last_event.summary.value: + if keyword in value.tag: + return value.simple_value + return None + + +class InMemoryEvaluatorHookTest(test.TestCase): + + def test_runs_eval_metrics(self): + + def model_fn(features, labels, mode): + _ = labels + if estimator_lib.ModeKeys.TRAIN == mode: + with ops.control_dependencies([features]): + train_op = state_ops.assign_add(training.get_global_step(), 1) + return estimator_lib.EstimatorSpec( + mode, loss=constant_op.constant(3.), train_op=train_op) + if estimator_lib.ModeKeys.EVAL == mode: + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(5.), + eval_metric_ops={'mean_of_features': metrics_lib.mean(features)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # 4.5 = sum(range(10))/10 + # before training + self.assertEqual(4.5, step_keyword_to_value[0]['mean_of_features']) + # intervals (every_n_iter=4) + self.assertEqual(4.5, step_keyword_to_value[4]['mean_of_features']) + self.assertEqual(4.5, step_keyword_to_value[8]['mean_of_features']) + # end + self.assertEqual(4.5, step_keyword_to_value[10]['mean_of_features']) + + def test_uses_latest_variable_value(self): + + def model_fn(features, labels, mode): + _ = labels + step = training.get_global_step() + w = variable_scope.get_variable( + 'w', + shape=[], + initializer=init_ops.zeros_initializer(), + dtype=dtypes.int64) + if estimator_lib.ModeKeys.TRAIN == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + step_inc = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([step_inc]): + assign_w_to_step_plus_2 = w.assign(step + 2) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + train_op=assign_w_to_step_plus_2) + if estimator_lib.ModeKeys.EVAL == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + loss = constant_op.constant(5.) + return estimator_lib.EstimatorSpec( + mode, + loss=loss, + # w is constant in each step, so the mean. + # w = 0 if step==0 else step+2 + eval_metric_ops={'mean_of_const': metrics_lib.mean(w)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # w = 0 if step==0 else step+2 + self.assertEqual(0, step_keyword_to_value[0]['mean_of_const']) + self.assertEqual(6, step_keyword_to_value[4]['mean_of_const']) + self.assertEqual(12, step_keyword_to_value[10]['mean_of_const']) + + def test_dnn_classifier(self): + embedding = feature_column_lib.embedding_column( + feature_column_lib.categorical_column_with_vocabulary_list( + 'wire_cast', ['kima', 'omar', 'stringer']), 8) + dnn = estimator_lib.DNNClassifier( + feature_columns=[embedding], hidden_units=[3, 1]) + + def train_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['omar'], ['kima']] + }, [[0], [1]])).repeat(3) + + def eval_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['stringer'], ['kima']] + }, [[0], [1]])).repeat(2) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + dnn, eval_input_fn, name='in-memory') + dnn.train(train_input_fn, hooks=[evaluator]) + self.assertTrue(os.path.isdir(dnn.eval_dir('in-memory'))) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + dnn.eval_dir('in-memory')) + + final_metrics = dnn.evaluate(eval_input_fn) + step = final_metrics[ops.GraphKeys.GLOBAL_STEP] + for summary_tag in final_metrics: + if summary_tag == ops.GraphKeys.GLOBAL_STEP: + continue + self.assertEqual(final_metrics[summary_tag], + step_keyword_to_value[step][summary_tag]) + + def test_raise_error_with_multi_worker(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_ps(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1'], + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_custom_saver_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(saver=training.Saver()), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom saver'): + evaluator.begin() + + def test_raise_error_with_custom_init_fn_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + + def init_fn(scaffold, session): + _, _ = scaffold, session + + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_fn=init_fn), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom init_fn'): + evaluator.begin() + + def test_raise_error_with_saveables_other_than_global_variables(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + w = variables.Variable( + initial_value=[0.], + trainable=False, + collections=[ops.GraphKeys.SAVEABLE_OBJECTS]) + init_op = control_flow_ops.group( + [w.initializer, training.get_global_step().initializer]) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_op=init_op), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support saveables'): + estimator.train(input_fn, hooks=[evaluator]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py index 3bf4abe83d54504d55de73b63f369cceaf149dd2..62a37abefb1f6ed291df1df3da6de35bfd2b6c52 100644 --- a/tensorflow/contrib/estimator/python/estimator/linear.py +++ b/tensorflow/contrib/estimator/python/estimator/linear.py @@ -39,6 +39,18 @@ class LinearEstimator(estimator.Estimator): feature_columns=[categorical_column_a, categorical_feature_a_x_categorical_feature_b]) + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator using the FTRL optimizer with regularization. estimator = LinearEstimator( head=tf.contrib.estimator.multi_label_head(n_classes=3), @@ -87,7 +99,8 @@ class LinearEstimator(estimator.Estimator): model_dir=None, optimizer='Ftrl', config=None, - partitioner=None): + partitioner=None, + sparse_combiner='sum'): """Initializes a `LinearEstimator` instance. Args: @@ -99,10 +112,16 @@ class LinearEstimator(estimator.Estimator): model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum" -- these are + effectively different ways to do example-level normalization, which can + be useful for bag-of-words features. for more details, see + @{tf.feature_column.linear_model$linear_model}. """ def _model_fn(features, labels, mode, config): return linear_lib._linear_model_fn( # pylint: disable=protected-access @@ -113,6 +132,7 @@ class LinearEstimator(estimator.Estimator): feature_columns=tuple(feature_columns or []), optimizer=optimizer, partitioner=partitioner, - config=config) + config=config, + sparse_combiner=sparse_combiner) super(LinearEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc index bb9b835889b1b5e36d6f470b51834d4c6bb3d493..7fcae5ad8e1536530e2d039e1d14df4e192c4fa3 100644 --- a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc +++ b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc @@ -62,10 +62,11 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { public: explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->MatchSignature( - {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, - DT_INT64, DT_FLOAT, DT_INT64, DT_BOOL}, - {DT_FLOAT, DT_FLOAT})); + OP_REQUIRES_OK(context, + context->MatchSignature( + {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, + DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL}, + {DT_FLOAT, DT_FLOAT})); } void Compute(OpKernelContext* context) override { @@ -75,8 +76,9 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { const Tensor& input_weights = context->input(3); const Tensor& input_indices = context->input(4); const Tensor& input_values = context->input(5); - const Tensor& input_block_size = context->input(6); - const Tensor& input_is_transpose = context->input(7); + const Tensor& entry_weights = context->input(6); + const Tensor& input_block_size = context->input(7); + const Tensor& input_is_transpose = context->input(8); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()), InvalidArgument("Input factors should be a matrix.")); @@ -89,13 +91,33 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { InvalidArgument("Input input_weights should be a vector.")); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()), InvalidArgument("Input input_indices should be a matrix.")); + OP_REQUIRES( + context, input_indices.dim_size(1) == 2, + InvalidArgument("Input input_indices should have shape (?, 2).")); OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()), InvalidArgument("Input input_values should be a vector")); + OP_REQUIRES(context, TensorShapeUtils::IsVector(entry_weights.shape()), + InvalidArgument("Input entry_weights should be a vector")); + OP_REQUIRES(context, input_indices.dim_size(0) == input_values.dim_size(0), + InvalidArgument("Input input_values' length should match the " + "first dimension of Input input_indices ")); OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()), InvalidArgument("Input input_block_size should be a scalar.")); OP_REQUIRES( context, TensorShapeUtils::IsScalar(input_is_transpose.shape()), InvalidArgument("Input input_is_transpose should be a scalar.")); + OP_REQUIRES( + context, + ((input_weights.dim_size(0) > 0 && + factor_weights.dim_size(0) == factors.dim_size(0) && + entry_weights.dim_size(0) == 0) || + (input_weights.dim_size(0) == 0 && factor_weights.dim_size(0) == 0 && + entry_weights.dim_size(0) == input_indices.dim_size(0))), + InvalidArgument("To specify the weights for observed entries, either " + "(1) entry_weights must be set or (2) input_weights " + "and factor_weights must be set, but not both.")); + // TODO(yifanchen): Deprecate the support of input_weights and + // factor_weights. const int64 factor_dim = factors.dim_size(1); const int64 factors_size = factors.dim_size(0); @@ -105,6 +127,7 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { const auto& input_weights_vec = input_weights.vec(); const float w_0 = unobserved_weights.scalar()(); const auto& input_values_vec = input_values.vec(); + const auto& entry_weights_vec = entry_weights.vec(); ConstEigenMatrixFloatMap factors_mat(factors.matrix().data(), factor_dim, factors_size); @@ -134,6 +157,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { return is_transpose ? indices_mat(0, i) : indices_mat(1, i); }; + const bool use_entry_weights = entry_weights_vec.size() > 0; + // TODO(rmlarsen): In principle, we should be using the SparseTensor class // and machinery for iterating over groups, but the fact that class // SparseTensor makes a complete copy of the matrix makes me reluctant to @@ -195,6 +220,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { // map using the hash of the thread id as the key. // // TODO(jpoulson): Switch to try_emplace once C++17 is supported + // TODO(b/72952120): Check whether the 3 lock-unlock pairs can be + // consolidated into just one. map_mutex.lock(); const auto key_count = factor_batch_map.count(id_hash); map_mutex.unlock(); @@ -213,6 +240,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { CHECK_LE(shard.second, perm.size()); CHECK_LE(shard.first, shard.second); const int64 input_index = get_input_index(perm[shard.first]); + const float input_weight = + use_entry_weights ? 1.0 : input_weights_vec(input_index); // Accumulate the rhs and lhs terms in the normal equations // for the non-zero elements in the row or column of the sparse matrix // corresponding to input_index. @@ -228,7 +257,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { const int64 factor_index = get_factor_index(i); const float input_value = input_values_vec(i); const float weight = - input_weights_vec(input_index) * factor_weights_vec(factor_index); + use_entry_weights ? entry_weights_vec(i) + : input_weight * factor_weights_vec(factor_index); CHECK_GE(weight, 0); factor_batch.col(num_batched) = factors_mat.col(factor_index) * std::sqrt(weight); diff --git a/tensorflow/contrib/factorization/ops/factorization_ops.cc b/tensorflow/contrib/factorization/ops/factorization_ops.cc index 11ea36946e92769cd6901eb998a20148250ef7ce..1d31bd38c824f24e9a70c0f69da129f5ddc18985 100644 --- a/tensorflow/contrib/factorization/ops/factorization_ops.cc +++ b/tensorflow/contrib/factorization/ops/factorization_ops.cc @@ -25,20 +25,33 @@ REGISTER_OP("WALSComputePartialLhsAndRhs") .Input("input_weights: float32") .Input("input_indices: int64") .Input("input_values: float32") + .Input("entry_weights: float32") .Input("input_block_size: int64") .Input("input_is_transpose: bool") .Output("partial_lhs: float32") .Output("partial_rhs: float32") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"( -Computes the partial left-hand side and right-hand side of WALS update. +Computes the partial left-hand side and right-hand side of WALS update. For +observed entry input_indices[i]=[m, n] with value input_values[i]=v, the weight +should be specified either through (1) entry_weights[i] or (2) through +input_weights[m] * factor_weights[n] (if input_is_transpose is false) or +input_weights[n] * factor_weights[m] (if input_is_transpose is true). Note it is +not allowed to have both (1) and (2) specified at the same time: when one +approach is used, the input tensors related to the other approach must be kept +completely empty. factors: Matrix of size m * k. -factor_weights: Vector of size m. Corresponds to column weights +factor_weights: Vector of size m. Corresponds to column weights. Should be empty + if entry_weights is used. unobserved_weights: Scalar. Weight for unobserved input entries. -input_weights: Vector of size n. Corresponds to row weights. +input_weights: Vector of size n. Corresponds to row weights. Should be empty if + entry_weights is used. input_indices: Indices for the input SparseTensor. input_values: Values for the input SparseTensor. +entry_weights: If not empty, this must be same length as input_vaues and is used + as the per-entry non-zero weight. If this is used, input_weights and + factor_weights must be empty. input_block_size: Scalar. Number of rows spanned by input. input_is_transpose: If true, logically transposes the input for processing. partial_lhs: 3-D tensor with size input_block_size x k x k. diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py index ba30fd997700f461b6afffa13cf371c598d3332e..6c2f1d46084d701beac1e3a99e3ad66bae57eda5 100644 --- a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py +++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py @@ -55,7 +55,41 @@ class WalsSolverOpsTest(test.TestCase): rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs( self._column_factors, self._column_weights, self._unobserved_weights, self._row_weights, sparse_block.indices, sparse_block.values, - sparse_block.dense_shape[0], False) + [], + input_block_size=sparse_block.dense_shape[0], + input_is_transpose=False) + self.assertAllClose(lhs_tensor.eval(), [[ + [0.014800, 0.017000, 0.019200], + [0.017000, 0.019600, 0.022200], + [0.019200, 0.022200, 0.025200], + ], [ + [0.0064000, 0.0080000, 0.0096000], + [0.0080000, 0.0100000, 0.0120000], + [0.0096000, 0.0120000, 0.0144000], + ], [ + [0.0099000, 0.0126000, 0.0153000], + [0.0126000, 0.0162000, 0.0198000], + [0.0153000, 0.0198000, 0.0243000], + ], [ + [0.058800, 0.067200, 0.075600], + [0.067200, 0.076800, 0.086400], + [0.075600, 0.086400, 0.097200], + ]]) + self.assertAllClose(rhs_matrix.eval(), [[0.019300, 0.023000, 0.026700], + [0.061600, 0.077000, 0.092400], + [0.160400, 0.220000, 0.279600], + [0.492800, 0.563200, 0.633600]]) + + def testWalsSolverLhsEntryWeights(self): + sparse_block = SparseBlock3x3() + with self.test_session(): + [lhs_tensor, + rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs( + self._column_factors, [], self._unobserved_weights, + [], sparse_block.indices, sparse_block.values, + [0.01, 0.03, 0.04, 0.03, 0.06, 0.12], + input_block_size=sparse_block.dense_shape[0], + input_is_transpose=False) self.assertAllClose(lhs_tensor.eval(), [[ [0.014800, 0.017000, 0.019200], [0.017000, 0.019600, 0.022200], diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 5cef4068ed119d5dbccd585c5b4e5e28840d2cc7..7ab70fbcfd7324961b61526a08daab7e393630e9 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -197,7 +197,8 @@ class WALSModel(object): row_weights=1, col_weights=1, use_factors_weights_cache=True, - use_gramian_cache=True): + use_gramian_cache=True, + use_scoped_vars=False): """Creates model for WALS matrix factorization. Args: @@ -239,6 +240,8 @@ class WALSModel(object): weights cache to take effect. use_gramian_cache: When True, the Gramians will be cached on the workers before the updates start. Defaults to True. + use_scoped_vars: When True, the factor and weight vars will also be nested + in a tf.name_scope. """ self._input_rows = input_rows self._input_cols = input_cols @@ -251,25 +254,46 @@ class WALSModel(object): regularization * linalg_ops.eye(self._n_components) if regularization is not None else None) assert (row_weights is None) == (col_weights is None) - self._row_weights = WALSModel._create_weights( - row_weights, self._input_rows, self._num_row_shards, "row_weights") - self._col_weights = WALSModel._create_weights( - col_weights, self._input_cols, self._num_col_shards, "col_weights") self._use_factors_weights_cache = use_factors_weights_cache self._use_gramian_cache = use_gramian_cache - self._row_factors = self._create_factors( - self._input_rows, self._n_components, self._num_row_shards, row_init, - "row_factors") - self._col_factors = self._create_factors( - self._input_cols, self._n_components, self._num_col_shards, col_init, - "col_factors") + + if use_scoped_vars: + with ops.name_scope("row_weights"): + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + with ops.name_scope("col_weights"): + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + with ops.name_scope("row_factors"): + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, + row_init, "row_factors") + with ops.name_scope("col_factors"): + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, + col_init, "col_factors") + else: + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, row_init, + "row_factors") + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, col_init, + "col_factors") + self._row_gramian = self._create_gramian(self._n_components, "row_gramian") self._col_gramian = self._create_gramian(self._n_components, "col_gramian") - self._row_update_prep_gramian = self._prepare_gramian( - self._col_factors, self._col_gramian) - self._col_update_prep_gramian = self._prepare_gramian( - self._row_factors, self._row_gramian) - self._create_transient_vars() + with ops.name_scope("row_prepare_gramian"): + self._row_update_prep_gramian = self._prepare_gramian( + self._col_factors, self._col_gramian) + with ops.name_scope("col_prepare_gramian"): + self._col_update_prep_gramian = self._prepare_gramian( + self._row_factors, self._row_gramian) + with ops.name_scope("transient_vars"): + self._create_transient_vars() @property def row_factors(self): @@ -919,6 +943,7 @@ class WALSModel(object): row_weights_slice, new_sp_input.indices, new_sp_input.values, + [], num_rows, transpose_input, name="wals_compute_partial_lhs_rhs")) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 555beddeaab419bcb23d06f960d370b706d744c8..05bcdac2caa77062f9a8a44a948d2897b439ea1f 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -95,7 +95,7 @@ def sequence_input_layer( Raises: ValueError: If any of the `feature_columns` is the wrong type. """ - feature_columns = fc._clean_feature_columns(feature_columns) + feature_columns = fc._normalize_feature_columns(feature_columns) for c in feature_columns: if not isinstance(c, fc._SequenceDenseColumn): raise ValueError( @@ -346,7 +346,8 @@ def sequence_numeric_column( key, shape=(1,), default_value=0., - dtype=dtypes.float32): + dtype=dtypes.float32, + normalizer_fn=None): """Returns a feature column that represents sequences of numeric data. Example: @@ -370,6 +371,12 @@ def sequence_numeric_column( default_value: A single value compatible with `dtype` that is used for padding the sparse data into a dense `Tensor`. dtype: The type of values. + normalizer_fn: If not `None`, a function that can be used to normalize the + value of the tensor after `default_value` is applied for parsing. + Normalizer function takes the input `Tensor` as its argument, and returns + the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that + even though the most common use case of this function is normalization, it + can be used for any kind of Tensorflow transformations. Returns: A `_SequenceNumericColumn`. @@ -383,12 +390,16 @@ def sequence_numeric_column( if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) + if normalizer_fn is not None and not callable(normalizer_fn): + raise TypeError( + 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) return _SequenceNumericColumn( key, shape=shape, default_value=default_value, - dtype=dtype) + dtype=dtype, + normalizer_fn=normalizer_fn) def _assert_all_equal_and_return(tensors, name=None): @@ -407,7 +418,7 @@ class _SequenceNumericColumn( fc._SequenceDenseColumn, collections.namedtuple( '_SequenceNumericColumn', - ['key', 'shape', 'default_value', 'dtype'])): + ['key', 'shape', 'default_value', 'dtype', 'normalizer_fn'])): """Represents sequences of numeric data.""" @property @@ -419,7 +430,10 @@ class _SequenceNumericColumn( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - return inputs.get(self.key) + input_tensor = inputs.get(self.key) + if self.normalizer_fn is not None: + input_tensor = self.normalizer_fn(input_tensor) + return input_tensor @property def _variable_shape(self): diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 88f5d535162939e063eb1e7f43d495137c5adef4..45d7b740462ca21139e2e93e34b43668f1e08a94 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test from tensorflow.python.training import monitored_session @@ -109,7 +110,7 @@ class SequenceInputLayerTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) def test_embedding_column_with_non_sequence_categorical(self): - """Tests that error is raised for non-sequence categorical column.""" + """Tests that error is raised for non-sequence embedding column.""" vocabulary_size = 3 sparse_input = sparse_tensor.SparseTensorValue( # example 0, ids [2] @@ -131,6 +132,107 @@ class SequenceInputLayerTest(test.TestCase): features={'aaa': sparse_input}, feature_columns=[embedding_column_a]) + def test_shared_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + + def _get_initializer(embedding_dimension, embedding_values): + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 3., 4.], [0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 5., 6.], [3., 4., 1., 2.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + # Test that columns are reordered alphabetically. + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension, + initializer=_get_initializer(embedding_dimension, embedding_values)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=shared_embedding_columns) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_shared_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence shared embedding column.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_shared_embedding\. categorical_column must ' + r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b + }, + feature_columns=shared_embedding_columns) + def test_indicator_column(self): vocabulary_size_a = 3 sparse_input_a = sparse_tensor.SparseTensorValue( @@ -577,6 +679,182 @@ class SequenceEmbeddingColumnTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) +class SequenceSharedEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [0, 2] + # example 2, ids [0] + # example 3, ids [] + indices=((0, 0), (1, 0), (1, 1), (2, 0)), + values=(1, 0, 2, 0), + dense_shape=(4, 2)) + + expected_lookups_a = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + expected_lookups_b = [ + # example 0, ids [1] + [[3., 5.], [0., 0.]], + # example 1, ids [0, 2] + [[1., 2.], [7., 11.]], + # example 2, ids [0] + [[1., 2.], [0., 0.]], + # example 3, ids [] + [[0., 0.], [0., 0.]], + ] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[0] + embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[0] + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual( + expected_lookups_a, embedding_lookup_a.eval(session=sess)) + self.assertAllEqual( + expected_lookups_b, embedding_lookup_b.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length_a = [1, 2] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [0, 2] + # example 1, ids [1] + indices=((0, 0), (0, 1), (1, 0)), + values=(0, 2, 1), + dense_shape=(2, 2)) + expected_sequence_length_b = [2, 1] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + sequence_length_a = sess.run(sequence_length_a) + self.assertAllEqual(expected_sequence_length_a, sequence_length_a) + self.assertEqual(np.int64, sequence_length_a.dtype) + sequence_length_b = sess.run(sequence_length_b) + self.assertAllEqual(expected_sequence_length_b, sequence_length_b) + self.assertEqual(np.int64, sequence_length_b.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length_a = [0, 1, 2, 0, 1, 0] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [] + # example 2, ids [] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [0, 1] + indices=((0, 0), (4, 0), (5, 0), (5, 1)), + values=(2, 1, 0, 1), + dense_shape=(6, 2)) + expected_sequence_length_b = [1, 0, 0, 0, 1, 2] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length_a, sequence_length_a.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length_b, sequence_length_b.eval(session=sess)) + + class SequenceIndicatorColumnTest(test.TestCase): def test_get_sequence_dense_tensor(self): @@ -670,6 +948,7 @@ class SequenceNumericColumnTest(test.TestCase): self.assertEqual((1,), a.shape) self.assertEqual(0., a.default_value) self.assertEqual(dtypes.float32, a.dtype) + self.assertIsNone(a.normalizer_fn) def test_shape_saved_as_tuple(self): a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) @@ -688,6 +967,10 @@ class SequenceNumericColumnTest(test.TestCase): ValueError, 'dtype must be convertible to float'): sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + def test_normalizer_fn_must_be_callable(self): + with self.assertRaisesRegexp(TypeError, 'must be a callable'): + sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') + def test_get_sequence_dense_tensor(self): sparse_input = sparse_tensor.SparseTensorValue( # example 0, values [[0.], [1]] @@ -708,6 +991,41 @@ class SequenceNumericColumnTest(test.TestCase): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) + def test_get_sequence_dense_tensor_with_normalizer_fn(self): + + def _increment_two(input_sparse_tensor): + return sparse_ops.sparse_add( + input_sparse_tensor, + sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2)) + ) + + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + + # Before _increment_two: + # [[0.], [1.]], + # [[10.], [0.]], + # After _increment_two: + # [[2.], [1.]], + # [[10.], [2.]], + expected_dense_tensor = [ + [[2.], [1.]], + [[10.], [2.]], + ] + numeric_column = sfc.sequence_numeric_column( + 'aaa', normalizer_fn=_increment_two) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + def test_get_sequence_dense_tensor_with_shape(self): """Tests get_sequence_dense_tensor with shape !=(1,).""" sparse_input = sparse_tensor.SparseTensorValue( diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index daba965a98893b992abdc598ec713f13020d6e91..484ffee3e7afe55c63cab2a463454353b2663e18 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -28,7 +28,6 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio -from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 020b5c99c61019254bef0b1dff6bc5901c92758a..b1b5126d9e9e5196a1733b80e0778e53cef7f774 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -21,7 +21,6 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py -from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 10d1ecc738de6777784200ba934a521dff592e28..dc49383c5c300e82839c478e097074b3e8776b3b 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -119,14 +119,13 @@ from tensorflow.python.framework.smart_cond import smart_cond from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec -from tensorflow.python.ops.array_ops import broadcast_to from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['nest', 'broadcast_to'] +_allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', 'is_sequence', diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index df7d7e9dae80722569efccbc9cc0d1b75e90cf03..34fd5018af125335845540dedfdffc984ba02313 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging class CriticalSectionTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCreateCriticalSection(self): cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") @@ -53,7 +53,7 @@ class CriticalSectionTest(test.TestCase): self.assertAllClose([2.0 * i for i in range(num_concurrent)], sorted(r_value)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCriticalSectionWithControlFlow(self): for outer_cond in [False, True]: for inner_cond in [False, True]: @@ -109,7 +109,7 @@ class CriticalSectionTest(test.TestCase): with self.assertRaisesOpError("Error"): self.evaluate(r) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCreateCriticalSectionFnReturnsOp(self): cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") @@ -332,7 +332,7 @@ class CriticalSectionTest(test.TestCase): self.evaluate(v.initializer) self.assertEqual(10, self.evaluate(out)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInsideFunction(self): cs = critical_section_ops.CriticalSection() v = resource_variable_ops.ResourceVariable(1) diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 40ae01bfcce1dde580e6a5f6d9c8ec1aa1abb83f..e8e318001972934c7d2154bc14744823a3ba09f9 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -712,7 +712,8 @@ class VariableDeviceChooser(object): num_tasks=0, job_name='ps', device_type='CPU', - device_index=0): + device_index=0, + replica=None): """Initialize VariableDeviceChooser. Usage: @@ -733,12 +734,15 @@ class VariableDeviceChooser(object): self._job_name = job_name self._device_type = device_type self._device_index = device_index + self._replica = replica self._num_tasks = num_tasks self._next_task_id = 0 def __call__(self, op): - device_spec = tf_device.DeviceSpec(device_type=self._device_type, - device_index=self._device_index) + device_spec = tf_device.DeviceSpec( + replica=self._replica, + device_type=self._device_type, + device_index=self._device_index) if self._num_tasks > 0: task_id = self._next_task_id self._next_task_id = (self._next_task_id + 1) % self._num_tasks diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 37ea6eb12aba7d25656f19cbbc86475c1228d916..7e0c7dbec1d9266b53a169fe83b88d1e3af77d04 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -506,6 +506,35 @@ class VariablesTest(test.TestCase): self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0') self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + def testVariableWithVariableDeviceChooserWithReplica(self): + + with ops.Graph().as_default(): + device_fn = variables_lib2.VariableDeviceChooser(replica=3, num_tasks=2) + with arg_scope([variables_lib2.variable], device=device_fn): + a = variables_lib2.variable('a', []) + b = variables_lib2.variable('b', []) + c = variables_lib2.variable('c', [], device='cpu:12') + d = variables_lib2.variable('d', []) + with ops.device('cpu:99'): + e_init = constant_op.constant(12) + e = variables_lib2.variable('e', initializer=e_init) + # The values below highlight how the VariableDeviceChooser puts initial + # values on the same device as the variable job. + self.assertDeviceEqual(a.device, '/job:ps/replica:3/task:0/cpu:0') + self.assertEqual(a.initial_value.op.colocation_groups(), + a.op.colocation_groups()) + self.assertDeviceEqual(b.device, '/job:ps/replica:3/task:1/cpu:0') + self.assertEqual(b.initial_value.op.colocation_groups(), + b.op.colocation_groups()) + self.assertDeviceEqual(c.device, '/cpu:12') + self.assertEqual(c.initial_value.op.colocation_groups(), + c.op.colocation_groups()) + self.assertDeviceEqual(d.device, '/job:ps/replica:3/task:0/cpu:0') + self.assertEqual(d.initial_value.op.colocation_groups(), + d.op.colocation_groups()) + self.assertDeviceEqual(e.device, '/job:ps/replica:3/task:1/cpu:0') + self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + def testVariableGPUPlacement(self): with ops.Graph().as_default(): @@ -930,8 +959,8 @@ class AssignFromCheckpointTest(test.TestCase): return saver.save(sess, checkpoint_dir, global_step=global_step) def testLoadExistingVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'load_existing_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables')) init_value0 = 10.0 init_value1 = 20.0 @@ -944,8 +973,8 @@ class AssignFromCheckpointTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -960,8 +989,8 @@ class AssignFromCheckpointTest(test.TestCase): # Tests restoring PartitionedVariables and tests using a dictionary # of lists as the assign_from_checkpoint() var_list param. def testLoadPartitionedVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'load_partitioned_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_partitioned_variables')) init_value0 = np.array([[10.0, 11.0], [12.0, 13.0]]) init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case. @@ -974,15 +1003,14 @@ class AssignFromCheckpointTest(test.TestCase): partitioner = partitioned_variables.variable_axis_size_partitioner(2) var0 = variables_lib2.variable( 'var0', shape=init_value0.shape, partitioner=partitioner) - var0full = variables_lib2.variable( - 'var0full', shape=init_value0.shape) + var0full = variables_lib2.variable('var0full', shape=init_value0.shape) var1 = variables_lib2.variable( 'var1', shape=init_value1.shape, partitioner=partitioner) # Convert var0 and var1 into a list of underlying variables. vars_to_restore = {'var0': list(var0) + [var0full], 'var1': list(var1)} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -992,16 +1020,18 @@ class AssignFromCheckpointTest(test.TestCase): # Request and test the variable values. PartitionedVariables can't # be evaled so we wrap them in an identity. - self.assertTrue(np.array_equal( - init_value0, array_ops.identity(var0).eval())) - self.assertTrue(np.array_equal( - init_value0, var0full.eval())) - self.assertTrue(np.array_equal( - init_value1, array_ops.identity(var1).eval())) + self.assertTrue( + np.array_equal(init_value0, + array_ops.identity(var0).eval())) + self.assertTrue(np.array_equal(init_value0, var0full.eval())) + self.assertTrue( + np.array_equal(init_value1, + array_ops.identity(var1).eval())) def testRaisesValueErrorIfAVariableIsntFound(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'raises_value_error_if_var_isnt_found')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'raises_value_error_if_var_isnt_found')) init_value0 = 10.0 init_value1 = 20.0 @@ -1019,8 +1049,9 @@ class AssignFromCheckpointTest(test.TestCase): variables_lib2.assign_from_checkpoint(model_path, vars_to_restore) def testInitFromCheckpointWithScopes(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'init_from_checkpoint_with_scopes')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'init_from_checkpoint_with_scopes')) init_value0 = np.asarray( [1.0, 3.0, 9.0], dtype=np.float32).reshape((1, 3, 1)) @@ -1038,8 +1069,8 @@ class AssignFromCheckpointTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=init_value1.shape) vars_to_restore = {'layer0/v0': var0, 'layer1/v1': var1} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1081,8 +1112,8 @@ class AssignFromCheckpointFnTest(test.TestCase): return saver.save(sess, checkpoint_dir, global_step=global_step) def testLoadExistingVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'load_existing_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1097,8 +1128,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1111,8 +1142,9 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'load_existing_vars_no_reshape')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'load_existing_vars_no_reshape')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1127,8 +1159,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1138,9 +1170,10 @@ class AssignFromCheckpointFnTest(test.TestCase): init_fn(sess) def testLoadExistingVariablesDifferentShapeAllowReshape(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), - 'load_existing_variables_different_shape_allow_reshape')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join( + self.get_temp_dir(), + 'load_existing_variables_different_shape_allow_reshape')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1169,8 +1202,8 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testNotFoundError(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'not_found_error')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'not_found_error')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1186,8 +1219,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var2 = variables_lib2.variable('my_var2', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1197,8 +1230,8 @@ class AssignFromCheckpointFnTest(test.TestCase): init_fn(sess) def testMissingVariablesList(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'missing_variables_list')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'missing_variables_list')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1228,8 +1261,8 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testMissingVariablesDict(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'missing_variables_dict')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'missing_variables_dict')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1279,9 +1312,8 @@ class ZeroInitializerOpTest(test.TestCase): def testZeroInitializer(self): for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64): for use_init in (False, True): - self._testZeroInitializer( - [10, 20], array_ops.ones( - [10, 20], dtype=dtype), use_init) + self._testZeroInitializer([10, 20], array_ops.ones( + [10, 20], dtype=dtype), use_init) class ZeroVarInitializerOpTest(test.TestCase): diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 0eb6889db1fae1c74aeb4392441b308392b091a5..0f0813c07f8bd330b089780064e02f8dfe7d49f6 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -75,6 +75,7 @@ tf_kernel_library( "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", "//third_party/eigen3", + "@local_config_cuda//cuda:cudnn_header", ], alwayslink = 1, ) @@ -94,6 +95,7 @@ tf_custom_op_library( "//tensorflow/core/kernels:conv_ops_gpu_hdrs", "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", + "@local_config_cuda//cuda:cudnn_header", ], ) diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 3d0ed899322c26bf4ae428930899d7a5885e9f21..4d62ac65ff619f98a18387058fdc8a0eade0d8f8 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -289,8 +289,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): conv = tensors[i] value = values[i] ref_value = ref_values[i] - print("expected = ", ref_value) - print("actual = ", value) + tf_logging.info("expected = ", ref_value) + tf_logging.info("actual = ", value) tol = 1e-5 if value.dtype == np.float16: tol = 1e-3 @@ -831,7 +831,8 @@ class FusedConvInt8Tests(test.TestCase): vertical_stride, padding_type) output_width = CalculateConvolvedOutputDim(input_width, filter_width, horizontal_stride, padding_type) - print("output_height=", output_height, ", output_width=", output_width) + tf_logging.info("output_height=", output_height, ", output_width=", + output_width) side_input, _, _ = gen_array_ops.quantize_v2( random_ops.random_uniform( @@ -866,8 +867,8 @@ class FusedConvInt8Tests(test.TestCase): with self.test_session(use_gpu=True) as sess: actual_y, expected_y = sess.run([actual, expected]) - print("actual_y = ", actual_y) - print("expected_y = ", expected_y) + tf_logging.info("actual_y = ", actual_y) + tf_logging.info("expected_y = ", expected_y) self.assertTrue(np.array_equal(actual_y, expected_y)) def testFusedConvInt8(self): diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index ff903a78cc36c1965b7655aa902501b1943637a8..d1441e1eb2aae0fb7d1771110f969bf727ebbb14 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -24,6 +24,7 @@ from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.ops import metrics as metrics_lib @@ -102,9 +103,20 @@ class GANHead(head._Head): # pylint: disable=protected-access name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ + + if not callable(generator_loss_fn): + raise TypeError('generator_loss_fn must be callable.') + if not callable(discriminator_loss_fn): + raise TypeError('discriminator_loss_fn must be callable.') + if not use_loss_summaries in [True, False, None]: + raise ValueError('use_loss_summaries must be True, False or None.') + if get_hooks_fn is not None and not callable(get_hooks_fn): + raise TypeError('get_hooks_fn must be callable.') + if name is not None and not isinstance(name, str): + raise TypeError('name must be string.') + if get_hooks_fn is None: get_hooks_fn = tfgan_train.get_sequential_train_hooks() - # TODO(joelshor): Validate inputs. if use_loss_summaries in [True, False]: generator_loss_fn = functools.partial( @@ -182,7 +194,10 @@ class GANHead(head._Head): # pylint: disable=protected-access if mode == model_fn_lib.ModeKeys.PREDICT: return model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.PREDICT, - predictions=gan_model.generated_data) + predictions=gan_model.generated_data, + export_outputs={ + 'predict': export_output.PredictOutput(gan_model.generated_data) + }) elif mode == model_fn_lib.ModeKeys.EVAL: gan_loss = self.create_loss( features=None, mode=mode, logits=gan_model, labels=None) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 6587f1fc600b94d27f7c12b44ca2136d0be5a8c5..5309d87765694fa476dae006105e842420a7c437 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -26,8 +26,11 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import training +_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument return math_ops.reduce_sum(gan_model.discriminator_real_outputs - @@ -71,13 +74,15 @@ class GANHeadTest(test.TestCase): return {} def _test_modes_helper(self, mode): - self.gan_head.create_estimator_spec( + return self.gan_head.create_estimator_spec( features=None, mode=mode, logits=get_gan_model()) def test_modes_predict(self): - self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + spec = self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + self.assertItemsEqual((_DEFAULT_SERVING_KEY, 'predict'), + spec.export_outputs.keys()) def test_modes_eval(self): self._test_modes_helper(model_fn_lib.ModeKeys.EVAL) diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 81e70ae30a4c72dbcedd1aabfe758ecca4c8b366..1435e19109ca2f3bbd6ce70e6e5f26a92dfc2713 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -34,8 +34,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" #if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/common_runtime/process_state.h" #endif // GOOGLE_CUDA #include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/lib/core/status.h" @@ -274,7 +275,7 @@ Status GdrMemoryManager::Init() { Allocator* allocators[] = { #if GOOGLE_CUDA - ProcessState::singleton()->GetCUDAHostAllocator(0), + GPUProcessState::singleton()->GetCUDAHostAllocator(0), ProcessState::singleton()->GetCPUAllocator(0), #endif // GOOGLE_CUDA cpu_allocator(), @@ -308,7 +309,8 @@ Status GdrMemoryManager::Init() { if (IsGDRAvailable()) { // Note we don't free allocated GPU memory so there is no free visitor int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1; - ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); + GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, + cuda_alloc_visitor); LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; } #endif // GOOGLE_CUDA @@ -430,7 +432,7 @@ void GdrMemoryManager::TransportOptionsFromTensor( #if GOOGLE_CUDA if (!on_host) { - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); GPUUtil::CopyGPUTensorToCPU( device, device_context, &tensor, host_copy, @@ -532,7 +534,7 @@ void GdrMemoryManager::TensorFromTransportOptions( Tensor host_copy; #if GOOGLE_CUDA if (mr == nullptr && !on_host) { - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); buffer = DMAHelper::buffer(&host_copy); addr = buffer->data(); diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index 1f9dd0decb84cf9b7b703f18c061d3c0c7a1cb25..9025c992a4467f521d6d8d514e6a5e92f5492947 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -57,7 +57,7 @@ Status GdrServer::Init() { new GdrWorker(env, remote_memory_manager_.get())); }; TF_RETURN_IF_ERROR( - GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func)); + GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func)); return remote_memory_manager_->Init(); } diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 592d37b432ee605d74162e0b8ec6ccdf426c45d1..026a3d1200033400472c4fd763a244c04b284a9b 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -189,9 +189,6 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None): if op._original_op: op_._original_op = op._original_op - # Add op to the graph - info.graph_._add_op(op_) - return op_, op_.outputs @@ -492,7 +489,7 @@ class Transformer(object): t_ = info.transformed_ts[t] consumer_op_ = info.transformed_ops[consumer_op] t_index_ = list(consumer_op_.inputs).index(tmp_t_) - consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access + consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access def _connect_control_inputs(self, info): """Connect the previously copied ops.""" diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c index 6a5d982dc8514d69277b8f042ac1256e28715d9e..2e5c84704f8464ab46d740ea3c1eef0548826e8d 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c +++ b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c @@ -19,7 +19,7 @@ limitations under the License. #include "hexagon_controller.h" -#include +#include #include #include "adspmsgd.h" diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index c2e32da133b32c8fe169302668031af8bace2c22..022e17d13963a14f81d76e683d13060d1f3f8a7e 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -35,6 +35,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; template struct FillProjectiveTransform; template struct FillProjectiveTransform; template struct FillProjectiveTransform; +template struct FillProjectiveTransform; template struct FillProjectiveTransform; template struct FillProjectiveTransform; @@ -99,6 +100,7 @@ class ImageProjectiveTransform : public OpKernel { TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); TF_CALL_int64(REGISTER); +TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h index ad501330617be89c87a0e94ab6e8773a6e1eecf6..209aa24548443bb10c13cd506b8c93c23cfff4a4 100644 --- a/tensorflow/contrib/image/kernels/image_ops.h +++ b/tensorflow/contrib/image/kernels/image_ops.h @@ -21,6 +21,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/types.h" @@ -58,6 +59,11 @@ class ProjectiveGenerator { ? transforms_.data() : &transforms_.data()[transforms_.dimension(1) * coords[0]]; float projection = transform[6] * output_x + transform[7] * output_y + 1.f; + if (projection == 0) { + // Return the fill value (0) for infinite coordinates, + // which are outside the input image + return T(0); + } const float input_x = (transform[0] * output_x + transform[1] * output_y + transform[2]) / projection; @@ -105,21 +111,21 @@ class ProjectiveGenerator { // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor) // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor) const float value_yfloor = - (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor), - DenseIndex(x_floor), channel, - fill_value) + - (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor), - DenseIndex(x_ceil), channel, - fill_value); + (x_ceil - x) * static_cast(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_floor), + channel, fill_value)) + + (x - x_floor) * static_cast(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_ceil), + channel, fill_value)); // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil) // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil) const float value_yceil = - (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil), - DenseIndex(x_floor), channel, - fill_value) + - (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil), - DenseIndex(x_ceil), channel, - fill_value); + (x_ceil - x) * static_cast(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_floor), + channel, fill_value)) + + (x - x_floor) * static_cast(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_ceil), + channel, fill_value)); // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor) // + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil) return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil); diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc index ebdcaea7abae2a967786831b62b331897aa3f6a3..e59f1bf8443732a4b84fe7461439e3d0ee7dd158 100644 --- a/tensorflow/contrib/image/ops/image_ops.cc +++ b/tensorflow/contrib/image/ops/image_ops.cc @@ -29,7 +29,7 @@ using shape_inference::ShapeHandle; REGISTER_OP("ImageProjectiveTransform") .Input("images: dtype") .Input("transforms: float32") - .Attr("dtype: {uint8, int32, int64, float32, float64}") + .Attr("dtype: {uint8, int32, int64, float16, float32, float64}") .Attr("interpolation: string") .Output("transformed_images: dtype") .SetShapeFn([](InferenceContext* c) { diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index b50177ae5651fbc15f292e11031411c2074357ec..62a22dcf3411fb160b3c432bbdd67303697f7262 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -30,7 +30,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest _DTYPES = set( - [dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]) + [dtypes.uint8, dtypes.int32, dtypes.int64, + dtypes.float16, dtypes.float32, dtypes.float64]) class ImageOpsTest(test_util.TensorFlowTestCase): @@ -127,6 +128,23 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [0, 1, 0, 1], [0, 1, 1, 1]]) + def test_extreme_projective_transform(self): + for dtype in _DTYPES: + with self.test_session(): + image = constant_op.constant( + [[1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1]], dtype=dtype) + transformation = constant_op.constant([1, 0, 0, 0, 1, 0, -1, 0], + dtypes.float32) + image_transformed = image_ops.transform(image, transformation) + self.assertAllEqual(image_transformed.eval(), + [[1, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 0]]) + def test_bilinear(self): with self.test_session(): image = constant_op.constant( diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index cd984c80543886be1f682933e2e003bd3374e425..86b0ffe9a0f2236d5ac7d5f846e7b5d2615c9b09 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -33,7 +33,8 @@ _image_ops_so = loader.load_op_library( resource_loader.get_path_to_datafile("_image_ops.so")) _IMAGE_DTYPES = set( - [dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]) + [dtypes.uint8, dtypes.int32, dtypes.int64, + dtypes.float16, dtypes.float32, dtypes.float64]) ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py index b4a99867ed46897f60be3f230838c3f576d5455e..61f78febfc07bb4e677259366a81c16b2b585244 100644 --- a/tensorflow/contrib/integrate/python/ops/odes.py +++ b/tensorflow/contrib/integrate/python/ops/odes.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops @@ -279,13 +278,27 @@ def _assert_increasing(t): return ops.control_dependencies([assert_increasing]) -def _check_input_types(t, y0): +def _check_input_types(y0, t, dt=None): if not (y0.dtype.is_floating or y0.dtype.is_complex): raise TypeError('`y0` must have a floating point or complex floating ' 'point dtype') if not t.dtype.is_floating: raise TypeError('`t` must have a floating point dtype') + if dt is not None and not dt.dtype.is_floating: + raise TypeError('`dt` must have a floating point dtype') + + +def _check_input_sizes(t, dt): + if len(t.get_shape().as_list()) > 1: + raise ValueError('t must be a 1D tensor') + + if len(dt.get_shape().as_list()) > 1: + raise ValueError('t must be a 1D tensor') + + if t.get_shape()[0] != dt.get_shape()[0] + 1: + raise ValueError('t and dt have incompatible lengths, must be N and N-1') + def _dopri5(func, y0, @@ -510,7 +523,7 @@ def odeint(func, # avoiding the need to pack/unpack in user functions. y0 = ops.convert_to_tensor(y0, name='y0') t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') - _check_input_types(t, y0) + _check_input_types(y0, t) error_dtype = abs(y0).dtype rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol') @@ -530,24 +543,74 @@ def odeint(func, class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): """Base class for fixed-grid ODE integrators.""" - def integrate(self, evol_func, y0, time_grid): - time_delta_grid = time_grid[1:] - time_grid[:-1] - - scan_func = self._make_scan_func(evol_func) + def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals): + """Returns integrated values of differential equation on the `time grid`. + + Numerically integrates differential equation defined via time derivative + evaluator `evol_func` using fixed time steps specified in dt_grid. + + Args: + evol_func: Callable, evaluates time derivative of y at a given time. + y0: N-D Tensor holds initial values of the solution. + time_grid: 1-D Tensor holding the time points at which the solution + will be recorded, must have a floating dtype. + dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid + intervals. Must be a floating dtype and have one less element than that + of the time_grid. + steps_on_intervals: 1-D Tensor of integer dtype, must have the same size + as dt_grid. Specifies number of steps needed for every interval. Assumes + steps_on_intervals * dt_grid == time intervals. + + Returns: + (N+1)-D tensor, where the first dimension corresponds to different + time points. Contains the solved value of y for each desired time point in + `t`, with the initial value `y0` being the first element along the first + dimension. + """ - y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid), - y0) - return array_ops.concat([[y0], y_grid], axis=0) + iteration_func = self._make_iteration_func(evol_func, dt_grid) + integrate_interval = self._make_interval_integrator(iteration_func, + steps_on_intervals) - def _make_scan_func(self, evol_func): + num_times = array_ops.size(time_grid) + current_time = time_grid[0] + solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times) + solution_array = solution_array.write(0, y0) - def scan_func(y, t_and_dt): - t, dt = t_and_dt + solution_array, _, _, _ = control_flow_ops.while_loop( + lambda _, __, ___, i: i < num_times, + integrate_interval, + (solution_array, y0, current_time, 1) + ) + solution_array = solution_array.stack() + solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape())) + return solution_array + + def _make_iteration_func(self, evol_func, dt_grid): + """Returns a function that builds operations of a single time step.""" + + def iteration_func(y, t, dt_step, interval_step): + """Performs a single time step advance.""" + dt = dt_grid[interval_step - 1] dy = self._step_func(evol_func, t, dt, y) dy = math_ops.cast(dy, dtype=y.dtype) - return y + dy + return y + dy, t + dt, dt_step + 1, interval_step + + return iteration_func + + def _make_interval_integrator(self, iteration_func, interval_sizes): + """Returns a function that builds operations for interval integration.""" - return scan_func + def integrate_interval(solution_array, y, t, interval_num): + """Integrates y with fixed time step on interval `interval_num`.""" + y, t, _, _ = control_flow_ops.while_loop( + lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1], + iteration_func, + (y, t, 0, interval_num) + ) + return solution_array.write(interval_num, y), y, t, interval_num + 1 + + return integrate_interval @abc.abstractmethod def _step_func(self, evol_func, t, dt, y): @@ -555,6 +618,7 @@ class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): class _MidpointFixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing midpoint scheme.""" def _step_func(self, evol_func, t, dt, y): dt_cast = math_ops.cast(dt, y.dtype) @@ -563,6 +627,7 @@ class _MidpointFixedGridIntegrator(_FixedGridIntegrator): class _RK4FixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing RK4 scheme.""" def _step_func(self, evol_func, t, dt, y): k1 = evol_func(y, t) @@ -575,7 +640,7 @@ class _RK4FixedGridIntegrator(_FixedGridIntegrator): return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6) -def odeint_fixed(func, y0, t, method='rk4', name=None): +def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None): """ODE integration on a fixed grid (with no step size control). Useful in certain scenarios to avoid the overhead of adaptive step size @@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): `y`. The initial time point should be the first element of this sequence, and each time must be larger than the previous time. May have any floating point dtype. + dt: 0-D or 1-D Tensor providing time step suggestion to be used on time + integration intervals in `t`. 1-D Tensor should provide values + for all intervals, must have 1 less element than that of `t`. + If given a 0-D Tensor, the value is interpreted as time step suggestion + same for all intervals. If passed None, then time step is set to be the + t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by + insuring an integer number of steps per interval, potentially reducing the + time step. method: One of 'midpoint' or 'rk4'. name: Optional name for the resulting operation. @@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): Raises: ValueError: Upon caller errors. """ - with ops.name_scope(name, 'odeint_fixed', [y0, t]): + with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]): t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') y0 = ops.convert_to_tensor(y0, name='y0') - _check_input_types(t, y0) + + intervals = t[1:] - t[:-1] + if dt is None: + dt = intervals + dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt') + + steps_on_intervals = math_ops.ceil(intervals / dt) + dt = intervals / steps_on_intervals + steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32) + + _check_input_types(y0, t, dt) + _check_input_sizes(t, dt) with _assert_increasing(t): with ops.name_scope(method): if method == 'midpoint': - return _MidpointFixedGridIntegrator().integrate(func, y0, t) + return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) elif method == 'rk4': - return _RK4FixedGridIntegrator().integrate(func, y0, t) + return _RK4FixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) else: raise ValueError('method not supported: {!s}'.format(method)) diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py index 3ec01212d25ca8dc6e13f340177a5e85138868d5..c7b4e2faa84e1a87cb1904b22eb0008ab1ee4be6 100644 --- a/tensorflow/contrib/integrate/python/ops/odes_test.py +++ b/tensorflow/contrib/integrate/python/ops/odes_test.py @@ -242,40 +242,56 @@ class InterpolationTest(test.TestCase): class OdeIntFixedTest(test.TestCase): - def _test_integrate_sine(self, method): + def _test_integrate_sine(self, method, t, dt=None): def evol_func(y, t): del t return array_ops.stack([y[1], -y[0]]) y0 = [0., 1.] - time_grid = np.linspace(0., 10., 200) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2) - def _test_integrate_gaussian(self, method): + def _test_integrate_gaussian(self, method, t, dt=None): def evol_func(y, t): return -math_ops.cast(t, dtype=y.dtype) * y[0] y0 = [1.] - time_grid = np.linspace(0., 2., 100) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2) + + def _test_integrate_sine_all(self, method): + uniform_time_grid = np.linspace(0., 10., 200) + non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0]) + uniform_dt = 0.02 + non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03]) + self._test_integrate_sine(method, uniform_time_grid) + self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt) + + def _test_integrate_gaussian_all(self, method): + uniform_time_grid = np.linspace(0., 2., 100) + non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0]) + uniform_dt = 0.01 + non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03]) + self._test_integrate_gaussian(method, uniform_time_grid) + self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt) def _test_everything(self, method): - self._test_integrate_sine(method) - self._test_integrate_gaussian(method) + self._test_integrate_sine_all(method) + self._test_integrate_gaussian_all(method) def test_midpoint(self): self._test_everything('midpoint') @@ -283,6 +299,21 @@ class OdeIntFixedTest(test.TestCase): def test_rk4(self): self._test_everything('rk4') + def test_dt_size_exceptions(self): + times = np.linspace(0., 2., 100) + dt = np.ones(99) * 0.01 + dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03]) + dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0) + times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0) + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_length) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_dim) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times_wrong_dim, dt) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index a4cd4a2cc4b99b5906185bd2b942ed15c1ddf5e4..2638b25ec424b5b4ef556ff769e94e64da32fec2 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -64,7 +64,7 @@ class KafkaDatasetOp : public DatasetOpKernel { eof_(eof), timeout_(timeout) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Kafka")})); @@ -81,7 +81,7 @@ class KafkaDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { return "KafkaDatasetOp::Dataset"; } + string DebugString() const override { return "KafkaDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, diff --git a/tensorflow/contrib/keras/api/keras/activations/__init__.py b/tensorflow/contrib/keras/api/keras/activations/__init__.py index d04838c218d6643a703723a1d163c88547c14da7..3f0184276f6b903be63f7b35459e4ad57044eb2c 100644 --- a/tensorflow/contrib/keras/api/keras/activations/__init__.py +++ b/tensorflow/contrib/keras/api/keras/activations/__init__.py @@ -19,22 +19,22 @@ from __future__ import division from __future__ import print_function # Activation functions. -from tensorflow.python.keras._impl.keras.activations import elu -from tensorflow.python.keras._impl.keras.activations import hard_sigmoid -from tensorflow.python.keras._impl.keras.activations import linear -from tensorflow.python.keras._impl.keras.activations import relu -from tensorflow.python.keras._impl.keras.activations import selu -from tensorflow.python.keras._impl.keras.activations import sigmoid -from tensorflow.python.keras._impl.keras.activations import softmax -from tensorflow.python.keras._impl.keras.activations import softplus -from tensorflow.python.keras._impl.keras.activations import softsign -from tensorflow.python.keras._impl.keras.activations import tanh +from tensorflow.python.keras.activations import elu +from tensorflow.python.keras.activations import hard_sigmoid +from tensorflow.python.keras.activations import linear +from tensorflow.python.keras.activations import relu +from tensorflow.python.keras.activations import selu +from tensorflow.python.keras.activations import sigmoid +from tensorflow.python.keras.activations import softmax +from tensorflow.python.keras.activations import softplus +from tensorflow.python.keras.activations import softsign +from tensorflow.python.keras.activations import tanh # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.activations import deserialize -from tensorflow.python.keras._impl.keras.activations import serialize -from tensorflow.python.keras._impl.keras.activations import get +from tensorflow.python.keras.activations import deserialize +from tensorflow.python.keras.activations import serialize +from tensorflow.python.keras.activations import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py index abf8393ae45d71dc0cb746706abb72f77b82d199..6dfb5cab17c088bfab8ed806adeabd793ced4d12 100644 --- a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.inception_v3 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 -from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input +from tensorflow.python.keras.applications.inception_v3 import decode_predictions +from tensorflow.python.keras.applications.inception_v3 import InceptionV3 +from tensorflow.python.keras.applications.inception_v3 import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py index b809e91193b459a46906443796344c092e1d2a6b..67306cc51e1927cfbc2db424b1f4165dabfa22f9 100644 --- a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.mobilenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet -from tensorflow.python.keras._impl.keras.applications.mobilenet import preprocess_input +from tensorflow.python.keras.applications.mobilenet import decode_predictions +from tensorflow.python.keras.applications.mobilenet import MobileNet +from tensorflow.python.keras.applications.mobilenet import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py index 530805d150bfe32c5b81d7d7d3f92e203b83b602..a25ff48b593a9a9ea56fd427a932bb64c10f7b7b 100644 --- a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.resnet50 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.resnet50 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50 +from tensorflow.python.keras.applications.resnet50 import decode_predictions +from tensorflow.python.keras.applications.resnet50 import preprocess_input +from tensorflow.python.keras.applications.resnet50 import ResNet50 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py index 118361604bbc7e0a88ed34243c0d5ea98856a301..4964b1b7deb56fe0025e9a8d8cb45d18e0209fea 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg16 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg16 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16 +from tensorflow.python.keras.applications.vgg16 import decode_predictions +from tensorflow.python.keras.applications.vgg16 import preprocess_input +from tensorflow.python.keras.applications.vgg16 import VGG16 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py index cda52628f3c10d65fdbe70b2f86cc12c771870a9..afb3abebdd6735e6f17bc94c1fcd15a31b74f983 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg19 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg19 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19 +from tensorflow.python.keras.applications.vgg19 import decode_predictions +from tensorflow.python.keras.applications.vgg19 import preprocess_input +from tensorflow.python.keras.applications.vgg19 import VGG19 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py index ae9cd9cd18c5ccc5ec37c8cd1bf36f8aabd9929c..2e3335d02aff0fff805fc2dac614b14e0593d40d 100644 --- a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.xception import decode_predictions -from tensorflow.python.keras._impl.keras.applications.xception import preprocess_input -from tensorflow.python.keras._impl.keras.applications.xception import Xception +from tensorflow.python.keras.applications.xception import decode_predictions +from tensorflow.python.keras.applications.xception import preprocess_input +from tensorflow.python.keras.applications.xception import Xception del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/backend/__init__.py b/tensorflow/contrib/keras/api/keras/backend/__init__.py index 10ef5a75852deb6595bced2703d7c5f29b0efac3..a755364014206e92289eec0b9c8e510251862e0e 100644 --- a/tensorflow/contrib/keras/api/keras/backend/__init__.py +++ b/tensorflow/contrib/keras/api/keras/backend/__init__.py @@ -19,144 +19,144 @@ from __future__ import division from __future__ import print_function # pylint: disable=redefined-builtin -from tensorflow.python.keras._impl.keras.backend import abs -from tensorflow.python.keras._impl.keras.backend import all -from tensorflow.python.keras._impl.keras.backend import any -from tensorflow.python.keras._impl.keras.backend import arange -from tensorflow.python.keras._impl.keras.backend import argmax -from tensorflow.python.keras._impl.keras.backend import argmin -from tensorflow.python.keras._impl.keras.backend import backend -from tensorflow.python.keras._impl.keras.backend import batch_dot -from tensorflow.python.keras._impl.keras.backend import batch_flatten -from tensorflow.python.keras._impl.keras.backend import batch_get_value -from tensorflow.python.keras._impl.keras.backend import batch_normalization -from tensorflow.python.keras._impl.keras.backend import batch_set_value -from tensorflow.python.keras._impl.keras.backend import bias_add -from tensorflow.python.keras._impl.keras.backend import binary_crossentropy -from tensorflow.python.keras._impl.keras.backend import cast -from tensorflow.python.keras._impl.keras.backend import cast_to_floatx -from tensorflow.python.keras._impl.keras.backend import categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import clear_session -from tensorflow.python.keras._impl.keras.backend import clip -from tensorflow.python.keras._impl.keras.backend import concatenate -from tensorflow.python.keras._impl.keras.backend import constant -from tensorflow.python.keras._impl.keras.backend import conv1d -from tensorflow.python.keras._impl.keras.backend import conv2d -from tensorflow.python.keras._impl.keras.backend import conv2d_transpose -from tensorflow.python.keras._impl.keras.backend import conv3d -from tensorflow.python.keras._impl.keras.backend import cos -from tensorflow.python.keras._impl.keras.backend import count_params -from tensorflow.python.keras._impl.keras.backend import ctc_batch_cost -from tensorflow.python.keras._impl.keras.backend import ctc_decode -from tensorflow.python.keras._impl.keras.backend import ctc_label_dense_to_sparse -from tensorflow.python.keras._impl.keras.backend import dot -from tensorflow.python.keras._impl.keras.backend import dropout -from tensorflow.python.keras._impl.keras.backend import dtype -from tensorflow.python.keras._impl.keras.backend import elu -from tensorflow.python.keras._impl.keras.backend import epsilon -from tensorflow.python.keras._impl.keras.backend import equal -from tensorflow.python.keras._impl.keras.backend import eval -from tensorflow.python.keras._impl.keras.backend import exp -from tensorflow.python.keras._impl.keras.backend import expand_dims -from tensorflow.python.keras._impl.keras.backend import eye -from tensorflow.python.keras._impl.keras.backend import flatten -from tensorflow.python.keras._impl.keras.backend import floatx -from tensorflow.python.keras._impl.keras.backend import foldl -from tensorflow.python.keras._impl.keras.backend import foldr -from tensorflow.python.keras._impl.keras.backend import function -from tensorflow.python.keras._impl.keras.backend import gather -from tensorflow.python.keras._impl.keras.backend import get_session -from tensorflow.python.keras._impl.keras.backend import get_uid -from tensorflow.python.keras._impl.keras.backend import get_value -from tensorflow.python.keras._impl.keras.backend import gradients -from tensorflow.python.keras._impl.keras.backend import greater -from tensorflow.python.keras._impl.keras.backend import greater_equal -from tensorflow.python.keras._impl.keras.backend import hard_sigmoid -from tensorflow.python.keras._impl.keras.backend import image_data_format -from tensorflow.python.keras._impl.keras.backend import in_test_phase -from tensorflow.python.keras._impl.keras.backend import in_top_k -from tensorflow.python.keras._impl.keras.backend import in_train_phase -from tensorflow.python.keras._impl.keras.backend import int_shape -from tensorflow.python.keras._impl.keras.backend import is_sparse -from tensorflow.python.keras._impl.keras.backend import l2_normalize -from tensorflow.python.keras._impl.keras.backend import learning_phase -from tensorflow.python.keras._impl.keras.backend import less -from tensorflow.python.keras._impl.keras.backend import less_equal -from tensorflow.python.keras._impl.keras.backend import log -from tensorflow.python.keras._impl.keras.backend import manual_variable_initialization -from tensorflow.python.keras._impl.keras.backend import map_fn -from tensorflow.python.keras._impl.keras.backend import max -from tensorflow.python.keras._impl.keras.backend import maximum -from tensorflow.python.keras._impl.keras.backend import mean -from tensorflow.python.keras._impl.keras.backend import min -from tensorflow.python.keras._impl.keras.backend import minimum -from tensorflow.python.keras._impl.keras.backend import moving_average_update -from tensorflow.python.keras._impl.keras.backend import name_scope -from tensorflow.python.keras._impl.keras.backend import ndim -from tensorflow.python.keras._impl.keras.backend import normalize_batch_in_training -from tensorflow.python.keras._impl.keras.backend import not_equal -from tensorflow.python.keras._impl.keras.backend import one_hot -from tensorflow.python.keras._impl.keras.backend import ones -from tensorflow.python.keras._impl.keras.backend import ones_like -from tensorflow.python.keras._impl.keras.backend import permute_dimensions -from tensorflow.python.keras._impl.keras.backend import placeholder -from tensorflow.python.keras._impl.keras.backend import pool2d -from tensorflow.python.keras._impl.keras.backend import pool3d -from tensorflow.python.keras._impl.keras.backend import pow -from tensorflow.python.keras._impl.keras.backend import print_tensor -from tensorflow.python.keras._impl.keras.backend import prod -from tensorflow.python.keras._impl.keras.backend import random_binomial -from tensorflow.python.keras._impl.keras.backend import random_normal -from tensorflow.python.keras._impl.keras.backend import random_normal_variable -from tensorflow.python.keras._impl.keras.backend import random_uniform -from tensorflow.python.keras._impl.keras.backend import random_uniform_variable -from tensorflow.python.keras._impl.keras.backend import relu -from tensorflow.python.keras._impl.keras.backend import repeat -from tensorflow.python.keras._impl.keras.backend import repeat_elements -from tensorflow.python.keras._impl.keras.backend import reset_uids -from tensorflow.python.keras._impl.keras.backend import reshape -from tensorflow.python.keras._impl.keras.backend import resize_images -from tensorflow.python.keras._impl.keras.backend import resize_volumes -from tensorflow.python.keras._impl.keras.backend import reverse -from tensorflow.python.keras._impl.keras.backend import rnn -from tensorflow.python.keras._impl.keras.backend import round -from tensorflow.python.keras._impl.keras.backend import separable_conv2d -from tensorflow.python.keras._impl.keras.backend import set_epsilon -from tensorflow.python.keras._impl.keras.backend import set_floatx -from tensorflow.python.keras._impl.keras.backend import set_image_data_format -from tensorflow.python.keras._impl.keras.backend import set_learning_phase -from tensorflow.python.keras._impl.keras.backend import set_session -from tensorflow.python.keras._impl.keras.backend import set_value -from tensorflow.python.keras._impl.keras.backend import shape -from tensorflow.python.keras._impl.keras.backend import sigmoid -from tensorflow.python.keras._impl.keras.backend import sign -from tensorflow.python.keras._impl.keras.backend import sin -from tensorflow.python.keras._impl.keras.backend import softmax -from tensorflow.python.keras._impl.keras.backend import softplus -from tensorflow.python.keras._impl.keras.backend import softsign -from tensorflow.python.keras._impl.keras.backend import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import spatial_2d_padding -from tensorflow.python.keras._impl.keras.backend import spatial_3d_padding -from tensorflow.python.keras._impl.keras.backend import sqrt -from tensorflow.python.keras._impl.keras.backend import square -from tensorflow.python.keras._impl.keras.backend import squeeze -from tensorflow.python.keras._impl.keras.backend import stack -from tensorflow.python.keras._impl.keras.backend import std -from tensorflow.python.keras._impl.keras.backend import stop_gradient -from tensorflow.python.keras._impl.keras.backend import sum -from tensorflow.python.keras._impl.keras.backend import switch -from tensorflow.python.keras._impl.keras.backend import tanh -from tensorflow.python.keras._impl.keras.backend import temporal_padding -from tensorflow.python.keras._impl.keras.backend import to_dense -from tensorflow.python.keras._impl.keras.backend import transpose -from tensorflow.python.keras._impl.keras.backend import truncated_normal -from tensorflow.python.keras._impl.keras.backend import update -from tensorflow.python.keras._impl.keras.backend import update_add -from tensorflow.python.keras._impl.keras.backend import update_sub -from tensorflow.python.keras._impl.keras.backend import var -from tensorflow.python.keras._impl.keras.backend import variable -from tensorflow.python.keras._impl.keras.backend import zeros -from tensorflow.python.keras._impl.keras.backend import zeros_like +from tensorflow.python.keras.backend import abs +from tensorflow.python.keras.backend import all +from tensorflow.python.keras.backend import any +from tensorflow.python.keras.backend import arange +from tensorflow.python.keras.backend import argmax +from tensorflow.python.keras.backend import argmin +from tensorflow.python.keras.backend import backend +from tensorflow.python.keras.backend import batch_dot +from tensorflow.python.keras.backend import batch_flatten +from tensorflow.python.keras.backend import batch_get_value +from tensorflow.python.keras.backend import batch_normalization +from tensorflow.python.keras.backend import batch_set_value +from tensorflow.python.keras.backend import bias_add +from tensorflow.python.keras.backend import binary_crossentropy +from tensorflow.python.keras.backend import cast +from tensorflow.python.keras.backend import cast_to_floatx +from tensorflow.python.keras.backend import categorical_crossentropy +from tensorflow.python.keras.backend import clear_session +from tensorflow.python.keras.backend import clip +from tensorflow.python.keras.backend import concatenate +from tensorflow.python.keras.backend import constant +from tensorflow.python.keras.backend import conv1d +from tensorflow.python.keras.backend import conv2d +from tensorflow.python.keras.backend import conv2d_transpose +from tensorflow.python.keras.backend import conv3d +from tensorflow.python.keras.backend import cos +from tensorflow.python.keras.backend import count_params +from tensorflow.python.keras.backend import ctc_batch_cost +from tensorflow.python.keras.backend import ctc_decode +from tensorflow.python.keras.backend import ctc_label_dense_to_sparse +from tensorflow.python.keras.backend import dot +from tensorflow.python.keras.backend import dropout +from tensorflow.python.keras.backend import dtype +from tensorflow.python.keras.backend import elu +from tensorflow.python.keras.backend import epsilon +from tensorflow.python.keras.backend import equal +from tensorflow.python.keras.backend import eval +from tensorflow.python.keras.backend import exp +from tensorflow.python.keras.backend import expand_dims +from tensorflow.python.keras.backend import eye +from tensorflow.python.keras.backend import flatten +from tensorflow.python.keras.backend import floatx +from tensorflow.python.keras.backend import foldl +from tensorflow.python.keras.backend import foldr +from tensorflow.python.keras.backend import function +from tensorflow.python.keras.backend import gather +from tensorflow.python.keras.backend import get_session +from tensorflow.python.keras.backend import get_uid +from tensorflow.python.keras.backend import get_value +from tensorflow.python.keras.backend import gradients +from tensorflow.python.keras.backend import greater +from tensorflow.python.keras.backend import greater_equal +from tensorflow.python.keras.backend import hard_sigmoid +from tensorflow.python.keras.backend import image_data_format +from tensorflow.python.keras.backend import in_test_phase +from tensorflow.python.keras.backend import in_top_k +from tensorflow.python.keras.backend import in_train_phase +from tensorflow.python.keras.backend import int_shape +from tensorflow.python.keras.backend import is_sparse +from tensorflow.python.keras.backend import l2_normalize +from tensorflow.python.keras.backend import learning_phase +from tensorflow.python.keras.backend import less +from tensorflow.python.keras.backend import less_equal +from tensorflow.python.keras.backend import log +from tensorflow.python.keras.backend import manual_variable_initialization +from tensorflow.python.keras.backend import map_fn +from tensorflow.python.keras.backend import max +from tensorflow.python.keras.backend import maximum +from tensorflow.python.keras.backend import mean +from tensorflow.python.keras.backend import min +from tensorflow.python.keras.backend import minimum +from tensorflow.python.keras.backend import moving_average_update +from tensorflow.python.keras.backend import name_scope +from tensorflow.python.keras.backend import ndim +from tensorflow.python.keras.backend import normalize_batch_in_training +from tensorflow.python.keras.backend import not_equal +from tensorflow.python.keras.backend import one_hot +from tensorflow.python.keras.backend import ones +from tensorflow.python.keras.backend import ones_like +from tensorflow.python.keras.backend import permute_dimensions +from tensorflow.python.keras.backend import placeholder +from tensorflow.python.keras.backend import pool2d +from tensorflow.python.keras.backend import pool3d +from tensorflow.python.keras.backend import pow +from tensorflow.python.keras.backend import print_tensor +from tensorflow.python.keras.backend import prod +from tensorflow.python.keras.backend import random_binomial +from tensorflow.python.keras.backend import random_normal +from tensorflow.python.keras.backend import random_normal_variable +from tensorflow.python.keras.backend import random_uniform +from tensorflow.python.keras.backend import random_uniform_variable +from tensorflow.python.keras.backend import relu +from tensorflow.python.keras.backend import repeat +from tensorflow.python.keras.backend import repeat_elements +from tensorflow.python.keras.backend import reset_uids +from tensorflow.python.keras.backend import reshape +from tensorflow.python.keras.backend import resize_images +from tensorflow.python.keras.backend import resize_volumes +from tensorflow.python.keras.backend import reverse +from tensorflow.python.keras.backend import rnn +from tensorflow.python.keras.backend import round +from tensorflow.python.keras.backend import separable_conv2d +from tensorflow.python.keras.backend import set_epsilon +from tensorflow.python.keras.backend import set_floatx +from tensorflow.python.keras.backend import set_image_data_format +from tensorflow.python.keras.backend import set_learning_phase +from tensorflow.python.keras.backend import set_session +from tensorflow.python.keras.backend import set_value +from tensorflow.python.keras.backend import shape +from tensorflow.python.keras.backend import sigmoid +from tensorflow.python.keras.backend import sign +from tensorflow.python.keras.backend import sin +from tensorflow.python.keras.backend import softmax +from tensorflow.python.keras.backend import softplus +from tensorflow.python.keras.backend import softsign +from tensorflow.python.keras.backend import sparse_categorical_crossentropy +from tensorflow.python.keras.backend import spatial_2d_padding +from tensorflow.python.keras.backend import spatial_3d_padding +from tensorflow.python.keras.backend import sqrt +from tensorflow.python.keras.backend import square +from tensorflow.python.keras.backend import squeeze +from tensorflow.python.keras.backend import stack +from tensorflow.python.keras.backend import std +from tensorflow.python.keras.backend import stop_gradient +from tensorflow.python.keras.backend import sum +from tensorflow.python.keras.backend import switch +from tensorflow.python.keras.backend import tanh +from tensorflow.python.keras.backend import temporal_padding +from tensorflow.python.keras.backend import to_dense +from tensorflow.python.keras.backend import transpose +from tensorflow.python.keras.backend import truncated_normal +from tensorflow.python.keras.backend import update +from tensorflow.python.keras.backend import update_add +from tensorflow.python.keras.backend import update_sub +from tensorflow.python.keras.backend import var +from tensorflow.python.keras.backend import variable +from tensorflow.python.keras.backend import zeros +from tensorflow.python.keras.backend import zeros_like del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py index 2d884790ddb9ccf49649c6af4cfd40cddbc38cb3..10e05f2969bc404d4cf3a9b7a999510cd40e3c17 100644 --- a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py +++ b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py @@ -18,19 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.callbacks import BaseLogger -from tensorflow.python.keras._impl.keras.callbacks import Callback -from tensorflow.python.keras._impl.keras.callbacks import CSVLogger -from tensorflow.python.keras._impl.keras.callbacks import EarlyStopping -from tensorflow.python.keras._impl.keras.callbacks import History -from tensorflow.python.keras._impl.keras.callbacks import LambdaCallback -from tensorflow.python.keras._impl.keras.callbacks import LearningRateScheduler -from tensorflow.python.keras._impl.keras.callbacks import ModelCheckpoint -from tensorflow.python.keras._impl.keras.callbacks import ProgbarLogger -from tensorflow.python.keras._impl.keras.callbacks import ReduceLROnPlateau -from tensorflow.python.keras._impl.keras.callbacks import RemoteMonitor -from tensorflow.python.keras._impl.keras.callbacks import TensorBoard -from tensorflow.python.keras._impl.keras.callbacks import TerminateOnNaN +from tensorflow.python.keras.callbacks import BaseLogger +from tensorflow.python.keras.callbacks import Callback +from tensorflow.python.keras.callbacks import CSVLogger +from tensorflow.python.keras.callbacks import EarlyStopping +from tensorflow.python.keras.callbacks import History +from tensorflow.python.keras.callbacks import LambdaCallback +from tensorflow.python.keras.callbacks import LearningRateScheduler +from tensorflow.python.keras.callbacks import ModelCheckpoint +from tensorflow.python.keras.callbacks import ProgbarLogger +from tensorflow.python.keras.callbacks import ReduceLROnPlateau +from tensorflow.python.keras.callbacks import RemoteMonitor +from tensorflow.python.keras.callbacks import TensorBoard +from tensorflow.python.keras.callbacks import TerminateOnNaN del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/constraints/__init__.py b/tensorflow/contrib/keras/api/keras/constraints/__init__.py index 152606d8ebbcadf57d971d508e15283da65e4aa3..08debf974ec3a36174c353ecaf9e425a9afc3f36 100644 --- a/tensorflow/contrib/keras/api/keras/constraints/__init__.py +++ b/tensorflow/contrib/keras/api/keras/constraints/__init__.py @@ -19,21 +19,21 @@ from __future__ import division from __future__ import print_function # Constraints functions / callable classes. -from tensorflow.python.keras._impl.keras.constraints import Constraint -from tensorflow.python.keras._impl.keras.constraints import max_norm -from tensorflow.python.keras._impl.keras.constraints import MaxNorm -from tensorflow.python.keras._impl.keras.constraints import min_max_norm -from tensorflow.python.keras._impl.keras.constraints import MinMaxNorm -from tensorflow.python.keras._impl.keras.constraints import non_neg -from tensorflow.python.keras._impl.keras.constraints import NonNeg -from tensorflow.python.keras._impl.keras.constraints import unit_norm -from tensorflow.python.keras._impl.keras.constraints import UnitNorm +from tensorflow.python.keras.constraints import Constraint +from tensorflow.python.keras.constraints import max_norm +from tensorflow.python.keras.constraints import MaxNorm +from tensorflow.python.keras.constraints import min_max_norm +from tensorflow.python.keras.constraints import MinMaxNorm +from tensorflow.python.keras.constraints import non_neg +from tensorflow.python.keras.constraints import NonNeg +from tensorflow.python.keras.constraints import unit_norm +from tensorflow.python.keras.constraints import UnitNorm # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.constraints import deserialize -from tensorflow.python.keras._impl.keras.constraints import serialize -from tensorflow.python.keras._impl.keras.constraints import get +from tensorflow.python.keras.constraints import deserialize +from tensorflow.python.keras.constraints import serialize +from tensorflow.python.keras.constraints import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py index b5371a03fd5f5755ba8844415276113c565f52db..a5a6fdab445d2d5328f203b6a704f89e9bb4ce67 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.boston_housing import load_data +from tensorflow.python.keras.datasets.boston_housing import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py index 68d3eb789ea2c410095c0c75e0b79a9b07d209a3..e74e5f347df2eeb626cd781c54c9a7b76561d4e9 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data +from tensorflow.python.keras.datasets.cifar10 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py index ca93742673341660ba69712feb59c5dd32ea3252..8f5753a6360dfbddb5678c4f2c02adff86b5f0cb 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data +from tensorflow.python.keras.datasets.cifar100 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py index 1c6396d2d32b88eaa900a5af4e62c7484fceab63..bd6ec4b8dfb0344ad0b89956939607ef51bb0889 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.imdb import get_word_index -from tensorflow.python.keras._impl.keras.datasets.imdb import load_data +from tensorflow.python.keras.datasets.imdb import get_word_index +from tensorflow.python.keras.datasets.imdb import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py index 364255f3387b59a419c010db9b93cdfbcba36186..f61145655bd5d98965e15fecd387d538e9bc642b 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.mnist import load_data +from tensorflow.python.keras.datasets.mnist import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py index bb6791a344ad0c372ac60cd4a332f5632841dd46..ade31f4ea9c33204a4350e6bc3a5a2469e54fd61 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.reuters import get_word_index -from tensorflow.python.keras._impl.keras.datasets.reuters import load_data +from tensorflow.python.keras.datasets.reuters import get_word_index +from tensorflow.python.keras.datasets.reuters import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/initializers/__init__.py b/tensorflow/contrib/keras/api/keras/initializers/__init__.py index 6b1fcfd2d9585d19ae3fd9705e128b19b1ec40e7..c6bdc4f0dac3f446238dc4cbc72fe4be278a5ff6 100644 --- a/tensorflow/contrib/keras/api/keras/initializers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/initializers/__init__.py @@ -19,30 +19,30 @@ from __future__ import division from __future__ import print_function # Initializer functions / callable classes. -from tensorflow.python.keras._impl.keras.initializers import Constant -from tensorflow.python.keras._impl.keras.initializers import Identity -from tensorflow.python.keras._impl.keras.initializers import Initializer -from tensorflow.python.keras._impl.keras.initializers import Ones -from tensorflow.python.keras._impl.keras.initializers import Orthogonal -from tensorflow.python.keras._impl.keras.initializers import RandomNormal -from tensorflow.python.keras._impl.keras.initializers import RandomUniform -from tensorflow.python.keras._impl.keras.initializers import TruncatedNormal -from tensorflow.python.keras._impl.keras.initializers import VarianceScaling -from tensorflow.python.keras._impl.keras.initializers import Zeros +from tensorflow.python.keras.initializers import Constant +from tensorflow.python.keras.initializers import Identity +from tensorflow.python.keras.initializers import Initializer +from tensorflow.python.keras.initializers import Ones +from tensorflow.python.keras.initializers import Orthogonal +from tensorflow.python.keras.initializers import RandomNormal +from tensorflow.python.keras.initializers import RandomUniform +from tensorflow.python.keras.initializers import TruncatedNormal +from tensorflow.python.keras.initializers import VarianceScaling +from tensorflow.python.keras.initializers import Zeros # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.initializers import glorot_normal -from tensorflow.python.keras._impl.keras.initializers import glorot_uniform -from tensorflow.python.keras._impl.keras.initializers import he_normal -from tensorflow.python.keras._impl.keras.initializers import he_uniform -from tensorflow.python.keras._impl.keras.initializers import lecun_normal -from tensorflow.python.keras._impl.keras.initializers import lecun_uniform +from tensorflow.python.keras.initializers import glorot_normal +from tensorflow.python.keras.initializers import glorot_uniform +from tensorflow.python.keras.initializers import he_normal +from tensorflow.python.keras.initializers import he_uniform +from tensorflow.python.keras.initializers import lecun_normal +from tensorflow.python.keras.initializers import lecun_uniform # Auxiliary utils. -from tensorflow.python.keras._impl.keras.initializers import deserialize -from tensorflow.python.keras._impl.keras.initializers import serialize -from tensorflow.python.keras._impl.keras.initializers import get +from tensorflow.python.keras.initializers import deserialize +from tensorflow.python.keras.initializers import serialize +from tensorflow.python.keras.initializers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index acf0a5e1799b7c57dfd82861c9ccc1f132c34375..3327a9f9a613bfb56e6a25af0fe1c0ca18609035 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,128 +20,128 @@ from __future__ import print_function # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.engine import Input -from tensorflow.python.keras._impl.keras.engine import InputLayer -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer # Advanced activations. -from tensorflow.python.keras._impl.keras.layers.advanced_activations import LeakyReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU +from tensorflow.python.keras.layers.advanced_activations import LeakyReLU +from tensorflow.python.keras.layers.advanced_activations import PReLU +from tensorflow.python.keras.layers.advanced_activations import ELU +from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU # Convolution layers. -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D +from tensorflow.python.keras.layers.convolutional import Conv1D +from tensorflow.python.keras.layers.convolutional import Conv2D +from tensorflow.python.keras.layers.convolutional import Conv3D +from tensorflow.python.keras.layers.convolutional import Conv2DTranspose +from tensorflow.python.keras.layers.convolutional import Conv3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConv2D # Convolution layer aliases. -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D +from tensorflow.python.keras.layers.convolutional import Convolution1D +from tensorflow.python.keras.layers.convolutional import Convolution2D +from tensorflow.python.keras.layers.convolutional import Convolution3D +from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose +from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D # Image processing layers. -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling2D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling3D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding1D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding2D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping3D +from tensorflow.python.keras.layers.convolutional import UpSampling1D +from tensorflow.python.keras.layers.convolutional import UpSampling2D +from tensorflow.python.keras.layers.convolutional import UpSampling3D +from tensorflow.python.keras.layers.convolutional import ZeroPadding1D +from tensorflow.python.keras.layers.convolutional import ZeroPadding2D +from tensorflow.python.keras.layers.convolutional import ZeroPadding3D +from tensorflow.python.keras.layers.convolutional import Cropping1D +from tensorflow.python.keras.layers.convolutional import Cropping2D +from tensorflow.python.keras.layers.convolutional import Cropping3D # Convolutional-recurrent layers. -from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import ConvLSTM2D +from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D # Core layers. -from tensorflow.python.keras._impl.keras.layers.core import Masking -from tensorflow.python.keras._impl.keras.layers.core import Dropout -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout1D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout2D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout3D -from tensorflow.python.keras._impl.keras.layers.core import Activation -from tensorflow.python.keras._impl.keras.layers.core import Reshape -from tensorflow.python.keras._impl.keras.layers.core import Permute -from tensorflow.python.keras._impl.keras.layers.core import Flatten -from tensorflow.python.keras._impl.keras.layers.core import RepeatVector -from tensorflow.python.keras._impl.keras.layers.core import Lambda -from tensorflow.python.keras._impl.keras.layers.core import Dense -from tensorflow.python.keras._impl.keras.layers.core import ActivityRegularization +from tensorflow.python.keras.layers.core import Masking +from tensorflow.python.keras.layers.core import Dropout +from tensorflow.python.keras.layers.core import SpatialDropout1D +from tensorflow.python.keras.layers.core import SpatialDropout2D +from tensorflow.python.keras.layers.core import SpatialDropout3D +from tensorflow.python.keras.layers.core import Activation +from tensorflow.python.keras.layers.core import Reshape +from tensorflow.python.keras.layers.core import Permute +from tensorflow.python.keras.layers.core import Flatten +from tensorflow.python.keras.layers.core import RepeatVector +from tensorflow.python.keras.layers.core import Lambda +from tensorflow.python.keras.layers.core import Dense +from tensorflow.python.keras.layers.core import ActivityRegularization # Embedding layers. -from tensorflow.python.keras._impl.keras.layers.embeddings import Embedding +from tensorflow.python.keras.layers.embeddings import Embedding # Locally-connected layers. -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected1D -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected2D +from tensorflow.python.keras.layers.local import LocallyConnected1D +from tensorflow.python.keras.layers.local import LocallyConnected2D # Merge layers. -from tensorflow.python.keras._impl.keras.layers.merge import Add -from tensorflow.python.keras._impl.keras.layers.merge import Multiply -from tensorflow.python.keras._impl.keras.layers.merge import Average -from tensorflow.python.keras._impl.keras.layers.merge import Maximum -from tensorflow.python.keras._impl.keras.layers.merge import Concatenate -from tensorflow.python.keras._impl.keras.layers.merge import Dot -from tensorflow.python.keras._impl.keras.layers.merge import add -from tensorflow.python.keras._impl.keras.layers.merge import multiply -from tensorflow.python.keras._impl.keras.layers.merge import average -from tensorflow.python.keras._impl.keras.layers.merge import maximum -from tensorflow.python.keras._impl.keras.layers.merge import concatenate -from tensorflow.python.keras._impl.keras.layers.merge import dot +from tensorflow.python.keras.layers.merge import Add +from tensorflow.python.keras.layers.merge import Multiply +from tensorflow.python.keras.layers.merge import Average +from tensorflow.python.keras.layers.merge import Maximum +from tensorflow.python.keras.layers.merge import Concatenate +from tensorflow.python.keras.layers.merge import Dot +from tensorflow.python.keras.layers.merge import add +from tensorflow.python.keras.layers.merge import multiply +from tensorflow.python.keras.layers.merge import average +from tensorflow.python.keras.layers.merge import maximum +from tensorflow.python.keras.layers.merge import concatenate +from tensorflow.python.keras.layers.merge import dot # Noise layers. -from tensorflow.python.keras._impl.keras.layers.noise import AlphaDropout -from tensorflow.python.keras._impl.keras.layers.noise import GaussianNoise -from tensorflow.python.keras._impl.keras.layers.noise import GaussianDropout +from tensorflow.python.keras.layers.noise import AlphaDropout +from tensorflow.python.keras.layers.noise import GaussianNoise +from tensorflow.python.keras.layers.noise import GaussianDropout # Normalization layers. -from tensorflow.python.keras._impl.keras.layers.normalization import BatchNormalization +from tensorflow.python.keras.layers.normalization import BatchNormalization # Pooling layers. -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling3D +from tensorflow.python.keras.layers.pooling import MaxPooling1D +from tensorflow.python.keras.layers.pooling import MaxPooling2D +from tensorflow.python.keras.layers.pooling import MaxPooling3D +from tensorflow.python.keras.layers.pooling import AveragePooling1D +from tensorflow.python.keras.layers.pooling import AveragePooling2D +from tensorflow.python.keras.layers.pooling import AveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D # Pooling layer aliases. -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D +from tensorflow.python.keras.layers.pooling import MaxPool1D +from tensorflow.python.keras.layers.pooling import MaxPool2D +from tensorflow.python.keras.layers.pooling import MaxPool3D +from tensorflow.python.keras.layers.pooling import AvgPool1D +from tensorflow.python.keras.layers.pooling import AvgPool2D +from tensorflow.python.keras.layers.pooling import AvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D # Recurrent layers. -from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN -from tensorflow.python.keras._impl.keras.layers.recurrent import GRU -from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM +from tensorflow.python.keras.layers.recurrent import SimpleRNN +from tensorflow.python.keras.layers.recurrent import GRU +from tensorflow.python.keras.layers.recurrent import LSTM # Wrapper functions -from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper -from tensorflow.python.keras._impl.keras.layers.wrappers import Bidirectional -from tensorflow.python.keras._impl.keras.layers.wrappers import TimeDistributed +from tensorflow.python.keras.layers.wrappers import Wrapper +from tensorflow.python.keras.layers.wrappers import Bidirectional +from tensorflow.python.keras.layers.wrappers import TimeDistributed del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py index 66721b694f5fd5fae7ca521ff56d4c6c6bce79b5..c4476a7bbd5056fa898468a46031bf3d8b1e44cf 100644 --- a/tensorflow/contrib/keras/api/keras/losses/__init__.py +++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py @@ -19,26 +19,26 @@ from __future__ import division from __future__ import print_function # Loss functions. -from tensorflow.python.keras._impl.keras.losses import binary_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_hinge -from tensorflow.python.keras._impl.keras.losses import cosine_proximity -from tensorflow.python.keras._impl.keras.losses import hinge -from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.losses import logcosh -from tensorflow.python.keras._impl.keras.losses import mean_absolute_error -from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.losses import poisson -from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import squared_hinge +from tensorflow.python.keras.losses import binary_crossentropy +from tensorflow.python.keras.losses import categorical_crossentropy +from tensorflow.python.keras.losses import categorical_hinge +from tensorflow.python.keras.losses import cosine_proximity +from tensorflow.python.keras.losses import hinge +from tensorflow.python.keras.losses import kullback_leibler_divergence +from tensorflow.python.keras.losses import logcosh +from tensorflow.python.keras.losses import mean_absolute_error +from tensorflow.python.keras.losses import mean_absolute_percentage_error +from tensorflow.python.keras.losses import mean_squared_error +from tensorflow.python.keras.losses import mean_squared_logarithmic_error +from tensorflow.python.keras.losses import poisson +from tensorflow.python.keras.losses import sparse_categorical_crossentropy +from tensorflow.python.keras.losses import squared_hinge # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.losses import deserialize -from tensorflow.python.keras._impl.keras.losses import serialize -from tensorflow.python.keras._impl.keras.losses import get +from tensorflow.python.keras.losses import deserialize +from tensorflow.python.keras.losses import serialize +from tensorflow.python.keras.losses import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py index 59faf037bce0f087d244a2faaeb52713bdc3b772..7317fdb52c5b79e787a49d71be49f5261d6b1fff 100644 --- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py +++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py @@ -19,28 +19,28 @@ from __future__ import division from __future__ import print_function # Metrics functions. -from tensorflow.python.keras._impl.keras.metrics import binary_accuracy -from tensorflow.python.keras._impl.keras.metrics import binary_crossentropy -from tensorflow.python.keras._impl.keras.metrics import categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import cosine_proximity -from tensorflow.python.keras._impl.keras.metrics import hinge -from tensorflow.python.keras._impl.keras.metrics import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_error -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.metrics import poisson -from tensorflow.python.keras._impl.keras.metrics import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import sparse_top_k_categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import squared_hinge -from tensorflow.python.keras._impl.keras.metrics import top_k_categorical_accuracy +from tensorflow.python.keras.metrics import binary_accuracy +from tensorflow.python.keras.metrics import binary_crossentropy +from tensorflow.python.keras.metrics import categorical_accuracy +from tensorflow.python.keras.metrics import categorical_crossentropy +from tensorflow.python.keras.metrics import cosine_proximity +from tensorflow.python.keras.metrics import hinge +from tensorflow.python.keras.metrics import kullback_leibler_divergence +from tensorflow.python.keras.metrics import mean_absolute_error +from tensorflow.python.keras.metrics import mean_absolute_percentage_error +from tensorflow.python.keras.metrics import mean_squared_error +from tensorflow.python.keras.metrics import mean_squared_logarithmic_error +from tensorflow.python.keras.metrics import poisson +from tensorflow.python.keras.metrics import sparse_categorical_crossentropy +from tensorflow.python.keras.metrics import sparse_top_k_categorical_accuracy +from tensorflow.python.keras.metrics import squared_hinge +from tensorflow.python.keras.metrics import top_k_categorical_accuracy # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.metrics import deserialize -from tensorflow.python.keras._impl.keras.metrics import serialize -from tensorflow.python.keras._impl.keras.metrics import get +from tensorflow.python.keras.metrics import deserialize +from tensorflow.python.keras.metrics import serialize +from tensorflow.python.keras.metrics import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/models/__init__.py b/tensorflow/contrib/keras/api/keras/models/__init__.py index 2fb4ac0960d38f28a1c9c897a0f1aedf57e048ac..3a196984cd88cb60fbc2a9db306ce8fecf0febc0 100644 --- a/tensorflow/contrib/keras/api/keras/models/__init__.py +++ b/tensorflow/contrib/keras/api/keras/models/__init__.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.models import load_model -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.models import model_from_config -from tensorflow.python.keras._impl.keras.models import model_from_json -from tensorflow.python.keras._impl.keras.models import model_from_yaml -from tensorflow.python.keras._impl.keras.models import save_model -from tensorflow.python.keras._impl.keras.models import Sequential +from tensorflow.python.keras.models import load_model +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.models import model_from_config +from tensorflow.python.keras.models import model_from_json +from tensorflow.python.keras.models import model_from_yaml +from tensorflow.python.keras.models import save_model +from tensorflow.python.keras.models import Sequential del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py index 44f47bc47f4a0e31aaf2ac8f67cfdbef410d8c44..4849a06747958ab41b8b6309fa848aff3da3f633 100644 --- a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py @@ -19,20 +19,20 @@ from __future__ import division from __future__ import print_function # Optimizer classes. -from tensorflow.python.keras._impl.keras.optimizers import Adadelta -from tensorflow.python.keras._impl.keras.optimizers import Adagrad -from tensorflow.python.keras._impl.keras.optimizers import Adam -from tensorflow.python.keras._impl.keras.optimizers import Adamax -from tensorflow.python.keras._impl.keras.optimizers import Nadam -from tensorflow.python.keras._impl.keras.optimizers import Optimizer -from tensorflow.python.keras._impl.keras.optimizers import RMSprop -from tensorflow.python.keras._impl.keras.optimizers import SGD +from tensorflow.python.keras.optimizers import Adadelta +from tensorflow.python.keras.optimizers import Adagrad +from tensorflow.python.keras.optimizers import Adam +from tensorflow.python.keras.optimizers import Adamax +from tensorflow.python.keras.optimizers import Nadam +from tensorflow.python.keras.optimizers import Optimizer +from tensorflow.python.keras.optimizers import RMSprop +from tensorflow.python.keras.optimizers import SGD # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.optimizers import deserialize -from tensorflow.python.keras._impl.keras.optimizers import serialize -from tensorflow.python.keras._impl.keras.optimizers import get +from tensorflow.python.keras.optimizers import deserialize +from tensorflow.python.keras.optimizers import serialize +from tensorflow.python.keras.optimizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py index b96e7675527041d3952b049f5f431d3df36eea4c..1f9e82b41bf09b235e93fa512a50ea4c3047c01b 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py @@ -18,20 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.image import apply_transform -from tensorflow.python.keras._impl.keras.preprocessing.image import array_to_img -from tensorflow.python.keras._impl.keras.preprocessing.image import DirectoryIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import flip_axis -from tensorflow.python.keras._impl.keras.preprocessing.image import ImageDataGenerator -from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array -from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator -from tensorflow.python.keras._impl.keras.preprocessing.image import load_img -from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_zoom +from tensorflow.python.keras.preprocessing.image import apply_transform +from tensorflow.python.keras.preprocessing.image import array_to_img +from tensorflow.python.keras.preprocessing.image import DirectoryIterator +from tensorflow.python.keras.preprocessing.image import flip_axis +from tensorflow.python.keras.preprocessing.image import ImageDataGenerator +from tensorflow.python.keras.preprocessing.image import img_to_array +from tensorflow.python.keras.preprocessing.image import Iterator +from tensorflow.python.keras.preprocessing.image import load_img +from tensorflow.python.keras.preprocessing.image import NumpyArrayIterator +from tensorflow.python.keras.preprocessing.image import random_channel_shift +from tensorflow.python.keras.preprocessing.image import random_rotation +from tensorflow.python.keras.preprocessing.image import random_shear +from tensorflow.python.keras.preprocessing.image import random_shift +from tensorflow.python.keras.preprocessing.image import random_zoom del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py index 112f6af5e588bcb2e85fdbecea86f402742d44e7..9a93b6fb57ff5aaab25f2b606249a6022814b5e4 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table -from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences -from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams +from tensorflow.python.keras.preprocessing.sequence import make_sampling_table +from tensorflow.python.keras.preprocessing.sequence import pad_sequences +from tensorflow.python.keras.preprocessing.sequence import skipgrams del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py index 5bf1a2fb21dc27f7aa10cd08b1496e3991c61d2f..86386a9b6762d1c5cb3915ace64686cc25367e0f 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot -from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence -from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer +from tensorflow.python.keras.preprocessing.text import one_hot +from tensorflow.python.keras.preprocessing.text import text_to_word_sequence +from tensorflow.python.keras.preprocessing.text import Tokenizer del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py index 3e707ccab577b5e28febd83d91f84d7b1c0d5d82..d668e39c09ca28239e56763f111fb01939bedc69 100644 --- a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py @@ -19,19 +19,19 @@ from __future__ import division from __future__ import print_function # Regularizer functions / callable classes. -from tensorflow.python.keras._impl.keras.regularizers import L1L2 -from tensorflow.python.keras._impl.keras.regularizers import Regularizer +from tensorflow.python.keras.regularizers import L1L2 +from tensorflow.python.keras.regularizers import Regularizer # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.regularizers import l1 -from tensorflow.python.keras._impl.keras.regularizers import l2 -from tensorflow.python.keras._impl.keras.regularizers import l1_l2 +from tensorflow.python.keras.regularizers import l1 +from tensorflow.python.keras.regularizers import l2 +from tensorflow.python.keras.regularizers import l1_l2 # Auxiliary utils. -from tensorflow.python.keras._impl.keras.regularizers import deserialize -from tensorflow.python.keras._impl.keras.regularizers import serialize -from tensorflow.python.keras._impl.keras.regularizers import get +from tensorflow.python.keras.regularizers import deserialize +from tensorflow.python.keras.regularizers import serialize +from tensorflow.python.keras.regularizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index a7c2179fe7ad434356921a5fb8709aa5b1f33498..47cd01b924fb43e8a83836c58f8ced61e9e88268 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -18,21 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer -from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope -from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix -from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.python.keras._impl.keras.utils.np_utils import normalize -from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model +from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import Sequence +from tensorflow.python.keras.utils.data_utils import SequenceEnqueuer +from tensorflow.python.keras.utils.generic_utils import custom_object_scope +from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import get_custom_objects +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.io_utils import HDF5Matrix +from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.np_utils import normalize +from tensorflow.python.keras.utils.np_utils import to_categorical +from tensorflow.python.keras.utils.vis_utils import plot_model del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py index a46f859273ea0117e29a403057f9f81bc758dd52..c4b7aa765c26bafbfcfe45df02e58d1cf1064b4b 100644 --- a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py +++ b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasClassifier -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasRegressor +from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier +from tensorflow.python.keras.wrappers.scikit_learn import KerasRegressor del absolute_import del division diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 762a2f0b57e95e2fef3dd177070701afb410e93a..102626925db560e47cdc73eb1e25e08836cb4fba 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,5 +1,10 @@ # K-FAC: Kronecker-Factored Approximate Curvature +# WARNING: +# ==third_party/tensorflow/contrib/kfac is deprecated. This will be== +# ==removed on 15-07-2018. Please import third_party/tensorflow_kfac.== +# ==== + **K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an approximate second-order optimization method, in TensorFlow. When applied to feedforward and convolutional neural networks, K-FAC can converge `>3.5x` diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index b7f63d8d94a7a427eb57afefeda3939f0c530f8e..03b9da793307b966632789fd11162306e6cd19f9 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings + # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est @@ -107,6 +109,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'. """ + warnings.warn( + "third_party.tensorflow.contrib.kfac is deprecated." + "This will be removed on 15-07-2018. Check README for further details.", + DeprecationWarning) # Parameters to be passed to the Fisher estimator: self._variables = var_list or tf_variables.trainable_variables self._cov_ema_decay = cov_ema_decay diff --git a/tensorflow/contrib/kinesis/BUILD b/tensorflow/contrib/kinesis/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..25443d0ad47aa7d503f905eb34000488b62f22c6 --- /dev/null +++ b/tensorflow/contrib/kinesis/BUILD @@ -0,0 +1,113 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_py_test", +) + +py_library( + name = "kinesis", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + ], +) + +tf_custom_op_library( + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = [":dataset_kernels"], +) + +tf_gen_op_libs( + op_lib_names = ["dataset_ops"], +) + +cc_library( + name = "dataset_kernels", + srcs = [ + "kernels/kinesis_dataset_ops.cc", + ], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core/platform/s3:aws_crypto", + "//third_party/eigen3", + "@aws", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +py_library( + name = "dataset_ops", + srcs = [ + "python/ops/kinesis_dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":kinesis_op_loader", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_dataset_ops", + out = "python/ops/gen_dataset_ops.py", + deps = ["//tensorflow/contrib/kinesis:dataset_ops_op_lib"], +) + +tf_kernel_library( + name = "dataset_ops_kernels", + deps = [ + ":dataset_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "kinesis_op_loader", + srcs = ["python/ops/kinesis_op_loader.py"], + dso = ["//tensorflow/contrib/kinesis:_dataset_ops.so"], + kernels = [ + ":dataset_ops_kernels", + "//tensorflow/contrib/kinesis:dataset_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_dataset_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + +tf_py_test( + name = "kinesis_test", + srcs = ["python/kernel_tests/kinesis_test.py"], + additional_deps = [ + ":kinesis", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "no_windows", + "notap", + ], +) diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/__init__.py b/tensorflow/contrib/kinesis/__init__.py similarity index 69% rename from tensorflow/python/keras/_impl/keras/preprocessing/__init__.py rename to tensorflow/contrib/kinesis/__init__.py index 2ca48cdbf9c066194f4f4ed448fd621167db7ba9..3824b8ae7532ab97a5ebf01ab66ece6476c87d42 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/__init__.py +++ b/tensorflow/contrib/kinesis/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Data preprocessing module. +"""Kinesis Dataset. + +@@KinesisDataset """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing import image -from tensorflow.python.keras._impl.keras.preprocessing import sequence -from tensorflow.python.keras._impl.keras.preprocessing import text +from tensorflow.contrib.kinesis.python.ops.kinesis_dataset_ops import KinesisDataset + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "KinesisDataset", +] +remove_undocumented(__name__) diff --git a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..3212279c4c50efb92acc712b82cb3e1a22c76870 --- /dev/null +++ b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc @@ -0,0 +1,359 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/platform/s3/aws_crypto.h" + +namespace tensorflow { +namespace { + +Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() { + static Aws::Client::ClientConfiguration config; + const char* endpoint = getenv("KINESIS_ENDPOINT"); + if (endpoint) { + config.endpointOverride = Aws::String(endpoint); + } + const char* region = getenv("AWS_REGION"); + if (region) { + config.region = Aws::String(region); + } else { + // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG + // is set with a truthy value. + const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG"); + string load_config = + load_config_env ? str_util::Lowercase(load_config_env) : ""; + if (load_config == "true" || load_config == "1") { + Aws::String config_file; + // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config. + const char* config_file_env = getenv("AWS_CONFIG_FILE"); + if (config_file_env) { + config_file = config_file_env; + } else { + const char* home_env = getenv("HOME"); + if (home_env) { + config_file = home_env; + config_file += "/.aws/config"; + } + } + Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file); + // Load the configuration. If successful, get the region. + // If the load is not successful, then generate a warning. + if (loader.Load()) { + auto profiles = loader.GetProfiles(); + if (!profiles["default"].GetRegion().empty()) { + config.region = profiles["default"].GetRegion(); + } + } else { + LOG(WARNING) << "Failed to load the profile in " << config_file << "."; + } + } + } + const char* use_https = getenv("KINESIS_USE_HTTPS"); + if (use_https) { + if (use_https[0] == '0') { + config.scheme = Aws::Http::Scheme::HTTP; + } else { + config.scheme = Aws::Http::Scheme::HTTPS; + } + } + const char* verify_ssl = getenv("KINESIS_VERIFY_SSL"); + if (verify_ssl) { + if (verify_ssl[0] == '0') { + config.verifySSL = false; + } else { + config.verifySSL = true; + } + } + const char* connect_timeout = getenv("KINESIS_CONNECT_TIMEOUT_MSEC"); + if (connect_timeout) { + int64 timeout; + + if (strings::safe_strto64(connect_timeout, &timeout)) { + config.connectTimeoutMs = timeout; + } + } + const char* request_timeout = getenv("KINESIS_REQUEST_TIMEOUT_MSEC"); + if (request_timeout) { + int64 timeout; + + if (strings::safe_strto64(request_timeout, &timeout)) { + config.requestTimeoutMs = timeout; + } + } + + return &config; +} + +Aws::Client::ClientConfiguration& GetDefaultClientConfig() { + static Aws::Client::ClientConfiguration* config = + InitializeDefaultClientConfig(); + return *config; +} + +static mutex mu(LINKER_INITIALIZED); +static unsigned count(0); +void AwsInitAPI() { + mutex_lock lock(mu); + count++; + if (count == 1) { + Aws::SDKOptions options; + options.cryptoOptions.sha256Factory_create_fn = []() { + return Aws::MakeShared(AWSCryptoAllocationTag); + }; + options.cryptoOptions.sha256HMACFactory_create_fn = []() { + return Aws::MakeShared(AWSCryptoAllocationTag); + }; + Aws::InitAPI(options); + } +} +void AwsShutdownAPI() { + mutex_lock lock(mu); + count--; + if (count == 0) { + Aws::SDKOptions options; + Aws::ShutdownAPI(options); + } +} +void ShutdownClient(Aws::Kinesis::KinesisClient* client) { + if (client != nullptr) { + delete client; + AwsShutdownAPI(); + } +} +} +class KinesisDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + std::string stream = ""; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "stream", &stream)); + std::string shard = ""; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "shard", &shard)); + bool read_indefinitely = true; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "read_indefinitely", + &read_indefinitely)); + int64 interval = -1; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "interval", &interval)); + OP_REQUIRES(ctx, (interval > 0), + errors::InvalidArgument( + "Interval value should be large than 0, got ", interval)); + *output = new Dataset(ctx, stream, shard, read_indefinitely, interval); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const string& stream, const string& shard, + const bool read_indefinitely, const int64 interval) + : GraphDatasetBase(ctx), + stream_(stream), + shard_(shard), + read_indefinitely_(read_indefinitely), + interval_(interval) {} + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Kinesis")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { return "KinesisDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + Node* stream = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(stream_, &stream)); + Node* shard = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(shard_, &shard)); + Node* read_indefinitely = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(read_indefinitely_, &read_indefinitely)); + Node* interval = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(interval_, &interval)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {stream, shard, read_indefinitely, interval}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + client_(nullptr, ShutdownClient) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (iterator_ == "") { + TF_RETURN_IF_ERROR(SetupStreamsLocked()); + } + do { + Aws::Kinesis::Model::GetRecordsRequest request; + auto outcome = client_->GetRecords( + request.WithShardIterator(iterator_).WithLimit(1)); + if (!outcome.IsSuccess()) { + return errors::Unknown(outcome.GetError().GetExceptionName(), ": ", + outcome.GetError().GetMessage()); + } + if (outcome.GetResult().GetRecords().size() == 0) { + // If no records were returned then nothing is available at the + // moment. + if (!dataset()->read_indefinitely_) { + *end_of_sequence = true; + return Status::OK(); + } + // Continue the loop after a period of time. + ctx->env()->SleepForMicroseconds(dataset()->interval_); + continue; + } + if (outcome.GetResult().GetRecords().size() != 1) { + return errors::Unknown("invalid number of records ", + outcome.GetResult().GetRecords().size(), + " returned"); + } + + iterator_ = outcome.GetResult().GetNextShardIterator(); + + const auto& data = outcome.GetResult().GetRecords()[0].GetData(); + StringPiece value( + reinterpret_cast(data.GetUnderlyingData()), + data.GetLength()); + Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); + value_tensor.scalar()() = std::string(value); + out_tensors->emplace_back(std::move(value_tensor)); + + *end_of_sequence = false; + return Status::OK(); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is currently not supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "RestoreInternal is currently not supported"); + } + + private: + // Sets up Kinesis streams to read from. + Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + AwsInitAPI(); + client_.reset( + new Aws::Kinesis::KinesisClient(GetDefaultClientConfig())); + + Aws::Kinesis::Model::DescribeStreamRequest request; + auto outcome = client_->DescribeStream( + request.WithStreamName(dataset()->stream_.c_str())); + if (!outcome.IsSuccess()) { + return errors::Unknown(outcome.GetError().GetExceptionName(), ": ", + outcome.GetError().GetMessage()); + } + Aws::String shard; + Aws::String sequence; + if (dataset()->shard_ == "") { + if (outcome.GetResult().GetStreamDescription().GetShards().size() != + 1) { + return errors::InvalidArgument( + "shard has to be provided unless the stream only have one " + "shard, there are ", + outcome.GetResult().GetStreamDescription().GetShards().size(), + " shards in stream ", dataset()->stream_); + } + shard = outcome.GetResult() + .GetStreamDescription() + .GetShards()[0] + .GetShardId(); + sequence = outcome.GetResult() + .GetStreamDescription() + .GetShards()[0] + .GetSequenceNumberRange() + .GetStartingSequenceNumber(); + } else { + for (const auto& entry : + outcome.GetResult().GetStreamDescription().GetShards()) { + if (entry.GetShardId() == dataset()->shard_.c_str()) { + shard = entry.GetShardId(); + sequence = + entry.GetSequenceNumberRange().GetStartingSequenceNumber(); + break; + } + } + if (shard == "") { + return errors::InvalidArgument("no shard ", dataset()->shard_, + " in stream ", dataset()->stream_); + } + } + + Aws::Kinesis::Model::GetShardIteratorRequest iterator_request; + auto iterator_outcome = client_->GetShardIterator( + iterator_request.WithStreamName(dataset()->stream_.c_str()) + .WithShardId(shard) + .WithShardIteratorType( + Aws::Kinesis::Model::ShardIteratorType::AT_SEQUENCE_NUMBER) + .WithStartingSequenceNumber(sequence)); + if (!iterator_outcome.IsSuccess()) { + return errors::Unknown(iterator_outcome.GetError().GetExceptionName(), + ": ", + iterator_outcome.GetError().GetMessage()); + } + iterator_ = iterator_outcome.GetResult().GetShardIterator(); + return Status::OK(); + } + + mutex mu_; + Aws::String iterator_ GUARDED_BY(mu_); + std::unique_ptr + client_ GUARDED_BY(mu_); + }; + + const std::string stream_; + const std::string shard_; + const bool read_indefinitely_; + const int64 interval_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("KinesisDataset").Device(DEVICE_CPU), + KinesisDatasetOp); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kinesis/ops/dataset_ops.cc b/tensorflow/contrib/kinesis/ops/dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..54204513cf22519ecfb5fa45748250ee0f4aac7a --- /dev/null +++ b/tensorflow/contrib/kinesis/ops/dataset_ops.cc @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("KinesisDataset") + .Input("stream: string") + .Input("shard: string") + .Input("read_indefinitely: bool") + .Input("interval: int64") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that emits the messages of one or more Kinesis topics. + +stream: A `tf.string` tensor containing the name of the stream. +shard: A `tf.string` tensor containing the id of the shard. +read_indefinitely: If `True`, the Kinesis dataset will keep retry + again on `EOF` after the `interval` period. If `False`, then + the dataset will stop on `EOF`. The default value is `True`. +interval: The interval for the Kinesis Client to wait before + it tries to get records again (in millisecond). +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7289b45c50fa92455b4c317b8a039ca414fa585e --- /dev/null +++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py @@ -0,0 +1,139 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for KinesisDataset. +NOTE: boto3 is needed and the test has to be invoked manually: +``` +$ bazel test -s --verbose_failures --config=opt \ + --action_env=AWS_ACCESS_KEY_ID=XXXXXX \ + --action_env=AWS_SECRET_ACCESS_KEY=XXXXXX \ + //tensorflow/contrib/kinesis:kinesis_test +``` +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import boto3 + +from tensorflow.contrib.kinesis.python.ops import kinesis_dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class KinesisDatasetTest(test.TestCase): + + def testKinesisDatasetOneShard(self): + client = boto3.client('kinesis', region_name='us-east-1') + + # Setup the Kinesis with 1 shard. + stream_name = "tf_kinesis_test_1" + client.create_stream(StreamName=stream_name, ShardCount=1) + # Wait until stream exists, default is 10 * 18 seconds. + client.get_waiter('stream_exists').wait(StreamName=stream_name) + for i in range(10): + data = "D" + str(i) + client.put_record( + StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i)) + + stream = array_ops.placeholder(dtypes.string, shape=[]) + num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kinesis_dataset_ops.KinesisDataset( + stream, read_indefinitely=False).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.test_session() as sess: + # Basic test: read from shard 0 of stream 1. + sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1}) + for i in range(10): + self.assertEqual("D" + str(i), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + client.delete_stream(StreamName=stream_name) + # Wait until stream deleted, default is 10 * 18 seconds. + client.get_waiter('stream_not_exists').wait(StreamName=stream_name) + + def testKinesisDatasetTwoShards(self): + client = boto3.client('kinesis', region_name='us-east-1') + + # Setup the Kinesis with 2 shards. + stream_name = "tf_kinesis_test_2" + client.create_stream(StreamName=stream_name, ShardCount=2) + # Wait until stream exists, default is 10 * 18 seconds. + client.get_waiter('stream_exists').wait(StreamName=stream_name) + + for i in range(10): + data = "D" + str(i) + client.put_record( + StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i)) + response = client.describe_stream(StreamName=stream_name) + shard_id_0 = response["StreamDescription"]["Shards"][0]["ShardId"] + shard_id_1 = response["StreamDescription"]["Shards"][1]["ShardId"] + + stream = array_ops.placeholder(dtypes.string, shape=[]) + shard = array_ops.placeholder(dtypes.string, shape=[]) + num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kinesis_dataset_ops.KinesisDataset( + stream, shard, read_indefinitely=False).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + data = list() + with self.test_session() as sess: + # Basic test: read from shard 0 of stream 2. + sess.run( + init_op, feed_dict={ + stream: stream_name, shard: shard_id_0, num_epochs: 1}) + with self.assertRaises(errors.OutOfRangeError): + # Use range(11) to guarantee the OutOfRangeError. + for i in range(11): + data.append(sess.run(get_next)) + + # Basic test: read from shard 1 of stream 2. + sess.run( + init_op, feed_dict={ + stream: stream_name, shard: shard_id_1, num_epochs: 1}) + with self.assertRaises(errors.OutOfRangeError): + # Use range(11) to guarantee the OutOfRangeError. + for i in range(11): + data.append(sess.run(get_next)) + + data.sort() + self.assertEqual(data, ["D" + str(i) for i in range(10)]) + + client.delete_stream(StreamName=stream_name) + # Wait until stream deleted, default is 10 * 18 seconds. + client.get_waiter('stream_not_exists').wait(StreamName=stream_name) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ca2df95ba4f20ec5fa58ff13530096e6e065f4fe --- /dev/null +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -0,0 +1,96 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kinesis Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import +from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops +from tensorflow.python.data.ops.dataset_ops import Dataset +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class KinesisDataset(Dataset): + """A Kinesis Dataset that consumes the message. + + Kinesis is a managed service provided by AWS for data streaming. + This dataset reads messages from Kinesis with each message presented + as a `tf.string`. + + For example, we can construct and use the KinesisDataset as follows: + ```python + dataset = tf.contrib.kinesis.KinesisDataset( + "kinesis_stream_name", read_indefinitely=False) + next = dataset.make_one_shot_iterator().get_next() + with tf.Session() as sess: + while True: + try: + print(sess.run(nxt)) + except tf.errors.OutOfRangeError: + break + ``` + + Since Kinesis is a data streaming service, data may not be available + at the time it is being read. The argument `read_indefinitely` is + used to control the behavior in this situation. If `read_indefinitely` + is `True`, then `KinesisDataset` will keep retrying to retrieve data + from the stream. If `read_indefinitely` is `False`, an `OutOfRangeError` + is returned immediately instead. + """ + + def __init__(self, + stream, + shard="", + read_indefinitely=True, + interval=100000): + """Create a KinesisDataset. + + Args: + stream: A `tf.string` tensor containing the name of the stream. + shard: A `tf.string` tensor containing the id of the shard. + read_indefinitely: If `True`, the Kinesis dataset will keep retry + again on `EOF` after the `interval` period. If `False`, then + the dataset will stop on `EOF`. The default value is `True`. + interval: The interval for the Kinesis Client to wait before + it tries to get records again (in millisecond). + """ + super(KinesisDataset, self).__init__() + self._stream = ops.convert_to_tensor( + stream, dtype=dtypes.string, name="stream") + self._shard = ops.convert_to_tensor( + shard, dtype=dtypes.string, name="shard") + self._read_indefinitely = ops.convert_to_tensor( + read_indefinitely, dtype=dtypes.bool, name="read_indefinitely") + self._interval = ops.convert_to_tensor( + interval, dtype=dtypes.int64, name="interval") + + def _as_variant_tensor(self): + return gen_dataset_ops.kinesis_dataset( + self._stream, self._shard, self._read_indefinitely, self._interval) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.string diff --git a/tensorflow/python/keras/datasets/boston_housing/__init__.py b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py similarity index 69% rename from tensorflow/python/keras/datasets/boston_housing/__init__.py rename to tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py index b5371a03fd5f5755ba8844415276113c565f52db..c9ce9f3646200a777cdbdf34b37626154ca730bb 100644 --- a/tensorflow/python/keras/datasets/boston_housing/__init__.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Boston housing price regression dataset.""" - +"""Python helper for loading kinesis ops and kernels.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.boston_housing import load_data +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader -del absolute_import -del division -del print_function +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index 3ba1026383ef146adb32197ae41b5c251155bf46..2ede5daee74223e812cc29e9708b1989b698fb4e 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -652,7 +652,8 @@ def map_fn(fn, labeled_tensor, name=None): tensor_lt = core.LabeledTensor(tensor, original_axes) return fn(tensor_lt).tensor - map_op = functional_ops.map_fn(tf_fn, labeled_tensor.tensor) + map_op = functional_ops.map_fn( + tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype) map_lt = core.LabeledTensor(map_op, final_axes) return core.identity(map_lt, name=scope) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 00f03a111ae8be7f49761ef5fb5a82810bcca182..bc3359693562deb1229a78a2db5c256c76f7fd8d 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -19,6 +19,8 @@ See the @{$python/contrib.layers} guide. @@avg_pool2d @@avg_pool3d @@batch_norm +@@convolution +@@convolution1d @@convolution2d @@convolution3d @@conv2d_in_plane diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index 49c3faf3b7f5eaa3b1542a1fdddcfaff99737a24..60e1d85ea9c08a51763fdaf08853f8d9b67347e5 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -458,7 +458,7 @@ def scattered_embedding_lookup_sparse(params, return embeddings -def embedding_lookup_unique(params, ids, name=None): +def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None): """Version of embedding_lookup that avoids duplicate lookups. This can save communication in the case of repeated ids. @@ -470,6 +470,9 @@ def embedding_lookup_unique(params, ids, name=None): `PartitionedVariable`. Shape `[index, d1, d2, ...]`. ids: A one-dimensional `Tensor` with type `int32` or `int64` containing the ids to be looked up in `params`. Shape `[ids1, ids2, ...]`. + partition_strategy: A string specifying the partitioning strategy, relevant + if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default + is `"mod"`. name: A name for this operation (optional). Returns: @@ -485,7 +488,8 @@ def embedding_lookup_unique(params, ids, name=None): ids_flat = array_ops.reshape( ids, math_ops.reduce_prod(shape, keepdims=True)) unique_ids, idx = array_ops.unique(ids_flat) - unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids) + unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids, + partition_strategy) embeds_flat = array_ops.gather(unique_embeddings, idx) embed_shape = array_ops.concat( [shape, array_ops.shape(unique_embeddings)[1:]], 0) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index dd2395f8c9748dadbecfe47df5511874d5f848ea..7ede193029d2d95fa4953b4c417a1e86ebb4a42e 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import itertools import math -import sys import numpy as np diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index 06060b99e7e58787994f20f037ffa451abbc7459..a85cff4f7098e9a5eedca1b0c8c0cb42e172d90a 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -683,11 +683,12 @@ def parse_feature_columns_from_sequence_examples( the serialized proto. Returns: - A tuple consisting of: - context_features: a dict mapping `FeatureColumns` from - `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s. - sequence_features: a dict mapping `FeatureColumns` from - `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s. + A tuple consisting of (context_features, sequence_features) + + * context_features: a dict mapping `FeatureColumns` from + `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s. + * sequence_features: a dict mapping `FeatureColumns` from + `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s. """ # Sequence example parsing requires a single (scalar) example. try: diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 2f3e57653c5d6d949c4dcc91635690322b7f90c4..beeabd6b65631cad88efd10d5faee1917e162e41 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -57,10 +57,10 @@ from tensorflow.python.training import moving_averages __all__ = [ 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', - 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', - 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', - 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', - 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', + 'convolution1d', 'convolution2d', 'convolution2d_in_plane', + 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose', + 'dense_to_sparse', 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', + 'gdn', 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'sequence_to_images', 'softmax', 'spatial_softmax', 'stack', 'unit_norm', @@ -2022,6 +2022,7 @@ class GDN(base.Layer): def beta_initializer(shape, dtype=None, partition_info=None): del partition_info # unused + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) return math_ops.sqrt(array_ops.ones(shape, dtype=dtype) + pedestal) def gamma_initializer(shape, dtype=None, partition_info=None): @@ -2029,6 +2030,7 @@ class GDN(base.Layer): assert len(shape) == 2 assert shape[0] == shape[1] eye = linalg_ops.eye(shape[0], dtype=dtype) + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) return math_ops.sqrt(self._gamma_init * eye + pedestal) beta = self.add_variable( @@ -2662,6 +2664,7 @@ def separable_convolution2d( normalizer_fn=None, normalizer_params=None, weights_initializer=initializers.xavier_initializer(), + pointwise_initializer=None, weights_regularizer=None, biases_initializer=init_ops.zeros_initializer(), biases_regularizer=None, @@ -2703,7 +2706,9 @@ def separable_convolution2d( `biases_regularizer` are ignored and `biases` are not created nor added. default set to None for no normalizer function normalizer_params: Normalization function parameters. - weights_initializer: An initializer for the weights. + weights_initializer: An initializer for the depthwise weights. + pointwise_initializer: An initializer for the pointwise weights. + default set to None, means use weights_initializer. weights_regularizer: Optional regularizer for the weights. biases_initializer: An initializer for the biases. If None skip biases. biases_regularizer: Optional regularizer for the biases. @@ -2735,6 +2740,9 @@ def separable_convolution2d( custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) + if pointwise_initializer is None: + pointwise_initializer = weights_initializer + df = ('channels_first' if data_format and data_format.startswith('NC') else 'channels_last') if num_outputs is not None: @@ -2750,7 +2758,7 @@ def separable_convolution2d( depth_multiplier=depth_multiplier, use_bias=not normalizer_fn and biases_initializer, depthwise_initializer=weights_initializer, - pointwise_initializer=weights_initializer, + pointwise_initializer=pointwise_initializer, bias_initializer=biases_initializer, depthwise_regularizer=weights_regularizer, pointwise_regularizer=weights_regularizer, diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 56e9194cebbe46907707f7ac0996f9a56fb53c0f..c5c7269b1f15849956e90654e3bcf8ab0eebc393 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1312,6 +1312,29 @@ class ConvolutionInPlaneTest(test.TestCase): self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) + def testConv1dShape(self): + width = 7 + with self.test_session(): + images = random_ops.random_uniform((5, width, 3), seed=1) + output = layers_lib.convolution1d(images, 32, 3) + self.assertEqual(output.op.name, 'Conv/Relu') + self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) + + def testConvInferSpatialDims(self): + depth, height, width = 7, 9, 11 + with self.test_session(): + images = np.random.uniform(size=(5, width, 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3]) + self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) + images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3, 3]) + self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32]) + images = np.random.uniform(size=(5, depth, height, width, + 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3, 3, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, depth, height, width, 32]) + class DenseToSparseTest(test.TestCase): diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 8ed9f446bcd5f222f486e43125dafc595852e5ce..0e35b1aa8bf682c1b4f7e8d974d3e8fad69e33cb 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -46,6 +46,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect __all__ = ["rev_block", "RevBlock", "recompute_grad"] @@ -449,6 +450,15 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False): `variable_scope(name, use_resource=True), which are the default in Eager mode and when running on TPU. + Warning: Because the function will be called again on the backwards pass, the + user should be careful to not use ops in their function that mutate state or + have randomness (for example, batch normalization or dropout). If the function + does have such operations, it is recommended that the function take the + `is_recomputing` keyword argument which will be `False` on the forward pass + and `True` on the backwards pass so that it can disable state changes when + `is_recomputing=True` (for example, not updating the moving averages in batch + normalization). + Args: fn: a function that takes Tensors (all as positional arguments) and returns a tuple of Tensors. @@ -482,6 +492,7 @@ def _is_on_tpu(): def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """See recompute_grad.""" + has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args for arg in args: if not isinstance(arg, framework_ops.Tensor): raise ValueError("All inputs to function must be Tensors") @@ -496,7 +507,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): vs = variable_scope.get_variable_scope() arg_scope = contrib_framework_ops.current_arg_scope() with backprop.GradientTape() as tape: - outputs = fn(*args) + fn_kwargs = {} + if has_is_recompute_kwarg: + fn_kwargs["is_recomputing"] = False + outputs = fn(*args, **fn_kwargs) original_vars = set(tape.watched_variables()) # Backward pass @@ -516,7 +530,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): with contrib_framework_ops.arg_scope(arg_scope): with variable_scope.variable_scope(vs, reuse=True): with backprop.GradientTape() as tape: - outputs = fn(*inputs) + fn_kwargs = {} + if has_is_recompute_kwarg: + fn_kwargs["is_recomputing"] = True + outputs = fn(*inputs, **fn_kwargs) recompute_vars = set(tape.watched_variables()) if original_vars != recompute_vars: raise ValueError(_WRONG_VARS_ERR) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index 997f53b9e1bbf9ac151cadd4a9f8e79c2e0ebca2..bc09ba8d439808c1582f207a99504012afcf33a6 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -21,9 +21,11 @@ from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.layers import rev_block_lib from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.layers import convolutional from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import normalization as normalization_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops @@ -342,6 +344,34 @@ class RecomputeTest(test.TestCase): for grad in grads: self.assertTrue(grad is not None) + def testWithIsRecomputeKwarg(self): + + kwarg_values = [] + + @rev_block_lib.recompute_grad + def layer_with_recompute(inputs, is_recomputing=False): + kwarg_values.append(is_recomputing) + out = core_layers.dense(inputs, 2) + out = normalization_layers.batch_normalization(out, training=True) + if is_recomputing: + # Ensure that the updates are not duplicated by popping off the latest + # 2 additions. + update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS) + update_ops.pop() + update_ops.pop() + return out + + x = array_ops.ones((2, 4), dtypes.float32) + with variable_scope.variable_scope("layer1", use_resource=True): + y = layer_with_recompute(x) + loss = math_ops.reduce_sum(y) + tvars = variables.trainable_variables() + gradients_impl.gradients(loss, [x] + tvars) + + update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS) + self.assertEqual(2, len(update_ops)) + self.assertEqual([False, True], kwarg_values) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 0fdbe8f6308e30db2043c400f37d7dcb6058d1f2..b56a88659bbd4467600788fc8e3e9dbf38ce8244 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -284,6 +284,7 @@ py_test( tags = [ "manual", "noasan", # times out + "optonly", # test is flaky without optimization. ], deps = [ ":learn", diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 70b70af98c51dcb991c19152607272673953ee2a..e100bc7a1e7be4896e9ab1c965775b5185b38897 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -31,7 +31,6 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -51,6 +50,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import session_run_hook from tensorflow.python.training import training as train +from tensorflow.python.training import training_util # The default learning rate of 0.2 is a historical artifact of the initial @@ -244,7 +244,9 @@ def sdca_model_fn(features, labels, mode, params): parent_scope = "linear" with variable_scope.variable_scope( - values=features.values(), name_or_scope=parent_scope) as scope: + values=features.values(), + name_or_scope=parent_scope, + partitioner=optimizer.partitioner) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index 0a863f0e20c05d3372ffd8f7677cd518390ecc9d..597ca4e86dbf66c86182f14a2a364b662d52fb0a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -43,6 +43,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test from tensorflow.python.training import ftrl from tensorflow.python.training import input as input_lib @@ -966,6 +967,63 @@ class LinearClassifierTest(test.TestCase): scores = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerPartitionedVariables(self): + """Tests LinearClassifier with SDCAOptimizer with partitioned variables.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [1.0], [1.0]]) + }, constant_op.constant([[1], [0], [1]]) + + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + + tf_config = { + 'cluster': { + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + config = run_config.RunConfig() + # Because we did not start a distributed cluster, we need to pass an + # empty ClusterSpec, otherwise the device_setter will look for + # distributed jobs, such as "/job:ps" which are not present. + config._cluster_spec = server_lib.ClusterSpec({}) + + classifier = linear.LinearClassifier( + feature_columns=[price, sq_footage_bucket, country, sq_footage_country], + weight_column_name='weights', + optimizer=sdca_optimizer, + config=config) + classifier.fit(input_fn=input_fn, steps=50) + scores = classifier.evaluate(input_fn=input_fn, steps=1) + print('all scores = {}'.format(scores)) + self.assertGreater(scores['accuracy'], 0.9) + def testEval(self): """Tests that eval produces correct metrics. """ @@ -1540,6 +1598,60 @@ class LinearRegressorTest(test.TestCase): loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] self.assertLess(loss, 0.05) + def testSdcaOptimizerPartitionedVariables(self): + """Tests LinearRegressor with SDCAOptimizer with partitioned variables.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([0.6, 0.8, 0.3]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [5.0], [7.0]]) + }, constant_op.constant([[1.55], [-1.25], [-3.0]]) + + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id', symmetric_l2_regularization=1.0, + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + tf_config = { + 'cluster': { + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + config = run_config.RunConfig() + # Because we did not start a distributed cluster, we need to pass an + # empty ClusterSpec, otherwise the device_setter will look for + # distributed jobs, such as "/job:ps" which are not present. + config._cluster_spec = server_lib.ClusterSpec({}) + + regressor = linear.LinearRegressor( + feature_columns=[price, sq_footage_bucket, country, sq_footage_country], + weight_column_name='weights', + optimizer=sdca_optimizer, + config=config) + regressor.fit(input_fn=input_fn, steps=20) + loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] + self.assertLess(loss, 0.05) + def testSdcaOptimizerSparseFeaturesWithL1Reg(self): """Tests LinearClassifier with SDCAOptimizer and sparse features.""" diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 541da9061732ad271f6d5456446a9c30b81e58dd..f8a3709ee57a32734afa7ac8133271c75d152b2c 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -505,7 +505,7 @@ class Experiment(object): eval_result = None last_warning_time = 0 while (not predicate_fn or predicate_fn( - eval_result, checkpoint_path=previous_path if eval_result else None)): + eval_result, checkpoint_path=previous_path)): # Exit if we have already reached number of steps to train. if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index d10927a0cdd5c67c8d2a8e569153235ee175ec4d..fb16c94c29660e2777942ea9cf30da51dbf90571 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -500,7 +500,7 @@ class ExperimentTest(test.TestCase): noop_hook = _NoopHook() def _predicate_fn(eval_result, checkpoint_path): - self.assertEqual(not eval_result, + self.assertEqual(eval_result is None, checkpoint_path is None) return est.eval_count < 3 # pylint: disable=cell-var-from-loop diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index b5741967ab52568725d7c9f03a0cc0b0f63f7459..ef0e08a777779e04f70d11fe83280ccaf1c178fd 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -35,6 +35,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_sdca_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import googletest @@ -132,15 +134,22 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero): return examples_dict, variables_dict -def make_variable_dict(max_age, max_gender): +def make_variable_dict(max_age, max_gender, partitioned=False): # TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from # examples_dict. - age_weights = variables_lib.Variable( - array_ops.zeros( - [max_age + 1], dtype=dtypes.float32)) - gender_weights = variables_lib.Variable( - array_ops.zeros( - [max_gender + 1], dtype=dtypes.float32)) + partitioner = None + if partitioned: + partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2, + axis=0) + with variable_scope.variable_scope( + name_or_scope='variables', + partitioner=partitioner): + age_weights = variables_lib.Variable( + array_ops.zeros( + [max_age + 1], dtype=dtypes.float32)) + gender_weights = variables_lib.Variable( + array_ops.zeros( + [max_gender + 1], dtype=dtypes.float32)) return dict( sparse_features_weights=[age_weights, gender_weights], dense_features_weights=[]) @@ -265,6 +274,54 @@ class SdcaWithLogisticLossTest(SdcaModelTest): self.assertAllClose( 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testPartitionedPrimals(self): + # Setup test data + example_protos = [ + make_example_proto({ + 'age': [0], + 'gender': [0] + }, 0), + make_example_proto({ + 'age': [1], + 'gender': [1] + }, 1), + ] + example_weights = [1.0, 1.0] + for num_shards in _SHARD_NUMBERS: + with self._single_threaded_test_session(): + examples = make_example_dict(example_protos, example_weights) + variables = make_variable_dict(1, 1, partitioned=True) + options = dict( + symmetric_l2_regularization=1, + symmetric_l1_regularization=0, + num_table_shards=num_shards, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + unregularized_loss = lr.unregularized_loss(examples) + loss = lr.regularized_loss(examples) + predictions = lr.predictions(examples) + self.assertAllClose(0.693147, unregularized_loss.eval()) + self.assertAllClose(0.693147, loss.eval()) + train_op = lr.minimize() + for _ in range(_MAX_ITERATIONS): + train_op.run() + lr.update_weights(train_op).run() + # The high tolerance in unregularized_loss comparisons is due to the + # fact that it's possible to trade off unregularized_loss vs. + # regularization and still have a sum that is quite close to the + # optimal regularized_loss value. SDCA's duality gap only ensures that + # the regularized_loss is within 0.01 of optimal. + # 0.525457 is the optimal regularized_loss. + # 0.411608 is the unregularized_loss at that optimum. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) + self.assertAllClose(0.525457, loss.eval(), atol=0.01) + predicted_labels = get_binary_predictions_for_logistic(predictions) + self.assertAllEqual([0, 1], predicted_labels.eval()) + self.assertAllClose( + 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testSparseRandom(self): dim = 20 num_examples = 1000 @@ -320,7 +377,10 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op.run() def testDistributedSimple(self): - # Setup test data + # Distributed SDCA may not converge if the workers update concurrently the + # same example. In this test the examples are partitioned across workers. + # The examples are the same for all workers, just the example_ids are + # different. example_protos = [ make_example_proto({ 'age': [0], @@ -332,13 +392,19 @@ class SdcaWithLogisticLossTest(SdcaModelTest): }, 1), ] example_weights = [1.0, 1.0] + examples = make_example_dict(example_protos, example_weights) + example_ids = array_ops.placeholder( + dtypes.string, shape=(len(example_weights),)) + examples['example_ids'] = example_ids + variables = make_variable_dict(1, 1) for num_shards in _SHARD_NUMBERS: for num_loss_partitions in _NUM_LOSS_PARTITIONS: with self._single_threaded_test_session(): - examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) options = dict( - symmetric_l2_regularization=1, + # Keep the same solution as for TestSimple: since the number of + # examples is multplied by num_loss_partitions, multiply also + # L2 by the same value. + symmetric_l2_regularization=num_loss_partitions, symmetric_l1_regularization=0, loss_type='logistic_loss', num_table_shards=num_shards, @@ -354,32 +420,30 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op = lr.minimize() - def minimize(): + def minimize(worker_id): with self._single_threaded_test_session(): + feed_dict = {example_ids: [ + str(i + worker_id*len(example_weights)) for i in range( + len(example_weights))]} for _ in range(_MAX_ITERATIONS): - train_op.run() # pylint: disable=cell-var-from-loop + train_op.run(feed_dict=feed_dict) # pylint: disable=cell-var-from-loop threads = [] - for _ in range(num_loss_partitions): - threads.append(threading.Thread(target=minimize)) + for worker_id in range(num_loss_partitions): + threads.append(threading.Thread(target=minimize, args=(worker_id,))) threads[-1].start() for t in threads: t.join() - lr.update_weights(train_op).run() - - # The high tolerance in unregularized_loss comparisons is due to the - # fact that it's possible to trade off unregularized_loss vs. - # regularization and still have a sum that is quite close to the - # optimal regularized_loss value. SDCA's duality gap only ensures - # that the regularized_loss is within 0.01 of optimal. - # 0.525457 is the optimal regularized_loss. - # 0.411608 is the unregularized_loss at that optimum. - self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) - self.assertAllClose(0.525457, loss.eval(), atol=0.01) + lr.update_weights(train_op).run(feed_dict={ + example_ids: [str(i) for i in range(len(example_weights))]}) + + # Test only the unregularized loss because the optimal value of the + # regularized loss depends on num_loss_partitions. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.02) predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllEqual([0, 1], predicted_labels.eval()) - self.assertTrue(lr.approximate_duality_gap().eval() < 0.02) + self.assertNear(0.0, lr.approximate_duality_gap().eval(), 0.02) def testSimpleNoL2(self): # Same as test above (so comments from above apply) but without an L2. diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index f980746a19fb8e0a02b9d023c127da7ab33e457f..0047d5753a773ce814d685f89da9ae6b04d21cb6 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -22,12 +22,14 @@ import collections from six.moves import range from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_sdca_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -43,9 +45,6 @@ __all__ = ['SdcaModel'] class SdcaModel(object): """Stochastic dual coordinate ascent solver for linear models. - This class currently only supports a single machine (multi-threaded) - implementation. We expect the weights and duals to fit in a single machine. - Loss functions supported: * Binary logistic loss @@ -182,18 +181,41 @@ class SdcaModel(object): # TODO(sibyl-Aix6ihai): Use optimizer interface to make use of slot creation logic. def _create_slots(self): - # Make internal variables which have the updates before applying L1 - # regularization. + """Make unshrinked internal variables (slots).""" + # Unshrinked variables have the updates before applying L1 regularization. + # Each unshrinked slot variable is either a `Variable` or list of + # `Variable`, depending on the value of its corresponding primary variable. + # We avoid using `PartitionedVariable` for the unshrinked slots since we do + # not need any of the extra information. self._slots = collections.defaultdict(list) for name in ['sparse_features_weights', 'dense_features_weights']: for var in self._variables[name]: - with ops.device(var.device): - # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is - # fixed - self._slots['unshrinked_' + name].append( - var_ops.Variable( - array_ops.zeros_like(var.initialized_value(), dtypes.float32), - name=var.op.name + '_unshrinked/SDCAOptimizer')) + # Our primary variable may be either a PartitionedVariable, or a list + # of Variables (each representing a partition). + if (isinstance(var, var_ops.PartitionedVariable) or + isinstance(var, list)): + var_list = [] + # pylint: disable=protected-access + for v in var: + with ops.colocate_with(v): + # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 + # is fixed. + slot_var = var_ops.Variable( + initial_value=array_ops.zeros_like(v.initialized_value(), + dtypes.float32), + name=v.op.name + '_unshrinked/SDCAOptimizer') + var_list.append(slot_var) + self._slots['unshrinked_' + name].append(var_list) + # pylint: enable=protected-access + else: + with ops.device(var.device): + # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is + # fixed. + self._slots['unshrinked_' + name].append( + var_ops.Variable( + array_ops.zeros_like(var.initialized_value(), + dtypes.float32), + name=var.op.name + '_unshrinked/SDCAOptimizer')) def _assertSpecified(self, items, check_in): for x in items: @@ -205,16 +227,25 @@ class SdcaModel(object): if not isinstance(check_in[x], list): raise ValueError(x + ' must be a list.') + def _var_to_list(self, var): + """Wraps var in a list if it is not a list or PartitionedVariable.""" + if not (isinstance(var, list) or + isinstance(var, var_ops.PartitionedVariable)): + var = [var] + return var + def _l1_loss(self): """Computes the (un-normalized) l1 loss of the model.""" with name_scope('sdca/l1_loss'): sums = [] for name in ['sparse_features_weights', 'dense_features_weights']: - for weights in self._convert_n_to_tensor(self._variables[name]): - with ops.device(weights.device): - sums.append( - math_ops.reduce_sum( - math_ops.abs(math_ops.cast(weights, dtypes.float64)))) + for var in self._variables[name]: + for v in self._var_to_list(var): + weights = internal_convert_to_tensor(v) + with ops.device(weights.device): + sums.append( + math_ops.reduce_sum( + math_ops.abs(math_ops.cast(weights, dtypes.float64)))) # SDCA L1 regularization cost is: l1 * sum(|weights|) return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums) @@ -223,17 +254,37 @@ class SdcaModel(object): with name_scope('sdca/l2_loss'): sums = [] for name in ['sparse_features_weights', 'dense_features_weights']: - for weights in self._convert_n_to_tensor(self._variables[name]): - with ops.device(weights.device): - sums.append( - math_ops.reduce_sum( - math_ops.square(math_ops.cast(weights, dtypes.float64)))) + for var in self._variables[name]: + for v in self._var_to_list(var): + weights = internal_convert_to_tensor(v) + with ops.device(weights.device): + sums.append(math_ops.reduce_sum(math_ops.square(math_ops.cast( + weights, dtypes.float64)))) # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2 return l2 * math_ops.add_n(sums) / 2.0 def _convert_n_to_tensor(self, input_list, as_ref=False): """Converts input list to a set of tensors.""" - return [internal_convert_to_tensor(x, as_ref=as_ref) for x in input_list] + # input_list can be a list of Variables (that are implicitly partitioned), + # in which case the underlying logic in internal_convert_to_tensor will not + # concatenate the partitions together. This method takes care of the + # concatenating (we only allow partitioning on the first axis). + output_list = [] + for x in input_list: + tensor_to_convert = x + if isinstance(x, list) or isinstance(x, var_ops.PartitionedVariable): + # We only allow for partitioning on the first axis. + tensor_to_convert = array_ops.concat(x, axis=0) + output_list.append(internal_convert_to_tensor( + tensor_to_convert, as_ref=as_ref)) + return output_list + + def _get_first_dimension_size_statically(self, w, num_partitions): + """Compute the static size of the first dimension for a sharded variable.""" + dim_0_size = w[0].get_shape()[0] + for p in range(1, num_partitions): + dim_0_size += w[p].get_shape()[0] + return dim_0_size def _linear_predictions(self, examples): """Returns predictions of the form w*x.""" @@ -286,6 +337,28 @@ class SdcaModel(object): result = math_ops.sigmoid(result) return result + def _get_partitioned_update_ops(self, + v_num, + num_partitions_by_var, + p_assignments_by_var, + gather_ids_by_var, + weights, + full_update, + p_assignments, + num_partitions): + """Get updates for partitioned variables.""" + num_partitions = num_partitions_by_var[v_num] + p_assignments = p_assignments_by_var[v_num] + gather_ids = gather_ids_by_var[v_num] + updates = data_flow_ops.dynamic_partition( + full_update, p_assignments, num_partitions) + update_ops = [] + for p in range(num_partitions): + with ops.colocate_with(weights[p]): + result = state_ops.scatter_add(weights[p], gather_ids[p], updates[p]) + update_ops.append(result) + return update_ops + def minimize(self, global_step=None, name=None): """Add operations to train a linear model by minimizing the loss function. @@ -318,18 +391,89 @@ class SdcaModel(object): # Solver returns example_state_update, new delta sparse_feature_weights # and delta dense_feature_weights. - weights_tensor = self._convert_n_to_tensor(self._slots[ - 'unshrinked_sparse_features_weights']) sparse_weights = [] sparse_indices = [] - for w, i in zip(weights_tensor, sparse_feature_indices): - # Find the feature ids to lookup in the variables. - with ops.device(w.device): - sparse_indices.append( - math_ops.cast( - array_ops.unique(math_ops.cast(i, dtypes.int32))[0], - dtypes.int64)) - sparse_weights.append(array_ops.gather(w, sparse_indices[-1])) + # If we have partitioned variables, keep a few lists of Tensors around + # that we need for the assign_add after the op call to + # gen_sdca_ops.sdca_optimizer(). + num_partitions_by_var = [] + p_assignments_by_var = [] + gather_ids_by_var = [] + for w, i in zip(self._slots['unshrinked_sparse_features_weights'], + sparse_feature_indices): + # Append the sparse_indices (in full-variable space). + sparse_idx = math_ops.cast( + array_ops.unique(math_ops.cast(i, dtypes.int32))[0], + dtypes.int64) + sparse_indices.append(sparse_idx) + if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable): + num_partitions = len(w) + flat_ids = array_ops.reshape(sparse_idx, [-1]) + # We use div partitioning, which is easiest to support downstream. + # Compute num_total_ids as the sum of dim-0 of w, then assign + # to partitions based on a constant number of ids per partition. + # Optimize if we already know the full shape statically. + dim_0_size = self._get_first_dimension_size_statically( + w, num_partitions) + + if dim_0_size.value: + num_total_ids = constant_op.constant(dim_0_size.value, + flat_ids.dtype) + else: + dim_0_sizes = [] + for p in range(num_partitions): + if w[p].get_shape()[0].value is not None: + dim_0_sizes.append(w[p].get_shape()[0].value) + else: + with ops.colocate_with(w[p]): + dim_0_sizes.append(array_ops.shape(w[p])[0]) + num_total_ids = math_ops.reduce_sum( + math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype)) + ids_per_partition = num_total_ids // num_partitions + extras = num_total_ids % num_partitions + + p_assignments = math_ops.maximum( + flat_ids // (ids_per_partition + 1), + (flat_ids - extras) // ids_per_partition) + + # Emulate a conditional using a boolean indicator tensor + new_ids = array_ops.where(p_assignments < extras, + flat_ids % (ids_per_partition + 1), + (flat_ids - extras) % ids_per_partition) + + # Cast partition assignments to int32 for use in dynamic_partition. + # There really should not be more than 2^32 partitions. + p_assignments = math_ops.cast(p_assignments, dtypes.int32) + # Partition list of ids based on assignments into num_partitions + # separate lists. + gather_ids = data_flow_ops.dynamic_partition(new_ids, + p_assignments, + num_partitions) + # Append these to the lists for use in the later update. + num_partitions_by_var.append(num_partitions) + p_assignments_by_var.append(p_assignments) + gather_ids_by_var.append(gather_ids) + + # Gather the weights from each partition. + partition_gathered_weights = [] + for p in range(num_partitions): + with ops.colocate_with(w[p]): + partition_gathered_weights.append( + array_ops.gather(w[p], gather_ids[p])) + + # Stitch the weights back together in the same order they were before + # we dynamic_partitioned them. + condition_indices = data_flow_ops.dynamic_partition( + math_ops.range(array_ops.shape(new_ids)[0]), + p_assignments, num_partitions) + batch_gathered_weights = data_flow_ops.dynamic_stitch( + condition_indices, partition_gathered_weights) + else: + w_as_tensor = internal_convert_to_tensor(w) + with ops.device(w_as_tensor.device): + batch_gathered_weights = array_ops.gather( + w_as_tensor, sparse_idx) + sparse_weights.append(batch_gathered_weights) # pylint: disable=protected-access esu, sfw, dfw = gen_sdca_ops.sdca_optimizer( @@ -355,12 +499,25 @@ class SdcaModel(object): with ops.control_dependencies([esu]): update_ops = [self._hashtable.insert(example_ids_hashed, esu)] # Update the weights before the proximal step. - for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'], - sparse_indices, sfw): - update_ops.append(state_ops.scatter_add(w, i, u)) + for v_num, (w, i, u) in enumerate( + zip(self._slots['unshrinked_sparse_features_weights'], + sparse_indices, sfw)): + if (isinstance(w, var_ops.PartitionedVariable) or + isinstance(w, list)): + update_ops += self._get_partitioned_update_ops( + v_num, num_partitions_by_var, p_assignments_by_var, + gather_ids_by_var, w, u, p_assignments, num_partitions) + else: + update_ops.append(state_ops.scatter_add(w, i, u)) for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw): - update_ops.append(w.assign_add(u)) - + if (isinstance(w, var_ops.PartitionedVariable) or + isinstance(w, list)): + split_updates = array_ops.split( + u, num_or_size_splits=[v.shape.as_list()[0] for v in w]) + for v, split_update in zip(w, split_updates): + update_ops.append(state_ops.assign_add(v, split_update)) + else: + update_ops.append(state_ops.assign_add(w, u)) if not global_step: return control_flow_ops.group(*update_ops) with ops.control_dependencies(update_ops): @@ -385,21 +542,22 @@ class SdcaModel(object): for name in ['sparse_features_weights', 'dense_features_weights']: for var, slot_var in zip(self._variables[name], self._slots['unshrinked_' + name]): - update_ops.append(var.assign(slot_var)) + for v, sv in zip(self._var_to_list(var), self._var_to_list(slot_var)): + update_ops.append(v.assign(sv)) # Apply proximal step. with ops.control_dependencies(update_ops): update_ops = [] for name in ['sparse_features_weights', 'dense_features_weights']: for var in self._variables[name]: - with ops.device(var.device): - # pylint: disable=protected-access - update_ops.append( - gen_sdca_ops.sdca_shrink_l1( - self._convert_n_to_tensor( - [var], as_ref=True), - l1=self._symmetric_l1_regularization(), - l2=self._symmetric_l2_regularization())) + for v in self._var_to_list(var): + with ops.device(v.device): + # pylint: disable=protected-access + update_ops.append( + gen_sdca_ops.sdca_shrink_l1( + self._convert_n_to_tensor([v], as_ref=True), + l1=self._symmetric_l1_regularization(), + l2=self._symmetric_l2_regularization())) return control_flow_ops.group(*update_ops) def approximate_duality_gap(self): diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index d4e54c82f988e0adcd16aad29702ee9f8b16aea3..200e7de6b95f17672c6ef51f887b15f9d185f775 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -116,6 +116,7 @@ def sdca_model_fn(features, labels, mode, params, config=None): num_loss_partitions = params["num_loss_partitions"] weight_column_name = params["weight_column_name"] update_weights_hook = params.get("update_weights_hook", None) + partitioner = params["partitioner"] loss_type = None if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access @@ -136,12 +137,14 @@ def sdca_model_fn(features, labels, mode, params, config=None): example_id_column=example_id_column, num_loss_partitions=n_loss_partitions, symmetric_l1_regularization=l1_regularization, - symmetric_l2_regularization=l2_regularization) + symmetric_l2_regularization=l2_regularization, + partitioner=partitioner) parent_scope = "linear" with variable_scope.variable_scope( - values=features.values(), name_or_scope=parent_scope) as scope: + values=features.values(), name_or_scope=parent_scope, + partitioner=partitioner) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( @@ -213,7 +216,8 @@ class _SDCAEstimator(estimator.Estimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `_SDCAEstimator` estimator object. Args: @@ -241,6 +245,8 @@ class _SDCAEstimator(estimator.Estimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `_SDCAEstimator` estimator. @@ -267,6 +273,7 @@ class _SDCAEstimator(estimator.Estimator): "l2_regularization": l2_regularization, "weight_column_name": weight_column_name, "update_weights_hook": _SdcaUpdateWeightsHook(), + "partitioner": partitioner, } super(_SDCAEstimator, self).__init__( @@ -336,7 +343,8 @@ class SDCALogisticClassifier(_SDCAEstimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `SDCALogisticClassifier` object. Args: @@ -361,6 +369,8 @@ class SDCALogisticClassifier(_SDCAEstimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `SDCALogisiticClassifier` estimator. @@ -376,7 +386,8 @@ class SDCALogisticClassifier(_SDCAEstimator): l2_regularization=l2_regularization, num_loss_partitions=num_loss_partitions, config=config, - feature_engineering_fn=None) + feature_engineering_fn=None, + partitioner=partitioner) def predict_classes(self, input_fn=None): """Runs inference to determine the predicted class. @@ -463,7 +474,8 @@ class SDCALinearRegressor(_SDCAEstimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `SDCALinearRegressor` estimator object. @@ -489,6 +501,8 @@ class SDCALinearRegressor(_SDCAEstimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `SDCALinearRegressor` estimator. @@ -503,7 +517,8 @@ class SDCALinearRegressor(_SDCAEstimator): l2_regularization=l2_regularization, num_loss_partitions=num_loss_partitions, config=config, - feature_engineering_fn=None) + feature_engineering_fn=None, + partitioner=partitioner) def predict_scores(self, input_fn): """Returns predicted scores for given features. diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py index bed3d5139fcbf9d9e8b85605c752736f26af6793..647667188238dc18b137eaad98356a79b3a549b4 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.linear_optimizer.python import sdca_estimator from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test @@ -273,6 +274,47 @@ class SDCALogisticClassifierTest(test.TestCase): metrics = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(metrics['accuracy'], 0.9) + def testPartitionedMixedFeatures(self): + """Tests SDCALogisticClassifier with a mix of features (partitioned).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([900.0, 700.0, 600.0]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [1.0], [1.0]]) + }, constant_op.constant([[1], [0], [1]]) + + with self._single_threaded_test_session(): + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + classifier = sdca_estimator.SDCALogisticClassifier( + example_id_column='example_id', + feature_columns=[ + price, sq_footage_bucket, country, sq_footage_country + ], + weight_column_name='weights', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + classifier.fit(input_fn=input_fn, steps=50) + metrics = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertGreater(metrics['accuracy'], 0.9) + class SDCALinearRegressorTest(test.TestCase): @@ -350,6 +392,48 @@ class SDCALinearRegressorTest(test.TestCase): loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] self.assertLess(loss, 0.05) + def testMixedFeaturesArbitraryWeightsPartitioned(self): + """Tests SDCALinearRegressor works with a mix of features (partitioned).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [5.0], [7.0]]) + }, constant_op.constant([[1.55], [-1.25], [-3.0]]) + + with self._single_threaded_test_session(): + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + regressor = sdca_estimator.SDCALinearRegressor( + example_id_column='example_id', + feature_columns=[ + price, sq_footage_bucket, country, sq_footage_country + ], + l2_regularization=1.0, + weight_column_name='weights', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + regressor.fit(input_fn=input_fn, steps=20) + loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] + self.assertLess(loss, 0.05) + def testSdcaOptimizerSparseFeaturesWithL1Reg(self): """SDCALinearRegressor works with sparse features and L1 regularization.""" diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 12039ecc6f357af07e0c2a08e17d46396f3ad386..9872c6f97c879d8994b6c26e65df33e368a0603e 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -64,7 +64,8 @@ class SDCAOptimizer(object): of workers running the train steps. It defaults to 1 (single machine). `num_table_shards` defines the number of shards for the internal state table, typically set to match the number of parameter servers for large - data sets. + data sets. You can also specify a `partitioner` object to partition the primal + weights during training (`div` partitioning strategy will be used). """ def __init__(self, @@ -73,13 +74,15 @@ class SDCAOptimizer(object): num_table_shards=None, symmetric_l1_regularization=0.0, symmetric_l2_regularization=1.0, - adaptive=True): + adaptive=True, + partitioner=None): self._example_id_column = example_id_column self._num_loss_partitions = num_loss_partitions self._num_table_shards = num_table_shards self._symmetric_l1_regularization = symmetric_l1_regularization self._symmetric_l2_regularization = symmetric_l2_regularization self._adaptive = adaptive + self._partitioner = partitioner def get_name(self): return 'SDCAOptimizer' @@ -108,6 +111,10 @@ class SDCAOptimizer(object): def adaptive(self): return self._adaptive + @property + def partitioner(self): + return self._partitioner + def get_train_step(self, columns_to_variables, weight_column_name, loss_type, features, targets, global_step): """Returns the training operation of an SdcaModel optimizer.""" @@ -175,10 +182,12 @@ class SDCAOptimizer(object): sparse_feature_column = _dense_tensor_to_sparse_feature_column( dense_bucket_tensor) sparse_feature_with_values.append(sparse_feature_column) - # For bucketized columns, the variables list contains exactly one - # element. - sparse_feature_with_values_weights.append( - columns_to_variables[column][0]) + # If a partitioner was used during variable creation, we will have a + # list of Variables here larger than 1. + vars_to_append = columns_to_variables[column][0] + if len(columns_to_variables[column]) > 1: + vars_to_append = columns_to_variables[column] + sparse_feature_with_values_weights.append(vars_to_append) elif isinstance( column, ( @@ -226,8 +235,12 @@ class SDCAOptimizer(object): array_ops.shape(ids)[0]), [-1]) sparse_feature_with_values.append( SparseFeatureColumn(example_ids_filtered, reproject_ids, weights)) - sparse_feature_with_values_weights.append( - columns_to_variables[column][0]) + # If a partitioner was used during variable creation, we will have a + # list of Variables here larger than 1. + vars_to_append = columns_to_variables[column][0] + if len(columns_to_variables[column]) > 1: + vars_to_append = columns_to_variables[column] + sparse_feature_with_values_weights.append(vars_to_append) else: raise ValueError('SDCAOptimizer does not support column type %s.' % type(column).__name__) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 55b984f260ec49ab9b52be6402885a46226cba70..73f5c1448d91c573efed34c6aaaf5c28feac6555 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -90,6 +90,16 @@ cc_library( deps = [":context"], ) +cc_library( + name = "kernel_api", + hdrs = [ + "builtin_op_data.h", + "builtin_ops.h", + "context.h", + "context_util.h", + ], +) + exports_files(["builtin_ops.h"]) cc_library( @@ -118,6 +128,7 @@ cc_library( hdrs = [ "allocation.h", "context.h", + "context_util.h", "error_reporter.h", "graph_info.h", "interpreter.h", @@ -174,6 +185,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "//tensorflow/contrib/lite/schema:schema_fbs", diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index cc8a8035d1dadeec98886ba1dae4cdf403f26de4..a616138d3321d43f66a2b430f7df609a13b9caf6 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -17,7 +17,29 @@ else endif endif -ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) +HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) + +# Self-hosting +TARGET_ARCH := ${HOST_ARCH} + +# Cross compiling +ifeq ($(CROSS),rpi) + TARGET_ARCH := armv7l + TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabihf- +endif + +ifeq ($(CROSS),riscv) + TARGET_ARCH := riscv + TARGET_TOOLCHAIN_PREFIX := riscv32-unknown-elf- +endif +ifeq ($(CROSS),stm32f7) + TARGET_ARCH := armf7 + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- +endif +ifeq ($(CROSS),stm32f1) + TARGET_ARCH := armm1 + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- +endif # Where compiled objects are stored. OBJDIR := $(MAKEFILE_DIR)/gen/obj/ @@ -25,11 +47,46 @@ BINDIR := $(MAKEFILE_DIR)/gen/bin/ LIBDIR := $(MAKEFILE_DIR)/gen/lib/ GENDIR := $(MAKEFILE_DIR)/gen/obj/ +LIBS := +ifeq ($(TARGET_ARCH),x86_64) + CXXFLAGS += -fPIC -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -pthread # -msse4.2 +endif + +ifeq ($(TARGET_ARCH),armv7l) + CXXFLAGS += -mfpu=neon -pthread -fPIC + LIBS += -ldl +endif + +ifeq ($(TARGET_ARCH),riscv) +# CXXFLAGS += -march=gap8 + CXXFLAGS += -DTFLITE_MCU + LIBS += -ldl + BUILD_TYPE := micro +endif + +ifeq ($(TARGET_ARCH),armf7) + CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -DTFLITE_MCU + CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections + CXXFLAGS += -funsigned-char -MMD + CXXFLAGS += -mcpu=cortex-m7 -mthumb -mfpu=fpv5-sp-d16 -mfloat-abi=softfp + CXXFLAGS += '-std=gnu++11' '-fno-rtti' '-Wvla' '-c' '-Wall' '-Wextra' '-Wno-unused-parameter' '-Wno-missing-field-initializers' '-fmessage-length=0' '-fno-exceptions' '-fno-builtin' '-ffunction-sections' '-fdata-sections' '-funsigned-char' '-MMD' '-fno-delete-null-pointer-checks' '-fomit-frame-pointer' '-Os' + LIBS += -ldl + BUILD_TYPE := micro +endif +ifeq ($(TARGET_ARCH),armm1) + CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -mcpu=cortex-m1 -mthumb -DTFLITE_MCU + CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections + CXXFLAGS += -funsigned-char -MMD + LIBS += -ldl +endif + # Settings for the host compiler. -CXX := $(CC_PREFIX)gcc -CXXFLAGS := --std=c++11 -O3 -DNDEBUG -CC := $(CC_PREFIX)gcc -CCFLAGS := -O3 -DNDEBUG +CXX := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}g++ +CXXFLAGS += --std=c++11 -O3 -DNDEBUG +CCFLAGS := ${CXXFLAGS} +CC := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}gcc +AR := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}ar +CFLAGS := LDOPTS := LDOPTS += -L/usr/local/lib ARFLAGS := -r @@ -48,7 +105,7 @@ INCLUDES := \ # override local versions in the source tree. INCLUDES += -I/usr/local/include -LIBS := \ +LIBS += \ -lstdc++ \ -lpthread \ -lm \ @@ -70,6 +127,12 @@ LIB_PATH := $(LIBDIR)$(LIB_NAME) # A small example program that shows how to link against the library. MINIMAL_PATH := $(BINDIR)minimal +# Benchmark static library and binary +BENCHMARK_LIB_NAME := benchmark-lib.a +BENCHMARK_BINARY_NAME := benchmark_model +BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME) +BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME) + MINIMAL_SRCS := \ tensorflow/contrib/lite/examples/minimal/minimal.cc MINIMAL_OBJS := $(addprefix $(OBJDIR), \ @@ -78,19 +141,29 @@ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) # What sources we want to compile, must be kept in sync with the main Bazel # build files. +PROFILER_SRCS := \ + tensorflow/contrib/lite/profiling/time.cc +PROFILE_SUMMARIZER_SRCS := \ + tensorflow/contrib/lite/profiling/profile_summarizer.cc \ + tensorflow/core/util/stats_calculator.cc + CORE_CC_ALL_SRCS := \ $(wildcard tensorflow/contrib/lite/*.cc) \ +$(wildcard tensorflow/contrib/lite/*.c) +ifneq ($(BUILD_TYPE),micro) +CORE_CC_ALL_SRCS += \ $(wildcard tensorflow/contrib/lite/kernels/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \ -$(wildcard tensorflow/contrib/lite/*.c) \ +$(PROFILER_SRCS) \ $(wildcard tensorflow/contrib/lite/kernels/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ $(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) \ $(wildcard tensorflow/contrib/lite/downloads/fft2d/fftsg.c) +endif # Remove any duplicates. CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) CORE_CC_EXCLUDE_SRCS := \ @@ -100,6 +173,11 @@ $(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ $(MINIMAL_SRCS) +ifeq ($(BUILD_TYPE),micro) +CORE_CC_EXCLUDE_SRCS += \ +tensorflow/contrib/lite/model.cc \ +tensorflow/contrib/lite/nnapi_delegate.cc +endif # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # File names of the intermediate files target compilation generates. @@ -107,18 +185,33 @@ TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) LIB_OBJS := $(TF_LITE_CC_OBJS) +# Benchmark sources +BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark +BENCHMARK_ALL_SRCS := $(TFLITE_CC_SRCS) \ + $(wildcard $(BENCHMARK_SRCS_DIR)/*.cc) \ + $(PROFILE_SUMMARIZER_SRCS) + +BENCHMARK_SRCS := $(filter-out \ + $(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \ + $(BENCHMARK_ALL_SRCS)) + +BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) + # For normal manually-created TensorFlow C++ source files. $(OBJDIR)%.o: %.cc @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ - # For normal manually-created TensorFlow C++ source files. $(OBJDIR)%.o: %.c @mkdir -p $(dir $@) $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ # The target that's compiled if there's no command-line arguments. -all: $(LIB_PATH) $(MINIMAL_PATH) +all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY) + +# The target that's compiled for micro-controllers +micro: $(LIB_PATH) # Gathers together all the objects we've compiled into a single '.a' archive. $(LIB_PATH): $(LIB_OBJS) @@ -131,6 +224,21 @@ $(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH) -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) + +$(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS) + @mkdir -p $(dir $@) + $(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS) + +benchmark_lib: $(BENCHMARK_LIB) +$(info $(BENCHMARK_BINARY)) +$(BENCHMARK_BINARY) : $(BENCHMARK_LIB) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(BENCHMARK_BINARY) \ + $(LIBFLAGS) $(BENCHMARK_LIB) $(LDFLAGS) $(LIBS) + +benchmark: $(BENCHMARK_BINARY) + # Gets rid of all generated files. clean: rm -rf $(MAKEFILE_DIR)/gen diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc index a4772731ecda92431c412672610a39c188dabf27..c42622ff02fc2837b61b35f19e834276c0518d1e 100644 --- a/tensorflow/contrib/lite/allocation.cc +++ b/tensorflow/contrib/lite/allocation.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#ifndef TFLITE_MCU #include +#endif #include #include #include @@ -27,10 +29,13 @@ limitations under the License. #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" +#ifndef TFLITE_MCU #include "tensorflow/contrib/lite/nnapi_delegate.h" +#endif namespace tflite { +#ifndef TFLITE_MCU MMAPAllocation::MMAPAllocation(const char* filename, ErrorReporter* error_reporter) : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { @@ -111,6 +116,7 @@ MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, buffer_ = ptr; buffer_size_bytes_ = num_bytes; } +#endif MemoryAllocation::~MemoryAllocation() {} diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 4f836d367747e06de682b5764206d33f6e2fb983..4257e754ad5c30e17ec8ba8d5c6e69b5c5bcd728 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -31,16 +31,17 @@ struct AllocationInfo { // The tensor index to be allocated or deallocated. int tensor; // Whether to allocate or deallocate - enum { ALLOC, DEALLOC } type; + enum Type { ALLOC, DEALLOC } type; }; ArenaPlanner::ArenaPlanner(TfLiteContext* context, - std::unique_ptr graph_info) + std::unique_ptr graph_info, + bool preserve_inputs) : context_(context), graph_info_(std::move(graph_info)), arena_(kDefaultArenaAlignment), - persistent_arena_(kDefaultArenaAlignment) {} - + persistent_arena_(kDefaultArenaAlignment), + preserve_inputs_(preserve_inputs) {} ArenaPlanner::~ArenaPlanner() {} int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) { @@ -67,6 +68,33 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Keeps track of references to each tensor. std::vector refcounts(graph_info_->num_tensors(), 0); + // `allocated` and `deallocated` are technically list of boolean values. + // We're saving the compiled binary size by using `vector`. + std::vector allocated(graph_info_->num_tensors(), false); + std::vector deallocated(graph_info_->num_tensors(), false); + + auto allocate = [this, &allocated, &deallocated](int node, + int tensor) -> TfLiteStatus { + if (allocated[tensor]) { + return kTfLiteOk; + } + TF_LITE_ENSURE(context_, !deallocated[tensor]); + alloc_queue_.push_back({node, tensor, AllocationInfo::ALLOC}); + allocated[tensor] = true; + return kTfLiteOk; + }; + + auto deallocate = [this, &allocated, &deallocated]( + int node, int tensor) -> TfLiteStatus { + if (!allocated[tensor]) { + // Do not enqueue a DEALLOC if the tensor is never allocated. + // This happened with the constant tensors. + return kTfLiteOk; + } + TF_LITE_ENSURE(context_, !deallocated[tensor]); + alloc_queue_.push_back({node, tensor, AllocationInfo::DEALLOC}); + return kTfLiteOk; + }; // There will be an entry in alloc_queue_ for the allocation of each tensor // and another for their deallocation. @@ -79,6 +107,32 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { refcounts[tensor_index]++; } + // Variable tensors should are also never overwritten and need to be alive all + // the time. + for (int tensor_index : graph_info_->variables()) { + refcounts[tensor_index]++; + } + + // Queue all graph inputs for allocation. If preserve_inputs_ is true, make + // sure they never be overwritten. + for (int tensor_index : graph_info_->inputs()) { + if (tensor_index != kOptionalTensor) { + if (preserve_inputs_) { + refcounts[tensor_index]++; + } + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); + } + } + + // Queue all graph variable tensors for allocation. + for (int tensor_index : graph_info_->variables()) { + if (tensor_index != kOptionalTensor) { + // Increase the reference count for input tensors by one, so it will + // never be deallocated. + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); + } + } + // Count references to node input tensors. for (int i = 0; i < graph_info_->num_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); @@ -94,10 +148,9 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Queue all graph inputs for allocation. for (int tensor_index : graph_info_->inputs()) { if (tensor_index != kOptionalTensor) { - alloc_queue_.push_back({0, tensor_index, AllocationInfo::ALLOC}); + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); } } - // Go through the graph in execution order. for (int i = 0; i < graph_info_->num_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); @@ -106,7 +159,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { TfLiteIntArray* node_outputs = node.outputs; for (int j = 0; j < node_outputs->size; ++j) { int tensor_index = node_outputs->data[j]; - alloc_queue_.push_back({i, tensor_index, AllocationInfo::ALLOC}); + TF_LITE_ENSURE_STATUS(allocate(i, tensor_index)); } // Then update the ref-counts of the node's inputs, and if necessary queue @@ -117,7 +170,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { if (tensor_index != kOptionalTensor) { refcounts[tensor_index]--; if (refcounts[tensor_index] == 0) { - alloc_queue_.push_back({i, tensor_index, AllocationInfo::DEALLOC}); + TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index)); } } } diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index e9d0fbc5a9b5aec06e28da8757466b25f40da2f5..1d84950e91bc48fd1c1a7e5b2d9063e20dea0718 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -43,8 +43,11 @@ struct AllocationInfo; class ArenaPlanner : public MemoryPlanner { public: // Ownership of 'context' is not taken and it must remain util the - // ArenaPlanner is destroyed. - ArenaPlanner(TfLiteContext* context, std::unique_ptr graph_info); + // ArenaPlanner is destroyed. If 'preserve_inputs' is true the inputs to the + // graph will not share memory with any other tensor, effectively preserving + // them until the end of inference. + ArenaPlanner(TfLiteContext* context, std::unique_ptr graph_info, + bool preserve_inputs); ~ArenaPlanner() override; ArenaPlanner(const ArenaPlanner&) = delete; ArenaPlanner& operator=(const ArenaPlanner&) = delete; @@ -100,6 +103,8 @@ class ArenaPlanner : public MemoryPlanner { // Raw memory buffer that is allocated for persistent tensors that are // declared as kTfLiteArenaRwPersistent. SimpleMemoryArena persistent_arena_; + + bool preserve_inputs_; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index a8a8755e2c9e81474f2ff9cd2b85c0eb3d5c3441..f5bd1932f976f5c7d0f0d14bbaf9ca3807dfd3b0 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -100,12 +100,18 @@ class TestGraph { std::vector* tensors() { return &tensors_; } const std::vector& inputs() { return inputs_; } const std::vector& outputs() { return outputs_; } + const std::vector& variables() { return variables_; } + + void SetVariables(const std::vector& variables) { + variables_ = variables; + } private: std::vector nodes_; std::vector tensors_; std::vector inputs_; std::vector outputs_; + std::vector variables_; }; // The GraphInfo for a TestGraph. @@ -123,6 +129,9 @@ class TestGraphInfo : public GraphInfo { } const std::vector& inputs() const override { return graph_->inputs(); } const std::vector& outputs() const override { return graph_->outputs(); } + const std::vector& variables() const override { + return graph_->variables(); + } private: TestGraph* graph_; @@ -142,11 +151,12 @@ void ReportError(TfLiteContext* context, const char* format, ...) { class ArenaPlannerTest : public ::testing::Test { protected: - void SetGraph(TestGraph* graph) { + void SetGraph(TestGraph* graph, bool preserve_inputs = false) { graph_ = graph; context_.ReportError = ReportError; planner_.reset(new ArenaPlanner( - &context_, std::unique_ptr(new TestGraphInfo(graph)))); + &context_, std::unique_ptr(new TestGraphInfo(graph)), + preserve_inputs)); CHECK(planner_->ResetAllocations() == kTfLiteOk); CHECK(planner_->PlanAllocations() == kTfLiteOk); } @@ -209,11 +219,8 @@ TEST_F(ArenaPlannerTest, ZeroSizedTensors) { TestGraph graph({1}, {{{1}, {2}, {}}}, {2}); (*graph.tensors())[1].bytes = 0; SetGraph(&graph); - // TODO(ahentz): this is currently broken because the arena finds two - // allocations with the same offset and returns an error. - ASSERT_FALSE(planner_->ExecuteAllocations(0, 10) == kTfLiteOk); - // EXPECT_EQ(GetOffset(1), 0); - // EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + ASSERT_EQ(planner_->ExecuteAllocations(0, 10), kTfLiteOk); + EXPECT_EQ((*graph_->tensors())[1].data.raw, nullptr); } TEST_F(ArenaPlannerTest, SimpleGraph) { @@ -237,6 +244,30 @@ TEST_F(ArenaPlannerTest, SimpleGraph) { EXPECT_EQ(GetOffset(3), 0); } +TEST_F(ArenaPlannerTest, SimpleGraphInputsPreserved) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, 5}, {3}, {}} // Third op + }, + {3}); + SetGraph(&graph, /*preserve_inputs=*/true); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 +4 +5 -2 +3 -4 -5 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + // Because we are keeping the inputs alive until the end (due to + // preserve_inputs=true), the output tensor will not be able to use that + // space. It will end up using the same are as tensor #2. + EXPECT_EQ(GetOffset(3), GetOffsetAfter(1)); +} + TEST_F(ArenaPlannerTest, SimpleGraphWithTemporary) { TestGraph graph({0, 1}, { @@ -309,13 +340,15 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentTensor) { { /* in, out, tmp */ {{0, 1}, {2}, {}}, // First op - {{2, 0}, {4}, {5}}, // Second op, with temporary + {{2, 0}, {4}, {5}}, // Second op, with persistent {{4, -1}, {3}, {}} // Third op, with optional }, {3}); // Make #1 persistent so it goes into its own arena. (*graph.tensors())[1].allocation_type = kTfLiteArenaRwPersistent; + // The only use case for kTfLiteArenaRwPersistent is variable tensor now. + graph.SetVariables({1}); SetGraph(&graph); Execute(0, 10); diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 85216776823eab2ab3ac2a3bc666f21e312acc6c..5543acc1f5dabaa8a54ec4d1f2027bc66a00f6db 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -1,4 +1,8 @@ """Generate Flatbuffer binary from json.""" +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) def tflite_copts(): """Defines compile time flags.""" @@ -185,32 +189,114 @@ def json_to_tflite(name, src, out): tools = [flatc], ) -def gen_zipped_test_files(name, files): +# This is the master list of generated examples that will be made into tests. A +# function called make_XXX_tests() must also appear in generate_examples.py. +# Disable a test by commenting it out. If you do, add a link to a bug or issue. +def generated_test_models(): + return [ + "add", + "arg_max", + "avg_pool", + "batch_to_space_nd", + "concat", + "constant", + "control_dep", + "conv", + "depthwiseconv", + "div", + "equal", + "exp", + "expand_dims", + "floor", + "fully_connected", + "fused_batch_norm", + "gather", + "global_batch_norm", + "greater", + "greater_equal", + "sum", + "l2norm", + "l2_pool", + "less", + "less_equal", + "local_response_norm", + "log_softmax", + "log", + "lstm", + "max_pool", + "maximum", + "mean", + "minimum", + "mul", + "neg", + "not_equal", + "pad", + "padv2", + # "prelu", + "pow", + "relu", + "relu1", + "relu6", + "reshape", + "resize_bilinear", + "rsqrt", + "shape", + "sigmoid", + "sin", + "slice", + "softmax", + "space_to_batch_nd", + "space_to_depth", + "sparse_to_dense", + "split", + "sqrt", + "squeeze", + "strided_slice", + "strided_slice_1d_exhaustive", + "sub", + "tile", + "topk", + "transpose", + "transpose_conv", + "where", + ] + +def gen_zip_test(name, test_name, **kwargs): + """Generate a zipped-example test and its dependent zip files. + + Args: + name: Resulting cc_test target name + test_name: Test targets this model. Comes from the list above. + **kwargs: tf_cc_test kwargs. + """ + gen_zipped_test_file( + name = "zip_%s" % test_name, + file = "%s.zip" % test_name, + ) + tf_cc_test(name, **kwargs) + +def gen_zipped_test_file(name, file): """Generate a zip file of tests by using :generate_examples. Args: - name: Name of output. We will produce "`name`_files" as a target. - files: A list of zip file basenames. + name: Name of output. We will produce "`file`.files" as a target. + file: The name of one of the generated_examples targets, e.g. "transpose" """ toco = "//tensorflow/contrib/lite/toco:toco" - out_files = [] - for f in files: - out_file = name + "/" + f - out_files.append(out_file) - native.genrule( - name = name + "_" + f + ".files", - cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco - + " --zip_to_output " + f + " $(@D)"), - outs = [out_file], - tools = [ - ":generate_examples", - toco, - ], - ) + native.genrule( + name = file + ".files", + cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco + + " --zip_to_output " + file + " $(@D)"), + outs = [file], + tools = [ + ":generate_examples", + toco, + ], + ) native.filegroup( name = name, - srcs = out_files, + srcs = [file], ) def gen_selected_ops(name, model): diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index 9f398f4a9f3dcafd7bd49fd5d95e9991b8b36b75..e9531aef19f04adf719156aa3e874dc5ce6e2b04 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -19,22 +19,23 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR/../../.." -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_x86_64/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_i386/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_armv7/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 \ -$SCRIPT_DIR/gen/lib/ios_armv7s/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_arm64/libtensorflow-lite.a +# Build library for supported architectures and packs them in a fat binary. +make_library() { + for arch in x86_64 i386 armv7 armv7s arm64 + do + make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=${arch} \ + -j 8 \ + $SCRIPT_DIR/gen/lib/ios_${arch}/${1} + done + lipo \ + tensorflow/contrib/lite/gen/lib/ios_x86_64/${1} \ + tensorflow/contrib/lite/gen/lib/ios_i386/${1} \ + tensorflow/contrib/lite/gen/lib/ios_armv7/${1} \ + tensorflow/contrib/lite/gen/lib/ios_armv7s/${1} \ + tensorflow/contrib/lite/gen/lib/ios_arm64/${1} \ + -create \ + -output tensorflow/contrib/lite/gen/lib/${1} +} -lipo \ -tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_i386/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_armv7/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_armv7s/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_arm64/libtensorflow-lite.a \ --create \ --output tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a +make_library libtensorflow-lite.a +make_library benchmark-lib.a diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 8660c653ae4c0c69e4f5ad8fae739c8c1db7414c..cda889bf502a535eac4249bbae645359cdb2135d 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -92,8 +92,17 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteSequenceRNNParams; +typedef enum { + kTfLiteFullyConnectedWeightsFormatDefault = 0, + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, +} TfLiteFullyConnectedWeightsFormat; + typedef struct { + // Parameters for FullyConnected version 1 or above. TfLiteFusedActivation activation; + + // Parameters for FullyConnected version 2 or above. + TfLiteFullyConnectedWeightsFormat weights_format; } TfLiteFullyConnectedParams; typedef enum { @@ -148,10 +157,20 @@ typedef struct { float beta; } TfLiteLocalResponseNormParams; +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + typedef struct { + // Parameters for LSTM version 1. TfLiteFusedActivation activation; float cell_clip; float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; } TfLiteLSTMParams; typedef struct { @@ -205,7 +224,7 @@ typedef struct { typedef struct { bool keep_dims; -} TfLiteMeanParams; +} TfLiteReducerParams; typedef struct { int num_splits; @@ -236,6 +255,14 @@ typedef struct { int stride_height; } TfLiteTransposeConvParams; +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + +typedef struct { + TfLiteType out_type; +} TfLiteShapeParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 7e285186f45a61a451fd7328b061e16059049ea5..a44e9182302d19acd1e1c183ed388531eec11d93 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. #ifdef __cplusplus extern "C" { @@ -93,10 +93,20 @@ typedef enum { kTfLiteBuiltinSlice = 65, kTfLiteBuiltinSin = 66, kTfLiteBuiltinTransposeConv = 67, + kTfLiteBuiltinSparseToDense = 68, + kTfLiteBuiltinTile = 69, + kTfLiteBuiltinExpandDims = 70, + kTfLiteBuiltinEqual = 71, + kTfLiteBuiltinNotEqual = 72, + kTfLiteBuiltinLog = 73, + kTfLiteBuiltinSum = 74, + kTfLiteBuiltinSqrt = 75, + kTfLiteBuiltinRsqrt = 76, + kTfLiteBuiltinShape = 77, + kTfLiteBuiltinPow = 78, } TfLiteBuiltinOperator; #ifdef __cplusplus } // extern "C" #endif // __cplusplus #endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ -} diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c index 5c6f5e72a47180cd98be46f60cfa8eaf28197806..7f2aa316f4a9a265b14a216a6ffa53c7f0757426 100644 --- a/tensorflow/contrib/lite/context.c +++ b/tensorflow/contrib/lite/context.c @@ -76,7 +76,7 @@ void TfLiteTensorFree(TfLiteTensor* t) { void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, TfLiteQuantizationParams quantization, char* buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, TfLiteTensor* tensor) { + const void* allocation, bool is_variable, TfLiteTensor* tensor) { TfLiteTensorFree(tensor); tensor->type = type; tensor->name = name; @@ -86,6 +86,7 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, tensor->bytes = size; tensor->allocation_type = allocation_type; tensor->allocation = allocation; + tensor->is_variable = is_variable; } void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 4eb66cc225eb04923be9aaa445a335ad822c8a6f..1ff8843fa78f48fc74b4d7e7d0cc4ae2a0d255af 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -39,6 +39,26 @@ extern "C" { typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; +// The list of external context types known to TF Lite. This list exists solely +// to avoid conflicts and to ensure ops can share the external contexts they +// need. Access to the external contexts is controled by one of the +// corresponding support files. +typedef enum { + kTfLiteEigenContext = 0, // include eigen_support.h to use. + kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. + kTfLiteMaxExternalContexts = 2 +} TfLiteExternalContextType; + +// An external context is a collection of information unrelated to the TF Lite +// framework, but useful to a subset of the ops. TF Lite knows very little +// about about the actual contexts, but it keeps a list of them, and is able to +// refresh them if configurations like the number of recommended threads +// change. +typedef struct { + TfLiteExternalContextType type; + TfLiteStatus (*Refresh)(struct TfLiteContext* context); +} TfLiteExternalContext; + // Forward declare so GetNode can use this is in Context. typedef struct _TfLiteRegistration TfLiteRegistration; typedef struct _TfLiteDelegate TfLiteDelegate; @@ -138,6 +158,8 @@ typedef enum { kTfLiteInt64 = 4, kTfLiteString = 5, kTfLiteBool = 6, + kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, } TfLiteType; // Parameters for asymmetric quantization. Quantized values can be converted @@ -148,7 +170,7 @@ typedef struct { int32_t zero_point; } TfLiteQuantizationParams; -// A union of points that points to memory for a given tensor. +// A union of pointers that points to memory for a given tensor. typedef union { int* i32; int64_t* i64; @@ -157,6 +179,8 @@ typedef union { const char* raw_const; uint8_t* uint8; bool* b; + int16_t* i16; + _Complex float* c64; } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped @@ -223,6 +247,9 @@ typedef struct { // delegate buffer. // WARNING: This is an // experimental interface that is subject to change. bool data_is_stale; + + // True if the tensor is a variable. + bool is_variable; } TfLiteTensor; // Free data memory of tensor `t`; @@ -235,9 +262,11 @@ void TfLiteTensorFree(TfLiteTensor* t); void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, TfLiteQuantizationParams quantization, char* buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, TfLiteTensor* tensor); + const void* allocation, bool is_variable, + TfLiteTensor* tensor); -// Resize the allocated data of a (dynamic) tensor. +// Resize the allocated data of a (dynamic) tensor. Tensors with allocation +// types other than kTfLiteDynamic will be ignored. void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); // A structure representing an instance of a node. @@ -330,10 +359,15 @@ typedef struct TfLiteContext { // eigen. int recommended_num_threads; - // TODO(ahentz): we should create a more general mechanism for this sort of - // library-global objects. - void* gemm_context; - void* eigen_context; + // Access external contexts by type. + // WARNING: This is an experimental interface that is subject to change. + TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*, + TfLiteExternalContextType); + // Set the value of a external context. Does not take ownership of the + // pointer. + // WARNING: This is an experimental interface that is subject to change. + void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType, + TfLiteExternalContext*); } TfLiteContext; typedef struct _TfLiteRegistration { @@ -368,6 +402,14 @@ typedef struct _TfLiteRegistration { // Returns kTfLiteOk on success. TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + // profiling_string is called during summarization of profiling information + // in order to group executions together. Providing a value here will cause a + // given op to appear multiple times is the profiling report. This is + // particularly useful for custom ops that can perform significantly + // different calculations depending on their `user-data`. + const char* (*profiling_string)(const TfLiteContext* context, + const TfLiteNode* node); + // Builtin codes. If this kernel refers to a builtin this is the code // of the builtin. This is so we can do marshaling to other frameworks like // NN API. diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h new file mode 100644 index 0000000000000000000000000000000000000000..abe802e34214caf4d5063da827b3aca4a82aa56d --- /dev/null +++ b/tensorflow/contrib/lite/context_util.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This provides a few C++ helpers that are useful for manipulating C structures +// in C++. +#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite +// C api uses. Can't use the google array_view, since we can't depend on even +// absl for embedded device reasons. +class TfLiteIntArrayView { + public: + // Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null + // and this view does not take ownership of it. + explicit TfLiteIntArrayView(const TfLiteIntArray* int_array) + : int_array_(int_array) {} + + TfLiteIntArrayView(const TfLiteIntArrayView&) = default; + TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default; + + typedef const int* const_iterator; + const_iterator begin() const { return int_array_->data; } + const_iterator end() const { return &int_array_->data[int_array_->size]; } + size_t size() const { return end() - begin(); } + + private: + const TfLiteIntArray* int_array_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..35a8f6ca4166e373ea1a0af5d4a013327b30d2b6 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = [ + "//visibility:public", +]) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "nnapi_delegate", + srcs = ["nnapi_delegate.cc"], + hdrs = ["nnapi_delegate.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + ], +) + +tf_cc_test( + name = "nnapi_delegate_test", + size = "small", + srcs = ["nnapi_delegate_test.cc"], + deps = [ + ":nnapi_delegate", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc new file mode 100644 index 0000000000000000000000000000000000000000..fd798c209e5112235cf6e351e231d4096006a8b0 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -0,0 +1,678 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/builtin_ops.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" + +#ifdef __ANDROID__ +#include +#endif + +namespace tflite { +namespace { + +// TODO(b/80621585): Consider printing error string, but don't for now to +// minimize binary size. +#define CHECK_NN(context, code) \ + if (code != ANEURALNETWORKS_NO_ERROR) { \ + context->ReportError(context, "NN API returned error (%d).\n", code); \ + return kTfLiteError; \ + } + +namespace { +int32_t GetAndroidSdkVersion() { +#ifdef __ANDROID__ + const char* sdkProp = "ro.build.version.sdk"; + char sdkVersion[PROP_VALUE_MAX]; + int length = __system_property_get(sdkProp, sdkVersion); + if (length != 0) { + for (int i = 0; i < length; ++i) { + int digit = sdkVersion[i] - '0'; + if (digit < 0 || digit > 9) { + // Non-numeric SDK version, assume it's higher then expected; + return std::numeric_limits::max(); + } + } + return atoi(sdkVersion); + } +#endif // __ANDROID__ + return 0; +} + +constexpr int32_t kMinSdkVersionForNNAPI = 27; +constexpr int32_t kMinSdkVersionForNNAPI11 = 28; +static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion(); + +} // namespace + +// RAII NN API Model Destructor for use with std::unique_ptr +struct NNFreeModel { + void operator()(ANeuralNetworksModel* model) { + ANeuralNetworksModel_free(model); + } +}; +// RAII NN API Compilation Destructor for use with std::unique_ptr +struct NNFreeCompilation { + void operator()(ANeuralNetworksCompilation* model) { + ANeuralNetworksCompilation_free(model); + } +}; + +// Track tensor indices to NN API tensor indices mapping. +class OperandMapping { + public: + // Given a TFLite index return the ANN index. If it doesn't exist + // return -1. + int lite_index_to_ann(int index) const { + if (index < lite_tensor_to_ann_tensor_.size()) + return lite_tensor_to_ann_tensor_[index]; + else + return -1; + } + + // NN API uses non tensor operands instead of structs. This creates one + // and returns the index. It uses a std::vector and resizes it as needed + // keeping -1 to unmapped values. Intermediate tensors likely will not + // be mapped. + int add_new_non_tensor_operand() { return next_ann_tensor_index_++; } + + // Add a new mapping from `tflite_index` and return the NN API tensor index. + int add_new_ann_tensor_index(int tflite_index) { + if (tflite_index >= lite_tensor_to_ann_tensor_.size()) { + lite_tensor_to_ann_tensor_.resize(tflite_index + 1, -1); + } + int new_tensor_index = next_ann_tensor_index_++; + lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index; + return new_tensor_index; + } + + private: + // Next index of ann tensor + int next_ann_tensor_index_ = 0; + + // Mapping from lite index. Use a std::vector for speed and code size + // rather than a map. + std::vector lite_tensor_to_ann_tensor_; +}; + +// Abstract builder for building an op in the NN API graph. This handles +// the disparity between TFLite and NN API operand types. NN API has singular +// operands for both tensors and parameters, and TFLite separates the two. +class NNAPIOpBuilder { + public: + NNAPIOpBuilder(TfLiteContext* context, OperandMapping* tensor_mapping, + ANeuralNetworksModel* nn_model) + : context_(context), + operand_mapping_(tensor_mapping), + nn_model_(nn_model) {} + + TfLiteStatus AddScalarInt32Operand(int32_t value) { + return AddScalarOperand(value, ANEURALNETWORKS_INT32); + } + + TfLiteStatus AddScalarFloat32Operand(float value) { + return AddScalarOperand(value, ANEURALNETWORKS_FLOAT32); + } + + TfLiteStatus AddVectorInt32Operand(const int32_t* values, + uint32_t num_values) { + return AddVectorOperand(values, num_values, + ANEURALNETWORKS_TENSOR_INT32); + } + + TfLiteStatus AddPoolingParams(void* data) { + auto builtin = reinterpret_cast(data); + AddScalarInt32Operand(builtin->padding); + AddScalarInt32Operand(builtin->stride_width); + AddScalarInt32Operand(builtin->stride_height); + AddScalarInt32Operand(builtin->filter_width); + AddScalarInt32Operand(builtin->filter_height); + AddScalarInt32Operand(builtin->activation); + return kTfLiteOk; + } + + TfLiteStatus AddTensorInput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_inputs_.push_back(ann_index); + return kTfLiteOk; + } + + TfLiteStatus AddTensorOutput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_outputs_.push_back(ann_index); + return kTfLiteOk; + } + + // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`. + // This returns the NN API tensor index corresponding to the created tensor. + // If another caller previously created a NN API tensor for `tensor_index` + // then the existing one is returned. + TfLiteStatus AddTensor(int tensor_index, int* ann_tensor_index_out) { + int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index); + if (ann_tensor_index != -1) { + *ann_tensor_index_out = ann_tensor_index; + return kTfLiteOk; + } + // Allocate a new tensor index + ann_tensor_index = operand_mapping_->add_new_ann_tensor_index(tensor_index); + + // Parameters needed for new type. + int32_t nn_type = 0; + float scale = 0.0f; + int32_t zeroPoint = 0; + TfLiteTensor* tensor = &context_->tensors[tensor_index]; + switch (tensor->type) { + case kTfLiteNoType: + // Tensors added during initialization of Ops don't have a type yet and + // should not be registered with the NNAPI. + *ann_tensor_index_out = -1; + return kTfLiteOk; + case kTfLiteFloat32: + nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; + break; + case kTfLiteUInt8: + nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; + break; + case kTfLiteInt32: + nn_type = ANEURALNETWORKS_TENSOR_INT32; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; + break; + default: + context_->ReportError(context_, "Logic error in NN API Delegate.\n"); + return kTfLiteError; + } + + ANeuralNetworksOperandType operand_type{ + nn_type, static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), scale, zeroPoint}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + + if (tensor->allocation_type == kTfLiteMmapRo) { + // TODO(b/80630405): Use NNAPIAllocation. + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_tensor_index, tensor->data.raw, + tensor->bytes)); + } + + *ann_tensor_index_out = ann_tensor_index; + return kTfLiteOk; + } + + // Finish emitting the op (of type `type`) into the NN API. + TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) { + // Actually add a NN API operation + CHECK_NN(context_, ANeuralNetworksModel_addOperation( + nn_model_, type, + static_cast(augmented_inputs_.size()), + augmented_inputs_.data(), + static_cast(augmented_outputs_.size()), + augmented_outputs_.data())); + augmented_inputs_.clear(); + augmented_outputs_.clear(); + return kTfLiteOk; + } + + private: + template + TfLiteStatus AddScalarOperand(T value, int32_t nn_type) { + ANeuralNetworksOperandType operand_type{.type = nn_type}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, &value, sizeof(T))); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + + template + TfLiteStatus AddVectorOperand(const T* values, uint32_t num_values, + int32_t nn_type) { + ANeuralNetworksOperandType operand_type{ + .type = nn_type, .dimensionCount = 1, .dimensions = &num_values}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, + ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, values, sizeof(T) * num_values)); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + + // TfLiteContext for error handling. Must be named context for macros to + // work. + TfLiteContext* context_; + + // Tracks relationship between indices + OperandMapping* operand_mapping_; + + // The model + ANeuralNetworksModel* nn_model_; + + // Inputs and outputs for the current op. These are augmented in the sense + // that NN API uses operands for all arguments, not just tensors, unlike + // TensorFlow lite. + std::vector augmented_inputs_; + std::vector augmented_outputs_; +}; + +// The kernel that represents the subgraph of TF Lite being run on NN API. +class NNAPIDelegateKernel { + public: + NNAPIDelegateKernel() = default; + + typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*, + NNAPIOpBuilder* builder, + TfLiteNode* node); + + // Return a function that knows how to translate a node into its operands + // when called. You can use this function to see if a node is supported + // (i.e. that MappingFn is not nullptr). + MappingFn Map(TfLiteContext* context, int builtin_code, int version, + TfLiteNode* node) { + switch (builtin_code) { + case kTfLiteBuiltinAdd: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_ADD; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMul: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_MUL; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinAveragePool2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_AVERAGE_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMaxPool2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_MAX_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinL2Pool2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_L2_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinConv2d: + if (version == 1) { + auto builtin = + reinterpret_cast(node->builtin_data); + if (builtin->dilation_width_factor != 1 || + builtin->dilation_height_factor != 1 || node->inputs->size != 3) { + // NNAPI does not support dilated Conv2D. + return nullptr; + } + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_CONV_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDepthwiseConv2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->depth_multiplier); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_DEPTHWISE_CONV_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinFullyConnected: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_FULLY_CONNECTED; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSoftmax: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarFloat32Operand(builtin->beta); + return ANEURALNETWORKS_SOFTMAX; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinReshape: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RESHAPE; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSqueeze: + // Squeeze requires NNAPI1.1. + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + // Note that we add the squeeze dimensions even if the dimensions + // were unspecified (empty), as NNAPI requires the operand. + builder->AddVectorInt32Operand( + builtin->squeeze_dims, + static_cast(builtin->num_squeeze_dims)); + return ANEURALNETWORKS_SQUEEZE; + }; + } else { + return nullptr; + } + break; + default: + return nullptr; + } + } + + // Initialize the kernel (a NN model). + TfLiteStatus Init(TfLiteContext* context, + const TfLiteDelegateParams* params) { + for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { + nodes_.push_back(node_index); + } + + if (!nn_model_) { + ANeuralNetworksModel* model; + CHECK_NN(context, ANeuralNetworksModel_create(&model)); + nn_model_.reset(model); + + TF_LITE_ENSURE_STATUS( + BuildGraph(context, params->input_tensors, params->output_tensors)); + } + + if (!nn_compilation_) { + ANeuralNetworksCompilation* compilation; + CHECK_NN(context, ANeuralNetworksCompilation_create(nn_model_.get(), + &compilation)); + CHECK_NN(context, ANeuralNetworksCompilation_finish(compilation)); + nn_compilation_.reset(compilation); + } + return kTfLiteOk; + } + + TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) { + ANeuralNetworksExecution* execution = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_create(nn_compilation_.get(), + &execution)); + + // Set the input tensor buffers. Note: we access tflite tensors using + // absolute indices but NN api indices inputs by relative indices. + int relative_input_index = 0; + for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { + TfLiteTensor* tensor = &context->tensors[absolute_input_index]; + // TODO(miaowang): make sure the delegation works with dequantized weights + // as intermediate tensors. + if (tensor->allocation_type != kTfLiteMmapRo) { + CHECK_NN(context, ANeuralNetworksExecution_setInput( + execution, relative_input_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_input_index++; + } + } + + // Set the output tensor buffers. + int relative_output_index = 0; + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TfLiteTensor* tensor = &context->tensors[output_index]; + CHECK_NN(context, ANeuralNetworksExecution_setOutput( + execution, relative_output_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_output_index++; + } + // Invoke ANN in blocking fashion. + ANeuralNetworksEvent* event = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event)); + CHECK_NN(context, ANeuralNetworksEvent_wait(event)); + ANeuralNetworksEvent_free(event); + ANeuralNetworksExecution_free(execution); + + return kTfLiteOk; + } + + private: + // ANN API state. + std::unique_ptr nn_model_; + std::unique_ptr + nn_compilation_; + // Node indices that this delegate is responsible for. Indices here + // indexes into the nodes array in the TfLiteContext. + std::vector nodes_; + // Track indices we use + OperandMapping operand_mapping_; + + TfLiteStatus AddOpsAndTensors(TfLiteContext* context) { + // The operand builder allows creating a single op. We create it at this + // reduced power position rather than in the for loop to avoid reallocating + // the vectors. + NNAPIOpBuilder builder(context, &operand_mapping_, nn_model_.get()); + // Add Tensors + // allocate outside to avoid realloc + for (auto node_index : nodes_) { + // Obtain the op and registration. + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + // Map inputs to NN API tensor indices. + for (auto input_index : TfLiteIntArrayView(node->inputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + } + // Get op type and operands + int nn_op_type = Map(context, reg->builtin_code, reg->version, node)( + context, &builder, node); + // Map outputs to NN API tensor indices. + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); + } + + builder.FinalizeAddOperation(nn_op_type); + } + return kTfLiteOk; + } + + TfLiteStatus BuildGraph(TfLiteContext* context, + const TfLiteIntArray* input_tensors, + const TfLiteIntArray* output_tensors) { + // Build the ops and tensors. + TF_LITE_ENSURE_STATUS(AddOpsAndTensors(context)); + // Map input and output tensor indices to ANN + std::vector inputs; + inputs.reserve(input_tensors->size); + std::vector outputs; + outputs.reserve(output_tensors->size); + // Make the TensorFlow lite inputs and outputs to ann_indices. + for (int i : TfLiteIntArrayView(input_tensors)) { + // Constant tensors are not NNAPI inputs. + if (context->tensors[i].allocation_type != kTfLiteMmapRo) { + inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + } + } + for (int i : TfLiteIntArrayView(output_tensors)) + outputs.push_back(operand_mapping_.lite_index_to_ann(i)); + // Tell ANN to declare inputs/outputs + CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs( + nn_model_.get(), inputs.size(), inputs.data(), + outputs.size(), outputs.data())); + // Finalize the model + CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get())); + + return kTfLiteOk; + } +}; + +} // namespace + +// Return a NN API Delegate struct that can check for support of ops. +TfLiteDelegate* NnApiDelegate() { + static TfLiteDelegate delegate = { + .data_ = nullptr, + .Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + // Do not check nodes_ if NN API is unavailable. + if (kAndroidSdkVersion < kMinSdkVersionForNNAPI || !NNAPIExists()) { + return kTfLiteOk; + } + + std::vector supported_nodes(1); + // We don't care about all nodes_, we only care about ones in the + // current plan. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + int total_supported_nodes = 0; + + // Check for every node if it is supported + // TODO(b/80625235): Fix this to do more careful checking of versioning. + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + NNAPIDelegateKernel dummy_kernel; + if (dummy_kernel.Map(context, registration->builtin_code, + registration->version, node)) { + supported_nodes.push_back(node_index); + } + total_supported_nodes += 1; + } + // Put the size at the beginning of the array. + supported_nodes[0] = supported_nodes.size() - 1; + + // NN API Delegate Registration (the pseudo kernel that will invoke NN + // API subgraphs) + static const TfLiteRegistration nnapi_delegate_kernel = { + .init = [](TfLiteContext* context, const char* buffer, + size_t length) -> void* { + const TfLiteDelegateParams* params = + reinterpret_cast(buffer); + NNAPIDelegateKernel* kernel_state = new NNAPIDelegateKernel; + kernel_state->Init(context, params); + return kernel_state; + }, + + .free = [](TfLiteContext* context, void* buffer) -> void { + delete reinterpret_cast(buffer); + }, + + .prepare = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + // Since the underlying resize happened ahead of delegation + // worked. This does nothing. + return kTfLiteOk; + }, + + .invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + NNAPIDelegateKernel* state = + reinterpret_cast(node->user_data); + return state->Invoke(context, node); + }, + + .builtin_code = kTfLiteBuiltinDelegate, + }; + + // Request TFLite to partition the graph and make kernels + // for each independent subgraph a new nnapi_delegate_kernel. + context->ReplaceSubgraphsWithDelegateKernels( + context, nnapi_delegate_kernel, + reinterpret_cast(supported_nodes.data()), + delegate); + return kTfLiteOk; + }}; + + return &delegate; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h new file mode 100644 index 0000000000000000000000000000000000000000..44cca2fd285370d700525f98ba33c861fb97be1e --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Return a delegate that can be used to use the NN API. +// e.g. +// NnApiDelegate* delegate = NnApiDelegate(); +// interpreter->ModifyGraphWithDelegate(&delegate); +// NnApiDelegate() returns a singleton, so you should not free this +// pointer or worry about its lifetime. +TfLiteDelegate* NnApiDelegate(); +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..aad10c9ce730a2e90481a123a1e3e323cfb2bd42 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -0,0 +1,676 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +// TODO(b/110368244): figure out how to share the existing tests in kernels/ but +// with the delegation on. Also, add more unit tests to improve code coverage. + +class FloatAddOpModel : public SingleOpModel { + public: + FloatAddOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +// Do a test with the NN API using no activation. +TEST(NNAPIDelegate, AddWithNoActivation) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +// Do a test with the NN api with relu. +TEST(NNAPIDelegate, AddWithRelu) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); +} + +class FloatMulOpModel : public SingleOpModel { + public: + FloatMulOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, MulWithNoActivation) { + FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); +} + +class FloatPoolingOpModel : public SingleOpModel { + public: + FloatPoolingOpModel(BuiltinOperator type, const TensorData& input, + int filter_width, int filter_height, + const TensorData& output) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + output_ = AddOutput(output); + + SetBuiltinOp( + type, BuiltinOptions_Pool2DOptions, + CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width, + filter_height, ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, AveragePoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75})); +} + +TEST(NNAPIDelegate, MaxPoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10})); +} + +TEST(NNAPIDelegate, L2PoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); +} + +class BaseConvolutionOpModel : public SingleOpModel { + public: + BaseConvolutionOpModel( + const TensorData& input, const TensorData& filter, + const TensorData& output, int stride_width = 2, int stride_height = 2, + enum Padding padding = Padding_VALID, + enum ActivationFunctionType activation = ActivationFunctionType_NONE, + int dilation_width_factor = 1, int dilation_height_factor = 1) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[0]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, + CreateConv2DOptions( + builder_, padding, stride_width, stride_height, activation, + dilation_width_factor, dilation_height_factor) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class ConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this tests we set the input and output scales so that the results +// match exactly the 'non-quantized' version. +TEST(NNAPIDelegate, SimpleTestQuantized) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 145, 129, 132, // + 145, 129, 132, // + 144, 131, 130, // + 164, 131, 130, // + })); +} + +TEST(NNAPIDelegate, Conv2DWithNoActivation) { + ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + })); +} + +class DepthwiseConvolutionOpModel : public SingleOpModel { + public: + DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter, + const TensorData& output) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[3]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + int input_depth = GetShape(input_)[3]; + int output_depth = GetShape(filter_)[3]; + int depth_mul = output_depth / input_depth; + + SetBuiltinOp( + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOptions_DepthwiseConv2DOptions, + CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul, + ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) { + DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 71, -34, 99, -20, // + 91, -26, 127, -4, // + })); +} + +class FloatFullyConnectedOpModel : public SingleOpModel { + public: + FloatFullyConnectedOpModel(int units, int batches, const TensorData& input, + const TensorData& output = {TensorType_FLOAT32}) + : batches_(batches), units_(units) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + int total_input_size = 1; + for (int i = 0; i < input.shape.size(); ++i) { + total_input_size *= input.shape[i]; + } + input_size_ = total_input_size / batches_; + + input_ = AddInput(input); + weights_ = + AddInput({input.type, {units_, input_size_}, input.min, input.max}); + + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {units_}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(weights_); + TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + SetBuiltinOp( + BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weights_; + int bias_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(NNAPIDelegate, FullyConnectedSimpleTest) { + FloatFullyConnectedOpModel m(/*units=*/3, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 10}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); +} + +class SoftmaxOpModel : public SingleOpModel { + public: + SoftmaxOpModel(int batches, int size, float beta) + : batches_(batches), input_size_(size), beta_(beta) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, beta_).Union()); + BuildInterpreter({{batches_, input_size_}}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; + + int batches_; + int input_size_; + float beta_; +}; + +TEST(NNAPIDelegate, SoftmaxSimpleTest) { + SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231}, + 1e-6))); +} + +class ReshapeOpModel : public SingleOpModel { + public: + ReshapeOpModel(std::initializer_list input_shape, + std::initializer_list new_shape) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(TensorType_FLOAT32); + new_shape_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) + .Union()); + BuildInterpreter({input_shape, {static_cast(new_shape.size())}}); + PopulateTensor(new_shape_, new_shape); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int new_shape_; + int output_; +}; + +TEST(NNAPIDelegate, ReshapeSimpleTest) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + +class SqueezeOpModel : public SingleOpModel { + public: + SqueezeOpModel(const TensorData& input, const TensorData& output, + std::initializer_list axis) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp( + BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions, + CreateSqueezeOptions(builder_, builder_.CreateVector(axis)) + .Union()); + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int new_shape_; + int output_; +}; + +TEST(NNAPIDelegate, SqueezeSimpleTest) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}}, + {}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +TEST(NNAPIDelegate, SqueezeWithAxisTest) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}}, + {2}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 436c3e1d4cad5e6ee355d7e9cf8ee7da1a8385ce..840015a7fad173dbd2ea353786871dd4e89abb98 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -30,9 +30,7 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once -# the archive has been propagated in mirror.bazel.build. -GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD index 57000072561303e8457f61b1ebe95d382fc01f10..4d2437e7d3714e1b8b427b0c6197b295c0355b07 100644 --- a/tensorflow/contrib/lite/examples/android/BUILD +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow camera demo app for Android. +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -24,28 +26,29 @@ cc_library( android_binary( name = "tflite_demo", srcs = glob([ - "src/**/*.java", + "app/src/main/java/**/*.java", ]), # Package assets from assets dir as well as all model targets. # Remove undesired models (and corresponding Activities in source) # to reduce APK size. assets = [ - "//tensorflow/contrib/lite/examples/android/assets:labels_mobilenet_quant_v1_224.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", - "//tensorflow/contrib/lite/examples/android/assets:conv_actions_labels.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:conv_actions_labels.txt", "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", - "//tensorflow/contrib/lite/examples/android/assets:box_priors.txt", - "//tensorflow/contrib/lite/examples/android/assets:coco_labels_list.txt", + "@tflite_mobilenet_ssd_quant//:detect.tflite", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:box_priors.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:coco_labels_list.txt", ], assets_dir = "", custom_package = "org.tensorflow.lite.demo", inline_constants = 1, - manifest = "AndroidManifest.xml", + manifest = "app/src/main/AndroidManifest.xml", nocompress_extensions = [ ".tflite", ], - resource_files = glob(["res/**"]), + resource_files = glob(["app/src/main/res/**"]), tags = [ "manual", "notap", @@ -55,31 +58,3 @@ android_binary( "//tensorflow/contrib/lite/java:tensorflowlite", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "bin/**", - "gen/**", - "gradleBuild/**", - "libs/**", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -filegroup( - name = "java_files", - srcs = glob(["src/**/*.java"]), -) - -filegroup( - name = "resource_files", - srcs = glob(["res/**"]), -) - -exports_files(["AndroidManifest.xml"]) diff --git a/tensorflow/contrib/lite/examples/android/android.iml b/tensorflow/contrib/lite/examples/android/android.iml new file mode 100644 index 0000000000000000000000000000000000000000..f0a5ac2bf4cdfb7c98f5704310fbf2f16e9065a2 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/android.iml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/contrib/lite/examples/android/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..1ffb9dd377730bb3dc872cbf1548fa29ffaa0949 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/build.gradle @@ -0,0 +1,60 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion '26.0.2' + defaultConfig { + applicationId "org.tensorflow.lite.demo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + // Remove this block. + jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +// import DownloadModels task +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +project.ext.TMP_DIR = project.buildDir.toString() + '/downloads' + +// Download default models; if you wish to use your own models then +// place them in the "assets" directory and comment out this line. +apply from: "download-models.gradle" + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.androidx.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/examples/android/app/download-models.gradle b/tensorflow/contrib/lite/examples/android/app/download-models.gradle new file mode 100644 index 0000000000000000000000000000000000000000..c100e37c16f38a65f7b1f64a3f6e3eaa1477e8eb --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/download-models.gradle @@ -0,0 +1,74 @@ +/* + * download-models.gradle + * Downloads model files from ${MODEL_URL} into application's asset folder + * Input: + * project.ext.TMP_DIR: absolute path to hold downloaded zip files + * project.ext.ASSET_DIR: absolute path to save unzipped model files + * Output: + * 3 model files will be downloaded into given folder of ext.ASSET_DIR + */ +// hard coded model files +// LINT.IfChange + +def models = ['conv_actions_tflite.zip', + 'mobilenet_ssd_tflite_v1.zip', + 'mobilenet_v1_224_android_quant_2017_11_08.zip', + 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip'] +// LINT.ThenChange(//tensorflow/contrib/lite/examples/android/BUILD) + +// Root URL for model archives +def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite' + +buildscript { + repositories { + jcenter() + } + dependencies { + classpath 'de.undercouch:gradle-download-task:3.2.0' + } +} + +import de.undercouch.gradle.tasks.download.Download +task downloadFile(type: Download){ + for (f in models) { + def modelUrl = MODEL_URL + "/" + f + println "Downloading ${f} from ${modelUrl}" + src modelUrl + } + + dest new File(project.ext.TMP_DIR) + overwrite true +} + +task extractModels(type: Copy) { + for (f in models) { + def localFile = f.split("/")[-1] + from zipTree(project.ext.TMP_DIR + '/' + localFile) + } + + into file(project.ext.ASSET_DIR) + fileMode 0644 + exclude '**/LICENSE' + + def needDownload = false + for (f in models) { + def localFile = f.split("/")[-1] + if (!(new File(project.ext.TMP_DIR + '/' + localFile)).exists()) { + needDownload = true + } + } + + if (needDownload) { + dependsOn downloadFile + } +} + +tasks.whenTaskAdded { task -> + if (task.name == 'assembleDebug') { + task.dependsOn 'extractModels' + } + if (task.name == 'assembleRelease') { + task.dependsOn 'extractModels' + } +} + diff --git a/tensorflow/contrib/lite/examples/android/AndroidManifest.xml b/tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/AndroidManifest.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml diff --git a/tensorflow/contrib/lite/examples/android/assets/BUILD b/tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/BUILD rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD diff --git a/tensorflow/contrib/lite/examples/android/assets/box_priors.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/box_priors.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java similarity index 96% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java index de997e454a1e33254cb7c2c932ca79d0072539fa..87160f6b3fb8c0d24e5df131d9becbb3eb6e2980 100644 --- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java +++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java @@ -1,5 +1,5 @@ /* - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * Copyright 2018 The TensorFlow Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,9 +50,10 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable // Configuration values for the prepackaged SSD model. private static final int TF_OD_API_INPUT_SIZE = 300; - private static final String TF_OD_API_MODEL_FILE = "mobilenet_ssd.tflite"; + private static final boolean TF_OD_API_IS_QUANTIZED = true; + private static final String TF_OD_API_MODEL_FILE = "detect.tflite"; private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt"; - + // Which detection model to use: by default uses Tensorflow Object Detection API frozen // checkpoints. private enum DetectorMode { @@ -107,7 +108,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable try { detector = TFLiteObjectDetectionAPIModel.create( - getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE); + getAssets(), + TF_OD_API_MODEL_FILE, + TF_OD_API_LABELS_FILE, + TF_OD_API_INPUT_SIZE, + TF_OD_API_IS_QUANTIZED); cropSize = TF_OD_API_INPUT_SIZE; } catch (final IOException e) { LOGGER.e("Exception initializing classifier!", e); diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java similarity index 50% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java index bfb4a0a04bc90566736864bf62340d1032961858..9eb21de9d03e387d3c25b38171e154a358dc81ce 100644 --- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java +++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java @@ -25,15 +25,14 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.util.ArrayList; -import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.PriorityQueue; -import java.util.StringTokenizer; import java.util.Vector; import org.tensorflow.demo.env.Logger; import org.tensorflow.lite.Interpreter; @@ -46,32 +45,35 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { private static final Logger LOGGER = new Logger(); // Only return this many results. - private static final int NUM_RESULTS = 1917; - private static final int NUM_CLASSES = 91; - - private static final float Y_SCALE = 10.0f; - private static final float X_SCALE = 10.0f; - private static final float H_SCALE = 5.0f; - private static final float W_SCALE = 5.0f; - + private static final int NUM_DETECTIONS = 10; + private boolean isModelQuantized; + // Float model + private static final float IMAGE_MEAN = 128.0f; + private static final float IMAGE_STD = 128.0f; + // Number of threads in the java app + private static final int NUM_THREADS = 4; // Config values. private int inputSize; - - private final float[][] boxPriors = new float[4][NUM_RESULTS]; - // Pre-allocated buffers. private Vector labels = new Vector(); private int[] intValues; + // outputLocations: array of shape [Batchsize, NUM_DETECTIONS,4] + // contains the location of detected boxes private float[][][] outputLocations; - private float[][][] outputClasses; - - float[][][][] img; + // outputClasses: array of shape [Batchsize, NUM_DETECTIONS] + // contains the classes of detected boxes + private float[][] outputClasses; + // outputScores: array of shape [Batchsize, NUM_DETECTIONS] + // contains the scores of detected boxes + private float[][] outputScores; + // numDetections: array of shape [Batchsize] + // contains the number of detected boxes + private float[] numDetections; + + private ByteBuffer imgData; private Interpreter tfLite; - private float expit(final float x) { - return (float) (1. / (1. + Math.exp(-x))); - } /** Memory-map the model file in Assets. */ private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename) @@ -84,77 +86,24 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } - private void loadCoderOptions( - final AssetManager assetManager, final String locationFilename, final float[][] boxPriors) - throws IOException { - // Try to be intelligent about opening from assets or sdcard depending on prefix. - final String assetPrefix = "file:///android_asset/"; - InputStream is; - if (locationFilename.startsWith(assetPrefix)) { - is = assetManager.open(locationFilename.split(assetPrefix, -1)[1]); - } else { - is = new FileInputStream(locationFilename); - } - - final BufferedReader reader = new BufferedReader(new InputStreamReader(is)); - - for (int lineNum = 0; lineNum < 4; ++lineNum) { - String line = reader.readLine(); - final StringTokenizer st = new StringTokenizer(line, ", "); - int priorIndex = 0; - while (st.hasMoreTokens()) { - final String token = st.nextToken(); - try { - final float number = Float.parseFloat(token); - boxPriors[lineNum][priorIndex++] = number; - } catch (final NumberFormatException e) { - // Silently ignore. - } - } - if (priorIndex != NUM_RESULTS) { - throw new RuntimeException( - "BoxPrior length mismatch: " + priorIndex + " vs " + NUM_RESULTS); - } - } - - LOGGER.i("Loaded box priors!"); - } - - void decodeCenterSizeBoxes(float[][][] predictions) { - for (int i = 0; i < NUM_RESULTS; ++i) { - float ycenter = predictions[0][i][0] / Y_SCALE * boxPriors[2][i] + boxPriors[0][i]; - float xcenter = predictions[0][i][1] / X_SCALE * boxPriors[3][i] + boxPriors[1][i]; - float h = (float) Math.exp(predictions[0][i][2] / H_SCALE) * boxPriors[2][i]; - float w = (float) Math.exp(predictions[0][i][3] / W_SCALE) * boxPriors[3][i]; - - float ymin = ycenter - h / 2.f; - float xmin = xcenter - w / 2.f; - float ymax = ycenter + h / 2.f; - float xmax = xcenter + w / 2.f; - - predictions[0][i][0] = ymin; - predictions[0][i][1] = xmin; - predictions[0][i][2] = ymax; - predictions[0][i][3] = xmax; - } - } - /** * Initializes a native TensorFlow session for classifying images. * * @param assetManager The asset manager to be used to load assets. * @param modelFilename The filepath of the model GraphDef protocol buffer. * @param labelFilename The filepath of label file for classes. + * @param inputSize The size of image input + * @param isQuantized Boolean representing model is quantized or not */ public static Classifier create( final AssetManager assetManager, final String modelFilename, final String labelFilename, - final int inputSize) throws IOException { + final int inputSize, + final boolean isQuantized) + throws IOException { final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel(); - d.loadCoderOptions(assetManager, "file:///android_asset/box_priors.txt", d.boxPriors); - InputStream labelsInput = null; String actualFilename = labelFilename.split("file:///android_asset/")[1]; labelsInput = assetManager.open(actualFilename); @@ -175,12 +124,23 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { throw new RuntimeException(e); } + d.isModelQuantized = isQuantized; // Pre-allocate buffers. - d.img = new float[1][inputSize][inputSize][3]; - + int numBytesPerChannel; + if (isQuantized) { + numBytesPerChannel = 1; // Quantized + } else { + numBytesPerChannel = 4; // Floating point + } + d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel); + d.imgData.order(ByteOrder.nativeOrder()); d.intValues = new int[d.inputSize * d.inputSize]; - d.outputLocations = new float[1][NUM_RESULTS][4]; - d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; + + d.tfLite.setNumThreads(NUM_THREADS); + d.outputLocations = new float[1][NUM_DETECTIONS][4]; + d.outputClasses = new float[1][NUM_DETECTIONS]; + d.outputScores = new float[1][NUM_DETECTIONS]; + d.numDetections = new float[1]; return d; } @@ -196,25 +156,37 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { // on the provided parameters. bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + imgData.rewind(); for (int i = 0; i < inputSize; ++i) { for (int j = 0; j < inputSize; ++j) { - int pixel = intValues[j * inputSize + i]; - img[0][j][i][2] = (float) (pixel & 0xFF) / 128.0f - 1.0f; - img[0][j][i][1] = (float) ((pixel >> 8) & 0xFF) / 128.0f - 1.0f; - img[0][j][i][0] = (float) ((pixel >> 16) & 0xFF) / 128.0f - 1.0f; + int pixelValue = intValues[i * inputSize + j]; + if (isModelQuantized) { + // Quantized model + imgData.put((byte) ((pixelValue >> 16) & 0xFF)); + imgData.put((byte) ((pixelValue >> 8) & 0xFF)); + imgData.put((byte) (pixelValue & 0xFF)); + } else { // Float model + imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + } } } Trace.endSection(); // preprocessBitmap // Copy the input data into TensorFlow. Trace.beginSection("feed"); - outputLocations = new float[1][NUM_RESULTS][4]; - outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; + outputLocations = new float[1][NUM_DETECTIONS][4]; + outputClasses = new float[1][NUM_DETECTIONS]; + outputScores = new float[1][NUM_DETECTIONS]; + numDetections = new float[1]; - Object[] inputArray = {img}; + Object[] inputArray = {imgData}; Map outputMap = new HashMap<>(); outputMap.put(0, outputLocations); outputMap.put(1, outputClasses); + outputMap.put(2, outputScores); + outputMap.put(3, numDetections); Trace.endSection(); // Run the inference call. @@ -222,56 +194,26 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { tfLite.runForMultipleInputsOutputs(inputArray, outputMap); Trace.endSection(); - decodeCenterSizeBoxes(outputLocations); - - // Find the best detections. - final PriorityQueue pq = - new PriorityQueue( - 1, - new Comparator() { - @Override - public int compare(final Recognition lhs, final Recognition rhs) { - // Intentionally reversed to put high confidence at the head of the queue. - return Float.compare(rhs.getConfidence(), lhs.getConfidence()); - } - }); - - // Scale them back to the input size. - for (int i = 0; i < NUM_RESULTS; ++i) { - float topClassScore = -1000f; - int topClassScoreIndex = -1; - - // Skip the first catch-all class. - for (int j = 1; j < NUM_CLASSES; ++j) { - float score = expit(outputClasses[0][i][j]); - - if (score > topClassScore) { - topClassScoreIndex = j; - topClassScore = score; - } - } - - if (topClassScore > 0.001f) { - final RectF detection = - new RectF( - outputLocations[0][i][1] * inputSize, - outputLocations[0][i][0] * inputSize, - outputLocations[0][i][3] * inputSize, - outputLocations[0][i][2] * inputSize); - - pq.add( - new Recognition( - "" + i, - labels.get(topClassScoreIndex), - outputClasses[0][i][topClassScoreIndex], - detection)); - } - } - - final ArrayList recognitions = new ArrayList(); - for (int i = 0; i < Math.min(pq.size(), 10); ++i) { - Recognition recog = pq.poll(); - recognitions.add(recog); + // Show the best detections. + // after scaling them back to the input size. + final ArrayList recognitions = new ArrayList<>(NUM_DETECTIONS); + for (int i = 0; i < NUM_DETECTIONS; ++i) { + final RectF detection = + new RectF( + outputLocations[0][i][1] * inputSize, + outputLocations[0][i][0] * inputSize, + outputLocations[0][i][3] * inputSize, + outputLocations[0][i][2] * inputSize); + // SSD Mobilenet V1 Model assumes class 0 is background class + // in label file and class labels start from 1 to number_of_classes+1, + // while outputClasses correspond to class index from 0 to number_of_classes + int labelOffset = 1; + recognitions.add( + new Recognition( + "" + i, + labels.get((int) outputClasses[0][i] + labelOffset), + outputScores[0][i], + detection)); } Trace.endSection(); // "recognizeImage" return recognitions; diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java diff --git a/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable/border.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable/border.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/attrs.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/attrs.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/base-strings.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/base-strings.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/colors.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/colors.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/strings.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/strings.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/contrib/lite/examples/android/build.gradle index 0d4de358156a5d139e35cc542b8d36ab24e763b9..a47fa4bbf6730c7d1269737564381c8464224713 100644 --- a/tensorflow/contrib/lite/examples/android/build.gradle +++ b/tensorflow/contrib/lite/examples/android/build.gradle @@ -1,52 +1,23 @@ -apply plugin: 'com.android.application' +// Top-level build file where you can add configuration options common to all sub-projects/modules. -android { - compileSdkVersion 26 - buildToolsVersion "26.0.1" - defaultConfig { - applicationId "org.tensorflow.lite.demo" - minSdkVersion 15 - targetSdkVersion 26 - versionCode 1 - versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" - - // Remove this block. - jackOptions { - enabled true - } - } - lintOptions { - abortOnError false - } - buildTypes { - release { - minifyEnabled false - proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' - } - } - aaptOptions { - noCompress "tflite" +buildscript { + repositories { + jcenter() } + dependencies { + classpath 'com.android.tools.build:gradle:3.0.1' - compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files } } -repositories { - maven { - url 'https://google.bintray.com/tensorflow' +allprojects { + repositories { + jcenter() } } -dependencies { - compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { - exclude group: 'com.android.support', module: 'support-annotations' - }) - compile 'org.tensorflow:tensorflow-lite:+' - - testCompile 'junit:junit:4.12' +task clean(type: Delete) { + delete rootProject.buildDir } diff --git a/tensorflow/contrib/lite/examples/android/settings.gradle b/tensorflow/contrib/lite/examples/android/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e7b4def49cb53d9aa04228dd3edb14c9e635e003 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/settings.gradle @@ -0,0 +1 @@ +include ':app' diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index 59b575ab6eec50b90768a26f8b2075a70c98bfbb..d74e275f0439b1ce56b29e0eadff5f211f6a4faa 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -25,8 +25,8 @@ #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" #include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" #define LOG(x) std::cerr diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm index 32da7f7e4fce5cafc3c4746e5847315172542fc9..0ab7aa25d0b4e6d2c02e61ec1d82b85258b3dfbc 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -24,8 +24,8 @@ #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" #include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" #include "ios_image_load.h" diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD index 9322e186a280e932a2441ab16ac8579d9ab67ee2..c61445114ecc6dfbe4f2b6ab666b28a8aa746be3 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -53,19 +53,18 @@ cc_library( ], ) -# TODO(ahentz): Test disabled as it has a memory leek from read_bmp -# cc_test( -# name = "label_image_test", -# srcs = [ -# "get_top_n.h", -# "get_top_n_impl.h", -# "label_image_test.cc", -# ], -# data = [ -# "testdata/grace_hopper.bmp", -# ], -# deps = [ -# ":bitmap_helpers", -# "//testing/base/public:gunit", -# ], -# ) +cc_test( + name = "label_image_test", + srcs = [ + "get_top_n.h", + "get_top_n_impl.h", + "label_image_test.cc", + ], + data = [ + "testdata/grace_hopper.bmp", + ], + deps = [ + ":bitmap_helpers", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc index 0b38cd38c83927c65d251b9356301b6bef7521f2..2735d1f5ea4e2a104f71a3a6f874d9acb2f48142 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc @@ -28,8 +28,9 @@ limitations under the License. namespace tflite { namespace label_image { -uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, - int width, int height, int channels, bool top_down) { +std::vector decode_bmp(const uint8_t* input, int row_size, int width, + int height, int channels, bool top_down) { + std::vector output(height * width * channels); for (int i = 0; i < height; i++) { int src_pos; int dst_pos; @@ -66,12 +67,11 @@ uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, } } } - return output; } -uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, - int* channels, Settings* s) { +std::vector read_bmp(const std::string& input_bmp_name, int* width, + int* height, int* channels, Settings* s) { int begin, end; std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary); @@ -87,14 +87,15 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, if (s->verbose) LOG(INFO) << "len: " << len << "\n"; - const uint8_t* img_bytes = new uint8_t[len]; + std::vector img_bytes(len); file.seekg(0, std::ios::beg); - file.read((char*)img_bytes, len); + file.read(reinterpret_cast(img_bytes.data()), len); const int32_t header_size = - *(reinterpret_cast(img_bytes + 10)); - *width = *(reinterpret_cast(img_bytes + 18)); - *height = *(reinterpret_cast(img_bytes + 22)); - const int32_t bpp = *(reinterpret_cast(img_bytes + 28)); + *(reinterpret_cast(img_bytes.data() + 10)); + *width = *(reinterpret_cast(img_bytes.data() + 18)); + *height = *(reinterpret_cast(img_bytes.data() + 22)); + const int32_t bpp = + *(reinterpret_cast(img_bytes.data() + 28)); *channels = bpp / 8; if (s->verbose) @@ -110,10 +111,9 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, bool top_down = (*height < 0); // Decode image, allocating tensor once the image size is known - uint8_t* output = new uint8_t[abs(*height) * *width * *channels]; const uint8_t* bmp_pixels = &img_bytes[header_size]; - return decode_bmp(bmp_pixels, row_size, output, *width, abs(*height), - *channels, top_down); + return decode_bmp(bmp_pixels, row_size, *width, abs(*height), *channels, + top_down); } } // namespace label_image diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h index 97343dde6b31694e5b2de20b35a7083fb8fe4a0e..5fc75b1f7274c14d49e4a26d6ce4902c037afa6b 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h @@ -22,8 +22,8 @@ limitations under the License. namespace tflite { namespace label_image { -uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, - int* channels, Settings* s); +std::vector read_bmp(const std::string& input_bmp_name, int* width, + int* height, int* channels, Settings* s); template void resize(T* out, uint8_t* in, int image_height, int image_width, diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 966fcd2a31fd4d4ff2c3e91633550a8effa81ee8..86d7d1cc4a625243791d5e7d5b746526a58efb6d 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -138,8 +138,8 @@ void RunInference(Settings* s) { int image_width = 224; int image_height = 224; int image_channels = 3; - uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height, - &image_channels, s); + std::vector in = read_bmp(s->input_bmp_name, &image_width, + &image_height, &image_channels, s); int input = interpreter->inputs()[0]; if (s->verbose) LOG(INFO) << "input: " << input << "\n"; @@ -168,12 +168,12 @@ void RunInference(Settings* s) { switch (interpreter->tensor(input)->type) { case kTfLiteFloat32: s->input_floating = true; - resize(interpreter->typed_tensor(input), in, image_height, - image_width, image_channels, wanted_height, wanted_width, - wanted_channels, s); + resize(interpreter->typed_tensor(input), in.data(), + image_height, image_width, image_channels, wanted_height, + wanted_width, wanted_channels, s); break; case kTfLiteUInt8: - resize(interpreter->typed_tensor(input), in, + resize(interpreter->typed_tensor(input), in.data(), image_height, image_width, image_channels, wanted_height, wanted_width, wanted_channels, s); break; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc index ce35483f76e8f40ced79e1ee30774c62d0eba94e..de7de21f7741d3d46cb96e793e8bc4bfb21384fe 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc @@ -27,20 +27,20 @@ namespace label_image { TEST(LabelImageTest, GraceHopper) { std::string lena_file = - "tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp"; + "tensorflow/contrib/lite/examples/label_image/testdata/" + "grace_hopper.bmp"; int height, width, channels; Settings s; - uint8_t *data; - - data = read_bmp(lena_file, &width, &height, &channels, &s); + std::vector input = + read_bmp(lena_file, &width, &height, &channels, &s); ASSERT_EQ(height, 606); ASSERT_EQ(width, 517); ASSERT_EQ(channels, 3); - uint8_t *out = new uint8_t[606 * 517 * 3]; - downsize(out, data, 606, 517, 3, 214, 214, 3, &s); - ASSERT_EQ(out[0], 0x15); - ASSERT_EQ(out[214 * 214 * 3 - 1], 0x12); + std::vector output(606 * 517 * 3); + resize(output.data(), input.data(), 606, 517, 3, 214, 214, 3, &s); + ASSERT_EQ(output[0], 0x15); + ASSERT_EQ(output[214 * 214 * 3 - 1], 0x11); } TEST(LabelImageTest, GetTopN) { diff --git a/tensorflow/contrib/lite/examples/minimal/BUILD b/tensorflow/contrib/lite/examples/minimal/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b403628d6c457ce3fb67eac3675fd7bb9187deab --- /dev/null +++ b/tensorflow/contrib/lite/examples/minimal/BUILD @@ -0,0 +1,27 @@ +# Description: +# TensorFlow Lite minimal example. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") + +tf_cc_binary( + name = "minimal", + srcs = [ + "minimal.cc", + ], + linkopts = tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + ], + "//conditions:default": [], + }), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc index 106e3b027055b67092f653c6bcdc4827b56bdbaa..8b65cde7b79fde19280ad778ea874c64b01d169a 100644 --- a/tensorflow/contrib/lite/examples/minimal/minimal.cc +++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/model.h" +#include #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" -#include +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/optional_debug_tools.h" // This is an example that is minimal to read a model // from disk and perform inference. There is no data being loaded @@ -29,23 +30,22 @@ limitations under the License. using namespace tflite; -#define TFLITE_MINIMAL_CHECK(x) \ - if(!(x)) { \ - fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ - exit(1); \ +#define TFLITE_MINIMAL_CHECK(x) \ + if (!(x)) { \ + fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ + exit(1); \ } - -int main(int argc, char *argv[]) { +int main(int argc, char* argv[]) { if(argc != 2) { - fprintf(stderr, "Usage: %s \n"); + fprintf(stderr, "minimal \n"); return 1; } const char* filename = argv[1]; // Load model - std::unique_ptr model - = tflite::FlatBufferModel::BuildFromFile(filename); + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromFile(filename); TFLITE_MINIMAL_CHECK(model != nullptr); // Build the interpreter @@ -57,12 +57,16 @@ int main(int argc, char *argv[]) { // Allocate tensor buffers. TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); + printf("=== Pre-invoke Interpreter State ===\n"); + tflite::PrintInterpreterState(interpreter.get()); // Fill input buffers // TODO(user): Insert code to fill input tensors // Run inference TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); + printf("\n\n=== Post-invoke Interpreter State ===\n"); + tflite::PrintInterpreterState(interpreter.get()); // Read output buffers // TODO(user): Insert getting data out code. diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md index 50cc146a87ee9ab94aea6a92fb2fb5c531f83369..a591a353dd8f0ac94ecaa3f12e1aa1c57566ef69 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -7,6 +7,9 @@ no surprise that the APIs try to avoid unnecessary copies at the expense of convenience. Similarly, consistency with TensorFlow APIs was not an explicit goal and some variance is to be expected. +There is also a Python API for TensorFlow Lite described +[here](../toco/g3doc/python_api.md#interpreter). + ## C++ In order to run the inference model in TensorFlow Lite, one has to load the diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md new file mode 100644 index 0000000000000000000000000000000000000000..bd2f797e6c5b05f52bec9fc34f1b8011aca70330 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md @@ -0,0 +1,206 @@ +# TensorFlow Lite Ops Versioning + +This document describes TensorFlow Lite's op versioning schema. Op +versioning enables developers to add new functionalities and parameters into +existing ops. In addition, it guarantees the following: + +* Backward compatibility: New TensorFlow Lite implementation should + handle an old model file. +* Forward compatibility: Old TensorFlow Lite implementation should + handle a new model file produced by new version of TOCO, as long as no new + features are used. +* Forward in-compatibility detection: If an old TensorFlow Lite implementation + reads a new model that contains a new version of an op which isn't + supported, it should report the error. + +## Example: Adding Dilation into Convolution + +The remainder of this document explains op versioning in TFLite by showing how +to add dilation parameters to the convolution operation. + +Knowledge of dilation is not required to understand this document. Note that: + +* 2 new integer parameters will be added: `dilation_width_factor` and + `dilation_height_factor`. +* Old convolution kernels that don't support dilation are equivalent to + setting the dilation factors to 1. + +### Change FlatBuffer Schema + +To add new parameters into an op, change the options table in +`lite/schema/schema.fbs`. + +For example, the options table of convolution looks like this: + +``` +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} +``` + +When adding new parameters: + +* Add comments indicating which parameters are supported by which version. +* When the new implementation gets the default values for newly added + parameters, it should work exactly the same as the old implementation. + +The table will be like this after the new parameters are added: + +``` +table Conv2DOptions { + // Parameters supported by version 1: + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + + // Parameters supported by version 2: + dilation_width_factor:int = 1; + dilation_height_factor:int = 1; +} +``` + +### Change C Structures and Kernel Implementation + +In TensorFlow Lite, the kernel implementation is decoupled from +FlatBuffer definition. The kernels read the parameter from C structures defined +in `lite/builtin_op_data.h`. + +The original convolution parameter is as follows: + +``` +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; +} TfLiteConvParams; +``` + +As with the FlatBuffer schema, add comments indicating which parameters are +supported starting from which version. The result is seen below: + +``` +typedef struct { + // Parameters supported by version 1: TfLitePadding padding; int + stride_width; + int stride_height; + TfLiteFusedActivation activation; + + // Parameters supported by version 2: + int dilation_width_factor; + int dilation_height_factor; +} TfLiteConvParams; +``` + +Please also change the kernel implementation to read the newly added parameters +from the C structures. The details are omitted here. + +### Change the FlatBuffer Reading Code + +The logic to read FlatBuffer and produce C structure is in `lite/model.cc`. + +Update the file to handle the new parameters, as shown below: + +``` +case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + params->dilation_width_factor = conv_params->dilation_width_factor(); + params->dilation_height_factor = conv_params->dilation_height_factor(); + } + *builtin_data = reinterpret_cast(params); + break; +} +``` + +It's not required to check the op version here. When the new implementation +reads an old model file where dilation factors are missing, it will use 1 as +the default value, and the new kernel will work consistently with the old +kernel. + +### Change Kernel Registration + +The MutableOpResolver (defined in `lite/op_resolver.h`) provides a few functions +to register op kernels. The minimum and maximum version are 1 by default: + +``` +void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); +void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); +``` + +The built-in ops are registered in `lite/kernels/register.cc`. In this example, +we implemented a new op kernel which can handle `Conv2D` version 1 and 2, so we +need to change this line: + +``` +AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D()); +``` + +to: + +``` +AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), 1, 2); +``` + +### Change TOCO TFLite exporter + +The last step is to make TOCO populate the minimum version that's required to +execute the op. In this example, it means: + +* Populate version=1 when dilation factors are all 1. +* Populate version=2 otherwise. + +To do this, you need to override `GetVersion` function for the operator class in +`lite/toco/tflite/operator.cc`. + +For ops with only one version, the `GetVersion` function is defined as: + +``` +int GetVersion(const Operator& op) const override { return 1; } +``` + +When supporting multiple versions, check the parameters and determine the +version for the op, as shown in the following example: + +``` +int GetVersion(const Operator& op) const override { + const auto& conv_op = static_cast(op); + if (conv_op.dilation_width_factor != 1 || + conv_op.dilation_height_factor != 1) { + return 2; + } + return 1; +} +``` + +### Delegation Implementation + +TensorFlow Lite provides a delegation API which enables delegating ops to +hardware backends. In Delegate's `Prepare` function, check if the version +is supported for every node in Delegation code. + +``` +const int kMinVersion = 1; +TfLiteNode* node; +TfLiteRegistration; +context->GetNodeAndRegistration(context, node_index, &node, ®istration); + +if (registration->version > kMinVersion) { + // Reject the node if the version isn't supported. +} +``` + +This is required even if the delegation only supports version 1 ops, so the +delegation can detect incompatibility when getting a higher version op. + diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 244919bc87deaaab93cdbc53cb049fa5542cbb51..dcd17bbeabda08eaf86f8d5ac7f26cea0d3719a3 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -95,11 +95,7 @@ Here is a list of TensorFlow operations that are usually removed from the graph: * [tf.divide](https://www.tensorflow.org/api_docs/python/tf/divide) * [tf.fake_quant_with_min_max_args](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args) * [tf.fake_quant_with_min_max_vars](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars) -* [tf.greater](https://www.tensorflow.org/api_docs/python/tf/greater) -* [tf.greater_equal](https://www.tensorflow.org/api_docs/python/tf/greater_equal) * [tf.identity](https://www.tensorflow.org/api_docs/python/tf/identity) -* [tf.less](https://www.tensorflow.org/api_docs/python/tf/less) -* [tf.less_equal](https://www.tensorflow.org/api_docs/python/tf/less_equal) * [tf.maximum](https://www.tensorflow.org/api_docs/python/tf/maximum) * [tf.minimum](https://www.tensorflow.org/api_docs/python/tf/minimum) * [tf.multiply](https://www.tensorflow.org/api_docs/python/tf/multiply) @@ -132,7 +128,6 @@ TensorFlow operation not listed above are likely unsupported. Notably, the following common ops are not supported at the moment: * [tf.depth_to_space](https://www.tensorflow.org/api_docs/python/tf/depth_to_space) -* [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather) * [tf.image.resize_bilinear](https://www.tensorflow.org/api_docs/python/tf/image/resize_bilinear) * [tf.tanh](https://www.tensorflow.org/api_docs/python/tf/tanh) @@ -258,6 +253,19 @@ Options { } ``` +**EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + equal to the corresponding element of the second tensor. +} +``` + **EXP** ``` @@ -297,6 +305,19 @@ Options { } ``` +**GATHER** + +``` +Inputs { + 0: params tensor + 1: indices tensor + 2: axis tensor (optional) +} +Outputs { + 0: a tensor with same type as the params tensor. +} +``` + **GREATER** ``` @@ -408,6 +429,17 @@ Outputs { } ``` +**LOG** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a tensor equivalent to log(input) +} +``` + **LOG_SOFTMAX** ``` @@ -491,6 +523,19 @@ Options { } ``` +**NOT_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is not + equal to the corresponding element of the second tensor. +} +``` + **RELU** ``` @@ -539,6 +584,31 @@ Options { } ``` +**RSQRT** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: result of computing element-wise reciprocal square root of the input tensor +} +``` + +**SHAPE** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a 1D tensor representing the shape of the input tensor +} +Options { + out_type: the output type of the op (int32 or int64). Defaults to int32. +} +``` + **SLICE** ``` @@ -595,6 +665,21 @@ Outputs { } ``` +**SPARSE_TO_DENSE** + +``` +Inputs { + 0: 0D or 1D or 2D tensor + 1: 1D tensor + 2: 0D or 1D tensor + 3: 0D tensor + 4: a boolean value +} +Outputs { + 0: Dense Tensor of shape output_shape. Has the same type as sparse_values. +} +``` + **SPLIT** ``` @@ -610,6 +695,17 @@ Options { } ``` +**SQRT** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: result of computing element-wise square root of the input tensor +} +``` + **SQUEEZE** ``` @@ -682,6 +778,18 @@ Outputs { } ``` +**POW** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: elementwise pow of the input tensors +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h index 313af5fb7574b42bcdd53b4baad06e4ccfb34053..77268d7aebe9ebfb33b9f35b319d34e6de8324ee 100644 --- a/tensorflow/contrib/lite/graph_info.h +++ b/tensorflow/contrib/lite/graph_info.h @@ -46,6 +46,9 @@ class GraphInfo { // Returns the indices of the output tensors. virtual const std::vector& outputs() const = 0; + + // Returns the indices of the variable tensors. + virtual const std::vector& variables() const = 0; }; // Represents a subgraph of a TensorFlow Lite graph. diff --git a/tensorflow/contrib/lite/graph_info_test.cc b/tensorflow/contrib/lite/graph_info_test.cc index ea38b43993fef71c6820c7a978351d92d5420287..89a8f36b416b5dec54c1e374cdcdae3ab9ab0cde 100644 --- a/tensorflow/contrib/lite/graph_info_test.cc +++ b/tensorflow/contrib/lite/graph_info_test.cc @@ -45,6 +45,7 @@ class SimpleTestGraph : public GraphInfo { TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; } const std::vector& inputs() const override { return inputs_; } const std::vector& outputs() const override { return outputs_; } + const std::vector& variables() const override { return variables_; } void AddNode(const std::vector& inputs, const std::vector& outputs) { @@ -67,6 +68,7 @@ class SimpleTestGraph : public GraphInfo { std::vector tensors_; std::vector inputs_; std::vector outputs_; + std::vector variables_; }; // Partition a graph to generate a list of subgraphs. This wraps the API call diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index ebb0aedc2001a86b7fcff67ef8703b5e4a845818..521216a4f1e84582731a1782f74ce981106f636b 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -22,17 +22,21 @@ limitations under the License. #include "tensorflow/contrib/lite/arena_planner.h" #include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/graph_info.h" -#include "tensorflow/contrib/lite/kernels/eigen_support.h" -#include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/memory_planner.h" +#ifndef TFLITE_MCU #include "tensorflow/contrib/lite/nnapi_delegate.h" +#endif #include "tensorflow/contrib/lite/profiling/profiler.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/util.h" namespace tflite { +#ifdef TFLITE_MCU +class NNAPIDelegate {}; +#endif namespace { @@ -53,6 +57,19 @@ void SetForbiddenContextFunction(FunctionType* func) { *func = reinterpret_cast(ForbiddenContextFunction); } +// Returns true if at least one tensor in the given list is kTfLiteDynamic. +template +bool HasDynamicTensorImpl(const TfLiteContext& context, + const TensorIntArray& int_array) { + for (int i : int_array) { + const TfLiteTensor& tensor = context.tensors[i]; + if (tensor.allocation_type == kTfLiteDynamic) { + return true; + } + } + return false; +} + } // namespace // A trivial implementation of GraphInfo around the Interpreter. @@ -82,6 +99,9 @@ class InterpreterInfo : public GraphInfo { const std::vector& outputs() const override { return interpreter_->outputs(); } + const std::vector& variables() const override { + return interpreter_->variables(); + } public: Interpreter* interpreter_; @@ -96,9 +116,9 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) context_.AddTensors = AddTensors; context_.tensors = nullptr; context_.tensors_size = 0; - context_.eigen_context = nullptr; - context_.gemm_context = nullptr; context_.recommended_num_threads = -1; + context_.GetExternalContext = GetExternalContext; + context_.SetExternalContext = SetExternalContext; // Invalid to call these these except from TfLiteDelegate SetForbiddenContextFunction(&context_.GetNodeAndRegistration); @@ -109,6 +129,11 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) tensors_.reserve(kTensorsReservedCapacity); nodes_and_registration_.reserve(kTensorsReservedCapacity); next_execution_plan_index_to_prepare_ = 0; + + for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) { + external_contexts_[i] = nullptr; + } + UseNNAPI(false); } @@ -266,6 +291,33 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( return kTfLiteOk; } +TfLiteExternalContext* Interpreter::GetExternalContext( + TfLiteExternalContextType type) { + if (type >= 0 && type < kTfLiteMaxExternalContexts) { + return external_contexts_[type]; + } + return nullptr; +} + +TfLiteExternalContext* Interpreter::GetExternalContext( + struct TfLiteContext* context, TfLiteExternalContextType type) { + return static_cast(context->impl_)->GetExternalContext(type); +} + +void Interpreter::SetExternalContext(TfLiteExternalContextType type, + TfLiteExternalContext* ctx) { + if (type >= 0 && type < kTfLiteMaxExternalContexts) { + external_contexts_[type] = ctx; + } +} + +void Interpreter::SetExternalContext(struct TfLiteContext* context, + TfLiteExternalContextType type, + TfLiteExternalContext* ctx) { + return static_cast(context->impl_) + ->SetExternalContext(type, ctx); +} + // Gets an TfLiteIntArray* representing the execution plan. The interpreter owns // this memory and it is only guaranteed to exist during the invocation of the // delegate prepare. @@ -302,6 +354,13 @@ TfLiteStatus Interpreter::SetOutputs(std::vector outputs) { return kTfLiteOk; } +TfLiteStatus Interpreter::SetVariables(std::vector variables) { + TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("variables", variables.data(), + variables.size())); + variables_ = std::move(variables); + return kTfLiteOk; +} + TfLiteStatus Interpreter::CheckTensorIndices(const char* label, const int* indices, int length) { // Making sure kOptionalTensor is not re-defined to something other than -1. @@ -334,6 +393,9 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, case kTfLiteFloat32: *bytes = sizeof(float) * count; break; + case kTfLiteInt16: + *bytes = sizeof(int16_t) * count; + break; case kTfLiteInt32: *bytes = sizeof(int32_t) * count; break; @@ -346,32 +408,58 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, case kTfLiteBool: *bytes = sizeof(bool) * count; break; + case kTfLiteComplex64: + *bytes = sizeof(std::complex) * count; + break; default: - ReportError( - &context_, - "Only float32, int32, int64, uint8, bool supported currently."); + ReportError(&context_, + "Only float32, int16, int32, int64, uint8, bool, complex64 " + "supported currently."); return kTfLiteError; } return kTfLiteOk; } TfLiteStatus Interpreter::AllocateTensors() { - next_execution_plan_index_to_prepare_ = 0; - if (memory_planner_) { - TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); - } - if (!consistent_) { ReportError(&context_, "AllocateTensors() called on inconsistent model."); return kTfLiteError; } + // Explicit (re)allocation is necessary if nodes have been changed or tensors + // have been resized. For inputs marked as dynamic, we can't short-circuit the + // allocation as the client may have done the resize manually. + if (state_ != kStateUninvokable && !HasDynamicTensorImpl(context_, inputs_)) { + return kTfLiteOk; + } + + next_execution_plan_index_to_prepare_ = 0; + if (memory_planner_) { + TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); + } + TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); - if (state_ == kStateUninvokable) { - state_ = kStateInvokable; + + state_ = kStateInvokable; + return kTfLiteOk; +} + +// TODO(ycling): Consider to provide other functions to initialize variable +// tensors to non-zero values. +TfLiteStatus Interpreter::ResetVariableTensorsToZero() { + for (auto& tensor : tensors_) { + if (!tensor.is_variable) { + continue; + } + + // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be + // allocated after the initial `PrepareOpsAndTensors()` is called. + TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, + kTfLiteArenaRwPersistent); + TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr); + + memset(tensor.data.raw, 0, tensor.bytes); } - TF_LITE_ENSURE(&context_, state_ == kStateInvokable || - state_ == kStateInvokableAndImmutable); return kTfLiteOk; } @@ -445,26 +533,26 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, "ResizeInputTensor is disallowed when graph is immutable."); return kTfLiteError; } - state_ = kStateUninvokable; // TODO(aselle): All bounds checks can be implemented as one-sided bounds // checks by casting to unsigned for efficiency. Profile before doing this. TF_LITE_ENSURE(&context_, tensor_index < context_.tensors_size && tensor_index >= 0); - TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims); - return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); + TfLiteTensor* tensor = &context_.tensors[tensor_index]; + + // Short-circuit the state change if the dimensions don't change, avoiding + // unnecessary (re)allocations. + if (EqualArrayAndTfLiteIntArray(tensor->dims, dims.size(), dims.data())) { + return kTfLiteOk; + } + + state_ = kStateUninvokable; + return ResizeTensorImpl(tensor, ConvertVectorToTfLiteIntArray(dims)); } -// Returns true if at least one tensor in the given list is kTfLiteDynamic. bool HasDynamicTensor(const TfLiteContext& context, - const TfLiteIntArray* tensors) { - for (int i = 0; i < tensors->size; ++i) { - const TfLiteTensor& tensor = context.tensors[tensors->data[i]]; - if (tensor.allocation_type == kTfLiteDynamic) { - return true; - } - } - return false; + const TfLiteIntArray* int_array) { + return HasDynamicTensorImpl(context, TfLiteIntArrayView{int_array}); } TfLiteStatus Interpreter::PrepareOpsStartingAt( @@ -495,7 +583,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt( TfLiteStatus Interpreter::PrepareOpsAndTensors() { if (!memory_planner_) { memory_planner_.reset(new ArenaPlanner( - &context_, std::unique_ptr(new InterpreterInfo(this)))); + &context_, std::unique_ptr(new InterpreterInfo(this)), + /*preserve_inputs=*/true)); memory_planner_->PlanAllocations(); } @@ -521,6 +610,7 @@ TfLiteStatus Interpreter::Invoke() { } TfLiteStatus status = kTfLiteOk; +#ifndef TFLITE_MCU if (nnapi_delegate_) { if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) { TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); @@ -534,6 +624,7 @@ TfLiteStatus Interpreter::Invoke() { return kTfLiteError; } } +#endif // Invocations are always done in node order. // Note that calling Invoke repeatedly will cause the original memory plan to @@ -572,9 +663,17 @@ TfLiteStatus Interpreter::Invoke() { } EnsureTensorsVectorCapacity(); + tensor_resized_since_op_invoke_ = false; if (OpInvoke(registration, &node) == kTfLiteError) { status = kTfLiteError; } + + // Force execution prep for downstream ops if the latest op triggered the + // resize of a dynamic tensor. + if (tensor_resized_since_op_invoke_ && + HasDynamicTensor(context_, node.outputs)) { + next_execution_plan_index_to_prepare_ = execution_plan_index + 1; + } } if (!allow_buffer_handle_output_) { @@ -687,7 +786,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( state_ = kStateUninvokable; TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, const_cast(buffer), bytes, - kTfLiteMmapRo, allocation, &tensor); + kTfLiteMmapRo, allocation, false, &tensor); } return kTfLiteOk; } @@ -698,7 +797,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( // to Interpreter. TfLiteStatus Interpreter::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization) { + const int* dims, TfLiteQuantizationParams quantization, bool is_variable) { if (state_ == kStateInvokableAndImmutable) { ReportError( &context_, @@ -716,11 +815,23 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite( TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims, rank, &required_bytes)); } + + TfLiteAllocationType allocation_type = kTfLiteArenaRw; + if (type == kTfLiteString) { + if (is_variable) { + // We don't have a real use case for string variable tensor. + ReportError(&context_, "String variable tensor isn't supported."); + return kTfLiteError; + } + allocation_type = kTfLiteDynamic; + } else if (is_variable) { + allocation_type = kTfLiteArenaRwPersistent; + } + TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, - /*buffer=*/nullptr, required_bytes, - type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw, - nullptr, &context_.tensors[tensor_index]); + /*buffer=*/nullptr, required_bytes, allocation_type, + nullptr, is_variable, &context_.tensors[tensor_index]); return kTfLiteOk; } @@ -736,7 +847,10 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size) { // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. if (tensor->allocation_type == kTfLiteArenaRw || - tensor->allocation_type == kTfLiteDynamic) { + tensor->allocation_type == kTfLiteDynamic || + tensor->allocation_type == kTfLiteArenaRwPersistent) { + tensor_resized_since_op_invoke_ |= + TfLiteIntArrayEqual(tensor->dims, new_size) == 0; if (tensor->type != kTfLiteString) { size_t bytesRequired; TfLiteStatus status = BytesRequired(tensor->type, new_size->data, @@ -767,6 +881,7 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, } void Interpreter::UseNNAPI(bool enable) { +#ifndef TFLITE_MCU // TODO(aselle): This is a workaround for finding if NNAPI exists. // We also need to make sure getLibraryHandle() is renamed to be NNAPI // prefixed. @@ -776,15 +891,18 @@ void Interpreter::UseNNAPI(bool enable) { } else if (!nnapi_delegate_) { nnapi_delegate_.reset(new NNAPIDelegate); } +#endif } void Interpreter::SetNumThreads(int num_threads) { context_.recommended_num_threads = num_threads; - // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to - // be required in order to compile the framework. - gemm_support::SetNumThreads(&context_, num_threads); - eigen_support::SetNumThreads(&context_, num_threads); + for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) { + auto* c = external_contexts_[i]; + if (c && c->Refresh) { + c->Refresh(&context_); + } + } } TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, @@ -828,9 +946,10 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, TF_LITE_ENSURE_OK(&context_, status); if (!allow_dynamic_tensors) { + // Reset the state to force tensor/op reallocation. + state_ = kStateUninvokable; TF_LITE_ENSURE_OK(&context_, AllocateTensors()); - TF_LITE_ENSURE(&context_, state_ == kStateInvokable || - state_ == kStateInvokableAndImmutable); + TF_LITE_ENSURE_EQ(&context_, state_, kStateInvokable); // After using a delegate which doesn't support dynamic tensors, make the // entire graph immutable. state_ = kStateInvokableAndImmutable; diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 0450e86ae7f84e4aa6c70235eb825ca3b4f7aebc..b69c50fbfce131f6862dc6e91387035e3d3bb7d8 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -17,6 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ #define TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#include #include #include #include @@ -39,6 +40,10 @@ constexpr TfLiteType typeToTfLiteType() { return kTfLiteInt32; } template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteInt16; +} +template <> constexpr TfLiteType typeToTfLiteType() { return kTfLiteInt64; } @@ -54,6 +59,10 @@ template <> constexpr TfLiteType typeToTfLiteType() { return kTfLiteBool; } +template <> +constexpr TfLiteType typeToTfLiteType>() { + return kTfLiteComplex64; +} // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; @@ -118,6 +127,11 @@ class Interpreter { // interpreter. TfLiteStatus SetOutputs(std::vector outputs); + // Provide a list of tensor indexes that are variable tensors. + // Each index is bound check and this modifies the consistent_ flag of the + // interpreter. + TfLiteStatus SetVariables(std::vector variables); + // Adds a node with the given parameters and returns the index of the new // node in `node_index` (optionally). Interpreter will take ownership of // `builtin_data` and destroy it with `free`. Ownership of 'init_data' @@ -160,13 +174,15 @@ class Interpreter { // to Interpreter. inline TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, - const std::vector& dims, TfLiteQuantizationParams quantization) { + const std::vector& dims, TfLiteQuantizationParams quantization, + bool is_variable = false) { return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(), - dims.data(), quantization); + dims.data(), quantization, is_variable); } TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization); + const int* dims, TfLiteQuantizationParams quantization, + bool is_variable = false); // Functions to access tensor data @@ -182,6 +198,9 @@ class Interpreter { // Read only access to list of outputs. const std::vector& outputs() const { return outputs_; } + // Read only access to list of variable tensors. + const std::vector& variables() const { return variables_; } + // Return the name of a given output. The given index must be between 0 and // outputs().size(). const char* GetOutputName(int index) const { @@ -249,13 +268,20 @@ class Interpreter { return nullptr; } - // Return a pointer into the data of a given input tensor. The given index - // must be between 0 and inputs().size(). + // Return a mutable pointer into the data of a given input tensor. The given + // index must be between 0 and inputs().size(). template T* typed_input_tensor(int index) { return typed_tensor(inputs_[index]); } + // Return an immutable pointer into the data of a given input tensor. The + // given index must be between 0 and inputs().size(). + template + const T* typed_input_tensor(int index) const { + return typed_tensor(inputs_[index]); + } + // Return a mutable pointer into the data of a given output tensor. The given // index must be between 0 and outputs().size(). template @@ -372,7 +398,20 @@ class Interpreter { allow_buffer_handle_output_ = allow_buffer_handle_output; } + // Reset all variable tensors to zero. + // WARNING: This is an experimental API and subject to change. + TfLiteStatus ResetVariableTensorsToZero(); + + // Retrieve an operator's description of its work, for profiling purposes. + const char* OpProfilingString(const TfLiteRegistration& op_reg, + const TfLiteNode* node) const { + if (op_reg.profiling_string == nullptr) return nullptr; + return op_reg.profiling_string(&context_, node); + } + private: + friend class InterpreterTest; + // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. void* OpInit(const TfLiteRegistration& op_reg, const char* buffer, @@ -485,6 +524,18 @@ class Interpreter { static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, TfLiteIntArray** execution_plan); + // Retrieve an existing external context by type. + TfLiteExternalContext* GetExternalContext(TfLiteExternalContextType type); + static TfLiteExternalContext* GetExternalContext( + struct TfLiteContext* context, TfLiteExternalContextType type); + + // Set the value of an external context. + void SetExternalContext(TfLiteExternalContextType type, + TfLiteExternalContext* ctx); + static void SetExternalContext(struct TfLiteContext* context, + TfLiteExternalContextType type, + TfLiteExternalContext* ctx); + // Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra // capacity. Calling this function may invalidate existing pointers to // tensors. After calling this function, adding `kTensorsCapacityHeadroom` @@ -534,6 +585,9 @@ class Interpreter { // interpreter. std::vector outputs_; + // Array of indices representing the tensors that are variable tensors. + std::vector variables_; + // The error reporter delegate that tflite will forward queries errors to. ErrorReporter* error_reporter_; @@ -565,8 +619,16 @@ class Interpreter { bool allow_buffer_handle_output_ = false; + // Tracking bit for whether a tensor was resized in the course of an op + // invocation. This is a useful hint to ensure that dynamic tensor outputs + // trigger downstream reallocation after op invocation. + bool tensor_resized_since_op_invoke_ = false; + // Profiler for this interpreter instance. profiling::Profiler* profiler_; + + // List of active external contexts. + TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts]; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 453c1ada1cf6263be14a3b170f209e3a30580cc3..4fa97512fca186fce8a2ec6514488b77c6d6511d 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -23,6 +23,21 @@ limitations under the License. #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { + +// InterpreterTest is a friend of Interpreter, so it can access context_. +class InterpreterTest : public ::testing::Test { + protected: + TfLiteContext* GetInterpreterContext() { return &interpreter_.context_; } + + Interpreter interpreter_; +}; + +namespace ops { +namespace builtin { +TfLiteRegistration* Register_PADV2(); +TfLiteRegistration* Register_NEG(); +} // namespace builtin +} // namespace ops namespace { // Make an interpreter that has no tensors and no nodes @@ -106,10 +121,9 @@ TEST(BasicInterpreter, CheckAllocate) { TfLiteType type; size_t size; } cases[] = { - {kTfLiteFloat32, sizeof(float)}, - {kTfLiteInt32, sizeof(int32_t)}, - {kTfLiteUInt8, sizeof(uint8_t)}, - {kTfLiteInt64, sizeof(int64_t)}, + {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)}, + {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)}, + {kTfLiteInt16, sizeof(int16_t)}, }; for (auto test : cases) { @@ -134,6 +148,7 @@ TEST(BasicInterpreter, CheckResize) { const int32_t int32s[] = {-3, -4}; const uint8_t uint8s[] = {3, 4}; const int64_t int64s[] = {6, -7}; + const int16_t int16s[] = {8, -9}; struct { TfLiteType type; @@ -144,6 +159,7 @@ TEST(BasicInterpreter, CheckResize) { {kTfLiteInt32, sizeof(int32_t), reinterpret_cast(int32s)}, {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast(uint8s)}, {kTfLiteInt64, sizeof(int64_t), reinterpret_cast(int64s)}, + {kTfLiteInt16, sizeof(int16_t), reinterpret_cast(int16s)}, }; for (auto test : cases) { @@ -179,10 +195,8 @@ TEST(BasicInterpreter, CheckAlignment) { struct { TfLiteType type; } cases[] = { - {kTfLiteFloat32}, - {kTfLiteInt32}, - {kTfLiteUInt8}, - {kTfLiteInt64}, + {kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8}, + {kTfLiteInt64}, {kTfLiteInt16}, }; for (auto test : cases) { @@ -211,7 +225,7 @@ TEST(BasicInterpreter, CheckArenaAllocation) { TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; std::vector sizes{2048, 4096, 1023, 2047, 1021, - 2047, 1023, 2046, 1021, 2048}; + 2047, 1023, 2046, 0, 2048}; for (int i = 0; i < sizes.size(); ++i) { interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]}, quant); @@ -226,31 +240,16 @@ TEST(BasicInterpreter, CheckArenaAllocation) { ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); - ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); - ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); - - ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw); - - ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(2)->data.raw); ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw); - - ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(4)->data.raw); ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(5)->data.raw, interpreter.tensor(7)->data.raw); + ASSERT_EQ(interpreter.tensor(6)->data.raw, interpreter.tensor(2)->data.raw); + // #7 is the one with the largest pointer. + ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); } TEST(BasicInterpreter, BufferAccess) { @@ -286,6 +285,57 @@ TEST(BasicInterpreter, NoOpInterpreter) { ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); } +TEST(BasicInterpreter, RedundantAllocateTensors) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + const auto data_raw = interpreter.tensor(0)->data.raw; + ASSERT_NE(data_raw, nullptr); + + // A redundant allocation request should have no impact. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.tensor(0)->data.raw, data_raw); +} + +TEST(BasicInterpreter, RedundantAllocateTensorsWithDynamicInputs) { + Interpreter interpreter; + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0}); + interpreter.SetOutputs({1}); + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 1, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + // Configure the input tensor as dynamic. + interpreter.tensor(0)->data.raw = nullptr; + interpreter.tensor(0)->allocation_type = kTfLiteDynamic; + + ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + + // Reset the output tensor's buffer. + interpreter.tensor(1)->data.raw = nullptr; + + // A redundant allocation request should be honored, as the input tensor + // was marked dynamic. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); +} + TEST(BasicInterpreter, ResizingTensors) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); @@ -314,6 +364,18 @@ TEST(BasicInterpreter, ResizingTensors) { EXPECT_EQ(tensor->bytes, 8 * sizeof(float)); ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(t, {}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 1 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {0}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 0); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 0}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 0); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + // TODO(ahentz): We shouldn't have to force reallocation, but // ResizeInputTensor doesn't realloc dynamic tensors. Also note that // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op. @@ -331,6 +393,37 @@ TEST(BasicInterpreter, ResizingTensors) { tensor->data.f[15] = 0.123f; } +TEST(BasicInterpreter, NoopResizingTensors) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + int t = interpreter.inputs()[0]; + TfLiteTensor* tensor = interpreter.tensor(t); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + tensor->data.f[5] = 0.123f; + + // Resizing to the same size should not trigger re-allocation. + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_NE(tensor->data.raw, nullptr); + ASSERT_EQ(tensor->data.f[5], 0.123f); + + // Explicitly allocating should be a no-op, as no resize was performed. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_NE(tensor->data.raw, nullptr); + ASSERT_EQ(tensor->data.f[5], 0.123f); +} + TEST(BasicInterpreter, OneOpInterpreter) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); @@ -603,6 +696,59 @@ TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) { EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteError); } +TEST(BasicInterpreter, DynamicTensorsResizeDescendants) { + // Assemble a graph with a node that has dynamically sized output (via the + // pad op), followed by a node with a standard element-wise op (negate). + Interpreter interpreter; + interpreter.AddTensors(4); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({3}); + TfLiteQuantizationParams quant; + interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {2, 2, 1, 1}, + quant); + interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "", {4, 2}, quant); + interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {}, quant); + interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {}, quant); + + TfLiteRegistration* pad_op = tflite::ops::builtin::Register_PADV2(); + TfLiteRegistration* neg_op = tflite::ops::builtin::Register_NEG(); + interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr, pad_op); + interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, neg_op); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // Configure [[2,2],[4,4]] padding and execute the graph. + interpreter.typed_tensor(1)[0] = 2; + interpreter.typed_tensor(1)[1] = 2; + interpreter.typed_tensor(1)[2] = 2; + interpreter.typed_tensor(1)[3] = 2; + interpreter.typed_tensor(1)[4] = 0; + interpreter.typed_tensor(1)[5] = 0; + interpreter.typed_tensor(1)[6] = 0; + interpreter.typed_tensor(1)[7] = 0; + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Both the output and intermediate tensor sizes should reflect the output + // from the dynamic pad operation. + ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 6 * 6); + ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 6 * 6); + + // Now configure [[4,4],[6,6]] padding and execute the graph. + interpreter.typed_tensor(1)[0] = 4; + interpreter.typed_tensor(1)[1] = 4; + interpreter.typed_tensor(1)[2] = 6; + interpreter.typed_tensor(1)[3] = 6; + interpreter.typed_tensor(1)[4] = 0; + interpreter.typed_tensor(1)[5] = 0; + interpreter.typed_tensor(1)[6] = 0; + interpreter.typed_tensor(1)[7] = 0; + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Again, the output and intermediate tensor sizes should reflect the *new* + // resize from the latest pad operation. + ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 10 * 14); + ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 10 * 14); +} + TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity), @@ -643,6 +789,47 @@ TEST(InterpreterTensorsCapacityTest, TestExceedHeadroom) { ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); } +struct TestExternalContext : public TfLiteExternalContext { + static const TfLiteExternalContextType kType = kTfLiteGemmLowpContext; + + static TestExternalContext* Get(TfLiteContext* context) { + return reinterpret_cast( + context->GetExternalContext(context, kType)); + } + + static void Set(TfLiteContext* context, TestExternalContext* value) { + context->SetExternalContext(context, kType, value); + } + + int num_refreshes = 0; +}; + +TEST_F(InterpreterTest, GetSetResetExternalContexts) { + auto* context = GetInterpreterContext(); + + TestExternalContext external_context; + external_context.Refresh = [](TfLiteContext* context) { + auto* ptr = TestExternalContext::Get(context); + if (ptr != nullptr) { + ++ptr->num_refreshes; + } + return kTfLiteOk; + }; + + EXPECT_EQ(TestExternalContext::Get(context), nullptr); + interpreter_.SetNumThreads(4); + + TestExternalContext::Set(context, &external_context); + EXPECT_EQ(TestExternalContext::Get(context), &external_context); + interpreter_.SetNumThreads(4); + interpreter_.SetNumThreads(5); + EXPECT_EQ(external_context.num_refreshes, 2); + + TestExternalContext::Set(context, nullptr); + EXPECT_EQ(TestExternalContext::Get(context), nullptr); + interpreter_.SetNumThreads(4); +} + // Test fixture that allows playing with execution plans. It creates a two // node graph that can be executed in either [0,1] order or [1,0] order. // The CopyOp records when it is invoked in the class member run_order_ diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 1e579226037fa360e4d5dad25077b8966e1126bc..593af81a18a1e20a41dcc8d9bb3a1d815876e294 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -1,7 +1,9 @@ # Description: # TensorFlow Lite Java API. -package(default_visibility = ["//visibility:private"]) +package(default_visibility = [ + "//tensorflow/contrib/lite/java/ovic:__pkg__", +]) licenses(["notice"]) # Apache 2.0 @@ -46,38 +48,6 @@ android_library( ], ) -android_library( - name = "ovicbenchmarkerlib", - srcs = [ - "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", - ], - manifest = "AndroidManifest.xml", - visibility = ["//visibility:public"], - deps = [ - ":tensorflowlite", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", - "@org_checkerframework_qual", - ], -) - -java_library( - name = "ovicbenchmarkerlib_java", - srcs = [ - "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", - ], - javacopts = JAVACOPTS, - visibility = ["//visibility:public"], - deps = [ - ":libtensorflowlite_jni.so", - ":tensorflowlite_java", - "//tensorflow/contrib/lite/java/src/main/native", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", - "@org_checkerframework_qual", - ], -) - java_library( name = "tensorflowlitelib", srcs = glob( @@ -180,24 +150,6 @@ java_test( ], ) -java_test( - name = "OvicClassifierTest", - size = "medium", - srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], - data = [ - "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", - "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", - ], - javacopts = JAVACOPTS, - test_class = "org.tensorflow.ovic.OvicClassifierTest", - visibility = ["//visibility:public"], - deps = [ - ":ovicbenchmarkerlib_java", - "@com_google_truth", - "@junit", - ], -) - filegroup( name = "libtensorflowlite_jni", srcs = select({ diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl index 4450bc9085555b3416f51bac07ea94a1240e919c..db837cf29edfc0ffe9950ffedc02cca1389b0fdf 100644 --- a/tensorflow/contrib/lite/java/aar_with_jni.bzl +++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl @@ -1,5 +1,7 @@ """Generate zipped aar file including different variants of .so in jni folder.""" +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + def aar_with_jni(name, android_library): # Generate dummy AndroidManifest.xml for dummy apk usage # (dummy apk is generated by _dummy_app_for_so target below) @@ -19,7 +21,7 @@ EOF # Generate dummy apk including .so files and later we extract out # .so files and throw away the apk. - native.android_binary( + android_binary( name = name + "_dummy_app_for_so", manifest = name + "_generated_AndroidManifest.xml", custom_package = "dummy.package.for.so", diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md index 2e818f728ef208d30b0eeb27ffd7e3fa0c7c1a2d..e3cea19e1683ac2680521bce66d1328e4b2caf1c 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -1,5 +1,14 @@ # TF Lite Android App +## Building in Android Studio with TensorFlow Lite AAR from JCenter. +The build.gradle is configured to use TensorFlow Lite's nightly build. + +If you see a build error related to compatibility with Tensorflow Lite's Java API (example: method X is +undefined for type Interpreter), there has likely been a backwards compatible +change to the API. You will need to pull new app code that's compatible with the +nightly build and may need to first wait a few days for our external and internal +code to merge. + ## Building from Source with Bazel 1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel): diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle index b76eaad8bb91224805d16b3d6f7c3274c9feb90c..49868c5a7566c8c537ac2ae9e0a4acc2c872ecbf 100644 --- a/tensorflow/contrib/lite/java/demo/app/build.gradle +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -5,11 +5,12 @@ android { buildToolsVersion "26.0.1" defaultConfig { applicationId "android.example.com.tflitecamerademo" - minSdkVersion 15 + // Required by Camera2 API. + minSdkVersion 21 targetSdkVersion 26 versionCode 1 versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" // Remove this block. jackOptions { @@ -43,7 +44,7 @@ repositories { dependencies { compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + androidTestCompile('com.androidx.test.espresso:espresso-core:2.2.2', { exclude group: 'com.android.support', module: 'support-annotations' }) compile 'com.android.support:appcompat-v7:25.2.0' @@ -52,7 +53,43 @@ dependencies { compile 'com.android.support:support-annotations:25.3.1' compile 'com.android.support:support-v13:25.2.0' - compile 'org.tensorflow:tensorflow-lite:+' + compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly' testCompile 'junit:junit:4.12' } + +def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" +def localCache = "build/intermediates/mobilenet_v1_224_android_quant_2017_11_08.zip" +def targetFolder = "src/main/assets" + +task downloadModel(type: DownloadUrlTask) { + doFirst { + println "Downloading ${modelDownloadUrl}" + } + sourceUrl = "${modelDownloadUrl}" + target = file("${localCache}") +} + +task unzipModel(type: Copy, dependsOn: 'downloadModel') { + doFirst { + println "Unzipping ${localCache}" + } + from zipTree("${localCache}") + into "${targetFolder}" +} + +// Ensure the model file is downloaded and extracted before every build +preBuild.dependsOn unzipModel + +class DownloadUrlTask extends DefaultTask { + @Input + String sourceUrl + + @OutputFile + File target + + @TaskAction + void download() { + ant.get(src: sourceUrl, dest: target) + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD index d6fbef9cc938993b283103984307ab51e609dd6e..220d6c2159b56f6349e93132418fa0f6c69d1ab3 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:private"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f232b00045cf1df6a31ada80af4cc5885a4c0099 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -0,0 +1,70 @@ +# Description: +# OVIC Benchmarker Java API. + +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") + +java_test( + name = "OvicClassifierTest", + size = "medium", + srcs = ["src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.ovic.OvicClassifierTest", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + "@com_google_truth", + "@junit", + ], +) + +java_binary( + name = "ovic_validator", + srcs = ["src/main/java/org/tensorflow/ovic/OvicValidator.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + ], + main_class = "org.tensorflow.ovic.OvicValidator", + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + ], +) + +android_library( + name = "ovicbenchmarkerlib", + srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], + manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) + +java_library( + name = "ovicbenchmarkerlib_java", + srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], + javacopts = JAVACOPTS, + deps = [ + "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", + "//tensorflow/contrib/lite/java:tensorflowlite_java", + "//tensorflow/contrib/lite/java/src/main/native", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md index 77799b35691813868fb65a2c8b068f41751717db..26349347faebac135ae555e0c5d8219046ab1c29 100644 --- a/tensorflow/contrib/lite/java/ovic/README.md +++ b/tensorflow/contrib/lite/java/ovic/README.md @@ -2,7 +2,7 @@ This folder contains building code for track one of the [Low Power ImageNet Recognition Challenge workshop at CVPR 2018.](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018) -## Pre-requesits +## Pre-requisite Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK. @@ -37,19 +37,37 @@ unzip -j /tmp/ovic.zip -d tensorflow/contrib/lite/java/ovic/src/testdata/ You can run test with Bazel as below. This helps to ensure that the installation is correct. ```sh -bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java:OvicClassifierTest --cxxopt=-Wno-all --test_output=all +bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:OvicClassifierTest --cxxopt=-Wno-all --test_output=all ``` ### Test your submissions -Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it as below. +Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it in two ways: -* Move your submission to the testdata folder: +#### Validate using randomly generated images + +You can call the validator binary below to verify that your model fits the format requirements. This often helps you to catch size mismatches (e.g. output should be [1, 1001] instead of [1,1,1,1001]). Let say the submission file is located at `/path/to/my_model.lite`, then call: + +```sh +bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:ovic_validator --cxxopt=-Wno-all +bazel-bin/tensorflow/contrib/lite/java/ovic/ovic_validator /path/to/my_model.lite +``` + +Successful validation should print the following message to terminal: + +``` +Successfully validated /path/to/my_model.lite. + +``` + +#### Test that the model produces sensible outcomes -Let say the submission file is located at `/tmp/my_model.lite`, then +You can go a step further to verify that the model produces results as expected. This helps you catch bugs during TOCO conversion (e.g. using the wrong mean and std values). + +* Move your submission to the testdata folder: ```sh -cp /tmp/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ +cp /path/to/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ ``` * Resize the test image to the resolutions that are expected by your submission: @@ -136,3 +154,5 @@ Note: the benchmarking results can be quite different depending on the backgroun | quantized_model.lite | 85 | 74 | | low_res_model.lite | 4.2 | 4.0 | +Since Pixel 2 has excellent support for 8-bit quantized models, we strongly recommend you to check out the [quantization training tutorial](https://www.tensorflow.org/performance/quantization). + diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD index 47101ff574a797a81c5d993b0863c024885f03a0..a8d751ade26adc358e130138381eab9956f2d848 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + # Sample app for OVIC benchmarking. licenses(["notice"]) # Apache 2.0 @@ -21,8 +23,8 @@ android_binary( resource_files = glob(["res/**"]), tags = ["manual"], deps = [ - "//tensorflow/contrib/lite/java:ovicbenchmarkerlib", "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib", "@androidsdk//com.android.support:support-v13-25.2.0", "@androidsdk//com.android.support:support-v4-25.2.0", ], diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle index c5d19bad89a93988a6830a17fe2fb4a60e2fb00f..3f32d62e5c08419c6413fffe09b64356edcac836 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle +++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle @@ -9,7 +9,7 @@ android { targetSdkVersion 26 versionCode 1 versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" // Remove this block. jackOptions { @@ -43,7 +43,7 @@ repositories { dependencies { compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + androidTestCompile('com.androidx.test.espresso:espresso-core:2.2.2', { exclude group: 'com.android.support', module: 'support-annotations' }) compile 'com.android.support:appcompat-v7:25.2.0' diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java new file mode 100644 index 0000000000000000000000000000000000000000..a504ec74a9d0a124f877a6377cae155f204849a5 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java @@ -0,0 +1,94 @@ +/*Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.ovic; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.Random; + +/** Validate a submission model. */ +public class OvicValidator { + private static void printUsage(PrintStream s) { + s.println("Java program that validates a submission model."); + s.println(); + s.println("Usage: ovic_validator "); + s.println(); + s.println("Where:"); + s.println(" is the model in TfLite format;"); + } + + public static void main(String[] args) { + if (args.length != 1) { + printUsage(System.err); + System.exit(1); + } + final String labelPath = + "tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt"; + + final String modelFile = args[0]; + try { + File labelsfile = new File(labelPath); + InputStream labelsInputStream = new FileInputStream(labelsfile); + MappedByteBuffer model = loadModelFile(modelFile); + OvicClassifier classifier = new OvicClassifier(labelsInputStream, model); + ByteBuffer imgData = createByteBufferForClassifier(classifier); + OvicSingleImageResult testResult = classifier.classifyByteBuffer(imgData); + if (testResult.topKClasses.isEmpty()) { + throw new RuntimeException("Failed to return top K predictions."); + } + System.out.printf("Successfully validated %s.%n", modelFile); + } catch (Exception e) { + System.out.println(e.getMessage()); + System.out.printf("Failed to validate %s.%n", modelFile); + } + } + + private static ByteBuffer createByteBufferForClassifier(OvicClassifier classifier) { + if (classifier == null) { + throw new RuntimeException("Cannot create image buffer with the classifier."); + } + int[] inputDims = classifier.getInputDims(); + int imgHeight = inputDims[1]; + int imgWidth = inputDims[2]; + ByteBuffer imgData = ByteBuffer.allocateDirect(imgHeight * imgWidth * 3); + imgData.order(ByteOrder.nativeOrder()); + Random rand = new Random(); + for (int y = 0; y < imgHeight; y++) { + for (int x = 0; x < imgWidth; x++) { + int val = rand.nextInt(); + imgData.put((byte) ((val >> 16) & 0xFF)); + imgData.put((byte) ((val >> 8) & 0xFF)); + imgData.put((byte) (val & 0xFF)); + } + } + return imgData; + } + + private static MappedByteBuffer loadModelFile(String modelFilePath) throws IOException { + File modelfile = new File(modelFilePath); + FileInputStream inputStream = new FileInputStream(modelfile); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = 0L; + long declaredLength = fileChannel.size(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index e84ee7112983ec584308b7cbcd919f119eccbcc9..4e22a68bf2e5e9cdc7783ffd829e124023a05479 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow.lite; import java.io.File; +import java.nio.ByteBuffer; import java.nio.MappedByteBuffer; import java.util.HashMap; import java.util.Map; @@ -80,6 +81,29 @@ public final class Interpreter implements AutoCloseable { wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads); } + /** + * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file. + * + *

The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The + * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a + * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. + */ + public Interpreter(@NonNull ByteBuffer byteBuffer) { + wrapper = new NativeInterpreterWrapper(byteBuffer); + } + + /** + * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and specifies the + * number of threads used for inference. + * + *

The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The + * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a + * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. + */ + public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) { + wrapper = new NativeInterpreterWrapper(byteBuffer, numThreads); + } + /** * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file. * @@ -111,7 +135,8 @@ public final class Interpreter implements AutoCloseable { * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large * input data. When {@link ByteBuffer} is used, its content should remain unchanged until * model inference is done. - * @param output a multidimensional array of output data. + * @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive + * types including int, float, long, and byte. */ public void run(@NonNull Object input, @NonNull Object output) { Object[] inputs = {input}; @@ -131,8 +156,9 @@ public final class Interpreter implements AutoCloseable { * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred * way to pass large input data. When {@link ByteBuffer} is used, its content should remain * unchanged until model inference is done. - * @param outputs a map mapping output indices to multidimensional arrays of output data. It only - * needs to keep entries for the outputs to be used. + * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link + * ByteBuffer}s of primitive types including int, float, long, and byte. It only needs to keep + * entries for the outputs to be used. */ public void runForMultipleInputsOutputs( @NonNull Object[] inputs, @NonNull Map outputs) { @@ -215,11 +241,11 @@ public final class Interpreter implements AutoCloseable { } } - public void setNumThreads(int num_threads) { + public void setNumThreads(int numThreads) { if (wrapper == null) { throw new IllegalStateException("The interpreter has already been closed."); } - wrapper.setNumThreads(num_threads); + wrapper.setNumThreads(numThreads); } /** Release resources associated with the {@code Interpreter}. */ @@ -229,5 +255,14 @@ public final class Interpreter implements AutoCloseable { wrapper = null; } + @Override + protected void finalize() throws Throwable { + try { + close(); + } finally { + super.finalize(); + } + } + NativeInterpreterWrapper wrapper; } diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index a43251cad13a4ed0b35367e796948b4b9a9faa67..80de88b6a1cd75b033e116f76f5612ee66e48f03 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -43,21 +43,31 @@ final class NativeInterpreterWrapper implements AutoCloseable { } /** - * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer}. The - * MappedByteBuffer should not be modified after the construction of a {@code - * NativeInterpreterWrapper}. + * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer}. The ByteBuffer should + * not be modified after the construction of a {@code NativeInterpreterWrapper}. The {@code + * ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a direct + * {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. */ - NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) { - this(mappedByteBuffer, /* numThreads= */ -1); + NativeInterpreterWrapper(ByteBuffer byteBuffer) { + this(byteBuffer, /* numThreads= */ -1); } /** - * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer} and specifies - * the number of inference threads. The MappedByteBuffer should not be modified after the - * construction of a {@code NativeInterpreterWrapper}. + * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer} and specifies the + * number of inference threads. The ByteBuffer should not be modified after the construction of a + * {@code NativeInterpreterWrapper}. The {@code ByteBuffer} can be either a {@code + * MappedByteBuffer} that memory-maps a model file, or a direct {@code ByteBuffer} of + * nativeOrder() that contains the bytes content of a model. */ - NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer, int numThreads) { - modelByteBuffer = mappedByteBuffer; + NativeInterpreterWrapper(ByteBuffer buffer, int numThreads) { + if (buffer == null + || (!(buffer instanceof MappedByteBuffer) + && (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) { + throw new IllegalArgumentException( + "Model ByteBuffer should be either a MappedByteBuffer of the model file, or a direct " + + "ByteBuffer using ByteOrder.nativeOrder() which contains bytes of model content."); + } + modelByteBuffer = buffer; errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); @@ -90,9 +100,10 @@ final class NativeInterpreterWrapper implements AutoCloseable { dataTypes[i] = dataType.getNumber(); if (dataType == DataType.BYTEBUFFER) { ByteBuffer buffer = (ByteBuffer) inputs[i]; - if (buffer.order() != ByteOrder.nativeOrder()) { + if (buffer == null || !buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()) { throw new IllegalArgumentException( - "Input error: ByteBuffer shoud use ByteOrder.nativeOrder()."); + "Input error: ByteBuffer should be a direct ByteBuffer that uses " + + "ByteOrder.nativeOrder()."); } numsOfBytes[i] = buffer.limit(); sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]); @@ -153,8 +164,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { useNNAPI(interpreterHandle, useNNAPI); } - void setNumThreads(int num_threads) { - numThreads(interpreterHandle, num_threads); + void setNumThreads(int numThreads) { + numThreads(interpreterHandle, numThreads); } /** Gets index of an input given its name. */ @@ -300,8 +311,30 @@ final class NativeInterpreterWrapper implements AutoCloseable { return DataType.fromNumber(type).toStringName(); } + /** + * Gets the quantization zero point of an output. + * + * @throws IllegalArgumentExeption if the output index is invalid. + */ + int getOutputQuantizationZeroPoint(int index) { + return getOutputQuantizationZeroPoint(interpreterHandle, index); + } + + /** + * Gets the quantization scale of an output. + * + * @throws IllegalArgumentExeption if the output index is invalid. + */ + float getOutputQuantizationScale(int index) { + return getOutputQuantizationScale(interpreterHandle, index); + } + private static native int getOutputDataType(long interpreterHandle, int outputIdx); + private static native int getOutputQuantizationZeroPoint(long interpreterHandle, int outputIdx); + + private static native float getOutputQuantizationScale(long interpreterHandle, int outputIdx); + private static final int ERROR_BUFFER_SIZE = 512; private long errorHandle; @@ -314,7 +347,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { private long inferenceDurationNanoseconds = -1; - private MappedByteBuffer modelByteBuffer; + private ByteBuffer modelByteBuffer; private Map inputsIndexes; @@ -328,13 +361,13 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native void useNNAPI(long interpreterHandle, boolean state); - private static native void numThreads(long interpreterHandle, int num_threads); + private static native void numThreads(long interpreterHandle, int numThreads); private static native long createErrorReporter(int size); private static native long createModel(String modelPathOrBuffer, long errorHandle); - private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle); + private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle); private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index 09e887aae3339e9f114c07d689c0d7b5e2fc384b..b2a3e04c55d86a33307e48571d50a72e0fa461ac 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -15,6 +15,8 @@ limitations under the License. package org.tensorflow.lite; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Arrays; /** @@ -29,8 +31,21 @@ final class Tensor { return new Tensor(nativeHandle); } - /** Reads Tensor content into an array. */ + /** + * Copies the contents of the tensor to {@code dst} and returns {@code dst}. + * + * @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}. + * @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example, + * mismatched data types or shapes). + * @throws BufferOverflowException If {@code dst} is a ByteBuffer with insufficient space for the + * data in this tensor. + */ T copyTo(T dst) { + if (dst instanceof ByteBuffer) { + ByteBuffer dstByteBuffer = (ByteBuffer) dst; + dstByteBuffer.put(buffer()); + return dst; + } if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) { throw new IllegalArgumentException( String.format( @@ -60,6 +75,12 @@ final class Tensor { this.shapeCopy = shape(nativeHandle); } + private ByteBuffer buffer() { + return buffer(nativeHandle).order(ByteOrder.nativeOrder()); + } + + private static native ByteBuffer buffer(long handle); + private static native int dtype(long handle); private static native int[] shape(long handle); diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 45f510da1d940a288e2794cb3e08f66451956b64..31f7b58fbc30cab9e6cb813094ea4b2627ba5cba 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -387,7 +387,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( jlong capacity = env->GetDirectBufferCapacity(model_buffer); if (!VerifyModel(buf, capacity)) { throwException(env, kIllegalArgumentException, - "MappedByteBuffer is not a valid flatbuffer model"); + "ByteBuffer is not a valid flatbuffer model"); return 0; } @@ -395,8 +395,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( buf, static_cast(capacity), error_reporter); if (!model) { throwException(env, kIllegalArgumentException, - "MappedByteBuffer does not encode a valid " - "TensorFlowLite model: %s", + "ByteBuffer does not encode a valid model: %s", error_reporter->CachedErrorMessage()); return 0; } @@ -426,7 +425,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( status = interpreter->AllocateTensors(); if (status != kTfLiteOk) { throwException(env, kNullPointerException, - "Internal error: Cannot allocate memory for the interpreter", + "Internal error: Cannot allocate memory for the interpreter:" + " %s", error_reporter->CachedErrorMessage()); return 0; } @@ -561,6 +561,38 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( return static_cast(type); } +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + const int idx = static_cast(output_idx); + if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { + throwException(env, kIllegalArgumentException, + "Failed to get %d-th output out of %d outputs", output_idx, + interpreter->outputs().size()); + return 0; + } + TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); + return static_cast(target->params.zero_point); +} + +JNIEXPORT jfloat JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 1.0f; + const int idx = static_cast(output_idx); + if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { + throwException(env, kIllegalArgumentException, + "Failed to get %d-th output out of %d outputs", output_idx, + interpreter->outputs().size()); + return 1.0f; + } + TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); + return static_cast(target->params.scale); +} + JNIEXPORT jboolean JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index eaa765cb343e9764bd0ef018d636a76f4b8a13e4..128ece49811a112684dac7b36810e920eeeb7351 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -152,6 +152,28 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( JNIEnv* env, jclass clazz, jlong handle, jint output_idx); +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JI)I + * + * Gets output quantization zero point. + */ +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JI)F + * + * Gets output quantization scale. + */ +JNIEXPORT jfloat JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx); + /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc index 005dca0253d2c30d56a15adf6e2b371d43f50945..08b4d042803708830221d5e25fe4463366a4c99a 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -43,31 +43,27 @@ size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, } switch (type) { case kTfLiteFloat32: { - jfloatArray a = static_cast(array); - jfloat* values = env->GetFloatArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseFloatArrayElements(a, values, JNI_ABORT); + jfloatArray float_array = static_cast(array); + jfloat* float_dst = static_cast(dst); + env->GetFloatArrayRegion(float_array, 0, num_elements, float_dst); return to_copy; } case kTfLiteInt32: { - jintArray a = static_cast(array); - jint* values = env->GetIntArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseIntArrayElements(a, values, JNI_ABORT); + jintArray int_array = static_cast(array); + jint* int_dst = static_cast(dst); + env->GetIntArrayRegion(int_array, 0, num_elements, int_dst); return to_copy; } case kTfLiteInt64: { - jlongArray a = static_cast(array); - jlong* values = env->GetLongArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseLongArrayElements(a, values, JNI_ABORT); + jlongArray long_array = static_cast(array); + jlong* long_dst = static_cast(dst); + env->GetLongArrayRegion(long_array, 0, num_elements, long_dst); return to_copy; } case kTfLiteUInt8: { - jbyteArray a = static_cast(array); - jbyte* values = env->GetByteArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseByteArrayElements(a, values, JNI_ABORT); + jbyteArray byte_array = static_cast(array); + jbyte* byte_dst = static_cast(dst); + env->GetByteArrayRegion(byte_array, 0, num_elements, byte_dst); return to_copy; } default: { @@ -207,6 +203,16 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, } } +JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, + jclass clazz, + jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return nullptr; + + return env->NewDirectByteBuffer(static_cast(tensor->data.raw), + static_cast(tensor->bytes)); +} + JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, jclass clazz, diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h index 3a4910dcc3a719fbb9f365dae693423de768349c..9ba95d9ac402662e6de69e3da8a60a6e841f97d6 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -24,8 +24,17 @@ extern "C" { #endif // __cplusplus /* - * Class: org_tensorflow_lite_TfLiteTensor - * Method: + * Class: org_tensorflow_lite_Tensor + * Method: buffer + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_Tensor + * Method: dtype * Signature: (J)I */ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, @@ -33,8 +42,8 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, jlong handle); /* - * Class: org_tensorflow_lite_TfLiteTensor - * Method: + * Class: org_tensorflow_lite_Tensor + * Method: shape * Signature: (J)[I */ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, @@ -42,8 +51,8 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jlong handle); /* - * Class: org_tensorflow_lite_TfLiteTensor - * Method: + * Class: org_tensorflow_lite_Tensor + * Method: readMultiDimensionalArray * Signature: (JLjava/lang/Object;) */ JNIEXPORT void JNICALL diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 210d9437241f117ab281b627a4352fce7d340bcb..e6deadffe2d7a110ff742b05a5bf06fa1bc67de9 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -19,6 +19,8 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; import java.io.File; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.file.Files; @@ -69,6 +71,49 @@ public final class InterpreterTest { fileChannel.close(); } + @Test + public void testRunWithDirectByteBufferModel() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) fileChannel.size()); + byteBuffer.order(ByteOrder.nativeOrder()); + fileChannel.read(byteBuffer); + Interpreter interpreter = new Interpreter(byteBuffer); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + fileChannel.close(); + } + + @Test + public void testRunWithInvalidByteBufferModel() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + ByteBuffer byteBuffer = ByteBuffer.allocate((int) fileChannel.size()); + byteBuffer.order(ByteOrder.nativeOrder()); + fileChannel.read(byteBuffer); + try { + Interpreter interpreter = new Interpreter(byteBuffer); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Model ByteBuffer should be either a MappedByteBuffer" + + " of the model file, or a direct ByteBuffer using ByteOrder.nativeOrder()"); + } + fileChannel.close(); + } + @Test public void testRun() { Interpreter interpreter = new Interpreter(MODEL_FILE); @@ -119,6 +164,24 @@ public final class InterpreterTest { interpreter.close(); } + @Test + public void testRunWithByteBufferOutput() { + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + ByteBuffer parsedOutput = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + try (Interpreter interpreter = new Interpreter(MODEL_FILE)) { + interpreter.run(fourD, parsedOutput); + } + float[] outputOneD = { + parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) + }; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + @Test public void testMobilenetRun() { // Create a gray image. diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 7c00d3196fd001a288d77d4e01f0b30978d72afe..029e5853e2f843fc38eeca0ffa9bb3a82390093b 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -41,6 +41,9 @@ public final class NativeInterpreterWrapperTest { private static final String BYTE_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/uint8.bin"; + private static final String QUANTIZED_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/quantized.bin"; + private static final String INVALID_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; @@ -108,6 +111,27 @@ public final class NativeInterpreterWrapperTest { wrapper.close(); } + @Test + public void testRunWithBufferOutput() { + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) { + float[] oneD = {1.23f, -6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs).hasLength(1); + ByteBuffer parsedOutput = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + outputs[0].copyTo(parsedOutput); + float[] outputOneD = { + parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) + }; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + } + @Test public void testRunWithInputsOfSameDims() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); @@ -536,4 +560,16 @@ public final class NativeInterpreterWrapperTest { assertThat(wrapper.getOutputDataType(0)).contains("byte"); wrapper.close(); } + + @Test + public void testGetOutputQuantizationParams() { + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) { + assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(0); + assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.0f); + } + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(QUANTIZED_MODEL_PATH)) { + assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(127); + assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.25f); + } + } } diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index 94b6632bb8dd7117bf4074da1939bd23ce732efd..dd9d37eedafaa8250f5f926375edcf7cb3b730a0 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -18,6 +18,9 @@ package org.tensorflow.lite; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -70,6 +73,32 @@ public final class TensorTest { assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); } + @Test + public void testCopyToByteBuffer() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + ByteBuffer parsedOutput = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + tensor.copyTo(parsedOutput); + assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4); + float[] outputOneD = { + parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) + }; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + + @Test + public void testCopyToInvalidByteBuffer() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + try { + tensor.copyTo(parsedOutput); + fail(); + } catch (BufferOverflowException e) { + // Expected. + } + } + @Test public void testCopyToWrongType() { Tensor tensor = Tensor.fromHandle(nativeHandle); diff --git a/tensorflow/contrib/lite/java/src/testdata/quantized.bin b/tensorflow/contrib/lite/java/src/testdata/quantized.bin new file mode 100644 index 0000000000000000000000000000000000000000..4062088cdf717e8752490de5c9acff35fd6af54f Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/quantized.bin differ diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD index b524246d436858bbf506809a38cead2897f78d93..af1d99ef41e6413d8ef2c6f478aaa8f9e3931ff8 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -1,6 +1,8 @@ # Description: # Internal helper function to test TF Lite API. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index b7291dd379a6c09a70a78de7bc6c2f217b293b26..27b8a16e1522de4d31b2870e6130fb3281941a05 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -142,11 +142,13 @@ cc_library( "conv.cc", "depthwise_conv.cc", "dequantize.cc", + "detection_postprocess.cc", "div.cc", "elementwise.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", "exp.cc", + "expand_dims.cc", "floor.cc", "fully_connected.cc", "gather.cc", @@ -156,25 +158,29 @@ cc_library( "lsh_projection.cc", "lstm.cc", "maximum_minimum.cc", - "mean.cc", "mfcc.cc", "mul.cc", "neg.cc", "pad.cc", "pooling.cc", + "pow.cc", + "reduce.cc", "register.cc", "reshape.cc", "resize_bilinear.cc", "select.cc", + "shape.cc", "skip_gram.cc", "slice.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "sparse_to_dense.cc", "split.cc", "squeeze.cc", "strided_slice.cc", "sub.cc", "svdf.cc", + "tile.cc", "topk_v2.cc", "transpose.cc", "transpose_conv.cc", @@ -243,6 +249,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "detection_postprocess_test", + size = "small", + srcs = ["detection_postprocess_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + tf_cc_test( name = "activations_test", size = "small", @@ -551,9 +571,9 @@ tf_cc_test( ) tf_cc_test( - name = "mean_test", + name = "reduce_test", size = "small", - srcs = ["mean_test.cc"], + srcs = ["reduce_test.cc"], tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", @@ -857,6 +877,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "tile_test", + size = "small", + srcs = ["tile_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "comparisons_test", size = "small", @@ -930,6 +964,63 @@ tf_cc_test( ":builtin_ops", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "expand_dims_test", + size = "small", + srcs = ["expand_dims_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "sparse_to_dense_test", + size = "small", + srcs = ["sparse_to_dense_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "shape_test", + size = "small", + srcs = ["shape_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "pow_test", + size = "small", + srcs = ["pow_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 4972159a05eb9a6951b2a5fd2c2832966f4b76df..99f81c4a8a78ab0b2a24955d77f25ed09da13b84 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -84,6 +84,38 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) { &data->input_left_shift); data->input_range_radius = CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } else if (input->type == kTfLiteInt16) { + static constexpr int kInputIntegerBits = 3; + static constexpr int kOutputFractionalBits = 15; + + // These operators are implemented in fixed-point arithmetic, + // which intrinsically wants symmetric ranges (zero_point==0) + // and power-of-two scales (power-of-two is abbreviated below as POT). + // While more general support would be possible by means of rescaling, + // that would add some overhead and some loss of accuracy and wouldn't + // be used at the moment as current quantized LSTM applications are + // happy with symmetric, power-of-two-scales quantization. So we just + // implement that narrow case only for now. + + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input_scale_log2_rounded; + TF_LITE_ENSURE(context, + CheckedLog2(input->params.scale, &input_scale_log2_rounded)); + + int output_scale_log2_rounded; + TF_LITE_ENSURE( + context, CheckedLog2(output->params.scale, &output_scale_log2_rounded)); + TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded, + -kOutputFractionalBits); + + data->input_left_shift = + (15 - kInputIntegerBits) + input_scale_log2_rounded; + // Support for shifts is limited until we have a parameterized version of + // SaturatingRoundingMultiplyByPOT(). + TF_LITE_ENSURE(context, data->input_left_shift >= 0); + TF_LITE_ENSURE(context, data->input_left_shift <= 1); } return context->ResizeTensor(context, output, @@ -114,6 +146,30 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { &data->input_left_shift); data->input_range_radius = CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } else if (input->type == kTfLiteInt16) { + static constexpr int kInputIntegerBits = 3; + static constexpr int kOutputFractionalBits = 15; + + // See comments in TanhPrepare about requiring zero_point==0 + // and a power-of-two ("POT") scale. + + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input_scale_log2_rounded; + TF_LITE_ENSURE(context, + CheckedLog2(input->params.scale, &input_scale_log2_rounded)); + + int output_scale_log2_rounded; + TF_LITE_ENSURE( + context, CheckedLog2(output->params.scale, &output_scale_log2_rounded)); + TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded, + -kOutputFractionalBits); + + data->input_left_shift = + (15 - kInputIntegerBits) + input_scale_log2_rounded; + // The int16 logistic implementation does not support shifting of the input. + TF_LITE_ENSURE_EQ(context, data->input_left_shift, 0); } return context->ResizeTensor(context, output, @@ -191,7 +247,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } break; default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } } @@ -211,7 +268,8 @@ TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } break; default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } } @@ -229,7 +287,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } break; default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } } @@ -247,16 +306,24 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { for (; in < in_end; in++, out++) *out = std::tanh(*in); return kTfLiteOk; } break; + case kTfLiteInt16: { + optimized_ops::Tanh(GetTensorData(input), GetTensorShape(input), + data->input_left_shift, + GetTensorData(output), + GetTensorShape(output)); + return kTfLiteOk; + } break; case kTfLiteUInt8: { - optimized_ops::Tanh(GetTensorData(input), GetTensorDims(input), + optimized_ops::Tanh(GetTensorData(input), GetTensorShape(input), input->params.zero_point, data->input_range_radius, data->input_multiplier, data->input_left_shift, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); return kTfLiteOk; } break; default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } } @@ -276,16 +343,23 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in)); break; } + case kTfLiteInt16: { + optimized_ops::Logistic( + GetTensorData(input), GetTensorShape(input), + GetTensorData(output), GetTensorShape(output)); + break; + } case kTfLiteUInt8: { optimized_ops::Logistic( - GetTensorData(input), GetTensorDims(input), + GetTensorData(input), GetTensorShape(input), input->params.zero_point, data->input_range_radius, data->input_multiplier, data->input_left_shift, - GetTensorData(output), GetTensorDims(output)); + GetTensorData(output), GetTensorShape(output)); break; } default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } return kTfLiteOk; @@ -336,26 +410,26 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, const int batch_size = input->dims->data[0]; const int input_size = input->dims->data[1]; optimized_ops::Softmax(GetTensorData(input), - GetTensorDims({batch_size, 1, 1, input_size}), + GetTensorShape({batch_size, 1, 1, input_size}), data->input_multiplier, data->input_left_shift, data->diff_min, GetTensorData(output), - GetTensorDims({batch_size, 1, 1, input_size})); + GetTensorShape({batch_size, 1, 1, input_size})); } // Takes a 4D tensor and perform softmax along the forth dimension. void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params) { - optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + optimized_ops::Softmax(GetTensorData(input), GetTensorShape(input), params->beta, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); } void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params, OpData* data) { - optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + optimized_ops::Softmax(GetTensorData(input), GetTensorShape(input), data->input_multiplier, data->input_left_shift, data->diff_min, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); } TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { @@ -377,8 +451,9 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { Softmax4DFloat(input, output, params); return kTfLiteOk; } - context->ReportError(context, - "Only 2D and 4D tensors supported currently."); + context->ReportError( + context, "Only 2D and 4D tensors supported currently, got %dD.", + NumDimensions(input)); return kTfLiteError; } case kTfLiteUInt8: { @@ -390,13 +465,15 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { Softmax4DQuantized(input, output, params, data); return kTfLiteOk; } - context->ReportError(context, - "Only 2D and 4D tensors supported currently."); + context->ReportError( + context, "Only 2D and 4D tensors supported currently, got %dD.", + NumDimensions(input)); return kTfLiteError; } default: - context->ReportError(context, - "Only float32 and uint8_t supported currently."); + context->ReportError( + context, "Only float32 and uint8_t supported currently, got %d.", + input->type); return kTfLiteError; } } @@ -407,11 +484,12 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { switch (input->type) { case kTfLiteFloat32: optimized_ops::LogSoftmax( - GetTensorData(input), GetTensorDims(input), - GetTensorData(output), GetTensorDims(output)); + GetTensorData(input), GetTensorShape(input), + GetTensorData(output), GetTensorShape(output)); return kTfLiteOk; default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently., got %d", + input->type); return kTfLiteError; } } @@ -422,7 +500,8 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* output = GetOutput(context, node, 0); if (input->type != kTfLiteFloat32) { - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } TF_LITE_ENSURE_EQ(context, input->dims->size, 4); diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index 50a84edd475c8051a563cf8ed9fc03099829b786..587e1303da6afed1fc711100f457f1bf62b0b7e1 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -75,23 +75,42 @@ class FloatActivationsOpModel : public BaseActivationsOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; -// TODO(ahentz): I don't quite understand the tradeoffs in the quantized -// implementation of sigmoid and software, but a tolerance of twice the output -// scale seems reasonable. We might want to change this if we have a better -// theoretical bound. +// Our fixed-point math function implementations have roughly 12 bits of +// accuracy, when specialized to 16-bit fixed-point arithmetic. +// That is purely an implementation compromise, it would have been possible +// to get closer to 16 bits of accuracy but that would be more expensive, +// and not needed for our purposes as ultimately the output is either +// immediately down-quantized to 8 bits, or will typically be at the output +// of the surrounding LSTM cell. +// So we can require roughly 2^-12 accuracy when the output is 16-bit, and +// we can more or less expect the full 2^-8 accuracy when the output is 8-bit. +// +// However, the representable output interval is often [-1, 1] (it has to be +// for tanh, and even for logistic, when we implement it in fixed-point, we +// typically have to do so on such a symmetric interval, e.g. ARM NEON only +// has signed fixed-point arithmetic (SQRDMULH)). As the width of [-1, 1] +// is 2, our representable values are often diluted by a factor of 2, whence +// the factor of 2 below. const float kQuantizedTolerance = 2 * (1. / 256); +const float kQuantizedToleranceInt16 = 2 * (1. / 4096); class QuantizedActivationsOpModel : public BaseActivationsOpModel { public: using BaseActivationsOpModel::BaseActivationsOpModel; + template void SetInput(std::initializer_list data) { - QuantizeAndPopulate(input_, data); + QuantizeAndPopulate(input_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + + std::vector GetOutput() { + return ExtractVector(output_); + } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } }; @@ -152,24 +171,47 @@ TEST(FloatActivationsOpTest, Tanh) { } TEST(QuantizedActivationsOpTest, Tanh) { + const float kMin = -1; + const float kMax = 127.f / 128.f; QuantizedActivationsOpModel m( BuiltinOperator_TANH, - /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -8, 8}, - /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, -1, 1}); - m.SetInput({ + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ 0, -6, 2, 4, // -4, -2, 8, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { 0.0, -0.999987, 0.964027, 0.999329, // - -0.996078, -0.96402, 0.99999, 0.76159, // + -0.999329, -0.96402, 0.99999, 0.76159, // }, - 4 * (1. / 256)))); - EXPECT_THAT(m.GetOutput(), - ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 226})); + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 225})); +} + +TEST(QuantizedActivationsOpTest, TanhInt16) { + const float kMin = -1; + const float kMax = 32767.f / 32768.f; + QuantizedActivationsOpModel m( + BuiltinOperator_TANH, + /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ + 0, -6, 2, 4, // + -4, -2, 8, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.0, -0.999987, 0.964027, 0.999329, // + -0.999329, -0.96402, 0.99999, 0.76159, // + }, + kQuantizedToleranceInt16))); } TEST(FloatActivationsOpTest, Sigmoid) { @@ -190,22 +232,43 @@ TEST(QuantizedActivationsOpTest, Sigmoid) { QuantizedActivationsOpModel m( BuiltinOperator_LOGISTIC, /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { 0.5, 0.002473, 0.880797, 0.982014, // 0.952574, 0.119203, 0.999955, 0.731059, // }, kQuantizedTolerance))); - EXPECT_THAT(m.GetOutput(), + EXPECT_THAT(m.GetOutput(), ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); } +TEST(QuantizedActivationsOpTest, SigmoidInt16) { + const float kMin = -1; + const float kMax = 32767.f / 32768.f; + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedToleranceInt16))); +} + TEST(FloatActivationsOpTest, Softmax4D) { FloatActivationsOpModel m(0.1, /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}}); @@ -241,12 +304,12 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { QuantizedActivationsOpModel m( 0.1, /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // depth = 0 3, -2, 10, 1, // depth = 1 }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { .23463, .12877, .28658, .35003, // @@ -258,21 +321,22 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { QuantizedActivationsOpModel m2( 0.1, /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10}); - m2.SetInput({ + m2.SetInput({ 0, -6, // 2, 4, // 3, -2, // 10, 1, // }); m2.Invoke(); - EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - { - 0.645656, 0.354344, // - 0.450166, 0.549834, // - 0.622459, 0.377541, // - 0.710949, 0.28905, // - }, - kQuantizedTolerance))); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); } TEST(FloatActivationsOpTest, Softmax2D) { @@ -309,12 +373,12 @@ TEST(FloatActivationsOpTest, Softmax2D) { TEST(QuantizedActivationsOpTest, Softmax2D) { QuantizedActivationsOpModel m(0.1, /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { .23463, .12877, .28658, .35003, // @@ -325,21 +389,22 @@ TEST(QuantizedActivationsOpTest, Softmax2D) { // Same input, but a different shape. QuantizedActivationsOpModel m2(0.1, /*input=*/{TensorType_UINT8, {4, 2}, -10, 10}); - m2.SetInput({ + m2.SetInput({ 0, -6, // 2, 4, // 3, -2, // 10, 1, // }); m2.Invoke(); - EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - { - 0.645656, 0.354344, // - 0.450166, 0.549834, // - 0.622459, 0.377541, // - 0.710949, 0.28905, // - }, - kQuantizedTolerance))); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); } // This contains the same test values as the Softmax test, but reference answer diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index 7ca1e35489cba3b5d2567bc04e532fedf8a527a7..f44d531cbfa9ed41f881380752558555aab97b4d 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -39,6 +39,23 @@ constexpr int kOutputTensor = 0; struct OpData { bool requires_broadcast; + + // These fields are used in both the general 8-bit -> 8bit quantized path, + // and the special 16-bit -> 16bit quantized path + int input1_shift; + int input2_shift; + int32 output_activation_min; + int32 output_activation_max; + + // These fields are used only in the general 8-bit -> 8bit quantized path + int32 input1_multiplier; + int32 input2_multiplier; + int32 output_multiplier; + int output_shift; + int left_shift; + int32 input1_offset; + int32 input2_offset; + int32 output_offset; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -52,6 +69,7 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); @@ -74,89 +92,169 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size = TfLiteIntArrayCopy(input1->dims); } + if (output->type == kTfLiteUInt8) { + // 8bit -> 8bit general quantized path, with general rescalings + data->input1_offset = -input1->params.zero_point; + data->input2_offset = -input2->params.zero_point; + data->output_offset = output->params.zero_point; + data->left_shift = 20; + const double twice_max_input_scale = + 2 * std::max(input1->params.scale, input2->params.scale); + const double real_input1_multiplier = + input1->params.scale / twice_max_input_scale; + const double real_input2_multiplier = + input2->params.scale / twice_max_input_scale; + const double real_output_multiplier = + twice_max_input_scale / + ((1 << data->left_shift) * output->params.scale); + + QuantizeMultiplierSmallerThanOneExp( + real_input1_multiplier, &data->input1_multiplier, &data->input1_shift); + data->input1_shift *= -1; + + QuantizeMultiplierSmallerThanOneExp( + real_input2_multiplier, &data->input2_multiplier, &data->input2_shift); + data->input2_shift *= -1; + + QuantizeMultiplierSmallerThanOneExp( + real_output_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; + + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + + } else if (output->type == kTfLiteInt16) { + // 16bit -> 16bit special quantized path, supporting only a rather + // narrow case of quantization parameters: zero_points must all be 0 + // ("symmetric quantization") and scales must be power-of-two (which + // we abbreviate as "POT" below). The intended use case for this path + // is in LSTM cells, where, due to the constraints of implementing + // some of the math in these LSTM cells in fixed-point arithmetic, + // we need to have such symmetric, power-of-two quantization + // (Fixed-point formats are inherently symmetric, power-of-two). + TF_LITE_ENSURE_EQ(context, input1->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, input2->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input1_scale_log2_rounded; + bool input1_scale_is_pot = + CheckedLog2(input1->params.scale, &input1_scale_log2_rounded); + TF_LITE_ENSURE(context, input1_scale_is_pot); + + int input2_scale_log2_rounded; + bool input2_scale_is_pot = + CheckedLog2(input2->params.scale, &input2_scale_log2_rounded); + TF_LITE_ENSURE(context, input2_scale_is_pot); + + int output_scale_log2_rounded; + bool output_scale_is_pot = + CheckedLog2(output->params.scale, &output_scale_log2_rounded); + TF_LITE_ENSURE(context, output_scale_is_pot); + + data->input1_shift = output_scale_log2_rounded - input1_scale_log2_rounded; + data->input2_shift = output_scale_log2_rounded - input2_scale_log2_rounded; + + // Shifting of one input is supported. The graph quantization should ensure + // that the other input matches the output. + TF_LITE_ENSURE(context, data->input1_shift == 0 || data->input2_shift == 0); + TF_LITE_ENSURE(context, data->input1_shift >= 0); + TF_LITE_ENSURE(context, data->input2_shift >= 0); + + CalculateActivationRangeQuantized(context, params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + return context->ResizeTensor(context, output, output_size); } template -void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); -#define TF_LITE_ADD(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - if (data->requires_broadcast) { - TF_LITE_ADD(reference_ops, BroadcastAdd); +void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, + const OpData* data, const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { +#define TF_LITE_ADD(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (output->type == kTfLiteInt32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd, int32_t); + } else { + TF_LITE_ADD(reference_ops, Add, int32_t); + } } else { - TF_LITE_ADD(reference_ops, Add); + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd, int32_t); + } else { + TF_LITE_ADD(optimized_ops, Add, int32_t); + } } - } else { - if (data->requires_broadcast) { - TF_LITE_ADD(optimized_ops, BroadcastAdd); + } else if (output->type == kTfLiteFloat32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd, float); + } else { + TF_LITE_ADD(reference_ops, Add, float); + } } else { - TF_LITE_ADD(optimized_ops, Add); + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd, float); + } else { + TF_LITE_ADD(optimized_ops, Add, float); + } } } #undef TF_LITE_ADD } template -void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - auto input1_offset = -input1->params.zero_point; - auto input2_offset = -input2->params.zero_point; - auto output_offset = output->params.zero_point; - const int left_shift = 20; - const double twice_max_input_scale = - 2 * std::max(input1->params.scale, input2->params.scale); - const double real_input1_multiplier = - input1->params.scale / twice_max_input_scale; - const double real_input2_multiplier = - input2->params.scale / twice_max_input_scale; - const double real_output_multiplier = - twice_max_input_scale / ((1 << left_shift) * output->params.scale); - - int32 input1_multiplier; - int input1_shift; - QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, - &input1_shift); - int32 input2_multiplier; - int input2_shift; - QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, - &input2_shift); - int32 output_multiplier; - int output_shift; - QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, - &output_shift); - - int32 output_activation_min, output_activation_max; - CalculateActivationRangeUint8(params->activation, output, - &output_activation_min, &output_activation_max); - -#define TF_LITE_ADD(type, opname) \ - type::opname(left_shift, GetTensorData(input1), \ - GetTensorDims(input1), input1_offset, input1_multiplier, \ - input1_shift, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, input2_multiplier, \ - input2_shift, output_offset, output_multiplier, output_shift, \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); - // The quantized version of Add doesn't support activations, so we - // always use BroadcastAdd. - if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops, BroadcastAdd); - } else { - TF_LITE_ADD(optimized_ops, BroadcastAdd); - } +TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, const OpData* data, + const TfLiteTensor* input1, + const TfLiteTensor* input2, + TfLiteTensor* output) { + if (output->type == kTfLiteUInt8) { +#define TF_LITE_ADD(type, opname) \ + type::opname( \ + data->left_shift, GetTensorData(input1), GetTensorDims(input1), \ + data->input1_offset, data->input1_multiplier, data->input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), \ + data->input2_offset, data->input2_multiplier, data->input2_shift, \ + data->output_offset, data->output_multiplier, data->output_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + // The quantized version of Add doesn't support activations, so we + // always use BroadcastAdd. + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops, BroadcastAdd); + } else { + TF_LITE_ADD(optimized_ops, BroadcastAdd); + } +#undef TF_LITE_ADD + } else if (output->type == kTfLiteInt16) { +#define TF_LITE_ADD(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + data->input1_shift, GetTensorData(input2), \ + GetTensorDims(input2), data->input2_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + // The quantized version of Add doesn't support activations, so we + // always use BroadcastAdd. + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops, Add); + } else { + TF_LITE_ADD(optimized_ops, Add); + } #undef TF_LITE_ADD + } + + return kTfLiteOk; } template @@ -168,15 +266,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (output->type == kTfLiteFloat32) { - EvalAddFloat(context, node, params, data, input1, input2, - output); - } else if (output->type == kTfLiteUInt8) { - EvalAddQuantized(context, node, params, data, input1, input2, - output); + if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { + EvalAdd(context, node, params, data, input1, input2, output); + } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + TF_LITE_ENSURE_OK(context, + EvalAddQuantized(context, node, params, data, + input1, input2, output)); } else { context->ReportError(context, - "Inputs and outputs not all float|uint8 types."); + "Inputs and outputs not all float|uint8|int16 types."); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc index 956d05bed5162f6ce59705d59aad77ff056dda77..0b5844321133de103919de76d367574f018a6698 100644 --- a/tensorflow/contrib/lite/kernels/add_test.cc +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -52,6 +52,13 @@ class FloatAddOpModel : public BaseAddOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; +class IntegerAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + class QuantizedAddOpModel : public BaseAddOpModel { public: using BaseAddOpModel::BaseAddOpModel; @@ -60,15 +67,26 @@ class QuantizedAddOpModel : public BaseAddOpModel { return Dequantize(ExtractVector(output_), GetScale(output_), GetZeroPoint(output_)); } + + std::vector GetDequantizedOutputInt16() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } }; // for quantized Add, the error shouldn't exceed 2*step -float GetTolerance(int min, int max) { +float GetTolerance(float min, float max) { float kQuantizedStep = (max - min) / 255.0; float kQuantizedTolerance = 2.0 * kQuantizedStep; return kQuantizedTolerance; } +float GetToleranceInt16(float min, float max) { + float kQuantizedStep = (max - min) / 32767.f; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + TEST(FloatAddOpModel, NoActivation) { FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, @@ -122,6 +140,57 @@ TEST(FloatAddOpModel, WithBroadcast) { } } +TEST(IntegerAddOpModel, NoActivation) { + IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-19, 4, 10, 13})); +} + +TEST(IntegerAddOpModel, ActivationRELU_N1_TO_1) { + IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 1, 1})); +} + +TEST(IntegerAddOpModel, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerAddOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5, 11, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-19, 04, 10, 13, 22, 21})) + << "With shape number " << i; + } +} + +TEST(IntegerAddOpModel, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerAddOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, // always a scalar + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-19, 3, 8, 9, 12, 21}))) + << "With shape number " << i; + } +} + TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { @@ -144,6 +213,31 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { } } +TEST(QuantizedAddOpModel, QuantizedTestsNoActivationInt16) { + const float kMin = -1.f; + const float kMax = 32767.f / 32768.f; + float kQuantizedTolerance = GetToleranceInt16(kMin, kMax); + std::vector> inputs1 = { + {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = { + {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = { + {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, diff --git a/tensorflow/contrib/lite/kernels/arg_max.cc b/tensorflow/contrib/lite/kernels/arg_max.cc index 738d475f60a66c8100fa5f8539660c25bd82128a..26f57e88962116f446e72fbc164d2747e8b633b4 100644 --- a/tensorflow/contrib/lite/kernels/arg_max.cc +++ b/tensorflow/contrib/lite/kernels/arg_max.cc @@ -52,7 +52,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output->type = kTfLiteInt64; break; default: - context->ReportError(context, "Unknown index output data type"); + context->ReportError(context, "Unknown index output data type: %d", + params->output_type); return kTfLiteError; } @@ -64,7 +65,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { break; default: - context->ReportError(context, "Only float32 and int types are supported"); + context->ReportError( + context, + "Unkonwn input type: %d, only float32 and int types are supported", + input->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index 0907547f9f3f5f2b059748940c1fd6d027cb8d8a..c09b15b3d263d6cd639234590c99a50a9a48f4a7 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -36,7 +36,7 @@ constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index); + context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); return scratch_tensor_index; } @@ -91,7 +91,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) { int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries = TfLiteIntArrayCreate(3); node->temporaries->data[0] = *scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = kTfLiteUInt8; @@ -114,6 +114,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } + node->temporaries->data[2] = *scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } } return kTfLiteOk; @@ -145,14 +155,14 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input, return kTfLiteOk; } -TfLiteStatus EvalQuantized(const TfLiteTensor* input, - const TfLiteTensor* input_weights, - const TfLiteTensor* recurrent_weights, - const TfLiteTensor* bias, - const TfLiteRNNParams* params, - TfLiteTensor* input_scratch, - TfLiteTensor* hidden_state_scratch, - TfLiteTensor* hidden_state, TfLiteTensor* output) { +TfLiteStatus EvalHybrid(const TfLiteTensor* input, + const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, + const TfLiteTensor* bias, const TfLiteRNNParams* params, + TfLiteTensor* input_scratch, + TfLiteTensor* hidden_state_scratch, + TfLiteTensor* scaling_factors, + TfLiteTensor* hidden_state, TfLiteTensor* output) { const int batch_size = input->dims->data[0]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[1]; @@ -176,12 +186,14 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, reinterpret_cast(input_scratch->data.uint8); int8_t* quantized_hidden_state_ptr = reinterpret_cast(hidden_state_scratch->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; kernel_utils::RnnBatchStep( input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, params->activation, quantized_input_ptr, - quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch); + quantized_hidden_state_ptr, scaling_factors_ptr, hidden_state_ptr_batch, + output_ptr_batch); return kTfLiteOk; } @@ -205,12 +217,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(mirkov): implement eval with quantized inputs as well. TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); - return EvalQuantized(input, input_weights, recurrent_weights, bias, - params, input_quantized, hidden_state_quantized, - hidden_state, output); + TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + return EvalHybrid(input, input_weights, recurrent_weights, bias, params, + input_quantized, hidden_state_quantized, + scaling_factors, hidden_state, output); } default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input_weights->type); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc index 262e1aeab159d2518e243069c79a2d7200a33c5a..c8cee88edfdbf42f422f66e4d0ca6eeb5eccbf8d 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -163,8 +163,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } break; default: - context->ReportError(context, - "Type is currently not supported by BatchToSpace."); + context->ReportError( + context, "Type %d is currently not supported by BatchToSpace.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_BATCH_TO_SPACE_ND diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc index 673eedc2e948ba91908d943d7911162638301d8b..8dd48af57fd1bd9ef21256410d6bede6b7baa566 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" @@ -53,6 +54,20 @@ void copyCast(const FromT* in, ToT* out, int num_elements) { [](FromT a) { return static_cast(a); }); } +template +void copyCast(const std::complex* in, ToT* out, int num_elements) { + std::transform(in, in + num_elements, out, [](std::complex a) { + return static_cast(std::real(a)); + }); +} + +template <> +void copyCast(const std::complex* in, std::complex* out, + int num_elements) { + std::transform(in, in + num_elements, out, + [](std::complex a) { return a; }); +} + template TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, int num_elements) { @@ -69,6 +84,13 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, case kTfLiteFloat32: copyCast(in, out->data.f, num_elements); break; + case kTfLiteBool: + copyCast(in, out->data.b, num_elements); + break; + case kTfLiteComplex64: + copyCast(in, reinterpret_cast*>(out->data.c64), + num_elements); + break; default: // Unsupported type. return kTfLiteError; @@ -90,6 +112,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return copyToTensor(input->data.uint8, output, num_elements); case kTfLiteFloat32: return copyToTensor(input->data.f, output, num_elements); + case kTfLiteBool: + return copyToTensor(input->data.b, output, num_elements); + case kTfLiteComplex64: + return copyToTensor( + reinterpret_cast*>(input->data.c64), output, + num_elements); default: // Unsupported type. return kTfLiteError; diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc index 4e56482a371550b6275a6380e2beebe3cef958ff..954f998206563a38c74a1382092851cfbee1013b 100644 --- a/tensorflow/contrib/lite/kernels/cast_test.cc +++ b/tensorflow/contrib/lite/kernels/cast_test.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" @@ -57,6 +59,87 @@ TEST(CastOpModel, CastFloatToInt) { ElementsAreArray({100, 20, 3, 0, 0, 1})); } +TEST(CastOpModel, CastFloatToBool) { + CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_BOOL, {3, 2}}); + m.PopulateTensor(m.input(), {100.f, -1.0f, 0.f, 0.4f, 0.999f, 1.1f}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({true, true, false, true, true, true})); +} + +TEST(CastOpModel, CastBoolToFloat) { + CastOpModel m({TensorType_BOOL, {3, 2}}, {TensorType_FLOAT32, {3, 2}}); + m.PopulateTensor(m.input(), {true, true, false, true, false, true}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f})); +} + +TEST(CastOpModel, CastComplex64ToFloat) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); + m.PopulateTensor>( + m.input(), + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), std::complex(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); +} + +TEST(CastOpModel, CastFloatToComplex64) { + CastOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector>(m.output()), + ElementsAreArray( + {std::complex(1.0f, 0.0f), std::complex(2.0f, 0.0f), + std::complex(3.0f, 0.0f), std::complex(4.0f, 0.0f), + std::complex(5.0f, 0.0f), std::complex(6.0f, 0.0f)})); +} + +TEST(CastOpModel, CastComplex64ToInt) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_INT32, {2, 3}}); + m.PopulateTensor>( + m.input(), + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), std::complex(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(CastOpModel, CastIntToComplex64) { + CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor(m.input(), {1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector>(m.output()), + ElementsAreArray( + {std::complex(1.0f, 0.0f), std::complex(2.0f, 0.0f), + std::complex(3.0f, 0.0f), std::complex(4.0f, 0.0f), + std::complex(5.0f, 0.0f), std::complex(6.0f, 0.0f)})); +} + +TEST(CastOpModel, CastComplex64ToComplex64) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor>( + m.input(), + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), std::complex(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector>(m.output()), + ElementsAreArray( + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), + std::complex(6.0f, 16.0f)})); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index b948334b6d82aecf047a423c85c4ab5060a5c864..f678f48fa5bbbcece6c5b87030d951783378d78f 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -23,6 +23,7 @@ namespace tflite { namespace ops { namespace builtin { namespace comparisons { +namespace { constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; @@ -67,6 +68,57 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { GetTensorData(input2), GetTensorDims(input2), \ GetTensorData(output), GetTensorDims(output)); +TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Equal, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type %d, requires float|int", + input1->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +// TODO(renjieliu): Refactor the logic to avoid duplications. +TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, NotEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type %d, requires float|int", + input1->type); + return kTfLiteError; + } + return kTfLiteOk; +} + TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); @@ -85,7 +137,8 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Does not support type other than float|int"); + "Does not support type %d, requires float|int", + input1->type); return kTfLiteError; } return kTfLiteOk; @@ -109,7 +162,8 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Does not support type other than float|int"); + "Does not support type %d, requires float|int", + input1->type); return kTfLiteError; } return kTfLiteOk; @@ -133,7 +187,8 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Does not support type other than float|int"); + "Does not support type %d, requires float|int", + input1->type); return kTfLiteError; } return kTfLiteOk; @@ -157,14 +212,29 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Does not support type other than float|int"); + "Does not support type %d, requires float|int", + input1->type); return kTfLiteError; } return kTfLiteOk; } +} // namespace } // namespace comparisons +TfLiteRegistration* Register_EQUAL() { + static TfLiteRegistration r = { + nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval}; + return &r; +} + +TfLiteRegistration* Register_NOT_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::NotEqualEval}; + return &r; +} + TfLiteRegistration* Register_GREATER() { static TfLiteRegistration r = {nullptr, nullptr, comparisons::ComparisonPrepare, diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index 835d238d36d1757a27119ae24b3c07232e9d3dc0..bb02e1c812fdc40bf515f1f978e9e39b5a16a4ea 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -21,18 +21,17 @@ limitations under the License. namespace tflite { namespace { -using ::testing::ElementsAreArray; +using ::testing::ElementsAre; -class GreaterOpModel : public SingleOpModel { +class ComparisonOpModel : public SingleOpModel { public: - GreaterOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { + ComparisonOpModel(std::initializer_list input1_shape, + std::initializer_list input2_shape, + TensorType input_type, BuiltinOperator op) { input1_ = AddInput(input_type); input2_ = AddInput(input_type); output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions, - CreateGreaterOptions(builder_).Union()); + ConfigureBuiltinOp(op); BuildInterpreter({input1_shape, input2_shape}); } @@ -46,245 +45,313 @@ class GreaterOpModel : public SingleOpModel { int input1_; int input2_; int output_; + + void ConfigureBuiltinOp(BuiltinOperator op) { + switch (op) { + case BuiltinOperator_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_EqualOptions, + CreateEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_NOT_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_NotEqualOptions, + CreateNotEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_GREATER: { + SetBuiltinOp(op, BuiltinOptions_GreaterOptions, + CreateGreaterOptions(builder_).Union()); + break; + } + case BuiltinOperator_GREATER_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_GreaterEqualOptions, + CreateGreaterEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_LESS: { + SetBuiltinOp(op, BuiltinOptions_LessOptions, + CreateLessOptions(builder_).Union()); + break; + } + case BuiltinOperator_LESS_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_LessEqualOptions, + CreateLessEqualOptions(builder_).Union()); + break; + } + default: { FAIL() << "We shouldn't get here."; } + } + } }; -TEST(ComparisonsTest, GreaterFloat) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); +TEST(ComparisonsTest, EqualFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterInt) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); +TEST(ComparisonsTest, EqualInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterBroadcast) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); +TEST(ComparisonsTest, EqualBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterBroadcastTwoD) { - GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); +TEST(ComparisonsTest, EqualBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, - false, true, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, false, false, + false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class GreaterEqualOpModel : public SingleOpModel { - public: - GreaterEqualOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_GREATER_EQUAL, - BuiltinOptions_GreaterEqualOptions, - CreateGreaterEqualOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } +TEST(ComparisonsTest, NotEqualFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); - int input1() { return input1_; } - int input2() { return input2_; } + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } +TEST(ComparisonsTest, NotEqualInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {1, 2, 7, 5}); + model.Invoke(); - private: - int input1_; - int input2_; - int output_; -}; + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, NotEqualBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, NotEqualBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, true, true, true, true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); +} + +TEST(ComparisonsTest, GreaterFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); +} TEST(ComparisonsTest, GreaterEqualFloat) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualInt) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualBroadcast) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) { - GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, - false, true, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, true, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class LessOpModel : public SingleOpModel { - public: - LessOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions, - CreateLessOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } - - int input1() { return input1_; } - int input2() { return input2_; } - - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } - - private: - int input1_; - int input2_; - int output_; -}; TEST(ComparisonsTest, LessFloat) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessInt) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 6, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessBroadcast) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessBroadcastTwoD) { - LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, - true, false, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class LessEqualOpModel : public SingleOpModel { - public: - LessEqualOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions, - CreateLessEqualOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } - - int input1() { return input1_; } - int input2() { return input2_; } - - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } - - private: - int input1_; - int input2_; - int output_; -}; - TEST(ComparisonsTest, LessEqualFloat) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualInt) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualBroadcast) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualBroadcastTwoD) { - LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, - true, false, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 3b467b3aa284586ab8e67ede55583adffbe06cc7..0321b2e2a0088bdb09b2c3c61827be8064fe939b 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -134,7 +134,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). data->need_im2col = (params->stride_width != 1 || params->stride_height != 1 || - filter_width != 1 || filter_height != 1); + params->dilation_width_factor != 1 || + params->dilation_height_factor != 1 || filter_width != 1 || + filter_height != 1); // If we're using the optimized multithreaded EigenTensor implementation of // convolution, it expects the filter weights to be transposed compared to // the normal TF Lite buffer format. Typical TF Lite weights are @@ -177,9 +179,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node)); - bool hasBias = node->inputs->size == 3; + bool has_bias = node->inputs->size == 3; // Check number of inputs/outputs - TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2); + TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; @@ -202,9 +204,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(ahentz): At this point the optimized versions require 'bias'. We can // either change that or document that convolution requires it. - TF_LITE_ENSURE(context, hasBias); + TF_LITE_ENSURE(context, has_bias); - if (hasBias) { + if (has_bias) { bias = &context->tensors[node->inputs->data[2]]; if (data_type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); @@ -212,8 +214,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } else { TF_LITE_ENSURE_EQ(context, bias->type, data_type); } - TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); - TF_LITE_ENSURE_EQ(context, bias->dims->data[0], filter->dims->data[0]); + TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0)); } int channels_out = filter->dims->data[0]; @@ -225,29 +226,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Matching GetWindowedOutputSize in TensorFlow. auto padding = params->padding; - auto computeOutSize = [padding](int imageSize, int filterSize, int stride, - int dilationRate) -> int { - int effectiveFilterSize = (filterSize - 1) * dilationRate + 1; + auto compute_out_size = [padding](int image_size, int filter_size, int stride, + int dilation_rate) -> int { + int effective_filter_size = (filter_size - 1) * dilation_rate + 1; return padding == kTfLitePaddingSame - ? (imageSize + stride - 1) / stride + ? (image_size + stride - 1) / stride : padding == kTfLitePaddingValid - ? (imageSize - effectiveFilterSize + stride) / stride + ? (image_size - effective_filter_size + stride) / stride : 0; }; - int outWidth = computeOutSize(width, filter_width, params->stride_width, - params->dilation_width_factor); - int outHeight = computeOutSize(height, filter_height, params->stride_height, - params->dilation_height_factor); + int out_width = compute_out_size(width, filter_width, params->stride_width, + params->dilation_width_factor); + int out_height = + compute_out_size(height, filter_height, params->stride_height, + params->dilation_height_factor); data->padding.height = ComputePadding(params->stride_height, params->dilation_height_factor, - height, filter_height, outHeight); + height, filter_height, out_height); data->padding.width = ComputePadding(params->stride_width, params->dilation_width_factor, width, - filter_width, outWidth); + filter_width, out_width); - TF_LITE_ENSURE(context, hasBias); + TF_LITE_ENSURE(context, has_bias); // Note that quantized inference requires that all tensors have their // parameters set. This is usually done during quantized training. @@ -255,8 +257,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); + TF_LITE_ENSURE(context, real_multiplier < 1.0); + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); @@ -264,8 +268,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); output_size->data[0] = batches; - output_size->data[1] = outHeight; - output_size->data[2] = outWidth; + output_size->data[1] = out_height; + output_size->data[2] = out_width; output_size->data[3] = channels_out; auto output_status = context->ResizeTensor(context, output, output_size); @@ -305,18 +309,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* hwcn_weights = &context->tensors[node->temporaries->data[data->hwcn_weights_index]]; hwcn_weights->type = data_type; - hwcn_weights->allocation_type = kTfLiteDynamic; - // Make sure we release any previous allocations before we reallocate. - // TODO(petewarden): Persistent arenas would be a better fit for this, but - // they aren't fully implemented yet. - if (hwcn_weights->data.raw) { - free(hwcn_weights->data.raw); - hwcn_weights->data.raw = nullptr; - } + hwcn_weights->allocation_type = kTfLiteArenaRwPersistent; - // Note that hwcn_weights_status is a kTfLiteDynamic tensor, and - // ResizeTensor will actually allocate space for it. The would be more - // efficient if we placed hwcn_weights_status in the persistent arena. auto hwcn_weights_status = context->ResizeTensor(context, hwcn_weights, hwcn_weights_size); if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status; @@ -378,8 +372,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); KernelType effective_kernel_type; if (((kernel_type == kMultithreadOptimized) || (kernel_type == kCblasOptimized)) && @@ -455,9 +449,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; - bool hasBias = node->inputs->size == 3; + bool has_bias = node->inputs->size == 3; TfLiteTensor* bias = - hasBias ? &context->tensors[node->inputs->data[2]] : nullptr; + has_bias ? &context->tensors[node->inputs->data[2]] : nullptr; TfLiteTensor* im2col = data->need_im2col ? &context->tensors[node->temporaries->data[data->im2col_index]] @@ -489,7 +483,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bias, im2col, hwcn_weights, output); break; default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc index 3ad8d7d4e10cb814db5f74fa70eef798a4e863ca..16e5f1d065d8ea6d187c5e368d6c9385fe62514b 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -151,8 +151,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); @@ -172,8 +173,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); void (*depthwise_conv)(const float*, const Dims<4>&, const float*, const Dims<4>&, const float*, const Dims<4>&, int, int, @@ -247,7 +248,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bias, output); break; default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc index 1439c8bce14ad127ed68dc54991aed8b8bb39383..c00cafb9fbfaf53d4dc301ccd3f21a6c6fd892e6 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -47,12 +47,6 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel { } output_ = AddOutput(output); - if (input.type != TensorType_FLOAT32) { - // The following is required by quantized inference. It is the unittest's - // responsibility to make sure the output scale falls into the correct - // range. - CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); - } int input_depth = GetShape(input_)[3]; int output_depth = GetShape(filter_)[3]; @@ -176,6 +170,43 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) { })); } +TEST(QuantizedDepthwiseConvolutionOpTest, + SimpleTestQuantizedFilterMultiplierGreaterThan1) { + QuantizedDepthwiseConvolutionOpModel quant_op( + {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128}, + {TensorType_UINT8, {}, -127, 128}); + DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + std::initializer_list input = { + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }; + std::initializer_list filter = { + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }; + std::initializer_list bias = {1, 2, 3, 4}; + + quant_op.SetInput(input); + quant_op.SetFilter(filter); + quant_op.SetBias(bias); + quant_op.Invoke(); + + float_op.SetInput(input); + float_op.SetFilter(filter); + float_op.SetBias(bias); + float_op.Invoke(); + + EXPECT_THAT(quant_op.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc new file mode 100644 index 0000000000000000000000000000000000000000..0c532cac5a9f59c8b09ff9aefc294e243561f027 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc @@ -0,0 +1,591 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace detection_postprocess { + +// Input tensors +constexpr int kInputTensorBoxEncodings = 0; +constexpr int kInputTensorClassPredictions = 1; +constexpr int kInputTensorAnchors = 2; + +// Output tensors +constexpr int kOutputTensorDetectionBoxes = 0; +constexpr int kOutputTensorDetectionClasses = 1; +constexpr int kOutputTensorDetectionScores = 2; +constexpr int kOutputTensorNumDetections = 3; + +constexpr size_t kNumCoordBox = 4; +constexpr size_t kBatchSize = 1; + +// Object Detection model produces axis-aligned boxes in two formats: +// BoxCorner represents the upper right (xmin, ymin) and +// lower left corner (xmax, ymax). +// CenterSize represents the center (xcenter, ycenter), height and width. +// BoxCornerEncoding and CenterSizeEncoding are related as follows: +// ycenter = y / y_scale * anchor.h + anchor.y; +// xcenter = x / x_scale * anchor.w + anchor.x; +// half_h = 0.5*exp(h/ h_scale)) * anchor.h; +// half_w = 0.5*exp(w / w_scale)) * anchor.w; +// ymin = ycenter - half_h +// ymax = ycenter + half_h +// xmin = xcenter - half_w +// xmax = xcenter + half_w +struct BoxCornerEncoding { + float ymin; + float xmin; + float ymax; + float xmax; +}; + +struct CenterSizeEncoding { + float y; + float x; + float h; + float w; +}; +// We make sure that the memory allocations are contiguous with static assert. +static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox, + "Size of BoxCornerEncoding is 4 float values"); +static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox, + "Size of CenterSizeEncoding is 4 float values"); + +struct OpData { + int max_detections; + int max_classes_per_detection; + float non_max_suppression_score_threshold; + float intersection_over_union_threshold; + int num_classes; + CenterSizeEncoding scale_values; + // Indices of Temporary tensors + int decoded_boxes_index; + int scores_index; + int active_candidate_index; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + op_data->max_detections = m["max_detections"].AsInt32(); + op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32(); + op_data->non_max_suppression_score_threshold = + m["nms_score_threshold"].AsFloat(); + op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat(); + op_data->num_classes = m["num_classes"].AsInt32(); + op_data->scale_values.y = m["y_scale"].AsFloat(); + op_data->scale_values.x = m["x_scale"].AsFloat(); + op_data->scale_values.h = m["h_scale"].AsFloat(); + op_data->scale_values.w = m["w_scale"].AsFloat(); + context->AddTensors(context, 1, &op_data->decoded_boxes_index); + context->AddTensors(context, 1, &op_data->scores_index); + context->AddTensors(context, 1, &op_data->active_candidate_index); + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +// TODO(chowdhery): Add to kernel_util.h +TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor, + std::initializer_list values) { + TfLiteIntArray* size = TfLiteIntArrayCreate(values.size()); + int index = 0; + for (int v : values) { + size->data[index] = v; + ++index; + } + return context->ResizeTensor(context, tensor, size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* op_data = reinterpret_cast(node->user_data); + // Inputs: box_encodings, scores, anchors + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* input_class_predictions = + GetInput(context, node, kInputTensorClassPredictions); + const TfLiteTensor* input_anchors = + GetInput(context, node, kInputTensorAnchors); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2); + // number of detected boxes + const int num_detected_boxes = + op_data->max_detections * op_data->max_classes_per_detection; + + // Outputs: detection_boxes, detection_scores, detection_classes, + // num_detections + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4); + // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4) + TfLiteTensor* detection_boxes = + GetOutput(context, node, kOutputTensorDetectionBoxes); + detection_boxes->type = kTfLiteFloat32; + SetTensorSizes(context, detection_boxes, + {kBatchSize, num_detected_boxes, kNumCoordBox}); + + // Output Tensor detection_classes: size is set to (1, num_detected_boxes) + TfLiteTensor* detection_classes = + GetOutput(context, node, kOutputTensorDetectionClasses); + detection_classes->type = kTfLiteFloat32; + SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes}); + + // Output Tensor detection_scores: size is set to (1, num_detected_boxes) + TfLiteTensor* detection_scores = + GetOutput(context, node, kOutputTensorDetectionScores); + detection_scores->type = kTfLiteFloat32; + SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes}); + + // Output Tensor num_detections: size is set to 1 + TfLiteTensor* num_detections = + GetOutput(context, node, kOutputTensorNumDetections); + num_detections->type = kTfLiteFloat32; + // TODO (chowdhery): Make it a scalar when available + SetTensorSizes(context, num_detections, {1}); + + // Temporary tensors + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(3); + node->temporaries->data[0] = op_data->decoded_boxes_index; + node->temporaries->data[1] = op_data->scores_index; + node->temporaries->data[2] = op_data->active_candidate_index; + + // decoded_boxes + TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index]; + decoded_boxes->type = kTfLiteFloat32; + decoded_boxes->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, decoded_boxes, + {input_box_encodings->dims->data[1], kNumCoordBox}); + + // scores + TfLiteTensor* scores = &context->tensors[op_data->scores_index]; + scores->type = kTfLiteFloat32; + scores->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, scores, + {input_class_predictions->dims->data[1], + input_class_predictions->dims->data[2]}); + + // active_candidate + TfLiteTensor* active_candidate = + &context->tensors[op_data->active_candidate_index]; + active_candidate->type = kTfLiteUInt8; + active_candidate->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, active_candidate, + {input_box_encodings->dims->data[1]}); + + return kTfLiteOk; +} + +class Dequantizer { + public: + Dequantizer(int zero_point, float scale) + : zero_point_(zero_point), scale_(scale) {} + float operator()(uint8 x) { + return (static_cast(x) - zero_point_) * scale_; + } + + private: + int zero_point_; + float scale_; +}; + +void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx, + float quant_zero_point, float quant_scale, + CenterSizeEncoding* box_centersize) { + const uint8* boxes = + GetTensorData(input_box_encodings) + kNumCoordBox * idx; + Dequantizer dequantize(quant_zero_point, quant_scale); + box_centersize->y = dequantize(boxes[0]); + box_centersize->x = dequantize(boxes[1]); + box_centersize->h = dequantize(boxes[2]); + box_centersize->w = dequantize(boxes[3]); +} + +template +T ReInterpretTensor(const TfLiteTensor* tensor) { + // TODO (chowdhery): check float + const float* tensor_base = tensor->data.f; + return reinterpret_cast(tensor_base); +} + +template +T ReInterpretTensor(TfLiteTensor* tensor) { + // TODO (chowdhery): check float + float* tensor_base = tensor->data.f; + return reinterpret_cast(tensor_base); +} + +TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, + OpData* op_data) { + // Parse input tensor boxencodings + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize); + const int num_boxes = input_box_encodings->dims->data[1]; + TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[2], kNumCoordBox); + const TfLiteTensor* input_anchors = + GetInput(context, node, kInputTensorAnchors); + + // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors + CenterSizeEncoding box_centersize; + CenterSizeEncoding scale_values = op_data->scale_values; + CenterSizeEncoding anchor; + for (int idx = 0; idx < num_boxes; ++idx) { + switch (input_box_encodings->type) { + // Quantized + case kTfLiteUInt8: + DequantizeBoxEncodings( + input_box_encodings, idx, + static_cast(input_box_encodings->params.zero_point), + static_cast(input_box_encodings->params.scale), + &box_centersize); + DequantizeBoxEncodings( + input_anchors, idx, + static_cast(input_anchors->params.zero_point), + static_cast(input_anchors->params.scale), &anchor); + break; + // Float + case kTfLiteFloat32: + box_centersize = ReInterpretTensor( + input_box_encodings)[idx]; + anchor = + ReInterpretTensor(input_anchors)[idx]; + break; + default: + // Unsupported type. + return kTfLiteError; + } + + float ycenter = box_centersize.y / scale_values.y * anchor.h + anchor.y; + float xcenter = box_centersize.x / scale_values.x * anchor.w + anchor.x; + float half_h = + 0.5f * static_cast(std::exp(box_centersize.h / scale_values.h)) * + anchor.h; + float half_w = + 0.5f * static_cast(std::exp(box_centersize.w / scale_values.w)) * + anchor.w; + TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + auto& box = ReInterpretTensor(decoded_boxes)[idx]; + box.ymin = ycenter - half_h; + box.xmin = xcenter - half_w; + box.ymax = ycenter + half_h; + box.xmax = xcenter + half_w; + } + return kTfLiteOk; +} + +void DecreasingPartialArgSort(const float* values, int num_values, + int num_to_sort, int* indices) { + std::iota(indices, indices + num_values, 0); + std::partial_sort( + indices, indices + num_to_sort, indices + num_values, + [&values](const int i, const int j) { return values[i] > values[j]; }); +} + +void SelectDetectionsAboveScoreThreshold(const std::vector& values, + const float threshold, + std::vector* keep_values, + std::vector* keep_indices) { + for (int i = 0; i < values.size(); i++) { + if (values[i] >= threshold) { + keep_values->emplace_back(values[i]); + keep_indices->emplace_back(i); + } + } +} + +bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) { + for (int i = 0; i < num_boxes; ++i) { + // ymax>=ymin, xmax>=xmin + auto& box = ReInterpretTensor(decoded_boxes)[i]; + if (box.ymin >= box.ymax || box.xmin >= box.xmax) { + return false; + } + } + return true; +} + +float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes, + const int i, const int j) { + auto& box_i = ReInterpretTensor(decoded_boxes)[i]; + auto& box_j = ReInterpretTensor(decoded_boxes)[j]; + const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin); + const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin); + if (area_i <= 0 || area_j <= 0) return 0.0; + const float intersection_ymin = std::max(box_i.ymin, box_j.ymin); + const float intersection_xmin = std::max(box_i.xmin, box_j.xmin); + const float intersection_ymax = std::min(box_i.ymax, box_j.ymax); + const float intersection_xmax = std::min(box_i.xmax, box_j.xmax); + const float intersection_area = + std::max(intersection_ymax - intersection_ymin, 0.0) * + std::max(intersection_xmax - intersection_xmin, 0.0); + return intersection_area / (area_i + area_j - intersection_area); +} + +// NonMaxSuppressionSingleClass() is O(n^2) pairwise comparison between boxes +// It assumes all boxes are good in beginning and sorts based on the scores. +// If lower-scoring box has too much overlap with a higher-scoring box, +// we get rid of the lower-scoring box. +TfLiteStatus NonMaxSuppressionSingleClassHelper( + TfLiteContext* context, TfLiteNode* node, OpData* op_data, + const std::vector& scores, std::vector* selected) { + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + const int num_boxes = input_box_encodings->dims->data[1]; + const int max_detections = op_data->max_detections; + const float non_max_suppression_score_threshold = + op_data->non_max_suppression_score_threshold; + const float intersection_over_union_threshold = + op_data->intersection_over_union_threshold; + // Maximum detections should be positive. + TF_LITE_ENSURE(context, (max_detections >= 0)); + // intersection_over_union_threshold should be positive + // and should be less than 1. + TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) && + (intersection_over_union_threshold <= 1.0f)); + // Validate boxes + TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes)); + + // threshold scores + std::vector keep_indices; + // TODO (chowdhery): Remove the dynamic allocation and replace it + // with temporaries, esp for std::vector + std::vector keep_scores; + SelectDetectionsAboveScoreThreshold( + scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices); + + int num_scores_kept = keep_scores.size(); + std::vector sorted_indices; + sorted_indices.resize(num_scores_kept); + DecreasingPartialArgSort(keep_scores.data(), num_scores_kept, num_scores_kept, + sorted_indices.data()); + + const int num_boxes_kept = num_scores_kept; + const int output_size = std::min(num_boxes_kept, max_detections); + selected->clear(); + TfLiteTensor* active_candidate = + &context->tensors[op_data->active_candidate_index]; + TF_LITE_ENSURE(context, (active_candidate->dims->data[0]) == num_boxes); + int num_active_candidate = num_boxes_kept; + uint8_t* active_box_candidate = (active_candidate->data.uint8); + for (int row = 0; row < num_boxes_kept; row++) { + active_box_candidate[row] = 1; + } + + for (int i = 0; i < num_boxes_kept; ++i) { + if (num_active_candidate == 0 || selected->size() >= output_size) break; + if (active_box_candidate[i] == 1) { + selected->push_back(keep_indices[sorted_indices[i]]); + active_box_candidate[i] = 0; + num_active_candidate--; + } else { + continue; + } + for (int j = i + 1; j < num_boxes_kept; ++j) { + if (active_box_candidate[j] == 1) { + float intersection_over_union = ComputeIntersectionOverUnion( + decoded_boxes, keep_indices[sorted_indices[i]], + keep_indices[sorted_indices[j]]); + + if (intersection_over_union > intersection_over_union_threshold) { + active_box_candidate[j] = 0; + num_active_candidate--; + } + } + } + } + return kTfLiteOk; +} + +// This function implements a fast version of Non Maximal Suppression for +// multiple classes where +// 1) we keep the top-k scores for each anchor and +// 2) during NMS, each anchor only uses the highest class score for sorting. +// 3) Compared to standard NMS, the worst runtime of this version is O(N^2) +// instead of O(KN^2) where N is the number of anchors and K the number of +// classes. +TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, + TfLiteNode* node, + OpData* op_data, + const float* scores) { + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + + TfLiteTensor* detection_boxes = + GetOutput(context, node, kOutputTensorDetectionBoxes); + TfLiteTensor* detection_classes = + GetOutput(context, node, kOutputTensorDetectionClasses); + TfLiteTensor* detection_scores = + GetOutput(context, node, kOutputTensorDetectionScores); + TfLiteTensor* num_detections = + GetOutput(context, node, kOutputTensorNumDetections); + + const int num_boxes = input_box_encodings->dims->data[1]; + const int num_classes = op_data->num_classes; + const int max_categories_per_anchor = op_data->max_classes_per_detection; + // The row index offset is 1 if background class is included and 0 otherwise. + const int label_offset = 1; + TF_LITE_ENSURE(context, (label_offset != -1)); + TF_LITE_ENSURE(context, (max_categories_per_anchor > 0)); + const int num_classes_with_background = num_classes + label_offset; + const int num_categories_per_anchor = + std::min(max_categories_per_anchor, num_classes); + std::vector max_scores; + max_scores.resize(num_boxes); + std::vector sorted_class_indices; + sorted_class_indices.resize(num_boxes * num_classes); + for (int row = 0; row < num_boxes; row++) { + const float* box_scores = + scores + row * num_classes_with_background + label_offset; + int* class_indices = sorted_class_indices.data() + row * num_classes; + DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor, + class_indices); + max_scores[row] = box_scores[class_indices[0]]; + } + // Perform non-maximal suppression on max scores + std::vector selected; + NonMaxSuppressionSingleClassHelper(context, node, op_data, max_scores, + &selected); + // Allocate output tensors + int output_box_index = 0; + for (const auto& selected_index : selected) { + const float* box_scores = + scores + selected_index * num_classes_with_background + label_offset; + const int* class_indices = + sorted_class_indices.data() + selected_index * num_classes; + + for (int col = 0; col < num_categories_per_anchor; ++col) { + int box_offset = num_categories_per_anchor * output_box_index + col; + // detection_boxes + ReInterpretTensor(detection_boxes)[box_offset] = + ReInterpretTensor( + decoded_boxes)[selected_index]; + // detection_classes + detection_classes->data.f[box_offset] = class_indices[col]; + // detection_scores + detection_scores->data.f[box_offset] = box_scores[class_indices[col]]; + output_box_index++; + } + } + num_detections->data.f[0] = output_box_index; + return kTfLiteOk; +} + +void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions, + const int num_boxes, + const int num_classes_with_background, + const TfLiteTensor* scores) { + float quant_zero_point = + static_cast(input_class_predictions->params.zero_point); + float quant_scale = static_cast(input_class_predictions->params.scale); + Dequantizer dequantize(quant_zero_point, quant_scale); + const uint8* scores_quant = GetTensorData(input_class_predictions); + for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) { + scores->data.f[idx] = dequantize(scores_quant[idx]); + } +} + +TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, + TfLiteNode* node, OpData* op_data) { + // Get the input tensors + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* input_class_predictions = + GetInput(context, node, kInputTensorClassPredictions); + const int num_boxes = input_box_encodings->dims->data[1]; + const int num_classes = op_data->num_classes; + TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0], + kBatchSize); + TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes); + const int num_classes_with_background = + input_class_predictions->dims->data[2]; + + TF_LITE_ENSURE(context, (num_classes_with_background == num_classes + 1)); + + const TfLiteTensor* scores; + switch (input_class_predictions->type) { + case kTfLiteUInt8: { + TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index]; + DequantizeClassPredictions(input_class_predictions, num_boxes, + num_classes_with_background, temporary_scores); + scores = temporary_scores; + } break; + case kTfLiteFloat32: + scores = input_class_predictions; + break; + default: + // Unsupported type. + return kTfLiteError; + } + NonMaxSuppressionMultiClassFastHelper(context, node, op_data, + GetTensorData(scores)); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + // TODO(chowdhery): Generalize for any batch size + TF_LITE_ENSURE(context, (kBatchSize == 1)); + auto* op_data = reinterpret_cast(node->user_data); + // These two functions correspond to two blocks in the Object Detection model. + // In future, we would like to break the custom op in two blocks, which is + // currently not feasible because we would like to input quantized inputs + // and do all calculations in float. Mixed quantized/float calculations are + // currently not supported in TFLite. + + // This fills in temporary decoded_boxes + // by transforming input_box_encodings and input_anchors from + // CenterSizeEncodings to BoxCornerEncoding + DecodeCenterSizeBoxes(context, node, op_data); + // This fills in the output tensors + // by choosing effective set of decoded boxes + // based on Non Maximal Suppression, i.e. selecting + // highest scoring non-overlapping boxes. + NonMaxSuppressionMultiClass(context, node, op_data); + + return kTfLiteOk; +} +} // namespace detection_postprocess + +TfLiteRegistration* Register_DETECTION_POSTPROCESS() { + static TfLiteRegistration r = {detection_postprocess::Init, + detection_postprocess::Free, + detection_postprocess::Prepare, + detection_postprocess::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e0f8484a328d7d1668afd096ad3d08204fbb4a1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc @@ -0,0 +1,235 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class BaseDetectionPostprocessOpModel : public SingleOpModel { + public: + BaseDetectionPostprocessOpModel(const TensorData& input1, + const TensorData& input2, + const TensorData& input3, + const TensorData& output1, + const TensorData& output2, + const TensorData& output3, + const TensorData& output4) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + input3_ = AddInput(input3); + output1_ = AddOutput(output1); + output2_ = AddOutput(output2); + output3_ = AddOutput(output3); + output4_ = AddOutput(output4); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("max_detections", 3); + fbb.Int("max_classes_per_detection", 1); + fbb.Float("nms_score_threshold", 0.0); + fbb.Float("nms_iou_threshold", 0.5); + fbb.Int("num_classes", 2); + fbb.Float("y_scale", 10.0); + fbb.Float("x_scale", 10.0); + fbb.Float("h_scale", 5.0); + fbb.Float("w_scale", 5.0); + }); + fbb.Finish(); + SetCustomOp("TFLite_Detection_PostProcess", fbb.GetBuffer(), + Register_DETECTION_POSTPROCESS); + BuildInterpreter({GetShape(input1_), GetShape(input2_), GetShape(input3_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + int input3() { return input3_; } + + template + void SetInput1(std::initializer_list data) { + PopulateTensor(input1_, data); + } + + template + void SetInput2(std::initializer_list data) { + PopulateTensor(input2_, data); + } + + template + void SetInput3(std::initializer_list data) { + PopulateTensor(input3_, data); + } + + template + std::vector GetOutput1() { + return ExtractVector(output1_); + } + + template + std::vector GetOutput2() { + return ExtractVector(output2_); + } + + template + std::vector GetOutput3() { + return ExtractVector(output3_); + } + + template + std::vector GetOutput4() { + return ExtractVector(output4_); + } + + std::vector GetOutputShape1() { return GetTensorShape(output1_); } + std::vector GetOutputShape2() { return GetTensorShape(output2_); } + std::vector GetOutputShape3() { return GetTensorShape(output3_); } + std::vector GetOutputShape4() { return GetTensorShape(output4_); } + + protected: + int input1_; + int input2_; + int input3_; + int output1_; + int output2_; + int output3_; + int output4_; +}; + +TEST(DetectionPostprocessOpTest, FloatTest) { + BaseDetectionPostprocessOpModel m( + {TensorType_FLOAT32, {1, 6, 4}}, {TensorType_FLOAT32, {1, 6, 3}}, + {TensorType_FLOAT32, {6, 4}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}); + + // six boxes in center-size encoding + m.SetInput1({0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}); + // class scores - two classes with background + m.SetInput2({0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., + .5, .4, 0., .3, .2}); + // six anchors in center-size encoding + m.SetInput3({0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, + 0.5, 0.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, + 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}); + // Same boxes in box-corner encoding: + // { 0.0, 0.0, 1.0, 1.0, + // 0.0, 0.1, 1.0, 1.1, + // 0.0, -0.1, 1.0, 0.9, + // 0.0, 10.0, 1.0, 11.0, + // 0.0, 10.1, 1.0, 11.1, + // 0.0, 100.0, 1.0, 101.0} + m.Invoke(); + // detection_boxes + // in center-size + std::vector output_shape1 = m.GetOutputShape1(); + EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4)); + EXPECT_THAT( + m.GetOutput1(), + ElementsAreArray(ArrayFloatNear( + {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0}, + 1e-1))); + // detection_classes + std::vector output_shape2 = m.GetOutputShape2(); + EXPECT_THAT(output_shape2, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1))); + // detection_scores + std::vector output_shape3 = m.GetOutputShape3(); + EXPECT_THAT(output_shape3, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1))); + // num_detections + std::vector output_shape4 = m.GetOutputShape4(); + EXPECT_THAT(output_shape4, ElementsAre(1)); + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({3.0}, 1e-1))); +} + +TEST(DetectionPostprocessOpTest, QuantizedTest) { + BaseDetectionPostprocessOpModel m( + {TensorType_UINT8, {1, 6, 4}, -1.0, 1.0}, + {TensorType_UINT8, {1, 6, 3}, 0.0, 1.0}, + {TensorType_UINT8, {6, 4}, 0.0, 100.5}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}); + // six boxes in center-size encoding + std::vector> inputs1 = { + {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}; + m.QuantizeAndPopulate(m.input1(), inputs1[0]); + // class scores - two classes with background + std::vector> inputs2 = { + {0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., .5, .4, 0., .3, + .2}}; + m.QuantizeAndPopulate(m.input2(), inputs2[0]); + // six anchors in center-size encoding + std::vector> inputs3 = { + {0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, + 0.5, 10.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}}; + m.QuantizeAndPopulate(m.input3(), inputs3[0]); + m.Invoke(); + // detection_boxes + // in center-size + std::vector output_shape1 = m.GetOutputShape1(); + EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4)); + EXPECT_THAT( + m.GetOutput1(), + ElementsAreArray(ArrayFloatNear( + {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0}, + 3e-1))); + // detection_classes + std::vector output_shape2 = m.GetOutputShape2(); + EXPECT_THAT(output_shape2, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1))); + // detection_scores + std::vector output_shape3 = m.GetOutputShape3(); + EXPECT_THAT(output_shape3, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1))); + // num_detections + std::vector output_shape4 = m.GetOutputShape4(); + EXPECT_THAT(output_shape4, ElementsAre(1)); + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({3.0}, 1e-1))); +} +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc index e52e4fe535c4e343ed29e22cfb6de0d9dc4b0cd7..bc5c3783fd63451fd6d600df2d8e93f740c68e95 100644 --- a/tensorflow/contrib/lite/kernels/div.cc +++ b/tensorflow/contrib/lite/kernels/div.cc @@ -83,8 +83,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); #define TF_LITE_DIV(type, opname) \ type::opname(GetTensorData(input1), GetTensorDims(input1), \ GetTensorData(input2), GetTensorDims(input2), \ @@ -118,8 +118,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { EvalFloat(context, node, params, data, input1, input2, output); } else { - context->ReportError(context, - "Div only supports FLOAT32 and quantized UINT8 now."); + context->ReportError( + context, "Div only supports FLOAT32 and quantized UINT8 now, got %d.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc index f1fdb42624073717fb70423ff70dfad08e578ca6..94927cb53df8033e55e647e19fb19afd7def788f 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.cc +++ b/tensorflow/contrib/lite/kernels/eigen_support.cc @@ -19,26 +19,41 @@ limitations under the License. namespace tflite { namespace eigen_support { +namespace { -struct RefCountedEigenContext { +struct RefCountedEigenContext : public TfLiteExternalContext { int num_references = 0; }; +RefCountedEigenContext* GetEigenContext(TfLiteContext* context) { + return reinterpret_cast( + context->GetExternalContext(context, kTfLiteEigenContext)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + Eigen::setNbThreads(context->recommended_num_threads); + return kTfLiteOk; +} + +} // namespace + void IncrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { if (context->recommended_num_threads != -1) { Eigen::setNbThreads(context->recommended_num_threads); } ptr = new RefCountedEigenContext; + ptr->type = kTfLiteEigenContext; + ptr->Refresh = Refresh; ptr->num_references = 0; - context->eigen_context = ptr; + context->SetExternalContext(context, kTfLiteEigenContext, ptr); } ptr->num_references++; } void DecrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to DecrementUsageCounter() not preceded by " @@ -46,15 +61,9 @@ void DecrementUsageCounter(TfLiteContext* context) { } if (--ptr->num_references == 0) { delete ptr; - context->eigen_context = nullptr; + context->SetExternalContext(context, kTfLiteEigenContext, nullptr); } } -void SetNumThreads(TfLiteContext* context, int num_threads) { - IncrementUsageCounter(context); - Eigen::setNbThreads(num_threads); - DecrementUsageCounter(context); -} - } // namespace eigen_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h index aa8c351fd8e8dae45f7d4807ce24d80bb393c41c..d47e691123282a8a8cc53c29be1d95af037e3939 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.h +++ b/tensorflow/contrib/lite/kernels/eigen_support.h @@ -28,9 +28,6 @@ void IncrementUsageCounter(TfLiteContext* context); // usages all temporary Eigen objects will be deleted. void DecrementUsageCounter(TfLiteContext* context); -// Set the number of threads that can be used by Eigen. -void SetNumThreads(TfLiteContext* context, int num_threads); - } // namespace eigen_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index b719a0839435e336890ed34a4461f5204b169bd5..59bab3c4ecd20bf938919ca606a5933f3112f233 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -23,7 +23,7 @@ namespace ops { namespace builtin { namespace elementwise { -TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); @@ -35,7 +35,8 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayCopy(input->dims)); } -TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { +inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, + float float_func(float)) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { @@ -44,24 +45,59 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { const float* in = GetTensorData(input); const float* in_end = in + elements; float* out = output->data.f; - for (; in < in_end; in++, out++) *out = std::sin(*in); + for (; in < in_end; in++, out++) *out = float_func(*in); return kTfLiteOk; } default: { - context->ReportError(context, "Only float32 is supported currently"); + context->ReportError(context, "Input type is %d, requires float32", + input->type); return kTfLiteError; } } } +TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::sin); +} + +TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::log); +} + +TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::sqrt); +} + +TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, [](float f) { return 1.f / std::sqrt(f); }); +} + } // namespace elementwise TfLiteRegistration* Register_SIN() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare, + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, elementwise::SinEval}; return &r; } +TfLiteRegistration* Register_LOG() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::LogEval}; + return &r; +} + +TfLiteRegistration* Register_SQRT() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::SqrtEval}; + return &r; +} + +TfLiteRegistration* Register_RSQRT() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::RsqrtEval}; + return &r; +} + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index 412ffb04b90fbc24d232d25d2a86ce639752c3e8..ce4c602ee5c788d67701af3ecd3e023f2b25aae7 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -24,12 +24,13 @@ namespace { using ::testing::ElementsAreArray; -class SinOpModel : public SingleOpModel { +class ElementWiseOpModel : public SingleOpModel { public: - SinOpModel(std::initializer_list input_shape) { + ElementWiseOpModel(BuiltinOperator op, + std::initializer_list input_shape) { input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_SIN, BuiltinOptions_NONE, 0); + SetBuiltinOp(op, BuiltinOptions_NONE, 0); BuildInterpreter({input_shape}); } @@ -42,7 +43,7 @@ class SinOpModel : public SingleOpModel { }; TEST(ElementWise, Sin) { - SinOpModel m({1, 1, 4, 1}); + ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {0, 3.1415926, -3.1415926, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -50,6 +51,33 @@ TEST(ElementWise, Sin) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, Log) { + ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 3.1415926, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1.14473, 0, 0}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(ElementWise, Sqrt) { + ElementWiseOpModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {0, 1, 2, 4}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1, 1.41421, 2}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(ElementWise, Rsqrt) { + ElementWiseOpModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 2, 4, 9}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({1, 0.7071, 0.5, 0.33333}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index 7539c0b30ded921df957217bebdc7b20ea4b40b4..0ba170a4da7b7f0d7afa8b425027b03185d3a559 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -24,7 +24,8 @@ limitations under the License. // Output: // Output.dim[0] == Tensor[0].dim[0], num of lookups // Output.dim[1] == Tensor[1].dim[1], num of items per row -// Each item in output is a raw bytes copy of corresponding item in input. +// Each item in output is a raw bytes copy of the corresponding item in input, +// or a dequantized value in the case of a uint8 input. // When indices are out of bound, the ops will not succeed. // @@ -69,11 +70,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, outputSize); } -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* output = GetOutput(context, node, 0); - const TfLiteTensor* lookup = GetInput(context, node, 0); - const TfLiteTensor* value = GetInput(context, node, 1); - +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lookup, const TfLiteTensor* value, + TfLiteTensor* output) { const int row_size = SizeOfDimension(value, 0); const int row_bytes = value->bytes / row_size; @@ -91,6 +90,52 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lookup, const TfLiteTensor* value, + TfLiteTensor* output) { + const int row_size = SizeOfDimension(value, 0); + const double scaling_factor = value->params.scale; + + // col_size after we flatten tensor into 2D. + int col_size = 1; + for (int i = 1; i < NumDimensions(value); i++) { + col_size *= SizeOfDimension(value, i); + } + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = lookup->data.i32[i]; + if (idx >= row_size || idx < 0) { + context->ReportError(context, "Embedding Lookup: index out of bounds."); + return kTfLiteError; + } else { + // Dequantize embedding values. + // TODO(alanchiao): refactor scalar multiply into separate function + // for ease of adding a neon equivalent if ever necessary. + for (int j = 0; j < col_size; j++) { + output->data.f[j + i * col_size] = + value->data.uint8[j + idx * col_size] * scaling_factor; + } + } + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* lookup = GetInput(context, node, 0); + const TfLiteTensor* value = GetInput(context, node, 1); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (value->type) { + case kTfLiteFloat32: + return EvalFloat(context, node, lookup, value, output); + case kTfLiteUInt8: + return EvalHybrid(context, node, lookup, value, output); + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } +} + } // namespace embedding_lookup TfLiteRegistration* Register_EMBEDDING_LOOKUP() { diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 9b501878f196216a61568bfa36e6615f4dd07478..04657fd86323ef1c58d069c06097c7665f55cc87 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -7,13 +7,14 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License +for the specific language governing permissions and limitations under the +License. ==============================================================================*/ // Unit test for TFLite Lookup op. +#include #include #include @@ -29,12 +30,13 @@ namespace { using ::testing::ElementsAreArray; -class EmbeddingLookupOpModel : public SingleOpModel { +class BaseEmbeddingLookupOpModel : public SingleOpModel { public: - EmbeddingLookupOpModel(std::initializer_list index_shape, - std::initializer_list weight_shape) { + BaseEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape, + TensorType weight_type = TensorType_FLOAT32) { input_ = AddInput(TensorType_INT32); - weight_ = AddInput(TensorType_FLOAT32); + weight_ = AddInput(weight_type); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); BuildInterpreter({index_shape, weight_shape}); @@ -44,6 +46,18 @@ class EmbeddingLookupOpModel : public SingleOpModel { PopulateTensor(input_, data); } + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weight_; + int output_; +}; + +class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel; + void Set3DWeightMatrix(const std::function& function) { TfLiteTensor* tensor = interpreter_->tensor(weight_); int rows = tensor->dims->data[0]; @@ -57,20 +71,25 @@ class EmbeddingLookupOpModel : public SingleOpModel { } } } +}; - std::vector GetOutput() { return ExtractVector(output_); } +class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + HybridEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape) + : BaseEmbeddingLookupOpModel(index_shape, weight_shape, + TensorType_UINT8) {} - private: - int input_; - int weight_; - int output_; + void SetWeight(std::initializer_list data) { + SymmetricQuantizeAndPopulate(weight_, data); + } }; // TODO(ahentz): write more tests that exercise the details of the op, such as // lookup errors and variable input shapes. TEST(EmbeddingLookupOpTest, SimpleTest) { EmbeddingLookupOpModel m({3}, {3, 2, 4}); - m.PopulateTensor(0, {1, 0, 2}); + m.SetInput({1, 0, 2}); m.Set3DWeightMatrix( [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); @@ -84,6 +103,69 @@ TEST(EmbeddingLookupOpTest, SimpleTest) { }))); } +TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 8}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + +TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + +TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed33012864354cd93eac2344f75d7eca302c8952 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/expand_dims.cc @@ -0,0 +1,113 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +namespace tflite { +namespace ops { +namespace builtin { +namespace expand_dims { +constexpr int kInput = 0; +constexpr int kAxis = 1; +constexpr int kOutput = 0; + +namespace { +TfLiteStatus ExpandTensorDim(TfLiteContext* context, const TfLiteTensor& input, + int axis, TfLiteTensor* output) { + const TfLiteIntArray& input_dims = *input.dims; + if (axis < 0) { + axis = input_dims.size + 1 + axis; + } + TF_LITE_ENSURE(context, axis <= input_dims.size); + + TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims.size + 1); + for (int i = 0; i < output_dims->size; ++i) { + if (i < axis) { + output_dims->data[i] = input_dims.data[i]; + } else if (i == axis) { + output_dims->data[i] = 1; + } else { + output_dims->data[i] = input_dims.data[i - 1]; + } + } + + return context->ResizeTensor(context, output, output_dims); +} + +TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context, + const TfLiteTensor& axis, int* axis_value) { + TF_LITE_ENSURE_EQ(context, NumElements(&axis), 1); + switch (axis.type) { + case kTfLiteInt32: + *axis_value = *GetTensorData(&axis); + return kTfLiteOk; + case kTfLiteInt64: + *axis_value = *GetTensorData(&axis); + return kTfLiteOk; + default: + return kTfLiteError; + } +} + +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, kInput); + const TfLiteTensor* axis = GetInput(context, node, kAxis); + TfLiteTensor* output = GetOutput(context, node, 0); + output->type = input->type; + if (IsConstantTensor(axis)) { + int axis_value; + TF_LITE_ENSURE_OK(context, + GetAxisValueFromTensor(context, *axis, &axis_value)); + return ExpandTensorDim(context, *input, axis_value, output); + } + SetTensorToDynamic(output); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + // Just copy input to output. + const TfLiteTensor* input = GetInput(context, node, kInput); + TfLiteTensor* output = GetOutput(context, node, 0); + const TfLiteTensor* axis = GetInput(context, node, kAxis); + if (IsDynamicTensor(output)) { + int axis_value; + TF_LITE_ENSURE_OK(context, + GetAxisValueFromTensor(context, *axis, &axis_value)); + TF_LITE_ENSURE_OK(context, + ExpandTensorDim(context, *input, axis_value, output)); + } + memcpy(output->data.raw, input->data.raw, input->bytes); + return kTfLiteOk; +} + +} // namespace expand_dims +TfLiteRegistration* Register_EXPAND_DIMS() { + static TfLiteRegistration r = {nullptr, nullptr, expand_dims::Prepare, + expand_dims::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..50dc860e5a83f185abc70a844abdbc974f7bc4e7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc @@ -0,0 +1,83 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ExpandDimsOpModel : public SingleOpModel { + public: + ExpandDimsOpModel(std::initializer_list input_shape, + TensorType input_type) { + input_ = AddInput(input_type); + axis_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions, + 0); + BuildInterpreter({input_shape, {1}}); + } + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } + std::vector GetValuesFloat() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int axis_; + int output_; +}; + +TEST(ExpandDimsOpTest, DifferentAxis) { + ExpandDimsOpModel m({2, 2}, TensorType_FLOAT32); + std::initializer_list values = {-1.f, 1.f, -2.f, 2.f}; + m.SetInputFloat(values); + m.SetAxis(0); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2})); + + m.SetAxis(1); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2})); + + m.SetAxis(2); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1})); + + m.SetAxis(-1); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index 1ba30649ec441fd7c7cb47357df90ff3ef6ed610..3b203dd480f95c5dc70a69aafce0bac6ab2cbc06 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -63,6 +63,7 @@ constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kBiasTensor = 2; constexpr int kOutputTensor = 0; +constexpr int kShuffledInputWorkspaceTensor = 1; constexpr int kScratchBufferTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -87,7 +88,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 3); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + // Shuffled formats need a workspace to store the shuffled input activations. + const int expected_outputs_count = + params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1 + : 2; + TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); @@ -101,17 +106,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { input_size *= input->dims->data[i]; } + TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2); const int batch_size = input_size / filter->dims->data[1]; const int num_units = filter->dims->data[0]; - TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]); + TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter->dims->data[1]); if (bias) { - TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); + TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0)); } - TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2); - TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1); - // Note that quantized inference requires that all tensors have their // parameters set. This is usually done during quantized training. TfLiteType data_type = input->type; @@ -119,11 +122,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); - CalculateActivationRangeUint8(params->activation, output, - &data->output_activation_min, - &data->output_activation_max); + TF_LITE_ENSURE(context, real_multiplier < 1.0); + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); } // If we have to perform on-the-fly quantization (with quantized weights and @@ -219,11 +224,8 @@ TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node, tensor_utils::ZeroVector(output->data.f, batch_size * num_units); } - // TODO(mirkov): change std::minmax_element with a vectorized call. - auto minmax_element = - std::minmax_element(input->data.f, input->data.f + total_input_size); // Save matrix multiplication computation for all zero input. - if (*minmax_element.first == 0.0 && *minmax_element.second == 0.0) { + if (tensor_utils::IsZeroVector(input->data.f, total_input_size)) { tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units, params->activation, output->data.f); @@ -281,44 +283,101 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t input_offset = -input->params.zero_point; int32_t filter_offset = -filter->params.zero_point; int32_t output_offset = output->params.zero_point; -#define TF_LITE_FULLY_CONNECTED(type) \ +#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \ type::FullyConnected( \ GetTensorData(input), GetTensorDims(input), input_offset, \ GetTensorData(filter), GetTensorDims(filter), filter_offset, \ GetTensorData(bias), GetTensorDims(bias), output_offset, \ data->output_multiplier, data->output_shift, \ data->output_activation_min, data->output_activation_max, \ - GetTensorData(output), GetTensorDims(output), gemm_context) + GetTensorData(output), GetTensorDims(output), \ + gemm_context) if (kernel_type == kReference) { - TF_LITE_FULLY_CONNECTED(reference_ops); - } else if (kernel_type == kPie) { - if (input->type == kTfLiteFloat32) { - // Pie currently only supports quantized models and float inputs/outputs. - TfLiteTensor* input_quantized = - &context->tensors[node->temporaries->data[0]]; - return EvalPieQuantized(context, node, params, data, input, filter, bias, - input_quantized, output); - } else { - // TODO(ahentz): we don't have a quantized version of the PIE kernels, so - // we just defer to the MINI ones. - TF_LITE_FULLY_CONNECTED(optimized_ops); + switch (output->type) { + case kTfLiteUInt8: + TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t); + break; + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(reference_ops, int16_t); + break; + default: + context->ReportError( + context, + "Quantized FullyConnected expects output data type uint8 or int16"); + return kTfLiteError; } + } else if (kernel_type == kPie && input->type == kTfLiteFloat32) { + // Pie currently only supports quantized models and float inputs/outputs. + TfLiteTensor* input_quantized = + &context->tensors[node->temporaries->data[0]]; + return EvalPieQuantized(context, node, params, data, input, filter, bias, + input_quantized, output); } else { - TF_LITE_FULLY_CONNECTED(optimized_ops); + switch (output->type) { + case kTfLiteUInt8: + TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t); + break; + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t); + break; + default: + context->ReportError( + context, + "Quantized FullyConnected expects output data type uint8 or int16"); + return kTfLiteError; + } } #undef TF_LITE_FULLY_CONNECTED return kTfLiteOk; } +template +TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, + OpData* data, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, + TfLiteTensor* output, + TfLiteTensor* shuffled_input_workspace) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + // TODO(b/110697972) decide more consistently if / how / where we want + // to perform this kind of runtime data type checks. + if (input->type != kTfLiteUInt8 || filter->type != kTfLiteUInt8 || + bias->type != kTfLiteInt32 || output->type != kTfLiteInt16 || + shuffled_input_workspace->type != kTfLiteUInt8) { + context->ReportError(context, "Unexpected data type"); + return kTfLiteError; + } + +#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \ + type::ShuffledFullyConnected( \ + GetTensorData(input), GetTensorDims(input), \ + GetTensorData(filter), GetTensorDims(filter), \ + GetTensorData(bias), GetTensorDims(bias), \ + data->output_multiplier, data->output_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output), \ + GetTensorData(shuffled_input_workspace), gemm_context) + if (kernel_type == kReference) { + TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops); + } else { + TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops); + } +#undef TF_LITE_SHUFFLED_FULLY_CONNECTED + + return kTfLiteOk; +} + template TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); #define TF_LITE_FULLY_CONNECTED(type) \ type::FullyConnected(GetTensorData(input), GetTensorDims(input), \ GetTensorData(filter), GetTensorDims(filter), \ @@ -355,10 +414,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return EvalFloat(context, node, params, data, input, filter, bias, output); case kTfLiteUInt8: - return EvalQuantized(context, node, params, data, input, - filter, bias, output); + if (params->weights_format == + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) { + TfLiteTensor* shuffled_input_workspace = + GetOutput(context, node, kShuffledInputWorkspaceTensor); + return EvalShuffledQuantized(context, node, params, data, + input, filter, bias, output, + shuffled_input_workspace); + } else if (params->weights_format == + kTfLiteFullyConnectedWeightsFormatDefault) { + return EvalQuantized(context, node, params, data, input, + filter, bias, output); + } else { + context->ReportError(context, + "Unhandled fully-connected weights format"); + return kTfLiteError; + } default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + filter->type); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc index 05dd028b484c09bdf90a09fab1238f48e8a9ddab..ec949056971ccb5f7a6f93fa9f236a93625ca6ad 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc @@ -15,6 +15,7 @@ limitations under the License. // Unit test for TFLite FULLY_CONNECTED op. #include +#include #include #include @@ -133,9 +134,12 @@ static float fully_connected_golden_output[] = { class BaseFullyConnectedOpModel : public SingleOpModel { public: // TODO(ahentz): test different activation types too. - BaseFullyConnectedOpModel(TfLiteRegistration* registration, int units, - int batches, const TensorData& input, - const TensorData& output = {TensorType_FLOAT32}) + BaseFullyConnectedOpModel( + TfLiteRegistration* registration, int units, int batches, + const TensorData& input, const TensorData& output = {TensorType_FLOAT32}, + ActivationFunctionType activation_func = ActivationFunctionType_RELU, + FullyConnectedOptionsWeightsFormat weights_format = + FullyConnectedOptionsWeightsFormat_DEFAULT) : batches_(batches), units_(units) { int total_input_size = 1; for (int i = 0; i < input.shape.size(); ++i) { @@ -159,10 +163,13 @@ class BaseFullyConnectedOpModel : public SingleOpModel { } output_ = AddOutput(output); + if (weights_format != FullyConnectedOptionsWeightsFormat_DEFAULT) { + AddOutput({TensorType_UINT8, input.shape}); + } SetBuiltinOp( BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, - CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + CreateFullyConnectedOptions(builder_, activation_func, weights_format) .Union()); resolver_ = absl::make_unique( BuiltinOperator_FULLY_CONNECTED, registration); @@ -188,13 +195,11 @@ class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel { public: using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; - void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + void SetBias(const std::vector& f) { PopulateTensor(bias_, f); } - void SetWeights(std::initializer_list f) { - PopulateTensor(weights_, f); - } + void SetWeights(const std::vector& f) { PopulateTensor(weights_, f); } - void SetInput(std::initializer_list data) { + void SetInput(const std::vector& data) { PopulateTensor(input_, data); } void SetInput(int offset, float* begin, float* end) { @@ -208,20 +213,50 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { public: using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; - void SetBias(std::initializer_list data) { + void SetBias(const std::vector& data) { QuantizeAndPopulate(bias_, data); } - void SetWeights(std::initializer_list data) { + void SetWeights(const std::vector& data) { QuantizeAndPopulate(weights_, data); } - void SetInput(std::initializer_list data) { + void ShuffleAndSetWeights(const std::vector& data, int input_depth, + int output_depth) { + std::vector shuffled_data(data.size()); + CHECK_EQ(input_depth % 16, 0); + CHECK_EQ(output_depth % 4, 0); + float* shuffled_data_ptr = shuffled_data.data(); + for (int block_o = 0; block_o < output_depth; block_o += 4) { + for (int block_i = 0; block_i < input_depth; block_i += 16) { + for (int o = 0; o < 4; o++) { + for (int i = 0; i < 16; i++) { + *shuffled_data_ptr++ = + data[(block_o + o) * input_depth + block_i + i]; + } + } + } + } + TfLiteTensor* t = interpreter_->tensor(weights_); + auto quantized_data = + Quantize(shuffled_data, t->params.scale, t->params.zero_point); + for (uint8_t& q : quantized_data) { + q ^= 0x80; + } + PopulateTensor(weights_, 0, quantized_data.data(), + quantized_data.data() + quantized_data.size()); + } + void SetInput(const std::vector& data) { QuantizeAndPopulate(input_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } }; @@ -256,12 +291,12 @@ class HybridFullyConnectedOpModel : public SingleOpModel { ops::builtin::Register_FULLY_CONNECTED_PIE()); BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); } - void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } - void SetWeights(std::initializer_list data) { + void SetBias(const std::vector& f) { PopulateTensor(bias_, f); } + void SetWeights(const std::vector& data) { SymmetricQuantizeAndPopulate(weights_, data); } - void SetInput(std::initializer_list f) { PopulateTensor(input_, f); } + void SetInput(const std::vector& f) { PopulateTensor(input_, f); } std::vector GetOutput() { return ExtractVector(output_); } int input_size() { return input_size_; } @@ -340,6 +375,24 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest) { EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); } +TEST_P(FloatFullyConnectedOpTest, SimpleTest2) { + FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/1, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 2}}); + m.SetWeights({ + 2, 4, // u = 0 + }); + m.SetBias({1}); + + m.SetInput({ + 1, 2, // b = 0 + 2, 1, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9)); +} + TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches*/ 2, @@ -350,7 +403,7 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 }); m.SetBias({1, 2, 3}); @@ -361,11 +414,136 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ - 24, 25, 26, // - 58, 59, 60, // - }))); - EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), + ElementsAre(151, 152, 153, 185, 186, 187)); +} + +void SimpleTestQuantizedInt16OutputCase( + TfLiteRegistration* registration, int input_depth, int output_depth, + int batches, FullyConnectedOptionsWeightsFormat weights_format) { + const uint8_t kWeightsZeroPoint = 128; + const float kWeightsScale = 1.f / 128.f; + const uint8_t kInputZeroPoint = 128; + const float kInputScale = 1.f / 128.f; + const float kInputMin = (0 - kInputZeroPoint) * kInputScale; + const float kInputMax = (255 - kInputZeroPoint) * kInputScale; + // Output ranges in [-8..8] encoded as int16 + const float kOutputScale = 8.f / 32768.f; + const float kOutputMin = -32768 * kOutputScale; + const float kOutputMax = 32767 * kOutputScale; + + QuantizedFullyConnectedOpModel m( + registration, output_depth, batches, + /*input=*/ + {TensorType_UINT8, {batches, input_depth}, kInputMin, kInputMax}, + /*output=*/{TensorType_INT16, {}, kOutputMin, kOutputMax}, + /*activation_func=*/ActivationFunctionType_NONE, weights_format); + + std::mt19937 random_engine; + std::uniform_int_distribution weights_dist; + + std::vector weights_data(input_depth * output_depth); + for (auto& w : weights_data) { + uint8_t q = weights_dist(random_engine); + w = (q - kWeightsZeroPoint) * kWeightsScale; + } + + // Based on weights_format, enforce any shape requirement for that format/path + // and set the (possibly shuffled) weights. + switch (weights_format) { + case FullyConnectedOptionsWeightsFormat_DEFAULT: + m.SetWeights(weights_data); + break; + case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + // The shuffled path currently supports only a restrictive subset of + // shapes, described by the following assertions: + CHECK_EQ(input_depth % 16, 0); + CHECK_EQ(output_depth % 4, 0); + CHECK(batches == 1 || batches == 4); + m.ShuffleAndSetWeights(weights_data, input_depth, output_depth); + break; + default: + LOG(FATAL) << "Unhandled weights format"; + } + + std::uniform_int_distribution input_dist; + std::vector input_data(input_depth * batches); + for (auto& i : input_data) { + uint8_t q = input_dist(random_engine); + i = (q - kInputZeroPoint) * kInputScale; + } + + std::vector bias_data(output_depth); + // As the output ranges in [-8, 8], it's reasonable to have bias values + // in [-1, 1], this won't result in too much saturation. + std::uniform_real_distribution bias_dist(-1.f, 1.f); + for (auto& b : bias_data) { + b = bias_dist(random_engine); + } + + m.SetBias(bias_data); + m.SetInput(input_data); + + m.Invoke(); + + std::vector expected_output_data(output_depth * batches); + for (int b = 0; b < batches; b++) { + for (int o = 0; o < output_depth; o++) { + float accum = bias_data[o]; + for (int i = 0; i < input_depth; i++) { + accum += + input_data[b * input_depth + i] * weights_data[o * input_depth + i]; + } + accum = std::min(accum, kOutputMax); + accum = std::max(accum, kOutputMin); + expected_output_data[b * output_depth + o] = accum; + } + } + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(expected_output_data, 3e-4f))); +} + +TEST_P(QuantizedFullyConnectedOpTest, + SimpleTestQuantizedInt16OutputDefaultWeights) { + for (int input_depth : {1, 3, 10, 100}) { + for (int output_depth : {1, 3, 10, 100}) { + for (int batch : {1, 3, 10, 100}) { + SimpleTestQuantizedInt16OutputCase( + GetRegistration(), input_depth, output_depth, batch, + FullyConnectedOptionsWeightsFormat_DEFAULT); + } + } + } +} + +TEST_P(QuantizedFullyConnectedOpTest, + SimpleTestQuantizedInt16OutputShuffled4x16Int8Weights) { + // The shuffled weights block shape is 4x16. The shape of the weights matrix + // is: rows = output_depth, cols = input_depth. It must be a multiple of 4x16. + // This means that output_depth must be a multiple of 4, and input_deth must + // be a multiple of 16. + for (int input_depth_numblocks : {1, 3}) { + for (int output_depth_numblocks : {1, 3}) { + int input_depth = 16 * input_depth_numblocks; + int output_depth = 4 * output_depth_numblocks; + // The fast shuffled path is currently supporting only batch sizes of 1 + // and 4. The idea is that the whole point of that path is to go as fast + // as possible for small batch size, which requires fully specializing + // it for each batch size, and for larger batch sizes the generic + // gemmlowp-based implementation is fast enough. + for (int batch : {1, 4}) { + SimpleTestQuantizedInt16OutputCase( + GetRegistration(), input_depth, output_depth, batch, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8); + } + } + } } TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) { @@ -396,11 +574,11 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) { /*max_abs_error=*/1.3f))); } -TEST(FloatFullyConnectedOpTest, SimpleTest4DInput) { +TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) { // Note that it is not required that the first dimension be the number of // batches. All we care is that the input can be evenly distributed in // batches. In this case, we need the input to have multiples of '2'. - FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(), + FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/3, /*batches=*/2, /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}}); m.SetWeights({ @@ -444,11 +622,13 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) { m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ - 24, 25, 26, // - 58, 59, 60, // - }))); - EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), + ElementsAre(151, 152, 153, 185, 186, 187)); } INSTANTIATE_TEST_CASE_P( diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc index c452d3ebac7d26880d81e80c3e0fa391fcfc477e..2b2a9e662051287fd1e3dbe8978f4689dc731064 100644 --- a/tensorflow/contrib/lite/kernels/gather.cc +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -40,10 +40,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Only INT32 positions are supported. TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); - // Check that input and output types match. - TF_LITE_ENSURE_EQ(context, input->type, output->type); - // TODO(mgubin): only 0D or 1D positions are currently supported. - TF_LITE_ENSURE(context, NumDimensions(positions) <= 1); + // Assign to output the input type. + output->type = input->type; // TODO(mgubin): Only default axis == 0 is supported. TF_LITE_ENSURE_EQ(context, params->axis, 0); // Check conditions for different types. @@ -59,8 +57,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); } break; default: - context->ReportError(context, - "Only float32 and string types are supported"); + context->ReportError( + context, "Only float32 and string types are supported, got %d", + input->type); return kTfLiteError; } const int num_dimensions = @@ -101,6 +100,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_GATHER(int32_t, int32_t); break; case kTfLiteString: { + // TODO(mgubin): Currently support only for 1D output tensors. DynamicBuffer buffer; const int32* indexes = positions->data.i32; const int num_strings = GetStringCount(input); diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc index cdadbeda1884ba0186846826dd16be6ff69878d9..1d4292955cced59a47e0500833a86113cb9d3eb8 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -96,6 +96,15 @@ TEST(GatherOpTest, Test0DIndexWith0DResult) { EXPECT_TRUE(m.GetOutputShape().empty()); } +TEST(GatherOpTest, Test2DIndexWith2DResult) { + GatherOpModel m({3}, TensorType_FLOAT32, {1, 2}); + m.SetInputFloat({1.0, 2.0, 3.0}); + m.SetPositions({1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({2.0, 1.0}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); +} + TEST(FloatGatherOpTest, Duplicate) { GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2}); m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc index 95f45ea768be7f9bae9570563f161792afbff436..ed334af2da877edf9f591612478e22f04cf15931 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.cc +++ b/tensorflow/contrib/lite/kernels/gemm_support.cc @@ -14,57 +14,70 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include + #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { namespace gemm_support { +namespace { -struct RefCountedGemmContext { - gemmlowp::GemmContext* gemm_context_ = nullptr; - int num_references_ = 0; +struct RefCountedGemmContext : public TfLiteExternalContext { + std::unique_ptr gemm_context; + int num_references = 0; }; +RefCountedGemmContext* GetGemmLowpContext(TfLiteContext* context) { + return reinterpret_cast( + context->GetExternalContext(context, kTfLiteGemmLowpContext)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + auto* ptr = GetGemmLowpContext(context); + if (ptr != nullptr) { + ptr->gemm_context->set_max_num_threads(context->recommended_num_threads); + } + return kTfLiteOk; +} + +} // namespace + void IncrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { ptr = new RefCountedGemmContext; - ptr->gemm_context_ = new gemmlowp::GemmContext(); + ptr->type = kTfLiteGemmLowpContext; + ptr->Refresh = Refresh; + ptr->gemm_context.reset(new gemmlowp::GemmContext()); if (context->recommended_num_threads != -1) { - ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads); + ptr->gemm_context->set_max_num_threads(context->recommended_num_threads); } - ptr->num_references_ = 0; - context->gemm_context = ptr; + ptr->num_references = 0; + context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr); } - ptr->num_references_++; + ptr->num_references++; } void DecrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to DecrementUsageCounter() not preceded by " "IncrementUsageCounter()"); } - if (--ptr->num_references_ == 0) { - delete ptr->gemm_context_; + if (--ptr->num_references == 0) { delete ptr; - context->gemm_context = nullptr; + context->SetExternalContext(context, kTfLiteGemmLowpContext, nullptr); } } gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to GetFromContext() not preceded by IncrementUsageCounter()"); } - return ptr->gemm_context_; -} - -void SetNumThreads(TfLiteContext* context, int num_threads) { - IncrementUsageCounter(context); - GetFromContext(context)->set_max_num_threads(num_threads); - DecrementUsageCounter(context); + return ptr->gemm_context.get(); } } // namespace gemm_support diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h index f033501cb6e341aa014fa4d956b531bd79aa555b..37af772c6846f2f8124faabf1a0f0987e2e9393d 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -45,9 +45,6 @@ void IncrementUsageCounter(TfLiteContext* context); // 'context'. If there are no more usages the GemmContext will be deleted. void DecrementUsageCounter(TfLiteContext* context); -// Set the number of threads that can be used by gemmlowp. -void SetNumThreads(TfLiteContext* context, int num_threads); - } // namespace gemm_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index d8340d426ae0bda1dbecc9322650f7c75985126b..7962fcbc9d6c839ea11d7355e955239194442e03 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -176,6 +176,40 @@ cc_library( }), ) +cc_library( + name = "legacy_optimized_base", + srcs = [], + hdrs = [ + "common.h", + "optimized/depthwiseconv_float.h", + "optimized/depthwiseconv_uint8.h", + "optimized/depthwiseconv_uint8_3x3_filter.h", + "optimized/legacy_optimized_ops.h", + "optimized/optimized_ops.h", + ], + copts = tflite_copts(), + deps = [ + ":quantization_util", + ":strided_slice_logic", + ":types", + ":legacy_reference_base", + ":round", + "//third_party/eigen3", + "@gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, + ":freebsd": tflite_deps_intel, + "//conditions:default": [], + }), +) + cc_library( name = "optimized", hdrs = [ @@ -273,6 +307,37 @@ cc_library( }), ) +cc_library( + name = "legacy_reference_base", + srcs = [], + hdrs = [ + "common.h", + "reference/depthwiseconv_float.h", + "reference/depthwiseconv_uint8.h", + "reference/legacy_reference_ops.h", + "reference/reference_ops.h", + ], + deps = [ + ":quantization_util", + ":round", + ":strided_slice_logic", + ":types", + "//third_party/eigen3", + "@gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, + ":freebsd": tflite_deps_intel, + "//conditions:default": [], + }), +) + cc_library( name = "reference", hdrs = ["tensor.h"], @@ -302,6 +367,8 @@ cc_library( name = "neon_tensor_utils", srcs = [ "optimized/neon_tensor_utils.cc", + "reference/portable_tensor_utils.cc", + "reference/portable_tensor_utils.h", ], hdrs = [ "common.h", @@ -313,11 +380,11 @@ cc_library( copts = NEON_FLAGS_IF_APPLICABLE + HARD_FP_FLAGS_IF_APPLICABLE, deps = [ ":cpu_check", - ":portable_tensor_utils", ":round", ":types", "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite/kernels:activation_functor", + "//tensorflow/contrib/lite/kernels:op_macros", "@arm_neon_2_x86_sse", "@gemmlowp", ], @@ -418,6 +485,15 @@ cc_library( }), ) +cc_library( + name = "test_util", + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + ":types", + ], +) + cc_test( name = "tensor_utils_test", srcs = ["tensor_utils_test.cc"], @@ -438,6 +514,84 @@ cc_test( ], ) +cc_test( + name = "depthwiseconv_float_test", + srcs = ["depthwiseconv_float_test.cc"], + deps = [ + ":optimized_base", + ":reference_base", + ":test_util", + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "depthwiseconv_quantized_test", + srcs = ["depthwiseconv_quantized_test.cc"], + deps = [ + ":optimized_base", + ":reference_base", + ":test_util", + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "resize_bilinear_test", + srcs = ["resize_bilinear_test.cc"], + tags = ["tflite_not_portable"], + deps = [ + ":optimized_base", + ":reference_base", + ":test_util", + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "softmax_quantized_test", + timeout = "long", + srcs = [ + "softmax_quantized_test.cc", + ], + deps = [ + ":optimized_base", + ":quantization_util", + ":reference_base", + ":test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "logsoftmax_quantized_test", + timeout = "long", + srcs = [ + "logsoftmax_quantized_test.cc", + ], + tags = ["tflite_not_portable"], + deps = [ + ":optimized_base", + ":quantization_util", + ":reference_base", + ":test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "log_quantized_test", + srcs = ["log_quantized_test.cc"], + deps = [ + ":optimized_base", + ":reference_base", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "cpu_check", hdrs = [ diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index ede95dfee069fa078b89d23b68ce1bb264761351..b86ca49c116875672c4516a2a47f7dae511a7116 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -87,12 +87,12 @@ float ActivationFunction(float x) { output_activation_max); } -inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( - int32 x, int32 quantized_multiplier, int right_shift) { +inline int32 MultiplyByQuantizedMultiplierSmallerThanOneExp( + int32 x, int32 quantized_multiplier, int left_shift) { using gemmlowp::RoundingDivideByPOT; using gemmlowp::SaturatingRoundingDoublingHighMul; return RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift); } inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..844ee6a53dd65b81f21ae1ef5b6d04192744a304 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" + +namespace tflite { +namespace { + +// Runs the DepthwiseConv and compares against the reference implementation. +template +void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride, int pad_width, int pad_height, + int depth_multiplier, const Dims<4>& output_dims) { + const int output_buffer_size = RequiredBufferSizeForDims(output_dims); + std::vector output_data(output_buffer_size); + std::vector reference_output_data(output_buffer_size); + reference_ops::DepthwiseConv(input_data, input_dims, filter_data, + filter_dims, bias_data, bias_dims, stride, + pad_width, pad_height, depth_multiplier, + reference_output_data.data(), output_dims); + optimized_ops::DepthwiseConv(input_data, input_dims, filter_data, + filter_dims, bias_data, bias_dims, stride, + pad_width, pad_height, depth_multiplier, + output_data.data(), output_dims); + double sum_abs_diff = 0; + float max_abs_val = 0; + for (int i = 0; i < output_buffer_size; i++) { + sum_abs_diff += std::abs(output_data[i] - reference_output_data[i]); + max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i])); + } + if (sum_abs_diff != 0.f) { + const float mean_diff = + static_cast(sum_abs_diff / output_buffer_size); + const float relative_error = std::abs(mean_diff) / max_abs_val; + ASSERT_LT(relative_error, 1e-5f); + } +} + +void TestOneDepthwiseConv(FusedActivationFunctionType Ac, + const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride, int pad_width, int pad_height, + int depth_multiplier, const Dims<4>& output_dims) { +#define TOCO_HANDLE_CASE(AC_TYPE) \ + if (AC_TYPE == Ac) { \ + TestOneDepthwiseConv(input_data, input_dims, filter_data, \ + filter_dims, bias_data, bias_dims, stride, \ + pad_width, pad_height, depth_multiplier, \ + output_dims); \ + return; \ + } + TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6) +#undef TOCO_HANDLE_CASE +} + +// This function picks some random DepthwiseConv params, which may or may not +// be legal. If they're not legal, it returns false. If they're legal, +// it runs the DepthwiseConv test and returns true. This allows the caller +// to loop until a test has been run. +bool TryTestOneDepthwiseConv() { + // We have to pick a lot of positive values, where we are particularly + // interested in small values because they are most likely to be special + // cases in optimized implementations, and secondarily because they allow + // tests to run fast, which means we can run more tests and get more + // coverage. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int filter_width = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const int output_depth = input_depth * depth_multiplier; + // The optimized DepthwiseConv implementation currently uses a fixed-size + // accumulator buffer on the stack, with that size. This currently means + // that it does not support larger output depths. It CHECK's for it, + // so it's safe in the sense that if a larger output depth was encountered, + // it would explicitly fail. We just need to adjust our testing to that + // constraint. + const int kMaxSupportedOutputDepth = 1024; + if (output_depth > kMaxSupportedOutputDepth) { + return false; + } + const auto ac = RandomElement(std::vector( + {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu, + FusedActivationFunctionType::kRelu6, + FusedActivationFunctionType::kRelu1})); + Dims<4> input_dims_inference = + MakeDimsForInference(input_depth, input_width, input_height, batch); + Dims<4> output_dims_inference; + int pad_width, pad_height; + const auto padding_type = + UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid; + if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width, + filter_height, stride, padding_type, + &output_dims_inference, &pad_width, &pad_height)) { + return false; + } + Dims<4> filter_dims_inference = + MakeDimsForInference(output_depth, filter_width, filter_height, 1); + Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1); + const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference); + const int filter_buffer_size = + RequiredBufferSizeForDims(filter_dims_inference); + std::vector input_data(input_buffer_size); + std::vector filter_data(filter_buffer_size); + std::vector bias_data(output_depth); + const float input_amplitude = 1.f; + const float filter_amplitude = 1.f; + const float bias_amplitude = + filter_width * filter_height * input_amplitude * filter_amplitude; + FillRandom(&input_data, -input_amplitude, input_amplitude); + FillRandom(&filter_data, -filter_amplitude, filter_amplitude); + FillRandom(&bias_data, -bias_amplitude, bias_amplitude); + TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference, + filter_data.data(), filter_dims_inference, + bias_data.data(), bias_dims_inference, stride, pad_width, + pad_height, depth_multiplier, output_dims_inference); + return true; +} + +void TestOneDepthwiseConv() { + while (!TryTestOneDepthwiseConv()) { + } +} + +TEST(TestDepthwiseConv, TestDepthwiseConv) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + TestOneDepthwiseConv(); + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c0fc8433e18fb7f7f89c17380210d94b39ffc94 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -0,0 +1,330 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" + +namespace tflite { +namespace { + +// Runs the DepthwiseConv and compares against the reference implementation. +template +int TestOneDepthwiseConvWithGivenOutputShift( + const std::uint8_t* input_data, const Dims<4>& input_dims, + std::int32_t input_offset, const std::uint8_t* filter_data, + const Dims<4>& filter_dims, std::int32_t filter_offset, + const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + std::int32_t output_offset, std::int32_t output_multiplier, + int output_shift, std::int32_t output_activation_min, + std::int32_t output_activation_max, const Dims<4>& output_dims) { + const int output_buffer_size = RequiredBufferSizeForDims(output_dims); + std::vector output_data(output_buffer_size); + std::vector reference_output_data(output_buffer_size); + reference_ops::DepthwiseConv( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, + reference_output_data.data(), output_dims); + optimized_ops::DepthwiseConv( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data.data(), + output_dims); + int saturated_min = 0; + int saturated_max = 0; + std::vector diff(output_buffer_size); + std::int64_t sum_diff = 0; + std::int64_t sum_abs_diff = 0; + for (int i = 0; i < output_buffer_size; i++) { + diff[i] = static_cast(output_data[i]) - + static_cast(reference_output_data[i]); + sum_diff += diff[i]; + sum_abs_diff += std::abs(diff[i]); + saturated_min += output_data[i] == output_activation_min; + saturated_max += output_data[i] == output_activation_max; + } + // These stats help understand test failures. + std::sort(std::begin(diff), std::end(diff)); + const int min_diff = diff.front(); + const int max_diff = diff.back(); + const int median_diff = diff[diff.size() / 2]; + const float mean_diff = static_cast(sum_diff) / output_buffer_size; + const float mean_abs_diff = + static_cast(sum_abs_diff) / output_buffer_size; + // Normally we should require bit-for-bit exact results. Unfortunately a bug + // in the Intel arm_neon_sse.h translation header that we use for x86 tests + // causes 1-bit inaccuracy in + // the vqrdmulh_n_s32 intrinsic, which causes off-by-1 errors in quantized + // DepthwiseConv ops. So we have to live with a few off-by-one errors for now, + // yet still ensure that no more than a small minority of values are wrong. + EXPECT_TRUE(std::abs(mean_diff) < 1e-5f && mean_abs_diff < 1e-5f && + std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 && + std::abs(max_diff) <= 1); + if (saturated_min > 2 * saturated_max) { + return -1; + } + if (saturated_max > 2 * saturated_min) { + return 1; + } + return 0; +} + +// The point of this function is that we can't practically know which +// output_shift value to pass to test DepthwiseConv. It's not easy to guess (we +// could do some +// statistics for large size, but they would be fragile at smaller sizes), and +// guessing wrong would mean that all the values get saturated so the test +// becomes +// vacuous. So we just bisect our way to reasonable output_shift values. +template +void TestOneDepthwiseConvBisectOutputShift( + const std::uint8_t* input_data, const Dims<4>& input_dims, + std::int32_t input_offset, const std::uint8_t* filter_data, + const Dims<4>& filter_dims, std::int32_t filter_offset, + const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + std::int32_t output_offset, std::int32_t output_multiplier, + int output_activation_bisect_start, int output_activation_bisect_end, + std::int32_t output_activation_min, std::int32_t output_activation_max, + const Dims<4>& output_dims) { + ASSERT_LT(output_activation_bisect_start, output_activation_bisect_end) + << "Bisection failed ?!?!"; + int output_shift_bisect_midpoint = + (output_activation_bisect_start + output_activation_bisect_end) / 2; + int bisect_result = TestOneDepthwiseConvWithGivenOutputShift( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, + output_shift_bisect_midpoint, output_activation_min, + output_activation_max, output_dims); + // At this point we know that the test succeeded (otherwise it would have + // aborted). + if (bisect_result == 0) { + // The result isn't particularly saturated on one or the other side. + // All good, we're done. + return; + } + if (output_activation_bisect_start == output_activation_bisect_end - 1) { + // There is still some saturation on one side, but the bisection is + // finished anyways. We're done; nothing more we can do about it. This + // happens + // in particular when using an activation with a narrow range. + return; + } + // Continue the bisection based on the present result. + int new_output_activation_bisect_start = bisect_result == 1 + ? output_shift_bisect_midpoint + : output_activation_bisect_start; + int new_output_activation_bisect_end = bisect_result == 1 + ? output_activation_bisect_end + : output_shift_bisect_midpoint; + TestOneDepthwiseConvBisectOutputShift( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, + new_output_activation_bisect_start, new_output_activation_bisect_end, + output_activation_min, output_activation_max, output_dims); +} + +template +void TestOneDepthwiseConv( + const std::uint8_t* input_data, const Dims<4>& input_dims, + std::int32_t input_offset, const std::uint8_t* filter_data, + const Dims<4>& filter_dims, std::int32_t filter_offset, + const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + std::int32_t output_offset, std::int32_t output_multiplier, + std::int32_t output_activation_min, std::int32_t output_activation_max, + const Dims<4>& output_dims) { + TestOneDepthwiseConvBisectOutputShift( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, 0, 32, + output_activation_min, output_activation_max, output_dims); +} + +void TestOneDepthwiseConv( + FusedActivationFunctionType Ac, const std::uint8_t* input_data, + const Dims<4>& input_dims, std::int32_t input_offset, + const std::uint8_t* filter_data, const Dims<4>& filter_dims, + std::int32_t filter_offset, const std::int32_t* bias_data, + const Dims<4>& bias_dims, int stride, int pad_width, int pad_height, + int depth_multiplier, std::int32_t output_offset, + std::int32_t output_multiplier, std::int32_t output_activation_min, + std::int32_t output_activation_max, const Dims<4>& output_dims) { +#define TOCO_HANDLE_CASE(AC_TYPE) \ + if (AC_TYPE == Ac) { \ + TestOneDepthwiseConv( \ + input_data, input_dims, input_offset, filter_data, filter_dims, \ + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, \ + depth_multiplier, output_offset, output_multiplier, \ + output_activation_min, output_activation_max, output_dims); \ + return; \ + } + TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6) +#undef TOCO_HANDLE_CASE +} + +bool TryTestDepthwiseConv(int batch, int input_depth, int input_width, + int input_height, int filter_width, int filter_height, + int depth_multiplier, int stride, + PaddingType padding_type) { + const int output_depth = input_depth * depth_multiplier; + // The optimized DepthwiseConv implementation currently uses a fixed-size + // accumulator buffer on the stack, with that size. This currently means + // that it does not support larger output depths. It CHECK's for it, + // so it's safe in the sense that if a larger output depth was encountered, + // it would explicitly fail. We just need to adjust our testing to that + // constraint. + const int kMaxSupportedOutputDepth = 1024; + if (output_depth > kMaxSupportedOutputDepth) { + return false; + } + const auto ac = RandomElement(std::vector( + {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu, + FusedActivationFunctionType::kRelu6, + FusedActivationFunctionType::kRelu1})); + int output_activation_min = 0; + int output_activation_max = 255; + if (ac != FusedActivationFunctionType::kNone && UniformRandomInt(0, 1)) { + output_activation_min = UniformRandomInt(0, 50); + output_activation_max = UniformRandomInt(200, 255); + } + const std::int32_t output_multiplier = + UniformRandomInt(1 << 29, std::numeric_limits::max()); + const std::int32_t input_offset = UniformRandomInt(-256, 0); + const std::int32_t filter_offset = UniformRandomInt(-256, 0); + const std::int32_t output_offset = UniformRandomInt(-256, 0); + Dims<4> input_dims_inference = + MakeDimsForInference(input_depth, input_width, input_height, batch); + Dims<4> output_dims_inference; + int pad_width, pad_height; + if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width, + filter_height, stride, padding_type, + &output_dims_inference, &pad_width, &pad_height)) { + return false; + } + Dims<4> filter_dims_inference = + MakeDimsForInference(output_depth, filter_width, filter_height, 1); + Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1); + const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference); + const int filter_buffer_size = + RequiredBufferSizeForDims(filter_dims_inference); + std::vector input_data(input_buffer_size); + std::vector filter_data(filter_buffer_size); + std::vector bias_data(output_depth); + FillRandom(&input_data); + FillRandom(&filter_data); + FillRandom(&bias_data, -10000, 10000); + TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference, + input_offset, filter_data.data(), filter_dims_inference, + filter_offset, bias_data.data(), bias_dims_inference, + stride, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_activation_min, + output_activation_max, output_dims_inference); + return true; +} + +// This function picks some random DepthwiseConv params, which may or may not +// be legal. If they're not legal, it returns false. If they're legal, +// it runs the DepthwiseConv test and returns true. This allows the caller +// to loop until a test has been run. +bool TryTestOneDepthwiseConv() { + // We have to pick a lot of positive values, where we are particularly + // interested in small values because they are most likely to be special + // cases in optimized implementations, and secondarily because they allow + // tests to run fast, which means we can run more tests and get more + // coverage. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int filter_width = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const auto padding_type = + UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid; + + return TryTestDepthwiseConv(batch, input_depth, input_width, input_height, + filter_width, filter_height, depth_multiplier, + stride, padding_type); +} + +// Tests parameters for the 3x3 filter kernel. +bool TryTestOneDepthwiseConv3x3Filter() { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = 8 * ExponentialRandomPositiveInt(0.9f, 10, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int filter_width = 3; + const int filter_height = 3; + const int depth_multiplier = 1; + const int stride = UniformRandomInt(1, 2); + // Although the kernel supports only kValid padding, we test that kSame + // is using the correct code path. + const auto padding_type = + UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid; + + return TryTestDepthwiseConv(batch, input_depth, input_width, input_height, + filter_width, filter_height, depth_multiplier, + stride, padding_type); +} + +void TestOneDepthwiseConv() { + while (!TryTestOneDepthwiseConv()) { + } +} + +void TestOneDepthwiseConv3x3Filter() { + while (!TryTestOneDepthwiseConv3x3Filter()) { + } +} + +TEST(TestDepthwiseConv, TestDepthwiseConv) { + const int kTestsToRun = 10 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + TestOneDepthwiseConv(); + } +} + +TEST(TestDepthwiseConv3x3Filter, TestDepthwiseConv) { + const int kTestsToRun = 3 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + TestOneDepthwiseConv3x3Filter(); + } +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 5f9cfc450db1c25ff604f99f93481e5ca590a5a2..a0e382edb6efe467c7b16624cf1760b0d1c6d760 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -52,21 +52,19 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, - float* hidden_state_ptr_batch, float* output_ptr_batch) { + float* scaling_factors, float* hidden_state_ptr_batch, + float* output_ptr_batch) { // Output = bias tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, output_ptr_batch); - // TODO(mirkov): change std::minmax_element with a vectorized call. - auto minmax_element = std::minmax_element( - input_ptr_batch, input_ptr_batch + batch_size * input_size); - // Save quantization and matmul computation for all zero input. - if (!(*minmax_element.first == 0.0 && *minmax_element.second == 0.0)) { + if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) { // Quantize input from float to uint8 + quantization params (scaling // factor). float unused_min, unused_max; - float* scaling_factors = new float[batch_size]; + // TODO(mirkov,raziel): replace this for-loop with a MACRO (or function) + // whichever is faster. for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; tensor_utils::SymmetricQuantizeFloats( @@ -80,16 +78,13 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_weights_ptr, num_units, input_size, quantized_input_ptr_batch, scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1); - delete[] scaling_factors; } - minmax_element = std::minmax_element( - hidden_state_ptr_batch, hidden_state_ptr_batch + batch_size * num_units); // Save quantization and matmul computation for all zero input. - if (!(*minmax_element.first == 0.0 && *minmax_element.second == 0.0)) { + if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch, + batch_size * num_units)) { // Quantize hidden_state float unused_min, unused_max; - float* scaling_factors = new float[batch_size]; for (int b = 0; b < batch_size; ++b) { const int offset = b * num_units; tensor_utils::SymmetricQuantizeFloats( @@ -104,7 +99,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, recurrent_weights_ptr, num_units, num_units, quantized_hidden_state_ptr_batch, scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1); - delete[] scaling_factors; } // Output = activation(Output) and update hidden_state @@ -155,6 +149,7 @@ void LstmStep( input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, input_gate_scratch, /*result_stride=*/1); } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1); @@ -169,8 +164,7 @@ void LstmStep( if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch, - /*result_stride=*/1); + n_batch, input_gate_scratch, /*result_stride=*/1); } tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, @@ -261,5 +255,263 @@ void LstmStep( output_state_ptr); } +// TODO(alanchiao): move this to tensor_utils. +void VectorMultiply(const int8_t* vector, const int v_size, const float scale, + float* result) { + for (int i = 0; i < v_size; ++i) { + *result++ = scale * *vector++; + } +} + +void LstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + + if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, + &unused_min, &unused_max, &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, forget_gate_scratch, + /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, output_gate_scratch, + /*result_stride=*/1); + } + + if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_output; + tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, + &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } + + // Save quantization and matmul computation for all zero input. + bool is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_input_weights_ptr, n_cell, + cell_to_input_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_forget_weights_ptr, n_cell, + cell_to_forget_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, + params->cell_clip, cell_state_ptr); + } + + is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + // For each batch and cell: update the output gate. + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_output_weights_ptr, n_cell, + cell_to_output_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_cell; + tensor_utils::SymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * projection_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr, + product_scaling_factors, n_batch, output_ptr_batch, + /*result_stride=*/1); + } + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, + params->proj_clip, output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); + } + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); +} + } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index cbfbcbeefcd34fa732799d89f52791b18855857d..2a11b37a6069367e8232350c2fc68d4c385e14ba 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -41,6 +41,9 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, // values of hidden_state_ptr_batch and input_ptr_batch, respectively. // These temporary storages are expected to be preallocated to the same size as // the respective pointers. +// An additional preallocated temporary storage 'scaling_factors' (of size +// batch_size) is used to store the scaling factors of the quantization (used +// for recovery). // {input,recurrent}_weights_scale params are used for dequantization/recovery. void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, float input_weights_scale, @@ -50,7 +53,8 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, - float* hidden_state_ptr_batch, float* output_ptr_batch); + float* scaling_factors, float* hidden_state_ptr_batch, + float* output_ptr_batch); // Performs an LSTM batch inference step for input specified by input_ptr_batch. // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and @@ -88,6 +92,89 @@ void LstmStep( float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch); +// Same as above but with quantized weight matrices. In detail: +// Input of size 'n_batch * n_input': +// input_ptr_batch +// +// LSTM weights: +// Quantized input weights of size 'n_cell * n_input': +// input_to_input_weights - optional (can be nullptr) +// input_to_forget_weights +// input_to_cell_weights +// input_to_input_weights +// Quantized recurrent weights of size 'n_cell * n_output': +// recurrent_to_input_weights - optional +// recurrent_to_forget_weights +// recurrent_to_cell_weights +// recurrent_to_input_weights +// Quantized peephole weights of size 'n_cell', representing diagonal matrices. +// cell_to_input_weights - optional +// cell_to_cell_weights - optional +// cell_to_output_weights - optional +// Quantized projection weights of size 'n_output * n_cell' +// projection_weights_ptr - optional +// Weight scales (scalars) for each of the weights above. +// input_to_input_weights_scale - optional +// input_to_forget_weights_scale +// input_to_cell_weights_scale +// input_to_output_weights_scale +// recurrent_to_input_weights_scale - optional +// recurrent_to_forget_weights_scale +// recurrent_to_cell_weights_scale +// recurrent_to_output_weights_scale +// cell_to_input_weights_scale, +// cell_to_forget_weights_scale, +// cell_to_output_weights_scale, +// projection_weights_scale - optional +// Gate biases of size 'n_cell': +// input_gate_bias_ptr - optional +// forget_gate_bias_ptr +// cell_gate_bias_ptr +// output_gate_bias_ptr +// +// Temporary pre-allocated storage for quantized values: +// quantized_input_ptr_batch (same size as input_ptr_batch) +// quantized_output_state_ptr (same size as output_state_ptr) +// quantized_cell_state_ptr (same size as cell_state_ptr) +// Temporary pre-allocated storage for recovered values: +// recovered_cell_weights (same size as cell_to_*_weights) +// +// Outputs: +// output_state_ptr - size 'n_batch * n_output' +// cell_state_ptr - size 'n_batch * n_cell' +// output_ptr_batch - size 'n_batch * n_output' +void LstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch); + } // namespace kernel_utils } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e9ff5242a43a8b54e0e6ae167cdcf7a341c918e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc @@ -0,0 +1,333 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS + +#include +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" + +namespace { + +class NumberGenerator { + public: + std::vector RandomIntVector(int n, int min_val, int max_val) { + std::vector vec(n); + double scale = static_cast(max_val + 1 - min_val) / engine_.max(); + for (auto& it : vec) { + it = min_val + std::floor(engine_() * scale); + } + return vec; + } + + std::mt19937 engine_; +}; + +class LogQuantizedTest : public ::testing::Test { + public: + NumberGenerator generator_; +}; + +// input_integer_bits <= 30. output_integer_bits > 0. +inline int32 LogPositiveValuesViaFloat(int32 input_val, int input_integer_bits, + int output_integer_bits) { + const double float_log_sum_of_exps = std::log( + static_cast(input_val) * 0.5 / (1 << (30 - input_integer_bits))); + static constexpr double min_int = + static_cast(std::numeric_limits::min()); + static constexpr double max_int = + static_cast(std::numeric_limits::max()); + double double_result = tflite::TfLiteRound(float_log_sum_of_exps * + (1 << (31 - output_integer_bits))); + return static_cast( + std::min(max_int, std::max(min_int, double_result))); +} + +void CheckOutputData(const std::vector& test_output, + const std::vector& reference_output, + const std::vector& test_input, + const string& check_label, int input_integer_bits, + int output_integer_bits, int tolerance) { + // In the special case of small input, specifically raw value of 5, a rounding + // up leads to difference in the output. We do not aim to be accurate for + // very small input values, and there should be sufficient input fractional + // bits that this is a small input. + static constexpr double error_from_rounding_up = 0.0224585; + const int n = test_output.size(); + ASSERT_EQ(n, reference_output.size()); + for (int i = 0; i < n; ++i) { + // Adjust tolerance when input <= 5*2^-(31-input_integer_bits). + const int adjusted_tolerance = + test_input[i] > 5 + ? tolerance + : std::max(tolerance, static_cast(std::ceil( + error_from_rounding_up * + (1 << (31 - output_integer_bits))))); + ASSERT_LE(std::abs(test_output[i] - reference_output[i]), + adjusted_tolerance) + << "Failure in \"" << check_label << "\" at i=" << i + << ", test_input[i]=" << test_input[i] << "=" + << static_cast(test_input[i]) / (1 << (31 - input_integer_bits)) + << ", test_output[i]=" << test_output[i] << "=" + << static_cast(test_output[i]) / + (1 << (31 - output_integer_bits)) + << ", reference_output[i]=" << reference_output[i] << "=" + << static_cast(reference_output[i]) / + (1 << (31 - output_integer_bits)) + << ", difference[i]=" << std::abs(reference_output[i] - test_output[i]) + << "=" + << static_cast(std::abs(reference_output[i] - test_output[i])) / + (1 << (31 - output_integer_bits)) + << "; tolerance=" << tolerance + << ", adj tolerance=" << adjusted_tolerance; + } +} + +void RightShiftVector(const std::vector& shifts, + std::vector* vec) { + const int n = vec->size(); + ASSERT_EQ(n, shifts.size()); + for (int i = 0; i < n; ++i) { + vec->at(i) = std::max(1, vec->at(i) >> shifts[i]); + } +} + +template +void RunSingleTest(const std::vector& test_input, + const string& check_label, int tolerance) { + const int n = test_input.size(); + std::vector float_gen_output(n, 0); + std::vector reference_output(n, 0); + std::vector optimized_output(n, 0); + + // Workaround the stupid things that intelligent humans do. + // Consequence of __builtin_clz(0u) may equal 31 instead of 32. + std::vector fudged_input(n, 0); + for (int i = 0; i < n; ++i) { + fudged_input[i] = std::max(test_input[i], 2); + } + + for (int i = 0; i < n; ++i) { + reference_output[i] = + tflite::reference_ops::log_x_for_x_greater_than_or_equal_to_1_impl< + OutputIntegerBits, InputIntegerBits>( + gemmlowp::FixedPoint::FromRaw( + fudged_input[i])) + .raw(); + optimized_output[i] = + tflite::optimized_ops::log_x_for_x_greater_than_or_equal_to_1_impl< + OutputIntegerBits, InputIntegerBits>( + gemmlowp::FixedPoint::FromRaw( + fudged_input[i])) + .raw(); + float_gen_output[i] = LogPositiveValuesViaFloat( + fudged_input[i], InputIntegerBits, OutputIntegerBits); + } + // Note that first check is intolerant. + { + std::ostringstream label; + label << check_label << " / optimized vs reference / InputIntegerBits=" + << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; + CheckOutputData( + optimized_output, reference_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, 0); + } + { + std::ostringstream label; + label << check_label << " / reference vs float-gen / InputIntegerBits=" + << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; + CheckOutputData( + reference_output, float_gen_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, tolerance); + } + { + std::ostringstream label; + label << check_label << " optimized vs float-gen / InputIntegerBits=" + << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; + CheckOutputData( + optimized_output, float_gen_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, tolerance); + } +} + +template +void RunSingleTest(const std::vector& test_input, int input_integer_bits, + const string& check_label, int tolerance) { +#define INPUT_CASE(K) \ + case K: \ + return RunSingleTest(test_input, check_label, \ + tolerance) + switch (input_integer_bits) { + INPUT_CASE(0); + INPUT_CASE(1); + INPUT_CASE(2); + INPUT_CASE(3); + INPUT_CASE(4); + INPUT_CASE(5); + INPUT_CASE(6); + INPUT_CASE(7); + INPUT_CASE(8); + INPUT_CASE(9); + INPUT_CASE(10); + INPUT_CASE(11); + INPUT_CASE(12); + INPUT_CASE(13); + INPUT_CASE(14); + INPUT_CASE(15); + INPUT_CASE(16); + INPUT_CASE(17); + INPUT_CASE(18); + INPUT_CASE(19); + INPUT_CASE(20); + INPUT_CASE(21); + INPUT_CASE(22); + INPUT_CASE(23); + INPUT_CASE(24); + INPUT_CASE(25); + INPUT_CASE(26); + INPUT_CASE(27); + INPUT_CASE(28); + INPUT_CASE(29); + default: + ASSERT_LE(input_integer_bits, 30) + << "Input integer bits not handled: " << input_integer_bits; + } +#undef INPUT_CASE +} + +void RunSingleTest(const std::vector& test_input, int input_integer_bits, + int output_integer_bits, const string& check_label, + int tolerance) { +#define OUTPUT_CASE(K) \ + case K: \ + return RunSingleTest(test_input, input_integer_bits, check_label, \ + tolerance) + switch (output_integer_bits) { + OUTPUT_CASE(0); + OUTPUT_CASE(1); + OUTPUT_CASE(2); + OUTPUT_CASE(3); + OUTPUT_CASE(4); + OUTPUT_CASE(5); + OUTPUT_CASE(6); + OUTPUT_CASE(7); + OUTPUT_CASE(8); + OUTPUT_CASE(9); + OUTPUT_CASE(10); + OUTPUT_CASE(11); + OUTPUT_CASE(12); + OUTPUT_CASE(13); + OUTPUT_CASE(14); + OUTPUT_CASE(15); + OUTPUT_CASE(16); + OUTPUT_CASE(17); + OUTPUT_CASE(18); + OUTPUT_CASE(19); + OUTPUT_CASE(20); + OUTPUT_CASE(21); + OUTPUT_CASE(22); + OUTPUT_CASE(23); + OUTPUT_CASE(24); + OUTPUT_CASE(25); + OUTPUT_CASE(26); + OUTPUT_CASE(27); + OUTPUT_CASE(28); + OUTPUT_CASE(29); + default: + ASSERT_LE(input_integer_bits, 30) + << "Input integer bits not handled: " << input_integer_bits; + } +#undef OUTPUT_CASE +} + +void RunUniformTest(int test_size, int input_integer_bits, + int output_integer_bits, const string& check_label, + int tolerance, NumberGenerator* generator) { + std::vector test_data = generator->RandomIntVector( + test_size, 2, std::numeric_limits::max() - 1); + test_data[0] = 2; + test_data[1] = 3; + test_data[2] = 4; + test_data[3] = std::numeric_limits::max() - 2; + test_data[4] = std::numeric_limits::max() - 1; + test_data[5] = std::numeric_limits::max(); + + RunSingleTest(test_data, input_integer_bits, output_integer_bits, + check_label + " / uniform test", tolerance); +} + +void RunUniformShiftUniformTest(int test_size, int input_integer_bits, + int output_integer_bits, + const string& check_label, int tolerance, + NumberGenerator* generator) { + std::vector test_data = generator->RandomIntVector( + test_size, 2, std::numeric_limits::max() - 1); + std::vector shifts = generator->RandomIntVector(test_size, 0, 29); + RightShiftVector(shifts, &test_data); + + RunSingleTest(test_data, input_integer_bits, output_integer_bits, + check_label + " / shifted test", tolerance); +} + +TEST_F(LogQuantizedTest, VariedIntegerBits) { + static constexpr int kVariations = 250; + static constexpr int kRunSize = 250; + static constexpr int kIntegerTolerance = 8; + static constexpr double kOutputFloatTolerance = 7.0e-7; + + std::vector input_integer_bits = + generator_.RandomIntVector(kVariations, 0, 24); + std::vector output_integer_bits = + generator_.RandomIntVector(kVariations, 1, 10); + + for (int i = 0; i < kVariations; ++i) { + int var_output_integer_bits = output_integer_bits[i]; + int tolerance = + std::max(1.0 * kIntegerTolerance, + (1 << (31 - var_output_integer_bits)) * kOutputFloatTolerance); + + RunUniformTest(kRunSize, input_integer_bits[i], var_output_integer_bits, + "VariedIntegerBits", tolerance, &generator_); + RunUniformShiftUniformTest(kRunSize, input_integer_bits[i], + var_output_integer_bits, "VariedIntegerBits", + tolerance, &generator_); + } +} + +TEST_F(LogQuantizedTest, SelectedIntegerBits) { + static constexpr int kInputBits = 12; + static constexpr int kOutputBits = 5; + static constexpr int kRunSize = 100000; + static constexpr int kIntegerTolerance = 4; + + RunUniformTest(kRunSize, kInputBits, kOutputBits, "SelectedIntegerBits", + kIntegerTolerance, &generator_); + RunUniformShiftUniformTest(kRunSize, kInputBits, kOutputBits, + "SelectedIntegerBits", kIntegerTolerance, + &generator_); +} + +} // namespace diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d2f1103e14b40b81c59c8053bcdbee30c85e5c78 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -0,0 +1,244 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" + +namespace tflite { +namespace { + +void RunLogSoftmaxFloatReference(const uint8* input_data, + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, + uint8* reference_output_data) { + const int ref_buffer_size = shape_common.FlatSize(); + std::vector reference_dequant_data(ref_buffer_size); + std::vector reference_output_float_data(ref_buffer_size); + + // Reference data generated via Dequant of input into float, and then applying + // float LogSoftmax. + reference_ops::Dequantize( + input_data, ToRuntimeDims(shape_common), input_offset, input_scale, + reference_dequant_data.data(), ToRuntimeDims(shape_common)); + optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common, + reference_output_float_data.data(), shape_common); + // Work with quantized scaling for LogSoftmax, under which 255 represents 0, + // and -16 gets nudged up to 0. + for (int i = 0; i < ref_buffer_size; i++) { + reference_output_data[i] = std::max( + 0, static_cast( + 255 + std::round(16.0f * reference_output_float_data[i]))); + } +} + +void CheckOutputData(const uint8* test_output, const uint8* reference_output, + const RuntimeShape& shape_common, + const string& check_label, bool be_exacting) { + const int buffer_size = shape_common.FlatSize(); + // While calculating some metrics in floating point, we work with quantized + // scaling. + std::vector diff(buffer_size); + int64_t sum_diff = 0; + int64_t sum_abs_diff = 0; + for (int i = 0; i < buffer_size; i++) { + diff[i] = static_cast(test_output[i]) - reference_output[i]; + sum_diff += diff[i]; + sum_abs_diff += std::abs(diff[i]); + } + // These stats help understand test failures. + std::sort(std::begin(diff), std::end(diff)); + const int min_diff = diff.front(); + const int max_diff = diff.back(); + const int median_diff = diff[diff.size() / 2]; + const float mean_diff = static_cast(sum_diff) / buffer_size; + const float mean_abs_diff = static_cast(sum_abs_diff) / buffer_size; + // We either check for bit exactness (against the reference quantized version) + // or for general accuracy, allowing off-by-one (against the float reference). + if (be_exacting) { + ASSERT_TRUE(std::abs(min_diff) == 0 && std::abs(max_diff) == 0) + << check_label << ": " + << "std::abs(min_diff)=" << std::abs(min_diff) + << ", std::abs(max_diff)=" << std::abs(max_diff); + } else { + // For small numbers of samples, the estimates of the means vary more. + // Rather than widen the tolerances, we skip the smaller tests. + ASSERT_TRUE(((std::abs(mean_diff) < 2e-2f && mean_abs_diff < 3e-2f) || + buffer_size < 10000) && + std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 && + std::abs(max_diff) <= 1) + << check_label << ": " + << "buffer_size=" << buffer_size << ", mean_diff=" << mean_diff + << ", mean_abs_diff=" << mean_abs_diff + << ", median_diff=" << median_diff << ", min_diff=" << min_diff + << ", max_diff=" << max_diff; + } +} + +// Runs the LogSoftmax and compares against the float reference implementation +// and the quantized reference implementation. +void RunOneLogSoftmaxTest(const uint8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); + std::vector optimized_logsoftmax_output(buffer_size); + std::vector reference_float_logsoftmax_output(buffer_size); + std::vector reference_quant_logsoftmax_output(buffer_size); + + RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, + input_scale, stride, beta, + reference_float_logsoftmax_output.data()); + + int32 input_beta_multiplier; + int input_beta_left_shift; + int32 reverse_scaling_divisor; + int reverse_scaling_right_shift; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessLogSoftmaxScalingExp( + beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier, + &input_beta_left_shift, &reverse_scaling_divisor, + &reverse_scaling_right_shift); + reverse_scaling_right_shift *= -1; + // diff_min has a negative value, and is used to limit the maximum magnitude + // of the diffs, which are <= 0. + const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, + input_beta_left_shift); + + optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier, + input_beta_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, + optimized_logsoftmax_output.data(), shape_common); + reference_ops::LogSoftmax( + input_data, shape_common, input_beta_multiplier, input_beta_left_shift, + reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, + reference_quant_logsoftmax_output.data(), shape_common); + + CheckOutputData(optimized_logsoftmax_output.data(), + reference_float_logsoftmax_output.data(), shape_common, + "Optimized vs float reference", false); + CheckOutputData(optimized_logsoftmax_output.data(), + reference_quant_logsoftmax_output.data(), shape_common, + "Optimized vs quant reference", true); + CheckOutputData(reference_quant_logsoftmax_output.data(), + reference_float_logsoftmax_output.data(), shape_common, + "Quant reference vs float reference", false); +} + +// This function picks some random LogSoftmax params, which are checked for +// desirability. If not acceptable, it returns false. If they're OK, +// it runs the LogSoftmax test and returns true. This allows the caller +// to loop until a test has been run. +// +// Currently we do not reject for any reason. +bool TryOneUniformLogSoftmax() { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // LogSoftmax, the width and height really just create test repetitions. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + static constexpr float beta = 1.0f; + + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); + + std::vector input_data(buffer_size); + FillRandom(&input_data); + RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, + input_scale, stride, beta); + return true; +} + +// See TryOneUniformLogSoftmax() for a general description. +// +// Tests with "skyscraper" input patterns are included for two reasons. (a) +// Bimodal distributions are potentially challenging and perhaps more +// realistic than simple uniform random inputs. (b) Some implementations of +// LogSoftmax may adapt as they traverse the depth, and so we test handling of +// cases where relatively small values are encountered at the beginning and end. +bool TryOneSkyscraperLogSoftmax(bool small_depth) { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // LogSoftmax, the width and height really just create test repetitions. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = small_depth + ? ExponentialRandomPositiveInt(0.75f, 40, 500) + : ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + static constexpr float beta = 1.0f; + // Extra parameters for skyscraper input patterns. + const double middle_proportion = + ExponentialRandomPositiveFloat(0.65f, 0.1, 1.0); + const int middle_min = UniformRandomInt(0, 255); + const int sides_max = UniformRandomInt(0, middle_min); + + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); + + std::vector input_data(buffer_size); + FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, + sides_max); + RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, + input_scale, stride, beta); + return true; +} + +TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneUniformLogSoftmax()) { + } + } +} + +TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperLogSoftmax(false)) { + } + } +} + +TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperLogSoftmax(true)) { + } + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h index dd6932ffe7b7a6f1101f146ce6472b0df4cbda3b..3fd00c89308d3b163111fda004287c715259352f 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -1691,14 +1691,20 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, const int filter_width = ArraySize(filter_dims, 1); const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); +#ifdef USE_NEON + const bool shift_left = (output_shift <= 0); + const int32 multiplier_power_of_two = shift_left ? (1 << -output_shift) : 1; +#endif TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); -#ifdef __aarch64__ +// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on +// Jetson TX-2. This compiler does not support the offsetof() macro. +#if defined(__aarch64__) && !defined(GOOGLE_L4T) // Call kernel optimized for depthwise convolutions using 3x3 filters if // parameters are supported. - if (Fast3x3FilterKernelSupported(input_dims, filter_dims, stride_width, - stride_height, pad_width, pad_height, - depth_multiplier, output_dims)) { + if (Fast3x3FilterKernelSupported( + input_dims, filter_dims, stride_width, stride_height, pad_width, + pad_height, depth_multiplier, output_dims, output_shift)) { DepthwiseConv3x3Filter(input_data, input_dims, input_offset, filter_data, filter_dims, filter_offset, bias_data, bias_dims, stride_width, stride_height, pad_width, pad_height, @@ -1833,12 +1839,20 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, acc[j] = vld1q_s32(acc_buffer + i + 4 * j); } - // Fixed-point multiplication. - for (int j = 0; j < 4; j++) { - acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); - } - for (int j = 0; j < 4; j++) { - acc[j] = RoundingDivideByPOT(acc[j], output_shift); + if (!shift_left) { + // Fixed-point multiplication. + for (int j = 0; j < 4; j++) { + acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); + } + for (int j = 0; j < 4; j++) { + acc[j] = RoundingDivideByPOT(acc[j], output_shift); + } + } else { + // Fixed-point multiplication. + for (int j = 0; j < 4; j++) { + acc[j] = vmulq_n_s32(acc[j], multiplier_power_of_two); + acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); + } } // Add the output offset. for (int j = 0; j < 4; j++) { @@ -1870,12 +1884,21 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, for (; i <= num_output_values - 8; i += 8) { int32x4_t acc0 = vld1q_s32(acc_buffer + i); int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4); - // Fixed-point multiplication. - acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); - acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); - // Rounding right shift. - acc0 = RoundingDivideByPOT(acc0, output_shift); - acc1 = RoundingDivideByPOT(acc1, output_shift); + if (!shift_left) { + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + // Rounding right shift. + acc0 = RoundingDivideByPOT(acc0, output_shift); + acc1 = RoundingDivideByPOT(acc1, output_shift); + } else { + // Fixed-point multiplication. + acc0 = vmulq_n_s32(acc0, multiplier_power_of_two); + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + + acc1 = vmulq_n_s32(acc1, multiplier_power_of_two); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + } // Add the output offset. acc0 = vaddq_s32(acc0, output_offset_vec); acc1 = vaddq_s32(acc1, output_offset_vec); @@ -1899,10 +1922,16 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // that will have to go through the very slow scalar code. for (; i <= num_output_values - 4; i += 4) { int32x4_t acc = vld1q_s32(acc_buffer + i); - // Fixed-point multiplication. - acc = vqrdmulhq_n_s32(acc, output_multiplier); - // Rounding right shift. - acc = RoundingDivideByPOT(acc, output_shift); + if (!shift_left) { + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + // Rounding right shift. + acc = RoundingDivideByPOT(acc, output_shift); + } else { + // Fixed-point multiplication. + acc = vmulq_n_s32(acc, multiplier_power_of_two); + acc = vqrdmulhq_n_s32(acc, output_multiplier); + } // Add the output offset. acc = vaddq_s32(acc, output_offset_vec); // Apply the activation function. @@ -1923,8 +1952,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // Handle leftover values, one by one. This is very slow. for (; i < num_output_values; i++) { int32 acc = acc_buffer[i]; - acc = MultiplyByQuantizedMultiplierSmallerThanOne( - acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, + -output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 55e0d5c3aa9ebb8b46403550e190b00a54cb53e5..0ce64f8c70d76f970df610f47947580a1efde720 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -23,3848 +23,2912 @@ limitations under the License. namespace tflite { namespace optimized_ops { -#ifdef __aarch64__ - -inline void preload_l1_keep(const uint8* ptr) { -#ifdef GEMMLOWP_ARM_64 - asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :); -#else - gemmlowp::Prefetch(ptr); -#endif -} - -// Implementation of quantized DepthwiseConv for 3x3 filters. - -// Below are helper structs to remove the use of arrays. -// There is an llvm bug that causes significant slowdown when using arrays for -// NEON intrinsics vector data types. -// See: https://bugs.llvm.org/show_bug.cgi?id=34945 - -struct Int32x8 { - int32x4_t low, high; -}; - -struct Filter3x3x8 { - int16x8_t f0, f1, f2, f3, f4, f5, f6, f7, f8; -}; - -// Loads 3x3 filter of depth 8 and adds filter offsets. -inline Filter3x3x8 Load3x3Filter(const uint8* filter_ptr, int32 filter_offset, - int output_depth) { - Filter3x3x8 filter; - - uint8x8_t temp_u8_0, temp_u8_1, temp_u8_2, temp_u8_3, temp_u8_4, temp_u8_5, - temp_u8_6, temp_u8_7, temp_u8_8; - int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset); - - temp_u8_0 = vld1_u8(filter_ptr + 0 * output_depth); - temp_u8_1 = vld1_u8(filter_ptr + 1 * output_depth); - temp_u8_2 = vld1_u8(filter_ptr + 2 * output_depth); - temp_u8_3 = vld1_u8(filter_ptr + 3 * output_depth); - temp_u8_4 = vld1_u8(filter_ptr + 4 * output_depth); - temp_u8_5 = vld1_u8(filter_ptr + 5 * output_depth); - temp_u8_6 = vld1_u8(filter_ptr + 6 * output_depth); - temp_u8_7 = vld1_u8(filter_ptr + 7 * output_depth); - temp_u8_8 = vld1_u8(filter_ptr + 8 * output_depth); - - filter.f0 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_0)); - filter.f1 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_1)); - filter.f2 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_2)); - filter.f3 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_3)); - filter.f4 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_4)); - filter.f5 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_5)); - filter.f6 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_6)); - filter.f7 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_7)); - filter.f8 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_8)); - - filter.f0 = vaddq_s16(filter.f0, filter_offset_vec); - filter.f1 = vaddq_s16(filter.f1, filter_offset_vec); - filter.f2 = vaddq_s16(filter.f2, filter_offset_vec); - filter.f3 = vaddq_s16(filter.f3, filter_offset_vec); - filter.f4 = vaddq_s16(filter.f4, filter_offset_vec); - filter.f5 = vaddq_s16(filter.f5, filter_offset_vec); - filter.f6 = vaddq_s16(filter.f6, filter_offset_vec); - filter.f7 = vaddq_s16(filter.f7, filter_offset_vec); - filter.f8 = vaddq_s16(filter.f8, filter_offset_vec); - - return filter; -} - -// Applies activation, offset and downquantize on a set of accumulator -// registers that correspond to a 2x2 output of depth 8. -// Stores results to output. -inline void DownquantizeAndStore2x2Output( - Int32x8 acc_0, Int32x8 acc_1, Int32x8 acc_2, Int32x8 acc_3, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - using gemmlowp::RoundingDivideByPOT; - const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); - const int32x4_t output_activation_min_vec = - vdupq_n_s32(output_activation_min); - const int32x4_t output_activation_max_vec = - vdupq_n_s32(output_activation_max); - - // Fixed-point multiplication. - acc_0.low = vqrdmulhq_n_s32(acc_0.low, output_multiplier); - acc_0.high = vqrdmulhq_n_s32(acc_0.high, output_multiplier); - acc_1.low = vqrdmulhq_n_s32(acc_1.low, output_multiplier); - acc_1.high = vqrdmulhq_n_s32(acc_1.high, output_multiplier); - acc_2.low = vqrdmulhq_n_s32(acc_2.low, output_multiplier); - acc_2.high = vqrdmulhq_n_s32(acc_2.high, output_multiplier); - acc_3.low = vqrdmulhq_n_s32(acc_3.low, output_multiplier); - acc_3.high = vqrdmulhq_n_s32(acc_3.high, output_multiplier); - - acc_0.low = RoundingDivideByPOT(acc_0.low, output_shift); - acc_0.high = RoundingDivideByPOT(acc_0.high, output_shift); - acc_1.low = RoundingDivideByPOT(acc_1.low, output_shift); - acc_1.high = RoundingDivideByPOT(acc_1.high, output_shift); - acc_2.low = RoundingDivideByPOT(acc_2.low, output_shift); - acc_2.high = RoundingDivideByPOT(acc_2.high, output_shift); - acc_3.low = RoundingDivideByPOT(acc_3.low, output_shift); - acc_3.high = RoundingDivideByPOT(acc_3.high, output_shift); - - // Add the output offset. - acc_0.low = vaddq_s32(acc_0.low, output_offset_vec); - acc_0.high = vaddq_s32(acc_0.high, output_offset_vec); - acc_1.low = vaddq_s32(acc_1.low, output_offset_vec); - acc_1.high = vaddq_s32(acc_1.high, output_offset_vec); - acc_2.low = vaddq_s32(acc_2.low, output_offset_vec); - acc_2.high = vaddq_s32(acc_2.high, output_offset_vec); - acc_3.low = vaddq_s32(acc_3.low, output_offset_vec); - acc_3.high = vaddq_s32(acc_3.high, output_offset_vec); - - // Apply the activation function. - acc_0.low = vmaxq_s32(acc_0.low, output_activation_min_vec); - acc_0.high = vmaxq_s32(acc_0.high, output_activation_min_vec); - acc_1.low = vmaxq_s32(acc_1.low, output_activation_min_vec); - acc_1.high = vmaxq_s32(acc_1.high, output_activation_min_vec); - acc_2.low = vmaxq_s32(acc_2.low, output_activation_min_vec); - acc_2.high = vmaxq_s32(acc_2.high, output_activation_min_vec); - acc_3.low = vmaxq_s32(acc_3.low, output_activation_min_vec); - acc_3.high = vmaxq_s32(acc_3.high, output_activation_min_vec); - - acc_0.low = vminq_s32(acc_0.low, output_activation_max_vec); - acc_0.high = vminq_s32(acc_0.high, output_activation_max_vec); - acc_1.low = vminq_s32(acc_1.low, output_activation_max_vec); - acc_1.high = vminq_s32(acc_1.high, output_activation_max_vec); - acc_2.low = vminq_s32(acc_2.low, output_activation_max_vec); - acc_2.high = vminq_s32(acc_2.high, output_activation_max_vec); - acc_3.low = vminq_s32(acc_3.low, output_activation_max_vec); - acc_3.high = vminq_s32(acc_3.high, output_activation_max_vec); - - // Saturating cast to uint8 and store to destination. - int16x4_t acc_0_low_s16 = vqmovn_s32(acc_0.low); - int16x4_t acc_0_high_s16 = vqmovn_s32(acc_0.high); - int16x4_t acc_1_low_s16 = vqmovn_s32(acc_1.low); - int16x4_t acc_1_high_s16 = vqmovn_s32(acc_1.high); - int16x4_t acc_2_low_s16 = vqmovn_s32(acc_2.low); - int16x4_t acc_2_high_s16 = vqmovn_s32(acc_2.high); - int16x4_t acc_3_low_s16 = vqmovn_s32(acc_3.low); - int16x4_t acc_3_high_s16 = vqmovn_s32(acc_3.high); - - int16x8_t res_0_s16 = vcombine_s16(acc_0_low_s16, acc_0_high_s16); - int16x8_t res_1_s16 = vcombine_s16(acc_1_low_s16, acc_1_high_s16); - int16x8_t res_2_s16 = vcombine_s16(acc_2_low_s16, acc_2_high_s16); - int16x8_t res_3_s16 = vcombine_s16(acc_3_low_s16, acc_3_high_s16); - - uint8x8_t res_0_u8 = vqmovun_s16(res_0_s16); - uint8x8_t res_1_u8 = vqmovun_s16(res_1_s16); - uint8x8_t res_2_u8 = vqmovun_s16(res_2_s16); - uint8x8_t res_3_u8 = vqmovun_s16(res_3_s16); - - vst1_u8(output_ptr, res_0_u8); - vst1_u8(output_ptr + output_depth, res_1_u8); - vst1_u8(output_ptr + output_depth * output_width, res_2_u8); - vst1_u8(output_ptr + output_depth * output_width + output_depth, res_3_u8); -} - -inline void DownquantizeAndStore(Int32x8 acc, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, - uint8* output_ptr) { - using gemmlowp::RoundingDivideByPOT; - const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); - const int32x4_t output_activation_min_vec = - vdupq_n_s32(output_activation_min); - const int32x4_t output_activation_max_vec = - vdupq_n_s32(output_activation_max); - - acc.low = vqrdmulhq_n_s32(acc.low, output_multiplier); - acc.high = vqrdmulhq_n_s32(acc.high, output_multiplier); - - acc.low = RoundingDivideByPOT(acc.low, output_shift); - acc.high = RoundingDivideByPOT(acc.high, output_shift); - - acc.low = vaddq_s32(acc.low, output_offset_vec); - acc.high = vaddq_s32(acc.high, output_offset_vec); - - acc.low = vmaxq_s32(acc.low, output_activation_min_vec); - acc.high = vmaxq_s32(acc.high, output_activation_min_vec); - - acc.low = vminq_s32(acc.low, output_activation_max_vec); - acc.high = vminq_s32(acc.high, output_activation_max_vec); - - int16x4_t acc_low_s16 = vqmovn_s32(acc.low); - int16x4_t acc_high_s16 = vqmovn_s32(acc.high); - - int16x8_t res_s16 = vcombine_s16(acc_low_s16, acc_high_s16); - uint8x8_t res_u8 = vqmovun_s16(res_s16); - vst1_u8(output_ptr, res_u8); -} +// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on +// Jetson TX-2. This compiler does not support the offsetof() macro. +#if defined(__aarch64__) && !defined(GOOGLE_L4T) +#include +// clang-format gets confused with this file and ends up formatting lines to +// be larger than 80 characters. Turn off here and back on at the end of the +// file. -inline void DownquantizeAndStore2Output( - Int32x8 acc_0, Int32x8 acc_1, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_ptr, int output_ptr_offset) { - { - using gemmlowp::RoundingDivideByPOT; - const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); - const int32x4_t output_activation_min_vec = - vdupq_n_s32(output_activation_min); - const int32x4_t output_activation_max_vec = - vdupq_n_s32(output_activation_max); - - // Fixed-point multiplication. - acc_0.low = vqrdmulhq_n_s32(acc_0.low, output_multiplier); - acc_0.high = vqrdmulhq_n_s32(acc_0.high, output_multiplier); - acc_1.low = vqrdmulhq_n_s32(acc_1.low, output_multiplier); - acc_1.high = vqrdmulhq_n_s32(acc_1.high, output_multiplier); - - acc_0.low = RoundingDivideByPOT(acc_0.low, output_shift); - acc_0.high = RoundingDivideByPOT(acc_0.high, output_shift); - acc_1.low = RoundingDivideByPOT(acc_1.low, output_shift); - acc_1.high = RoundingDivideByPOT(acc_1.high, output_shift); - - // Add the output offset. - acc_0.low = vaddq_s32(acc_0.low, output_offset_vec); - acc_0.high = vaddq_s32(acc_0.high, output_offset_vec); - acc_1.low = vaddq_s32(acc_1.low, output_offset_vec); - acc_1.high = vaddq_s32(acc_1.high, output_offset_vec); - - // Apply the activation function. - acc_0.low = vmaxq_s32(acc_0.low, output_activation_min_vec); - acc_0.high = vmaxq_s32(acc_0.high, output_activation_min_vec); - acc_1.low = vmaxq_s32(acc_1.low, output_activation_min_vec); - acc_1.high = vmaxq_s32(acc_1.high, output_activation_min_vec); - - acc_0.low = vminq_s32(acc_0.low, output_activation_max_vec); - acc_0.high = vminq_s32(acc_0.high, output_activation_max_vec); - acc_1.low = vminq_s32(acc_1.low, output_activation_max_vec); - acc_1.high = vminq_s32(acc_1.high, output_activation_max_vec); - } - - // Saturating cast to uint8 and store to destination. - int16x8_t res_0_s16; - { - int16x4_t acc_0_low_s16 = vqmovn_s32(acc_0.low); - int16x4_t acc_0_high_s16 = vqmovn_s32(acc_0.high); - res_0_s16 = vcombine_s16(acc_0_low_s16, acc_0_high_s16); - } - - int16x8_t res_1_s16; - { - int16x4_t acc_1_low_s16 = vqmovn_s32(acc_1.low); - int16x4_t acc_1_high_s16 = vqmovn_s32(acc_1.high); - res_1_s16 = vcombine_s16(acc_1_low_s16, acc_1_high_s16); - } - - uint8x8_t res_0_u8 = vqmovun_s16(res_0_s16); - uint8x8_t res_1_u8 = vqmovun_s16(res_1_s16); - vst1_u8(output_ptr, res_0_u8); - vst1_u8(output_ptr + output_ptr_offset, res_1_u8); -} - -// Performs multiply accumulate on 3 inputs of depth 8. -inline Int32x8 MultiplyAccumulateRow(Int32x8 accum, int16x8_t f0, int16x8_t f1, - int16x8_t f2, int16x8_t i0, int16x8_t i1, - int16x8_t i2) { - accum.low = vmlal_s16(accum.low, vget_low_s16(f0), vget_low_s16(i0)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f0), vget_high_s16(i0)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f1), vget_low_s16(i1)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f1), vget_high_s16(i1)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f2), vget_low_s16(i2)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f2), vget_high_s16(i2)); - return accum; -} - -// Performs multiply accumulate on 3 inputs of depth 8. -inline Int32x8 MultiplyAccumulate3x3Filter(const Filter3x3x8& f, int16x8_t i0, - int16x8_t i1, int16x8_t i2, - int16x8_t i3, int16x8_t i4, - int16x8_t i5, int16x8_t i6, - int16x8_t i7, int16x8_t i8, - Int32x8 accum) { - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f0), vget_low_s16(i0)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f0), vget_high_s16(i0)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f1), vget_low_s16(i1)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f1), vget_high_s16(i1)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f2), vget_low_s16(i2)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f2), vget_high_s16(i2)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f3), vget_low_s16(i3)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f3), vget_high_s16(i3)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f4), vget_low_s16(i4)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f4), vget_high_s16(i4)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f5), vget_low_s16(i5)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f5), vget_high_s16(i5)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f6), vget_low_s16(i6)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f6), vget_high_s16(i6)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f7), vget_low_s16(i7)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f7), vget_high_s16(i7)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f8), vget_low_s16(i8)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f8), vget_high_s16(i8)); - return accum; -} - -inline void DotProductAndStore(const Filter3x3x8& filter, int16x8_t i0, - int16x8_t i1, int16x8_t i2, int16x8_t i3, - int16x8_t i4, int16x8_t i5, int16x8_t i6, - int16x8_t i7, int16x8_t i8, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr) { - Int32x8 acc; - acc.low = vld1q_s32(bias_ptr); - acc.high = vld1q_s32(bias_ptr + 4); - - acc = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i3, i4, i5, i6, i7, i8, - acc); - - DownquantizeAndStore(acc, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, - output_ptr); -} - -// Performs multiply-accumulate on a 3x4 input for 2 horizontal outputs. -inline void DotProductAndStore2xStride1( - const Filter3x3x8& filter, int16x8_t i0, int16x8_t i1, int16x8_t i2, - int16x8_t i3, int16x8_t i4, int16x8_t i5, int16x8_t i6, int16x8_t i7, - int16x8_t i8, int16x8_t i9, int16x8_t i10, int16x8_t i11, - const int32* bias_ptr, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_ptr, int output_ptr_offset) { - Int32x8 acc_0, acc_1; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i4, i5, i6, i8, i9, - i10, acc_0); - acc_1 = MultiplyAccumulate3x3Filter(filter, i1, i2, i3, i5, i6, i7, i9, i10, - i11, acc_1); - DownquantizeAndStore2Output(acc_0, acc_1, output_offset, output_multiplier, - output_shift, output_activation_min, - output_activation_max, output_ptr, - output_ptr_offset); -} - -// Performs multiply-accumulate on a 4x3 input for 2 vertical outputs. -inline void DotProductAndStore2yStride1( - const Filter3x3x8& filter, int16x8_t i0, int16x8_t i1, int16x8_t i2, - int16x8_t i3, int16x8_t i4, int16x8_t i5, int16x8_t i6, int16x8_t i7, - int16x8_t i8, int16x8_t i9, int16x8_t i10, int16x8_t i11, - const int32* bias_ptr, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_ptr, int output_ptr_offset) { - Int32x8 acc_0, acc_1; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i3, i4, i5, i6, i7, - i8, acc_0); - acc_1 = MultiplyAccumulate3x3Filter(filter, i3, i4, i5, i6, i7, i8, i9, i10, - i11, acc_1); - DownquantizeAndStore2Output(acc_0, acc_1, output_offset, output_multiplier, - output_shift, output_activation_min, - output_activation_max, output_ptr, - output_ptr_offset); -} - -// A kernel that is optimized on the number of output cells in the x and y -// direction, and the stride. Assumes 3x3 filters of 8 depth. -template -struct ConvKernel3x3FilterDepth8 {}; - -template <> -struct ConvKernel3x3FilterDepth8<8, 8, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - // To process 8x8 outputs using a 3x3 filter, we require 10x10 inputs. - // Load inputs for the first 2 filters on the top left, then slide to - // the right, down, left, down, right, etc. in a snake-like path. This - // minimizes the total number of loads. - // - // INPUT OUTPUT - // |\----------------\ |\------------\ - // | \ \ | \ \ - // | \----------------\ | \------------\ - // | | 0 ... 9 | | | 0 ... 7 | - // | | 10 ... 19 | ---> | | 8 ... 15 | - // | | 20 ... 29 | \ | .. ... .. | - // \ | .. ... .. | \| 56 ... 63 | - // \| 90 ... 109 | |------------| - // |----------------| - // - // The first set of loads corresponds to: - // - // INPUT OUTPUT - // |\----------------- |\----------- - // | \ | \ - // | \----------------- | \---------- - // | | 0 1 2 3 ... | | 0 1 ... - // | | 10 11 12 13 ... ---> | | .. ... - // | | 20 21 22 23 ... | .. ... - // | | .. ... ... - // - // The next set of loads correspond to a sliding window to the right. - // It loads inputs 4, 5, 14, 15, 23, 24 and keeps 2, 3, 12, 13, and 22: - // - // INPUT OUTPUT - // |\------------------- |\------------- - // | \ | \ - // | \------------------- | \------------ - // | | .. 2 3 4 5 ... | | .. 2 3 ... - // | | .. 12 13 14 15 ... ---> | | .. ... - // | | .. 21 22 23 24 ... | .. ... - // | | .. ... ... - // - // And so on... - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. Referring to the - // indexes in the diagram above, this corresponds to outputs (0) and (1). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Slide to the right for outputs x = [2, 3], y = 0. Referring to the - // indexes in the diagram above, this corresponds to outputs (2) and (3). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - - // Slide to the right again for outputs x = [4, 5], y = 0. Referring to the - // indexes in the diagram above, this corresponds to outputs (4) and (5). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 4 * output_depth, output_depth); - - // Slide to the right one last time for outputs x = [6, 7], y = 0. - // Referring to the indexes in the diagram above, this corresponds to - // outputs (6) and (7). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 6 * output_depth, output_depth); - - // Slide to down for outputs x = [6, 7], y = 1. Referring to the indexes in - // the diagram above, this corresponds to outputs (14) and (15). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 6 * output_depth + output_row_size, - output_depth); - - // Slide left for outputs x = [4, 5], y = 1. Referring to the indexes in - // the diagram above, this corresponds to outputs (12) and (13). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 4 * output_depth + output_row_size, - output_depth); - - // Slide left again for outputs x = [2, 3], y = 1. Referring to the indexes - // in the diagram above, this corresponds to outputs (10) and (11). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth + output_row_size, - output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 1. Referring to the - // indexes in the diagram above, this corresponds to outputs (8) and (9). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + output_row_size, output_depth); - - // Slide down for outputs x = [0, 1], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (16) and (17). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_row_size, output_depth); - - // Slide right for outputs x = [2, 3], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (18) and (19). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 2 * output_row_size, output_depth); - - // Slide right for outputs x = [4, 5], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (20) and (21). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 2 * output_row_size, output_depth); - - // Slide right one more time for outputs x = [6, 7], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (22) and (23). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 2 * output_row_size, output_depth); - - // Slide down for outputs x = [6, 7], y = 3. Referring to the indexes in - // the diagram above, this corresponds to outputs (30) and (31). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 3 * output_row_size, output_depth); - - // Slide left for outputs x = [4, 5], y = 3. Referring to the indexes in - // the diagram above, this corresponds to outputs (28) and (29). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 3 * output_row_size, output_depth); - - // Slide left for outputs x = [2, 3], y = 3. Referring to the indexes in - // the diagram above, this corresponds to outputs (26) and (27). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 3 * output_row_size, output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 3. Referring to the - // indexes in the diagram above, this corresponds to outputs (24) and (25). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 3 * output_row_size, output_depth); - - // Slide down for outputs x = [0, 1], y = 4. Referring to the indexes in - // the diagram above, this corresponds to outputs (32) and (33). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 4 * output_row_size, output_depth); - - // Slide right for outputs x = [2, 3], y = 4. Referring to the indexes in - // the diagram above, this corresponds to outputs (34) and (35). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 4 * output_row_size, output_depth); - - // Slide right for outputs x = [4, 5], y = 4. Referring to the indexes in - // the diagram above, this corresponds to outputs (36) and (37). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 4 * output_row_size, output_depth); - - // Slide right one more time for outputs x = [6, 7], y = 4. Referring to the - // indexes in the diagram above, this corresponds to outputs (38) and (39). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 4 * output_row_size, output_depth); - - // Slide down for outputs x = [6, 7], y = 5. Referring to the indexes in - // the diagram above, this corresponds to outputs (46) and (47). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 5 * output_row_size, output_depth); - - // Slide left for outputs x = [4, 5], y = 5. Referring to the indexes in - // the diagram above, this corresponds to outputs (44) and (45). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 5 * output_row_size, output_depth); - - // Slide left for outputs x = [2, 3], y = 5. Referring to the indexes in - // the diagram above, this corresponds to outputs (42) and (43). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 5 * output_row_size, output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 5. Referring to the - // indexes in the diagram above, this corresponds to outputs (40) and (41). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 5 * output_row_size, output_depth); - - // Slide down for outputs x = [0, 1], y = 6. Referring to the indexes in - // the diagram above, this corresponds to outputs (48) and (49). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 8 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 6 * output_row_size, output_depth); - - // Slide right for outputs x = [2, 3], y = 6. Referring to the indexes in - // the diagram above, this corresponds to outputs (50) and (51). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 6 * output_row_size, output_depth); - - // Slide right for outputs x = [4, 5], y = 6. Referring to the indexes in - // the diagram above, this corresponds to outputs (52) and (53). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 6 * output_row_size, output_depth); - - // Slide right one more time for outputs x = [6, 7], y = 6. Referring to the - // indexes in the diagram above, this corresponds to outputs (54) and (55). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 6 * output_row_size, output_depth); - - // Slide down for outputs x = [6, 7], y = 7. Referring to the indexes in the - // diagram above, this corresponds to outputs (62) and (63). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 9 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 7 * output_row_size, output_depth); - - // Slide left for outputs x = [4, 5], y = 7. Referring to the indexes in the - // diagram above, this corresponds to outputs (60) and (61). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 7 * output_row_size, output_depth); - - // Slide left for outputs x = [2, 3], y = 7. Referring to the indexes in the - // diagram above, this corresponds to outputs (58) and (59). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 7 * output_row_size, output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 7. Referring to the - // indexes in the diagram above, this corresponds to outputs (56) and (57). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 7 * output_row_size, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<4, 4, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - // To process 4x4 outputs using a 3x3 filter, we require 6x6 inputs. - // Load inputs for the first 2 filters on the top left, then slide to - // the right, down, left, down, right, etc. in a snake-like path. This - // minimizes the total number of loads. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Now load 1x2 inputs on the top right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - - // Now load next inputs when sliding window down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth + output_row_size, - output_depth); - - // Now load next inputs when sliding window left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + output_row_size, output_depth); - - // Now load next inputs when sliding window down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_row_size, output_depth); - - // Now load next inputs when sliding window right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 2 * output_row_size, output_depth); - - // Now load next inputs when sliding window down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 2 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 3 * output_row_size, output_depth); - - // Now load next inputs when sliding window left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 3 * output_row_size, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<4, 2, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Now load next inputs one row down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Now load next row. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Now load last row. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<4, 1, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 2x1 outputs starting from the top. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2yStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_row_size); - - // Load inputs for bottom 2 rows. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2yStride1( - filter, input_6, input_7, input_8, input_9, input_10, input_11, input_0, - input_1, input_2, input_3, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_row_size, - output_row_size); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<2, 2, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - Int32x8 acc_0, acc_1, acc_2, acc_3; - - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_2.low = vld1q_s32(bias_ptr); - acc_3.low = vld1q_s32(bias_ptr); - - bias_ptr += 4; - acc_0.high = vld1q_s32(bias_ptr); - acc_1.high = vld1q_s32(bias_ptr); - acc_2.high = vld1q_s32(bias_ptr); - acc_3.high = vld1q_s32(bias_ptr); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - // Add scope for input registers to help the compiler know that it is - // not needed. - { - // To process 2x2 outputs using a 3x3 filter, we require 4x4 inputs. - // Load inputs for the top two filters first. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - const uint8* ptr = input_ptr; - - // Load top 3 rows. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - // Multiply-accum for top-left output. - acc_0 = MultiplyAccumulate3x3Filter(filter, input_0, input_1, input_2, - input_4, input_5, input_6, input_8, - input_9, input_10, acc_0); - - // Multiply-accum for top-right output. - acc_1 = MultiplyAccumulate3x3Filter(filter, input_1, input_2, input_3, - input_5, input_6, input_7, input_9, - input_10, input_11, acc_1); - - // Now load the bottom row. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } +// clang-format off - // Multiply-accum for bottom-left output. - acc_2 = MultiplyAccumulate3x3Filter(filter, input_4, input_5, input_6, - input_8, input_9, input_10, input_0, - input_1, input_2, acc_2); - - // Multiply-accum for bottom-right output. - acc_3 = MultiplyAccumulate3x3Filter(filter, input_5, input_6, input_7, - input_9, input_10, input_11, input_1, - input_2, input_3, acc_3); - } - - DownquantizeAndStore2x2Output(acc_0, acc_1, acc_2, acc_3, output_offset, - output_multiplier, output_shift, - output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<2, 4, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Now load 1x2 inputs on the top right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - - // Now load next inputs when sliding window down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth + output_row_size, - output_depth); - - // Now load next inputs when sliding window left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + output_row_size, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<1, 4, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Now load 1x2 inputs on the right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_depth * 4; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } +#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE 10 * 10 * 64 - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - } +// Encapsulates constant parameters used in DepthwiseConv. +// 64-bit is used for types that will be added to 64-bit addresses in asm. +struct DepthwiseConvParams { + int64_t input_depth; + int64_t input_row_size; + int64_t output_depth; + int64_t output_row_size; + int64_t filter_row_size; + int32 input_offset; + int32 output_offset; + int32 filter_offset; + int32 output_multiplier; + int32 output_activation_min; + int32 output_activation_max; + int32 output_shift; + int32 input_width; + int32 input_height; + int32 stride_width; + int32 stride_height; + int32 output_width; + int32 output_height; }; -template <> -struct ConvKernel3x3FilterDepth8<2, 1, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - // To process 2x1 outputs using a 3x3 filter, we require 4x3 inputs. - // Load all inputs at the beginning. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. - { - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2yStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth * output_width); - } -}; +#define STR(s) STR_UNEXPANDED(s) +#define STR_UNEXPANDED(s) #s + +// Represents the number of bytes offset from the start of the +// DepthwiseConvParams struct. This is used in the asm to load parameters. +// Keep these values in sync with the static_asserts below. +#define OFFSET_INPUT_DEPTH 0 +#define OFFSET_INPUT_ROW_SIZE 8 +#define OFFSET_OUTPUT_DEPTH 16 +#define OFFSET_OUTPUT_ROW_SIZE 24 +#define OFFSET_FILTER_ROW_SIZE 32 +#define OFFSET_INPUT_OFFSET 40 +#define OFFSET_OUTPUT_OFFSET 44 +#define OFFSET_FILTER_OFFSET 48 +#define OFFSET_OUTPUT_MULTIPLIER 52 +#define OFFSET_OUTPUT_ACTIVATION_MIN 56 +#define OFFSET_OUTPUT_ACTIVATION_MAX 60 +#define OFFSET_OUTPUT_SHIFT 64 +#define OFFSET_INPUT_WIDTH 68 +#define OFFSET_INPUT_HEIGHT 72 +#define OFFSET_STRIDE_WIDTH 76 +#define OFFSET_STRIDE_HEIGHT 80 +#define OFFSET_OUTPUT_WIDTH 84 +#define OFFSET_OUTPUT_HEIGHT 88 + +static_assert(offsetof(DepthwiseConvParams, input_depth) == + OFFSET_INPUT_DEPTH, ""); +static_assert(offsetof(DepthwiseConvParams, input_row_size) == + OFFSET_INPUT_ROW_SIZE, ""); +static_assert(offsetof(DepthwiseConvParams, output_depth) == + OFFSET_OUTPUT_DEPTH, ""); +static_assert(offsetof(DepthwiseConvParams, output_row_size) == + OFFSET_OUTPUT_ROW_SIZE, ""); +static_assert(offsetof(DepthwiseConvParams, filter_row_size) == + OFFSET_FILTER_ROW_SIZE, ""); +static_assert(offsetof(DepthwiseConvParams, input_offset) == + OFFSET_INPUT_OFFSET, ""); +static_assert(offsetof(DepthwiseConvParams, output_offset) == + OFFSET_OUTPUT_OFFSET, ""); +static_assert(offsetof(DepthwiseConvParams, filter_offset) == + OFFSET_FILTER_OFFSET, ""); +static_assert(offsetof(DepthwiseConvParams, output_multiplier) == + OFFSET_OUTPUT_MULTIPLIER, ""); +static_assert(offsetof(DepthwiseConvParams, output_activation_min) == + OFFSET_OUTPUT_ACTIVATION_MIN, ""); +static_assert(offsetof(DepthwiseConvParams, output_activation_max) == + OFFSET_OUTPUT_ACTIVATION_MAX, ""); +static_assert(offsetof(DepthwiseConvParams, output_shift) == + OFFSET_OUTPUT_SHIFT, ""); +static_assert(offsetof(DepthwiseConvParams, input_width) == + OFFSET_INPUT_WIDTH, ""); +static_assert(offsetof(DepthwiseConvParams, input_height) == + OFFSET_INPUT_HEIGHT, ""); +static_assert(offsetof(DepthwiseConvParams, stride_width) == + OFFSET_STRIDE_WIDTH, ""); +static_assert(offsetof(DepthwiseConvParams, stride_height) == + OFFSET_STRIDE_HEIGHT, ""); +static_assert(offsetof(DepthwiseConvParams, output_width) == + OFFSET_OUTPUT_WIDTH, ""); +static_assert(offsetof(DepthwiseConvParams, output_height) == + OFFSET_OUTPUT_HEIGHT, ""); + +template +struct DepthwiseConvWindow {}; template <> -struct ConvKernel3x3FilterDepth8<4, 2, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - const int output_row_size = output_depth * output_width; - - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - Int32x8 acc_0, acc_1; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9; - - const uint8* ptr = input_ptr; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; - - // Load first 2 rows. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next 2 rows. - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Moving onto the next row of outputs. - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next 2 rows. - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Moving onto the next row of outputs. - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next 2 rows. - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Moving onto the next row of outputs. - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load last row. - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); +struct DepthwiseConvWindow<8, 1, 1> { + public: + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, int64_t input_depth, + int64_t input_row_size, int32 output_window_height, + int32 output_window_width, + const DepthwiseConvParams* params_ptr) { + const int64_t input_width_increment = 2 * input_depth; + const int64_t input_height_increment = 2 * input_row_size; + const int64_t output_height_increment = 2 * params_ptr->output_row_size; + +#define DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "1" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "2" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "3" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "4" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "5" +#define DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "6" +#define DEPTHWISECONV_LABEL_HEIGHT_1 "7" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "8" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "9" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "10" +#define DEPTHWISECONV_LABEL_HEIGHT_1_END "11" + + asm volatile( + // Performs depthwise convolutions for a window specified by + // |output_window_height| and |output_window_width|. The inner-most loop + // processes 2x2 outputs, and any leftovers at the end. + // + // Algorithm works as follows: + // + // 1. Load filters of 8 depth (8x3x3). Registers v0--v8 hold filter + // values. + // 2. For 2 output heights at a time: + // i. For 2 output widths at a time, load inputs for a 2x1 (2 + // height, 1 width) output window (4x3 input window). + // Registers v9--v20 hold input values. Mul-add with + // accumulators v21--v24. Then run activation, downquantize + // and store. Repeat for the next 2x1 output window, + // leveraging overlapping inputs. + // ii. Handle single leftover width if exists. + // 3. Handle single leftover height if exists. + // i. For 2 output widths at a time, load inputs for a 1x2 (1 + // height, 2 width) output window (3x4 input window). + // Registers v9--v20 hold input values. Mul-add with + // accumulators v21--v24. Then run activation, downquantize + // and store. Repeat for the next 1x2 output window, + // leveraging overlapping inputs. + // ii. Handle single leftover width if exists. + // + // Loads are placed as soon as the register is no longer needed and + // interleaved with arithmetic operations to take advantage of + // dual-issue pipelines. We also add input offsets as far from the loads + // as possible to give loads enough cycles to fetch data from memory. + + // Set "constant" registers. These registers may be replaced with temp + // values from time to time when there are not enough NEON registers. + // We use x9--x15 general purpose registers as they are caller-saved + // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT + "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr x3, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "cmp %w[output_window_height], #2\n" + "dup v26.8h, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v29.4s, w2\n" + "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "dup v30.4s, w4\n" + "ldr w0, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v31.4s, w0\n" + "neg w9, w9\n" + "dup v28.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "add x10, %[bias_ptr], #16\n" + "ldr x1, [%[params_ptr], #" STR(OFFSET_OUTPUT_ROW_SIZE) "]\n" + "dup v9.8h, w9\n" + + // Load filters and add offsets. + "ld1 {v0.8b}, [%[filter_ptr]], x3\n" + "ld1 {v1.8b}, [%[filter_ptr]], x3\n" + "uaddw v0.8h, v9.8h, v0.8b\n" + "ld1 {v2.8b}, [%[filter_ptr]], x3\n" + "uaddw v1.8h, v9.8h, v1.8b\n" + "ld1 {v3.8b}, [%[filter_ptr]], x3\n" + "uaddw v2.8h, v9.8h, v2.8b\n" + "ld1 {v4.8b}, [%[filter_ptr]], x3\n" + "uaddw v3.8h, v9.8h, v3.8b\n" + "ld1 {v5.8b}, [%[filter_ptr]], x3\n" + "uaddw v4.8h, v9.8h, v4.8b\n" + "ld1 {v6.8b}, [%[filter_ptr]], x3\n" + "uaddw v5.8h, v9.8h, v5.8b\n" + "ld1 {v7.8b}, [%[filter_ptr]], x3\n" + "uaddw v6.8h, v9.8h, v6.8b\n" + "ld1 {v8.8b}, [%[filter_ptr]], x3\n" + "uaddw v7.8h, v9.8h, v7.8b\n" + "uaddw v8.8h, v9.8h, v8.8b\n" + + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_LOOP ":\n" + // This loop processes 2x2 outputs. To avoid register exhaustion, + // inputs for the left 2 outputs are loaded first, then the right + // two outputs. + "mov x11, %[input_ptr]\n" + "mov x12, x11\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "add x13, x11, %[input_row_size]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "add x14, x13, %[input_row_size]\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x15, x14, %[input_row_size]\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "mov w5, %w[output_window_width]\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "add x7, %[output_ptr], x1\n" + "ld1 {v15.8b}, [x14], %[input_depth]\n" + // The height 2 / width 2 loop loads an extra 2x1 outputs (2 height, + // 1 width) in anticipation for the next iteration. Make sure + // |output_window_width| is large enough to handle the additional + // loads, otherwise jump to specific the appropriate label to handle + // smaller widths. + "cmp w5, #2\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v16.8b}, [x14], %[input_depth]\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "ld1 {v18.8b}, [x15], %[input_depth]\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "ld1 {v19.8b}, [x15], %[input_depth]\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "ld1 {v20.8b}, [x15], %[input_depth]\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v22.4s}, [x10]\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "ld1 {v24.4s}, [x10]\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "f\n" + "cmp w5, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP ":\n" + // Mul-add left outputs. + "smlal v21.4s, v0.4h, v9.4h\n" + "subs w5, w5, #2\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "cmp w5, #3\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "ld1 {v9.8b}, [x12]\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v14.4h\n" + "smlal2 v24.4s, v2.8h, v14.8h\n" + "smlal v21.4s, v3.4h, v12.4h\n" + "smlal2 v22.4s, v3.8h, v12.8h\n" + "ld1 {v12.8b}, [x13]\n" + "smlal v23.4s, v3.4h, v15.4h\n" + "smlal2 v24.4s, v3.8h, v15.8h\n" + "smlal v21.4s, v4.4h, v13.4h\n" + "smlal2 v22.4s, v4.8h, v13.8h\n" + "smlal v23.4s, v4.4h, v16.4h\n" + "smlal2 v24.4s, v4.8h, v16.8h\n" + "smlal v21.4s, v5.4h, v14.4h\n" + "smlal2 v22.4s, v5.8h, v14.8h\n" + "smlal v23.4s, v5.4h, v17.4h\n" + "smlal2 v24.4s, v5.8h, v17.8h\n" + "smlal v21.4s, v6.4h, v15.4h\n" + "smlal2 v22.4s, v6.8h, v15.8h\n" + "ld1 {v15.8b}, [x14]\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "ld1 {v18.8b}, [x15]\n" + "smlal v21.4s, v7.4h, v16.4h\n" + "smlal2 v22.4s, v7.8h, v16.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x3\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "st1 {v23.8b}, [x7], x3\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + + // Mul-add right outputs. + "smlal v21.4s, v0.4h, v10.4h\n" + "add x11, x11, %[input_width_increment]\n" + "smlal2 v22.4s, v0.8h, v10.8h\n" + "mov x12, x11\n" + "smlal v23.4s, v0.4h, v13.4h\n" + "add x13, x11, %[input_row_size]\n" + "smlal2 v24.4s, v0.8h, v13.8h\n" + "add x14, x13, %[input_row_size]\n" + "smlal v21.4s, v1.4h, v11.4h\n" + "add x15, x14, %[input_row_size]\n" + "smlal2 v22.4s, v1.8h, v11.8h\n" + "smlal v23.4s, v1.4h, v14.4h\n" + "smlal2 v24.4s, v1.8h, v14.8h\n" + "smlal v21.4s, v2.4h, v9.4h\n" + "smlal2 v22.4s, v2.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "smlal v21.4s, v5.4h, v12.4h\n" + "smlal2 v22.4s, v5.8h, v12.8h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v5.4h, v15.4h\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v5.8h, v15.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v21.4s, v6.4h, v16.4h\n" + "smlal2 v22.4s, v6.8h, v16.8h\n" + "smlal v23.4s, v6.4h, v19.4h\n" + "smlal2 v24.4s, v6.8h, v19.8h\n" + "smlal v21.4s, v7.4h, v17.4h\n" + "smlal2 v22.4s, v7.8h, v17.8h\n" + "smlal v23.4s, v7.4h, v20.4h\n" + "smlal2 v24.4s, v7.8h, v20.8h\n" + "smlal v21.4s, v8.4h, v15.4h\n" + "smlal2 v22.4s, v8.8h, v15.8h\n" + "ld1 {v15.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v8.4h, v18.4h\n" + "ld1 {v16.8b}, [x14], %[input_depth]\n" + "smlal2 v24.4s, v8.8h, v18.8h\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "ld1 {v18.8b}, [x15], %[input_depth]\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "ld1 {v19.8b}, [x15], %[input_depth]\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "ld1 {v20.8b}, [x15], %[input_depth]\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x3\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "st1 {v23.8b}, [x7], x3\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w5, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + // Handle last 2 columns if exists. + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER ":\n" + // Mul-add left outputs. + "smlal v21.4s, v0.4h, v9.4h\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "ld1 {v9.8b}, [x12]\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v14.4h\n" + "smlal2 v24.4s, v2.8h, v14.8h\n" + "smlal v21.4s, v3.4h, v12.4h\n" + "smlal2 v22.4s, v3.8h, v12.8h\n" + "ld1 {v12.8b}, [x13]\n" + "smlal v23.4s, v3.4h, v15.4h\n" + "smlal2 v24.4s, v3.8h, v15.8h\n" + "smlal v21.4s, v4.4h, v13.4h\n" + "smlal2 v22.4s, v4.8h, v13.8h\n" + "smlal v23.4s, v4.4h, v16.4h\n" + "smlal2 v24.4s, v4.8h, v16.8h\n" + "smlal v21.4s, v5.4h, v14.4h\n" + "smlal2 v22.4s, v5.8h, v14.8h\n" + "smlal v23.4s, v5.4h, v17.4h\n" + "smlal2 v24.4s, v5.8h, v17.8h\n" + "smlal v21.4s, v6.4h, v15.4h\n" + "smlal2 v22.4s, v6.8h, v15.8h\n" + "ld1 {v15.8b}, [x14]\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "ld1 {v18.8b}, [x15]\n" + "smlal v21.4s, v7.4h, v16.4h\n" + "smlal2 v22.4s, v7.8h, v16.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x3\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "st1 {v23.8b}, [x7], x3\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + + // Mul-add right outputs. + "smlal v21.4s, v0.4h, v10.4h\n" + "smlal2 v22.4s, v0.8h, v10.8h\n" + "smlal v23.4s, v0.4h, v13.4h\n" + "smlal2 v24.4s, v0.8h, v13.8h\n" + "smlal v21.4s, v1.4h, v11.4h\n" + "smlal2 v22.4s, v1.8h, v11.8h\n" + "smlal v23.4s, v1.4h, v14.4h\n" + "smlal2 v24.4s, v1.8h, v14.8h\n" + "smlal v21.4s, v2.4h, v9.4h\n" + "smlal2 v22.4s, v2.8h, v9.8h\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "smlal v21.4s, v5.4h, v12.4h\n" + "smlal2 v22.4s, v5.8h, v12.8h\n" + "smlal v23.4s, v5.4h, v15.4h\n" + "smlal2 v24.4s, v5.8h, v15.8h\n" + "smlal v21.4s, v6.4h, v16.4h\n" + "smlal2 v22.4s, v6.8h, v16.8h\n" + "smlal v23.4s, v6.4h, v19.4h\n" + "smlal2 v24.4s, v6.8h, v19.8h\n" + "smlal v21.4s, v7.4h, v17.4h\n" + "smlal2 v22.4s, v7.8h, v17.8h\n" + "smlal v23.4s, v7.4h, v20.4h\n" + "smlal2 v24.4s, v7.8h, v20.8h\n" + "smlal v21.4s, v8.4h, v15.4h\n" + "smlal2 v22.4s, v8.8h, v15.8h\n" + "smlal v23.4s, v8.4h, v18.4h\n" + "smlal2 v24.4s, v8.8h, v18.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v21.8b}, [x6], x3\n" + "st1 {v23.8b}, [x7], x3\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "f\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v14.4h\n" + "smlal2 v24.4s, v2.8h, v14.8h\n" + "smlal v21.4s, v3.4h, v12.4h\n" + "smlal2 v22.4s, v3.8h, v12.8h\n" + "smlal v23.4s, v3.4h, v15.4h\n" + "smlal2 v24.4s, v3.8h, v15.8h\n" + "smlal v21.4s, v4.4h, v13.4h\n" + "smlal2 v22.4s, v4.8h, v13.8h\n" + "smlal v23.4s, v4.4h, v16.4h\n" + "smlal2 v24.4s, v4.8h, v16.8h\n" + "smlal v21.4s, v5.4h, v14.4h\n" + "smlal2 v22.4s, v5.8h, v14.8h\n" + "smlal v23.4s, v5.4h, v17.4h\n" + "smlal2 v24.4s, v5.8h, v17.8h\n" + "smlal v21.4s, v6.4h, v15.4h\n" + "smlal2 v22.4s, v6.8h, v15.8h\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "smlal v21.4s, v7.4h, v16.4h\n" + "smlal2 v22.4s, v7.8h, v16.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v9.16b, v21.16b, v28.16b\n" + "and v12.16b, v22.16b, v28.16b\n" + "and v15.16b, v23.16b, v28.16b\n" + "and v18.16b, v24.16b, v28.16b\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v12.4s, v12.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sshr v18.4s, v18.4s, #31\n" + "sqadd v21.4s, v21.4s, v9.4s\n" + "sqadd v22.4s, v22.4s, v12.4s\n" + "sqadd v23.4s, v23.4s, v15.4s\n" + "sqadd v24.4s, v24.4s, v18.4s\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v21.8b}, [x6], x3\n" + "st1 {v23.8b}, [x7], x3\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP ":\n" + "subs %w[output_window_height], %w[output_window_height], #2\n" + "add %[input_ptr], %[input_ptr], %[input_height_increment]\n" + "cmp %w[output_window_height], #2\n" + "add %[output_ptr], %[output_ptr], %[output_height_increment]\n" + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "b\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP ":\n" + "cmp %w[output_window_height], #1\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + DEPTHWISECONV_LABEL_HEIGHT_1 ":\n" + "mov x12, %[input_ptr]\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "add x13, %[input_ptr], %[input_row_size]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "add x14, x13, %[input_row_size]\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x15, x14, %[input_row_size]\n" + "mov w5, %w[output_window_width]\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "add x7, %[output_ptr], x1\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + // The height 1 / width 2 loop loads an extra 1x1 output in anticipation + // for the next iteration. Make sure |output_window_width| is large + // enough to handle the additional load, otherwise jump to the + // appropriate label to handle smaller widths. + "cmp w5, #2\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + "ld1 {v18.8b}, [x14], %[input_depth]\n" + "ld1 {v19.8b}, [x14], %[input_depth]\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "ld1 {v22.4s}, [x10]\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "ld1 {v24.4s}, [x10]\n" + + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "f\n" + "cmp w5, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP ":\n" + // Load inputs for 3x4 input window which corresponds to a 1x2 output + // window. + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v16.8b}, [x13]\n" + "smlal v23.4s, v0.4h, v10.4h\n" + "ld1 {v20.8b}, [x14]\n" + "smlal2 v24.4s, v0.8h, v10.8h\n" + "subs w5, w5, #2\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "cmp w5, #3\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "add %[input_ptr], %[input_ptr], %[input_width_increment]\n" + "smlal v23.4s, v1.4h, v11.4h\n" + "mov x12, %[input_ptr]\n" + "smlal2 v24.4s, v1.8h, v11.8h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x13, %[input_ptr], %[input_row_size]\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "add x14, x13, %[input_row_size]\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "add x15, x14, %[input_row_size]\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v3.4h, v14.4h\n" + "smlal2 v24.4s, v3.8h, v14.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v4.4h, v15.4h\n" + "smlal2 v24.4s, v4.8h, v15.8h\n" + "smlal v21.4s, v5.4h, v15.4h\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "smlal2 v22.4s, v5.8h, v15.8h\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v5.4h, v16.4h\n" + "smlal2 v24.4s, v5.8h, v16.8h\n" + "smlal v21.4s, v6.4h, v17.4h\n" + "smlal2 v22.4s, v6.8h, v17.8h\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "smlal v21.4s, v7.4h, v18.4h\n" + "smlal2 v22.4s, v7.8h, v18.8h\n" + "ld1 {v18.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v19.4h\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + "smlal2 v22.4s, v8.8h, v19.8h\n" + "ld1 {v19.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [%[output_ptr]], x3\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "st1 {v23.8b}, [%[output_ptr]], x3\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w5, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + // Handle last two horizontal outputs if exists. + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v16.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v0.4h, v10.4h\n" + "ld1 {v20.8b}, [x14], %[input_depth]\n" + "smlal2 v24.4s, v0.8h, v10.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v11.4h\n" + "smlal2 v24.4s, v1.8h, v11.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v23.4s, v3.4h, v14.4h\n" + "smlal2 v24.4s, v3.8h, v14.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v23.4s, v4.4h, v15.4h\n" + "smlal2 v24.4s, v4.8h, v15.8h\n" + "smlal v21.4s, v5.4h, v15.4h\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "smlal2 v22.4s, v5.8h, v15.8h\n" + "smlal v23.4s, v5.4h, v16.4h\n" + "smlal2 v24.4s, v5.8h, v16.8h\n" + "smlal v21.4s, v6.4h, v17.4h\n" + "smlal2 v22.4s, v6.8h, v17.8h\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "smlal v21.4s, v7.4h, v18.4h\n" + "smlal2 v22.4s, v7.8h, v18.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v19.4h\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + "smlal2 v22.4s, v8.8h, v19.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v21.8b}, [%[output_ptr]], x3\n" + "st1 {v23.8b}, [%[output_ptr]], x3\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + // Handle bottom right output if exists. + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v21.4s, v5.4h, v15.4h\n" + "smlal2 v22.4s, v5.8h, v15.8h\n" + "smlal v21.4s, v6.4h, v17.4h\n" + "smlal2 v22.4s, v6.8h, v17.8h\n" + "smlal v21.4s, v7.4h, v18.4h\n" + "smlal2 v22.4s, v7.8h, v18.8h\n" + "smlal v21.4s, v8.4h, v19.4h\n" + "smlal2 v22.4s, v8.8h, v19.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "and v9.16b, v21.16b, v28.16b\n" + "and v12.16b, v22.16b, v28.16b\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v12.4s, v12.4s, #31\n" + "sqadd v21.4s, v21.4s, v9.4s\n" + "sqadd v22.4s, v22.4s, v12.4s\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtun v21.8b, v21.8h\n" + "st1 {v21.8b}, [%[output_ptr]]\n" + DEPTHWISECONV_LABEL_HEIGHT_1_END ":\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), + [output_window_height] "+r"(output_window_height) + : + // Inputs. + [bias_ptr] "r"(bias_ptr), [input_row_size] "r"(input_row_size), + [input_depth] "r"(input_depth), + [output_window_width] "r"(output_window_width), + [input_width_increment] "r"(input_width_increment), + [input_height_increment] "r"(input_height_increment), + [output_height_increment] "r"(output_height_increment), + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_HEIGHT_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1 +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_END } }; template <> -struct ConvKernel3x3FilterDepth8<4, 4, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - // Reuse 4x2 kernel twice. - ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth, - output_width); - - ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run( - input_ptr + 4 * input_depth, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr + 2 * output_depth, output_depth, output_width); +struct DepthwiseConvWindow<8, 2, 2> { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, int64_t input_depth, + int64_t input_row_size, int32 output_window_height, + int32 output_window_width, + const DepthwiseConvParams* params_ptr) { + const int64_t input_width_increment = 4 * input_depth; + const int64_t input_height_increment = 4 * input_row_size; + const int64_t output_height_increment = 2 * params_ptr->output_row_size; + +#define DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "1" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "2" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "3" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "4" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "5" +#define DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "6" +#define DEPTHWISECONV_LABEL_HEIGHT_1 "7" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "8" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "9" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "10" +#define DEPTHWISECONV_LABEL_HEIGHT_1_END "11" + + asm volatile( + // Performs depthwise convolutions for a window specified by + // |output_window_height| and |output_window_width|. The inner-most loop + // processes 2x2 outputs, and any leftovers at the end. + // + // Algorithm works as follows: + // + // 1. Load filters of 8 depth (8x3x3). Registers v0--v8 hold filter + // values. + // 2. For 2 output heights at a time: + // i. For 2 output widths at a time at stride 2, a 5x5 input + // window is required. To avoid register exhaustion, we load + // the first 2 rows of the 5x5 input window into registers + // v9--v18, and use the same registers to load the next 2 + // rows, and finally v9--v13 to load the last row. + // Accumulators for all 2x2 outputs are reserved by registers + // v21-v22 (top left output), v23-v24 (top right output), + // v19-v20 (bottom left output), v25-v26 (bottom right + // output). + // ii. Handle single leftover width if exists. + // 3. Handle single leftover height if exists. + // i. For 2 output widths at a time at stride 2, load inputs for + // a 1x2 (1 height, 2 width) output window (3x5 input + // window). Registers v9--v24 hold input values. Mul-add with + // accumulators v24--v27. + // ii. Handle single leftover width if exists. + // + // Loads are placed as soon as the register is no longer needed and + // interleaved with arithmetic operations to take advantage of + // dual-issue pipelines. We also add input offsets as far from the loads + // as possible to give loads enough cycles to fetch data from memory. + + // Set "constant" registers. These registers may be replaced with temp + // values from time to time when there are not enough NEON registers. + // We use x9--x15 general purpose registers as they are caller-saved + // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "ldr w0, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "cmp %w[output_window_height], #2\n" + "dup v28.8h, w0\n" + "neg w9, w9\n" + "ldr w1, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.4s, w9\n" + "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w1\n" + "ldr w3, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "dup v29.4s, w2\n" + "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w3\n" + "ldr x5, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "dup v31.4s, w4\n" + "ldr x19, [%[params_ptr], #" STR(OFFSET_OUTPUT_ROW_SIZE) "]\n" + "ldr w20, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + + // Load filters and add offsets. + "add x10, %[bias_ptr], #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], x5\n" + "dup v9.8h, w20\n" + "ld1 {v1.8b}, [%[filter_ptr]], x5\n" + "uaddw v0.8h, v9.8h, v0.8b\n" + "ld1 {v2.8b}, [%[filter_ptr]], x5\n" + "uaddw v1.8h, v9.8h, v1.8b\n" + "ld1 {v3.8b}, [%[filter_ptr]], x5\n" + "uaddw v2.8h, v9.8h, v2.8b\n" + "ld1 {v4.8b}, [%[filter_ptr]], x5\n" + "uaddw v3.8h, v9.8h, v3.8b\n" + "ld1 {v5.8b}, [%[filter_ptr]], x5\n" + "uaddw v4.8h, v9.8h, v4.8b\n" + "ld1 {v6.8b}, [%[filter_ptr]], x5\n" + "uaddw v5.8h, v9.8h, v5.8b\n" + "ld1 {v7.8b}, [%[filter_ptr]], x5\n" + "uaddw v6.8h, v9.8h, v6.8b\n" + "ld1 {v8.8b}, [%[filter_ptr]]\n" + "uaddw v7.8h, v9.8h, v7.8b\n" + "uaddw v8.8h, v9.8h, v8.8b\n" + + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_LOOP ":\n" + // Load the first two rows of the 5x5 input window, then reuse the + // same registers to load subsequent rows as they become available. + "mov x11, %[input_ptr]\n" + "mov x12, x11\n" + "add x13, x12, %[input_row_size]\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "mov w14, %w[output_window_width]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + // The height 2 / width 2 loop loads an extra 1 output horizontally in + // anticipation for the next iteration. Make sure + // |output_window_width| is large enough to handle the additional + // load, otherwise jump to the appropriate label to handle smaller + // widths. + "cmp w14, #2\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x15, x13, %[input_row_size]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "add x7, %[output_ptr], x19\n" + "ld1 {v16.8b}, [x13], %[input_depth]\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "ld1 {v22.4s}, [x10]\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "ld1 {v24.4s}, [x10]\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "ld1 {v19.4s}, [%[bias_ptr]]\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "ld1 {v20.4s}, [x10]\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "ld1 {v25.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "ld1 {v26.4s}, [x10]\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "f\n" + "cmp w14, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v13.8b}, [x12]\n" + "add x12, x15, %[input_row_size]\n" + "smlal v23.4s, v0.4h, v11.4h\n" + "ld1 {v17.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v0.8h, v11.8h\n" + "ld1 {v18.8b}, [x13]\n" + "add x13, x12, %[input_row_size]\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "ld1 {v9.8b}, [x15], %[input_depth]\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v3.4h, v14.4h\n" + "smlal2 v22.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "subs w14, w14, #2\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "cmp w14, #3\n" + "smlal v21.4s, v4.4h, v15.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v22.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v5.4h, v16.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v22.4s, v5.8h, v16.8h\n" + "ld1 {v16.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v1.4h, v12.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v24.4s, v1.8h, v12.8h\n" + "ld1 {v12.8b}, [x15], %[input_depth]\n" + "smlal v23.4s, v2.4h, v13.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v24.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x15]\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "ld1 {v17.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v5.4h, v18.4h\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "smlal2 v24.4s, v5.8h, v18.8h\n" + "ld1 {v18.8b}, [x12]\n" + + "smlal v21.4s, v6.4h, v9.4h\n" + "smlal2 v22.4s, v6.8h, v9.8h\n" + "smlal v19.4s, v0.4h, v9.4h\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "smlal2 v20.4s, v0.8h, v9.8h\n" + "ld1 {v9.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v6.4h, v11.4h\n" + "smlal2 v24.4s, v6.8h, v11.8h\n" + "smlal v21.4s, v7.4h, v10.4h\n" + "smlal2 v22.4s, v7.8h, v10.8h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal v19.4s, v1.4h, v10.4h\n" + "smlal2 v20.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v7.4h, v12.4h\n" + "smlal2 v24.4s, v7.8h, v12.8h\n" + "smlal v25.4s, v1.4h, v12.4h\n" + "smlal2 v26.4s, v1.8h, v12.8h\n" + "smlal v21.4s, v8.4h, v11.4h\n" + "smlal2 v22.4s, v8.8h, v11.8h\n" + "add x11, x11, %[input_width_increment]\n" + "smlal v19.4s, v2.4h, v11.4h\n" + "mov x12, x11\n" + "smlal2 v20.4s, v2.8h, v11.8h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal v25.4s, v0.4h, v11.4h\n" + "smlal2 v26.4s, v0.8h, v11.8h\n" + "ld1 {v11.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v8.4h, v13.4h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v8.8h, v13.8h\n" + "smlal v25.4s, v2.4h, v13.4h\n" + "smlal2 v26.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x13]\n" + "add x13, x12, %[input_row_size]\n" + "add x15, x13, %[input_row_size]\n" + + "dup v28.4s, w9\n" + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v27.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v23.8b}, [x6], x5\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + + "smlal v19.4s, v6.4h, v9.4h\n" + "smlal2 v20.4s, v6.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal v25.4s, v6.4h, v11.4h\n" + "smlal2 v26.4s, v6.8h, v11.8h\n" + "smlal v19.4s, v7.4h, v10.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v20.4s, v7.8h, v10.8h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal v25.4s, v7.4h, v12.4h\n" + "smlal2 v26.4s, v7.8h, v12.8h\n" + "smlal v19.4s, v8.4h, v11.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v20.4s, v8.8h, v11.8h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "smlal v25.4s, v8.4h, v13.4h\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "smlal2 v26.4s, v8.8h, v13.8h\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "smlal v19.4s, v3.4h, v14.4h\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "smlal2 v20.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v25.4s, v3.4h, v16.4h\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "smlal2 v26.4s, v3.8h, v16.8h\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "smlal v19.4s, v4.4h, v15.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v20.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "smlal v25.4s, v4.4h, v17.4h\n" + "smlal2 v26.4s, v4.8h, v17.8h\n" + "smlal v19.4s, v5.4h, v16.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v20.4s, v5.8h, v16.8h\n" + "ld1 {v16.8b}, [x13], %[input_depth]\n" + "smlal v25.4s, v5.4h, v18.4h\n" + "smlal2 v26.4s, v5.8h, v18.8h\n" + + "dup v28.4s, w9\n" + "sqrdmulh v19.4s, v19.4s, v27.4s\n" + "sqrdmulh v20.4s, v20.4s, v27.4s\n" + "sqrdmulh v25.4s, v25.4s, v27.4s\n" + "sqrdmulh v26.4s, v26.4s, v27.4s\n" + "and v27.16b, v19.16b, v28.16b\n" + "and v29.16b, v20.16b, v28.16b\n" + "and v30.16b, v25.16b, v28.16b\n" + "and v31.16b, v26.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v19.4s, v19.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v20.4s, v20.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v25.4s, v25.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v26.4s, v26.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v19.4s, v19.4s, v28.4s\n" + "srshl v20.4s, v20.4s, v28.4s\n" + "srshl v25.4s, v25.4s, v28.4s\n" + "srshl v26.4s, v26.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v19.4s, v19.4s, v29.4s\n" + "add v20.4s, v20.4s, v29.4s\n" + "add v25.4s, v25.4s, v29.4s\n" + "add v26.4s, v26.4s, v29.4s\n" + "smax v19.4s, v19.4s, v30.4s\n" + "smax v20.4s, v20.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smin v19.4s, v19.4s, v31.4s\n" + "smin v20.4s, v20.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "sqxtn v19.4h, v19.4s\n" + "sqxtn v25.4h, v25.4s\n" + "sqxtn2 v19.8h, v20.4s\n" + "ld1 {v20.4s}, [x10]\n" + "sqxtn2 v25.8h, v26.4s\n" + "ld1 {v26.4s}, [x10]\n" + "sqxtun v19.8b, v19.8h\n" + "sqxtun v25.8b, v25.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v19.8b}, [x7], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v25.8b}, [x7], x5\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "ld1 {v19.4s}, [%[bias_ptr]]\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "ld1 {v25.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w14, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + // Handle last 2 columns if exists. + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v13.8b}, [x12]\n" + "add x12, x15, %[input_row_size]\n" + "smlal v23.4s, v0.4h, v11.4h\n" + "ld1 {v17.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v0.8h, v11.8h\n" + "ld1 {v18.8b}, [x13]\n" + "add x13, x12, %[input_row_size]\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "ld1 {v9.8b}, [x15], %[input_depth]\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v3.4h, v14.4h\n" + "smlal2 v22.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "smlal v21.4s, v4.4h, v15.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v22.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v5.4h, v16.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v22.4s, v5.8h, v16.8h\n" + "ld1 {v16.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v1.4h, v12.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v24.4s, v1.8h, v12.8h\n" + "ld1 {v12.8b}, [x15], %[input_depth]\n" + "smlal v23.4s, v2.4h, v13.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v24.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x15]\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "ld1 {v17.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v5.4h, v18.4h\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "smlal2 v24.4s, v5.8h, v18.8h\n" + "ld1 {v18.8b}, [x12]\n" + + "smlal v21.4s, v6.4h, v9.4h\n" + "smlal2 v22.4s, v6.8h, v9.8h\n" + "smlal v19.4s, v0.4h, v9.4h\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "smlal2 v20.4s, v0.8h, v9.8h\n" + "ld1 {v9.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v6.4h, v11.4h\n" + "smlal2 v24.4s, v6.8h, v11.8h\n" + "smlal v21.4s, v7.4h, v10.4h\n" + "smlal2 v22.4s, v7.8h, v10.8h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal v19.4s, v1.4h, v10.4h\n" + "smlal2 v20.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v7.4h, v12.4h\n" + "smlal2 v24.4s, v7.8h, v12.8h\n" + "smlal v25.4s, v1.4h, v12.4h\n" + "smlal2 v26.4s, v1.8h, v12.8h\n" + "smlal v21.4s, v8.4h, v11.4h\n" + "smlal2 v22.4s, v8.8h, v11.8h\n" + "smlal v19.4s, v2.4h, v11.4h\n" + "smlal2 v20.4s, v2.8h, v11.8h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal v25.4s, v0.4h, v11.4h\n" + "smlal2 v26.4s, v0.8h, v11.8h\n" + "ld1 {v11.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v8.4h, v13.4h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v8.8h, v13.8h\n" + "smlal v25.4s, v2.4h, v13.4h\n" + "smlal2 v26.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x13]\n" + + "dup v28.4s, w9\n" + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v27.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v23.8b}, [x6]\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + + "smlal v19.4s, v6.4h, v9.4h\n" + "smlal2 v20.4s, v6.8h, v9.8h\n" + "smlal v25.4s, v6.4h, v11.4h\n" + "smlal2 v26.4s, v6.8h, v11.8h\n" + "smlal v19.4s, v7.4h, v10.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v20.4s, v7.8h, v10.8h\n" + "smlal v25.4s, v7.4h, v12.4h\n" + "smlal2 v26.4s, v7.8h, v12.8h\n" + "smlal v19.4s, v8.4h, v11.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v20.4s, v8.8h, v11.8h\n" + "smlal v25.4s, v8.4h, v13.4h\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "smlal2 v26.4s, v8.8h, v13.8h\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "smlal v19.4s, v3.4h, v14.4h\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "smlal2 v20.4s, v3.8h, v14.8h\n" + "smlal v25.4s, v3.4h, v16.4h\n" + "smlal2 v26.4s, v3.8h, v16.8h\n" + "smlal v19.4s, v4.4h, v15.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v20.4s, v4.8h, v15.8h\n" + "smlal v25.4s, v4.4h, v17.4h\n" + "smlal2 v26.4s, v4.8h, v17.8h\n" + "smlal v19.4s, v5.4h, v16.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v20.4s, v5.8h, v16.8h\n" + "smlal v25.4s, v5.4h, v18.4h\n" + "smlal2 v26.4s, v5.8h, v18.8h\n" + + "dup v28.4s, w9\n" + "sqrdmulh v19.4s, v19.4s, v27.4s\n" + "sqrdmulh v20.4s, v20.4s, v27.4s\n" + "sqrdmulh v25.4s, v25.4s, v27.4s\n" + "sqrdmulh v26.4s, v26.4s, v27.4s\n" + "and v27.16b, v19.16b, v28.16b\n" + "and v29.16b, v20.16b, v28.16b\n" + "and v30.16b, v25.16b, v28.16b\n" + "and v31.16b, v26.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v19.4s, v19.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v20.4s, v20.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v25.4s, v25.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v26.4s, v26.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v19.4s, v19.4s, v28.4s\n" + "srshl v20.4s, v20.4s, v28.4s\n" + "srshl v25.4s, v25.4s, v28.4s\n" + "srshl v26.4s, v26.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v19.4s, v19.4s, v29.4s\n" + "add v20.4s, v20.4s, v29.4s\n" + "add v25.4s, v25.4s, v29.4s\n" + "add v26.4s, v26.4s, v29.4s\n" + "smax v19.4s, v19.4s, v30.4s\n" + "smax v20.4s, v20.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smin v19.4s, v19.4s, v31.4s\n" + "smin v20.4s, v20.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "sqxtn v19.4h, v19.4s\n" + "sqxtn v25.4h, v25.4s\n" + "sqxtn2 v19.8h, v20.4s\n" + "sqxtn2 v25.8h, v26.4s\n" + "sqxtun v19.8b, v19.8h\n" + "sqxtun v25.8b, v25.8h\n" + "st1 {v19.8b}, [x7], x5\n" + "st1 {v25.8b}, [x7]\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "f\n" + + // Handle last column if exists. + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER ":\n" + // Registers v9, v10, v11, v14, v15, and v16 have already been loaded + // with the correct values at this point. This corresponds to the + // first two input rows of the top left output. Now load the last + // input row for this output. Once these inputs are no longer needed, + // load the input rows for the bottom left output. + "add x12, x15, %[input_row_size]\n" + "add x13, x12, %[input_row_size]\n" + + "ld1 {v12.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v13.8b}, [x15], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v17.8b}, [x15]\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x12]\n" + "smlal v21.4s, v3.4h, v14.4h\n" + "smlal2 v22.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v21.4s, v4.4h, v15.4h\n" + "smlal2 v22.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "smlal v21.4s, v5.4h, v16.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v22.4s, v5.8h, v16.8h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "ld1 {v16.8b}, [x13]\n" + + "smlal v21.4s, v6.4h, v12.4h\n" + "smlal2 v22.4s, v6.8h, v12.8h\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v7.4h, v13.4h\n" + "smlal2 v22.4s, v7.8h, v13.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v2.4h, v17.4h\n" + "smlal2 v24.4s, v2.8h, v17.8h\n" + + "dup v26.4s, w9\n" + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "and v18.16b, v21.16b, v26.16b\n" + "and v19.16b, v22.16b, v26.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v21.4s, v21.4s, v18.4s\n" + "sqadd v22.4s, v22.4s, v19.4s\n" + "srshl v21.4s, v21.4s, v26.4s\n" + "srshl v22.4s, v22.4s, v26.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtun v21.8b, v21.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v21.8b}, [x6]\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + + "smlal v23.4s, v3.4h, v9.4h\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "smlal2 v24.4s, v3.8h, v9.8h\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "smlal v23.4s, v4.4h, v10.4h\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "smlal2 v24.4s, v4.8h, v10.8h\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "smlal v23.4s, v5.4h, v11.4h\n" + "smlal2 v24.4s, v5.8h, v11.8h\n" + + "smlal v23.4s, v6.4h, v14.4h\n" + "smlal2 v24.4s, v6.8h, v14.8h\n" + "smlal v23.4s, v7.4h, v15.4h\n" + "smlal2 v24.4s, v7.8h, v15.8h\n" + "smlal v23.4s, v8.4h, v16.4h\n" + "smlal2 v24.4s, v8.8h, v16.8h\n" + + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v18.16b, v23.16b, v26.16b\n" + "and v19.16b, v24.16b, v26.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v23.4s, v23.4s, v18.4s\n" + "sqadd v24.4s, v24.4s, v19.4s\n" + "srshl v23.4s, v23.4s, v26.4s\n" + "srshl v24.4s, v24.4s, v26.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v23.8b}, [x7]\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP ":\n" + "subs %w[output_window_height], %w[output_window_height], #2\n" + "add %[input_ptr], %[input_ptr], %[input_height_increment]\n" + "cmp %w[output_window_height], #2\n" + "add %[output_ptr], %[output_ptr], %[output_height_increment]\n" + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "b\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP ":\n" + "cmp %w[output_window_height], #1\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + DEPTHWISECONV_LABEL_HEIGHT_1 ":\n" + "mov x11, %[input_ptr]\n" + "mov x12, x11\n" + "add x13, x12, %[input_row_size]\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "add x15, x13, %[input_row_size]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "mov w14, %w[output_window_width]\n" + // The height 1 / width 2 loop loads an extra 1x1 output in anticipation + // for the next iteration. Make sure |output_window_width| is large + // enough to handle the additional load, otherwise jump to the + // appropriate label to handle smaller widths. + "cmp w14, #2\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "ld1 {v15.8b}, [x15], %[input_depth]\n" + "ld1 {v16.8b}, [x15], %[input_depth]\n" + "ld1 {v17.8b}, [x15], %[input_depth]\n" + + "uaddw v9.8h, v28.8h, v9.8b\n" + "ld1 {v24.4s}, [%[bias_ptr]]\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "ld1 {v25.4s}, [x10]\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "ld1 {v26.4s}, [%[bias_ptr]]\n" + "ld1 {v27.4s}, [x10]\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "f\n" + "cmp w14, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP ":\n" + "smlal v24.4s, v0.4h, v9.4h\n" + "ld1 {v18.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v0.8h, v9.8h\n" + "ld1 {v19.8b}, [x12]\n" + "smlal v26.4s, v0.4h, v11.4h\n" + "ld1 {v20.8b}, [x13], %[input_depth]\n" + "smlal2 v27.4s, v0.8h, v11.8h\n" + "ld1 {v21.8b}, [x13]\n" + "smlal v24.4s, v1.4h, v10.4h\n" + "ld1 {v22.8b}, [x15], %[input_depth]\n" + "smlal2 v25.4s, v1.8h, v10.8h\n" + "ld1 {v23.8b}, [x15]\n" + "smlal v24.4s, v2.4h, v11.4h\n" + "subs w14, w14, #2\n" + "smlal2 v25.4s, v2.8h, v11.8h\n" + "cmp w14, #3\n" + "smlal v24.4s, v3.4h, v12.4h\n" + "add x11, x11, %[input_width_increment]\n" + "smlal2 v25.4s, v3.8h, v12.8h\n" + "mov x12, x11\n" + "smlal v26.4s, v3.4h, v14.4h\n" + "add x13, x12, %[input_row_size]\n" + "smlal2 v27.4s, v3.8h, v14.8h\n" + "add x15, x13, %[input_row_size]\n" + "smlal v24.4s, v4.4h, v13.4h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v4.8h, v13.8h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal v24.4s, v5.4h, v14.4h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v5.8h, v14.8h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal v24.4s, v6.4h, v15.4h\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "smlal2 v25.4s, v6.8h, v15.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v26.4s, v6.4h, v17.4h\n" + "ld1 {v15.8b}, [x15], %[input_depth]\n" + "smlal2 v27.4s, v6.8h, v17.8h\n" + "smlal v24.4s, v7.4h, v16.4h\n" + "smlal2 v25.4s, v7.8h, v16.8h\n" + "ld1 {v16.8b}, [x15], %[input_depth]\n" + "smlal v24.4s, v8.4h, v17.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v25.4s, v8.8h, v17.8h\n" + "ld1 {v17.8b}, [x15], %[input_depth]\n" + "uaddw v19.8h, v28.8h, v19.8b\n" + + "smlal v26.4s, v1.4h, v18.4h\n" + "uaddw v20.8h, v28.8h, v20.8b\n" + "smlal2 v27.4s, v1.8h, v18.8h\n" + "smlal v26.4s, v2.4h, v19.4h\n" + "uaddw v21.8h, v28.8h, v21.8b\n" + "smlal2 v27.4s, v2.8h, v19.8h\n" + "smlal v26.4s, v4.4h, v20.4h\n" + "smlal v26.4s, v5.4h, v21.4h\n" + "smlal2 v27.4s, v4.8h, v20.8h\n" + "uaddw v22.8h, v28.8h, v22.8b\n" + "smlal2 v27.4s, v5.8h, v21.8h\n" + "uaddw v23.8h, v28.8h, v23.8b\n" + "smlal v26.4s, v7.4h, v22.4h\n" + "smlal2 v27.4s, v7.8h, v22.8h\n" + "smlal v26.4s, v8.4h, v23.4h\n" + "smlal2 v27.4s, v8.8h, v23.8h\n" + + "dup v28.4s, w1\n" + "dup v29.4s, w9\n" + "sqrdmulh v24.4s, v24.4s, v28.4s\n" + "sqrdmulh v25.4s, v25.4s, v28.4s\n" + "sqrdmulh v26.4s, v26.4s, v28.4s\n" + "sqrdmulh v27.4s, v27.4s, v28.4s\n" + "dup v28.4s, w2\n" + "and v30.16b, v24.16b, v29.16b\n" + "and v31.16b, v25.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v24.4s, v24.4s, v30.4s\n" + "sqadd v25.4s, v25.4s, v31.4s\n" + "and v30.16b, v26.16b, v29.16b\n" + "and v31.16b, v27.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v26.4s, v26.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v27.4s, v27.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v24.4s, v24.4s, v29.4s\n" + "srshl v25.4s, v25.4s, v29.4s\n" + "srshl v26.4s, v26.4s, v29.4s\n" + "srshl v27.4s, v27.4s, v29.4s\n" + "add v24.4s, v24.4s, v28.4s\n" + "add v25.4s, v25.4s, v28.4s\n" + "add v26.4s, v26.4s, v28.4s\n" + "add v27.4s, v27.4s, v28.4s\n" + "dup v28.8h, w0\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smax v27.4s, v27.4s, v30.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "smin v27.4s, v27.4s, v31.4s\n" + "sqxtn v24.4h, v24.4s\n" + "sqxtn v26.4h, v26.4s\n" + "sqxtn2 v24.8h, v25.4s\n" + "ld1 {v25.4s}, [x10]\n" + "sqxtn2 v26.8h, v27.4s\n" + "ld1 {v27.4s}, [x10]\n" + "sqxtun v24.8b, v24.8h\n" + "sqxtun v26.8b, v26.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v24.8b}, [x6], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v26.8b}, [x6], x5\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "ld1 {v24.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "ld1 {v26.4s}, [%[bias_ptr]]\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w14, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + // Handle last two horizontal outputs if exists. + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER ":\n" + "smlal v24.4s, v0.4h, v9.4h\n" + "ld1 {v18.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v0.8h, v9.8h\n" + "ld1 {v19.8b}, [x12]\n" + "smlal v26.4s, v0.4h, v11.4h\n" + "ld1 {v20.8b}, [x13], %[input_depth]\n" + "smlal2 v27.4s, v0.8h, v11.8h\n" + "ld1 {v21.8b}, [x13]\n" + "smlal v24.4s, v1.4h, v10.4h\n" + "ld1 {v22.8b}, [x15], %[input_depth]\n" + "smlal2 v25.4s, v1.8h, v10.8h\n" + "ld1 {v23.8b}, [x15]\n" + "smlal v24.4s, v2.4h, v11.4h\n" + "smlal2 v25.4s, v2.8h, v11.8h\n" + "smlal v24.4s, v3.4h, v12.4h\n" + "smlal2 v25.4s, v3.8h, v12.8h\n" + "smlal v26.4s, v3.4h, v14.4h\n" + "smlal2 v27.4s, v3.8h, v14.8h\n" + "smlal v24.4s, v4.4h, v13.4h\n" + "smlal2 v25.4s, v4.8h, v13.8h\n" + "smlal v24.4s, v5.4h, v14.4h\n" + "smlal2 v25.4s, v5.8h, v14.8h\n" + "smlal v24.4s, v6.4h, v15.4h\n" + "smlal2 v25.4s, v6.8h, v15.8h\n" + "smlal v26.4s, v6.4h, v17.4h\n" + "smlal2 v27.4s, v6.8h, v17.8h\n" + "smlal v24.4s, v7.4h, v16.4h\n" + "smlal2 v25.4s, v7.8h, v16.8h\n" + "smlal v24.4s, v8.4h, v17.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v25.4s, v8.8h, v17.8h\n" + "uaddw v19.8h, v28.8h, v19.8b\n" + + "smlal v26.4s, v1.4h, v18.4h\n" + "uaddw v20.8h, v28.8h, v20.8b\n" + "smlal2 v27.4s, v1.8h, v18.8h\n" + "smlal v26.4s, v2.4h, v19.4h\n" + "uaddw v21.8h, v28.8h, v21.8b\n" + "smlal2 v27.4s, v2.8h, v19.8h\n" + "smlal v26.4s, v4.4h, v20.4h\n" + "smlal v26.4s, v5.4h, v21.4h\n" + "smlal2 v27.4s, v4.8h, v20.8h\n" + "uaddw v22.8h, v28.8h, v22.8b\n" + "smlal2 v27.4s, v5.8h, v21.8h\n" + "uaddw v23.8h, v28.8h, v23.8b\n" + "smlal v26.4s, v7.4h, v22.4h\n" + "smlal2 v27.4s, v7.8h, v22.8h\n" + "smlal v26.4s, v8.4h, v23.4h\n" + "smlal2 v27.4s, v8.8h, v23.8h\n" + + "dup v28.4s, w1\n" + "dup v29.4s, w9\n" + "sqrdmulh v24.4s, v24.4s, v28.4s\n" + "sqrdmulh v25.4s, v25.4s, v28.4s\n" + "sqrdmulh v26.4s, v26.4s, v28.4s\n" + "sqrdmulh v27.4s, v27.4s, v28.4s\n" + "dup v28.4s, w2\n" + "and v30.16b, v24.16b, v29.16b\n" + "and v31.16b, v25.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v24.4s, v24.4s, v30.4s\n" + "sqadd v25.4s, v25.4s, v31.4s\n" + "and v30.16b, v26.16b, v29.16b\n" + "and v31.16b, v27.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v26.4s, v26.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v27.4s, v27.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v24.4s, v24.4s, v29.4s\n" + "srshl v25.4s, v25.4s, v29.4s\n" + "srshl v26.4s, v26.4s, v29.4s\n" + "srshl v27.4s, v27.4s, v29.4s\n" + "add v24.4s, v24.4s, v28.4s\n" + "add v25.4s, v25.4s, v28.4s\n" + "add v26.4s, v26.4s, v28.4s\n" + "add v27.4s, v27.4s, v28.4s\n" + "dup v28.8h, w0\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smax v27.4s, v27.4s, v30.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "smin v27.4s, v27.4s, v31.4s\n" + "sqxtn v24.4h, v24.4s\n" + "sqxtn v26.4h, v26.4s\n" + "sqxtn2 v24.8h, v25.4s\n" + "sqxtn2 v26.8h, v27.4s\n" + "sqxtun v24.8b, v24.8h\n" + "sqxtun v26.8b, v26.8h\n" + "st1 {v24.8b}, [x6], x5\n" + "st1 {v26.8b}, [x6]\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + // Handle bottom right output if exists. + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER ":\n" + "dup v26.4s, w9\n" + "dup v27.4s, w1\n" + "dup v29.4s, w2\n" + + "smlal v24.4s, v0.4h, v9.4h\n" + "smlal2 v25.4s, v0.8h, v9.8h\n" + "smlal v24.4s, v1.4h, v10.4h\n" + "smlal2 v25.4s, v1.8h, v10.8h\n" + "smlal v24.4s, v2.4h, v11.4h\n" + "smlal2 v25.4s, v2.8h, v11.8h\n" + "smlal v24.4s, v3.4h, v12.4h\n" + "smlal2 v25.4s, v3.8h, v12.8h\n" + "smlal v24.4s, v4.4h, v13.4h\n" + "smlal2 v25.4s, v4.8h, v13.8h\n" + "smlal v24.4s, v5.4h, v14.4h\n" + "smlal2 v25.4s, v5.8h, v14.8h\n" + "smlal v24.4s, v6.4h, v15.4h\n" + "smlal2 v25.4s, v6.8h, v15.8h\n" + "smlal v24.4s, v7.4h, v16.4h\n" + "smlal2 v25.4s, v7.8h, v16.8h\n" + "smlal v24.4s, v8.4h, v17.4h\n" + "smlal2 v25.4s, v8.8h, v17.8h\n" + + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "sqrdmulh v25.4s, v25.4s, v27.4s\n" + "and v18.16b, v24.16b, v26.16b\n" + "and v19.16b, v25.16b, v26.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v24.4s, v24.4s, v18.4s\n" + "sqadd v25.4s, v25.4s, v19.4s\n" + "srshl v24.4s, v24.4s, v26.4s\n" + "srshl v25.4s, v25.4s, v26.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "add v25.4s, v25.4s, v29.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "sqxtn v24.4h, v24.4s\n" + "sqxtn2 v24.8h, v25.4s\n" + "sqxtun v24.8b, v24.8h\n" + "st1 {v24.8b}, [x6]\n" + + DEPTHWISECONV_LABEL_HEIGHT_1_END ":\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), + [output_window_height] "+r"(output_window_height) + : + // Inputs. + [bias_ptr] "r"(bias_ptr), [input_row_size] "r"(input_row_size), + [input_depth] "r"(input_depth), + [output_window_width] "r"(output_window_width), + [input_width_increment] "r"(input_width_increment), + [input_height_increment] "r"(input_height_increment), + [output_height_increment] "r"(output_height_increment), + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x9", "x10", "x11", "x12", "x13", "x14", "x15", + "x19", "x20"); +#undef DEPTHWISECONV_LABEL_HEIGHT_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1 +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_END } }; -template <> -struct ConvKernel3x3FilterDepth8<4, 1, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - const int output_row_size = output_depth * output_width; - - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. - output_ptr += output_row_size; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - DotProductAndStore( - filter, input_6, input_7, input_8, input_0, input_1, input_2, input_3, - input_4, input_5, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Third output. - output_ptr += output_row_size; - - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - - DotProductAndStore( - filter, input_3, input_4, input_5, input_6, input_7, input_8, input_0, - input_1, input_2, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Fourth output. - output_ptr += output_row_size; - - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<2, 2, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - Int32x8 acc_0, acc_1, acc_2, acc_3; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_2.low = vld1q_s32(bias_ptr); - acc_3.low = vld1q_s32(bias_ptr); - - bias_ptr += 4; - acc_0.high = vld1q_s32(bias_ptr); - acc_1.high = vld1q_s32(bias_ptr); - acc_2.high = vld1q_s32(bias_ptr); - acc_3.high = vld1q_s32(bias_ptr); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - // Add scope for input registers to help the compiler know that it is - // not needed. - { - // To process 2x2 outputs using a 3x3 filter at stride 2, we require - // 5x5 inputs. We load the first 5x2 inputs at a time. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9; - - const uint8* ptr = input_ptr; - - // Load inputs. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; - - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next inputs. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - // Moving onto the two bottom outputs. - acc_2 = MultiplyAccumulateRow(acc_2, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_3 = MultiplyAccumulateRow(acc_3, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_2 = MultiplyAccumulateRow(acc_2, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_3 = MultiplyAccumulateRow(acc_3, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load last input row. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; +enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter }; - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - } - - acc_2 = MultiplyAccumulateRow(acc_2, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_3 = MultiplyAccumulateRow(acc_3, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - } - - DownquantizeAndStore2x2Output(acc_0, acc_1, acc_2, acc_3, output_offset, - output_multiplier, output_shift, - output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - } -}; +template +struct DepthwiseConvPartial {}; template <> -struct ConvKernel3x3FilterDepth8<2, 4, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - // Reuse 2x2 kernel twice. - ConvKernel3x3FilterDepth8<2, 2, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth, - output_width); - - ConvKernel3x3FilterDepth8<2, 2, 2, 2>::Run( - input_ptr + 4 * input_depth, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr + 2 * output_depth, output_depth, output_width); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 1x1 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the 1x1 input and filter values. + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w10\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "cmp x11, #16\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w10, w10\n" + "dup v29.4s, w10\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w10\n" + "dup v25.8h, w9\n" + + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "subs x11, x11, #8\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "cmp x11, #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v8", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", + // We use these general-purpose registers. + "x9", "x10", "x11"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; template <> -struct ConvKernel3x3FilterDepth8<2, 1, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - const int output_row_size = output_depth * output_width; - - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. - output_ptr += output_row_size; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - DotProductAndStore( - filter, input_6, input_7, input_8, input_0, input_1, input_2, input_3, - input_4, input_5, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 2x2 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 2x2 input and + // filter values. + + // Load input and filter values. + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "ldr x9, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "cmp x15, #16\n" + "add x12, %[input_ptr], x15\n" + "add x13, %[input_ptr], x9\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "add x14, x13, x15\n" + "ld1 {v9.8b}, [x12], #8\n" + "ldr x6, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + + "add x9, %[filter_ptr], x15\n" + "ld1 {v10.8b}, [x13], #8\n" + "add x10, %[filter_ptr], x6\n" + "ld1 {v11.8b}, [x14], #8\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "add x11, x10, x15\n" + "ld1 {v1.8b}, [x9], #8\n" + "ld1 {v2.8b}, [x10], #8\n" + "ld1 {v3.8b}, [x11], #8\n" + + // Load constants. + "ldr w6, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w7\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w7, w7\n" + "dup v29.4s, w7\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w7\n" + "dup v25.8h, w6\n" + + // Add input and filter offsets. + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "subs x15, x15, #8\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "cmp x15, #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], #8\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "ld1 {v1.8b}, [x9], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], #8\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v2.8b}, [x10], #8\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x14], #8\n" + "ld1 {v3.8b}, [x11], #8\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + // We use these general-purpose registers. + "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; template <> -struct ConvKernel3x3FilterDepth8<1, 2, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. - output_ptr += output_depth; - - ptr = input_ptr + 3 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - DotProductAndStore( - filter, input_2, input_0, input_1, input_5, input_3, input_4, input_8, - input_6, input_7, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 2x3 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 2x3 input and + // filter values. + + // Load input and filter values. + "ldr x7, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n" + "mov x12, %[input_ptr]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "mov x9, %[filter_ptr]\n" + "ldr x14, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + "add x13, x12, x11\n" + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + + "ld1 {v8.8b}, [x12], x7\n" + "add x10, x9, x14\n" + "ld1 {v9.8b}, [x12], x7\n" + "cmp x15, #16\n" + "ld1 {v10.8b}, [x12]\n" + "add %[input_ptr], %[input_ptr], #8\n" + "ld1 {v11.8b}, [x13], x7\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "ld1 {v12.8b}, [x13], x7\n" + "ld1 {v13.8b}, [x13]\n" + + "ld1 {v0.8b}, [x9], x7\n" + "ld1 {v1.8b}, [x9], x7\n" + "ld1 {v2.8b}, [x9]\n" + "ld1 {v3.8b}, [x10], x7\n" + "ld1 {v4.8b}, [x10], x7\n" + "ld1 {v5.8b}, [x10]\n" + + // Load constants. + "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w13, w13\n" + "dup v29.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w13\n" + "dup v25.8h, w12\n" + + // Add input and filter offsets. + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "mov x12, %[input_ptr]\n" + "subs x15, x15, #8\n" + "add x13, x12, x11\n" + "cmp x15, #16\n" + "add %[input_ptr], %[input_ptr], #8\n" + + "smlal v16.4s, v0.4h, v8.4h\n" + "mov x9, %[filter_ptr]\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [x12], x7\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "add x10, x9, x14\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], x7\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x12]\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v0.8b}, [x9], x7\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x13], x7\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "ld1 {v1.8b}, [x9], x7\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "ld1 {v12.8b}, [x13], x7\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "ld1 {v2.8b}, [x9]\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + "ld1 {v13.8b}, [x13]\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "ld1 {v3.8b}, [x10], x7\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "ld1 {v4.8b}, [x10], x7\n" + "and v18.16b, v16.16b, v29.16b\n" + "ld1 {v5.8b}, [x10]\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12", + "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; template <> -struct ConvKernel3x3FilterDepth8<1, 4, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. - output_ptr += output_depth; - - ptr = input_ptr + 3 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - DotProductAndStore( - filter, input_2, input_0, input_1, input_5, input_3, input_4, input_8, - input_6, input_7, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Third output. - output_ptr += output_depth; - - ptr = input_ptr + 5 * input_depth; - temp_2 = vld1_u8(ptr); - temp_0 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_5 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_8 = vld1_u8(ptr); - temp_6 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - - DotProductAndStore( - filter, input_1, input_2, input_0, input_4, input_5, input_3, input_7, - input_8, input_6, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Fourth output. - output_ptr += output_depth; - - ptr = input_ptr + 7 * input_depth; - temp_1 = vld1_u8(ptr); - temp_2 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_7 = vld1_u8(ptr); - temp_8 = vld1_u8(ptr + input_depth); - - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 3x2 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 3x2 input and + // filter values. + + // Load input and filter values. + "ldr x6, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n" + "mov x12, %[input_ptr]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "mov x7, %[filter_ptr]\n" + "ldr x5, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + "add x13, x12, x11\n" + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "add x14, x13, x11\n" + + "ld1 {v8.8b}, [x12], x6\n" + "add x9, x7, x5\n" + "ld1 {v9.8b}, [x12]\n" + "cmp x15, #16\n" + "add x10, x9, x5\n" + "ld1 {v10.8b}, [x13], x6\n" + "add %[input_ptr], %[input_ptr], #8\n" + "ld1 {v11.8b}, [x13]\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "ld1 {v12.8b}, [x14], x6\n" + "ld1 {v13.8b}, [x14]\n" + + "ld1 {v0.8b}, [x7], x6\n" + "ld1 {v1.8b}, [x7]\n" + "ld1 {v2.8b}, [x9], x6\n" + "ld1 {v3.8b}, [x9]\n" + "ld1 {v4.8b}, [x10], x6\n" + "ld1 {v5.8b}, [x10]\n" + + // Load constants. + "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w13, w13\n" + "dup v29.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w13\n" + "dup v25.8h, w12\n" + + // Add input and filter offsets. + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "mov x12, %[input_ptr]\n" + "subs x15, x15, #8\n" + "add x13, x12, x11\n" + "cmp x15, #16\n" + "add x14, x13, x11\n" + "add %[input_ptr], %[input_ptr], #8\n" + + "smlal v16.4s, v0.4h, v8.4h\n" + "mov x7, %[filter_ptr]\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [x12], x6\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "add x9, x7, x5\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "add x10, x9, x5\n" + "ld1 {v9.8b}, [x12]\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], x6\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v0.8b}, [x7], x6\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x13]\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "ld1 {v1.8b}, [x7]\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "ld1 {v12.8b}, [x14], x6\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "ld1 {v2.8b}, [x9], x6\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + "ld1 {v13.8b}, [x14]\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "ld1 {v3.8b}, [x9]\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "ld1 {v4.8b}, [x10], x6\n" + "and v18.16b, v16.16b, v29.16b\n" + "ld1 {v5.8b}, [x10]\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12", + "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x5", "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; -template -struct ConvKernel3x3FilterDepth8<1, 1, kFixedStrideWidth, kFixedStrideHeight> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - - uint8x8_t temp_0 = vld1_u8(input_ptr); - uint8x8_t temp_1 = vld1_u8(input_ptr + input_depth); - uint8x8_t temp_2 = vld1_u8(input_ptr + 2 * input_depth); - - input_ptr += input_row_size; - uint8x8_t temp_3 = vld1_u8(input_ptr); - uint8x8_t temp_4 = vld1_u8(input_ptr + input_depth); - uint8x8_t temp_5 = vld1_u8(input_ptr + 2 * input_depth); - - input_ptr += input_row_size; - uint8x8_t temp_6 = vld1_u8(input_ptr); - uint8x8_t temp_7 = vld1_u8(input_ptr + input_depth); - uint8x8_t temp_8 = vld1_u8(input_ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - } -}; - -inline void ShuffleInput(const uint8* input_ptr, int input_depth, - int input_width, int input_height, int output_depth, - int output_width, int output_height, - uint8* output_ptr) { - const int input_row_size = input_depth * input_width; - - for (int y = 0; y < output_height; y++) { +#undef OFFSET_INPUT_DEPTH +#undef OFFSET_INPUT_ROW_SIZE +#undef OFFSET_OUTPUT_DEPTH +#undef OFFSET_OUTPUT_ROW_SIZE +#undef OFFSET_INPUT_OFFSET +#undef OFFSET_OUTPUT_OFFSET +#undef OFFSET_FILTER_OFFSET +#undef OFFSET_OUTPUT_MULTIPLIER +#undef OFFSET_OUTPUT_ACTIVATION_MIN +#undef OFFSET_OUTPUT_ACTIVATION_MAX +#undef OFFSET_OUTPUT_SHIFT +#undef OFFSET_INPUT_WIDTH +#undef OFFSET_INPUT_HEIGHT +#undef OFFSET_OUTPUT_WIDTH +#undef OFFSET_OUTPUT_HEIGHT +#undef STR +#undef STR_UNEXPANDED + +// Copies a subset of the input designated by |input_ptr| into |output_ptr| +// with the specified output dimensions. Supports output depths of 64 only as +// this is the cache line size. +inline void ShuffleInput(const uint8* input_ptr, int64_t input_depth, + int32 input_width, int32 input_height, + int64_t output_depth, int32 output_width, + int32 output_height, uint8* output_ptr) { + const int64_t input_row_size = input_depth * input_width; + for (int32 y = 0; y < output_height; y++) { const uint8* ptr = input_ptr; - for (int x = 0; x < output_width; x++) { + for (int32 x = 0; x < output_width; x++) { memcpy(output_ptr, ptr, output_depth); output_ptr += output_depth; ptr += input_depth; @@ -3873,561 +2937,262 @@ inline void ShuffleInput(const uint8* input_ptr, int input_depth, } } -template -struct ConvRow3x3FilterDepth8 {}; - -template -struct ConvRow3x3FilterDepth8<1, kFixedStrideWidth, kFixedStrideHeight> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - - // 1x4 at a time. - for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<1, 4, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 4 * kFixedStrideWidth * input_depth; - output_data += 4 * output_depth; - } - - // 1x1 at a time. - for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<1, 1, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } +// Calculates the input size depending on stride and output. +inline int32 get_shuffle_input_size(int32 stride, int32 output) { + return stride * (output - 1) + 3; +} - input_data += kFixedStrideWidth * input_depth; - output_data += output_depth; - } +// Indicates the input and output dimensions used when shuffling input +// activations. +struct ShuffleParams { + int32 output_width; + int32 output_height; + int32 input_width; + int32 input_height; + + ShuffleParams() = default; + ShuffleParams(int32 output_width, int32 output_height, int32 stride_width, + int32 stride_height) + : output_width(output_width) + , output_height(output_height) + , input_width(get_shuffle_input_size(stride_width, output_width)) + , input_height(get_shuffle_input_size(stride_height, output_height)) { } }; -template -struct ConvRow3x3FilterDepth8<2, kFixedStrideWidth, kFixedStrideHeight> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - - // 2x4 at a time. - for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<2, 4, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 4 * kFixedStrideWidth * input_depth; - output_data += 4 * output_depth; - } - - // 2x2 at a time. - for (; out_x <= output_width - 2; out_x += 2) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<2, 2, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 2 * kFixedStrideWidth * input_depth; - output_data += 2 * output_depth; - } - - // 2x1 at a time. - for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<2, 1, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += kFixedStrideWidth * input_depth; - output_data += output_depth; +template +struct DepthwiseConvThroughDepth { + // Runs the DepthwiseConvWindow kernels through the depth dimension from + // |start_depth| to |end_depth|. Keep this not inlined to maintain a small + // binary size. We use a DepthwiseConvParams struct for read only params + // to minimize call overhead. + static __attribute__((noinline)) void Run(const uint8* input_ptr, + const uint8* filter_ptr, const int32* bias_ptr, uint8* output_ptr, + int64_t start_depth, int64_t end_depth, int64_t input_depth, + int64_t input_row_size, int32 output_window_height, + int32 output_window_width, const DepthwiseConvParams& params) { + for (; start_depth <= end_depth - 8; start_depth += 8) { + DepthwiseConvWindow<8, kStrideWidth, kStrideHeight>::Run( + input_ptr, filter_ptr, bias_ptr, output_ptr, input_depth, + input_row_size, output_window_height, output_window_width, ¶ms); + input_ptr += 8; + output_ptr += 8; + filter_ptr += 8; + bias_ptr += 8; } } }; -template <> -struct ConvRow3x3FilterDepth8<4, 1, 1> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - - // 4x4 at a time. - for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 4, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } +template +struct DepthwiseConvMultiRow { + using ConvKernel = DepthwiseConvThroughDepth; - input_data += 4 * input_depth; - output_data += 4 * output_depth; - } - - // Handle the rest of the right side. - // 4x2 at a time. - for (; out_x <= output_width - 2; out_x += 2) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 2, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 2 * input_depth; - output_data += 2 * output_depth; - } - - // 4x1 at a time. - for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 1, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += input_depth; - output_data += output_depth; - } - } -}; - -template <> -struct ConvRow3x3FilterDepth8<4, 2, 2> { - // The buffer size of the shuffled input. - static inline constexpr int ShuffleWorkspaceSize() { return 64 * 9 * 9; } - - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, + static inline void Run(const uint8* input_data, int32 start_x, int32 end_x, + const uint8* filter_data, const int32* bias_data, + uint8* output_data, const DepthwiseConvParams& params, + const ShuffleParams& shuffle_params, uint8* shuffle_workspace) { - // Branch and cache misses increase substantially with stride 2 kernels. - // Adding prefetching reduces latency by as much as 2x. - const int i0 = 0; - const int i1 = input_depth; - const int i2 = 2 * input_depth; - const int i3 = 3 * input_depth; - const int i4 = 4 * input_depth; - const int i5 = 5 * input_depth; - const int i6 = 6 * input_depth; - const int i7 = 7 * input_depth; - const int i8 = 8 * input_depth; - -#define DEPTHWISECONV_PRELOAD_ROW(input_ptr, i) \ - preload_l1_keep(input_ptr + i * input_row_size + i0); \ - preload_l1_keep(input_ptr + i * input_row_size + i1); \ - preload_l1_keep(input_ptr + i * input_row_size + i2); \ - preload_l1_keep(input_ptr + i * input_row_size + i3); \ - preload_l1_keep(input_ptr + i * input_row_size + i4); \ - preload_l1_keep(input_ptr + i * input_row_size + i5); \ - preload_l1_keep(input_ptr + i * input_row_size + i6); \ - preload_l1_keep(input_ptr + i * input_row_size + i7); \ - preload_l1_keep(input_ptr + i * input_row_size + i8); - - int out_x = start_x; - // 4x4 at a time. - for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - int depth = 0; - for (; depth <= output_depth - 64; depth += 64) { - // Preload 9x9 input. - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 0); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 1); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 2); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 3); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 4); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 5); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 6); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 7); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 8); - - // For a large input window (64x9x9) that is small enough to fit in L1 - // cache, copy the input into a separate buffer and run the kernel on - // this new buffer. This reduces the likelihood of cache misses when - // the kernel is loading input data. If this size is ever changed, - // update the ShuffleWorkspaceSize() function to return the new size. - ShuffleInput(input_ptr, input_depth, input_width, input_height, 64, 9, - 9, shuffle_workspace); - const uint8* shuffled_ptr = &shuffle_workspace[0]; - - for (int micro_depth = 0; micro_depth <= 64 - 8; micro_depth += 8) { - ConvKernel3x3FilterDepth8<4, 4, 2, 2>::Run( - shuffled_ptr, 64, input_offset, 64 * 9, filter_ptr, filter_offset, - bias_ptr, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, - output_depth, output_width); - - shuffled_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; + TFLITE_DCHECK(shuffle_params.input_height == + get_shuffle_input_size(kStrideHeight, shuffle_params.output_height)); + TFLITE_DCHECK(shuffle_params.input_width == + get_shuffle_input_size(kStrideWidth, shuffle_params.output_width)); + TFLITE_DCHECK(64 * shuffle_params.input_width * shuffle_params.input_height + <= DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE); + + int32 out_x = start_x; + + // Run shuffling on inputs with sufficiently large depth and width. When + // these parameters are large enough, more time is taken to load inputs + // from memory. At this point, it becomes useful to prefetch and + // preshuffle the input data to maximize locality. + if (params.output_depth > 64 || + (params.output_depth <= 64 && params.input_width > 150)) { + for (; out_x <= (end_x - shuffle_params.output_width); + out_x += shuffle_params.output_width) { + const uint8* input_ptr = input_data; + const int32* bias_ptr = bias_data; + const uint8* filter_ptr = filter_data; + uint8* output_ptr = output_data; + int64_t depth = 0; + const int64_t shuffle_row_size = 64 * shuffle_params.input_width; + + for (; depth <= params.output_depth - 64; depth += 64) { + // Preload. + const uint8* h_ptr = input_ptr; + for (int32 i = 0; i < shuffle_params.input_height; i++) { + const uint8* ptr = h_ptr; + for (int32 j = 0; j < shuffle_params.input_width; j++) { + asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :); + ptr += params.input_depth; + } + h_ptr += params.input_row_size; + } + + // For a large enough input, shuffle into buckets. + ShuffleInput(input_ptr, params.input_depth, params.input_width, + params.input_height, 64, shuffle_params.input_width, + shuffle_params.input_height, shuffle_workspace); + ConvKernel::Run(shuffle_workspace, filter_ptr, bias_ptr, output_ptr, + 0, 64, 64, shuffle_row_size, + shuffle_params.output_height, + shuffle_params.output_width, params); + input_ptr += 64; + output_ptr += 64; + filter_ptr += 64; + bias_ptr += 64; } - input_ptr += 64; - } - - // Preload 9x9 input one more time for the rest of the depth. - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 0); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 1); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 2); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 3); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 4); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 5); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 6); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 7); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 8); - - for (; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 4, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 4 * 2 * input_depth; - output_data += 4 * output_depth; - } - -#undef DEPTHWISECONV_PRELOAD_ROW - - // Handle the rest of the right side. - // 4x2 at a time. - for (; out_x <= output_width - 2; out_x += 2) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; + // Preload. + const uint8* h_ptr = input_ptr; + for (int32 i = 0; i < shuffle_params.input_height; i++) { + const uint8* ptr = h_ptr; + for (int32 j = 0; j < shuffle_params.input_width; j++) { + asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :); + ptr += params.input_depth; + } + h_ptr += params.input_row_size; + } - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); + // Handle leftover depth. + ConvKernel::Run(input_ptr, filter_ptr, bias_ptr, output_ptr, + depth, params.output_depth, params.input_depth, + params.input_row_size, shuffle_params.output_height, + shuffle_params.output_width, params); - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; + input_data += + shuffle_params.output_width * kStrideWidth * params.input_depth; + output_data += shuffle_params.output_width * params.output_depth; } - - input_data += 2 * 2 * input_depth; - output_data += 2 * output_depth; } - // 4x1 at a time. - for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 1, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 2 * input_depth; - output_data += output_depth; + const int32 output_leftover_width = end_x - out_x; + if (output_leftover_width > 0) { + ConvKernel::Run(input_data, filter_data, bias_data, output_data, 0, + params.output_depth, params.input_depth, + params.input_row_size, shuffle_params.output_height, + output_leftover_width, params); } } }; -template <> -struct ConvRow3x3FilterDepth8<8, 2, 2> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - // Reuse 4 row kernels twice. - ConvRow3x3FilterDepth8<4, 2, 2>::Run( - input_data, start_x, start_y, input_depth, input_width, input_height, - input_row_size, input_offset, filter_data, filter_offset, bias_data, - output_offset, output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_depth, output_width, - shuffle_workspace); - - ConvRow3x3FilterDepth8<4, 2, 2>::Run( - input_data + 2 * 4 * input_row_size, start_x, start_y + 4, input_depth, - input_width, input_height, input_row_size, input_offset, filter_data, - filter_offset, bias_data, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_data + 4 * output_depth * output_width, output_depth, - output_width, shuffle_workspace); +// Processes the borders of the input for pad_width and pad_height = 1. +// Calls 4 asm kernels: +// * 1x1 input shape. +// * Corner edges. +// * Horizontal edges. +// * Vertical edges. +inline void DepthwiseConvHandlePadding(const uint8* input_data, + const uint8* filter_data, const int32* bias_data, uint8* output_data, + const DepthwiseConvParams& params) { + if (params.input_width == 1 && params.input_height == 1) { + const uint8* filter_ptr = filter_data + params.filter_row_size + + params.output_depth; + DepthwiseConvPartial::Run(input_data, filter_ptr, + bias_data, output_data, ¶ms); + return; } -}; -template <> -struct ConvRow3x3FilterDepth8<8, 1, 1> { - // The buffer size of the shuffled input. - static inline constexpr int ShuffleWorkspaceSize() { return 64 * 10 * 10; } - - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - // 8x8 at a time. - for (; out_x <= output_width - 8; out_x += 8) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - int depth = 0; - for (; depth <= output_depth - 64; depth += 64) { - // For a large input window (64x10x10) that is small enough to fit in L1 - // cache, copy the input into a separate buffer and run the kernel on - // this new buffer. This reduces the likelihood of cache misses when - // the kernel is loading input data. If the size of the input window - // changes, update the function ShuffleWorkspaceSize() with the new - // size. - ShuffleInput(input_ptr, input_depth, input_width, input_height, 64, 10, - 10, shuffle_workspace); - const uint8* shuffled_ptr = shuffle_workspace; - - for (int micro_depth = 0; micro_depth <= 64 - 8; micro_depth += 8) { - ConvKernel3x3FilterDepth8<8, 8, 1, 1>::Run( - shuffled_ptr, 64, input_offset, 64 * 10, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - shuffled_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - input_ptr += 64; - } + const int32 out_x_start_corner = 0; + const int32 out_x_end_corner = params.output_width - 1; + const int32 out_y_start_corner = 0; + const int32 out_y_end_corner = params.output_height - 1; + + // Handle top row. + const uint8* input_ptr = input_data; + const uint8* filter_ptr = filter_data + params.filter_row_size + + params.output_depth; + uint8* output_ptr = output_data; + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); + + input_ptr += (params.stride_width - 1) * params.input_depth; + filter_ptr = filter_data + params.filter_row_size; + output_ptr += params.output_depth; + + for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner; + out_x++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_depth; + output_ptr += params.output_depth; + } - for (; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<8, 8, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); - input_data += 8 * input_depth; - output_data += 8 * output_depth; - } + // Handle left side. + input_ptr = input_data + (params.stride_width - 1) * params.input_row_size; + filter_ptr = filter_data + params.input_depth; + output_ptr = output_data + params.output_row_size; - // Handle the rest of the right side by re-using 4 row kernels twice. - ConvRow3x3FilterDepth8<4, 1, 1>::Run( - input_data, out_x, start_y, input_depth, input_width, input_height, - input_row_size, input_offset, filter_data, filter_offset, bias_data, - output_offset, output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_depth, output_width, - shuffle_workspace); - - ConvRow3x3FilterDepth8<4, 1, 1>::Run( - input_data + 4 * input_row_size, out_x, start_y + 4, input_depth, - input_width, input_height, input_row_size, input_offset, filter_data, - filter_offset, bias_data, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_data + 4 * output_depth * output_width, output_depth, - output_width, shuffle_workspace); + for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner; + out_y++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_row_size; + output_ptr += params.output_row_size; } -}; -inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims, - const Dims<4>& filter_dims, - int stride_width, int stride_height, - int pad_width, int pad_height, - int depth_multiplier, - const Dims<4>& output_dims) { - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int input_depth = ArraySize(input_dims, 0); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - - bool supported = filter_width == 3 && filter_height == 3 && - depth_multiplier == 1 && - (stride_width == 1 || stride_width == 2) && - (stride_height == 1 || stride_height == 2) && - (stride_width == stride_height) && pad_width == 0 && - pad_height == 0 && (input_depth % 8) == 0; + // Handle right side. + input_ptr = input_data + (params.input_width - 2) * params.input_depth + + (params.stride_width - 1) * params.input_row_size; + filter_ptr = filter_data; + output_ptr = output_data + params.output_row_size + + (params.output_width - 1) * params.output_depth; + + for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner; + out_y++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_row_size; + output_ptr += params.output_row_size; + } + + // Handle bottom row. + input_ptr = input_data + (params.input_height - 2) * params.input_row_size; + filter_ptr = filter_data + params.output_depth; + output_ptr = output_data + + (params.output_height - 1) * params.output_row_size; + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); + + input_ptr += (params.stride_width == 1) ? 0 : params.input_depth; + filter_ptr = filter_data; + output_ptr += params.output_depth; + + for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner; + out_x++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_depth; + output_ptr += params.output_depth; + } + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); +} + +inline bool Fast3x3FilterKernelSupported( + const Dims<4>& input_dims, const Dims<4>& filter_dims, int32 stride_width, + int32 stride_height, int32 pad_width, int32 pad_height, + int32 depth_multiplier, const Dims<4>& output_dims, int32 output_shift) { + const int32 input_height = ArraySize(input_dims, 2); + const int32 input_width = ArraySize(input_dims, 1); + const int32 input_depth = ArraySize(input_dims, 0); + const int32 filter_height = ArraySize(filter_dims, 2); + const int32 filter_width = ArraySize(filter_dims, 1); + const int32 output_height = ArraySize(output_dims, 2); + const int32 output_width = ArraySize(output_dims, 1); + + bool supported = + filter_width == 3 && filter_height == 3 && depth_multiplier == 1 && + (stride_width == 1 || stride_width == 2) && + (stride_height == 1 || stride_height == 2) && + (stride_width == stride_height) && (pad_width == 0 || pad_width == 1) && + (pad_height == 0 || pad_height == 1) && (pad_width == pad_height) && + (input_depth % 8) == 0 && (output_shift > 0); if (!supported) { return false; @@ -4436,145 +3201,194 @@ inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims, // Handle case where padding is zero but padding type is not kValid. // This would require special boundary case handling that is not supported. - const int out_x = output_width - 1; - const int out_y = output_height - 1; + const int32 out_x = output_width - 1; + const int32 out_y = output_height - 1; - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int32 in_x_origin = (out_x * stride_width) - pad_width; + const int32 in_y_origin = (out_y * stride_height) - pad_height; - const int in_x_end = in_x_origin + filter_width; - const int in_y_end = in_y_origin + filter_height; + const int32 in_x_end = in_x_origin + filter_width; + const int32 in_y_end = in_y_origin + filter_height; // Supported only if filter on the right and bottom boundary lies completely - // within the input. - return in_x_end <= input_width && in_y_end <= input_height; + // within the input if padding is zero. + if (pad_width == 0 && pad_height == 0) { + return in_x_end <= input_width && in_y_end <= input_height; + } + + // Else if padding is 1, supported if bottom right filter lies +1 past input + // width and height. + supported = in_x_end <= (input_width + 1) && in_y_end <= (input_height + 1); + + if (!supported) { + return false; + } + + // Shapes with width 1 and height > 1, and vice versa are not supported yet. + if (input_width == 1) { + supported = (input_width == input_height); + } else if (input_height == 1) { + supported = (input_width == input_height); + } + return supported; } inline void DepthwiseConv3x3Filter( const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, int stride_width, - int stride_height, int pad_width, int pad_height, int depth_multiplier, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int input_depth = ArraySize(input_dims, 0); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - - // Algorithm assumes below constraints. It is optimized for depth multiplier - // of 1, 3x3 filter, no padding and strides 1 and 2. - TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + const int32* bias_data, const Dims<4>& bias_dims, int32 stride_width, + int32 stride_height, int32 pad_width, int32 pad_height, + int32 depth_multiplier, int32 output_offset, int32 output_multiplier, + int32 output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); + DepthwiseConvParams params; + params.input_depth = ArraySize(input_dims, 0); + params.input_width = ArraySize(input_dims, 1); + params.input_height = ArraySize(input_dims, 2); + params.input_row_size = params.input_depth * params.input_width; + params.input_offset = input_offset; + params.stride_width = stride_width; + params.stride_height = stride_height; + params.output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + params.output_width = ArraySize(output_dims, 1); + params.output_height = ArraySize(output_dims, 2); + params.output_row_size = params.output_depth * params.output_width; + params.output_offset = output_offset; + params.filter_offset = filter_offset; + params.output_multiplier = output_multiplier; + params.output_shift = output_shift; + params.output_activation_min = output_activation_min; + params.output_activation_max = output_activation_max; + + const int32 filter_height = ArraySize(filter_dims, 2); + const int32 filter_width = ArraySize(filter_dims, 1); + params.filter_row_size = params.output_depth * filter_width; + + // Algorithm assumes below constraints. It is optimized for depth + // multiplier of 1, 3x3 filter, no padding and strides 1 and 2. + TFLITE_DCHECK(params.output_depth == params.input_depth * depth_multiplier); TFLITE_DCHECK(depth_multiplier == 1); TFLITE_DCHECK(filter_height == 3); TFLITE_DCHECK(filter_width == 3); - TFLITE_DCHECK(pad_height == 0); - TFLITE_DCHECK(pad_width == 0); TFLITE_DCHECK(stride_height == 1 || stride_height == 2); TFLITE_DCHECK(stride_width == 1 || stride_width == 2); TFLITE_DCHECK(stride_width == stride_height); + TFLITE_DCHECK(pad_height == 0 || pad_height == 1); + TFLITE_DCHECK(pad_width == 0 || pad_width == 1); + TFLITE_DCHECK(pad_width == pad_height); + + const int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int64_t input_batch_size = params.input_row_size * params.input_height; + const int64_t output_batch_size = + params.output_row_size * params.output_height; + + ShuffleParams one_row_shuffle_params, two_row_shuffle_params, + four_row_shuffle_params, eight_row_shuffle_params; + if (stride_width == 1) { + one_row_shuffle_params = ShuffleParams(30, 1, 1, 1); + two_row_shuffle_params = ShuffleParams(22, 2, 1, 1); + four_row_shuffle_params = ShuffleParams(14, 4, 1, 1); + eight_row_shuffle_params = ShuffleParams(8, 8, 1, 1); + } else { + one_row_shuffle_params = ShuffleParams(14, 1, 2, 2); + two_row_shuffle_params = ShuffleParams(8, 2, 2, 2); + four_row_shuffle_params = ShuffleParams(4, 4, 2, 2); + eight_row_shuffle_params = ShuffleParams(2, 8, 2, 2); + } - const int input_row_size = input_depth * (input_width + 2 * pad_width); - const int output_row_size = output_depth * output_width; - const int input_batch_size = input_row_size * (input_height + 2 * pad_height); - const int output_batch_size = output_depth * output_width * output_height; - - using conv_row_func_t = decltype(&ConvRow3x3FilterDepth8<1, 1, 1>::Run); - conv_row_func_t conv_1_output_row = ConvRow3x3FilterDepth8<1, 1, 1>::Run; - conv_row_func_t conv_2_output_rows = ConvRow3x3FilterDepth8<2, 1, 1>::Run; - conv_row_func_t conv_4_output_rows = ConvRow3x3FilterDepth8<4, 1, 1>::Run; - conv_row_func_t conv_8_output_rows = ConvRow3x3FilterDepth8<8, 1, 1>::Run; - + using conv_multirow_func_t = decltype(&DepthwiseConvMultiRow<1, 1>::Run); + conv_multirow_func_t conv_multirow_func = DepthwiseConvMultiRow<1, 1>::Run; if (stride_width == 2) { - conv_1_output_row = ConvRow3x3FilterDepth8<1, 2, 2>::Run; - conv_2_output_rows = ConvRow3x3FilterDepth8<2, 2, 2>::Run; - conv_4_output_rows = ConvRow3x3FilterDepth8<4, 2, 2>::Run; - conv_8_output_rows = ConvRow3x3FilterDepth8<8, 2, 2>::Run; + conv_multirow_func = DepthwiseConvMultiRow<2, 2>::Run; } // Allocate maximum memory needed for shuffled input. // TODO(mariewhite): The size of this workspace is small enough to be // allocated on the stack. Eventually we will want to move it to the heap - // and have it allocated outside of this function, like the im2col_array used - // in gemmlowp. -#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE 10 * 10 * 64 + // and have it allocated outside of this function, like the im2col_array + // used in gemmlowp. uint8 shuffle_workspace[DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE]; - // Make sure the kernels using this buffer will not run out of bounds. - static_assert(ConvRow3x3FilterDepth8<8, 1, 1>::ShuffleWorkspaceSize() <= - DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE, - "Shuffle workspace size is too small."); - static_assert(ConvRow3x3FilterDepth8<4, 2, 2>::ShuffleWorkspaceSize() <= - DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE, - "Shuffle workspace size is too small."); - -#undef DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE - - for (int b = 0; b < batches; ++b) { + for (int32 b = 0; b < batches; ++b) { const uint8* input_ptr = input_data + b * input_batch_size; uint8* output_ptr = output_data + b * output_batch_size; - int out_y = 0; + int32 out_x = 0; + int32 out_y = 0; + int32 end_x = params.output_width; + int32 end_y = params.output_height; + + if (pad_width == 1 && pad_height == 1) { + DepthwiseConvHandlePadding(input_ptr, filter_data, bias_data, output_ptr, + params); + + // Update extents now that the edges have been handled. + out_x = 1; + end_x = params.output_width - 1; + out_y = 1; + end_y = params.output_height - 1; + const int in_x = (out_x * stride_width) - pad_width; + const int in_y = (out_y * stride_height) - pad_height; + input_ptr += in_y * params.input_row_size + in_x * params.input_depth; + output_ptr += out_y * params.output_row_size + + out_x * params.output_depth; + } + + // Shuffling shapes that maximize width over the shuffle workspace size + // perform better since the inputs are closer together, minimizing + // shuffling time. + // + // If the input shape has width large enough for the 2 row kernels, + // we prefer to use this. The innermost loop of the kernels handle + // 2 height x 2 width so this is the fastest path. + // + // If the input shape has smaller width but larger height, shuffling is + // still useful and can benefit from kernels 4 row and 8 row kernels. // Handle 8 rows at a time. - for (; out_y <= output_height - 8; out_y += 8) { - conv_8_output_rows(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, - filter_data, filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += 8 * stride_height * input_row_size; - output_ptr += 8 * output_row_size; + if (params.input_width < four_row_shuffle_params.input_width) { + for (; out_y <= end_y - 8; out_y += 8) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, eight_row_shuffle_params, + shuffle_workspace); + input_ptr += 8 * stride_height * params.input_row_size; + output_ptr += 8 * params.output_row_size; + } } // Handle 4 rows at a time. - for (; out_y <= output_height - 4; out_y += 4) { - conv_4_output_rows(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, - filter_data, filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += 4 * stride_height * input_row_size; - output_ptr += 4 * output_row_size; + if (params.input_width < two_row_shuffle_params.input_width) { + for (; out_y <= end_y - 4; out_y += 4) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, four_row_shuffle_params, + shuffle_workspace); + input_ptr += 4 * stride_height * params.input_row_size; + output_ptr += 4 * params.output_row_size; + } } // Handle 2 rows at a time. - for (; out_y <= output_height - 2; out_y += 2) { - conv_2_output_rows(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, - filter_data, filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += 2 * stride_height * input_row_size; - output_ptr += 2 * output_row_size; + for (; out_y <= end_y - 2; out_y += 2) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, two_row_shuffle_params, + shuffle_workspace); + input_ptr += 2 * stride_height * params.input_row_size; + output_ptr += 2 * params.output_row_size; } // Handle one row at a time. - for (; out_y < output_height; out_y++) { - conv_1_output_row(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, filter_data, - filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += stride_height * input_row_size; - output_ptr += output_row_size; + for (; out_y < end_y; out_y++) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, one_row_shuffle_params, + shuffle_workspace); + input_ptr += stride_height * params.input_row_size; + output_ptr += params.output_row_size; } } } +// clang-format on #endif // __aarch64__ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..7816752132761d9523ffc1f45b3740c0817ed402 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -0,0 +1,324 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ + +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Unoptimized reference ops: +using reference_ops::Relu1; +using reference_ops::Relu6; + +inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { + return RuntimeShape( + {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, + output_data, DimsToShape(output_dims)); +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), beta, output_data, + DimsToShape(output_dims)); +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier, + input_beta_left_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_multiplier, int32 input_left_shift, + int32 reverse_scaling_divisor, + int32 reverse_scaling_right_shift, int diff_min, + uint8* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier, + input_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const int16* input_data, const Dims<4>& input_dims, + int16* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const int16* input_data, const Dims<4>& input_dims, + int input_left_shift, int16* output_data, + const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data, + DimsToShape(output_dims)); +} + +} // namespace optimized_ops +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index 08f7cfa5a5f9453cd187164078898e754126da52..5ba7e2af9b8f2beeee151e219997b68f5c7a6bce 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -162,7 +162,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate( int batch, row, col; for (batch = 0; batch < n_batch; ++batch) { - const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch]; + const float batch_scaling_factor = scaling_factors[batch]; // Copy the vector data to an aligned vector. memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols); // Compute dot-product for every column. @@ -232,7 +232,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate( int32 neon_sum = vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1); - *result += ((neon_sum + postable_sum) * batch_scaling_factor_inv); + *result += ((neon_sum + postable_sum) * batch_scaling_factor); } // for row } // for batch @@ -352,6 +352,30 @@ void NeonSub1Vector(const float* vector, int v_size, float* result) { } } +bool NeonIsZeroVector(const float* vector, int v_size) { + // If v_size is not divisible by kFloatWeightsPerNeonLane, we cannot + // use the main vectorized loop, and we need to process sequentially. + // postamble_start shows the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + const float32x4_t zero_x4_float = vmovq_n_f32(0.0f); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + const float32x4_t i_x4_float = vld1q_f32(vector + v); + uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float); + if (vgetq_lane_u32(cmp_result, 0) == 0) return false; + if (vgetq_lane_u32(cmp_result, 1) == 0) return false; + if (vgetq_lane_u32(cmp_result, 2) == 0) return false; + if (vgetq_lane_u32(cmp_result, 3) == 0) return false; + } + + // Postamble loop + for (int v = postamble_start; v < v_size; ++v) { + if (vector[v] != 0.0) return false; + } + return true; +} + void NeonClipVector(const float* vector, int v_size, float abs_limit, float* result) { // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main @@ -394,13 +418,14 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size, *scaling_factor = 1; return; } - *scaling_factor = kScale / range; + *scaling_factor = range / kScale; + const float scaling_factor_inv = 1.0f / *scaling_factor; const int postamble_start = size - (size & (2 * kFloatWeightsPerNeonLane - 1)); // Vectorized constants. - const float32x4_t q_factor_f32x4 = vmovq_n_f32(*scaling_factor); + const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv); const float32x4_t point5_f32x4 = vmovq_n_f32(0.5); const float32x4_t zero_f32x4 = vmovq_n_f32(0.0); const int32x4_t scale_i32x4 = vmovq_n_s32(kScale); @@ -452,7 +477,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size, for (int i = postamble_start; i < size; ++i) { const int32 quantized_value = - static_cast(TfLiteRound(*scaling_factor * values[i])); + static_cast(TfLiteRound(scaling_factor_inv * values[i])); quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value)); } } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 9e60d0657b49ed5ee1f2e999acb86ebee0eab972..7a5a8fc54123946229963abd1720030d0bb358bf 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -100,6 +100,11 @@ void ZeroVector(float* vector, int v_size) { float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } +// Check if all entries of a vector are zero. +bool IsZeroVector(const float* vector, int v_size) { + return NEON_OR_PORTABLE(IsZeroVector, vector, v_size); +} + void ClipVector(const float* vector, int v_size, float abs_limit, float* result) { NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 5c67066311e62cac4ffbfb271368112be7534580..8597707b24325588b1b4dc4f4ac68ccfa9cecd36 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,16 +40,37 @@ namespace tflite { namespace optimized_ops { // Unoptimized reference ops: +using reference_ops::ArgMax; using reference_ops::BroadcastGreater; using reference_ops::BroadcastGreaterEqual; using reference_ops::BroadcastLess; using reference_ops::BroadcastLessEqual; +using reference_ops::Concatenation; +using reference_ops::DepthConcatenation; +using reference_ops::Dequantize; +using reference_ops::Div; +using reference_ops::FakeQuant; +using reference_ops::Gather; using reference_ops::Greater; using reference_ops::GreaterEqual; using reference_ops::Less; using reference_ops::LessEqual; +using reference_ops::Mean; using reference_ops::RankOneSelect; +using reference_ops::Relu1; +using reference_ops::Relu6; +using reference_ops::ReluX; using reference_ops::Select; +using reference_ops::SpaceToBatchND; +using reference_ops::StridedSlice; +using reference_ops::Transpose; + +// TODO(b/80247582) Remove this constant. +// This will be phased out as the shifts are revised with more thought. Use of a +// constant enables us to track progress on this work. +// +// Used mainly to convert from old-style shifts (right) to new-style (left). +static constexpr int kReverseShift = -1; // Make a local VectorMap typedef allowing to map a float array // as a Eigen vector expression. The std::conditional here is to @@ -65,6 +86,12 @@ using VectorMap = typename std::conditional< Eigen::Dynamic, 1>>, Eigen::Map>>::type; +template +VectorMap MapAsVector(Scalar* data, const RuntimeShape& shape) { + const int size = shape.FlatSize(); + return VectorMap(data, size, 1); +} + template VectorMap MapAsVector(Scalar* data, const Dims& dims) { const int size = FlatSize(dims); @@ -81,6 +108,23 @@ using MatrixMap = typename std::conditional< Eigen::Dynamic, Eigen::Dynamic>>, Eigen::Map>>::type; +template +MatrixMap MapAsMatrixWithLastDimAsRows(Scalar* data, + const RuntimeShape& shape) { + const int dims_count = shape.DimensionsCount(); + const int rows = shape.Dims(dims_count - 1); + const int cols = FlatSizeSkipDim(shape, dims_count - 1); + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithFirstDimAsCols(Scalar* data, + const RuntimeShape& shape) { + const int cols = shape.Dims(0); + const int rows = FlatSizeSkipDim(shape, 0); + return MatrixMap(data, rows, cols); +} + template MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, const Dims& dims) { @@ -127,19 +171,51 @@ template MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, const Dims& dims, int rows) { - int cols = 1; - bool matched_rows = false; - for (int d = 0; d < N; d++) { - cols *= dims.sizes[d]; - if (cols == rows) { - matched_rows = true; - cols = 1; - } - } - TFLITE_DCHECK(matched_rows); + const int flatsize = FlatSize(dims); + TFLITE_DCHECK((flatsize % rows) == 0); + const int cols = flatsize / rows; return MatrixMap(data, rows, cols); } +// This is like the template-parameter version, except that the power-of-two is +// passed as a function parameter. The template version is to be preferred, +// since some target hardware optimizations depend on the range of the exponent. +template +IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { + if (exponent == 0) { + return x; + } + using ScalarIntegerType = + typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; + const IntegerType min = + gemmlowp::Dup(std::numeric_limits::min()); + const IntegerType max = + gemmlowp::Dup(std::numeric_limits::max()); + const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); + + const std::int32_t threshold = + ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); + const IntegerType positive_mask = + gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); + const IntegerType negative_mask = + gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); + + IntegerType result = gemmlowp::ShiftLeft(x, exponent); + result = gemmlowp::SelectUsingMask(positive_mask, max, result); + result = gemmlowp::SelectUsingMask(negative_mask, min, result); + return result; +} + +// This is like the template-parameter version, except that the power-of-two is +// passed as a function parameter. See raw-integer version for further comments. +template +gemmlowp::FixedPoint +SaturatingRoundingMultiplyByPOTParam( + gemmlowp::FixedPoint a, int exponent) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); +} + // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE // BROADCASTING. // @@ -1036,10 +1112,10 @@ struct GemmlowpOutputPipeline { gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> Pipeline; - static Pipeline Make(const int32* bias_data, int output_rows, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max) { + static Pipeline MakeExp(const int32* bias_data, int output_rows, + int32 output_offset, int32 output_multiplier, + int output_left_shift, int32 output_activation_min, + int32 output_activation_max) { ColVectorMap bias_vector(bias_data, output_rows); gemmlowp::OutputStageBiasAddition bias_addition_stage; bias_addition_stage.bias_vector = bias_vector; @@ -1047,7 +1123,7 @@ struct GemmlowpOutputPipeline { quantize_down_stage; quantize_down_stage.result_offset_after_shift = output_offset; quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; - quantize_down_stage.result_shift = output_shift; + quantize_down_stage.result_shift = -output_left_shift; gemmlowp::OutputStageClamp clamp_stage; clamp_stage.min = output_activation_min; clamp_stage.max = output_activation_max; @@ -1100,8 +1176,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, input_data, filter_cols, batches, filter_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, batches, output_rows); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -1210,11 +1286,11 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, } // Internal function doing the actual arithmetic work for -// ExperimentalShuffledFullyConnected. +// ShuffledFullyConnected. // May be called either directly by it (single-threaded case) or may be used // as the 'task' for worker threads to run (multi-threaded case, see -// ExperimentalShuffledFullyConnectedWorkerTask below). -inline void ExperimentalShuffledFullyConnectedWorkerImpl( +// ShuffledFullyConnectedWorkerTask below). +inline void ShuffledFullyConnectedWorkerImpl( const uint8* shuffled_input_workspace_data, const int8* shuffled_weights_data, int batches, int output_depth, int output_stride, int accum_depth, const int32* bias_data, @@ -1488,14 +1564,16 @@ inline void ExperimentalShuffledFullyConnectedWorkerImpl( #endif } -// Wraps ExperimentalShuffledFullyConnectedWorkerImpl into a Task class +// Wraps ShuffledFullyConnectedWorkerImpl into a Task class // to allow using gemmlowp's threadpool. -struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task { - ExperimentalShuffledFullyConnectedWorkerTask( - const uint8* input_data, const int8* shuffled_weights_data, int batches, - int output_depth, int output_stride, int accum_depth, - const int32* bias_data, int32 output_multiplier, int output_shift, - int16* output_data) +struct ShuffledFullyConnectedWorkerTask : gemmlowp::Task { + ShuffledFullyConnectedWorkerTask(const uint8* input_data, + const int8* shuffled_weights_data, + int batches, int output_depth, + int output_stride, int accum_depth, + const int32* bias_data, + int32 output_multiplier, int output_shift, + int16* output_data) : input_data_(input_data), shuffled_weights_data_(shuffled_weights_data), batches_(batches), @@ -1508,7 +1586,7 @@ struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task { output_data_(output_data) {} void Run() override { - ExperimentalShuffledFullyConnectedWorkerImpl( + ShuffledFullyConnectedWorkerImpl( input_data_, shuffled_weights_data_, batches_, output_depth_, output_stride_, accum_depth_, bias_data_, output_multiplier_, output_shift_, output_data_); @@ -1526,15 +1604,14 @@ struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task { int16* output_data_; }; -inline void ExperimentalShuffledFullyConnected( +inline void ShuffledFullyConnected( const uint8* input_data, const Dims<4>& input_dims, const uint8* shuffled_weights_data, const Dims<4>& weights_dims, const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, int output_shift, int32 output_activation_min, int32 output_activation_max, int16* output_data, const Dims<4>& output_dims, uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) { - gemmlowp::ScopedProfilingLabel label( - "ExperimentalShuffledFullyConnected/8bit"); + gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit"); (void)gemm_context; // only used in optimized code. TFLITE_DCHECK_EQ(output_activation_min, -32768); TFLITE_DCHECK_EQ(output_activation_max, 32767); @@ -1618,7 +1695,7 @@ inline void ExperimentalShuffledFullyConnected( if (thread_count == 1) { // Single-thread case: do the computation on the current thread, don't // use a threadpool - ExperimentalShuffledFullyConnectedWorkerImpl( + ShuffledFullyConnectedWorkerImpl( shuffled_input_workspace_data, int8_shuffled_weights_data, batches, output_depth, output_depth, accum_depth, bias_data, output_multiplier, output_shift, output_data); @@ -1633,7 +1710,7 @@ inline void ExperimentalShuffledFullyConnected( int row_start = 0; for (int i = 0; i < thread_count; i++) { int row_end = std::min(output_depth, row_start + kRowsPerWorker); - tasks[i] = new ExperimentalShuffledFullyConnectedWorkerTask( + tasks[i] = new ShuffledFullyConnectedWorkerTask( shuffled_input_workspace_data, int8_shuffled_weights_data + row_start * accum_depth, batches, row_end - row_start, output_depth, accum_depth, bias_data + row_start, @@ -1730,6 +1807,100 @@ inline void ExtractPatchIntoBufferColumn( } } +template +void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, + const Dims<4>& filter_dims, int stride_width, + int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + const Dims<4>& output_dims, uint8 byte_zero, + T* im2col_data) { + // For dilated convolution, the input pixels are not contiguous therefore we + // can't use the same opitimizations as Im2Col(). Though note this code would + // work fine for the non-dilated case too (though likely a bit slower). + gemmlowp::ScopedProfilingLabel label("DilatedIm2col"); + TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK(im2col_data); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + MatchingArraySize(output_dims, 0, filter_dims, 3); + + // Construct the MxN sized im2col matrix. + // The rows M, are sub-ordered B x H x W + Dims<4> row_dims; + row_dims.sizes[0] = output_width; + row_dims.sizes[1] = output_height; + row_dims.sizes[2] = batches; + row_dims.sizes[3] = 1; + ComputeStrides(&row_dims); + + // The columns, N, are sub-ordered Kh x Kw x Din + Dims<4> col_dims; + col_dims.sizes[0] = input_depth; + col_dims.sizes[1] = filter_width; + col_dims.sizes[2] = filter_height; + col_dims.sizes[3] = 1; + ComputeStrides(&col_dims); + + // Use dimensions M and N to construct dims for indexing directly into im2col + Dims<4> im2col_dims; + im2col_dims.sizes[0] = FlatSize(col_dims); + im2col_dims.sizes[1] = FlatSize(row_dims); + im2col_dims.sizes[2] = 1; + im2col_dims.sizes[3] = 1; + ComputeStrides(&im2col_dims); + + // Loop through the output rows (B x H x W) + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + // Each im2col row is an output pixel. Arrange the input data in this + // row in an order we can conveniently multiply with the filter data. + int row_offset = Offset(row_dims, out_x, out_y, batch, 0); + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Loop through all the pixels of the filter (Kh x Kw) + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int in_y = in_y_origin + dilation_height_factor * filter_y; + if ((in_y >= 0) && (in_y < input_height)) { + // Filter row is within the input data. + // Loop through all the filter pixels in this row. + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + dilation_width_factor * filter_x; + int col_offset = Offset(col_dims, 0, filter_x, filter_y, 0); + T* dst = im2col_data + + Offset(im2col_dims, col_offset, row_offset, 0, 0); + if ((in_x >= 0) && (in_x < input_width)) { + // Filter pixel is within the input, copy the input data. + T const* src = + input_data + Offset(input_dims, 0, in_x, in_y, batch); + memcpy(dst, src, input_depth * sizeof(T)); + } else { + // Filter pixel is outside the input, zero it out. + memset(dst, byte_zero, input_depth * sizeof(T)); + } + } + } else { + // Filter row is outside the input, zero out the entire filter row. + int col_offset = Offset(col_dims, 0, 0, filter_y, 0); + T* dst = + im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); + memset(dst, byte_zero, filter_width * input_depth * sizeof(T)); + } + } + } + } + } +} + template void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int kheight, @@ -1770,74 +1941,6 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, kwidth, byte_zero, output_data, output_dims); } -inline void DilatedConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, - int dilation_width_factor, int dilation_height_factor, - int pad_width, int pad_height, - float output_activation_min, - float output_activation_max, float* output_data, - const Dims<4>& output_dims, float* im2col_data, - const Dims<4>& im2col_dims) { - gemmlowp::ScopedProfilingLabel label("DilatedConv"); - // This is a copy of the reference Conv implementation. We do not currently - // have an optimized path for dilation. - (void)im2col_data; // only used in optimized code. - (void)im2col_dims; // only used in optimized code. - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); - const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); - if (bias_data) { - TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0)); - } - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; - float total = 0.f; - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - const int in_x = in_x_origin + dilation_width_factor * filter_x; - const int in_y = - in_y_origin + dilation_height_factor * filter_y; - // If the location is outside the bounds of the input image, - // use zero as a default value. - if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && - (in_y < input_height)) { - float input_value = input_data[Offset(input_dims, in_channel, - in_x, in_y, batch)]; - float filter_value = - filter_data[Offset(filter_dims, in_channel, filter_x, - filter_y, out_channel)]; - total += (input_value * filter_value); - } - } - } - } - float bias_value = 0.0f; - if (bias_data) { - bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; - } - output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = - ActivationFunctionWithMinMax(total + bias_value, - output_activation_min, - output_activation_max); - } - } - } - } -} - inline void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, const float* bias_data, const Dims<4>& bias_dims, @@ -1846,29 +1949,32 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims, float* im2col_data, const Dims<4>& im2col_dims) { - if ((dilation_width_factor != 1) || (dilation_height_factor != 1)) { - return DilatedConv(input_data, input_dims, filter_data, filter_dims, - bias_data, bias_dims, stride_width, stride_height, - dilation_width_factor, dilation_height_factor, pad_width, - pad_height, output_activation_min, output_activation_max, - output_data, output_dims, im2col_data, im2col_dims); - } - (void)im2col_data; (void)im2col_dims; gemmlowp::ScopedProfilingLabel label("Conv"); + // NB: static_cast(0x00000000h) == 0.0f + const uint8 float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const Dims<4>* gemm_input_dims = nullptr; const int filter_width = ArraySize(filter_dims, 1); const int filter_height = ArraySize(filter_dims, 2); + const bool need_dilated_im2col = + dilation_width_factor != 1 || dilation_height_factor != 1; const bool need_im2col = stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1; - if (need_im2col) { + if (need_dilated_im2col) { + DilatedIm2col(input_data, input_dims, filter_dims, stride_width, + stride_height, dilation_width_factor, dilation_height_factor, + pad_width, pad_height, output_dims, float_zero_byte, + im2col_data); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else if (need_im2col) { TFLITE_DCHECK(im2col_data); Im2col(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_height, filter_width, 0, im2col_data, - im2col_dims); + pad_height, filter_height, filter_width, float_zero_byte, + im2col_data, im2col_dims); gemm_input_data = im2col_data; gemm_input_dims = &im2col_dims; } else { @@ -1979,11 +2085,23 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, } const int gemm_input_rows = gemm_input_dims->sizes[0]; - const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0); + // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784). + // The root cause has not yet been identified though. Same applies below for + // the other calls commented out. This is a partial rollback of cl/196819423. + // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0); + const int gemm_input_cols = gemm_input_dims->sizes[1] * + gemm_input_dims->sizes[2] * + gemm_input_dims->sizes[3]; const int filter_rows = filter_dims.sizes[3]; - const int filter_cols = FlatSizeSkipDim(filter_dims, 3); + // See b/79927784. + // const int filter_cols = FlatSizeSkipDim(filter_dims, 3); + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; const int output_rows = output_dims.sizes[0]; - const int output_cols = FlatSizeSkipDim(output_dims, 0); + // See b/79927784. + // const int output_cols = FlatSizeSkipDim(output_dims, 0); + const int output_cols = + output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; TFLITE_DCHECK_EQ(output_rows, filter_rows); TFLITE_DCHECK_EQ(output_cols, gemm_input_cols); TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows); @@ -1997,8 +2115,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, gemm_input_data, gemm_input_rows, gemm_input_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, output_cols); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -2155,8 +2273,8 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, input_data, filter_cols, output_cols, filter_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, output_cols, output_rows); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -2243,52 +2361,29 @@ void GlobalBatchNormalization(const float* input_data, } } -inline void Relu(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Relu(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Relu (not fused)"); - const auto input = MapAsVector(input_data, input_dims); - auto output = MapAsVector(output_data, output_dims); + const auto input = MapAsVector(input_data, input_shape); + auto output = MapAsVector(output_data, output_shape); output = input.cwiseMax(0.0f); } -inline void Relu1(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); - const int flat_size = MatchingFlatSize(input_dims, output_dims); - for (int i = 0; i < flat_size; ++i) { - const float val = input_data[i]; - const float upper = 1; - const float lower = -1; - const float clamped = val > upper ? upper : val < lower ? lower : val; - output_data[i] = clamped; - } -} - -inline void Relu6(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); - const int flat_size = MatchingFlatSize(input_dims, output_dims); - for (int i = 0; i < flat_size; ++i) { - const float val = input_data[i]; - const float upper = 6; - const float lower = 0; - const float clamped = val > upper ? upper : val < lower ? lower : val; - output_data[i] = clamped; - } -} - template -void L2Normalization(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +void L2Normalization(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("L2Normalization"); static_assert(Ac == FusedActivationFunctionType::kNone, ""); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { float squared_l2_norm = 0; for (int c = 0; c < depth; ++c) { - const float val = input_data[depth * i + c]; + const float val = input_data[c]; squared_l2_norm += val * val; } const float l2_norm = std::sqrt(squared_l2_norm); @@ -2300,8 +2395,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, } } -inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, - int* output_shift) { +inline void GetInvSqrtQuantizedMultiplierExp(int32 input, + int32* output_inv_sqrt, + int* output_shift) { *output_shift = 11; while (input >= (1 << 29)) { input /= 4; @@ -2343,34 +2439,41 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, *output_inv_sqrt <<= -*output_shift; *output_shift = 0; } + *output_shift *= kReverseShift; } -inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, +inline void L2Normalization(const uint8* input_data, + const RuntimeShape& input_shape, int32 input_zero_point, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit"); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - TFLITE_DCHECK_EQ(outer_size, 1); - int32 square_l2_norm = 0; - for (int i = 0; i < depth; i++) { - int32 diff = input_data[i] - input_zero_point; - square_l2_norm += diff * diff; - } - int32 inv_l2norm_multiplier; - int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); - - for (int i = 0; i < depth; i++) { - int32 diff = input_data[i] - input_zero_point; - int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( - 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); - int32 unclamped_output_val = 128 + rescaled_diff; - int32 output_val = std::min(255, std::max(0, unclamped_output_val)); - output_data[i] = static_cast(output_val); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + for (int i = 0; i < outer_size; ++i) { + int32 square_l2_norm = 0; + for (int c = 0; c < depth; c++) { + // Note that input_data advances by depth in the second pass below. + int32 diff = input_data[c] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int c = 0; c < depth; c++) { + int32 diff = *input_data - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + *output_data = static_cast(output_val); + ++input_data; + ++output_data; + } } } @@ -2506,14 +2609,19 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data, const int32 input2_val = input2_offset + input2_data[i]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + - output_offset; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + + output_offset; const int32 clamped_output = std::min( output_activation_max, std::max(output_activation_min, raw_output)); output_data[i] = static_cast(clamped_output); @@ -2568,25 +2676,13 @@ inline void Add(int left_shift, const uint8* input1_data, output_activation_max, output_data); } -template inline void Add(const int16* input1_data, const Dims<4>& input1_dims, int input1_shift, const int16* input2_data, const Dims<4>& input2_dims, int input2_shift, int16 output_activation_min, int16 output_activation_max, int16* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Add/Int16"); - // This is a copy of the reference implementation. We do not currently have a - // properly optimized version. - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, -32768); - TFLITE_DCHECK_EQ(output_activation_max, 32767); - } const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); @@ -2612,6 +2708,42 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, } } +inline void Add(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32 output_activation_min, int32 output_activation_max, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/int32"); + + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] + input2_data[i], output_activation_min, + output_activation_max); + } +} + +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims, + input2_shift, output_activation_min, output_activation_max, output_data, + output_dims); +} + template void Add(const int32* input1_data, const Dims<4>& input1_dims, const int32* input2_data, const Dims<4>& input2_dims, @@ -2732,15 +2864,17 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -3081,9 +3215,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, const int32 input2_val = input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; const int32 unclamped_result = - output_offset + - MultiplyByQuantizedMultiplierSmallerThanOne( - input1_val * input2_val, output_multiplier, output_shift); + output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, output_multiplier, + kReverseShift * output_shift); const int32 clamped_output = std::min(output_activation_max, std::max(output_activation_min, unclamped_result)); @@ -3110,19 +3244,6 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, output_data, output_dims); } -// TODO(aselle): This is not actually optimized yet. -inline void Div(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); - for (int i = 0; i < flat_size; i++) { - output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] / input2_data[i], output_activation_min, - output_activation_max); - } -} - // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -3265,15 +3386,17 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sub, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sub, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -3286,105 +3409,6 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, } } -template -void Concatenation(int concat_dim, const Scalar* const* input_data, - const Dims<4>* const* input_dims, int inputs_count, - Scalar* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Concatenation"); - int concat_size = 0; - for (int i = 0; i < inputs_count; i++) { - for (int j = 0; j < 4; j++) { - if (j != concat_dim) { - MatchingArraySize(*input_dims[i], j, output_dims, j); - } - } - concat_size += ArraySize(*input_dims[i], concat_dim); - } - TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - // for now we dont have a model with a Concatenation - // with fused activation function. - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); - int outer_size = 1; - for (int i = concat_dim + 1; i < 4; i++) { - outer_size *= output_dims.sizes[i]; - } - Scalar* output_ptr = output_data; - for (int k = 0; k < outer_size; k++) { - for (int i = 0; i < inputs_count; ++i) { - const int copy_size = - input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; - memcpy(output_ptr, input_data[i] + k * copy_size, - copy_size * sizeof(Scalar)); - output_ptr += copy_size; - } - } -} - -// TODO(prabhumk): This is the same as the reference implementation. -// TODO(prabhumk): The quantized implementation of concatentation isn't fully -// quantized as it takes scale as a floating point value. This should be fixed -// when optimizng this routine further. -inline void Concatenation(int concat_dim, const uint8* const* input_data, - const Dims<4>* const* input_dims, - const int32* input_zeropoint, - const float* input_scale, int inputs_count, - uint8* output_data, const Dims<4>& output_dims, - const int32 output_zeropoint, - const float output_scale) { - // The arguments input_zeropoint and input_scale are expected to be an array - // that have the quantization parameters for all the inputs to the concat - // operator. - gemmlowp::ScopedProfilingLabel label("Concatenation"); - TFLITE_DCHECK_GT(inputs_count, 1); - int concat_size = 0; - for (int i = 0; i < inputs_count; i++) { - for (int j = 0; j < 4; j++) { - if (j != concat_dim) { - MatchingArraySize(*input_dims[i], j, output_dims, j); - } - } - concat_size += ArraySize(*input_dims[i], concat_dim); - } - TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - int outer_size = 1; - for (int i = concat_dim + 1; i < 4; i++) { - outer_size *= output_dims.sizes[i]; - } - const float inverse_output_scale = 1.f / output_scale; - uint8* output_ptr = output_data; - for (int k = 0; k < outer_size; k++) { - for (int i = 0; i < inputs_count; ++i) { - const int copy_size = - input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; - const uint8* input_ptr = input_data[i] + k * copy_size; - if (input_zeropoint[i] == output_zeropoint && - input_scale[i] == output_scale) { - memcpy(output_ptr, input_ptr, copy_size); - } else { - const float scale = input_scale[i] * inverse_output_scale; - const float bias = -input_zeropoint[i] * scale; - for (int j = 0; j < copy_size; ++j) { - const int32_t value = - static_cast(round(input_ptr[j] * scale + bias)) + - output_zeropoint; - output_ptr[j] = - static_cast(std::max(std::min(255, value), 0)); - } - } - output_ptr += copy_size; - } - } -} - -template -void DepthConcatenation(const Scalar* const* input_data, - const Dims<4>* const* input_dims, int inputs_count, - Scalar* output_data, const Dims<4>& output_dims) { - Concatenation(0, input_data, input_dims, inputs_count, - output_data, output_dims); -} - inline void LstmCell(const float* input_data, const Dims<4>& input_dims, const float* prev_activ_data, const Dims<4>& prev_activ_dims, const float* weights_data, @@ -3747,23 +3771,25 @@ inline int NodeOffset(int b, int h, int w, int height, int width) { return (b * height + h) * width + w; } -inline void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int kwidth, int kheight, - float output_activation_min, +inline void AveragePool(const float* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float output_activation_min, float output_activation_max, float* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("AveragePool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); // TODO(benoitjacob) make this a proper reference impl without Eigen! - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // TODO(benoitjacob) get rid of the dynamic memory allocation here! Eigen::VectorXf out_count(out_mat.cols()); out_count.setZero(); @@ -3801,9 +3827,9 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, for (int y = 0; y < output_height; ++y) { for (int x = 0; x < output_width; ++x) { for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + output_data[Offset(output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - output_data[Offset(output_dims, c, x, y, b)], + output_data[Offset(output_shape, b, y, x, c)], output_activation_min, output_activation_max); } } @@ -3811,44 +3837,23 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int kwidth, int kheight, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, float* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, +inline void AveragePool(const uint8* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("AveragePool/8bit"); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -3868,11 +3873,12 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, uint16 acc[kAccBufferMaxSize]; memset(acc, 0, depth * sizeof(acc[0])); const uint8* input_ptr = - input_data + input_dims.strides[1] * in_x_origin + - input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + input_data + + depth * (in_x_origin + + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + - filter_x_start * input_dims.strides[1]; + const uint8* input_row_ptr = + input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { int channel = 0; #ifdef USE_NEON @@ -3903,7 +3909,7 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } uint8* output_ptr = - output_data + Offset(output_dims, 0, out_x, out_y, batch); + output_data + Offset(output_shape, batch, out_y, out_x, 0); int channel = 0; #ifdef USE_NEON #define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ @@ -3944,54 +3950,23 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void MaxPool(const float* input_data, const Dims<4>& input_dims, +inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int kwidth, int kheight, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("MaxPool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // Prefill the output to minimum representable float value out_mat.setConstant(std::numeric_limits::lowest()); for (int b = 0; b < batches; ++b) { @@ -4024,9 +3999,9 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, for (int y = 0; y < output_height; ++y) { for (int x = 0; x < output_width; ++x) { for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + output_data[Offset(output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - output_data[Offset(output_dims, c, x, y, b)], + output_data[Offset(output_shape, b, y, x, c)], output_activation_min, output_activation_max); } } @@ -4034,41 +4009,21 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int kwidth, int kheight, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, +inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("MaxPool/8bit"); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -4086,11 +4041,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, uint8 acc[kAccBufferMaxSize]; memset(acc, 0, depth * sizeof(acc[0])); const uint8* input_ptr = - input_data + input_dims.strides[1] * in_x_origin + - input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + input_data + + depth * (in_x_origin + + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + - filter_x_start * input_dims.strides[1]; + const uint8* input_row_ptr = + input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { int channel = 0; #ifdef USE_NEON @@ -4116,7 +4072,7 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } uint8* output_ptr = - output_data + Offset(output_dims, 0, out_x, out_y, batch); + output_data + Offset(output_shape, batch, out_y, out_x, 0); int channel = 0; #ifdef USE_NEON for (; channel <= depth - 16; channel += 16) { @@ -4143,53 +4099,23 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void L2Pool(const float* input_data, const Dims<4>& input_dims, +inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("L2Pool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); // Actually carry out L2 Pool. Code is written in forward mode: we go through // the input values once, and write to all the pooled regions that it maps to. - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); Eigen::VectorXf in_square(in_mat.rows()); Eigen::VectorXf out_count(out_mat.cols()); out_count.setZero(); @@ -4231,28 +4157,6 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt(); } -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - inline void LocalResponseNormalization(const float* input_data, const Dims<4>& input_dims, int range, float bias, float alpha, float beta, @@ -4298,14 +4202,14 @@ inline void LocalResponseNormalization(const float* input_data, } } -inline void Softmax(const float* input_data, const Dims<4>& input_dims, +inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Softmax"); - MatchingFlatSize(input_dims, output_dims); + MatchingFlatSize(input_shape, output_shape); - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // Compute the exponential first, removing the max coefficient for numerical // stability. out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta; @@ -4317,10 +4221,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, out_mat.array().rowwise() *= scale; } -inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, +inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_beta_multiplier, int32 input_beta_left_shift, int diff_min, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -4334,8 +4238,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPoint0 = gemmlowp::FixedPoint; gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int b = 0; b < outer_size; ++b) { const uint8* input_data_ptr = input_data + b * depth; @@ -4525,11 +4432,14 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, // TODO(myenik): This is the same as the reference implementation, not actually // optimized yet. -inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("LogSoftmax"); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { const float* block_input_data = input_data + i * depth; @@ -4556,12 +4466,125 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, } } +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1_impl( + gemmlowp::FixedPoint input_val) { + // assert(__builtin_clz(0u) >= std::numeric_limits::digits - 1); + // assert(__builtin_clz(0u) <= std::numeric_limits::digits); + using FixedPoint0 = gemmlowp::FixedPoint; + // The reason for accumulating the result with an extra bit of headroom is + // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * + // recip_denom will otherwise introduce an error. + static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; + using FixedPointAccum = gemmlowp::FixedPoint; + + const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1488522236, std::log(2.0)); + const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); + const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1518500250, std::sqrt(0.5)); + const FixedPoint0 one_quarter = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); + + const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1057819769, + 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); + + const FixedPointAccum shifted_quarter = + gemmlowp::Rescale(one_quarter); + + // Reinterpret the input value as Q0.31, because we will figure out the + // required shift "ourselves" instead of using, say, Rescale. + FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); + // z_a_pow_2 = input_integer_bits - z_a_headroom; + int z_a_headroom_plus_1 = __builtin_clz(static_cast(z_a.raw())); + FixedPoint0 r_a_tmp = + SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); + const int32 r_a_raw = + SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); + // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); + // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, + // InputIntegerBits - z_b_headroom - 0.25); + const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), + shifted_quarter); + + // z_b is treated like z_a, but premultiplying by sqrt(0.5). + FixedPoint0 z_b = z_a * sqrt_half; + int z_b_headroom = __builtin_clz(static_cast(z_b.raw())) - 1; + const int32 r_b_raw = + SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); + const FixedPointAccum z_b_pow_2_adj = SaturatingSub( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), + shifted_quarter); + + const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); + const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( + std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); + + const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); + FixedPoint0 q = r - sqrt_sqrt_half; + q = q + q; + + const FixedPoint0 common_sq = q * q; + const FixedPoint0 num = q * r + q * common_sq * alpha_n; + const FixedPoint0 denom_minus_one_0 = + p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; + const FixedPoint0 recip_denom = + one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); + + const FixedPointAccum num_scaled = gemmlowp::Rescale(num); + return gemmlowp::Rescale(z_pow_2_adj * log_2 + + num_scaled * recip_denom); +} + +// Minimum output bits to accommodate log of maximum input range. It actually +// does not matter if one considers, say, [-64,64] or [-64,64). +// +// For example, run this through Octave: +// [0:127; ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] +constexpr int min_log_x_output_bits(int input_bits) { + return input_bits > 90 + ? 7 + : input_bits > 44 + ? 6 + : input_bits > 21 + ? 5 + : input_bits > 10 + ? 4 + : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; +} + +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1( + gemmlowp::FixedPoint input_val) { + static_assert( + OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), + "Output integer bits must be sufficent to accommodate logs of inputs."); + return log_x_for_x_greater_than_or_equal_to_1_impl( + input_val); +} + // Currently just a copy of the reference code. -inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, +inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_multiplier, int32 input_left_shift, int32 reverse_scaling_divisor, int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8"); // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as @@ -4576,8 +4599,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { const uint8* block_input_data = input_data + i * depth; @@ -4601,13 +4627,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } - // TODO(b/77858996): Implement fixed-point log(). - // Not a fully-quantized implementation: floating-point log(). - const float float_log_sum_of_exps = - std::log(static_cast(sum_of_exps.raw()) / - (1 << (31 - kAccumulationIntegerBits))); - const int32 fixed_log_sum_of_exps = static_cast(TfLiteRound( - float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits)))); + const int32 fixed_log_sum_of_exps = + log_x_for_x_greater_than_or_equal_to_1( + sum_of_exps) + .raw(); // rescaled_diff_min is smallest representable in // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the @@ -4618,9 +4641,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, fixed_log_sum_of_exps + std::numeric_limits::lowest(); const int adjusted_diff_min = std::max(diff_min - 1, // Note use of > below instead of >= above. - MultiplyByQuantizedMultiplierSmallerThanOne( + MultiplyByQuantizedMultiplierSmallerThanOneExp( rescaled_diff_min, reverse_scaling_divisor, - reverse_scaling_right_shift)); + kReverseShift * reverse_scaling_right_shift)); for (int c = 0; c < depth; ++c) { int32 input_diff = static_cast(block_input_data[c]) - max_in_row; @@ -4644,21 +4667,21 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Logistic(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic"); - auto input_map = MapAsVector(input_data, input_dims); - auto output_map = MapAsVector(output_data, output_dims); + auto input_map = MapAsVector(input_data, input_shape); + auto output_map = MapAsVector(output_data, output_shape); output_map.array() = input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op()); } -inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, +inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic/Uint8"); - const int size = MatchingFlatSize(input_dims, output_dims); + const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; #ifdef USE_NEON @@ -4790,10 +4813,10 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const int16* input_data, const Dims<4>& input_dims, - int16* output_data, const Dims<4>& output_dims) { +inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, + int16* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic/Int16"); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { } @@ -4850,21 +4873,21 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Tanh(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Tanh"); - auto input_map = MapAsVector(input_data, input_dims); - auto output_map = MapAsVector(output_data, output_dims); + auto input_map = MapAsVector(input_data, input_shape); + auto output_map = MapAsVector(output_data, output_shape); output_map.array() = input_map.array().tanh(); } -inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, +inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { // Note that this is almost the exact same code as in Logistic(). gemmlowp::ScopedProfilingLabel label("Tanh"); - const int size = MatchingFlatSize(input_dims, output_dims); + const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; int32_t output_zero_point = 128; @@ -5005,16 +5028,16 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const int16* input_data, const Dims<4>& input_dims, +inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, int input_left_shift, int16* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Tanh/Int16"); // Support for shifts is limited until we have a parameterized version of // SaturatingRoundingMultiplyByPOT(). TFLITE_DCHECK_GE(input_left_shift, 0); TFLITE_DCHECK_LE(input_left_shift, 1); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); int c = 0; const int16* input_data_ptr = input_data; @@ -5105,49 +5128,6 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims, } } -inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, - int32 zero_point, double scale, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Dequantize"); - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; ++i) { - int32 val = input_data[i]; - float result = static_cast(scale * (val - zero_point)); - output_data[i] = result; - } -} - -inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, - float rmin, float rmax, int num_bits, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("FakeQuant"); - - // 0 should always be a representable value. Let's assume that the initial - // min,max range contains 0. - TFLITE_DCHECK_LE(rmin, 0.0f); - TFLITE_DCHECK_GE(rmax, 0.0f); - TFLITE_DCHECK_LT(rmin, rmax); - - // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor. - int quant_min = 0; - int quant_max = (1 << num_bits) - 1; - float nudged_min, nudged_max, nudged_scale; - NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, - &nudged_max, &nudged_scale); - const float inv_nudged_scale = 1.0f / nudged_scale; - - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; ++i) { - const float src_val = input_data[i]; - const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); - const float clamped_shifted = clamped - nudged_min; - const float dst_val = - TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + - nudged_min; - output_data[i] = dst_val; - } -} - template inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data, const Dims<4>& output_dims) { @@ -5165,26 +5145,6 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims, output_map.array() = Eigen::floor(input_map.array()); } -template -inline void Gather(const T* input_data, const Dims<4>& input_dims, - int input_rank, const int32* coords_data, - const Dims<4>& coords_dims, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Gather"); - - TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); - int stride = input_dims.strides[input_rank - 1]; - T* out = output_data; - - for (int i = 0; i < coords_dims.sizes[0]; i++) { - TFLITE_DCHECK_GE(coords_data[i], 0); - TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); - const T* in = input_data + coords_data[i] * stride; - memcpy(out, in, sizeof(T) * stride); - out += stride; - } -} - #ifdef USE_NEON inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, float scale, float* output_ptr) { @@ -5513,6 +5473,46 @@ inline void ResizeBilinearGeneric(const float* input_data, } } +template +inline void ResizeBilinearGenericSmallChannel( + const T* input_data, const Dims<4>& input_dims, T* output_data, + const Dims<4>& output_dims, int32 batches, int32 input_height, + int32 input_width, int32 depth, int32 output_height, int32 output_width, + float height_scale, float width_scale) { + memset(output_data, 0, + batches * output_height * output_width * depth * sizeof(T)); + + T* output_ptr = &output_data[0]; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(input_x); + int32 x1 = std::min(x0 + 1, input_width - 1); + + int32 input_offset[4] = { + Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b), + Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)}; + float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)), + (1 - (input_y - y0)) * (input_x - x0), + (input_y - y0) * (1 - (input_x - x0)), + (input_y - y0) * (input_x - x0)}; + + for (int d = 0; d < depth; d++) { + const T* input_ptr = &input_data[d]; + *output_ptr++ = static_cast(input_ptr[input_offset[0]] * scale[0] + + input_ptr[input_offset[1]] * scale[1] + + input_ptr[input_offset[2]] * scale[2] + + input_ptr[input_offset[3]] * scale[3]); + } + } + } + } +} + inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, @@ -5553,6 +5553,41 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8 +// or int16 arithmetic. +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims, bool align_corners) { + gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + + float height_scale = + (align_corners && output_height > 1) + ? (static_cast(input_height - 1) / (output_height - 1)) + : (static_cast(input_height) / output_height); + + float width_scale = + (align_corners && output_width > 1) + ? (static_cast(input_width - 1) / (output_width - 1)) + : (static_cast(input_width) / output_width); + + ResizeBilinearGenericSmallChannel( + input_data, input_dims, output_data, output_dims, batches, input_height, + input_width, depth, output_height, output_width, height_scale, + width_scale); +} + // legacy, for compatibility with old checked-in code inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, @@ -5562,53 +5597,13 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, output_data, output_dims, /*align_corners=*/false); } -template -inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, - const Dims<4>& block_shape_dims, - const int32* paddings_data, - const Dims<4>& paddings_dims, T* output_data, +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, const Dims<4>& output_dims) { - // Unoptimized - Straight copy from reference ops. - gemmlowp::ScopedProfilingLabel label("SpaceToBatchND"); - - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); - const int block_shape_height = block_shape_data[0]; - const int block_shape_width = block_shape_data[1]; - const int padding_top = paddings_data[0]; - const int padding_left = paddings_data[2]; - - for (int out_b = 0; out_b < output_batch_size; ++out_b) { - int input_batch = out_b % input_batch_size; - int shift_w = (out_b / input_batch_size) % block_shape_width; - int shift_h = (out_b / input_batch_size) / block_shape_width; - for (int out_h = 0; out_h < output_height; ++out_h) { - for (int out_w = 0; out_w < output_width; ++out_w) { - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); - if (out_h * block_shape_height + shift_h < padding_top || - out_h * block_shape_height + shift_h >= - padding_top + input_height || - out_w * block_shape_width + shift_w < padding_left || - out_w * block_shape_width + shift_w >= padding_left + input_width) { - memset(out, 0, depth * sizeof(T)); - } else { - const T* in = - input_data + - Offset(input_dims, 0, - (out_w * block_shape_width + shift_w) - padding_left, - (out_h * block_shape_height + shift_h) - padding_top, - input_batch); - memcpy(out, in, depth * sizeof(T)); - } - } - } - } + ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); } // Helper methods for BatchToSpaceND. @@ -5813,54 +5808,6 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, output_dims, 0); } -// UNOPTIMIZED COPY of StridedSlice from reference_ops.h. -template -inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, - const std::vector& start_indices, - const std::vector& stop_indices, - const std::vector& strides, T* output_data, - const Dims<4>& output_dims) { - TFLITE_DCHECK_EQ(start_indices.size(), 4); - TFLITE_DCHECK_EQ(stop_indices.size(), 4); - TFLITE_DCHECK_EQ(strides.size(), 4); - const int start_b = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 3); - const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 3); - const int start_h = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 2); - const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 2); - const int start_w = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 1); - const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 1); - const int start_d = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 0); - const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 0); - - T* out_ptr = output_data; - for (int in_b = start_b; - !strided_slice::LoopCondition(in_b, stop_b, strides[3]); - in_b += strides[3]) { - for (int in_h = start_h; - !strided_slice::LoopCondition(in_h, stop_h, strides[2]); - in_h += strides[2]) { - for (int in_w = start_w; - !strided_slice::LoopCondition(in_w, stop_w, strides[1]); - in_w += strides[1]) { - for (int in_d = start_d; - !strided_slice::LoopCondition(in_d, stop_d, strides[0]); - in_d += strides[0]) { - *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; - } - } - } - } -} - template inline void Slice(const T* input_data, const Dims<4>& input_dims, const std::vector& begin, const std::vector& size, @@ -5895,41 +5842,6 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } } -template -inline void Mean(const T* input_data, const Dims<4>& input_dims, - const std::vector& reduction_indices, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Mean"); - const int output_batch = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int output_depth = ArraySize(output_dims, 0); - - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - - // The current implementation only supports simultaneous reduction over - // width and height. - TFLITE_DCHECK_EQ(reduction_indices.size(), 2); - TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || - (reduction_indices[0] == 2 && reduction_indices[1] == 1)); - TFLITE_DCHECK_EQ(output_height, 1); - TFLITE_DCHECK_EQ(output_width, 1); - - for (int out_b = 0; out_b < output_batch; ++out_b) { - for (int out_d = 0; out_d < output_depth; ++out_d) { - float value = 0; - for (int in_h = 0; in_h < input_height; ++in_h) { - for (int in_w = 0; in_w < input_width; ++in_w) { - value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; - } - } - output_data[Offset(output_dims, out_d, 0, 0, out_b)] = - value / (input_width * input_height); - } - } -} - template void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, @@ -6009,130 +5921,84 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, output_map.array() = input1_map.array().max(max_value); } -template -void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, - T2* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("ArgMax"); - - // The current ArgMax implemention can only determine the index of the maximum - // value in the last dimension. So the axis argument is ignored. - - // For ArgMax, the number of output dimensions = (number of input dimensions - - // 1). For the sake of simplicity, the output dimensions are equal to the - // input dimensions here. We enforce the constraint that the last dimension - // must always be 1. - TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = ArraySize(input_dims, 0); - for (int i = 0; i < outer_size; ++i) { - auto max_value = *input_data; - ++input_data; - int max_index = 0; - for (int d = 1; d < depth; ++d) { - const auto& curr_value = *input_data; - if (curr_value > max_value) { - max_value = curr_value; - max_index = d; - } - ++input_data; - } - *output_data = max_index; - ++output_data; - } -} - template -void Transpose(const T* input, const Dims<4>& input_dims, T* output, - const Dims<4>& output_dims, const int* permuted_axes) { - int out_sizes[4]; - // Compute the inverse permutation array so we can do an output centered - // transpose. Also, check to make sure output_dims is matching input_dims. - for (int k = 0; k < 4; k++) { - out_sizes[k] = - MatchingArraySize(input_dims, permuted_axes[k], output_dims, k); - } - - // Naive transpose loop (iterate on output index and compute input index). - int o[4]; // loop index (on output). - int i[4]; - for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) { - i[permuted_axes[3]] = o[3]; - for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) { - i[permuted_axes[2]] = o[2]; - for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) { - i[permuted_axes[1]] = o[1]; - for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) { - i[permuted_axes[0]] = o[0]; - output[Offset(output_dims, o)] = input[Offset(input_dims, i)]; - } - } - } - } -} +void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, + const Dims<4>& filter_dims, int stride_width, + int stride_height, int pad_width, int pad_height, + const Dims<4>& output_dims, uint8 zero_byte, + T* im2col_data) { + gemmlowp::ScopedProfilingLabel label("TransposeIm2col"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK(im2col_data); -inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("TransposeConv"); - // THIS FUNCTION IS A COPY FROM reference_ops.h. - // To optimize, start by using the conv code with transposed weights for the - // case of stride_height = stride_width = 1. const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); const int filter_height = ArraySize(filter_dims, 2); const int filter_width = ArraySize(filter_dims, 1); const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); - - // Although transpose convolution simplifies to convolution with transposed - // weights for strides of 1, non-unitary striding complicates matters. To - // keep this reference implementation as clear as possible, we use a "scatter" - // access pattern, where we loop through all the input elements, computing - // their influence on the output, rather than looping through the output - // elements in the typical "gather" access pattern of a conv. We therefore - // must initialize the output array to zero. - for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = - 0.0f; - } - } - } - } - - // Loop through input elements one at a time. + MatchingArraySize(output_dims, 0, filter_dims, 0); // output_depth + + // Construct the MxN sized im2col matrix. + // The rows M, are sub-ordered B x H x W + Dims<4> row_dims; + row_dims.sizes[0] = output_width; + row_dims.sizes[1] = output_height; + row_dims.sizes[2] = batches; + row_dims.sizes[3] = 1; + ComputeStrides(&row_dims); + + // The columns, N, are sub-ordered Kh x Kw x Din + Dims<4> col_dims; + col_dims.sizes[0] = input_depth; + col_dims.sizes[1] = filter_width; + col_dims.sizes[2] = filter_height; + col_dims.sizes[3] = 1; + ComputeStrides(&col_dims); + + // Use dimensions M and N to construct dims for indexing directly into im2col + Dims<4> im2col_dims; + im2col_dims.sizes[0] = FlatSize(col_dims); + im2col_dims.sizes[1] = FlatSize(row_dims); + im2col_dims.sizes[2] = 1; + im2col_dims.sizes[3] = 1; + ComputeStrides(&im2col_dims); + + // Build the im2col matrix by looping through all the input pixels, + // computing their influence on the output, rather than looping through all + // the output pixels. We therefore must initialize the im2col array to zero. + // This is potentially inefficient because we subsequently overwrite bytes + // set here. However, in practice memset is very fast and costs negligible. + memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T)); + + // Loop through the output batches for (int batch = 0; batch < batches; ++batch) { + // Loop through input pixels one at a time. for (int in_y = 0; in_y < input_height; ++in_y) { for (int in_x = 0; in_x < input_width; ++in_x) { - for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - // Loop through the output elements it will influence - const int out_x_origin = (in_x * stride_width) - pad_width; - const int out_y_origin = (in_y * stride_height) - pad_height; - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + // Loop through the output pixels it will influence + const int out_x_origin = (in_x * stride_width) - pad_width; + const int out_y_origin = (in_y * stride_height) - pad_height; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int out_y = out_y_origin + filter_y; + // Is output pixel within height bounds? + if ((out_y >= 0) && (out_y < output_height)) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - for (int out_channel = 0; out_channel < output_depth; - ++out_channel) { - // Compute output element location - const int out_x = out_x_origin + filter_x; - const int out_y = out_y_origin + filter_y; - // We cannot accumulate out of bounds - if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) && - (out_y < output_height)) { - float input_value = input_data[Offset(input_dims, in_channel, - in_x, in_y, batch)]; - float filter_value = - filter_data[Offset(filter_dims, out_channel, filter_x, - filter_y, in_channel)]; - output_data[Offset(output_dims, out_channel, out_x, out_y, - batch)] += input_value * filter_value; - } + const int out_x = out_x_origin + filter_x; + // Is output pixel within width bounds? + if ((out_x >= 0) && (out_x < output_width)) { + // Copy the input elements of this pixel + T const* src = + input_data + Offset(input_dims, 0, in_x, in_y, batch); + T* dst = im2col_data + + Offset(im2col_dims, + Offset(col_dims, 0, filter_x, filter_y, 0), + Offset(row_dims, out_x, out_y, batch, 0), 0, 0); + memcpy(dst, src, input_depth * sizeof(T)); } } } @@ -6142,6 +6008,31 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } } +inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + gemmlowp::ScopedProfilingLabel label("TransposeConv"); + + // Note we could use transposed weights with forward conv for unstrided + // cases. But we are already getting good performance with this code as-is. + TFLITE_DCHECK(im2col_data); + TransposeIm2col(input_data, input_dims, filter_dims, stride_width, + stride_height, pad_width, pad_height, output_dims, 0, + im2col_data); + + const auto im2col_matrix_map = + MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map); +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index d570dadd86b4dc7c3abe341a4955320367330b9c..f14667090f5c3867c7992211272063239f3b92aa 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -127,6 +127,10 @@ void PortableZeroVector(float* vector, int v_size); // Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); +// Check if all entries of a vector are zero. +bool PortableIsZeroVector(const float* vector, int v_size); +bool NeonIsZeroVector(const float* vector, int v_size); + // Symmetric quantizer. void PortableSymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index b0951aac8cbb98a181d9dcaef88770fadfc74f62..e224980493aa11f642da103ee7d7377b6c4b1da0 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include #include #include @@ -48,15 +49,15 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier, TFLITE_CHECK_GE(*left_shift, 0); } -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift) { +void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift) { TFLITE_CHECK_LT(double_multiplier, 1.); TFLITE_CHECK_GT(double_multiplier, 0.); int shift; QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); TFLITE_CHECK_LE(shift, 0); - *right_shift = -shift; + *left_shift = shift; } void PreprocessSoftmaxScaling(double beta, double input_scale, @@ -78,20 +79,21 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, quantized_multiplier, left_shift); } -void PreprocessLogSoftmaxScaling(double beta, double input_scale, - int input_integer_bits, - int32_t* quantized_multiplier, int* left_shift, - int32_t* reverse_scaling_divisor, - int* reverse_scaling_right_shift) { +void PreprocessLogSoftmaxScalingExp(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, + int* left_shift, + int32_t* reverse_scaling_divisor, + int* reverse_scaling_left_shift) { PreprocessSoftmaxScaling(beta, input_scale, input_integer_bits, quantized_multiplier, left_shift); // Also calculate what amounts to the inverse scaling factor for the input. const double real_reverse_scaling_divisor = (1 << (31 - *left_shift)) / static_cast(*quantized_multiplier); - tflite::QuantizeMultiplierSmallerThanOne(real_reverse_scaling_divisor, - reverse_scaling_divisor, - reverse_scaling_right_shift); + tflite::QuantizeMultiplierSmallerThanOneExp(real_reverse_scaling_divisor, + reverse_scaling_divisor, + reverse_scaling_left_shift); } int CalculateInputRadius(int input_integer_bits, int input_left_shift) { @@ -125,4 +127,16 @@ void NudgeQuantizationRange(const float min, const float max, *nudged_max = (quant_max_float - nudged_zero_point) * (*scale); } +bool CheckedLog2(const float x, int* log2_result) { + // Using TfLiteRound instead of std::round and std::log instead of + // std::log2 to work around these fuctions being missing in a toolchain + // used in some TensorFlow tests as of May 2018. + const float x_log2 = std::log(x) * (1.0f / std::log(2.0f)); + const float x_log2_rounded = TfLiteRound(x_log2); + const float x_log2_fracpart = x_log2 - x_log2_rounded; + + *log2_result = static_cast(x_log2_rounded); + return std::abs(x_log2_fracpart) < 1e-3; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index 4a217515f142b2451ebd61e423871b95cdc09748..525857a2e6f73276d0a6e64770947169033c7667 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -167,9 +167,9 @@ IntOut SafeCast(FloatIn x) { // this is intended as a RIGHT-shift. // // Restricted to the case where the multiplier < 1 (and non-negative). -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift); +void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift); // Decompose a double multiplier into a Q0.31 int32 representation of its // significand, and shift representation of its exponent. @@ -197,11 +197,12 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, int input_integer_bits, int32_t* quantized_multiplier, int* left_shift); // Like PreprocessSoftmaxScaling, but inverse scaling factors also calculated. -void PreprocessLogSoftmaxScaling(double beta, double input_scale, - int input_integer_bits, - int32_t* quantized_multiplier, int* left_shift, - int32_t* reverse_scaling_divisor, - int* reverse_scaling_right_shift); +void PreprocessLogSoftmaxScalingExp(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, + int* left_shift, + int32_t* reverse_scaling_divisor, + int* reverse_scaling_left_shift); // Calculate the largest input that will result in a within-bounds intermediate // result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words, // it must not overflow before we reduce the value by multiplication by the @@ -217,6 +218,11 @@ void NudgeQuantizationRange(const float min, const float max, const int quant_min, const int quant_max, float* nudged_min, float* nudged_max, float* scale); +// If x is approximately a power of two (with any positive or negative +// exponent), stores that exponent (i.e. log2(x)) in *log2_result, otherwise +// returns false. +bool CheckedLog2(const float x, int* log2_result); + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index 2d74b3d3849812a2dc95fabcd680aa280c99ca55..94773b47d3817d7ed7240f74545ad04e7fa4bd52 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -196,21 +196,21 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) { EXPECT_DEATH(ChooseQuantizationParams(10.0, -30.0), ""); } -TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { +TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) { auto quantize = [](double d) { int32_t q; int s; - QuantizeMultiplierSmallerThanOne(d, &q, &s); + QuantizeMultiplierSmallerThanOneExp(d, &q, &s); return std::pair{q, s}; }; EXPECT_DEATH(quantize(-0.1), ""); EXPECT_DEATH(quantize(0.0), ""); - EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); + EXPECT_THAT(quantize(0.25), Pair(1073741824, -1)); // Around 0.5 we can see the change in exponent and how we try hard to // void hitting max int32. - EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1)); + EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, -1)); EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0)); EXPECT_THAT(quantize(0.50), Pair(1073741824, 0)); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h index e9b6baeaee87d22aef238410bc9f447509a81c47..d57739279f44ad2a9fff2bd4dca21047b8147f2a 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -76,8 +76,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne( - acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, + -output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..878b2441b4f2828a014673f5bd80fb8aa29514db --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h @@ -0,0 +1,332 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ + +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +namespace reference_ops { + +inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { + return RuntimeShape( + {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, + output_data, DimsToShape(output_dims)); +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu1(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu6(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), beta, output_data, + DimsToShape(output_dims)); +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier, + input_beta_left_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_multiplier, int32 input_left_shift, + int32 reverse_scaling_divisor, + int32 reverse_scaling_right_shift, int diff_min, + uint8* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier, + input_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const int16* input_data, const Dims<4>& input_dims, + int16* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const int16* input_data, const Dims<4>& input_dims, + int input_left_shift, int16* output_data, + const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data, + DimsToShape(output_dims)); +} + +} // namespace reference_ops +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index 2607adc0c18aeaa8dc2061e0e95a307205700a08..ccf112c990f3b5cba755a9b29aadd5aa82104849 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -29,9 +29,18 @@ float PortableClip(float f, float abs_limit) { return result; } +bool PortableIsZeroVector(const float* vector, int v_size) { + for (int i = 0; i < v_size; ++i) { + if (*vector++ != 0.0f) return false; + } + return true; +} + void PortableSymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, - float* max, float* scaling_factor) { + int8_t* quantized_values, + float* __restrict__ min, + float* __restrict__ max, + float* __restrict__ scaling_factor) { auto minmax = std::minmax_element(values, values + size); *min = *minmax.first; *max = *minmax.second; @@ -42,10 +51,11 @@ void PortableSymmetricQuantizeFloats(const float* values, const int size, *scaling_factor = 1; return; } - *scaling_factor = kScale / range; + *scaling_factor = range / kScale; + const float scaling_factor_inv = 1.0f / *scaling_factor; for (int i = 0; i < size; ++i) { const int32_t quantized_value = - static_cast(TfLiteRound(*scaling_factor * values[i])); + static_cast(TfLiteRound(values[i] * scaling_factor_inv)); // Clamp: just in case some odd numeric offset. quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value)); } @@ -71,13 +81,14 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, void PortableMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, float* __restrict__ result, int result_stride) { + const int8_t* __restrict__ vectors, + const float* __restrict__ scaling_factors, int n_batch, + float* __restrict__ result, int result_stride) { int batch, row, col; for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) { - const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch]; + const float batch_scaling_factor = scaling_factors[batch]; // Get the address of the first row. - int8_t* row_ptr = (int8_t*)matrix; // NOLINT + const int8_t* row_ptr = matrix; for (row = 0; row < m_rows; ++row, result += result_stride) { // Initialize the dot product sum for the row to 0. int32_t dotprod = 0; @@ -88,7 +99,7 @@ void PortableMatrixBatchVectorMultiplyAccumulate( for (col = 0; col < m_cols; ++col, ++row_ptr) { dotprod += (*row_ptr) * (vectors[col]); } // for col - *result += (dotprod * batch_scaling_factor_inv); + *result += (dotprod * batch_scaling_factor); } // for row } // for batch } diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index 1757a9f5e5299401fb9fef1d38870bd4e63ba3c1..d2e1fecd25cf3d11d3daffcc566dc1d5df97128c 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -25,6 +25,8 @@ namespace tensor_utils { // Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); +bool PortableIsZeroVector(const float* vector, int v_size); + void PortableSymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, float* max, float* scaling_factor); @@ -112,6 +114,10 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector, float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } +bool IsZeroVector(const float* vector, int v_size) { + return PortableIsZeroVector(vector, v_size); +} + void SymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, float* max, float* scaling_factor) { diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 36a92a21a03044624e348f0ba1066d2d2ec14e8f..9357e7407eb83fe8ea3486dfdde8742fc6323ee9 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -33,8 +33,131 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { + +// TODO(b/77858996): Add these to gemmlowp. +template +IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + return a; +} + +template <> +inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t sum = a64 + b64; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max( + static_cast(std::numeric_limits::min()), + sum))); +} + +template +gemmlowp::FixedPoint SaturatingAddNonGemmlowp( + gemmlowp::FixedPoint a, + gemmlowp::FixedPoint b) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingAddNonGemmlowp(a.raw(), b.raw())); +} + +template +IntegerType SaturatingSub(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + return a; +} + +template <> +inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) { + std::int32_t a32 = a; + std::int32_t b32 = b; + std::int32_t diff = a32 - b32; + return static_cast(std::min(32767, std::max(-32768, diff))); +} + +template <> +inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t diff = a64 - b64; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max( + static_cast(std::numeric_limits::min()), + diff))); +} + +template +gemmlowp::FixedPoint SaturatingSub( + gemmlowp::FixedPoint a, + gemmlowp::FixedPoint b) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingSub(a.raw(), b.raw())); +} +// End section to be moved to gemmlowp. + namespace reference_ops { +// TODO(b/80247582) Remove this constant. +// This will be phased out as the shifts are revised with more thought. Use of a +// constant enables us to track progress on this work. +// +// Used mainly to convert from old-style shifts (right) to new-style (left). +static constexpr int kReverseShift = -1; + +template +int CountLeadingZeros(T integer_input) { + static_assert(std::is_unsigned::value, + "Only unsigned integer types handled."); + if (integer_input == 0) { + return std::numeric_limits::digits; + } + const T one_in_leading_positive = static_cast(1) + << (std::numeric_limits::digits - 1); + int leading_zeros = 0; + while (integer_input < one_in_leading_positive) { + integer_input <<= 1; + ++leading_zeros; + } + return leading_zeros; +} + +template +IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { + if (exponent == 0) { + return x; + } + using ScalarIntegerType = + typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; + const IntegerType min = + gemmlowp::Dup(std::numeric_limits::min()); + const IntegerType max = + gemmlowp::Dup(std::numeric_limits::max()); + const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); + + const std::int32_t threshold = + ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); + const IntegerType positive_mask = + gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); + const IntegerType negative_mask = + gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); + + IntegerType result = gemmlowp::ShiftLeft(x, exponent); + result = gemmlowp::SelectUsingMask(positive_mask, max, result); + result = gemmlowp::SelectUsingMask(negative_mask, min, result); + return result; +} + +// If we want to leave IntegerBits fixed, then multiplication +// by a power of two has to be saturating/rounding, not exact anymore. +template +gemmlowp::FixedPoint +SaturatingRoundingMultiplyByPOTParam( + gemmlowp::FixedPoint a, int exponent) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); +} + // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE // BROADCASTING. // @@ -291,8 +414,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne( - acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( + acc, output_multiplier, kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -515,8 +638,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier, - output_shift); + acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( + acc, output_multiplier, kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -574,7 +697,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, } } -inline void ExperimentalShuffledFullyConnected( +inline void ShuffledFullyConnected( const uint8* input_data, const Dims<4>& input_dims, const uint8* shuffled_weights_data, const Dims<4>& weights_dims, const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, @@ -791,9 +914,9 @@ void GlobalBatchNormalization(const float* input_data, } } -inline void Relu(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float lower = 0; @@ -802,9 +925,10 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims, } } -inline void Relu1(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu1(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float upper = 1; @@ -814,9 +938,10 @@ inline void Relu1(const float* input_data, const Dims<4>& input_dims, } } -inline void Relu6(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu6(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float upper = 6; @@ -826,12 +951,28 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims, } } +inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data, + const RuntimeShape& input_shape, uint8* output_data, + const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const uint8 val = input_data[i]; + const uint8 clamped = + val > max_value ? max_value : val < min_value ? min_value : val; + output_data[i] = clamped; + } +} + template -void L2Normalization(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +void L2Normalization(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { static_assert(Ac == FusedActivationFunctionType::kNone, ""); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { float squared_l2_norm = 0; for (int c = 0; c < depth; ++c) { @@ -845,8 +986,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, } } -inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, - int* output_shift) { +inline void GetInvSqrtQuantizedMultiplierExp(int32 input, + int32* output_inv_sqrt, + int* output_shift) { *output_shift = 11; while (input >= (1 << 29)) { input /= 4; @@ -888,39 +1030,45 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, *output_inv_sqrt <<= -*output_shift; *output_shift = 0; } + *output_shift *= kReverseShift; } -inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, +inline void L2Normalization(const uint8* input_data, + const RuntimeShape& input_shape, int32 input_zero_point, uint8* output_data, - const Dims<4>& output_dims) { - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - TFLITE_DCHECK_EQ(outer_size, 1); - int32 square_l2_norm = 0; - for (int i = 0; i < depth; i++) { - int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; - square_l2_norm += diff * diff; - } - int32 inv_l2norm_multiplier; - int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); - - for (int i = 0; i < depth; i++) { - int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; - int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( - 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); - int32 unclamped_output_val = 128 + rescaled_diff; - int32 output_val = std::min(255, std::max(0, unclamped_output_val)); - output_data[Offset(output_dims, i, 0, 0, 0)] = - static_cast(output_val); + const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + for (int i = 0; i < outer_size; ++i) { + int32 square_l2_norm = 0; + for (int c = 0; c < depth; c++) { + int32 diff = input_data[depth * i + c] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int c = 0; c < depth; c++) { + int32 diff = input_data[depth * i + c] - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[depth * i + c] = static_cast(output_val); + } } } -inline void Add(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +template +inline void Add(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( @@ -979,15 +1127,17 @@ inline void Add(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1000,22 +1150,12 @@ inline void Add(int left_shift, const uint8* input1_data, } } -template inline void Add(const int16* input1_data, const Dims<4>& input1_dims, int input1_shift, const int16* input2_data, const Dims<4>& input2_dims, int input2_shift, int16 output_activation_min, int16 output_activation_max, int16* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, -32768); - TFLITE_DCHECK_EQ(output_activation_max, 32767); - } const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); @@ -1041,6 +1181,28 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, } } +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims, + input2_shift, output_activation_min, output_activation_max, output_data, + output_dims); +} + // TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -1133,15 +1295,17 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1186,15 +1350,17 @@ inline void BroadcastAddFivefold( const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1374,9 +1540,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, const int32 input2_val = input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; const int32 unclamped_result = - output_offset + - MultiplyByQuantizedMultiplierSmallerThanOne( - input1_val * input2_val, output_multiplier, output_shift); + output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, output_multiplier, + kReverseShift * output_shift); const int32 clamped_output = std::min(output_activation_max, std::max(output_activation_min, unclamped_result)); @@ -1590,15 +1756,17 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sub, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sub, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1615,7 +1783,6 @@ template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, Scalar* output_data, const Dims<4>& output_dims) { - TFLITE_DCHECK_GT(inputs_count, 1); int concat_size = 0; for (int i = 0; i < inputs_count; i++) { for (int j = 0; j < 4; j++) { @@ -1626,7 +1793,9 @@ void Concatenation(int concat_dim, const Scalar* const* input_data, concat_size += ArraySize(*input_dims[i], concat_dim); } TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + // For now we don't have a model with a Concatenation with fused activation. + TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone); int outer_size = 1; for (int i = concat_dim + 1; i < 4; i++) { outer_size *= output_dims.sizes[i]; @@ -1793,7 +1962,7 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, // The quantization of the input, output arrays is as follows: // - The input activations are quantized as uint8 on the interval // [-1, 127/128]. -// The rationale for that is that that is the natural interval for output +// The rationale for that is that is the natural interval for output // activations (see next point) and these need to be concatenated together. // We could accommodate different ranges by re-scaling, but we empirically // found that setting the input activations range to be [-1, 127/128] in the @@ -1858,7 +2027,7 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, // However, for a fixed-point implementation in 16-bit integers, using 5 // integer bits to represent the [-16, 16] range would leave only 11 // fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive -// representable values. Notice that that is higher than the +// representable values. Notice that is higher than the // worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic. // Using [-8, 8] thus seems like the better compromise overall, enjoying // an increment of 2.4e-4 between representable values and a worst-case @@ -2104,18 +2273,21 @@ inline int NodeOffset(int b, int h, int w, int height, int width) { return (b * height + h) * width + w; } -inline void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, +inline void AveragePool(const float* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float output_activation_min, float output_activation_max, float* output_data, - const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + const RuntimeShape& output_shape) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2139,12 +2311,12 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; total += - input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; filter_count++; } } const float average = total / filter_count; - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = ActivationFunctionWithMinMax(average, output_activation_min, output_activation_max); } @@ -2153,42 +2325,22 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, float* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, +inline void AveragePool(const uint8* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2211,14 +2363,15 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, ++filter_x) { const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; - acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + acc += + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; filter_count++; } } acc = (acc + filter_count / 2) / filter_count; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = static_cast(acc); } } @@ -2226,50 +2379,19 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void L2Pool(const float* input_data, const Dims<4>& input_dims, +inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + float* output_data, const RuntimeShape& output_shape) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2293,13 +2415,13 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; const float val = - input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; sum_squares += val * val; filter_count++; } } const float l2pool_result = std::sqrt(sum_squares / filter_count); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = ActivationFunctionWithMinMax(l2pool_result, output_activation_min, output_activation_max); } @@ -2308,40 +2430,19 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const float* input_data, const Dims<4>& input_dims, +inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + float* output_data, const RuntimeShape& output_shape) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2365,10 +2466,10 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, const int in_y = in_y_origin + filter_y; max = std::max( max, - input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + input_data[Offset(input_shape, batch, in_y, in_x, channel)]); } } - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = ActivationFunctionWithMinMax(max, output_activation_min, output_activation_max); } @@ -2377,42 +2478,22 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, +inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { TFLITE_DCHECK_LE(output_activation_min, output_activation_max); TFLITE_DCHECK_GE(output_activation_min, 0); TFLITE_DCHECK_LE(output_activation_max, 255); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2436,12 +2517,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, const int in_y = in_y_origin + filter_y; max = std::max( max, - input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + input_data[Offset(input_shape, batch, in_y, in_x, channel)]); } } max = std::max(max, output_activation_min); max = std::min(max, output_activation_max); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = static_cast(max); } } @@ -2449,38 +2530,6 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - inline void LocalResponseNormalization(const float* input_data, const Dims<4>& input_dims, int range, float bias, float alpha, float beta, @@ -2504,11 +2553,14 @@ inline void LocalResponseNormalization(const float* input_data, } } -inline void Softmax(const float* input_data, const Dims<4>& input_dims, +inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, - const Dims<4>& output_dims) { - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { // Find max element value which we'll use to ensure numerical stability @@ -2533,10 +2585,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, } } -inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, +inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_beta_multiplier, int32 input_beta_left_shift, int diff_min, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -2549,8 +2601,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { uint8 max_in_row = 0; @@ -2611,10 +2666,13 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); +inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { // Find max element value which we'll use to ensure numerical stability @@ -2639,11 +2697,126 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, } } -inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, +// Although currently the name of this function says that it cannot handle +// values less than 1, in practice it can handle as low as 1/x_max, where +// x_max is the largest representable input. In other words, the output range +// is symmetric. +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1_impl( + gemmlowp::FixedPoint input_val) { + using FixedPoint0 = gemmlowp::FixedPoint; + // The reason for accumulating the result with an extra bit of headroom is + // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * + // recip_denom will otherwise introduce an error. + static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; + using FixedPointAccum = gemmlowp::FixedPoint; + + const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1488522236, std::log(2.0)); + const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); + const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1518500250, std::sqrt(0.5)); + const FixedPoint0 one_quarter = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); + + const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1057819769, + 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); + + const FixedPointAccum shifted_quarter = + gemmlowp::Rescale(one_quarter); + + // Reinterpret the input value as Q0.31, because we will figure out the + // required shift "ourselves" instead of using, say, Rescale. + FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); + // z_a_pow_2 = input_integer_bits - z_a_headroom; + int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); + FixedPoint0 r_a_tmp = + SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); + const int32 r_a_raw = + SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); + // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); + // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, + // InputIntegerBits - z_b_headroom - 0.25); + const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), + shifted_quarter); + + // z_b is treated like z_a, but premultiplying by sqrt(0.5). + FixedPoint0 z_b = z_a * sqrt_half; + int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; + const int32 r_b_raw = + SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); + const FixedPointAccum z_b_pow_2_adj = SaturatingSub( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), + shifted_quarter); + + const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); + const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( + std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); + + const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); + FixedPoint0 q = r - sqrt_sqrt_half; + q = q + q; + + const FixedPoint0 common_sq = q * q; + const FixedPoint0 num = q * r + q * common_sq * alpha_n; + const FixedPoint0 denom_minus_one_0 = + p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; + const FixedPoint0 recip_denom = + one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); + + const FixedPointAccum num_scaled = gemmlowp::Rescale(num); + return gemmlowp::Rescale(z_pow_2_adj * log_2 + + num_scaled * recip_denom); +} + +// Minimum output bits to accommodate log of maximum input range. It actually +// does not matter if one considers, say, [-64,64] or [-64,64). +// +// For example, run this through Octave: +// [0:127; ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] +constexpr int min_log_x_output_bits(int input_bits) { + return input_bits > 90 + ? 7 + : input_bits > 44 + ? 6 + : input_bits > 21 + ? 5 + : input_bits > 10 + ? 4 + : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; +} + +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1( + gemmlowp::FixedPoint input_val) { + static_assert( + OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), + "Output integer bits must be sufficent to accommodate logs of inputs."); + return log_x_for_x_greater_than_or_equal_to_1_impl( + input_val); +} + +inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_multiplier, int32 input_left_shift, int32 reverse_scaling_divisor, int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -2657,8 +2830,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { uint8 max_in_row = 0; @@ -2681,13 +2857,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } - // TODO(b/77858996): Implement fixed-point log(). - // Not a fully-quantized implementation: floating-point log(). - const float float_log_sum_of_exps = - std::log(static_cast(sum_of_exps.raw()) / - (1 << (31 - kAccumulationIntegerBits))); - const int32 fixed_log_sum_of_exps = static_cast(TfLiteRound( - float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits)))); + const int32 fixed_log_sum_of_exps = + log_x_for_x_greater_than_or_equal_to_1( + sum_of_exps) + .raw(); // rescaled_diff_min is smallest representable in // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the @@ -2698,9 +2871,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, fixed_log_sum_of_exps + std::numeric_limits::lowest(); const int adjusted_diff_min = std::max(diff_min - 1, // Note use of > below instead of >= above. - MultiplyByQuantizedMultiplierSmallerThanOne( + MultiplyByQuantizedMultiplierSmallerThanOneExp( rescaled_diff_min, reverse_scaling_divisor, - reverse_scaling_right_shift)); + kReverseShift * reverse_scaling_right_shift)); for (int c = 0; c < depth; ++c) { int32 input_diff = @@ -2725,9 +2898,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Logistic(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { float val = input_data[i]; @@ -2736,11 +2909,11 @@ inline void Logistic(const float* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, +inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); + uint8* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { const uint8 input_val_u8 = input_data[i]; @@ -2774,9 +2947,9 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const int16* input_data, const Dims<4>& input_dims, - int16* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, + int16* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -2792,9 +2965,9 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Tanh(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { float val = input_data[i]; @@ -2803,12 +2976,12 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, +inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { const int32 output_zero_point = 128; - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { const uint8 input_val_u8 = input_data[i]; @@ -2843,15 +3016,15 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const int16* input_data, const Dims<4>& input_dims, +inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, int input_left_shift, int16* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // Support for shifts is limited until we have a parameterized version of // SaturatingRoundingMultiplyByPOT(). TFLITE_DCHECK_GE(input_left_shift, 0); TFLITE_DCHECK_LE(input_left_shift, 1); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); // F0 uses 0 integer bits, range [-1, 1]. // This is the return type of math functions such as tanh, logistic, @@ -2956,9 +3129,10 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, } } -inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, +template +inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims, const int32* output_size_data, - const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_size_dims, T* output_data, const Dims<4>& output_dims, bool align_corners) { int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); @@ -2990,15 +3164,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 x0 = static_cast(std::floor(input_x)); int32 x1 = std::min(x0 + 1, input_width - 1); for (int c = 0; c < depth; ++c) { - float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] * - (1 - (input_y - y0)) * - (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x0, y1, b)] * - (input_y - y0) * (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x1, y0, b)] * - (1 - (input_y - y0)) * (input_x - x0) + - input_data[Offset(input_dims, c, x1, y1, b)] * - (input_y - y0) * (input_x - x0); + T interpolation = + static_cast(input_data[Offset(input_dims, c, x0, y0, b)] * + (1 - (input_y - y0)) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x0, y1, b)] * + (input_y - y0) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x1, y0, b)] * + (1 - (input_y - y0)) * (input_x - x0) + + input_data[Offset(input_dims, c, x1, y1, b)] * + (input_y - y0) * (input_x - x0)); output_data[Offset(output_dims, c, x, y, b)] = interpolation; } } @@ -3011,8 +3185,18 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, const Dims<4>& output_dims) { - ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, - output_data, output_dims, /*align_corners=*/false); + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); +} + +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); } template @@ -3172,7 +3356,7 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, template inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, + int begin_mask, int end_mask, int shrink_axis_mask, const std::vector& start_indices, const std::vector& stop_indices, const std::vector& strides, T* output_data, @@ -3184,20 +3368,24 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, TFLITE_DCHECK_EQ(strides.size(), 4); const int start_b = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 3); - const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 3); + const int stop_b = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 3, start_b); const int start_h = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 2); - const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 2); + const int stop_h = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 2, start_h); const int start_w = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 1); - const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 1); + const int stop_w = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 1, start_w); const int start_d = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 0); - const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 0); + const int stop_d = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 0, start_d); T* out_ptr = output_data; for (int in_b = start_b; @@ -3259,63 +3447,152 @@ inline void Exp(const T* input_data, const size_t num_elements, } } +// A generic reduce method that can be used for reduce_sum, reduce_mean, etc. +// This method iterates through input data and reduce elements along the +// dimensions given in axis. +template +inline bool Reduce(const In* input_data, const int* input_dims, + const int* output_dims, const int input_num_dims, + const int output_num_dims, const int* axis, + const int num_axis, int* input_iter, + Out reducer(Out current, const In in), Out* output_data) { + // Reset input iterator. + TFLITE_DCHECK(input_num_dims > 0); + for (int idx = 0; idx < input_num_dims; ++idx) { + input_iter[idx] = 0; + } + // Iterate through input_data. + do { + size_t input_offset = + ReducedOutputOffset(input_num_dims, input_dims, input_iter, 0, nullptr); + size_t output_offset = ReducedOutputOffset(input_num_dims, input_dims, + input_iter, num_axis, axis); + output_data[output_offset] = + reducer(output_data[output_offset], input_data[input_offset]); + } while (NextIndex(input_num_dims, input_dims, input_iter)); + return true; +} + +inline bool ResolveAxis(const int num_dims, const int* axis, const int num_axis, + int* out_axis, int* out_num_axis) { + *out_num_axis = 0; // Just in case. + // o(n^2) is fine since out_num_axis should be really small, mostly <= 4 + for (int idx = 0; idx < num_axis; ++idx) { + // Handle negative index. + int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx]; + TFLITE_DCHECK(current >= 0 && current < num_dims); + bool is_dup = false; + for (int j = 0; j < *out_num_axis; ++j) { + if (out_axis[j] == current) { + is_dup = true; + break; + } + } + if (!is_dup) { + out_axis[*out_num_axis] = current; + *out_num_axis += 1; + } + } + return true; +} + +// This method expects that output_data has been initialized. +template +inline bool ReduceSumImpl(const In* input_data, const int* input_dims, + const int* output_dims, const int input_num_dims, + const int output_num_dims, const int* axis, + const int num_axis, int* input_iter, + Out* output_data) { + auto reducer = [](Out current, const In in) -> Out { + const Out actual_in = static_cast(in); + return current + actual_in; + }; + return Reduce(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, axis, num_axis, input_iter, reducer, + output_data); +} + +// Computes the sum of elements across dimensions given in axis. +template +inline bool Sum(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis) { + // Reset output data. + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + size_t current = static_cast(output_dims[idx]); + // Overflow prevention. + if (num_outputs > std::numeric_limits::max() / current) { + return false; + } + num_outputs *= current; + } + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = T(); + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + return ReduceSumImpl(input_data, input_dims, output_dims, + input_num_dims, output_num_dims, resolved_axis, + num_resolved_axis, temp_index, output_data); +} + +// Computes the mean of elements across dimensions given in axis. +// It does so in two stages, first calculates the sum of elements along the axis +// then divides it by the number of element in axis. template inline bool Mean(const T* input_data, const int* input_dims, const int input_num_dims, T* output_data, const int* output_dims, const int output_num_dims, const int* axis, const int num_axis_dimensions, bool keep_dims, int* temp_index, int* resolved_axis, U* temp_sum) { - // resets output data. + // Reset output data. size_t num_outputs = 1; for (int idx = 0; idx < output_num_dims; ++idx) { - num_outputs *= static_cast(output_dims[idx]); + size_t current = static_cast(output_dims[idx]); + // Overflow prevention. + if (num_outputs > std::numeric_limits::max() / current) { + return false; + } + num_outputs *= current; } for (size_t idx = 0; idx < num_outputs; ++idx) { output_data[idx] = T(); temp_sum[idx] = U(); } - // resets temp index. - for (int idx = 0; idx < input_num_dims; ++idx) { - temp_index[idx] = 0; - } - // resolves axis. + + // Resolve axis. int num_resolved_axis = 0; - for (int idx = 0; idx < num_axis_dimensions; ++idx) { - int current = axis[idx]; - TFLITE_DCHECK(current < input_num_dims && current + input_num_dims >= 0); - if (current < 0) { - current += input_num_dims; - } - bool is_dup = false; - for (int j = 0; j < num_resolved_axis; ++j) { - if (resolved_axis[j] == current) { - is_dup = true; - break; - } - } - if (!is_dup) { - resolved_axis[num_resolved_axis++] = current; - } + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; } - // iterates through input_data. - for (bool has_next = true; has_next; - has_next = NextIndex(input_num_dims, input_dims, temp_index)) { - size_t input_offset = - ReducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr); - size_t output_offset = - ReducedOutputOffset(input_num_dims, input_dims, temp_index, - num_resolved_axis, resolved_axis); - temp_sum[output_offset] += static_cast(input_data[input_offset]); - } - // takes average by num of elements added to get mean. - size_t num_elements_in_axis = 1; + + if (!ReduceSumImpl(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, temp_sum)) { + return false; + } + + // Calculate mean by dividing output_data by num of aggregated element. + U num_elements_in_axis = 1; for (int idx = 0; idx < num_resolved_axis; ++idx) { size_t current = static_cast(input_dims[resolved_axis[idx]]); + // Overflow prevention. if (current > (std::numeric_limits::max() / num_elements_in_axis)) { return false; } num_elements_in_axis *= current; } + if (num_elements_in_axis > 0) { for (size_t idx = 0; idx < num_outputs; ++idx) { output_data[idx] = @@ -3470,7 +3747,7 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, template void Transpose(const T* input, const Dims<4>& input_dims, T* output, - const Dims<4>& output_dims, int* permuted_axes) { + const Dims<4>& output_dims, const int* permuted_axes) { int out_sizes[4]; // Compute the inverse permutation array so we can do an output centered // transpose. Also, check to make sure output_dims is matching input_dims. @@ -3501,10 +3778,11 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, int stride_width, int stride_height, int pad_width, int pad_height, float* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, float* /*im2col_data*/, + const Dims<4>& /*im2col_dims*/) { const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); const int filter_height = ArraySize(filter_dims, 2); @@ -3519,7 +3797,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, // computing their influence on the output, rather than looping through the // output elements in the typical "gather" access pattern of a conv. We // therefore must initialize the output array to zero. - for (int i = 0; i < FlatSize(output_dims); i++) { + const int num_elements = FlatSize(output_dims); + for (int i = 0; i < num_elements; i++) { output_data[i] = 0.0f; } @@ -3544,8 +3823,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, float input_value = input_data[Offset(input_dims, in_channel, in_x, in_y, batch)]; float filter_value = - filter_data[Offset(filter_dims, out_channel, filter_x, - filter_y, in_channel)]; + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] += input_value * filter_value; } @@ -3558,6 +3837,16 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } } +template +inline bool EqualFn(T lhs, T rhs) { + return lhs == rhs; +} + +template +inline bool NotEqualFn(T lhs, T rhs) { + return lhs != rhs; +} + template inline bool GreaterFn(T lhs, T rhs) { return lhs > rhs; @@ -3604,10 +3893,14 @@ inline void Comparison(int left_shift, const T* input1_data, const int32 input2_val = input2_offset + input2_data[i]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); output_data[i] = F(scaled_input1_val, scaled_input2_val); } } @@ -3656,11 +3949,13 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); output_data[Offset(output_dims, c, x, y, b)] = F(scaled_input1_val, scaled_input2_val); } @@ -3715,6 +4010,8 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, input2_offset, input2_multiplier, \ input2_shift, output_data, output_dims); \ } +TFLITE_COMPARISON_OP(Equal); +TFLITE_COMPARISON_OP(NotEqual); TFLITE_COMPARISON_OP(Greater); TFLITE_COMPARISON_OP(GreaterEqual); TFLITE_COMPARISON_OP(Less); @@ -3754,6 +4051,72 @@ inline void RankOneSelect(const D* input_condition_data, } } +// For easy implementation, the indices is always a vector of size-4 vectors. +template +inline void SparseToDense(const std::vector>& indices, + const T* values, T default_value, T* output_data, + const Dims<4>& output_dims, bool value_is_scalar) { + const int value_count = indices.size(); + + // First fill the output_data with default value. + const int num_elements = FlatSize(output_dims); + for (int i = 0; i < num_elements; ++i) { + output_data[i] = default_value; + } + + // Special handle for value is scalar case to avoid checking the boolean + // condition within the loop every time. + if (value_is_scalar) { + for (int i = 0; i < value_count; ++i) { + const std::vector& index = indices[i]; + TFLITE_DCHECK_EQ(index.size(), 4); + const T value = *values; // just use the first value. + output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = + value; + } + return; + } + + // Go through the values and indices to fill the sparse values. + for (int i = 0; i < value_count; ++i) { + const std::vector& index = indices[i]; + TFLITE_DCHECK_EQ(index.size(), 4); + const T value = values[i]; + output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = + value; + } +} + +template +inline void Pow(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = std::pow(input1_data[i], input2_data[i]); + } +} + +template +inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + std::pow(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d8765f11b2941ef5871c7db8e3582e506713aa6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace { +template +void TestOneResizeBilinear(int batch, int depth, int input_width, + int input_height, int output_width, + int output_height, float error_threshold) { + Dims<4> input_dims_inference = + MakeDimsForInference(depth, input_width, input_height, batch); + Dims<4> output_dims_inference = + MakeDimsForInference(depth, output_width, output_height, batch); + + const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference); + const int output_buffer_size = + RequiredBufferSizeForDims(output_dims_inference); + + std::vector input_data(input_buffer_size, 0); + std::vector reference_output_data(output_buffer_size, 0); + // Initialize the output data with something other than zero, so we can catch + // issue with kernels failing to initialize the output. + std::vector output_data(output_buffer_size, 3); + + const T min_amplitude = static_cast(0); + const T max_amplitude = static_cast(255); + FillRandom(&input_data, min_amplitude, max_amplitude); + + Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1); + std::vector output_size_data = {output_height, output_width}; + + reference_ops::ResizeBilinear( + input_data.data(), input_dims_inference, output_size_data.data(), + output_size_dims, reference_output_data.data(), output_dims_inference); + optimized_ops::ResizeBilinear(input_data.data(), input_dims_inference, + output_size_data.data(), output_size_dims, + output_data.data(), output_dims_inference); + + double sum_diff = 0; + float max_abs_val = 0; + for (int i = 0; i < output_buffer_size; i++) { + sum_diff += std::abs(static_cast(output_data[i]) - + static_cast(reference_output_data[i])); + max_abs_val = std::max( + max_abs_val, std::abs(static_cast(reference_output_data[i]))); + } + + if (sum_diff != 0.f) { + const float mean_diff = static_cast(sum_diff / output_buffer_size); + const float relative_error = std::abs(mean_diff) / max_abs_val; + ASSERT_LT(relative_error, error_threshold); + } +} + +TEST(ResizeBilinear, TestResizeBilinear8Bit) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 0.025); + } +} + +TEST(ResizeBilinear2x2, TestResizeBilinear8Bit) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = input_width * 2; + const int output_height = input_height * 2; + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); + } +} + +TEST(ResizeBilinear, TestResizeBilinear) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); + } +} + +TEST(ResizeBilinear2x2, TestResizeBilinear) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = input_width * 2; + const int output_height = input_height * 2; + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a7dad3c14e60fac9da9c0bcfd5d1d4c8f10b71c7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc @@ -0,0 +1,229 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" + +namespace tflite { +namespace { + +void RunSoftmaxFloatReference(const uint8* input_data, + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, + uint8* reference_output_data) { + const int ref_buffer_size = shape_common.FlatSize(); + std::vector reference_dequant_data(ref_buffer_size); + std::vector reference_output_float_data(ref_buffer_size); + + // Reference data generated via Dequant of input into float, and then applying + // float Softmax. + reference_ops::Dequantize( + input_data, ToRuntimeDims(shape_common), input_offset, input_scale, + reference_dequant_data.data(), ToRuntimeDims(shape_common)); + optimized_ops::Softmax(reference_dequant_data.data(), shape_common, beta, + reference_output_float_data.data(), shape_common); + // Work with quantized scaling for Softmax, under which 256 represents 1, but + // we limit this to 255. + for (int i = 0; i < ref_buffer_size; i++) { + reference_output_data[i] = std::min( + 255, + static_cast(std::round(256.0f * reference_output_float_data[i]))); + } +} + +void CheckOutputData(const uint8* test_output, const uint8* reference_output, + const RuntimeShape& shape_common, + const string& check_label, bool be_exacting) { + const int buffer_size = shape_common.FlatSize(); + // While calculating some metrics in floating point, we work with quantized + // scaling. + std::vector diff(buffer_size); + int64_t sum_diff = 0; + int64_t sum_abs_diff = 0; + for (int i = 0; i < buffer_size; i++) { + diff[i] = static_cast(test_output[i]) - reference_output[i]; + sum_diff += diff[i]; + sum_abs_diff += std::abs(diff[i]); + } + // These stats help understand test failures. + std::sort(std::begin(diff), std::end(diff)); + const int min_diff = diff.front(); + const int max_diff = diff.back(); + const int median_diff = diff[diff.size() / 2]; + const float mean_diff = static_cast(sum_diff) / buffer_size; + const float mean_abs_diff = static_cast(sum_abs_diff) / buffer_size; + // We either check for bit exactness (against the reference quantized version) + // or for general accuracy, allowing off-by-one (against the float reference). + if (be_exacting) { + ASSERT_TRUE(std::abs(min_diff) == 0 && std::abs(max_diff) == 0); + } else { + // For small numbers of samples, the estimates of the means vary more. + // Rather than widen the tolerances, we skip the smaller tests. + ASSERT_TRUE(((std::abs(mean_diff) < 2e-2f && mean_abs_diff < 3e-2f) || + buffer_size < 10000) && + std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 && + std::abs(max_diff) <= 1); + } +} + +// Runs the Softmax and compares against the float reference implementation and +// the quantized reference implementation. +void RunOneSoftmaxTest(const uint8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); + std::vector optimized_softmax_output(buffer_size); + std::vector reference_float_softmax_output(buffer_size); + std::vector reference_quant_softmax_output(buffer_size); + + RunSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, + stride, beta, reference_float_softmax_output.data()); + + int32 input_beta_multiplier; + int input_beta_left_shift; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessSoftmaxScaling(beta, input_scale, kScaledDiffIntegerBits, + &input_beta_multiplier, + &input_beta_left_shift); + // diff_min has a negative value, and is used to limit the maximum magnitude + // of the diffs, which are <= 0. + const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, + input_beta_left_shift); + + optimized_ops::Softmax(input_data, shape_common, input_beta_multiplier, + input_beta_left_shift, diff_min, + optimized_softmax_output.data(), shape_common); + reference_ops::Softmax(input_data, shape_common, input_beta_multiplier, + input_beta_left_shift, diff_min, + reference_quant_softmax_output.data(), shape_common); + + CheckOutputData(optimized_softmax_output.data(), + reference_float_softmax_output.data(), shape_common, + "Optimized vs float reference", false); + CheckOutputData(optimized_softmax_output.data(), + reference_quant_softmax_output.data(), shape_common, + "Optimized vs quant reference", true); + CheckOutputData(reference_quant_softmax_output.data(), + reference_float_softmax_output.data(), shape_common, + "Quant reference vs float reference", false); +} + +// This function picks some random Softmax params, which are checked for +// desirability. If not acceptable, it returns false. If they're OK, +// it runs the Softmax test and returns true. This allows the caller +// to loop until a test has been run. +// +// Currently we do not reject for any reason. +bool TryOneUniformSoftmax() { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // Softmax, the width and height really just create test repetitions. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); + + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); + + std::vector input_data(buffer_size); + FillRandom(&input_data); + RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, + stride, beta); + return true; +} + +// See TryOneUniformSoftmax() for a general description. +// +// Tests with "skyscraper" input patterns are included for two reasons. (a) +// Bimodal distributions are potentially challenging and perhaps more +// realistic than simple uniform random inputs. (b) Some implementations of +// Softmax may adapt as they traverse the depth, and so we test handling of +// cases where relatively small values are encountered at the beginning and end. +bool TryOneSkyscraperSoftmax(bool small_depth) { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // Softmax, the width and height really just create test repetitions. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = small_depth + ? ExponentialRandomPositiveInt(0.75f, 40, 500) + : ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); + // Extra parameters for skyscraper input patterns. + const double middle_proportion = + ExponentialRandomPositiveFloat(0.65f, 0.1, 1.0); + const int middle_min = UniformRandomInt(0, 255); + const int sides_max = UniformRandomInt(0, middle_min); + + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); + + std::vector input_data(buffer_size); + FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, + sides_max); + RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, + stride, beta); + return true; +} + +TEST(TestQuantizedSoftmax, UniformSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneUniformSoftmax()) { + } + } +} + +TEST(TestQuantizedSoftmax, SkyscraperSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperSoftmax(false)) { + } + } +} + +TEST(TestQuantizedSoftmax, SmallSkyscraperSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperSoftmax(true)) { + } + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h index ef77371bf65cc975dfa35275c8daa32de112a249..5994fad5c73df1dde6e33ba46dbd6e0802ea61be 100644 --- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h +++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h @@ -74,12 +74,22 @@ inline int StartForAxis(int begin_mask, // size 4, this function would return 4 as the stop, because it is one past the // "real" indices of 0, 1, 2 & 3. template -inline int StopForAxis(int end_mask, std::vector const& stop_indices, +inline int StopForAxis(int end_mask, int shrink_axis_mask, + std::vector const& stop_indices, std::vector const& strides, - int const* input_shape, int axis) { + int const* input_shape, int axis, int start_for_axis) { // Begin with the specified index + const bool shrink_axis = shrink_axis_mask & (1 << axis); int stop = stop_indices[axis]; + // When shrinking an axis, the end position does not matter (and can be + // incorrect when negative indexing is used, see Issue #19260). Always use + // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has + // already been adjusted for negative indices. + if (shrink_axis) { + stop = start_for_axis + 1; + } + // end_mask override if (end_mask & (1 << axis)) { if (strides[axis] > 0) { @@ -93,7 +103,7 @@ inline int StopForAxis(int end_mask, std::vector const& stop_indices, } // Handle negative indices - int axis_size = input_shape[axis]; + const int axis_size = input_shape[axis]; if (stop < 0) { stop += axis_size; } diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index ce887cea8b794b4b0cfd31722581cf9327be625e..ee2af5b46046c9e8bdc5816d5b6e9e9100cdc240 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#include #include #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" @@ -34,6 +35,11 @@ inline uint8_t* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.uint8 : nullptr; } +template <> +inline int16_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + template <> inline int32_t* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.i32 : nullptr; @@ -49,6 +55,13 @@ inline bool* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.b : nullptr; } +template <> +inline std::complex* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr + ? reinterpret_cast*>(tensor->data.c64) + : nullptr; +} + template inline const T* GetTensorData(const TfLiteTensor* tensor); @@ -62,6 +75,11 @@ inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.uint8 : nullptr; } +template <> +inline const int16_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + template <> inline const int32_t* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.i32 : nullptr; @@ -77,6 +95,13 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.b : nullptr; } +template <> +inline const std::complex* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr + ? reinterpret_cast*>(tensor->data.c64) + : nullptr; +} + inline int RemapDim(int max_dimensions, int d) { return max_dimensions - d - 1; } @@ -114,6 +139,19 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { return GetTensorDims(dims->data, dims->size); } +inline RuntimeShape GetTensorShape(std::vector data) { + return RuntimeShape(data.size(), data.data()); +} + +inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return RuntimeShape(); + } + + auto* dims = tensor->dims; + return RuntimeShape(dims->size, dims->data); +} + // A list of tensors in a format that can be used by kernels like split and // concatenation. template diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index e1c9ccd84b09fdca0241192ef4dff62ded096433..5160e22307ae0894fabd0e9c4f7b9cd38b00840e 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -23,6 +23,9 @@ namespace tensor_utils { // Limit a float input f between +abs_limit and -abs_limit. float Clip(float f, float abs_limit); +// Checks if all entries of vector are zero. +bool IsZeroVector(const float* vector, int v_size); + // Quantizes a buffer of floating point values using a symmetric quantization // (i.e. linear quantization without an offset) to 8-bit signed integers. // It also outputs the range (min, max) of the floating point buffer, and the diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc index 3d8a2eada0c30132b5646327da5cd9fd80ccd39f..aa0d49ae4db6b4952b5864166f4a13459763cf44 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -32,6 +32,25 @@ TEST(uKernels, ClipTest) { {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0}))); } +TEST(uKernels, IsZeroTest) { + constexpr int kVectorSize = 21; + static float zeros[kVectorSize] = {0.0}; + EXPECT_TRUE(IsZeroVector(zeros, kVectorSize)); + + static float nonzeros[kVectorSize] = { + 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, + 1e-13, 1e-14, 1e-15, 1e-16, 1e-17, 1e-18, 1e-19, + 1e-20, 1e-21, 1e-22, 1e-23, 1e-24, 1e-25, 1e-26}; + EXPECT_FALSE(IsZeroVector(nonzeros, kVectorSize)); +} + +TEST(uKernels, GeneratedIsZeroTest) { + constexpr int kVectorSize = 39; + std::vector input(kVectorSize); + ZeroVector(input.data(), kVectorSize); + EXPECT_TRUE(IsZeroVector(input.data(), kVectorSize)); +} + TEST(uKernels, SymmetricQuantizeFloatsTest) { constexpr int kVectorSize = 9; static float input[kVectorSize] = {-640, -635.0, -630, 10.0, 2.0, @@ -44,7 +63,8 @@ TEST(uKernels, SymmetricQuantizeFloatsTest) { EXPECT_EQ(min, -640); EXPECT_EQ(max, 1000); - EXPECT_NEAR(scaling_factor, 0.127, 1e-6); // EQ won't work due to fpoint. + // EQ won't work due to fpoint. + EXPECT_NEAR(scaling_factor, 1000 / 127.0, 1e-6); EXPECT_THAT(output, testing::ElementsAreArray({-81, -81, -80, 1, 0, -1, -1, 0, 127})); } @@ -76,7 +96,7 @@ TEST(uKernels, SymmetricQuantizeFloatsAllAlmostZeroTest) { EXPECT_NEAR(min, -9e-05, 1e-6); EXPECT_NEAR(max, 0.0002, 1e-6); - EXPECT_EQ(scaling_factor, 635000); + EXPECT_NEAR(scaling_factor, 1.57e-6, 1e-6); EXPECT_THAT(output, testing::ElementsAreArray({-6, 19, -4, -57, 1, 25, 6, 127, 0})); } diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/contrib/lite/kernels/internal/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b1fd9b344d99103ee2a1b5b95fd697ccf4a64d0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/test_util.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" + +#include +#include + +namespace tflite { + +Dims<4> MakeDimsForInference(int depth, int width, int height, int batch) { + Dims<4> result; + int cum_prod = 1; + + result.sizes[0] = depth; + result.strides[0] = cum_prod; + cum_prod *= result.sizes[0]; + + result.sizes[1] = width; + result.strides[1] = cum_prod; + cum_prod *= result.sizes[1]; + + result.sizes[2] = height; + result.strides[2] = cum_prod; + cum_prod *= result.sizes[2]; + + result.sizes[3] = batch; + result.strides[3] = cum_prod; + + return result; +} + +// this is a copied from an internal function in propagate_fixed_sizes.cc +bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width, + int filter_height, int stride, PaddingType padding_type, + Dims<4>* output_dims, int* pad_width, int* pad_height) { + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int batch = ArraySize(input_dims, 3); + + int output_height = 0; + int output_width = 0; + if (padding_type == PaddingType::kValid) { + output_height = (input_height + stride - filter_height) / stride; + output_width = (input_width + stride - filter_width) / stride; + } else if (padding_type == PaddingType::kSame) { + output_height = (input_height + stride - 1) / stride; + output_width = (input_width + stride - 1) / stride; + } else { + return false; + } + + if (output_width <= 0 || output_height <= 0) { + return false; + } + + *pad_height = + ((output_height - 1) * stride + filter_height - input_height) / 2; + *pad_width = ((output_width - 1) * stride + filter_width - input_width) / 2; + *output_dims = + MakeDimsForInference(output_depth, output_width, output_height, batch); + return true; +} + +std::mt19937& RandomEngine() { + static std::mt19937 engine; + return engine; +} + +int UniformRandomInt(int min, int max) { + std::uniform_int_distribution dist(min, max); + return dist(RandomEngine()); +} + +float UniformRandomFloat(float min, float max) { + std::uniform_real_distribution dist(min, max); + return dist(RandomEngine()); +} + +int ExponentialRandomPositiveInt(float percentile, int percentile_val, + int max_val) { + const float lambda = + -std::log(1.f - percentile) / static_cast(percentile_val); + std::exponential_distribution dist(lambda); + float val; + do { + val = dist(RandomEngine()); + } while (!val || !std::isfinite(val) || val > max_val); + return static_cast(std::ceil(val)); +} + +float ExponentialRandomPositiveFloat(float percentile, float percentile_val, + float max_val) { + const float lambda = + -std::log(1.f - percentile) / static_cast(percentile_val); + std::exponential_distribution dist(lambda); + float val; + do { + val = dist(RandomEngine()); + } while (!std::isfinite(val) || val > max_val); + return val; +} + +void FillRandom(std::vector* vec, float min, float max) { + std::uniform_real_distribution dist(min, max); + auto gen = std::bind(dist, RandomEngine()); + std::generate(std::begin(*vec), std::end(*vec), gen); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.h b/tensorflow/contrib/lite/kernels/internal/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..26078cef49a7868c64ff0095898eebe5e4de8751 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/test_util.h @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +// Creates a Dims struct from a set of dimensions. +Dims<4> MakeDimsForInference(int depth, int width, int height, int batch); + +// Computes output and padding dimensions. +bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width, + int filter_height, int stride, PaddingType padding_type, + Dims<4>* output_dims, int* pad_width, int* pad_height); + +// Returns a mt19937 random engine. +std::mt19937& RandomEngine(); + +// Returns a random integer uniformly distributed between |min| and |max|. +int UniformRandomInt(int min, int max); + +// Returns a random float uniformly distributed between |min| and |max|. +float UniformRandomFloat(float min, float max); + +// Returns a random element in |v|. +template +const T& RandomElement(const std::vector& v) { + return v[UniformRandomInt(0, v.size() - 1)]; +} + +// Returns a random exponentially distributed integer. +int ExponentialRandomPositiveInt(float percentile, int percentile_val, + int max_val); + +// Returns a random exponentially distributed float. +float ExponentialRandomPositiveFloat(float percentile, float percentile_val, + float max_val); + +// Fills a vector with random floats between |min| and |max|. +void FillRandom(std::vector* vec, float min, float max); + +// Fills a vector with random numbers between |min| and |max|. +template +void FillRandom(std::vector* vec, T min, T max) { + std::uniform_int_distribution dist(min, max); + auto gen = std::bind(dist, RandomEngine()); + std::generate(std::begin(*vec), std::end(*vec), gen); +} + +// Fills a vector with random numbers. +template +void FillRandom(std::vector* vec) { + FillRandom(vec, std::numeric_limits::min(), std::numeric_limits::max()); +} + +template +void FillRandom(typename std::vector::iterator begin_it, + typename std::vector::iterator end_it, T min, T max) { + std::uniform_int_distribution dist(min, max); + auto gen = std::bind(dist, RandomEngine()); + std::generate(begin_it, end_it, gen); +} + +// Fill with a "skyscraper" pattern, in which there is a central section (across +// the depth) with higher values than the surround. +template +void FillRandomSkyscraper(std::vector* vec, int depth, + double middle_proportion, uint8 middle_min, + uint8 sides_max) { + for (auto base_it = std::begin(*vec); base_it != std::end(*vec); + base_it += depth) { + auto left_it = base_it + std::ceil(0.5 * depth * (1.0 - middle_proportion)); + auto right_it = + base_it + std::ceil(0.5 * depth * (1.0 + middle_proportion)); + FillRandom(base_it, left_it, std::numeric_limits::min(), sides_max); + FillRandom(left_it, right_it, middle_min, std::numeric_limits::max()); + FillRandom(right_it, base_it + depth, std::numeric_limits::min(), + sides_max); + } +} + +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 43c68832785ac87f51c298370b50dc722167dc7f..fa2420713fea4faa3596251a95c2ed9606878b98 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,76 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#include +#include + #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" namespace tflite { enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; +enum class PaddingType { kNone, kSame, kValid }; + +// This enumeration allows for non-default formats for the weights array +// of a fully-connected operator, allowing the use of special optimized +// runtime paths. +enum class FullyConnectedWeightsFormat : uint8 { + // Default format (flat 2D layout, the inner contiguous dimension + // is input_depth, the outer non-contiguous dimension is output_depth) + kDefault, + // Summary: optimized layout for fast CPU runtime implementation, + // aimed specifically at ARM CPUs at the moment, and specialized for + // 8-bit quantized layers. + // + // The use case we're concerned with here is: 8-bit quantization, + // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in + // a key application that drove this), very small batch size (e.g. 1 -- 4). + // + // Even with 8-bit quantization of weights, the performance of memory + // accesses to the weights can become the dominant issue when + // the batch size is small, so each weight value is used in only a few + // arithmetic ops, i.e. the fully-connected node has a low arithmetic + // intensity. The specific issues that arise are of three kinds: + // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory + // bound. That's the "good" issue to run into. + // (2) One may run into sub-optimal pre-fetching: the data hasn't been + // prefetched into the cache by the time we need it. + // (3) One may run into cache aliasing: multiple values that are + // pre-fetched, alias each other in the L1 cache (which typically + // has only 4-way set associativity in ARM CPUs) and thus evict + // each other before we get to using them. + // + // The point of this shuffling is to avoid issues (2) and (3) so that + // we get as fast as possible given only the hard constraint (1). + // This is achieved by turning the difficulty into a solution: the + // difficulty, that each value loaded from memory is used only in + // one kernel iteration, making this operation memory-intensive, hints at + // the solution, of shuffling the weights so that they are stored in the + // exact order as the kernel needs to load them, so that the memory + // accesses made by the kernel are trivial. This solves (2) because the + // trivial memory access pattern allows the CPU's automatic prefetching + // to perform very well (no need even for preload instructions), and this + // solves (3) because the values being loaded concurrently are now + // contiguous in the address space, thus don't alias each other in the cache. + // + // On ARM, we typically want our kernel to process a 4x16 block of weights + // at a time, because: + // - 16 is the number of bytes in a NEON register. + // - 4 is how many rows we need to handle concurrently in the kernel in + // order to have sufficient mutual independence of instructions to + // maximize arithmetic throughput. + // + // Finally, the 'Int8' part in the name refers to the fact that this + // weights format has each weights value encoded as a signed int8 value, + // even if the data type of the weights buffer is uint8. This is intended + // to save runtime kernels the effort to have to XOR the top bit of these + // bytes before using them in signed arithmetic, see this file for more + // explanations on the 'signed int8 trick' in matrix multiplication kernels: + // + // tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc + // + kShuffled4x16Int8, +}; // Quantization parameters, determining the mapping of quantized values // to real values (i.e. determining how quantized values are mathematically @@ -43,6 +108,125 @@ struct Dims { int strides[N]; }; +class RuntimeShape { + public: + // Shapes with dimensions up to 4 are stored directly in the structure, while + // larger shapes are separately allocated. + static constexpr int kMaxSmallSize = 4; + + RuntimeShape() : size_(0) {} + + explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) { + if (dimensions_count > kMaxSmallSize) { + dims_pointer_ = new int32[dimensions_count]; + } + } + + RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) { + ReplaceWith(dimensions_count, dims_data); + } + + RuntimeShape(const std::initializer_list init_list) : size_(0) { + BuildFrom(init_list); + } + + ~RuntimeShape() { + if (size_ > kMaxSmallSize) { + delete[] dims_pointer_; + } + } + + inline int32 DimensionsCount() const { return size_; } + inline int32 Dims(int i) const { + TFLITE_DCHECK_GE(i, 0); + TFLITE_DCHECK_LT(i, size_); + return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i]; + } + inline void SetDim(int i, int32 val) { + TFLITE_DCHECK_GE(i, 0); + TFLITE_DCHECK_LT(i, size_); + if (size_ > kMaxSmallSize) { + dims_pointer_[i] = val; + } else { + dims_[i] = val; + } + } + inline int32* DimsData() { + return size_ > kMaxSmallSize ? dims_pointer_ : dims_; + } + inline const int32* DimsData() const { + return size_ > kMaxSmallSize ? dims_pointer_ : dims_; + } + + inline void Resize(int dimensions_count) { + if (size_ > kMaxSmallSize) { + delete[] dims_pointer_; + } + size_ = dimensions_count; + if (dimensions_count > kMaxSmallSize) { + dims_pointer_ = new int32[dimensions_count]; + } + } + + inline void ReplaceWith(int dimensions_count, const int32* dims_data) { + Resize(dimensions_count); + int32* dst_dims = DimsData(); + std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32)); + } + + template + inline void BuildFrom(const T& src_iterable) { + const int dimensions_count = + std::distance(src_iterable.begin(), src_iterable.end()); + Resize(dimensions_count); + int32* data = DimsData(); + for (auto it : src_iterable) { + *data = it; + ++data; + } + } + + inline void BuildFrom(const std::initializer_list init_list) { + BuildFrom>(init_list); + } + + // Returns the total count of elements, that is the size when flattened into a + // vector. + inline int FlatSize() const { + int buffer_size = 1; + const int* dims_data = DimsData(); + for (int i = 0; i < size_; i++) { + const int dim = dims_data[i]; + TFLITE_DCHECK_GE(dim, 1); + buffer_size *= dim; + } + return buffer_size; + } + + private: + int32 size_; + union { + int32 dims_[kMaxSmallSize]; + int32* dims_pointer_; + }; +}; + +// Converts inference-style shape to legacy tflite::Dims<4>. +inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) { + tflite::Dims<4> result; + const int dimensions_count = array_shape.DimensionsCount(); + TFLITE_CHECK_LE(dimensions_count, 4); + int cum_prod = 1; + for (int i = 0; i < 4; i++) { + const int new_dim = + (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1; + result.sizes[i] = new_dim; + result.strides[i] = cum_prod; + cum_prod *= new_dim; + } + return result; +} + // Gets next index to iterate through a multidimensional array. inline bool NextIndex(const int num_dims, const int* dims, int* current) { TFLITE_DCHECK_GT(num_dims, 0); @@ -95,6 +279,15 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims, return offset; } +inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < shape.Dims(0)); + TFLITE_DCHECK(i1 >= 0 && i1 < shape.Dims(1)); + TFLITE_DCHECK(i2 >= 0 && i2 < shape.Dims(2)); + TFLITE_DCHECK(i3 >= 0 && i3 < shape.Dims(3)); + const int* dims_data = shape.DimsData(); + return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3; +} + inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]); TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]); @@ -109,6 +302,9 @@ inline int Offset(const Dims<4>& dims, int* index) { } // Get array size, DCHECKing that the dim index is in range. +// +// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims() +// already performs this check. template int ArraySize(const Dims& array, int index) { TFLITE_DCHECK(index >= 0 && index < N); @@ -130,6 +326,21 @@ int MatchingArraySize(const ArrayType1& array1, int index1, return MatchingArraySize(array1, index1, args...); } +// Get common shape dim, DCHECKing that they all agree. +inline int MatchingDim(const RuntimeShape& shape1, int index1, + const RuntimeShape& shape2, int index2) { + TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2)); + return shape1.Dims(index1); +} + +template +int MatchingDim(const RuntimeShape& shape1, int index1, + const RuntimeShape& shape2, int index2, Args... args) { + TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2)); + return MatchingDim(shape1, index1, args...); +} + +// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize(). template inline int FlatSize(const Dims& dims) { int flat_size = 1; @@ -144,6 +355,50 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) { return FlatSize(dims); } +// Flat size calculation, checking that dimensions match with one or more other +// arrays. +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return shape.FlatSize(); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1, check_shape_2); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2, + const RuntimeShape& check_shape_3) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3); +} + // Flat size calculation, checking that dimensions match with one or more other // arrays. template @@ -170,7 +425,7 @@ inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } - return FlatSize(dims, check_dims_1, check_dims_2); + return MatchingFlatSize(dims, check_dims_1, check_dims_2); } template @@ -181,7 +436,7 @@ inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } - return FlatSize(dims, check_dims_1, check_dims_2, check_dims_3); + return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3); } // Data is required to be contiguous, and so many operators can use either the @@ -249,6 +504,72 @@ inline int MatchingFlatSizeSkipDim(const Dims& dims, int skip_dim, check_dims_3); } +// Data is required to be contiguous, and so many operators can use either the +// full array flat size or the flat size with one dimension skipped (commonly +// the depth). +inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) { + const int dims_count = shape.DimensionsCount(); + TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count); + const auto* dims_data = shape.DimsData(); + int flat_size = 1; + for (int i = 0; i < dims_count; ++i) { + flat_size *= (i == skip_dim) ? 1 : dims_data[i]; + } + return flat_size; +} + +// A combination of MatchingFlatSize() and FlatSizeSkipDim(). +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return FlatSizeSkipDim(shape, skip_dim); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2, + const RuntimeShape& check_shape_3) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2, + check_shape_3); +} + template bool IsPackedWithoutStrides(const Dims& dims) { int expected_stride = 1; @@ -259,6 +580,14 @@ bool IsPackedWithoutStrides(const Dims& dims) { return true; } +template +void ComputeStrides(Dims* dims) { + dims->strides[0] = 1; + for (int d = 1; d < N; d++) { + dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1]; + } +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc index 239b533a17efaa25d632c140c17e9f35d01a80ef..08f942c933552aa6ca7369550c928efba9e2e93e 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.cc +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -37,19 +37,17 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <= 1e-6 * std::min(input_product_scale, bias_scale)); TF_LITE_ENSURE(context, input_product_scale >= 0); - TF_LITE_ENSURE(context, input_product_scale < output_scale); *multiplier = input_product_scale / output_scale; return kTfLiteOk; } -void CalculateActivationRangeUint8(TfLiteFusedActivation activation, - TfLiteTensor* output, int32_t* act_min, - int32_t* act_max) { - const int32_t qmin = std::numeric_limits::min(); - const int32_t qmax = std::numeric_limits::max(); - +namespace { +void CalculateActivationRangeQuantizedImpl(TfLiteFusedActivation activation, + int32_t qmin, int32_t qmax, + TfLiteTensor* output, + int32_t* act_min, int32_t* act_max) { const auto scale = output->params.scale; const auto zero_point = output->params.zero_point; @@ -71,23 +69,38 @@ void CalculateActivationRangeUint8(TfLiteFusedActivation activation, *act_max = qmax; } } - -void CalculateActivationRangeFloat(TfLiteFusedActivation activation, - float* activation_min, - float* activation_max) { - if (activation == kTfLiteActRelu) { - *activation_min = 0.f; - *activation_max = std::numeric_limits::max(); - } else if (activation == kTfLiteActRelu6) { - *activation_min = 0.f; - *activation_max = 6.f; - } else if (activation == kTfLiteActRelu1) { - *activation_min = -1.f; - *activation_max = 1.f; +} // namespace + +TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context, + TfLiteFusedActivation activation, + TfLiteTensor* output, + int32_t* act_min, + int32_t* act_max) { + int32_t qmin = 0; + int32_t qmax = 0; + if (output->type == kTfLiteUInt8) { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); + } else if (output->type == kTfLiteInt16) { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); } else { - *activation_min = std::numeric_limits::lowest(); - *activation_max = std::numeric_limits::max(); + TF_LITE_ENSURE(context, false); } + + CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min, + act_max); + return kTfLiteOk; +} + +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + + CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min, + act_max); } bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) { diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index 82cded36f2ed2777daccafee5890f47c0d7254e8..c8ce3c917d5bf66e01fbae95c18dfe97b3c84bae 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#include + #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" @@ -86,14 +88,35 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, TfLiteTensor* output, double* multiplier); -// Calculates the useful range of an activation layer given its activation -// tensor. +// Calculates the useful quantized range of an activation layer given its +// activation tensor. +TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context, + TfLiteFusedActivation activation, + TfLiteTensor* output, + int32_t* act_min, + int32_t* act_max); void CalculateActivationRangeUint8(TfLiteFusedActivation activation, TfLiteTensor* output, int32_t* act_min, int32_t* act_max); -void CalculateActivationRangeFloat(TfLiteFusedActivation activation, - float* activation_min, - float* activation_max); +// Calculates the useful range of an activation layer given its activation +// tensor.a +template +void CalculateActivationRange(TfLiteFusedActivation activation, + T* activation_min, T* activation_max) { + if (activation == kTfLiteActRelu) { + *activation_min = 0; + *activation_max = std::numeric_limits::max(); + } else if (activation == kTfLiteActRelu6) { + *activation_min = 0; + *activation_max = 6; + } else if (activation == kTfLiteActRelu1) { + *activation_min = -1; + *activation_max = 1; + } else { + *activation_min = std::numeric_limits::lowest(); + *activation_max = std::numeric_limits::max(); + } +} // Return true if the given tensors have the same shape. bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2); diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc index 7cea63da871219276a9daa551d890f28c839f21d..a7b54c6b842332feb2d9e7179e79ae054bd23bb9 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -70,8 +70,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { #define TF_LITE_L2NORM(type) \ type::L2Normalization( \ - GetTensorData(input), GetTensorDims(input), \ - GetTensorData(output), GetTensorDims(output)) + GetTensorData(input), GetTensorShape(input), \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2NORM(reference_ops); @@ -81,10 +81,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } #undef TF_LITE_L2NORM } else if (output->type == kTfLiteUInt8) { -#define TF_LITE_L2NORM(type) \ - type::L2Normalization(GetTensorData(input), GetTensorDims(input), \ - input->params.zero_point, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_L2NORM(type) \ + type::L2Normalization(GetTensorData(input), GetTensorShape(input), \ + input->params.zero_point, \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2NORM(reference_ops); @@ -94,7 +94,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } #undef TF_LITE_L2NORM } else { - context->ReportError(context, "Inputs and outputs not all float types."); + context->ReportError(context, "Output type is %d, requires float.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc index 042314ccf55cb6de12c743448fbe040f35e7baab..070ed60040997f18f7e8053acc9532adc2377400 100644 --- a/tensorflow/contrib/lite/kernels/l2norm_test.cc +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -67,7 +67,7 @@ class L2NormOpModel : public SingleOpModel { int output_; }; -TEST(L2NormOpTest, SimpleTest) { +TEST(L2NormOpTest, SimpleFloatTest) { L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE); m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); @@ -76,6 +76,23 @@ TEST(L2NormOpTest, SimpleTest) { ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); } +TEST(L2NormOpTest, MultipleBatchFloatTest) { + L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32, + ActivationFunctionType_NONE); + m.SetInput({ + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({ + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + })); +} + TEST(L2NormOpTest, SimpleUint8Test) { L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE); @@ -88,6 +105,32 @@ TEST(L2NormOpTest, SimpleUint8Test) { ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1))); } +TEST(L2NormOpTest, MultipleBatchUint8Test) { + L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE); + + m.QuantizeAndPopulate(m.input(), + { + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({ + 58, 166, 173, 205, 83, 134, // batch 1 + 58, 166, 173, 205, 83, 134, // batch 2 + 58, 166, 173, 205, 83, 134, // batch 3 + })); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + }, + 0.1))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc index c15a5170b85e180911d98d20205579bc0b9e8558..36dca299d0e07a84af60a13dfeb50b0f8fe38ee2 100644 --- a/tensorflow/contrib/lite/kernels/local_response_norm.cc +++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc @@ -77,7 +77,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } #undef TF_LITE_LOCAL_RESPONSE_NORM } else { - context->ReportError(context, "Inputs and outputs not all float types."); + context->ReportError(context, "Output type is %d, requires float.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc index 62820a2f5113cb6ae252386aaf3842135383b79f..9a8d35e82cbc3a7e55246e6c06599b2838d1ee67 100644 --- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc @@ -90,10 +90,9 @@ TEST(LogSoftmaxOpTest, CompareWithTFmini) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::LogSoftmax(input_buffer, input_dims, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::LogSoftmax(input_buffer, input_shape, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 990b3da0554ebcb13f995fa281ed04f8c7c6d7ea..3577ae6caa1e02ce2e5db2e8054ba9c2fccbe93e 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -24,7 +24,10 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" @@ -34,6 +37,20 @@ namespace ops { namespace builtin { namespace lstm { +struct OpData { + // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel + // (5 inputs). + TfLiteLSTMKernelType kernel_type; + + // These fields are only used by full kernel. + int activation_state_tensor_index; + int cell_state_tensor_index; + int scratch_tensor_index; +}; + +// For full inputs kernel (18 or 20 inputs). +namespace full { + // Input Tensors of size {n_batch, n_input} constexpr int kInputTensor = 0; @@ -65,26 +82,33 @@ constexpr int kProjectionWeightsTensor = 16; // Optional // Projection bias tensor of size {n_output} constexpr int kProjectionBiasTensor = 17; // Optional +// If the node has 20 inputs, the following 2 tensors are used as state tensors. +// These are defined as variable tensors, and will be modified by this op. +constexpr int kInputActivationStateTensor = 18; +constexpr int kInputCellStateTensor = 19; + // Output tensors. +// * If the node has 18 inputs, these 2 tensors are used as state tensors. +// * If the node has 20 inputs, these 2 tensors are ignored. +// TODO(ycling): Make the 2 output state tensors optional, and propagate the +// state to output tensors when the 2 tensors present. constexpr int kOutputStateTensor = 0; constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); - return scratch_tensor_index; -} - -void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + auto* op_data = new OpData; + op_data->kernel_type = kTfLiteLSTMFullKernel; + context->AddTensors(context, /*tensors_to_add=*/7, + &op_data->scratch_tensor_index); + return op_data; } // Check that input tensor dimensions matches with each other. TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -94,7 +118,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - if (input_to_input_weights) { + if (input_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); @@ -114,7 +138,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - if (recurrent_to_input_weights) { + if (recurrent_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], n_cell); @@ -204,7 +228,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* projection_weights = GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - if (projection_weights) { + if (projection_weights != nullptr) { TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); @@ -212,7 +236,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* projection_bias = GetOptionalInputTensor(context, node, kProjectionBiasTensor); - if (projection_bias) { + if (projection_bias != nullptr) { TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); } @@ -233,15 +257,37 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, // Allocate a temporary scratch tensor. Also check that the sizes of the input // tensors match each other. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); + OpData* op_data = reinterpret_cast(node->user_data); - // Check we have all the inputs and outputs we need. - TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); TF_LITE_ENSURE_EQ(context, node->outputs->size, 3); + // True if the node is using input variable state tensors. It means: + // * The state tensors are defined as inputs. In this case it would be the + // 19th and 20th input tensors. + // * Otherwise, the output tensors are used to store states. + bool use_input_variable_states; + if (node->inputs->size == 20) { + use_input_variable_states = true; + op_data->activation_state_tensor_index = + node->inputs->data[kInputActivationStateTensor]; + op_data->cell_state_tensor_index = + node->inputs->data[kInputCellStateTensor]; + } else if (node->inputs->size == 18) { + use_input_variable_states = false; + op_data->activation_state_tensor_index = + node->outputs->data[kOutputStateTensor]; + op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor]; + } else { + context->ReportError( + context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs", + node->inputs->size); + return kTfLiteError; + } + // Inferring batch size, number of outputs and number of cells from the // input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE(context, input->dims->size > 1); const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; @@ -262,110 +308,185 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check that input tensor dimensions matches with each other. CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); - // Get the pointer to output, output_state and cell_state tensors. + // Get the pointer to output, activation_state and cell_state tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - // Resize the output, output_state and cell_state tensors. + TfLiteTensor* activation_state = + &context->tensors[op_data->activation_state_tensor_index]; + TfLiteTensor* cell_state = + &context->tensors[op_data->cell_state_tensor_index]; + + if (use_input_variable_states) { + // Check the shape of input state tensors. + // These tensor may be 1D or 2D. It's fine as long as the total size is + // correct. + TF_LITE_ENSURE_EQ(context, NumElements(activation_state), + n_batch * n_output); + TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); + } else { + // If the state tensors are outputs, this function takes the + // responsibility to resize the state tensors. + TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2); + activation_state_size->data[0] = n_batch; + activation_state_size->data[1] = n_output; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state, + activation_state_size)); + + TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); + cell_size->data[0] = n_batch; + cell_size->data[1] = n_cell; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state, cell_size)); + // Mark state tensors as persistent tensors. + activation_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + } + + // Resize the output tensors. TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); output_size->data[0] = n_batch; output_size->data[1] = n_output; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size)); - TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2); - output_state_size->data[0] = n_batch; - output_state_size->data[1] = n_output; - TF_LITE_ENSURE_OK( - context, context->ResizeTensor(context, output_state, output_state_size)); + // The weights are of consistent type, so it suffices to check one. + // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. + const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && + input->type == kTfLiteFloat32); - TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); - cell_size->data[0] = n_batch; - cell_size->data[1] = n_cell; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, cell_state, cell_size)); + TfLiteIntArrayFree(node->temporaries); + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(7); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } + node->temporaries->data[0] = op_data->scratch_tensor_index; // Create a scratch buffer tensor. - TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); - node->temporaries->data[0] = *scratch_tensor_index; TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; - // Mark state tensors as persistent tensors. - output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); const bool use_cifg = (input_to_input_weights == nullptr); + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; if (use_cifg) { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 3; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); } else { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Input, Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 4; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // activation_state and cell_state tensors. + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; + TfLiteTensor* activation_state_quantized = + GetTemporary(context, node, /*index=*/2); + activation_state_quantized->type = kTfLiteUInt8; + activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(activation_state_quantized->dims, + activation_state->dims)) { + TfLiteIntArray* activation_state_quantized_size = + TfLiteIntArrayCopy(activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, activation_state_quantized, + activation_state_quantized_size)); + } + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + cell_state_quantized->type = kTfLiteUInt8; + cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { + TfLiteIntArray* cell_state_quantized_size = + TfLiteIntArrayCopy(cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state_quantized, + cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). + node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[6] = op_data->scratch_tensor_index + 6; + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + recovered_cell_weights->type = kTfLiteFloat32; + recovered_cell_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); + recovered_cell_weights_size->data[0] = n_cell; + if (!TfLiteIntArrayEqual(recovered_cell_weights->dims, + recovered_cell_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_cell_weights, + recovered_cell_weights_size)); + } } return kTfLiteOk; } // The LSTM Op engine. -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - - const TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - const TfLiteTensor* input_to_forget_weights = - GetInput(context, node, kInputToForgetWeightsTensor); - const TfLiteTensor* input_to_cell_weights = - GetInput(context, node, kInputToCellWeightsTensor); - const TfLiteTensor* input_to_output_weights = - GetInput(context, node, kInputToOutputWeightsTensor); - - const TfLiteTensor* recurrent_to_input_weights = - GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - const TfLiteTensor* recurrent_to_forget_weights = - GetInput(context, node, kRecurrentToForgetWeightsTensor); - const TfLiteTensor* recurrent_to_cell_weights = - GetInput(context, node, kRecurrentToCellWeightsTensor); - const TfLiteTensor* recurrent_to_output_weights = - GetInput(context, node, kRecurrentToOutputWeightsTensor); - - const TfLiteTensor* cell_to_input_weights = - GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - const TfLiteTensor* cell_to_forget_weights = - GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - const TfLiteTensor* cell_to_output_weights = - GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - - const TfLiteTensor* input_gate_bias = - GetOptionalInputTensor(context, node, kInputGateBiasTensor); - const TfLiteTensor* forget_gate_bias = - GetInput(context, node, kForgetGateBiasTensor); - const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); - const TfLiteTensor* output_gate_bias = - GetInput(context, node, kOutputGateBiasTensor); - - const TfLiteTensor* projection_weights = - GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - const TfLiteTensor* projection_bias = - GetOptionalInputTensor(context, node, kProjectionBiasTensor); - - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* activation_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; // n_cell and n_output will be the same size when there is no projection. @@ -377,9 +498,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); - // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); - float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; float* forget_gate_scratch = nullptr; @@ -428,7 +546,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const float* cell_bias_ptr = cell_bias->data.f; const float* output_gate_bias_ptr = output_gate_bias->data.f; - float* output_state_ptr = output_state->data.f; + float* activation_state_ptr = activation_state->data.f; float* cell_state_ptr = cell_state->data.f; float* output_ptr_batch = output->data.f; @@ -441,12 +559,493 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, output_ptr_batch); + activation_state_ptr, cell_state_ptr, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); + + return kTfLiteOk; +} + +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, + TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, + TfLiteTensor* activation_state_quantized, + TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + reinterpret_cast(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast(projection_weights->data.uint8); + const float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const float* input_ptr_batch = input->data.f; + const int8_t* input_to_forget_weights_ptr = + reinterpret_cast(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* activation_state_ptr = activation_state->data.f; + float* cell_state_ptr = cell_state->data.f; + float* output_ptr_batch = output->data.f; + + // Temporary storage for quantized values and scaling factors. + int8_t* quantized_input_ptr = + reinterpret_cast(input_quantized->data.uint8); + int8_t* quantized_activation_state_ptr = + reinterpret_cast(activation_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, + projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, + input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, + recovered_cell_weights_ptr, quantized_input_ptr, + quantized_activation_state_ptr, quantized_cell_state_ptr, + activation_state_ptr, cell_state_ptr, output_ptr_batch); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* activation_state = + &context->tensors[op_data->activation_state_tensor_index]; + TfLiteTensor* cell_state = + &context->tensors[op_data->cell_state_tensor_index]; + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(mirkov): add a check that weights are all uint8s or all floats. + switch (input_to_output_weights->type) { + case kTfLiteFloat32: { + return EvalFloat(input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, + scratch_buffer, activation_state, cell_state, output); + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* activation_state_quantized = + GetTemporary(context, node, /*index=*/2); + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + return EvalHybrid( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, scratch_buffer, + scaling_factors, prod_scaling_factors, recovered_cell_weights, + input_quantized, activation_state_quantized, cell_state_quantized, + activation_state, cell_state, output); + } + default: + context->ReportError(context, "Type %d is not currently supported.", + input_to_output_weights->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace full + +// For basic kernel (5-inputs). +namespace basic { + +enum InputTensor { + kInputData = 0, + kInputPrevActivation = 1, + kInputWeights = 2, + kInputBiases = 3, + kInputPrevState = 4, + kInputNum = 5, +}; + +enum OutputTensor { + kOutputActivation = 0, + kOutputState = 1, + kOutputConcatTemp = 2, + kOutputActivationTemp = 3, + kOutputNum = 4, +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + op_data->kernel_type = kTfLiteLSTMBasicKernel; + // `scratch_tensor_index` is unused in this kernel. + op_data->scratch_tensor_index = -1; + return op_data; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, node->inputs->size == kInputNum); + TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); + + const TfLiteTensor* input = GetInput(context, node, kInputData); + const TfLiteTensor* prev_activation = + GetInput(context, node, kInputPrevActivation); + const TfLiteTensor* weights = GetInput(context, node, kInputWeights); + const TfLiteTensor* bias = GetInput(context, node, kInputBiases); + const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); + + TF_LITE_ENSURE_EQ(context, input->dims->size, 2); + const int num_batches = input->dims->data[0]; + const int input_depth = input->dims->data[1]; + + TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2); + TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches); + const int activation_depth = prev_activation->dims->data[1]; + const int total_depth = input_depth + activation_depth; + + TF_LITE_ENSURE_EQ(context, weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, weights->dims->data[0], 4 * activation_depth); + TF_LITE_ENSURE_EQ(context, weights->dims->data[1], total_depth); + + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, bias->dims->data[0], 4 * activation_depth); + + TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2); + TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches); + TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth); + + TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); + TfLiteTensor* state_out = GetOutput(context, node, kOutputState); + TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); + TfLiteTensor* activation_temp = + GetOutput(context, node, kOutputActivationTemp); + + TF_LITE_ENSURE_OK(context, context->ResizeTensor( + context, activation_out, + TfLiteIntArrayCopy(prev_activation->dims))); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, state_out, + TfLiteIntArrayCopy(prev_state->dims))); + TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2); + concat_temp_size->data[0] = num_batches; + concat_temp_size->data[1] = total_depth; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, concat_temp, concat_temp_size)); + TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2); + activation_temp_size->data[0] = num_batches; + activation_temp_size->data[1] = 4 * activation_depth; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp, + activation_temp_size)); + + // Set the state tensors as persistent. + for (auto index : {kInputPrevActivation, kInputPrevState}) { + TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; + tensor->allocation_type = kTfLiteArenaRwPersistent; + } return kTfLiteOk; } +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputData); + const TfLiteTensor* prev_activation = + GetInput(context, node, kInputPrevActivation); + const TfLiteTensor* weights = GetInput(context, node, kInputWeights); + const TfLiteTensor* bias = GetInput(context, node, kInputBiases); + const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); + + TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); + TfLiteTensor* state_out = GetOutput(context, node, kOutputState); + TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); + TfLiteTensor* activation_temp = + GetOutput(context, node, kOutputActivationTemp); + + if (input->type == kTfLiteFloat32 && + prev_activation->type == kTfLiteFloat32 && + weights->type == kTfLiteFloat32 && bias->type == kTfLiteFloat32 && + prev_state->type == kTfLiteFloat32 && state_out->type == kTfLiteFloat32 && + activation_out->type == kTfLiteFloat32 && + concat_temp->type == kTfLiteFloat32 && + activation_temp->type == kTfLiteFloat32) { + optimized_ops::LstmCell( + // Inputs. + GetTensorData(input), GetTensorDims(input), + GetTensorData(prev_activation), GetTensorDims(prev_activation), + GetTensorData(weights), GetTensorDims(weights), + GetTensorData(bias), GetTensorDims(bias), + GetTensorData(prev_state), GetTensorDims(prev_state), + // Outputs. + GetTensorData(state_out), GetTensorDims(state_out), + GetTensorData(activation_out), GetTensorDims(activation_out), + GetTensorData(concat_temp), GetTensorDims(concat_temp), + GetTensorData(activation_temp), GetTensorDims(activation_temp)); + } else if (input->type == kTfLiteUInt8 && + prev_activation->type == kTfLiteUInt8 && + weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 && + prev_state->type == kTfLiteInt16 && + state_out->type == kTfLiteInt16 && + activation_out->type == kTfLiteUInt8 && + concat_temp->type == kTfLiteUInt8 && + activation_temp->type == kTfLiteInt16) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + int state_scale_log2_rounded; + if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) { + context->ReportError( + context, + "The internal state of a LSTM cell must have a power-of-two scale."); + return kTfLiteError; + } + const int state_integer_bits = 15 + state_scale_log2_rounded; + if (state_integer_bits != 4) { + context->ReportError(context, + "The only case of quantized LstmCell currently " + "supported is with StateIntegerBits==4"); + return kTfLiteError; + } + + double real_accum_multiplier = 4096 * bias->params.scale; + int32 accum_multiplier; + int accum_shift; + tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier, + &accum_shift); + optimized_ops::LstmCell<4>( + // Inputs. + GetTensorData(input), GetTensorDims(input), + GetTensorData(prev_activation), GetTensorDims(prev_activation), + GetTensorData(weights), GetTensorDims(weights), + GetTensorData(bias), GetTensorDims(bias), + GetTensorData(prev_state), GetTensorDims(prev_state), + // Outputs. + GetTensorData(state_out), GetTensorDims(state_out), + GetTensorData(activation_out), GetTensorDims(activation_out), + GetTensorData(concat_temp), GetTensorDims(concat_temp), + GetTensorData(activation_temp), GetTensorDims(activation_temp), + weights->params.zero_point, accum_multiplier, accum_shift, + gemm_context); + } else { + context->ReportError(context, + "Unsupported combination of data types for LstmCell"); + return kTfLiteError; + } + + // TODO(ycling): Investigate if this copy can be avoided with the 5-inputs + // LSTM kernel. + memcpy(prev_activation->data.raw, activation_out->data.raw, + activation_out->bytes); + memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes); + + return kTfLiteOk; +} + +} // namespace basic + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + gemm_support::IncrementUsageCounter(context); + + const auto* params = reinterpret_cast(buffer); + switch (params->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Init(context, buffer, length); + case kTfLiteLSTMBasicKernel: + return basic::Init(context, buffer, length); + } +} +void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); + + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + switch (op_data->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Prepare(context, node); + case kTfLiteLSTMBasicKernel: + return basic::Prepare(context, node); + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + switch (op_data->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Eval(context, node); + case kTfLiteLSTMBasicKernel: + return basic::Eval(context, node); + } +} + } // namespace lstm TfLiteRegistration* Register_LSTM() { diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index d81220d8d30793616444c03e8647b0877a39a4d9..0b7c56133e3cbb3d85f75657b6141620a8019e61 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite LSTM op. -#include #include #include @@ -35,7 +34,8 @@ class LSTMOpModel : public SingleOpModel { LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, - const std::vector>& input_shapes) + const std::vector>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -45,31 +45,31 @@ class LSTMOpModel : public SingleOpModel { if (use_cifg) { input_to_input_weights_ = AddNullInput(); } else { - input_to_input_weights_ = AddInput(TensorType_FLOAT32); + input_to_input_weights_ = AddInput(weight_type); } - input_to_forget_weights_ = AddInput(TensorType_FLOAT32); - input_to_cell_weights_ = AddInput(TensorType_FLOAT32); - input_to_output_weights_ = AddInput(TensorType_FLOAT32); + input_to_forget_weights_ = AddInput(weight_type); + input_to_cell_weights_ = AddInput(weight_type); + input_to_output_weights_ = AddInput(weight_type); if (use_cifg) { recurrent_to_input_weights_ = AddNullInput(); } else { - recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_input_weights_ = AddInput(weight_type); } - recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_forget_weights_ = AddInput(weight_type); + recurrent_to_cell_weights_ = AddInput(weight_type); + recurrent_to_output_weights_ = AddInput(weight_type); if (use_peephole) { if (use_cifg) { cell_to_input_weights_ = AddNullInput(); } else { - cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + cell_to_input_weights_ = AddInput(weight_type); } - cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); - cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); } else { cell_to_input_weights_ = AddNullInput(); cell_to_forget_weights_ = AddNullInput(); @@ -86,7 +86,7 @@ class LSTMOpModel : public SingleOpModel { output_gate_bias_ = AddInput(TensorType_FLOAT32); if (use_projection_weights) { - projection_weights_ = AddInput(TensorType_FLOAT32); + projection_weights_ = AddInput(weight_type); if (use_projection_bias) { projection_bias_ = AddInput(TensorType_FLOAT32); } else { @@ -97,6 +97,12 @@ class LSTMOpModel : public SingleOpModel { projection_bias_ = AddNullInput(); } + // Adding the 2 input state tensors. + input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + output_state_ = AddOutput(TensorType_FLOAT32); cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -192,8 +198,9 @@ class LSTMOpModel : public SingleOpModel { zero_buffer.get() + zero_buffer_size); } - void SetInput(int offset, float* begin, float* end) { - PopulateTensor(input_, offset, begin, end); + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); } std::vector GetOutput() { return ExtractVector(output_); } @@ -203,7 +210,7 @@ class LSTMOpModel : public SingleOpModel { int num_cells() { return n_cell_; } int num_batches() { return n_batch_; } - private: + protected: int input_; int input_to_input_weights_; int input_to_forget_weights_; @@ -226,6 +233,8 @@ class LSTMOpModel : public SingleOpModel { int projection_weights_; int projection_bias_; + int input_activation_state_; + int input_cell_state_; int output_; int output_state_; @@ -237,7 +246,174 @@ class LSTMOpModel : public SingleOpModel { int n_output_; }; -TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { +class HybridLSTMOpModel : public LSTMOpModel { + public: + HybridLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, + float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole, + use_projection_weights, use_projection_bias, cell_clip, + proj_clip, input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + LSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end); + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -257,10 +433,10 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {n_cell, n_input}, // input_to_cell_weight tensor {n_cell, n_input}, // input_to_output_weight tensor - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor + {n_cell, n_output}, // recurrent_to_input_weight_tensor + {n_cell, n_output}, // recurrent_to_forget_weight_tensor + {n_cell, n_output}, // recurrent_to_cell_weight_tensor + {n_cell, n_output}, // recurrent_to_output_weight_tensor {0}, // cell_to_input_weight tensor {0}, // cell_to_forget_weight tensor @@ -275,79 +451,137 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, - -0.34550029, 0.04266912, -0.15680569, - -0.34856534, 0.43890524}); - - lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, - -0.20583314, 0.44344562, 0.22077113, - -0.29909778}); - - lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, - -0.31343272, -0.40032279, 0.44781327, - 0.01387155, -0.35593212}); - - lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, - 0.40525138, 0.44272184, 0.03897077, -0.1556896, - 0.19487578}); - - lstm.SetInputGateBias({0., 0., 0., 0.}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetCellBias({0., 0., 0., 0.}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetForgetGateBias({1., 1., 1., 1.}); + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - lstm.SetOutputGateBias({0., 0., 0., 0.}); - - lstm.SetRecurrentToInputWeights( - {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, - -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, - -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); - - lstm.SetRecurrentToCellWeights( - {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, - -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, - -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - lstm.SetRecurrentToForgetWeights( - {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, - -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, - 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetRecurrentToOutputWeights( - {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, - 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, - -0.51818722, -0.15390486, 0.0468148, 0.39922136}); +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, - -0.15358765, -0.03716109, 0.12507336, - 0.41193449, -0.20860538, -0.15053082, - 0.09120187, 0.24278517, -0.12222792}; + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, + /*tolerance=*/0.0157651); +} - lstm.SetInput(0, batch0_start, batch0_end); +class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; - lstm.Invoke(); + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; } -} +}; -TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -385,74 +619,689 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, - 0.04717243, 0.48944736, -0.38535351, - -0.17212132}); - - lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, - -0.3633365, -0.22755712, 0.28253698, 0.24407166, - 0.33826375}); - - lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, - -0.09426838, -0.44257352, 0.54939759, - 0.01533556, 0.42751634}); - - lstm.SetCellBias({0., 0., 0., 0.}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetForgetGateBias({1., 1., 1., 1.}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetOutputGateBias({0., 0., 0., 0.}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - lstm.SetRecurrentToCellWeights( - {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, - 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, - 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, - 0.21193194}); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); - lstm.SetRecurrentToForgetWeights( - {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, - 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, - -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - lstm.SetRecurrentToOutputWeights( - {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, - -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, - 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetCellToForgetWeights( - {0.47485286, -0.51955009, -0.24458408, 0.31544167}); - lstm.SetCellToOutputWeights( - {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, - -0.05163646, -0.42312205, -0.01218222, - 0.24201041, -0.08124574, -0.358325, - -0.04621704, 0.21641694, -0.06471302}; + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); - - lstm.SetInput(0, batch0_start, batch0_end); - - lstm.Invoke(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); +} - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, -0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; } -} +}; -TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -489,588 +1338,98 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights( - {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, - 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, - -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, - -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, - -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, - -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, - -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, - 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, - 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, - 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, - -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, - 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, - -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, - -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, - -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, - 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, - -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, - -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, - -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, - -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); - - lstm.SetInputToForgetWeights( - {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, - -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, - -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, - 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, - 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, - -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, - -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, - 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, - 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, - 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, - 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, - -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, - 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, - -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, - -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, - 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, - 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, - 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, - -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, - 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); - - lstm.SetInputToCellWeights( - {-0.04580283, -0.09549462, -0.032418985, -0.06454633, - -0.043528453, 0.043018587, -0.049152344, -0.12418144, - -0.078985475, -0.07596889, 0.019484362, -0.11434962, - -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, - -0.025034338, -0.0028890965, 0.048929527, 0.06235075, - 0.10665918, -0.032036792, -0.08505916, -0.10843358, - -0.13002433, -0.036816437, -0.02130134, -0.016518239, - 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, - -0.10652836, -0.1037554, -0.13056071, -0.03266643, - -0.033702414, -0.006473424, -0.04611692, 0.014419339, - -0.025174323, 0.0396852, 0.081777506, 0.06157468, - 0.10210095, -0.009658194, 0.046511717, 0.03603906, - 0.0069369148, 0.015960095, -0.06507666, 0.09551598, - 0.053568836, 0.06408714, 0.12835667, -0.008714329, - -0.20211966, -0.12093674, 0.029450472, 0.2849013, - -0.029227901, 0.1164364, -0.08560263, 0.09941786, - -0.036999565, -0.028842626, -0.0033637602, -0.017012902, - -0.09720865, -0.11193351, -0.029155117, -0.017936034, - -0.009768936, -0.04223324, -0.036159635, 0.06505112, - -0.021742892, -0.023377212, -0.07221364, -0.06430552, - 0.05453865, 0.091149814, 0.06387331, 0.007518393, - 0.055960953, 0.069779344, 0.046411168, 0.10509911, - 0.07463894, 0.0075130584, 0.012850982, 0.04555431, - 0.056955688, 0.06555285, 0.050801456, -0.009862683, - 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); - - lstm.SetInputToOutputWeights( - {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, - -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, - 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, - -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, - -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, - 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, - -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, - -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, - -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, - -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, - 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, - 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, - 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, - -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, - 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, - 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, - -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, - 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, - -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, - -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); - - lstm.SetInputGateBias( - {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, - -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, - -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, - 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); - - lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, - 0.11098921, 0.15378423, 0.09263801, 0.09790885, - 0.09508917, 0.061199076, 0.07665568, -0.015443159, - -0.03499149, 0.046190713, 0.08895977, 0.10899629, - 0.40694186, 0.06030037, 0.012413437, -0.06108739}); - - lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, - -0.1483596, -0.10639995, -0.091433935, 0.058573797, - -0.06809782, -0.07889636, -0.043246906, -0.09829136, - -0.4279842, 0.034901652, 0.18797937, 0.0075234566, - 0.016178843, 0.1749513, 0.13975595, 0.92058027}); - - lstm.SetOutputGateBias( - {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, - 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, - 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, - -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); - - lstm.SetRecurrentToInputWeights( - {-0.001374326, -0.078856036, 0.10672688, 0.029162422, - -0.11585556, 0.02557986, -0.13446963, -0.035785314, - -0.01244275, 0.025961924, -0.02337298, -0.044228926, - -0.055839065, -0.046598054, -0.010546039, -0.06900766, - 0.027239809, 0.022582639, -0.013296484, -0.05459212, - 0.08981, -0.045407712, 0.08682226, -0.06867011, - -0.14390695, -0.02916037, 0.000996957, 0.091420636, - 0.14283475, -0.07390571, -0.06402044, 0.062524505, - -0.093129106, 0.04860203, -0.08364217, -0.08119002, - 0.009352075, 0.22920375, 0.0016303885, 0.11583097, - -0.13732095, 0.012405723, -0.07551853, 0.06343048, - 0.12162708, -0.031923793, -0.014335606, 0.01790974, - -0.10650317, -0.0724401, 0.08554849, -0.05727212, - 0.06556731, -0.042729504, -0.043227166, 0.011683251, - -0.013082158, -0.029302018, -0.010899579, -0.062036745, - -0.022509435, -0.00964907, -0.01567329, 0.04260106, - -0.07787477, -0.11576462, 0.017356863, 0.048673786, - -0.017577527, -0.05527947, -0.082487635, -0.040137455, - -0.10820036, -0.04666372, 0.022746278, -0.07851417, - 0.01068115, 0.032956902, 0.022433773, 0.0026891115, - 0.08944216, -0.0685835, 0.010513544, 0.07228705, - 0.02032331, -0.059686817, -0.0005566496, -0.086984694, - 0.040414046, -0.1380399, 0.094208956, -0.05722982, - 0.012092817, -0.04989123, -0.086576, -0.003399834, - -0.04696032, -0.045747425, 0.10091314, 0.048676282, - -0.029037097, 0.031399418, -0.0040285117, 0.047237843, - 0.09504992, 0.041799378, -0.049185462, -0.031518843, - -0.10516937, 0.026374253, 0.10058866, -0.0033195973, - -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, - -0.10167381, 0.042500053, -0.01447153, 0.06464186, - -0.017142897, 0.03312627, 0.009205989, 0.024138335, - -0.011337001, 0.035530265, -0.010912711, 0.0706555, - -0.005894094, 0.051841937, -0.1401738, -0.02351249, - 0.0365468, 0.07590991, 0.08838724, 0.021681072, - -0.10086113, 0.019608743, -0.06195883, 0.077335775, - 0.023646897, -0.095322326, 0.02233014, 0.09756986, - -0.048691444, -0.009579111, 0.07595467, 0.11480546, - -0.09801813, 0.019894179, 0.08502348, 0.004032281, - 0.037211012, 0.068537936, -0.048005626, -0.091520436, - -0.028379958, -0.01556313, 0.06554592, -0.045599163, - -0.01672207, -0.020169014, -0.011877351, -0.20212261, - 0.010889619, 0.0047078193, 0.038385306, 0.08540671, - -0.017140968, -0.0035865551, 0.016678626, 0.005633034, - 0.015963363, 0.00871737, 0.060130805, 0.028611384, - 0.10109069, -0.015060172, -0.07894427, 0.06401885, - 0.011584063, -0.024466386, 0.0047652307, -0.09041358, - 0.030737216, -0.0046374933, 0.14215417, -0.11823516, - 0.019899689, 0.006106124, -0.027092824, 0.0786356, - 0.05052217, -0.058925, -0.011402121, -0.024987547, - -0.0013661642, -0.06832946, -0.015667673, -0.1083353, - -0.00096863037, -0.06988685, -0.053350925, -0.027275559, - -0.033664223, -0.07978348, -0.025200296, -0.017207067, - -0.058403496, -0.055697463, 0.005798788, 0.12965427, - -0.062582195, 0.0013350133, -0.10482091, 0.0379771, - 0.072521195, -0.0029455067, -0.13797039, -0.03628521, - 0.013806405, -0.017858358, -0.01008298, -0.07700066, - -0.017081132, 0.019358726, 0.0027079724, 0.004635139, - 0.062634714, -0.02338735, -0.039547626, -0.02050681, - 0.03385117, -0.083611414, 0.002862572, -0.09421313, - 0.058618143, -0.08598433, 0.00972939, 0.023867095, - -0.053934585, -0.023203006, 0.07452513, -0.048767887, - -0.07314807, -0.056307215, -0.10433547, -0.06440842, - 0.04328182, 0.04389765, -0.020006588, -0.09076438, - -0.11652589, -0.021705797, 0.03345259, -0.010329105, - -0.025767034, 0.013057034, -0.07316461, -0.10145612, - 0.06358255, 0.18531723, 0.07759293, 0.12006465, - 0.1305557, 0.058638252, -0.03393652, 0.09622831, - -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, - -0.005644518, 0.06857898, -0.12598175, -0.035084512, - 0.03156317, -0.12794146, -0.031963028, 0.04692781, - 0.030070418, 0.0071660685, -0.095516115, -0.004643372, - 0.040170413, -0.062104587, -0.0037324072, 0.0554317, - 0.08184801, -0.019164372, 0.06791302, 0.034257166, - -0.10307039, 0.021943003, 0.046745934, 0.0790918, - -0.0265588, -0.007824208, 0.042546265, -0.00977924, - -0.0002440307, -0.017384544, -0.017990116, 0.12252321, - -0.014512694, -0.08251313, 0.08861942, 0.13589665, - 0.026351685, 0.012641483, 0.07466548, 0.044301085, - -0.045414884, -0.051112458, 0.03444247, -0.08502782, - -0.04106223, -0.028126027, 0.028473156, 0.10467447}); - - lstm.SetRecurrentToForgetWeights( - {-0.057784554, -0.026057621, -0.068447545, -0.022581743, - 0.14811787, 0.10826372, 0.09471067, 0.03987225, - -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, - 0.08414449, -0.022036452, -0.00066928595, -0.09203576, - 0.032950465, -0.10985798, -0.023809856, 0.0021431844, - -0.02196096, -0.00326074, 0.00058621005, -0.074678116, - -0.06193199, 0.055729095, 0.03736828, 0.020123724, - 0.061878487, -0.04729229, 0.034919553, -0.07585433, - -0.04421272, -0.044019096, 0.085488975, 0.04058006, - -0.06890133, -0.030951202, -0.024628663, -0.07672815, - 0.034293607, 0.08556707, -0.05293577, -0.033561368, - -0.04899627, 0.0241671, 0.015736353, -0.095442444, - -0.029564252, 0.016493602, -0.035026584, 0.022337519, - -0.026871363, 0.004780428, 0.0077918363, -0.03601621, - 0.016435321, -0.03263031, -0.09543275, -0.047392778, - 0.013454138, 0.028934088, 0.01685226, -0.086110644, - -0.046250615, -0.01847454, 0.047608484, 0.07339695, - 0.034546845, -0.04881143, 0.009128804, -0.08802852, - 0.03761666, 0.008096139, -0.014454086, 0.014361001, - -0.023502491, -0.0011840804, -0.07607001, 0.001856849, - -0.06509276, -0.006021153, -0.08570962, -0.1451793, - 0.060212336, 0.055259194, 0.06974018, 0.049454916, - -0.027794661, -0.08077226, -0.016179763, 0.1169753, - 0.17213494, -0.0056326236, -0.053934924, -0.0124349, - -0.11520337, 0.05409887, 0.088759385, 0.0019655675, - 0.0042065294, 0.03881498, 0.019844765, 0.041858196, - -0.05695512, 0.047233116, 0.038937137, -0.06542224, - 0.014429736, -0.09719407, 0.13908425, -0.05379757, - 0.012321099, 0.082840554, -0.029899208, 0.044217527, - 0.059855383, 0.07711018, -0.045319796, 0.0948846, - -0.011724666, -0.0033288454, -0.033542685, -0.04764985, - -0.13873616, 0.040668588, 0.034832682, -0.015319203, - -0.018715994, 0.046002675, 0.0599172, -0.043107376, - 0.0294216, -0.002314414, -0.022424703, 0.0030315618, - 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, - 0.12375372, -0.0006038222, 0.029104086, 0.087442465, - 0.052958444, 0.07558703, 0.04817258, 0.044462286, - -0.015213451, -0.08783778, -0.0561384, -0.003008196, - 0.047060397, -0.002058388, 0.03429439, -0.018839769, - 0.024734668, 0.024614193, -0.042046934, 0.09597743, - -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, - -0.02558259, -0.022822596, -0.023273505, -0.02464396, - -0.10991725, -0.006240552, 0.0074488563, 0.024044557, - 0.04383914, -0.046476185, 0.028658995, 0.060410924, - 0.050786525, 0.009452605, -0.0073054377, -0.024810238, - 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, - 0.015898481, 0.021362653, -0.030262267, 0.016587038, - -0.011442813, 0.041154444, -0.007631438, -0.03423484, - -0.010977775, 0.036152758, 0.0066366293, 0.11915515, - 0.02318443, -0.041350313, 0.021485701, -0.10906167, - -0.028218046, -0.00954771, 0.020531068, -0.11995105, - -0.03672871, 0.024019798, 0.014255957, -0.05221243, - -0.00661567, -0.04630967, 0.033188973, 0.10107534, - -0.014027541, 0.030796422, -0.10270911, -0.035999842, - 0.15443139, 0.07684145, 0.036571592, -0.035900835, - -0.0034699554, 0.06209149, 0.015920248, -0.031122351, - -0.03858649, 0.01849943, 0.13872518, 0.01503974, - 0.069941424, -0.06948533, -0.0088794185, 0.061282158, - -0.047401894, 0.03100163, -0.041533746, -0.10430945, - 0.044574402, -0.01425562, -0.024290353, 0.034563623, - 0.05866852, 0.023947537, -0.09445152, 0.035450947, - 0.02247216, -0.0042998926, 0.061146557, -0.10250651, - 0.020881841, -0.06747029, 0.10062043, -0.0023941975, - 0.03532124, -0.016341697, 0.09685456, -0.016764693, - 0.051808182, 0.05875331, -0.04536488, 0.001626336, - -0.028892258, -0.01048663, -0.009793449, -0.017093895, - 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, - -0.001845119, -0.03551521, 0.0018358806, 0.05763657, - -0.01769146, 0.040995963, 0.02235177, -0.060430344, - 0.11475477, -0.023854522, 0.10071741, 0.0686208, - -0.014250481, 0.034261297, 0.047418304, 0.08562733, - -0.030519066, 0.0060542435, 0.014653856, -0.038836084, - 0.04096551, 0.032249358, -0.08355519, -0.026823482, - 0.056386515, -0.010401743, -0.028396193, 0.08507674, - 0.014410365, 0.020995233, 0.17040324, 0.11511526, - 0.02459721, 0.0066619175, 0.025853224, -0.023133837, - -0.081302024, 0.017264642, -0.009585969, 0.09491168, - -0.051313367, 0.054532815, -0.014298593, 0.10657464, - 0.007076659, 0.10964551, 0.0409152, 0.008275321, - -0.07283536, 0.07937492, 0.04192024, -0.1075027}); - - lstm.SetRecurrentToCellWeights( - {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, - 0.055647098, -0.05713207, -0.05626563, 0.005559383, - 0.03375411, -0.025757805, -0.088049285, 0.06017052, - -0.06570978, 0.007384076, 0.035123326, -0.07920549, - 0.053676967, 0.044480428, -0.07663568, 0.0071805613, - 0.08089997, 0.05143358, 0.038261272, 0.03339287, - -0.027673481, 0.044746667, 0.028349208, 0.020090483, - -0.019443132, -0.030755889, -0.0040000007, 0.04465846, - -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, - -0.10893326, 0.076739706, -0.08509834, -0.027997585, - 0.037871376, 0.01449768, -0.09002357, -0.06111149, - -0.046195522, 0.0422062, -0.005683705, -0.1253618, - -0.012925729, -0.04890792, 0.06985068, 0.037654128, - 0.03398274, -0.004781977, 0.007032333, -0.031787455, - 0.010868644, -0.031489216, 0.09525667, 0.013939797, - 0.0058680447, 0.0167067, 0.02668468, -0.04797466, - -0.048885044, -0.12722108, 0.035304096, 0.06554885, - 0.00972396, -0.039238118, -0.05159735, -0.11329045, - 0.1613692, -0.03750952, 0.06529313, -0.071974665, - -0.11769596, 0.015524369, -0.0013754242, -0.12446318, - 0.02786344, -0.014179351, 0.005264273, 0.14376344, - 0.015983658, 0.03406988, -0.06939408, 0.040699873, - 0.02111075, 0.09669095, 0.041345075, -0.08316494, - -0.07684199, -0.045768797, 0.032298047, -0.041805092, - 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, - -0.024950314, 0.11574242, 0.04508852, -0.04335324, - 0.06760663, -0.027437469, 0.07216407, 0.06977076, - -0.05438599, 0.034033038, -0.028602652, 0.05346137, - 0.043184172, -0.037189785, 0.10420091, 0.00882477, - -0.054019816, -0.074273005, -0.030617684, -0.0028467078, - 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, - 0.04361412, -0.007001822, 0.09631092, -0.06702025, - -0.042049985, -0.035070654, -0.04103342, -0.10273396, - 0.0544271, 0.037184782, -0.13150354, -0.0058036847, - -0.008264958, 0.042035464, 0.05891794, 0.029673764, - 0.0063542654, 0.044788733, 0.054816857, 0.062257513, - -0.00093483756, 0.048938446, -0.004952862, -0.007730018, - -0.04043371, -0.017094059, 0.07229206, -0.023670016, - -0.052195564, -0.025616996, -0.01520939, 0.045104615, - -0.007376126, 0.003533447, 0.006570588, 0.056037236, - 0.12436656, 0.051817212, 0.028532185, -0.08686856, - 0.11868599, 0.07663395, -0.07323171, 0.03463402, - -0.050708205, -0.04458982, -0.11590894, 0.021273347, - 0.1251325, -0.15313013, -0.12224372, 0.17228661, - 0.023029093, 0.086124025, 0.006445803, -0.03496501, - 0.028332196, 0.04449512, -0.042436164, -0.026587414, - -0.006041347, -0.09292539, -0.05678812, 0.03897832, - 0.09465633, 0.008115513, -0.02171956, 0.08304309, - 0.071401566, 0.019622514, 0.032163795, -0.004167056, - 0.02295182, 0.030739572, 0.056506045, 0.004612461, - 0.06524936, 0.059999723, 0.046395954, -0.0045512207, - -0.1335546, -0.030136576, 0.11584653, -0.014678886, - 0.0020118146, -0.09688814, -0.0790206, 0.039770417, - -0.0329582, 0.07922767, 0.029322514, 0.026405897, - 0.04207835, -0.07073373, 0.063781224, 0.0859677, - -0.10925287, -0.07011058, 0.048005477, 0.03438226, - -0.09606514, -0.006669445, -0.043381985, 0.04240257, - -0.06955775, -0.06769346, 0.043903265, -0.026784198, - -0.017840602, 0.024307009, -0.040079936, -0.019946516, - 0.045318738, -0.12233574, 0.026170589, 0.0074471775, - 0.15978073, 0.10185836, 0.10298046, -0.015476589, - -0.039390966, -0.072174534, 0.0739445, -0.1211869, - -0.0347889, -0.07943156, 0.014809798, -0.12412325, - -0.0030663363, 0.039695457, 0.0647603, -0.08291318, - -0.018529687, -0.004423833, 0.0037507233, 0.084633216, - -0.01514876, -0.056505352, -0.012800942, -0.06994386, - 0.012962922, -0.031234352, 0.07029052, 0.016418684, - 0.03618972, 0.055686004, -0.08663945, -0.017404709, - -0.054761406, 0.029065743, 0.052404847, 0.020238016, - 0.0048197987, -0.0214882, 0.07078733, 0.013016777, - 0.06262858, 0.009184685, 0.020785125, -0.043904778, - -0.0270329, -0.03299152, -0.060088247, -0.015162964, - -0.001828936, 0.12642565, -0.056757294, 0.013586685, - 0.09232601, -0.035886683, 0.06000002, 0.05229691, - -0.052580316, -0.082029596, -0.010794592, 0.012947712, - -0.036429964, -0.085508935, -0.13127148, -0.017744139, - 0.031502828, 0.036232427, -0.031581745, 0.023051167, - -0.05325106, -0.03421577, 0.028793324, -0.034633752, - -0.009881397, -0.043551125, -0.018609839, 0.0019097115, - -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); - - lstm.SetRecurrentToOutputWeights({ - 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, - -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, - -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, - -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, - -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, - -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, - -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, - 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, - -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, - 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, - -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, - -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, - 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, - 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, - -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, - 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, - 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, - 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, - 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, - 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, - -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, - 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, - -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, - 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, - 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, - 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, - -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, - -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, - -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, - -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, - -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, - -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, - 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, - 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, - -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, - 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, - -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, - -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, - -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, - 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, - 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, - 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, - -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, - 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, - -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, - -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, - -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, - -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, - 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, - -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, - 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, - -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, - -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, - -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, - -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, - 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, - 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, - -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, - 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, - 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, - -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, - 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, - 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, - 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, - }); - - lstm.SetCellToInputWeights( - {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, - -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, - -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, - 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); - - lstm.SetCellToForgetWeights( - {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, - -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, - -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, - 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); - - lstm.SetCellToOutputWeights( - {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, - -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, - -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, - 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); - - lstm.SetProjectionWeights( - {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, - 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, - -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, - -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, - 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, - 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, - 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, - 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, - -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, - -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, - -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, - 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, - 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, - 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, - 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, - 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, - -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, - 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, - -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, - 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, - -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, - -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, - 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, - -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, - 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, - -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, - -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, - 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, - -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, - -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, - -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, - 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, - 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, - -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, - 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, - 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, - 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, - 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, - 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, - -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, - -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, - 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, - -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, - -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, - 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, - 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, - 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, - -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, - -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, - -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, - 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, - -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, - 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, - 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, - -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, - -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, - -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, - 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, - -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, - -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, - -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, - 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, - 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, - 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); - - static float lstm_input[][20] = { - {// Batch0: 4 (input_sequence_size) * 5 (n_input) - 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, - 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, - 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, - - {// Batch1: 4 (input_sequence_size) * 5 (n_input) - 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, - 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, - 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; - - static float lstm_golden_output[][64] = { - {// Batch0: 4 (input_sequence_size) * 16 (n_output) - -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, - -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, - -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, - 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, - -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, - -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, - 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, - 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, - 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, - 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, - -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, - -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, - 0.0286833, 0.00824207, 0.0264887, 0.0305169}, - {// Batch1: 4 (input_sequence_size) * 16 (n_output) - -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, - -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, - 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, - 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, - -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, - -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, - 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, - 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, - 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, - 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, - -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, - -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, - 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInput(0, batch0_start, batch0_end); +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; - float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); - float* batch1_end = batch1_start + lstm.num_inputs(); - lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); - lstm.Invoke(); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); - float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); - float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); - float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); - expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc b/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc index 0752aa1804722accb1f88910fe013ffd632a4503..fd4d5367c5a6369b5ffeeea30a910262bc0796a9 100644 --- a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc +++ b/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc @@ -126,10 +126,10 @@ TEST(MaximumOpTest, FloatWithBroadcastTest) { TEST(MaximumOpTest, Int32WithBroadcastTest) { std::initializer_list data1 = {1, 0, -1, -2, 3, 11}; std::initializer_list data2 = {2}; - TestModel(BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2}}, + TestModel(BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2}}, {TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2}}, data1, data2, {2, 2, 2, 2, 3, 11}); - TestModel(BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2}}, + TestModel(BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2}}, {TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2}}, data1, data2, {1, 0, -1, -2, 2, 2}); } diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 6c4c3a1edc43af5f97d8a233ed717a730878b7b4..1f72f3a3c7af4f9e042c9b2ac09252fab5de1a4f 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -39,6 +39,14 @@ constexpr int kOutputTensor = 0; struct OpData { bool requires_broadcast; + + // Parameters used in the quantized paths where the output is 8bit + int32 output_activation_min; + int32 output_activation_max; + + // Parameters used in all quantized paths + int32_t output_multiplier; + int output_shift; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -52,6 +60,7 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); @@ -62,7 +71,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, input1->type, input2->type); - output->type = input2->type; data->requires_broadcast = !HaveSameShapes(input1, input2); @@ -74,6 +82,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size = TfLiteIntArrayCopy(input1->dims); } + if (output->type == kTfLiteUInt8) { + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + double real_multiplier = + input1->params.scale * input2->params.scale / output->params.scale; + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; + } + return context->ResizeTensor(context, output, output_size); } @@ -83,8 +105,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); #define TF_LITE_MUL(type, opname) \ type::opname(GetTensorData(input1), GetTensorDims(input1), \ GetTensorData(input2), GetTensorDims(input2), \ @@ -107,41 +129,60 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, } template -void EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - auto input1_offset = -input1->params.zero_point; - auto input2_offset = -input2->params.zero_point; - auto output_offset = output->params.zero_point; - - int32_t output_multiplier; - int output_shift; - - double real_multiplier = - input1->params.scale * input2->params.scale / output->params.scale; - QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, - &output_shift); - - int32 output_activation_min, output_activation_max; - CalculateActivationRangeUint8(params->activation, output, - &output_activation_min, &output_activation_max); - -#define TF_LITE_MUL(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, output_offset, \ - output_multiplier, output_shift, output_activation_min, \ - output_activation_max, GetTensorData(output), \ +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteMulParams* params, const OpData* data, + const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { + if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 && + output->type == kTfLiteUInt8) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + -input1->params.zero_point, GetTensorData(input2), \ + GetTensorDims(input2), -input2->params.zero_point, \ + output->params.zero_point, data->output_multiplier, \ + data->output_shift, data->output_activation_min, \ + data->output_activation_max, GetTensorData(output), \ GetTensorDims(output)); - // The quantized version of Mul doesn't support activations, so we - // always use BroadcastMul. - if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops, BroadcastMul); + // The quantized version of Mul doesn't support activations, so we + // always use BroadcastMul. + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, BroadcastMul); + } else { + TF_LITE_MUL(optimized_ops, BroadcastMul); + } +#undef TF_LITE_MUL + } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && + output->type == kTfLiteInt16) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + GetTensorData(output), GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, Mul); + } else { + TF_LITE_MUL(optimized_ops, Mul); + } +#undef TF_LITE_MUL + } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && + output->type == kTfLiteUInt8) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output->params.zero_point, data->output_activation_min, \ + data->output_activation_max, GetTensorData(output), \ + GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, Mul); + } else { + TF_LITE_MUL(optimized_ops, Mul); + } +#undef TF_LITE_MUL } else { - TF_LITE_MUL(optimized_ops, BroadcastMul); + context->ReportError( + context, "Unsupported combination of input and output types in Mul."); + return kTfLiteError; } -#undef TF_LITE_MUL + return kTfLiteOk; } template @@ -155,12 +196,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { EvalFloat(context, node, params, data, input1, input2, output); - } else if (output->type == kTfLiteUInt8) { - EvalQuantized(context, node, params, data, input1, input2, - output); + } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + TF_LITE_ENSURE_OK( + context, EvalQuantized(context, node, params, data, input1, + input2, output)); } else { - context->ReportError(context, - "Mul only supports FLOAT32 and quantized UINT8 now."); + context->ReportError( + context, + "Mul only supports FLOAT32 and quantized UINT8 and INT16 now, got %d.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc index f1a30f82634631ba8320421d5b36ffe446f443fa..43d56e50d2686ff2624f36a0c5d8e43279a572cc 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -58,6 +58,9 @@ class FloatMulOpModel : public BaseMulOpModel { const float kQuantizedStep = 2.0 / 255.0; const float kQuantizedTolerance = 2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep; +const float kQuantizedStepInt16 = 2.0 / 32767.0; +const float kQuantizedToleranceInt16 = + 2.0 * kQuantizedStepInt16 + kQuantizedStepInt16 * kQuantizedStepInt16; class QuantizedMulOpModel : public BaseMulOpModel { public: @@ -67,6 +70,11 @@ class QuantizedMulOpModel : public BaseMulOpModel { return Dequantize(ExtractVector(output_), GetScale(output_), GetZeroPoint(output_)); } + + std::vector GetDequantizedOutputInt16() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } }; TEST(FloatMulOpTest, NoActivation) { @@ -138,6 +146,38 @@ TEST(QuantizedMulOpTest, NoActivation) { kQuantizedTolerance))); } +TEST(QuantizedMulOpTest, NoActivationInt16) { + const float kMin = -1.f; + const float kMax = 32767.f / 32768.f; + QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, + kQuantizedToleranceInt16))); +} + +TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) { + const float kMinInt16 = -1.f; + const float kMaxInt16 = 32767.f / 32768.f; + const float kMinUint8 = -1.f; + const float kMaxUint8 = 127.f / 128.f; + QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16}, + {TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16}, + {TensorType_UINT8, {}, kMinUint8, kMaxUint8}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, + kQuantizedTolerance))); +} + // for quantized Mul, the error shouldn't exceed 2*step float GetTolerance(int min, int max) { float kQuantizedStep = (max - min) / 255.0; diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc index b8b53f340234a29e3c234d4d0fbb5496f6894dce..4124c05388cca180c2b417603e6d239f1f97b5bf 100644 --- a/tensorflow/contrib/lite/kernels/neg.cc +++ b/tensorflow/contrib/lite/kernels/neg.cc @@ -59,7 +59,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError( - context, "Neg only currently supports int64, int32, and float32.", + context, + "Neg only currently supports int64, int32, and float32, got %d.", input->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/neg_test.cc b/tensorflow/contrib/lite/kernels/neg_test.cc index 3c95ac8cc2727fdeff5f39aa2fe30eb6129a6022..3d3594c60bbe1684dff7b1816f5f8a715b1abc60 100644 --- a/tensorflow/contrib/lite/kernels/neg_test.cc +++ b/tensorflow/contrib/lite/kernels/neg_test.cc @@ -58,9 +58,9 @@ TEST(NegOpModel, NegFloat) { TEST(NegOpModel, NegInt32) { NegOpModel m({TensorType_INT32, {2, 3}}, {TensorType_INT32, {2, 3}}); - m.SetInput({-2, -1, 0, 1, 2, 3}); + m.SetInput({-2, -1, 0, 1, 2, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1, 0, -1, -2, -3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1, 0, -1, -2, -3})); } TEST(NegOpModel, NegInt64) { diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index bcad58406af1cdd466e410a06011641692194be4..1c728a473326564a85a5e7d3d72718265979e29a 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -95,6 +95,12 @@ class LSTMOpModel : public SingleOpModel { projection_bias_ = AddNullInput(); } + // Adding the 2 input state tensors. + input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + output_state_ = AddOutput(TensorType_FLOAT32); cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -228,6 +234,8 @@ class LSTMOpModel : public SingleOpModel { int projection_weights_; int projection_bias_; + int input_activation_state_; + int input_cell_state_; int output_; int output_state_; diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index ecac2dd5e3c5fa2c4ef6f81e997f6d1515af6e43..4be8c243c17c533e8c7d5aa7bb50c9d790b06995 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -128,7 +128,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(nupurgarg): Change kernel implementation to use padding arrays in // forward order (depth, width, height, batch). // Build paddings in order of int[] = {batch, height, width, depth} to match - // kernel implementation of Pad in referenced_ops.h and optimized_ops.h. + // kernel implementation of Pad in reference_ops.h and optimized_ops.h. for (int idx = op_context.dims - 1; idx >= 0; --idx) { before_padding.push_back(paddings_data[idx * 2]); after_padding.push_back(paddings_data[idx * 2 + 1]); @@ -199,7 +199,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } } break; default: - context->ReportError(context, "Type is currently not supported by Pad."); + context->ReportError(context, + "Type %d is currently not supported by Pad.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_PAD diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc index 645d9f4008600227b6f0a615d5d9ebfd1b7b69f7..7240fe04ccdadfb7b9703c3f2775c4b3502bd1d9 100644 --- a/tensorflow/contrib/lite/kernels/pooling.cc +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -80,24 +80,24 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { // Matching GetWindowedOutputSize in TensorFlow. auto padding = params->padding; - auto computeOutSize = [padding](int imageSize, int filterSize, - int stride) -> int { + auto compute_out_size = [padding](int image_size, int filter_size, + int stride) -> int { return padding == kTfLitePaddingSame - ? (imageSize + stride - 1) / stride + ? (image_size + stride - 1) / stride : padding == kTfLitePaddingValid - ? (imageSize - filterSize + stride) / stride + ? (image_size - filter_size + stride) / stride : 0; }; - int outWidth = - computeOutSize(width, params->filter_width, params->stride_width); - int outHeight = - computeOutSize(height, params->filter_height, params->stride_height); + int out_width = + compute_out_size(width, params->filter_width, params->stride_width); + int out_height = + compute_out_size(height, params->filter_height, params->stride_height); data->padding.height = ComputePadding(params->stride_height, 1, height, - params->filter_height, outHeight); + params->filter_height, out_height); data->padding.width = ComputePadding(params->stride_width, 1, width, - params->filter_width, outWidth); + params->filter_width, out_width); if (input->type == kTfLiteUInt8) { if (pool_type == kAverage || pool_type == kMax) { @@ -111,12 +111,12 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { } } - TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); - outputSize->data[0] = batches; - outputSize->data[1] = outHeight; - outputSize->data[2] = outWidth; - outputSize->data[3] = channels_out; - return context->ResizeTensor(context, output, outputSize); + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = batches; + output_size->data[1] = out_height; + output_size->data[2] = out_width; + output_size->data[3] = channels_out; + return context->ResizeTensor(context, output, output_size); } template @@ -124,14 +124,15 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; - CalculateActivationRangeFloat(params->activation, &activation_min, - &activation_max); -#define TF_LITE_AVERAGE_POOL(type) \ - type::AveragePool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) + CalculateActivationRange(params->activation, &activation_min, + &activation_max); +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, \ + activation_min, activation_max, \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_AVERAGE_POOL(reference_ops); } else { @@ -148,13 +149,13 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t activation_max; CalculateActivationRangeUint8(params->activation, output, &activation_min, &activation_max); -#define TF_LITE_AVERAGE_POOL(type) \ - type::AveragePool(GetTensorData(input), GetTensorDims(input), \ - params->stride_width, params->stride_height, \ - data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, \ - activation_min, activation_max, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, \ + activation_min, activation_max, \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_AVERAGE_POOL(reference_ops); } else { @@ -168,14 +169,15 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; - CalculateActivationRangeFloat(params->activation, &activation_min, - &activation_max); -#define TF_LITE_MAX_POOL(type) \ - type::MaxPool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) + CalculateActivationRange(params->activation, &activation_min, + &activation_max); +#define TF_LITE_MAX_POOL(type) \ + type::MaxPool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), \ + GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_MAX_POOL(reference_ops); } else { @@ -193,12 +195,12 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeUint8(params->activation, output, &activation_min, &activation_max); #define TF_LITE_MAX_POOL(type) \ - type::MaxPool(GetTensorData(input), GetTensorDims(input), \ + type::MaxPool(GetTensorData(input), GetTensorShape(input), \ params->stride_width, params->stride_height, \ data->padding.width, data->padding.height, \ params->filter_width, params->filter_height, activation_min, \ activation_max, GetTensorData(output), \ - GetTensorDims(output)) + GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_MAX_POOL(reference_ops); } else { @@ -212,14 +214,15 @@ void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; - CalculateActivationRangeFloat(params->activation, &activation_min, - &activation_max); -#define TF_LITE_L2_POOL(type) \ - type::L2Pool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) + CalculateActivationRange(params->activation, &activation_min, + &activation_max); +#define TF_LITE_L2_POOL(type) \ + type::L2Pool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), \ + GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2_POOL(reference_ops); } else { @@ -246,7 +249,8 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { output); break; default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } return kTfLiteOk; @@ -267,7 +271,8 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { MaxEvalQuantized(context, node, params, data, input, output); break; default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } return kTfLiteOk; @@ -288,7 +293,8 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) { // We don't have a quantized implementation, so just fall through to the // 'default' case. default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a539c47a8fbe392e0e6542ab8ffb9065b550485 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pow.cc @@ -0,0 +1,143 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pow { +namespace { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for pow op. +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + + const TfLiteType type = input1->type; + if (type != kTfLiteInt32 && type != kTfLiteFloat32) { + context->ReportError(context, "Unsupported data type %d.", type); + return kTfLiteError; + } + output->type = type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +template +void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2, + TfLiteTensor* output, bool requires_broadcast) { + if (requires_broadcast) { + reference_ops::BroadcastPow(GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), + GetTensorDims(output)); + } else { + reference_ops::Pow(GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), GetTensorDims(output)); + } +} + +TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) { + const int64_t num_elements = NumElements(input); + const int32_t* data = GetTensorData(input); + for (int i = 0; i < num_elements; ++i) { + if (data[i] < 0) { + context->ReportError(context, + "POW does not support negative value for int32."); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (output->type) { + case kTfLiteInt32: { + // TensorFlow does not support negative for int32. + TF_LITE_ENSURE_OK(context, CheckValue(context, input2)); + PowImpl(input1, input2, output, data->requires_broadcast); + break; + } + case kTfLiteFloat32: { + PowImpl(input1, input2, output, data->requires_broadcast); + break; + } + default: { + context->ReportError(context, "Unsupported data type: %d", output->type); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +} // namespace +} // namespace pow + +TfLiteRegistration* Register_POW() { + static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..474d323bc3a1a0f224aa0575a5bbd35394aa2f53 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pow_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +template +class PowOpModel : public SingleOpModel { + public: + PowOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_POW, BuiltinOptions_PowOptions, + CreatePowOptions(builder_).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(PowOpModel, Simple) { + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8)); +} + +TEST(PowOpModel, NegativeAndZeroValue) { + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {0, 2, -7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 0}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1)); +} + +TEST(PowOpModel, Float) { + PowOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + model.PopulateTensor(model.input1(), {0.3, 0.4, 0.7, 5.8}); + model.PopulateTensor(model.input2(), {0.5, 2.7, 3.1, 3.2}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.5477226, 0.08424846, 0.33098164, 277.313}, 1e-3))); +} + +TEST(PowOpModel, NegativeFloatTest) { + PowOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + model.PopulateTensor(model.input1(), {0.3, 0.4, 0.7, 5.8}); + model.PopulateTensor(model.input2(), {0.5, -2.7, 3.1, -3.2}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.5477226, 11.869653, 0.33098164, 0.003606}, 1e-3))); +} + +TEST(PowOpModel, BroadcastTest) { + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1}}, {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor(model.input2(), {4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/reduce.cc similarity index 72% rename from tensorflow/contrib/lite/kernels/mean.cc rename to tensorflow/contrib/lite/kernels/reduce.cc index 03e5db24de3f3c2d4e17df21bc0b592a02078d6b..31c331a8c61ded203af9ff2ae127cb6f985e2932 100644 --- a/tensorflow/contrib/lite/kernels/mean.cc +++ b/tensorflow/contrib/lite/kernels/reduce.cc @@ -25,21 +25,21 @@ limitations under the License. namespace tflite { namespace ops { namespace builtin { -namespace mean { +namespace reduce { -// This file has reference implementation of Mean. +// This file has reference implementation of reduce_* operators. enum KernelType { kReference, }; -struct MeanContext { - MeanContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); axis = GetInput(context, node, 1); output = GetOutput(context, node, 0); } - TfLiteMeanParams* params; + TfLiteReducerParams* params; const TfLiteTensor* input; const TfLiteTensor* axis; TfLiteTensor* output; @@ -58,7 +58,7 @@ void Free(TfLiteContext* context, void* buffer) { } // Resizes the temp tensor that stores resolved axis. -TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, +TfLiteStatus ResizeTempAxis(TfLiteContext* context, OpContext* op_context, TfLiteTensor* resolved_axis) { TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1); axis_size->data[0] = static_cast(NumElements(op_context->axis)); @@ -66,7 +66,7 @@ TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, } // Resizes the temp tensor that stores temp sum of reduced elements. -TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context, +TfLiteStatus ResizeTempSum(TfLiteContext* context, OpContext* op_context, TfLiteTensor* temp_sum) { TfLiteIntArray* size = TfLiteIntArrayCreate(1); size->data[0] = static_cast(NumElements(op_context->output)); @@ -74,8 +74,7 @@ TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context, } // Resizes output array based on the input size and resolved axis. -TfLiteStatus ResizeOutputTensor(TfLiteContext* context, - MeanContext* op_context) { +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) { size_t num_axis = NumElements(op_context->axis); const TfLiteIntArray* input_dims = op_context->input->dims; int input_num_dims = NumDimensions(op_context->input); @@ -140,7 +139,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, // Initializes temp tensors to store index and resolved axis. TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, - MeanContext* op_context) { + OpContext* op_context) { // Creates a temp index to iterate through input data. int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); @@ -180,33 +179,44 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - MeanContext op_context(context, node); + OpContext op_context(context, node); TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context)); TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); - TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); // Leaves work to Eval if axis is not constant; else resizes output. if (!IsConstantTensor(op_context.axis)) { SetTensorToDynamic(op_context.output); SetTensorToDynamic(resolved_axis); - SetTensorToDynamic(temp_sum); return kTfLiteOk; } resolved_axis->allocation_type = kTfLiteArenaRw; TF_LITE_ENSURE_OK(context, ResizeTempAxis(context, &op_context, resolved_axis)); TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + return kTfLiteOk; +} + +TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_OK(context, PrepareSimple(context, node)); + + // reduce_mean requires a buffer to store intermediate sum result. + OpContext op_context(context, node); + TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); + if (!IsConstantTensor(op_context.axis)) { + SetTensorToDynamic(temp_sum); + return kTfLiteOk; + } temp_sum->allocation_type = kTfLiteArenaRw; return ResizeTempSum(context, &op_context, temp_sum); } template -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - MeanContext op_context(context, node); +TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); int num_axis = static_cast(NumElements(op_context.axis)); TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); @@ -255,16 +265,75 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { #undef TF_LITE_MEAN return kTfLiteOk; } -} // namespace mean + +template +TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + int num_axis = static_cast(NumElements(op_context.axis)); + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_SUM(kernel_type, data_type) \ + kernel_type::Sum<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ + GetTensorData(resolved_axis)) + + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, float)); + break; + case kTfLiteInt32: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int)); + break; + case kTfLiteInt64: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int64_t)); + break; + case kTfLiteUInt8: + TF_LITE_ENSURE_EQ(context, op_context.input->params.scale, + op_context.output->params.scale); + TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, + op_context.output->params.zero_point); + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, uint8_t)); + break; + default: + return kTfLiteError; + } + } +#undef TF_LITE_SUM + return kTfLiteOk; +} + +} // namespace reduce TfLiteRegistration* Register_MEAN_REF() { - static TfLiteRegistration r = {mean::Init, mean::Free, mean::Prepare, - mean::Eval}; + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareMean, + reduce::EvalMean}; + return &r; +} + +TfLiteRegistration* Register_SUM_REF() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareSimple, + reduce::EvalSum}; return &r; } // TODO(kanlig): add optimized implementation of Mean. TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); } +TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); } } // namespace builtin } // namespace ops diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc similarity index 53% rename from tensorflow/contrib/lite/kernels/mean_test.cc rename to tensorflow/contrib/lite/kernels/reduce_test.cc index 79c9957f76fdb994be0a71f2e90b883435de4815..9e946822c686f6f20505d60b6161239624c94696 100644 --- a/tensorflow/contrib/lite/kernels/mean_test.cc +++ b/tensorflow/contrib/lite/kernels/reduce_test.cc @@ -23,7 +23,7 @@ namespace { using ::testing::ElementsAreArray; -class BaseMeanOpModel : public SingleOpModel { +class BaseOpModel : public SingleOpModel { public: void SetAxis(std::initializer_list data) { PopulateTensor(axis_, data); } @@ -53,7 +53,7 @@ class BaseMeanOpModel : public SingleOpModel { }; // Model for the tests case where axis is a const tensor. -class MeanOpConstModel : public BaseMeanOpModel { +class MeanOpConstModel : public BaseOpModel { public: MeanOpConstModel(const TensorData& input, const TensorData& output, std::initializer_list axis_shape, @@ -61,26 +61,59 @@ class MeanOpConstModel : public BaseMeanOpModel { input_ = AddInput(input); axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, - CreateMeanOptions(builder_, keep_dims).Union()); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); BuildInterpreter({GetShape(input_)}); } }; // Model for the tests case where axis is a dynamic tensor. -class MeanOpDynamicModel : public BaseMeanOpModel { +class MeanOpDynamicModel : public BaseOpModel { public: MeanOpDynamicModel(const TensorData& input, const TensorData& output, const TensorData& axis, bool keep_dims) { input_ = AddInput(input); axis_ = AddInput(axis); output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, - CreateMeanOptions(builder_, keep_dims).Union()); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); BuildInterpreter({GetShape(input_)}); } }; +// Model for the tests case where axis is a const tensor. +class SumOpConstModel : public BaseOpModel { + public: + SumOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. +class SumOpDynamicModel : public BaseOpModel { + public: + SumOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// for quantized Add, the error shouldn't exceed step +float GetTolerance(int min, int max) { return (max - min) / 255.0; } + +// Tests for reduce_mean TEST(ConstFloatMeanOpTest, NotKeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, @@ -149,8 +182,6 @@ TEST(DynamicFloatMeanOpTest, Scale) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); } -// for quantized Add, the error shouldn't exceed step -float GetTolerance(int min, int max) { return (max - min) / 255.0; } TEST(ConstUint8MeanOpTest, NotKeepDims) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); @@ -209,6 +240,135 @@ TEST(DynamicUint8MeanOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance))); } +// Tests for reduce_sum + +TEST(ConstFloatSumOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({144, 156}))); +} + +TEST(ConstFloatSumOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({84, 100, 116}))); +} + +TEST(DynamicFloatSumOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::initializer_list axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({144, 156}))); +} + +TEST(DynamicFloatSumOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true); + std::initializer_list axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({84, 100, 116}))); +} + +TEST(DynamicFloatSumOpTest, Scale) { + std::initializer_list data = {9.527}; + SumOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::initializer_list axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + +TEST(ConstUint8SumOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance))); +} + +TEST(ConstUint8SumOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + SumOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-0.407843, -0.313726, 0.0941177}, + kQuantizedTolerance))); +} + +TEST(DynamicUint8SumOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-5.0, 2.0); + std::initializer_list data = {1.3, -4.8, -3.6, 0.24}; + SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0}, + {TensorType_UINT8, {2}, -5.0, 2.0}, + {TensorType_INT32, {1}}, false); + std::initializer_list axis = {1}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({1.48235, 1.64706}, kQuantizedTolerance))); +} + +TEST(DynamicUint8SumOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::initializer_list data = {11.14, -0.14, 7.423, 0.879}; + SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0}, + {TensorType_UINT8, {2}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::initializer_list axis = {0}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 21cc185e9fbec42fe86dd65d3308a0011175c869..0ca08cd8f38216549b4383ebaacbf4c54442cd97 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -22,6 +22,7 @@ namespace custom { TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); TfLiteRegistration* Register_MFCC(); +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); } // namespace custom @@ -73,6 +74,7 @@ TfLiteRegistration* Register_SQUEEZE(); TfLiteRegistration* Register_STRIDED_SLICE(); TfLiteRegistration* Register_EXP(); TfLiteRegistration* Register_TOPK_V2(); +TfLiteRegistration* Register_LOG(); TfLiteRegistration* Register_LOG_SOFTMAX(); TfLiteRegistration* Register_CAST(); TfLiteRegistration* Register_DEQUANTIZE(); @@ -85,11 +87,21 @@ TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_LESS(); TfLiteRegistration* Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); +TfLiteRegistration* Register_TILE(); TfLiteRegistration* Register_NEG(); +TfLiteRegistration* Register_SUM(); TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); TfLiteRegistration* Register_TRANSPOSE_CONV(); +TfLiteRegistration* Register_EXPAND_DIMS(); +TfLiteRegistration* Register_SPARSE_TO_DENSE(); +TfLiteRegistration* Register_EQUAL(); +TfLiteRegistration* Register_NOT_EQUAL(); +TfLiteRegistration* Register_SQRT(); +TfLiteRegistration* Register_RSQRT(); +TfLiteRegistration* Register_SHAPE(); +TfLiteRegistration* Register_POW(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -111,7 +123,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, Register_EMBEDDING_LOOKUP_SPARSE()); - AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED()); + AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); @@ -123,7 +137,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORMALIZATION()); - AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, Register_BIDIRECTIONAL_SEQUENCE_LSTM()); AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, @@ -144,6 +159,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); AddBuiltin(BuiltinOperator_EXP, Register_EXP()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); + AddBuiltin(BuiltinOperator_LOG, Register_LOG()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); AddBuiltin(BuiltinOperator_CAST, Register_CAST()); AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE()); @@ -161,12 +177,24 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SLICE, Register_SLICE()); AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); + AddBuiltin(BuiltinOperator_TILE, Register_TILE()); + AddBuiltin(BuiltinOperator_SUM, Register_SUM()); + AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); + AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); + AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); + AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL()); + AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); + AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); + AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); + AddBuiltin(BuiltinOperator_POW, Register_POW()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); AddCustom("AudioSpectrogram", tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); + AddCustom("TFLite_Detection_PostProcess", + tflite::ops::custom::Register_DETECTION_POSTPROCESS()); } } // namespace builtin diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index b928f1b302580d52f708bbf85dfcfc0f79ff1e69..940718d67e70b7206227b891ea529cb9e9619161 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -32,4 +32,4 @@ class BuiltinOpResolver : public MutableOpResolver { } // namespace ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index e4bd0f5b85d50c38b0a96f8f67f5be92f098d215..86c4cd3ee88013ca4174f444d0388bc036d9cde6 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -61,12 +61,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1); - // TODO(ahentz): Our current implementations only support float32. - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32); // ResizeBilinear creates a float tensor even when the input is made of // integers. - output->type = kTfLiteFloat32; + output->type = input->type; if (!IsConstantTensor(size)) { SetTensorToDynamic(output); @@ -90,21 +88,29 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } if (output->type == kTfLiteFloat32) { -#define TF_LITE_RESIZE_BILINEAR(type) \ - type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ - GetTensorData(size), GetTensorDims(size), \ - GetTensorData(output), GetTensorDims(output), \ +#define TF_LITE_RESIZE_BILINEAR(type, datatype) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(size), GetTensorDims(size), \ + GetTensorData(output), GetTensorDims(output), \ params->align_corners) if (kernel_type == kReference) { - TF_LITE_RESIZE_BILINEAR(reference_ops); + TF_LITE_RESIZE_BILINEAR(reference_ops, float); } if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { - TF_LITE_RESIZE_BILINEAR(optimized_ops); + TF_LITE_RESIZE_BILINEAR(optimized_ops, float); + } + } else if (output->type == kTfLiteUInt8) { + if (kernel_type == kReference) { + TF_LITE_RESIZE_BILINEAR(reference_ops, uint8_t); + } + if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { + TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t); } #undef TF_LITE_RESIZE_BILINEAR } else { - context->ReportError(context, "Inputs and outputs not all float types."); + context->ReportError(context, "Output type is %d, requires float.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 4e03f3820a5c14ee1692c553db61e385716b1723..10caffea03ebcec7862df1627541ac3d076b04e4 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -22,6 +22,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using uint8 = std::uint8_t; class ResizeBilinearOpModel : public SingleOpModel { public: @@ -34,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel { } else { size_ = AddInput({TensorType_INT32, {2}}); } - output_ = AddOutput(TensorType_FLOAT32); // Always float. + output_ = AddOutput(input.type); SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, CreateResizeBilinearOptions(builder_).Union()); @@ -45,12 +46,16 @@ class ResizeBilinearOpModel : public SingleOpModel { } } - void SetInput(std::initializer_list data) { + template + void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } void SetSize(std::initializer_list data) { PopulateTensor(size_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + std::vector GetOutput() { + return ExtractVector(output_); + } private: int input_; @@ -60,60 +65,121 @@ class ResizeBilinearOpModel : public SingleOpModel { TEST(ResizeBilinearOpTest, HorizontalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); - m.SetInput({3, 6}); + m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); - const_m.SetInput({3, 6}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(ResizeBilinearOpTest, HorizontalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}); + m.SetInput({3, 6}); + m.SetSize({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } TEST(ResizeBilinearOpTest, VerticalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); - m.SetInput({3, 9}); + m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); - const_m.SetInput({3, 9}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(ResizeBilinearOpTest, VerticalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}); + m.SetInput({3, 9}); + m.SetSize({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } TEST(ResizeBilinearOpTest, TwoDimensionalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); - m.SetInput({ + m.SetInput({ 3, 6, // 9, 12 // }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 6, // 9, 12 // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); } TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); - m.SetInput({ + m.SetInput({ 3, 6, // 9, 12, // 4, 10, // @@ -121,60 +187,123 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - 4, 8, 10, // - 8, 12, 14, // - 10, 14, 16, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 6, // 9, 12, // 4, 10, // 10, 16 // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - 4, 8, 10, // - 8, 12, 14, // - 10, 14, 16, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); - m.SetInput({ + m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 14, 12, 16, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 14, 12, 16, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 13, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 13, 16, // + }))); } +TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}); + m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 13, 12, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 13, 12, 16, // + }))); +} } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc index 9bc8a1a34a0fc57aed0aff952b0b200d938ac1f9..9b6cee3cb55bf93b987fa8e59bdf9c591f5c0372 100644 --- a/tensorflow/contrib/lite/kernels/select.cc +++ b/tensorflow/contrib/lite/kernels/select.cc @@ -97,7 +97,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) { break; \ default: \ context->ReportError(context, \ - "Does not support type other than bool|float|int"); \ + "Does not support type other than bool|float|int, " \ + "got %d", \ + type); \ return kTfLiteError; \ } diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc index cfe24a5fc92765747d1c75bc3e6964b959e2205d..4664b9acb444747167f991944ddc120e9941ccd6 100644 --- a/tensorflow/contrib/lite/kernels/select_test.cc +++ b/tensorflow/contrib/lite/kernels/select_test.cc @@ -88,11 +88,11 @@ TEST(SelectOpTest, SelectUInt8) { TensorType_UINT8); model.PopulateTensor(model.input1(), {false, true, false, false}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } @@ -101,11 +101,11 @@ TEST(SelectOpTest, SelectInt32) { TensorType_INT32); model.PopulateTensor(model.input1(), {false, true, false, false}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } @@ -113,11 +113,11 @@ TEST(SelectOpTest, RankOneSelectInt32) { SelectOpModel model({2}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32); model.PopulateTensor(model.input1(), {false, true}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 3, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 3, 4})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1})); } @@ -125,11 +125,11 @@ TEST(SelectOpTest, RankZeroSelectInt32) { SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32); model.PopulateTensor(model.input1(), {false}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 7, 8})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 7, 8})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1})); } diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc new file mode 100644 index 0000000000000000000000000000000000000000..dbcd2ef004f490f00193153be7a2cfda83e73c24 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/shape.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace shape { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +template +void ExtractShape(const TfLiteTensor* input, OutType* output_data) { + for (int i = 0; i < NumDimensions(input); ++i) { + output_data[i] = SizeOfDimension(input, i); + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + auto* params = reinterpret_cast(node->builtin_data); + switch (params->out_type) { + case kTfLiteInt32: + output->type = kTfLiteInt32; + break; + case kTfLiteInt64: + output->type = kTfLiteInt64; + break; + default: + context->ReportError(context, "Unknown shape output data type: %d", + params->out_type); + return kTfLiteError; + } + + // Shape always produces a 1-dimensional output tensor, where each output + // element is the length of the corresponding input tensor's dimension. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(1); + output_size->data[0] = NumDimensions(input); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TFLITE_DCHECK_EQ(NumDimensions(output), 1); + TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input)); + + switch (output->type) { + case kTfLiteInt32: + ExtractShape(input, GetTensorData(output)); + break; + case kTfLiteInt64: + ExtractShape(input, GetTensorData(output)); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace shape + +TfLiteRegistration* Register_SHAPE() { + static TfLiteRegistration r = {nullptr, nullptr, shape::Prepare, shape::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/shape_test.cc b/tensorflow/contrib/lite/kernels/shape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..27b48f4e992a8f02d56815bd1bd9074f5b41f400 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/shape_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class ShapeOpModel : public SingleOpModel { + public: + ShapeOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType output_type) { + input_ = AddInput(input_type); + output_ = AddOutput(output_type); + SetBuiltinOp(BuiltinOperator_SHAPE, BuiltinOptions_ShapeOptions, + CreateShapeOptions(builder_, output_type).Union()); + BuildInterpreter({input_shape}); + } + + TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); } + + int input() { return input_; } + + int32_t GetOutputSize() { return GetTensorSize(output_); } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(ShapeOpTest, OutTypeInt) { + ShapeOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32, + TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(ShapeOpTest, OutTypeInt64) { + ShapeOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32, + TensorType_INT64); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(ShapeOpTest, ScalarTensor) { + ShapeOpModel model({}, TensorType_FLOAT32, TensorType_INT32); + model.Invoke(); + + EXPECT_EQ(model.GetOutputSize(), 0); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0})); +} + +TEST(ShapeOpTest, EmptyTensor) { + ShapeOpModel model({1, 0}, TensorType_FLOAT32, TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc index b28934e2f7418136fbeb65b4b9c322bd67f5319b..6a20e802a99cdf23a005a8cd9f1fd97b03c8070a 100644 --- a/tensorflow/contrib/lite/kernels/slice.cc +++ b/tensorflow/contrib/lite/kernels/slice.cc @@ -85,7 +85,8 @@ TfLiteStatus ResizeOutputShape(TfLiteContext* context, TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector( context, input, begin, size, &output_shape_vector)); } else { - context->ReportError(context, "Type is currently not supported by Slice."); + context->ReportError( + context, "Type %d is currently not supported by Slice.", begin->type); return kTfLiteError; } @@ -148,7 +149,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetBeginAndSizeVectors(NumDimensions(input), begin, size, &begins, &sizes); } else { - context->ReportError(context, "Type is currently not supported by Slice."); + context->ReportError( + context, "Type %d is currently not supported by Slice.", begin->type); return kTfLiteError; } @@ -179,8 +181,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SLICE(bool); break; default: - context->ReportError(context, - "Type is currently not supported by Slice."); + context->ReportError( + context, "Type %d is currently not supported by Slice.", input->type); return kTfLiteError; } #undef TF_LITE_SLICE diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc index 6c5338ff0fd26337c9adc8e0b94a0a88edfde37f..727822f6beaa8a63ca8f1b57ba4993d2e59f7e0b 100644 --- a/tensorflow/contrib/lite/kernels/softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -92,10 +92,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::Softmax(input_buffer, input_dims, beta, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::Softmax(input_buffer, input_shape, beta, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), @@ -120,10 +119,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::Softmax(input_buffer, input_dims, beta, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::Softmax(input_buffer, input_shape, beta, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc index 1e35869958a77907cbd9bbd228315b43903750a2..c9269599e58f095ded4788e2ab064583ae0a708c 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -152,8 +152,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } break; default: - context->ReportError(context, - "Type is currently not supported by SpaceToBatch."); + context->ReportError( + context, "Type %d is currently not supported by SpaceToBatch.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_SPACE_TO_BATCH_ND diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc index aafce895123cc21500a1e7edbed1e7552f8a3783..9dbe9b9edaccc3ea75f1997378aba5a218ee3030 100644 --- a/tensorflow/contrib/lite/kernels/space_to_depth.cc +++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc @@ -113,7 +113,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } break; default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } #undef TF_LITE_SPACE_TO_DEPTH diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc new file mode 100644 index 0000000000000000000000000000000000000000..404c32ad9ca8b9f1e467b747708ccb451f2a5118 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc @@ -0,0 +1,275 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace sparse_to_dense { + +constexpr int kIndicesTensor = 0; +constexpr int kOutputShapeTensor = 1; +constexpr int kValueInputTensor = 2; +constexpr int kDefaultValueTensor = 3; +constexpr int kOutputTensor = 0; + +constexpr int kMaxDimensions = 4; + +template +TfLiteStatus Resize(TfLiteContext* context, const TfLiteTensor* output_shape, + TfLiteTensor* output) { + const int output_dimensions = NumElements(output_shape); + TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions); + for (int i = 0; i < output_dimensions; ++i) { + output_shape_array->data[i] = GetTensorData(output_shape)[i]; + } + + return context->ResizeTensor(context, output, output_shape_array); +} + +TfLiteStatus CheckDimensionsMatch(TfLiteContext* context, + const TfLiteTensor* indices, + const TfLiteTensor* output_shape, + const TfLiteTensor* values) { + switch (NumDimensions(indices)) { + case 0: + case 1: { + if (NumDimensions(values) == 0) { + TF_LITE_ENSURE_EQ(context, NumElements(indices), NumElements(values)); + } + TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 1); + break; + } + case 2: { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 1), + NumElements(output_shape)); + if (NumDimensions(values) == 0) + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + NumElements(values)); + break; + } + default: + context->ReportError( + context, "Wrong indices dimensions %d, should be less than 3.", + NumDimensions(indices)); + return kTfLiteError; + } + return kTfLiteOk; +} + +// Convert indices into a vector of 4-d vectors. +// TODO(renjieliu): Revisit here to improve the performance, since multiple +// allocations of std::vectors will be quite slow on phones. +template +TfLiteStatus GetIndicesVector(TfLiteContext* context, + const TfLiteTensor* indices, + const int num_indices, + std::vector>* indices_vector) { + // Note because TfLite will reverse the dimensions, so pad zeros upfront. + switch (NumDimensions(indices)) { + case 0: + case 1: { + const auto indices_data = GetTensorData(indices); + for (int i = 0; i < num_indices; ++i) { + std::vector index({0, 0, 0, indices_data[i]}); + indices_vector->push_back(index); + } + break; + } + case 2: { + const int true_dimensions = SizeOfDimension(indices, 1); + TF_LITE_ENSURE(context, true_dimensions <= kMaxDimensions); + for (int i = 0; i < num_indices; ++i) { + std::vector index; + index.reserve(kMaxDimensions); + // Fill the index with 1 up to kMaxDimensions - true_dimensions to + // satisfy the needs for 4-dimension index. + for (int j = 0; j < kMaxDimensions - true_dimensions; ++j) { + index.push_back(0); + } + for (int j = 0; j < true_dimensions; ++j) { + index.push_back(GetTensorData(indices)[i * true_dimensions + j]); + } + + indices_vector->push_back(index); + } + break; + } + default: + context->ReportError(context, + "Indices dimensions problem, got %d dimensions", + NumDimensions(indices)); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus ResizeOutputShape(TfLiteContext* context, + const TfLiteTensor* output_shape, + TfLiteTensor* output) { + if (output_shape->type == kTfLiteInt32) { + return Resize(context, output_shape, output); + } else if (output_shape->type == kTfLiteInt64) { + return Resize(context, output_shape, output); + } else { + context->ReportError(context, "Dense shape type %d not supported.", + output_shape->type); + return kTfLiteError; + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + const TfLiteTensor* default_value = + GetInput(context, node, kDefaultValueTensor); + + // TODO(renjieliu): Handle validate_indices. + + // Indices can be 0-D, 1-D or 2-D. + TF_LITE_ASSERT(NumDimensions(indices) >= 0); + TF_LITE_ENSURE(context, NumDimensions(indices) < 3); + TF_LITE_ASSERT(NumDimensions(output_shape) >= 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + // Values can be 0-D or 1-D. + TF_LITE_ASSERT(NumDimensions(values) >= 0); + TF_LITE_ENSURE(context, NumDimensions(values) < 2); + + TF_LITE_ENSURE_EQ(context, NumElements(default_value), 1); + + TF_LITE_ENSURE( + context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64); + TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 || + output_shape->type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, values->type, default_value->type); + + // Ensure dimensions match. + TF_LITE_ENSURE_OK( + context, CheckDimensionsMatch(context, indices, output_shape, values)); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + + if (!IsConstantTensor(output_shape)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputShape(context, output_shape, output); +} + +template +TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + const TfLiteTensor* default_value = + GetInput(context, node, kDefaultValueTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputShape(context, output_shape, output)); + } + + const int num_indices = SizeOfDimension(indices, 0); + const bool value_is_scalar = NumDimensions(values) == 0; + std::vector> indices_vector; + indices_vector.reserve(num_indices); + TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, + &indices_vector)); + reference_ops::SparseToDense(indices_vector, GetTensorData(values), + *GetTensorData(default_value), + GetTensorData(output), GetTensorDims(output), + value_is_scalar); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + + // Currently only supports float32 and int32. + switch (values->type) { + case kTfLiteFloat32: { + switch (indices->type) { + case kTfLiteInt32: { + return SparseToDenseImpl(context, node); + } + case kTfLiteInt64: { + return SparseToDenseImpl(context, node); + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + indices->type); + return kTfLiteError; + } + break; + } + case kTfLiteInt32: { + switch (indices->type) { + case kTfLiteInt32: { + return SparseToDenseImpl(context, node); + } + case kTfLiteInt64: { + return SparseToDenseImpl(context, node); + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + indices->type); + return kTfLiteError; + } + break; + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + values->type); + return kTfLiteError; + } +} + +} // namespace sparse_to_dense + +TfLiteRegistration* Register_SPARSE_TO_DENSE() { + static TfLiteRegistration r = {nullptr, nullptr, sparse_to_dense::Prepare, + sparse_to_dense::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a51ec17afcefd791680d7aa42cef467f481f6dbc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc @@ -0,0 +1,155 @@ + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class SparseToDenseOpModel : public SingleOpModel { + public: + SparseToDenseOpModel(std::initializer_list indices_shape, + std::initializer_list output_shape_shape, + std::initializer_list values_shape, T default_value, + TensorType tensor_index_type, + TensorType tensor_input_type) { + indices_ = AddInput(tensor_index_type); + output_shape_ = AddInput(TensorType_INT32); + values_ = AddInput(tensor_input_type); + default_value_ = AddInput(tensor_input_type); + output_ = AddOutput(tensor_input_type); + + SetBuiltinOp(BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOptions_SparseToDenseOptions, + CreateSparseToDenseOptions(builder_, false).Union()); + BuildInterpreter({indices_shape, output_shape_shape, values_shape, {1}}); + + PopulateTensor(default_value_, {default_value}); + } + + int indices() { return indices_; } + int output_shape() { return output_shape_; } + int values() { return values_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int indices_; + int output_shape_; + int values_; + int default_value_; + int output_; +}; + +TEST(SparseToDenseOpModelTest, ZeroDimensionTest) { + SparseToDenseOpModel m({1}, {1}, {1}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {3}); + m.PopulateTensor(m.output_shape(), {5}); + m.PopulateTensor(m.values(), {7}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 7, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(SparseToDenseOpModelTest, OneDimensionTest) { + SparseToDenseOpModel m({3}, {1}, {3}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {1, 3, 5}); + m.PopulateTensor(m.output_shape(), {7}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 0, 4, 0, 6, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({7})); +} + +TEST(SparseToDenseOpModelTest, TwoDimensionsTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 4, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, DefaultValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, IntegerValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_INT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, Int64IndexTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT64, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc index c6b94c25be3a4e471c24c91b770c105d668b2df6..b14448604123253bac9c50c21f047891721ab122 100644 --- a/tensorflow/contrib/lite/kernels/split.cc +++ b/tensorflow/contrib/lite/kernels/split.cc @@ -76,8 +76,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits); auto input_type = op_context.input->type; - TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || + input_type == kTfLiteUInt8 || + input_type == kTfLiteInt16); for (int i = 0; i < NumOutputs(node); ++i) { GetOutput(context, node, i)->type = input_type; } @@ -137,9 +138,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SPLIT(uint8_t); break; } + case kTfLiteInt16: { + TF_LITE_SPLIT(int16_t); + break; + } default: - context->ReportError(context, - "Only float32 and uint8 are currently supported."); + context->ReportError( + context, + "Only float32, uint8 and int16 are currently supported, got %d.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_SPLIT diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index 9417be32b3b83ea5539c89f8a0d2b6e4e20e69c4..bed2117f9ae3a64e963478eb03b46f0547f4c05f 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -121,10 +121,19 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, int32_t begin = GetBeginValueAtIndex(op_context, idx); int32_t end = GetEndValueAtIndex(op_context, idx); + // When shrinking an axis, the end position does not matter (and can be + // incorrect when negative indexing is used, see Issue #19260). Always use + // begin + 1 to generate a length 1 slice, since begin has + // already been adjusted for negative indices by GetBeginValueAtIndex. + const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx); + if (shrink_axis) { + end = begin + 1; + } + // This is valid for both positive and negative strides int32_t dim_shape = ceil((end - begin) / static_cast(stride)); dim_shape = dim_shape < 0 ? 0 : dim_shape; - if (!(op_context->params->shrink_axis_mask & (1 << idx))) { + if (!shrink_axis) { output_shape_vector.push_back(dim_shape); } } @@ -204,13 +213,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { int begin_mask = ReverseMaskBits(op_context.params->begin_mask, op_context.dims); int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims); + int shrink_axis_mask = + ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); -#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ - kernel_type::StridedSlice(GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), begin_mask, \ - end_mask, starts, stops, strides, \ - GetTensorData(op_context.output), \ - GetTensorDims(op_context.output)) +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \ + starts, stops, strides, GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) switch (op_context.input->type) { case kTfLiteFloat32: @@ -235,8 +246,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Type is currently not supported " - "by StridedSlice."); + "Type %d is currently not supported " + "by StridedSlice.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_STRIDED_SLICE diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc index cc39179bc705aa1083e74b06f8f7f3fb45e9f616..c5d4f9affb46c82b4dec15bc0653d7315d132335 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -21,7 +21,6 @@ limitations under the License. namespace tflite { namespace { -using ::int32; using ::testing::ElementsAreArray; template data) { PopulateTensor(input_, data); } - void SetBegin(std::initializer_list data) { - PopulateTensor(begin_, data); + void SetBegin(std::initializer_list data) { + PopulateTensor(begin_, data); } - void SetEnd(std::initializer_list data) { - PopulateTensor(end_, data); + void SetEnd(std::initializer_list data) { + PopulateTensor(end_, data); } - void SetStrides(std::initializer_list data) { - PopulateTensor(strides_, data); + void SetStrides(std::initializer_list data) { + PopulateTensor(strides_, data); } std::vector GetOutput() { @@ -384,6 +383,45 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); } +TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) { + // This is equivalent to tf.range(4)[-1]. + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({-1}); + m.SetEnd({0}); + m.SetStrides({1}); + + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) { + // This is equivalent to tf.range(4)[:, tf.newaxis][-2, -1]. + StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({-2, -1}); + m.SetEnd({-1, 0}); + m.SetStrides({1, 1}); + + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) { + // This is equivalent to tf.range(4)[:, tf.newaxis][:, -1]. + StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 1, 1, 0, 0, 2); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({0, -1}); + m.SetEnd({0, 0}); + m.SetStrides({1, 1}); + + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3})); +} + TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); @@ -395,17 +433,6 @@ TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); } -TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); - m.SetInput({1, 2, 3, 4}); - m.SetBegin({-2}); - m.SetEnd({-3}); - m.SetStrides({-1}); - m.Invoke(); - EXPECT_TRUE(m.GetOutputShape().empty()); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); -} - TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6}); @@ -538,7 +565,7 @@ TEST(StridedSliceOpTest, RunTwice) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index 9531ecba98991a0f7d8da0ec018f39150e7df9cf..1247525d416e8166a9e2e1d67c7907c00b0f6723 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -83,8 +83,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); #define TF_LITE_SUB(type, opname) \ type::opname(GetTensorData(input1), GetTensorDims(input1), \ GetTensorData(input2), GetTensorDims(input2), \ @@ -126,16 +126,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, int32 input1_multiplier; int input1_shift; - QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, - &input1_shift); + QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, + &input1_multiplier, &input1_shift); + input1_shift *= -1; int32 input2_multiplier; int input2_shift; - QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, - &input2_shift); + QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, + &input2_multiplier, &input2_shift); + input2_shift *= -1; int32 output_multiplier; int output_shift; - QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, - &output_shift); + QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, + &output_multiplier, &output_shift); + output_shift *= -1; int32 output_activation_min, output_activation_max; CalculateActivationRangeUint8(params->activation, output, @@ -174,8 +177,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { EvalQuantized(context, node, params, data, input1, input2, output); } else { - context->ReportError(context, - "Inputs and outputs not all float|uint8 types."); + context->ReportError( + context, "output type %d is not supported, requires float|uint8 types.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 308860c299e9d74729d35b760e0f605437872c92..22eebdd4ceb16aeabc5e799c708f7236b3e2be37 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + +// SVDF op that compresses a fully connected op via low-rank matrix +// factorization. See https://research.google.com/pubs/archive/43813.pdf for +// details. #include #include #include @@ -32,6 +36,67 @@ namespace ops { namespace builtin { namespace svdf { +namespace { + +struct OpData { + int scratch_tensor_index; + bool float_weights_time_initialized; +}; + +static inline void ApplyTimeWeightsBiasAndActivation( + int batch_size, int memory_size, int num_filters, int num_units, int rank, + const TfLiteTensor* weights_time, const TfLiteTensor* bias, + TfLiteFusedActivation activation, TfLiteTensor* state, + TfLiteTensor* scratch, TfLiteTensor* output) { + // Compute matmul(state, weights_time). + // The right most column is used to save temporary output (with the size of + // num_filters). This is achieved by starting at state->data.f and having the + // stride equal to memory_size. + for (int b = 0; b < batch_size; ++b) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::BatchVectorBatchVectorDotProduct( + weights_time->data.f, state_ptr_batch, memory_size, num_filters, + scratch_ptr_batch, /*result_stride=*/1); + } + + // Initialize output with bias if provided. + if (bias) { + tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, + output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, batch_size * num_units); + } + + // Reduction sum. + for (int b = 0; b < batch_size; ++b) { + float* output_ptr_batch = output->data.f + b * num_units; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch, + num_units, rank); + } + + // Apply activation. + for (int b = 0; b < batch_size; ++b) { + float* output_ptr_batch = output->data.f + b * num_units; + tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units, + activation, output_ptr_batch); + } + + // Left shift the state to make room for next cycle's activation. + // TODO(alanchiao): explore collapsing this into a single loop. + for (int b = 0; b < batch_size; ++b) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int f = 0; f < num_filters; ++f) { + tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, + /*shift_value=*/0.0); + state_ptr_batch += memory_size; + } + } +} + +} // namespace + constexpr int kInputTensor = 0; constexpr int kWeightsFeatureTensor = 1; constexpr int kWeightsTimeTensor = 2; @@ -40,29 +105,34 @@ constexpr int kStateTensor = 0; constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData; + op_data->float_weights_time_initialized = false; + context->AddTensors(context, /*tensors_to_add=*/4, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - int* scratch_tensor_index = reinterpret_cast(node->user_data); + const auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); + int scratch_tensor_index = op_data->scratch_tensor_index; // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); - TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* weights_feature = GetInput(context, node, kWeightsFeatureTensor); const TfLiteTensor* weights_time = GetInput(context, node, kWeightsTimeTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + // Check all the parameters of tensor match within themselves and match the // input configuration. const int rank = params->rank; @@ -103,10 +173,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size_array)); + // The weights are of consistent type, so it suffices to check one. + const bool is_hybrid_op = + (input->type == kTfLiteFloat32 && weights_feature->type == kTfLiteUInt8); + // Resize scratch. TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); - node->temporaries->data[0] = *scratch_tensor_index; + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(4); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } + node->temporaries->data[0] = scratch_tensor_index; TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2); scratch_size_array->data[0] = batch_size; @@ -118,24 +196,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor, scratch_size_array)); - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* weights_feature = - GetInput(context, node, kWeightsFeatureTensor); - const TfLiteTensor* weights_time = - GetInput(context, node, kWeightsTimeTensor); + if (is_hybrid_op) { + // Tell interpreter to allocate temporary tensors to store quantized values + // of input tensors. + node->temporaries->data[1] = scratch_tensor_index + 1; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } - TfLiteTensor* state = GetOutput(context, node, kStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); + // Tell interpreter to allocate temporary tensors to store scaling factors. + node->temporaries->data[2] = scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + // Used to store dequantized weights_time matrix for hybrid computation + // of matmul(state, weights_time), which occurs in floating point. + node->temporaries->data[3] = scratch_tensor_index + 3; + TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3); + float_weights_time->type = kTfLiteFloat32; + // Persistent so that we can compute the dequantized weights only once. + float_weights_time->allocation_type = kTfLiteArenaRwPersistent; + if (!TfLiteIntArrayEqual(float_weights_time->dims, weights_time->dims)) { + TfLiteIntArray* float_weights_time_size = + TfLiteIntArrayCopy(weights_time->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, float_weights_time, + float_weights_time_size)); + } + } + return kTfLiteOk; +} +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* input, + const TfLiteTensor* weights_feature, + const TfLiteTensor* weights_time, + const TfLiteTensor* bias, const TfLiteSVDFParams* params, + TfLiteTensor* scratch, TfLiteTensor* state, + TfLiteTensor* output) { const int rank = params->rank; const int batch_size = input->dims->data[0]; const int input_size = input->dims->data[1]; @@ -146,67 +256,151 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Clear the activation (state left most column). // TODO(ghodrat): Add a test which initialize state with invalid values in // left most column and make sure it passes. - for (int b = 0; b < batch_size; b++) { + for (int b = 0; b < batch_size; ++b) { float* state_ptr_batch = state->data.f + b * memory_size * num_filters; - for (int c = 0; c < num_filters; c++) { + for (int c = 0; c < num_filters; ++c) { float* state_ptr = state_ptr_batch + c * memory_size; state_ptr[memory_size - 1] = 0.0; } } // Compute conv1d(inputs, weights_feature). - // The state left most column is used to save current cycle activation. This + // The state right most column is used to save current cycle activation. This // is achieved by starting at state->data.f[memory_size - 1] and having the // stride equal to memory_size. tensor_utils::MatrixBatchVectorMultiplyAccumulate( weights_feature->data.f, num_filters, input_size, input->data.f, batch_size, &state->data.f[memory_size - 1], memory_size); - // Compute matmul(state, weights_time). - // The right most column is used to save temporary output (with the size of - // num_filters). This is achieved by starting at state->data.f and having the - // stride equal to memory_size. - for (int b = 0; b < batch_size; b++) { + ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters, + num_units, rank, weights_time, bias, + params->activation, state, scratch, output); + return kTfLiteOk; +} + +TfLiteStatus EvalHybrid( + TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input, + const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time, + const TfLiteTensor* bias, const TfLiteSVDFParams* params, + TfLiteTensor* scratch, TfLiteTensor* scaling_factors, + TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) { + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int num_filters = weights_feature->dims->data[0]; + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + // Initialize the pointer to input. + const float* input_ptr_batch = input->data.f; + + // Initialize the pointer to storage for quantized values and + // scaling factors. + int8_t* quantized_input_ptr_batch = + reinterpret_cast(input_quantized->data.uint8); + + float* scaling_factors_ptr = scaling_factors->data.f; + + // Other initializations. + const int8_t* weights_feature_ptr = + reinterpret_cast(weights_feature->data.uint8); + const float weights_feature_scale = weights_feature->params.scale; + + // Clear the activation (state left most column). + // TODO(ghodrat): Add a test which initialize state with invalid values in + // left most column and make sure it passes. + for (int b = 0; b < batch_size; ++b) { float* state_ptr_batch = state->data.f + b * memory_size * num_filters; - float* scratch_ptr_batch = scratch->data.f + b * num_filters; - tensor_utils::BatchVectorBatchVectorDotProduct( - weights_time->data.f, state_ptr_batch, memory_size, num_filters, - scratch_ptr_batch, /*result_stride=*/1); + for (int c = 0; c < num_filters; ++c) { + float* state_ptr = state_ptr_batch + c * memory_size; + state_ptr[memory_size - 1] = 0.0; + } } - // Initialize output with bias if provided. - if (bias) { - tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, - output->data.f); - } else { - tensor_utils::ZeroVector(output->data.f, batch_size * num_units); - } + if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) { + // Quantize input from float to int8. + float unused_min, unused_max; + for (int b = 0; b < batch_size; ++b) { + const int offset = b * input_size; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors_ptr[b]); + scaling_factors_ptr[b] *= weights_feature_scale; + } - // Reduction sum - for (int b = 0; b < batch_size; b++) { - float* output_ptr_batch = output->data.f + b * num_units; - float* scratch_ptr_batch = scratch->data.f + b * num_filters; - tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch, - num_units, rank); + // Compute conv1d(inputs, weights_feature). + // The state right most column is used to save current cycle activation. + // This is achieved by starting at state->data.f[memory_size - 1] and having + // the stride equal to memory_size. + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + weights_feature_ptr, num_filters, input_size, quantized_input_ptr_batch, + scaling_factors_ptr, batch_size, &state->data.f[memory_size - 1], + memory_size); } - // Apply activation. - for (int b = 0; b < batch_size; b++) { - float* output_ptr_batch = output->data.f + b * num_units; - tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units, - params->activation, output_ptr_batch); - } + // TODO(alanchiao): can optimize hybrid case ~5% by unrolling loop in applying + // time weights so that the inner loop multiplies eight elements at a time. + ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters, + num_units, rank, weights_time, bias, + params->activation, state, scratch, output); + return kTfLiteOk; +} - // Right shift the state. - for (int b = 0; b < batch_size; b++) { - float* state_ptr_batch = state->data.f + b * memory_size * num_filters; - for (int f = 0; f < num_filters; f++) { - tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, - /*shift_value=*/0.0); - state_ptr_batch += memory_size; +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* weights_feature = + GetInput(context, node, kWeightsFeatureTensor); + const TfLiteTensor* weights_time = + GetInput(context, node, kWeightsTimeTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + + TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* state = GetOutput(context, node, kStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (weights_feature->type) { + case kTfLiteFloat32: { + return EvalFloat(context, node, input, weights_feature, weights_time, + bias, params, scratch, state, output); + break; } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + TfLiteTensor* float_weights_time = + GetTemporary(context, node, /*index=*/3); + + // Dequantize weights time. + // TODO(alanchiao): this dequantization initialization only needs to + // happen once per model and should theoretically be placed in either Init + // or Prepare. However, TFLite doesn't allocate float_weights_time until + // the Eval function. + // TODO(alanchiao): refactor logic out into dequantize function. + if (!op_data->float_weights_time_initialized) { + const float dequantization_scale = weights_time->params.scale; + const int8_t* weights_time_ptr = + reinterpret_cast(weights_time->data.uint8); + for (int i = 0; i < NumElements(float_weights_time); ++i) { + float_weights_time->data.f[i] = + weights_time_ptr[i] * dequantization_scale; + } + op_data->float_weights_time_initialized = true; + } + return EvalHybrid(context, node, input, weights_feature, + float_weights_time, bias, params, scratch, + scaling_factors, input_quantized, state, output); + break; + } + default: + context->ReportError(context, "Type %d not currently supported.", + weights_feature->type); + return kTfLiteError; } - return kTfLiteOk; } } // namespace svdf diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index 0f166dc69b95f3459388135b3a6c4d9b73a31cb4..5af3ff85004ce43c5b75c6f12761f121c0d8deca 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -126,17 +126,20 @@ static float svdf_golden_output_rank_2[] = { }; // Derived class of SingleOpModel, which is used to test SVDF TFLite op. -class SVDFOpModel : public SingleOpModel { +class BaseSVDFOpModel : public SingleOpModel { public: - SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank) + BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank, + TensorType weights_feature_type = TensorType_FLOAT32, + TensorType weights_time_type = TensorType_FLOAT32) : batches_(batches), units_(units), input_size_(input_size), memory_size_(memory_size), rank_(rank) { input_ = AddInput(TensorType_FLOAT32); - weights_feature_ = AddInput(TensorType_FLOAT32); - weights_time_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(weights_feature_type); + weights_time_ = AddInput(weights_time_type); bias_ = AddNullInput(); state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -182,7 +185,7 @@ class SVDFOpModel : public SingleOpModel { int num_units() { return units_; } int num_batches() { return batches_; } - private: + protected: int input_; int weights_feature_; int weights_time_; @@ -197,7 +200,61 @@ class SVDFOpModel : public SingleOpModel { int rank_; }; -TEST(SVDFOpTest, BlackBoxTestRank1) { +class SVDFOpModel : public BaseSVDFOpModel { + public: + using BaseSVDFOpModel::BaseSVDFOpModel; +}; + +class HybridSVDFOpModel : public BaseSVDFOpModel { + public: + HybridSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank) + : BaseSVDFOpModel(batches, units, input_size, memory_size, rank, + TensorType_UINT8, TensorType_UINT8) {} + + void SetWeightsFeature(std::initializer_list f) { + SymmetricQuantizeAndPopulate(weights_feature_, f); + } + + void SetWeightsTime(std::initializer_list f) { + SymmetricQuantizeAndPopulate(weights_time_, f); + } +}; + +class SVDFOpTest : public ::testing::Test { + protected: + void VerifyGoldens(float golden_input[], float golden_output[], + int golden_size, BaseSVDFOpModel* svdf, + float tolerance = 1e-5) { + const int svdf_num_batches = svdf->num_batches(); + const int svdf_input_size = svdf->input_size(); + const int svdf_num_units = svdf->num_units(); + const int input_sequence_size = + golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF + // op and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = + golden_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf->SetInput(0, batch_start, batch_end); + + svdf->Invoke(); + + const float* golden_start = + golden_output + i * svdf_num_units * svdf_num_batches; + const float* golden_end = + golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +TEST_F(SVDFOpTest, BlackBoxTestRank1) { SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, /*memory_size=*/10, /*rank=*/1); svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, @@ -218,31 +275,11 @@ TEST(SVDFOpTest, BlackBoxTestRank1) { -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); svdf.ResetState(); - const int svdf_num_batches = svdf.num_batches(); - const int svdf_input_size = svdf.input_size(); - const int svdf_num_units = svdf.num_units(); - const int input_sequence_size = - sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); - // Going over each input batch, setting the input tensor, invoking the SVDF op - // and checking the output with the expected golden values. - for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; - float* batch_end = batch_start + svdf_input_size * svdf_num_batches; - svdf.SetInput(0, batch_start, batch_end); - - svdf.Invoke(); - - float* golden_start = - svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches; - float* golden_end = golden_start + svdf_num_units * svdf_num_batches; - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), + &svdf); } -TEST(SVDFOpTest, BlackBoxTestRank2) { +TEST_F(SVDFOpTest, BlackBoxTestRank2) { SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, /*memory_size=*/10, /*rank=*/2); svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, @@ -278,28 +315,75 @@ TEST(SVDFOpTest, BlackBoxTestRank2) { 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); svdf.ResetState(); - const int svdf_num_batches = svdf.num_batches(); - const int svdf_input_size = svdf.input_size(); - const int svdf_num_units = svdf.num_units(); - const int input_sequence_size = - sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); - // Going over each input batch, setting the input tensor, invoking the SVDF op - // and checking the output with the expected golden values. - for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; - float* batch_end = batch_start + svdf_input_size * svdf_num_batches; - svdf.SetInput(0, batch_start, batch_end); - - svdf.Invoke(); - - float* golden_start = - svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches; - float* golden_end = golden_start + svdf_num_units * svdf_num_batches; - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), + &svdf); +} + +TEST_F(SVDFOpTest, BlackBoxTestHybridRank1) { + HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), + &svdf, + /*tolerance=*/0.002945); +} + +TEST_F(SVDFOpTest, BlackBoxTestHybridRank2) { + HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), + &svdf, + /*tolerance=*/0.00625109); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index 1a01ee093626c08badd65858fc16ad44e69e4912..9156917140b5af6c0f38c878ab77fef7f93b049a 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -32,8 +32,8 @@ std::vector> ArrayFloatNear(const std::vector& values, return matchers; } -int SingleOpModel::AddInput(const TensorData& t) { - int id = AddTensor(t, {}); +int SingleOpModel::AddInput(const TensorData& t, bool is_variable) { + int id = AddTensor(t, {}, is_variable); inputs_.push_back(id); return id; } @@ -112,8 +112,15 @@ void SingleOpModel::BuildInterpreter( if (shape.empty()) continue; CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk); } + + // Modify delegate with function. + if (apply_delegate_fn_) { + apply_delegate_fn_(interpreter_.get()); + } + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) << "Cannot allocate tensors"; + interpreter_->ResetVariableTensorsToZero(); } void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); } diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index 55edc97d19fa75bedb6c0928fcf9c7be5f434522..bedbe93ae65662647f6a0fb0c9c6a6a921e148bb 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -114,13 +114,22 @@ class SingleOpModel { SingleOpModel() {} ~SingleOpModel() {} + // Set a function callback that is run right after graph is prepared + // that allows applying external delegates. This is useful for testing + // other runtimes like NN API or GPU. + void SetApplyDelegate(std::function apply_delegate_fn) { + apply_delegate_fn_ = apply_delegate_fn; + } + // Copying or assignment is disallowed to simplify ownership semantics. SingleOpModel(const SingleOpModel&) = delete; SingleOpModel& operator=(const SingleOpModel&) = delete; // Add a TensorType input tensor and return its index. - int AddInput(TensorType type) { return AddInput(TensorData{type}); } - int AddInput(const TensorData& t); + int AddInput(TensorType type, bool is_variable = false) { + return AddInput(TensorData{type}, is_variable); + } + int AddInput(const TensorData& t, bool is_variable = false); // Templated version of AddConstInput(). template @@ -139,20 +148,18 @@ class SingleOpModel { int AddOutput(const TensorData& t); template - void QuantizeAndPopulate(int index, std::initializer_list data) { + void QuantizeAndPopulate(int index, const std::vector& data) { TfLiteTensor* t = interpreter_->tensor(index); auto q = Quantize(data, t->params.scale, t->params.zero_point); PopulateTensor(index, 0, q.data(), q.data() + q.size()); } - void SymmetricQuantizeAndPopulate(int index, - std::initializer_list data) { + void SymmetricQuantizeAndPopulate(int index, const std::vector& data) { TfLiteTensor* t = interpreter_->tensor(index); - std::vector values(data); - const int length = values.size(); + const int length = data.size(); std::vector q(length); float min, max, scaling_factor; - tensor_utils::SymmetricQuantizeFloats(values.data(), length, q.data(), &min, + tensor_utils::SymmetricQuantizeFloats(data.data(), length, q.data(), &min, &max, &scaling_factor); // Update quantization params. t->params.scale = scaling_factor; @@ -189,8 +196,22 @@ class SingleOpModel { } // Populate the tensor given its index. + // TODO(b/110696148) clean up and merge with vector-taking variant below. + template + void PopulateTensor(int index, const std::initializer_list& data) { + T* v = interpreter_->typed_tensor(index); + CHECK(v) << "No tensor with index '" << index << "'."; + for (T f : data) { + *v = f; + ++v; + } + } + + // Populate the tensor given its index. + // TODO(b/110696148) clean up and merge with initializer_list-taking variant + // above. template - void PopulateTensor(int index, std::initializer_list data) { + void PopulateTensor(int index, const std::vector& data) { T* v = interpreter_->typed_tensor(index); CHECK(v) << "No tensor with index '" << index << "'."; for (T f : data) { @@ -253,7 +274,8 @@ class SingleOpModel { } template - int AddTensor(TensorData t, std::initializer_list data) { + int AddTensor(TensorData t, std::initializer_list data, + bool is_variable = false) { int id = tensors_.size(); // This is slightly different depending on whether we are adding a @@ -270,6 +292,9 @@ class SingleOpModel { } else if (t.type == TensorType_INT32) { std::tie(t.scale, t.zero_point) = QuantizationParams(t.min, t.max); + } else if (t.type == TensorType_INT16) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); } else { LOG(FATAL) << "No support for the requested quantized type"; } @@ -302,7 +327,7 @@ class SingleOpModel { tensors_.push_back(CreateTensor(builder_, builder_.CreateVector(t.shape), t.type, /*buffer=*/buffer_id, - /*name=*/0, q_params)); + /*name=*/0, q_params, is_variable)); tensor_data_[id] = t; @@ -317,6 +342,9 @@ class SingleOpModel { std::vector> operators_; std::vector> buffers_; std::map> custom_registrations_; + // A function pointer that gets called after the interpreter is created but + // before evaluation happens. This is useful for applying a delegate. + std::function apply_delegate_fn_; }; // Base class for single op unit tests. diff --git a/tensorflow/contrib/lite/kernels/test_util_test.cc b/tensorflow/contrib/lite/kernels/test_util_test.cc index 1e10e89061213b6fcabd404310893dd97a51d83f..236580347254d336609a3081736f54e069b5cb5a 100644 --- a/tensorflow/contrib/lite/kernels/test_util_test.cc +++ b/tensorflow/contrib/lite/kernels/test_util_test.cc @@ -22,22 +22,22 @@ using ::testing::ElementsAreArray; TEST(TestUtilTest, QuantizeVector) { std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; - auto q_data = Quantize(data, /*scale=*/1.0, /*zero_point=*/0); - std::vector expected = {0, 0, 0, 1, 1, 255}; + auto q_data = Quantize(data, /*scale=*/1.0, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 1, 1, 255}; EXPECT_THAT(q_data, ElementsAreArray(expected)); } TEST(TestUtilTest, QuantizeVectorScalingDown) { std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; - auto q_data = Quantize(data, /*scale=*/10.0, /*zero_point=*/0); - std::vector expected = {0, 0, 0, 0, 0, 100}; + auto q_data = Quantize(data, /*scale=*/10.0, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 0, 0, 100}; EXPECT_THAT(q_data, ElementsAreArray(expected)); } TEST(TestUtilTest, QuantizeVectorScalingUp) { std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; - auto q_data = Quantize(data, /*scale=*/0.1, /*zero_point=*/0); - std::vector expected = {0, 0, 0, 5, 10, 255}; + auto q_data = Quantize(data, /*scale=*/0.1, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 5, 10, 255}; EXPECT_THAT(q_data, ElementsAreArray(expected)); } diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc new file mode 100644 index 0000000000000000000000000000000000000000..af77f074742eb3fef10a74616ff679255911fbb2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/tile.cc @@ -0,0 +1,194 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +namespace tflite { +namespace ops { +namespace builtin { +namespace tile { + +constexpr int kInputTensor = 0; +constexpr int kInputMultipliers = 1; +constexpr int kOutputTensor = 0; + +namespace { +template +TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape, + const TfLiteTensor* multipliers, + int num_dimensions) { + const T* multipliers_v = GetTensorData(multipliers); + + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + for (int i = 0; i < num_dimensions; ++i) { + output_shape->data[i] = shape.data[i] * multipliers_v[i]; + } + return output_shape; +} + +TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + + const int num_dimensions = NumDimensions(input); + const int num_multipliers = NumElements(multipliers); + TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers); + switch (multipliers->type) { + case kTfLiteInt32: + return context->ResizeTensor( + context, output, + MultiplyShapeDims(*input->dims, multipliers, + num_dimensions)); + case kTfLiteInt64: + return context->ResizeTensor( + context, output, + MultiplyShapeDims(*input->dims, multipliers, + num_dimensions)); + default: + context->ReportError(context, "Tile not supported multiply tensor type."); + return kTfLiteError; + } +} + +template +void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier, + T* out_data) { + for (int i = 0; i < multiplier; ++i) { + const T* in_end = in_data + in_size; + T* new_out_data = std::copy(in_data, in_end, out_data); + in_data = out_data; + out_data = new_out_data; + } +} + +template +std::pair TileOneDimension(const TfLiteIntArray& in_dimensions, + const T* in_data, const M* multipliers, + T* out_data, int dimension) { + const int dimension_size = in_dimensions.data[dimension]; + if (dimension == in_dimensions.size - 1) { + CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], + out_data); + return std::make_pair(dimension_size, + dimension_size * multipliers[dimension]); + } + int total_stride_size = 0, total_tiled_stride_size = 0; + const T* copy_from_data = in_data; + T* copy_to_data = out_data; + for (int i = 0; i < dimension_size; ++i) { + int stride_size = 0, tiled_stride_size = 0; + std::tie(stride_size, tiled_stride_size) = + TileOneDimension(in_dimensions, copy_from_data, multipliers, + copy_to_data, dimension + 1); + copy_from_data += stride_size; + copy_to_data += tiled_stride_size; + total_stride_size += stride_size; + total_tiled_stride_size += tiled_stride_size; + } + CopyMultipleTimes(out_data, total_tiled_stride_size, + multipliers[dimension] - 1, + out_data + total_tiled_stride_size); + return std::make_pair(total_stride_size, + total_tiled_stride_size * multipliers[dimension]); +} + +template +void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data, + const TfLiteTensor* multipliers, TfLiteTensor* out_data) { + // Doing recursively tiling from top to down dimension. + switch (multipliers->type) { + case kTfLiteInt32: + TileOneDimension(in_dimensions, GetTensorData(in_data), + GetTensorData(multipliers), + GetTensorData(out_data), 0); + break; + case kTfLiteInt64: + TileOneDimension(in_dimensions, GetTensorData(in_data), + GetTensorData(multipliers), + GetTensorData(out_data), 0); + break; + default: + break; + } +} +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + // Only int32 and int64 multipliers type is supported. + TF_LITE_ENSURE_MSG(context, + (multipliers->type == kTfLiteInt32) || + (multipliers->type == kTfLiteInt64), + "Tile only supports int32 and int64 mutlipliers."); + + if (IsConstantTensor(multipliers)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } else { + SetTensorToDynamic(output); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } + + switch (output->type) { + case kTfLiteFloat32: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteUInt8: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteInt32: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteInt64: + Tile(*(input->dims), input, multipliers, output); + break; + default: + context->ReportError(context, "Type is currently not supported by Tile."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace tile +TfLiteRegistration* Register_TILE() { + static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f78c224e54f0c71bc6622134a1c8e4142c22daa --- /dev/null +++ b/tensorflow/contrib/lite/kernels/tile_test.cc @@ -0,0 +1,256 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; +class TileOpModel : public SingleOpModel { + public: + TileOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType multiply_type) { + input_ = AddInput(input_type); + multipliers_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_TILE, BuiltinOptions_TileOptions, 0); + BuildInterpreter({input_shape, {static_cast(input_shape.size())}}); + } + + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt64(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetMultipliers(std::initializer_list data) { + PopulateTensor(multipliers_, data); + } + + std::vector GetOutputFloat() { return ExtractVector(output_); } + + std::vector GetOutputUInt8() { return ExtractVector(output_); } + + std::vector GetOutputInt32() { return ExtractVector(output_); } + + std::vector GetOutputInt64() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int multipliers_; + int output_; +}; + +TEST(TileTest, Float32Vector) { + TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({1.f, 2.f, 3.f}); + m.SetMultipliers({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f})); +} + +TEST(TileTest, Float32Matrix) { + TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Float32HighDimension) { + TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + }); + m.SetMultipliers({2, 3, 1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutputFloat(), + ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, + 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, + 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, + 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 6, 3})); +} + +TEST(TileTest, Uint8Matrix) { + TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32); + m.SetInputUInt8({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputUInt8(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int32Matrix) { + TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32); + m.SetInputInt32({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt32(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int64Matrix) { + TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32); + m.SetInputInt64({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int64Matrix64Multipliers) { + TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64); + m.SetInputInt64({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc index 0feb42b85bbae695294211f82698211ca89eae04..2dd760bbfebd1faa8b7ff9158bc1a1b1d4647525 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2.cc @@ -56,11 +56,13 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { output_values_shape->data[num_dimensions - 1] = k; TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + // Force output types. + output_indexes->type = kTfLiteInt32; + output_values->type = input->type; auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size, TfLiteIntArray* delete_on_error) { TfLiteStatus status = context->ResizeTensor(context, tensor, new_size); if (status != kTfLiteOk) { - TfLiteIntArrayFree(new_size); if (delete_on_error != nullptr) { TfLiteIntArrayFree(delete_on_error); } @@ -214,7 +216,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output_values->data.i64); break; default: - context->ReportError(context, "Type is currently not supported by TopK."); + context->ReportError(context, + "Type %d is currently not supported by TopK.", + output_values->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc index 212f8acc76d4afba56933029175f69b34ea87a3e..2abb89b617742b33b9280b15ad379422c5c9b207 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc @@ -42,32 +42,32 @@ class TopKV2OpModel : public SingleOpModel { PopulateTensor(input_, data); } - void SetInputUInt8(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); } - void SetInputInt32(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); } void SetInputInt64(std::initializer_list data) { PopulateTensor(input_, data); } - std::vector GetIndexes() { - return ExtractVector(output_indexes_); + std::vector GetIndexes() { + return ExtractVector(output_indexes_); } std::vector GetValuesFloat() { return ExtractVector(output_values_); } - std::vector GetValuesUInt8() { - return ExtractVector(output_values_); + std::vector GetValuesUInt8() { + return ExtractVector(output_values_); } - std::vector GetValuesInt32() { - return ExtractVector(output_values_); + std::vector GetValuesInt32() { + return ExtractVector(output_values_); } std::vector GetValuesInt64() { @@ -119,7 +119,7 @@ TEST(TopKV2OpTest, VectorFloat) { EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(ArrayFloatNear({0.8, 0.2}))); } -// Check that uint8 works. +// Check that uint8_t works. TEST(TopKV2OpTest, TypeUint8) { TopKV2OpModel m({2, 3}, TensorType_UINT8, 2); m.SetInputUInt8({1, 2, 3, 251, 250, 249}); @@ -128,7 +128,7 @@ TEST(TopKV2OpTest, TypeUint8) { EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250})); } -// Check that int32 works. +// Check that int32_t works. TEST(TopKV2OpTest, TypeInt32) { TopKV2OpModel m({2, 3}, TensorType_INT32, 2); m.SetInputInt32({1, 2, 3, 10251, 10250, 10249}); diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc index 8316a23c18dea4f036773ab093b483fa45babd45..800b0563d7ee6126d65005ff4ef61219db9eebb5 100644 --- a/tensorflow/contrib/lite/kernels/transpose.cc +++ b/tensorflow/contrib/lite/kernels/transpose.cc @@ -136,7 +136,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Type is currently not supported by Transpose."); + "Type %d is currently not supported by Transpose.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_TRANSPOSE diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index 3c99661029ed1ac881536f83519dcec355c60d50..7182374a6f2ec39c670e02e6fda9b967ae0a5b43 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/eigen_support.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" @@ -38,9 +39,35 @@ constexpr int kWeightsTensor = 1; constexpr int kDataInputTensor = 2; constexpr int kOutputTensor = 0; -TfLiteStatus ResizeOutputShape(TfLiteContext* context, - const TfLiteTensor* output_shape, - TfLiteTensor* output) { +const int kTensorNotAllocated = -1; + +struct OpData { + // IDs are the arbitrary identifiers used by TF Lite to identify and access + // memory buffers. + int im2col_id = kTensorNotAllocated; + + // im2col is the only temporary currently tracked, therefore always index 0. + // If more temporaries are added, they should be properly tracked. + int32_t im2col_index = 0; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to use as scratch space for im2col, and + // to carry information from Prepare() to Eval(). + auto* data = new OpData; + eigen_support::IncrementUsageCounter(context); + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + eigen_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); +} + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + const TfLiteTensor* output_shape, + TfLiteTensor* output) { // Currently only support int32 for output shape. if (output_shape->type != kTfLiteInt32) { context->ReportError(context, "Output shape is %d, not int32.", @@ -56,15 +83,60 @@ TfLiteStatus ResizeOutputShape(TfLiteContext* context, return context->ResizeTensor(context, output, output_shape_array); } +// Allocate temporary im2col tensor. +static TfLiteStatus AllocateIm2colTensor(TfLiteContext* context, + TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + if (data->im2col_id == kTensorNotAllocated) { + context->AddTensors(context, 1, &data->im2col_id); + } + + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(1); + node->temporaries->data[data->im2col_index] = data->im2col_id; + + return kTfLiteOk; +} + +TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context, + const TfLiteTensor* output_shape, + const TfLiteTensor* weights, + const TfLiteTensor* input, + TfLiteTensor* im2col) { + if (output_shape->type != kTfLiteInt32) { + context->ReportError(context, "im2col shape is %d, not int32.", + output_shape->type); + return kTfLiteError; + } + TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4); + TfLiteIntArray* im2col_shape_array = TfLiteIntArrayCreate(4); + im2col_shape_array->data[0] = output_shape->data.i32[0]; + im2col_shape_array->data[1] = output_shape->data.i32[1]; + im2col_shape_array->data[2] = output_shape->data.i32[2]; + const int input_depth = SizeOfDimension(input, 3); + const int filter_width = SizeOfDimension(weights, 1); + const int filter_height = SizeOfDimension(weights, 2); + im2col_shape_array->data[3] = input_depth * filter_height * filter_width; + + im2col->type = input->type; + im2col->allocation_type = kTfLiteArenaRw; + return context->ResizeTensor(context, im2col, im2col_shape_array); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TF_LITE_ENSURE_STATUS(AllocateIm2colTensor(context, node)); + const TfLiteTensor* output_shape = GetInput(context, node, kOutputShapeTensor); const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + OpData* user_data = reinterpret_cast(node->user_data); + TfLiteTensor* im2col = + &context->tensors[node->temporaries->data[user_data->im2col_index]]; TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); @@ -79,13 +151,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Ensure that weights and inputs have the same channel dimension. // Note: TOCO will reorder weights in the following format: OHWI. TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3), - SizeOfDimension(weights, 0)); - - if (!IsConstantTensor(output_shape)) { + SizeOfDimension(weights, 3)); + + if (IsConstantTensor(output_shape)) { + TF_LITE_ENSURE_STATUS(ResizeOutputTensor(context, output_shape, output)); + TF_LITE_ENSURE_STATUS( + ResizeIm2ColTensor(context, output_shape, weights, input, im2col)); + } else { + // Defer resizing until Eval(). SetTensorToDynamic(output); - return kTfLiteOk; } - return ResizeOutputShape(context, output_shape, output); + return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { @@ -94,13 +170,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - + OpData* user_data = reinterpret_cast(node->user_data); + TfLiteTensor* im2col = + &context->tensors[node->temporaries->data[user_data->im2col_index]]; const auto* params = reinterpret_cast(node->builtin_data); if (IsDynamicTensor(output)) { TF_LITE_ENSURE_OK(context, - ResizeOutputShape(context, output_shape, output)); + ResizeOutputTensor(context, output_shape, output)); + } + if (IsDynamicTensor(im2col)) { + TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape, + weights, input, im2col)); } // Get height and width of the output image. @@ -123,7 +205,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorData(input), GetTensorDims(input), GetTensorData(weights), GetTensorDims(weights), stride_width, stride_height, padding_size.width, padding_size.height, - GetTensorData(output), GetTensorDims(output)); + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); break; default: context->ReportError(context, "Type %d, not currently supported.", @@ -136,8 +219,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace transpose_conv TfLiteRegistration* Register_TRANSPOSE_CONV() { - static TfLiteRegistration r = {nullptr, nullptr, transpose_conv::Prepare, - transpose_conv::Eval}; + static TfLiteRegistration r = {transpose_conv::Init, transpose_conv::Free, + transpose_conv::Prepare, transpose_conv::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc index 52be08934997f484337e4a3592bc7af832601695..c741df19dee09b140954d0c110800cbd849c2f11 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include #include +#include "absl/memory/memory.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" @@ -24,9 +25,49 @@ namespace { using ::testing::ElementsAreArray; +class ConstTransposeConvOpModel : public SingleOpModel { + // Just to be extra confusing, transpose_conv has an _input_ named + // "output_shape". This input sets the shape of the output tensor of the op. + // In this version of the test class, "output_shape" is a constant that must + // be specified in the constructor. + public: + ConstTransposeConvOpModel(TfLiteRegistration* registration, + std::initializer_list input_shape, + std::initializer_list filter_shape, + std::initializer_list output_shape_data, + Padding padding, int stride_w, int stride_h) { + output_shape_ = AddConstInput(TensorType_INT32, output_shape_data, + {static_cast(output_shape_data.size())}); + filter_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions, + CreateTransposeConvOptions(builder_, padding, stride_w, stride_h) + .Union()); + resolver_ = absl::make_unique( + BuiltinOperator_TRANSPOSE_CONV, registration); + BuildInterpreter({{4}, filter_shape, input_shape}); + } + + int output_shape() { return output_shape_; } + int filter() { return filter_; } + int input() { return input_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int output_shape_; + int filter_; + int input_; + int output_; +}; + class TransposeConvOpModel : public SingleOpModel { public: - TransposeConvOpModel(std::initializer_list input_shape, + TransposeConvOpModel(TfLiteRegistration* registration, + std::initializer_list input_shape, std::initializer_list filter_shape, Padding padding, int stride_w, int stride_h) { output_shape_ = AddInput(TensorType_INT32); @@ -37,6 +78,8 @@ class TransposeConvOpModel : public SingleOpModel { BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions, CreateTransposeConvOptions(builder_, padding, stride_w, stride_h) .Union()); + resolver_ = absl::make_unique( + BuiltinOperator_TRANSPOSE_CONV, registration); BuildInterpreter({{4}, filter_shape, input_shape}); } @@ -54,6 +97,15 @@ class TransposeConvOpModel : public SingleOpModel { int output_; }; +const auto kKernelMap = new std::map({}); + +class TransposeConvOpTest : public SingleOpTest { + protected: + const std::map& GetKernelMap() override { + return *kKernelMap; + } +}; + // Test case: // output = tf.nn.conv2d_backprop_input( // tf.constant([ 1, 4, 4, 1 ]), @@ -61,8 +113,9 @@ class TransposeConvOpModel : public SingleOpModel { // tf.constant(np.arange(1, 17), shape=[ 1, 4, 4, 1 ], dtype=tf.float32), // [1, 1, 1, 1 ], // "SAME") -TEST(TransposeConvOpModelTest, SimpleTest) { - TransposeConvOpModel m({1, 4, 4, 1}, {1, 3, 3, 1}, Padding_SAME, 1, 1); +TEST_P(TransposeConvOpTest, SimpleTest) { + TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 1}, {1, 3, 3, 1}, + Padding_SAME, 1, 1); m.PopulateTensor(m.output_shape(), {1, 4, 4, 1}); m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9}); m.PopulateTensor( @@ -75,6 +128,21 @@ TEST(TransposeConvOpModelTest, SimpleTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } +// Test case: Same as above, but with a const "output_shape" +TEST_P(TransposeConvOpTest, ConstSimpleTest) { + ConstTransposeConvOpModel m(GetRegistration(), {1, 4, 4, 1}, {1, 4, 4, 1}, + {1, 3, 3, 1}, Padding_SAME, 1, 1); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9}); + m.PopulateTensor( + m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({29, 62, 83, 75, 99, 192, 237, 198, 207, 372, + 417, 330, 263, 446, 485, 365})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + // Test case: // filter = tf.constant(np.arange(1, 19), // shape=[ 3, 3, 1, 2 ], @@ -87,11 +155,12 @@ TEST(TransposeConvOpModelTest, SimpleTest) { // "SAME") // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1]) -TEST(TransposeConvOpModelTest, TwoFiltersTest) { - TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_SAME, 1, 1); +TEST_P(TransposeConvOpTest, TwoFiltersTest) { + TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 2}, {1, 3, 3, 2}, + Padding_SAME, 1, 1); m.PopulateTensor(m.output_shape(), {1, 4, 4, 1}); - m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, - 8, 10, 12, 14, 16, 18}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18}); m.PopulateTensor( m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -116,11 +185,12 @@ TEST(TransposeConvOpModelTest, TwoFiltersTest) { // "VALID") // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18]) -TEST(TransposeConvOpModelTest, PaddingValidTest) { - TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_VALID, 1, 1); +TEST_P(TransposeConvOpTest, PaddingValidTest) { + TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 2}, {1, 3, 3, 2}, + Padding_VALID, 1, 1); m.PopulateTensor(m.output_shape(), {1, 6, 6, 1}); - m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, - 8, 10, 12, 14, 16, 18}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18}); m.PopulateTensor( m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -146,8 +216,9 @@ TEST(TransposeConvOpModelTest, PaddingValidTest) { // tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32), // [1, 2, 2, 1 ], // "VALID") -TEST(TransposeConvOpModelTest, StrideValidTest) { - TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 1}, Padding_VALID, 2, 2); +TEST_P(TransposeConvOpTest, StrideValidTest) { + TransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {1, 3, 3, 1}, + Padding_VALID, 2, 2); m.PopulateTensor(m.output_shape(), {1, 5, 5, 1}); m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9}); m.PopulateTensor(m.input(), {1, 2, 3, 4}); @@ -170,11 +241,30 @@ TEST(TransposeConvOpModelTest, StrideValidTest) { // tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32), // [1, 2, 2, 1 ], // "VALID") -TEST(TransposeConvOpModelTest, MultiChannelTest) { - TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 2}, Padding_VALID, 2, 2); +TEST_P(TransposeConvOpTest, MultiChannelTest) { + TransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {2, 3, 3, 1}, + Padding_VALID, 2, 2); m.PopulateTensor(m.output_shape(), {1, 5, 5, 2}); - m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18}); + m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, + 8, 10, 12, 14, 16, 18}); + m.PopulateTensor(m.input(), {1, 2, 3, 4}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 7, 10, 6, 8, 10, 12, 7, 8, 9, + 10, 25, 28, 18, 20, 22, 24, 16, 20, 24, 28, 62, 72, + 42, 48, 54, 60, 21, 24, 27, 30, 61, 68, 36, 40, 44, + 48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2})); +} + +// Test case: Same as above, but with a const "output_shape" +TEST_P(TransposeConvOpTest, ConstMultiChannelTest) { + ConstTransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {2, 3, 3, 1}, + {1, 5, 5, 2}, Padding_VALID, 2, 2); + m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, + 8, 10, 12, 14, 16, 18}); m.PopulateTensor(m.input(), {1, 2, 3, 4}); m.Invoke(); @@ -199,8 +289,9 @@ TEST(TransposeConvOpModelTest, MultiChannelTest) { // "SAME") // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[-1]) -TEST(TransposeConvOpModelTest, AccuracyTest) { - TransposeConvOpModel m({1, 1, 2, 1}, {1, 3, 3, 1}, Padding_SAME, 3, 3); +TEST_P(TransposeConvOpTest, AccuracyTest) { + TransposeConvOpModel m(GetRegistration(), {1, 1, 2, 1}, {1, 3, 3, 1}, + Padding_SAME, 3, 3); m.PopulateTensor(m.output_shape(), {1, 3, 4, 1}); m.PopulateTensor(m.filter(), {9, 5, 6, 9, 8, 5, 3, 1, 4}); m.PopulateTensor(m.input(), {323, 521}); @@ -212,6 +303,10 @@ TEST(TransposeConvOpModelTest, AccuracyTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 4, 1})); } +INSTANTIATE_TEST_CASE_P( + TransposeConvOpTest, TransposeConvOpTest, + ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index 1c28123a24edd9886476bf8e9ea3ba4c692baa2b..32daf2bb02d5f63391cc5ba45654acd4acfbfe56 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -70,9 +70,21 @@ constexpr int kOutputStateTensor = 0; constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; +// Temporary tensors +enum TemporaryTensor { + kScratchBuffer = 0, + kInputQuantized = 1, + kOutputStateQuantized = 2, + kCellStateQuantized = 3, + kScalingFactors = 4, + kProductScalingFactors = 5, + kRecoveredCellWeights = 6, + kNumTemporaryTensors = 7 +}; + void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); + context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); return scratch_tensor_index; } @@ -84,7 +96,7 @@ void Free(TfLiteContext* context, void* buffer) { TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -242,6 +254,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE(context, input->dims->size > 1); const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; @@ -288,86 +301,156 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, cell_state, cell_size)); - // Create a scratch buffer tensor. + // Mark state tensors as persistent tensors. + output_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + + // The weights are of consistent type, so it suffices to check one. + // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. + const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && + input->type == kTfLiteFloat32); + TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + // Create a scratch buffer tensor. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; - // Mark state tensors as persistent tensors. - output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); const bool use_cifg = (input_to_input_weights == nullptr); + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; if (use_cifg) { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 3; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); } else { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Input, Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 4; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // output_state and cell_state tensors. + node->temporaries->data[kInputQuantized] = + *scratch_tensor_index + kInputQuantized; + TfLiteTensor* input_quantized = + GetTemporary(context, node, kInputQuantized); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[kOutputStateQuantized] = + *scratch_tensor_index + kOutputStateQuantized; + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, kOutputStateQuantized); + output_state_quantized->type = kTfLiteUInt8; + output_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(output_state_quantized->dims, + output_state->dims)) { + TfLiteIntArray* output_state_quantized_size = + TfLiteIntArrayCopy(output_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output_state_quantized, + output_state_quantized_size)); + } + node->temporaries->data[kCellStateQuantized] = + *scratch_tensor_index + kCellStateQuantized; + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, kCellStateQuantized); + cell_state_quantized->type = kTfLiteUInt8; + cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { + TfLiteIntArray* cell_state_quantized_size = + TfLiteIntArrayCopy(cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state_quantized, + cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). + node->temporaries->data[kScalingFactors] = + *scratch_tensor_index + kScalingFactors; + TfLiteTensor* scaling_factors = + GetTemporary(context, node, kScalingFactors); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[kProductScalingFactors] = + *scratch_tensor_index + kProductScalingFactors; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, kProductScalingFactors); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[kRecoveredCellWeights] = + *scratch_tensor_index + kRecoveredCellWeights; + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, kRecoveredCellWeights); + recovered_cell_weights->type = kTfLiteFloat32; + recovered_cell_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); + recovered_cell_weights_size->data[0] = n_cell; + if (!TfLiteIntArrayEqual(recovered_cell_weights->dims, + recovered_cell_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_cell_weights, + recovered_cell_weights_size)); + } } return kTfLiteOk; } // The LSTM Op engine. -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - - const TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - const TfLiteTensor* input_to_forget_weights = - GetInput(context, node, kInputToForgetWeightsTensor); - const TfLiteTensor* input_to_cell_weights = - GetInput(context, node, kInputToCellWeightsTensor); - const TfLiteTensor* input_to_output_weights = - GetInput(context, node, kInputToOutputWeightsTensor); - - const TfLiteTensor* recurrent_to_input_weights = - GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - const TfLiteTensor* recurrent_to_forget_weights = - GetInput(context, node, kRecurrentToForgetWeightsTensor); - const TfLiteTensor* recurrent_to_cell_weights = - GetInput(context, node, kRecurrentToCellWeightsTensor); - const TfLiteTensor* recurrent_to_output_weights = - GetInput(context, node, kRecurrentToOutputWeightsTensor); - - const TfLiteTensor* cell_to_input_weights = - GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - const TfLiteTensor* cell_to_forget_weights = - GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - const TfLiteTensor* cell_to_output_weights = - GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - - const TfLiteTensor* input_gate_bias = - GetOptionalInputTensor(context, node, kInputGateBiasTensor); - const TfLiteTensor* forget_gate_bias = - GetInput(context, node, kForgetGateBiasTensor); - const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); - const TfLiteTensor* output_gate_bias = - GetInput(context, node, kOutputGateBiasTensor); - - const TfLiteTensor* projection_weights = - GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - const TfLiteTensor* projection_bias = - GetOptionalInputTensor(context, node, kProjectionBiasTensor); - - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; @@ -380,8 +463,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); - // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; float* forget_gate_scratch = nullptr; @@ -432,6 +513,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float* output_state_ptr = output_state->data.f; float* cell_state_ptr = cell_state->data.f; + // Feed the sequence into the LSTM step-by-step. for (int t = 0; t < max_time; t++) { const float* input_ptr_batch = input->data.f + t * n_batch * n_input; float* output_ptr_batch = output->data.f + t * n_batch * n_output; @@ -452,6 +534,262 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, + TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + reinterpret_cast(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast(projection_weights->data.uint8); + float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const int8_t* input_to_forget_weights_ptr = + reinterpret_cast(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + + // Temporary storage for quantized values and scaling factors. + int8_t* quantized_input_ptr = + reinterpret_cast(input_quantized->data.uint8); + int8_t* quantized_output_state_ptr = + reinterpret_cast(output_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + + // Feed the sequence into the LSTM step-by-step. + for (int t = 0; t < max_time; t++) { + const float* input_ptr_batch = input->data.f + t * n_batch * n_input; + float* output_ptr_batch = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, + input_to_input_weights_scale, input_to_forget_weights_ptr, + input_to_forget_weights_scale, input_to_cell_weights_ptr, + input_to_cell_weights_scale, input_to_output_weights_ptr, + input_to_output_weights_scale, recurrent_to_input_weights_ptr, + recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, + recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, + recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, + recurrent_to_output_weights_scale, cell_to_input_weights_ptr, + cell_to_input_weights_scale, cell_to_forget_weights_ptr, + cell_to_forget_weights_scale, cell_to_output_weights_ptr, + cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr, + cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, + projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell, + n_input, n_output, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, scaling_factors_ptr, + prod_scaling_factors_ptr, recovered_cell_weights_ptr, + quantized_input_ptr, quantized_output_state_ptr, + quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, + output_ptr_batch); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input_to_output_weights->type) { + case kTfLiteFloat32: { + return EvalFloat(input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, + scratch_buffer, output_state, cell_state, output); + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, /*index=*/2); + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + return EvalHybrid( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, scratch_buffer, + scaling_factors, prod_scaling_factors, recovered_cell_weights, + input_quantized, output_state_quantized, cell_state_quantized, + output_state, cell_state, output); + } + default: + context->ReportError(context, "Type %d is not currently supported.", + input_to_output_weights->type); + return kTfLiteError; + } + return kTfLiteOk; +} } // namespace unidirectional_sequence_lstm TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() { diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc index 5881ced7c7a616ef2c24db60892cbbf9eec7c42e..de38bdef6fd1b019c7790a664b29cd45d29e5dcc 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite Sequential LSTM op. -#include #include #include @@ -37,7 +36,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, - const std::vector>& input_shapes) + const std::vector>& input_shapes, + const TensorType& weights_type = TensorType_FLOAT32) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -48,31 +48,31 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { if (use_cifg) { input_to_input_weights_ = AddNullInput(); } else { - input_to_input_weights_ = AddInput(TensorType_FLOAT32); + input_to_input_weights_ = AddInput(weights_type); } - input_to_forget_weights_ = AddInput(TensorType_FLOAT32); - input_to_cell_weights_ = AddInput(TensorType_FLOAT32); - input_to_output_weights_ = AddInput(TensorType_FLOAT32); + input_to_forget_weights_ = AddInput(weights_type); + input_to_cell_weights_ = AddInput(weights_type); + input_to_output_weights_ = AddInput(weights_type); if (use_cifg) { recurrent_to_input_weights_ = AddNullInput(); } else { - recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_input_weights_ = AddInput(weights_type); } - recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_forget_weights_ = AddInput(weights_type); + recurrent_to_cell_weights_ = AddInput(weights_type); + recurrent_to_output_weights_ = AddInput(weights_type); if (use_peephole) { if (use_cifg) { cell_to_input_weights_ = AddNullInput(); } else { - cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + cell_to_input_weights_ = AddInput(weights_type); } - cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); - cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + cell_to_forget_weights_ = AddInput(weights_type); + cell_to_output_weights_ = AddInput(weights_type); } else { cell_to_input_weights_ = AddNullInput(); cell_to_forget_weights_ = AddNullInput(); @@ -89,7 +89,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { output_gate_bias_ = AddInput(TensorType_FLOAT32); if (use_projection_weights) { - projection_weights_ = AddInput(TensorType_FLOAT32); + projection_weights_ = AddInput(weights_type); if (use_projection_bias) { projection_bias_ = AddInput(TensorType_FLOAT32); } else { @@ -196,8 +196,9 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { zero_buffer.get() + zero_buffer_size); } - void SetInput(int offset, float* begin, float* end) { - PopulateTensor(input_, offset, begin, end); + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); } std::vector GetOutput() { return ExtractVector(output_); } @@ -208,7 +209,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { int num_batches() { return n_batch_; } int sequence_length() { return sequence_length_; } - private: + protected: int input_; int input_to_input_weights_; int input_to_forget_weights_; @@ -243,7 +244,183 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { int sequence_length_; }; -TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { +// The hybrid model has quantized weights. +class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel { + public: + HybridUnidirectionalLSTMOpModel( + int n_batch, int n_input, int n_cell, int n_output, int sequence_length, + bool use_cifg, bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : UnidirectionalLSTMOpModel( + n_batch, n_input, n_cell, n_output, sequence_length, use_cifg, + use_peephole, use_projection_weights, use_projection_bias, + cell_clip, proj_clip, input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + // Feed the whole sequence as input. + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(), + batch_start, batch_end); + } + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + EXPECT_GT(num_outputs, 0); + std::vector expected; + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + } + + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -252,9 +429,11 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { const int sequence_length = 3; UnidirectionalLSTMOpModel lstm( - n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, - /*use_peephole=*/false, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor @@ -281,77 +460,138 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, - -0.34550029, 0.04266912, -0.15680569, - -0.34856534, 0.43890524}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, - -0.20583314, 0.44344562, 0.22077113, - -0.29909778}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, - -0.31343272, -0.40032279, 0.44781327, - 0.01387155, -0.35593212}); + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, - 0.40525138, 0.44272184, 0.03897077, -0.1556896, - 0.19487578}); +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; - lstm.SetInputGateBias({0., 0., 0., 0.}); + HybridUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor - lstm.SetCellBias({0., 0., 0., 0.}); + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor - lstm.SetForgetGateBias({1., 1., 1., 1.}); + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor - lstm.SetOutputGateBias({0., 0., 0., 0.}); + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor - lstm.SetRecurrentToInputWeights( - {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, - -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, - -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor - lstm.SetRecurrentToCellWeights( - {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, - -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, - -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); - lstm.SetRecurrentToForgetWeights( - {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, - -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, - 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetRecurrentToOutputWeights( - {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, - 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, - -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - // Input should have n_input * sequence_length many values. - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, - -0.15358765, -0.03716109, 0.12507336, - 0.41193449, -0.20860538, -0.15053082, - 0.09120187, 0.24278517, -0.12222792}; + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - float* batch0_start = lstm_input; - float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, + /*tolerance=*/0.0157651); +} - lstm.SetInput(0, batch0_start, batch0_end); +class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; - lstm.Invoke(); + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; - float* golden_start = lstm_golden_output; - float* golden_end = - golden_start + lstm.num_outputs() * lstm.sequence_length(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); -} + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; + } +}; -TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { +TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -360,9 +600,11 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { const int sequence_length = 3; UnidirectionalLSTMOpModel lstm( - n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, - /*use_peephole=*/true, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor @@ -389,71 +631,690 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, - 0.04717243, 0.48944736, -0.38535351, - -0.17212132}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, - -0.3633365, -0.22755712, 0.28253698, 0.24407166, - 0.33826375}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, - -0.09426838, -0.44257352, 0.54939759, - 0.01533556, 0.42751634}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + HybridUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor - lstm.SetCellBias({0., 0., 0., 0.}); + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor - lstm.SetForgetGateBias({1., 1., 1., 1.}); + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor - lstm.SetOutputGateBias({0., 0., 0., 0.}); + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor - lstm.SetRecurrentToCellWeights( - {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, - 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, - 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, - 0.21193194}); + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); - lstm.SetRecurrentToForgetWeights( - {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, - 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, - -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetRecurrentToOutputWeights( - {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, - -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, - 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetCellToForgetWeights( - {0.47485286, -0.51955009, -0.24458408, 0.31544167}); - lstm.SetCellToOutputWeights( - {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, - -0.05163646, -0.42312205, -0.01218222, - 0.24201041, -0.08124574, -0.358325, - -0.04621704, 0.21641694, -0.06471302}; + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - float* batch0_start = lstm_input; - float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); - - lstm.SetInput(0, batch0_start, batch0_end); - - lstm.Invoke(); - - float* golden_start = lstm_golden_output; - float* golden_end = - golden_start + lstm.num_outputs() * lstm.sequence_length(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } -TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, -0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + } +}; + +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -461,8 +1322,9 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { const int sequence_length = 4; UnidirectionalLSTMOpModel lstm( - n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, - /*use_peephole=*/true, /*use_projection_weights=*/true, + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, { @@ -491,588 +1353,99 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights( - {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, - 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, - -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, - -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, - -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, - -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, - -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, - 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, - 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, - 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, - -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, - 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, - -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, - -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, - -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, - 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, - -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, - -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, - -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, - -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); - - lstm.SetInputToForgetWeights( - {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, - -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, - -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, - 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, - 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, - -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, - -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, - 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, - 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, - 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, - 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, - -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, - 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, - -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, - -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, - 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, - 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, - 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, - -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, - 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); - - lstm.SetInputToCellWeights( - {-0.04580283, -0.09549462, -0.032418985, -0.06454633, - -0.043528453, 0.043018587, -0.049152344, -0.12418144, - -0.078985475, -0.07596889, 0.019484362, -0.11434962, - -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, - -0.025034338, -0.0028890965, 0.048929527, 0.06235075, - 0.10665918, -0.032036792, -0.08505916, -0.10843358, - -0.13002433, -0.036816437, -0.02130134, -0.016518239, - 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, - -0.10652836, -0.1037554, -0.13056071, -0.03266643, - -0.033702414, -0.006473424, -0.04611692, 0.014419339, - -0.025174323, 0.0396852, 0.081777506, 0.06157468, - 0.10210095, -0.009658194, 0.046511717, 0.03603906, - 0.0069369148, 0.015960095, -0.06507666, 0.09551598, - 0.053568836, 0.06408714, 0.12835667, -0.008714329, - -0.20211966, -0.12093674, 0.029450472, 0.2849013, - -0.029227901, 0.1164364, -0.08560263, 0.09941786, - -0.036999565, -0.028842626, -0.0033637602, -0.017012902, - -0.09720865, -0.11193351, -0.029155117, -0.017936034, - -0.009768936, -0.04223324, -0.036159635, 0.06505112, - -0.021742892, -0.023377212, -0.07221364, -0.06430552, - 0.05453865, 0.091149814, 0.06387331, 0.007518393, - 0.055960953, 0.069779344, 0.046411168, 0.10509911, - 0.07463894, 0.0075130584, 0.012850982, 0.04555431, - 0.056955688, 0.06555285, 0.050801456, -0.009862683, - 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); - - lstm.SetInputToOutputWeights( - {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, - -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, - 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, - -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, - -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, - 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, - -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, - -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, - -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, - -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, - 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, - 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, - 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, - -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, - 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, - 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, - -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, - 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, - -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, - -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); - - lstm.SetInputGateBias( - {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, - -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, - -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, - 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); - - lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, - 0.11098921, 0.15378423, 0.09263801, 0.09790885, - 0.09508917, 0.061199076, 0.07665568, -0.015443159, - -0.03499149, 0.046190713, 0.08895977, 0.10899629, - 0.40694186, 0.06030037, 0.012413437, -0.06108739}); - - lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, - -0.1483596, -0.10639995, -0.091433935, 0.058573797, - -0.06809782, -0.07889636, -0.043246906, -0.09829136, - -0.4279842, 0.034901652, 0.18797937, 0.0075234566, - 0.016178843, 0.1749513, 0.13975595, 0.92058027}); - - lstm.SetOutputGateBias( - {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, - 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, - 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, - -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); - - lstm.SetRecurrentToInputWeights( - {-0.001374326, -0.078856036, 0.10672688, 0.029162422, - -0.11585556, 0.02557986, -0.13446963, -0.035785314, - -0.01244275, 0.025961924, -0.02337298, -0.044228926, - -0.055839065, -0.046598054, -0.010546039, -0.06900766, - 0.027239809, 0.022582639, -0.013296484, -0.05459212, - 0.08981, -0.045407712, 0.08682226, -0.06867011, - -0.14390695, -0.02916037, 0.000996957, 0.091420636, - 0.14283475, -0.07390571, -0.06402044, 0.062524505, - -0.093129106, 0.04860203, -0.08364217, -0.08119002, - 0.009352075, 0.22920375, 0.0016303885, 0.11583097, - -0.13732095, 0.012405723, -0.07551853, 0.06343048, - 0.12162708, -0.031923793, -0.014335606, 0.01790974, - -0.10650317, -0.0724401, 0.08554849, -0.05727212, - 0.06556731, -0.042729504, -0.043227166, 0.011683251, - -0.013082158, -0.029302018, -0.010899579, -0.062036745, - -0.022509435, -0.00964907, -0.01567329, 0.04260106, - -0.07787477, -0.11576462, 0.017356863, 0.048673786, - -0.017577527, -0.05527947, -0.082487635, -0.040137455, - -0.10820036, -0.04666372, 0.022746278, -0.07851417, - 0.01068115, 0.032956902, 0.022433773, 0.0026891115, - 0.08944216, -0.0685835, 0.010513544, 0.07228705, - 0.02032331, -0.059686817, -0.0005566496, -0.086984694, - 0.040414046, -0.1380399, 0.094208956, -0.05722982, - 0.012092817, -0.04989123, -0.086576, -0.003399834, - -0.04696032, -0.045747425, 0.10091314, 0.048676282, - -0.029037097, 0.031399418, -0.0040285117, 0.047237843, - 0.09504992, 0.041799378, -0.049185462, -0.031518843, - -0.10516937, 0.026374253, 0.10058866, -0.0033195973, - -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, - -0.10167381, 0.042500053, -0.01447153, 0.06464186, - -0.017142897, 0.03312627, 0.009205989, 0.024138335, - -0.011337001, 0.035530265, -0.010912711, 0.0706555, - -0.005894094, 0.051841937, -0.1401738, -0.02351249, - 0.0365468, 0.07590991, 0.08838724, 0.021681072, - -0.10086113, 0.019608743, -0.06195883, 0.077335775, - 0.023646897, -0.095322326, 0.02233014, 0.09756986, - -0.048691444, -0.009579111, 0.07595467, 0.11480546, - -0.09801813, 0.019894179, 0.08502348, 0.004032281, - 0.037211012, 0.068537936, -0.048005626, -0.091520436, - -0.028379958, -0.01556313, 0.06554592, -0.045599163, - -0.01672207, -0.020169014, -0.011877351, -0.20212261, - 0.010889619, 0.0047078193, 0.038385306, 0.08540671, - -0.017140968, -0.0035865551, 0.016678626, 0.005633034, - 0.015963363, 0.00871737, 0.060130805, 0.028611384, - 0.10109069, -0.015060172, -0.07894427, 0.06401885, - 0.011584063, -0.024466386, 0.0047652307, -0.09041358, - 0.030737216, -0.0046374933, 0.14215417, -0.11823516, - 0.019899689, 0.006106124, -0.027092824, 0.0786356, - 0.05052217, -0.058925, -0.011402121, -0.024987547, - -0.0013661642, -0.06832946, -0.015667673, -0.1083353, - -0.00096863037, -0.06988685, -0.053350925, -0.027275559, - -0.033664223, -0.07978348, -0.025200296, -0.017207067, - -0.058403496, -0.055697463, 0.005798788, 0.12965427, - -0.062582195, 0.0013350133, -0.10482091, 0.0379771, - 0.072521195, -0.0029455067, -0.13797039, -0.03628521, - 0.013806405, -0.017858358, -0.01008298, -0.07700066, - -0.017081132, 0.019358726, 0.0027079724, 0.004635139, - 0.062634714, -0.02338735, -0.039547626, -0.02050681, - 0.03385117, -0.083611414, 0.002862572, -0.09421313, - 0.058618143, -0.08598433, 0.00972939, 0.023867095, - -0.053934585, -0.023203006, 0.07452513, -0.048767887, - -0.07314807, -0.056307215, -0.10433547, -0.06440842, - 0.04328182, 0.04389765, -0.020006588, -0.09076438, - -0.11652589, -0.021705797, 0.03345259, -0.010329105, - -0.025767034, 0.013057034, -0.07316461, -0.10145612, - 0.06358255, 0.18531723, 0.07759293, 0.12006465, - 0.1305557, 0.058638252, -0.03393652, 0.09622831, - -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, - -0.005644518, 0.06857898, -0.12598175, -0.035084512, - 0.03156317, -0.12794146, -0.031963028, 0.04692781, - 0.030070418, 0.0071660685, -0.095516115, -0.004643372, - 0.040170413, -0.062104587, -0.0037324072, 0.0554317, - 0.08184801, -0.019164372, 0.06791302, 0.034257166, - -0.10307039, 0.021943003, 0.046745934, 0.0790918, - -0.0265588, -0.007824208, 0.042546265, -0.00977924, - -0.0002440307, -0.017384544, -0.017990116, 0.12252321, - -0.014512694, -0.08251313, 0.08861942, 0.13589665, - 0.026351685, 0.012641483, 0.07466548, 0.044301085, - -0.045414884, -0.051112458, 0.03444247, -0.08502782, - -0.04106223, -0.028126027, 0.028473156, 0.10467447}); - - lstm.SetRecurrentToForgetWeights( - {-0.057784554, -0.026057621, -0.068447545, -0.022581743, - 0.14811787, 0.10826372, 0.09471067, 0.03987225, - -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, - 0.08414449, -0.022036452, -0.00066928595, -0.09203576, - 0.032950465, -0.10985798, -0.023809856, 0.0021431844, - -0.02196096, -0.00326074, 0.00058621005, -0.074678116, - -0.06193199, 0.055729095, 0.03736828, 0.020123724, - 0.061878487, -0.04729229, 0.034919553, -0.07585433, - -0.04421272, -0.044019096, 0.085488975, 0.04058006, - -0.06890133, -0.030951202, -0.024628663, -0.07672815, - 0.034293607, 0.08556707, -0.05293577, -0.033561368, - -0.04899627, 0.0241671, 0.015736353, -0.095442444, - -0.029564252, 0.016493602, -0.035026584, 0.022337519, - -0.026871363, 0.004780428, 0.0077918363, -0.03601621, - 0.016435321, -0.03263031, -0.09543275, -0.047392778, - 0.013454138, 0.028934088, 0.01685226, -0.086110644, - -0.046250615, -0.01847454, 0.047608484, 0.07339695, - 0.034546845, -0.04881143, 0.009128804, -0.08802852, - 0.03761666, 0.008096139, -0.014454086, 0.014361001, - -0.023502491, -0.0011840804, -0.07607001, 0.001856849, - -0.06509276, -0.006021153, -0.08570962, -0.1451793, - 0.060212336, 0.055259194, 0.06974018, 0.049454916, - -0.027794661, -0.08077226, -0.016179763, 0.1169753, - 0.17213494, -0.0056326236, -0.053934924, -0.0124349, - -0.11520337, 0.05409887, 0.088759385, 0.0019655675, - 0.0042065294, 0.03881498, 0.019844765, 0.041858196, - -0.05695512, 0.047233116, 0.038937137, -0.06542224, - 0.014429736, -0.09719407, 0.13908425, -0.05379757, - 0.012321099, 0.082840554, -0.029899208, 0.044217527, - 0.059855383, 0.07711018, -0.045319796, 0.0948846, - -0.011724666, -0.0033288454, -0.033542685, -0.04764985, - -0.13873616, 0.040668588, 0.034832682, -0.015319203, - -0.018715994, 0.046002675, 0.0599172, -0.043107376, - 0.0294216, -0.002314414, -0.022424703, 0.0030315618, - 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, - 0.12375372, -0.0006038222, 0.029104086, 0.087442465, - 0.052958444, 0.07558703, 0.04817258, 0.044462286, - -0.015213451, -0.08783778, -0.0561384, -0.003008196, - 0.047060397, -0.002058388, 0.03429439, -0.018839769, - 0.024734668, 0.024614193, -0.042046934, 0.09597743, - -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, - -0.02558259, -0.022822596, -0.023273505, -0.02464396, - -0.10991725, -0.006240552, 0.0074488563, 0.024044557, - 0.04383914, -0.046476185, 0.028658995, 0.060410924, - 0.050786525, 0.009452605, -0.0073054377, -0.024810238, - 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, - 0.015898481, 0.021362653, -0.030262267, 0.016587038, - -0.011442813, 0.041154444, -0.007631438, -0.03423484, - -0.010977775, 0.036152758, 0.0066366293, 0.11915515, - 0.02318443, -0.041350313, 0.021485701, -0.10906167, - -0.028218046, -0.00954771, 0.020531068, -0.11995105, - -0.03672871, 0.024019798, 0.014255957, -0.05221243, - -0.00661567, -0.04630967, 0.033188973, 0.10107534, - -0.014027541, 0.030796422, -0.10270911, -0.035999842, - 0.15443139, 0.07684145, 0.036571592, -0.035900835, - -0.0034699554, 0.06209149, 0.015920248, -0.031122351, - -0.03858649, 0.01849943, 0.13872518, 0.01503974, - 0.069941424, -0.06948533, -0.0088794185, 0.061282158, - -0.047401894, 0.03100163, -0.041533746, -0.10430945, - 0.044574402, -0.01425562, -0.024290353, 0.034563623, - 0.05866852, 0.023947537, -0.09445152, 0.035450947, - 0.02247216, -0.0042998926, 0.061146557, -0.10250651, - 0.020881841, -0.06747029, 0.10062043, -0.0023941975, - 0.03532124, -0.016341697, 0.09685456, -0.016764693, - 0.051808182, 0.05875331, -0.04536488, 0.001626336, - -0.028892258, -0.01048663, -0.009793449, -0.017093895, - 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, - -0.001845119, -0.03551521, 0.0018358806, 0.05763657, - -0.01769146, 0.040995963, 0.02235177, -0.060430344, - 0.11475477, -0.023854522, 0.10071741, 0.0686208, - -0.014250481, 0.034261297, 0.047418304, 0.08562733, - -0.030519066, 0.0060542435, 0.014653856, -0.038836084, - 0.04096551, 0.032249358, -0.08355519, -0.026823482, - 0.056386515, -0.010401743, -0.028396193, 0.08507674, - 0.014410365, 0.020995233, 0.17040324, 0.11511526, - 0.02459721, 0.0066619175, 0.025853224, -0.023133837, - -0.081302024, 0.017264642, -0.009585969, 0.09491168, - -0.051313367, 0.054532815, -0.014298593, 0.10657464, - 0.007076659, 0.10964551, 0.0409152, 0.008275321, - -0.07283536, 0.07937492, 0.04192024, -0.1075027}); - - lstm.SetRecurrentToCellWeights( - {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, - 0.055647098, -0.05713207, -0.05626563, 0.005559383, - 0.03375411, -0.025757805, -0.088049285, 0.06017052, - -0.06570978, 0.007384076, 0.035123326, -0.07920549, - 0.053676967, 0.044480428, -0.07663568, 0.0071805613, - 0.08089997, 0.05143358, 0.038261272, 0.03339287, - -0.027673481, 0.044746667, 0.028349208, 0.020090483, - -0.019443132, -0.030755889, -0.0040000007, 0.04465846, - -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, - -0.10893326, 0.076739706, -0.08509834, -0.027997585, - 0.037871376, 0.01449768, -0.09002357, -0.06111149, - -0.046195522, 0.0422062, -0.005683705, -0.1253618, - -0.012925729, -0.04890792, 0.06985068, 0.037654128, - 0.03398274, -0.004781977, 0.007032333, -0.031787455, - 0.010868644, -0.031489216, 0.09525667, 0.013939797, - 0.0058680447, 0.0167067, 0.02668468, -0.04797466, - -0.048885044, -0.12722108, 0.035304096, 0.06554885, - 0.00972396, -0.039238118, -0.05159735, -0.11329045, - 0.1613692, -0.03750952, 0.06529313, -0.071974665, - -0.11769596, 0.015524369, -0.0013754242, -0.12446318, - 0.02786344, -0.014179351, 0.005264273, 0.14376344, - 0.015983658, 0.03406988, -0.06939408, 0.040699873, - 0.02111075, 0.09669095, 0.041345075, -0.08316494, - -0.07684199, -0.045768797, 0.032298047, -0.041805092, - 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, - -0.024950314, 0.11574242, 0.04508852, -0.04335324, - 0.06760663, -0.027437469, 0.07216407, 0.06977076, - -0.05438599, 0.034033038, -0.028602652, 0.05346137, - 0.043184172, -0.037189785, 0.10420091, 0.00882477, - -0.054019816, -0.074273005, -0.030617684, -0.0028467078, - 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, - 0.04361412, -0.007001822, 0.09631092, -0.06702025, - -0.042049985, -0.035070654, -0.04103342, -0.10273396, - 0.0544271, 0.037184782, -0.13150354, -0.0058036847, - -0.008264958, 0.042035464, 0.05891794, 0.029673764, - 0.0063542654, 0.044788733, 0.054816857, 0.062257513, - -0.00093483756, 0.048938446, -0.004952862, -0.007730018, - -0.04043371, -0.017094059, 0.07229206, -0.023670016, - -0.052195564, -0.025616996, -0.01520939, 0.045104615, - -0.007376126, 0.003533447, 0.006570588, 0.056037236, - 0.12436656, 0.051817212, 0.028532185, -0.08686856, - 0.11868599, 0.07663395, -0.07323171, 0.03463402, - -0.050708205, -0.04458982, -0.11590894, 0.021273347, - 0.1251325, -0.15313013, -0.12224372, 0.17228661, - 0.023029093, 0.086124025, 0.006445803, -0.03496501, - 0.028332196, 0.04449512, -0.042436164, -0.026587414, - -0.006041347, -0.09292539, -0.05678812, 0.03897832, - 0.09465633, 0.008115513, -0.02171956, 0.08304309, - 0.071401566, 0.019622514, 0.032163795, -0.004167056, - 0.02295182, 0.030739572, 0.056506045, 0.004612461, - 0.06524936, 0.059999723, 0.046395954, -0.0045512207, - -0.1335546, -0.030136576, 0.11584653, -0.014678886, - 0.0020118146, -0.09688814, -0.0790206, 0.039770417, - -0.0329582, 0.07922767, 0.029322514, 0.026405897, - 0.04207835, -0.07073373, 0.063781224, 0.0859677, - -0.10925287, -0.07011058, 0.048005477, 0.03438226, - -0.09606514, -0.006669445, -0.043381985, 0.04240257, - -0.06955775, -0.06769346, 0.043903265, -0.026784198, - -0.017840602, 0.024307009, -0.040079936, -0.019946516, - 0.045318738, -0.12233574, 0.026170589, 0.0074471775, - 0.15978073, 0.10185836, 0.10298046, -0.015476589, - -0.039390966, -0.072174534, 0.0739445, -0.1211869, - -0.0347889, -0.07943156, 0.014809798, -0.12412325, - -0.0030663363, 0.039695457, 0.0647603, -0.08291318, - -0.018529687, -0.004423833, 0.0037507233, 0.084633216, - -0.01514876, -0.056505352, -0.012800942, -0.06994386, - 0.012962922, -0.031234352, 0.07029052, 0.016418684, - 0.03618972, 0.055686004, -0.08663945, -0.017404709, - -0.054761406, 0.029065743, 0.052404847, 0.020238016, - 0.0048197987, -0.0214882, 0.07078733, 0.013016777, - 0.06262858, 0.009184685, 0.020785125, -0.043904778, - -0.0270329, -0.03299152, -0.060088247, -0.015162964, - -0.001828936, 0.12642565, -0.056757294, 0.013586685, - 0.09232601, -0.035886683, 0.06000002, 0.05229691, - -0.052580316, -0.082029596, -0.010794592, 0.012947712, - -0.036429964, -0.085508935, -0.13127148, -0.017744139, - 0.031502828, 0.036232427, -0.031581745, 0.023051167, - -0.05325106, -0.03421577, 0.028793324, -0.034633752, - -0.009881397, -0.043551125, -0.018609839, 0.0019097115, - -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); - - lstm.SetRecurrentToOutputWeights({ - 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, - -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, - -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, - -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, - -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, - -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, - -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, - 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, - -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, - 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, - -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, - -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, - 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, - 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, - -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, - 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, - 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, - 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, - 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, - 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, - -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, - 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, - -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, - 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, - 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, - 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, - -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, - -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, - -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, - -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, - -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, - -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, - 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, - 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, - -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, - 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, - -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, - -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, - -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, - 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, - 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, - 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, - -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, - 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, - -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, - -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, - -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, - -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, - 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, - -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, - 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, - -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, - -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, - -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, - -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, - 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, - 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, - -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, - 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, - 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, - -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, - 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, - 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, - 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, - }); - - lstm.SetCellToInputWeights( - {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, - -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, - -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, - 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); - - lstm.SetCellToForgetWeights( - {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, - -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, - -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, - 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); - - lstm.SetCellToOutputWeights( - {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, - -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, - -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, - 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); - - lstm.SetProjectionWeights( - {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, - 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, - -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, - -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, - 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, - 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, - 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, - 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, - -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, - -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, - -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, - 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, - 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, - 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, - 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, - 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, - -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, - 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, - -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, - 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, - -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, - -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, - 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, - -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, - 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, - -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, - -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, - 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, - -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, - -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, - -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, - 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, - 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, - -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, - 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, - 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, - 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, - 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, - 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, - -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, - -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, - 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, - -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, - -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, - 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, - 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, - 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, - -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, - -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, - -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, - 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, - -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, - 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, - 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, - -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, - -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, - -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, - 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, - -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, - -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, - -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, - 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, - 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, - 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); - - static float lstm_input[][20] = { - {// Batch0: 4 (input_sequence_size) * 5 (n_input) - 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, - 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, - 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, - - {// Batch1: 4 (input_sequence_size) * 5 (n_input) - 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, - 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, - 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; - - static float lstm_golden_output[][64] = { - {// Batch0: 4 (input_sequence_size) * 16 (n_output) - -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, - -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, - -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, - 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, - -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, - -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, - 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, - 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, - 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, - 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, - -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, - -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, - 0.0286833, 0.00824207, 0.0264887, 0.0305169}, - {// Batch1: 4 (input_sequence_size) * 16 (n_output) - -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, - -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, - 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, - 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, - -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, - -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, - 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, - 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, - 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, - 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, - -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, - -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, - 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - for (int i = 0; i < lstm.sequence_length(); i++) { - float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end); +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + const int sequence_length = 4; - float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); - float* batch1_end = batch1_start + lstm.num_inputs(); - lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end); - } + HybridUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor - lstm.Invoke(); + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor - std::vector expected; - for (int i = 0; i < lstm.sequence_length(); i++) { - float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); - float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); - float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); - float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); - expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); - expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); - } - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index 22c80df19c5dd7ebbaf8065a93e9e24527267302..164a0cbd08d6ce82a413f12ba6b1703087a30aba 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -41,7 +41,7 @@ constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index); + context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); return scratch_tensor_index; } @@ -102,7 +102,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) { int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries = TfLiteIntArrayCreate(3); node->temporaries->data[0] = *scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = kTfLiteUInt8; @@ -125,6 +125,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } + node->temporaries->data[2] = *scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } } return kTfLiteOk; } @@ -187,14 +197,12 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input, return kTfLiteOk; } -TfLiteStatus EvalQuantized(const TfLiteTensor* input, - const TfLiteTensor* input_weights, - const TfLiteTensor* recurrent_weights, - const TfLiteTensor* bias, - const TfLiteSequenceRNNParams* params, - TfLiteTensor* input_scratch, - TfLiteTensor* hidden_state_scratch, - TfLiteTensor* hidden_state, TfLiteTensor* output) { +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias, + const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch, + TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, + TfLiteTensor* hidden_state, TfLiteTensor* output) { const bool time_major = params->time_major; const int batch_size = (time_major) ? input->dims->data[1] : input->dims->data[0]; @@ -218,6 +226,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, reinterpret_cast(input_scratch->data.uint8); int8_t* quantized_hidden_state_ptr = reinterpret_cast(hidden_state_scratch->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; if (time_major) { // Initialize the pointer to hidden state. @@ -233,7 +242,8 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, params->activation, quantized_input_ptr, - quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch); + quantized_hidden_state_ptr, scaling_factors_ptr, + hidden_state_ptr_batch, output_ptr_batch); } } else { // For each batch @@ -252,7 +262,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, /*batch_size=*/1, params->activation, quantized_input_ptr, quantized_hidden_state_ptr, - hidden_state_ptr_batch, output_ptr_batch); + scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch); } } } @@ -278,12 +288,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(mirkov): implement eval with quantized inputs as well. TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); - return EvalQuantized(input, input_weights, recurrent_weights, bias, - params, input_quantized, hidden_state_quantized, - hidden_state, output); + TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + return EvalHybrid(input, input_weights, recurrent_weights, bias, params, + input_quantized, hidden_state_quantized, + scaling_factors, hidden_state, output); } default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input_weights->type); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 80fcb28bc7f6c09c7b979fcefcbc25deef583a00..c448fb71db204494042192d6a75ac4d600467e47 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -45,6 +45,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_FLOAT32: *type = kTfLiteFloat32; break; + case TensorType_INT16: + *type = kTfLiteInt16; + break; case TensorType_INT32: *type = kTfLiteInt32; break; @@ -60,6 +63,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_BOOL: *type = kTfLiteBool; break; + case TensorType_COMPLEX64: + *type = kTfLiteComplex64; + break; default: error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", EnumNameTensorType(tensor_type), tensor_type); @@ -322,12 +328,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = nullptr; switch (op_type) { - case BuiltinOperator_CALL: - // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are - // ok for now, since there is no call implementation either. - break; - case BuiltinOperator_CUSTOM: - break; case BuiltinOperator_CONV_2D: { TfLiteConvParams* params = MallocPOD(); if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { @@ -343,21 +343,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_TANH: - case BuiltinOperator_LOGISTIC: - case BuiltinOperator_RELU: - case BuiltinOperator_RELU_N1_TO_1: - case BuiltinOperator_RELU6: - case BuiltinOperator_CONCAT_EMBEDDINGS: - case BuiltinOperator_EXP: - case BuiltinOperator_TOPK_V2: - case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_DEQUANTIZE: - case BuiltinOperator_PRELU: - case BuiltinOperator_FLOOR: - case BuiltinOperator_NEG: - case BuiltinOperator_SIN: - break; case BuiltinOperator_CAST: { TfLiteCastParams* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_CastOptions()) { @@ -445,9 +430,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_EMBEDDING_LOOKUP: - // no-op. - break; case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { TfLiteEmbeddingLookupSparseParams* params = MallocPOD(); @@ -465,6 +447,18 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, op->builtin_options_as_FullyConnectedOptions()) { params->activation = parse_activation( fully_connected_params->fused_activation_function()); + switch (fully_connected_params->weights_format()) { + case FullyConnectedOptionsWeightsFormat_DEFAULT: + params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; + break; + case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + params->weights_format = + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; + break; + default: + error_reporter->Report("Unhandled fully-connected weights format."); + return kTfLiteError; + } } *builtin_data = reinterpret_cast(params); break; @@ -558,6 +552,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, parse_activation(lstm_params->fused_activation_function()); params->cell_clip = lstm_params->cell_clip(); params->proj_clip = lstm_params->proj_clip(); + switch (lstm_params->kernel_type()) { + case LSTMKernelType_FULL: + params->kernel_type = kTfLiteLSTMFullKernel; + break; + case LSTMKernelType_BASIC: + params->kernel_type = kTfLiteLSTMBasicKernel; + break; + } } *builtin_data = reinterpret_cast(params); break; @@ -571,12 +573,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_PAD: { - break; - } - case BuiltinOperator_PADV2: { - break; - } case BuiltinOperator_RESHAPE: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { @@ -616,18 +612,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_SPACE_TO_BATCH_ND: { - break; - } - case BuiltinOperator_BATCH_TO_SPACE_ND: { - break; - } - case BuiltinOperator_TRANSPOSE: { - break; - } - case BuiltinOperator_MEAN: { - auto* params = MallocPOD(); - if (auto* schema_params = op->builtin_options_as_MeanOptions()) { + case BuiltinOperator_MEAN: + case BuiltinOperator_SUM: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { params->keep_dims = schema_params->keep_dims(); } *builtin_data = reinterpret_cast(params); @@ -664,10 +652,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_MAXIMUM: - case BuiltinOperator_MINIMUM: { - break; - } case BuiltinOperator_ARG_MAX: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { @@ -677,16 +661,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_GREATER: - case BuiltinOperator_GREATER_EQUAL: - case BuiltinOperator_LESS: - case BuiltinOperator_LESS_EQUAL: - case BuiltinOperator_SELECT: { - break; - } - case BuiltinOperator_SLICE: { - break; - } case BuiltinOperator_TRANSPOSE_CONV: { TfLiteTransposeConvParams* params = MallocPOD(); @@ -699,11 +673,73 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_SPARSE_TO_DENSE: { + TfLiteSparseToDenseParams* params = + MallocPOD(); + if (auto* sparse_to_dense_params = + op->builtin_options_as_SparseToDenseOptions()) { + params->validate_indices = sparse_to_dense_params->validate_indices(); + } + *builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SHAPE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ShapeOptions()) { + ConvertTensorType(schema_params->out_type(), ¶ms->out_type, + error_reporter); + } + *builtin_data = static_cast(params); + break; + } case BuiltinOperator_DELEGATE: { // TODO(ycling): Revisit when supporting saving delegated models. error_reporter->Report("DELEGATE op shouldn't exist in model."); return kTfLiteError; } + + // Below are the ops with no builtin_data strcture. + case BuiltinOperator_BATCH_TO_SPACE_ND: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. + case BuiltinOperator_CALL: + case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_CUSTOM: + case BuiltinOperator_DEQUANTIZE: + case BuiltinOperator_EMBEDDING_LOOKUP: + case BuiltinOperator_EQUAL: + case BuiltinOperator_EXP: + case BuiltinOperator_EXPAND_DIMS: + case BuiltinOperator_FLOOR: + case BuiltinOperator_GREATER: + case BuiltinOperator_GREATER_EQUAL: + case BuiltinOperator_LESS: + case BuiltinOperator_LESS_EQUAL: + case BuiltinOperator_LOG: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_MAXIMUM: + case BuiltinOperator_MINIMUM: + case BuiltinOperator_NEG: + case BuiltinOperator_NOT_EQUAL: + case BuiltinOperator_PAD: + case BuiltinOperator_PADV2: + case BuiltinOperator_PRELU: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU6: + case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_RSQRT: + case BuiltinOperator_SELECT: + case BuiltinOperator_SIN: + case BuiltinOperator_SLICE: + case BuiltinOperator_SPACE_TO_BATCH_ND: + case BuiltinOperator_SQRT: + case BuiltinOperator_TANH: + case BuiltinOperator_TILE: + case BuiltinOperator_TOPK_V2: + case BuiltinOperator_TRANSPOSE: + case BuiltinOperator_POW: + break; } return kTfLiteOk; } @@ -725,7 +761,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes( } const TfLiteRegistration* registration = - flatbuffer_op_index_to_registration_[op->opcode_index()]; + flatbuffer_op_index_to_registration_[index]; if (registration == nullptr) { error_reporter_->Report("Skipping op for opcode_index %d\n", index); status = kTfLiteError; @@ -844,7 +880,16 @@ TfLiteStatus InterpreterBuilder::ParseTensors( const char* buffer_ptr; TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); + bool is_variable = tensor->is_variable(); if (buffer_ptr) { + if (is_variable) { + error_reporter_->Report( + "Tensor %d is a variable tensor with buffer. " + "It's not supported now.\n", + i); + status = kTfLiteError; + } + if (interpreter->SetTensorParametersReadOnly( i, type, get_name(tensor), dims, quantization, buffer_ptr, buffer_size, allocation_) != kTfLiteOk) { @@ -853,8 +898,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; } } else { - if (interpreter->SetTensorParametersReadWrite( - i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { + if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor), + dims, quantization, + is_variable) != kTfLiteOk) { error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", i); status = kTfLiteError; @@ -938,6 +984,15 @@ TfLiteStatus InterpreterBuilder::operator()( if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) return cleanup_and_error(); + std::vector variables; + for (int i = 0; i < (*interpreter)->tensors_size(); ++i) { + auto* tensor = (*interpreter)->tensor(i); + if (tensor->is_variable) { + variables.push_back(i); + } + } + (**interpreter).SetVariables(std::move(variables)); + return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD index f8767b443a2aa64b666c3b6bfb7db30cc0be62ea..f18a2ca07a5f66b760e96a6d9a57db8d6c26b7b9 100644 --- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc index ceef8e6a29c4fb03fc3be5e7e6fb062d144f9250..5d6c47dce8d90192d35a3a51fe6d0beb11f3b23f 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor.cc +++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc @@ -104,11 +104,11 @@ void GetSegmentPredictions( }); // Add backoff response. - for (const string& backoff : config.backoff_responses) { + for (const auto& backoff : config.backoff_responses) { if (predictor_responses->size() >= config.num_response) { break; } - predictor_responses->push_back({backoff, config.backoff_confidence}); + predictor_responses->emplace_back(backoff, config.backoff_confidence); } } diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc index e6c8d966f1aff5a867f9469f8fcdec526df84763..c7e08814fdf502f1ecfea60af3385fc7aa6055fa 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc +++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc @@ -35,8 +35,8 @@ const char kModelName[] = "smartreply_ondevice_model.bin"; const char kSamples[] = "smartreply_samples.tsv"; string TestDataPath() { - return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", - "contrib/lite/models/testdata/")); + return string(absl::StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", + "contrib/lite/models/testdata/")); } MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") { @@ -55,7 +55,7 @@ class PredictorTest : public ::testing::Test { protected: PredictorTest() { model_ = tflite::FlatBufferModel::BuildFromFile( - StrCat(TestDataPath(), "/", kModelName).c_str()); + absl::StrCat(TestDataPath(), "/", kModelName).c_str()); CHECK(model_); } ~PredictorTest() override {} @@ -121,7 +121,7 @@ TEST_F(PredictorTest, BatchTest) { int total_triggers = 0; string line; - std::ifstream fin(StrCat(TestDataPath(), "/", kSamples)); + std::ifstream fin(absl::StrCat(TestDataPath(), "/", kSamples)); while (std::getline(fin, line)) { const std::vector fields = absl::StrSplit(line, '\t'); if (fields.empty()) { diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 107c84e66607aa7144f60b26f9402677a96ac9ee..905c0919cb690012c2feba2cca821aa43fb2ddff 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -29,27 +29,46 @@ limitations under the License. namespace tflite { -// TODO(aselle): FATAL leaves resources hanging. -void FATAL(const char* format, ...) { +void logError(const char* format, ...) { + // TODO(mikie): use android logging, stderr is not captured for Java + // applications va_list args; va_start(args, format); vfprintf(stderr, format, args); va_end(args); + fprintf(stderr, "\n"); fflush(stderr); - exit(1); } +#define FATAL(...) \ + logError(__VA_ARGS__); \ + exit(1); + // TODO(aselle): Change the error model to use status codes. -#define CHECK_TFLITE_SUCCESS(x) \ - if (x != kTfLiteOk) { \ - FATAL("Aborting since tflite returned failure."); \ +#define CHECK_TFLITE_SUCCESS(x) \ + if (x != kTfLiteOk) { \ + FATAL("Aborting since tflite returned failure nnapi_delegate.cc:%d.", \ + __LINE__); \ } -#define CHECK_NN(x) \ - if (x != ANEURALNETWORKS_NO_ERROR) { \ - FATAL("Aborting since tflite returned failure."); \ +#define CHECK_NN(x) \ + if (x != ANEURALNETWORKS_NO_ERROR) { \ + FATAL("Aborting since NNAPI returned failure nnapi_delegate.cc:%d", \ + __LINE__); \ } +#define RETURN_ERROR_IF_NN_FAILED(x) \ + if (x != ANEURALNETWORKS_NO_ERROR) { \ + logError( \ + "Returning error since NNAPI returned failure nnapi_delegate.cc:%d.", \ + __LINE__); \ + return kTfLiteError; \ + } + +// Tracking of NNAPI operand ids +static const int64_t kOperandIdNotSet = -1; +static const int64_t kOperandNotNeeded = -2; + namespace { int32_t GetAndroidSdkVersion() { @@ -104,21 +123,16 @@ NNAPIDelegate::~NNAPIDelegate() { } // Adds the tensors of the interpreter to the NN API model. -// Returns the number of operands added. -uint32_t addTensorOperands(tflite::Interpreter* interpreter, - ANeuralNetworksModel* nn_model, - const std::vector& skip_list) { +TfLiteStatus addTensorOperands(tflite::Interpreter* interpreter, + ANeuralNetworksModel* nn_model, + uint32_t* no_of_operands_added, + std::vector* nnapi_ids) { uint32_t next_id = 0; for (size_t i = 0; i < interpreter->tensors_size(); i++) { - // skip temporaries tensors. - bool shouldSkip = false; - for (auto skip_idx : skip_list) { - if (i == skip_idx) { - shouldSkip = true; - break; - } - } - if (shouldSkip) continue; + // Skip temporaries and RNN back-edges. + if ((*nnapi_ids)[i] == kOperandNotNeeded) continue; + + (*nnapi_ids)[i] = int64_t(next_id); int32_t nn_type = 0; // NNAPI requires 32-bit float scale to be zero, tflite doesn't care @@ -144,7 +158,18 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, zeroPoint = tensor->params.zero_point; break; default: - FATAL("Unsupported type."); + logError("Unsupported tensor type %d", tensor->type); + return kTfLiteError; + } + if (tensor->dims->size == 0) { + logError("NNAPI doesn't support tensors with rank 0 (index %d name %s)", + i, tensor->name); + return kTfLiteError; + } + if (tensor->dims->size > 4) { + logError("NNAPI doesn't support tensors with rank > 4 (index %d name %s)", + i, tensor->name); + return kTfLiteError; } // TODO(aselle): Note, many of these are intermediate results. Do I need // to ever specify these sizes. I am currently below doing setValue @@ -154,30 +179,53 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, ANeuralNetworksOperandType operand_type{ nn_type, static_cast(tensor->dims->size), reinterpret_cast(tensor->dims->data), scale, zeroPoint}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); - + RETURN_ERROR_IF_NN_FAILED( + ANeuralNetworksModel_addOperand(nn_model, &operand_type)); // TODO(aselle): Based on Michael's suggestion, limiting this to read // only memory if (tensor->allocation_type == kTfLiteMmapRo) { if (const NNAPIAllocation* alloc = dynamic_cast( static_cast(tensor->allocation))) { - CHECK_NN(ANeuralNetworksModel_setOperandValueFromMemory( - nn_model, next_id, alloc->memory(), alloc->offset(tensor->data.raw), - tensor->bytes)); + RETURN_ERROR_IF_NN_FAILED( + ANeuralNetworksModel_setOperandValueFromMemory( + nn_model, next_id, alloc->memory(), + alloc->offset(tensor->data.raw), tensor->bytes)); } else { - CHECK_NN(ANeuralNetworksModel_setOperandValue( + RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_setOperandValue( nn_model, next_id, tensor->data.raw, tensor->bytes)); } + } else if (tensor->bytes == 0) { + // These size 0 tensors are optional tensors reserved. + RETURN_ERROR_IF_NN_FAILED( + ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0)); } + ++next_id; } - return next_id; + *no_of_operands_added = next_id; + return kTfLiteOk; +} + +void MapAndAddTensorIds(const int* from_ids_buf, size_t from_ids_count, + std::vector* into, + const std::vector& map) { + for (size_t i = 0; i < from_ids_count; i++) { + int from_id = from_ids_buf[i]; + if (from_id == kOptionalTensor) { + into->push_back(from_id); + } else { + into->push_back(map[from_id]); + } + } } // Adds the operations and their parameters to the NN API model. // 'next-id' is the operand ID of the next operand of the model. -void AddOpsAndParams(tflite::Interpreter* interpreter, - ANeuralNetworksModel* nn_model, uint32_t next_id) { +TfLiteStatus AddOpsAndParams( + tflite::Interpreter* interpreter, ANeuralNetworksModel* nn_model, + uint32_t next_id, std::vector* model_state_inputs, + std::vector* model_state_outputs, + const std::vector& tensor_id_to_nnapi_id) { for (size_t i = 0; i < interpreter->nodes_size(); i++) { const auto* node_and_registration = interpreter->node_and_registration(i); const TfLiteNode& node = node_and_registration->first; @@ -186,8 +234,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, static_cast(registration.builtin_code); // Add the parameters. - std::vector augmented_inputs( - node.inputs->data, node.inputs->data + node.inputs->size); + std::vector augmented_inputs, augmented_outputs; + MapAndAddTensorIds(node.inputs->data, node.inputs->size, &augmented_inputs, + tensor_id_to_nnapi_id); + MapAndAddTensorIds(node.outputs->data, node.outputs->size, + &augmented_outputs, tensor_id_to_nnapi_id); auto add_scalar_int32 = [&nn_model, &augmented_inputs, &next_id](int value) { @@ -207,46 +258,83 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, augmented_inputs.push_back(next_id++); }; + auto add_vector_int32 = [&](const int* values, uint32_t num_values) { + ANeuralNetworksOperandType operand_type{ + .type = ANEURALNETWORKS_TENSOR_INT32, + .dimensionCount = 1, + .dimensions = &num_values}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue( + nn_model, next_id, values, sizeof(int32_t) * num_values)); + augmented_inputs.push_back(next_id++); + }; + + // Handle state tensors of RNN, LSTM, SVDF. + // For each state_out tensor, a corresponding state_in operand needs to be + // created for NNAPI. auto duplicate_state_tensor_float32 = - [interpreter, &nn_model, &augmented_inputs](int tensor_id) { + [interpreter, &nn_model, &next_id, &augmented_inputs, + &model_state_inputs, &model_state_outputs](int tensor_id) { const TfLiteTensor* tensor = interpreter->tensor(tensor_id); - CHECK_NN(ANeuralNetworksModel_setOperandValue( - nn_model, tensor_id, tensor->data.raw, tensor->bytes)); - augmented_inputs.push_back(tensor_id); + ANeuralNetworksOperandType operand_type{ + ANEURALNETWORKS_TENSOR_FLOAT32, + static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), + tensor->params.scale, tensor->params.zero_point}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + augmented_inputs.push_back(next_id); + model_state_inputs->push_back(next_id); + model_state_outputs->push_back(tensor_id); + next_id++; }; + auto check_and_add_activation = [&add_scalar_int32](int activation) { + if (activation > kTfLiteActRelu6) { + FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations"); + } + add_scalar_int32(activation); + }; - auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); }; + auto add_add_params = [&add_scalar_int32](void* data) { + auto* builtin = reinterpret_cast(data); + if (builtin->activation > kTfLiteActRelu6) { + FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations"); + } + add_scalar_int32(builtin->activation); + }; - auto add_pooling_params = [&add_scalar_int32](void* data) { + auto add_pooling_params = [&add_scalar_int32, + &check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->padding); add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); add_scalar_int32(builtin->filter_width); add_scalar_int32(builtin->filter_height); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; - auto add_convolution_params = [&add_scalar_int32](void* data) { + auto add_convolution_params = [&add_scalar_int32, + &check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->padding); add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; - auto add_depthwise_conv_params = [&add_scalar_int32](void* data) { + auto add_depthwise_conv_params = [&add_scalar_int32, + &check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->padding); add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); add_scalar_int32(builtin->depth_multiplier); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; - auto add_fully_connected_params = [&add_scalar_int32](void* data) { + auto add_fully_connected_params = [&check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; auto add_concatenation_params = [&add_scalar_int32](void* data) { @@ -275,39 +363,71 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, add_scalar_float32(builtin->proj_clip); }; + // LSTM in NNAPI requires scratch tensor as an output operand. + auto add_lstm_scratch_tensor_float32 = [interpreter, &node, &nn_model, + &next_id, &augmented_outputs]() { + if (node.temporaries->size == 0) return; + int scratch_buffer_index = node.temporaries->data[0]; + const TfLiteTensor* tensor = interpreter->tensor(scratch_buffer_index); + ANeuralNetworksOperandType operand_type{ + ANEURALNETWORKS_TENSOR_FLOAT32, + static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), tensor->params.scale, + tensor->params.zero_point}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + augmented_outputs.insert(augmented_outputs.begin(), next_id++); + }; + auto add_mean_params = [&add_scalar_int32](void* data) { - auto builtin = reinterpret_cast(data); + auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->keep_dims); }; -#if 0 - auto add_reshape_params = [&](void* data) { - auto builtin = reinterpret_cast(data); - uint32_t tensor_size_shape = builtin->num_dimensions; - ANeuralNetworksOperandType operand_type{ - ANEURALNETWORKS_TENSOR_INT32, - {static_cast(1), - reinterpret_cast(&tensor_size_shape)}, - 0, - 0}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) - CHECK_NN(ANeuralNetworksModel_setOperandValue( - nn_model, next_id, builtin->shape, - sizeof(int) * builtin->num_dimensions)); - augmented_inputs.push_back(next_id++); + auto add_svdf_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->rank); + add_scalar_int32(builtin->activation); }; -#endif + + auto add_rnn_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->activation); + }; + + auto add_squeeze_params = [&](void* data) { + const auto* builtin = reinterpret_cast(data); + // Note that we add the squeeze dimensions even if the dimensions were + // unspecified (empty), as NNAPI requires the operand. + add_vector_int32(builtin->squeeze_dims, + static_cast(builtin->num_squeeze_dims)); + }; + + // Handle optional input tensors. + auto add_optional_tensors = [&nn_model, &augmented_inputs, + &next_id](int nn_type) { + for (size_t idx = 0; idx < augmented_inputs.size(); idx++) { + if (augmented_inputs[idx] == kOptionalTensor) { + const std::vector dim = {0, 0}; + ANeuralNetworksOperandType operand_type{nn_type, 2, dim.data(), 0, 0}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, + nullptr, 0)) + augmented_inputs[idx] = next_id++; + } + } + }; + int nnapi_version = 10; ANeuralNetworksOperationType nn_op_type; switch (builtin) { case tflite::BuiltinOperator_ADD: nn_op_type = ANEURALNETWORKS_ADD; - add_add_params(); + add_add_params(node.builtin_data); break; case tflite::BuiltinOperator_MUL: nn_op_type = ANEURALNETWORKS_MUL; - add_add_params(); + add_add_params(node.builtin_data); break; case tflite::BuiltinOperator_AVERAGE_POOL_2D: add_pooling_params(node.builtin_data); @@ -321,7 +441,14 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, add_pooling_params(node.builtin_data); nn_op_type = ANEURALNETWORKS_L2_POOL_2D; break; - case tflite::BuiltinOperator_CONV_2D: + case tflite::BuiltinOperator_CONV_2D: { + auto builtin = reinterpret_cast(node.builtin_data); + if (builtin->dilation_width_factor != 1 || + builtin->dilation_height_factor != 1 || node.inputs->size != 3) { + logError("NNAPI does not support dilated Conv2D."); + return kTfLiteError; + } + } add_convolution_params(node.builtin_data); nn_op_type = ANEURALNETWORKS_CONV_2D; break; @@ -365,14 +492,36 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, nn_op_type = ANEURALNETWORKS_SPACE_TO_DEPTH; break; case tflite::BuiltinOperator_LSTM: { + if (node.inputs->size + /* no of params */ 3 != 21) { + logError("NNAPI only supports 21-input LSTMs"); + return kTfLiteError; + } duplicate_state_tensor_float32( - node.outputs->data[/*kOutputStateTensor*/ 1]); + node.outputs->data[/*kOutputStateTensor*/ 0]); duplicate_state_tensor_float32( - node.outputs->data[/*kCellStateTensor*/ 2]); + node.outputs->data[/*kCellStateTensor*/ 1]); add_lstm_params(node.builtin_data); + add_lstm_scratch_tensor_float32(); + add_optional_tensors(ANEURALNETWORKS_TENSOR_FLOAT32); nn_op_type = ANEURALNETWORKS_LSTM; break; } + case tflite::BuiltinOperator_SVDF: { + duplicate_state_tensor_float32(node.outputs->data[/*kStateTensor*/ 0]); + add_svdf_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_SVDF; + break; + } + case tflite::BuiltinOperator_RNN: { + duplicate_state_tensor_float32( + node.outputs->data[/*kHiddenStateTensor*/ 0]); + add_rnn_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_RNN; + break; + } + case tflite::BuiltinOperator_EMBEDDING_LOOKUP: + nn_op_type = ANEURALNETWORKS_EMBEDDING_LOOKUP; + break; case tflite::BuiltinOperator_PAD: nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_PAD; @@ -385,19 +534,25 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_DIV: nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_DIV; + check_and_add_activation( + reinterpret_cast(node.builtin_data)->activation); break; case tflite::BuiltinOperator_SUB: nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_SUB; + check_and_add_activation( + reinterpret_cast(node.builtin_data)->activation); + break; + case tflite::BuiltinOperator_SQUEEZE: + nnapi_version = 11; // requires NNAPI 1.1 + add_squeeze_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_SQUEEZE; break; case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: case tflite::BuiltinOperator_LSH_PROJECTION: - case tflite::BuiltinOperator_SVDF: case tflite::BuiltinOperator_HASHTABLE_LOOKUP: - case tflite::BuiltinOperator_RNN: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: - case tflite::BuiltinOperator_EMBEDDING_LOOKUP: case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: @@ -414,7 +569,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_TOPK_V2: case tflite::BuiltinOperator_TRANSPOSE: case tflite::BuiltinOperator_SPLIT: - case tflite::BuiltinOperator_SQUEEZE: case tflite::BuiltinOperator_STRIDED_SLICE: case tflite::BuiltinOperator_EXP: case tflite::BuiltinOperator_LOG_SOFTMAX: @@ -433,13 +587,24 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SELECT: case tflite::BuiltinOperator_SLICE: case tflite::BuiltinOperator_SIN: + case tflite::BuiltinOperator_LOG: case tflite::BuiltinOperator_TRANSPOSE_CONV: - FATAL("Op code %d is currently not delegated to NNAPI", builtin); - nn_op_type = -1; // set to invalid + case tflite::BuiltinOperator_TILE: + case tflite::BuiltinOperator_EXPAND_DIMS: + case tflite::BuiltinOperator_SPARSE_TO_DENSE: + case tflite::BuiltinOperator_EQUAL: + case tflite::BuiltinOperator_NOT_EQUAL: + case tflite::BuiltinOperator_SUM: + case tflite::BuiltinOperator_SQRT: + case tflite::BuiltinOperator_RSQRT: + case tflite::BuiltinOperator_SHAPE: + case tflite::BuiltinOperator_POW: + logError("Op code %d is currently not delegated to NNAPI", builtin); + return kTfLiteError; break; case tflite::BuiltinOperator_CUSTOM: - FATAL("Custom operations are not supported when using NNAPI."); - nn_op_type = -1; // set to invalid + logError("Custom operations are not supported when using NNAPI."); + return kTfLiteError; break; } @@ -448,39 +613,76 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, } // Add the operation. - CHECK_NN(ANeuralNetworksModel_addOperation( + RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_addOperation( nn_model, nn_op_type, static_cast(augmented_inputs.size()), - augmented_inputs.data(), static_cast(node.outputs->size), - reinterpret_cast(node.outputs->data))); + augmented_inputs.data(), + static_cast(augmented_outputs.size()), + reinterpret_cast(augmented_outputs.data()))); } + return kTfLiteOk; } TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { - // TODO(aselle): This is not correct. need to handle resize invalidation. - if (nn_model_ && nn_compiled_model_) return kTfLiteOk; + if (nn_model_ && nn_compiled_model_) return model_status_; + // TODO(aselle): This is not correct. need to handle resize invalidation. if (!nn_model_) { CHECK_NN(ANeuralNetworksModel_create(&nn_model_)); - // Find all the temporary tensors and put them in a skip_list. - std::vector skip_list; + // Find which tensors should be added to NNAPI. TFLite has temporaries + // and RNN back-edges which are are not valid for NNAPI. We look through all + // inputs and outputs and mark the mapping in tensor_id_to_nnapi_id with + // kOperandIdNotSet. addTensorOperands will replace those with the + // corresponding NNAPI operand ids and skip kOperandNotNeeded entries. + std::vector tensor_id_to_nnapi_id(interpreter->tensors_size(), + kOperandNotNeeded); + auto set_ids_to_not_set = [&tensor_id_to_nnapi_id](const int* buf, + size_t count) { + for (int j = 0; j < count; j++) { + auto tensor_id = buf[j]; + if (tensor_id != kOptionalTensor) { + tensor_id_to_nnapi_id[tensor_id] = kOperandIdNotSet; + } + } + }; for (size_t i = 0; i < interpreter->nodes_size(); i++) { const auto* node_and_registration = interpreter->node_and_registration(i); const TfLiteNode& node = node_and_registration->first; - if (node.temporaries != nullptr) { - for (int j = 0; j < node.temporaries->size; j++) { - skip_list.push_back(static_cast(node.temporaries->data[j])); - } - } + set_ids_to_not_set(node.inputs->data, node.inputs->size); + set_ids_to_not_set(node.outputs->data, node.outputs->size); } + set_ids_to_not_set(interpreter->inputs().data(), + interpreter->inputs().size()); + set_ids_to_not_set(interpreter->outputs().data(), + interpreter->outputs().size()); + + uint32_t next_id = 0; + RETURN_ERROR_IF_NN_FAILED(addTensorOperands( + interpreter, nn_model_, &next_id, &tensor_id_to_nnapi_id)); + RETURN_ERROR_IF_NN_FAILED( + AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_, + &model_states_outputs_, tensor_id_to_nnapi_id)); + + std::vector augmented_inputs; + MapAndAddTensorIds(interpreter->inputs().data(), + interpreter->inputs().size(), &augmented_inputs, + tensor_id_to_nnapi_id); + augmented_inputs.insert(augmented_inputs.end(), + model_states_inputs_.begin(), + model_states_inputs_.end()); + std::vector augmented_outputs; + MapAndAddTensorIds(interpreter->outputs().data(), + interpreter->outputs().size(), &augmented_outputs, + tensor_id_to_nnapi_id); + MapAndAddTensorIds(model_states_outputs_.data(), + model_states_outputs_.size(), &augmented_outputs, + tensor_id_to_nnapi_id); - uint32_t next_id = addTensorOperands(interpreter, nn_model_, skip_list); - AddOpsAndParams(interpreter, nn_model_, next_id); CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs( - nn_model_, static_cast(interpreter->inputs().size()), - reinterpret_cast(interpreter->inputs().data()), - static_cast(interpreter->outputs().size()), - reinterpret_cast(interpreter->outputs().data()))); + nn_model_, static_cast(augmented_inputs.size()), + reinterpret_cast(augmented_inputs.data()), + static_cast(augmented_outputs.size()), + reinterpret_cast(augmented_outputs.data()))); CHECK_NN(ANeuralNetworksModel_finish(nn_model_)); } if (!nn_compiled_model_) { @@ -492,7 +694,13 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { if (!nn_model_) { - TF_LITE_ENSURE_STATUS(BuildGraph(interpreter)); + model_status_ = BuildGraph(interpreter); + if (model_status_ != kTfLiteOk) { + logError("Failed to build graph for NNAPI"); + } + } + if (model_status_ != kTfLiteOk) { + return model_status_; } ANeuralNetworksExecution* execution = nullptr; @@ -507,6 +715,7 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { CHECK_NN(ANeuralNetworksExecution_setInput( execution, i, nullptr, tensor->data.raw, tensor->bytes)); } + // Tell nn api where to place final data. for (size_t i = 0; i < interpreter->outputs().size(); i++) { int output = interpreter->outputs()[i]; @@ -514,6 +723,24 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { CHECK_NN(ANeuralNetworksExecution_setOutput( execution, i, nullptr, tensor->data.raw, tensor->bytes)); } + + // The state_out of previous invocation need to be mapped to state_in of + // current invocation. + for (size_t i = 0; i < model_states_outputs_.size(); i++) { + int state_tensor_idx = model_states_outputs_[i]; + TfLiteTensor* tensor = interpreter->tensor(state_tensor_idx); + // Here we are using a deep copy for state_in tensors so that we are not + // reading and writing into the same buffer during a invocation. + // TODO(miaowang): using double shared buffer to minimize the copies. + CHECK_NN(ANeuralNetworksExecution_setInput( + execution, i + interpreter->inputs().size(), nullptr, tensor->data.raw, + tensor->bytes)); + // Tell NNAPI where to output the state_out. + CHECK_NN(ANeuralNetworksExecution_setOutput( + execution, i + interpreter->outputs().size(), nullptr, tensor->data.raw, + tensor->bytes)); + } + // Currently use blocking compute. ANeuralNetworksEvent* event = nullptr; CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event)); diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h index e98000929a1168c786f6c18f498f9d1d72311ada..8dc7d38a303f51b7ccefefd8c9d2990b443e6827 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.h +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -59,6 +59,16 @@ class NNAPIDelegate { ANeuralNetworksModel* nn_model_ = nullptr; // The NN API compilation handle ANeuralNetworksCompilation* nn_compiled_model_ = nullptr; + // Model status + TfLiteStatus model_status_ = kTfLiteOk; + + // List of state tensors for LSTM, RNN, SVDF. + // NN API does not allow ops to maintain states across multiple + // invocations. We need to manually create state input tensors from + // corresponding state output tensors of TFLite operations, and map them + // correctly. + std::vector model_states_inputs_; // holds NNAPI operand ids + std::vector model_states_outputs_; // holds TFLite tensor ids }; } // namespace tflite diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h index 38a27069421586f28a5fbe4c7880a28f80548b98..9d7e3f20854a3596181ffa885cc17cfdbd16356e 100644 --- a/tensorflow/contrib/lite/op_resolver.h +++ b/tensorflow/contrib/lite/op_resolver.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/util.h" namespace tflite { @@ -55,8 +56,7 @@ struct OperatorKeyHasher { size_t operator()(const T& x) const { size_t a = ValueHasher()(x.first); size_t b = ValueHasher()(x.second); - // Hash combinator used by TensorFlow core. - return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4)); + return CombineHashes({a, b}); } }; } // namespace op_resolver_hasher diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc index dfdd80ea8a42af683632be1d7e8ab0062847077d..f1f025f777c987c5ee47bdea457a973896b9bb82 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.cc +++ b/tensorflow/contrib/lite/optional_debug_tools.cc @@ -50,6 +50,10 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteString"; case kTfLiteBool: return "kTfLiteBool"; + case kTfLiteInt16: + return "kTfLiteInt16"; + case kTfLiteComplex64: + return "kTfLiteComplex64"; } return "(invalid)"; } @@ -82,13 +86,13 @@ void PrintInterpreterState(Interpreter* interpreter) { for (int tensor_index = 0; tensor_index < interpreter->tensors_size(); tensor_index++) { TfLiteTensor* tensor = interpreter->tensor(tensor_index); - printf("Tensor %3d %10s %15s %10zu bytes (%4.1f MB) ", tensor_index, - TensorTypeName(tensor->type), AllocTypeName(tensor->allocation_type), - tensor->bytes, float(tensor->bytes) / float(1 << 20)); + printf("Tensor %3d %-20s %10s %15s %10zu bytes (%4.1f MB) ", tensor_index, + tensor->name, TensorTypeName(tensor->type), + AllocTypeName(tensor->allocation_type), tensor->bytes, + (static_cast(tensor->bytes) / (1 << 20))); PrintTfLiteIntVector(tensor->dims); - printf("\n"); } - + printf("\n"); for (int node_index = 0; node_index < interpreter->nodes_size(); node_index++) { const std::pair* node_and_reg = @@ -104,7 +108,4 @@ void PrintInterpreterState(Interpreter* interpreter) { } } -// Prints a dump of what tensors and what nodes are in the interpreter. -TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); - } // namespace tflite diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h index 1b6998cda382782b974bea3d18ffb6217e8f780c..7fb4b8d8b7ae87cc6e8dd8503c8a4ce0cef2ce8d 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.h +++ b/tensorflow/contrib/lite/optional_debug_tools.h @@ -24,9 +24,6 @@ namespace tflite { // Prints a dump of what tensors and what nodes are in the interpreter. void PrintInterpreterState(Interpreter* interpreter); -// Prints a dump of what tensors and what nodes are in the interpreter. -TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); - } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD index 15999e5d4188db1e191936ae6d84faf8cce5ca6e..a162b87b8f98576ec7c3b2623d1d34f2baef6cce 100644 --- a/tensorflow/contrib/lite/profiling/BUILD +++ b/tensorflow/contrib/lite/profiling/BUILD @@ -2,9 +2,11 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + common_copts = [ "-Wall", -] +] + tflite_copts() cc_library( name = "profiler", @@ -29,6 +31,43 @@ cc_library( name = "profile_buffer", hdrs = ["profile_buffer.h"], copts = common_copts, + deps = [":time"], +) + +cc_library( + name = "time", + srcs = ["time.cc"], + hdrs = ["time.h"], + copts = common_copts, +) + +cc_library( + name = "profile_summarizer", + srcs = ["profile_summarizer.cc"], + hdrs = ["profile_summarizer.h"], + copts = common_copts, + deps = [ + ":profiler", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/core:stats_calculator_portable", + ], +) + +cc_test( + name = "profile_summarizer_test", + srcs = ["profile_summarizer_test.cc"], + copts = common_copts, + deps = [ + ":profile_summarizer", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], ) cc_test( diff --git a/tensorflow/contrib/lite/profiling/profile_buffer.h b/tensorflow/contrib/lite/profiling/profile_buffer.h index 299b2a9cad161ce05ba68f39cf612f9866a0b656..65d86dce47f397c7dad6cc2beb8ffa1f95b29d45 100644 --- a/tensorflow/contrib/lite/profiling/profile_buffer.h +++ b/tensorflow/contrib/lite/profiling/profile_buffer.h @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/profiling/time.h" + namespace tflite { namespace profiling { @@ -74,7 +76,7 @@ class ProfileBuffer { if (!enabled_) { return kInvalidEventHandle; } - uint64_t timestamp = NowMicros(); + uint64_t timestamp = time::NowMicros(); int index = current_index_ % event_buffer_.size(); event_buffer_[index].tag = tag; event_buffer_[index].event_type = event_type; @@ -103,7 +105,7 @@ class ProfileBuffer { } int event_index = event_handle % max_size; - event_buffer_[event_index].end_timestamp_us = NowMicros(); + event_buffer_[event_index].end_timestamp_us = time::NowMicros(); } // Returns the size of the buffer. @@ -134,12 +136,6 @@ class ProfileBuffer { } private: - static uint64_t NowMicros() { - // TODO(shashishekhar): Refactor this to a separate file. - struct timeval tv; - gettimeofday(&tv, nullptr); - return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; - } bool enabled_; uint32_t current_index_; std::vector event_buffer_; diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..c37a0965884a803e82da536f73a8f32a28691651 --- /dev/null +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" + +#include + +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { +namespace profiling { +namespace { + +using Detail = tensorflow::StatsCalculator::Detail; + +struct OperatorDetails { + std::string name; + std::vector inputs; + std::vector outputs; +}; + +std::string GetTensorName(const tflite::Interpreter& interpreter, + int tensor_index) { + const auto tensor = interpreter.tensor(tensor_index); + if (tensor == nullptr || tensor->name == nullptr) { + return "Unknown"; + } + return tensor->name; +} +std::vector GetTensorNames(const tflite::Interpreter& interpreter, + const TfLiteIntArray* tensor_indices) { + std::vector tensors; + tensors.reserve(tensor_indices->size); + for (int i = 0; i < tensor_indices->size; i++) { + tensors.push_back(GetTensorName(interpreter, tensor_indices->data[i])); + } + return tensors; +} + +std::string ToString(const std::vector& str_vector) { + std::stringstream stream; + stream << "["; + bool first = true; + for (const auto& s : str_vector) { + if (!first) { + stream << ", "; + } else { + first = false; + } + stream << s; + } + stream << "]"; + return stream.str(); +} + +OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, + int node_index) { + auto node_reg = interpreter.node_and_registration(node_index); + auto inputs = node_reg->first.inputs; + auto outputs = node_reg->first.outputs; + int code = node_reg->second.builtin_code; + const char* op_name = nullptr; + if (code == tflite::BuiltinOperator_CUSTOM) { + const char* custom_name = node_reg->second.custom_name; + op_name = custom_name ? custom_name : "UnknownCustomOp"; + } else { + op_name = tflite::EnumNamesBuiltinOperator()[code]; + } + const char* profiling_string = + interpreter.OpProfilingString(node_reg->second, &node_reg->first); + OperatorDetails details; + details.name = op_name; + if (profiling_string) { + details.name += ":" + string(profiling_string); + } + details.inputs = GetTensorNames(interpreter, inputs); + details.outputs = GetTensorNames(interpreter, outputs); + return details; +} + +tensorflow::StatSummarizerOptions GetProfileSummarizerOptions() { + auto options = tensorflow::StatSummarizerOptions(); + options.show_summary = true; + options.show_memory = false; + return options; +} + +} // namespace + +ProfileSummarizer::ProfileSummarizer() + : stats_calculator_( + new ::tensorflow::StatsCalculator(GetProfileSummarizerOptions())) {} + +void ProfileSummarizer::ProcessProfiles( + const std::vector& profile_stats, + const tflite::Interpreter& interpreter) { + std::vector events; + std::copy_if(profile_stats.begin(), profile_stats.end(), + std::back_inserter(events), [](const ProfileEvent* e) { + return e->event_type == + ProfileEvent::EventType::OPERATOR_INVOKE_EVENT && + e->end_timestamp_us >= e->begin_timestamp_us; + }); + // Sort with begin_time. + std::sort(events.begin(), events.end(), + [](const ProfileEvent* const& a, const ProfileEvent* const& b) { + return a->begin_timestamp_us < b->begin_timestamp_us; + }); + if (events.empty()) { + return; + } + + int64_t base_start_us = events[0]->begin_timestamp_us; + int node_num = 0; + int64_t curr_total_us = 0; + std::map details; + for (auto event : events) { + auto op_details = GetOperatorDetails(interpreter, event->event_metadata); + auto node_name = ToString(op_details.outputs); + auto result = details.emplace(node_name, Detail()); + Detail* detail = &(result.first->second); + detail->start_us.UpdateStat(event->begin_timestamp_us - base_start_us); + int64_t node_exec_time = + event->end_timestamp_us - event->begin_timestamp_us; + detail->rel_end_us.UpdateStat(node_exec_time); + curr_total_us += node_exec_time; + ++node_num; + + if (result.second) { + detail->name = node_name; + detail->type = op_details.name; + detail->run_order = node_num; + detail->times_called = 0; + } + ++detail->times_called; + } + stats_calculator_->UpdateDetails(details); + stats_calculator_->UpdateRunTotalUs(curr_total_us); +} +} // namespace profiling +} // namespace tflite diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.h b/tensorflow/contrib/lite/profiling/profile_summarizer.h new file mode 100644 index 0000000000000000000000000000000000000000..a529ff87428d70d002241311d7f70f185521020f --- /dev/null +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ +#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ + +#include + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/profiling/profiler.h" +#include "tensorflow/core/util/stats_calculator.h" + +namespace tflite { +namespace profiling { + +// Creates a summary of operator invocations in the interpreter. +class ProfileSummarizer { + public: + ProfileSummarizer(); + virtual ~ProfileSummarizer() {} + + // Process profile events to update statistics for operator invocations. + void ProcessProfiles(const std::vector& profile_stats, + const tflite::Interpreter& interpreter); + + // Returns a string detailing the accumulated runtime stats in a tab-separated + // format which can be pasted into a spreadsheet for further analysis. + std::string GetOutputString() const { + return stats_calculator_->GetOutputString(); + } + + std::string GetShortSummary() const { + return stats_calculator_->GetShortSummary(); + } + + private: + std::unique_ptr stats_calculator_; +}; + +} // namespace profiling +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..67a5eecfa05379c7a721e7d669fcd02602e5e369 --- /dev/null +++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc @@ -0,0 +1,156 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { +namespace profiling { + +namespace { + +#ifdef TFLITE_PROFILING_ENABLED +TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0); + const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1); + + TfLiteTensor* output = GetOutput(context, node, /*index=*/0); + + int32_t* output_data = output->data.i32; + *output_data = *(input1->data.i32) + *(input2->data.i32); + return kTfLiteOk; +} + +const char* SimpleOpProfilingString(const TfLiteContext* context, + const TfLiteNode* node) { + return "Profile"; +} + +TfLiteRegistration* RegisterSimpleOp() { + static TfLiteRegistration registration = { + nullptr, nullptr, nullptr, + SimpleOpEval, nullptr, tflite::BuiltinOperator_CUSTOM, + "SimpleOpEval", 1}; + return ®istration; +} + +TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() { + static TfLiteRegistration registration = {nullptr, + nullptr, + nullptr, + SimpleOpEval, + SimpleOpProfilingString, + tflite::BuiltinOperator_CUSTOM, + "SimpleOpEval", + 1}; + return ®istration; +} +#endif + +class SimpleOpModel : public SingleOpModel { + public: + void Init(const std::function& registration); + tflite::Interpreter* GetInterpreter() { return interpreter_.get(); } + void SetInputs(int32_t x, int32_t y) { + PopulateTensor(inputs_[0], {x}); + PopulateTensor(inputs_[1], {y}); + } + int32_t GetOutput() { return ExtractVector(output_)[0]; } + + private: + int inputs_[2]; + int output_; +}; + +void SimpleOpModel::Init( + const std::function& registration) { + inputs_[0] = AddInput({TensorType_INT32, {1}}); + inputs_[1] = AddInput({TensorType_INT32, {1}}); + output_ = AddOutput({TensorType_INT32, {}}); + SetCustomOp("SimpleAdd", {}, registration); + BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])}); +} + +TEST(ProfileSummarizerTest, Empty) { + ProfileSummarizer summarizer; + std::string output = summarizer.GetOutputString(); + EXPECT_GT(output.size(), 0); +} + +#ifdef TFLITE_PROFILING_ENABLED +TEST(ProfileSummarizerTest, Interpreter) { + Profiler profiler; + SimpleOpModel m; + m.Init(RegisterSimpleOp); + auto interpreter = m.GetInterpreter(); + interpreter->SetProfiler(&profiler); + profiler.StartProfiling(); + m.SetInputs(1, 2); + m.Invoke(); + // 3 = 1 + 2 + EXPECT_EQ(m.GetOutput(), 3); + profiler.StopProfiling(); + ProfileSummarizer summarizer; + auto events = profiler.GetProfileEvents(); + EXPECT_EQ(1, events.size()); + summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); + auto output = summarizer.GetOutputString(); + // TODO(shashishekhar): Add a better test here. + ASSERT_TRUE(output.find("SimpleOpEval") != std::string::npos) << output; +} + +TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) { + Profiler profiler; + SimpleOpModel m; + m.Init(RegisterSimpleOpWithProfilingDetails); + auto interpreter = m.GetInterpreter(); + interpreter->SetProfiler(&profiler); + profiler.StartProfiling(); + m.SetInputs(1, 2); + m.Invoke(); + // 3 = 1 + 2 + EXPECT_EQ(m.GetOutput(), 3); + profiler.StopProfiling(); + ProfileSummarizer summarizer; + auto events = profiler.GetProfileEvents(); + EXPECT_EQ(1, events.size()); + summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); + auto output = summarizer.GetOutputString(); + // TODO(shashishekhar): Add a better test here. + ASSERT_TRUE(output.find("SimpleOpEval:Profile") != std::string::npos) + << output; +} + +#endif + +} // namespace +} // namespace profiling +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/profiling/time.cc b/tensorflow/contrib/lite/profiling/time.cc new file mode 100644 index 0000000000000000000000000000000000000000..446660bb747cd6e3b694669b64ac1d95cf415fbe --- /dev/null +++ b/tensorflow/contrib/lite/profiling/time.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/profiling/time.h" + +#include + +namespace tflite { +namespace profiling { +namespace time { +uint64_t NowMicros() { + struct timeval tv; + gettimeofday(&tv, nullptr); + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; +} +} // namespace time +} // namespace profiling +} // namespace tflite diff --git a/tensorflow/contrib/lite/profiling/time.h b/tensorflow/contrib/lite/profiling/time.h new file mode 100644 index 0000000000000000000000000000000000000000..cc2ec319b8a95b3efa0aab0ac9f97a88bf7b5536 --- /dev/null +++ b/tensorflow/contrib/lite/profiling/time.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ +#define TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ + +#include + +namespace tflite { +namespace profiling { +namespace time { +uint64_t NowMicros(); +} // namespace time +} // namespace profiling +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 4920e83970d1cb7f60a38f95ea05986d52b0bbe7..27909a9458f6b09f96cb556a5254f01e54f46e05 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -36,6 +36,16 @@ py_test( ], ) +py_binary( + name = "tflite_convert", + srcs = ["tflite_convert.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite", + ], +) + py_library( name = "lite", srcs = ["lite.py"], @@ -45,7 +55,23 @@ py_library( ":convert", ":convert_saved_model", ":interpreter", + ":lite_constants", ":op_hint", + "//tensorflow/python:graph_util", + "//tensorflow/python/saved_model:constants", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/tools:freeze_graph_lib", + ], +) + +py_test( + name = "lite_test", + srcs = ["lite_test.py"], + data = [":interpreter_test_data"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":lite", ], ) @@ -111,9 +137,9 @@ py_library( visibility = ["//visibility:public"], deps = [ ":convert", - ":lite_constants", "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python:graph_util", + "//tensorflow/python:platform", "//tensorflow/python/tools:freeze_graph_lib", ], ) @@ -150,20 +176,3 @@ py_test( "//tensorflow/python/saved_model", ], ) - -py_binary( - name = "convert_saved_model_to_frozen_graph", - srcs = ["convert_saved_model_to_frozen_graph.py"], - srcs_version = "PY2AND3", - deps = [ - ":convert_saved_model", - ], -) - -# Transitive dependencies of this target will be included in the pip package. -py_library( - name = "tf_lite_py_pip", - deps = [ - ":convert_saved_model", - ], -) diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index c4200c879ba0e17b3bd183f4004eb75ebdd2f5ee..0ea2630f711727787332f207bdff6383aac8097c 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -25,7 +25,6 @@ import tempfile as _tempfile from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 -from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util.lazy_loader import LazyLoader @@ -111,37 +110,75 @@ def tensor_name(x): return x.name.split(":")[0] -def toco_convert(input_data, - input_tensors, - output_tensors, - inference_type=lite_constants.FLOAT, - input_format=lite_constants.TENSORFLOW_GRAPHDEF, - output_format=lite_constants.TFLITE, - quantized_input_stats=None, - drop_control_dependency=True): - """Convert a model using TOCO from `input_format` to `output_format`. +def build_toco_convert_protos(input_tensors, + output_tensors, + inference_type=lite_constants.FLOAT, + inference_input_type=None, + input_format=lite_constants.TENSORFLOW_GRAPHDEF, + output_format=lite_constants.TFLITE, + quantized_input_stats=None, + default_ranges_stats=None, + drop_control_dependency=True, + reorder_across_fake_quant=False, + allow_custom_ops=False, + change_concat_input_ranges=False, + quantize_weights=False, + dump_graphviz_dir=None, + dump_graphviz_video=False): + """Builds protocol buffers describing a conversion of a model using TOCO. Typically this is to convert from TensorFlow GraphDef to TFLite, in which case the default `input_format` and `output_format` are sufficient. Args: - input_data: Input data (i.e. often `sess.graph_def`). input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. + inference_type: Target data type of real-number arrays in the output file. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of real-number input arrays. Allows + for a different type for input arrays in the case of quantization. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + input_format: Type of data to read Currently must be + `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) + quantized_input_stats: List of tuples of integers representing the mean and + standard deviation. Each tuple maps to the corresponding input tensor. + Only need if `inference_type` is `QUANTIZED_UINT8`. (default None) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) + drop_control_dependency: Boolean indicating whether to drop control + dependencies silently. This is due to TFLite not supporting control + dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. + (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) + quantize_weights: Boolean indicating whether to store weights as quantized + weights followed by dequantize operations. Computation is still done in + float, but reduces model size (at the cost of accuracy and latency). + (default False) + dump_graphviz_dir: Full filepath of folder to dump the graphs at various + stages of processing GraphViz .dot files. Preferred over + --output_format=GRAPHVIZ_DOT in order to keep the requirements of the + output file. (default None) + dump_graphviz_video: Boolean indicating whether to dump the graph after + every graph transformation. (default False) Returns: - The converted data. For example if tflite was the destination, then - this will be a tflite flatbuffer in a bytes array. + model_flags, toco_flags: two protocol buffers describing the conversion + process. Raises: ValueError: If the input tensor type is unknown @@ -151,37 +188,60 @@ def toco_convert(input_data, toco = _toco_flags_pb2.TocoFlags() toco.input_format = input_format toco.output_format = output_format + toco.inference_type = inference_type + if inference_input_type: + toco.inference_input_type = inference_input_type toco.drop_control_dependency = drop_control_dependency + toco.reorder_across_fake_quant = reorder_across_fake_quant + toco.allow_custom_ops = allow_custom_ops + toco.quantize_weights = quantize_weights + if default_ranges_stats: + toco.default_ranges_min = default_ranges_stats[0] + toco.default_ranges_max = default_ranges_stats[1] + if dump_graphviz_dir: + toco.dump_graphviz_dir = dump_graphviz_dir + toco.dump_graphviz_include_video = dump_graphviz_video + model = _model_flags_pb2.ModelFlags() - toco.inference_type = inference_type + model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): - if input_tensor.dtype == _dtypes.float32: - tflite_input_type = lite_constants.FLOAT - elif input_tensor.dtype == _dtypes.int32: - tflite_input_type = lite_constants.INT32 - elif input_tensor.dtype == _dtypes.int64: - tflite_input_type = lite_constants.INT64 - # TODO(aselle): Insert strings when they are available - else: - raise ValueError("Tensors %s not known type %r" % (input_tensor.name, - input_tensor.dtype)) - input_array = model.input_arrays.add() - if inference_type == lite_constants.QUANTIZED_UINT8: - if tflite_input_type == lite_constants.FLOAT: - tflite_input_type = lite_constants.QUANTIZED_UINT8 input_array.mean_value, input_array.std_value = quantized_input_stats[idx] - input_array.name = tensor_name(input_tensor) input_array.shape.dims.extend(map(int, input_tensor.get_shape())) for output_tensor in output_tensors: model.output_arrays.append(tensor_name(output_tensor)) + return model, toco + + +def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): + """"Convert a model using TOCO. + + Typically this function is used to convert from TensorFlow GraphDef to TFLite. + Conversion can be customized by providing arguments that are forwarded to + `build_toco_convert_protos` (see documentation for details). - # TODO(aselle): Consider handling the case of allowing quantized - # inputs to be converted to float (via the toco.inference_input_type field). - data = toco_convert_protos(model.SerializeToString(), - toco.SerializeToString(), + Args: + input_data: Input data (i.e. often `sess.graph_def`), + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + *args: See `build_toco_convert_protos`, + **kwargs: See `build_toco_convert_protos`. + + Returns: + The converted data. For example if TFLite was the destination, then + this will be a tflite flatbuffer in a bytes array. + + Raises: + Defined in `build_toco_convert_protos`. + """ + model_flags, toco_flags = build_toco_convert_protos(input_tensors, + output_tensors, + *args, **kwargs) + data = toco_convert_protos(model_flags.SerializeToString(), + toco_flags.SerializeToString(), input_data.SerializeToString()) return data diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index a7eddf3408f54dff5fa49ff6fa7b61cd0b8a22e4..1553464b9fe30f596c151bcc67efe891bb913ba3 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -18,34 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.lite.python import convert -from tensorflow.contrib.lite.python import lite_constants -from tensorflow.contrib.lite.toco import model_flags_pb2 -from tensorflow.contrib.saved_model.python.saved_model import reader -from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils +from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import ops -from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import tag_constants - - -def _write_and_flush_file(file_path, data_str): - """Writes data to file path. - - Args: - file_path: Full path of the file to store data in. - data_str: Data represented as a string. - - Returns: None. - """ - with gfile.Open(file_path, "wb") as data_file: - data_file.write(data_str) - data_file.flush() def _log_tensor_details(tensor_info): @@ -77,21 +57,8 @@ def _get_meta_graph_def(saved_model_dir, tag_set): Raises: ValueError: No valid MetaGraphDef for given tag_set. """ - saved_model = reader.read_saved_model(saved_model_dir) - tag_sets = [] - result_meta_graph_def = None - for meta_graph_def in saved_model.meta_graphs: - meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags) - tag_sets.append(meta_graph_tag_set) - if meta_graph_tag_set == tag_set: - result_meta_graph_def = meta_graph_def - logging.info("The given saved_model contains the following tags: %s", - tag_sets) - if result_meta_graph_def is not None: - return result_meta_graph_def - else: - raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible " - "values are '{}'. ".format(tag_set, tag_sets)) + with session.Session(graph=ops.Graph()) as sess: + return loader.load(sess, tag_set, saved_model_dir) def _get_signature_def(meta_graph, signature_key): @@ -110,15 +77,13 @@ def _get_signature_def(meta_graph, signature_key): signature_def_map = meta_graph.signature_def signature_def_keys = set(signature_def_map.keys()) logging.info( - "The given saved_model MetaGraphDef contains SignatureDefs with the " + "The given SavedModel MetaGraphDef contains SignatureDefs with the " "following keys: %s", signature_def_keys) if signature_key not in signature_def_keys: - raise ValueError("No '{}' in the saved_model\'s SignatureDefs. Possible " - "values are '{}'. ".format(signature_key, - signature_def_keys)) - signature_def = signature_def_utils.get_signature_def_by_key( - meta_graph, signature_key) - return signature_def + raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible " + "values are '{}'.".format(signature_key, + ",".join(signature_def_keys))) + return signature_def_map[signature_key] def _get_inputs_outputs(signature_def): @@ -170,29 +135,10 @@ def _get_tensors(graph, signature_def_tensor_names=None, """ tensors = [] if user_tensor_names: - # Get the list of all of the tensors with and without the tensor index. - all_tensor_names = [ - tensor.name for op in graph.get_operations() for tensor in op.outputs - ] - all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names] - # Sort the tensor names. user_tensor_names = sorted(user_tensor_names) - # Get the tensors associated with the tensor names. - tensors = [] - invalid_tensors = [] - for name in user_tensor_names: - if name not in all_tensor_names_only: - invalid_tensors.append(name) - else: - idx = all_tensor_names_only.index(name) - tensors.append(graph.get_tensor_by_name(all_tensor_names[idx])) - - # Throw ValueError if any user input names are not valid tensors. - if invalid_tensors: - raise ValueError("Invalid tensors '{}' were found.".format( - ",".join(invalid_tensors))) + tensors = get_tensors_from_tensor_names(graph, user_tensor_names) elif signature_def_tensor_names: tensors = [ graph.get_tensor_by_name(name) @@ -207,25 +153,74 @@ def _get_tensors(graph, signature_def_tensor_names=None, return tensors -def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, - output_arrays, tag_set, signature_key, batch_size): +def get_tensors_from_tensor_names(graph, tensor_names): + """Gets the Tensors associated with the `tensor_names` in the provided graph. + + Args: + graph: TensorFlow Graph. + tensor_names: List of strings that represent names of tensors in the graph. + + Returns: + A list of Tensor objects in the same order the names are provided. + + Raises: + ValueError: + tensor_names contains an invalid tensor name. + """ + # Get the list of all of the tensors. + tensor_name_to_tensor = { + tensor_name(tensor): tensor for op in graph.get_operations() + for tensor in op.values() + } + + # Get the tensors associated with tensor_names. + tensors = [] + invalid_tensors = [] + for name in tensor_names: + tensor = tensor_name_to_tensor.get(name) + if tensor is None: + invalid_tensors.append(name) + else: + tensors.append(tensor) + + # Throw ValueError if any user input names are not valid tensors. + if invalid_tensors: + raise ValueError("Invalid tensors '{}' were found.".format( + ",".join(invalid_tensors))) + return tensors + + +def set_tensor_shapes(tensors, shapes): + """Sets Tensor shape for each tensor if the shape is defined. + + Args: + tensors: TensorFlow ops.Tensor. + shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). + """ + if shapes: + for tensor in tensors: + shape = shapes.get(tensor_name(tensor)) + if shape is not None: + tensor.set_shape(shape) + + +def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, + output_arrays, tag_set, signature_key): """Converts a SavedModel to a frozen graph. Args: saved_model_dir: SavedModel directory to convert. input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of + from SignatureDef when none are provided. + input_shapes: Dict of strings representing input tensor names to list of integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) + arrays from SignatureDef when none are provided. tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") + analyze. All tags in the tag set must be present. signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) Returns: frozen_graph_def: Frozen GraphDef. @@ -236,186 +231,32 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, ValueError: SavedModel doesn't contain a MetaGraphDef identified by tag_set. signature_key is not in the MetaGraphDef. + assets/ directory is in the MetaGraphDef. input_shapes does not match the length of input_arrays. - input_shapes has a None value after the 1st dimension. input_arrays or output_arrays are not valid. - Unable to load Session. """ - # Set default values for inputs if they are set to None. - if signature_key is None: - signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if tag_set is None: - tag_set = set([tag_constants.SERVING]) - if batch_size is None: - batch_size = 1 - # Read SignatureDef. meta_graph = _get_meta_graph_def(saved_model_dir, tag_set) signature_def = _get_signature_def(meta_graph, signature_key) inputs, outputs = _get_inputs_outputs(signature_def) + # Check SavedModel for assets directory. + collection_def = meta_graph.collection_def + if constants.ASSETS_KEY in collection_def: + raise ValueError("SavedModels with assets/ directory are not supported.") + graph = ops.Graph() with session.Session(graph=graph) as sess: - # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory. loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir) # Gets input and output tensors. # TODO(zhixianyan): Use TFLite supported Op list to filter outputs. in_tensors = _get_tensors(graph, inputs, input_arrays) out_tensors = _get_tensors(graph, outputs, output_arrays) - - # Gets fully defined tensor shape. An input tensor with None in the first - # dimension, e.g. (None, 224, 224, 3), is replaced with the batch_size. - # Shapes with None after the first dimension result in a ValueError. - # TODO(zhixianyan): Add supports for input tensor with more None in shape. - for tensor in in_tensors: - if (input_shapes and tensor.name in input_shapes and - input_shapes[tensor.name] is not None): - shape = input_shapes[tensor.name] - else: - shape = tensor.get_shape().as_list() - - if None in shape[1:]: - raise ValueError( - "None is only supported in the 1st dimension. Tensor '{0}' has " - "invalid shape '{1}'.".format(tensor.name, shape)) - elif shape[0] is None: - shape[0] = batch_size - tensor.set_shape(shape) + set_tensor_shapes(in_tensors, input_shapes) output_names = [node.split(":")[0] for node in outputs] frozen_graph_def = tf_graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), output_names) return frozen_graph_def, in_tensors, out_tensors - raise ValueError("Unable to load Session.") - - -def saved_model_to_frozen_graphdef( - saved_model_dir, - output_file_model, - output_file_flags, - input_arrays=None, - input_shapes=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - batch_size=1): - """Converts a SavedModel to a frozen graph. Writes graph to tmp directory. - - Stores frozen graph and command line flags in the tmp directory. - - Args: - saved_model_dir: SavedModel directory to convert. - output_file_model: Full file path to save frozen graph. - output_file_flags: Full file path to save ModelFlags. - input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of - integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). - Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) - output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) - tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") - signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) - - Returns: None. - - Raises: - ValueError: Unable to convert to frozen graph. - """ - frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model( - saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, - signature_key, batch_size) - - # Initialize model flags. - model = model_flags_pb2.ModelFlags() - - for input_tensor in in_tensors: - input_array = model.input_arrays.add() - input_array.name = convert.tensor_name(input_tensor) - input_array.shape.dims.extend(map(int, input_tensor.get_shape())) - - for output_tensor in out_tensors: - model.output_arrays.append(convert.tensor_name(output_tensor)) - - # Write model and ModelFlags to file. ModelFlags contain input array and - # output array information that is parsed from the SignatureDef and used for - # analysis by TOCO. - _write_and_flush_file(output_file_model, frozen_graph_def.SerializeToString()) - _write_and_flush_file(output_file_flags, model.SerializeToString()) - - -def tflite_from_saved_model( - saved_model_dir, - output_file=None, - input_arrays=None, - input_shapes=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - batch_size=1, - inference_type=lite_constants.FLOAT, - input_format=lite_constants.TENSORFLOW_GRAPHDEF, - output_format=lite_constants.TFLITE, - quantized_input_stats=None, - drop_control_dependency=True): - """Converts a SavedModel to TFLite FlatBuffer. - - Args: - saved_model_dir: SavedModel directory to convert. - output_file: File path to write result TFLite FlatBuffer. - input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of - integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). - Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) - output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) - tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") - signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. - - Returns: - The converted data. For example if tflite was the destination, then - this will be a tflite flatbuffer in a bytes array. - - Raises: - ValueError: Unable to convert to frozen graph. - """ - frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model( - saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, - signature_key, batch_size) - - result = convert.toco_convert( - input_data=frozen_graph_def, - input_tensors=in_tensors, - output_tensors=out_tensors, - inference_type=inference_type, - input_format=input_format, - output_format=output_format, - quantized_input_stats=quantized_input_stats, - drop_control_dependency=drop_control_dependency) - - if output_file is not None: - with gfile.Open(output_file, "wb") as f: - f.write(result) - logging.info("Successfully converted to: %s", output_file) - - return result diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py index db95fc8ad7f94b52d33c72f6ec5819bdfe8cf05f..92c4ebb2465c2abaa1cefd020e69b2f7ad6a54a5 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -25,12 +25,12 @@ from __future__ import print_function import os from tensorflow.contrib.lite.python import convert_saved_model -from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.estimator import estimator_lib as estimator from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.layers import layers from tensorflow.python.ops import array_ops @@ -38,13 +38,68 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.losses import losses -from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import training as train -class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): +class TensorFunctionsTest(test_util.TensorFlowTestCase): + + def testGetTensorsValid(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + tensors = convert_saved_model.get_tensors_from_tensor_names( + sess.graph, ["Placeholder"]) + self.assertEqual("Placeholder:0", tensors[0].name) + + def testGetTensorsInvalid(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + with self.assertRaises(ValueError) as error: + convert_saved_model.get_tensors_from_tensor_names(sess.graph, + ["invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + def testSetTensorShapeValid(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]}) + self.assertEqual([5, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeNoneValid(self): + tensor = array_ops.placeholder(dtype=dtypes.float32) + self.assertEqual(None, tensor.shape) + + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]}) + self.assertEqual([1, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeInvalid(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], + {"invalid-input": [5, 3, 5]}) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeEmpty(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], {}) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + +class FreezeSavedModelTest(test_util.TensorFlowTestCase): def _createSimpleSavedModel(self, shape): """Create a simple SavedModel on the fly.""" @@ -57,82 +112,167 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): saved_model.simple_save(sess, saved_model_dir, inputs, outputs) return saved_model_dir + def _createSavedModelTwoInputArrays(self, shape): + """Create a simple SavedModel.""" + saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") + with session.Session() as sess: + in_tensor_1 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name="inputB") + in_tensor_2 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name="inputA") + out_tensor = in_tensor_1 + in_tensor_2 + inputs = {"x": in_tensor_1, "y": in_tensor_2} + outputs = {"z": out_tensor} + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return saved_model_dir + + def _getArrayNames(self, tensors): + return [tensor.name for tensor in tensors] + + def _getArrayShapes(self, tensors): + dims = [] + for tensor in tensors: + dim_tensor = [] + for dim in tensor.shape: + if isinstance(dim, tensor_shape.Dimension): + dim_tensor.append(dim.value) + else: + dim_tensor.append(dim) + dims.append(dim_tensor) + return dims + + def _convertSavedModel(self, + saved_model_dir, + input_arrays=None, + input_shapes=None, + output_arrays=None, + tag_set=None, + signature_key=None): + if tag_set is None: + tag_set = set([tag_constants.SERVING]) + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model( + saved_model_dir=saved_model_dir, + input_arrays=input_arrays, + input_shapes=input_shapes, + output_arrays=output_arrays, + tag_set=tag_set, + signature_key=signature_key) + return graph_def, in_tensors, out_tensors + def testSimpleSavedModel(self): - """Test a simple SavedModel created on the fly.""" - # Create a simple SavedModel + """Test a SavedModel.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Convert to tflite - result = convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir) - self.assertTrue(result) + _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) def testSimpleSavedModelWithNoneBatchSizeInShape(self): - """Test a simple SavedModel, with None in input tensor's shape.""" + """Test a SavedModel with None in input tensor's shape.""" saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) - result = convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir) - self.assertTrue(result) + _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir) - def testSimpleSavedModelWithMoreNoneInShape(self): - """Test a simple SavedModel, fail as more None in input shape.""" - saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, None, 3]) - # Convert to tflite: this should raise ValueError, as 3rd dim is None. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir) + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]]) - def testSimpleSavedModelWithWrongSignatureKey(self): - """Test a simple SavedModel, fail as given signature is invalid.""" + def testSimpleSavedModelWithInvalidSignatureKey(self): + """Test a SavedModel that fails due to an invalid signature_key.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Convert to tflite: this should raise ValueError, as - # signature_key does not exit in the saved_model. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, signature_key="wrong-key") - - def testSimpleSavedModelWithWrongOutputArray(self): - """Test a simple SavedModel, fail as given output_arrays is invalid.""" - # Create a simple SavedModel + with self.assertRaises(ValueError) as error: + self._convertSavedModel(saved_model_dir, signature_key="invalid-key") + self.assertEqual( + "No 'invalid-key' in the SavedModel's SignatureDefs. " + "Possible values are 'serving_default'.", str(error.exception)) + + def testSimpleSavedModelWithInvalidOutputArray(self): + """Test a SavedModel that fails due to invalid output arrays.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Convert to tflite: this should raise ValueError, as - # output_arrays is not valid for the saved_model. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, output_arrays=["wrong-output"]) + with self.assertRaises(ValueError) as error: + self._convertSavedModel(saved_model_dir, output_arrays=["invalid-output"]) + self.assertEqual("Invalid tensors 'invalid-output' were found.", + str(error.exception)) def testSimpleSavedModelWithWrongInputArrays(self): - """Test a simple SavedModel, fail as given input_arrays is invalid.""" + """Test a SavedModel that fails due to invalid input arrays.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Checks invalid input_arrays. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, input_arrays=["wrong-input"]) - # Checks valid and invalid input_arrays. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, - input_arrays=["Placeholder", "wrong-input"]) + + # Check invalid input_arrays. + with self.assertRaises(ValueError) as error: + self._convertSavedModel(saved_model_dir, input_arrays=["invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + # Check valid and invalid input_arrays. + with self.assertRaises(ValueError) as error: + self._convertSavedModel( + saved_model_dir, input_arrays=["Placeholder", "invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) def testSimpleSavedModelWithCorrectArrays(self): - """Test a simple SavedModel, with correct input_arrays and output_arrays.""" + """Test a SavedModel with correct input_arrays and output_arrays.""" saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) - result = convert_saved_model.tflite_from_saved_model( + _, in_tensors, out_tensors = self._convertSavedModel( saved_model_dir=saved_model_dir, input_arrays=["Placeholder"], output_arrays=["add"]) - self.assertTrue(result) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]]) def testSimpleSavedModelWithCorrectInputArrays(self): - """Test a simple SavedModel, with correct input_arrays and input_shapes.""" + """Test a SavedModel with correct input_arrays and input_shapes.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - result = convert_saved_model.tflite_from_saved_model( + _, in_tensors, out_tensors = self._convertSavedModel( saved_model_dir=saved_model_dir, input_arrays=["Placeholder"], input_shapes={"Placeholder": [1, 16, 16, 3]}) - self.assertTrue(result) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) + + def testTwoInputArrays(self): + """Test a simple SavedModel.""" + saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3]) + + _, in_tensors, out_tensors = self._convertSavedModel( + saved_model_dir=saved_model_dir, input_arrays=["inputB", "inputA"]) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0", "inputB:0"]) + self.assertEqual( + self._getArrayShapes(in_tensors), [[1, 16, 16, 3], [1, 16, 16, 3]]) + + def testSubsetInputArrays(self): + """Test a SavedModel with a subset of the input array names of the model.""" + saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3]) + + # Check case where input shape is given. + _, in_tensors, out_tensors = self._convertSavedModel( + saved_model_dir=saved_model_dir, + input_arrays=["inputA"], + input_shapes={"inputA": [1, 16, 16, 3]}) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) + + # Check case where input shape is None. + _, in_tensors, out_tensors = self._convertSavedModel( + saved_model_dir=saved_model_dir, input_arrays=["inputA"]) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) def testMultipleMetaGraphDef(self): - """Test saved model with multiple MetaGraphDef.""" + """Test saved model with multiple MetaGraphDefs.""" saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd") builder = saved_model.builder.SavedModelBuilder(saved_model_dir) with session.Session(graph=ops.Graph()) as sess: @@ -161,91 +301,13 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): builder.save(True) # Convert to tflite - convert_saved_model.tflite_from_saved_model( + _, in_tensors, out_tensors = self._convertSavedModel( saved_model_dir=saved_model_dir, tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"])) - -class ConvertSavedModelTestBasicGraphToText(test_util.TensorFlowTestCase): - - def _createSimpleSavedModel(self, shape): - """Create a simple SavedModel.""" - saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") - with session.Session() as sess: - in_tensor_1 = array_ops.placeholder( - shape=shape, dtype=dtypes.float32, name="inputB") - in_tensor_2 = array_ops.placeholder( - shape=shape, dtype=dtypes.float32, name="inputA") - out_tensor = in_tensor_1 + in_tensor_2 - inputs = {"x": in_tensor_1, "y": in_tensor_2} - outputs = {"z": out_tensor} - saved_model.simple_save(sess, saved_model_dir, inputs, outputs) - return saved_model_dir - - def _getInputArrayNames(self, model_proto): - return [data.name for data in model_proto.input_arrays] - - def _getInputArrayShapes(self, model_proto): - return [ - [dim for dim in data.shape.dims] for data in model_proto.input_arrays - ] - - def _get_model_flags_proto_from_file(self, filename): - proto = _model_flags_pb2.ModelFlags() - with gfile.Open(filename, "rb") as output_file: - proto.ParseFromString(output_file.read()) - output_file.close() - return proto - - def testSimpleSavedModel(self): - """Test a simple SavedModel.""" - saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - output_file_model = os.path.join(self.get_temp_dir(), "model.pb") - output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt") - - convert_saved_model.saved_model_to_frozen_graphdef( - saved_model_dir=saved_model_dir, - output_file_model=output_file_model, - output_file_flags=output_file_flags, - input_arrays=["inputB", "inputA"]) - - proto = self._get_model_flags_proto_from_file(output_file_flags) - self.assertEqual(proto.output_arrays, ["add"]) - self.assertEqual(self._getInputArrayNames(proto), ["inputA", "inputB"]) - self.assertEqual( - self._getInputArrayShapes(proto), [[1, 16, 16, 3], [1, 16, 16, 3]]) - - def testSimpleSavedModelWithDifferentInputNames(self): - """Test a simple SavedModel.""" - saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - output_file_model = os.path.join(self.get_temp_dir(), "model.pb") - output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt") - - # Check case where input shape is given. - convert_saved_model.saved_model_to_frozen_graphdef( - saved_model_dir=saved_model_dir, - output_file_model=output_file_model, - output_file_flags=output_file_flags, - input_arrays=["inputA"], - input_shapes={"inputA": [1, 16, 16, 3]}) - - proto = self._get_model_flags_proto_from_file(output_file_flags) - self.assertEqual(proto.output_arrays, ["add"]) - self.assertEqual(self._getInputArrayNames(proto), ["inputA"]) - self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]]) - - # Check case where input shape is None. - convert_saved_model.saved_model_to_frozen_graphdef( - saved_model_dir=saved_model_dir, - output_file_model=output_file_model, - output_file_flags=output_file_flags, - input_arrays=["inputA"], - input_shapes={"inputA": None}) - - proto = self._get_model_flags_proto_from_file(output_file_flags) - self.assertEqual(proto.output_arrays, ["add"]) - self.assertEqual(self._getInputArrayNames(proto), ["inputA"]) - self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]]) + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 28, 28]]) class Model(keras.Model): @@ -354,7 +416,7 @@ def dummy_input_fn(): return image, labels -class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase): +class FreezeSavedModelTestTrainGraph(test_util.TensorFlowTestCase): def testTrainedMnistSavedModel(self): """Test mnist SavedModel, trained with dummy data and small steps.""" @@ -379,13 +441,16 @@ class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase): # Convert to tflite and test output saved_model_name = os.listdir(saved_model_dir)[0] saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name) - output_file = os.path.join(saved_model_dir, saved_model_final_dir + ".lite") + # TODO(zhixianyan): no need to limit output_arrays to `Softmax' # once b/74205001 fixed and argmax implemented in tflite. - result = convert_saved_model.tflite_from_saved_model( + result = convert_saved_model.freeze_saved_model( saved_model_dir=saved_model_final_dir, + input_arrays=None, + input_shapes=None, output_arrays=["Softmax"], - output_file=output_file) + tag_set=set([tag_constants.SERVING]), + signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) self.assertTrue(result) diff --git a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py b/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py deleted file mode 100644 index 4d9782f4a6a9e853c3afdbd97d4264a818937e63..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Python console command for generating frozen models from SavedModels. - -This exists to add SavedModel compatibility to TOCO. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import sys -from tensorflow.contrib.lite.python.convert_saved_model import saved_model_to_frozen_graphdef -from tensorflow.python.platform import app - -FLAGS = None - - -def execute(unused_args): - """Calls function to convert the SavedModel to a frozen graph.""" - # Error handling. - if FLAGS.input_shapes and not FLAGS.input_arrays: - raise ValueError("Input shapes requires input arrays to be specified.") - - # Calls saved_model_to_frozen_graphdef function to generate frozen graph. - input_arrays = (FLAGS.input_arrays.split(",") if FLAGS.input_arrays else None) - input_shapes = None - if FLAGS.input_shapes: - input_shapes = { - input_arrays[idx]: shape.split(",") - for idx, shape in enumerate(FLAGS.input_shapes.split(":")) - } - output_arrays = ( - FLAGS.output_arrays.split(",") if FLAGS.output_arrays else None) - tag_set = set(FLAGS.tag_set.split(",")) if FLAGS.tag_set else None - - saved_model_to_frozen_graphdef( - saved_model_dir=FLAGS.saved_model_directory, - output_file_model=FLAGS.output_file_model, - output_file_flags=FLAGS.output_file_flags, - input_arrays=input_arrays, - input_shapes=input_shapes, - output_arrays=output_arrays, - tag_set=tag_set, - signature_key=FLAGS.signature_key, - batch_size=FLAGS.batch_size) - - -def main(): - global FLAGS - # Parses flags. - parser = argparse.ArgumentParser( - description="Invoke SavedModel to frozen model converter.") - parser.add_argument( - "saved_model_directory", - type=str, - help="Full path to directory containing the SavedModel.") - parser.add_argument( - "output_file_model", - type=str, - help="Full file path to save frozen graph.") - parser.add_argument( - "output_file_flags", type=str, help="Full file path to save ModelFlags.") - parser.add_argument( - "--input_arrays", - type=str, - help="Name of the input arrays, comma-separated.") - parser.add_argument( - "--input_shapes", - type=str, - help="Shapes corresponding to --input_arrays, colon-separated.") - parser.add_argument( - "--output_arrays", - type=str, - help="Name of the output arrays, comma-separated.") - parser.add_argument( - "--tag_set", type=str, help="Name of output arrays, comma-separated.") - parser.add_argument( - "--signature_key", - type=str, - help="Key identifying SignatureDef containing inputs and outputs.") - parser.add_argument( - "--batch_size", - type=int, - help="Batch size for the model. Replaces the first dimension of an " - "input size array if undefined.") - - FLAGS, unparsed = parser.parse_known_args() - - app.run(main=execute, argv=[sys.argv[0]] + unparsed) - - -if __name__ == "__main__": - main() diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index 5fbc55145217dd8a4e3eec4108e18d1c2be5c883..fd908234254185e0a0639618e936ca8ff58631da 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys from tensorflow.python.util.lazy_loader import LazyLoader # Lazy load since some of the performance benchmark skylark rules @@ -54,7 +55,7 @@ class Interpreter(object): elif model_content and not model_path: self._interpreter = ( _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( - model_content, len(model_content))) + model_content)) if not self._interpreter: raise ValueError( 'Failed to create model from {} bytes'.format(len(model_content))) @@ -64,9 +65,38 @@ class Interpreter(object): raise ValueError('Can\'t both provide `model_path` and `model_content`') def allocate_tensors(self): + self._ensure_safe() if not self._interpreter.AllocateTensors(): raise ValueError('Failed to allocate tensors') + def _safe_to_run(self): + """Returns true if there exist no numpy array buffers. + + This means it is safe to run tflite calls that may destroy internally + allocated memory. This works, because in the wrapper.cc we have made + the numpy base be the self._interpreter. + """ + # NOTE, our tensor() call in cpp will use _interpreter as a base pointer. + # If this environment is the only _interpreter, then the ref count should be + # 2 (1 in self and 1 in temporary of sys.getrefcount). + return sys.getrefcount(self._interpreter) == 2 + + def _ensure_safe(self): + """Makes sure no numpy arrays pointing to internal buffers are active. + + This should be called from any function that will call a function on + _interpreter that may reallocate memory e.g. invoke(), ... + + Raises: + RuntimeError: If there exist numpy objects pointing to internal memory + then we throw. + """ + if not self._safe_to_run(): + raise RuntimeError("""There is at least 1 reference to internal data + in the interpreter in the form of a numpy array or slice. Be sure to + only hold the function returned from tensor() if you are using raw + data access.""") + def _get_tensor_details(self, tensor_index): """Gets tensor details. @@ -109,7 +139,10 @@ class Interpreter(object): ] def set_tensor(self, tensor_index, value): - """Sets the value of the input. + """Sets the value of the input tensor. Note this copies data in `value`. + + If you want to avoid copying, you can use the `tensor()` function to get a + numpy buffer pointing to the input buffer in the tflite interpreter. Args: tensor_index: Tensor index of tensor to set. This value can be gotten from @@ -133,6 +166,7 @@ class Interpreter(object): Raises: ValueError: If the interpreter could not resize the input tensor. """ + self._ensure_safe() if not self._interpreter.ResizeInputTensor(input_index, tensor_size): raise ValueError('Failed to resize input') @@ -147,7 +181,7 @@ class Interpreter(object): ] def get_tensor(self, tensor_index): - """Sets the value of the input. + """Gets the value of the input tensor. Note this makes a copy so prefer `tensor()`. Args: tensor_index: Tensor index of tensor to get. This value can be gotten from @@ -158,6 +192,60 @@ class Interpreter(object): """ return self._interpreter.GetTensor(tensor_index) + def tensor(self, tensor_index): + """Returns function that gives a numpy view of the current tensor buffer. + + This allows reading and writing to this tensors w/o copies. This more + closely mirrors the C++ Interpreter class interface's tensor() member, hence + the name. Be careful to not hold these output references through calls + to `allocate_tensors()` and `invoke()`. + + Usage: + + interpreter.allocate_tensors() + input = interpreter.tensor(interpreter.get_input_details()[0]["index"]) + output = interpreter.tensor(interpreter.get_output_details()[0]["index"]) + for i in range(10): + input().fill(3.) + interpreter.invoke() + print("inference %s" % output) + + Notice how this function avoids making a numpy array directly. This is + because it is important to not hold actual numpy views to the data longer + than necessary. If you do, then the interpreter can no longer be invoked, + because it is possible the interpreter would resize and invalidate the + referenced tensors. The NumPy API doesn't allow any mutability of the + the underlying buffers. + + WRONG: + + input = interpreter.tensor(interpreter.get_input_details()[0]["index"])() + output = interpreter.tensor(interpreter.get_output_details()[0]["index"])() + interpreter.allocate_tensors() # This will throw RuntimeError + for i in range(10): + input.fill(3.) + interpreter.invoke() # this will throw RuntimeError since input,output + + Args: + tensor_index: Tensor index of tensor to get. This value can be gotten from + the 'index' field in get_output_details. + + Returns: + A function that can return a new numpy array pointing to the internal + TFLite tensor state at any point. It is safe to hold the function forever, + but it is not safe to hold the numpy array forever. + """ + return lambda: self._interpreter.tensor(self._interpreter, tensor_index) + def invoke(self): + """Invoke the interpreter. + + Be sure to set the input sizes, allocate tensors and fill values before + calling this. + + Raises: + ValueError: When the underlying interpreter fails raise ValueError. + """ + self._ensure_safe() if not self._interpreter.Invoke(): raise ValueError('Failed to invoke TFLite model') diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py index f802edf020db8a9d4e7bb890aadaae7e34e983a8..5f1fa26c3b7f76309a6f1f80aa3c1e4889781528 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -91,5 +91,61 @@ class InterpreterTest(test_util.TensorFlowTestCase): self.assertTrue((expected_output == output_data).all()) +class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase): + + def setUp(self): + self.interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite')) + self.interpreter.allocate_tensors() + self.input0 = self.interpreter.get_input_details()[0]['index'] + self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32) + + def testTensorAccessor(self): + """Check that tensor returns a reference.""" + array_ref = self.interpreter.tensor(self.input0) + np.copyto(array_ref(), self.initial_data) + self.assertAllEqual(array_ref(), self.initial_data) + self.assertAllEqual( + self.interpreter.get_tensor(self.input0), self.initial_data) + + def testGetTensorAccessor(self): + """Check that get_tensor returns a copy.""" + self.interpreter.set_tensor(self.input0, self.initial_data) + array_initial_copy = self.interpreter.get_tensor(self.input0) + new_value = np.add(1., array_initial_copy) + self.interpreter.set_tensor(self.input0, new_value) + self.assertAllEqual(array_initial_copy, self.initial_data) + self.assertAllEqual(self.interpreter.get_tensor(self.input0), new_value) + + def testBase(self): + self.assertTrue(self.interpreter._safe_to_run()) + _ = self.interpreter.tensor(self.input0) + self.assertTrue(self.interpreter._safe_to_run()) + in0 = self.interpreter.tensor(self.input0)() + self.assertFalse(self.interpreter._safe_to_run()) + in0b = self.interpreter.tensor(self.input0)() + self.assertFalse(self.interpreter._safe_to_run()) + # Now get rid of the buffers so that we can evaluate. + del in0 + del in0b + self.assertTrue(self.interpreter._safe_to_run()) + + def testBaseProtectsFunctions(self): + in0 = self.interpreter.tensor(self.input0)() + # Make sure we get an exception if we try to run an unsafe operation + with self.assertRaisesRegexp( + RuntimeError, 'There is at least 1 reference'): + _ = self.interpreter.allocate_tensors() + # Make sure we get an exception if we try to run an unsafe operation + with self.assertRaisesRegexp( + RuntimeError, 'There is at least 1 reference'): + _ = self.interpreter.invoke() + # Now test that we can run + del in0 # this is our only buffer reference, so now it is safe to change + in0safe = self.interpreter.tensor(self.input0) + _ = self.interpreter.allocate_tensors() + del in0safe # make sure in0Safe is held but lint doesn't complain + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD index 453eda6e7345762666917fd501b69c7181c349e8..634c2a1e1f5005208b4eea5c853a43cccf4d244c 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD @@ -14,8 +14,8 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/core:lib", - "//tensorflow/python:numpy_lib", - "//util/python:python_headers", + "//third_party/py/numpy:headers", + "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", ], ) @@ -27,6 +27,6 @@ tf_py_wrap_cc( ], deps = [ ":interpreter_wrapper_lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 16f4f30b94313453c3a4f0496a6ac5649847f076..5554d08fa08fdc6ddcb042d12f979164a144e337 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -21,7 +21,14 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/python/lib/core/numpy.h" + +// Disallow Numpy 1.7 deprecated symbols. +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include + +#include "numpy/arrayobject.h" +#include "numpy/ufuncobject.h" #if PY_MAJOR_VERSION >= 3 #define PY_TO_CPPSTRING PyBytes_AsStringAndSize @@ -35,6 +42,13 @@ namespace tflite { namespace interpreter_wrapper { namespace { + +// Calls PyArray's initialization to initialize all the API pointers. Note that +// this usage implies only this translation unit can use the pointers. See +// tensorflow/python/core/numpy.cc for a strategy if we ever need to extend +// this further. +void ImportNumpy() { import_array1(); } + std::unique_ptr CreateInterpreter( const tflite::FlatBufferModel* model, const tflite::ops::builtin::BuiltinOpResolver& resolver) { @@ -42,6 +56,8 @@ std::unique_ptr CreateInterpreter( return nullptr; } + ImportNumpy(); + std::unique_ptr interpreter; tflite::InterpreterBuilder(*model, resolver)(&interpreter); if (interpreter) { @@ -66,6 +82,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_FLOAT32; case kTfLiteInt32: return NPY_INT32; + case kTfLiteInt16: + return NPY_INT16; case kTfLiteUInt8: return NPY_UINT8; case kTfLiteInt64: @@ -74,6 +92,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_OBJECT; case kTfLiteBool: return NPY_BOOL; + case kTfLiteComplex64: + return NPY_COMPLEX64; case kTfLiteNoType: return -1; } @@ -88,6 +108,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { return kTfLiteFloat32; case NPY_INT32: return kTfLiteInt32; + case NPY_INT16: + return kTfLiteInt16; case NPY_UINT8: return kTfLiteUInt8; case NPY_INT64: @@ -98,6 +120,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { case NPY_STRING: case NPY_UNICODE: return kTfLiteString; + case NPY_COMPLEX64: + return kTfLiteComplex64; } LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type; return kTfLiteNoType; @@ -282,47 +306,93 @@ bool InterpreterWrapper::SetTensor(int i, PyObject* value) { return true; } -PyObject* InterpreterWrapper::GetTensor(int i) const { - if (!interpreter_) { +namespace { + +PyObject* CheckGetTensorArgs(Interpreter* interpreter, int tensor_index, + TfLiteTensor** tensor, int* type_num) { + if (!interpreter) { LOG(ERROR) << "Invalid interpreter."; Py_INCREF(Py_None); return Py_None; } - if (i >= interpreter_->tensors_size()) { - LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index " - << interpreter_->inputs().size(); + if (tensor_index >= interpreter->tensors_size() || tensor_index < 0) { + LOG(ERROR) << "Invalid tensor index: " << tensor_index + << " exceeds max tensor index " << interpreter->inputs().size(); Py_INCREF(Py_None); return Py_None; } - const TfLiteTensor* output_tensor = interpreter_->tensor(i); - const int tensor_size = output_tensor->bytes; - if (tensor_size <= 0) { + *tensor = interpreter->tensor(tensor_index); + if ((*tensor)->bytes == 0) { LOG(ERROR) << "Invalid tensor size"; Py_INCREF(Py_None); return Py_None; } - int type_num = TfLiteTypeToPyArrayType(output_tensor->type); - if (type_num == -1) { - LOG(ERROR) << "Unknown tensor type " << output_tensor->type; + *type_num = TfLiteTypeToPyArrayType((*tensor)->type); + if (*type_num == -1) { + LOG(ERROR) << "Unknown tensor type " << (*tensor)->type; Py_INCREF(Py_None); return Py_None; } - void* data = malloc(tensor_size); - memcpy(data, output_tensor->data.raw, tensor_size); + if (!(*tensor)->data.raw) { + LOG(ERROR) << "Tensor data is null."; + Py_INCREF(Py_None); + return Py_None; + } - const TfLiteIntArray* output_dims = output_tensor->dims; - std::vector dims(output_dims->data, - output_dims->data + output_dims->size); + return nullptr; +} + +} // namespace + +PyObject* InterpreterWrapper::GetTensor(int i) const { + // Sanity check accessor + TfLiteTensor* tensor = nullptr; + int type_num = 0; + if (PyObject* pynone_or_nullptr = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { + return pynone_or_nullptr; + } + std::vector dims(tensor->dims->data, + tensor->dims->data + tensor->dims->size); + // Make a buffer copy but we must tell Numpy It owns that data or else + // it will leak. + void* data = malloc(tensor->bytes); + if (!data) { + LOG(ERROR) << "Malloc to copy tensor failed."; + Py_INCREF(Py_None); + return Py_None; + } + memcpy(data, tensor->data.raw, tensor->bytes); PyObject* np_array = PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); - + PyArray_ENABLEFLAGS(reinterpret_cast(np_array), + NPY_ARRAY_OWNDATA); return PyArray_Return(reinterpret_cast(np_array)); } +PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { + // Sanity check accessor + TfLiteTensor* tensor = nullptr; + int type_num = 0; + if (PyObject* pynone_or_nullptr = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { + return pynone_or_nullptr; + } + + std::vector dims(tensor->dims->data, + tensor->dims->data + tensor->dims->size); + PyArrayObject* np_array = + reinterpret_cast(PyArray_SimpleNewFromData( + dims.size(), dims.data(), type_num, tensor->data.raw)); + Py_INCREF(base_object); // SetBaseObject steals, so we need to add. + PyArray_SetBaseObject(np_array, base_object); + return PyArray_Return(np_array); +} + InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( const char* model_path) { std::unique_ptr model = @@ -331,9 +401,14 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( - const char* data, size_t len) { + PyObject* data) { + char * buf = nullptr; + Py_ssize_t length; + if (PY_TO_CPPSTRING(data, &buf, &length) == -1) { + return nullptr; + } std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(data, len); + tflite::FlatBufferModel::BuildFromBuffer(buf, length); return model ? new InterpreterWrapper(std::move(model)) : nullptr; } diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index 0972c572595f5044a305a81afaccbea5f131247c..681448be20cfc013a0c4d02a6aa549744b976077 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -19,7 +19,9 @@ limitations under the License. #include #include +// Place `` before to avoid build failures in macOS. #include +#include // We forward declare TFLite classes here to avoid exposing them to SWIG. namespace tflite { @@ -40,8 +42,7 @@ class InterpreterWrapper { static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path); // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromBuffer(const char* data, - size_t len); + static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data); ~InterpreterWrapper(); bool AllocateTensors(); @@ -57,6 +58,9 @@ class InterpreterWrapper { PyObject* TensorQuantization(int i) const; bool SetTensor(int i, PyObject* value); PyObject* GetTensor(int i) const; + // Returns a reference to tensor index i as a numpy array. The base_object + // should be the interpreter object providing the memory. + PyObject* tensor(PyObject* base_object, int i); private: InterpreterWrapper(std::unique_ptr model); diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 86b25e68acaf5d74e3dd11784446e7bda3d329ee..29a1487c1f468055dde85ef6c2657a50f3d2f32b 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -16,23 +16,423 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. +@@TocoConverter @@toco_convert @@toco_convert_protos -@@tflite_from_saved_model @@Interpreter @@OpHint @@convert_op_hints_to_stubs +@@build_toco_convert_protos + +@@FLOAT +@@QUANTIZED_UINT8 +@@TFLITE +@@GRAPHVIZ_DOT """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import +from six import PY3 + +from google.protobuf import text_format as _text_format +from google.protobuf.message import DecodeError +from tensorflow.contrib.lite.python import lite_constants as constants +from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import +from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.contrib.lite.python.convert import toco_convert -from tensorflow.contrib.lite.python.convert import toco_convert_protos -from tensorflow.contrib.lite.python.convert_saved_model import tflite_from_saved_model -from tensorflow.contrib.lite.python.interpreter import Interpreter -from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs -from tensorflow.contrib.lite.python.op_hint import OpHint -# pylint: enable=unused-import +from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import +from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model +from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names +from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes +from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import +from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import +from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import +from tensorflow.core.framework import graph_pb2 as _graph_pb2 +from tensorflow.python import keras as _keras +from tensorflow.python.client import session as _session +from tensorflow.python.framework import graph_util as tf_graph_util +from tensorflow.python.framework.importer import import_graph_def +from tensorflow.python.ops.variables import global_variables_initializer +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants +# from tensorflow.python.util.all_util import remove_undocumented + + +class TocoConverter(object): + """Convert a TensorFlow model into `output_format` using TOCO. + + This is used to convert from a TensorFlow GraphDef or SavedModel into either a + TFLite FlatBuffer or graph visualization. + + Attributes: + + inference_type: Target data type of real-number arrays in the output file. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of real-number input arrays. Allows + for a different type for input arrays in the case of quantization. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) + quantized_input_stats: Dict of strings representing input tensor names + mapped to tuple of integers representing the mean and standard deviation + of the training data (e.g., {"foo" : (0., 1.)}). Only need if + `inference_type` is `QUANTIZED_UINT8`. (default {}) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) + drop_control_dependency: Boolean indicating whether to drop control + dependencies silently. This is due to TFLite not supporting control + dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) + allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. + (default False) + quantize_weights: Boolean indicating whether to store weights as quantized + weights followed by dequantize operations. Computation is still done in + float, but reduces model size (at the cost of accuracy and latency). + (default False) + dump_graphviz_dir: Full filepath of folder to dump the graphs at various + stages of processing GraphViz .dot files. Preferred over + --output_format=GRAPHVIZ_DOT in order to keep the requirements of the + output file. (default None) + dump_graphviz_video: Boolean indicating whether to dump the graph after + every graph transformation. (default False) + + Example usage: + + # Converting a GraphDef from session. + converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) + + # Converting a GraphDef from file. + converter = lite.TocoConverter.from_frozen_graph( + graph_def_file, input_arrays, output_arrays) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) + + # Converting a SavedModel. + converter = lite.TocoConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + """ + + def __init__(self, graph_def, input_tensors, output_tensors): + """Constructor for TocoConverter. + + Args: + + graph_def: Frozen TensorFlow GraphDef. + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + """ + self._graph_def = graph_def + self._input_tensors = input_tensors + self._output_tensors = output_tensors + self.inference_type = constants.FLOAT + self.inference_input_type = None + self.output_format = constants.TFLITE + self.quantized_input_stats = {} + self.default_ranges_stats = None + self.drop_control_dependency = True + self.reorder_across_fake_quant = False + self.change_concat_input_ranges = False + self.allow_custom_ops = False + self.quantize_weights = False + self.dump_graphviz_dir = None + self.dump_graphviz_video = False + + @classmethod + def from_session(cls, sess, input_tensors, output_tensors): + """Creates a TocoConverter class from a TensorFlow Session. + + Args: + sess: TensorFlow Session. + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + + Returns: + TocoConverter class. + """ + graph_def = _freeze_graph(sess, output_tensors) + return cls(graph_def, input_tensors, output_tensors) + + @classmethod + def from_frozen_graph(cls, + graph_def_file, + input_arrays, + output_arrays, + input_shapes=None): + """Creates a TocoConverter class from a file containing a frozen GraphDef. + + Args: + graph_def_file: Full filepath of file containing frozen GraphDef. + input_arrays: List of input tensors to freeze graph with. + output_arrays: List of output tensors to freeze graph with. + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + + Returns: + TocoConverter class. + + Raises: + ValueError: + Unable to parse input file. + The graph is not frozen. + input_arrays or output_arrays contains an invalid tensor name. + """ + with _session.Session() as sess: + sess.run(global_variables_initializer()) + + # Read GraphDef from file. + graph_def = _graph_pb2.GraphDef() + with open(graph_def_file, "rb") as f: + file_content = f.read() + try: + graph_def.ParseFromString(file_content) + except (_text_format.ParseError, DecodeError): + try: + print("Ignore 'tcmalloc: large alloc' warnings.") + + if not isinstance(file_content, str): + if PY3: + file_content = file_content.decode('utf-8') + else: + file_content = file_content.encode('utf-8') + _text_format.Merge(file_content, graph_def) + except (_text_format.ParseError, DecodeError): + raise ValueError( + "Unable to parse input file '{}'.".format(graph_def_file)) + sess.graph.as_default() + import_graph_def(graph_def, name="") + + # Get input and output tensors. + input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) + output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) + set_tensor_shapes(input_tensors, input_shapes) + + # Check if graph is frozen. + if not _is_frozen_graph(sess): + raise ValueError("Please freeze the graph using freeze_graph.py.") + + # Create TocoConverter class. + return cls(sess.graph_def, input_tensors, output_tensors) + + @classmethod + def from_saved_model(cls, + saved_model_dir, + input_arrays=None, + input_shapes=None, + output_arrays=None, + tag_set=None, + signature_key=None): + """Creates a TocoConverter class from a SavedModel. + + Args: + saved_model_dir: SavedModel directory to convert. + input_arrays: List of input tensors to freeze graph with. Uses input + arrays from SignatureDef when none are provided. (default None) + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to + analyze. All tags in the tag set must be present. (default set("serve")) + signature_key: Key identifying SignatureDef containing inputs and outputs. + (default DEFAULT_SERVING_SIGNATURE_DEF_KEY) + + Returns: + TocoConverter class. + """ + if tag_set is None: + tag_set = set([tag_constants.SERVING]) + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + + result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes, + output_arrays, tag_set, signature_key) + return cls( + graph_def=result[0], input_tensors=result[1], output_tensors=result[2]) + + @classmethod + def from_keras_model_file(cls, + model_file, + input_arrays=None, + input_shapes=None, + output_arrays=None): + """Creates a TocoConverter class from a tf.keras model file. + + Args: + model_file: Full filepath of HDF5 file containing the tf.keras model. + input_arrays: List of input tensors to freeze graph with. Uses input + arrays from SignatureDef when none are provided. (default None) + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + + Returns: + TocoConverter class. + """ + _keras.backend.clear_session() + _keras.backend.set_learning_phase(False) + keras_model = _keras.models.load_model(model_file) + sess = _keras.backend.get_session() + + # Get input and output tensors. + if input_arrays: + input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) + else: + input_tensors = keras_model.inputs + + if output_arrays: + output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) + else: + output_tensors = keras_model.outputs + set_tensor_shapes(input_tensors, input_shapes) + + graph_def = _freeze_graph(sess, output_tensors) + return cls(graph_def, input_tensors, output_tensors) + + def convert(self): + """Converts a TensorFlow GraphDef based on instance variables. + + Returns: + The converted data in serialized format. Either a TFLite Flatbuffer or a + Graphviz graph depending on value in `output_format`. + + Raises: + ValueError: + Input shape is not specified. + None value for dimension in input_tensor. + """ + # Checks dimensions in input tensor. + for tensor in self._input_tensors: + if not tensor.get_shape(): + raise ValueError("Provide an input shape for input array '{0}'.".format( + tensor_name(tensor))) + shape = tensor.get_shape().as_list() + if None in shape[1:]: + raise ValueError( + "None is only supported in the 1st dimension. Tensor '{0}' has " + "invalid shape '{1}'.".format(tensor_name(tensor), shape)) + elif shape[0] is None: + self._set_batch_size(batch_size=1) + + # Get quantization stats. Ensures there is one stat per name if the stats + # are specified. + if self.quantized_input_stats: + quantized_stats = [] + invalid_stats = [] + for tensor in self._input_tensors: + name = tensor_name(tensor) + if name in self.quantized_input_stats: + quantized_stats.append(self.quantized_input_stats[name]) + else: + invalid_stats.append(name) + + if invalid_stats: + raise ValueError("Quantization input stats are not available for input " + "tensors '{0}'.".format(",".join(invalid_stats))) + else: + quantized_stats = None + + # Converts model. + result = toco_convert( + input_data=self._graph_def, + input_tensors=self._input_tensors, + output_tensors=self._output_tensors, + inference_type=self.inference_type, + inference_input_type=self.inference_input_type, + input_format=constants.TENSORFLOW_GRAPHDEF, + output_format=self.output_format, + quantized_input_stats=quantized_stats, + default_ranges_stats=self.default_ranges_stats, + drop_control_dependency=self.drop_control_dependency, + reorder_across_fake_quant=self.reorder_across_fake_quant, + change_concat_input_ranges=self.change_concat_input_ranges, + allow_custom_ops=self.allow_custom_ops, + quantize_weights=self.quantize_weights, + dump_graphviz_dir=self.dump_graphviz_dir, + dump_graphviz_video=self.dump_graphviz_video) + return result + + def get_input_arrays(self): + """Returns a list of the names of the input tensors. + + Returns: + List of strings. + """ + return [tensor_name(tensor) for tensor in self._input_tensors] + + def _set_batch_size(self, batch_size): + """Sets the first dimension of the input tensor to `batch_size`. + + Args: + batch_size: Batch size for the model. Replaces the first dimension of an + input size array if undefined. (default 1) + """ + for tensor in self._input_tensors: + shape = tensor.get_shape().as_list() + shape[0] = batch_size + tensor.set_shape(shape) + + +def _is_frozen_graph(sess): + """Determines if the graph is frozen. + + Determines if a graph has previously been frozen by checking for any + operations of type Variable*. If variables are found, the graph is not frozen. + + Args: + sess: TensorFlow Session. + + Returns: + Bool. + """ + for op in sess.graph.get_operations(): + if op.type.startswith("Variable") or op.type.endswith("VariableOp"): + return False + return True + + +def _freeze_graph(sess, output_tensors): + """Returns a frozen GraphDef. + + Freezes a graph with Variables in it. Otherwise the existing GraphDef is + returned. + + Args: + sess: TensorFlow Session. + output_tensors: List of output tensors (only .name is used from this). + + Returns: + Frozen GraphDef. + """ + if not _is_frozen_graph(sess): + sess.run(global_variables_initializer()) + output_arrays = [tensor_name(tensor) for tensor in output_tensors] + return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def, + output_arrays) + else: + return sess.graph_def + +# remove_undocumented(__name__) diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ca2af5aaed3ee4f4fce5f0d31eaa61df0e11f364 --- /dev/null +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -0,0 +1,898 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lite.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile +import numpy as np + +from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python import lite_constants +from tensorflow.contrib.lite.python.interpreter import Interpreter +from tensorflow.python import keras +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import saved_model +from tensorflow.python.training.training_util import write_graph + + +class FromSessionTest(test_util.TensorFlowTestCase): + + def testFloat(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testQuantization(self): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + out_tensor = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = { + 'inputA': (0., 1.), + 'inputB': (0., 1.) + } # mean, std_dev + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((1., 0.), + input_details[0]['quantization']) # scale, zero_point + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.uint8, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((1., 0.), + input_details[1]['quantization']) # scale, zero_point + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('output', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + + def testQuantizationInvalid(self): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + out_tensor = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual( + 'Quantization input stats are not available for input tensors ' + '\'inputB\'.', str(error.exception)) + + def testSizeNoneInvalid(self): + in_tensor = array_ops.placeholder(dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Test invalid shape. None after 1st dimension. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual('Provide an input shape for input array \'Placeholder\'.', + str(error.exception)) + + def testBatchSizeInvalid(self): + in_tensor = array_ops.placeholder( + shape=[1, None, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Test invalid shape. None after 1st dimension. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual( + 'None is only supported in the 1st dimension. Tensor ' + '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.', + str(error.exception)) + + def testBatchSizeValid(self): + in_tensor = array_ops.placeholder( + shape=[None, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFreezeGraph(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + var = variable_scope.get_variable( + 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + var + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + # TODO(nupurgarg): Verify value of contents in GraphViz. + def testGraphviz(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.output_format = lite_constants.GRAPHVIZ_DOT + graphviz_output = converter.convert() + self.assertTrue(graphviz_output) + + # TODO(nupurgarg): Verify value of contents in GraphViz. + def testDumpGraphviz(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure interpreter is able to allocate and check graphviz data. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + num_items_graphviz = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + converter.dump_graphviz_video = True + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure graphviz folder has more data after using video flag. + num_items_graphviz_video = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz_video > num_items_graphviz) + + def testInferenceInputType(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_input_type = lite_constants.QUANTIZED_UINT8 + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((1., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + + def testDefaultRangesStats(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev + converter.default_ranges_stats = (0, 6) # min, max + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((1., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + + def testQuantizeWeights(self): + np.random.seed(0) + # We need the tensor to have more than 1024 elements for quantize_weights + # to kick in. Thus, the [33, 33] shape. + in_tensor_1 = array_ops.placeholder( + shape=[33, 33], dtype=dtypes.float32, name='inputA') + in_tensor_2 = constant_op.constant( + np.random.uniform(low=-10., high=10., size=(33, 33)), + shape=[33, 33], + dtype=dtypes.float32, + name='inputB') + out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') + sess = session.Session() + + # Convert float model. + float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1], + [out_tensor]) + float_tflite = float_converter.convert() + self.assertTrue(float_tflite) + + # Convert quantized weights model. + quantized_weights_converter = lite.TocoConverter.from_session( + sess, [in_tensor_1], [out_tensor]) + quantized_weights_converter.quantize_weights = True + quantized_weights_tflite = quantized_weights_converter.convert() + self.assertTrue(quantized_weights_tflite) + + # Ensure that the quantized weights tflite model is smaller. + self.assertTrue(len(quantized_weights_tflite) < len(float_tflite)) + + +class FromFrozenGraphFile(test_util.TensorFlowTestCase): + + def testFloat(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph(graph_def_file, + ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFloatWithShapesArray(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph( + graph_def_file, ['Placeholder'], ['add'], + input_shapes={'Placeholder': [1, 16, 16, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + + def testFreezeGraph(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + var = variable_scope.get_variable( + 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + var + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Ensure the graph with variables cannot be converted. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual('Please freeze the graph using freeze_graph.py.', + str(error.exception)) + + def testPbtxt(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt') + write_graph(sess.graph_def, '', graph_def_file, True) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph(graph_def_file, + ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testInvalidFile(self): + graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file') + with gfile.Open(graph_def_file, 'wb') as temp_file: + temp_file.write('bad data') + temp_file.flush() + + # Attempts to convert the invalid model. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual( + 'Unable to parse input file \'{}\'.'.format(graph_def_file), + str(error.exception)) + + +class FromSavedModelTest(test_util.TensorFlowTestCase): + + def _createSavedModel(self, shape): + """Create a simple SavedModel.""" + saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel') + with session.Session() as sess: + in_tensor_1 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name='inputB') + in_tensor_2 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name='inputA') + out_tensor = in_tensor_1 + in_tensor_2 + inputs = {'x': in_tensor_1, 'y': in_tensor_2} + outputs = {'z': out_tensor} + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return saved_model_dir + + def testSimpleModel(self): + """Test a SavedModel.""" + saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testNoneBatchSize(self): + """Test a SavedModel, with None in input tensor's shape.""" + saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3]) + + converter = lite.TocoConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testOrderInputArrays(self): + """Test a SavedModel ordering of input arrays.""" + saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) + + converter = lite.TocoConverter.from_saved_model( + saved_model_dir, input_arrays=['inputB', 'inputA']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testSubsetInputArrays(self): + """Test a SavedModel with a subset of the input array names of the model.""" + saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) + + # Check case where input shape is given. + converter = lite.TocoConverter.from_saved_model( + saved_model_dir, + input_arrays=['inputA'], + input_shapes={'inputA': [1, 16, 16, 3]}) + + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check case where input shape is None. + converter = lite.TocoConverter.from_saved_model( + saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None}) + + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + +class FromKerasFile(test_util.TensorFlowTestCase): + + def setUp(self): + keras.backend.clear_session() + + def _getSequentialModel(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + model.predict(x) + + try: + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + finally: + os.close(fd) + return keras_file + + def testSequentialModel(self): + """Test a Sequential tf.keras model with default inputs.""" + keras_file = self._getSequentialModel() + + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testSequentialModelInputArray(self): + """Test a Sequential tf.keras model testing input arrays argument.""" + keras_file = self._getSequentialModel() + + # Invalid input array raises error. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_keras_model_file( + keras_file, input_arrays=['invalid-input']) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + # Valid input array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_arrays=['dense_input']) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + def testSequentialModelInputShape(self): + """Test a Sequential tf.keras model testing input shapes argument.""" + keras_file = self._getSequentialModel() + + # Passing in shape of invalid input array has no impact as long as all input + # arrays have a shape. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_shapes={'invalid-input': [2, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Passing in shape of valid input array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_shapes={'dense_input': [2, 3]}) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + # Check input shape from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertTrue(([2, 3] == input_details[0]['shape']).all()) + + def testSequentialModelOutputArray(self): + """Test a Sequential tf.keras model testing output arrays argument.""" + keras_file = self._getSequentialModel() + + # Invalid output array raises error. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_keras_model_file( + keras_file, output_arrays=['invalid-output']) + self.assertEqual("Invalid tensors 'invalid-output' were found.", + str(error.exception)) + + # Valid output array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, output_arrays=['time_distributed/Reshape_1']) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + def testFunctionalModel(self): + """Test a Functional tf.keras model with default inputs.""" + inputs = keras.layers.Input(shape=(3,), name='input') + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + model.predict(x) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFunctionalModelMultipleInputs(self): + """Test a Functional tf.keras model with multiple inputs and outputs.""" + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.mae], + loss_weights=[1., 0.5]) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) + + model.predict([input_a_np, input_b_np], batch_size=5) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('input_a', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('input_b', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(2, len(output_details)) + self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + self.assertEqual('dropout/Identity', output_details[1]['name']) + self.assertEqual(np.float32, output_details[1]['dtype']) + self.assertTrue(([1, 4] == output_details[1]['shape']).all()) + self.assertEqual((0., 0.), output_details[1]['quantization']) + + def testFunctionalSequentialModel(self): + """Test a Functional tf.keras model containing a Sequential model.""" + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model = keras.models.Model(model.input, model.output) + + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + model.predict(x) + + model.predict(x) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd1f4f76ee693414a8515a5bd2567001b53e2ea --- /dev/null +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -0,0 +1,374 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python command line interface for running TOCO.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys + +from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python import lite_constants +from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.platform import app + + +def _parse_array(values, type_fn=str): + if values: + return [type_fn(val) for val in values.split(",") if val] + + +def _parse_set(values): + if values: + return set(values.split(",")) + + +def _get_toco_converter(flags): + """Makes a TocoConverter object based on the flags provided. + + Args: + flags: argparse.Namespace object containing TFLite flags. + + Returns: + TocoConverter object. + """ + # Parse input and output arrays. + input_arrays = _parse_array(flags.input_arrays) + input_shapes = None + if flags.input_shapes: + input_shapes_list = [ + _parse_array(shape, type_fn=int) + for shape in flags.input_shapes.split(":") + ] + input_shapes = dict(zip(input_arrays, input_shapes_list)) + output_arrays = _parse_array(flags.output_arrays) + + converter_kwargs = { + "input_arrays": input_arrays, + "input_shapes": input_shapes, + "output_arrays": output_arrays + } + + # Create TocoConverter. + if flags.graph_def_file: + converter_fn = lite.TocoConverter.from_frozen_graph + converter_kwargs["graph_def_file"] = flags.graph_def_file + elif flags.saved_model_dir: + converter_fn = lite.TocoConverter.from_saved_model + converter_kwargs["saved_model_dir"] = flags.saved_model_dir + converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set) + converter_kwargs["signature_key"] = flags.saved_model_signature_key + elif flags.keras_model_file: + converter_fn = lite.TocoConverter.from_keras_model_file + converter_kwargs["model_file"] = flags.keras_model_file + + return converter_fn(**converter_kwargs) + + +def _convert_model(flags): + """Calls function to convert the TensorFlow model into a TFLite model. + + Args: + flags: argparse.Namespace object. + + Raises: + ValueError: Invalid flags. + """ + # Create converter. + converter = _get_toco_converter(flags) + if flags.inference_type: + converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type) + if flags.inference_input_type: + converter.inference_input_type = _types_pb2.IODataType.Value( + flags.inference_input_type) + if flags.output_format: + converter.output_format = _toco_flags_pb2.FileFormat.Value( + flags.output_format) + + if flags.mean_values and flags.std_dev_values: + input_arrays = converter.get_input_arrays() + std_dev_values = _parse_array(flags.std_dev_values, type_fn=int) + mean_values = _parse_array(flags.mean_values, type_fn=int) + quant_stats = list(zip(mean_values, std_dev_values)) + if ((not flags.input_arrays and len(input_arrays) > 1) or + (len(input_arrays) != len(quant_stats))): + raise ValueError("Mismatching --input_arrays, --std_dev_values, and " + "--mean_values. The flags must have the same number of " + "items. The current input arrays are '{0}'. " + "--input_arrays must be present when specifying " + "--std_dev_values and --mean_values with multiple input " + "tensors in order to map between names and " + "values.".format(",".join(input_arrays))) + converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) + if (flags.default_ranges_min is not None) and (flags.default_ranges_max is + not None): + converter.default_ranges_stats = (flags.default_ranges_min, + flags.default_ranges_max) + + if flags.drop_control_dependency: + converter.drop_control_dependency = flags.drop_control_dependency + if flags.reorder_across_fake_quant: + converter.reorder_across_fake_quant = flags.reorder_across_fake_quant + if flags.change_concat_input_ranges: + converter.change_concat_input_ranges = flags.change_concat_input_ranges + if flags.allow_custom_ops: + converter.allow_custom_ops = flags.allow_custom_ops + if flags.quantize_weights: + if flags.inference_type == lite_constants.QUANTIZED_UINT8: + raise ValueError("--quantized_weights is not supported with " + "--inference_type=QUANTIZED_UINT8") + converter.quantize_weights = flags.quantize_weights + if flags.dump_graphviz_dir: + converter.dump_graphviz_dir = flags.dump_graphviz_dir + if flags.dump_graphviz_video: + converter.dump_graphviz_vode = flags.dump_graphviz_video + + # Convert model. + output_data = converter.convert() + with open(flags.output_file, "wb") as f: + f.write(output_data) + + +def _check_flags(flags, unparsed): + """Checks the parsed and unparsed flags to ensure they are valid. + + Raises an error if previously support unparsed flags are found. Raises an + error for parsed flags that don't meet the required conditions. + + Args: + flags: argparse.Namespace object containing TFLite flags. + unparsed: List of unparsed flags. + + Raises: + ValueError: Invalid flags. + """ + + # Check unparsed flags for common mistakes based on previous TOCO. + def _get_message_unparsed(flag, orig_flag, new_flag): + if flag.startswith(orig_flag): + return "\n Use {0} instead of {1}".format(new_flag, orig_flag) + return "" + + if unparsed: + output = "" + for flag in unparsed: + output += _get_message_unparsed(flag, "--input_file", "--graph_def_file") + output += _get_message_unparsed(flag, "--savedmodel_directory", + "--saved_model_dir") + output += _get_message_unparsed(flag, "--std_value", "--std_dev_values") + output += _get_message_unparsed(flag, "--batch_size", "--input_shapes") + output += _get_message_unparsed(flag, "--dump_graphviz", + "--dump_graphviz_dir") + if output: + raise ValueError(output) + + # Check that flags are valid. + if flags.graph_def_file and (not flags.input_arrays or + not flags.output_arrays): + raise ValueError("--input_arrays and --output_arrays are required with " + "--graph_def_file") + + if flags.input_shapes: + if not flags.input_arrays: + raise ValueError("--input_shapes must be used with --input_arrays") + if flags.input_shapes.count(":") != flags.input_arrays.count(","): + raise ValueError("--input_shapes and --input_arrays must have the same " + "number of items") + + if flags.std_dev_values or flags.mean_values: + if bool(flags.std_dev_values) != bool(flags.mean_values): + raise ValueError("--std_dev_values and --mean_values must be used " + "together") + if flags.std_dev_values.count(",") != flags.mean_values.count(","): + raise ValueError("--std_dev_values, --mean_values must have the same " + "number of items") + + if (flags.default_ranges_min is None) != (flags.default_ranges_max is None): + raise ValueError("--default_ranges_min and --default_ranges_max must be " + "used together") + + if flags.dump_graphviz_video and not flags.dump_graphviz: + raise ValueError("--dump_graphviz_video must be used with --dump_graphviz") + + +def run_main(_): + """Main in toco_convert.py.""" + parser = argparse.ArgumentParser( + description=("Command line tool to run TensorFlow Lite Optimizing " + "Converter (TOCO).")) + + # Output file flag. + parser.add_argument( + "--output_file", + type=str, + help="Full filepath of the output file.", + required=True) + + # Input file flags. + input_file_group = parser.add_mutually_exclusive_group(required=True) + input_file_group.add_argument( + "--graph_def_file", + type=str, + help="Full filepath of file containing frozen TensorFlow GraphDef.") + input_file_group.add_argument( + "--saved_model_dir", + type=str, + help="Full filepath of directory containing the SavedModel.") + input_file_group.add_argument( + "--keras_model_file", + type=str, + help="Full filepath of HDF5 file containing tf.Keras model.") + + # Model format flags. + parser.add_argument( + "--output_format", + type=str.upper, + choices=["TFLITE", "GRAPHVIZ_DOT"], + help="Output file format.") + parser.add_argument( + "--inference_type", + type=str.upper, + choices=["FLOAT", "QUANTIZED_UINT8"], + help="Target data type of real-number arrays in the output file.") + parser.add_argument( + "--inference_input_type", + type=str.upper, + choices=["FLOAT", "QUANTIZED_UINT8"], + help=("Target data type of real-number input arrays. Allows for a " + "different type for input arrays in the case of quantization.")) + + # Input and output arrays flags. + parser.add_argument( + "--input_arrays", + type=str, + help="Names of the output arrays, comma-separated.") + parser.add_argument( + "--input_shapes", + type=str, + help="Shapes corresponding to --input_arrays, colon-separated.") + parser.add_argument( + "--output_arrays", + type=str, + help="Names of the output arrays, comma-separated.") + + # SavedModel related flags. + parser.add_argument( + "--saved_model_tag_set", + type=str, + help=("Comma-separated set of tags identifying the MetaGraphDef within " + "the SavedModel to analyze. All tags must be present. " + "(default \"serve\")")) + parser.add_argument( + "--saved_model_signature_key", + type=str, + help=("Key identifying the SignatureDef containing inputs and outputs. " + "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)")) + + # Quantization flags. + parser.add_argument( + "--std_dev_values", + type=str, + help=("Standard deviation of training data for each input tensor, " + "comma-separated integers. Used for quantization. (default None)")) + parser.add_argument( + "--mean_values", + type=str, + help=("Mean of training data for each input tensor, comma-separated " + "integers. Used for quantization. (default None)")) + parser.add_argument( + "--default_ranges_min", + type=int, + help=("Default value for min bound of min/max range values used for all " + "arrays without a specified range, Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) + parser.add_argument( + "--default_ranges_max", + type=int, + help=("Default value for max bound of min/max range values used for all " + "arrays without a specified range, Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) + parser.add_argument( + "--quantize_weights", + type=bool, + help=("Store float weights as quantized weights followed by dequantize " + "operations. Inference is still done in FLOAT, but reduces model " + "size (at the cost of accuracy and latency).")) + + # Graph manipulation flags. + parser.add_argument( + "--drop_control_dependency", + action="store_true", + help=("Boolean indicating whether to drop control dependencies silently. " + "This is due to TensorFlow not supporting control dependencies. " + "(default True)")) + parser.add_argument( + "--reorder_across_fake_quant", + action="store_true", + help=("Boolean indicating whether to reorder FakeQuant nodes in " + "unexpected locations. Used when the location of the FakeQuant " + "nodes is preventing graph transformations necessary to convert " + "the graph. Results in a graph that differs from the quantized " + "training graph, potentially causing differing arithmetic " + "behavior. (default False)")) + parser.add_argument( + "--change_concat_input_ranges", + action="store_true", + help=("Boolean to change behavior of min/max ranges for inputs and " + "outputs of the concat operator for quantized models. Changes the " + "ranges of concat operator overlap when true. (default False)")) + parser.add_argument( + "--allow_custom_ops", + action="store_true", + help=("Boolean indicating whether to allow custom operations. When false " + "any unknown operation is an error. When true, custom ops are " + "created for any op that is unknown. The developer will need to " + "provide these to the TensorFlow Lite runtime with a custom " + "resolver. (default False)")) + + # Logging flags. + parser.add_argument( + "--dump_graphviz_dir", + type=str, + help=("Full filepath of folder to dump the graphs at various stages of " + "processing GraphViz .dot files. Preferred over --output_format=" + "GRAPHVIZ_DOT in order to keep the requirements of the output " + "file.")) + parser.add_argument( + "--dump_graphviz_video", + action="store_true", + help=("Boolean indicating whether to dump the graph after every graph " + "transformation")) + + tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:]) + try: + _check_flags(tflite_flags, unparsed) + except ValueError as e: + parser.print_usage() + file_name = os.path.basename(sys.argv[0]) + sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e))) + sys.exit(1) + _convert_model(tflite_flags) + + +def main(): + app.run(main=run_main, argv=sys.argv[:1]) + + +if __name__ == "__main__": + main() diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD index 9717a4a1a496b888348514584888e62c4e3703b4..f095151cae835aa202ff4c9f43e175246f54f1cf 100644 --- a/tensorflow/contrib/lite/schema/BUILD +++ b/tensorflow/contrib/lite/schema/BUILD @@ -65,6 +65,7 @@ cc_test( ], tags = [ "tflite_not_portable_android", + "tflite_not_portable_ios", ], deps = [ "//tensorflow/core:lib_platform", diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc index ac408d2f94b98d505afe4c951d7cc2ff960606fb..9dc8daa227dd68ccde2efa4013ac4465a72e6bb0 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc @@ -39,7 +39,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. #ifdef __cplusplus extern "C" { @@ -57,7 +57,6 @@ const char* kFileFooter = } // extern "C" #endif // __cplusplus #endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ -} )"; } // anonymous namespace diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 8bdeb035f5a778fa3b0d85da36d6b8d6721445ea..15fb8bbdb8f100201750faf706eb45b697319dfb 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -34,6 +34,8 @@ enum TensorType : byte { INT64 = 4, STRING = 5, BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, } // Parameters for converting a quantized tensor back to float. Given a @@ -63,6 +65,8 @@ table Tensor { buffer:uint; name:string; // For debugging and importing back into tensorflow. quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; } // A list of builtin operators. Builtin operators are slightly faster than custom @@ -145,6 +149,17 @@ enum BuiltinOperator : byte { SLICE = 65, SIN = 66, TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM=74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, } // Options for the builtin operators. @@ -175,7 +190,7 @@ union BuiltinOptions { BatchToSpaceNDOptions, SpaceToBatchNDOptions, TransposeOptions, - MeanOptions, + ReducerOptions, SubOptions, DivOptions, SqueezeOptions, @@ -198,6 +213,13 @@ union BuiltinOptions { SelectOptions, SliceOptions, TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, + PowOptions, } enum Padding : byte { SAME, VALID } @@ -275,9 +297,18 @@ table BidirectionalSequenceRNNOptions { fused_activation_function:ActivationFunctionType; } +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} + // An implementation of TensorFlow fully_connected (a.k.a Dense) layer. table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; } table SoftmaxOptions { @@ -309,11 +340,23 @@ table LocalResponseNormalizationOptions { beta:float; } +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + // An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell table LSTMOptions { + // Parameters for LSTM version 1 or above. fused_activation_function:ActivationFunctionType; cell_clip: float; // Optional, 0.0 means no clipping proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; } table ResizeBilinearOptions { @@ -385,7 +428,7 @@ table TransposeOptions { table ExpOptions { } -table MeanOptions { +table ReducerOptions { keep_dims: bool; } @@ -419,6 +462,9 @@ table DequantizeOptions { table MaximumMinimumOptions { } +table TileOptions { +} + table ArgMaxOptions { output_type : TensorType; } @@ -450,6 +496,27 @@ table TransposeConvOptions { stride_h:int; } +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table PowOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { @@ -481,6 +548,16 @@ table Operator { builtin_options:BuiltinOptions; custom_options:[ubyte]; custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; } // The root type, defining a subgraph, which typically represents an entire diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 35c34f53a6bf9716941f623b43f238c681252747..fe0ff9a7a5ba0764475f4a7c14cd875b3cdb2aa8 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -127,8 +127,8 @@ struct TransposeOptionsT; struct ExpOptions; struct ExpOptionsT; -struct MeanOptions; -struct MeanOptionsT; +struct ReducerOptions; +struct ReducerOptionsT; struct SqueezeOptions; struct SqueezeOptionsT; @@ -151,6 +151,9 @@ struct DequantizeOptionsT; struct MaximumMinimumOptions; struct MaximumMinimumOptionsT; +struct TileOptions; +struct TileOptionsT; + struct ArgMaxOptions; struct ArgMaxOptionsT; @@ -178,6 +181,24 @@ struct SliceOptionsT; struct TransposeConvOptions; struct TransposeConvOptionsT; +struct ExpandDimsOptions; +struct ExpandDimsOptionsT; + +struct SparseToDenseOptions; +struct SparseToDenseOptionsT; + +struct EqualOptions; +struct EqualOptionsT; + +struct NotEqualOptions; +struct NotEqualOptionsT; + +struct ShapeOptions; +struct ShapeOptionsT; + +struct PowOptions; +struct PowOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -201,11 +222,13 @@ enum TensorType { TensorType_INT64 = 4, TensorType_STRING = 5, TensorType_BOOL = 6, + TensorType_INT16 = 7, + TensorType_COMPLEX64 = 8, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_BOOL + TensorType_MAX = TensorType_COMPLEX64 }; -inline TensorType (&EnumValuesTensorType())[7] { +inline TensorType (&EnumValuesTensorType())[9] { static TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -213,7 +236,9 @@ inline TensorType (&EnumValuesTensorType())[7] { TensorType_UINT8, TensorType_INT64, TensorType_STRING, - TensorType_BOOL + TensorType_BOOL, + TensorType_INT16, + TensorType_COMPLEX64 }; return values; } @@ -227,6 +252,8 @@ inline const char **EnumNamesTensorType() { "INT64", "STRING", "BOOL", + "INT16", + "COMPLEX64", nullptr }; return names; @@ -305,11 +332,22 @@ enum BuiltinOperator { BuiltinOperator_SLICE = 65, BuiltinOperator_SIN = 66, BuiltinOperator_TRANSPOSE_CONV = 67, + BuiltinOperator_SPARSE_TO_DENSE = 68, + BuiltinOperator_TILE = 69, + BuiltinOperator_EXPAND_DIMS = 70, + BuiltinOperator_EQUAL = 71, + BuiltinOperator_NOT_EQUAL = 72, + BuiltinOperator_LOG = 73, + BuiltinOperator_SUM = 74, + BuiltinOperator_SQRT = 75, + BuiltinOperator_RSQRT = 76, + BuiltinOperator_SHAPE = 77, + BuiltinOperator_POW = 78, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_TRANSPOSE_CONV + BuiltinOperator_MAX = BuiltinOperator_POW }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[67] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -377,7 +415,18 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[67] { BuiltinOperator_SELECT, BuiltinOperator_SLICE, BuiltinOperator_SIN, - BuiltinOperator_TRANSPOSE_CONV + BuiltinOperator_TRANSPOSE_CONV, + BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOperator_TILE, + BuiltinOperator_EXPAND_DIMS, + BuiltinOperator_EQUAL, + BuiltinOperator_NOT_EQUAL, + BuiltinOperator_LOG, + BuiltinOperator_SUM, + BuiltinOperator_SQRT, + BuiltinOperator_RSQRT, + BuiltinOperator_SHAPE, + BuiltinOperator_POW }; return values; } @@ -452,6 +501,17 @@ inline const char **EnumNamesBuiltinOperator() { "SLICE", "SIN", "TRANSPOSE_CONV", + "SPARSE_TO_DENSE", + "TILE", + "EXPAND_DIMS", + "EQUAL", + "NOT_EQUAL", + "LOG", + "SUM", + "SQRT", + "RSQRT", + "SHAPE", + "POW", nullptr }; return names; @@ -490,7 +550,7 @@ enum BuiltinOptions { BuiltinOptions_BatchToSpaceNDOptions = 24, BuiltinOptions_SpaceToBatchNDOptions = 25, BuiltinOptions_TransposeOptions = 26, - BuiltinOptions_MeanOptions = 27, + BuiltinOptions_ReducerOptions = 27, BuiltinOptions_SubOptions = 28, BuiltinOptions_DivOptions = 29, BuiltinOptions_SqueezeOptions = 30, @@ -513,11 +573,18 @@ enum BuiltinOptions { BuiltinOptions_SelectOptions = 47, BuiltinOptions_SliceOptions = 48, BuiltinOptions_TransposeConvOptions = 49, + BuiltinOptions_SparseToDenseOptions = 50, + BuiltinOptions_TileOptions = 51, + BuiltinOptions_ExpandDimsOptions = 52, + BuiltinOptions_EqualOptions = 53, + BuiltinOptions_NotEqualOptions = 54, + BuiltinOptions_ShapeOptions = 55, + BuiltinOptions_PowOptions = 56, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_TransposeConvOptions + BuiltinOptions_MAX = BuiltinOptions_PowOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -546,7 +613,7 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] { BuiltinOptions_BatchToSpaceNDOptions, BuiltinOptions_SpaceToBatchNDOptions, BuiltinOptions_TransposeOptions, - BuiltinOptions_MeanOptions, + BuiltinOptions_ReducerOptions, BuiltinOptions_SubOptions, BuiltinOptions_DivOptions, BuiltinOptions_SqueezeOptions, @@ -568,7 +635,14 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] { BuiltinOptions_LessEqualOptions, BuiltinOptions_SelectOptions, BuiltinOptions_SliceOptions, - BuiltinOptions_TransposeConvOptions + BuiltinOptions_TransposeConvOptions, + BuiltinOptions_SparseToDenseOptions, + BuiltinOptions_TileOptions, + BuiltinOptions_ExpandDimsOptions, + BuiltinOptions_EqualOptions, + BuiltinOptions_NotEqualOptions, + BuiltinOptions_ShapeOptions, + BuiltinOptions_PowOptions }; return values; } @@ -602,7 +676,7 @@ inline const char **EnumNamesBuiltinOptions() { "BatchToSpaceNDOptions", "SpaceToBatchNDOptions", "TransposeOptions", - "MeanOptions", + "ReducerOptions", "SubOptions", "DivOptions", "SqueezeOptions", @@ -625,6 +699,13 @@ inline const char **EnumNamesBuiltinOptions() { "SelectOptions", "SliceOptions", "TransposeConvOptions", + "SparseToDenseOptions", + "TileOptions", + "ExpandDimsOptions", + "EqualOptions", + "NotEqualOptions", + "ShapeOptions", + "PowOptions", nullptr }; return names; @@ -743,8 +824,8 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions; }; -template<> struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = BuiltinOptions_MeanOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReducerOptions; }; template<> struct BuiltinOptionsTraits { @@ -835,6 +916,34 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TileOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PowOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1074,13 +1183,13 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_TransposeOptions ? reinterpret_cast(value) : nullptr; } - MeanOptionsT *AsMeanOptions() { - return type == BuiltinOptions_MeanOptions ? - reinterpret_cast(value) : nullptr; + ReducerOptionsT *AsReducerOptions() { + return type == BuiltinOptions_ReducerOptions ? + reinterpret_cast(value) : nullptr; } - const MeanOptionsT *AsMeanOptions() const { - return type == BuiltinOptions_MeanOptions ? - reinterpret_cast(value) : nullptr; + const ReducerOptionsT *AsReducerOptions() const { + return type == BuiltinOptions_ReducerOptions ? + reinterpret_cast(value) : nullptr; } SubOptionsT *AsSubOptions() { return type == BuiltinOptions_SubOptions ? @@ -1258,6 +1367,62 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_TransposeConvOptions ? reinterpret_cast(value) : nullptr; } + SparseToDenseOptionsT *AsSparseToDenseOptions() { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } + const SparseToDenseOptionsT *AsSparseToDenseOptions() const { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } + TileOptionsT *AsTileOptions() { + return type == BuiltinOptions_TileOptions ? + reinterpret_cast(value) : nullptr; + } + const TileOptionsT *AsTileOptions() const { + return type == BuiltinOptions_TileOptions ? + reinterpret_cast(value) : nullptr; + } + ExpandDimsOptionsT *AsExpandDimsOptions() { + return type == BuiltinOptions_ExpandDimsOptions ? + reinterpret_cast(value) : nullptr; + } + const ExpandDimsOptionsT *AsExpandDimsOptions() const { + return type == BuiltinOptions_ExpandDimsOptions ? + reinterpret_cast(value) : nullptr; + } + EqualOptionsT *AsEqualOptions() { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast(value) : nullptr; + } + const EqualOptionsT *AsEqualOptions() const { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast(value) : nullptr; + } + NotEqualOptionsT *AsNotEqualOptions() { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast(value) : nullptr; + } + const NotEqualOptionsT *AsNotEqualOptions() const { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast(value) : nullptr; + } + ShapeOptionsT *AsShapeOptions() { + return type == BuiltinOptions_ShapeOptions ? + reinterpret_cast(value) : nullptr; + } + const ShapeOptionsT *AsShapeOptions() const { + return type == BuiltinOptions_ShapeOptions ? + reinterpret_cast(value) : nullptr; + } + PowOptionsT *AsPowOptions() { + return type == BuiltinOptions_PowOptions ? + reinterpret_cast(value) : nullptr; + } + const PowOptionsT *AsPowOptions() const { + return type == BuiltinOptions_PowOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -1365,6 +1530,64 @@ inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { return EnumNamesLSHProjectionType()[index]; } +enum FullyConnectedOptionsWeightsFormat { + FullyConnectedOptionsWeightsFormat_DEFAULT = 0, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 = 1, + FullyConnectedOptionsWeightsFormat_MIN = FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_MAX = FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 +}; + +inline FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] { + static FullyConnectedOptionsWeightsFormat values[] = { + FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 + }; + return values; +} + +inline const char **EnumNamesFullyConnectedOptionsWeightsFormat() { + static const char *names[] = { + "DEFAULT", + "SHUFFLED4x16INT8", + nullptr + }; + return names; +} + +inline const char *EnumNameFullyConnectedOptionsWeightsFormat(FullyConnectedOptionsWeightsFormat e) { + const size_t index = static_cast(e); + return EnumNamesFullyConnectedOptionsWeightsFormat()[index]; +} + +enum LSTMKernelType { + LSTMKernelType_FULL = 0, + LSTMKernelType_BASIC = 1, + LSTMKernelType_MIN = LSTMKernelType_FULL, + LSTMKernelType_MAX = LSTMKernelType_BASIC +}; + +inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] { + static LSTMKernelType values[] = { + LSTMKernelType_FULL, + LSTMKernelType_BASIC + }; + return values; +} + +inline const char **EnumNamesLSTMKernelType() { + static const char *names[] = { + "FULL", + "BASIC", + nullptr + }; + return names; +} + +inline const char *EnumNameLSTMKernelType(LSTMKernelType e) { + const size_t index = static_cast(e); + return EnumNamesLSTMKernelType()[index]; +} + enum CombinerType { CombinerType_SUM = 0, CombinerType_MEAN = 1, @@ -1534,9 +1757,11 @@ struct TensorT : public flatbuffers::NativeTable { uint32_t buffer; std::string name; std::unique_ptr quantization; + bool is_variable; TensorT() : type(TensorType_FLOAT32), - buffer(0) { + buffer(0), + is_variable(false) { } }; @@ -1547,7 +1772,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_TYPE = 6, VT_BUFFER = 8, VT_NAME = 10, - VT_QUANTIZATION = 12 + VT_QUANTIZATION = 12, + VT_IS_VARIABLE = 14 }; const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); @@ -1564,6 +1790,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const QuantizationParameters *quantization() const { return GetPointer(VT_QUANTIZATION); } + bool is_variable() const { + return GetField(VT_IS_VARIABLE, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && @@ -1574,6 +1803,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.Verify(name()) && VerifyOffset(verifier, VT_QUANTIZATION) && verifier.VerifyTable(quantization()) && + VerifyField(verifier, VT_IS_VARIABLE) && verifier.EndTable(); } TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -1599,6 +1829,9 @@ struct TensorBuilder { void add_quantization(flatbuffers::Offset quantization) { fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); } + void add_is_variable(bool is_variable) { + fbb_.AddElement(Tensor::VT_IS_VARIABLE, static_cast(is_variable), 0); + } explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1617,12 +1850,14 @@ inline flatbuffers::Offset CreateTensor( TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, flatbuffers::Offset name = 0, - flatbuffers::Offset quantization = 0) { + flatbuffers::Offset quantization = 0, + bool is_variable = false) { TensorBuilder builder_(_fbb); builder_.add_quantization(quantization); builder_.add_name(name); builder_.add_buffer(buffer); builder_.add_shape(shape); + builder_.add_is_variable(is_variable); builder_.add_type(type); return builder_.Finish(); } @@ -1633,14 +1868,16 @@ inline flatbuffers::Offset CreateTensorDirect( TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, const char *name = nullptr, - flatbuffers::Offset quantization = 0) { + flatbuffers::Offset quantization = 0, + bool is_variable = false) { return tflite::CreateTensor( _fbb, shape ? _fbb.CreateVector(*shape) : 0, type, buffer, name ? _fbb.CreateString(name) : 0, - quantization); + quantization, + is_variable); } flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -2374,22 +2611,29 @@ flatbuffers::Offset CreateBidirectionalSequence struct FullyConnectedOptionsT : public flatbuffers::NativeTable { typedef FullyConnectedOptions TableType; ActivationFunctionType fused_activation_function; + FullyConnectedOptionsWeightsFormat weights_format; FullyConnectedOptionsT() - : fused_activation_function(ActivationFunctionType_NONE) { + : fused_activation_function(ActivationFunctionType_NONE), + weights_format(FullyConnectedOptionsWeightsFormat_DEFAULT) { } }; struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef FullyConnectedOptionsT NativeTableType; enum { - VT_FUSED_ACTIVATION_FUNCTION = 4 + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_WEIGHTS_FORMAT = 6 }; ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + FullyConnectedOptionsWeightsFormat weights_format() const { + return static_cast(GetField(VT_WEIGHTS_FORMAT, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_WEIGHTS_FORMAT) && verifier.EndTable(); } FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2403,6 +2647,9 @@ struct FullyConnectedOptionsBuilder { void add_fused_activation_function(ActivationFunctionType fused_activation_function) { fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } + void add_weights_format(FullyConnectedOptionsWeightsFormat weights_format) { + fbb_.AddElement(FullyConnectedOptions::VT_WEIGHTS_FORMAT, static_cast(weights_format), 0); + } explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2417,8 +2664,10 @@ struct FullyConnectedOptionsBuilder { inline flatbuffers::Offset CreateFullyConnectedOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + FullyConnectedOptionsWeightsFormat weights_format = FullyConnectedOptionsWeightsFormat_DEFAULT) { FullyConnectedOptionsBuilder builder_(_fbb); + builder_.add_weights_format(weights_format); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -2802,10 +3051,12 @@ struct LSTMOptionsT : public flatbuffers::NativeTable { ActivationFunctionType fused_activation_function; float cell_clip; float proj_clip; + LSTMKernelType kernel_type; LSTMOptionsT() : fused_activation_function(ActivationFunctionType_NONE), cell_clip(0.0f), - proj_clip(0.0f) { + proj_clip(0.0f), + kernel_type(LSTMKernelType_FULL) { } }; @@ -2814,7 +3065,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, - VT_PROJ_CLIP = 8 + VT_PROJ_CLIP = 8, + VT_KERNEL_TYPE = 10 }; ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -2825,11 +3077,15 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { float proj_clip() const { return GetField(VT_PROJ_CLIP, 0.0f); } + LSTMKernelType kernel_type() const { + return static_cast(GetField(VT_KERNEL_TYPE, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_CELL_CLIP) && VerifyField(verifier, VT_PROJ_CLIP) && + VerifyField(verifier, VT_KERNEL_TYPE) && verifier.EndTable(); } LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2849,6 +3105,9 @@ struct LSTMOptionsBuilder { void add_proj_clip(float proj_clip) { fbb_.AddElement(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); } + void add_kernel_type(LSTMKernelType kernel_type) { + fbb_.AddElement(LSTMOptions::VT_KERNEL_TYPE, static_cast(kernel_type), 0); + } explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2865,10 +3124,12 @@ inline flatbuffers::Offset CreateLSTMOptions( flatbuffers::FlatBufferBuilder &_fbb, ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, float cell_clip = 0.0f, - float proj_clip = 0.0f) { + float proj_clip = 0.0f, + LSTMKernelType kernel_type = LSTMKernelType_FULL) { LSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_kernel_type(kernel_type); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -3673,16 +3934,16 @@ inline flatbuffers::Offset CreateExpOptions( flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -struct MeanOptionsT : public flatbuffers::NativeTable { - typedef MeanOptions TableType; +struct ReducerOptionsT : public flatbuffers::NativeTable { + typedef ReducerOptions TableType; bool keep_dims; - MeanOptionsT() + ReducerOptionsT() : keep_dims(false) { } }; -struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef MeanOptionsT NativeTableType; +struct ReducerOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ReducerOptionsT NativeTableType; enum { VT_KEEP_DIMS = 4 }; @@ -3694,38 +3955,38 @@ struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_KEEP_DIMS) && verifier.EndTable(); } - MeanOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ReducerOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReducerOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -struct MeanOptionsBuilder { +struct ReducerOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_keep_dims(bool keep_dims) { - fbb_.AddElement(MeanOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); + fbb_.AddElement(ReducerOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); } - explicit MeanOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit ReducerOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - MeanOptionsBuilder &operator=(const MeanOptionsBuilder &); - flatbuffers::Offset Finish() { + ReducerOptionsBuilder &operator=(const ReducerOptionsBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateMeanOptions( +inline flatbuffers::Offset CreateReducerOptions( flatbuffers::FlatBufferBuilder &_fbb, bool keep_dims = false) { - MeanOptionsBuilder builder_(_fbb); + ReducerOptionsBuilder builder_(_fbb); builder_.add_keep_dims(keep_dims); return builder_.Finish(); } -flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateReducerOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SqueezeOptionsT : public flatbuffers::NativeTable { typedef SqueezeOptions TableType; @@ -4131,6 +4392,46 @@ inline flatbuffers::Offset CreateMaximumMinimumOptions( flatbuffers::Offset CreateMaximumMinimumOptions(flatbuffers::FlatBufferBuilder &_fbb, const MaximumMinimumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct TileOptionsT : public flatbuffers::NativeTable { + typedef TileOptions TableType; + TileOptionsT() { + } +}; + +struct TileOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TileOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + TileOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TileOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit TileOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TileOptionsBuilder &operator=(const TileOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTileOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + TileOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateTileOptions(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct ArgMaxOptionsT : public flatbuffers::NativeTable { typedef ArgMaxOptions TableType; TensorType output_type; @@ -4543,6 +4844,274 @@ inline flatbuffers::Offset CreateTransposeConvOptions( flatbuffers::Offset CreateTransposeConvOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct ExpandDimsOptionsT : public flatbuffers::NativeTable { + typedef ExpandDimsOptions TableType; + ExpandDimsOptionsT() { + } +}; + +struct ExpandDimsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ExpandDimsOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ExpandDimsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExpandDimsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExpandDimsOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit ExpandDimsOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ExpandDimsOptionsBuilder &operator=(const ExpandDimsOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateExpandDimsOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + ExpandDimsOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateExpandDimsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SparseToDenseOptionsT : public flatbuffers::NativeTable { + typedef SparseToDenseOptions TableType; + bool validate_indices; + SparseToDenseOptionsT() + : validate_indices(false) { + } +}; + +struct SparseToDenseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SparseToDenseOptionsT NativeTableType; + enum { + VT_VALIDATE_INDICES = 4 + }; + bool validate_indices() const { + return GetField(VT_VALIDATE_INDICES, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VALIDATE_INDICES) && + verifier.EndTable(); + } + SparseToDenseOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SparseToDenseOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_validate_indices(bool validate_indices) { + fbb_.AddElement(SparseToDenseOptions::VT_VALIDATE_INDICES, static_cast(validate_indices), 0); + } + explicit SparseToDenseOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SparseToDenseOptionsBuilder &operator=(const SparseToDenseOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSparseToDenseOptions( + flatbuffers::FlatBufferBuilder &_fbb, + bool validate_indices = false) { + SparseToDenseOptionsBuilder builder_(_fbb); + builder_.add_validate_indices(validate_indices); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct EqualOptionsT : public flatbuffers::NativeTable { + typedef EqualOptions TableType; + EqualOptionsT() { + } +}; + +struct EqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef EqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + EqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct EqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit EqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + EqualOptionsBuilder &operator=(const EqualOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + EqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct NotEqualOptionsT : public flatbuffers::NativeTable { + typedef NotEqualOptions TableType; + NotEqualOptionsT() { + } +}; + +struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef NotEqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NotEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NotEqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit NotEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + NotEqualOptionsBuilder &operator=(const NotEqualOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateNotEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + NotEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ShapeOptionsT : public flatbuffers::NativeTable { + typedef ShapeOptions TableType; + TensorType out_type; + ShapeOptionsT() + : out_type(TensorType_FLOAT32) { + } +}; + +struct ShapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ShapeOptionsT NativeTableType; + enum { + VT_OUT_TYPE = 4 + }; + TensorType out_type() const { + return static_cast(GetField(VT_OUT_TYPE, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OUT_TYPE) && + verifier.EndTable(); + } + ShapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ShapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ShapeOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_out_type(TensorType out_type) { + fbb_.AddElement(ShapeOptions::VT_OUT_TYPE, static_cast(out_type), 0); + } + explicit ShapeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ShapeOptionsBuilder &operator=(const ShapeOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateShapeOptions( + flatbuffers::FlatBufferBuilder &_fbb, + TensorType out_type = TensorType_FLOAT32) { + ShapeOptionsBuilder builder_(_fbb); + builder_.add_out_type(out_type); + return builder_.Finish(); +} + +flatbuffers::Offset CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct PowOptionsT : public flatbuffers::NativeTable { + typedef PowOptions TableType; + PowOptionsT() { + } +}; + +struct PowOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef PowOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + PowOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PowOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit PowOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + PowOptionsBuilder &operator=(const PowOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreatePowOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + PowOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -4641,6 +5210,7 @@ struct OperatorT : public flatbuffers::NativeTable { BuiltinOptionsUnion builtin_options; std::vector custom_options; CustomOptionsFormat custom_options_format; + std::vector mutating_variable_inputs; OperatorT() : opcode_index(0), custom_options_format(CustomOptionsFormat_FLEXBUFFERS) { @@ -4656,7 +5226,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_BUILTIN_OPTIONS_TYPE = 10, VT_BUILTIN_OPTIONS = 12, VT_CUSTOM_OPTIONS = 14, - VT_CUSTOM_OPTIONS_FORMAT = 16 + VT_CUSTOM_OPTIONS_FORMAT = 16, + VT_MUTATING_VARIABLE_INPUTS = 18 }; uint32_t opcode_index() const { return GetField(VT_OPCODE_INDEX, 0); @@ -4752,8 +5323,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const TransposeOptions *builtin_options_as_TransposeOptions() const { return builtin_options_type() == BuiltinOptions_TransposeOptions ? static_cast(builtin_options()) : nullptr; } - const MeanOptions *builtin_options_as_MeanOptions() const { - return builtin_options_type() == BuiltinOptions_MeanOptions ? static_cast(builtin_options()) : nullptr; + const ReducerOptions *builtin_options_as_ReducerOptions() const { + return builtin_options_type() == BuiltinOptions_ReducerOptions ? static_cast(builtin_options()) : nullptr; } const SubOptions *builtin_options_as_SubOptions() const { return builtin_options_type() == BuiltinOptions_SubOptions ? static_cast(builtin_options()) : nullptr; @@ -4821,12 +5392,36 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const TransposeConvOptions *builtin_options_as_TransposeConvOptions() const { return builtin_options_type() == BuiltinOptions_TransposeConvOptions ? static_cast(builtin_options()) : nullptr; } + const SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const { + return builtin_options_type() == BuiltinOptions_SparseToDenseOptions ? static_cast(builtin_options()) : nullptr; + } + const TileOptions *builtin_options_as_TileOptions() const { + return builtin_options_type() == BuiltinOptions_TileOptions ? static_cast(builtin_options()) : nullptr; + } + const ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const { + return builtin_options_type() == BuiltinOptions_ExpandDimsOptions ? static_cast(builtin_options()) : nullptr; + } + const EqualOptions *builtin_options_as_EqualOptions() const { + return builtin_options_type() == BuiltinOptions_EqualOptions ? static_cast(builtin_options()) : nullptr; + } + const NotEqualOptions *builtin_options_as_NotEqualOptions() const { + return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast(builtin_options()) : nullptr; + } + const ShapeOptions *builtin_options_as_ShapeOptions() const { + return builtin_options_type() == BuiltinOptions_ShapeOptions ? static_cast(builtin_options()) : nullptr; + } + const PowOptions *builtin_options_as_PowOptions() const { + return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } CustomOptionsFormat custom_options_format() const { return static_cast(GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); } + const flatbuffers::Vector *mutating_variable_inputs() const { + return GetPointer *>(VT_MUTATING_VARIABLE_INPUTS); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_OPCODE_INDEX) && @@ -4840,6 +5435,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && verifier.Verify(custom_options()) && VerifyField(verifier, VT_CUSTOM_OPTIONS_FORMAT) && + VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) && + verifier.Verify(mutating_variable_inputs()) && verifier.EndTable(); } OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4951,8 +5548,8 @@ template<> inline const TransposeOptions *Operator::builtin_options_as inline const MeanOptions *Operator::builtin_options_as() const { - return builtin_options_as_MeanOptions(); +template<> inline const ReducerOptions *Operator::builtin_options_as() const { + return builtin_options_as_ReducerOptions(); } template<> inline const SubOptions *Operator::builtin_options_as() const { @@ -5043,6 +5640,34 @@ template<> inline const TransposeConvOptions *Operator::builtin_options_as inline const SparseToDenseOptions *Operator::builtin_options_as() const { + return builtin_options_as_SparseToDenseOptions(); +} + +template<> inline const TileOptions *Operator::builtin_options_as() const { + return builtin_options_as_TileOptions(); +} + +template<> inline const ExpandDimsOptions *Operator::builtin_options_as() const { + return builtin_options_as_ExpandDimsOptions(); +} + +template<> inline const EqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_EqualOptions(); +} + +template<> inline const NotEqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_NotEqualOptions(); +} + +template<> inline const ShapeOptions *Operator::builtin_options_as() const { + return builtin_options_as_ShapeOptions(); +} + +template<> inline const PowOptions *Operator::builtin_options_as() const { + return builtin_options_as_PowOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -5067,6 +5692,9 @@ struct OperatorBuilder { void add_custom_options_format(CustomOptionsFormat custom_options_format) { fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast(custom_options_format), 0); } + void add_mutating_variable_inputs(flatbuffers::Offset> mutating_variable_inputs) { + fbb_.AddOffset(Operator::VT_MUTATING_VARIABLE_INPUTS, mutating_variable_inputs); + } explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -5087,8 +5715,10 @@ inline flatbuffers::Offset CreateOperator( BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, flatbuffers::Offset> custom_options = 0, - CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS, + flatbuffers::Offset> mutating_variable_inputs = 0) { OperatorBuilder builder_(_fbb); + builder_.add_mutating_variable_inputs(mutating_variable_inputs); builder_.add_custom_options(custom_options); builder_.add_builtin_options(builtin_options); builder_.add_outputs(outputs); @@ -5107,7 +5737,8 @@ inline flatbuffers::Offset CreateOperatorDirect( BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, const std::vector *custom_options = nullptr, - CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS, + const std::vector *mutating_variable_inputs = nullptr) { return tflite::CreateOperator( _fbb, opcode_index, @@ -5116,7 +5747,8 @@ inline flatbuffers::Offset CreateOperatorDirect( builtin_options_type, builtin_options, custom_options ? _fbb.CreateVector(*custom_options) : 0, - custom_options_format); + custom_options_format, + mutating_variable_inputs ? _fbb.CreateVector(*mutating_variable_inputs) : 0); } flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -5487,6 +6119,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t { auto _e = buffer(); _o->buffer = _e; }; { auto _e = name(); if (_e) _o->name = _e->str(); }; { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr(_e->UnPack(_resolver)); }; + { auto _e = is_variable(); _o->is_variable = _e; }; } inline flatbuffers::Offset Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -5502,13 +6135,15 @@ inline flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder & auto _buffer = _o->buffer; auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0; + auto _is_variable = _o->is_variable; return tflite::CreateTensor( _fbb, _shape, _type, _buffer, _name, - _quantization); + _quantization, + _is_variable); } inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -5812,6 +6447,7 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const fl (void)_o; (void)_resolver; { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = weights_format(); _o->weights_format = _e; }; } inline flatbuffers::Offset FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -5823,9 +6459,11 @@ inline flatbuffers::Offset CreateFullyConnectedOptions(fl (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FullyConnectedOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; + auto _weights_format = _o->weights_format; return tflite::CreateFullyConnectedOptions( _fbb, - _fused_activation_function); + _fused_activation_function, + _weights_format); } inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -6008,6 +6646,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_ { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; { auto _e = cell_clip(); _o->cell_clip = _e; }; { auto _e = proj_clip(); _o->proj_clip = _e; }; + { auto _e = kernel_type(); _o->kernel_type = _e; }; } inline flatbuffers::Offset LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -6021,11 +6660,13 @@ inline flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBuffe auto _fused_activation_function = _o->fused_activation_function; auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; + auto _kernel_type = _o->kernel_type; return tflite::CreateLSTMOptions( _fbb, _fused_activation_function, _cell_clip, - _proj_clip); + _proj_clip, + _kernel_type); } inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -6429,28 +7070,28 @@ inline flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferB _fbb); } -inline MeanOptionsT *MeanOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new MeanOptionsT(); +inline ReducerOptionsT *ReducerOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ReducerOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void MeanOptions::UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void ReducerOptions::UnPackTo(ReducerOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; { auto _e = keep_dims(); _o->keep_dims = _e; }; } -inline flatbuffers::Offset MeanOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateMeanOptions(_fbb, _o, _rehasher); +inline flatbuffers::Offset ReducerOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateReducerOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateReducerOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MeanOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReducerOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _keep_dims = _o->keep_dims; - return tflite::CreateMeanOptions( + return tflite::CreateReducerOptions( _fbb, _keep_dims); } @@ -6643,6 +7284,29 @@ inline flatbuffers::Offset CreateMaximumMinimumOptions(fl _fbb); } +inline TileOptionsT *TileOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TileOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void TileOptions::UnPackTo(TileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset TileOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateTileOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateTileOptions(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TileOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateTileOptions( + _fbb); +} + inline ArgMaxOptionsT *ArgMaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ArgMaxOptionsT(); UnPackTo(_o, _resolver); @@ -6862,6 +7526,150 @@ inline flatbuffers::Offset CreateTransposeConvOptions(flat _stride_h); } +inline ExpandDimsOptionsT *ExpandDimsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ExpandDimsOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ExpandDimsOptions::UnPackTo(ExpandDimsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset ExpandDimsOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateExpandDimsOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateExpandDimsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ExpandDimsOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateExpandDimsOptions( + _fbb); +} + +inline SparseToDenseOptionsT *SparseToDenseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SparseToDenseOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SparseToDenseOptions::UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = validate_indices(); _o->validate_indices = _e; }; +} + +inline flatbuffers::Offset SparseToDenseOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSparseToDenseOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SparseToDenseOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _validate_indices = _o->validate_indices; + return tflite::CreateSparseToDenseOptions( + _fbb, + _validate_indices); +} + +inline EqualOptionsT *EqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new EqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void EqualOptions::UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset EqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateEqualOptions( + _fbb); +} + +inline NotEqualOptionsT *NotEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new NotEqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void NotEqualOptions::UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset NotEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateNotEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NotEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNotEqualOptions( + _fbb); +} + +inline ShapeOptionsT *ShapeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ShapeOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ShapeOptions::UnPackTo(ShapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = out_type(); _o->out_type = _e; }; +} + +inline flatbuffers::Offset ShapeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateShapeOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ShapeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _out_type = _o->out_type; + return tflite::CreateShapeOptions( + _fbb, + _out_type); +} + +inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new PowOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void PowOptions::UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset PowOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreatePowOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PowOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreatePowOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -6910,6 +7718,7 @@ inline void Operator::UnPackTo(OperatorT *_o, const flatbuffers::resolver_functi { auto _e = builtin_options(); if (_e) _o->builtin_options.value = BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); }; { auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom_options[_i] = _e->Get(_i); } } }; { auto _e = custom_options_format(); _o->custom_options_format = _e; }; + { auto _e = mutating_variable_inputs(); if (_e) { _o->mutating_variable_inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->mutating_variable_inputs[_i] = _e->Get(_i) != 0; } } }; } inline flatbuffers::Offset Operator::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -6927,6 +7736,7 @@ inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuild auto _builtin_options = _o->builtin_options.Pack(_fbb); auto _custom_options = _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; auto _custom_options_format = _o->custom_options_format; + auto _mutating_variable_inputs = _o->mutating_variable_inputs.size() ? _fbb.CreateVector(_o->mutating_variable_inputs) : 0; return tflite::CreateOperator( _fbb, _opcode_index, @@ -6935,7 +7745,8 @@ inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuild _builtin_options_type, _builtin_options, _custom_options, - _custom_options_format); + _custom_options_format, + _mutating_variable_inputs); } inline SubGraphT *SubGraph::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -7152,8 +7963,8 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(obj); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case BuiltinOptions_SubOptions: { @@ -7244,6 +8055,34 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -7366,8 +8205,8 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(obj); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_SubOptions: { @@ -7458,6 +8297,34 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -7568,9 +8435,9 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateTransposeOptions(_fbb, ptr, _rehasher).Union(); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(value); - return CreateMeanOptions(_fbb, ptr, _rehasher).Union(); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(value); + return CreateReducerOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SubOptions: { auto ptr = reinterpret_cast(value); @@ -7660,6 +8527,34 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateTransposeConvOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + return CreateSparseToDenseOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(value); + return CreateTileOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(value); + return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(value); + return CreateShapeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(value); + return CreatePowOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -7770,8 +8665,8 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new TransposeOptionsT(*reinterpret_cast(u.value)); break; } - case BuiltinOptions_MeanOptions: { - value = new MeanOptionsT(*reinterpret_cast(u.value)); + case BuiltinOptions_ReducerOptions: { + value = new ReducerOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SubOptions: { @@ -7862,6 +8757,34 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new TransposeConvOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_SparseToDenseOptions: { + value = new SparseToDenseOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_TileOptions: { + value = new TileOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ExpandDimsOptions: { + value = new ExpandDimsOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_EqualOptions: { + value = new EqualOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_NotEqualOptions: { + value = new NotEqualOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ShapeOptions: { + value = new ShapeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_PowOptions: { + value = new PowOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -7999,8 +8922,8 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(value); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(value); delete ptr; break; } @@ -8114,6 +9037,41 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc index 2f2004f56bcad5b56f9dd6d4bc824ec14d79e795..4eaf6f1bfe76efc1e6737d03d58be9bc87bb849d 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.cc +++ b/tensorflow/contrib/lite/simple_memory_arena.cc @@ -36,6 +36,12 @@ TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, ArenaAlloc* new_alloc) { TF_LITE_ENSURE(context, alignment < arena_alignment_); + if (size == 0) { + new_alloc->offset = 0; + new_alloc->size = 0; + return kTfLiteOk; + } + size_t current_top = 0; if (!allocs_.empty()) { @@ -75,6 +81,10 @@ TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, TfLiteStatus SimpleMemoryArena::Deallocate(TfLiteContext* context, const ArenaAlloc& alloc) { + if (alloc.size == 0) { + return kTfLiteOk; + } + int erased_allocs_count = 0; auto it = allocs_.begin(); while (it != allocs_.end()) { @@ -122,7 +132,11 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context, char** output_ptr) { TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); - *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + if (alloc.size == 0) { + *output_ptr = nullptr; + } else { + *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + } return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h index 5faf78b59e3755d22e4e866d433e622baa6c66c1..f738315cf2f91403f9dcb6fa9e66b40bd70495aa 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.h +++ b/tensorflow/contrib/lite/simple_memory_arena.h @@ -39,7 +39,8 @@ struct ArenaAlloc { // This small class is responsible for allocating, deallocating and reusing // dynamic memory from a common underlying buffer. The arena can be used in // scenarios when the pattern of memory allocations and deallocations is -// repetitive, e.g. running NN inference in multiple iterations. +// repetitive, e.g. running NN inference in multiple iterations. Note that +// zero-sized allocations are explicitly allowed, and will resolve to null. class SimpleMemoryArena { public: explicit SimpleMemoryArena(size_t arena_alignment) diff --git a/tensorflow/contrib/lite/simple_memory_arena_test.cc b/tensorflow/contrib/lite/simple_memory_arena_test.cc index 4444f642eb75c563c57762d095e454ac63d836c6..60d4d5e768aeda958574422e1c36a7cc2f6a1429 100644 --- a/tensorflow/contrib/lite/simple_memory_arena_test.cc +++ b/tensorflow/contrib/lite/simple_memory_arena_test.cc @@ -43,6 +43,47 @@ TEST(SimpleMemoryArenaTest, BasicArenaOperations) { EXPECT_EQ(allocs[5].offset, 1024); } +TEST(SimpleMemoryArenaTest, BasicZeroAlloc) { + TfLiteContext context; + SimpleMemoryArena arena(64); + ArenaAlloc alloc; + + // Zero-sized allocs should have a 0 offset and size. + ASSERT_EQ(arena.Allocate(&context, 32, 0, &alloc), kTfLiteOk); + EXPECT_EQ(alloc.offset, 0); + EXPECT_EQ(alloc.size, 0); + + // Deallocation of zero-sized allocs should always succeed (even redundantly). + ASSERT_EQ(arena.Deallocate(&context, alloc), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, alloc), kTfLiteOk); + + // The zero-sized alloc should resolve to null. + char* resolved_ptr = nullptr; + ASSERT_EQ(arena.Commit(&context), kTfLiteOk); + ASSERT_EQ(arena.ResolveAlloc(&context, alloc, &resolved_ptr), kTfLiteOk); + EXPECT_EQ(resolved_ptr, nullptr); +} + +TEST(SimpleMemoryArenaTest, InterleavedZeroAlloc) { + TfLiteContext context; + SimpleMemoryArena arena(64); + ArenaAlloc allocs[4]; + + // Interleave some zero and non-zero-sized allocations and deallocations. + ASSERT_EQ(arena.Allocate(&context, 32, 2047, &allocs[0]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 0, &allocs[1]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 1023, &allocs[2]), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, allocs[1]), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, allocs[2]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 2047, &allocs[3]), kTfLiteOk); + + // Deallocation of a zero-sized alloc should not impact the allocator offsets. + EXPECT_EQ(allocs[0].offset, 0); + EXPECT_EQ(allocs[1].offset, 0); + EXPECT_EQ(allocs[2].offset, 2048); + EXPECT_EQ(allocs[3].offset, 2048); +} + TEST(SimpleMemoryArenaTest, TestAfterClear) { TfLiteContext context; SimpleMemoryArena arena(64); diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc index a89776b29f895fe82ee71efe00c0949c58c109df..a316a40b62d89189da43768d448acdf5bbeca129 100644 --- a/tensorflow/contrib/lite/string_util.cc +++ b/tensorflow/contrib/lite/string_util.cc @@ -105,7 +105,7 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) { dims->data[0] = offset_.size() - 1; // Store number of strings. TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params, tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation, - tensor); + tensor->is_variable, tensor); } int GetStringCount(const char* raw_buffer) { diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index a722fe106beeddbf08cb7c8a6dd63ecfc2f80933..789bc695f8e9f8721edeb3b3a3f2af59b36adeed 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -6,7 +6,8 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow/contrib/lite:build_def.bzl", - "gen_zipped_test_files", + "gen_zip_test", + "generated_test_models", ) load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load( @@ -14,60 +15,56 @@ load( "tf_cc_test", ) -gen_zipped_test_files( - name = "optest", - files = [ - "add.zip", - "arg_max.zip", - "avg_pool.zip", - "batch_to_space_nd.zip", - "concat.zip", - "constant.zip", - "control_dep.zip", - "conv.zip", - "depthwiseconv.zip", - "div.zip", - "exp.zip", - "floor.zip", - "fully_connected.zip", - "fused_batch_norm.zip", - "gather.zip", - "global_batch_norm.zip", - "greater.zip", - "greater_equal.zip", - "l2_pool.zip", - "l2norm.zip", - "less.zip", - "less_equal.zip", - "local_response_norm.zip", - "log_softmax.zip", - "max_pool.zip", - "maximum.zip", - "mean.zip", - "minimum.zip", - "mul.zip", - "neg.zip", - "pad.zip", - "padv2.zip", - "relu.zip", - "relu1.zip", - "relu6.zip", - "reshape.zip", - "resize_bilinear.zip", - "sigmoid.zip", - "sin.zip", - "slice.zip", - "softmax.zip", - "space_to_batch_nd.zip", - "space_to_depth.zip", - "split.zip", - "squeeze.zip", - "strided_slice.zip", - "sub.zip", - "topk.zip", - "transpose.zip", - "transpose_conv.zip", - "where.zip", +[gen_zip_test( + name = "zip_test_%s" % test_name, + size = "large", + srcs = ["generated_examples_zip_test.cc"], + args = [ + ] + select({ + "//tensorflow:android": [], + "//conditions:default": [ + "--zip_file_path=$(location :zip_%s)" % test_name, + # TODO(angerson) We may be able to add an external unzip binary instead + # of relying on an existing one for OSS builds. + "--unzip_binary_path=/usr/bin/unzip", + ], + }), + data = [ + ":zip_%s" % test_name, + ], + shard_count = 20, + tags = [ + "gen_zip_test", + "no_oss", + "tflite_not_portable", + ], + test_name = test_name, + deps = [ + ":parse_testdata_lib", + ":tflite_driver", + ":util", + "@com_google_googletest//:gtest", + "@com_googlesource_code_re2//:re2", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ] + select({ + "//conditions:default": [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_test_lib", + ], + }), +) for test_name in generated_test_models()] + +test_suite( + name = "generated_zip_tests", + tags = [ + "gen_zip_test", ], ) @@ -162,6 +159,7 @@ cc_library( deps = [ ":split", ":test_runner", + "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:builtin_ops", ], @@ -174,6 +172,7 @@ cc_test( data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], tags = [ "tflite_not_portable_android", + "tflite_not_portable_ios", ], deps = [ ":tflite_driver", @@ -352,42 +351,4 @@ cc_binary( ], ) -tf_cc_test( - name = "generated_examples_zip_test", - size = "large", - srcs = ["generated_examples_zip_test.cc"], - args = [ - "--zip_files_dir=tensorflow/contrib/lite/testing/optest", - # TODO(angerson) We may be able to add an external unzip binary instead - # of relying on an existing one for OSS builds. - "--unzip_binary_path=/usr/bin/unzip", - ], - data = [":optest"], - shard_count = 20, - tags = [ - "no_oss", - "tflite_not_portable", - ], - deps = [ - ":parse_testdata_lib", - ":tflite_driver", - ":util", - "@com_google_googletest//:gtest", - "@com_googlesource_code_re2//:re2", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", - ] + select({ - "//conditions:default": [ - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib", - "//tensorflow/core:android_tensorflow_test_lib", - ], - }), -) - tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 1008dd6fe2be8bc367c07265bdf2f0ca26391f13..50237ed79232cff0be7ae8c5b125ac1ee7fdf520 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -24,12 +24,15 @@ bazel run //tensorflow/contrib/lite/testing:generate_examples To more easily debug failures use (or override) the --save_graphdefs flag to place text proto graphdefs into the generated zip files. """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse +import functools import itertools +import operator import os import random import re @@ -55,10 +58,11 @@ from tensorflow.python.ops import rnn parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") parser.add_argument("output_path", help="Directory where the outputs will be go.") -parser.add_argument("--zip_to_output", - type=str, - help="Particular zip to output.", - required=False) +parser.add_argument( + "--zip_to_output", + type=str, + help="Particular zip to output.", + required=True) parser.add_argument("--toco", type=str, help="Path to toco tool.", @@ -90,12 +94,10 @@ KNOWN_BUGS = { r"sigmoid.*input_shape=\[\]": "67645668", # Concat doesn't work with a single input tensor r"concat.*num_tensors=1": "67378344", - # Transposition in MatMul is not supported. - r"fully_connected.*transpose_.=True": "67586970", + # Transposition in MatMul is not fully supported. + "fully_connected.*transpose_a=True": "67586970", # Softmax graphs are too complex. r"softmax.*dim=0": "67749831", - # SpaceToDepth only supports float32. - r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", # BatchToSpaceND only supports 4D tensors. r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", # Div will use floordiv. @@ -115,6 +117,8 @@ class ExtraTocoOptions(object): self.allow_custom_ops = False # Rnn states that are used to support rnn / lstm cells. self.rnn_states = None + # Split the LSTM inputs from 5 inoputs to 18 inputs for TFLite. + self.split_tflite_lstm_inputs = None def toco_options(data_types, @@ -133,7 +137,7 @@ def toco_options(data_types, Returns: the options in a string. """ - shape_str = ":".join([",".join(str(y) for y in x) for x in shapes]) + shape_str = ":".join([",".join(str(y) for y in x) for x in shapes if x]) inference_type = "FLOAT" # TODO(ahentz): if we get multi-input quantization to work we need this # to change @@ -143,14 +147,20 @@ def toco_options(data_types, " --inference_type=%s" % inference_type + " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" + " --input_arrays=%s" % ",".join(input_arrays) + - " --input_shapes=%s" % shape_str + " --output_arrays=%s" % ",".join(output_arrays)) + if shape_str: + s += (" --input_shapes=%s" % shape_str) if extra_toco_options.drop_control_dependency: s += " --drop_control_dependency" if extra_toco_options.allow_custom_ops: s += " --allow_custom_ops" if extra_toco_options.rnn_states: s += (" --rnn_states='" + extra_toco_options.rnn_states + "'") + if extra_toco_options.split_tflite_lstm_inputs is not None: + if extra_toco_options.split_tflite_lstm_inputs: + s += " --split_tflite_lstm_inputs=true" + else: + s += " --split_tflite_lstm_inputs=false" return s @@ -235,6 +245,19 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): return value.astype(dtype) +def create_scalar_data(dtype, min_value=-100, max_value=100): + """Build scalar tensor data range from min_value to max_value exclusively.""" + + if dtype in _TF_TYPE_INFO: + dtype = _TF_TYPE_INFO[dtype][0] + + if dtype in (tf.float32, tf.float16): + value = (max_value - min_value) * np.random.random() + min_value + elif dtype in (tf.int32, tf.uint8, tf.int64): + value = np.random.randint(min_value, max_value + 1) + return np.array(value, dtype=dtype) + + def freeze_graph(session, outputs): """Freeze the current graph. @@ -325,6 +348,11 @@ def normalize_output_name(output_name): ":0") else output_name +# How many test cases we may have in a zip file. Too many test cases will +# slow down the test data generation process. +_MAX_TESTS_PER_ZIP = 500 + + def make_zip_of_tests(zip_path, test_parameters, make_graph, @@ -354,19 +382,39 @@ def make_zip_of_tests(zip_path, Raises: RuntimeError: if there are toco errors that can't be ignored. """ + parameter_count = 0 + for parameters in test_parameters: + parameter_count += functools.reduce( + operator.mul, [len(values) for values in parameters.values()]) + + if parameter_count > _MAX_TESTS_PER_ZIP: + raise RuntimeError( + "Too many parameter combinations for generating '%s'.\n" + "There are %d combinations while the upper limit is %d.\n" + "Having too many combinations will slow down the tests.\n" + "Please consider splitting the test into multiple functions.\n" + % (zip_path, parameter_count, _MAX_TESTS_PER_ZIP)) # TODO(aselle): Make this allow multiple inputs outputs. archive = zipfile.PyZipFile(zip_path, "w") zip_manifest = [] convert_report = [] toco_errors = 0 + + processed_labels = set() for parameters in test_parameters: keys = parameters.keys() for curr in itertools.product(*parameters.values()): - label = zip_path.replace(".zip", "") + (",".join( + label = zip_path.replace(".zip", "_") + (",".join( "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", "")) if label[0] == "/": label = label[1:] + if label in processed_labels: + # Do not populate data for the same label more than once. It will cause + # errors when unzipping. + continue + processed_labels.add(label) + param_dict = dict(zip(keys, curr)) def build_example(label, param_dict_real): @@ -419,6 +467,11 @@ def make_zip_of_tests(zip_path, sess, tf.global_variables() + inputs + outputs) if use_frozen_graph else sess.graph_def + + if "split_tflite_lstm_inputs" in param_dict_real: + extra_toco_options.split_tflite_lstm_inputs = param_dict_real[ + "split_tflite_lstm_inputs"] + tflite_model_binary, toco_log = toco_convert( graph_def.SerializeToString(), input_tensors, output_tensors, extra_toco_options) @@ -465,6 +518,7 @@ def make_zip_of_tests(zip_path, report["toco_log"]) convert_report.append((param_dict, report)) + report_io = StringIO() report_lib.make_report_table(report_io, zip_path, convert_report) archive.writestr("report.html", report_io.getvalue()) @@ -651,7 +705,7 @@ def make_constant_tests(zip_path): def make_binary_op_tests(zip_path, binary_operator): - """Make a set of tests to do add with and without broadcast.""" + """Make a set of tests to do binary ops with and without broadcast.""" # These parameters are split because we don't support broadcasting. test_parameters = [{ @@ -701,65 +755,89 @@ def make_binary_op_tests(zip_path, binary_operator): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_mean_tests(zip_path): - """Make a set of tests to do mean.""" +def make_reduce_tests(reduce_op): + """Make a set of tests to do reduce operation. - test_parameters = [{ - "input_dtype": [tf.float32, tf.int32, tf.int64], - "input_shape": [[3, 2, 4]], - "axis": [ - None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], - [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], - [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] - ], - "const_axis": [True, False], - "keepdims": [True, False], - }, { - "input_dtype": [tf.float32, tf.int32, tf.int64], - "input_shape": [[1, 224, 224, 3]], - "axis": [ - None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], - [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, - -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], - [2, 2, 3], [-3, -3, -4], [-3, 2, 1] - ], - "const_axis": [True, False], - "keepdims": [True, False], - }] + Args: + reduce_op: TensorFlow reduce operation to test, i.e. `tf.reduce_mean`. - def build_graph(parameters): - """Build the mean op testing graph.""" - input_tensor = tf.placeholder( - dtype=parameters["input_dtype"], - name="input", - shape=parameters["input_shape"]) + Returns: + a function representing the true generator with `reduce_op_in` curried. + """ - # Get axis as either a placeholder or constants. - if parameters["const_axis"]: - axis = parameters["axis"] - input_tensors = [input_tensor] - else: - if isinstance(parameters["axis"], list): - shape = [len(parameters["axis"])] + def f(zip_path): + """Actual function that generates examples.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape": [[3, 2, 4]], + "axis": [ + None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], + [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], + [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] + ], + "const_axis": [True, False], + "keepdims": [True, False], + }, { + "input_dtype": [tf.float32], + "input_shape": [[1, 8, 8, 3]], + "axis": [ + None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], + [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, + -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], + [2, 2, 3], [-3, -3, -4], [-3, 2, 1] + ], + "const_axis": [True, False], + "keepdims": [True, False], + }] + + def build_graph(parameters): + """Build the mean op testing graph.""" + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + + # Get axis as either a placeholder or constants. + if parameters["const_axis"]: + axis = parameters["axis"] + input_tensors = [input_tensor] else: - shape = [0] # shape for None or integers. - axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) - input_tensors = [input_tensor, axis] + if isinstance(parameters["axis"], list): + shape = [len(parameters["axis"])] + else: + shape = [0] # shape for None or integers. + axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) + input_tensors = [input_tensor, axis] - out = tf.reduce_mean( - input_tensor, axis=axis, keepdims=parameters["keepdims"]) - return input_tensors, [out] + out = reduce_op( + input_tensor, axis=axis, keepdims=parameters["keepdims"]) + return input_tensors, [out] - def build_inputs(parameters, sess, inputs, outputs): - values = [ - create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) - ] - if not parameters["const_axis"]: - if parameters["axis"]: - values.append(np.array(parameters["axis"])) - return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data(parameters["input_dtype"], + parameters["input_shape"])] + if not parameters["const_axis"]: + if parameters["axis"]: + values.append(np.array(parameters["axis"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return f + + +def make_mean_tests(zip_path): + """Make a set of tests to do mean.""" + + return make_reduce_tests(tf.reduce_mean)(zip_path) + + +def make_sum_tests(zip_path): + """Make a set of tests to do sum.""" + + return make_reduce_tests(tf.reduce_sum)(zip_path) def make_exp_tests(zip_path): @@ -912,6 +990,10 @@ def make_mul_tests(zip_path): make_binary_op_tests(zip_path, tf.multiply) +def make_pow_tests(zip_path): + make_binary_op_tests(zip_path, tf.pow) + + def make_gather_tests(zip_path): """Make a set of tests to do gather.""" @@ -1243,6 +1325,12 @@ def make_fully_connected_tests(zip_path): "transpose_a": [False], "transpose_b": [False], "constant_filter": [True, False], + }, { + "shape1": [[40, 37]], + "shape2": [[40, 37]], + "transpose_a": [False], + "transpose_b": [True], + "constant_filter": [True, False], }] def build_graph(parameters): @@ -1313,10 +1401,10 @@ def make_local_response_norm_tests(zip_path): # Chose a set of parameters test_parameters = [{ "input_shape": [[1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]], - "depth_radius": [None, 0, 1, 3, 4, 5], - "bias": [None, 0.1, 0.3, -0.1], - "alpha": [None, 1, 2, -3], - "beta": [None, 0.5, 0.25, 2], + "depth_radius": [None, 0, 1, 3, 5], + "bias": [None, 0.3, -0.1], + "alpha": [None, 2, -3], + "beta": [None, 0.25, 2], }] def build_graph(parameters): @@ -1467,6 +1555,32 @@ def make_reshape_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_shape_tests(zip_path): + """Make a set of tests to do shape.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[], [0], [1, 1, 1, 3], [2, 3, 4, 5], [5, 5], [10]], + "out_type": [tf.int32, tf.int64], + }] + + def build_graph(parameters): + """Build the topk op testing graph.""" + # Note that we intentionally leave out the shape from the input placeholder + # to prevent the Shape operation from being optimized out during conversion. + input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input") + out = tf.shape(input_value, out_type=parameters["out_type"]) + return [input_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_resize_bilinear_tests(zip_path): """Make a set of tests to do resize_bilinear.""" @@ -1548,7 +1662,7 @@ def make_space_to_depth_tests(zip_path): """Make a set of tests to do space_to_depth.""" test_parameters = [{ - "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64], + "dtype": [tf.float32, tf.int32, tf.uint8, tf.int64], "input_shape": [[2, 12, 24, 1]], "block_size": [2, 3, 4], }] @@ -1791,77 +1905,8 @@ def make_squeeze_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_strided_slice_tests(zip_path): - """Make a set of tests to do strided_slice.""" - - # TODO(soroosh): add test/support for uint8. - test_parameters = [ - # 4-D - { - "dtype": [tf.float32, tf.int32, tf.int64], - "index_type": [tf.int32], - "input_shape": [[12, 2, 2, 5]], - "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], - "end": [[8, 2, 2, 3], [12, 2, 2, 5]], - "strides": [None, [2, 1, 3, 1]], - "begin_mask": [None, 1, 8], - "end_mask": [None, 1, 8], - "shrink_axis_mask": [None, 1, 8, 11, 15, -1], - "constant_indices": [False, True], - }, - # Begin, end, strides dim are different from input shape - { - "dtype": [tf.float32], - "index_type": [tf.int32], - "input_shape": [[12, 2, 2, 5]], - "begin": [[0]], - "end": [[1]], - "strides": [None, [1]], - "begin_mask": [0], - "end_mask": [0], - "shrink_axis_mask": [1], - "constant_indices": [True], - }, - # 2-D - { - "dtype": [tf.float32, tf.int32, tf.int64], - "index_type": [tf.int32], - "input_shape": [[2, 3]], - "begin": [[0, 0], [1, 0]], - "end": [[2, 3], [2, 2]], - "strides": [None, [2, 2]], - "begin_mask": [None, 1, 2], - "end_mask": [None, 1, 2], - "shrink_axis_mask": [None, 1, 2, 3, -1], - "constant_indices": [False, True], - }, - # 1-D Exhaustive - { - "dtype": [tf.float32], - "index_type": [tf.int32], - "input_shape": [[4]], - "begin": [[-100], [-3], [-2], [-1], [0], [1], [2], [3], [100]], - "end": [[-100], [-3], [-2], [-1], [0], [1], [2], [3], [100]], - "strides": [-2, -1, 1, 2], - "begin_mask": [0, 1], - "end_mask": [0, 1], - "shrink_axis_mask": [0], - "constant_indices": [False], - }, - # Negative strides - { - "dtype": [tf.float32], - "index_type": [tf.int32], - "input_shape": [[2, 3]], - "begin": [[0, -1]], - "end": [[2, -3]], - "strides": [[1, -1]], - "begin_mask": [None, 1, 2], - "end_mask": [None, 1, 2], - "shrink_axis_mask": [None, 1, 2, 3, -1], - "constant_indices": [False], - }, - ] +def _make_strided_slice_tests(zip_path, test_parameters): + """Utility function to make strided_slice_tests based on parameters.""" def build_graph(parameters): """Build graph for stride_slice test.""" @@ -1923,6 +1968,100 @@ def make_strided_slice_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_strided_slice_tests(zip_path): + """Make a set of tests to do strided_slice.""" + + # TODO(soroosh): add test/support for uint8. + test_parameters = [ + # 4-D (basic cases with const/non-const indices). + { + "dtype": [tf.float32, tf.int32, tf.int64], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "strides": [None, [2, 1, 3, 1]], + "begin": [[0, 0, 0, 0]], + "end": [[12, 2, 2, 5]], + "begin_mask": [None], + "end_mask": [None], + "shrink_axis_mask": [None], + "constant_indices": [False, True], + }, + # 4-D with non-trivial begin & end. + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], + "end": [[8, 2, 2, 3], [12, 2, 2, 5]], + "strides": [None, [2, 1, 3, 1]], + "begin_mask": [None, 8], + "end_mask": [None, 3], + "shrink_axis_mask": [None, 15, -1], + "constant_indices": [True], + }, + # Begin, end, strides dim are different from input shape + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0]], + "end": [[1]], + "strides": [None, [1]], + "begin_mask": [0], + "end_mask": [0], + "shrink_axis_mask": [1], + "constant_indices": [True], + }, + # 2-D + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[2, 3]], + "begin": [[0, 0]], + "end": [[2, 2]], + "strides": [None, [2, 2]], + "begin_mask": [None, 1, 2], + "end_mask": [None, 1, 2], + "shrink_axis_mask": [None, 1, 2, 3, -1], + "constant_indices": [False, True], + }, + # Negative strides + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[2, 3]], + "begin": [[0, -1]], + "end": [[2, -3]], + "strides": [[1, -1]], + "begin_mask": [None, 1, 2], + "end_mask": [None, 1, 2], + "shrink_axis_mask": [None, 1, 2, 3, -1], + "constant_indices": [False], + }, + ] + _make_strided_slice_tests(zip_path, test_parameters) + + +def make_strided_slice_1d_exhaustive_tests(zip_path): + """Make a set of exhaustive tests for 1D strided_slice.""" + test_parameters = [ + # 1-D Exhaustive + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[3]], + "begin": [[-2], [-1], [0], [1], [2]], + "end": [[-2], [-1], [0], [1], [2]], + "strides": [[-2], [-1], [1], [2]], + "begin_mask": [0, 1], + "end_mask": [0, 1], + "shrink_axis_mask": [0], + "constant_indices": [False], + }, + ] + _make_strided_slice_tests(zip_path, test_parameters) + + def make_lstm_tests(zip_path): """Make a set of tests to do basic Lstm cell.""" @@ -1933,6 +2072,7 @@ def make_lstm_tests(zip_path): "time_step_size": [1], "input_vec_size": [3], "num_cells": [4], + "split_tflite_lstm_inputs": [True, False], }, ] @@ -2067,6 +2207,74 @@ def make_arg_max_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_equal_tests(zip_path): + """Make a set of tests to do equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the equal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_not_equal_tests(zip_path): + """Make a set of tests to do not equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the not euqal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.not_equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_greater_tests(zip_path): """Make a set of tests to do greater.""" @@ -2254,30 +2462,54 @@ def make_neg_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def _make_elementwise_tests(op): + """Make a set of tests to do element-wise operations.""" + + def f(zip_path): + """Actual function that generates examples.""" + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + }] + + def build_graph(parameters): + """Build the unary op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape"]) + out = op(input_value) + return [input_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict={inputs[0]: input_value}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return f + + def make_sin_tests(zip_path): """Make a set of tests to do sin.""" + return _make_elementwise_tests(tf.sin)(zip_path) - test_parameters = [{ - "input_dtype": [tf.float32], - "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], - }] - def build_graph(parameters): - """Build the sin op testing graph.""" - input_value = tf.placeholder( - dtype=parameters["input_dtype"], - name="input1", - shape=parameters["input_shape"]) - out = tf.sin(input_value) - return [input_value], [out] +def make_log_tests(zip_path): + """Make a set of tests to do log.""" + return _make_elementwise_tests(tf.log)(zip_path) - def build_inputs(parameters, sess, inputs, outputs): - input_value = create_tensor_data(parameters["input_dtype"], - parameters["input_shape"]) - return [input_value], sess.run( - outputs, feed_dict={inputs[0]: input_value}) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_sqrt_tests(zip_path): + """Make a set of tests to do sqrt.""" + return _make_elementwise_tests(tf.sqrt)(zip_path) + + +def make_rsqrt_tests(zip_path): + """Make a set of tests to do 1/sqrt.""" + return _make_elementwise_tests(tf.rsqrt)(zip_path) def make_where_tests(zip_path): @@ -2431,6 +2663,134 @@ def make_transpose_conv_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_tile_tests(zip_path): + """Make a set of tests to do tile.""" + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[3, 2, 1], [2, 2, 2]], + "multiplier_dtype": [tf.int32, tf.int64], + "multiplier_shape": [[3]] + }] + + def build_graph(parameters): + """Build the tile op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + shape=parameters["input_shape"], + name="input") + multiplier_value = tf.placeholder( + dtype=parameters["multiplier_dtype"], + shape=parameters["multiplier_shape"], + name="multiplier") + out = tf.tile(input_value, multiplier_value) + return [input_value, multiplier_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + multipliers_value = create_tensor_data(parameters["multiplier_dtype"], + parameters["multiplier_shape"]) + return [input_value, multipliers_value], sess.run( + outputs, + feed_dict={ + inputs[0]: input_value, + inputs[1]: multipliers_value + }) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_expand_dims_tests(zip_path): + """Make a set of tests to do expand_dims.""" + + test_parameters = [{ + "input_type": [tf.float32, tf.int32], + "input_shape": [[3, 4], [10, 10, 3]], + "axis_value": [0, 1, 2, -1, -2], + }] + + def build_graph(parameters): + """Build the where op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_type"], + name="input", + shape=parameters["input_shape"]) + axis_value = tf.placeholder(dtype=tf.int32, name="axis", shape=[1]) + out = tf.expand_dims(input_value, axis=axis_value) + return [input_value, axis_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_type"], + parameters["input_shape"]) + axis_value = np.array([parameters["axis_value"]], dtype=np.int32) + return [input_value, axis_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value, axis_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_sparse_to_dense_tests(zip_path): + """Make a set of tests to do sparse to dense.""" + + test_parameters = [{ + "value_dtype": [tf.float32, tf.int32], + "index_dtype": [tf.int32, tf.int64], + "value_count": [1, 3, 6, 8], + "dense_shape": [[15], [3, 10], [4, 4, 4, 4], [7, 10, 9]], + "default_value": [0, -1], + "value_is_scalar": [True, False], + }] + + # Return a single value for 1-D dense shape, but a tuple for other shapes. + def generate_index(dense_shape): + if len(dense_shape) == 1: + return np.random.randint(dense_shape[0]) + else: + index = [] + for shape in dense_shape: + index.append(np.random.randint(shape)) + return tuple(index) + + def build_graph(parameters): + """Build the sparse_to_dense op testing graph.""" + dense_shape = parameters["dense_shape"] + + # Special handle for value_is_scalar case. + # value_count must be 1. + if parameters["value_is_scalar"] and parameters["value_count"] == 1: + value = tf.placeholder( + name="value", dtype=parameters["value_dtype"], shape=()) + else: + value = tf.placeholder( + name="value", + dtype=parameters["value_dtype"], + shape=[parameters["value_count"]]) + indices = set() + while len(indices) < parameters["value_count"]: + indices.add(generate_index(dense_shape)) + indices = tf.constant(tuple(indices), dtype=parameters["index_dtype"]) + # TODO(renjieliu): Add test for validate_indices case. + out = tf.sparse_to_dense( + indices, + dense_shape, + value, + parameters["default_value"], + validate_indices=False) + + return [value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + if parameters["value_is_scalar"] and parameters["value_count"] == 1: + input_value = create_scalar_data(parameters["value_dtype"]) + else: + input_value = create_tensor_data(parameters["value_dtype"], + [parameters["value_count"]]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc index c0c861ff6da2fc144b9303dfdd48f19794cebeca..c1092e4d25567f0374e3cd5a27bde32419d3db19 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.cc +++ b/tensorflow/contrib/lite/testing/generate_testspec.cc @@ -25,7 +25,7 @@ namespace testing { template void GenerateCsv(const std::vector& shape, float min, float max, string* out) { - auto random_float = [](int min, int max) { + auto random_float = [](float min, float max) { static unsigned int seed; return min + (max - min) * static_cast(rand_r(&seed)) / RAND_MAX; }; @@ -37,16 +37,10 @@ void GenerateCsv(const std::vector& shape, float min, float max, *out = Join(data.data(), data.size(), ","); } -bool GenerateTestSpecFromTensorflowModel( - std::iostream& stream, const string& tensorflow_model_path, - const string& tflite_model_path, const std::vector& input_layer, +std::vector GenerateInputValues( + const std::vector& input_layer, const std::vector& input_layer_type, - const std::vector& input_layer_shape, - const std::vector& output_layer) { - CHECK_EQ(input_layer.size(), input_layer_type.size()); - CHECK_EQ(input_layer.size(), input_layer_shape.size()); - - // Generate inputs. + const std::vector& input_layer_shape) { std::vector input_values; input_values.resize(input_layer.size()); for (int i = 0; i < input_layer.size(); i++) { @@ -73,9 +67,22 @@ bool GenerateTestSpecFromTensorflowModel( default: fprintf(stderr, "Unsupported type %d (%s) when generating testspec.\n", type, input_layer_type[i].c_str()); - return false; + input_values.clear(); + return input_values; } } + return input_values; +} + +bool GenerateTestSpecFromTensorflowModel( + std::iostream& stream, const string& tensorflow_model_path, + const string& tflite_model_path, int num_invocations, + const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer) { + CHECK_EQ(input_layer.size(), input_layer_type.size()); + CHECK_EQ(input_layer.size(), input_layer_shape.size()); // Invoke tensorflow model. TfDriver runner(input_layer, input_layer_type, input_layer_shape, @@ -91,39 +98,51 @@ bool GenerateTestSpecFromTensorflowModel( return false; } - for (int i = 0; i < input_values.size(); i++) { - runner.SetInput(i, input_values[i]); - if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; - return false; - } - } - - runner.Invoke(); - if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; - return false; - } - - // Write test spec. + // Write first part of test spec, defining model and input shapes. stream << "load_model: " << tflite_model_path << "\n"; stream << "reshape {\n"; for (const auto& shape : input_layer_shape) { stream << " input: \"" << shape << "\"\n"; } stream << "}\n"; - stream << "invoke {\n"; - for (const auto& value : input_values) { - stream << " input: \"" << value << "\"\n"; - } - for (int i = 0; i < output_layer.size(); i++) { - stream << " output: \"" << runner.ReadOutput(i) << "\"\n"; + + // Generate inputs. + for (int i = 0; i < num_invocations; ++i) { + // Note that the input values are random, so each invocation will have a + // different set. + std::vector input_values = + GenerateInputValues(input_layer, input_layer_type, input_layer_shape); + if (input_values.empty()) return false; + + // Run TensorFlow. + for (int j = 0; j < input_values.size(); j++) { + runner.SetInput(j, input_values[j]); + if (!runner.IsValid()) { + cerr << runner.GetErrorMessage() << endl; + return false; + } + } + + runner.Invoke(); if (!runner.IsValid()) { cerr << runner.GetErrorMessage() << endl; return false; } + + // Write second part of test spec, with inputs and outputs. + stream << "invoke {\n"; + for (const auto& value : input_values) { + stream << " input: \"" << value << "\"\n"; + } + for (int j = 0; j < output_layer.size(); j++) { + stream << " output: \"" << runner.ReadOutput(j) << "\"\n"; + if (!runner.IsValid()) { + cerr << runner.GetErrorMessage() << endl; + return false; + } + } + stream << "}\n"; } - stream << "}\n"; return true; } diff --git a/tensorflow/contrib/lite/testing/generate_testspec.h b/tensorflow/contrib/lite/testing/generate_testspec.h index 6e31a853c3f7f82a89126ff83af784ffd418741a..bfaf5e7ec89bbdd85b68a7dc45d7686e143e5d3d 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.h +++ b/tensorflow/contrib/lite/testing/generate_testspec.h @@ -30,13 +30,15 @@ namespace testing { // stream: mutable iostream that contains the contents of test spec. // tensorflow_model_path: path to TensorFlow model. // tflite_model_path: path to tflite_model_path that the test spec runs +// num_invocations: how many pairs of inputs and outputs will be generated. // against. input_layer: names of input tensors. Example: input1 // input_layer_type: datatypes of input tensors. Example: float // input_layer_shape: shapes of input tensors, separated by comma. example: // 1,3,4 output_layer: names of output tensors. Example: output bool GenerateTestSpecFromTensorflowModel( std::iostream& stream, const string& tensorflow_model_path, - const string& tflite_model_path, const std::vector& input_layer, + const string& tflite_model_path, int num_invocations, + const std::vector& input_layer, const std::vector& input_layer_type, const std::vector& input_layer_shape, const std::vector& output_layer); diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index c085ea28ea9c95a796bf6a76f77b8d6d19aee36a..c4e20312d891be6f659845fe4fc66e085955b81b 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -35,8 +35,14 @@ namespace { bool FLAGS_ignore_known_bugs = true; // TODO(b/71769302) zip_files_dir should have a more accurate default, if // possible -string* FLAGS_zip_files_dir = new string("./"); +string* FLAGS_zip_file_path = new string("./"); +#ifndef __ANDROID__ string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip"); +#else +string* FLAGS_unzip_binary_path = new string("/system/bin/unzip"); +#endif +bool FLAGS_use_nnapi = false; +bool FLAGS_ignore_unsupported_nnapi = false; } // namespace // TensorFlow system environment for file system called. @@ -47,9 +53,6 @@ tensorflow::Env* env = tensorflow::Env::Default(); // Key is a substring of the test name and value is a bug number. // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { - // Add only supports float32. (and "constant" tests use Add) - {R"(^\/adda.*int32)", "68808744"}, - {R"(^\/constant.*int32)", "68808744"}, {R"(^\/mul.*int32)", "68808744"}, {R"(^\/div.*int32)", "68808744"}, {R"(^\/sub.*int32)", "68808744"}, @@ -61,25 +64,25 @@ std::map kBrokenTests = { "70527055"}, // L2Norm only supports tensors with 4D or fewer. - {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, + {R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, // SpaceToBatchND only supports 4D tensors. {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, // L2Norm only works for dim=-1. - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, // ResizeBilinear looks completely incompatible with Tensorflow {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"}, @@ -137,7 +140,10 @@ class ZipEnvironment : public ::testing::Environment { *out_dir = dir; return tensorflow::Status::OK(); } else { - return tensorflow::Status(tensorflow::error::UNKNOWN, "unzip failed"); + return tensorflow::Status(tensorflow::error::UNKNOWN, + "unzip failed. " + "stdout:\n" + + out + "\nstderr:\n" + err); } } @@ -191,8 +197,7 @@ tensorflow::Status ReadManifest(const string& original_file, const string& dir, } // Get a list of tests from a zip file `zip_file_name`. -std::vector UnarchiveZipAndFindTestNames(const string& zip_file_name) { - string zip_file = *FLAGS_zip_files_dir + "/" + zip_file_name; +std::vector UnarchiveZipAndFindTestNames(const string& zip_file) { string decompress_tmp_dir; TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir)); std::vector stuff; @@ -210,7 +215,7 @@ TEST_P(OpsTest, RunZipTests) { std::ifstream tflite_stream(tflite_test_case); ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case; - tflite::testing::TfLiteDriver test_driver(/*use_nnapi=*/true); + tflite::testing::TfLiteDriver test_driver(FLAGS_use_nnapi); test_driver.SetModelBaseDir(tflite_dir); string bug_number; @@ -221,16 +226,21 @@ TEST_P(OpsTest, RunZipTests) { } bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver); + string message = test_driver.GetErrorMessage(); if (bug_number.empty()) { - EXPECT_TRUE(result) << test_driver.GetErrorMessage(); + if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) { + EXPECT_EQ(message, string("Failed to invoke interpreter")) << message; + } else { + EXPECT_TRUE(result) << message; + } } else { if (FLAGS_ignore_known_bugs) { EXPECT_FALSE(result) << "Test was expected to fail but is now passing; " "you can mark http://b/" << bug_number << " as fixed! Yay!"; } else { - EXPECT_TRUE(result) << test_driver.GetErrorMessage() - << ": Possibly due to http://b/" << bug_number; + EXPECT_TRUE(result) << message << ": Possibly due to http://b/" + << bug_number; } } } @@ -251,66 +261,10 @@ struct ZipPathParamName { } }; -// Instantiate a test. This assumes `zip_base`.zip is a declared data file -// of this test. -#define INSTANTIATE_TESTS(zip_base) \ - INSTANTIATE_TEST_CASE_P( \ - zip_base, OpsTest, \ - ::testing::ValuesIn(UnarchiveZipAndFindTestNames(#zip_base ".zip")), \ - ZipPathParamName()); - -INSTANTIATE_TESTS(add) -INSTANTIATE_TESTS(arg_max) -INSTANTIATE_TESTS(avg_pool) -INSTANTIATE_TESTS(batch_to_space_nd) -INSTANTIATE_TESTS(concat) -INSTANTIATE_TESTS(constant) -INSTANTIATE_TESTS(control_dep) -INSTANTIATE_TESTS(conv) -INSTANTIATE_TESTS(depthwiseconv) -INSTANTIATE_TESTS(div) -INSTANTIATE_TESTS(exp) -INSTANTIATE_TESTS(floor) -INSTANTIATE_TESTS(fully_connected) -INSTANTIATE_TESTS(fused_batch_norm) -INSTANTIATE_TESTS(gather) -INSTANTIATE_TESTS(global_batch_norm) -INSTANTIATE_TESTS(greater) -INSTANTIATE_TESTS(greater_equal) -INSTANTIATE_TESTS(l2_pool) -INSTANTIATE_TESTS(l2norm) -INSTANTIATE_TESTS(less) -INSTANTIATE_TESTS(less_equal) -INSTANTIATE_TESTS(local_response_norm) -INSTANTIATE_TESTS(log_softmax) -INSTANTIATE_TESTS(max_pool) -INSTANTIATE_TESTS(maximum) -INSTANTIATE_TESTS(mean) -INSTANTIATE_TESTS(minimum) -INSTANTIATE_TESTS(mul) -INSTANTIATE_TESTS(neg) -INSTANTIATE_TESTS(pad) -INSTANTIATE_TESTS(padv2) -// INSTANTIATE_TESTS(prelu) -INSTANTIATE_TESTS(relu) -INSTANTIATE_TESTS(relu1) -INSTANTIATE_TESTS(relu6) -INSTANTIATE_TESTS(reshape) -INSTANTIATE_TESTS(resize_bilinear) -INSTANTIATE_TESTS(sigmoid) -INSTANTIATE_TESTS(sin) -INSTANTIATE_TESTS(slice) -INSTANTIATE_TESTS(softmax) -INSTANTIATE_TESTS(space_to_batch_nd) -INSTANTIATE_TESTS(space_to_depth) -INSTANTIATE_TESTS(split) -INSTANTIATE_TESTS(squeeze) -INSTANTIATE_TESTS(strided_slice) -INSTANTIATE_TESTS(sub) -INSTANTIATE_TESTS(topk) -INSTANTIATE_TESTS(transpose) -INSTANTIATE_TESTS(transpose_conv) -INSTANTIATE_TESTS(where) +INSTANTIATE_TEST_CASE_P( + tests, OpsTest, + ::testing::ValuesIn(UnarchiveZipAndFindTestNames(*FLAGS_zip_file_path)), + ZipPathParamName()); } // namespace testing } // namespace tflite @@ -323,11 +277,17 @@ int main(int argc, char** argv) { "ignore_known_bugs", &tflite::testing::FLAGS_ignore_known_bugs, "If a particular model is affected by a known bug, the " "corresponding test should expect the outputs to not match."), - tensorflow::Flag("zip_files_dir", tflite::testing::FLAGS_zip_files_dir, - "Required: Location of the test zips."), + tensorflow::Flag("zip_file_path", tflite::testing::FLAGS_zip_file_path, + "Required: Location of the test zip file."), tensorflow::Flag("unzip_binary_path", tflite::testing::FLAGS_unzip_binary_path, - "Required: Location of a suitable unzip binary.")}; + "Required: Location of a suitable unzip binary."), + tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi, + "Whether to enable the NNAPI delegate"), + tensorflow::Flag("ignore_unsupported_nnapi", + &tflite::testing::FLAGS_ignore_unsupported_nnapi, + "Don't fail tests just because delegation to NNAPI " + "is not possible")}; bool success = tensorflow::Flags::Parse(&argc, argv, flags); if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) { fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); @@ -335,6 +295,8 @@ int main(int argc, char** argv) { } ::tflite::LogToStderr(); + // TODO(mikie): googletest arguments do not work - maybe the tensorflow flags + // parser removes them? ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc index 5afa0f800cdaa8bf70a11cb6e2ac64ace8138e79..f2c49fe389763110279b3dd1e4f13b1522de0460 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc +++ b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc @@ -20,12 +20,29 @@ int main(int argc, char** argv) { ::tflite::testing::DiffOptions options = ::tflite::testing::ParseTfliteDiffFlags(&argc, argv); if (options.tensorflow_model.empty()) return 1; + int failure_count = 0; - for (int i = 0; i < 100; i++) { - if (!tflite::testing::RunDiffTest(options)) { + for (int i = 0; i < options.num_runs_per_pass; i++) { + if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/1)) { ++failure_count; } } - fprintf(stderr, "Num errors: %d\n", failure_count); + int failures_in_first_pass = failure_count; + + if (failure_count == 0) { + // Let's try again with num_invocations > 1 to make sure we can do multiple + // invocations without resetting the interpreter. + for (int i = 0; i < options.num_runs_per_pass; i++) { + if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/2)) { + ++failure_count; + } + } + } + + fprintf(stderr, "Num errors in single-inference pass: %d\n", + failures_in_first_pass); + fprintf(stderr, "Num errors in multi-inference pass : %d\n", + failure_count - failures_in_first_pass); + return failure_count != 0 ? 1 : 0; } diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h index 706108ed73bb3fd9bd784cffffe322d6981433e6..7a57e8d3fba29cd106eb038992bb5ed12bb457ae 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h @@ -30,6 +30,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { string input_layer_type; string input_layer_shape; string output_layer; + int32_t num_runs_per_pass = 100; } values; std::vector flags = { @@ -49,6 +50,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { tensorflow::Flag("output_layer", &values.output_layer, "Names of output tensors, separated by comma. Example " "output_1,output_2"), + tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass, + "Number of full runs in each pass."), }; bool no_inputs = *argc == 1; @@ -63,7 +66,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { Split(values.input_layer, ","), Split(values.input_layer_type, ","), Split(values.input_layer_shape, ":"), - Split(values.output_layer, ",")}; + Split(values.output_layer, ","), + values.num_runs_per_pass}; } } // namespace testing diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/contrib/lite/testing/tflite_diff_util.cc index f601d3752ddb5df9f2b5ac73d9bc303efaade4a5..19f34c0a51e442804bf2824adc3a1d8bde1eb4b0 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_util.cc +++ b/tensorflow/contrib/lite/testing/tflite_diff_util.cc @@ -25,13 +25,14 @@ limitations under the License. namespace tflite { namespace testing { -bool RunDiffTest(const DiffOptions& options) { +bool RunDiffTest(const DiffOptions& options, int num_invocations) { std::stringstream tflite_stream; if (!GenerateTestSpecFromTensorflowModel( tflite_stream, options.tensorflow_model, options.tflite_model, - options.input_layer, options.input_layer_type, - options.input_layer_shape, options.output_layer)) + num_invocations, options.input_layer, options.input_layer_type, + options.input_layer_shape, options.output_layer)) { return false; + } TfLiteDriver tflite_driver(/*use_nnapi=*/true); tflite_driver.LoadModel(options.tflite_model); return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver); diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h index 326fa6c3e28000dee9b6eb9cc5b3a6c5c87e28d0..4ab2f230fdcdfe4616ab1706aa41f0e806665f66 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_util.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h @@ -40,10 +40,14 @@ struct DiffOptions { // Names of output tensors. // Example output_1,output_2 std::vector output_layer; + // Number of full runs (from building interpreter to checking outputs) in + // each of the passes. The first pass has a single inference, while the + // second pass does multiple inferences back to back. + int num_runs_per_pass; }; // Run a single TensorFLow Lite diff test with a given options. -bool RunDiffTest(const DiffOptions& options); +bool RunDiffTest(const DiffOptions& options, int num_invocations); } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 75ac24719aa8fad960ae06d006eda386d44d721a..4d08fb545801521213890a4f5a9b010de57b27cd 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/testing/split.h" namespace tflite { @@ -143,6 +144,7 @@ void TfLiteDriver::AllocateTensors() { Invalidate("Failed to allocate tensors"); return; } + ResetLSTMStateTensors(); must_allocate_tensors_ = false; } } @@ -161,6 +163,7 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) { Invalidate("Failed build interpreter"); return; } + interpreter_->UseNNAPI(use_nnapi_); must_allocate_tensors_ = true; } @@ -281,5 +284,31 @@ bool TfLiteDriver::CheckResults() { return success; } +void TfLiteDriver::ResetLSTMStateTensors() { + interpreter_->ResetVariableTensorsToZero(); + + // Below is a workaround for initializing state tensors for LSTM. + // TODO(ycling): Remove the code below after nobody is using the 18-inputs + // definition. + for (auto node_index : interpreter_->execution_plan()) { + const auto& node_and_reg = interpreter_->node_and_registration(node_index); + const auto& node = node_and_reg->first; + const auto& registration = node_and_reg->second; + + if (registration.builtin_code == tflite::BuiltinOperator_LSTM) { + const auto* params = + reinterpret_cast(node.builtin_data); + if (params->kernel_type == kTfLiteLSTMFullKernel && + node.inputs->size == 18 && node.outputs->size >= 2) { + // The first 2 outputs of LSTM are state tensors. + for (int i = 0; i < 2; ++i) { + int node_index = node.outputs->data[i]; + ResetTensor(node_index); + } + } + } + } +} + } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h index 02b7de1534e648734d7bc53154afa42f2ef256b4..5493ba3631b0423942cc9c4f98fbd6393a404060 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.h +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -48,6 +48,8 @@ class TfLiteDriver : public TestRunner { string ReadOutput(int id) override { return "no-op"; } private: + void ResetLSTMStateTensors(); + class Expectation; bool use_nnapi_ = false; diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index b8acc9a8e0361a4c38fcbe2f16be172e637b95c6..209dce56cbdfbbff5884aa9961bd29e9cf98f49d 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -143,7 +143,6 @@ cc_library( ":toco_graphviz_dump_options", ":toco_port", ":types_proto_cc", - "//tensorflow/cc/saved_model:tag_constants", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings", @@ -169,41 +168,6 @@ cc_library( ], ) -cc_library( - name = "toco_saved_model", - srcs = [ - "toco_saved_model.cc", - ], - hdrs = [ - "toco_saved_model.h", - ], - visibility = ["//visibility:public"], - deps = [ - ":model_cmdline_flags", - ":model_flags_proto_cc", - ":toco_flags_proto_cc", - ":types_proto_cc", - "//tensorflow/cc/tools:freeze_saved_model", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "toco_saved_model_test", - srcs = ["toco_saved_model_test.cc"], - deps = [ - ":model_cmdline_flags", - ":toco_cmdline_flags", - ":toco_saved_model", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core:test", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) - cc_library( name = "graph_transformations", srcs = [ @@ -213,6 +177,7 @@ cc_library( "graph_transformations/convert_squeeze_to_reshape.cc", "graph_transformations/convert_trivial_addn_to_add.cc", "graph_transformations/convert_trivial_stack_to_reshape.cc", + "graph_transformations/convert_trivial_tile_to_concat.cc", "graph_transformations/convert_trivial_transpose_to_reshape.cc", "graph_transformations/create_im2col_arrays.cc", "graph_transformations/dequantize.cc", @@ -220,10 +185,10 @@ cc_library( "graph_transformations/drop_im2col_arrays.cc", "graph_transformations/ensure_bias_vectors.cc", "graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc", - "graph_transformations/experimental_shuffle_fc_weights.cc", "graph_transformations/fuse_activation_functions.cc", "graph_transformations/fuse_binary_into_following_affine.cc", "graph_transformations/fuse_binary_into_preceding_affine.cc", + "graph_transformations/fuse_broadcast_into_following_binary.cc", "graph_transformations/graph_transformations.cc", "graph_transformations/hardcode_min_max.cc", "graph_transformations/identify_dilated_conv.cc", @@ -237,6 +202,7 @@ cc_library( "graph_transformations/lstm_utils.cc", "graph_transformations/make_initial_dequantize_operator.cc", "graph_transformations/merge_reshape_into_preceding_transpose.cc", + "graph_transformations/move_binary_operator_before_reshape.cc", "graph_transformations/propagate_activation_function_into_constants.cc", "graph_transformations/propagate_array_data_types.cc", "graph_transformations/propagate_default_min_max.cc", @@ -245,6 +211,7 @@ cc_library( "graph_transformations/quantization_util.cc", "graph_transformations/quantization_util.h", "graph_transformations/quantize.cc", + "graph_transformations/quantize_weights.cc", "graph_transformations/read_fake_quant_min_max.cc", "graph_transformations/remove_final_dequantize_op.cc", "graph_transformations/remove_tensorflow_assert.cc", @@ -292,8 +259,8 @@ cc_library( "graph_transformations/resolve_tensorflow_matmul.cc", "graph_transformations/resolve_tensorflow_merge.cc", "graph_transformations/resolve_tensorflow_switch.cc", - "graph_transformations/resolve_tensorflow_tile.cc", "graph_transformations/resolve_transpose_attributes.cc", + "graph_transformations/shuffle_fc_weights.cc", "graph_transformations/unfuse_activation_functions.cc", "graph_transformations/unpartition_embedding_lookup.cc", "graph_transformations/unroll_batch_matmul.cc", @@ -373,6 +340,7 @@ tf_cc_test( ":toco_tooling", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_googletest//:gtest_main", ], @@ -410,6 +378,7 @@ tf_cc_test( deps = [ ":model", ":tooling_util", + "//tensorflow/core:lib", "@com_google_googletest//:gtest_main", ], ) @@ -427,7 +396,6 @@ tf_cc_binary( ":toco_cmdline_flags", ":toco_flags_proto_cc", ":toco_port", - ":toco_saved_model", ":toco_tooling", ":types_proto_cc", "//tensorflow/core:lib", diff --git a/tensorflow/contrib/lite/toco/README.md b/tensorflow/contrib/lite/toco/README.md index 522e260ad2a14c5f8e080c0a0f538f4192b7ed2d..2db6a627ab59604a99cafe3b38df08b70092d989 100644 --- a/tensorflow/contrib/lite/toco/README.md +++ b/tensorflow/contrib/lite/toco/README.md @@ -17,11 +17,12 @@ Usage information is given in these documents: Once an application developer has a trained TensorFlow model, TOCO will accept that model and generate a TensorFlow Lite [FlatBuffer](https://google.github.io/flatbuffers/) file. TOCO currently supports -[SavedModels](https://www.tensorflow.org/programmers_guide/saved_model#using_savedmodel_with_estimators) -and frozen graphs (models generated via -[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)). -The TensorFlow Lite FlatBuffer file can be shipped to client devices, generally -mobile devices, where the TensorFlow Lite interpreter handles them on-device. -This flow is represented in the diagram below. +[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators), +frozen graphs (models generated via +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)), +and `tf.Keras` model files. The TensorFlow Lite FlatBuffer file can be shipped +to client devices, generally mobile devices, where the TensorFlow Lite +interpreter handles them on-device. This flow is represented in the diagram +below. ![drawing](g3doc/toco_landscape.svg) diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 6c0311af0a926711955caaa1c7507d7c52c77069..aef35ad490656c09a7d7314aa033bc985b3af661 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -21,13 +21,13 @@ limitations under the License. #include #include #include +#include "tensorflow/contrib/lite/toco/toco_port.h" #if defined(PLATFORM_GOOGLE) #include "strings/split.h" +#include "strings/strip.h" #endif #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" -#include "tensorflow/cc/saved_model/tag_constants.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_types.h" namespace toco { @@ -145,8 +145,10 @@ class Arg final { } string outer_member_copy = outer_member; absl::StripAsciiWhitespace(&outer_member); - if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false; - if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false; + if (!strings::TryStripPrefixString(outer_member, "{", &outer_member)) + return false; + if (!strings::TryStripSuffixString(outer_member, "}", &outer_member)) + return false; const std::vector inner_fields_vector = absl::StrSplit(outer_member, ','); @@ -223,7 +225,7 @@ struct ParsedTocoFlags { Arg output_file; Arg input_format = Arg("TENSORFLOW_GRAPHDEF"); Arg output_format = Arg("TFLITE"); - Arg savedmodel_tagset = Arg(tensorflow::kSavedModelTagServe); + Arg savedmodel_tagset; // TODO(aselle): command_line_flags doesn't support doubles Arg default_ranges_min = Arg(0.); Arg default_ranges_max = Arg(0.); @@ -234,6 +236,7 @@ struct ParsedTocoFlags { Arg drop_fake_quant = Arg(false); Arg reorder_across_fake_quant = Arg(false); Arg allow_custom_ops = Arg(false); + Arg quantize_weights = Arg(false); // Deprecated flags Arg input_type; Arg input_types; @@ -242,6 +245,7 @@ struct ParsedTocoFlags { Arg propagate_fake_quant_num_bits = Arg(false); Arg allow_nudging_weights_to_use_fast_gemm_kernel = Arg(false); Arg dedupe_array_min_size_bytes = Arg(64); + Arg split_tflite_lstm_inputs = Arg(true); }; } // namespace toco diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 166ead918471ee1b06d9683b8dc7baf7bcbdc427..6877fb237c0514a972589ac0301647104f5ed7ed 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -16,8 +16,6 @@ limitations under the License. #include #include -#include -#include #include #include "absl/strings/str_replace.h" @@ -91,10 +89,7 @@ Color GetColorForArray(const Model& model, const string& array_name) { // We use gray colors for them because they are the majority // of arrays so we want to highlight other arrays instead of them. // First, we use a bolder gray for input/output arrays: - const auto& dump_options = *GraphVizDumpOptions::singleton(); - if (IsInputArray(model, array_name) || - array_name == dump_options.graphviz_first_array || - array_name == dump_options.graphviz_last_array) { + if (IsInputArray(model, array_name)) { return Color(0x9E, 0x9E, 0x9E); } if (IsOutputArray(model, array_name)) { @@ -137,6 +132,12 @@ void AppendArrayVal(string* string, Array const& array, int index) { return; } AppendF(string, "%d", data[index]); + } else if (array.buffer->type == ArrayDataType::kBool) { + const auto& data = array.GetBuffer().data; + if (index >= data.size()) { + return; + } + AppendF(string, "%d", data[index]); } } @@ -145,6 +146,7 @@ NodeProperties GetPropertiesForArray(const Model& model, NodeProperties node_properties; node_properties.color = GetColorForArray(model, array_name); node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}}); + node_properties.log2_buffer_size = 0.0f; // Append array shape to the label. auto& array = model.GetArray(array_name); @@ -164,9 +166,12 @@ NodeProperties GetPropertiesForArray(const Model& model, } node_properties.label += "]"; - int buffer_size = RequiredBufferSizeForShape(array.shape()); - node_properties.log2_buffer_size = - std::log2(static_cast(buffer_size)); + int buffer_size = 0; + if (IsValid(array.shape())) { + buffer_size = RequiredBufferSizeForShape(array.shape()); + node_properties.log2_buffer_size = + std::log2(static_cast(buffer_size)); + } if (array.buffer) { const auto& array = model.GetArray(array_name); @@ -199,8 +204,6 @@ NodeProperties GetPropertiesForArray(const Model& model, AppendF(&node_properties.label, "}"); } } - } else { - node_properties.log2_buffer_size = 0.0f; } if (array.minmax) { @@ -224,7 +227,7 @@ NodeProperties GetPropertiesForArray(const Model& model, NodeProperties GetPropertiesForOperator(const Operator& op) { NodeProperties node_properties; - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { node_properties.label = static_cast(op).tensorflow_op; } else { @@ -287,47 +290,6 @@ NodeProperties GetPropertiesForOperator(const Operator& op) { return node_properties; } -std::vector OperatorsToDump(const Model& model) { - const auto& dump_options = *GraphVizDumpOptions::singleton(); - bool first_specified = !dump_options.graphviz_first_array.empty(); - bool last_specified = !dump_options.graphviz_last_array.empty(); - CHECK_EQ(first_specified, last_specified); - std::vector ops_to_dump; - if (last_specified) { - // Return only the part of the graph between graphviz_first_array - // and graphviz_last_array. - CHECK(model.HasArray(dump_options.graphviz_first_array)); - CHECK(model.HasArray(dump_options.graphviz_last_array)); - std::unordered_set arrays_already_produced; - std::vector arrays_to_produce; - arrays_to_produce.push_back(dump_options.graphviz_last_array); - while (!arrays_to_produce.empty()) { - const string array = arrays_to_produce.back(); - arrays_to_produce.pop_back(); - CHECK(!arrays_already_produced.count(array)); - arrays_already_produced.insert(array); - const Operator* op = GetOpWithOutput(model, array); - if (!op) { - continue; - } - ops_to_dump.push_back(op); - for (const string& input : op->inputs) { - if (arrays_already_produced.count(input) || - input == dump_options.graphviz_first_array) { - continue; - } - arrays_to_produce.push_back(input); - } - } - } else { - // Return the whole graph. - for (const auto& op : model.operators) { - ops_to_dump.push_back(op.get()); - } - } - return ops_to_dump; -} - } // namespace void DumpGraphviz(const Model& model, string* output_file_contents) { @@ -348,30 +310,30 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { constexpr char kRNNBackEdgeFormat[] = "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n"; - std::vector ops_to_dump = OperatorsToDump(model); - std::set already_added_arrays; - for (int op_index = 0; op_index < ops_to_dump.size(); op_index++) { - const Operator& op = *ops_to_dump[op_index]; + for (const auto& array_kv : model.GetArrayMap()) { + // Add node for array. + const string& array_name = array_kv.first; + const auto& array_properties = GetPropertiesForArray(model, array_name); + AppendF(output_file_contents, kNodeFormat, array_name, + array_properties.label, "octagon", + array_properties.color.FillColorString().c_str(), + array_properties.color.TextColorString().c_str()); + } + for (int op_index = 0; op_index < model.operators.size(); op_index++) { + const Operator& op = *model.operators[op_index]; // Add node for operator. auto op_properties = GetPropertiesForOperator(op); string operator_id = StringF("op%05d", op_index); AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label, "box", op_properties.color.FillColorString().c_str(), op_properties.color.TextColorString().c_str()); - // Add nodes and edges for all inputs of the operator. + // Add edges for all inputs of the operator. for (const auto& input : op.inputs) { if (!model.HasArray(input)) { // Arrays should _always_ exist. Except, perhaps, during development. continue; } auto array_properties = GetPropertiesForArray(model, input); - if (!already_added_arrays.count(input)) { - AppendF(output_file_contents, kNodeFormat, input, - array_properties.label, "octagon", - array_properties.color.FillColorString().c_str(), - array_properties.color.TextColorString().c_str()); - } - // Draw lines that transport more data thicker (Otherwise, where would the // data fit? right?). float line_width = @@ -387,22 +349,14 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { } AppendF(output_file_contents, kEdgeFormat, input, operator_id, line_width, weight); - already_added_arrays.insert(input); } - // Add nodes and edges for all outputs of the operator. + // Add edges for all outputs of the operator. for (const auto& output : op.outputs) { if (!model.HasArray(output)) { // Arrays should _always_ exist. Except, perhaps, during development. continue; } auto array_properties = GetPropertiesForArray(model, output); - if (!already_added_arrays.count(output)) { - AppendF(output_file_contents, kNodeFormat, output, - array_properties.label, "octagon", - array_properties.color.FillColorString().c_str(), - array_properties.color.TextColorString().c_str()); - } - // See comments above regarding weight and line_width calculations. float line_width = std::max(0.5f, array_properties.log2_buffer_size / 3.0f); @@ -412,7 +366,6 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { } AppendF(output_file_contents, kEdgeFormat, operator_id, output, line_width, weight); - already_added_arrays.insert(output); } } diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index f5157149afca17383a8625c489f15a23ce6dd224..6be6b25f9318deb08bd427d5e3166909fae8f3ea 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -145,7 +145,7 @@ void ConvertFloatTensorConst(const string& name, const Shape& input_shape, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -162,7 +162,7 @@ void ConvertFloatTensorConst(const string& name, const Shape& input_shape, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -178,7 +178,7 @@ void ConvertFloatTensorConst(const Model& model, const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -199,7 +199,7 @@ void ConvertFloatTensorConst(const Model& model, const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -222,7 +222,7 @@ void ConvertIntTensorConst(const Model& model, const string& name, } CHECK(model.HasArray(name)); const auto& array = model.GetArray(name); - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -245,7 +245,7 @@ void CreateIntTensorConst(const string& name, const std::vector& data, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -268,7 +268,7 @@ void CreateMatrixShapeTensorConst(const string& name, int rows, int cols, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -286,7 +286,7 @@ void CreateDummyConcatDimTensorConst(const string& name, int dim, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -301,7 +301,7 @@ void CreateReshapeShapeTensorConst(const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -341,7 +341,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, conv_output += "/conv"; } - auto* conv2d_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node(); conv2d_op->set_op("Conv2D"); conv2d_op->set_name(conv_output); *conv2d_op->add_input() = src_op.inputs[0]; @@ -377,7 +377,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, (*conv2d_op->mutable_attr())["padding"].set_s(padding); if (has_bias) { - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(conv_output); @@ -409,7 +409,7 @@ void ConvertDepthwiseConvOperator(const Model& model, conv_output += "/conv"; } - auto* dc2d_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node(); dc2d_op->set_op("DepthwiseConv2dNative"); dc2d_op->set_name(conv_output); *dc2d_op->add_input() = src_op.inputs[0]; @@ -457,7 +457,7 @@ void ConvertDepthwiseConvOperator(const Model& model, (*dc2d_op->mutable_attr())["padding"].set_s(padding); if (has_bias) { - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(conv_output); @@ -482,7 +482,7 @@ void ConvertDepthwiseConvOperator(const Model& model, void ConvertTransposeConvOperator(const Model& model, const TransposeConvOperator& src_op, GraphDef* tensorflow_graph) { - auto* conv2d_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node(); conv2d_op->set_op("Conv2DBackpropInput"); conv2d_op->set_name(src_op.outputs[0]); *conv2d_op->add_input() = src_op.inputs[0]; @@ -494,7 +494,7 @@ void ConvertTransposeConvOperator(const Model& model, const auto& weights_array = model.GetArray(weights_array_name); CHECK(weights_array.buffer->type == ArrayDataType::kFloat); ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI, - AxesOrder::kHWIO, tensorflow_graph); + AxesOrder::kHWOI, tensorflow_graph); auto& strides = (*conv2d_op->mutable_attr())["strides"]; strides.mutable_list()->add_i(1); strides.mutable_list()->add_i(src_op.stride_height); @@ -514,7 +514,7 @@ void ConvertTransposeConvOperator(const Model& model, void ConvertDepthToSpaceOperator(const Model& model, const DepthToSpaceOperator& src_op, GraphDef* tensorflow_graph) { - auto* op = tensorflow_graph->add_node(); + tensorflow::NodeDef* op = tensorflow_graph->add_node(); op->set_op("DepthToSpace"); op->set_name(src_op.outputs[0]); *op->add_input() = src_op.inputs[0]; @@ -525,7 +525,7 @@ void ConvertDepthToSpaceOperator(const Model& model, void ConvertSpaceToDepthOperator(const Model& model, const SpaceToDepthOperator& src_op, GraphDef* tensorflow_graph) { - auto* op = tensorflow_graph->add_node(); + tensorflow::NodeDef* op = tensorflow_graph->add_node(); op->set_op("SpaceToDepth"); op->set_name(src_op.outputs[0]); *op->add_input() = src_op.inputs[0]; @@ -546,7 +546,7 @@ void ConvertFullyConnectedOperator(const Model& model, CHECK_EQ(fc_weights_shape.dimensions_count(), 2); CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, tensorflow_graph); - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); reshape_op->add_input(src_op.inputs[0]); @@ -568,7 +568,7 @@ void ConvertFullyConnectedOperator(const Model& model, const string transpose_perm = AvailableArrayName(model, transpose_output + "/perm"); CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph); - auto transpose_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node(); transpose_op->set_op("Transpose"); transpose_op->set_name(transpose_output); *transpose_op->add_input() = src_op.inputs[1]; @@ -577,7 +577,7 @@ void ConvertFullyConnectedOperator(const Model& model, GetTensorFlowDataType(model, src_op.inputs[1])); (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32); - auto* matmul_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node(); matmul_op->set_op("MatMul"); matmul_op->set_name(matmul_output); *matmul_op->add_input() = reshape_output; @@ -590,7 +590,7 @@ void ConvertFullyConnectedOperator(const Model& model, // Add the bias, if it exists. if (has_bias) { - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(matmul_output); @@ -615,7 +615,7 @@ void ConvertFullyConnectedOperator(const Model& model, void ConvertAddOperator(const Model& model, const AddOperator& src_op, GraphDef* tensorflow_graph) { - auto* add_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); add_op->set_op("Add"); add_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -626,7 +626,7 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op, void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, GraphDef* tensorflow_graph) { - auto* add_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); add_op->set_op("AddN"); add_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { @@ -638,7 +638,7 @@ void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, void ConvertMulOperator(const Model& model, const MulOperator& src_op, GraphDef* tensorflow_graph) { - auto* add_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); add_op->set_op("Mul"); add_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -649,7 +649,7 @@ void ConvertMulOperator(const Model& model, const MulOperator& src_op, void ConvertReluOperator(const ReluOperator& src_op, GraphDef* tensorflow_graph) { - auto* relu_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Relu"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; @@ -662,7 +662,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, const string min_bounds = src_op.outputs[0] + "/min_bounds"; const string max_output = src_op.outputs[0] + "/max_output"; - auto* max_bounds_const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node(); max_bounds_const_op->set_op("Const"); max_bounds_const_op->set_name(max_bounds); (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -671,7 +671,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, max_bounds_const_op_tensor->set_dtype(DT_FLOAT); max_bounds_const_op_tensor->add_float_val(-1.0f); - auto* min_bounds_const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node(); min_bounds_const_op->set_op("Const"); min_bounds_const_op->set_name(min_bounds); (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -680,14 +680,14 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, min_bounds_const_op_tensor->set_dtype(DT_FLOAT); min_bounds_const_op_tensor->add_float_val(1.0f); - auto* max_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* max_op = tensorflow_graph->add_node(); max_op->set_op("Maximum"); max_op->set_name(max_output); *max_op->add_input() = src_op.inputs[0]; *max_op->add_input() = max_bounds; (*max_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* min_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* min_op = tensorflow_graph->add_node(); min_op->set_op("Minimum"); min_op->set_name(src_op.outputs[0]); *min_op->add_input() = max_output; @@ -697,7 +697,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, void ConvertRelu6Operator(const Relu6Operator& src_op, GraphDef* tensorflow_graph) { - auto* relu_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Relu6"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; @@ -705,7 +705,7 @@ void ConvertRelu6Operator(const Relu6Operator& src_op, } void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) { - auto* op = tensorflow_graph->add_node(); + tensorflow::NodeDef* op = tensorflow_graph->add_node(); op->set_op("Log"); op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -715,7 +715,7 @@ void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) { void ConvertLogisticOperator(const LogisticOperator& src_op, GraphDef* tensorflow_graph) { - auto* relu_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Sigmoid"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; @@ -724,7 +724,7 @@ void ConvertLogisticOperator(const LogisticOperator& src_op, void ConvertTanhOperator(const TanhOperator& src_op, GraphDef* tensorflow_graph) { - auto* tanh_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node(); tanh_op->set_op("Tanh"); tanh_op->set_name(src_op.outputs[0]); *tanh_op->add_input() = src_op.inputs[0]; @@ -735,8 +735,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op != nullptr && - providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -745,7 +744,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, const string softmax_size = src_op.outputs[0] + "/softmax_insert_size"; softmax_input = reshape_output; - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); *reshape_op->add_input() = src_op.inputs[0]; @@ -762,7 +761,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); } - auto* softmax_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node(); softmax_op->set_op("Softmax"); softmax_op->set_name(src_op.outputs[0]); *softmax_op->add_input() = softmax_input; @@ -776,8 +775,7 @@ void ConvertLogSoftmaxOperator(const Model& model, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op != nullptr && - providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -787,7 +785,7 @@ void ConvertLogSoftmaxOperator(const Model& model, const string softmax_size = src_op.outputs[0] + "/log_softmax_insert_size"; softmax_input = reshape_output; - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); *reshape_op->add_input() = src_op.inputs[0]; @@ -804,7 +802,7 @@ void ConvertLogSoftmaxOperator(const Model& model, CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); } - auto* log_softmax_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node(); log_softmax_op->set_op("LogSoftmax"); log_softmax_op->set_name(src_op.outputs[0]); *log_softmax_op->add_input() = softmax_input; @@ -819,7 +817,7 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, const string rsqrt_output = src_op.outputs[0] + "/rsqrt"; const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled"; - auto* sum_reduction_indices_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node(); sum_reduction_indices_op->set_op("Const"); sum_reduction_indices_op->set_name(sum_reduction_indices); (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -833,26 +831,26 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, sum_reduction_indices_tensor->add_int_val(0); sum_reduction_indices_tensor->add_int_val(1); - auto* square_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); square_op->set_op("Square"); square_op->set_name(square_output); *square_op->add_input() = src_op.inputs[0]; (*square_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* sum_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sum_op = tensorflow_graph->add_node(); sum_op->set_op("Sum"); sum_op->set_name(sum_output); *sum_op->add_input() = square_output; *sum_op->add_input() = sum_reduction_indices; (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* rsqrt_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node(); rsqrt_op->set_op("Rsqrt"); rsqrt_op->set_name(rsqrt_output); *rsqrt_op->add_input() = sum_output; (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* mul_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_op = tensorflow_graph->add_node(); mul_op->set_op("Mul"); mul_op->set_name(src_op.outputs[0]); *mul_op->add_input() = src_op.inputs[0]; @@ -863,7 +861,7 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, void ConvertLocalResponseNormalizationOperator( const LocalResponseNormalizationOperator& src_op, GraphDef* tensorflow_graph) { - auto* lrn_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node(); lrn_op->set_op("LRN"); lrn_op->set_name(src_op.outputs[0]); *lrn_op->add_input() = src_op.inputs[0]; @@ -875,7 +873,7 @@ void ConvertLocalResponseNormalizationOperator( void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, GraphDef* tensorflow_graph) { - auto* fakequant_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node(); fakequant_op->set_op("FakeQuantWithMinMaxArgs"); fakequant_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -890,7 +888,7 @@ void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, GraphDef* tensorflow_graph) { - auto* maxpool_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node(); maxpool_op->set_op("MaxPool"); maxpool_op->set_name(src_op.outputs[0]); *maxpool_op->add_input() = src_op.inputs[0]; @@ -918,7 +916,7 @@ void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, void ConvertAveragePoolOperator(const AveragePoolOperator& src_op, GraphDef* tensorflow_graph) { - auto* avgpool_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node(); avgpool_op->set_op("AvgPool"); avgpool_op->set_name(src_op.outputs[0]); *avgpool_op->add_input() = src_op.inputs[0]; @@ -947,7 +945,7 @@ void ConvertAveragePoolOperator(const AveragePoolOperator& src_op, void ConvertConcatenationOperator(const Model& model, const ConcatenationOperator& src_op, GraphDef* tensorflow_graph) { - auto* dc_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* dc_op = tensorflow_graph->add_node(); dc_op->set_op("ConcatV2"); dc_op->set_name(src_op.outputs[0]); const string dummy_axis = src_op.outputs[0] + "/axis"; @@ -965,7 +963,7 @@ void ConvertConcatenationOperator(const Model& model, void ConvertTensorFlowReshapeOperator(const Model& model, const TensorFlowReshapeOperator& src_op, GraphDef* tensorflow_graph) { - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -987,7 +985,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, const string square_output = src_op.outputs[0] + "/square"; const string avgpool_output = src_op.outputs[0] + "/avgpool"; - auto* square_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); square_op->set_op("Square"); square_op->set_name(square_output); *square_op->add_input() = src_op.inputs[0]; @@ -1002,7 +1000,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } - auto* avgpool_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node(); avgpool_op->set_op("AvgPool"); avgpool_op->set_name(avgpool_output); *avgpool_op->add_input() = square_output; @@ -1020,7 +1018,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, ksize.mutable_list()->add_i(src_op.kwidth); ksize.mutable_list()->add_i(1); - auto* sqrt_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node(); sqrt_op->set_op("Sqrt"); sqrt_op->set_name(src_op.outputs[0]); *sqrt_op->add_input() = avgpool_output; @@ -1029,7 +1027,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, void ConvertSquareOperator(const TensorFlowSquareOperator& src_op, GraphDef* tensorflow_graph) { - auto* square_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); square_op->set_op("Square"); square_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1039,7 +1037,7 @@ void ConvertSquareOperator(const TensorFlowSquareOperator& src_op, void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op, GraphDef* tensorflow_graph) { - auto* sqrt_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node(); sqrt_op->set_op("Sqrt"); sqrt_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1047,10 +1045,23 @@ void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op, (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); } +void ConvertRsqrtOperator(const Model& model, + const TensorFlowRsqrtOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node(); + rsqrt_op->set_op("Rsqrt"); + rsqrt_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *rsqrt_op->add_input() = src_op.inputs[0]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*rsqrt_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertSplitOperator(const Model& model, const TensorFlowSplitOperator& src_op, GraphDef* tensorflow_graph) { - auto* split_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* split_op = tensorflow_graph->add_node(); split_op->set_op("Split"); split_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { @@ -1071,7 +1082,7 @@ void ConvertSplitOperator(const Model& model, void ConvertCastOperator(const Model& model, const CastOperator& src_op, GraphDef* tensorflow_graph) { - auto* cast_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* cast_op = tensorflow_graph->add_node(); cast_op->set_op("Cast"); cast_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1085,7 +1096,7 @@ void ConvertCastOperator(const Model& model, const CastOperator& src_op, void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, GraphDef* tensorflow_graph) { - auto* floor_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* floor_op = tensorflow_graph->add_node(); floor_op->set_op("Floor"); floor_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1095,7 +1106,7 @@ void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, GraphDef* tensorflow_graph) { - auto* gather_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* gather_op = tensorflow_graph->add_node(); gather_op->set_op("Gather"); gather_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1103,13 +1114,14 @@ void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, *gather_op->add_input() = src_op.inputs[1]; (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32); - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*gather_op->mutable_attr())["Tparams"].set_type(params_type); } void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, GraphDef* tensorflow_graph) { - auto* argmax_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node(); argmax_op->set_op("ArgMax"); argmax_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1126,7 +1138,7 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, void ConvertTransposeOperator(const Model& model, const TransposeOperator& src_op, GraphDef* tensorflow_graph) { - auto* transpose_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node(); transpose_op->set_op("Transpose"); transpose_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1141,7 +1153,7 @@ void ConvertTransposeOperator(const Model& model, void ConvertTensorFlowShapeOperator(const Model& model, const TensorFlowShapeOperator& src_op, GraphDef* tensorflow_graph) { - auto* shape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* shape_op = tensorflow_graph->add_node(); shape_op->set_op("Shape"); shape_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1154,7 +1166,7 @@ void ConvertTensorFlowShapeOperator(const Model& model, void ConvertRankOperator(const Model& model, const RankOperator& src_op, GraphDef* tensorflow_graph) { - auto* rank_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* rank_op = tensorflow_graph->add_node(); rank_op->set_op("Rank"); rank_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1165,7 +1177,7 @@ void ConvertRankOperator(const Model& model, const RankOperator& src_op, void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, GraphDef* tensorflow_graph) { - auto* range_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* range_op = tensorflow_graph->add_node(); range_op->set_op("Range"); range_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); @@ -1178,7 +1190,7 @@ void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, void ConvertStackOperator(const Model& model, const StackOperator& src_op, GraphDef* tensorflow_graph) { - auto* stack_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* stack_op = tensorflow_graph->add_node(); stack_op->set_op("Stack"); stack_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { @@ -1191,7 +1203,7 @@ void ConvertStackOperator(const Model& model, const StackOperator& src_op, void ConvertFillOperator(const Model& model, const FillOperator& src_op, GraphDef* tensorflow_graph) { - auto* fill_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* fill_op = tensorflow_graph->add_node(); fill_op->set_op("Fill"); fill_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1205,7 +1217,7 @@ void ConvertFillOperator(const Model& model, const FillOperator& src_op, void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op, GraphDef* tensorflow_graph) { - auto* floor_div_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node(); floor_div_op->set_op("FloorDiv"); floor_div_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1218,7 +1230,7 @@ void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op, void ConvertExpandDimsOperator(const Model& model, const ExpandDimsOperator& src_op, GraphDef* tensorflow_graph) { - auto* expand_dims_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node(); expand_dims_op->set_op("ExpandDims"); expand_dims_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1233,7 +1245,7 @@ void ConvertExpandDimsOperator(const Model& model, void ConvertResizeBilinearOperator(const Model& model, const ResizeBilinearOperator& src_op, GraphDef* tensorflow_graph) { - auto* resize_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* resize_op = tensorflow_graph->add_node(); resize_op->set_op("ResizeBilinear"); resize_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1283,7 +1295,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // works the same since the tensor has the same underlying data layout. const string axis_output = concat_output + "/axis"; CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph); - auto* concat_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* concat_op = tensorflow_graph->add_node(); concat_op->set_op("ConcatV2"); concat_op->set_name(concat_output); *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT]; @@ -1311,7 +1323,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Fully connected matrix multiply const string matmul_output = base + "MatMul"; - auto* matmul_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node(); matmul_op->set_op("MatMul"); matmul_op->set_name(matmul_output); *matmul_op->add_input() = concat_output; @@ -1340,7 +1352,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Add biases string biasadd_output = base + "BiasAdd"; - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(biasadd_output); biasadd_op->add_input(matmul_output); @@ -1353,7 +1365,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // The dimension is the same as the concatenation dimension CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph); string split_output = base + "split"; - auto* split_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* split_op = tensorflow_graph->add_node(); split_op->set_op("Split"); split_op->set_name(split_output); *split_op->add_input() = split_dim_output; @@ -1363,21 +1375,21 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Activation functions and memory computations const string tanh_0_output = base + "Tanh"; - auto* tanh_0_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node(); tanh_0_op->set_op("Tanh"); tanh_0_op->set_name(tanh_0_output); *tanh_0_op->add_input() = split_output + ":1"; (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT); const string sigmoid_1_output = base + "Sigmoid_1"; - auto* logistic_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node(); logistic_1_op->set_op("Sigmoid"); logistic_1_op->set_name(sigmoid_1_output); *logistic_1_op->add_input() = split_output; (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string mul_1_output = base + "mul_1"; - auto* mul_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node(); mul_1_op->set_op("Mul"); mul_1_op->set_name(mul_1_output); *mul_1_op->add_input() = sigmoid_1_output; @@ -1385,21 +1397,21 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string sigmoid_0_output = base + "Sigmoid"; - auto* logistic_2_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node(); logistic_2_op->set_op("Sigmoid"); logistic_2_op->set_name(sigmoid_0_output); *logistic_2_op->add_input() = split_output + ":2"; (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT); const string sigmoid_2_output = base + "Sigmoid_2"; - auto* logistic_3_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node(); logistic_3_op->set_op("Sigmoid"); logistic_3_op->set_name(sigmoid_2_output); *logistic_3_op->add_input() = split_output + ":3"; (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT); const string mul_0_output = base + "mul"; - auto* mul_0_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node(); mul_0_op->set_op("Mul"); mul_0_op->set_name(mul_0_output); *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT]; @@ -1407,7 +1419,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT); const string add_1_output = src_op.outputs[LstmCellOperator::STATE_OUTPUT]; - auto* add_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node(); add_1_op->set_op("Add"); add_1_op->set_name(add_1_output); *add_1_op->add_input() = mul_0_output; @@ -1415,14 +1427,14 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string tanh_1_output = base + "Tanh_1"; - auto* tanh_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node(); tanh_1_op->set_op("Tanh"); tanh_1_op->set_name(tanh_1_output); *tanh_1_op->add_input() = add_1_output; (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]; - auto* mul_2_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node(); mul_2_op->set_op("Mul"); mul_2_op->set_name(mul_2_output); *mul_2_op->add_input() = tanh_1_output; @@ -1433,14 +1445,15 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, void ConvertSpaceToBatchNDOperator(const Model& model, const SpaceToBatchNDOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("SpaceToBatchND"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32); @@ -1449,14 +1462,15 @@ void ConvertSpaceToBatchNDOperator(const Model& model, void ConvertBatchToSpaceNDOperator(const Model& model, const BatchToSpaceNDOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("BatchToSpaceND"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32); @@ -1464,18 +1478,19 @@ void ConvertBatchToSpaceNDOperator(const Model& model, void ConvertPadOperator(const Model& model, const PadOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Pad"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); // Create the params tensor. - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(src_op.inputs[1]); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1494,7 +1509,7 @@ void ConvertPadOperator(const Model& model, const PadOperator& src_op, void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("PadV2"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1502,11 +1517,12 @@ void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); // Create the params tensor. - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(src_op.inputs[1]); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1525,7 +1541,7 @@ void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, void CreateSliceInput(const string& input_name, const std::vector& values, GraphDef* tensorflow_graph) { - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(input_name); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1542,7 +1558,7 @@ void CreateSliceInput(const string& input_name, const std::vector& values, void ConvertStridedSliceOperator(const Model& model, const StridedSliceOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("StridedSlice"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 4); @@ -1551,7 +1567,8 @@ void ConvertStridedSliceOperator(const Model& model, *new_op->add_input() = src_op.inputs[2]; *new_op->add_input() = src_op.inputs[3]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Index"].set_type(DT_INT32); @@ -1569,7 +1586,7 @@ void ConvertStridedSliceOperator(const Model& model, void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Slice"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); @@ -1577,7 +1594,8 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Index"].set_type(DT_INT32); @@ -1588,14 +1606,15 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Mean"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); if (src_op.keep_dims) { @@ -1603,7 +1622,7 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, } // Create the params tensor. - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(src_op.inputs[1]); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1619,13 +1638,14 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Squeeze"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); *new_op->add_input() = src_op.inputs[0]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); if (!src_op.squeeze_dims.empty()) { @@ -1638,58 +1658,79 @@ void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, void ConvertSubOperator(const Model& model, const SubOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); sub_op->set_op("Sub"); sub_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *sub_op->add_input() = src_op.inputs[0]; *sub_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*sub_op->mutable_attr())["T"].set_type(data_type); } void ConvertTensorFlowMinimumOperator(const Model& model, const TensorFlowMinimumOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); sub_op->set_op("Minimum"); sub_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *sub_op->add_input() = src_op.inputs[0]; *sub_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*sub_op->mutable_attr())["T"].set_type(data_type); } void ConvertTensorFlowMaximumOperator(const Model& model, const TensorFlowMaximumOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); sub_op->set_op("Maximum"); sub_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *sub_op->add_input() = src_op.inputs[0]; *sub_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*sub_op->mutable_attr())["T"].set_type(data_type); } void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); sub_op->set_op("Select"); sub_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); *sub_op->add_input() = src_op.inputs[0]; *sub_op->add_input() = src_op.inputs[1]; *sub_op->add_input() = src_op.inputs[2]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[1]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); (*sub_op->mutable_attr())["T"].set_type(data_type); } +void ConvertTileOperator(const Model& model, + const TensorFlowTileOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* tile_op = tensorflow_graph->add_node(); + tile_op->set_op("Tile"); + tile_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *tile_op->add_input() = src_op.inputs[0]; + *tile_op->add_input() = src_op.inputs[1]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*tile_op->mutable_attr())["T"].set_type(data_type); + const tensorflow::DataType multiples_data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type); +} + void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, GraphDef* tensorflow_graph) { - auto* topk_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* topk_op = tensorflow_graph->add_node(); topk_op->set_op("TOPKV2"); topk_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1702,12 +1743,13 @@ void ConvertRandomUniformOperator(const Model& model, const RandomUniformOperator& src_op, GraphDef* tensorflow_graph) { CHECK(tensorflow_graph != nullptr); - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("RandomUniform"); CHECK_EQ(src_op.inputs.size(), 1); new_op->set_name(src_op.outputs[0]); *new_op->add_input() = src_op.inputs[0]; - const auto shape_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType shape_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(shape_type); (*new_op->mutable_attr())["dtype"].set_type( GetTensorFlowDataType(src_op.dtype)); @@ -1718,16 +1760,52 @@ void ConvertRandomUniformOperator(const Model& model, void ConvertComparisonOperator(const Model& model, const Operator& src_op, const char* op_name, GraphDef* tensorflow_graph) { - auto* comparison_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node(); comparison_op->set_op(op_name); comparison_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *comparison_op->add_input() = src_op.inputs[0]; *comparison_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*comparison_op->mutable_attr())["T"].set_type(data_type); } +void ConvertSparseToDenseOperator(const Model& model, + const SparseToDenseOperator& src_op, + const char* op_name, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node(); + sparse_to_dense_op->set_op(op_name); + sparse_to_dense_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 4); + for (int i = 0; i < 4; ++i) { + *sparse_to_dense_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[3]); + (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type); + const tensorflow::DataType index_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type); + (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b( + src_op.validate_indices); +} + +void ConvertPowOperator(const Model& model, const PowOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* pow_op = tensorflow_graph->add_node(); + pow_op->set_op(op_name); + pow_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *pow_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*pow_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1808,20 +1886,24 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertConcatenationOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowReshape) { + } else if (src_op.type == OperatorType::kReshape) { ConvertTensorFlowReshapeOperator( model, static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kL2Pool) { ConvertL2PoolOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSquare) { + } else if (src_op.type == OperatorType::kSquare) { ConvertSquareOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSqrt) { + } else if (src_op.type == OperatorType::kSqrt) { ConvertSqrtOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSplit) { + } else if (src_op.type == OperatorType::kRsqrt) { + ConvertRsqrtOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kSplit) { ConvertSplitOperator(model, static_cast(src_op), tensorflow_graph); @@ -1865,11 +1947,11 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kSub) { ConvertSubOperator(model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowMinimum) { + } else if (src_op.type == OperatorType::kMinimum) { ConvertTensorFlowMinimumOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowMaximum) { + } else if (src_op.type == OperatorType::kMaximum) { ConvertTensorFlowMaximumOperator( model, static_cast(src_op), tensorflow_graph); @@ -1888,7 +1970,7 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kTranspose) { ConvertTransposeOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowShape) { + } else if (src_op.type == OperatorType::kShape) { ConvertTensorFlowShapeOperator( model, static_cast(src_op), tensorflow_graph); @@ -1919,17 +2001,28 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertRandomUniformOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowGreater) { + } else if (src_op.type == OperatorType::kEqual) { + ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph); + } else if (src_op.type == OperatorType::kNotEqual) { + ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph); + } else if (src_op.type == OperatorType::kGreater) { ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) { + } else if (src_op.type == OperatorType::kGreaterEqual) { ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowLess) { + } else if (src_op.type == OperatorType::kLess) { ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowLessEqual) { + } else if (src_op.type == OperatorType::kLessEqual) { ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph); } else if (src_op.type == OperatorType::kSelect) { ConvertSelectOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTile) { + ConvertTileOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kPow) { + ConvertPowOperator(model, static_cast(src_op), "Pow", + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } @@ -1937,7 +2030,7 @@ void ConvertOperator(const Model& model, const Operator& src_op, void AddPlaceholder(const string& name, ArrayDataType type, GraphDef* tensorflow_graph) { - auto* placeholder = tensorflow_graph->add_node(); + tensorflow::NodeDef* placeholder = tensorflow_graph->add_node(); placeholder->set_op("Placeholder"); switch (type) { case ArrayDataType::kBool: @@ -1966,7 +2059,7 @@ void AddPlaceholder(const string& name, ArrayDataType type, void AddPlaceholderForRNNState(const Model& model, const string& name, int size, GraphDef* tensorflow_graph) { - auto* placeholder = tensorflow_graph->add_node(); + tensorflow::NodeDef* placeholder = tensorflow_graph->add_node(); placeholder->set_op("Placeholder"); placeholder->set_name(name); (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT); diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md index 7680cdd344814bf6cbc7bbe11c915f220642d55d..18b7848db86e553ec645fa87298420012b5f753f 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md @@ -9,59 +9,56 @@ complemented by the following documents: Table of contents: -* [Convert a TensorFlow SavedModel to TensorFlow Lite](#savedmodel) -* [Convert a TensorFlow GraphDef to TensorFlow Lite for float - inference](#graphdef-float) +* [Command-line tools](#tools) + * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) +* [Basic examples](#basic) + * [Convert a TensorFlow GraphDef](#graphdef) + * [Convert a TensorFlow SavedModel](#savedmodel) + * [Convert a tf.keras model](#keras) * [Quantization](#quantization) - * [Convert a TensorFlow GraphDef to TensorFlow Lite for quantized - inference](#graphdef-quant) + * [Convert a TensorFlow GraphDef for quantized inference](#graphdef-quant) * [Use "dummy-quantization" to try out quantized inference on a float graph](#dummy-quant) * [Specifying input and output arrays](#specifying-input-and-output-arrays) - * [Multiple output arrays](#multiple-output-arrays) * [Multiple input arrays](#multiple-input-arrays) + * [Multiple output arrays](#multiple-output-arrays) * [Specifying subgraphs](#specifying-subgraphs) -* [Other conversions supported by TOCO](#other-conversions) - * [Optimize a TensorFlow GraphDef](#optimize-graphdef) - * [Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef - format](#to-graphdef) -* [Logging](#logging) - * [Standard logging](#standard-logging) - * [Verbose logging](#verbose-logging) - * [Graph "video" logging](#graph-video-logging) * [Graph visualizations](#graph-visualizations) * [Using --output_format=GRAPHVIZ_DOT](#using-output-formatgraphviz-dot) * [Using --dump_graphviz](#using-dump-graphviz) + * [Graph "video" logging](#graph-video-logging) * [Legend for the graph visualizations](#graphviz-legend) -## Convert a TensorFlow SavedModel to TensorFlow Lite +## Command-line tools -The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite -FlatBuffer to perform floating-point inference. +There are two approaches to running TOCO via command line. -``` -bazel run --config=opt \ - third_party/tensorflow/contrib/lite/toco:toco -- \ - --savedmodel_directory=/tmp/saved_model \ - --output_file=/tmp/foo.tflite -``` +* `tflite_convert`: Starting from TensorFlow 1.9, the command-line tool + `tflite_convert` will be installed as part of the Python package. All of the + examples below use `tflite_convert` for simplicity. + * Example: `tflite --output_file=...` +* `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow + repository](https://www.tensorflow.org/install/install_sources#clone_the_tensorflow_repository) + and use `bazel`. This is the recommended approach for converting models that + utilize new features that were not supported by TOCO in TensorFlow 1.9. + * Example: `bazel run + //tensorflow/contrib/lite/python:tflite_convert -- + --output_file=...` -[SavedModel](https://www.tensorflow.org/programmers_guide/saved_model#using_savedmodel_with_estimators) -has fewer required flags than frozen graphs (described [below](#graphdef-float)) -due to access to additional data contained within the SavedModel. The values for -`--input_arrays` and `--output_arrays` are an aggregated, alphabetized list of -the inputs and outputs in the -[SignatureDefs](https://www.tensorflow.org/serving/signature_defs) within the -[MetaGraphDef](https://www.tensorflow.org/programmers_guide/saved_model#apis_to_build_and_load_a_savedmodel) -specified by `--savedmodel_tagset`. The value for `input_shapes` is -automatically determined from the MetaGraphDef whenever possible. The default -value for `--inference_type` for SavedModels is `FLOAT`. +### Converting models prior to TensorFlow 1.9. -There is currently no support for MetaGraphDefs without a SignatureDef or for -MetaGraphDefs that use the [`assets/` -directory](https://www.tensorflow.org/programmers_guide/saved_model#structure_of_a_savedmodel_directory). +The recommended approach for using TOCO prior to TensorFlow 1.9 is the [Python +API](python_api.md#pre-tensorflow-1.9). If a command line tool is desired, the +`toco` command line tool was available in TensorFlow 1.7. Enter `toco --help` in +Terminal for additional details on the command-line flags available. There were +no command line tools in TensorFlow 1.8. + +## Basic examples -## Convert a TensorFlow GraphDef to TensorFlow Lite for float inference +The following section shows examples of how to convert a basic float-point model +from each of the supported data formats into a TensorFlow Lite FlatBuffers. + +### Convert a TensorFlow GraphDef The follow example converts a basic TensorFlow GraphDef (frozen by [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)) @@ -71,19 +68,54 @@ graphs contain the variables stored in Checkpoint files as Const ops. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ --output_file=/tmp/foo.tflite \ - --inference_type=FLOAT \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 +``` + +The value for `input_shapes` is automatically determined whenever possible. + +### Convert a TensorFlow SavedModel + +The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite +FlatBuffer to perform floating-point inference. + +``` +tflite_convert \ + --output_file=/tmp/foo.tflite \ + --saved_model_dir=/tmp/saved_model +``` + +[SavedModel](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators) +has fewer required flags than frozen graphs due to access to additional data +contained within the SavedModel. The values for `--input_arrays` and +`--output_arrays` are an aggregated, alphabetized list of the inputs and outputs +in the [SignatureDefs](https://www.tensorflow.org/serving/signature_defs) within +the +[MetaGraphDef](https://www.tensorflow.org/guide/saved_model#apis_to_build_and_load_a_savedmodel) +specified by `--saved_model_tag_set`. As with the GraphDef, the value for +`input_shapes` is automatically determined whenever possible. + +There is currently no support for MetaGraphDefs without a SignatureDef or for +MetaGraphDefs that use the [`assets/` +directory](https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory). + +### Convert a tf.Keras model + +The following example converts a `tf.keras` model into a TensorFlow Lite +Flatbuffer. The `tf.keras` file must contain both the model and the weights. + +``` +tflite_convert \ + --output_file=/tmp/foo.tflite \ + --keras_model_file=/tmp/keras_model.h5 ``` ## Quantization -### Convert a TensorFlow GraphDef to TensorFlow Lite for quantized inference +### Convert a TensorFlow GraphDef for quantized inference TOCO is compatible with fixed point quantization models described [here](https://www.tensorflow.org/performance/quantization). These are float @@ -97,18 +129,14 @@ The following command generates a quantized TensorFlow Lite FlatBuffer from a "quantized" TensorFlow GraphDef. ``` -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/some_quantized_graph.pb \ +tflite_convert \ --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ + --graph_def_file=/tmp/some_quantized_graph.pb \ --inference_type=QUANTIZED_UINT8 \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ - --mean_value=128 \ - --std_value=127 + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 \ + --mean_values=128 \ + --std_dev_values=127 ``` ### Use \"dummy-quantization\" to try out quantized inference on a float graph @@ -126,45 +154,20 @@ a reasonable guess is that most activation ranges should be contained in [0, 6]. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ --output_file=/tmp/foo.cc \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ --inference_type=QUANTIZED_UINT8 \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 \ --default_ranges_min=0 \ --default_ranges_max=6 \ - --mean_value=127.5 \ - --std_value=127.5 + --mean_values=128 \ + --std_dev_values=127 ``` ## Specifying input and output arrays -### Multiple output arrays - -The flag `output_arrays` takes in a comma-separated list of output arrays as -seen in the example below. This is useful for models or subgraphs with multiple -outputs. - -``` -curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ - | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ - --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ - --input_shape=1,224,224,3 \ - --input_array=input \ - --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu -``` - ### Multiple input arrays The flag `input_arrays` takes in a comma-separated list of input arrays as seen @@ -174,21 +177,33 @@ inputs. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ +tflite_convert \ + --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \ --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \ --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \ - --output_array=InceptionV1/Logits/Predictions/Reshape_1 + --output_arrays=InceptionV1/Logits/Predictions/Reshape_1 ``` Note that `input_shapes` is provided as a colon-separated list. Each input shape corresponds to the input array at the same position in the respective list. +### Multiple output arrays + +The flag `output_arrays` takes in a comma-separated list of output arrays as +seen in the example below. This is useful for models or subgraphs with multiple +outputs. + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ + | tar xzv -C /tmp +tflite_convert \ + --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \ + --output_file=/tmp/foo.tflite \ + --input_arrays=input \ + --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu +``` + ### Specifying subgraphs Any array in the input file can be specified as an input or output array in @@ -203,158 +218,57 @@ GraphDef. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ +tflite_convert \ + --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \ --output_file=/tmp/foo.pb \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TENSORFLOW_GRAPHDEF \ --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \ --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \ - --output_array=InceptionV1/InceptionV1/Mixed_3b/concat_v2 + --output_arrays=InceptionV1/InceptionV1/Mixed_3b/concat_v2 ``` -Note that the final representation of an on-device inference workload (say, in -TensorFlow Lite FlatBuffers format) tends to have coarser granularity than the -very fine granularity of the TensorFlow GraphDef representation. For example, -while a fully-connected layer is typically represented as at least four separate -ops in TensorFlow GraphDef (Reshape, MatMul, BiasAdd, Relu...), it is typically -represented as a single "fused" op (FullyConnected) in the converter's optimized -representation and in the final on-device representation (e.g. in TensorFlow -Lite FlatBuffer format). As the level of granularity gets coarser, some +Note that the final representation in TensorFlow Lite FlatBuffers tends to have +coarser granularity than the very fine granularity of the TensorFlow GraphDef +representation. For example, while a fully-connected layer is typically +represented as at least four separate ops in TensorFlow GraphDef (Reshape, +MatMul, BiasAdd, Relu...), it is typically represented as a single "fused" op +(FullyConnected) in the converter's optimized representation and in the final +on-device representation. As the level of granularity gets coarser, some intermediate arrays (say, the array between the MatMul and the BiasAdd in the -TensorFlow GraphDef) are dropped. When specifying intermediate arrays as -`--input_arrays` / `--output_arrays`, it is desirable (and often required) to -specify arrays that are meant to survive in the final form of the graph, after -fusing. These are typically the outputs of activation functions (since -everything in each layer until the activation function tends to get fused). - -## Other conversions supported by TOCO +TensorFlow GraphDef) are dropped. -The converter accepts both TENSORFLOW_GRAPHDEF and TFLITE file formats as both -`--input_format` and `--output_format`. This means that conversion to and from -any supported format is possible. - -### Optimize a TensorFlow GraphDef - -Same-format "conversions" can be used to optimize and simplify a graph or be -used to [get a subgraph](#specifying-subgraphs) of a graph. The flag -`--inference_type` is not required because TensorFlow graphs, including those -containing the -[`FakeQuant*`](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization) -ops are always float graphs. - -``` -curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ - | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ - --output_file=/tmp/foo.pb \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TENSORFLOW_GRAPHDEF \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 -``` - -### Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef format - -The converter supports file format conversions from TensorFlow Lite, back into -TensorFlow GraphDef format. - -``` -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/foo.tflite \ - --output_file=/tmp/foo.pb \ - --input_format=TFLITE \ - --output_format=TENSORFLOW_GRAPHDEF \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 -``` +When specifying intermediate arrays as `--input_arrays` and `--output_arrays`, +it is desirable (and often required) to specify arrays that are meant to survive +in the final form of the graph, after fusing. These are typically the outputs of +activation functions (since everything in each layer until the activation +function tends to get fused). ## Logging -### Standard logging - -The converter generates some informative log messages during processing. The -easiest way to view them is to add `--logtostderr` to command lines as seen in -the following example. - -``` -curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ - | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ - --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ - --logtostderr -``` - -After some initialization messages, we get the following informative messages: - -``` -I1101 21:51:33.297475 5339 graph_transformations.cc:39] Before general graph transformations: 416 operators, 583 arrays (0 quantized) -I1101 21:51:33.308972 5339 graph_transformations.cc:39] After general graph transformations pass 1: 31 operators, 89 arrays (0 quantized) -I1101 21:51:33.309204 5339 graph_transformations.cc:39] Before dequantization graph transformations: 31 operators, 89 arrays (0 quantized) -I1101 21:51:33.309368 5339 allocate_transient_arrays.cc:312] Total transient array allocated size: 1048576 bytes, theoretical optimal value: 786432 bytes. -I1101 21:51:33.309484 5339 toco_tooling.cc:249] Estimated count of arithmetic ops: 0.099218 billion (note that a multiply-add is counted as 2 ops). -``` - -### Verbose logging - -For debugging purposes, the converter supports two levels of verbose logging, -which can be set by passing a `--v=` flag: - -* For `--v=1`, the converter generates text dumps of the graph at various - points during processing as well as log messages about every graph - transformation that took place. -* For `--v=2`, the converter additionally generates log messages about graph - transformations that were considered but not performed. - -### Graph "video" logging - -When `--dump_graphviz=` is used (see the section on [graph -visualizations](#graph-visualizations)), one may additionally pass -`--dump_graphviz_video`, which causes a graph visualization to be dumped after -each individual graph transformation. This results in thousands of files. -Typically, one would then bisect into these files to understand when a given -change was introduced in the graph. ## Graph visualizations TOCO can export a graph to the GraphViz Dot format for easy visualization via -either the `--output_format` flag or the `--dump_graphviz` flag. The subsections -below outline the use cases for each. +either the `--output_format` flag or the `--dump_graphviz_dir` flag. The +subsections below outline the use cases for each. ### Using `--output_format=GRAPHVIZ_DOT` The first way to get a graphviz rendering is to pass `GRAPHVIZ_DOT` into `--output_format`. This results in a plausible visualization of the graph. This -reduces the requirements that normally exist during conversion between other -input and output formats. For example, this may be useful if conversion from -TENSORFLOW_GRAPHDEF to TFLITE is failing. +reduces the requirements that exist during conversion between other input and +output formats. This may be useful if conversion from TENSORFLOW_GRAPHDEF to +TFLITE is failing. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ --output_file=/tmp/foo.dot \ - --input_format=TENSORFLOW_GRAPHDEF \ --output_format=GRAPHVIZ_DOT \ --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 ``` The resulting `.dot` file can be rendered into a PDF as follows: @@ -375,49 +289,35 @@ Example PDF files are viewable online in the next section. ### Using `--dump_graphviz` -The second way to get a graphviz rendering is to pass the `--dump_graphviz=` +The second way to get a graphviz rendering is to pass the `--dump_graphviz_dir` flag, specifying a destination directory to dump GraphViz rendering to. Unlike -the previous approach, this one allows you to keep your real command-line (with -your real `--output_format` and other flags) unchanged, just appending a -`--dump_graphviz=` flag to it. This provides a visualization of the actual graph -during a specific conversion process. +the previous approach, this one retains the original output format. This +provides a visualization of the actual graph resulting from a specific +conversion process. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ - --dump_graphviz=/tmp + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 \ + --dump_graphviz_dir=/tmp ``` -This generates a few files in the destination directory, here `/tmp`. The two -most important files are: - -``` -/tmp/toco_AT_IMPORT.dot -/tmp/toco_AFTER_TRANSFORMATIONS.dot -``` - -`toco_AT_IMPORT.dot` represents the graph as it was imported from -`--input_file`, before any transformation was applied to it (besides some -transformations that are applied immediately while importing). This tends to be -a complex visualization with limited information, but is useful especially in -situations where a conversion command fails (this file is generated even if the -conversion subsequently fails). +This generates a few files in the destination directory. The two most important +files are `toco_AT_IMPORT.dot` and `/tmp/toco_AFTER_TRANSFORMATIONS.dot`. +`toco_AT_IMPORT.dot` represents the original graph containing only the +transformations done at import time. This tends to be a complex visualization +with limited information about each node. It is useful in situations where a +conversion command fails. `toco_AFTER_TRANSFORMATIONS.dot` represents the graph after all transformations -were applied to it, just before it was exported to the `--output_file`. -Typically, this is a much smaller graph with more information about each node. +were applied to it, just before it is exported. Typically, this is a much +smaller graph with more information about each node. -Again, these can be rendered to PDFs: +As before, these can be rendered to PDFs: ``` dot -Tpdf -O /tmp/toco_*.dot @@ -428,6 +328,14 @@ Sample output files can be seen here: * [toco_AT_IMPORT.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AT_IMPORT.dot.pdf) * [toco_AFTER_TRANSFORMATIONS.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AFTER_TRANSFORMATIONS.dot.pdf). +### Graph "video" logging + +When `--dump_graphviz_dir` is used, one may additionally pass +`--dump_graphviz_video`. This causes a graph visualization to be dumped after +each individual graph transformation, resulting in thousands of files. +Typically, one would then bisect into these files to understand when a given +change was introduced in the graph. + ### Legend for the graph visualizations * Operators are red square boxes with the following hues of red: diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index 9e99287f828c22aa81eb216c087f3261e378fc14..decc8a45a40ffba2a27320ce8391b1916391d744 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -1,7 +1,8 @@ # TensorFlow Lite Optimizing Converter command-line glossary -This page is complete reference of command-line flags. It is complemented by the -following other documents: +This page is complete reference of command-line flags used by TOCO's command +line starting from TensorFlow 1.9 up until the most recent build of TensorFlow. +It is complemented by the following other documents: * [README](../README.md) * [Command-line examples](cmdline_examples.md) @@ -16,116 +17,81 @@ Table of contents: ## High-level flags -The following high level flags specify the location of the input and output +The following high level flags specify the details of the input and output files. The flag `--output_file` is always required. Additionally, either -`--input_file` or `--savedmodel_directory` is required. - -* `--savedmodel_directory`. Type: string. Specifies the full path to the - directory containing the SavedModel. -* `--savedmodel_tagset`. Type: string. Default: +`--graph_def_file`, `--saved_model_dir` or `--keras_model_file` is required. + +* `--output_file`. Type: string. Specifies the full path of the output file. +* `--graph_def_file`. Type: string. Specifies the full path of the input + GraphDef file frozen using + [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). +* `--saved_model_dir`. Type: string. Specifies the full path to the directory + containing the SavedModel. +* `--keras_model_file`. Type: string. Specifies the full path of the HDF5 file + containing the tf.keras model. +* `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of + the output file. Allowed values: + * `TFLITE`: TensorFlow Lite FlatBuffer format. + * `GRAPHVIZ_DOT`: GraphViz `.dot` format containg a visualization of the + graph after graph transformations. + * Note that passing `GRAPHVIZ_DOT` to `--output_format` leads to loss + of TFLite specific transformations. Therefore, the resulting + visualization may not reflect the final set of graph + transformations. To get a final visualization with all graph + transformations use `--dump_graphviz` instead. + +The following flags specify optional parameters when using SavedModels. + +* `--saved_model_tag_set`. Type: string. Default: [kSavedModelTagServe](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/tag_constants.h). Specifies a comma-separated set of tags identifying the MetaGraphDef within the SavedModel to analyze. All tags in the tag set must be specified. -* `--input_file`. Type: string. Specifies the path of the input file. This may - be either an absolute or a relative path. -* `--output_file`. Type: string. Specifies the path of the output file. - -The following high level flags specify the types of the input and output files: - -* `--input_format`. Type: string. Default: `TENSORFLOW_GRAPHDEF`. Specifies - the format of the input file. Allowed values: - * `TENSORFLOW_GRAPHDEF` — The TensorFlow GraphDef format. Both - binary and text proto formats are allowed. - * `TFLITE` — The TensorFlow Lite FlatBuffers format. -* `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of - the output file. Allowed values: - * `TENSORFLOW_GRAPHDEF` — The TensorFlow GraphDef format. Always - produces a file in binary (not text) proto format. - * `TFLITE` — The TensorFlow Lite FlatBuffers format. - * Whether a float or quantized TensorFlow Lite file will be produced - depends on the `--inference_type` flag. - * `GRAPHVIZ_DOT` — The GraphViz `.dot` format. This asks the - converter to generate a reasonable graphical representation of the graph - after simplification by a generic set of transformation. - * A typical `dot` command line to view the resulting graph might look - like: `dot -Tpdf -O file.dot`. - * Note that since passing this `--output_format` means losing the - information of which output format you actually care about, and - since the converter's transformations depend on the specific output - format, the resulting visualization may not fully reflect what you - would get on the actual output format that you are using. To avoid - that concern, and generally to get a visualization of exactly what - you get in your actual output format as opposed to just a merely - plausible visualization of a model, consider using `--dump_graphviz` - instead and keeping your true `--output_format`. +* `--saved_model_signature_key`. Type: string. Default: + [DEFAULT_SERVING_SIGNATURE_DEF_KEY](https://www.tensorflow.org/api_docs/python/tf/saved_model/signature_constants). + Specifies the key identifying the SignatureDef containing inputs and + outputs. ## Model flags *Model flags* provide additional information about the model stored in the input file. -* `--output_array`. Type: string. Specifies a single array as the output - activations. Incompatible with `--output_arrays`. -* `--output_arrays`. Type: comma-separated list of strings. Specifies a list - of arrays as the output activations, for models with multiple outputs. - Incompatible with `--output_array`. -* `--input_array`. Type: string. Specifies a single array as the input - activations. Incompatible with `--input_arrays`. -* `--input_arrays`. Type: comma-separated list of strings. Specifies a list of - arrays as the input activations, for models with multiple inputs. - Incompatible with `--input_array`. -* `--batch_size`. Type: integer. Default: 1. Specifies the batch size for the - model. Replaces the first dimension of an input size array if undefined. Use - only with SavedModels when neither `--input_shape` nor `input_shapes` flags - are specified. Incompatible with GraphDefs. - -When `--input_array` is used, the following flags are available to provide -additional information about the single input array: - -* `--input_shape`. Type: comma-separated list of integers. Specifies the shape - of the input array, in TensorFlow convention: starting with the outer-most - dimension (the dimension corresponding to the largest offset stride in the - array layout), ending with the inner-most dimension (the dimension along - which array entries are typically laid out contiguously in memory). - * For example, a typical vision model might pass - `--input_shape=1,60,80,3`, meaning a batch size of 1 (no batching), an - input image height of 60, an input image width of 80, and an input image - depth of 3, for the typical case where the input image is a RGB bitmap - (3 channels, depth=3) stored by horizontal scanlines (so 'width' is the - next innermost dimension after 'depth'). -* `--mean_value` and `--std_value`. Type: floating-point. The decimal point - character is always the dot (`.`) regardless of the locale. These specify - the (de-)quantization parameters of the input array, when it is quantized. - * The meaning of mean_value and std_value is as follows: each quantized - value in the quantized input array will be interpreted as a mathematical - real number (i.e. as an input activation value) according to the - following formula: +* `--input_arrays`. Type: comma-separated list of strings. Specifies the list + of names of input activation tensors. +* `--output_arrays`. Type: comma-separated list of strings. Specifies the list + of names of output activation tensors. + +The following flags define properties of the input tensors. Each item in the +`--input_arrays` flag should correspond to each item in the following flags +based on index. + +* `--input_shapes`. Type: colon-separated list of comma-separated lists of + integers. Each comma-separated list of integers gives the shape of one of + the input arrays specified in [TensorFlow + convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape). + * Example: `--input_shapes=1,60,80,3` for a typical vision model means a + batch size of 1, an input image height of 60, an input image width of + 80, and an input image depth of 3 (representing RGB channels). + * Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means "foo" + has a shape of [2, 3] and "bar" has a shape of [4, 5, 6]. +* `--std_dev_values`, `--mean_values`. Type: comma-separated list of integers. + These specify the (de-)quantization parameters of the input array, when it + is quantized. + * The meaning of `mean_values` and `std_dev_values` is as follows: each + quantized value in the quantized input array will be interpreted as a + mathematical real number (i.e. as an input activation value) according + to the following formula: * `real_value = (quantized_input_value - mean_value) / std_value`. * When performing float inference (`--inference_type=FLOAT`) on a quantized input, the quantized input would be immediately dequantized by the inference code according to the above formula, before proceeding with float inference. * When performing quantized inference - (`--inference_type=QUANTIZED_UINT8`), no dequantization is ever to be - performed by the inference code; however, the quantization parameters of - all arrays, including those of the input arrays as specified by - mean_value and std_value, all participate in the determination of the - fixed-point multipliers used in the quantized inference code. - -When `--input_arrays` is used, the following flags are available to provide -additional information about the multiple input arrays: - -* `--input_shapes`. Type: colon-separated list of comma-separated lists of - integers. Each comma-separated list of integer gives the shape of one of the - input arrays specified in `--input_arrays`, in the same order. See - `--input_shape` for details. - * Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means that - there are two input arrays. The first one, "foo", has shape [2,3]. The - second one, "bar", has shape [4,5,6]. -* `--mean_values`, `--std_values`. Type: comma-separated lists of - floating-point numbers. Each number gives the corresponding value for one of - the input arrays specified in `--input_arrays`, in the same order. See - `--mean_value`, `--std_value` for details. + (`--inference_type=QUANTIZED_UINT8`), no dequantization is performed by + the inference code. However, the quantization parameters of all arrays, + including those of the input arrays as specified by `mean_value` and + `std_dev_value`, determine the fixed-point multipliers used in the + quantized inference code. ## Transformation flags @@ -133,21 +99,13 @@ additional information about the multiple input arrays: the graph, i.e. they specify requested properties that the output file should have. -* `--inference_type`. Type: string. Sets the type of real-number arrays in the - output file, that is, controls the representation (quantization) of real - numbers in the output file, except for input arrays, which are controlled by - `--inference_input_type`. +* `--inference_type`. Type: string. Default: `FLOAT`. Data type of all + real-number arrays in the output file except for input arrays (defined by + `--inference_input_type`). Must be `{FLOAT, QUANTIZED_UINT8}`. - This flag only impacts real-number arrays. By "real-number" we mean float - arrays, and quantized arrays. This excludes plain integer arrays, strings - arrays, and every other data type. - - For real-number arrays, the impact of this flag is to allow the output file - to choose a different real-numbers representation (quantization) from what - the input file used. For any other types of arrays, changing the data type - would not make sense. - - Specifically: + This flag only impacts real-number arrays including float and quantized + arrays. This excludes all other data types including plain integer arrays + and string arrays. Specifically: * If `FLOAT`, then real-numbers arrays will be of type float in the output file. If they were quantized in the input file, then they get @@ -155,72 +113,54 @@ have. * If `QUANTIZED_UINT8`, then real-numbers arrays will be quantized as uint8 in the output file. If they were float in the input file, then they get quantized. - * If not set, then all real-numbers arrays retain the same type in the - output file as they have in the input file. - -* `--inference_input_type`. Type: string. Similar to inference_type, but - allows to control specifically the quantization of input arrays, separately - from other arrays. - - If not set, then the value of `--inference_type` is implicitly used, i.e. by - default input arrays are quantized like other arrays. - - Like `--inference_type`, this only affects real-number arrays. By - "real-number" we mean float arrays, and quantized arrays. This excludes - plain integer arrays, strings arrays, and every other data type. - - The typical use for this flag is for vision models taking a bitmap as input, - typically with uint8 channels, yet still requiring floating-point inference. - For such image models, the uint8 input is quantized, i.e. the uint8 values - are interpreted as real numbers, and the quantization parameters used for - such input arrays are their `mean_value`, `std_value` parameters. - -* `--default_ranges_min`, `--default_ranges_max`. Type: floating-point. The - decimal point character is always the dot (`.`) regardless of the locale. - These flags enable what is called "dummy quantization". If defined, their - effect is to define fallback (min, max) range values for all arrays that do - not have a properly specified (min, max) range in the input file, thus - allowing to proceed with quantization of non-quantized or - incorrectly-quantized input files. This enables easy performance prototyping - ("how fast would my model run if I quantized it?") but should never be used - in production as the resulting quantized arithmetic is inaccurate. - -* `--drop_fake_quant`. Type: boolean. Default: false. Causes fake-quantization - nodes to be dropped from the graph. This may be used to recover a plain - float graph from a fake-quantized graph. - -* `--reorder_across_fake_quant`. Type: boolean. Default: false. Normally, - fake-quantization nodes must be strict boundaries for graph transformations, - in order to ensure that quantized inference has the exact same arithmetic - behavior as quantized training --- which is the whole point of quantized - training and of FakeQuant nodes in the first place. However, that entails - subtle requirements on where exactly FakeQuant nodes must be placed in the - graph. Some quantized graphs have FakeQuant nodes at unexpected locations, - that prevent graph transformations that are necessary in order to generate a - well-formed quantized representation of these graphs. Such graphs should be - fixed, but as a temporary work-around, setting this - reorder_across_fake_quant flag allows the converter to perform necessary - graph transformations on them, at the cost of no longer faithfully matching - inference and training arithmetic. + +* `--inference_input_type`. Type: string. Data type of a real-number input + array in the output file. By default the `--inference_type` is used as type + of all of the input arrays. Flag is primarily intended for generating a + float-point graph with a quantized input array. A Dequantized operator is + added immediately after the input array. Must be `{FLOAT, QUANTIZED_UINT8}`. + + The flag is typically used for vision models taking a bitmap as input but + requiring floating-point inference. For such image models, the uint8 input + is quantized and the quantization parameters used for such input arrays are + their `mean_value` and `std_dev_value` parameters. + +* `--default_ranges_min`, `--default_ranges_max`. Type: floating-point. + Default value for the (min, max) range values used for all arrays without a + specified range. Allows user to proceed with quantization of non-quantized + or incorrectly-quantized input files. These flags produce models with low + accuracy. They are intended for easy experimentation with quantization via + "dummy quantization". + +* `--drop_control_dependency`. Type: boolean. Default: True. Indicates whether + to drop control dependencies silently. This is due to TensorFlow Lite not + supporting control dependencies. + +* `--reorder_across_fake_quant`. Type: boolean. Default: False. Indicates + whether to reorder FakeQuant nodes in unexpected locations. Used when the + location of the FakeQuant nodes is preventing graph transformations + necessary to convert the graph. Results in a graph that differs from the + quantized training graph, potentially causing differing arithmetic behavior. + +* `--allow_custom_ops`. Type: string. Default: False. Indicates whether to + allow custom operations. When false, any unknown operation is an error. When + true, custom ops are created for any op that is unknown. The developer will + need to provide these to the TensorFlow Lite runtime with a custom resolver. + +* `--quantize_weights`. Type: boolean. Default: False. Indicates whether to + store weights as quantized weights followed by dequantize operations. + Computation is still done in float, but reduces model size (at the cost of + accuracy and latency). ## Logging flags -The following are standard Google logging flags: - -* `--logtostderr` redirects Google logging to standard error, typically making - it visible in a terminal. -* `--v` sets verbose logging levels (for debugging purposes). Defined levels: - * `--v=1`: log all graph transformations that did make a change on the - graph. - * `--v=2`: log all graph transformations that did *not* make a change on - the graph. - -The following flags allow to generate graph visualizations of the actual graph -at various points during transformations: - -* `--dump_graphviz=/path` enables dumping of the graphs at various stages of - processing as GraphViz `.dot` files. Generally preferred over - `--output_format=GRAPHVIZ_DOT` as this allows you to keep your actually - relevant `--output_format`. -* `--dump_graphviz_video` enables dumping of the graph after every single - graph transformation (for debugging purposes). +The following flags generate graph visualizations of the graph as +[GraphViz](https://www.graphviz.org/) `.dot` files at various points during +graph transformations: + +* `--dump_graphviz_dir`. Type: string. Specifies the full path of the + directory to output GraphViz `.dot` files. Outputs the graph immediately + after reading in the graph and after all of the transformations have been + completed. +* `--dump_graphviz_video`. Type: boolean. Outputs GraphViz after every graph + transformation. Requires `--dump_graphviz_dir` to be specified. diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index f0fd638a618c75c75d336a746f9b1d8dccaea470..3799eac0a1181afe3b63d2f8651745c2ec61f5e0 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -1,69 +1,268 @@ -# TensorFlow Lite Optimizing Converter (TOCO) Python API reference +# TensorFlow Lite Optimizing Converter & Interpreter Python API reference -This page provides examples on how to use TOCO via the Python API. It is -complemented by the following documents: +This page provides examples on how to use TOCO and the TensorFlow Lite +interpreter via the Python API. It is complemented by the following documents: * [README](../README.md) * [Command-line examples](cmdline_examples.md) * [Command-line glossary](cmdline_reference.md) +Table of contents: + +* [High-level overview](#high-level-overview) +* [API](#api) +* [Basic examples](#basic) + * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess) + * [Exporting a GraphDef from file](#basic-graphdef-file) + * [Exporting a SavedModel](#basic-savedmodel) + * [Exporting a tf.keras File](#basic-keras-file) +* [Complex examples](#complex) + * [Exporting a quantized GraphDef](#complex-quant) +* [TensorFlow Lite Python interpreter](#interpreter) + * [Using the interpreter from a model file](#interpreter-file) + * [Using the interpreter from model data](#interpreter-data) +* [Additional instructions](#additional-instructions) + * [Build from source code](#latest-package) + * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) + ## High-level overview While the TensorFlow Lite Optimizing Converter can be used from the command -line, it is often convenient to use it as part of Python model build and +line, it is often convenient to use it as part of a Python model build and training script. This is so that conversion can be part of your model development pipeline. This allows you to know early and often that you are designing a model that can be targeted to devices with mobile. ## API -In Python you can run `help(tf.contrib.lite)` to get documentation on functions. -In particular, `tf.contrib.lite.toco_convert` presents a simple API and -`tf.contrib.lite.toco_from_protos` allows more detailed control of TOCO using -the protobuf interface to TOCO. +The API for converting TensorFlow models to TensorFlow Lite as of TensorFlow 1.9 +is `tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is +`tf.contrib.lite.Interpreter`. + +`TocoConverter` provides class methods based on the original format of the +model. `TocoConverter.from_session()` is available for GraphDefs. +`TocoConverter.from_saved_model()` is available for SavedModels. +`TocoConverter.from_keras_model_file()` is available for `tf.Keras` files. +Example usages for simple float-point models are shown in [Basic +Examples](#basic). Examples usages for more complex models is shown in [Complex +Examples](#complex). + +**NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python +interpreter when the conversion fails. This will be remedied as soon as +possible. + +## Basic examples + +The following section shows examples of how to convert a basic float-point model +from each of the supported data formats into a TensorFlow Lite FlatBuffers. -## Example +### Exporting a GraphDef from tf.Session -In particular, here we show creating a simple model and converting it to a -TensorFlow Lite Model. +The following example shows how to convert a TensorFlow GraphDef into a +TensorFlow Lite FlatBuffer from a `tf.Session` object. ```python import tensorflow as tf img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3)) +val = img + var out = tf.identity(val, name="out") + with tf.Session() as sess: - tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) - open("test.tflite", "wb").write(tflite_model) + converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) ``` -**NOTE** Currently, the TOCO command will cause a fatal error to the Python -interpreter when TOCO conversion fails. This will be remedied as soon as -possible. +### Exporting a GraphDef from file + +The following example shows how to convert a TensorFlow GraphDef into a +TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and +`.pbtxt` files are accepted. + +The example uses +[Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz). +The function only supports GraphDefs frozen via +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). + +```python +import tensorflow as tf + +graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb" +input_arrays = ["input"] +output_arrays = ["MobilenetV1/Predictions/Softmax"] + +converter = tf.contrib.lite.TocoConverter.from_frozen_graph( + graph_def_file, input_arrays, output_arrays) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` -## Example 2: Export with variables +### Exporting a SavedModel -If a model has variables, they need to be turned into constants. This process is -known as freezing, and it can actually be accomplished with +The following example shows how to convert a SavedModel into a TensorFlow Lite +FlatBuffer. + +```python +import tensorflow as tf + +converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +For more complex SavedModels, the optional parameters that can be passed into +`TocoConverter.from_saved_model()` are `input_arrays`, `input_shapes`, +`output_arrays`, `tag_set` and `signature_key`. Details of each parameter are +available by running `help(tf.contrib.lite.TocoConverter)`. + +### Exporting a tf.keras File + +The following example shows how to convert a `tf.keras` model into a TensorFlow +Lite FlatBuffer. + +```python +import tensorflow as tf + +converter = tf.contrib.lite.TocoConverter.from_keras_model_file("keras_model.h5") +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +The `tf.keras` file must contain both the model and the weights. A comprehensive +example including model construction can be seen below. + +```python +import numpy as np +import tensorflow as tf + +# Generate tf.keras model. +model = tf.keras.models.Sequential() +model.add(tf.keras.layers.Dense(2, input_shape=(3,))) +model.add(tf.keras.layers.RepeatVector(3)) +model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3))) +model.compile(loss=tf.keras.losses.MSE, + optimizer=tf.keras.optimizers.RMSprop(lr=0.0001), + metrics=[tf.keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + +x = np.random.random((1, 3)) +y = np.random.random((1, 3, 3)) +model.train_on_batch(x, y) +model.predict(x) + +# Save tf.keras model in HDF5 format. +keras_file = "keras_model.h5" +tf.keras.models.save_model(model, keras_file) + +# Convert to TensorFlow Lite model. +converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_file) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +## Complex examples + +For models where the default value of the attributes is not sufficient, the +attribute's values should be set before calling `convert()`. In order to call +any constants use `tf.contrib.lite.constants.` as seen below with +`QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TocoConverter)` in the Python +terminal for detailed documentation on the attributes. + +Although the examples are demonstrated on GraphDefs containing only constants. +The same logic can be applied irrespective of the input data format. + +### Exporting a quantized GraphDef + +The following example shows how to convert a quantized model into a TensorFlow +Lite FlatBuffer. ```python import tensorflow as tf img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -var = tf.get_variable("weights", dtype=tf.float32, shape=(1,64,64,3)) -val = img + var +const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +val = img + const +out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output") + +with tf.Session() as sess: + converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8 + input_arrays = converter.get_input_arrays() + converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) +``` + +## TensorFlow Lite Python interpreter + +### Using the interpreter from a model file + +The following example shows how to use the TensorFlow Lite Python interpreter +when provided a TensorFlow Lite FlatBuffer file. The example also demonstrates +how to run inference on random input data. Run +`help(tf.contrib.lite.Interpreter)` in the Python terminal to get detailed +documentation on the interpreter. + +```python +import numpy as np +import tensorflow as tf -def canonical_name(x): - return x.name.split(":")[0] +# Load TFLite model and allocate tensors. +interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite") +interpreter.allocate_tensors() +# Get input and output tensors. +input_details = interpreter.get_input_details() +output_details = interpreter.get_output_details() + +# Test model on random input data. +input_shape = input_details[0]['shape'] +input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32) +interpreter.set_tensor(input_details[0]['index'], input_data) + +interpreter.invoke() +output_data = interpreter.get_tensor(output_details[0]['index']) +print(output_data) +``` + +### Using the interpreter from model data + +The following example shows how to use the TensorFlow Lite Python interpreter +when starting with the TensorFlow Lite Flatbuffer model previously loaded. This +example shows an end-to-end use case, starting from building the TensorFlow +model. + +```python +import numpy as np +import tensorflow as tf + +img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) +const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +val = img + const out = tf.identity(val, name="out") + with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - out_tensors = [out] - frozen_graphdef = tf.graph_util.convert_variables_to_constants( - sess, sess.graph_def, map(canonical_name, out_tensors)) - tflite_model = tf.contrib.lite.toco_convert( - frozen_graphdef, [img], out_tensors) - open("converted_model.tflite", "wb").write(tflite_model) + converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + tflite_model = converter.convert() + +# Load TFLite model and allocate tensors. +interpreter = tf.contrib.lite.Interpreter(model_content=tflite_model) +interpreter.allocate_tensors() ``` + +## Additional instructions + +### Build from source code + +In order to run the latest version of the TOCO Python API, clone the TensorFlow +repository, configure the installation, and build and install the pip package. +Detailed instructions are available +[here](https://www.tensorflow.org/install/install_sources). + +### Converting models prior to TensorFlow 1.9. + +To use TOCO in TensorFlow 1.7 and TensorFlow 1.8, use the `toco_convert` +function. Run `help(tf.contrib.lite.toco_convert)` to get details about accepted +parameters. diff --git a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg index a47c088991299159be39bc490149720dae43eb53..262e13a591b998c4f38f0a9f44a5b385f612df90 100644 --- a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg +++ b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index 0fffab574ddd8ad75ec07ae4442f363a36ed289e..1ea83abf8eb1b49f649e81def29857094cd0c2d7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -38,6 +38,16 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { // Depthwise conv does not support dilation return false; } + auto& input_array = model->GetArray(conv_op->inputs[0]); + if (!input_array.has_shape()) { + // Shapes not propagated yet + return false; + } + if (input_array.shape().dims(3) != 1) { + // Not a pure convolution: Conv does accumulation across the depth + // dimension. + return false; + } auto& weights_array = model->GetArray(conv_op->inputs[1]); if (!weights_array.buffer) { // Yield until the weights are resolved as a constant array. @@ -46,11 +56,6 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { if (weights_array.data_type != ArrayDataType::kFloat) { return false; } - if (weights_array.shape().dims(3) != 1) { - // Not a pure convolution: Conv does accumulation across the depth - // dimension. - return false; - } // At this point we know we have a pure conv. Rewrite it as DepthwiseConv. AddMessageF( "%s is purely convolutional (input/weights depth is 1), replacing it by " diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc new file mode 100644 index 0000000000000000000000000000000000000000..b689be07926ecd9be4cc317735dc88eb90950e13 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { + auto tile_it = model->operators.begin() + op_index; + if (tile_it->get()->type != OperatorType::kTile) { + return false; + } + auto* tile_op = static_cast(tile_it->get()); + + const auto& input_array = model->GetArray(tile_op->inputs[0]); + const auto& multiples_array = model->GetArray(tile_op->inputs[1]); + const auto& output_array = model->GetArray(tile_op->outputs[0]); + if (!input_array.has_shape() || !multiples_array.has_shape() || + !output_array.has_shape()) { + // Yield until PropagateFixedSizes has been run on this op. + return false; + } + // Note: We can assume we have error checked inputs in PropagateFixedSizes. + + if (!multiples_array.buffer) { + // Yield until the multiples is constant. + return false; + } + std::vector const& multiples = + multiples_array.GetBuffer().data; + + // We can simplify the tile if only a single dimension is being multiplied. + // It then just becomes a concat along that dimension. + int non_one_dims = 0; + int concat_axis = 0; + for (int i = 0; i < multiples.size(); ++i) { + if (multiples[i] != 1) { + ++non_one_dims; + concat_axis = i; + } + } + if (non_one_dims != 1) { + // The tile is non-trivial. Good luck. + AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)", + LogName(*tile_op)); + return false; + } + + // The tile is like a concat. + AddMessageF("Simplifying %s to a Concat along a single axis %d", + LogName(*tile_op), concat_axis); + + auto* concat_op = new ConcatenationOperator; + + // Copy input and output. + // Note that we multiply out the input by the number of times requested. + for (int i = 0; i < multiples[concat_axis]; ++i) { + concat_op->inputs.push_back(tile_op->inputs[0]); + } + concat_op->axis = concat_axis; + concat_op->outputs = tile_op->outputs; + + // Delete multiples array if unused. + if (IsDiscardableArray(*model, tile_op->inputs[1]) && + CountOpsWithInput(*model, tile_op->inputs[1]) == 1) { + model->EraseArray(tile_op->inputs[1]); + } + + // Replace the operator in the graph. + const auto concat_it = model->operators.emplace(tile_it, concat_op); + tile_it = concat_it + 1; + CHECK_EQ(tile_it->get(), tile_op); + model->operators.erase(tile_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc index 076415ece8c1039caa32e947fe54ab3e101bec9e..1e68cd678bce6c27f1852a5ae0c13362d8938cdd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -25,17 +25,12 @@ limitations under the License. namespace toco { -bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { - auto conv_it = model->operators.begin() + op_index; - if (conv_it->get()->type != OperatorType::kConv) { - return false; - } - auto* conv_op = static_cast(conv_it->get()); - if (conv_op->outputs.size() == 2) { +bool ProcessConvOperator(Model* model, ConvOperator* op) { + if (op->outputs.size() == 2) { // We already have an im2col array return false; } - const auto& weights_array = model->GetArray(conv_op->inputs[1]); + const auto& weights_array = model->GetArray(op->inputs[1]); if (!weights_array.has_shape()) { // We need to yield until weights dims have been resolved, because // from the weights dims we determine whether an im2col array is @@ -45,25 +40,52 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { const auto& weights_shape = weights_array.shape(); const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); - if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 && - conv_op->stride_height == 1) { - // 1x1 unstrided conv does not need an im2col array. + if (kwidth == 1 && kheight == 1 && op->stride_width == 1 && + op->stride_height == 1 && op->dilation_width_factor == 1 && + op->dilation_height_factor == 1) { + // 1x1 unstrided undilated conv does not need an im2col array. return false; } // Create the im2col array. - CHECK_EQ(conv_op->outputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); const string& im2col_array_name = - AvailableArrayName(*model, conv_op->inputs[0] + "_im2col"); + AvailableArrayName(*model, op->inputs[0] + "_im2col"); model->GetOrCreateArray(im2col_array_name); - conv_op->outputs.push_back(im2col_array_name); - AddMessageF( - "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, " - "stride_height=%d", - LogName(*conv_op), kwidth, kheight, conv_op->stride_width, - conv_op->stride_height); + op->outputs.push_back(im2col_array_name); return true; } +bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { + if (op->outputs.size() == 2) { + // We already have an im2col array + return false; + } + + // Always create an im2col array for transpose_conv. + CHECK_EQ(op->outputs.size(), 1); + const string& im2col_array_name = AvailableArrayName( + *model, op->inputs[TransposeConvOperator::DATA_INPUT] + "_im2col"); + model->GetOrCreateArray(im2col_array_name); + op->outputs.push_back(im2col_array_name); + + return true; +} + +bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + + switch (op->type) { + case OperatorType::kConv: + return ProcessConvOperator(model, static_cast(op)); + case OperatorType::kTransposeConv: + return ProcessTransposeConvOperator( + model, static_cast(op)); + default: + return false; + } +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc index 498c864bde6d656c8318e981204cb42cb3a4d03f..2c7ffe488477ef1a544dfe6f36a6e0d1ac40aa96 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -111,7 +111,7 @@ bool DequantizeArray(const string& array_name, auto* op_outputting_array = GetOpWithOutput(*model, array_name); if (op_outputting_array) { - if (op_outputting_array->type == OperatorType::kTensorFlowReshape) { + if (op_outputting_array->type == OperatorType::kReshape) { return true; } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc index 708ecf6e0a96811ab274fbb25f748f562cd3afad..e80ed036b311cfc586c40ece410ef6a6432a0cd9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -26,17 +26,38 @@ namespace toco { namespace { +int GetOutputDepthFromWeights(const Model& model, const Operator& op) { + const string& weights_name = op.inputs[1]; + const auto& weights_shape = model.GetArray(weights_name).shape(); + if (op.type == OperatorType::kConv || + op.type == OperatorType::kFullyConnected) { + return weights_shape.dims(0); + } + if (op.type == OperatorType::kDepthwiseConv) { + return weights_shape.dims(3); + } + LOG(FATAL) << "Unhandled operator type"; + return 0; +} + bool ProcessLinearOperator(Model* model, Operator* op) { if (op->inputs.size() >= 3) { return false; } const string& output_name = op->outputs[0]; + const string& weights_name = op->inputs[1]; + if (!model->GetArray(weights_name).has_shape()) { + return false; + } + const int depth = GetOutputDepthFromWeights(*model, *op); const string& bias_name = AvailableArrayName(*model, output_name + "_bias"); op->inputs.push_back(bias_name); DCHECK_EQ(op->inputs.size(), 3); auto& bias_array = model->GetOrCreateArray(bias_name); bias_array.data_type = ArrayDataType::kFloat; - + bias_array.mutable_shape()->mutable_dims()->push_back(depth); + auto& bias_buffer = bias_array.GetMutableBuffer(); + bias_buffer.data.resize(depth, 0.f); return true; } } // namespace diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc index 394fa349e2663e2806344f27a96a5132a2d4a810..75642bbc37be6b3140e5b79a463ca70b5786d772 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc @@ -122,7 +122,7 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, case OperatorType::kFullyConnected: { weights_index = 1; const auto& fc_op = static_cast(op); - CHECK(!fc_op.experimental_shuffled_weights) + CHECK(fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) << "This graph transformation expects to run before FC weights get " "shuffled."; break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc new file mode 100644 index 0000000000000000000000000000000000000000..874d8def571fbce4219de15285c8df6fd2487a9a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// Returns true if the given op is strictly a broadcasting operation. +// This is commonly seen as a Concat of the same input multiple times, and is +// often generated from Tile ops that were converted via the +// convert_trivial_tile_to_concat transformation. +bool IsBroadcastingOp(const Model& model, Operator* op) { + // Concatenation of identical inputs is usually a broadcast. + if (op->type == OperatorType::kConcatenation) { + // Verify that all inputs are the same. + for (int i = 1; i < op->inputs.size(); ++i) { + if (op->inputs[i] != op->inputs[0]) { + return false; + } + } + return true; + } + + // There are other things we could look for (Stack/etc) when needed. + return false; +} + +} // namespace + +// Finds an operation that looks like a broadcast (concat of the same sources +// along the last dimension) and drops it by relying on the ability of certain +// binary ops to perform an implicit broadcast. +bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + + // Test for binary ops of types that we know how to resolve + if (binary_op->inputs.size() != 2) { + return false; + } + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + // NOTE: either of these ops may be nullptr if the input array is constant. + Operator* const op[2] = { + GetOpWithOutput(*model, binary_op->inputs[0]), + GetOpWithOutput(*model, binary_op->inputs[1]), + }; + + // Check whether either input is a broadcast-like concat. + bool is_op_0_broadcast = op[0] && IsBroadcastingOp(*model, op[0]); + bool is_op_1_broadcast = op[1] && IsBroadcastingOp(*model, op[1]); + if (!is_op_0_broadcast && !is_op_1_broadcast) { + // Neither input is a broadcast-looking thing. + AddMessageF("Neither input looks broadcasty"); + return false; + } else if (is_op_0_broadcast && is_op_1_broadcast) { + AddMessageF( + "Unable to fuse broadcast into %s as both inputs (%s, %s) are " + "broadcasts", + LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)", + op[1] ? LogName(*op[1]) : "(?)"); + return false; + } + int broadcast_index = is_op_0_broadcast ? 0 : 1; + + // Just pull out the input of the broadcast op and pass it directly to the + // binary op. + AddMessageF("Fusing broadcast op %s into the following binary %s", + LogName(*op[broadcast_index]), LogName(*binary_op)); + binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0]; + + // We leave the broadcast op in; it'll get cleaned up if it's not used later. + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 8da242aa9c2ca4917a681c95c3eded894664c046..8cd1298bcacd7b9c1379ccb4532885f686484278 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -117,12 +117,14 @@ DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape) +DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTileToConcat) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) +DECLARE_GRAPH_TRANSFORMATION(FuseBroadcastIntoFollowingBinary) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) @@ -133,12 +135,14 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu) DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv) DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) +DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape) DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants) DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits); DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax) DECLARE_GRAPH_TRANSFORMATION(Quantize) +DECLARE_GRAPH_TRANSFORMATION(QuantizeWeights) DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp) DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert) DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity) @@ -164,7 +168,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge) DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch) -DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose) @@ -190,7 +193,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather) DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero) DECLARE_GRAPH_TRANSFORMATION(Dequantize) DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup) -DECLARE_GRAPH_TRANSFORMATION(ExperimentalShuffleFCWeights) +DECLARE_GRAPH_TRANSFORMATION(ShuffleFCWeights) class PropagateDefaultMinMax : public GraphTransformation { public: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index d63ee7c9519d169a2f44ec1afe81125217db8976..2f1bb8f0ad6374243e5a094701eef54cd086151a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -133,24 +133,20 @@ bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) { } bool HardcodeMinMaxForSplit(Model* model, Operator* op) { - for (const auto& output : op->outputs) { - if (model->GetArray(output).minmax) { - LOG(WARNING) << "Skipping min-max setting for " << LogName(*op) - << " because output " << output << " already has min-max."; - return false; - } - } // Data is in second input. auto& input_array = model->GetArray(op->inputs[1]); if (!input_array.minmax) { return false; - } else { - for (const auto& output : op->outputs) { - auto& array = model->GetArray(output); + } + bool changed = false; + for (const auto& output : op->outputs) { + auto& array = model->GetArray(output); + if (!array.minmax || !(array.GetMinMax() == input_array.GetMinMax())) { + changed = true; array.GetOrCreateMinMax() = *input_array.minmax; } - return true; } + return changed; } // The output of average or max pooling is within the same range as its input. @@ -232,6 +228,14 @@ bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min, return true; } +bool MinMaxApproximatelyEqual(const MinMax& minmax1, const MinMax& minmax2) { + const double magnitude = + std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min); + const double tolerated = 1e-6 * magnitude; + return std::abs(minmax1.min - minmax2.min) < tolerated && + std::abs(minmax1.max - minmax2.max) < tolerated; +} + // Propagates MinMax from any of the listed arrays, to all others. // If multiple of these arrays have MinMax, then these are required // to agree with each other. @@ -254,7 +258,7 @@ bool PropagateMinMaxAmongArrays(Model* model, for (const string& array_name : array_names) { auto& array = model->GetArray(array_name); if (array.minmax) { - CHECK(*array.minmax == *reference_minmax) + CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax)) << "Both the following arrays have minmax, and they disagree: " << reference_array_name << " (" << reference_minmax->min << "," << reference_minmax->max << ") and " << array_name << " (" @@ -353,7 +357,7 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForConcatenation(model, op); break; - case OperatorType::kTensorFlowSplit: + case OperatorType::kSplit: changed = HardcodeMinMaxForSplit(model, op); break; @@ -362,9 +366,11 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForAverageOrMaxPool(model, op); break; + case OperatorType::kResizeBilinear: + case OperatorType::kSlice: case OperatorType::kStridedSlice: case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kPad: case OperatorType::kGather: case OperatorType::kTranspose: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc index ae3301f467de5714230e731b4bab87ddc1637201..d49857cfc22ecaf5feb06b39a42187f8adb61d50 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -90,12 +90,13 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { } // Conv Op - ConvOperator* conv_op = dynamic_cast( - has_expand_op ? GetOpWithInput(*model, post_stb_op->outputs[0]) - : GetOpWithInput(*model, stb_op->outputs[0])); - if (!conv_op || conv_op->type != OperatorType::kConv) { + const string& input_of_conv_op = + has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0]; + auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op); + if (conv_base_op->type != OperatorType::kConv) { return false; } + auto* conv_op = static_cast(conv_base_op); if (conv_op->inputs.size() != 2) { // The conv op must only have weights, no bias. return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc index 419a0776a6b987a18df059d3c1d4bf4370cd24d8..b78efd7fc3602dc2d6e03fd28d694c344b61c17c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -44,10 +44,9 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { const auto* div_or_mul_op = div_it->get(); OperatorType expected_op_type_producing_div_or_mul_input; if (div_or_mul_op->type == OperatorType::kDiv) { - expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt; + expected_op_type_producing_div_or_mul_input = OperatorType::kSqrt; } else if (div_or_mul_op->type == OperatorType::kMul) { - expected_op_type_producing_div_or_mul_input = - OperatorType::kTensorFlowRsqrt; + expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt; } else { return false; } @@ -75,8 +74,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { Operator* add_op = nullptr; Operator* op_producing_add_input = nullptr; if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd || - op_producing_sqrt_or_rsqrt_input->type == - OperatorType::kTensorFlowMaximum) { + op_producing_sqrt_or_rsqrt_input->type == OperatorType::kMaximum) { add_op = op_producing_sqrt_or_rsqrt_input; bool add_can_be_removed = false; CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2); @@ -113,7 +111,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { Operator* sum_op = add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input; - if (sum_op->type != OperatorType::kTensorFlowSum) { + if (sum_op->type != OperatorType::kSum) { AddMessageF( "Giving up trying to identify L2Normalization subgraph: " "expected Sum op, got %s", @@ -122,7 +120,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { } Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); - if (square_op->type != OperatorType::kTensorFlowSquare) { + if (square_op->type != OperatorType::kSquare) { AddMessageF( "Giving up trying to identify L2Normalization subgraph: " "expected Square op, got %s", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc index e4d52476c649de53b3ab663f53ce7a5538dbb5ab..705e73779b7f74698149d5e9e56f69a371326ceb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc @@ -41,7 +41,7 @@ std::vector>::iterator FindOperator( bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { const auto sqrt_it = model->operators.begin() + op_index; const auto* sqrt_op = sqrt_it->get(); - if (sqrt_op->type != OperatorType::kTensorFlowSqrt) { + if (sqrt_op->type != OperatorType::kSqrt) { return false; } @@ -52,6 +52,13 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { const Operator* square_op; Operator* prev_to_sqrt_op = GetOpWithOutput(*model, sqrt_op->inputs[0]); + if (prev_to_sqrt_op == nullptr) { + AddMessageF( + "Giving up trying to identify L2Pool subgraph: " + "expected AveragePool op, but Sqrt op has no preceding op"); + return false; + } + if (prev_to_sqrt_op->type != OperatorType::kAveragePool) { AddMessageF( "Giving up trying to identify L2Pool subgraph: " @@ -65,7 +72,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { square_op = GetOpWithOutput(*model, avpool_op->inputs[0]); CHECK_EQ(square_op->inputs.size(), 1); - if (square_op->type != OperatorType::kTensorFlowSquare) { + if (square_op->type != OperatorType::kSquare) { AddMessageF( "Giving up trying to identify L2Pool subgraph: " "expected Square op, got %s", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc index e9842524c829b839b97b3453a36c41efe186efbb..c0b014b45eb1df25173ce3ca3fa488b0655c3c76 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -35,19 +35,24 @@ std::vector>::iterator FindOperator( return it; } -bool GetStateArrayForBackEdge(const Model& model, - const string& back_edge_source_array, - string* state_array = nullptr) { - for (const auto& rnn_state : model.flags.rnn_states()) { - if (back_edge_source_array == rnn_state.back_edge_source_array()) { - // Found LSTM cell output - if (state_array) { - *state_array = rnn_state.state_array(); - } - return true; +bool ValidateSourceOp(const Model& model, const string& array_name, + OperatorType op_type, Operator** source_op) { + if (op_type == OperatorType::kNone) { + CHECK(!source_op); + } else { + CHECK(source_op); + *source_op = GetOpWithOutput(model, array_name); + if (*source_op == nullptr) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((*source_op)->type != op_type) { + return false; } } - return false; + + return true; } // Returns true if the given operator has exactly 1 input, and is connected to @@ -62,24 +67,10 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((op_type != OperatorType::kNone) && (x == nullptr)) { + if (!ValidateSourceOp(model, op.inputs[0], op_type, connected_op)) { return false; } - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != op_type)) { - return false; - } - - // Successfully matched. Optionally return matching input operators. - if (connected_op) { - *connected_op = x; - } - return true; } @@ -96,40 +87,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { - return false; - } - - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != a_op_type)) { + if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) { return false; } // Check if second input is disconnected/connected to an operator - Operator* y = GetOpWithOutput(model, op.inputs[1]); - if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { - return false; - } - if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { + if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) { return false; } - // Check that second operator, if connected, is of correct type - if ((y != nullptr) && (y->type != b_op_type)) { - return false; - } - - // Successfully matched. Optionally return matching input operators. - if (a_op != nullptr) { - *a_op = x; - } - if (b_op != nullptr) { - *b_op = y; - } return true; } @@ -147,57 +113,20 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { - return false; - } - - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != a_op_type)) { + if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) { return false; } // Check if second input is disconnected/connected to an operator - Operator* y = GetOpWithOutput(model, op.inputs[1]); - if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { - return false; - } - if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { - return false; - } - - // Check that second operator, if connected, is of correct type - if ((y != nullptr) && (y->type != b_op_type)) { + if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) { return false; } // Check if third input is disconnected/connected to an operator - Operator* z = GetOpWithOutput(model, op.inputs[2]); - if ((c_op_type == OperatorType::kNone) && (z != nullptr)) { - return false; - } - if ((c_op_type != OperatorType::kNone) && (z == nullptr)) { - return false; - } - - // Check that third operator, if connected, is of correct type - if ((z != nullptr) && (z->type != c_op_type)) { + if (!ValidateSourceOp(model, op.inputs[2], c_op_type, c_op)) { return false; } - // Successfully matched. Optionally return matching input operators. - if (a_op != nullptr) { - *a_op = x; - } - if (b_op != nullptr) { - *b_op = y; - } - if (c_op != nullptr) { - *c_op = z; - } return true; } @@ -231,11 +160,6 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { &state_combine_add)) { return false; } - string prev_state; - if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0], - &prev_state)) { - return false; - } // State forget & remember addition Operator *state_forget_mul, *state_remember_mul; @@ -244,9 +168,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { &state_remember_mul)) { return false; } - if (state_forget_mul->inputs[0] != prev_state) { - return false; - } + const string prev_state = state_forget_mul->inputs[0]; // State forget gate Operator* state_forget_sig; @@ -266,26 +188,26 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { // State remember "information" activation function Operator* fc_output_split; - if (!MatchOperatorInputs(*state_info_tanh, *model, - OperatorType::kTensorFlowSplit, &fc_output_split)) { + if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit, + &fc_output_split)) { return false; } // State remember gate activation function Operator* tmp; - if (!MatchOperatorInputs(*state_remember_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } // State forget gate activation function - if (!MatchOperatorInputs(*state_forget_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } // Fully connected output activation function - if (!MatchOperatorInputs(*fc_output_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } @@ -306,8 +228,8 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { return false; } - if (static_cast(fully_connected) - ->experimental_shuffled_weights) { + if (static_cast(fully_connected)->weights_format != + FullyConnectedWeightsFormat::kDefault) { // Not yet implemented: experimental shuffled weights in fused LSTM cell. return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc index 3f768bfee12ebe31ebeb72855eb67ec03d5bcf8c..5b6a984ee143a6007471b165510030cd3ad3f73c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -33,9 +33,10 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { return false; } - // Already a compact LstmCell with LstmCellOperator::NUM_INPUTS of inputs, - // do not need to merge cell inputs. - if (src_op->inputs.size() == LstmCellOperator::NUM_INPUTS) { + // Already a compact LstmCell. Do not need to merge cell inputs. + const auto* src_lstm_op = static_cast(src_op); + if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL || + src_lstm_op->inputs.size() != kExtendedLstmInputCount) { return false; } @@ -136,6 +137,7 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { // Emplace a new LSTM cell operator (use basic 5 inputs kernel). auto lstm_cell_op = absl::make_unique(); + lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_BASIC; // Compact LstmCell's 5 inputs. lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index 8e66323bd769ca166d6b521c5b7b2f1cb944b0a2..46d1fce50e5d6e2a74cf5461d731e46469dde5bf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -33,9 +33,10 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { return false; } - // Already an extended LstmCell with kExtendedLstmInputCount of inputs, - // do not need to split cell inputs. - if (curr_op->inputs.size() == kExtendedLstmInputCount) { + const auto* curr_lstm_op = static_cast(curr_op); + // Already an extended LstmCell. Do not need to split cell inputs. + if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC || + curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) { return false; } @@ -56,6 +57,7 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { // Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc). auto lstm_cell_op = absl::make_unique(); + lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_FULL; lstm_cell_op->inputs.resize(kExtendedLstmInputCount); int num_input = model->GetArray(curr_op->inputs[LstmCellOperator::DATA_INPUT]) .shape() @@ -72,6 +74,12 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { lstm_cell_op->inputs[kInputTensor] = curr_op->inputs[LstmCellOperator::ACTIV_OUTPUT]; + // Previous states. + lstm_cell_op->inputs[kInputActivationStateTensor] = + curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]; + lstm_cell_op->inputs[kInputCellStateTensor] = + curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT]; + // Get original weight tensor and decompose 1 tensor to 8 sub tensors. Array& kernel = model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]); @@ -158,10 +166,6 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { // Erase curr lstm op being replaced. DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model); DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model); - DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT], - model); - DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT], - model); model->operators.erase(FindOp(*model, curr_op)); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc index bddb563206f763a756685d196836fa41825cf045..94820a016622a12654e91967737e05fc91ed404c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -60,24 +60,22 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { // Follow sequences of min+max and max+min. First get the leading op. const auto op_it = model->operators.begin() + op_index; const auto* op_0 = op_it->get(); - if (op_0->type != OperatorType::kTensorFlowMinimum && - op_0->type != OperatorType::kTensorFlowMaximum) { + if (op_0->type != OperatorType::kMinimum && + op_0->type != OperatorType::kMaximum) { return false; } // Get the paired op and ensure it's the counter to the first. const auto* op_1 = GetOpWithInput(*model, op_0->outputs[0]); if (!op_1 || - (op_1->type != OperatorType::kTensorFlowMinimum && - op_1->type != OperatorType::kTensorFlowMaximum) || + (op_1->type != OperatorType::kMinimum && + op_1->type != OperatorType::kMaximum) || op_0->type == op_1->type) { return false; } - const auto* min_op = - op_0->type == OperatorType::kTensorFlowMinimum ? op_0 : op_1; - const auto* max_op = - op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1; + const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1; + const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1; if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) { return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h index 1c32a781698ec78003ebbf9caff28557924323e5..6d8603a1133a7478647b8bcc49ea1eceba28df31 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h @@ -47,10 +47,14 @@ enum ExtendedLstmCellInputs { kOutputGateBiasTensor = 15, kProjectionWeightsTensor = 16, // Optional kProjectionBiasTensor = 17, // Optional - kExtendedLstmInputCount = 18 + kInputActivationStateTensor = 18, + // The op can handle 18 inputs or 20 inputs. + kInputCellStateTensor = 19, + kExtendedLstmInputCount = 20, }; enum ExtendedLstmCellOutputs { + // TODO(ycling): Make the 2 output state tensors optional. kOutputStateTensor = 0, kCellStateTensor = 1, kOutputTensor = 2, diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc index 5065004093434475172a39efdcfd26c10c49148b..95bc7f7d4b8b517c1cc5a73b3e85bbd985ce460f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc @@ -106,7 +106,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, std::size_t op_index) { auto it = model->operators.begin() + op_index; auto* reshape_op = ConvertOperator( - it->get(), OperatorType::kTensorFlowReshape); + it->get(), OperatorType::kReshape); if (reshape_op == nullptr) { return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f44c65285bdef6ba314b16122fdd550bfa47e6a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc @@ -0,0 +1,178 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +bool IsTailOfShape(const Shape& tail, const Shape& shape) { + // Return true if 'tail' dimensions are the same as the ending dimensions of + // 'shape'. + + int shape_end = shape.dimensions_count() - 1; + int tail_end = tail.dimensions_count() - 1; + + if (tail_end > shape_end) { + // tail cannot be longer than shape. + return false; + } + + // Walk dimensions back to front and compare + for (int i = 0; i <= tail_end; i++) { + if (shape.dims(shape_end - i) != tail.dims(tail_end - i)) { + return false; + } + } + return true; +} + +} // namespace + +// If a binary operator is doing a broadcast operation from a constant array, +// and the constant array shape is the tail of both the other input shape, and a +// subsequent reshape op's output shape, we can swap their order. Since we +// prefer to have reshape ops after mathematic ops, this can allow for the +// collapsing of some reshapes. The WaveNet model in particular benefits from +// this transformation. +// +// Note we are testing for one particular case of a broader set of possible +// binary-reshape op transformations. This transformation could be generalized. +bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + Operator* binary_op = binary_it->get(); + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv && + binary_op->type != OperatorType::kFloorDiv && + binary_op->type != OperatorType::kFloorMod && + binary_op->type != OperatorType::kMinimum && + binary_op->type != OperatorType::kMaximum && + binary_op->type != OperatorType::kLess && + binary_op->type != OperatorType::kLessEqual && + binary_op->type != OperatorType::kGreater && + binary_op->type != OperatorType::kGreaterEqual) { + return false; + } + + // BINARY OP INPUT CHECKS + CHECK_EQ(binary_op->inputs.size(), 2); + const bool input_is_const[2] = { + IsConstantParameterArray(*model, binary_op->inputs[0]), + IsConstantParameterArray(*model, binary_op->inputs[1]), + }; + if (!input_is_const[0] && !input_is_const[1]) { + // To limit our scope, we require one constant input. Though there's no + // reason this transformation wouldn't work with all variable inputs. + return false; + } + if (input_is_const[0] && input_is_const[1]) { + // Both inputs are constants. Leave this for constants propagation. + return false; + } + const int constant_input_idx = input_is_const[0] ? 0 : 1; + const int variable_input_idx = input_is_const[0] ? 1 : 0; + CHECK(input_is_const[constant_input_idx]); + CHECK(!input_is_const[variable_input_idx]); + + const auto& variable_input_array = + model->GetArray(binary_op->inputs[variable_input_idx]); + if (!variable_input_array.has_shape()) { + AddMessageF( + "Not moving %s because it's non-constant input shape is not resolved.", + LogName(*binary_op)); + return false; + } + if (!IsTailOfShape( + model->GetArray(binary_op->inputs[constant_input_idx]).shape(), + model->GetArray(binary_op->inputs[variable_input_idx]).shape())) { + // Constant array shape must be the latter part of the variable shape. + return false; + } + + // RESHAPE OP CHECKS + auto reshape_it = + FindOpWithOutput(*model, binary_op->inputs[variable_input_idx]); + if (reshape_it == model->operators.end()) { + AddMessageF("Not moving %s because it's variable input is not connected.", + LogName(*binary_op)); + return false; + } + Operator* reshape_op = reshape_it->get(); + if (reshape_op->type != OperatorType::kReshape) { + AddMessageF("Not moving %s because the preceding %s is not a reshape op", + LogName(*binary_op), LogName(*reshape_op)); + return false; + } + const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]); + if (!reshape_input_array.has_shape()) { + AddMessageF( + "Not moving %s because it's non-constant input shape is not resolved " + "yet", + LogName(*binary_op)); + return false; + } + if (!IsTailOfShape( + model->GetArray(binary_op->inputs[constant_input_idx]).shape(), + model->GetArray(reshape_op->outputs[0]).shape())) { + // Constant array shape must be the latter part of the binary op output + // shape. + return false; + } + + // EXTRA CHECKS ON CONNECTING ARRAY + for (const string& output_array : model->flags.output_arrays()) { + if (binary_op->inputs[variable_input_idx] == output_array) { + AddMessageF( + "Not moving %s because the output of reshape op %s is an output op.", + LogName(*binary_op), LogName(*reshape_op)); + return false; + } + } + int count_ops_consuming_output = + CountOpsWithInput(*model, binary_op->inputs[variable_input_idx]); + DCHECK_GE(count_ops_consuming_output, 1); + if (count_ops_consuming_output > 1) { + AddMessageF( + "Not moving %s because the output of reshape op %s is consumed by " + "another op", + LogName(*binary_op), LogName(*reshape_op)); + return false; + } + + // SWAP ORDER OF BINARY AND RESHAPE OPS + AddMessageF("Moving op %s before reshape op %s", LogName(*binary_op), + LogName(*reshape_op)); + + // Swap op input and outputs + std::iter_swap(reshape_op->inputs.begin(), + binary_op->inputs.begin() + variable_input_idx); + std::iter_swap(reshape_op->outputs.begin(), binary_op->outputs.begin()); + + // Swap operator ordering + std::iter_swap(binary_it, reshape_it); + + // Clear binary output shape so it will be re-propagated + model->GetArray(binary_op->outputs[0]).clear_shape(); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 6342cf3e8af4d85ad869a5d60a63d62ca2b00588..00ab7cbaa90b399ca08bdfba82991fbd5d2c9f7e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -56,20 +56,22 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // These operators unconditionally produce float outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); break; - case OperatorType::kTensorFlowLess: - case OperatorType::kTensorFlowLessEqual: - case OperatorType::kTensorFlowGreater: - case OperatorType::kTensorFlowGreaterEqual: + case OperatorType::kLess: + case OperatorType::kLessEqual: + case OperatorType::kGreater: + case OperatorType::kGreaterEqual: + case OperatorType::kEqual: + case OperatorType::kNotEqual: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; case OperatorType::kRank: - case OperatorType::kTensorFlowShape: + case OperatorType::kShape: // These operators only produce int32 outputs. SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); break; - case OperatorType::kTensorFlowSplit: - case OperatorType::kTensorFlowConcat: + case OperatorType::kSplit: + case OperatorType::kConcat: case OperatorType::kFill: { // These operators produce an output with the same type as their 2nd input CHECK_GE(op->inputs.size(), 2); @@ -133,7 +135,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { model->GetArray(op->outputs[1]).data_type = ArrayDataType ::kInt32; break; } - case OperatorType::kTensorFlowUnsupported: { + case OperatorType::kUnsupported: { auto* unsupported_op = static_cast(op); // Some output tensors from the op could be eliminated by optimization. // This can make unsupported_op->output_data_types have more elements than @@ -163,6 +165,24 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type_x); break; } + case OperatorType::kSparseToDense: { + // Select produces outputs with the same type as their 3rd input + CHECK_EQ(op->inputs.size(), 4); + const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type; + const ArrayDataType data_type_default = + model->GetArray(op->inputs[3]).data_type; + CHECK(data_type == data_type_default); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } + case OperatorType::kPow: { + CHECK_EQ(op->inputs.size(), 2); + CHECK(model->GetArray(op->inputs[0]).data_type == + model->GetArray(op->inputs[1]).data_type); + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc index 50b90e7c2bfddb0382a4d44ad6c90fc7f7701273..cd078ef189e922682098a0ec8dc4743060181aac 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc @@ -25,6 +25,14 @@ limitations under the License. namespace toco { +namespace { + +bool SupportsMinMax(const Array& array) { + return array.data_type == ArrayDataType::kFloat; +} + +} // namespace + // Propagates default min/max values to any operator input/output array that // is missing them. // @@ -39,14 +47,16 @@ bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) { for (const auto& input : op->inputs) { auto& input_array = model->GetArray(input); - if (!input_array.minmax && !input_array.buffer) { + if (!input_array.minmax && !input_array.buffer && + SupportsMinMax(input_array)) { did_change |= SetArrayMinMax(input, &input_array); } } for (const auto& output : op->outputs) { auto& output_array = model->GetArray(output); - if (!output_array.minmax && !output_array.buffer) { + if (!output_array.minmax && !output_array.buffer && + SupportsMinMax(output_array)) { did_change |= SetArrayMinMax(output, &output_array); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 6d51fc8c31e6c86701c3dc1fd07a9a5479114738..0f2592d05f6e01599735c5138c53ba7779ce805d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -27,11 +27,21 @@ namespace toco { namespace { -void ChangeArrayDataType(GraphTransformation* transformation, Array* array, +bool ChangeArrayDataType(GraphTransformation* transformation, Array* array, ArrayDataType new_data_type, const MinMax* new_minmax) { + // The code below assumes kInt16, see + // GetQuantizationParamsFromMinMax + if (new_data_type != ArrayDataType::kInt16) { + return false; + } + + bool changed = false; // Ensure the array ends up in the new type (if it hasn't yet been quantized). - array->final_data_type = new_data_type; + if ((array->final_data_type != new_data_type)) { + array->final_data_type = new_data_type; + changed = true; + } if (array->minmax && array->quantization_params) { // The array is already quantized and has min/max info. @@ -70,10 +80,10 @@ void ChangeArrayDataType(GraphTransformation* transformation, Array* array, // Directly change the type as the array was already quantized. array->data_type = new_data_type; - } else { + changed = true; + } else if (!array->quantization_params) { // Array has not yet been quantized so we can just set the final data type // and assign the new min/max value (if provided). - CHECK(!array->quantization_params); if (!array->minmax && new_minmax) { transformation->AddMessageF("Forcing new minmax to %g,%g (%s)", @@ -82,16 +92,18 @@ void ChangeArrayDataType(GraphTransformation* transformation, Array* array, auto& array_minmax = array->GetOrCreateMinMax(); array_minmax.min = new_minmax->min; array_minmax.max = new_minmax->max; + changed = true; } } + return changed; } // Returns true if the op blocks our backward recursive data type propagation. bool DoesOpBlockBackwardPropagation(const Operator& op) { switch (op.type) { case OperatorType::kConcatenation: - case OperatorType::kTensorFlowConcat: - case OperatorType::kTensorFlowConcatV2: + case OperatorType::kConcat: + case OperatorType::kConcatV2: // Concat shouldn't block propagation, but we do expect that all inputs // have the same range. return false; @@ -100,9 +112,10 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) { // FakeQuant so make sure we move across them. case OperatorType::kGather: // Gathers need their parameters changed to the appropriate data type. - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: case OperatorType::kSelect: + case OperatorType::kTile: // Reshapes and transposes don't change values. return false; default: @@ -120,10 +133,13 @@ bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) { // Ignore gather indices. return input_index != 0; break; - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: // Ignore reshape/transpose shapes/dimensions. return input_index != 0; + case OperatorType::kTile: + // Ignore tile multiples. + return input_index != 0; default: return false; } @@ -155,9 +171,8 @@ bool RecursivelyBackwardPropagateDataType(GraphTransformation* transformation, "Adjusting input final data type of array %s from %s to %s", input, ArrayDataTypeName(input_array.final_data_type), ArrayDataTypeName(new_data_type)); - did_change = true; - ChangeArrayDataType(transformation, &input_array, new_data_type, - &new_minmax); + did_change |= ChangeArrayDataType(transformation, &input_array, + new_data_type, &new_minmax); // Walk up into all ops producing the inputs to this op. for (auto& producing_op : model->operators) { @@ -208,9 +223,8 @@ bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation, "Adjusting output final data type of array %s from %s to %s", output, ArrayDataTypeName(output_array.final_data_type), ArrayDataTypeName(new_data_type)); - did_change = true; - ChangeArrayDataType(transformation, &output_array, new_data_type, - nullptr); + did_change |= ChangeArrayDataType(transformation, &output_array, + new_data_type, nullptr); // Walk down into all ops consuming the output of this op. for (auto& consuming_op : model->operators) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 9d1d27f3ef01a572c2ae232b1f172a8e05374381..8eb0423283a267652e3d51361b8a0440f46d0c8b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -120,49 +120,7 @@ void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x, CHECK(output_array->has_shape()); } -int GetOutputDepthFromWeights(const Model& model, const Operator& op) { - const string& weights_name = op.inputs[1]; - const auto& weights_shape = model.GetArray(weights_name).shape(); - if (op.type == OperatorType::kConv || - op.type == OperatorType::kFullyConnected) { - return weights_shape.dims(0); - } else if (op.type == OperatorType::kDepthwiseConv) { - return weights_shape.dims(3); - } else { - LOG(FATAL) << "Unhandled operator type"; - } -} - -bool EnsureBiasVectorShape(Model* model, Operator* op) { - const string& weights_name = op->inputs[1]; - const auto& weights_array = model->GetArray(weights_name); - // Yield until weights shape has been resolved. - if (!weights_array.has_shape()) { - return false; - } - - if (op->inputs.size() < 3) { - return false; - } - auto& bias_array = model->GetArray(op->inputs[2]); - if (bias_array.has_shape()) { - return true; - } - - const int output_depth = GetOutputDepthFromWeights(*model, *op); - bias_array.copy_shape(Shape({output_depth})); - - auto& float_buffer = bias_array.GetMutableBuffer(); - float_buffer.data.resize(output_depth, 0); - - return true; -} - void ProcessConvOperator(Model* model, ConvOperator* op) { - if (!EnsureBiasVectorShape(model, op)) { - return; - } - const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -211,12 +169,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { // might as well calculate the output shape and ensure it matches the // specified one - // Check if we have already run. - auto& output_array = model->GetArray(op->outputs[0]); - if (output_array.has_shape()) { - return; - } - // SPECIFIED OUTPUT SHAPE // The below is the specified, or prescribed output shape, _given_ to the // operator as an input. @@ -278,20 +230,26 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { << "TransposeConv input shape must have 4 dimensions. Input \"" << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " << toco::ShapeToString(weights_shape) << "."; - CHECK_EQ(input_shape.dims(3), weights_shape.dims(0)) + CHECK_EQ(input_shape.dims(3), weights_shape.dims(3)) << "Input shape depth and weight depth do not agree"; // Set the output shape according to the specified output shape. std::vector const& specified_output_shape = specified_output_shape_array.GetBuffer().data; + auto& output_array = model->GetArray(op->outputs[0]); *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape; -} -void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { - if (!EnsureBiasVectorShape(model, op)) { - return; + // Set im2col array dimensions if there is one. + if (op->outputs.size() == 2) { + const int input_depth = weights_shape.dims(3); + auto& im2col_array = model->GetArray(op->outputs[1]); + im2col_array.copy_shape( + Shape{specified_output_shape[0], specified_output_shape[1], + specified_output_shape[2], input_depth * kheight * kwidth}); } +} +void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -321,7 +279,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { if (!op->depth_multiplier) { op->depth_multiplier = output_depth / input_depth; } - QCHECK_EQ(output_depth, input_depth * op->depth_multiplier) + CHECK_EQ(output_depth, input_depth * op->depth_multiplier) << "input/output depths and depth_multiplier don't match"; const int kheight = weights_shape.dims(1); @@ -406,10 +364,6 @@ void ProcessOpWithShapeInput(Model* model, Operator* op) { } void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { - if (!EnsureBiasVectorShape(model, op)) { - return; - } - const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -568,11 +522,11 @@ void ProcessAddNOperator(Model* model, Operator* op) { bool KeepDims(const Operator& op) { switch (op.type) { - case OperatorType::kTensorFlowMin: + case OperatorType::kMin: // Reduction Min return static_cast(op).keep_dims; - case OperatorType::kTensorFlowMax: + case OperatorType::kMax: // Reduction Max return static_cast(op).keep_dims; - case OperatorType::kTensorFlowSum: + case OperatorType::kSum: return static_cast(op).keep_dims; case OperatorType::kMean: return static_cast(op).keep_dims; @@ -1085,9 +1039,6 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { QCHECK_GE(input_shape.dimensions_count(), 1); op->input_rank = input_shape.dimensions_count(); - // We only support 1-D indices. - QCHECK_EQ(indices_shape.dimensions_count(), 1); - // Copy the input dimensions to the output except for dimension 0, // where the dimension of indices_shape is used. // TODO(mgubin): if axis != 0 this is not true, change when it's supported. @@ -1337,8 +1288,8 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { op->begin_mask, op->start_indices, op->strides, input_array.shape().dims().data(), axis); int stop_index = tflite::strided_slice::StopForAxis( - op->end_mask, op->stop_indices, op->strides, - input_array.shape().dims().data(), axis); + op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides, + input_array.shape().dims().data(), axis, start_index); int dim_size = ceil(static_cast(stop_index - start_index) / op->strides[axis]); @@ -1477,6 +1428,76 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { *output_array.mutable_shape()->mutable_dims() = output_dims; } +void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + + const Array& output_shape_array = model->GetArray(op->inputs[1]); + if (!output_shape_array.has_shape()) return; + CHECK_EQ(output_shape_array.shape().dimensions_count(), 1); + + // Output should not go over four dimensions. + CHECK_LE(output_shape_array.shape().dims(0), 4); + + const string& output_name = op->outputs[0]; + Array& output_array = model->GetArray(output_name); + if (output_array.has_shape()) return; + + CHECK(output_shape_array.data_type == ArrayDataType::kInt32 || + output_shape_array.data_type == ArrayDataType::kInt64); + if (output_shape_array.data_type == ArrayDataType::kInt32) { + *output_array.mutable_shape()->mutable_dims() = + output_shape_array.GetBuffer().data; + } else { + const std::vector& output_shape_data = + output_shape_array.GetBuffer().data; + std::copy( + output_shape_data.begin(), output_shape_data.end(), + std::back_inserter(*output_array.mutable_shape()->mutable_dims())); + } +} + +void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& multiples_array = model->GetArray(op->inputs[1]); + if (!multiples_array.has_shape()) { + // Yield until multiples shape been resolved. + return; + } + if (!multiples_array.buffer) { + // Yield until the multiples is constant. + return; + } + CHECK(multiples_array.data_type == ArrayDataType::kInt32) + << "Tile multiples input must be int32"; + + std::vector const& multiples = + multiples_array.GetBuffer().data; + CHECK_EQ(multiples.size(), input_shape.dimensions_count()) + << "Tile multiples input " << op->inputs[1] + << " must be same length as input dimensions"; + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->resize(multiples.size()); + for (int i = 0; i < mutable_dims->size(); ++i) { + (*mutable_dims)[i] = input_shape.dims(i) * multiples[i]; + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1503,14 +1524,14 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kLogistic: case OperatorType::kTanh: case OperatorType::kLocalResponseNormalization: - case OperatorType::kTensorFlowIdentity: + case OperatorType::kIdentity: case OperatorType::kFakeQuant: case OperatorType::kNeg: - case OperatorType::kTensorFlowRsqrt: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: - case OperatorType::kTensorFlowAll: - case OperatorType::kTensorFlowAssert: + case OperatorType::kRsqrt: + case OperatorType::kSqrt: + case OperatorType::kSquare: + case OperatorType::kAll: + case OperatorType::kAssert: case OperatorType::kCast: case OperatorType::kFloor: case OperatorType::kExp: @@ -1529,12 +1550,15 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kDiv: case OperatorType::kFloorDiv: case OperatorType::kFloorMod: - case OperatorType::kTensorFlowLess: - case OperatorType::kTensorFlowLessEqual: - case OperatorType::kTensorFlowGreater: - case OperatorType::kTensorFlowMaximum: - case OperatorType::kTensorFlowMinimum: - case OperatorType::kTensorFlowGreaterEqual: + case OperatorType::kLess: + case OperatorType::kLessEqual: + case OperatorType::kGreater: + case OperatorType::kMaximum: // Element-wise Maximum + case OperatorType::kMinimum: // Element-wise Minimum + case OperatorType::kGreaterEqual: + case OperatorType::kEqual: + case OperatorType::kNotEqual: + case OperatorType::kPow: ProcessSimpleBinaryOperator(model, op); break; case OperatorType::kAddN: @@ -1567,7 +1591,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessFullyConnectedOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: ProcessTensorFlowReshapeOperator( model, static_cast(op)); break; @@ -1580,9 +1604,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kL2Pool: ProcessL2PoolOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowMin: - case OperatorType::kTensorFlowMax: - case OperatorType::kTensorFlowSum: + case OperatorType::kMin: // Reduction Min + case OperatorType::kMax: // Reduction Max + case OperatorType::kSum: case OperatorType::kMean: ProcessTensorFlowReductionOperator(model, op); break; @@ -1593,34 +1617,26 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSliceOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowTile: - // We don't currently implement the propagation of fixed sizes through - // a TensorFlow Tile. - // - // Fortunately, we don't need to: so far, we have only dealt with Tile - // or Slice ops in subgraphs that are identified as L2Normalization. - // See IdentifyL2Normalization. - break; - case OperatorType::kTensorFlowSwitch: + case OperatorType::kSwitch: // We can't know the sizes of the outputs until we have resolved the // predicate, and once we have resolved the predicate, the whole // Switch node will get resolved away. // See ResolveTensorFlowSwitch. break; - case OperatorType::kTensorFlowMerge: + case OperatorType::kMerge: // No need to bother resolving TensorFlow Merge ops: other graph // transformations will remove them anyway. // See ResolveTensorFlowMerge. break; - case OperatorType::kTensorFlowSplit: + case OperatorType::kSplit: ProcessTensorFlowSplitOperator(model, static_cast(op)); break; case OperatorType::kSqueeze: ProcessSqueezeOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowConcat: - case OperatorType::kTensorFlowConcatV2: + case OperatorType::kConcat: + case OperatorType::kConcatV2: // Unimplemented, hopefully another graph transformation will // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat // will resolve this node to a DepthConcatenation, or else we have @@ -1636,7 +1652,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kRank: ProcessRankOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowShape: + case OperatorType::kShape: ProcessShapeOperator(model, static_cast(op)); break; case OperatorType::kStack: @@ -1657,7 +1673,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessLstmCellOperator(model, static_cast(op)); break; case OperatorType::kBatchMatMul: - case OperatorType::kTensorFlowMatMul: + case OperatorType::kMatMul: // MatMul operators are converted to FullyConnected, after which their // shapes are propagated. break; @@ -1682,7 +1698,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kArgMax: ProcessArgMaxOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowUnsupported: + case OperatorType::kUnsupported: break; case OperatorType::kSvdf: ProcessSvdfOperator(model, static_cast(op)); @@ -1700,6 +1716,13 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 1); ProcessOpWithShapeInput(model, op); break; + case OperatorType::kSparseToDense: + ProcessSparseToDenseOperator(model, + static_cast(op)); + break; + case OperatorType::kTile: + ProcessTileOperator(model, static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 142841fcc460e8a5e9e4f2333496f4ece2557275..58885b4950733bfc9d394127e597a08232cd5663 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -33,7 +33,7 @@ namespace { bool SupportsQuantization(const Operator& op) { auto type = op.type; - if (type == OperatorType::kTensorFlowUnsupported) { + if (type == OperatorType::kUnsupported) { auto* unsupported = static_cast(&op); return unsupported->quantized; } @@ -42,25 +42,25 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kConcatenation || type == OperatorType::kL2Normalization || type == OperatorType::kAdd || type == OperatorType::kAveragePool || type == OperatorType::kMaxPool || - type == OperatorType::kTensorFlowMinimum || - type == OperatorType::kTensorFlowMaximum || + type == OperatorType::kMinimum || type == OperatorType::kMaximum || type == OperatorType::kLogistic || type == OperatorType::kSoftmax || - type == OperatorType::kLogSoftmax || - type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub || + type == OperatorType::kLogSoftmax || type == OperatorType::kSlice || + type == OperatorType::kResizeBilinear || + type == OperatorType::kSplit || type == OperatorType::kSub || type == OperatorType::kSqueeze || type == OperatorType::kPad || - type == OperatorType::kPadV2 || - type == OperatorType::kTensorFlowReshape || + type == OperatorType::kPadV2 || type == OperatorType::kReshape || type == OperatorType::kTanh || type == OperatorType::kMul || + type == OperatorType::kSpaceToBatchND || type == OperatorType::kSpaceToDepth || type == OperatorType::kStridedSlice || type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell || type == OperatorType::kGather || type == OperatorType::kTranspose || type == OperatorType::kMean || - type == OperatorType::kTensorFlowGreater || - type == OperatorType::kTensorFlowGreaterEqual || - type == OperatorType::kTensorFlowLess || - type == OperatorType::kTensorFlowLessEqual || - type == OperatorType::kSelect; + type == OperatorType::kGreater || + type == OperatorType::kGreaterEqual || type == OperatorType::kLess || + type == OperatorType::kLessEqual || type == OperatorType::kSelect || + type == OperatorType::kArgMax || type == OperatorType::kRelu || + type == OperatorType::kRelu1 || type == OperatorType::kRelu6; } const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { @@ -326,14 +326,15 @@ bool ChooseQuantizationForOperatorOutput( output, OperatorTypeName(op.type)); return true; } - if ((op.type == OperatorType::kDepthToSpace) || - (op.type == OperatorType::kSpaceToDepth) || - (op.type == OperatorType::kTensorFlowReshape) || - (op.type == OperatorType::kTensorFlowSplit) || - (op.type == OperatorType::kConcatenation && - model->flags.change_concat_input_ranges())) { + if ((op.type == OperatorType::kConcatenation && + model->flags.change_concat_input_ranges()) || + op.type == OperatorType::kDepthToSpace || + op.type == OperatorType::kSpaceToDepth || + op.type == OperatorType::kReshape || op.type == OperatorType::kSplit || + op.type == OperatorType::kRelu || op.type == OperatorType::kRelu1 || + op.type == OperatorType::kRelu6) { int data_input_index = 0; - if (op.type == OperatorType::kTensorFlowSplit) { + if (op.type == OperatorType::kSplit) { data_input_index = 1; } // Copying and rearrangement ops should preserve the quantization parameters @@ -506,36 +507,47 @@ bool Quantize::Run(Model* model, std::size_t op_index) { // Check if the output of that Dequantize op was not used by any // other operator. We will then erase that Dequantize op. if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { - // If any of the model's output_arrays was pointing to the - // Dequantize op's output, let it point to the Dequantize op's - // input instead. - for (int i = 0; i < model->flags.output_arrays_size(); i++) { - if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) { - // TODO(b/78013785): never rename output arrays. - if (IsInputArray(*model, dequantize_op->inputs[0])) { - // The op input is an input array and the output is an output - // array and we can't have an array be both. Insert a copy - // op to ensure the two arrays stay separate. - AddMessageF( - "Tried to rename output array %d while removing dequant " - "op %s but array is also an input; inserting copy %s " - "-> %s", - i, LogName(*dequantize_op), model->flags.output_arrays(i), - dequantize_op->inputs[0]); - InsertCopyOperator(model, dequantize_op->inputs[0], - dequantize_op->outputs[0]); - } else { - // Op output is strictly used as an output array, so we can - // just rename the array and directly bypass the op. - AddMessageF( - "Renaming output array %d after removing dequant op %s: " - "%s -> %s", - i, LogName(*dequantize_op), model->flags.output_arrays(i), - dequantize_op->inputs[0]); - model->flags.set_output_arrays(i, dequantize_op->inputs[0]); - model->EraseArray(dequantize_op->outputs[0]); + if (IsDiscardableArray(*model, dequantize_op->outputs[0])) { + // Usual case: we can just discard the dequantize output. + model->EraseArray(dequantize_op->outputs[0]); + } else { + // The dequantize output is not discardable. Special care needed. + // If any of the model's output_arrays was pointing to the + // Dequantize op's output, let it point to the Dequantize op's + // input instead. + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == + dequantize_op->outputs[0]) { + // TODO(b/78013785): never rename output arrays. + if (IsInputArray(*model, dequantize_op->inputs[0])) { + // The op input is an input array and the output is an + // output array and we can't have an array be both. Insert a + // copy op to ensure the two arrays stay separate. + AddMessageF( + "Tried to rename output array %d while removing " + "dequant " + "op %s but array is also an input; inserting copy %s " + "-> %s", + i, LogName(*dequantize_op), + model->flags.output_arrays(i), + dequantize_op->inputs[0]); + InsertCopyOperator(model, dequantize_op->inputs[0], + dequantize_op->outputs[0]); + } else { + // Op output is strictly used as an output array, so we can + // just rename the array and directly bypass the op. + AddMessageF( + "Renaming output array %d after removing dequant op " + "%s: " + "%s -> %s", + i, LogName(*dequantize_op), + model->flags.output_arrays(i), + dequantize_op->inputs[0]); + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + model->EraseArray(dequantize_op->outputs[0]); + } + break; } - break; } } model->operators.erase(dequantize_it); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc new file mode 100644 index 0000000000000000000000000000000000000000..88ea0945e7dd15ba325d34ea3fdbf34ff7d91381 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +// The minimum number of elements a weights array must have to be quantized +// by this transformation. +// TODO(suharshs): Make this minimum size configurable. +const int kWeightsMinSize = 1024; + +// Gets the quantization params from the float array. +void GetQuantizationParamsFromArray(const Array& array, + QuantizationParams* params) { + const std::vector& float_vals = + array.GetBuffer().data; + auto minmax = std::minmax_element(float_vals.begin(), float_vals.end()); + MinMax toco_minmax; + toco_minmax.min = *minmax.first; + toco_minmax.max = *minmax.second; + GetQuantizationParams(ArrayDataType::kUint8, toco_minmax, params); +} + +} // namespace + +bool QuantizeWeights::Run(Model* model, std::size_t op_index) { + const auto op_it = model->operators.begin() + op_index; + Operator* op = op_it->get(); + + // Get the weights tensor, if the current operator has one. + int weights_index; + if (op->type == OperatorType::kConv || + op->type == OperatorType::kDepthwiseConv || + op->type == OperatorType::kFullyConnected) { + weights_index = 1; + } else if (op->type == OperatorType::kLstmCell) { + weights_index = LstmCellOperator::WEIGHTS_INPUT; + } else { + return false; + } + + // Return early if the array isn't a constant param, this can happen in early + // transformation passes until transpose operations following the weight array + // are resolved. + const string weights = op->inputs[weights_index]; + if (!IsConstantParameterArray(*model, weights)) { + return false; + } + + // Return early if the weight tensor is not type float. + Array& weights_array = model->GetArray(weights); + if (weights_array.data_type != ArrayDataType::kFloat) { + return false; + } + + // Return early if the tensor is too small. Small tensors don't take up too + // much space and can result in bad quantization results. + if (weights_array.GetBuffer().data.size() < + kWeightsMinSize) { + return false; + } + + // Quantize the weight tensor to type kUint8. + QuantizationParams params; + GetQuantizationParamsFromArray(weights_array, ¶ms); + QuantizeArray(this, model, weights, ArrayDataType::kUint8, params); + + // Insert a Dequantize operation after the quantized weights tensor. + auto* dequantize_op = new DequantizeOperator; + model->operators.emplace(op_it, dequantize_op); + + // Create a new intermediate tensor to connect the Dequantize op to the + // original op. + const string dequantized_output = + AvailableArrayName(*model, weights + "_dequantized"); + Array& dequantized_output_array = model->GetOrCreateArray(dequantized_output); + dequantized_output_array.data_type = ArrayDataType::kFloat; + + // Connect up the new Dequantize op with the weights and original op. + op->inputs[weights_index] = dequantized_output; + dequantize_op->inputs = {weights}; + dequantize_op->outputs = {dequantized_output}; + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc index 35a0c465327f352863350e7a8af714d16b7be393..73ad326299bbd929afbb8dda2c41b97a126afbe1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -26,7 +26,7 @@ namespace toco { bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) { const auto assert_it = model->operators.begin() + op_index; const auto* assert_op = assert_it->get(); - if (assert_op->type != OperatorType::kTensorFlowAssert) { + if (assert_op->type != OperatorType::kAssert) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc index 404269bbfd9312bbbab32489783d9e4217ecbd89..7ec7752f25dad1c24b821733c0e6dafbd1cd8bf2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc @@ -28,7 +28,7 @@ namespace toco { bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) { const auto passthru_it = model->operators.begin() + op_index; const auto* passthru_op = passthru_it->get(); - if (passthru_op->type != OperatorType::kTensorFlowIdentity) { + if (passthru_op->type != OperatorType::kIdentity) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc index a950fe6442bc656b725a1f0687f4c024f4fb0f84..9f5d8b94507ec11957c3ae55ffca510eeb81ac89 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -97,7 +97,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, "Cannot remove %s, neither its main input nor its output may be " "discarded", LogName(*passthru_op)); - if (passthru_op->type != OperatorType::kTensorFlowReshape && + if (passthru_op->type != OperatorType::kReshape && model->GetArray(main_input_name).has_shape()) { // We can't remove either array but we can remove the op. Converting it to // a reshape gives us some hope of later on fixing that (either in the diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc index eaee1c662b7cedb2baec7be47e12e348c3e7b25c..142c876b154755ac9c6b93e560f22ec8d6ec6563 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc @@ -47,11 +47,11 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model, double clamp_min; double clamp_max; switch (op_type) { - case OperatorType::kTensorFlowMinimum: + case OperatorType::kMinimum: // Element-wise Minimum clamp_min = -std::numeric_limits::infinity(); clamp_max = clamp_value; break; - case OperatorType::kTensorFlowMaximum: + case OperatorType::kMaximum: // Element-wise Maximum clamp_min = clamp_value; clamp_max = std::numeric_limits::infinity(); break; @@ -72,8 +72,8 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model, bool RemoveTrivialQuantizedMinMax::Run(Model* model, std::size_t op_index) { const auto it = model->operators.begin() + op_index; auto* op = it->get(); - if ((op->type != OperatorType::kTensorFlowMinimum && - op->type != OperatorType::kTensorFlowMaximum) || + if ((op->type != OperatorType::kMinimum && + op->type != OperatorType::kMaximum) || op->inputs.size() != 2) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc index e28d8cf01eafee64e08ac2cc4b43ea7c227456c2..404f27e067402474484d3ee8e23595fb9f93a6c9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -30,7 +30,7 @@ namespace { bool IsReshapeTrivial(const Model& model, const Operator& op, RemoveTrivialReshape* transformation) { - CHECK(op.type == OperatorType::kTensorFlowReshape); + CHECK(op.type == OperatorType::kReshape); // One way in which a reshape can be trivial is if its // output shape is == its input shape @@ -58,7 +58,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, // is only consumed by another reshape. if (CountOpsWithInput(model, op.outputs[0]) == 1) { const auto* next_op = GetOpWithInput(model, op.outputs[0]); - if (next_op->type == OperatorType::kTensorFlowReshape) { + if (next_op->type == OperatorType::kReshape) { transformation->AddMessageF( "%s is trivial because its output is only consumed by another " "Reshape op %s", @@ -75,7 +75,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) { const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); - if (reshape_op->type != OperatorType::kTensorFlowReshape) { + if (reshape_op->type != OperatorType::kReshape) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc index 1956ab2d2021cda84a0d715534923d6174c30dd1..dde91234a8240f4518cd105c2cc4e79102735980 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc @@ -48,7 +48,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { for (const auto& rnn_state : model->flags.rnn_states()) { if (output == rnn_state.state_array()) { CHECK(op->type == OperatorType::kFill || - op->type == OperatorType::kTensorFlowIdentity); + op->type == OperatorType::kIdentity); found_output_as_rnn_state_array = true; break; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc index 9f5b7920cb937b021eb23fc1d5fdc3c1ff18a72d..550de83018f25a7aa4da82707fedb86434615fb0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc @@ -37,8 +37,8 @@ bool IsElementwiseOperator(OperatorType optype) { case OperatorType::kRelu1: case OperatorType::kRelu6: case OperatorType::kTanh: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: + case OperatorType::kSqrt: + case OperatorType::kSquare: return true; default: return false; @@ -51,7 +51,7 @@ bool IsMoveOperator(OperatorType optype) { case OperatorType::kExpandDims: case OperatorType::kSpaceToDepth: case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: return true; default: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc index 9e7fe1b1ccd851dd998e59e75ff798f52f7c6e5a..c907a597cb719b68dbf36868a75e49a7c5181423 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc @@ -123,8 +123,8 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) { } TensorFlowReshapeOperator* reshape_op = - ConvertOperator( - reshape_it->get(), OperatorType::kTensorFlowReshape); + ConvertOperator(reshape_it->get(), + OperatorType::kReshape); if (reshape_op == nullptr) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc index a06919e228dc2084f8943a714a0ca111d013c159..b8b35161d77e5b6dd8c30e03959dba3c60d1d56c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc @@ -50,7 +50,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { // will delete this op. return false; } - std::vector crops_buffer = + const std::vector& crops_buffer = crops_array.GetBuffer().data; for (int i = 0; i < crops_dims[0]; ++i) { op->before_crops.push_back(crops_buffer[i * 2]); @@ -62,7 +62,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { if (!block_shape_array.has_shape()) return false; const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); - std::vector block_shape_buffer = + const std::vector& block_shape_buffer = block_shape_array.GetBuffer().data; for (int i = 0; i < block_shape_dims[0]; ++i) { op->block_shape.push_back(block_shape_buffer[i]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc index 6e78653fad238085da5ba66166884093ea9b0214..f7e5aa6609bd4f7eb2a95750125e30a7803b36e1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -145,17 +145,17 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model, outval = floor(val0 / val1); } else if (binary_op->type == OperatorType::kFloorMod) { outval = val0 - (floor(val0 / val1) * val1); - } else if (binary_op->type == OperatorType::kTensorFlowMinimum) { + } else if (binary_op->type == OperatorType::kMinimum) { outval = std::min(val0, val1); - } else if (binary_op->type == OperatorType::kTensorFlowMaximum) { + } else if (binary_op->type == OperatorType::kMaximum) { outval = std::max(val0, val1); - } else if (binary_op->type == OperatorType::kTensorFlowLess) { + } else if (binary_op->type == OperatorType::kLess) { outval = val0 < val1; - } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) { + } else if (binary_op->type == OperatorType::kLessEqual) { outval = val0 <= val1; - } else if (binary_op->type == OperatorType::kTensorFlowGreater) { + } else if (binary_op->type == OperatorType::kGreater) { outval = val0 > val1; - } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) { + } else if (binary_op->type == OperatorType::kGreaterEqual) { outval = val0 >= val1; } else { LOG(FATAL) << "should not get here"; @@ -198,12 +198,12 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { binary_op->type != OperatorType::kDiv && binary_op->type != OperatorType::kFloorDiv && binary_op->type != OperatorType::kFloorMod && - binary_op->type != OperatorType::kTensorFlowMinimum && - binary_op->type != OperatorType::kTensorFlowMaximum && - binary_op->type != OperatorType::kTensorFlowLess && - binary_op->type != OperatorType::kTensorFlowLessEqual && - binary_op->type != OperatorType::kTensorFlowGreater && - binary_op->type != OperatorType::kTensorFlowGreaterEqual) { + binary_op->type != OperatorType::kMinimum && + binary_op->type != OperatorType::kMaximum && + binary_op->type != OperatorType::kLess && + binary_op->type != OperatorType::kLessEqual && + binary_op->type != OperatorType::kGreater && + binary_op->type != OperatorType::kGreaterEqual) { return false; } CHECK_EQ(binary_op->inputs.size(), 2); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc index 7e7ad383e7789891f5396845241e70143dc8b76f..41562ab393694d76c5cb6c5df5f7df2a71f893f5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc @@ -25,7 +25,7 @@ namespace toco { bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); - if (base_op->type != OperatorType::kTensorFlowReshape) { + if (base_op->type != OperatorType::kReshape) { return false; } const auto* op = static_cast(base_op); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index 9ea01acd05364224ce219bed533c999793a2a2f1..8a0e3e8995839a737b5671701a97b514b0fc7bf1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -22,8 +22,7 @@ namespace toco { bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { const auto it = model->operators.begin() + op_index; const auto* op = it->get(); - if (!(op->type == OperatorType::kTensorFlowShape || - op->type == OperatorType::kRank)) { + if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) { return false; } @@ -48,7 +47,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { // Compute the output CHECK(!output_array.buffer); auto& output_buffer = output_array.GetMutableBuffer(); - if (op->type == OperatorType::kTensorFlowShape) { + if (op->type == OperatorType::kShape) { // Copy the input shape into the output buffer. output_buffer.data = input_array.shape().dims(); } else if (op->type == OperatorType::kRank) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc index 69db1942cd52af810acf38a818997c71122d8500..a4d5f1923a1dffdff1ef51eb5317fa5794a8bc27 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc @@ -41,7 +41,7 @@ void Stack(Model* model, StackOperator const& op) { const auto& input_array = model->GetArray(op.inputs[i]); int input_size = RequiredBufferSizeForShape(input_array.shape()); memcpy(&output_data[dst_offset], &input_array.GetBuffer().data[0], - input_size * sizeof(Type)); + input_size * ElementSize(Type)); dst_offset += input_size; } CHECK_EQ(dst_offset, output_data.size()); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 1dd52e906900e997f282740404a81b9fcd21e867..9d8bd4fc39344a4ea1fa4942a2a99ec535b5bee8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -38,6 +38,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, CHECK_EQ(op.new_axis_mask, 0); int num_input_axes = op.start_indices.size(); + CHECK_EQ(num_input_axes, op.start_indices.size()); CHECK_EQ(num_input_axes, op.stop_indices.size()); CHECK_EQ(num_input_axes, op.strides.size()); @@ -49,11 +50,16 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, // Initialize source coordinate Shape const& input_shape = input_array.shape(); Buffer const& input_buffer = input_array.GetBuffer(); - std::vector src_coord(op.start_indices.size()); + std::vector src_coord(num_input_axes); + std::vector stop_for_axis(num_input_axes); for (int axis = 0; axis < num_input_axes; axis++) { - src_coord[axis] = tflite::strided_slice::StartForAxis( + int start = tflite::strided_slice::StartForAxis( op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(), axis); + src_coord[axis] = start; + stop_for_axis[axis] = tflite::strided_slice::StopForAxis( + op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides, + input_shape.dims().data(), axis, start); } // In order to handle any number (N) of dimensions, we copy elements one by @@ -76,9 +82,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, } // Check if we've overflowed. - int stop = tflite::strided_slice::StopForAxis( - op.end_mask, op.stop_indices, op.strides, input_shape.dims().data(), - axis); + int stop = stop_for_axis[axis]; if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) { // Reset axis and set carry src_coord[axis] = tflite::strided_slice::StartForAxis( @@ -155,14 +159,7 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { break; } - // Erase input array if no longer used - if (IsDiscardableArray(*model, op->inputs[0]) && - CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->EraseArray(op->inputs[0]); - } - - // Erase the operator - model->operators.erase(it); + DeleteOpAndArraysIfUnused(model, it->get()); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index f6c8f79d8d3311dc2294e3ec406a184b2a16a6b5..f89ef85fdb63ca4906c7f016e86bb1f9d8a7099a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -53,13 +53,13 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { case OperatorType::kCast: case OperatorType::kLog: case OperatorType::kNeg: - case OperatorType::kTensorFlowRsqrt: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: - case OperatorType::kTensorFlowSum: - case OperatorType::kTensorFlowMin: - case OperatorType::kTensorFlowMax: - case OperatorType::kTensorFlowReshape: + case OperatorType::kRsqrt: + case OperatorType::kSqrt: + case OperatorType::kSquare: + case OperatorType::kSum: + case OperatorType::kMin: // Reduction Min + case OperatorType::kMax: // Reduction Max + case OperatorType::kReshape: case OperatorType::kRelu6: case OperatorType::kRelu1: case OperatorType::kRelu: @@ -103,7 +103,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { // The min-max is only copied for ops that copy data without arithmetic. // In future trivial transpose, etc, can be handled here. - if (unary_op->type == OperatorType::kTensorFlowReshape) { + if (unary_op->type == OperatorType::kReshape) { CopyMinMaxFromFirstInput(*unary_op, model); } @@ -164,10 +164,10 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = outval; } - } else if (unary_op->type == OperatorType::kTensorFlowReshape) { + } else if (unary_op->type == OperatorType::kReshape) { CHECK(input_buffer_size == output_buffer_size); output_float_data = *input_float_data; - } else if (unary_op->type == OperatorType::kTensorFlowSum) { + } else if (unary_op->type == OperatorType::kSum) { CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs"; if (!IsConstantParameterArray(*model, unary_op->inputs[1])) { AddMessageF("Axis input is non-constant"); @@ -196,7 +196,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = sum; } - } else if (unary_op->type == OperatorType::kTensorFlowMin) { + } else if (unary_op->type == OperatorType::kMin) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { @@ -207,7 +207,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { min = std::min(min, (*input_float_data)[i]); } output_float_data[0] = min; - } else if (unary_op->type == OperatorType::kTensorFlowMax) { + } else if (unary_op->type == OperatorType::kMax) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { @@ -220,9 +220,9 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { output_float_data[0] = max; } else if (unary_op->type == OperatorType::kNeg || unary_op->type == OperatorType::kLog || - unary_op->type == OperatorType::kTensorFlowRsqrt || - unary_op->type == OperatorType::kTensorFlowSqrt || - unary_op->type == OperatorType::kTensorFlowSquare) { + unary_op->type == OperatorType::kRsqrt || + unary_op->type == OperatorType::kSqrt || + unary_op->type == OperatorType::kSquare) { // Element-wise ops. Should have perfectly matching sizes here. for (int i = 0; i < output_dims_count; i++) { CHECK_EQ(output_shape.dims(i), input_shape.dims(i)); @@ -235,11 +235,11 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { outval = -val; } else if (unary_op->type == OperatorType::kLog) { outval = std::log(val); - } else if (unary_op->type == OperatorType::kTensorFlowRsqrt) { + } else if (unary_op->type == OperatorType::kRsqrt) { outval = 1.0f / std::sqrt(val); - } else if (unary_op->type == OperatorType::kTensorFlowSqrt) { + } else if (unary_op->type == OperatorType::kSqrt) { outval = std::sqrt(val); - } else if (unary_op->type == OperatorType::kTensorFlowSquare) { + } else if (unary_op->type == OperatorType::kSquare) { outval = val * val; } else { LOG(FATAL) << "should not get here."; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc index bc70db0bd8c26319fa140616de96452260a01058..8266e2c205b65e9d8a969643f102bb852be9125b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -51,11 +51,12 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order, } bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { - auto reorder_it = model->operators.begin() + op_index; - auto* reorder_op = static_cast(reorder_it->get()); - if (reorder_op->type != OperatorType::kReorderAxes) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + if (op->type != OperatorType::kReorderAxes) { return false; } + auto* reorder_op = static_cast(op); const auto& input_array_name = reorder_op->inputs[0]; const auto& output_array_name = reorder_op->outputs[0]; auto& input_array = model->GetArray(input_array_name); @@ -95,7 +96,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { // Remove the op and output array. model->EraseArray(output_array_name); - model->operators.erase(reorder_it); + model->operators.erase(it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc index 2e063e35548aa5e51c3bcc94a2dfc7992180d014..b615c9a545695e5d14fa5809e0c38a770f23ea24 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -28,7 +28,7 @@ namespace toco { bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) { const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); - if (reshape_op->type != OperatorType::kTensorFlowReshape) { + if (reshape_op->type != OperatorType::kReshape) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc index dad6aceccfd201b3db07c29c99a8c6ef75bb89a1..fab50bec1fc5ec50cecba53845457931ed59c0b8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc @@ -53,7 +53,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { // will delete this op. return false; } - std::vector paddings_buffer = + const std::vector& paddings_buffer = paddings_array.GetBuffer().data; for (int i = 0; i < paddings_dims[0]; ++i) { op->before_paddings.push_back(paddings_buffer[i * 2]); @@ -66,7 +66,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { if (!block_shape_array.has_shape()) return false; const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); - std::vector block_shape_buffer = + const std::vector& block_shape_buffer = block_shape_array.GetBuffer().data; for (int i = 0; i < block_shape_dims[0]; ++i) { op->block_shape.push_back(block_shape_buffer[i]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc index dd3e73635ae0215510f0a8d1aee487da5af35700..e8bb85704e1c750300079681b5a12f6a488b6b48 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc @@ -36,7 +36,7 @@ bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) { // If the output is consumed by a reshape op, it's a trivial squeeze. if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) { const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]); - if (next_op->type == OperatorType::kTensorFlowReshape) { + if (next_op->type == OperatorType::kReshape) { AddMessageF( "%s is trivial because its output is only consumed by a " "Reshape op", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc index 5c0c1e3478fa0d94104d1b76bab176b98b314c50..fa5ee899334bdf2d39a6861b0e0c4548142e9d2a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -28,8 +28,8 @@ namespace toco { bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { auto concat_it = model->operators.begin() + op_index; const auto* tf_concat_op = concat_it->get(); - if (tf_concat_op->type != OperatorType::kTensorFlowConcat && - tf_concat_op->type != OperatorType::kTensorFlowConcatV2) { + if (tf_concat_op->type != OperatorType::kConcat && + tf_concat_op->type != OperatorType::kConcatV2) { return false; } @@ -38,7 +38,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { // of inputs: in Concat,the axis is the first input, while in // ConcatV2, it is the last input. std::size_t axis_pos = 0; - if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) { + if (tf_concat_op->type == OperatorType::kConcatV2) { axis_pos = tf_concat_op->inputs.size() - 1; } const string axis_name = tf_concat_op->inputs[axis_pos]; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index 2a236d3f98784e8244942f94d5a250b5bc00a8ad..fcf30bd34725fc59bb819e75deda0dadf330f372 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -26,27 +26,40 @@ namespace toco { bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { auto matmul_it = model->operators.begin() + op_index; - if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) { + if (matmul_it->get()->type != OperatorType::kMatMul) { return false; } const auto* matmul_op = static_cast(matmul_it->get()); + // Handling transposition of the first input here isn't very simple because + // we need to know the actual shape in order to produce a proper + // TransposeOperator. However, the second input is supposed to be 2D, so we + // can actually handle transposition of that matrix, which happens to be more + // common anyway. + CHECK(!matmul_op->transpose_a); + // Reorder the axes on the second input. TensorFlow uses row-major ordering // on both inputs, however this is inefficient for the FullyConnected // operator. We'll transpose the second input to be in column-major order now // and let constant propagation optimize things (if possible). - auto* transpose_op = new TransposeOperator; - transpose_op->inputs = { - matmul_op->inputs[1], - CreateInt32Array( - model, - AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose/perm"), - {1, 0})}; - transpose_op->outputs = { - AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")}; - model->GetOrCreateArray(transpose_op->outputs[0]); - model->operators.emplace(matmul_it, transpose_op); + string input_lhs = matmul_op->inputs[0]; + string input_rhs = matmul_op->inputs[1]; + if (!matmul_op->transpose_b) { + auto* transpose_op = new TransposeOperator; + transpose_op->inputs = { + matmul_op->inputs[1], + CreateInt32Array(model, + AvailableArrayName( + *model, matmul_op->inputs[1] + "/transpose/perm"), + {1, 0})}; + transpose_op->outputs = { + AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")}; + model->GetOrCreateArray(transpose_op->outputs[0]); + model->operators.emplace(matmul_it, transpose_op); + + input_rhs = transpose_op->outputs[0]; + } // Refresh iterator. matmul_it = model->operators.begin(); @@ -57,9 +70,6 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { } DCHECK_EQ(matmul_it->get(), matmul_op); - string input_lhs = matmul_op->inputs[0]; - string input_rhs = transpose_op->outputs[0]; - // Construct the new FullyConnectedOperator. auto* fc_op = new FullyConnectedOperator; fc_op->outputs = matmul_op->outputs; @@ -97,7 +107,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { // MatMul op as a FullyConnected. However, TensorFlow skips the Reshape ops if // the input doesn't need reshaping, so we can't just match (Reshape, MatMul) // pairs. - if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) { + if (previous_op && previous_op->type == OperatorType::kReshape) { AddMessageF("Combining %s and %s into %s", LogName(*previous_op), LogName(*matmul_op), LogName(*fc_op)); const auto& previous_op_output = previous_op->outputs[0]; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc index 38e0005890ac10410df4ddb5290be8fcc948c349..4edffe3d48fd880c0261b34fc407b8e2ac66ccb9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -27,7 +27,7 @@ namespace toco { bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { const auto merge_it = model->operators.begin() + op_index; const auto* merge_op = merge_it->get(); - if (merge_op->type != OperatorType::kTensorFlowMerge) { + if (merge_op->type != OperatorType::kMerge) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc index a418073441f1241a5acb1164b36f332828ea2e99..da8e7a2d1c06cf89b9708b404da7667565245f8f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -27,7 +27,7 @@ namespace toco { bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { const auto switch_it = model->operators.begin() + op_index; const auto* switch_op = switch_it->get(); - if (switch_op->type != OperatorType::kTensorFlowSwitch) { + if (switch_op->type != OperatorType::kSwitch) { return false; } @@ -92,7 +92,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { if (*input_it == switch_op->outputs[nonselected_output_index]) { // Let us guard our assumption that only Merge nodes consume the outputs // of Switch nodes: - CHECK(other_op->type == OperatorType::kTensorFlowMerge); + CHECK(other_op->type == OperatorType::kMerge); input_it = other_op->inputs.erase(input_it); } else { ++input_it; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc deleted file mode 100644 index 1ddf54c778cd1fae7a8fce0ecb97209274e71ac0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include -#include -#include - -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace toco { - -namespace { - -void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op, - int operand_index) { - CHECK(tile_op->type == OperatorType::kTensorFlowTile); - CHECK_EQ(binary_op->inputs.size(), 2); - CHECK_EQ(tile_op->inputs.size(), 2); - const string tile_multiplier_array = tile_op->inputs[1]; - const string tile_output_array = tile_op->outputs[0]; - binary_op->inputs[operand_index] = tile_op->inputs[0]; - auto tile_it = model->operators.begin(); - for (; tile_it != model->operators.end(); ++tile_it) { - if (tile_it->get() == tile_op) { - break; - } - } - CHECK(tile_it != model->operators.end()); - CHECK(tile_it->get() == tile_op); - model->operators.erase(tile_it); - if (!CountOpsWithInput(*model, tile_multiplier_array) && - !GetOpWithOutput(*model, tile_multiplier_array)) { - model->EraseArray(tile_multiplier_array); - } - if (!CountOpsWithInput(*model, tile_output_array)) { - model->EraseArray(tile_output_array); - } -} -} // namespace - -bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) { - const auto binary_it = model->operators.begin() + op_index; - auto* binary_op = binary_it->get(); - // Test for binary ops of types that we know how to resolve - if (binary_op->inputs.size() != 2) { - return false; - } - if (binary_op->type != OperatorType::kAdd && - binary_op->type != OperatorType::kMul && - binary_op->type != OperatorType::kSub && - binary_op->type != OperatorType::kDiv) { - return false; - } - - Operator* const op[2] = { - GetOpWithOutput(*model, binary_op->inputs[0]), - GetOpWithOutput(*model, binary_op->inputs[1]), - }; - - // In the unlikely case where both operands are Tile, we can't infer the - // output - // size without the Tile nodes, so we have to bail out. - if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] && - op[1]->type == OperatorType::kTensorFlowTile) { - return false; - } - - for (int i = 0; i < 2; i++) { - if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) { - // We can only remove a Tile operator is no other op than the present - // binary op was consuming its tiled output. - if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) { - AddMessageF("Removing %s", LogName(*op[i])); - RemoveTileOperator(model, op[i], binary_op, i); - return true; - } - } - } - return false; -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc rename to tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc index c00cdcb944b085dda41033b95c96537cc2e047c3..22c258cec5fde4144c4b048d5ec60a8604362cbb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc @@ -24,14 +24,14 @@ limitations under the License. namespace toco { -bool ExperimentalShuffleFCWeights::Run(Model* model, std::size_t op_index) { +bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { Operator* op = model->operators[op_index].get(); if (op->type != OperatorType::kFullyConnected) { return false; } FullyConnectedOperator* fc_op = static_cast(op); // Exit if this FC op already has shuffled weights - if (fc_op->experimental_shuffled_weights) { + if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) { return false; } const Array& input_array = model->GetArray(fc_op->inputs[0]); @@ -135,7 +135,7 @@ bool ExperimentalShuffleFCWeights::Run(Model* model, std::size_t op_index) { CHECK_EQ(shuffled_data_ptr, shuffled_data.data() + rows * cols); // Switch this FC op to using the shuffled weights. weights_data = std::move(shuffled_data); - fc_op->experimental_shuffled_weights = true; + fc_op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8; AddMessageF("Applied experimental shuffling to the weights of %s", LogName(*op)); // Add a second output array to this FC op, serving as a workspace to perform diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD index 8dcd4adc90b188c745cadb9815c3c46383705833..95e8433be2a332cfce5175f4f65ea0b83d5638c5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -8,8 +8,8 @@ load( ) tf_cc_test( - name = "resolve_constant_concatenation_test", - srcs = ["resolve_constant_concatenation_test.cc"], + name = "lstm_utils_test", + srcs = ["lstm_utils_test.cc"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", @@ -19,8 +19,20 @@ tf_cc_test( ) tf_cc_test( - name = "lstm_utils_test", - srcs = ["lstm_utils_test.cc"], + name = "quantize_weights_test", + srcs = ["quantize_weights_test.cc"], + deps = [ + "//tensorflow/contrib/lite/toco:graph_transformations", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_test( + name = "resolve_constant_concatenation_test", + srcs = ["resolve_constant_concatenation_test.cc"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c05eb0929fd775d315fa735b4c9842a7fc024fa8 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include +#include "absl/memory/memory.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +class QuantizeWeightsTest : public ::testing::Test { + protected: + QuantizeWeightsTest() {} + + // The name of the weights input array. + const string kWeightsName = "weights"; + // The zero_point of the values in the input array. + const int kZeroPoint = 128; + + // Prepare a hypothetical TOCO model of a quantizable fully connected float + // layer. + void PrepareModel(Model* model, int elements_per_dim) { + std::vector fc_input_names = {"inputs", kWeightsName}; + + const int kDim = 4; + const int buf_size = std::pow(elements_per_dim, static_cast(kDim)); + auto in_buf = absl::make_unique(buf_size); + // Initialize the array with values from -128.0 to 127.0, since these values + // should be exactly representable by quantization. + for (int i = 0; i < buf_size; i++) { + in_buf[i] = static_cast(i % 256 - kZeroPoint); + } + + for (const string& fc_input_name : fc_input_names) { + Array& in_array = model->GetOrCreateArray(fc_input_name); + in_array.data_type = ArrayDataType::kFloat; + + // Initialize shape for the input array. + Shape* in_array_shape = in_array.mutable_shape(); + std::vector* in_array_shape_dim = in_array_shape->mutable_dims(); + in_array_shape_dim->resize(kDim, elements_per_dim); + auto& in_array_buffer = + in_array.GetMutableBuffer(); + in_array_buffer.data.resize(buf_size); + float* buf_ptr = + in_array.GetMutableBuffer().data.data(); + std::copy(in_buf.get(), in_buf.get() + buf_size, buf_ptr); + } + + auto* fc_op = new FullyConnectedOperator; + fc_op->inputs = fc_input_names; + fc_op->outputs = {"fc_op_outputs"}; + Array& out_array = model->GetOrCreateArray(fc_op->outputs[0]); + out_array.data_type = ArrayDataType::kFloat; + Shape* out_array_shape = out_array.mutable_shape(); + std::vector* out_array_shape_dim = out_array_shape->mutable_dims(); + out_array_shape_dim->resize(kDim, elements_per_dim); + model->operators.push_back(std::unique_ptr(fc_op)); + } +}; + +TEST_F(QuantizeWeightsTest, QuantizedFullyConnected) { + // Test that weight arrays that are large enough are quantized. + Model model; + // 6 elements per dim gives us 1296 elements, which is sufficient to be + // quantized. + PrepareModel(&model, 6); + + // Check the state of the graph before the transformation. + const auto& float_array_map = model.GetArrayMap(); + EXPECT_EQ(float_array_map.size(), 3); + // Before the transformation, all arrays should be type float. + for (const auto& element : float_array_map) { + EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat); + } + const std::vector float_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + + // Invoke the transformation. + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::QuantizeWeights); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + + // Check the state of the graph after the transformation. + const auto& quantized_array_map = model.GetArrayMap(); + EXPECT_EQ(quantized_array_map.size(), 4); + // After the transformation, three arrays should be type float and one array + // should be uint8. + int num_float = 0; + int num_uint8 = 0; + for (const auto& element : quantized_array_map) { + if (element.second->data_type == ArrayDataType::kFloat) { + num_float++; + } else if (element.second->data_type == ArrayDataType::kUint8) { + num_uint8++; + } else { + FAIL() << "Unexpected array type."; + } + } + EXPECT_EQ(num_float, 3); + EXPECT_EQ(num_uint8, 1); + // Ensure that the values were quantized correctly. + const std::vector& quantized_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + for (int i = 0; i < quantized_weight_vals.size(); i++) { + EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i] + kZeroPoint); + } + + // Ensure that a Dequantize operator has been inserted before the + // FullyConnectedLayer. + EXPECT_EQ(model.operators[0]->type, OperatorType::kDequantize); +} + +TEST_F(QuantizeWeightsTest, NotQuantizedFullyConnected) { + // Test that weight arrays that are too small are left untouched. + Model model; + // 5 elements per dim gives us 625 elements, which is NOT sufficient to be + // quantized. + PrepareModel(&model, 5); + + // Check the state of the graph before the transformation. + const auto& float_array_map = model.GetArrayMap(); + EXPECT_EQ(float_array_map.size(), 3); + // Before the transformation, all arrays should be type float. + for (auto it = float_array_map.begin(); it != float_array_map.end(); it++) { + EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat); + } + std::vector float_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + + // Invoke the transformation. + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::QuantizeWeights); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + + // Check the state of the graph after the transformation. + const auto& post_array_map = model.GetArrayMap(); + EXPECT_EQ(post_array_map.size(), 3); + for (auto it = post_array_map.begin(); it != post_array_map.end(); it++) { + EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat); + } + // Ensure that the values remain unchanged. + std::vector const& quantized_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + for (int i = 0; i < quantized_weight_vals.size(); i++) { + EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i]); + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index 3a1d175b9823f085c9b8730caba8bedd7eb87d52..66cfed4ac26969729d1881f11ba6ae74d9817fb5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include #include @@ -126,7 +124,7 @@ class ResolveConstantConcatenationTest : public ::testing::Test { Array& in_array = model->GetOrCreateArray(concat_input_name); in_array.data_type = ArrayDataType::kFloat; - // Initialize shape for the input array. + // Initialize shape for the input array. Shape* in_array_shape = in_array.mutable_shape(); std::vector* in_array_shape_dim = in_array_shape->mutable_dims(); for (int i = 0; i < kDim; i++) { diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index af84c667a7e89e0fc2f90818e1fef3618f4328de..5c32a39035f3c5396b09621bacaa58a7baa3ae9b 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" #include "tensorflow/contrib/lite/toco/tensorflow_util.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" @@ -44,6 +43,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -63,8 +63,6 @@ using tensorflow::TensorShapeProto; namespace toco { -using port::Status; - namespace { bool HasAttr(const NodeDef& node, const string& attr_name) { return node.attr().count(attr_name) > 0; @@ -130,6 +128,42 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node, return attr.list(); } +tensorflow::Status CheckOptionalAttr(const NodeDef& node, + const string& attr_name, + const string& expected_value) { + if (HasAttr(node, attr_name)) { + const string& value = GetStringAttr(node, attr_name); + if (value != expected_value) { + return tensorflow::errors::InvalidArgument( + "Unexpected value for attribute '" + attr_name + "'. Expected '" + + expected_value + "'"); + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status CheckOptionalAttr( + const NodeDef& node, const string& attr_name, + const tensorflow::DataType& expected_value) { + if (HasAttr(node, attr_name)) { + const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name); + if (value != expected_value) { + return tensorflow::errors::InvalidArgument( + "Unexpected value for attribute '" + attr_name + "'. Expected '" + + tensorflow::DataType_Name(expected_value) + "'"); + } + } + return tensorflow::Status::OK(); +} + +template +tensorflow::Status ExpectValue(const T1& v1, const T2& v2, + const string& description) { + if (v1 == v2) return tensorflow::Status::OK(); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Unexpected ", description, ": got ", v1, ", expected ", v2)); +} + ArrayDataType ConvertDataType(tensorflow::DataType dtype) { if (dtype == DT_UINT8) return ArrayDataType::kUint8; @@ -148,9 +182,10 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) { return ArrayDataType::kNone; } -Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< - tensorflow::TensorShapeProto_Dim>& input_dims, - int* input_flat_size, Shape* shape) { +tensorflow::Status ImportShape( + const TFLITE_PROTO_NS::RepeatedPtrField& + input_dims, + int* input_flat_size, Shape* shape) { std::vector input_dims_only_sizes; for (auto& d : input_dims) { if (d.size() == 0) { @@ -160,23 +195,24 @@ Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< // For now, tweaking this to record a 0-D shape instead. shape->mutable_dims()->clear(); if (input_flat_size != nullptr) *input_flat_size = 0; - return Status::OK(); + return tensorflow::Status::OK(); } // TensorFlow's shapes use int64s, while TOCO uses ints. if (d.size() > std::numeric_limits::max()) { - return Status(false, "Shape element overflows"); + return tensorflow::errors::InvalidArgument("Shape element overflows"); } input_dims_only_sizes.push_back(d.size()); } *shape->mutable_dims() = input_dims_only_sizes; - if (input_flat_size == nullptr) return Status::OK(); + if (input_flat_size == nullptr) return tensorflow::Status::OK(); return NumElements(input_dims_only_sizes, input_flat_size); } -Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -203,14 +239,18 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_float_data.data())); } else { - return Status(false, - "Neither input_content nor float_val have the right " - "dimensions for this float tensor"); + return tensorflow::errors::InvalidArgument( + absl::StrCat("Neither input_content (", + input_tensor.tensor_content().size() / sizeof(float), + ") nor float_val (", input_tensor.float_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this float tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -223,7 +263,11 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int_val_size()) { + if (input_tensor.int_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int_val(0); + } + } else if (input_tensor.int_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); } @@ -232,14 +276,18 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status(false, - "Neither input_content nor int_val have the right dimensions " - "for this uint8 tensor"); + return tensorflow::errors::InvalidArgument( + absl::StrCat("Neither input_content (", + input_tensor.tensor_content().size() / sizeof(uint8_t), + ") nor int_val (", input_tensor.int_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this uint8 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -252,7 +300,11 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int_val_size()) { + if (input_tensor.int_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int_val(0); + } + } else if (input_tensor.int_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); } @@ -261,14 +313,17 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status(false, - "Neither input_content nor int_val have the right dimensions " - "for this int32 tensor"); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Neither input_content (", + input_tensor.tensor_content().size() / sizeof(int32), ") nor int_val (", + input_tensor.int_val_size(), ") have the right dimensions (", + input_flat_size, ") for this int32 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -281,8 +336,12 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int64_val_size()) { - for (int i = 0; i < input_tensor.int64_val_size(); i++) { + if (input_tensor.int64_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int64_val(0); + } + } else if (input_tensor.int64_val_size() == input_flat_size) { + for (int i = 0; i < input_tensor.float_val_size(); i++) { output_int_data[i] = input_tensor.int64_val(i); } } else if (input_tensor.tensor_content().size() == @@ -290,14 +349,18 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status(false, - "Neither input_content nor int64_val have the right " - "dimensions for this int64 tensor"); + return tensorflow::errors::InvalidArgument( + absl::StrCat("Neither input_content (", + input_tensor.tensor_content().size() / sizeof(int64), + ") nor int64_val (", input_tensor.int64_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this int64 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -311,7 +374,11 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()), false); CHECK_GE(output_bool_data.size(), input_flat_size); - if (input_tensor.bool_val_size()) { + if (input_tensor.bool_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_bool_data[i] = input_tensor.bool_val(0); + } + } else if (input_tensor.bool_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.bool_val_size(); i++) { output_bool_data[i] = input_tensor.bool_val(i); } @@ -327,16 +394,19 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { // So far only encountered that in an array with 1 entry, let's // require that until we encounter a graph where that's not the case. if (output_bool_data.size() != 1) { - return Status(false, - "Neither input_content nor bool_val have the right " - "dimensions for this bool tensor"); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Neither input_content (", input_tensor.tensor_content().size(), + ") nor bool_val (", input_tensor.bool_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this bool tensor")); } output_bool_data[0] = false; } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportStringArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -346,9 +416,9 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { if (!status.ok()) return status; if (input_flat_size != input_tensor.string_val_size()) { - return Status(false, - "Input_content string_val doesn't have the right dimensions " - "for this string tensor"); + return tensorflow::errors::InvalidArgument( + "Input_content string_val doesn't have the right dimensions " + "for this string tensor"); } auto& output_string_data = @@ -358,7 +428,7 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { for (int i = 0; i < input_flat_size; ++i) { output_string_data[i] = input_tensor.string_val(i); } - return Status::OK(); + return tensorflow::Status::OK(); } // Count the number of inputs of a given node. If @@ -372,18 +442,19 @@ int GetInputsCount(const NodeDef& node, return i; } } - return node.input_size(); - } else { - return node.input_size(); } + return node.input_size(); } -void CheckInputsCount(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - int expected_input_count) { - QCHECK_EQ(GetInputsCount(node, tf_import_flags), expected_input_count) - << node.op() << " node expects " << expected_input_count - << " input(s) other than control dependencies: " << node.DebugString(); +tensorflow::Status CheckInputsCount( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + int expected_input_count) { + if (GetInputsCount(node, tf_import_flags) != expected_input_count) { + return tensorflow::errors::FailedPrecondition( + node.op(), " node expects ", expected_input_count, + " input(s) other than control dependencies: ", node.DebugString()); + } + return tensorflow::Status::OK(); } template @@ -398,14 +469,14 @@ string CreateConstArray(Model* model, string const& name, return array_name; } -Status ConvertConstOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConstOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Const"); const auto& tensor = GetTensorAttr(node, "value"); const auto dtype = GetDataTypeAttr(node, "dtype"); - Status status = Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); auto& array = model->GetOrCreateArray(node.name()); switch (dtype) { @@ -441,24 +512,21 @@ Status ConvertConstOperator(const NodeDef& node, array.GetMutableBuffer(); break; } - if (!status.ok()) { - status.AppendMessage(" (while processing node '" + node.name() + "')"); - } - return status; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + status, " (while processing node '" + node.name() + "')"); + return tensorflow::Status::OK(); } -void ConvertConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Conv2D"); - CheckInputsCount(node, tf_import_flags, 2); + TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2)); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. - if (HasAttr(node, "data_format")) { - CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); - } - CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC")); + TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT)); const auto& input_name = node.input(0); const auto& weights_name = node.input(1); @@ -483,27 +551,26 @@ void ConvertConvOperator(const NodeDef& node, auto* conv = new ConvOperator; conv->inputs = {input_name, reordered_weights_name}; conv->outputs = {node.name()}; + if (!HasAttr(node, "strides")) { + return tensorflow::errors::InvalidArgument("Missing attribute 'strides'"); + } const auto& strides = GetListAttr(node, "strides"); - CHECK_EQ(strides.i_size(), 4); - CHECK_EQ(strides.i(0), 1); - CHECK_EQ(strides.i(3), 1); + TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)")); conv->stride_height = strides.i(1); conv->stride_width = strides.i(2); if (HasAttr(node, "dilations")) { const auto& dilations = GetListAttr(node, "dilations"); - CHECK_EQ(dilations.i_size(), 4); - CHECK_EQ(dilations.i(0), 1) - << "Can only import Conv ops with dilation along the height (1st) or " - "width (2nd) axis. TensorFlow op \"" - << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " - << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) - << "]."; - CHECK_EQ(dilations.i(3), 1) - << "Can only import Conv ops with dilation along the height (1st) or " - "width (2nd) axis. TensorFlow op \"" - << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " - << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) - << "]."; + TF_RETURN_IF_ERROR( + ExpectValue(dilations.i_size(), 4, "number of dilations")); + if (dilations.i(0) != 1 || dilations.i(3) != 1) { + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Can only import Conv ops with dilation along the height " + "(1st) or width (2nd) axis. TensorFlow op \"", + node.name(), "\" had dilations:[ ", dilations.i(0), ", ", + dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "].")); + } conv->dilation_height_factor = dilations.i(1); conv->dilation_width_factor = dilations.i(2); } else { @@ -516,16 +583,19 @@ void ConvertConvOperator(const NodeDef& node, } else if (padding == "VALID") { conv->padding.type = PaddingType::kValid; } else { - LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + return tensorflow::errors::InvalidArgument( + "Bad padding (only SAME and VALID are supported)"); } model->operators.emplace_back(conv); + + return tensorflow::Status::OK(); } -void ConvertDepthwiseConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDepthwiseConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "DepthwiseConv2dNative"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -572,13 +642,14 @@ void ConvertDepthwiseConvOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(conv); + return tensorflow::Status::OK(); } -void ConvertDepthToSpaceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDepthToSpaceOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "DepthToSpace"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); auto* op = new DepthToSpaceOperator; @@ -587,28 +658,37 @@ void ConvertDepthToSpaceOperator(const NodeDef& node, op->block_size = GetIntAttr(node, "block_size"); QCHECK_GE(op->block_size, 2); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSpaceToDepthOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSpaceToDepthOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SpaceToDepth"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); - CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + tensorflow::DataType dtype = GetDataTypeAttr(node, "T"); + if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 && + dtype != DT_INT64) { + const auto* enum_descriptor = tensorflow::DataType_descriptor(); + LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:" + << enum_descriptor->FindValueByNumber(dtype)->name() << ". " + << "T must be one of {DT_FLOAT, DT_INT8, DT_INT32, DT_INT64}."; + } auto* op = new SpaceToDepthOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); op->block_size = GetIntAttr(node, "block_size"); QCHECK_GE(op->block_size, 2); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBiasAddOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertBiasAddOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "BiasAdd"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto& input_name = node.input(0); const auto& bias_name = node.input(1); @@ -618,13 +698,14 @@ void ConvertBiasAddOperator(const NodeDef& node, biasadd->inputs.push_back(bias_name); biasadd->outputs.push_back(node.name()); model->operators.emplace_back(biasadd); + return tensorflow::Status::OK(); } -void ConvertRandomUniform(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertRandomUniform( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "RandomUniform"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32); auto op = absl::make_unique(); @@ -635,86 +716,12 @@ void ConvertRandomUniform(const NodeDef& node, op->seed2 = GetIntAttr(node, "seed2"); CHECK(model != nullptr); model->operators.emplace_back(std::move(op)); + return tensorflow::Status::OK(); } -void ConvertReluOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Relu"); - CheckInputsCount(node, tf_import_flags, 1); - const auto& input_name = node.input(0); - auto* relu = new ReluOperator; - relu->inputs.push_back(input_name); - relu->outputs.push_back(node.name()); - model->operators.emplace_back(relu); -} - -void ConvertRelu6Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Relu6"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new Relu6Operator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLogOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Log"); - CheckInputsCount(node, tf_import_flags, 1); - - auto op = absl::make_unique(); - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(std::move(op)); -} - -void ConvertLogisticOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sigmoid"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new LogisticOperator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertTanhOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Tanh"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new TanhOperator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertDivOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK(node.op() == "Div" || node.op() == "RealDiv"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new DivOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertIdentityOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertIdentityOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient"); auto* op = new TensorFlowIdentityOperator; @@ -731,13 +738,14 @@ void ConvertIdentityOperator(const NodeDef& node, op->inputs.push_back(input_name); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFakeQuantWithMinMaxArgs( +tensorflow::Status ConvertFakeQuantWithMinMaxArgs( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); auto* op = new FakeQuantOperator; op->inputs.push_back(node.input(0)); op->minmax.reset(new MinMax); @@ -748,9 +756,10 @@ void ConvertFakeQuantWithMinMaxArgs( // tf.fake_quant_with_min_max_args num_bits defaults to 8. op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFakeQuantWithMinMaxVars( +tensorflow::Status ConvertFakeQuantWithMinMaxVars( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars"); @@ -766,46 +775,14 @@ void ConvertFakeQuantWithMinMaxVars( op->outputs.push_back(node.name()); op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertNegOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Neg"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new NegOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertRsqrtOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Rsqrt"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowRsqrtOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSqrtOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sqrt"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowSqrtOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSqueezeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSqueezeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Squeeze"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); auto* op = new SqueezeOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); @@ -819,73 +796,14 @@ void ConvertSqueezeOperator(const NodeDef& node, } model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSquareOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Square"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowSquareOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAddOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Add"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new AddOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAddNOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "AddN"); - const int num_inputs = GetInputsCount(node, tf_import_flags); - auto* op = new AddNOperator; - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Mul"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new MulOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSubOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sub"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new SubOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSumOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Sum"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSumOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -894,74 +812,14 @@ void ConvertSumOperator(const NodeDef& node, if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertTileOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Tile"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowTileOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSliceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Slice"); - CheckInputsCount(node, tf_import_flags, 3); - auto* op = new SliceOperator; - for (int i = 0; i < 3; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertPadOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Pad"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new PadOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertPadV2Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "PadV2"); - CheckInputsCount(node, tf_import_flags, 3); - auto* op = new PadV2Operator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->inputs.push_back(node.input(2)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertShapeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Shape"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowShapeOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSplitOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSplitOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Split"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSplitOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -972,25 +830,14 @@ void ConvertSplitOperator(const NodeDef& node, } op->num_split = num_split; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertMergeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Merge"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMergeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSwitchOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSwitchOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Switch"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSwitchOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -998,13 +845,14 @@ void ConvertSwitchOperator(const NodeDef& node, // Switch operators have two outputs: "name" and "name:1". op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSoftmaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSoftmaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Softmax"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); auto* softmax = new SoftmaxOperator; softmax->inputs.push_back(input_name); @@ -1013,25 +861,14 @@ void ConvertSoftmaxOperator(const NodeDef& node, CHECK(!node.attr().count("beta")); // Stab in the dark, just in case. softmax->beta = 1.f; model->operators.emplace_back(softmax); + return tensorflow::Status::OK(); } -void ConvertLogSoftmaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "LogSoftmax"); - CheckInputsCount(node, tf_import_flags, 1); - const auto& input_name = node.input(0); - auto* log_softmax = new LogSoftmaxOperator; - log_softmax->inputs.push_back(input_name); - log_softmax->outputs.push_back(node.name()); - model->operators.emplace_back(log_softmax); -} - -void ConvertLRNOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertLRNOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "LRN"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); auto* lrn = new LocalResponseNormalizationOperator; lrn->inputs.push_back(input_name); @@ -1041,13 +878,14 @@ void ConvertLRNOperator(const NodeDef& node, lrn->alpha = GetFloatAttr(node, "alpha"); lrn->beta = GetFloatAttr(node, "beta"); model->operators.emplace_back(lrn); + return tensorflow::Status::OK(); } -void ConvertMaxPoolOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMaxPoolOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "MaxPool"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -1083,13 +921,14 @@ void ConvertMaxPoolOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(maxpool); + return tensorflow::Status::OK(); } -void ConvertAvgPoolOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertAvgPoolOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "AvgPool"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -1121,24 +960,13 @@ void ConvertAvgPoolOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(avgpool); + return tensorflow::Status::OK(); } -void ConvertReshapeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Reshape"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowReshapeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertBatchMatMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 2); +tensorflow::Status ConvertBatchMatMulOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false)); @@ -1148,33 +976,36 @@ void ConvertBatchMatMulOperator(const NodeDef& node, batch_matmul->inputs = {node.input(0), node.input(1)}; batch_matmul->outputs = {node.name()}; model->operators.emplace_back(batch_matmul); + return tensorflow::Status::OK(); } -void ConvertMatMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 2); +tensorflow::Status ConvertMatMulOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - // Transpose flags should be easy to support, but we don't have a - // GraphDef with them to test on at the moment. - CHECK_EQ(HasAttr(node, "transpose_a") && GetBoolAttr(node, "transpose_a"), - false); - CHECK_EQ(HasAttr(node, "transpose_b") && GetBoolAttr(node, "transpose_b"), - false); CHECK(!HasAttr(node, "adjoint_a") || (GetBoolAttr(node, "adjoint_a") == false)); CHECK(!HasAttr(node, "adjoint_b") || (GetBoolAttr(node, "adjoint_b") == false)); auto* matmul = new TensorFlowMatMulOperator; + if (HasAttr(node, "transpose_a")) { + matmul->transpose_a = GetBoolAttr(node, "transpose_a"); + } + if (HasAttr(node, "transpose_b")) { + matmul->transpose_b = GetBoolAttr(node, "transpose_b"); + } + matmul->inputs = {node.input(0), node.input(1)}; matmul->outputs = {node.name()}; model->operators.emplace_back(matmul); + return tensorflow::Status::OK(); } -void ConvertConcatOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConcatOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { Operator* op = nullptr; if (node.op() == "Concat") { op = new TensorFlowConcatOperator; @@ -1194,104 +1025,38 @@ void ConvertConcatOperator(const NodeDef& node, } op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertAllOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "All"); - auto* op = new TensorFlowAllOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAssertOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Assert"); - auto* op = new TensorFlowAssertOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLessOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Less"); - auto* op = new TensorFlowLessOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLessEqualOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "LessEqual"); - auto* op = new TensorFlowLessEqualOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSinOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sin"); - auto* op = new SinOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertGreaterOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Greater"); - auto* op = new TensorFlowGreaterOperator; +// This method supports simple operators without additional attributes. +template +tensorflow::Status ConvertSimpleOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + auto* op = new Op; const int num_inputs = GetInputsCount(node, tf_import_flags); for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); } op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertGreaterEqualOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "GreaterEqual"); - auto* op = new TensorFlowGreaterEqualOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); +// This method supports simple operators without additional attributes. +template +tensorflow::Status ConvertSimpleOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs)); + return ConvertSimpleOperator(node, tf_import_flags, model); } -void ConvertMaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Max"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowMaxOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1300,13 +1065,14 @@ void ConvertMaxOperator(const NodeDef& node, if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertMinOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMinOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Min"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowMinOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1315,35 +1081,12 @@ void ConvertMinOperator(const NodeDef& node, if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertMaximumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Maximum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMaximumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMinimumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Minimum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMinimumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertUnsupportedOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertUnsupportedOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { LOG(INFO) << "Converting unsupported operation: " << node.op(); auto* op = new TensorFlowUnsupportedOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1362,29 +1105,20 @@ void ConvertUnsupportedOperator(const NodeDef& node, for (int i = 0; i < output_types.type_size(); ++i) { op->output_data_types.push_back(ConvertDataType(output_types.type(i))); } + } else if (HasAttr(node, "Tout")) { + const auto& output_type = GetDataTypeAttr(node, "Tout"); + op->output_data_types.push_back(ConvertDataType(output_type)); } + return tensorflow::Status::OK(); } -void ConvertSelectOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 3); - - auto* op = new SelectOperator; - for (const auto& input : node.input()) { - op->inputs.push_back(input); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertStridedSliceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertStridedSliceOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "StridedSlice"); // TODO(soroosh): The 4th input (strides) should be e optional, to be // consistent with TF. - CheckInputsCount(node, tf_import_flags, 4); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); auto* op = new StridedSliceOperator; for (const auto& input : node.input()) { @@ -1404,14 +1138,15 @@ void ConvertStridedSliceOperator(const NodeDef& node, : 0; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertPlaceholderOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertPlaceholderOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput"); if (node.op() == "Placeholder") { - CheckInputsCount(node, tf_import_flags, 0); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0)); } auto& array = model->GetOrCreateArray(node.name()); if (node.attr().count("dtype")) { @@ -1436,17 +1171,20 @@ void ConvertPlaceholderOperator(const NodeDef& node, } } } + return tensorflow::Status::OK(); } -void ConvertNoOpOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) {} +tensorflow::Status ConvertNoOpOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + return tensorflow::Status::OK(); +} -void ConvertCastOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertCastOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Cast"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT"); const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT"); auto* op = new CastOperator; @@ -1455,27 +1193,31 @@ void ConvertCastOperator(const NodeDef& node, op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFloorOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertFloorOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Floor"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); CHECK(data_type == DT_FLOAT); auto* op = new FloorOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertGatherOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertGatherOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Gather" || node.op() == "GatherV2"); - if (node.op() == "Gather") CheckInputsCount(node, tf_import_flags, 2); - if (node.op() == "GatherV2") CheckInputsCount(node, tf_import_flags, 3); + if (node.op() == "Gather") + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + if (node.op() == "GatherV2") + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64); auto* op = new GatherOperator; @@ -1485,13 +1227,14 @@ void ConvertGatherOperator(const NodeDef& node, // should read it an pass it on to the TF Lite Interpreter. op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertArgMaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertArgMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "ArgMax"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; const auto output_type = HasAttr(node, "output_type") @@ -1505,13 +1248,14 @@ void ConvertArgMaxOperator(const NodeDef& node, op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertResizeBilinearOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertResizeBilinearOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "ResizeBilinear"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new ResizeBilinearOperator; op->align_corners = false; @@ -1523,13 +1267,14 @@ void ConvertResizeBilinearOperator(const NodeDef& node, op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBatchNormWithGlobalNormalizationOperator( +tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization"); - CheckInputsCount(node, tf_import_flags, 5); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); // TODO(ahentz): to really match tensorflow we need to add variance_epsilon // to the input, before feeding it into TensorFlowRsqrtOperator. @@ -1572,13 +1317,14 @@ void ConvertBatchNormWithGlobalNormalizationOperator( op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFusedBatchNormOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertFusedBatchNormOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "FusedBatchNorm"); - CheckInputsCount(node, tf_import_flags, 5); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); // Declare shortcuts for the inputs. const string& gamma_input = node.input(1); @@ -1624,13 +1370,14 @@ void ConvertFusedBatchNormOperator(const NodeDef& node, op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSpaceToBatchNDOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSpaceToBatchNDOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SpaceToBatchND"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32); auto* op = new SpaceToBatchNDOperator; @@ -1639,13 +1386,14 @@ void ConvertSpaceToBatchNDOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBatchToSpaceNDOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertBatchToSpaceNDOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "BatchToSpaceND"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32); auto* op = new BatchToSpaceNDOperator; @@ -1654,24 +1402,14 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertExpOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Exp"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new ExpOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMeanOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMeanOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Mean"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new MeanOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1682,11 +1420,12 @@ void ConvertMeanOperator(const NodeDef& node, } else if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertSvdfOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSvdfOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Svdf"); const int input_size = GetInputsCount(node, tf_import_flags); QCHECK(input_size == 3 || input_size == 4) @@ -1709,14 +1448,15 @@ void ConvertSvdfOperator(const NodeDef& node, } op->rank = node.attr().at("Rank").i(); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // This is just bare bones support to get the shapes to propagate. -void ConvertTransposeConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertTransposeConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Conv2DBackpropInput"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new TransposeConvOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1757,11 +1497,13 @@ void ConvertTransposeConvOperator(const NodeDef& node, if (existing_transpose) { CHECK(existing_transpose->type == OperatorType::kTranspose); } else { - // Transpose weights from HWIO order to OHWI order, which is more efficient - // for computation + // Transpose weights from HWOI order to OHWI order, which is more efficient + // for computation. (Note that TensorFlow considers the order as HWIO + // because they consider this a backward conv, inverting the sense of + // input/output.) TransposeOperator* transpose = new TransposeOperator; string perm_array = CreateConstArray( - model, node.name() + "_transpose_perm", {3, 0, 1, 2}); + model, node.name() + "_transpose_perm", {2, 0, 1, 3}); transpose->inputs = {weights_name, perm_array}; transpose->outputs = {transposed_weights_name}; model->operators.emplace_back(transpose); @@ -1778,61 +1520,14 @@ void ConvertTransposeConvOperator(const NodeDef& node, "Conv2DBackpropInput nodes."; } model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertExpandDimsOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "ExpandDims"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new ExpandDimsOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFillOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Fill"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FillOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFloorDivOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "FloorDiv"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FloorDivOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFloorModOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "FloorMod"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FloorModOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertRangeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertRangeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Range"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new RangeOperator; if (HasAttr(node, "Tidx")) { const auto dtype = toco::GetDataTypeAttr(node, "Tidx"); @@ -1845,22 +1540,12 @@ void ConvertRangeOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertRankOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Rank"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new RankOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertStackOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertStackOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK((node.op() == "Stack") || (node.op() == "Pack")); auto* op = new StackOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1876,18 +1561,7 @@ void ConvertStackOperator(const NodeDef& node, op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0; op->outputs.push_back(node.name()); model->operators.emplace_back(op); -} - -void ConvertTransposeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Transpose"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TransposeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // Some TensorFlow ops only occur in graph cycles, representing @@ -1900,7 +1574,7 @@ void ConvertTransposeOperator(const NodeDef& node, // such ops as RNN back-edges, which is technically incorrect (does not // allow representing the op's semantics) but good enough to get a // graph visualization. -void ConvertOperatorSpecialCasedAsRNNBackEdge( +tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { // At the moment, the only type of operator special-cased in this way is @@ -1913,6 +1587,23 @@ void ConvertOperatorSpecialCasedAsRNNBackEdge( rnn_state->set_discardable(true); rnn_state->set_state_array(node.name()); rnn_state->set_back_edge_source_array(node.input(0)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertShapeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Shape"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); + const auto out_type = + HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32; + CHECK(out_type == DT_INT64 || out_type == DT_INT32); + auto op = absl::make_unique(); + op->output_data_type = ConvertDataType(out_type); + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); } void StripCaretFromArrayNames(Model* model) { @@ -2055,9 +1746,9 @@ bool InlineAllFunctions(GraphDef* graphdef) { return graph_modified; } -void ConvertTopKV2Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertTopKV2Operator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK((node.op() == "TopK") || (node.op() == "TopKV2")); auto op = absl::make_unique(); op->inputs.push_back(node.input(0)); @@ -2067,22 +1758,23 @@ void ConvertTopKV2Operator(const NodeDef& node, model, node.name() + "k", {static_cast(GetIntAttr(node, "k"))}); op->inputs.push_back(k_array); } else { - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); op->inputs.push_back(node.input(1)); } // The op has two outputs. op->outputs.push_back(node.name()); op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertDynamicPartitionOperator( +tensorflow::Status ConvertDynamicPartitionOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { auto op = absl::make_unique(); CHECK(HasAttr(node, "num_partitions")); op->num_partitions = GetIntAttr(node, "num_partitions"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); CHECK_GT(op->num_partitions, 1); @@ -2091,11 +1783,12 @@ void ConvertDynamicPartitionOperator( op->outputs.push_back(node.name() + ":" + std::to_string(i)); } model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertDynamicStitchOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDynamicStitchOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { // The parallel and non-parallel variants are the same besides whether they // have a parallel loop; there are no behavioral differences. CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch"); @@ -2103,199 +1796,158 @@ void ConvertDynamicStitchOperator(const NodeDef& node, CHECK(HasAttr(node, "N")); op->num_partitions = GetIntAttr(node, "N"); // Expect all ID partitions + all value partitions. - CheckInputsCount(node, tf_import_flags, op->num_partitions * 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2)); for (int i = 0; i < op->num_partitions * 2; ++i) { op->inputs.push_back(node.input(i)); } op->outputs.push_back(node.name()); model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertSparseToDenseOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "SparseToDense"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); + + auto* op = new SparseToDenseOperator; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + + op->validate_indices = HasAttr(node, "validate_indices") + ? GetBoolAttr(node, "validate_indices") + : true; + model->operators.emplace_back(op); + return tensorflow::Status::OK(); } } // namespace namespace internal { -Status ImportTensorFlowNode(const tensorflow::NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - // TODO(ahentz): Historically these functions all CHECK-fail on error. We've - // been slowly converting them to return Status. - if (node.op() == "Const") { - return ConvertConstOperator(node, tf_import_flags, model); - } else if (node.op() == "Conv2D") { - ConvertConvOperator(node, tf_import_flags, model); - } else if (node.op() == "Conv2DBackpropInput") { - ConvertTransposeConvOperator(node, tf_import_flags, model); - } else if (node.op() == "DepthwiseConv2dNative") { - ConvertDepthwiseConvOperator(node, tf_import_flags, model); - } else if (node.op() == "DepthToSpace") { - ConvertDepthToSpaceOperator(node, tf_import_flags, model); - } else if (node.op() == "SpaceToDepth") { - ConvertSpaceToDepthOperator(node, tf_import_flags, model); - } else if (node.op() == "BiasAdd") { - ConvertBiasAddOperator(node, tf_import_flags, model); - } else if (node.op() == "Relu") { - ConvertReluOperator(node, tf_import_flags, model); - } else if (node.op() == "Relu6") { - ConvertRelu6Operator(node, tf_import_flags, model); - } else if (node.op() == "Sigmoid") { - ConvertLogisticOperator(node, tf_import_flags, model); - } else if (node.op() == "Tanh") { - ConvertTanhOperator(node, tf_import_flags, model); - } else if (node.op() == "MaxPool") { - ConvertMaxPoolOperator(node, tf_import_flags, model); - } else if (node.op() == "AvgPool") { - ConvertAvgPoolOperator(node, tf_import_flags, model); - } else if (node.op() == "Reshape") { - ConvertReshapeOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchMatMul") { - ConvertBatchMatMulOperator(node, tf_import_flags, model); - } else if (node.op() == "MatMul") { - ConvertMatMulOperator(node, tf_import_flags, model); - } else if (node.op() == "Div" || node.op() == "RealDiv") { - ConvertDivOperator(node, tf_import_flags, model); - } else if (node.op() == "Identity" || node.op() == "CheckNumerics" || - node.op() == "StopGradient") { - ConvertIdentityOperator(node, tf_import_flags, model); - } else if (node.op() == "FakeQuantWithMinMaxVars") { - ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model); - } else if (node.op() == "FakeQuantWithMinMaxArgs") { - ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model); - } else if (node.op() == "Neg") { - ConvertNegOperator(node, tf_import_flags, model); - } else if (node.op() == "Rsqrt") { - ConvertRsqrtOperator(node, tf_import_flags, model); - } else if (node.op() == "Squeeze") { - ConvertSqueezeOperator(node, tf_import_flags, model); - } else if (node.op() == "Sqrt") { - ConvertSqrtOperator(node, tf_import_flags, model); - } else if (node.op() == "Square") { - ConvertSquareOperator(node, tf_import_flags, model); - } else if (node.op() == "Add") { - ConvertAddOperator(node, tf_import_flags, model); - } else if (node.op() == "AddN") { - ConvertAddNOperator(node, tf_import_flags, model); - } else if (node.op() == "Mul") { - ConvertMulOperator(node, tf_import_flags, model); - } else if (node.op() == "Sub") { - ConvertSubOperator(node, tf_import_flags, model); - } else if (node.op() == "Sum") { - ConvertSumOperator(node, tf_import_flags, model); - } else if (node.op() == "Tile") { - ConvertTileOperator(node, tf_import_flags, model); - } else if (node.op() == "Concat" || node.op() == "ConcatV2") { - ConvertConcatOperator(node, tf_import_flags, model); - } else if (node.op() == "LRN") { - ConvertLRNOperator(node, tf_import_flags, model); - } else if (node.op() == "Softmax") { - ConvertSoftmaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Log") { - ConvertLogOperator(node, tf_import_flags, model); - } else if (node.op() == "LogSoftmax") { - ConvertLogSoftmaxOperator(node, tf_import_flags, model); - } else if (node.op() == "All") { - ConvertAllOperator(node, tf_import_flags, model); - } else if (node.op() == "Assert") { - ConvertAssertOperator(node, tf_import_flags, model); - } else if (node.op() == "Less") { - ConvertLessOperator(node, tf_import_flags, model); - } else if (node.op() == "LessEqual") { - ConvertLessEqualOperator(node, tf_import_flags, model); - } else if (node.op() == "Greater") { - ConvertGreaterOperator(node, tf_import_flags, model); - } else if (node.op() == "GreaterEqual") { - ConvertGreaterEqualOperator(node, tf_import_flags, model); - } else if (node.op() == "Max") { - ConvertMaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Min") { - ConvertMinOperator(node, tf_import_flags, model); - } else if (node.op() == "Maximum") { - ConvertMaximumOperator(node, tf_import_flags, model); - } else if (node.op() == "Minimum") { - ConvertMinimumOperator(node, tf_import_flags, model); - } else if (node.op() == "Merge") { - ConvertMergeOperator(node, tf_import_flags, model); - } else if (node.op() == "Pad") { - ConvertPadOperator(node, tf_import_flags, model); - } else if (node.op() == "PadV2") { - ConvertPadV2Operator(node, tf_import_flags, model); - } else if (node.op() == "StridedSlice") { - ConvertStridedSliceOperator(node, tf_import_flags, model); - } else if (node.op() == "Shape") { - ConvertShapeOperator(node, tf_import_flags, model); - } else if (node.op() == "Slice") { - ConvertSliceOperator(node, tf_import_flags, model); - } else if (node.op() == "Split") { - ConvertSplitOperator(node, tf_import_flags, model); - } else if (node.op() == "Switch") { - ConvertSwitchOperator(node, tf_import_flags, model); - } else if (node.op() == "Placeholder") { - ConvertPlaceholderOperator(node, tf_import_flags, model); - } else if (node.op() == "PlaceholderWithDefault") { - ConvertIdentityOperator(node, tf_import_flags, model); - } else if (node.op() == "LegacyFedInput") { - ConvertPlaceholderOperator(node, tf_import_flags, model); - } else if (node.op() == "NoOp") { - ConvertNoOpOperator(node, tf_import_flags, model); - } else if (node.op() == "Cast") { - ConvertCastOperator(node, tf_import_flags, model); - } else if (node.op() == "Floor") { - ConvertFloorOperator(node, tf_import_flags, model); - } else if (node.op() == "Gather" || node.op() == "GatherV2") { - ConvertGatherOperator(node, tf_import_flags, model); - } else if (node.op() == "ResizeBilinear") { - ConvertResizeBilinearOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchNormWithGlobalNormalization") { - ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags, - model); - } else if (node.op() == "FusedBatchNorm") { - ConvertFusedBatchNormOperator(node, tf_import_flags, model); - } else if (node.op() == "SpaceToBatchND") { - ConvertSpaceToBatchNDOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchToSpaceND") { - ConvertBatchToSpaceNDOperator(node, tf_import_flags, model); - } else if (node.op() == "Mean") { - ConvertMeanOperator(node, tf_import_flags, model); - } else if (node.op() == "Svdf") { - ConvertSvdfOperator(node, tf_import_flags, model); - } else if (node.op() == "NextIteration") { - ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model); - } else if (node.op() == "ExpandDims") { - ConvertExpandDimsOperator(node, tf_import_flags, model); - } else if (node.op() == "Fill") { - ConvertFillOperator(node, tf_import_flags, model); - } else if (node.op() == "FloorDiv") { - ConvertFloorDivOperator(node, tf_import_flags, model); - } else if (node.op() == "FloorMod") { - ConvertFloorModOperator(node, tf_import_flags, model); - } else if (node.op() == "Range") { - ConvertRangeOperator(node, tf_import_flags, model); - } else if (node.op() == "Rank") { - ConvertRankOperator(node, tf_import_flags, model); - } else if (node.op() == "Stack" || node.op() == "Pack") { - ConvertStackOperator(node, tf_import_flags, model); - } else if (node.op() == "Transpose") { - ConvertTransposeOperator(node, tf_import_flags, model); - } else if (node.op() == "ArgMax") { - ConvertArgMaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Exp") { - ConvertExpOperator(node, tf_import_flags, model); - } else if (node.op() == "TopK" || node.op() == "TopKV2") { - ConvertTopKV2Operator(node, tf_import_flags, model); - } else if (node.op() == "DynamicPartition") { - ConvertDynamicPartitionOperator(node, tf_import_flags, model); - } else if (node.op() == "DynamicStitch" || - node.op() == "ParallelDynamicStitch") { - ConvertDynamicStitchOperator(node, tf_import_flags, model); - } else if (node.op() == "RandomUniform") { - ConvertRandomUniform(node, tf_import_flags, model); - } else if (node.op() == "Sin") { - ConvertSinOperator(node, tf_import_flags, model); - } else if (node.op() == "Select") { - ConvertSelectOperator(node, tf_import_flags, model); + +using ConverterType = tensorflow::Status (*)( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model); +using ConverterMapType = std::unordered_map; + +ConverterMapType GetTensorFlowNodeConverterMap() { + return std::unordered_map({ + {"Add", ConvertSimpleOperator}, + {"AddN", ConvertSimpleOperator}, + {"All", ConvertSimpleOperator}, + {"ArgMax", ConvertArgMaxOperator}, + {"Assert", ConvertSimpleOperator}, + {"AvgPool", ConvertAvgPoolOperator}, + {"BatchMatMul", ConvertBatchMatMulOperator}, + {"BatchNormWithGlobalNormalization", + ConvertBatchNormWithGlobalNormalizationOperator}, + {"BatchToSpaceND", ConvertBatchToSpaceNDOperator}, + {"BiasAdd", ConvertBiasAddOperator}, + {"Cast", ConvertCastOperator}, + {"CheckNumerics", ConvertIdentityOperator}, + {"Concat", ConvertConcatOperator}, + {"ConcatV2", ConvertConcatOperator}, + {"Const", ConvertConstOperator}, + {"Conv2D", ConvertConvOperator}, + {"Conv2DBackpropInput", ConvertTransposeConvOperator}, + {"DepthToSpace", ConvertDepthToSpaceOperator}, + {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator}, + {"Div", ConvertSimpleOperator}, + {"DynamicPartition", ConvertDynamicPartitionOperator}, + {"DynamicStitch", ConvertDynamicStitchOperator}, + {"Equal", ConvertSimpleOperator}, + {"Exp", ConvertSimpleOperator}, + {"ExpandDims", ConvertSimpleOperator}, + {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs}, + {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars}, + {"Fill", ConvertSimpleOperator}, + {"Floor", ConvertFloorOperator}, + {"FloorDiv", ConvertSimpleOperator}, + {"FloorMod", ConvertSimpleOperator}, + {"FusedBatchNorm", ConvertFusedBatchNormOperator}, + {"Gather", ConvertGatherOperator}, + {"GatherV2", ConvertGatherOperator}, + {"Greater", ConvertSimpleOperator}, + {"GreaterEqual", + ConvertSimpleOperator}, + {"Identity", ConvertIdentityOperator}, + {"LRN", ConvertLRNOperator}, + {"LegacyFedInput", ConvertPlaceholderOperator}, + {"Less", ConvertSimpleOperator}, + {"LessEqual", ConvertSimpleOperator}, + {"Log", ConvertSimpleOperator}, + {"Log", ConvertSimpleOperator}, + {"LogSoftmax", ConvertSimpleOperator}, + {"MatMul", ConvertMatMulOperator}, + {"Max", ConvertMaxOperator}, + {"MaxPool", ConvertMaxPoolOperator}, + {"Maximum", ConvertSimpleOperator}, + {"Mean", ConvertMeanOperator}, + {"Merge", ConvertSimpleOperator}, + {"Min", ConvertMinOperator}, + {"Minimum", ConvertSimpleOperator}, + {"Mul", ConvertSimpleOperator}, + {"Neg", ConvertSimpleOperator}, + {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge}, + {"NoOp", ConvertNoOpOperator}, + {"NotEqual", ConvertSimpleOperator}, + {"Pack", ConvertStackOperator}, + {"Pad", ConvertSimpleOperator}, + {"PadV2", ConvertSimpleOperator}, + {"ParallelDynamicStitch", ConvertDynamicStitchOperator}, + {"Placeholder", ConvertPlaceholderOperator}, + {"PlaceholderWithDefault", ConvertIdentityOperator}, + {"Pow", ConvertSimpleOperator}, + {"RandomUniform", ConvertRandomUniform}, + {"Range", ConvertRangeOperator}, + {"Rank", ConvertSimpleOperator}, + {"RealDiv", ConvertSimpleOperator}, + {"Relu", ConvertSimpleOperator}, + {"Relu6", ConvertSimpleOperator}, + {"Reshape", ConvertSimpleOperator}, + {"ResizeBilinear", ConvertResizeBilinearOperator}, + {"Rsqrt", ConvertSimpleOperator}, + {"Select", ConvertSimpleOperator}, + {"Shape", ConvertShapeOperator}, + {"Sigmoid", ConvertSimpleOperator}, + {"Sin", ConvertSimpleOperator}, + {"Slice", ConvertSimpleOperator}, + {"Softmax", ConvertSoftmaxOperator}, + {"SpaceToBatchND", ConvertSpaceToBatchNDOperator}, + {"SpaceToDepth", ConvertSpaceToDepthOperator}, + {"SparseToDense", ConvertSparseToDenseOperator}, + {"Split", ConvertSplitOperator}, + {"Sqrt", ConvertSimpleOperator}, + {"Square", ConvertSimpleOperator}, + {"Squeeze", ConvertSqueezeOperator}, + {"Stack", ConvertStackOperator}, + {"StopGradient", ConvertIdentityOperator}, + {"StridedSlice", ConvertStridedSliceOperator}, + {"Sub", ConvertSimpleOperator}, + {"Sum", ConvertSumOperator}, + {"Svdf", ConvertSvdfOperator}, + {"Switch", ConvertSwitchOperator}, + {"Tanh", ConvertSimpleOperator}, + {"Tile", ConvertSimpleOperator}, + {"TopK", ConvertTopKV2Operator}, + {"TopKV2", ConvertTopKV2Operator}, + {"Transpose", ConvertSimpleOperator}, + }); +} + +tensorflow::Status ImportTensorFlowNode( + const tensorflow::NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, Model* model, + const ConverterMapType& converter_map) { + auto converter = converter_map.find(node.op()); + if (converter == converter_map.end()) { + return ConvertUnsupportedOperator(node, tf_import_flags, model); } else { - ConvertUnsupportedOperator(node, tf_import_flags, model); + return converter->second(node, tf_import_flags, model); } - return Status::OK(); } } // namespace internal @@ -2321,10 +1973,13 @@ std::unique_ptr ImportTensorFlowGraphDef( } Model* model = new Model; + const internal::ConverterMapType& converter_map = + internal::GetTensorFlowNodeConverterMap(); for (auto node : inlined_graph.node()) { StripZeroOutputIndexFromInputs(&node); - auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model); + auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model, + converter_map); CHECK(status.ok()) << status.error_message(); } diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index 5dc78f73ad2e2ab6f1fcb1ee430513488ce47027..90e6f698efee6a6a32da18a658e72c3e8b6550c0 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/status.h" namespace toco { -using port::Status; using tensorflow::AttrValue; using tensorflow::DT_BOOL; using tensorflow::DT_FLOAT; @@ -33,10 +33,17 @@ using tensorflow::DT_INT64; using tensorflow::DT_QUINT8; using tensorflow::DT_STRING; using tensorflow::NodeDef; +using tensorflow::Status; namespace internal { +using ConverterType = tensorflow::Status (*)( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model); +using ConverterMapType = std::unordered_map; + +ConverterMapType GetTensorFlowNodeConverterMap(); Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, - Model*); + Model*, const ConverterMapType&); } // namespace internal namespace { @@ -104,8 +111,9 @@ class ShapeImportTest : public ::testing::TestWithParam { Status ImportNode(const NodeDef& node) { Model model; - return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), - &model); + const auto converter = internal::GetTensorFlowNodeConverterMap(); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model, + converter); } }; @@ -117,9 +125,10 @@ TEST_P(ShapeImportTest, ShapeElementIsNegative) { NodeDef node; BuildConstNode({1, -2, 10}, GetParam(), 0, &node); auto status = ImportNode(node); - EXPECT_EQ(status.error_message(), - "Tensor shape should not include negative values (while processing " - "node 'Node1')"); + EXPECT_EQ( + status.error_message(), + "Tensor shape should not include negative values\n\t (while processing " + "node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeElementIsNegative, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -129,7 +138,7 @@ TEST_P(ShapeImportTest, ShapeElementTooLarge) { BuildConstNode({3000000000}, GetParam(), 0, &node); auto status = ImportNode(node); EXPECT_EQ(status.error_message(), - "Shape element overflows (while processing node 'Node1')"); + "Shape element overflows\n\t (while processing node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeElementTooLarge, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -139,7 +148,7 @@ TEST_P(ShapeImportTest, ShapeTooLarge) { BuildConstNode({1000000, 2000000, 2000000, 2000000}, GetParam(), 0, &node); auto status = ImportNode(node); EXPECT_EQ(status.error_message(), - "Tensor shape is too large (while processing node 'Node1')"); + "Tensor shape is too large\n\t (while processing node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeTooLarge, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -150,8 +159,9 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) { auto status = ImportNode(node); EXPECT_THAT(status.error_message(), ::testing::MatchesRegex( - "Neither input_content nor .*_val have the right dimensions " - "for this .* tensor .while processing node 'Node1'.")); + "Neither input_content .0. nor .*_val .0. have the right " + "dimensions .8. for this .* tensor\n\t .while processing " + "node 'Node1'.")); } INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest, ::testing::ValuesIn(TestTypes())); diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index d878ac54e4d819efc1b0951acbbab23b3387eac5..3a1d243f87b20651aafe3b31cb14804e94dee72b 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ #define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#include #include #include #include @@ -32,7 +33,7 @@ namespace toco { using tflite::QuantizationParams; -enum class OperatorType { +enum class OperatorType : uint8 { kNone, // General-purpose neural network operators. kAdd, @@ -96,38 +97,38 @@ enum class OperatorType { // Special operators used for importing TensorFlow nodes. // The general intent is to have some graph transformation either // drop them or rewrite them as general-purpose operators. - kTensorFlowAll, - kTensorFlowAssert, - kTensorFlowConcat, - kTensorFlowConcatV2, - kTensorFlowGreater, - kTensorFlowGreaterEqual, - kTensorFlowIdentity, - kTensorFlowLess, - kTensorFlowLessEqual, - kTensorFlowMax, - kTensorFlowMaximum, - kTensorFlowMin, - kTensorFlowMinimum, - kTensorFlowMatMul, - kTensorFlowMerge, + kAll, + kAssert, + kConcat, + kConcatV2, + kGreater, + kGreaterEqual, + kIdentity, + kLess, + kLessEqual, + kMax, // Reduction Max + kMaximum, // Element-wise Maximum + kMin, // Reduction Min + kMinimum, // Element-wise Minimum + kMatMul, + kMerge, kNeg, - kTensorFlowReshape, - kTensorFlowRsqrt, - kTensorFlowShape, - kTensorFlowSplit, - kTensorFlowSqrt, - kTensorFlowSquare, - kTensorFlowSum, - kTensorFlowSwitch, - kTensorFlowTile, + kReshape, + kRsqrt, + kShape, + kSplit, + kSqrt, + kSquare, + kSum, + kSwitch, + kTile, kTranspose, kTopK_V2, kDynamicPartition, kDynamicStitch, // An unsupported TF operation. It's only needed to be able to represent TF // graph internally and is expected to be dropped by graph transformations. - kTensorFlowUnsupported, + kUnsupported, // Finally, TensorFlow uses different conventions for axes ordering, // see AxesOrder, and this cannot always be resolved at the time of importing // nodes, as TensorFlow parameters may be constant-expression subgraphs @@ -135,6 +136,10 @@ enum class OperatorType { // special nodes in the graph to shuffle axes. kReorderAxes, kSelect, + kSparseToDense, + kEqual, + kNotEqual, + kPow, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -152,25 +157,27 @@ enum class AxesOrder { k1HWO, // Our standard for DepthwiseConv weights kHWIM, // TensorFlow DepthwiseConv weights kNHWC, // TensorFlow activations + kHWOI, // TensorFlow back-prop conv weights }; // The type of the scalars in an array. // Note that the type does not by itself tell whether the values in the array -// are real (are literally interpreted as real numbers) or quantized (only -// acquire a meaning as real numbers in conjunction with QuantizationParams). +// are non-quantized (can be accessed directly) or quantized (must be +// interpreted in conjunction with QuantizationParams). // // In practice though: -// float values are always real +// float values are never quantized // uint8 values are always quantized -// int32 values are either real or quantized (depending on whether +// int32 values are sometimes quantized (depending on whether // QuantizationParams are present). -// other types are unused at the moment. +// complex values are never quantized +// other types are never quantized at the moment. // // kNone means that we don't know the data type yet, or that we don't care // because we'll be dropping the array anyway (e.g. some exotic array types // may be involved only in debug-only subgraphs that we may not be interested // in actually supporting). -enum class ArrayDataType { +enum class ArrayDataType : uint8 { kNone, // 0 kBool, kFloat, @@ -182,7 +189,8 @@ enum class ArrayDataType { kUint32, kInt64, kUint64, // 10 - kString + kString, + kComplex64, }; // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type @@ -236,6 +244,10 @@ template <> struct DataTypeImpl { typedef string Type; }; +template <> +struct DataTypeImpl { + typedef std::complex Type; +}; template using DataType = typename DataTypeImpl::Type; @@ -429,7 +441,8 @@ struct SpaceToDepthOperator : Operator { // input activations as a matrix, followed by a MatMul node. struct FullyConnectedOperator : Operator { FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {} - bool experimental_shuffled_weights = false; + FullyConnectedWeightsFormat weights_format = + FullyConnectedWeightsFormat::kDefault; }; // Dequantization operator, converting a quantized array of integers with @@ -526,7 +539,15 @@ struct LstmCellOperator : Operator { ACTIV_TEMP = 3, NUM_OUTPUTS = 4 }; - LstmCellOperator() : Operator(OperatorType::kLstmCell) {} + enum KernelType { + KERNEL_BASIC = 0, + KERNEL_FULL = 1, + }; + + LstmCellOperator() + : Operator(OperatorType::kLstmCell), kernel_type(KERNEL_BASIC) {} + + KernelType kernel_type; }; // Element-wise multiplication operator. @@ -789,7 +810,7 @@ struct DivOperator : Operator { // // TensorFlow equivalent: Identity struct TensorFlowIdentityOperator : Operator { - TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {} + TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {} }; // Batch matrix multiplication operator. This comes from the (deprecated) @@ -815,7 +836,9 @@ struct BatchMatMulOperator : Operator { // // TensorFlow equivalent: MatMul struct TensorFlowMatMulOperator : Operator { - TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {} + TensorFlowMatMulOperator() : Operator(OperatorType::kMatMul) {} + bool transpose_a = false; + bool transpose_b = false; }; // Padding operator. Pads a tensor with zeros. @@ -949,7 +972,7 @@ struct StridedSliceOperator : Operator { // TensorFlow equivalent: Reshape --- except that we only support a special case // here, where the output shape is a matrix (2D) shape. struct TensorFlowReshapeOperator : Operator { - TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {} + TensorFlowReshapeOperator() : Operator(OperatorType::kReshape) {} std::vector shape; }; @@ -1119,7 +1142,7 @@ struct SelectOperator : Operator { // // TensorFlow equivalent: Rsqrt struct TensorFlowRsqrtOperator : Operator { - TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {} + TensorFlowRsqrtOperator() : Operator(OperatorType::kRsqrt) {} }; // Stacks a list of rank-R tensors into one rank-(R+1) tensor. @@ -1145,10 +1168,10 @@ struct StackOperator : Operator { // This operation outputs a 1-D integer tensor representing the shape of // the input. // -// TensorFlow equivalent: Shape. We currently assume that the output is int32 -// and not int64. The output type could be stored herein. +// TensorFlow equivalent: Shape. struct TensorFlowShapeOperator : Operator { - TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {} + TensorFlowShapeOperator() : Operator(OperatorType::kShape) {} + ArrayDataType output_data_type = ArrayDataType::kInt32; }; // Element-wise square-root (x^0.5) operator. @@ -1158,7 +1181,7 @@ struct TensorFlowShapeOperator : Operator { // // TensorFlow equivalent: Sqrt struct TensorFlowSqrtOperator : Operator { - TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {} + TensorFlowSqrtOperator() : Operator(OperatorType::kSqrt) {} }; // Element-wise square (x*x) operator. @@ -1168,7 +1191,7 @@ struct TensorFlowSqrtOperator : Operator { // // TensorFlow equivalent: Square struct TensorFlowSquareOperator : Operator { - TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {} + TensorFlowSquareOperator() : Operator(OperatorType::kSquare) {} }; // Transposes a tensor. @@ -1196,24 +1219,24 @@ struct SubOperator : Operator { SubOperator() : Operator(OperatorType::kSub) {} }; -// Global sum reduction: computes the sum of all of entries in the input array. -// Thus the output is "0-dimensional": it consists of a single scalar value. +// Sum reduction: computes the sum of all of entries across the axes. // // Inputs: // inputs[0]: required: the input array // -// TensorFlow equivalent: Sum --- except that we only support the special case -// of global reduction across all dimensions. +// TensorFlow equivalent: Sum struct TensorFlowSumOperator : Operator { - TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {} + TensorFlowSumOperator() : Operator(OperatorType::kSum) {} bool keep_dims = false; }; // TensorFlow Tile equivalent. Refer to TensorFlow documentation for details. -// Not fully supported, just a placeholder to handle TensorFlow graphs and -// support graph transformations to other operator types by matching sub-graphs. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: int array with length of rank(input[0]) struct TensorFlowTileOperator : Operator { - TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {} + TensorFlowTileOperator() : Operator(OperatorType::kTile) {} }; // TensorFlow Slice equivalent. Refer to TensorFlow documentation for details. @@ -1228,7 +1251,7 @@ struct SliceOperator : Operator { // Not fully supported, just a placeholder to handle TensorFlow graphs and // support graph transformations to other operator types by matching sub-graphs. struct TensorFlowSplitOperator : Operator { - TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {} + TensorFlowSplitOperator() : Operator(OperatorType::kSplit) {} int num_split = 0; }; @@ -1239,7 +1262,7 @@ struct TensorFlowSplitOperator : Operator { // dimension then we can change this op into a DepthConcatenation op. // Otherwise, we hope for some other graph transformation to drop this node. struct TensorFlowConcatOperator : Operator { - TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {} + TensorFlowConcatOperator() : Operator(OperatorType::kConcat) {} }; // TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for @@ -1250,7 +1273,7 @@ struct TensorFlowConcatOperator : Operator { // dimension then we can change this op into a DepthConcatenation op. // Otherwise, we hope for some other graph transformation to drop this node. struct TensorFlowConcatV2Operator : Operator { - TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {} + TensorFlowConcatV2Operator() : Operator(OperatorType::kConcatV2) {} }; // TensorFlow Merge equivalent. Refer to TensorFlow documentation for details. @@ -1266,7 +1289,7 @@ struct TensorFlowConcatV2Operator : Operator { // control flow that can be resolved at tooling time (independently of input // activations). struct TensorFlowMergeOperator : Operator { - TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {} + TensorFlowMergeOperator() : Operator(OperatorType::kMerge) {} }; // TensorFlow Switch equivalent. Refer to TensorFlow documentation for details. @@ -1289,7 +1312,7 @@ struct TensorFlowMergeOperator : Operator { // control flow that can be resolved at tooling time (independently of input // activations). struct TensorFlowSwitchOperator : Operator { - TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {} + TensorFlowSwitchOperator() : Operator(OperatorType::kSwitch) {} }; // TensorFlow All equivalent. Refer to TensorFlow documentation for details. @@ -1298,7 +1321,7 @@ struct TensorFlowSwitchOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowAllOperator : Operator { - TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {} + TensorFlowAllOperator() : Operator(OperatorType::kAll) {} }; // TensorFlow Assert equivalent. Refer to TensorFlow documentation for details. @@ -1306,7 +1329,7 @@ struct TensorFlowAllOperator : Operator { // support graph transformations to other operator types by matching sub-graphs. // Typically, we just drop Assert nodes. struct TensorFlowAssertOperator : Operator { - TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {} + TensorFlowAssertOperator() : Operator(OperatorType::kAssert) {} }; // TensorFlow Less equivalent. Refer to TensorFlow documentation for details. @@ -1315,7 +1338,7 @@ struct TensorFlowAssertOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowLessOperator : Operator { - TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {} + TensorFlowLessOperator() : Operator(OperatorType::kLess) {} }; // TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for @@ -1325,8 +1348,7 @@ struct TensorFlowLessOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowLessEqualOperator : Operator { - TensorFlowLessEqualOperator() - : Operator(OperatorType::kTensorFlowLessEqual) {} + TensorFlowLessEqualOperator() : Operator(OperatorType::kLessEqual) {} }; // TensorFlow Less equivalent. Refer to TensorFlow documentation for details. @@ -1335,7 +1357,7 @@ struct TensorFlowLessEqualOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowGreaterOperator : Operator { - TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {} + TensorFlowGreaterOperator() : Operator(OperatorType::kGreater) {} }; // TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for @@ -1345,8 +1367,23 @@ struct TensorFlowGreaterOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowGreaterEqualOperator : Operator { - TensorFlowGreaterEqualOperator() - : Operator(OperatorType::kTensorFlowGreaterEqual) {} + TensorFlowGreaterEqualOperator() : Operator(OperatorType::kGreaterEqual) {} +}; + +// TensorFlow Equal equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowEqualOperator : Operator { + TensorFlowEqualOperator() : Operator(OperatorType::kEqual) {} +}; + +// TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for +// details. +struct TensorFlowNotEqualOperator : Operator { + TensorFlowNotEqualOperator() : Operator(OperatorType::kNotEqual) {} }; // Global max reduction: computes the max of all of entries in the input array. @@ -1358,7 +1395,7 @@ struct TensorFlowGreaterEqualOperator : Operator { // TensorFlow equivalent: Max --- except that we only support the special case // of global reduction across all dimensions. struct TensorFlowMaxOperator : Operator { - TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {} + TensorFlowMaxOperator() : Operator(OperatorType::kMax) {} bool keep_dims = false; }; @@ -1371,7 +1408,7 @@ struct TensorFlowMaxOperator : Operator { // TensorFlow equivalent: Min --- except that we only support the special case // of global reduction across all dimensions. struct TensorFlowMinOperator : Operator { - TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {} + TensorFlowMinOperator() : Operator(OperatorType::kMin) {} bool keep_dims = false; }; @@ -1384,7 +1421,7 @@ struct TensorFlowMinOperator : Operator { // // TensorFlow equivalent: Maximum struct TensorFlowMaximumOperator : Operator { - TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {} + TensorFlowMaximumOperator() : Operator(OperatorType::kMaximum) {} }; // Element-wise minimum operator. Currently it only supports scalar as @@ -1396,14 +1433,13 @@ struct TensorFlowMaximumOperator : Operator { // // TensorFlow equivalent: Minimum struct TensorFlowMinimumOperator : Operator { - TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {} + TensorFlowMinimumOperator() : Operator(OperatorType::kMinimum) {} }; // General TF operation, unsupported by tf.mini. Expected to be dropped by // graph transformations. struct TensorFlowUnsupportedOperator : Operator { - TensorFlowUnsupportedOperator() - : Operator(OperatorType::kTensorFlowUnsupported) {} + TensorFlowUnsupportedOperator() : Operator(OperatorType::kUnsupported) {} // The original TF operation type. Used for diagnostic purposes. string tensorflow_op; @@ -1598,13 +1634,37 @@ struct DynamicStitchOperator : Operator { int num_partitions; }; +// SparseToDense operator: +// +// Inputs: +// Inputs[0]: required: sparse_indices. +// Inputs[1]: required: output_shape. +// Inputs[2]: required: sparse_values. +// +// TensorFlow equivalent: SparseToDense. +struct SparseToDenseOperator : Operator { + SparseToDenseOperator() : Operator(OperatorType::kSparseToDense) {} + bool validate_indices; +}; + +// Pow operator: +// +// Inputs: +// Inputs[0]: required: A tensor. +// Inputs[1]: required: A tensor. +// +// TensorFlow equivalent: Pow. +struct PowOperator : Operator { + PowOperator() : Operator(OperatorType::kPow) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are // offsets from the start of the workspace buffer, expressed in bytes. struct Alloc { - int start = 0; - int end = 0; + int64 start = 0; + int64 end = 0; }; inline bool operator<(const Alloc& a, const Alloc& b) { diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 7bbeab7c9d1e42d28f221f1a1134d9d05fe6ab51..06072d1fcb0612ed8193b3a0be1317923fe95bcc 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -48,7 +48,7 @@ bool ParseModelFlagsFromCommandLineFlags( "that information from the input file."), Flag("input_arrays", parsed_flags.input_arrays.bind(), parsed_flags.input_arrays.default_value(), - "Names of the output arrays, comma-separated. If not specified, " + "Names of the input arrays, comma-separated. If not specified, " "will try to read that information from the input file."), Flag("output_array", parsed_flags.output_array.bind(), parsed_flags.output_array.default_value(), @@ -74,16 +74,16 @@ bool ParseModelFlagsFromCommandLineFlags( "height, input array width, input array depth."), Flag("batch_size", parsed_flags.batch_size.bind(), parsed_flags.batch_size.default_value(), - "Batch size for the model. Replaces the first dimension of an " - "input size array if undefined. Use only with SavedModels when " - "--input_shapes flag is not specified. Always use --input_shapes " - "flag with frozen graphs."), + "Deprecated. Batch size for the model. Replaces the first dimension " + "of an input size array if undefined. Use only with SavedModels " + "when --input_shapes flag is not specified. Always use " + "--input_shapes flag with frozen graphs."), Flag("input_data_type", parsed_flags.input_data_type.bind(), parsed_flags.input_data_type.default_value(), "Deprecated: use --input_data_types instead. Input array type, if " "not already provided in the graph. " "Typically needs to be specified when passing arbitrary arrays " - "to --input_array."), + "to --input_arrays."), Flag("input_data_types", parsed_flags.input_data_types.bind(), parsed_flags.input_data_types.default_value(), "Input arrays types, comma-separated, if not already provided in " @@ -124,14 +124,6 @@ bool ParseModelFlagsFromCommandLineFlags( parsed_flags.model_checks.default_value(), "A list of model checks to be applied to verify the form of the " "model. Applied after the graph transformations after import."), - Flag("graphviz_first_array", parsed_flags.graphviz_first_array.bind(), - parsed_flags.graphviz_first_array.default_value(), - "If set, defines the start of the sub-graph to be dumped to " - "GraphViz."), - Flag( - "graphviz_last_array", parsed_flags.graphviz_last_array.bind(), - parsed_flags.graphviz_last_array.default_value(), - "If set, defines the end of the sub-graph to be dumped to GraphViz."), Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(), parsed_flags.dump_graphviz.default_value(), "Dump graphviz during LogDump call. If string is non-empty then " @@ -180,8 +172,6 @@ bool ParseModelFlagsFromCommandLineFlags( if (!tensorflow::Flags::Parse(argc, argv, flags)) return false; } auto& dump_options = *GraphVizDumpOptions::singleton(); - dump_options.graphviz_first_array = parsed_flags.graphviz_first_array.value(); - dump_options.graphviz_last_array = parsed_flags.graphviz_last_array.value(); dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value(); dump_options.dump_graphviz = parsed_flags.dump_graphviz.value(); diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD index 6c4f8e12cdd5b3222997c4a2b0ac243cc74324e0..93fe756a55d378fa205ff88be5e18aff586e5dca 100644 --- a/tensorflow/contrib/lite/toco/python/BUILD +++ b/tensorflow/contrib/lite/toco/python/BUILD @@ -12,10 +12,11 @@ cc_library( deps = [ "//tensorflow/contrib/lite/toco:model_flags_proto_cc", "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_graphviz_dump_options", "//tensorflow/contrib/lite/toco:toco_port", "//tensorflow/contrib/lite/toco:toco_tooling", "//tensorflow/core:lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -26,7 +27,7 @@ tf_py_wrap_cc( ":toco_python_api", "//tensorflow/contrib/lite/toco:model_flags_proto_cc", "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", - "//util/python:python_headers", + "//third_party/python_runtime:headers", "@com_google_absl//absl/strings", ], ) @@ -41,12 +42,6 @@ py_binary( ], ) -py_binary( - name = "toco_wrapper", - srcs = ["toco_wrapper.py"], - srcs_version = "PY2AND3", -) - tf_py_test( name = "toco_from_protos_test", srcs = ["toco_from_protos_test.py"], diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc index 5b1db852b4f8e89c1a591cfe18a0ab0aa2db04c9..d93e104038741e6e59608f04115854d611f1f9ae 100644 --- a/tensorflow/contrib/lite/toco/python/toco_python_api.cc +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/python/toco_python_api.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" @@ -62,7 +63,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error); if (error) return nullptr; - // Use toco to produce new outputs + // Use TOCO to produce new outputs. toco::ModelFlags model_flags; if (!model_flags.ParseFromString(model_flags_proto_txt)) { LOG(FATAL) << "Model proto failed to parse." << std::endl; @@ -71,6 +72,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, if (!toco_flags.ParseFromString(toco_flags_proto_txt)) { LOG(FATAL) << "Toco proto failed to parse." << std::endl; } + + auto& dump_options = *GraphVizDumpOptions::singleton(); + if (toco_flags.has_dump_graphviz_dir()) { + dump_options.dump_graphviz = toco_flags.dump_graphviz_dir(); + } + if (toco_flags.has_dump_graphviz_include_video()) { + dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video(); + } + + // Convert model. std::unique_ptr model = toco::Import(toco_flags, model_flags, input_contents_txt); toco::Transform(toco_flags, model.get()); diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h index 9af38e937c29804f950ea53ee86b70e3ccb02360..7e8ad9c1dafa68dd91e4a0eb3bfb742207878c59 100644 --- a/tensorflow/contrib/lite/toco/python/toco_python_api.h +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h @@ -15,8 +15,8 @@ limitations under the License. #ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ #define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ -#include #include +#include namespace toco { diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/lite/toco/python/toco_wrapper.py deleted file mode 100644 index 6d6b500d7eccd353f566a4bad76df35e0e849d95..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/python/toco_wrapper.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Wrapper for runninmg toco binary embedded in pip site-package. - -NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/. -It can only install Python "console-scripts." This will work as a console -script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS). -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - - -def main(): - # Pip installs the binary in aux-bin off of main site-package install. - # Just find it and exec, passing all arguments in the process. - # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary. - print("""TOCO from pip install is currently not working on command line. -Please use the python TOCO API or use -bazel run tensorflow/contrib/lite:toco -- from a TensorFlow source dir. -""") - sys.exit(1) - # TODO(aselle): Replace this when we find a way to run toco without - # blowing up executable size. - # binary = os.path.join(tf.__path__[0], 'aux-bin/toco') - # os.execvp(binary, sys.argv) diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/contrib/lite/toco/runtime/types.h index f5de5a5781a5304634642680e6a3cef60e7b844b..207f2c1706ef4cc12572e381c38f61a504ece232 100644 --- a/tensorflow/contrib/lite/toco/runtime/types.h +++ b/tensorflow/contrib/lite/toco/runtime/types.h @@ -24,6 +24,7 @@ namespace toco { // TODO(ahentz): These are just stopgaps for now, untils we move all // the code over to tflite. using tflite::Dims; +using tflite::FullyConnectedWeightsFormat; using tflite::FusedActivationFunctionType; using tflite::RequiredBufferSizeForDims; diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index e1025c66642d2860c5916bf7625f1c0403c9901c..a02f90988b2863900b6a735fd69aa1975a762338 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -24,6 +24,7 @@ cc_library( deps = [ ":types", "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 335b496dccdbdb7e342515868e1d7195c98f0351..19722468079a32b76f6952db6ca818da470a03ac 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -45,14 +45,20 @@ using ::tflite::Tensor; namespace { -details::OperatorKey GetOperatorKey(const ::toco::Operator& op) { +details::OperatorKey GetOperatorKey( + const ::toco::Operator& op, + const std::map>& ops_by_type) { string custom_code; - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { const TensorFlowUnsupportedOperator& unsupported_op = static_cast(op); custom_code = unsupported_op.tensorflow_op; } - return details::OperatorKey(op.type, custom_code); + int version = 1; + if (ops_by_type.count(op.type) != 0) { + version = ops_by_type.at(op.type)->GetVersion(op); + } + return details::OperatorKey(op.type, custom_code, version); } } // Anonymous namespace. @@ -74,11 +80,13 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { } } -void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map) { +void LoadOperatorsMap( + const Model& model, OperatorsMap* operators_map, + const std::map>& ops_by_type) { // First find a list of unique operator types. std::set keys; for (const auto& op : model.operators) { - keys.insert(GetOperatorKey(*op)); + keys.insert(GetOperatorKey(*op, ops_by_type)); } // Now assign indices to them and fill in the map. int index = 0; @@ -91,7 +99,8 @@ void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map) { Offset>> ExportTensors( const Model& model, const details::TensorsMap& tensors_map, - FlatBufferBuilder* builder, std::vector* buffers_to_write) { + FlatBufferBuilder* builder, std::vector* buffers_to_write, + const std::set& variable_tensor_indices) { // In the end we will need to produce a vector sorted by the indices of the // tensors in the tensors_map. std::map> ordered_tensors; @@ -131,9 +140,11 @@ Offset>> ExportTensors( scale, zero_point); int index = tensors_map.at(tensor_name); + bool is_variable = + variable_tensor_indices.find(index) != variable_tensor_indices.end(); ordered_tensors[index] = CreateTensor(*builder, builder->CreateVector(shape), type, buffer_index, - builder->CreateString(tensor_name), q_param); + builder->CreateString(tensor_name), q_param, is_variable); } std::vector> tensor_vector; @@ -185,8 +196,9 @@ Offset>> ExportOperatorCodes( std::map> ordered_opcodes; for (const auto& op : model.operators) { - const details::OperatorKey operator_key = GetOperatorKey(*op); + const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type); int op_index = operators_map.at(operator_key); + int op_version = operator_key.version; string name = HelpfulOperatorTypeName(*op); bool is_builtin = false; @@ -197,9 +209,9 @@ Offset>> ExportOperatorCodes( if (is_builtin) { ordered_opcodes[op_index] = - CreateOperatorCode(*builder, builtin_ops[name], 0); + CreateOperatorCode(*builder, builtin_ops[name], 0, op_version); } else { - // This could be a kTensorFlowUnsupported, in which case we should be + // This could be a kUnsupported, in which case we should be // able to retrieve the original Tensorflow name from the OperatorKey, or // this could be a proper TOCO operator that is completely unknown to TF // Lite. @@ -211,8 +223,9 @@ Offset>> ExportOperatorCodes( if (error_summary) { error_summary->insert(name); } - ordered_opcodes[op_index] = CreateOperatorCode( - *builder, BuiltinOperator_CUSTOM, builder->CreateString(name)); + ordered_opcodes[op_index] = + CreateOperatorCode(*builder, BuiltinOperator_CUSTOM, + builder->CreateString(name), op_version); } } @@ -229,7 +242,10 @@ Offset>> ExportOperators( const Model& model, const std::map>& ops_by_type, const details::OperatorsMap& operators_map, - const details::TensorsMap& tensors_map, FlatBufferBuilder* builder) { + const details::TensorsMap& tensors_map, FlatBufferBuilder* builder, + std::set* variable_tensor_indices) { + variable_tensor_indices->clear(); + // The operators are in execution order, so we just follow tf.mini order. std::vector> op_vector; for (const auto& op : model.operators) { @@ -244,20 +260,38 @@ Offset>> ExportOperators( outputs.push_back(tensors_map.at(output)); } - int op_index = operators_map.at(GetOperatorKey(*op)); + int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type)); - // This is a custom op unless we can find it in ops_by_type, and even then - // it could be a custom op (such as kTensorFlowUnsupported). + auto tflite_op_it = ops_by_type.find(op->type); + BaseOperator* tflite_op = tflite_op_it == ops_by_type.end() + ? nullptr + : tflite_op_it->second.get(); + // This is a custom op unless we can find it in ops_by_type, and even then + // it could be a custom op (such as kUnsupported). auto options = Options::Custom(0); - if (ops_by_type.count(op->type) != 0) { - options = ops_by_type.at(op->type)->Serialize(*op, builder); + + std::vector mutating_input_variables; + if (tflite_op) { + options = tflite_op->Serialize(*op, builder); + mutating_input_variables = tflite_op->GetMutatingInputVariables(*op); + + if (!mutating_input_variables.empty()) { + for (int i = 0; i < op->inputs.size(); ++i) { + if (!mutating_input_variables[i]) { + continue; + } + int32_t variable_tensor_index = tensors_map.at(op->inputs[i]); + variable_tensor_indices->insert(variable_tensor_index); + } + } } // The only supported CustomOptionFormat is FLEXBUFFERS now. op_vector.push_back(CreateOperator( *builder, op_index, builder->CreateVector(inputs), builder->CreateVector(outputs), options.type, options.builtin, - options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS)); + options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS, + builder->CreateVector(mutating_input_variables))); } return builder->CreateVector(op_vector); @@ -279,28 +313,31 @@ Offset>> ExportBuffers( void Export(const Model& model, bool allow_custom_ops, string* output_file_contents) { - flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); - const auto ops_by_type = BuildOperatorByTypeMap(); + Export(model, allow_custom_ops, output_file_contents, ops_by_type); +} + +void Export( + const Model& model, bool allow_custom_ops, string* output_file_contents, + const std::map>& ops_by_type) { + flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); details::TensorsMap tensors_map; details::LoadTensorsMap(model, &tensors_map); details::OperatorsMap operators_map; - details::LoadOperatorsMap(model, &operators_map); + details::LoadOperatorsMap(model, &operators_map, ops_by_type); std::vector buffers_to_write; Array empty_array; buffers_to_write.push_back(&empty_array); - auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write); - auto inputs = ExportInputTensors(model, tensors_map, &builder); - auto outputs = ExportOutputTensors(model, tensors_map, &builder); - std::set error_summary; auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, &builder, &error_summary); + const string fake_quant_operation_name = "FAKE_QUANT"; + if (error_summary.count(fake_quant_operation_name) != 0) { LOG(ERROR) << fake_quant_operation_name @@ -312,19 +349,43 @@ void Export(const Model& model, bool allow_custom_ops, error_summary.erase(fake_quant_operation_name); } if (!allow_custom_ops && !error_summary.empty()) { - LOG(QFATAL) << "Some of the operators in the model are not supported by " - "the standard TensorFlow Lite runtime. If you have a custom " - "implementation for them you can disable this error with " - "--allow_custom_ops. Here is a list of operators for which " - "you will need custom implementations: " - << absl::StrJoin(error_summary, ", ") << "."; + // Remove ExpandDims and ReorderAxes from unimplemented list unless they + // compose the list. Both ops are removed during graph transformations. + // However, if an op is unimplemented earlier in the model, the graph + // transformation is unable to run because the output shape is not defined. + // This causes unnecessary confusion during model conversion time. + std::set error_summary_final; + for (const auto& op_type : error_summary) { + if (op_type != "ReorderAxes" && op_type != "ExpandDims") { + error_summary_final.insert(op_type); + } + } + if (error_summary_final.empty()) { + error_summary_final = error_summary; + } + + LOG(QFATAL) + << "Some of the operators in the model are not supported by " + "the standard TensorFlow Lite runtime. If you have a custom " + "implementation for them you can disable this error with " + "--allow_custom_ops, or by setting allow_custom_ops=True " + "when calling tf.contrib.lite.toco_convert(). Here is a list " + "of operators for which you will need custom implementations: " + << absl::StrJoin(error_summary_final, ", ") << "."; } - auto ops = - ExportOperators(model, ops_by_type, operators_map, tensors_map, &builder); + std::set variable_tensor_indices; + auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map, + &builder, &variable_tensor_indices); + + auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write, + variable_tensor_indices); + auto inputs = ExportInputTensors(model, tensors_map, &builder); + auto outputs = ExportOutputTensors(model, tensors_map, &builder); // TODO(aselle): add support to toco for multiple subgraphs. - auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops); + auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops, + /* name */ 0); std::vector> subgraphs = {subgraph}; auto buffers = ExportBuffers(model, buffers_to_write, &builder); diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 8c79cb820015e16847ce48c171e8f6e41f60c319..58ea5c725c378827aac79f2a5a2cdca59ccc0162 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -16,6 +16,8 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ #include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/util.h" namespace toco { @@ -25,39 +27,56 @@ namespace tflite { // result in the given string. void Export(const Model& model, bool allow_custom_ops, string* output_file_contents); + // This if backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. inline void Export(const Model& model, string* output_file_contents) { Export(model, true, output_file_contents); } +// Export API with custom TFLite operator mapping. +void Export( + const Model& model, bool allow_custom_ops, string* output_file_contents, + const std::map>& ops_by_type); + namespace details { // A maps from tensor name to its final position in the TF Lite buffer. using TensorsMap = std::unordered_map; // A key to identify an operator. -// Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to +// Only when `type` is `kUnsupported`, `custom_code` is filled to // identify which operation is used. struct OperatorKey { - OperatorKey(OperatorType type, const std::string& custom_code) - : type(type), custom_code(custom_code) {} + OperatorKey(OperatorType type, const std::string& custom_code, int version) + : type(type), custom_code(custom_code), version(version) {} const OperatorType type; const std::string custom_code; + const int version; bool operator<(const OperatorKey& other) const { if (type < other.type) return true; - if (type > other.type) return false; - return custom_code < other.custom_code; + else if (type > other.type) + return false; + else if (custom_code < other.custom_code) + return true; + else if (custom_code > other.custom_code) + return false; + else + return version < other.version; } bool operator==(const OperatorKey& other) const { - return type == other.type && custom_code == other.custom_code; + return type == other.type && custom_code == other.custom_code && + version == other.version; } struct Hash { - std::size_t operator()(const OperatorKey& key) const { - return std::hash()(static_cast(key.type)) ^ - std::hash()(key.custom_code); + size_t operator()(const OperatorKey& key) const { + return ::tflite::CombineHashes( + {std::hash()(static_cast(key.type)), + std::hash()(key.custom_code), + std::hash()(key.version)}); } }; }; @@ -66,11 +85,12 @@ struct OperatorKey { using OperatorsMap = std::unordered_map; void LoadTensorsMap(const Model& model, TensorsMap* tensors_map); -void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map); +void LoadOperatorsMap( + const Model& model, OperatorsMap* operators_map, + const std::map>& ops_by_type); } // namespace details } // namespace tflite - } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index 6754372330797ae30230af26a3b478c24ad44005..d1fdbcb8e9131e1d65fa32ca0395bbc17b2014e7 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -17,6 +17,9 @@ limitations under the License. #include #include #include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/toco/tflite/types.h" namespace toco { namespace tflite { @@ -65,12 +68,13 @@ TEST_F(ExportTest, LoadOperatorsMap) { BuildTestModel(); details::OperatorsMap operators; - details::LoadOperatorsMap(input_model_, &operators); - EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "")]); - EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "")]); - EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "")]); - EXPECT_EQ(3, operators[details::OperatorKey( - OperatorType::kTensorFlowUnsupported, "MyCrazyOp")]); + const auto ops_by_type = BuildOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]); + EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]); + EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]); + EXPECT_EQ(3, operators[details::OperatorKey(OperatorType::kUnsupported, + "MyCrazyOp", 1)]); } TEST_F(ExportTest, Export) { @@ -104,6 +108,160 @@ TEST_F(ExportTest, Export) { EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2)); } +// This test is based on a hypothetical scenario that dilation is supported +// only in Conv version 2. So Toco populates version=1 when dialation +// parameters are all 1, and version=2 otehrwise. +class FakeConvolutionOperator + : public BuiltinOperator { + public: + FakeConvolutionOperator() + : BuiltinOperator(::tflite::BuiltinOperator_CONV_2D, + OperatorType::kConv) {} + + // Returning the op version according to the op parameters. + int GetVersion(const Operator& op) const override { + const TocoOperator& conv_op = static_cast(op); + if (conv_op.dilation_width_factor != 1 || + conv_op.dilation_height_factor != 1) { + // Version 2 if dilation is used. + return 2; + } + return 1; + } + + // Note: The read / write code doesn't need to be changed if we stick with + // the restrictions: + // * Only adding parameters at the bottom of the Flatbuffer tables. + // * When the default value of parameters are used, the op works consistently + // with the previous version. + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width, + op.stride_height, activation_function, + op.dilation_width_factor, + op.dilation_height_factor); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->dilation_width_factor = options.dilation_w_factor(); + op->dilation_height_factor = options.dilation_h_factor(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class VersionedOpExportTest : public ::testing::Test { + protected: + void SetUp() override { + input_model_.GetOrCreateArray("input"); + input_model_.GetOrCreateArray("filter"); + input_model_.GetOrCreateArray("output"); + } + void AddConvOp(bool use_dialation) { + { + auto* op = new ConvOperator; + op->inputs.push_back("input"); + op->inputs.push_back("filter"); + op->inputs.push_back("output"); + + op->padding.type = PaddingType::kSame; + op->stride_width = 1; + op->stride_height = 1; + if (use_dialation) { + op->dilation_width_factor = 2; + op->dilation_height_factor = 2; + } else { + op->dilation_width_factor = 1; + op->dilation_height_factor = 1; + } + input_model_.operators.emplace_back(op); + } + } + + std::map> + BuildFakeOperatorByTypeMap() { + std::map> result; + result[OperatorType::kConv] = + std::unique_ptr(new FakeConvolutionOperator); + return result; + } + + Model input_model_; +}; + +TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) { + AddConvOp(false); + + details::OperatorsMap operators; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + + EXPECT_EQ(1, operators.size()); + EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); +} + +TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) { + AddConvOp(true); + + details::OperatorsMap operators; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + + EXPECT_EQ(1, operators.size()); + EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); +} + +TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) { + AddConvOp(false); + AddConvOp(true); + + details::OperatorsMap operators; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + + EXPECT_EQ(2, operators.size()); + EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); + EXPECT_EQ(1, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); +} + +TEST_F(VersionedOpExportTest, Export) { + AddConvOp(false); + AddConvOp(true); + + string result; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + Export(input_model_, true, &result, ops_by_type); + + auto* model = ::tflite::GetModel(result.data()); + auto operator_codes = model->operator_codes(); + + // Verify that 2 operator codes are populdated. Both are CONV_2D but with + // different versions. + EXPECT_EQ(2, operator_codes->size()); + EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, + (*operator_codes)[0]->builtin_code()); + EXPECT_EQ(1, (*operator_codes)[0]->version()); + EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, + (*operator_codes)[1]->builtin_code()); + EXPECT_EQ(2, (*operator_codes)[1]->version()); + + // Verify that the 2 operators points to the correct indices of the operation + // codes. + auto operators = (*model->subgraphs())[0]->operators(); + EXPECT_EQ(2, operators->size()); + EXPECT_EQ(0, (*operators)[0]->opcode_index()); + EXPECT_EQ(1, (*operators)[1]->opcode_index()); +} + // TODO(ahentz): tests for tensors, inputs, outpus, opcodes and operators. } // namespace diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index c0e7ab2ef57ed8edf1b7cda08c64f6ae66172af3..1dd4915b31413e5afb04b45ee7c4893a2eded66d 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -113,15 +113,35 @@ void ImportOperators( << operators_table.size(); } string opname = operators_table.at(index); + + // Find and use the appropriate operator deserialization factory. + std::unique_ptr new_op = nullptr; if (ops_by_name.count(opname) == 0) { - LOG(FATAL) << "Op '" << opname << "' not supported"; + string effective_opname = "TENSORFLOW_UNSUPPORTED"; + if (ops_by_name.count(effective_opname) == 0) { + LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found."; + } + new_op = ops_by_name.at(effective_opname) + ->Deserialize(input_op->builtin_options(), + input_op->custom_options()); + if (new_op->type == OperatorType::kUnsupported) { + auto* unsupported_op = + static_cast(new_op.get()); + unsupported_op->tensorflow_op = opname; + // TODO(b/109932940): Remove this when quantized is removed. + // For now, we assume all ops are quantized. + unsupported_op->quantized = true; + } else { + LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator"; + } + } else { + new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(), + input_op->custom_options()); } - - auto new_op = ops_by_name.at(opname)->Deserialize( - input_op->builtin_options(), input_op->custom_options()); model->operators.emplace_back(new_op.release()); auto* op = model->operators.back().get(); + // Make sure all the inputs and outputs are hooked up. auto inputs = input_op->inputs(); for (int i = 0; i < inputs->Length(); i++) { auto input_index = inputs->Get(i); @@ -201,6 +221,8 @@ std::unique_ptr Import(const ModelFlags& model_flags, model.get()); ImportIOTensors(*input_model, tensors_table, model.get()); + UndoWeightsShuffling(model.get()); + return model; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 2cd97002be2da5dca23a0bacc13a1fd92ae67b37..7e55ae92bd57447cc821b21b40ba289cb484a9ed 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/toco/tflite/operator.h" +// TODO(ycling): Consider refactoring to extract the LSTM definition out of +// graph_transformation module. +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" #include "tensorflow/contrib/lite/toco/tflite/custom_operator.h" #include "tensorflow/contrib/lite/toco/tflite/simple_operator.h" @@ -53,6 +56,8 @@ class AveragePool op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Convolution @@ -83,6 +88,8 @@ class Convolution op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class DepthwiseConvolution @@ -112,6 +119,8 @@ class DepthwiseConvolution op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Add : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class SpaceToBatchND @@ -149,6 +160,8 @@ class SpaceToBatchND void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } }; class Sub : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Div : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class BatchToSpaceND @@ -206,6 +223,8 @@ class BatchToSpaceND void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } }; class Cast : public BuiltinOperatorsrc_data_type = DataType::Deserialize(options.in_data_type()); op->dst_data_type = DataType::Deserialize(options.out_data_type()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Concatenation @@ -243,6 +264,8 @@ class Concatenation TocoOperator* op) const override { op->axis = options.axis(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class DepthToSpace : public CustomOperator { @@ -255,6 +278,8 @@ class DepthToSpace : public CustomOperator { void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { op->block_size = m["block_size"].AsInt64(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class FakeQuant : public CustomOperator { @@ -274,6 +299,8 @@ class FakeQuant : public CustomOperator { const auto& num_bits = m["num_bits"]; op->num_bits = num_bits.IsInt() ? num_bits.AsInt32() : 8; } + + int GetVersion(const Operator& op) const override { return 1; } }; class FullyConnected @@ -287,13 +314,46 @@ class FullyConnected flatbuffers::FlatBufferBuilder* builder) const override { auto activation_function = ActivationFunction::Serialize(op.fused_activation_function); - return ::tflite::CreateFullyConnectedOptions(*builder, activation_function); + ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format; + switch (op.weights_format) { + case FullyConnectedWeightsFormat::kDefault: + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; + break; + case FullyConnectedWeightsFormat::kShuffled4x16Int8: + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; + break; + default: + LOG(ERROR) << "Unhandled FC weights format"; + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; + } + return ::tflite::CreateFullyConnectedOptions(*builder, activation_function, + tflite_weights_format); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); + switch (options.weights_format()) { + case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT: + op->weights_format = FullyConnectedWeightsFormat::kDefault; + break; + case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8; + break; + default: + LOG(ERROR) << "Unhandled FC weights format"; + op->weights_format = FullyConnectedWeightsFormat::kDefault; + } + } + + int GetVersion(const Operator& op) const override { + const auto& fc_op = static_cast(op); + return fc_op.weights_format == FullyConnectedWeightsFormat::kDefault ? 1 + : 2; } }; @@ -311,6 +371,8 @@ class Gather : public BuiltinOperatoraxis = options.axis(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Svdf : public BuiltinOperatorrank = options.rank(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class L2Normalization @@ -351,6 +415,8 @@ class L2Normalization op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class L2Pool : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class LocalResponseNormalization @@ -401,6 +469,8 @@ class LocalResponseNormalization op->alpha = options.alpha(); op->beta = options.beta(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class MaxPool : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Mul : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Pad : public BuiltinOperator { + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateTileOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + int GetVersion(const Operator& op) const override { return 1; } }; class PadV2 : public BuiltinOperatorshape.insert(op->shape.end(), options.new_shape()->begin(), options.new_shape()->end()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Softmax @@ -516,6 +612,8 @@ class Softmax TocoOperator* op) const override { op->beta = options.beta(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class SpaceToDepth @@ -534,6 +632,8 @@ class SpaceToDepth TocoOperator* op) const override { op->block_size = options.block_size(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Transpose @@ -549,6 +649,8 @@ class Transpose void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } }; class Lstm : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { + ::tflite::LSTMKernelType kernel_type; + switch (op.kernel_type) { + case LstmCellOperator::KERNEL_BASIC: + kernel_type = ::tflite::LSTMKernelType_BASIC; + break; + case LstmCellOperator::KERNEL_FULL: + kernel_type = ::tflite::LSTMKernelType_FULL; + break; + } + // Current toco converter only supports tanh, no clip. return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/ ::tflite::ActivationFunctionType_TANH, /*cell_clip=*/0.0, - /*proj_clip=*/0.0); + /*proj_clip=*/0.0, kernel_type); } void ReadOptions(const TfLiteOptions& options, @@ -570,23 +682,83 @@ class Lstm : public BuiltinOperatorkernel_type = LstmCellOperator::KERNEL_BASIC; + break; + case ::tflite::LSTMKernelType_FULL: + op->kernel_type = LstmCellOperator::KERNEL_FULL; + break; + } + } + + int GetVersion(const Operator& op) const override { + const auto& lstm_op = static_cast(op); + switch (lstm_op.kernel_type) { + case LstmCellOperator::KERNEL_FULL: + return 1; + case LstmCellOperator::KERNEL_BASIC: + return 2; + } + } + + std::vector GetMutatingInputVariables( + const Operator& op) const override { + const auto& lstm_op = static_cast(op); + + std::vector mutating_input_variables(op.inputs.size(), false); + switch (lstm_op.kernel_type) { + case LstmCellOperator::KERNEL_FULL: { + mutating_input_variables[kInputActivationStateTensor] = true; + mutating_input_variables[kInputCellStateTensor] = true; + break; + } + case LstmCellOperator::KERNEL_BASIC: { + mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true; + mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true; + break; + } + } + return mutating_input_variables; + } +}; + +class Mean : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); } + + int GetVersion(const Operator& op) const override { return 1; } }; -class Mean : public BuiltinOperator { +class Sum + : public BuiltinOperator { public: using BuiltinOperator::BuiltinOperator; flatbuffers::Offset WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateMeanOptions(*builder, op.keep_dims); + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { op->keep_dims = options.keep_dims(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class ResizeBilinear @@ -605,6 +777,8 @@ class ResizeBilinear TocoOperator* op) const override { op->align_corners = options.align_corners(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Squeeze @@ -626,6 +800,8 @@ class Squeeze options.squeeze_dims()->begin(), options.squeeze_dims()->end()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Split @@ -644,6 +820,8 @@ class Split TocoOperator* op) const override { op->num_split = options.num_splits(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class StridedSlice @@ -668,6 +846,8 @@ class StridedSlice op->new_axis_mask = options.new_axis_mask(); op->shrink_axis_mask = options.shrink_axis_mask(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class TopK_V2 : public BuiltinOperatoroutput_data_type = DataType::Deserialize(options.output_type()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class TransposeConv @@ -722,6 +906,67 @@ class TransposeConv op->stride_width = options.stride_w(); op->stride_height = options.stride_h(); } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class SparseToDense + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->validate_indices = options.validate_indices(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class ExpandDims + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateExpandDimsOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class Shape + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateShapeOptions( + *builder, DataType::Serialize(op.output_data_type)); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->output_data_type = DataType::Deserialize(options.out_type()); + } + + int GetVersion(const Operator& op) const override { return 1; } }; class TensorFlowUnsupported : public BaseOperator { @@ -784,6 +1029,20 @@ class TensorFlowUnsupported : public BaseOperator { fbb->Bool(key, attr.b()); has_valid_attr = true; break; + case tensorflow::AttrValue::kList: + if (attr.list().i_size() > 0) { + auto start = fbb->StartVector(key); + for (const int64_t v : attr.list().i()) { + fbb->Add(v); + } + fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); + has_valid_attr = true; + } else { + LOG(WARNING) + << "Ignoring unsupported type in list attribute with key '" + << key << "'"; + } + break; default: LOG(WARNING) << "Ignoring unsupported attribute type with key '" << key << "'"; @@ -820,6 +1079,14 @@ class TensorFlowUnsupported : public BaseOperator { case flexbuffers::TYPE_BOOL: (*attr)[key].set_b(value.AsBool()); break; + case flexbuffers::TYPE_VECTOR_INT: { + auto* list = (*attr)[key].mutable_list(); + const auto& vector = value.AsTypedVector(); + for (size_t i = 0; i < vector.size(); i++) { + list->add_i(vector[i].AsInt64()); + } + break; + } default: LOG(WARNING) << "Ignoring unsupported attribute type with key '" << key << "'"; @@ -828,6 +1095,12 @@ class TensorFlowUnsupported : public BaseOperator { } node_def.SerializeToString(&op->tensorflow_node_def); } + + int GetVersion(const Operator& op) const override { + // TODO(ycling): Deisng and implement a way to plumb the version of + // custom ops. + return 1; + } }; namespace { @@ -872,8 +1145,8 @@ std::vector> BuildOperatorList() { ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad)); ops.emplace_back( new PadV2(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2)); - ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE, - OperatorType::kTensorFlowReshape)); + ops.emplace_back( + new Reshape(::tflite::BuiltinOperator_RESHAPE, OperatorType::kReshape)); ops.emplace_back( new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax)); ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH, @@ -884,12 +1157,13 @@ std::vector> BuildOperatorList() { OperatorType::kTranspose)); ops.emplace_back( new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); + ops.emplace_back(new Sum(::tflite::BuiltinOperator_SUM, OperatorType::kSum)); ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, OperatorType::kResizeBilinear)); ops.emplace_back( new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); - ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT, - OperatorType::kTensorFlowSplit)); + ops.emplace_back( + new Split(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit)); ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice)); ops.emplace_back( @@ -900,22 +1174,28 @@ std::vector> BuildOperatorList() { new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); ops.emplace_back( new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); + ops.emplace_back( + new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile)); + ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS, + OperatorType::kExpandDims)); ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv)); + ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE, + OperatorType::kSparseToDense)); + ops.emplace_back( + new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape)); // Custom Operators. ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); - ops.emplace_back(new TensorFlowUnsupported( - "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported)); + ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED", + OperatorType::kUnsupported)); // There operators are supported by Toco, but not by TF Lite, and has no // attributes. ops.emplace_back( new SimpleOperator("ADDN", OperatorType::kAddN)); - ops.emplace_back(new SimpleOperator( - "RSQRT", OperatorType::kTensorFlowRsqrt)); // Simple Operators. ops.emplace_back(new SimpleOperator( "DEQUANTIZE", OperatorType::kDequantize)); @@ -937,23 +1217,34 @@ std::vector> BuildOperatorList() { ops.emplace_back(new SimpleOperator( "LOG_SOFTMAX", OperatorType::kLogSoftmax)); ops.emplace_back(new SimpleOperator( - "MAXIMUM", OperatorType::kTensorFlowMaximum)); + "MAXIMUM", OperatorType::kMaximum)); // Element-wise Maximum ops.emplace_back(new SimpleOperator( - "MINIMUM", OperatorType::kTensorFlowMinimum)); + "MINIMUM", OperatorType::kMinimum)); // Element-wise Minimum ops.emplace_back(new SimpleOperator( - "GREATER", OperatorType::kTensorFlowGreater)); + "GREATER", OperatorType::kGreater)); ops.emplace_back(new SimpleOperator( - "GREATER_EQUAL", OperatorType::kTensorFlowGreaterEqual)); - ops.emplace_back(new SimpleOperator( - "LESS", OperatorType::kTensorFlowLess)); + "GREATER_EQUAL", OperatorType::kGreaterEqual)); + ops.emplace_back( + new SimpleOperator("LESS", OperatorType::kLess)); ops.emplace_back(new SimpleOperator( - "LESS_EQUAL", OperatorType::kTensorFlowLessEqual)); + "LESS_EQUAL", OperatorType::kLessEqual)); + ops.emplace_back(new SimpleOperator( + "EQUAL", OperatorType::kEqual)); + ops.emplace_back(new SimpleOperator( + "NOT_EQUAL", OperatorType::kNotEqual)); ops.emplace_back(new SimpleOperator("NEG", OperatorType::kNeg)); ops.emplace_back( new SimpleOperator("SELECT", OperatorType::kSelect)); ops.emplace_back( new SimpleOperator("SLICE", OperatorType::kSlice)); + ops.emplace_back(new SimpleOperator("POW", OperatorType::kPow)); + // Element-wise operator ops.emplace_back(new SimpleOperator("SIN", OperatorType::kSin)); + ops.emplace_back(new SimpleOperator("LOG", OperatorType::kLog)); + ops.emplace_back( + new SimpleOperator("SQRT", OperatorType::kSqrt)); + ops.emplace_back(new SimpleOperator( + "RSQRT", OperatorType::kRsqrt)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index 85f7bdafe04979abc14f826ef667b3fa1aeec65c..d9ea23edf2b08146773ca58762623397e0f6257c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -77,6 +77,27 @@ class BaseOperator { const BuiltinOptions* builtin_options, const CustomOptions* custom_options) const = 0; + // Get the op version by op parameters. + // The function need to be overridden to return the op version based on the + // parameters. Note: + // * The first version for each op should be 1 (to be consistent with the + // default value in Flatbuffer. `return 1;` is okay for newly implemented + // ops. + // * When multiple versions are defined for an op, this function need to be + // overridden. (See example in `operator_test.cc`) + virtual int GetVersion(const Operator& op) const = 0; + + // Given a Toco `Operator`, return a list of booleans indicating the op + // mutates which input variables. + // * If the op mutates any input variables, it should return a list of bool + // with the same length as inputs. + // * Otherwise, it will return an empty list. + virtual std::vector GetMutatingInputVariables( + const Operator& op) const { + // Most ops don't have variable tensors. This function can be overridden. + return std::vector(); + } + private: string name_; OperatorType type_; diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index fe594c6da9826ab904d162c9e28e1455b1bf69f6..8b6808d3c78d8c51c1b33d09eb4082326100b028 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -74,8 +74,10 @@ class OperatorTest : public ::testing::Test { auto new_toco_op = op.Deserialize(output_options->builtin_options(), output_options->custom_options()); - CHECK(dynamic_cast(new_toco_op.get())) - << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to " + CHECK(new_toco_op->type == toco_op.type) + << "The type of the serialized and deserialized" + << HelpfulOperatorTypeName(*new_toco_op) + << " does not match the type of the original " << HelpfulOperatorTypeName(toco_op); return std::unique_ptr(dynamic_cast(new_toco_op.release())); @@ -110,15 +112,21 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax); CheckSimpleOperator( - "MAXIMUM", OperatorType::kTensorFlowMaximum); + "MAXIMUM", OperatorType::kMaximum); // Element-wise Maximum CheckSimpleOperator( - "MINIMUM", OperatorType::kTensorFlowMinimum); - CheckSimpleOperator("LESS", - OperatorType::kTensorFlowLess); + "MINIMUM", OperatorType::kMinimum); // Element-wise Minimum + CheckSimpleOperator("LESS", OperatorType::kLess); CheckSimpleOperator("NEG", OperatorType::kNeg); CheckSimpleOperator("SELECT", OperatorType::kSelect); CheckSimpleOperator("SLICE", OperatorType::kSlice); CheckSimpleOperator("SIN", OperatorType::kSin); + CheckSimpleOperator("EQUAL", OperatorType::kEqual); + CheckSimpleOperator("NOT_EQUAL", + OperatorType::kNotEqual); + CheckSimpleOperator("LOG", OperatorType::kLog); + CheckSimpleOperator("SQRT", OperatorType::kSqrt); + CheckSimpleOperator("RSQRT", OperatorType::kRsqrt); + CheckSimpleOperator("POW", OperatorType::kPow); } TEST_F(OperatorTest, BuiltinAdd) { @@ -247,7 +255,7 @@ TEST_F(OperatorTest, BuiltinReshape) { TensorFlowReshapeOperator op; op.shape = {1, 2, 4, 5, 8}; auto output_toco_op = SerializeAndDeserialize( - GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op); + GetOperator("RESHAPE", OperatorType::kReshape), op); EXPECT_EQ(op.shape, output_toco_op->shape); } @@ -270,8 +278,8 @@ TEST_F(OperatorTest, BuiltinSpaceToDepth) { TEST_F(OperatorTest, CustomSplit) { TensorFlowSplitOperator op; op.num_split = 123; - auto output_toco_op = SerializeAndDeserialize( - GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op); + auto output_toco_op = + SerializeAndDeserialize(GetOperator("SPLIT", OperatorType::kSplit), op); EXPECT_EQ(op.num_split, output_toco_op->num_split); } @@ -420,6 +428,23 @@ TEST_F(OperatorTest, BuiltinTransposeConv) { EXPECT_EQ(op.padding.type, output_toco_op->padding.type); } +TEST_F(OperatorTest, BuiltinShape) { + TensorFlowShapeOperator op; + op.output_data_type = ArrayDataType::kInt64; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("SHAPE", OperatorType::kShape), op); + EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); +} + +TEST_F(OperatorTest, BuiltinSparseToDense) { + SparseToDenseOperator op; + op.validate_indices = false; + std::unique_ptr output_toco_op = + SerializeAndDeserialize( + GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense), op); + EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; @@ -430,12 +455,17 @@ TEST_F(OperatorTest, TensorFlowUnsupported) { (*attr)["str_attr"].set_s("Hello World"); (*attr)["int_attr"].set_i(17); (*attr)["bool_attr"].set_b(true); + { + auto* list = (*attr)["list_int_attr"].mutable_list(); + list->add_i(1); + list->add_i(20); + list->add_i(1LL << 40); + list->add_i(-(1LL << 40)); + } node_def.SerializeToString(&op.tensorflow_node_def); - auto output_toco_op = - SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", - OperatorType::kTensorFlowUnsupported), - op); + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op); ::tensorflow::NodeDef output_node_def; output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); @@ -444,15 +474,22 @@ TEST_F(OperatorTest, TensorFlowUnsupported) { EXPECT_EQ("Hello World", output_attr.at("str_attr").s()); EXPECT_EQ(17, output_attr.at("int_attr").i()); EXPECT_EQ(true, output_attr.at("bool_attr").b()); + + { + const auto& list = output_attr.at("list_int_attr").list(); + ASSERT_EQ(4, list.i_size()); + EXPECT_EQ(1, list.i(0)); + EXPECT_EQ(20, list.i(1)); + EXPECT_EQ(1LL << 40, list.i(2)); + EXPECT_EQ(-(1LL << 40), list.i(3)); + } } TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; - auto output_toco_op = - SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", - OperatorType::kTensorFlowUnsupported), - op); + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op); ::tensorflow::NodeDef output_node_def; output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); diff --git a/tensorflow/contrib/lite/toco/tflite/simple_operator.h b/tensorflow/contrib/lite/toco/tflite/simple_operator.h index 72678c82a22a7168f858747b0b1c6a2b515b6578..a7f7e886f61d3bbf221c0ab7a24d6c3e629ec274 100644 --- a/tensorflow/contrib/lite/toco/tflite/simple_operator.h +++ b/tensorflow/contrib/lite/toco/tflite/simple_operator.h @@ -41,6 +41,8 @@ class SimpleOperator : public BaseOperator { const CustomOptions* custom_options) const override { return std::unique_ptr(new T); } + + int GetVersion(const Operator& op) const override { return 1; } }; } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc index 4867c3a62e68406428644cd05bddf212008c2656..754f0b4b8c661355c99d9e5a86f2d7844414a303 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -88,6 +88,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { switch (array_data_type) { case ArrayDataType::kFloat: return ::tflite::TensorType_FLOAT32; + case ArrayDataType::kInt16: + return ::tflite::TensorType_INT16; case ArrayDataType::kInt32: return ::tflite::TensorType_INT32; case ArrayDataType::kInt64: @@ -98,6 +100,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { return ::tflite::TensorType_STRING; case ArrayDataType::kBool: return ::tflite::TensorType_BOOL; + case ArrayDataType::kComplex64: + return ::tflite::TensorType_COMPLEX64; default: // FLOAT32 is filled for unknown data types. // TODO(ycling): Implement type inference in TF Lite interpreter. @@ -109,6 +113,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { switch (::tflite::TensorType(tensor_type)) { case ::tflite::TensorType_FLOAT32: return ArrayDataType::kFloat; + case ::tflite::TensorType_INT16: + return ArrayDataType::kInt16; case ::tflite::TensorType_INT32: return ArrayDataType::kInt32; case ::tflite::TensorType_INT64: @@ -119,6 +125,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { return ArrayDataType::kUint8; case ::tflite::TensorType_BOOL: return ArrayDataType::kBool; + case ::tflite::TensorType_COMPLEX64: + return ArrayDataType::kComplex64; default: LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'."; } @@ -131,6 +139,8 @@ flatbuffers::Offset> DataBuffer::Serialize( switch (array.data_type) { case ArrayDataType::kFloat: return CopyBuffer(array, builder); + case ArrayDataType::kInt16: + return CopyBuffer(array, builder); case ArrayDataType::kInt32: return CopyBuffer(array, builder); case ArrayDataType::kInt64: @@ -141,6 +151,8 @@ flatbuffers::Offset> DataBuffer::Serialize( return CopyBuffer(array, builder); case ArrayDataType::kBool: return CopyBoolToBuffer(array, builder); + case ArrayDataType::kComplex64: + return CopyBuffer(array, builder); default: LOG(FATAL) << "Unhandled array data type."; } @@ -154,6 +166,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, switch (tensor.type()) { case ::tflite::TensorType_FLOAT32: return CopyBuffer(buffer, array); + case ::tflite::TensorType_INT16: + return CopyBuffer(buffer, array); case ::tflite::TensorType_INT32: return CopyBuffer(buffer, array); case ::tflite::TensorType_INT64: @@ -164,6 +178,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, return CopyBuffer(buffer, array); case ::tflite::TensorType_BOOL: return CopyBuffer(buffer, array); + case ::tflite::TensorType_COMPLEX64: + return CopyBuffer(buffer, array); default: LOG(FATAL) << "Unhandled tensor type."; } diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc index 564f303b9bb41a777633ecabd666aa93ec3faefe..8e9f30ba3a6e6b98fa9c4237567b0797a5a797aa 100644 --- a/tensorflow/contrib/lite/toco/tflite/types_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/toco/tflite/types.h" +#include + #include #include @@ -71,7 +73,8 @@ TEST(DataType, SupportedTypes) { {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}, - {ArrayDataType::kBool, ::tflite::TensorType_BOOL}}; + {ArrayDataType::kBool, ::tflite::TensorType_BOOL}, + {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64}}; for (auto x : testdata) { EXPECT_EQ(x.second, DataType::Serialize(x.first)); EXPECT_EQ(x.first, DataType::Deserialize(x.second)); @@ -151,6 +154,12 @@ TEST(DataBuffer, Int32) { ::testing::ElementsAre(1, 1 << 30)); } +TEST(DataBuffer, Int16) { + Array recovered = ToFlatBufferAndBack({1, 1 << 14}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(1, 1 << 14)); +} + TEST(DataBuffer, String) { Array recovered = ToFlatBufferAndBack( {"AA", "BBB", "Best. String. Ever."}); @@ -165,6 +174,14 @@ TEST(DataBuffer, Bool) { ::testing::ElementsAre(true, false, true)); } +TEST(DataBuffer, Complex64) { + Array recovered = ToFlatBufferAndBack( + {std::complex(1.0f, 2.0f), std::complex(3.0f, 4.0f)}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(std::complex(1.0f, 2.0f), + std::complex(3.0f, 4.0f))); +} + TEST(Padding, All) { EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame)); EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME)); diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc index 8041aa9e7fbfdaf44134395fee4b2bb01633893a..0b460bd178a49cafefd3438b7ae1c38a07b2ab7c 100644 --- a/tensorflow/contrib/lite/toco/toco.cc +++ b/tensorflow/contrib/lite/toco/toco.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/toco_saved_model.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" #include "tensorflow/core/platform/logging.h" @@ -49,17 +48,6 @@ void CheckFrozenModelPermissions(const Arg& input_file) { << input_file.value() << ".\n"; } -// Checks the permissions of the SavedModel directory. -void CheckSavedModelPermissions(const Arg& savedmodel_directory) { - QCHECK(savedmodel_directory.specified()) - << "Missing required flag --savedmodel_directory.\n"; - QCHECK( - port::file::Exists(savedmodel_directory.value(), port::file::Defaults()) - .ok()) - << "Specified savedmodel_directory does not exist: " - << savedmodel_directory.value() << ".\n"; -} - // Reads the contents of the GraphDef from either the frozen graph file or the // SavedModel directory. If it reads the SavedModel directory, it updates the // ModelFlags and TocoFlags accordingly. @@ -69,24 +57,16 @@ void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, string* graph_def_contents) { port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n"); - bool has_input_file = parsed_toco_flags.input_file.specified(); - bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified(); - - // Ensure either input_file or savedmodel_directory flag has been set. - QCHECK_NE(has_input_file, has_savedmodel_dir) - << "Specify either input_file or savedmodel_directory flag.\n"; + // Ensure savedmodel_directory is not set. + QCHECK(!parsed_toco_flags.savedmodel_directory.specified()) + << "Use `tensorflow/contrib/lite/python/tflite_convert` script with " + << "SavedModel directories.\n"; // Checks the input file permissions and reads the contents. - if (has_input_file) { - CheckFrozenModelPermissions(parsed_toco_flags.input_file); - CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), - graph_def_contents, port::file::Defaults()) - .ok()); - } else { - CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory); - GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags, - model_flags, graph_def_contents); - } + CheckFrozenModelPermissions(parsed_toco_flags.input_file); + CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), + graph_def_contents, port::file::Defaults()) + .ok()); } void ToolMain(const ParsedTocoFlags& parsed_toco_flags, diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index 7786a4ada335abc9a01a0a6e423125f2d67957c2..c6d0a03452f7477841d7e68665baf32dff45f41c 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -41,7 +41,7 @@ bool ParseTocoFlagsFromCommandLineFlags( "extension."), Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(), parsed_flags.savedmodel_directory.default_value(), - "Full path to the directory containing the SavedModel."), + "Deprecated. Full path to the directory containing the SavedModel."), Flag("output_file", parsed_flags.output_file.bind(), parsed_flags.output_file.default_value(), "Output file. " @@ -55,9 +55,9 @@ bool ParseTocoFlagsFromCommandLineFlags( "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."), Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(), parsed_flags.savedmodel_tagset.default_value(), - "Comma-separated set of tags identifying the MetaGraphDef within " - "the SavedModel to analyze. All tags in the tag set must be " - "specified."), + "Deprecated. Comma-separated set of tags identifying the " + "MetaGraphDef within the SavedModel to analyze. All tags in the tag " + "set must be specified."), Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), parsed_flags.default_ranges_min.default_value(), "If defined, will be used as the default value for the min bound " @@ -153,6 +153,16 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.dedupe_array_min_size_bytes.default_value(), "Minimum size of constant arrays to deduplicate; arrays smaller " "will not be deduplicated."), + Flag("split_tflite_lstm_inputs", + parsed_flags.split_tflite_lstm_inputs.bind(), + parsed_flags.split_tflite_lstm_inputs.default_value(), + "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. " + "Ignored if the output format is not TFLite."), + Flag("quantize_weights", parsed_flags.quantize_weights.bind(), + parsed_flags.quantize_weights.default_value(), + "Store weights as quantized weights followed by dequantize " + "operations. Computation is still done in float, but reduces model " + "size (at the cost of accuracy and latency)."), }; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); @@ -245,6 +255,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel, FlagRequirement::kNone); READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone); + READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); + READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); // Deprecated flag handling. if (parsed_toco_flags.input_type.specified()) { @@ -278,6 +290,11 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, QCHECK(toco::IODataType_Parse(input_types[0], &input_type)); toco_flags->set_inference_input_type(input_type); } + if (parsed_toco_flags.quantize_weights.value()) { + QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8) + << "quantize_weights is not supported with inference_type " + "QUANTIZED_UINT8."; + } #undef READ_TOCO_FLAG #undef PARSE_TOCO_FLAG diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index 8589ca361dae2561207f9fa0c57b3240240c08d6..b4a9870d5834d1d5689d15ebc131ac0ead3e9850 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 19. +// Next ID to use: 26. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -165,4 +165,22 @@ message TocoFlags { // Minimum size of constant arrays to deduplicate; arrays smaller will not be // deduplicated. optional int64 dedupe_array_min_size_bytes = 18 [default = 64]; + + // Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. + // Ignored if the output format is not TFLite. + optional bool split_tflite_lstm_inputs = 19 [default = true]; + + // Store weights as quantized weights followed by dequantize operations. + // Computation is still done in float, but reduces model size (at the cost of + // accuracy and latency). + optional bool quantize_weights = 20 [default = false]; + + // Full filepath of folder to dump the graphs at various stages of processing + // GraphViz .dot files. Preferred over --output_format=GRAPHVIZ_DOT in order + // to keep the requirements of the output file. + optional string dump_graphviz_dir = 24; + + // Boolean indicating whether to dump the graph after every graph + // transformation. + optional bool dump_graphviz_include_video = 25; } diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h index d6c3ba6543378b3e15b5fb7816f52376fe05123d..7cdd55e5422589aa000000b82d09b9d8397d7a88 100644 --- a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h +++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h @@ -21,8 +21,6 @@ namespace toco { // Global data for determining whether to output graph viz format from toco. struct GraphVizDumpOptions { - std::string graphviz_first_array; - std::string graphviz_last_array; std::string dump_graphviz; bool dump_graphviz_video = false; diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc index a1c8696cd06a30bfe8661bb70aa4f2d6d175aac3..de76fd4032d24eff8a6c2fd0c16a911b9c00186b 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -16,8 +16,16 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" +#if defined(__ANDROID__) && defined(__ARM_ARCH_7A__) +namespace std { +double round(double x) { return ::round(x); } +} // namespace std +#endif + namespace toco { namespace port { void CopyToBuffer(const string& src, char* dest) { @@ -55,8 +63,12 @@ void CheckInitGoogleIsDone(const char* message) { namespace file { // Conversion to our wrapper Status. -Status ToStatus(const ::util::Status& uts) { - return Status(uts.ok(), uts.error_message()); +tensorflow::Status ToStatus(const ::util::Status& uts) { + if (!uts.ok()) { + return tensorflow::Status(tensorflow::errors::Code(uts.error_code()), + uts.error_message()); + } + return tensorflow::Status::OK(); } // Conversion to our wrapper Options. @@ -65,7 +77,7 @@ toco::port::file::Options ToOptions(const ::file::Options& options) { return Options(); } -Status Writable(const string& filename) { +tensorflow::Status Writable(const string& filename) { File* f = nullptr; const auto status = ::file::Open(filename, "w", &f, ::file::Defaults()); if (f) { @@ -74,22 +86,24 @@ Status Writable(const string& filename) { return ToStatus(status); } -Status Readable(const string& filename, const file::Options& options) { +tensorflow::Status Readable(const string& filename, + const file::Options& options) { return ToStatus(::file::Readable(filename, ::file::Defaults())); } -Status Exists(const string& filename, const file::Options& options) { +tensorflow::Status Exists(const string& filename, + const file::Options& options) { auto status = ::file::Exists(filename, ::file::Defaults()); return ToStatus(status); } -Status GetContents(const string& filename, string* contents, - const file::Options& options) { +tensorflow::Status GetContents(const string& filename, string* contents, + const file::Options& options) { return ToStatus(::file::GetContents(filename, contents, ::file::Defaults())); } -Status SetContents(const string& filename, const string& contents, - const file::Options& options) { +tensorflow::Status SetContents(const string& filename, const string& contents, + const file::Options& options) { return ToStatus(::file::SetContents(filename, contents, ::file::Defaults())); } @@ -133,37 +147,42 @@ void CheckInitGoogleIsDone(const char* message) { namespace file { -Status Writable(const string& filename) { +tensorflow::Status Writable(const string& filename) { FILE* f = fopen(filename.c_str(), "w"); if (f) { fclose(f); - return Status(true, ""); + return tensorflow::Status::OK(); } - return Status(false, "not writable"); + return tensorflow::errors::NotFound("not writable"); } -Status Readable(const string& filename, const file::Options& options) { +tensorflow::Status Readable(const string& filename, + const file::Options& options) { FILE* f = fopen(filename.c_str(), "r"); if (f) { fclose(f); - return Status(true, ""); + return tensorflow::Status::OK(); } - return Status(false, "not readable"); + return tensorflow::errors::NotFound("not readable"); } -Status Exists(const string& filename, const file::Options& options) { +tensorflow::Status Exists(const string& filename, + const file::Options& options) { struct stat statbuf; int ret = stat(filename.c_str(), &statbuf); - return Status(ret != -1, ""); + if (ret == -1) { + return tensorflow::errors::NotFound("file doesn't exist"); + } + return tensorflow::Status::OK(); } -Status GetContents(const string& path, string* output, - const file::Options& options) { +tensorflow::Status GetContents(const string& path, string* output, + const file::Options& options) { output->clear(); int fd = open(path.c_str(), O_RDONLY); if (fd == -1) { - return Status(false, "can't open() for read"); + return tensorflow::errors::NotFound("can't open() for read"); } // Direct read, for speed. @@ -174,25 +193,25 @@ Status GetContents(const string& path, string* output, if (size == 0) { // Done. close(fd); - return Status(true, ""); + return tensorflow::Status::OK(); } else if (size == -1) { // Error. close(fd); - return Status(false, "error during read()"); + return tensorflow::errors::Internal("error during read()"); } else { output->append(buffer, size); } } CHECK(0); - return Status(false, "internal error"); + return tensorflow::errors::Internal("internal error"); } -Status SetContents(const string& filename, const string& contents, - const file::Options& options) { +tensorflow::Status SetContents(const string& filename, const string& contents, + const file::Options& options) { int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664); if (fd == -1) { - return Status(false, "can't open() for write"); + return tensorflow::errors::Internal("can't open() for write"); } size_t i = 0; @@ -201,13 +220,13 @@ Status SetContents(const string& filename, const string& contents, ssize_t written = write(fd, &contents[i], to_write); if (written == -1) { close(fd); - return Status(false, "write() error"); + return tensorflow::errors::Internal("write() error"); } i += written; } close(fd); - return Status(true, ""); + return tensorflow::Status::OK(); } string JoinPath(const string& base, const string& filename) { diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h index 906792ef569e5b8dd2a40f6cf683fa8a35946012..17f82b9dd7dcc633aa204038b6d965f4eb6967bb 100644 --- a/tensorflow/contrib/lite/toco/toco_port.h +++ b/tensorflow/contrib/lite/toco/toco_port.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "google/protobuf/text_format.h" #include "tensorflow/contrib/lite/toco/format_port.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/platform.h" #if defined(PLATFORM_GOOGLE) @@ -33,28 +34,26 @@ limitations under the License. #define TFLITE_PROTO_NS google::protobuf #endif -namespace toco { -namespace port { - -class Status { - public: - static Status OK() { return Status(true, ""); } - - // Create a failed status with no message. - Status() {} - - Status(bool ok, const string& message) : ok_(ok), message_(message) {} - - void AppendMessage(const string& message) { message_ += message; } +#ifdef __ANDROID__ +#include +namespace std { - bool ok() const { return ok_; } +template +std::string to_string(T value) +{ + std::ostringstream os ; + os << value ; + return os.str() ; +} - const string error_message() const { return message_; } +#ifdef __ARM_ARCH_7A__ +double round(double x); +#endif +} +#endif - private: - bool ok_ = false; - string message_; -}; +namespace toco { +namespace port { void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags); void CheckInitGoogleIsDone(const char* message); @@ -65,14 +64,14 @@ inline Options Defaults() { Options o; return o; } -Status GetContents(const string& filename, string* contents, - const Options& options); -Status SetContents(const string& filename, const string& contents, - const Options& options); +tensorflow::Status GetContents(const string& filename, string* contents, + const Options& options); +tensorflow::Status SetContents(const string& filename, const string& contents, + const Options& options); string JoinPath(const string& base, const string& filename); -Status Writable(const string& filename); -Status Readable(const string& filename, const Options& options); -Status Exists(const string& filename, const Options& options); +tensorflow::Status Writable(const string& filename); +tensorflow::Status Readable(const string& filename, const Options& options); +tensorflow::Status Exists(const string& filename, const Options& options); } // namespace file // Copy `src` string to `dest`. User must ensure `dest` has enough space. diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.cc b/tensorflow/contrib/lite/toco/toco_saved_model.cc deleted file mode 100644 index 26f55a66c729894a990258080e397bb42ea98a13..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/toco_saved_model.cc +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "absl/strings/numbers.h" -#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" -#include "tensorflow/contrib/lite/toco/toco_saved_model.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" - -namespace toco { -namespace { - -// Loads a SavedModel from the directory specified in parsed_toco_flags. -// Returns a SavedModelBundle with the requested MetaGraphDef. -const tensorflow::SavedModelBundle* LoadSavedModel( - const ParsedTocoFlags& parsed_toco_flags) { - const string model_path = parsed_toco_flags.savedmodel_directory.value(); - QCHECK(tensorflow::MaybeSavedModelDirectory(model_path)) - << "Model is not saved in the supported SavedModel format.\n"; - - // Gets the tags identifying the MetaGraphDef from the command line arguments. - string tags_str; - if (parsed_toco_flags.savedmodel_tagset.specified()) { - tags_str = parsed_toco_flags.savedmodel_tagset.value(); - } else { - tags_str = parsed_toco_flags.savedmodel_tagset.default_value(); - } - auto tags = absl::StrSplit(tags_str, ','); - - // Loads MetaGraphDef. - auto* bundle = new tensorflow::SavedModelBundle; - TF_CHECK_OK(tensorflow::LoadSavedModel(tensorflow::SessionOptions(), - tensorflow::RunOptions(), model_path, - tags, bundle)) - << "Failed to load exported model from " << model_path - << ". Ensure the model contains the required tags '" << tags_str - << "'.\n"; - return bundle; -} - -// Returns the array name without the postfix. -// -// e.g. reduces "input:0" to "input". -string GetArrayName(const string& name) { - const std::vector& names = absl::StrSplit(name, ':'); - return names[0]; -} - -// Returns the list of array names without the postfix sorted alphabetically. -std::set GetSortedNames(const std::unordered_set& names) { - std::vector final_names; - final_names.reserve(names.size()); - for (const auto& name : names) { - final_names.push_back(GetArrayName(name)); - } - return std::set(final_names.begin(), final_names.end()); -} - -// Gets the final shape after replacing the first dimension with batch size, if -// it is undefined (containing the value -1). Returns whether the shape is -// valid. -bool ReplaceShapeBatchSize(const tensorflow::TensorShapeProto& shape, - int batch_size, - tensorflow::TensorShapeProto* final_shape) { - for (int idx = 0; idx < shape.dim().size(); ++idx) { - int64 final_dim = shape.dim()[idx].size(); - if (final_dim == -1) { - if (idx > 0) return false; - final_dim = batch_size; - } - final_shape->add_dim()->set_size(final_dim); - } - return true; -} - -// Updates the input arrays in ModelFlags to contain the shape of the array. -void ProcessInputShapes(const tensorflow::GraphDef& graph_def, int batch_size, - ModelFlags* model_flags) { - // Build map of input array names to input arrays. - std::unordered_map input_data_map; - for (auto& input : *model_flags->mutable_input_arrays()) { - input_data_map[input.name()] = &input; - } - - // Adds shapes to the input arrays if the shape is valid. - for (const tensorflow::NodeDef& node_def : graph_def.node()) { - if (input_data_map.find(node_def.name()) != input_data_map.end()) { - const auto shape_it = node_def.attr().find("shape"); - if (shape_it != node_def.attr().end()) { - tensorflow::TensorShapeProto final_shape; - bool is_valid = ReplaceShapeBatchSize(shape_it->second.shape(), - batch_size, &final_shape); - - if (is_valid) { - auto* shape = input_data_map.at(node_def.name())->mutable_shape(); - QCHECK_EQ(shape->dims_size(), 0) - << "The shape for the input '" << node_def.name() - << "' was previously defined. For clarity please define inputs " - << "via --input_arrays and input_shapes flags.\n"; - for (const auto& dim : final_shape.dim()) { - shape->add_dims(dim.size()); - } - } - } - } - } - - // Checks all input arrays have a shape. - for (auto const& input : model_flags->input_arrays()) { - QCHECK(input.shape().dims_size() > 0) - << "A valid input shape was not found for input '" << input.name() - << "'. Please define via --input_arrays and --input_shapes flags.\n"; - } -} - -} // namespace - -void ParseMetaData(const tensorflow::GraphDef& graph_def, - const std::unordered_set& inputs, - const std::unordered_set& outputs, - const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags) { - if (!parsed_model_flags.input_arrays.specified()) { - const std::set sorted_inputs = GetSortedNames(inputs); - for (const auto& input_name : sorted_inputs) { - model_flags->add_input_arrays()->set_name(input_name); - } - } - - if (!parsed_model_flags.output_arrays.specified()) { - const std::set sorted_outputs = GetSortedNames(outputs); - for (const auto& output_name : sorted_outputs) { - model_flags->add_output_arrays(GetArrayName(output_name)); - } - } - - if (!parsed_model_flags.input_shapes.specified()) { - int batch_size = parsed_model_flags.batch_size.value(); - ProcessInputShapes(graph_def, batch_size, model_flags); - } - - if (!parsed_toco_flags.inference_type.specified()) { - toco_flags->set_inference_type(IODataType::FLOAT); - } -} - -// TODO(nupurgarg): Add top level tests. -void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags, - string* graph_def_contents) { - // Loads the MetaGraphDef within a SavedModelBundle. - auto bundle = LoadSavedModel(parsed_toco_flags); - - // Converts the MetaGraphDef to frozen GraphDef. - tensorflow::GraphDef frozen_graph_def; - std::unordered_set inputs; - std::unordered_set outputs; - TF_CHECK_OK(tensorflow::FreezeSavedModel(*bundle, &frozen_graph_def, &inputs, - &outputs)); - - // Reads the frozen GraphDef into a string. - QCHECK(frozen_graph_def.SerializeToString(graph_def_contents)) - << "Unable to generate serialized GraphDef.\n"; - - // Process inputs and outputs and metadata within GraphDef. - const tensorflow::GraphDef graph_def = bundle->meta_graph_def.graph_def(); - ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags, - parsed_model_flags, toco_flags, model_flags); -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.h b/tensorflow/contrib/lite/toco/toco_saved_model.h deleted file mode 100644 index 7a0fabd82d90131a3b2d28c757c08dcb0f9e3988..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/toco_saved_model.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ - -#include -#include - -#include "tensorflow/cc/tools/freeze_saved_model.h" -#include "tensorflow/contrib/lite/toco/args.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/types.pb.h" - -namespace toco { - -// Parses metadata into `toco_flags` and `model_flags`. -// -// Stores `inputs` as input_arrays and `outputs` as output_arrays in -// `model_flags`. Infers input_shapes from the GraphDef and stores it in -// `model_flags` as part of the input_arrays. Assumes inference_type is FLOAT -// and stores it in `toco_flags`. -void ParseMetaData(const tensorflow::GraphDef& graph_def, - const std::unordered_set& inputs, - const std::unordered_set& outputs, - const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags); - -// Generates a frozen graph from the SavedModel in the directory specified in -// `toco_flags`. Reads frozen graph contents into `graph_def_contents`. Parses -// metadata relating to the GraphDef into `toco_flags` and `model_flags`. -void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags, - string* graph_def_contents); - -} // namespace toco - -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ diff --git a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc deleted file mode 100644 index 5e122afe65dc29abc85f142f4019aae5058ace51..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc +++ /dev/null @@ -1,274 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/lite/toco/toco_saved_model.h" -#include "absl/strings/str_join.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" -#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -#include -#include - -namespace toco { -namespace { - -using tensorflow::ops::Add; -using tensorflow::ops::Const; -using tensorflow::ops::FakeQuantWithMinMaxArgs; -using tensorflow::ops::Placeholder; - -class TocoSavedModelTest : public ::testing::Test { - protected: - // Calls functions to process cmdline arguments and calls ParseMetaData. - // ParseMetaData parses input_arrays, output_arrays, and gets metadata from - // SavedModel it is not defined in the cmdline arguments. - void ProcessGraphDefMetadata(const std::unordered_set& inputs, - const std::unordered_set& outputs, - const tensorflow::GraphDef& graph_def) { - ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags_, &toco_flags_); - ReadModelFlagsFromCommandLineFlags(parsed_model_flags_, &model_flags_); - ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags_, - parsed_model_flags_, &toco_flags_, &model_flags_); - } - - // Gets the GraphDef from the SavedModelBundle and processes metadata. - void ProcessSavedModelMetadata(const std::unordered_set& inputs, - const std::unordered_set& outputs) { - const tensorflow::GraphDef graph_def = bundle_.meta_graph_def.graph_def(); - ProcessGraphDefMetadata(inputs, outputs, graph_def); - } - - // Returns a GraphDef representing a simple float model with a single input. - tensorflow::GraphDef GetFloatGraphDef(const std::vector& shape) { - tensorflow::GraphDef graph_def; - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - tensorflow::Output input = - Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::PartialTensorShape(shape))); - tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); - tensorflow::Output add = Add(scope.WithOpName("add"), input, zero); - - TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); - return graph_def; - } - - // Returns a GraphDef representing a simple float model with two inputs. - tensorflow::GraphDef GetComplexFloatGraphDef() { - tensorflow::GraphDef graph_def; - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - tensorflow::Output inputA = - Placeholder(scope.WithOpName("inputA"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); - tensorflow::Output inputB = - Placeholder(scope.WithOpName("inputB"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); - tensorflow::Output add = Add(scope.WithOpName("add"), inputB, inputA); - - TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); - return graph_def; - } - - // Returns a GraphDef representing a simple quantized model. - tensorflow::GraphDef GetQuantizedGraphDef() { - tensorflow::GraphDef graph_def; - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - tensorflow::Output input = - Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); - tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); - tensorflow::Output fake_quant = - FakeQuantWithMinMaxArgs(scope.WithOpName("quant"), zero); - tensorflow::Output add = Add(scope.WithOpName("add"), input, fake_quant); - - TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); - return graph_def; - } - - // Gets the values in the input_arrays flag. - std::vector GetInputArrays() { - std::vector actual; - for (const auto& input : model_flags_.input_arrays()) { - actual.push_back(input.name()); - } - return actual; - } - - // Gets the values in the output_arrays flag. - std::vector GetOutputArrays() { - std::vector actual(model_flags_.output_arrays().begin(), - model_flags_.output_arrays().end()); - return actual; - } - - // Gets the shape of the given input array. - string GetInputShape(const string& input_array) { - for (const auto& input : model_flags_.input_arrays()) { - if (input.name() == input_array) { - std::vector dims; - for (int idx = 0; idx < input.shape().dims_size(); ++idx) { - dims.push_back(std::to_string(input.shape().dims(idx))); - } - return absl::StrJoin(dims, ","); - } - } - return ""; - } - - tensorflow::SavedModelBundle bundle_; - ParsedTocoFlags parsed_toco_flags_; - ParsedModelFlags parsed_model_flags_; - TocoFlags toco_flags_; - ModelFlags model_flags_; -}; - -// Tests if input_arrays, output_arrays, inference_type, and output_arrays are -// added to ModelFlags if they are not specified in cmdline arguments. -// Tests if the default batch size replaces a -1 in the first dimension. -TEST_F(TocoSavedModelTest, NoCmdLine) { - tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); - - ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if the order of input_arrays and output_arrays is deterministic when -// they are taken from the SavedModel. -TEST_F(TocoSavedModelTest, NoCmdLineMultipleArrays) { - tensorflow::GraphDef graph_def = GetComplexFloatGraphDef(); - - // Note: The model does not have two outputs. However, the function does not - // need an accurate output_array list. This is only meant to test order. - ProcessGraphDefMetadata({"inputB", "inputA"}, {"add", "invalid"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add", "invalid"})); - EXPECT_EQ(GetInputShape("inputA"), "1,3,3,1"); - EXPECT_EQ(GetInputShape("inputB"), "1,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if input_shapes is inferred when input_arrays is passed in via cmdline -// arguments. -TEST_F(TocoSavedModelTest, InputNameWithoutInputShape) { - parsed_model_flags_.input_arrays.bind()("input"); - tensorflow::GraphDef graph_def = GetFloatGraphDef({2, 3, 3, 1}); - - ProcessGraphDefMetadata({"not_used_input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "2,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Ensures a failure occurs when input_shapes is defined without input_arrays. -TEST_F(TocoSavedModelTest, InputShapeWithoutInputName) { - parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); - tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); - - EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), - "failed: input_shapes.size\\(\\) == " - "model_flags->input_arrays_size\\(\\)"); -} - -// Tests if the cmdline values of input_arrays, input_shapes are used when -// specified with an empty GraphDef. -TEST_F(TocoSavedModelTest, InputArraysCmdLine) { - parsed_model_flags_.input_arrays.bind()("inputA,inputB"); - parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); - - ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"output0", "output1"})); - EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); - EXPECT_EQ(GetInputShape("inputB"), "9,12"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if the cmdline values of input_arrays, input_shapes are used when -// specified even if values exist within the GraphDef. -TEST_F(TocoSavedModelTest, InputArraysCmdLineWithGraphDef) { - parsed_model_flags_.input_arrays.bind()("inputA"); - parsed_model_flags_.input_shapes.bind()("1,224,224,1"); - tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); - - ProcessGraphDefMetadata({"inputA"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if the cmdline values of input_arrays, input_shapes, inference_type, -// and output_arrays are used when specified with an empty GraphDef. -TEST_F(TocoSavedModelTest, AllParamsCmdLine) { - parsed_model_flags_.input_arrays.bind()("inputA,inputB"); - parsed_model_flags_.output_arrays.bind()("outputA,outputB"); - parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); - parsed_toco_flags_.inference_type.bind()("FLOAT"); - - ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"outputA", "outputB"})); - EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); - EXPECT_EQ(GetInputShape("inputB"), "9,12"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if a quantized graph gives the correct values assuming type is passed -// in via command line. -TEST_F(TocoSavedModelTest, QuantizedNoCmdLine) { - parsed_toco_flags_.inference_type.bind()("QUANTIZED_UINT8"); - tensorflow::GraphDef graph_def = GetQuantizedGraphDef(); - - ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::QUANTIZED_UINT8); -} - -// Tests if the provided batch size replaces a -1 in the first dimension of -// input shape. -TEST_F(TocoSavedModelTest, MissingShapeParameterValid) { - parsed_model_flags_.batch_size.bind()(3); - tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); - - ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "3,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Ensures a failure occurs if there is a -1 in a dimension aside from the first -// position of input shape. -TEST_F(TocoSavedModelTest, MissingShapeParameterInvalid) { - parsed_model_flags_.batch_size.bind()(3); - tensorflow::GraphDef graph_def = GetFloatGraphDef({1, -1, 3, 1}); - - EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), - "A valid input shape was not found for input 'input'."); -} - -} // namespace -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index b5531ca2f4785e0c95703f95977be93a0ba2a8e2..fc1636831b266b6aa426c564a0c1c7ca99bc0ff1 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -34,11 +34,11 @@ limitations under the License. namespace toco { namespace { -// CHECK-fails if the model contains a kTensorFlowUnsupported operation. +// CHECK-fails if the model contains a kUnsupported operation. void CheckUnsupportedOperations(const Model& model) { std::set unsupported_ops; for (auto& op : model.operators) { - if (op->type == OperatorType::kTensorFlowUnsupported) { + if (op->type == OperatorType::kUnsupported) { unsupported_ops.insert( static_cast(op.get()) ->tensorflow_op); @@ -56,6 +56,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ConvertSqueezeToReshape); transformations->Add(new ConvertTrivialAddNToAdd); transformations->Add(new ConvertTrivialStackToReshape); + transformations->Add(new ConvertTrivialTileToConcat); transformations->Add(new ConvertTrivialTransposeToReshape); transformations->Add(new ConvertReorderAxes); transformations->Add(new ResolveReshapeAttributes); @@ -76,7 +77,9 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowMatMul); transformations->Add(new FuseBinaryIntoPrecedingAffine); transformations->Add(new FuseBinaryIntoFollowingAffine); + transformations->Add(new FuseBroadcastIntoFollowingBinary); transformations->Add(new MergeReshapeIntoPrecedingTranspose); + transformations->Add(new MoveBinaryOperatorBeforeReshape); transformations->Add(new ReorderElementwiseUnary); transformations->Add(new ReorderReshapeTranspose); transformations->Add(new ResolveBatchNormalization); @@ -94,7 +97,6 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowMerge); transformations->Add(new ResolveSqueezeAttributes); transformations->Add(new ResolveTensorFlowSwitch); - transformations->Add(new ResolveTensorFlowTile); transformations->Add(new ResolveTensorFlowConcat); transformations->Add(new ResolveMultiplyByZero); transformations->Add(new IdentifyDilatedConv); @@ -133,6 +135,8 @@ bool SupportsPreallocatedWorkspace(FileFormat format) { return (format == TFLITE); } +bool SupportsShuffledFCWeights(FileFormat format) { return format == TFLITE; } + bool IsRealValued(toco::ArrayDataType type) { // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used // for quantized real-number values, and no other integer type is ever used @@ -263,12 +267,15 @@ void Transform(const TocoFlags& toco_flags, Model* model) { if (!toco_flags.debug_disable_recurrent_cell_fusion()) { transformations.Add(new IdentifyLstmCell); } - if (output_format == TFLITE) { + if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) { transformations.Add(new toco::SplitLstmCellInputs); } else { transformations.Add(new toco::MergeLstmCellInputs); } } + if (toco_flags.quantize_weights()) { + transformations.Add(new QuantizeWeights); + } transformations.Add(new ResolveConstantConcatenation); RunGraphTransformations(model, "general graph transformations", transformations); @@ -331,6 +338,10 @@ void Transform(const TocoFlags& toco_flags, Model* model) { new RemoveFinalDequantizeOp, ensure_safe_for_int8_kernels, }); + if (SupportsShuffledFCWeights(output_format)) { + RunGraphTransformations(model, "shuffling of FC weights", + {new ShuffleFCWeights}); + } } else { GraphTransformationsSet dequantization_transformations{new Dequantize}; // Dequantize creates FakeQuant nodes. We may want to discard diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 1e6314f2dc78297c8bdacb19cf89292603695e3f..01113506d0ebbf25c057ab0a50730a45eeef64a5 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" namespace toco { @@ -338,23 +338,23 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Div) HANDLE_OPERATORTYPENAME_CASE(Tanh) HANDLE_OPERATORTYPENAME_CASE(Sin) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert) + HANDLE_OPERATORTYPENAME_CASE(All) + HANDLE_OPERATORTYPENAME_CASE(Assert) HANDLE_OPERATORTYPENAME_CASE(ExpandDims) HANDLE_OPERATORTYPENAME_CASE(Fill) HANDLE_OPERATORTYPENAME_CASE(FloorMod) HANDLE_OPERATORTYPENAME_CASE(FloorDiv) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum) + HANDLE_OPERATORTYPENAME_CASE(Greater) + HANDLE_OPERATORTYPENAME_CASE(GreaterEqual) + HANDLE_OPERATORTYPENAME_CASE(Identity) + HANDLE_OPERATORTYPENAME_CASE(Less) + HANDLE_OPERATORTYPENAME_CASE(LessEqual) + HANDLE_OPERATORTYPENAME_CASE(MatMul) + HANDLE_OPERATORTYPENAME_CASE(Max) // Reduction Max + HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum + HANDLE_OPERATORTYPENAME_CASE(Merge) + HANDLE_OPERATORTYPENAME_CASE(Min) // Reduction Min + HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum HANDLE_OPERATORTYPENAME_CASE(Neg) HANDLE_OPERATORTYPENAME_CASE(Pad) HANDLE_OPERATORTYPENAME_CASE(PadV2) @@ -362,22 +362,22 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Stack) HANDLE_OPERATORTYPENAME_CASE(Range) HANDLE_OPERATORTYPENAME_CASE(Rank) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape) + HANDLE_OPERATORTYPENAME_CASE(Reshape) HANDLE_OPERATORTYPENAME_CASE(Squeeze) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape) + HANDLE_OPERATORTYPENAME_CASE(Rsqrt) + HANDLE_OPERATORTYPENAME_CASE(Shape) HANDLE_OPERATORTYPENAME_CASE(Slice) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch) + HANDLE_OPERATORTYPENAME_CASE(Split) + HANDLE_OPERATORTYPENAME_CASE(Sqrt) + HANDLE_OPERATORTYPENAME_CASE(Square) + HANDLE_OPERATORTYPENAME_CASE(Switch) HANDLE_OPERATORTYPENAME_CASE(Sub) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile) + HANDLE_OPERATORTYPENAME_CASE(Sum) + HANDLE_OPERATORTYPENAME_CASE(Tile) HANDLE_OPERATORTYPENAME_CASE(Transpose) HANDLE_OPERATORTYPENAME_CASE(TransposeConv) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2) + HANDLE_OPERATORTYPENAME_CASE(Concat) + HANDLE_OPERATORTYPENAME_CASE(ConcatV2) HANDLE_OPERATORTYPENAME_CASE(Cast) HANDLE_OPERATORTYPENAME_CASE(Floor) HANDLE_OPERATORTYPENAME_CASE(Gather) @@ -388,11 +388,15 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Svdf) HANDLE_OPERATORTYPENAME_CASE(ArgMax) HANDLE_OPERATORTYPENAME_CASE(TopK_V2) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported) + HANDLE_OPERATORTYPENAME_CASE(Unsupported) HANDLE_OPERATORTYPENAME_CASE(Exp) HANDLE_OPERATORTYPENAME_CASE(DynamicPartition) HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) HANDLE_OPERATORTYPENAME_CASE(Select) + HANDLE_OPERATORTYPENAME_CASE(SparseToDense) + HANDLE_OPERATORTYPENAME_CASE(Equal) + HANDLE_OPERATORTYPENAME_CASE(NotEqual) + HANDLE_OPERATORTYPENAME_CASE(Pow) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -400,7 +404,7 @@ const char* OperatorTypeName(OperatorType type) { } string HelpfulOperatorTypeName(const Operator& op) { - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { return toco::port::StringF( "(Unsupported TensorFlow op: %s)", static_cast(op).tensorflow_op); @@ -410,16 +414,20 @@ string HelpfulOperatorTypeName(const Operator& op) { bool OperatorSupportsFusedActivation(OperatorType type) { switch (type) { - case OperatorType::kConcatenation: - case OperatorType::kFakeQuant: - case OperatorType::kGather: - case OperatorType::kSlice: - case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: - case OperatorType::kTensorFlowSplit: - return false; - default: + case OperatorType::kAdd: + case OperatorType::kAveragePool: + case OperatorType::kBatchNormalization: + case OperatorType::kConv: + case OperatorType::kDepthwiseConv: + case OperatorType::kDiv: + case OperatorType::kFullyConnected: + case OperatorType::kL2Pool: + case OperatorType::kMaxPool: + case OperatorType::kMul: + case OperatorType::kSub: return true; + default: + return false; } } @@ -439,8 +447,12 @@ void LogSummary(int log_level, const Model& model) { } void LogArray(int log_level, const Model& model, const string& name) { - const auto& array = model.GetArray(name); VLOG(log_level) << "Array: " << name; + if (!model.HasArray(name)) { + VLOG(log_level) << " DOES NOT EXIST"; + return; + } + const auto& array = model.GetArray(name); VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type); VLOG(log_level) << " Final type: " << ArrayDataTypeName(array.final_data_type); @@ -582,6 +594,13 @@ void UnextendShape(Shape* shape, int new_shape_size) { shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); } +bool IsValid(const Shape& shape) { + for (int i = 0; i < shape.dimensions_count(); ++i) { + if (shape.dims(i) < 1) return false; + } + return true; +} + void CheckShapeDimensions(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i @@ -1862,18 +1881,15 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, output_axes_order == AxesOrder::kHWIO) { // 3210 <- 3210 // HWIO <- OHWI - (*shuffle)[0] = 1; - (*shuffle)[1] = 2; - (*shuffle)[2] = 3; - (*shuffle)[3] = 0; + *shuffle = {1, 2, 3, 0}; } else if (input_axes_order == AxesOrder::kHWIO && output_axes_order == AxesOrder::kOHWI) { // 3210 <- 3210 // OHWI <- HWIO - (*shuffle)[0] = 3; - (*shuffle)[1] = 0; - (*shuffle)[2] = 1; - (*shuffle)[3] = 2; + *shuffle = {3, 0, 1, 2}; + } else if (input_axes_order == AxesOrder::kOHWI && + output_axes_order == AxesOrder::kHWOI) { + *shuffle = {1, 2, 0, 3}; } else { LOG(FATAL) << "Bad shuffle"; } @@ -2019,6 +2035,8 @@ int AxesCount(AxesOrder axes_order) { return 4; case AxesOrder::kNHWC: return 4; + case AxesOrder::kHWOI: + return 4; default: LOG(FATAL) << "Bad AxesOrder"; return 0; @@ -2187,4 +2205,51 @@ void UseArraysExtraInfo(Model* model, bool quantize_output) { } } +void UndoWeightsShuffling(Model* model) { + for (const auto& op : model->operators) { + if (op->type != toco::OperatorType::kFullyConnected) { + continue; + } + const auto& fc_op = static_cast(*op); + if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) { + continue; + } + const string& weights_name = fc_op.inputs[1]; + QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1); + auto& weights_array = model->GetArray(weights_name); + QCHECK(weights_array.data_type == ArrayDataType::kUint8); + auto& weights_data = + weights_array.GetMutableBuffer().data; + const auto& weights_shape = weights_array.shape(); + QCHECK_EQ(weights_shape.dimensions_count(), 2); + const int rows = weights_shape.dims(0); + const int cols = weights_shape.dims(1); + QCHECK_EQ(rows % 4, 0); + QCHECK_EQ(cols % 16, 0); + CHECK_EQ(rows * cols, weights_data.size()); + // Compute the de-shuffled weights + std::vector deshuffled_data(weights_data.size()); + uint8* shuffled_data_ptr = weights_data.data(); + for (int r = 0; r < rows; r += 4) { + for (int c = 0; c < cols; c += 16) { + for (int i = 0; i < 4; i++) { + uint8* deshuffled_data_ptr = + deshuffled_data.data() + (r + i) * cols + c; + for (int j = 0; j < 16; j++) { + uint8 shuffled_val = *shuffled_data_ptr++; + // Deshuffling isn't only about deshuffling the storage layout, + // it's also about undoing the flipping of the sign bit, which is + // performed on the shuffled weights. + uint8 deshuffled_val = shuffled_val ^ 0x80; + *deshuffled_data_ptr++ = deshuffled_val; + } + } + } + } + CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols); + // Switch this FC op to using the deshuffled weights. + weights_data = std::move(deshuffled_data); + } +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 1f596ca8e5a28f17e816c33eea03725d16f7ce12..5dbfa54fa0369676dce638aec171b409a468da9f 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -26,14 +26,15 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/core/platform/logging.h" #if TOCO_SUPPORT_PORTABLE_PROTOS -#include "third_party/protobuf/src/google/protobuf/text_format.h" +#include "third_party/protobuf/include/google/protobuf/text_format.h" #endif // TOCO_SUPPORT_PORTABLE_PROTOS #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" // TODO(aselle): Replace with using a container specific hash override instead. namespace std { @@ -100,6 +101,8 @@ std::vector>::iterator FindOp(Model& model, const char* OperatorTypeName(OperatorType type); string HelpfulOperatorTypeName(const Operator& op); +// Whether the operator can be fused with an activation function. Note that this +// will return false by default for new operators; fusing support is opt-in. bool OperatorSupportsFusedActivation(OperatorType type); void DumpGraphvizVideoFrame(const Model& model); @@ -112,7 +115,9 @@ void ExtendShape(Shape* shape, int new_shape_size); // TODO(b/36075966): Clean up when dims superseded by array shape. void UnextendShape(Shape* shape, int new_shape_size); -// Checks (using CHECK) that all dimensions of 'shape' are at least 1. +// Checks that all dimensions of 'shape' are at least 1. +bool IsValid(const Shape& shape); +// Same as above, but reports error using CHECK. void CheckShapeDimensions(const Shape& shape); // Given two shapes with potentially different dimensionality and dimension @@ -315,7 +320,7 @@ void UseArraysExtraInfo(Model* model, bool quantize_output); // doesn't have enough range to represent the sum of elements, an error is // returned. template -port::Status NumElements(const std::vector& shape, U* num_elements) { +tensorflow::Status NumElements(const std::vector& shape, U* num_elements) { static_assert( std::numeric_limits::max() <= std::numeric_limits::max(), "vector type exceed capabilities of NumElements"); @@ -326,19 +331,24 @@ port::Status NumElements(const std::vector& shape, U* num_elements) { // TensorFlow's shapes sometimes include -1 to represent an "unknown" // size but TOCO isn't able to create arrays of unknown sizes and will // crash in RequiredBufferSizeForShape(). - return port::Status(false, - "Tensor shape should not include negative values"); + return tensorflow::errors::InvalidArgument( + "Tensor shape should not include negative values"); } if (static_cast(dim) > std::numeric_limits::max() / *num_elements) { *num_elements = 0; - return port::Status(false, "Tensor shape is too large"); + return tensorflow::errors::InvalidArgument("Tensor shape is too large"); } *num_elements *= dim; } - return port::Status::OK(); + return tensorflow::Status::OK(); } +// A model file may have shuffled FC weights. +// When that happens, we want to de-shuffle them immediately on import, +// so that the rest of toco doesn't need to know about shuffled weights. +void UndoWeightsShuffling(Model* model); + } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc index 87fd30db2cf54824a3c34ed875291d898f1a9e38..8609e5beddd200be4e5ebfe1fb2a79048e0e60ab 100644 --- a/tensorflow/contrib/lite/toco/tooling_util_test.cc +++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/lib/core/status.h" namespace toco { @@ -99,7 +100,7 @@ static const char kLargeTensorMessage[] = "Tensor shape is too large"; TEST(NumElementsTest, Int) { int count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 1024, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -114,7 +115,7 @@ TEST(NumElementsTest, Int) { TEST(NumElementsTest, Int32) { int32_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 1024, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -129,7 +130,7 @@ TEST(NumElementsTest, Int32) { TEST(NumElementsTest, Int64) { int64_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{16777216, 16777216, 32767}, &count); EXPECT_TRUE(status.ok()); @@ -144,7 +145,7 @@ TEST(NumElementsTest, Int64) { TEST(NumElementsTest, UnsignedInt32) { uint32_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 2048, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -159,7 +160,7 @@ TEST(NumElementsTest, UnsignedInt32) { TEST(NumElementsTest, UnsignedInt64) { uint64_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{16777216, 16777216, 65535}, &count); @@ -174,4 +175,10 @@ TEST(NumElementsTest, UnsignedInt64) { EXPECT_EQ(status.error_message(), kLargeTensorMessage); } +TEST(FusedActivationTest, DefaultsToUnfused) { + EXPECT_TRUE(OperatorSupportsFusedActivation(OperatorType::kAdd)); + EXPECT_FALSE(OperatorSupportsFusedActivation(OperatorType::kNone)); + EXPECT_FALSE(OperatorSupportsFusedActivation(static_cast(255))); +} + } // namespace toco diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 824a164651073bac846a514505726a8ee85cc41d..a3df37358fac4d688ce7c513ed951cdd7e6bca1a 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -7,6 +7,8 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +common_copts = ["-Wall"] + py_binary( name = "visualize", srcs = ["visualize.py"], @@ -28,34 +30,6 @@ tf_cc_binary( ], ) -tf_cc_binary( - name = "benchmark_model", - srcs = ["benchmark_model.cc"], - linkopts = select({ - "//tensorflow:android": [ - "-pie", - "-landroid", - "-lm", - "-z defs", - "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export - ], - "//conditions:default": [], - }), - deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib", - ], - "//conditions:default": [ - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], - }), -) - cc_library( name = "gen_op_registration", srcs = ["gen_op_registration.cc"], @@ -79,6 +53,7 @@ cc_test( ], tags = [ "tflite_not_portable_android", + "tflite_not_portable_ios", ], deps = [ ":gen_op_registration", diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..183a545295f690decec47f1c31aa473667408a3d --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -0,0 +1,100 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +common_copts = ["-Wall"] + tflite_copts() + +cc_binary( + name = "benchmark_model", + srcs = [ + "benchmark_main.cc", + "logging.h", + ], + copts = common_copts, + linkopts = tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + ], + "//conditions:default": [], + }), + deps = [ + ":benchmark_tflite_model_lib", + ], +) + +cc_library( + name = "command_line_flags", + srcs = ["command_line_flags.cc"], + hdrs = ["command_line_flags.h"], + copts = common_copts, +) + +cc_test( + name = "command_line_flags_test", + srcs = ["command_line_flags_test.cc"], + copts = common_copts, + visibility = ["//visibility:private"], + deps = [ + ":command_line_flags", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "benchmark_tflite_model_lib", + srcs = [ + "benchmark_tflite_model.cc", + "logging.h", + ], + hdrs = ["benchmark_tflite_model.h"], + copts = common_copts, + deps = [ + ":benchmark_model_lib", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/contrib/lite/profiling:profiler", + ], +) + +cc_library( + name = "benchmark_params", + srcs = [ + "benchmark_params.cc", + "logging.h", + ], + hdrs = ["benchmark_params.h"], + copts = common_copts, +) + +cc_library( + name = "benchmark_model_lib", + srcs = [ + "benchmark_model.cc", + "logging.h", + ], + hdrs = ["benchmark_model.h"], + copts = common_copts, + deps = [ + ":benchmark_params", + ":command_line_flags", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/contrib/lite/profiling:profiler", + "//tensorflow/contrib/lite/profiling:time", + "//tensorflow/core:stats_calculator_portable", + ], +) + +tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md new file mode 100644 index 0000000000000000000000000000000000000000..93769305bde210b58f3b2cb668a9d8c1ad0ce396 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/README.md @@ -0,0 +1,209 @@ +# TFLite Model Benchmark Tool + +## Description + +A simple C++ binary to benchmark a TFLite model and its individual operators, +both on desktop machines and on Android. The binary takes a TFLite model, +generates random inputs and then repeatedly runs the model for specified number +of runs. Aggregrate latency statistics are reported after running the benchmark. + +The instructions below are for running the binary on Desktop and Android, +for iOS please use the +[iOS benchmark app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios). + +## Parameters + +The binary takes the following required parameters: + +* `graph`: `string` \ + The path to the TFLite model file. +* `input_layer`: `string` \ + The name of the input layer, this is typically the first layer of the model. +* `input_layer_shape`: `string` \ + The shape of the input layer. This is a comma separated string of the shape + of tensor of input layer. + +and the following optional parameters: + +* `num_threads`: `int` (default=1) \ + The number of threads to use for running TFLite interpreter. +* `warmup_runs`: `int` (default=1) \ + The number of warmup runs to do before starting the benchmark. +* `run_delay`: `float` (default=-1.0) \ + The delay in seconds between subsequent benchmark runs. Non-positive values + mean use no delay. +* `use_nnapi`: `bool` (default=false) \ + Whether to use [Android NNAPI] (https://developer.android.com/ndk/guides/neuralnetworks/). + This API is available on recent Android devices. + +## To build/install/run + +### On Android: + +(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android to edit the `WORKSPACE` to configure the android NDK/SDK. + +(1) Build for your specific platform, e.g.: + +``` +bazel build -c opt \ + --config=android_arm \ + --cxxopt='--std=c++11' \ + tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` + +(2) Connect your phone. Push the binary to your phone with adb push + (make the directory if required): + +``` +adb push bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model /data/local/tmp +``` + +(3) Make the binary executable. + +``` +adb shell chmod +x /data/local/tmp/benchmark_model +``` + +(4) Push the compute graph that you need to test. For example: + +``` +adb push mobilenet_quant_v1_224.tflite /data/local/tmp +``` + +(5) Run the benchmark. For example: + +``` +adb shell /data/local/tmp/benchmark_model \ + --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ + --input_layer="input" \ + --input_layer_shape="1,224,224,3" \ + --num_threads=4 +``` + +### On desktop: +(1) build the binary + +``` +bazel build -c opt tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` + +(2) Run on your compute graph, similar to the Android case but without the need of adb shell. +For example: + +``` +bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \ + --graph=mobilenet_quant_v1_224.tflite \ + --input_layer="Placeholder" \ + --input_layer_shape="1,224,224,3" \ + --num_threads=4 +``` + +The MobileNet graph used as an example here may be downloaded from +https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip + + +## Reducing variance between runs on Android. + +Most modern Android phones use [ARM big.LITTLE](https://en.wikipedia.org/wiki/ARM_big.LITTLE) +architecture where some cores are more power hungry but faster than other cores. +When running benchmarks on these phones there can be significant variance +between different runs of the benchmark. One way to reduce variance between runs +is to set the [CPU affinity](https://en.wikipedia.org/wiki/Processor_affinity) +before running the benchmark. On Android this can be done using the `taskset` +command. +E.g. for running the benchmark on big cores on Pixel 2 with a single thread one +can use the following command: + +``` +adb shell tasket f0 /data/local/tmp/benchmark_model \ + --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ + --input_layer="input" \ + --input_layer_shape="1,224,224,3" \ + --num_threads=1 +``` + +where `f0` is the affinity mask for big cores on Pixel 2. +Note: The affinity mask varies with the device. + +## Profiling model operators +The benchmark model binary also allows you to profile operators and give execution times of each operator. To do this, +compile the binary with a compiler flag that enables profiling to be compiled in. Pass **--copt=-DTFLITE_PROFILING_ENABLED** +to compile benchmark with profiling support. +For example, to compile with profiling support on Android, add this flag to the previous command: + +``` +bazel build -c opt \ + --config=android_arm \ + --cxxopt='--std=c++11' \ + --copt=-DTFLITE_PROFILING_ENABLED \ + tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` +This compiles TFLite with profiling enabled, now you can run the benchmark binary like before. The binary will produce detailed statistics for each operation similar to those shown below: + +``` + +============================== Run Order ============================== + [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] + CONV_2D 0.000 4.269 4.269 0.107% 0.107% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] + DEPTHWISE_CONV_2D 4.270 2.150 2.150 0.054% 0.161% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6] + CONV_2D 6.421 6.107 6.107 0.153% 0.314% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + DEPTHWISE_CONV_2D 12.528 1.366 1.366 0.034% 0.348% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6] + CONV_2D 13.895 4.195 4.195 0.105% 0.454% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6] + DEPTHWISE_CONV_2D 18.091 1.260 1.260 0.032% 0.485% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6] + CONV_2D 19.352 6.652 6.652 0.167% 0.652% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + DEPTHWISE_CONV_2D 26.005 0.698 0.698 0.018% 0.670% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6] + CONV_2D 26.703 3.344 3.344 0.084% 0.754% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6] + DEPTHWISE_CONV_2D 30.047 0.646 0.646 0.016% 0.770% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6] + CONV_2D 30.694 5.800 5.800 0.145% 0.915% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + DEPTHWISE_CONV_2D 36.495 0.331 0.331 0.008% 0.924% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] + CONV_2D 36.826 2.838 2.838 0.071% 0.995% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] + DEPTHWISE_CONV_2D 39.665 0.439 0.439 0.011% 1.006% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] + CONV_2D 40.105 5.293 5.293 0.133% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + DEPTHWISE_CONV_2D 45.399 0.352 0.352 0.009% 1.147% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] + CONV_2D 45.752 5.322 5.322 0.133% 1.281% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + DEPTHWISE_CONV_2D 51.075 0.357 0.357 0.009% 1.290% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] + CONV_2D 51.432 5.693 5.693 0.143% 1.433% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + DEPTHWISE_CONV_2D 57.126 0.366 0.366 0.009% 1.442% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6] + CONV_2D 57.493 5.472 5.472 0.137% 1.579% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + DEPTHWISE_CONV_2D 62.966 0.364 0.364 0.009% 1.588% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6] + CONV_2D 63.330 5.404 5.404 0.136% 1.724% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + DEPTHWISE_CONV_2D 68.735 0.155 0.155 0.004% 1.728% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6] + CONV_2D 68.891 2.970 2.970 0.074% 1.802% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6] + DEPTHWISE_CONV_2D 71.862 0.206 0.206 0.005% 1.807% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6] + CONV_2D 72.069 5.888 5.888 0.148% 1.955% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + AVERAGE_POOL_2D 77.958 0.036 0.036 0.001% 1.956% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool] + CONV_2D 77.994 1.445 1.445 0.036% 1.992% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd] + RESHAPE 79.440 0.002 0.002 0.000% 1.992% 0.000 0 [MobilenetV1/Predictions/Reshape] + SOFTMAX 79.443 0.029 0.029 0.001% 1.993% 0.000 0 [MobilenetV1/Predictions/Softmax] + +============================== Top by Computation Time ============================== + [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] + CONV_2D 19.352 6.652 6.652 0.167% 0.167% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + CONV_2D 6.421 6.107 6.107 0.153% 0.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + CONV_2D 72.069 5.888 5.888 0.148% 0.468% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + CONV_2D 30.694 5.800 5.800 0.145% 0.613% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + CONV_2D 51.432 5.693 5.693 0.143% 0.756% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + CONV_2D 57.493 5.472 5.472 0.137% 0.893% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + CONV_2D 63.330 5.404 5.404 0.136% 1.029% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + CONV_2D 45.752 5.322 5.322 0.133% 1.162% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + CONV_2D 40.105 5.293 5.293 0.133% 1.295% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + CONV_2D 0.000 4.269 4.269 0.107% 1.402% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] + +Number of nodes executed: 31 +============================== Summary by node type ============================== + [Node type] [count] [avg ms] [avg %] [cdf %] [mem KB] [times called] + CONV_2D 15 1.406 89.270% 89.270% 0.000 0 + DEPTHWISE_CONV_2D 13 0.169 10.730% 100.000% 0.000 0 + SOFTMAX 1 0.000 0.000% 100.000% 0.000 0 + RESHAPE 1 0.000 0.000% 100.000% 0.000 0 + AVERAGE_POOL_2D 1 0.000 0.000% 100.000% 0.000 0 + +Timings (microseconds): count=50 first=79449 curr=81350 min=77385 max=88213 avg=79732 std=1929 +Memory (bytes): count=0 +31 nodes observed + + +Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9 +``` + + diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..372d31e838e5666df492ee3156022249a2d97691 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +int Main(int argc, char** argv) { +#ifdef TFLITE_CUSTOM_OPS_HEADER + TFLITE_LOG(INFO) << "STARTING with custom ops!"; +#else + TFLITE_LOG(INFO) << "STARTING!"; +#endif + BenchmarkTfLiteModel benchmark; + BenchmarkLoggingListener listener; + benchmark.AddListener(&listener); + benchmark.Run(argc, argv); + return 0; +} +} // namespace benchmark +} // namespace tflite + +int main(int argc, char** argv) { return tflite::benchmark::Main(argc, argv); } diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..08648bcfe26365d180d984fde8f8e04b22eb45dd --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" + +#include + +#include +#include + +#include "tensorflow/contrib/lite/profiling/time.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace { +void SleepForSeconds(double sleep_seconds) { + if (sleep_seconds <= 0.0) { + return; + } + // Convert the run_delay string into a timespec. + timespec req; + req.tv_sec = static_cast(sleep_seconds); + req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; + // If requested, sleep between runs for an arbitrary amount of time. + // This can be helpful to determine the effect of mobile processor + // scaling and thermal throttling. +#ifdef PLATFORM_WINDOWS + Sleep(sleep_seconds * 1000); +#else + nanosleep(&req, nullptr); +#endif +} + +} // namespace + +namespace tflite { +namespace benchmark { +using tensorflow::Stat; + +BenchmarkParams BenchmarkModel::DefaultParams() { + BenchmarkParams params; + params.AddParam("num_runs", BenchmarkParam::Create(50)); + params.AddParam("run_delay", BenchmarkParam::Create(-1.0f)); + params.AddParam("num_threads", BenchmarkParam::Create(1)); + params.AddParam("benchmark_name", BenchmarkParam::Create("")); + params.AddParam("output_prefix", BenchmarkParam::Create("")); + params.AddParam("warmup_runs", BenchmarkParam::Create(1)); + return params; +} + +BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {} + +void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { + auto inference_us = results.inference_time_us(); + auto init_us = results.startup_latency_us(); + auto warmup_us = results.warmup_time_us(); + TFLITE_LOG(INFO) << "Average inference timings in us: " + << "Warmup: " << warmup_us.avg() << ", " + << "Init: " << init_us << ", " + << "no stats: " << inference_us.avg(); +} + +std::vector BenchmarkModel::GetFlags() { + return { + CreateFlag("num_runs", ¶ms_, "number of runs"), + CreateFlag("run_delay", ¶ms_, "delay between runs in seconds"), + CreateFlag("num_threads", ¶ms_, "number of threads"), + CreateFlag("benchmark_name", ¶ms_, "benchmark name"), + CreateFlag("output_prefix", ¶ms_, + "benchmark output prefix"), + CreateFlag("warmup_runs", ¶ms_, + "how many runs to initialize model"), + }; +} + +void BenchmarkModel::LogFlags() { + TFLITE_LOG(INFO) << "Num runs: [" << params_.Get("num_runs") << "]"; + TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" + << params_.Get("run_delay") << "]"; + TFLITE_LOG(INFO) << "Num threads: [" << params_.Get("num_threads") + << "]"; + TFLITE_LOG(INFO) << "Benchmark name: [" + << params_.Get("benchmark_name") << "]"; + TFLITE_LOG(INFO) << "Output prefix: [" + << params_.Get("output_prefix") << "]"; + TFLITE_LOG(INFO) << "Warmup runs: [" << params_.Get("warmup_runs") + << "]"; +} + +Stat BenchmarkModel::Run(int num_times, RunType run_type) { + Stat run_stats; + TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations "; + for (int run = 0; run < num_times; run++) { + listeners_.OnSingleRunStart(run_type); + int64_t start_us = profiling::time::NowMicros(); + RunImpl(); + int64_t end_us = profiling::time::NowMicros(); + listeners_.OnSingleRunEnd(); + + run_stats.UpdateStat(end_us - start_us); + SleepForSeconds(params_.Get("run_delay")); + } + + std::stringstream stream; + run_stats.OutputToStream(&stream); + TFLITE_LOG(INFO) << stream.str() << std::endl; + + return run_stats; +} + +void BenchmarkModel::Run(int argc, char **argv) { + if (!ParseFlags(argc, argv)) { + return; + } + + LogFlags(); + + listeners_.OnBenchmarkStart(params_); + int64_t initialization_start_us = profiling::time::NowMicros(); + Init(); + int64_t initialization_end_us = profiling::time::NowMicros(); + int64_t startup_latency_us = initialization_end_us - initialization_start_us; + TFLITE_LOG(INFO) << "Initialized session in " << startup_latency_us / 1e3 + << "ms"; + + uint64_t input_bytes = ComputeInputBytes(); + Stat warmup_time_us = + Run(params_.Get("warmup_runs"), WARMUP); + Stat inference_time_us = + Run(params_.Get("num_runs"), REGULAR); + listeners_.OnBenchmarkEnd( + {startup_latency_us, input_bytes, warmup_time_us, inference_time_us}); +} + +bool BenchmarkModel::ParseFlags(int argc, char **argv) { + auto flag_list = GetFlags(); + const bool parse_result = + Flags::Parse(&argc, const_cast(argv), flag_list); + if (!parse_result) { + std::string usage = Flags::Usage(argv[0], flag_list); + TFLITE_LOG(ERROR) << usage; + return false; + } + return ValidateFlags(); +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h new file mode 100644 index 0000000000000000000000000000000000000000..942e21f67a7f864f16b7b1b85b2599d5c872b5c7 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h" +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include "tensorflow/core/util/stats_calculator.h" + +namespace tflite { +namespace benchmark { + +enum RunType { + WARMUP, + REGULAR, +}; + +class BenchmarkResults { + public: + BenchmarkResults(int64_t startup_latency_us, uint64_t input_bytes, + tensorflow::Stat warmup_time_us, + tensorflow::Stat inference_time_us) + : startup_latency_us_(startup_latency_us), + input_bytes_(input_bytes), + warmup_time_us_(warmup_time_us), + inference_time_us_(inference_time_us) {} + + tensorflow::Stat inference_time_us() const { + return inference_time_us_; + } + tensorflow::Stat warmup_time_us() const { return warmup_time_us_; } + int64_t startup_latency_us() const { return startup_latency_us_; } + uint64_t input_bytes() const { return input_bytes_; } + double throughput_MB_per_second() const { + double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) / + inference_time_us_.sum(); + return bytes_per_sec / (1024.0 * 1024.0); + } + + private: + int64_t startup_latency_us_; + uint64_t input_bytes_; + tensorflow::Stat warmup_time_us_; + tensorflow::Stat inference_time_us_; +}; + +class BenchmarkListener { + public: + virtual void OnBenchmarkStart(const BenchmarkParams& params) {} + virtual void OnSingleRunStart(RunType runType) {} + virtual void OnSingleRunEnd() {} + virtual void OnBenchmarkEnd(const BenchmarkResults& results) {} + virtual ~BenchmarkListener() {} +}; + +// A listener that forwards its method calls to a collection of listeners. +class BenchmarkListeners : public BenchmarkListener { + public: + // Added a listener to the listener collection. + // |listener| is not owned by the instance of |BenchmarkListeners|. + // |listener| should not be null and should outlast the instance of + // |BenchmarkListeners|. + void AddListener(BenchmarkListener* listener) { + listeners_.push_back(listener); + } + + void OnBenchmarkStart(const BenchmarkParams& params) override { + for (auto listener : listeners_) { + listener->OnBenchmarkStart(params); + } + } + + void OnSingleRunStart(RunType runType) override { + for (auto listener : listeners_) { + listener->OnSingleRunStart(runType); + } + } + + void OnSingleRunEnd() override { + for (auto listener : listeners_) { + listener->OnSingleRunEnd(); + } + } + + void OnBenchmarkEnd(const BenchmarkResults& results) override { + for (auto listener : listeners_) { + listener->OnBenchmarkEnd(results); + } + } + + ~BenchmarkListeners() {} + + private: + // Use vector so listeners are invoked in the order they are added. + std::vector listeners_; +}; + +// Benchmark listener that just logs the results of benchmark run. +class BenchmarkLoggingListener : public BenchmarkListener { + void OnBenchmarkEnd(const BenchmarkResults& results) override; +}; + +template +Flag CreateFlag(const char* name, BenchmarkParams* params, + const std::string& usage) { + return Flag(name, [params, name](const T& val) { params->Set(name, val); }, + params->Get(name), usage); +} + +// Benchmarks a model. +// +// Subclasses need to implement initialization and running of the model. +// The results can be collected by adding BenchmarkListener(s). +class BenchmarkModel { + public: + static BenchmarkParams DefaultParams(); + BenchmarkModel(); + BenchmarkModel(BenchmarkParams params) : params_(std::move(params)) {} + virtual ~BenchmarkModel() {} + bool ParseFlags(int argc, char** argv); + virtual void Init() = 0; + void Run(int argc, char** argv); + void AddListener(BenchmarkListener* listener) { + listeners_.AddListener(listener); + } + + protected: + virtual void LogFlags(); + virtual bool ValidateFlags() { return true; } + virtual std::vector GetFlags(); + virtual uint64_t ComputeInputBytes() = 0; + virtual tensorflow::Stat Run(int num_times, RunType run_type); + virtual void RunImpl() = 0; + BenchmarkParams params_; + BenchmarkListeners listeners_; +}; + +} // namespace benchmark +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc new file mode 100644 index 0000000000000000000000000000000000000000..1dcf580a9d4995e6cb3706d3562bc8a2f4670082 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h" + +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +void BenchmarkParam::AssertHasSameType(BenchmarkParam::ParamType a, + BenchmarkParam::ParamType b) { + TFLITE_BENCHMARK_CHECK(a == b) << "Type mismatch while accessing parameter."; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_INT32; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_BOOL; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_FLOAT; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_STRING; +} + +void BenchmarkParams::AssertParamExists(const std::string& name) const { + TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found."; +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h new file mode 100644 index 0000000000000000000000000000000000000000..33448dd1623577fdfda6316c588cc60ccbaa1994 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +template +class TypedBenchmarkParam; + +class BenchmarkParam { + protected: + enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING }; + + public: + template + static std::unique_ptr Create(const T& default_value) { + return std::unique_ptr( + new TypedBenchmarkParam(default_value)); + } + + template + TypedBenchmarkParam* AsTyped() { + AssertHasSameType(GetValueType(), type_); + return static_cast*>(this); + } + virtual ~BenchmarkParam() {} + BenchmarkParam(ParamType type) : type_(type) {} + + private: + static void AssertHasSameType(ParamType a, ParamType b); + template + static ParamType GetValueType(); + + const ParamType type_; +}; + +template +class TypedBenchmarkParam : public BenchmarkParam { + public: + TypedBenchmarkParam(const T& value) + : BenchmarkParam(GetValueType()), value_(value) {} + void Set(const T& value) { value_ = value; } + + T Get() { return value_; } + + private: + T value_; +}; + +class BenchmarkParams { + public: + void AddParam(const std::string& name, + std::unique_ptr value) { + params_[name] = std::move(value); + } + + bool HasParam(const std::string& name) const { + return params_.find(name) != params_.end(); + } + + template + void Set(const std::string& name, const T& value) { + AssertParamExists(name); + params_.at(name)->AsTyped()->Set(value); + } + + template + T Get(const std::string& name) const { + AssertParamExists(name); + return params_.at(name)->AsTyped()->Get(); + } + + private: + void AssertParamExists(const std::string& name) const; + std::unordered_map> params_; +}; + +} // namespace benchmark +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..73affc26b034f415ae2a2101e0b558cdb94d8d5b --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -0,0 +1,334 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +#ifdef TFLITE_CUSTOM_OPS_HEADER +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); +#endif + +namespace tflite { +namespace benchmark { + +void ProfilingListener::SetInterpreter(tflite::Interpreter* interpreter) { + TFLITE_BENCHMARK_CHECK(interpreter); + interpreter_ = interpreter; + interpreter_->SetProfiler(&profiler_); +} + +void ProfilingListener::OnSingleRunStart(RunType run_type) { + if (run_type == REGULAR) { + profiler_.Reset(); + profiler_.StartProfiling(); + } +} + +void ProfilingListener::OnBenchmarkEnd(const BenchmarkResults& results) { + if (has_profiles_) { + TFLITE_LOG(INFO) << summarizer_.GetOutputString(); + } +} + +void ProfilingListener::OnSingleRunEnd() { + profiler_.StopProfiling(); + auto profile_events = profiler_.GetProfileEvents(); + has_profiles_ = !profile_events.empty(); + summarizer_.ProcessProfiles(profile_events, *interpreter_); +} + +namespace { + +std::vector Split(const std::string& str, const char delim) { + std::istringstream input(str); + std::vector results; + std::string item; + while (std::getline(input, item, delim)) { + results.push_back(item); + } + return results; +} + +template +bool SplitAndParse(const std::string& str, char delim, std::vector* values) { + std::istringstream input(str); + bool first = true; + while (!input.eof()) { + if (!first) { + char c; + input >> c; + if (c != delim) { + return false; + } + } else { + first = false; + } + T val; + input >> val; + if (!input.eof() && !input.good()) { + return false; + } + values->push_back(val); + } + return true; +} + +template +void FillRandomValue(T* ptr, const std::vector& sizes, + const std::function& random_func) { + int num_elements = 1; + for (int dim : sizes) { + num_elements *= dim; + } + for (int i = 0; i < num_elements; ++i) { + *ptr++ = random_func(); + } +} + +void FillRandomString(tflite::DynamicBuffer* buffer, + const std::vector& sizes, + const std::function& random_func) { + int num_elements = 1; + for (int dim : sizes) { + num_elements *= dim; + } + for (int i = 0; i < num_elements; ++i) { + auto str = random_func(); + buffer->AddString(str.data(), str.length()); + } +} + +bool PopulateInputLayerInfo( + const string& names_string, const string& shapes_string, + std::vector* info) { + std::vector names = Split(names_string, ','); + std::vector shapes = Split(shapes_string, ':'); + + if (names.size() != shapes.size()) { + TFLITE_LOG(ERROR) << "The number of items in" + << " --input_layer_shape (" << shapes_string << ", with " + << shapes.size() << " items)" + << " must match the number of items in" + << " --input_layer (" << names_string << ", with " + << names.size() << " items)." + << " For example --input_layer=input1,input2" + << " --input_layer_shape=1,224,224,4:1,20"; + return false; + } + + for (int i = 0; i < names.size(); ++i) { + info->push_back(BenchmarkTfLiteModel::InputLayerInfo()); + BenchmarkTfLiteModel::InputLayerInfo& input = info->back(); + + input.name = names[i]; + + TFLITE_BENCHMARK_CHECK(SplitAndParse(shapes[i], ',', &input.shape)) + << "Incorrect size string specified: " << shapes[i]; + for (int dim : input.shape) { + if (dim == -1) { + TFLITE_LOG(ERROR) + << "Any unknown sizes in the shapes (-1's) must be replaced" + << " with the size you want to benchmark with."; + return false; + } + } + } + + return true; +} + +BenchmarkParams GetDefaultParams() { + BenchmarkParams default_params = BenchmarkModel::DefaultParams(); + default_params.AddParam("graph", BenchmarkParam::Create("")); + default_params.AddParam("input_layer", + BenchmarkParam::Create("")); + default_params.AddParam("input_layer_shape", + BenchmarkParam::Create("")); + default_params.AddParam("use_nnapi", BenchmarkParam::Create(false)); + return default_params; +} + +} // namespace + +BenchmarkTfLiteModel::BenchmarkTfLiteModel() + : BenchmarkModel(GetDefaultParams()) { + AddListener(&profiling_listener_); +} + +BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params) + : BenchmarkModel(std::move(params)) { + AddListener(&profiling_listener_); +} + +std::vector BenchmarkTfLiteModel::GetFlags() { + std::vector flags = BenchmarkTfLiteModel::BenchmarkModel::GetFlags(); + std::vector specific_flags = { + CreateFlag("graph", ¶ms_, "graph file name"), + CreateFlag("input_layer", ¶ms_, "input layer names"), + CreateFlag("input_layer_shape", ¶ms_, + "input layer shape"), + CreateFlag("use_nnapi", ¶ms_, "use nnapi api")}; + + flags.insert(flags.end(), specific_flags.begin(), specific_flags.end()); + return flags; +} + +void BenchmarkTfLiteModel::LogFlags() { + BenchmarkModel::LogFlags(); + TFLITE_LOG(INFO) << "Graph: [" << params_.Get("graph") << "]"; + TFLITE_LOG(INFO) << "Input layers: [" + << params_.Get("input_layer") << "]"; + TFLITE_LOG(INFO) << "Input shapes: [" + << params_.Get("input_layer_shape") << "]"; + TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get("use_nnapi") << "]"; +} + +bool BenchmarkTfLiteModel::ValidateFlags() { + if (params_.Get("graph").empty()) { + TFLITE_LOG(ERROR) + << "Please specify the name of your TF Lite input file with --graph"; + return false; + } + return PopulateInputLayerInfo(params_.Get("input_layer"), + params_.Get("input_layer_shape"), + &inputs); +} + +uint64_t BenchmarkTfLiteModel::ComputeInputBytes() { + TFLITE_BENCHMARK_CHECK(interpreter); + uint64_t total_input_bytes = 0; + for (int input : interpreter->inputs()) { + auto* t = interpreter->tensor(input); + total_input_bytes += t->bytes; + } + return total_input_bytes; +} + +void BenchmarkTfLiteModel::Init() { + std::string graph = params_.Get("graph"); + model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); + if (!model) { + TFLITE_LOG(FATAL) << "Failed to mmap model " << graph; + } + TFLITE_LOG(INFO) << "Loaded model " << graph; + model->error_reporter(); + TFLITE_LOG(INFO) << "resolved reporter"; + +#ifdef TFLITE_CUSTOM_OPS_HEADER + tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); +#else + tflite::ops::builtin::BuiltinOpResolver resolver; +#endif + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + TFLITE_LOG(FATAL) << "Failed to construct interpreter"; + } + profiling_listener_.SetInterpreter(interpreter.get()); + + const int32_t num_threads = params_.Get("num_threads"); + + if (num_threads != -1) { + interpreter->SetNumThreads(num_threads); + } + + bool use_nnapi = params_.Get("use_nnapi"); + + interpreter->UseNNAPI(use_nnapi); + auto interpreter_inputs = interpreter->inputs(); + + if (!inputs.empty()) { + TFLITE_BENCHMARK_CHECK_EQ(inputs.size(), interpreter_inputs.size()) + << "Inputs mismatch: Model inputs #:" << interpreter_inputs.size() + << " expected: " << inputs.size(); + } + + // TFLITE_BENCHMARK_CHECK that all names and types match + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + TFLITE_BENCHMARK_CHECK_EQ(t->name, input.name) + << "Tensor # " << i << " is named " << t->name << " but flags call it " + << input.name; + } + + // Resize all non-string tensors. + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + if (t->type != kTfLiteString) { + interpreter->ResizeInputTensor(i, input.shape); + } + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + TFLITE_LOG(FATAL) << "Failed to allocate tensors!"; + } + + // Set the values of the input tensors. + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + std::vector sizes = input.shape; + + // TODO(ahentz): below we ignore the O-th dimension (number of batches). + if (t->type == kTfLiteFloat32) { + FillRandomValue( + interpreter->typed_tensor(i), + std::vector(sizes.begin() + 1, sizes.end()), + []() { return static_cast(rand()) / RAND_MAX - 0.5f; }); + } else if (t->type == kTfLiteUInt8) { + FillRandomValue( + interpreter->typed_tensor(i), + std::vector(sizes.begin() + 1, sizes.end()), + []() { return static_cast(rand()) % 255; }); + } else if (t->type == kTfLiteString) { + tflite::DynamicBuffer buffer; + FillRandomString(&buffer, sizes, []() { + return "we're have some friends over saturday to hang out in the yard"; + }); + buffer.WriteToTensor(interpreter->tensor(i)); + } else { + TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name + << " of type " << t->type; + } + } +} + +void BenchmarkTfLiteModel::RunImpl() { + if (interpreter->Invoke() != kTfLiteOk) { + TFLITE_LOG(FATAL) << "Failed to invoke!"; + } +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h new file mode 100644 index 0000000000000000000000000000000000000000..50cc3f24b3bd2f31555eac69ff208fa2480449b9 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_ + +#include +#include +#include + +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" + +namespace tflite { +namespace benchmark { + +// Dumps profiling events if profiling is enabled +class ProfilingListener : public BenchmarkListener { + public: + explicit ProfilingListener() : interpreter_(nullptr), has_profiles_(false) {} + + void SetInterpreter(Interpreter* interpreter); + + void OnSingleRunStart(RunType run_type) override; + + void OnSingleRunEnd() override; + + void OnBenchmarkEnd(const BenchmarkResults& results) override; + + private: + Interpreter* interpreter_; + profiling::Profiler profiler_; + profiling::ProfileSummarizer summarizer_; + bool has_profiles_; +}; + +// Benchmarks a TFLite model by running tflite interpreter. +class BenchmarkTfLiteModel : public BenchmarkModel { + public: + BenchmarkTfLiteModel(); + BenchmarkTfLiteModel(BenchmarkParams params); + + std::vector GetFlags() override; + void LogFlags() override; + bool ValidateFlags() override; + uint64_t ComputeInputBytes() override; + void Init() override; + void RunImpl() override; + virtual ~BenchmarkTfLiteModel() {} + + struct InputLayerInfo { + std::string name; + std::vector shape; + }; + + private: + std::unique_ptr model; + std::unique_ptr interpreter; + std::vector inputs; + ProfilingListener profiling_listener_; +}; + +} // namespace benchmark +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff818b9dcb5ee0b58b95c3dceae74083dbd4f0da --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc @@ -0,0 +1,198 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" + +#include +#include +#include +#include +#include + +namespace tflite { +namespace { + +template +std::string ToString(T val) { + std::ostringstream stream; + stream << val; + return stream.str(); +} + +bool ParseFlag(const std::string& arg, const std::string& flag, + const std::function& parse_func, + bool* value_parsing_ok) { + *value_parsing_ok = true; + std::string flag_prefix = "--" + flag + "="; + if (arg.find(flag_prefix) != 0) { + return false; + } + bool has_value = arg.size() >= flag_prefix.size(); + *value_parsing_ok = has_value; + if (has_value) { + *value_parsing_ok = parse_func(arg.substr(flag_prefix.size())); + } + return true; +} + +template +bool ParseFlag(const std::string& flag_value, + const std::function& hook) { + std::istringstream stream(flag_value); + T read_value; + stream >> read_value; + if (!stream.eof() && !stream.good()) { + return false; + } + hook(read_value); + return true; +} + +bool ParseBoolFlag(const std::string& flag_value, + const std::function& hook) { + if (flag_value != "true" && flag_value != "false") { + return false; + } + + hook(flag_value == "true"); + return true; +} +} // namespace + +Flag::Flag(const char* name, const std::function& hook, + int32_t default_value, const std::string& usage_text) + : name_(name), + type_(TYPE_INT32), + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); + }), + default_for_display_(ToString(default_value)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, const std::function& hook, + int64_t default_value, const std::string& usage_text) + : name_(name), + type_(TYPE_INT64), + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); + }), + default_for_display_(ToString(default_value)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, const std::function& hook, + float default_value, const std::string& usage_text) + : name_(name), + type_(TYPE_FLOAT), + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); + }), + default_for_display_(ToString(default_value)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, const std::function& hook, + bool default_value, const std::string& usage_text) + : name_(name), + type_(TYPE_BOOL), + value_hook_([hook](const std::string& flag_value) { + return ParseBoolFlag(flag_value, hook); + }), + default_for_display_(default_value ? "true" : "false"), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, + const std::function& hook, + const std::string& default_value, const std::string& usage_text) + : name_(name), + type_(TYPE_STRING), + value_hook_([hook](const std::string& flag_value) { + hook(flag_value); + return true; + }), + default_for_display_(default_value), + usage_text_(usage_text) {} + +bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const { + return ParseFlag(arg, name_, value_hook_, value_parsing_ok); +} + +std::string Flag::GetTypeName() const { + switch (type_) { + case TYPE_INT32: + return "int32"; + case TYPE_INT64: + return "int64"; + case TYPE_FLOAT: + return "float"; + case TYPE_BOOL: + return "bool"; + case TYPE_STRING: + return "string"; + } + + return "unknown"; +} + +/*static*/ bool Flags::Parse(int* argc, const char** argv, + const std::vector& flag_list) { + bool result = true; + std::vector unknown_flags; + for (int i = 1; i < *argc; ++i) { + if (std::string(argv[i]) == "--") { + while (i < *argc) { + unknown_flags.push_back(argv[i]); + ++i; + } + break; + } + + bool was_found = false; + for (const Flag& flag : flag_list) { + bool value_parsing_ok; + was_found = flag.Parse(argv[i], &value_parsing_ok); + if (!value_parsing_ok) { + result = false; + } + if (was_found) { + break; + } + } + if (!was_found) { + unknown_flags.push_back(argv[i]); + } + } + int dst = 1; // Skip argv[0] + for (auto f : unknown_flags) { + argv[dst++] = f; + } + argv[dst++] = nullptr; + *argc = unknown_flags.size() + 1; + return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0); +} + +/*static*/ std::string Flags::Usage(const std::string& cmdline, + const std::vector& flag_list) { + std::ostringstream usage_text; + usage_text << "usage: " << cmdline << "\n"; + if (!flag_list.empty()) { + usage_text << "Flags:\n"; + } + + for (const Flag& flag : flag_list) { + auto type_name = flag.GetTypeName(); + usage_text << "\t"; + usage_text << "--" << flag.name_ << "=" << flag.default_for_display_; + usage_text << "\t" << type_name << "\t" << flag.usage_text_ << "\n"; + } + return usage_text.str(); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..2e514ae3ead3b602b8217998ec09177b1e6a2376 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h @@ -0,0 +1,123 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ + +#include +#include +#include + +namespace tflite { +// A simple command-line argument parsing module. +// Dependency free simplified port of core/util/command_line_flags. +// This class is written for benchmarks and uses inefficient string +// concatenation. This was written to avoid dependency on tensorflow/core/util +// which transitively brings in a lot of other dependencies that are not +// necessary for tflite benchmarking code. +// The recommended way of using it is with local variables and an initializer +// list of Flag objects, for example: +// +// int some_int = 10; +// bool some_switch = false; +// std::string some_name = "something"; +// +// std::vector flag_list = { +// Flag::CreateFlag("some_int", &some_int, "an integer that affects X"), +// Flag::CreateFlag("some_switch", &some_switch, "a bool that affects Y"), +// Flag::CreateFlag("some_name", &some_name, "a string that affects Z") +// }; +// // Get usage message before ParseFlags() to capture default values. +// std::string usage = Flag::Usage(argv[0], flag_list); +// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list); +// +// tensorflow::port::InitMain(usage.c_str(), &argc, &argv); +// if (argc != 1 || !parsed_values_ok) { +// ...output usage and error message... +// } +// +// The argc and argv values are adjusted by the Parse function so all that +// remains is the program name (at argv[0]) and any unknown arguments fill the +// rest of the array. This means you can check for flags that weren't understood +// by seeing if argv is greater than 1. +// The result indicates if there were any errors parsing the values that were +// passed to the command-line switches. For example, --some_int=foo would return +// false because the argument is expected to be an integer. +// +// NOTE: Unlike gflags-style libraries, this library is intended to be +// used in the `main()` function of your binary. It does not handle +// flag definitions that are scattered around the source code. + +// A description of a single command line flag, holding its name, type, usage +// text, and a pointer to the corresponding variable. +class Flag { + public: + template + static Flag CreateFlag(const char* name, T* val, const char* usage) { + return Flag(name, [val](const T& v) { *val = v; }, *val, usage); + } + + Flag(const char* name, const std::function& hook, + int32_t default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + int64_t default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + float default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + bool default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + const std::string& default_value, const std::string& usage_text); + + private: + friend class Flags; + + bool Parse(const std::string& arg, bool* value_parsing_ok) const; + + std::string name_; + enum { + TYPE_INT32, + TYPE_INT64, + TYPE_BOOL, + TYPE_STRING, + TYPE_FLOAT, + } type_; + + std::string GetTypeName() const; + + std::function value_hook_; + std::string default_for_display_; + + std::string usage_text_; +}; + +class Flags { + public: + // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag + // instances matching flags in flaglist[]. Update the variables associated + // with matching flags, and remove the matching arguments from (*argc, argv). + // Return true iff all recognized flag values were parsed correctly, and the + // first remaining argument is not "--help". + static bool Parse(int* argc, const char** argv, + const std::vector& flag_list); + + // Return a usage message with command line cmdline, and the + // usage_text strings in flag_list[]. + static std::string Usage(const std::string& cmdline, + const std::vector& flag_list); +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..03da8051099899241fa5241374d754adb1aa93c6 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace { + +TEST(CommandLineFlagsTest, BasicUsage) { + int some_int32 = 10; + int64_t some_int64 = 21474836470; // max int32 is 2147483647 + bool some_switch = false; + std::string some_name = "something_a"; + float some_float = -23.23f; + const char* argv_strings[] = {"program_name", + "--some_int32=20", + "--some_int64=214748364700", + "--some_switch=true", + "--some_name=somethingelse", + "--some_float=42.0"}; + int argc = 6; + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + { + Flag::CreateFlag("some_int32", &some_int32, "some int32"), + Flag::CreateFlag("some_int64", &some_int64, "some int64"), + Flag::CreateFlag("some_switch", &some_switch, "some switch"), + Flag::CreateFlag("some_name", &some_name, "some name"), + Flag::CreateFlag("some_float", &some_float, "some float"), + }); + + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(20, some_int32); + EXPECT_EQ(214748364700, some_int64); + EXPECT_EQ(true, some_switch); + EXPECT_EQ("somethingelse", some_name); + EXPECT_NEAR(42.0f, some_float, 1e-5f); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, EmptyStringFlag) { + int argc = 2; + std::string some_string = "invalid"; + const char* argv_strings[] = {"program_name", "--some_string="}; + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_string", &some_string, "some string")}); + + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(some_string, ""); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadIntValue) { + int some_int = 10; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_int=notanumber"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int, "some int")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(10, some_int); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadBoolValue) { + bool some_switch = false; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_switch=notabool"}; + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_switch", &some_switch, "some switch")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(false, some_switch); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadFloatValue) { + float some_float = -23.23f; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_float=notanumber"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_float", &some_float, "some float")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_NEAR(-23.23f, some_float, 1e-5f); + EXPECT_EQ(argc, 1); +} + +// Return whether str==pat, but allowing any whitespace in pat +// to match zero or more whitespace characters in str. +static bool MatchWithAnyWhitespace(const std::string& str, + const std::string& pat) { + bool matching = true; + int pat_i = 0; + for (int str_i = 0; str_i != str.size() && matching; str_i++) { + if (isspace(str[str_i])) { + matching = (pat_i != pat.size() && isspace(pat[pat_i])); + } else { + while (pat_i != pat.size() && isspace(pat[pat_i])) { + pat_i++; + } + matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]); + } + } + while (pat_i != pat.size() && isspace(pat[pat_i])) { + pat_i++; + } + return (matching && pat_i == pat.size()); +} + +TEST(CommandLineFlagsTest, UsageString) { + int some_int = 10; + int64_t some_int64 = 21474836470; // max int32 is 2147483647 + bool some_switch = false; + std::string some_name = "something"; + // Don't test float in this case, because precision is hard to predict and + // match against, and we don't want a flakey test. + const std::string tool_name = "some_tool_name"; + std::string usage = Flags::Usage( + tool_name + " ", + {Flag::CreateFlag("some_int", &some_int, "some int"), + Flag::CreateFlag("some_int64", &some_int64, "some int64"), + Flag::CreateFlag("some_switch", &some_switch, "some switch"), + Flag::CreateFlag("some_name", &some_name, "some name")}); + // Match the usage message, being sloppy about whitespace. + const char* expected_usage = + " usage: some_tool_name \n" + "Flags:\n" + "--some_int=10\tint32\tsome int\n" + "--some_int64=21474836470\tint64\tsome int64\n" + "--some_switch=false\tbool\tsome switch\n" + "--some_name=something\tstring\tsome name\n"; + ASSERT_EQ(MatchWithAnyWhitespace(usage, expected_usage), true) << usage; + + // Again but with no flags. + usage = Flags::Usage(tool_name, {}); + ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true) + << usage; +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/README.md b/tensorflow/contrib/lite/tools/benchmark/ios/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c8d3307e29efaebdc5c309dc7e4262b54d64943f --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/README.md @@ -0,0 +1,43 @@ +# TFLite iOS benchmark app. + +## Description + +An iOS app to benchmark TFLite models. + +The app reads benchmark parameters from a JSON file named `benchmark_params.json` +in its `benchmark_data` directory. Any downloaded models for benchmarking should +also be placed in `benchmark_data` directory. + +The JSON file specifies the name of the model file and other benchmarking +parameters like inputs to the model, type of inputs, number of iterations, +number of threads. The default values in the JSON file are for the +Mobilenet_1.0_224 model +([paper](https://arxiv.org/pdf/1704.04861.pdf), +[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)) + +## To build/install/run + +- Follow instructions at [iOS build for TFLite] +(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md) +to build TFLite. + +Running + +```bash +tensorflow/contrib/lite/build_ios_universal_lib.sh +``` +will also build `tensorflow/contrib/lite/gen/lib/benchmark-lib.a` . + +- Now copy the downloaded model file to `benchmark_data` directory. + +- Modify `benchmark_params.json` change the `input_layer`, `input_layer_shape` +and other benchmark parameters. + +- Change `Build Phases -> Copy Bundle Resources` and add the model file to the +resources that need to be copied. + +- Ensure that `Build Phases -> Link Binary With Library` contains the +`Accelerate framework` and `tensorflow/contrib/lite/gen/lib/benchmark-lib.a`. + +- Now try running the app. The app has a single button that runs the benchmark + on the model and displays results in a text view below. diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..b908f733d49b56a6b41ebea4185f1fe8c11edc60 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj @@ -0,0 +1,381 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 50; + objects = { + +/* Begin PBXBuildFile section */ + 6FE7579A20D59CE500F01636 /* benchmark_params.json in Resources */ = {isa = PBXBuildFile; fileRef = 6FE7579920D59CE500F01636 /* benchmark_params.json */; }; + 6FE7579D20D5A5E000F01636 /* benchmark-lib.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */; }; + 6FE7579F20D5A6A700F01636 /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 6FE7579E20D5A6A700F01636 /* Accelerate.framework */; }; + 6FE757A120D5AB8100F01636 /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */; }; + 6FE93FFD20D592D8008C9FE4 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */; }; + 6FE9400020D592D8008C9FE4 /* BenchmarkViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */; }; + 6FE9400320D592D8008C9FE4 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 6FE9400120D592D8008C9FE4 /* Main.storyboard */; }; + 6FE9400520D592DA008C9FE4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 6FE9400420D592DA008C9FE4 /* Assets.xcassets */; }; + 6FE9400B20D592DA008C9FE4 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 6FE9400A20D592DA008C9FE4 /* main.m */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 6FE7579920D59CE500F01636 /* benchmark_params.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = benchmark_params.json; sourceTree = ""; }; + 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "benchmark-lib.a"; path = "$SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib/benchmark-lib.a"; sourceTree = ""; }; + 6FE7579E20D5A6A700F01636 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; + 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = ""; }; + 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TFLiteBenchmark.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 6FE93FFB20D592D8008C9FE4 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = ""; }; + 6FE93FFE20D592D8008C9FE4 /* BenchmarkViewController.h */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.h; path = BenchmarkViewController.h; sourceTree = ""; }; + 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = BenchmarkViewController.mm; sourceTree = ""; }; + 6FE9400220D592D8008C9FE4 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; + 6FE9400420D592DA008C9FE4 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + 6FE9400920D592DA008C9FE4 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 6FE9400A20D592DA008C9FE4 /* main.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = main.m; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 6FE93FF520D592D8008C9FE4 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE7579F20D5A6A700F01636 /* Accelerate.framework in Frameworks */, + 6FE7579D20D5A5E000F01636 /* benchmark-lib.a in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 6FE7579820D59C8B00F01636 /* benchmark_data */ = { + isa = PBXGroup; + children = ( + 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */, + 6FE7579920D59CE500F01636 /* benchmark_params.json */, + ); + path = benchmark_data; + sourceTree = ""; + }; + 6FE7579B20D5A5E000F01636 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 6FE7579E20D5A6A700F01636 /* Accelerate.framework */, + 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 6FE93FEF20D592D8008C9FE4 = { + isa = PBXGroup; + children = ( + 6FE93FFA20D592D8008C9FE4 /* TFLiteBenchmark */, + 6FE93FF920D592D8008C9FE4 /* Products */, + 6FE7579B20D5A5E000F01636 /* Frameworks */, + ); + sourceTree = ""; + }; + 6FE93FF920D592D8008C9FE4 /* Products */ = { + isa = PBXGroup; + children = ( + 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */, + ); + name = Products; + sourceTree = ""; + }; + 6FE93FFA20D592D8008C9FE4 /* TFLiteBenchmark */ = { + isa = PBXGroup; + children = ( + 6FE7579820D59C8B00F01636 /* benchmark_data */, + 6FE93FFB20D592D8008C9FE4 /* AppDelegate.h */, + 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */, + 6FE93FFE20D592D8008C9FE4 /* BenchmarkViewController.h */, + 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */, + 6FE9400120D592D8008C9FE4 /* Main.storyboard */, + 6FE9400420D592DA008C9FE4 /* Assets.xcassets */, + 6FE9400920D592DA008C9FE4 /* Info.plist */, + 6FE9400A20D592DA008C9FE4 /* main.m */, + ); + path = TFLiteBenchmark; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 6FE93FF720D592D8008C9FE4 /* TFLiteBenchmark */ = { + isa = PBXNativeTarget; + buildConfigurationList = 6FE9400E20D592DA008C9FE4 /* Build configuration list for PBXNativeTarget "TFLiteBenchmark" */; + buildPhases = ( + 6FE93FF420D592D8008C9FE4 /* Sources */, + 6FE93FF520D592D8008C9FE4 /* Frameworks */, + 6FE93FF620D592D8008C9FE4 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = TFLiteBenchmark; + productName = TFLiteBenchmark; + productReference = 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 6FE93FF020D592D8008C9FE4 /* Project object */ = { + isa = PBXProject; + attributes = { + LastUpgradeCheck = 1000; + ORGANIZATIONNAME = Example; + TargetAttributes = { + 6FE93FF720D592D8008C9FE4 = { + CreatedOnToolsVersion = 10.0; + }; + }; + }; + buildConfigurationList = 6FE93FF320D592D8008C9FE4 /* Build configuration list for PBXProject "TFLiteBenchmark" */; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 6FE93FEF20D592D8008C9FE4; + productRefGroup = 6FE93FF920D592D8008C9FE4 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 6FE93FF720D592D8008C9FE4 /* TFLiteBenchmark */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 6FE93FF620D592D8008C9FE4 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE757A120D5AB8100F01636 /* mobilenet_v1_1.0_224.tflite in Resources */, + 6FE9400520D592DA008C9FE4 /* Assets.xcassets in Resources */, + 6FE9400320D592D8008C9FE4 /* Main.storyboard in Resources */, + 6FE7579A20D59CE500F01636 /* benchmark_params.json in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 6FE93FF420D592D8008C9FE4 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE9400020D592D8008C9FE4 /* BenchmarkViewController.mm in Sources */, + 6FE9400B20D592DA008C9FE4 /* main.m in Sources */, + 6FE93FFD20D592D8008C9FE4 /* AppDelegate.m in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXVariantGroup section */ + 6FE9400120D592D8008C9FE4 /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 6FE9400220D592D8008C9FE4 /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; +/* End PBXVariantGroup section */ + +/* Begin XCBuildConfiguration section */ + 6FE9400C20D592DA008C9FE4 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 11.0; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + ONLY_ACTIVE_ARCH = YES; + OTHER_CFLAGS = ""; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + SDKROOT = iphoneos; + }; + name = Debug; + }; + 6FE9400D20D592DA008C9FE4 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 11.0; + MTL_ENABLE_DEBUG_INFO = NO; + OTHER_CFLAGS = ""; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + SDKROOT = iphoneos; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 6FE9400F20D592DA008C9FE4 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + "HEADER_SEARCH_PATHS[arch=*]" = ( + $SRCROOT/../../../../../../../, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/eigen, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/gemmlowp, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/neon_2_sse, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/farmhash/src, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/flatbuffers/include, + ); + INFOPLIST_FILE = TFLiteBenchmark/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib; + PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark; + PRODUCT_NAME = "$(TARGET_NAME)"; + TARGETED_DEVICE_FAMILY = "1,2"; + "USER_HEADER_SEARCH_PATHS[arch=*]" = ""; + }; + name = Debug; + }; + 6FE9401020D592DA008C9FE4 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + "HEADER_SEARCH_PATHS[arch=*]" = ( + $SRCROOT/../../../../../../../, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/eigen, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/gemmlowp, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/neon_2_sse, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/farmhash/src, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/flatbuffers/include, + ); + INFOPLIST_FILE = TFLiteBenchmark/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib; + PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark; + PRODUCT_NAME = "$(TARGET_NAME)"; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 6FE93FF320D592D8008C9FE4 /* Build configuration list for PBXProject "TFLiteBenchmark" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 6FE9400C20D592DA008C9FE4 /* Debug */, + 6FE9400D20D592DA008C9FE4 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 6FE9400E20D592DA008C9FE4 /* Build configuration list for PBXNativeTarget "TFLiteBenchmark" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 6FE9400F20D592DA008C9FE4 /* Debug */, + 6FE9401020D592DA008C9FE4 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 6FE93FF020D592D8008C9FE4 /* Project object */; +} diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h new file mode 100644 index 0000000000000000000000000000000000000000..a55c03e00b5065e3b149c65f820f11d13c064d87 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h @@ -0,0 +1,22 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m new file mode 100644 index 0000000000000000000000000000000000000000..b1165940e9a29ac693d473a1c852b7b0681392fc --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m @@ -0,0 +1,27 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + return YES; +} +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..d8db8d65fd79fd541b2b7eba75c7378af3448f9c --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,98 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : "ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..da4a164c918651cdd1e11dca5cc62c333f097601 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..bfa36129419f8bd7ad73581cb9f07b8c6eec3fcf --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..adcfe1ef4e708ea6f87c77f4a740b58e5027d3e5 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h new file mode 100644 index 0000000000000000000000000000000000000000..ec6dea0546060881682c44ad451f4812a2f3d7ea --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h @@ -0,0 +1,21 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface BenchmarkViewController : UIViewController +@property(weak, nonatomic) IBOutlet UITextView *resultsView; + +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm new file mode 100644 index 0000000000000000000000000000000000000000..356d5b0e17abc715de9b8f7a20ec7459f3468da1 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm @@ -0,0 +1,125 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "BenchmarkViewController.h" +#import +#import +#import +#import +#import "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#import "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace { +NSString* FilePathForResourceName(NSString* filename) { + NSString* name = [filename stringByDeletingPathExtension]; + NSString* extension = [filename pathExtension]; + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + TFLITE_LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; + } + return file_path; +} + +NSDictionary* ParseJson() { + NSString* params_json_path = FilePathForResourceName(@"benchmark_params.json"); + NSData* data = [NSData dataWithContentsOfFile:params_json_path]; + return [NSJSONSerialization JSONObjectWithData:data options:kNilOptions error:nil]; +} + +std::string FormatCommandLineParam(NSString* key, NSString* value) { + std::ostringstream stream; + stream << "--" << [key UTF8String] << "=" << [value UTF8String]; + return stream.str(); +} + +// Reads the |benchmark_params.json| to read command line parameters and returns them as a vector of +// strings. +void ReadCommandLineParameters(std::vector* params) { + NSDictionary* param_dict = ParseJson(); + for (NSString* key in param_dict) { + NSString* value = param_dict[key]; + if ([key isEqualToString:@"graph"]) { + value = FilePathForResourceName(value); + } + params->push_back(FormatCommandLineParam(key, value)); + } +} +std::vector StringVecToCharPtrVec(const std::vector& str_vec) { + std::vector charptr_vec; + std::transform(str_vec.begin(), str_vec.end(), std::back_inserter(charptr_vec), + [](const std::string& s) -> char* { return const_cast(s.c_str()); }); + return charptr_vec; +} + +class ResultsListener : public tflite::benchmark::BenchmarkListener { + public: + void OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults& results) override; + std::string Results() { return results_; } + + private: + std::string results_; +}; + +void OutputMicrosecondsStatToStream(const tensorflow::Stat& time_us, + const std::string& prefix, std::ostringstream* stream) { + *stream << prefix << "Num runs: " << time_us.count() << "\n"; + + *stream << prefix << "Average: " << time_us.avg() / 1e3 << " ms\n"; + *stream << prefix << "Min: " << time_us.min() / 1e3 << " ms \n"; + *stream << prefix << "Max: " << time_us.max() / 1e3 << " ms \n"; + *stream << prefix << "Std deviation: " << time_us.std_deviation() / 1e3 << " ms\n"; +} + +void ResultsListener::OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults& results) { + std::ostringstream stream; + const std::string prefix = " - "; + stream << "Startup latency: "; + stream << results.startup_latency_us() / 1e3 << " ms\n"; + stream << "\nInference:\n"; + OutputMicrosecondsStatToStream(results.inference_time_us(), prefix, &stream); + stream << "\nWarmup:\n"; + OutputMicrosecondsStatToStream(results.warmup_time_us(), prefix, &stream); + + results_ = stream.str(); +} + +std::string RunBenchmark() { + ResultsListener listener; + tflite::benchmark::BenchmarkTfLiteModel benchmark; + benchmark.AddListener(&listener); + // TODO(shashishekhar): Passing arguments like this is brittle, refactor the BenchmarkParams + // so that it contains arguments for BenchmarkTfLiteModel and set parameters using BenchmarkParams + std::vector command_line_params; + // Benchmark model expects first arg to be program name. + // push a string for name of program. + command_line_params.push_back("benchmark_tflite_model"); + ReadCommandLineParameters(&command_line_params); + std::vector argv = StringVecToCharPtrVec(command_line_params); + int argc = static_cast(argv.size()); + benchmark.Run(argc, argv.data()); + return listener.Results(); +} +} // namespace + +@interface BenchmarkViewController () +@end + +@implementation BenchmarkViewController +- (IBAction)onBenchmarkModel:(UIButton*)sender { + std::string results = RunBenchmark(); + [_resultsView setText:[NSString stringWithUTF8String:results.c_str()]]; +} +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..96051cf08ff54b51f458eca6f0126dd99dfc51dc --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist @@ -0,0 +1,43 @@ + + + + + UILaunchStoryboardName + Main + CFBundleDevelopmentRegion + $(DEVELOPMENT_LANGUAGE) + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json new file mode 100644 index 0000000000000000000000000000000000000000..d344a7a5efaef53500bc0f88d29ca7aecf59290a --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json @@ -0,0 +1,10 @@ +{ + "benchmark_name" : "mobile_net_benchmark", + "num_threads" : "4", + "num_runs" : "20", + "warmup_runs" : "1", + "graph" : "mobilenet_v1_1.0_224.tflite", + "input_layer" : "input", + "input_layer_shape" : "1,224,224,3", + "run_delay" : "-1" +} diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m new file mode 100644 index 0000000000000000000000000000000000000000..1e70b9cd1d82f320ec048642520dbc54dc0f7934 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m @@ -0,0 +1,23 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "AppDelegate.h" + +int main(int argc, char* argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/tensorflow/contrib/lite/tools/benchmark/logging.h b/tensorflow/contrib/lite/tools/benchmark/logging.h new file mode 100644 index 0000000000000000000000000000000000000000..9e9292e2feacf0eff0751534f02cdacd21c9b0dd --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/logging.h @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_ + +// LOG and CHECK macros for benchmarks. + +#include +#include +#include + +namespace tflite { +namespace logging { +// A wrapper that logs to stderr. +// +// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros. +class LoggingWrapper { + public: + enum class LogSeverity : int { + INFO = 0, + WARN = 1, + ERROR = 2, + FATAL = 3, + }; + LoggingWrapper(LogSeverity severity) + : severity_(severity), should_log_(true) {} + LoggingWrapper(LogSeverity severity, bool log) + : severity_(severity), should_log_(log) {} + std::stringstream& Stream() { return stream_; } + ~LoggingWrapper() { + if (should_log_) { + std::cerr << stream_.str() << std::endl; + if (severity_ == LogSeverity::FATAL) { + std::flush(std::cerr); + std::abort(); + } + } + } + + private: + std::stringstream stream_; + LogSeverity severity_; + bool should_log_; +}; + +} // namespace logging + +} // namespace tflite + +#define TFLITE_LOG(severity) \ + tflite::logging::LoggingWrapper( \ + tflite::logging::LoggingWrapper::LogSeverity::severity) \ + .Stream() + +#define TFLITE_BENCHMARK_CHECK(condition) \ + tflite::logging::LoggingWrapper( \ + tflite::logging::LoggingWrapper::LogSeverity::FATAL, \ + (condition) ? false : true) \ + .Stream() + +#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b) + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc deleted file mode 100644 index 869c531b3e3db37f634761e7b25d4ffa1e8304a7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/tools/benchmark_model.cc +++ /dev/null @@ -1,475 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/command_line_flags.h" - -#ifdef TFLITE_CUSTOM_OPS_HEADER -void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); -#endif - -namespace tflite { - -using ::tensorflow::Env; -using ::tensorflow::str_util::Split; -using ::tensorflow::str_util::SplitAndParseAsFloats; -using ::tensorflow::str_util::SplitAndParseAsInts; - -struct InputLayerInfo { - string name; - TfLiteType data_type; - std::vector shape; - // Note that initialization_values is currently unused. - std::vector initialization_values; -}; - -template -void FillRandomValue(T* ptr, const std::vector& sizes, - const std::function& random_func) { - int num_elements = 1; - for (int dim : sizes) { - num_elements *= dim; - } - for (int i = 0; i < num_elements; ++i) { - *ptr++ = random_func(); - } -} - -void FillRandomString(tflite::DynamicBuffer* buffer, - const std::vector& sizes, - const std::function& random_func) { - int num_elements = 1; - for (int dim : sizes) { - num_elements *= dim; - } - for (int i = 0; i < num_elements; ++i) { - auto str = random_func(); - buffer->AddString(str.data(), str.length()); - } -} - -TfLiteType TfLiteTypeFromString(const string& input_layer_type) { - if (input_layer_type == "string") - return kTfLiteString; - else if (input_layer_type == "float") - return kTfLiteFloat32; - else if (input_layer_type == "uint8") - return kTfLiteUInt8; - else if (input_layer_type == "int32") - return kTfLiteInt32; - else if (input_layer_type == "int64") - return kTfLiteInt64; - else - return kTfLiteNoType; -} - -std::vector ShapeFromTfLiteTensor(TfLiteTensor* t) { - std::vector result; - result.reserve(t->dims->size); - for (int i = 0; i < t->dims->size; ++i) { - result.push_back(t->dims->data[i]); - } - CHECK(!result.empty()) << "Found no shapes in model"; - return result; -} - -bool CreateInterpreter(const string& graph, - std::unique_ptr* model, - std::unique_ptr* interpreter) { - *model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); - if (!model) { - std::cerr << "Failed to load model " << graph << std::endl; - return false; - } - -#ifdef TFLITE_CUSTOM_OPS_HEADER - tflite::MutableOpResolver resolver; - RegisterSelectedOps(&resolver); -#else - tflite::ops::builtin::BuiltinOpResolver resolver; -#endif - - tflite::InterpreterBuilder(*(model->get()), resolver)(interpreter); - if (!(*interpreter)) { - std::cerr << "Failed to construct interpreter" << std::endl; - return false; - } - - return true; -} - -bool PrepareInterpreter(const std::vector inputs, - int num_threads, bool use_nnapi, - Interpreter* interpreter) { - if (num_threads != -1) { - interpreter->SetNumThreads(num_threads); - } - - interpreter->UseNNAPI(use_nnapi); - - // Check that all names and types match - for (const InputLayerInfo& input : inputs) { - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - CHECK_EQ(t->name, input.name) - << "Tensor # " << i << " is named " << t->name - << " but flags call it " << input.name; - CHECK_EQ(t->type, input.data_type) - << "Could not match the type of input tensor " << t->name; - } - } - - // Resize all non-string tensors. - for (const InputLayerInfo& input : inputs) { - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - if (t->type != kTfLiteString) { - interpreter->ResizeInputTensor(i, input.shape); - } - } - } - - if (interpreter->AllocateTensors() != kTfLiteOk) { - std::cerr << "Failed to allocate tensors!" << std::endl; - return false; - } - - // Set the values of the input tensors. - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - std::vector sizes = ShapeFromTfLiteTensor(t); - - // TODO(ahentz): below we ignore the O-th dimension (number of batches). - if (t->type == kTfLiteFloat32) { - FillRandomValue( - interpreter->typed_tensor(i), - std::vector(sizes.begin() + 1, sizes.end()), - []() { return static_cast(rand()) / RAND_MAX - 0.5f; }); - } else if (t->type == kTfLiteUInt8) { - FillRandomValue( - interpreter->typed_tensor(i), - std::vector(sizes.begin() + 1, sizes.end()), - []() { return static_cast(rand()) % 255; }); - } else if (t->type == kTfLiteString) { - tflite::DynamicBuffer buffer; - FillRandomString(&buffer, sizes, []() { - return "we're have some friends over saturday to hang out in the yard"; - }); - buffer.WriteToTensor(interpreter->tensor(i)); - } else { - std::cerr << "Don't know how to populate tensor " << t->name - << " of type " << t->type << std::endl; - return false; - } - } - return true; -} - -bool PopulateInputLayerInfo(const string& names_string, - const string& shapes_string, - const string& types_string, - const string& values_string, - std::vector* info) { - std::vector names = Split(names_string, ','); - std::vector shapes = Split(shapes_string, ':'); - std::vector types = Split(types_string, ','); - std::vector values = Split(values_string, ':'); - - if (names.size() != shapes.size()) { - LOG(ERROR) << "The number of items in" - << " --input_layer_shape (" << shapes_string << ", with " - << shapes.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." - << " For example --input_layer=input1,input2" - << " --input_layer_shape=1,224,224,4:1,20"; - return false; - } - if (names.size() != types.size()) { - LOG(ERROR) << "The number of items in" - << " --input_layer_type (" << types_string << ", with " - << types.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." - << " For example --input_layer=input1,input2" - << " --input_layer_type=float,int"; - return false; - } - - for (int i = 0; i < names.size(); ++i) { - info->push_back(InputLayerInfo()); - InputLayerInfo& input = info->back(); - - input.name = names[i]; - - input.data_type = TfLiteTypeFromString(types[i]); - CHECK(input.data_type != kTfLiteNoType) - << types[i] << " was an invalid type"; - - CHECK(SplitAndParseAsInts(shapes[i], ',', &input.shape)) - << "Incorrect size string specified: " << shapes[i]; - for (int dim : input.shape) { - if (dim == -1) { - LOG(ERROR) << "Any unknown sizes in the shapes (-1's) must be replaced" - << " with the size you want to benchmark with."; - return false; - } - } - - if (i < values.size()) { - CHECK(SplitAndParseAsFloats(values[i], ',', &input.initialization_values)) - << "Incorrect initialization values string specified: " << values[i]; - } - } - - return true; -} - -bool RunBenchmark(Interpreter* interpreter, int64_t* inference_time_us) { - const int64_t start_time = Env::Default()->NowMicros(); - - if (interpreter->Invoke() != kTfLiteOk) { - std::cerr << "Failed to invoke!"; - return false; - } - - const int64_t end_time = Env::Default()->NowMicros(); - *inference_time_us = end_time - start_time; - return true; -} - -class Latencies { - public: - void AddMeasurement(int64_t time_us) { - max_ = std::max(time_us, max_); - min_ = std::min(time_us, min_); - ++count_; - sum_ += time_us; - squared_sum_ += static_cast(time_us) * time_us; - } - - double avg() const { - if (count_ == 0) return std::numeric_limits::quiet_NaN(); - return static_cast(sum_) / count_; - } - - int64_t std_deviation() const { - if (count_ == 0 || min_ == max_) return 0; - return sqrt(squared_sum_ / count_ - avg() * avg()); - } - - void OutputToStream(std::ostream* stream) const { - *stream << "count=" << count_; - if (count_ == 0) return; - *stream << " min=" << min_ << " max=" << max_; - *stream << " avg=" << avg() << " std=" << std_deviation(); - } - - private: - int64_t count_ = 0; - int64_t min_ = std::numeric_limits::max(); - int64_t max_ = std::numeric_limits::min(); - int64_t sum_ = 0; - double squared_sum_ = 0; -}; - -bool TimeMultipleRuns(Interpreter* interpreter, double sleep_seconds, - int num_runs, int64* total_time_us) { - // Convert the run_delay string into a timespec. - timespec req; - req.tv_sec = static_cast(sleep_seconds); - req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; - - *total_time_us = 0; - - std::cout << "Running benchmark for " << num_runs - << " iterations: " << std::endl; - - Latencies latencies; - for (int i = 0; i < num_runs; ++i) { - int64_t time_us; - bool run_status = RunBenchmark(interpreter, &time_us); - latencies.AddMeasurement(time_us); - *total_time_us += time_us; - if (!run_status) { - std::cout << "Failed on run " << i << std::endl; - return false; - } - - // If requested, sleep between runs for an arbitrary amount of time. - // This can be helpful to determine the effect of mobile processor - // scaling and thermal throttling. - if (sleep_seconds > 0.0) { -#ifdef PLATFORM_WINDOWS - Sleep(sleep_seconds * 1000); -#else - nanosleep(&req, nullptr); -#endif - } - } - latencies.OutputToStream(&std::cout); - std::cout << std::endl; - - return true; -} - -int Main(int argc, char** argv) { - using tensorflow::Flag; - using tensorflow::Flags; - - string graph; // e.g.: /data/local/tmp/tfl_inception-v1_model.fb - string input_layer_string; // e.g.: input - string input_layer_shape_string; // e.g.: 1,224,224,3 - string input_layer_type_string; // e.g.: float - string input_layer_values_string; - string output_layer_string; // e.g.: output - int num_runs = 50; - string run_delay = "-1.0"; - int num_threads = 1; - string benchmark_name = ""; - string output_prefix = ""; - int warmup_runs = 1; - bool use_nnapi = false; - - std::vector flag_list = { - Flag("graph", &graph, "graph file name"), - // All the following flags are optional, but can be used in order - // to benchmark different input shapes. - Flag("input_layer", &input_layer_string, "input layer names"), - Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), - Flag("input_layer_type", &input_layer_type_string, "input layer type"), - Flag("input_layer_values", &input_layer_values_string, - "values to initialize the inputs with"), - Flag("output_layer", &output_layer_string, "output layer name"), - Flag("num_runs", &num_runs, "number of runs"), - Flag("run_delay", &run_delay, "delay between runs in seconds"), - Flag("num_threads", &num_threads, "number of threads"), - Flag("benchmark_name", &benchmark_name, "benchmark name"), - Flag("output_prefix", &output_prefix, "benchmark output prefix"), - Flag("warmup_runs", &warmup_runs, "how many runs to initialize model"), - Flag("use_nnapi", &use_nnapi, "use nnapi api"), - }; - string usage = Flags::Usage(argv[0], flag_list); - const bool parse_result = Flags::Parse(&argc, argv, flag_list); - tensorflow::port::InitMain(argv[0], &argc, &argv); - - if (!parse_result) { - std::cerr << usage << std::endl; - return -1; - } - - std::cout << "Graph: [" << graph << "]" << std::endl; - if (!input_layer_string.empty()) { - std::cout << "Input layers: [" << input_layer_string << "]" << std::endl; - std::cout << "Input shapes: [" << input_layer_shape_string << "]" - << std::endl; - std::cout << "Input types: [" << input_layer_type_string << "]" - << std::endl; - } - if (!output_layer_string.empty()) { - std::cout << "Output layers: [" << output_layer_string << "]" << std::endl; - } - std::cout << "Num runs: [" << num_runs << "]" << std::endl; - std::cout << "Inter-run delay (seconds): [" << run_delay << "]" << std::endl; - std::cout << "Num threads: [" << num_threads << "]" << std::endl; - if (!benchmark_name.empty()) { - std::cout << "Benchmark name: [" << benchmark_name << "]" << std::endl; - std::cout << "Output prefix: [" << output_prefix << "]" << std::endl; - } - std::cout << "Warmup runs: [" << warmup_runs << "]" << std::endl; - std::cout << "Use nnapi : [" << use_nnapi << "]" << std::endl; - - if (graph.empty()) { - std::cout - << "Please specify the name of your TF Lite input file with --graph" - << std::endl; - return -1; - } - - std::vector inputs; - if (!PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, - input_layer_type_string, - input_layer_values_string, &inputs)) { - return -1; - } - - int64 initialization_start_us = Env::Default()->NowMicros(); - - std::unique_ptr model; - std::unique_ptr interpreter; - if (!CreateInterpreter(graph, &model, &interpreter)) { - return -1; - } - if (!PrepareInterpreter(inputs, num_threads, use_nnapi, interpreter.get())) { - return -1; - } - - int64 initialization_end_us = Env::Default()->NowMicros(); - - const double initialization_time_s = - (initialization_end_us - initialization_start_us) / 1000000.0f; - std::cout << "Initialized session in " << initialization_time_s << "s" - << std::endl; - - const double sleep_seconds = std::strtod(run_delay.c_str(), nullptr); - - // If requested, run through the graph first to preinitialize everything - // before the benchmarking runs. - int64 warmup_time_us = 0; - if (warmup_runs > 0) { - if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, warmup_runs, - &warmup_time_us)) { - std::cerr << "Warmup failed" << std::endl; - return -1; - } - } - - // Capture overall inference time without stat logging overhead. This is the - // timing data that can be compared to other libaries. - int64 no_stat_time_us = 0; - if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, num_runs, - &no_stat_time_us)) { - std::cerr << "Timing failed." << std::endl; - return -1; - } - - std::cout << "Average inference timings in us: " << no_stat_time_us / num_runs - << " , Warmup: " - << (warmup_runs > 0 ? warmup_time_us / warmup_runs : 0) << ", " - << std::endl; - - return 0; -} - -} // namespace tflite - -int main(int argc, char** argv) { return ::tflite::Main(argc, argv); } diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc index ce8a7857d2dd66b12e9ea970911ef1dd01e4550e..ad7d59ecb41a0c81a6a4d8edae5fa6b4b5a7bede 100644 --- a/tensorflow/contrib/lite/tools/verifier_test.cc +++ b/tensorflow/contrib/lite/tools/verifier_test.cc @@ -41,7 +41,7 @@ class TfLiteFlatbufferModelBuilder { } TfLiteFlatbufferModelBuilder(const std::vector& builtin_ops, - const std::vector& custom_ops) { + const std::vector& custom_ops) { buffers_.push_back( CreateBuffer(builder_, builder_.CreateVector(std::vector{}))); @@ -194,8 +194,8 @@ TEST(VerifyModel, TensorBufferIsNotValid) { /*operators=*/0, builder.CreateString("Main"))}); auto buffers = builder.CreateVector(std::vector>{ - CreateBuffer(builder, - builder.CreateVector(std::vector{1, 2, 3, 4, 5, 6})), + CreateBuffer(builder, builder.CreateVector( + std::vector{1, 2, 3, 4, 5, 6})), }); auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, /*operator_codes=*/0, diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc index fb4af07d060cac3a6a4e01c7d625b6db5241f10d..8ccb65c24fd64f05d7e2c888f7932e586c1e11ec 100644 --- a/tensorflow/contrib/lite/util.cc +++ b/tensorflow/contrib/lite/util.cc @@ -38,4 +38,14 @@ bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, return true; } +size_t CombineHashes(std::initializer_list hashes) { + size_t result = 0; + // Hash combiner used by TensorFlow core. + for (size_t hash : hashes) { + result = result ^ + (hash + 0x9e3779b97f4a7800ULL + (result << 10) + (result >> 4)); + } + return result; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h index a34db35823104414cce028b9119397da085d05b1..89d9b4f5cffa99e708f391fd8fe19208009b5e79 100644 --- a/tensorflow/contrib/lite/util.h +++ b/tensorflow/contrib/lite/util.h @@ -35,6 +35,8 @@ TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims); bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, const int* b); +size_t CombineHashes(std::initializer_list hashes); + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_UTIL_H_ diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5d4682ec9f4b8c5864383bd1d2f4c0b41a11baad..889accdd5aafae2931048ffdd26408cccb3c874e 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib import lookup from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -1396,15 +1397,22 @@ class KeyValueTensorInitializerTest(test.TestCase): class IndexTableFromTensor(test.TestCase): + @test_util.run_in_graph_and_eager_modes def test_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + table = lookup.index_table_from_tensor( + mapping=("brain", "salad", "surgery"), num_oov_buckets=1) + + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(table.lookup( + constant_op.constant(("salad", "surgery", "tarkus")))) + else: + # Reinitializing a table in eager should work. table = lookup.index_table_from_tensor( mapping=("brain", "salad", "surgery"), num_oov_buckets=1) - ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) - - self.assertRaises(errors_impl.OpError, ids.eval) - lookup_ops.tables_initializer().run() - self.assertAllEqual((1, 2, 3), ids.eval()) + self.evaluate(lookup_ops.tables_initializer()) + ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_int32_index_table_from_tensor_with_tensor_init(self): with self.test_session(): @@ -1662,7 +1670,7 @@ class InitializeTableFromFileOpTest(test.TestCase): f.write("\n".join(values) + "\n") return vocabulary_file - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitializeStringTable(self): vocabulary_file = self._createVocabFile("one_column_1.txt") default_value = -1 diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index bdad34a665e47a4e060fcaddfffecfdc876a8fb0..651de4e2f446b2da39b000cde2541872116cbdba 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -482,9 +482,12 @@ def hinge_loss(logits, labels=None, scope=None): """Method that returns the loss tensor for hinge loss. Args: - logits: The logits, a float tensor. + logits: The logits, a float tensor. Note that logits are assumed to be + unbounded and 0-centered. A value > 0 (resp. < 0) is considered a positive + (resp. negative) binary prediction. labels: The ground truth output tensor. Its shape should match the shape of - logits. The values of the tensor are expected to be 0.0 or 1.0. + logits. The values of the tensor are expected to be 0.0 or 1.0. Internally + the {0,1} labels are converted to {-1,1} when calculating the hinge loss. scope: The scope for the operations performed in computing the loss. Returns: diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh index fc88f59e0948e1d3ed7cce9b809bf30ba280af12..fb9e77ae1bcfc3404f1fdf90ab2697a4e79a9836 100755 --- a/tensorflow/contrib/makefile/build_all_android.sh +++ b/tensorflow/contrib/makefile/build_all_android.sh @@ -30,6 +30,14 @@ arm64-v8a armeabi armeabi-v7a mips mips64 x86 x86_64 tegra)" exit 1 } +echo "********************************************************************" +echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference." +echo "You are currently using an older version. Please switch over to TensorFlow Lite." +echo "" +echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite" +echo "********************************************************************" +echo "" + if [[ -z "${NDK_ROOT}" ]]; then echo "NDK_ROOT should be set as an environment variable" 1>&2 exit 1 diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh index 0a458a27b3ac9b1a24b0f42de2f0166d515e8cd9..1d4677ef4bd1e8811998d1464e63902544153a49 100755 --- a/tensorflow/contrib/makefile/build_all_ios.sh +++ b/tensorflow/contrib/makefile/build_all_ios.sh @@ -31,6 +31,14 @@ usage() { exit 1 } +echo "********************************************************************" +echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference." +echo "You are currently using an older version. Please switch over to TensorFlow Lite." +echo "" +echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite" +echo "********************************************************************" +echo "" + DEFAULT_ARCH="i386 x86_64 armv7 armv7s arm64" while getopts "a:g:T" opt_name; do case "$opt_name" in diff --git a/tensorflow/contrib/makefile/compile_nsync.sh b/tensorflow/contrib/makefile/compile_nsync.sh index e8c6edd7ba9aa6a45d956d1d5655b2809d8d2309..a28fc3a87f9503074806d780a11878a9274efc6f 100755 --- a/tensorflow/contrib/makefile/compile_nsync.sh +++ b/tensorflow/contrib/makefile/compile_nsync.sh @@ -270,7 +270,7 @@ for arch in $archs; do PLATFORM_LDFLAGS=-pthread MKDEP=${CC} -M -std=c++11 PLATFORM_C=../../platform/c++11/src/nsync_semaphore_mutex.cc \ - ../../platform/c++11/src/per_thread_waiter.cc \ + ../../platform/posix/src/per_thread_waiter.c \ ../../platform/c++11/src/yield.cc \ ../../platform/c++11/src/time_rep_timespec.cc \ ../../platform/c++11/src/nsync_panic.cc diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index eff9081e35c285027c764c5bdbaf14f78bc5f512..48953e2e3843ff92744514d28bd725cc0d72f3a8 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -27,9 +27,7 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once -# the archive has been propagated in mirror.bazel.build. -GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index d4c3f2eda8be0c70e961afe582983b9f73769c77..6e7423f85e3b66e2f40b25c0b83d0fcaa54817a9 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc tensorflow/core/kernels/reduction_ops_any.cc tensorflow/core/kernels/reduction_ops_all.cc tensorflow/core/kernels/roll_op.cc +tensorflow/core/kernels/queue_op.cc tensorflow/core/kernels/queue_ops.cc tensorflow/core/kernels/queue_base.cc tensorflow/core/kernels/pooling_ops_common.cc @@ -300,7 +301,6 @@ tensorflow/core/kernels/spacetobatch_op.cc tensorflow/core/kernels/batchtospace_op.cc tensorflow/core/kernels/warn_about_ints.cc tensorflow/core/kernels/segment_reduction_ops.cc -tensorflow/core/kernels/batch_util.cc tensorflow/core/ops/audio_ops.cc tensorflow/core/kernels/decode_proto_op.cc tensorflow/core/kernels/encode_proto_op.cc diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index 4f2c82ca23011667662c74507fcbd99bcde4c7c0..66cb493e5c5bb9b8645e87dc7f5b274d916f64fc 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -77,7 +77,31 @@ py_test( py_test( name = "metric_ops_test", srcs = ["python/ops/metric_ops_test.py"], - shard_count = 16, + shard_count = 30, + srcs_version = "PY2AND3", + tags = ["noasan"], # times out b/63678675 + deps = [ + ":metrics_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "metric_ops_large_test", + size = "large", + srcs = ["python/ops/metric_ops_large_test.py"], srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 deps = [ diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 00a933e5e0c537033573b225d43581f74557b240..b14202ff9ec38016f926ee37c8acbd2bbb4c6ef5 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1064,7 +1064,7 @@ def streaming_auc(predictions, name=name) -def _compute_dynamic_auc(labels, predictions, curve='ROC'): +def _compute_dynamic_auc(labels, predictions, curve='ROC', weights=None): """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. Computes the area under the ROC or PR curve using each prediction as a @@ -1077,13 +1077,22 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): predictions: A 1-D `Tensor` of predictions whose values are `float64`. curve: The name of the curve to be computed, 'ROC' for the Receiving Operating Characteristic or 'PR' for the Precision-Recall curve. + weights: A 1-D `Tensor` of weights whose values are `float64`. Returns: A scalar `Tensor` containing the area-under-curve value for the input. """ - # Count the total number of positive and negative labels in the input. + # Compute the total weight and the total positive weight. size = array_ops.size(predictions) - total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32) + if weights is None: + weights = array_ops.ones_like(labels, dtype=dtypes.float64) + labels, predictions, weights = metrics_impl._remove_squeezable_dimensions( + labels, predictions, weights) + total_weight = math_ops.reduce_sum(weights) + total_positive = math_ops.reduce_sum( + array_ops.where( + math_ops.greater(labels, 0), weights, + array_ops.zeros_like(labels, dtype=dtypes.float64))) def continue_computing_dynamic_auc(): """Continues dynamic auc computation, entered if labels are not all equal. @@ -1091,9 +1100,11 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): Returns: A scalar `Tensor` containing the area-under-curve value. """ - # Sort the predictions descending, and the corresponding labels as well. + # Sort the predictions descending, keeping the same order for the + # corresponding labels and weights. ordered_predictions, indices = nn.top_k(predictions, k=size) ordered_labels = array_ops.gather(labels, indices) + ordered_weights = array_ops.gather(weights, indices) # Get the counts of the unique ordered predictions. _, _, counts = array_ops.unique_with_counts(ordered_predictions) @@ -1103,23 +1114,39 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32) # Count the positives to the left of the split indices. - positives = math_ops.cast( - array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]), - dtypes.int32) - true_positives = array_ops.gather(positives, splits) + true_positives = array_ops.gather( + array_ops.pad( + math_ops.cumsum( + array_ops.where( + math_ops.greater(ordered_labels, 0), ordered_weights, + array_ops.zeros_like(ordered_labels, + dtype=dtypes.float64))), + paddings=[[1, 0]]), splits) if curve == 'ROC': - # Count the negatives to the left of every split point and the total - # number of negatives for computing the FPR. - false_positives = math_ops.subtract(splits, true_positives) - total_negative = size - total_positive + # Compute the weight of the negatives to the left of every split point and + # the total weight of the negatives number of negatives for computing the + # FPR. + false_positives = array_ops.gather( + array_ops.pad( + math_ops.cumsum( + array_ops.where( + math_ops.less(ordered_labels, 1), ordered_weights, + array_ops.zeros_like( + ordered_labels, dtype=dtypes.float64))), + paddings=[[1, 0]]), splits) + total_negative = total_weight - total_positive x_axis_values = math_ops.truediv(false_positives, total_negative) y_axis_values = math_ops.truediv(true_positives, total_positive) elif curve == 'PR': x_axis_values = math_ops.truediv(true_positives, total_positive) # For conformance, set precision to 1 when the number of positive # classifications is 0. + positives = array_ops.gather( + array_ops.pad(math_ops.cumsum(ordered_weights), paddings=[[1, 0]]), + splits) y_axis_values = array_ops.where( - math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits), + math_ops.greater(splits, 0), + math_ops.truediv(true_positives, positives), array_ops.ones_like(true_positives, dtype=dtypes.float64)) # Calculate trapezoid areas. @@ -1133,7 +1160,7 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): return control_flow_ops.cond( math_ops.logical_or( math_ops.equal(total_positive, 0), math_ops.equal( - total_positive, size)), + total_positive, total_weight)), true_fn=lambda: array_ops.constant(0, dtypes.float64), false_fn=continue_computing_dynamic_auc) @@ -1143,7 +1170,8 @@ def streaming_dynamic_auc(labels, curve='ROC', metrics_collections=(), updates_collections=(), - name=None): + name=None, + weights=None): """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. USAGE NOTE: this approach requires storing all of the predictions and labels @@ -1168,6 +1196,8 @@ def streaming_dynamic_auc(labels, should be added to. name: An optional name for the variable_scope that contains the metric variables. + weights: A 'Tensor' of non-negative weights whose values are castable to + `float64`. Will be flattened into a 1-D `Tensor`. Returns: auc: A scalar `Tensor` containing the current area-under-curve value. @@ -1195,14 +1225,24 @@ def streaming_dynamic_auc(labels, check_ops.assert_less_equal( labels, array_ops.ones_like(labels, dtypes.int64), - message='labels must be 0 or 1, at least one is >1') + message='labels must be 0 or 1, at least one is >1'), ]): preds_accum, update_preds = streaming_concat( predictions, name='concat_preds') labels_accum, update_labels = streaming_concat( labels, name='concat_labels') - update_op = control_flow_ops.group(update_labels, update_preds) - auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve) + if weights is not None: + weights = array_ops.reshape( + math_ops.cast(weights, dtypes.float64), [-1]) + weights_accum, update_weights = streaming_concat( + weights, name='concat_weights') + update_op = control_flow_ops.group(update_labels, update_preds, + update_weights) + else: + weights_accum = None + update_op = control_flow_ops.group(update_labels, update_preds) + auc = _compute_dynamic_auc( + labels_accum, preds_accum, curve=curve, weights=weights_accum) if updates_collections: ops.add_to_collections(updates_collections, update_op) if metrics_collections: @@ -1544,7 +1584,7 @@ def precision_recall_at_equal_thresholds(labels, result: A named tuple (See PrecisionRecallData within the implementation of this function) with properties that are variables of shape `[num_thresholds]`. The names of the properties are tp, fp, tn, fn, - precision, recall, thresholds. + precision, recall, thresholds. Types are same as that of predictions. update_op: An op that accumulates values. Raises: @@ -1570,7 +1610,6 @@ def precision_recall_at_equal_thresholds(labels, check_ops.assert_type(labels, dtypes.bool) - dtype = predictions.dtype with variable_scope.variable_scope(name, 'precision_recall_at_equal_thresholds', (labels, predictions, weights)): @@ -1592,11 +1631,16 @@ def precision_recall_at_equal_thresholds(labels, predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - # We cast to float to ensure we have 0.0 or 1.0. - f_labels = math_ops.cast(labels, dtype) + # It's important we aggregate using float64 since we're accumulating a lot + # of 1.0's for the true/false labels, and accumulating to float32 will + # be quite inaccurate even with just a modest amount of values (~20M). + # We use float64 instead of integer primarily since GPU scatter kernel + # only support floats. + agg_dtype = dtypes.float64 - # Get weighted true/false labels. - true_labels = f_labels * weights + f_labels = math_ops.cast(labels, agg_dtype) + weights = math_ops.cast(weights, agg_dtype) + true_labels = f_labels * weights false_labels = (1.0 - f_labels) * weights # Flatten predictions and labels. @@ -1638,9 +1682,9 @@ def precision_recall_at_equal_thresholds(labels, with ops.name_scope('variables'): tp_buckets_v = metrics_impl.metric_variable( - [num_thresholds], dtype, name='tp_buckets') + [num_thresholds], agg_dtype, name='tp_buckets') fp_buckets_v = metrics_impl.metric_variable( - [num_thresholds], dtype, name='fp_buckets') + [num_thresholds], agg_dtype, name='fp_buckets') with ops.name_scope('update_op'): update_tp = state_ops.scatter_add( @@ -1660,18 +1704,21 @@ def precision_recall_at_equal_thresholds(labels, fn = tp[0] - tp # We use a minimum to prevent division by 0. - epsilon = 1e-7 + epsilon = ops.convert_to_tensor(1e-7, dtype=agg_dtype) precision = tp / math_ops.maximum(epsilon, tp + fp) recall = tp / math_ops.maximum(epsilon, tp + fn) + # Convert all tensors back to predictions' dtype (as per function contract). + out_dtype = predictions.dtype + _convert = lambda tensor: math_ops.cast(tensor, out_dtype) result = PrecisionRecallData( - tp=tp, - fp=fp, - tn=tn, - fn=fn, - precision=precision, - recall=recall, - thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds)) + tp=_convert(tp), + fp=_convert(fp), + tn=_convert(tn), + fn=_convert(fn), + precision=_convert(precision), + recall=_convert(recall), + thresholds=_convert(math_ops.lin_space(0.0, 1.0, num_thresholds))) update_op = control_flow_ops.group(update_tp, update_fp) return result, update_op @@ -2496,7 +2543,7 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name): name: An optional variable_scope name. Returns: - The recall at a the given `precision`. + The recall at a given `precision`. """ precisions = math_ops.div(tp, tp + fp + _EPSILON) tf_index = math_ops.argmin( diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7acfc383eb9a659a600752cf57b4978daa8a07bc --- /dev/null +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Large tests for metric_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.metrics.python.ops import metric_ops +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testLargeCase(self): + shape = [32, 512, 256, 1] + predictions = random_ops.random_uniform( + shape, 0.0, 1.0, dtype=dtypes_lib.float32) + labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5) + + result, update_op = metric_ops.precision_recall_at_equal_thresholds( + labels=labels, predictions=predictions, num_thresholds=201) + # Run many updates, enough to cause highly inaccurate values if the + # code used float32 for accumulation. + num_updates = 71 + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_updates): + sess.run(update_op) + + prdata = sess.run(result) + + # Since we use random values, we won't know the tp/fp/tn/fn values, but + # tp and fp at threshold 0 should be the total number of positive and + # negative labels, hence their sum should be total number of pixels. + expected_value = 1.0 * np.product(shape) * num_updates + got_value = prdata.tp[0] + prdata.fp[0] + # They should be at least within 1. + self.assertNear(got_value, expected_value, 1.0) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 76420db8bda39435bcc2be2fd3d8c3467d6753e2..a09fc4abd461323d67e914c70932688816fed764 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2127,6 +2127,44 @@ class StreamingDynamicAUCTest(test.TestCase): sess.run(update_op) self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5) + def testWithWeights(self): + batch_size = 10 + num_batches = 100 + labels = np.array([]) + predictions = np.array([]) + weights = np.array([]) + tf_labels = variables.Variable( + array_ops.ones(batch_size, dtypes_lib.int32), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.int32) + tf_predictions = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + tf_weights = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + auc, update_op = metrics.streaming_dynamic_auc(tf_labels, + tf_predictions, + weights=tf_weights) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_batches): + new_labels = np.random.randint(0, 2, size=batch_size) + noise = np.random.uniform(-0.2, 0.2, size=batch_size) + new_predictions = 0.4 + 0.2 * new_labels + noise + new_weights = np.random.uniform(0.0, 3.0, size=batch_size) + labels = np.concatenate([labels, new_labels]) + predictions = np.concatenate([predictions, new_predictions]) + weights = np.concatenate([weights, new_weights]) + sess.run([tf_labels.assign(new_labels), + tf_predictions.assign(new_predictions), + tf_weights.assign(new_weights)]) + sess.run(update_op) + expected_auc = _np_auc(predictions, labels, weights) + self.assertAlmostEqual(expected_auc, auc.eval()) + class AucWithConfidenceIntervalsTest(test.TestCase): @@ -2333,47 +2371,24 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): np.random.seed(1) ops.reset_default_graph() - def _testResultsEqual(self, expected_dict, gotten_result): + def _testResultsEqual(self, expected_dict, gotten_result, eps=None): """Tests that 2 results (dicts) represent the same data. Args: expected_dict: A dictionary with keys that are the names of properties of PrecisionRecallData and whose values are lists of floats. gotten_result: A PrecisionRecallData object. + eps: Epsilon value to use for testing output values. If unspecified, use + default from assertAllClose. """ gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()} self.assertItemsEqual(list(expected_dict.keys()), list(gotten_dict.keys())) for key, expected_values in expected_dict.items(): - self.assertAllClose(expected_values, gotten_dict[key]) - - def _testCase(self, predictions, labels, expected_result, weights=None): - """Performs a test given a certain scenario of labels, predictions, weights. - - Args: - predictions: The predictions tensor. Of type float32. - labels: The labels tensor. Of type bool. - expected_result: The expected result (dict) that maps to tensors. - weights: Optional weights tensor. - """ - with self.test_session() as sess: - predictions_tensor = constant_op.constant( - predictions, dtype=dtypes_lib.float32) - labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) - weights_tensor = None - if weights: - weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32) - gotten_result, update_op = ( - metric_ops.precision_recall_at_equal_thresholds( - labels=labels_tensor, - predictions=predictions_tensor, - weights=weights_tensor, - num_thresholds=3)) - - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - - self._testResultsEqual(expected_result, gotten_result) + if eps is not None: + self.assertAllClose(expected_values, gotten_dict[key], atol=eps) + else: + self.assertAllClose(expected_values, gotten_dict[key]) def testVars(self): metric_ops.precision_recall_at_equal_thresholds( @@ -2414,6 +2429,50 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): for _ in range(3): self._testResultsEqual(initial_result, result) + def _testCase(self, + predictions, + labels, + expected_result, + dtype=dtypes_lib.float32, + eps=None, + weights=None): + """Performs a test given a certain scenario of labels, predictions, weights. + + Args: + predictions: The predictions tensor. Of type dtype. + labels: The labels tensor. Of type bool. + expected_result: The expected result (dict) that maps to tensors. + dtype: Data type to use for predictions and weights tensor. Default + is float32. + eps: Epsilon value to use for testing output values. If unspecified, use + default from assertAllClose. + weights: Optional weights tensor. + """ + with self.test_session() as sess: + predictions_tensor = constant_op.constant(predictions, dtype=dtype) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) + weights_tensor = None + if weights: + weights_tensor = constant_op.constant(weights, dtype=dtype) + gotten_result, update_op = ( + metric_ops.precision_recall_at_equal_thresholds( + labels=labels_tensor, + predictions=predictions_tensor, + weights=weights_tensor, + num_thresholds=3)) + self.assertEqual(gotten_result.tp.dtype, dtype) + self.assertEqual(gotten_result.fp.dtype, dtype) + self.assertEqual(gotten_result.tn.dtype, dtype) + self.assertEqual(gotten_result.fn.dtype, dtype) + self.assertEqual(gotten_result.precision.dtype, dtype) + self.assertEqual(gotten_result.recall.dtype, dtype) + self.assertEqual(gotten_result.thresholds.dtype, dtype) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self._testResultsEqual(expected_result, gotten_result, eps=eps) + def testAllTruePositives(self): self._testCase( [[1]], [[True]], { @@ -2489,6 +2548,35 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): }, weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]]) + def testFloat64(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], { + 'tp': [4, 3, 0], + 'fp': [2, 0, 0], + 'tn': [0, 2, 2], + 'fn': [0, 1, 4], + 'precision': [2.0 / 3.0, 1.0, 0.0], + 'recall': [1.0, 0.75, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }, + dtype=dtypes_lib.float64) + + def testFloat16(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], { + 'tp': [4, 3, 0], + 'fp': [2, 0, 0], + 'tn': [0, 2, 2], + 'fn': [0, 1, 4], + 'precision': [2.0 / 3.0, 1.0, 0.0], + 'recall': [1.0, 0.75, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }, + dtype=dtypes_lib.float16, + eps=1e-3) + class StreamingSpecificityAtSensitivityTest(test.TestCase): @@ -4649,199 +4737,204 @@ class StreamingSparseRecallTest(test.TestCase): self._test_sparse_recall_at_top_k( labels, top_k_predictions, expected=1.0 / 2) - def test_one_label_at_k1_weighted(self): + def _test_one_label_at_k1_weighted(self, labels): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] top_k_predictions = [[3], [3]] - sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], - [0, 0, 1, 0]]) - dense_labels = np.array([[3], [2]], dtype=np.int64) - for labels in (sparse_labels, dense_labels): - # Class 3: 1 label, 2 predictions, 1 correct. - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0,)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(2.0,)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(2.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=NAN, - class_id=3, - weights=(0.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=NAN, - class_id=3, - weights=(0.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=NAN, - class_id=3, - weights=(0.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=NAN, - class_id=3, - weights=(0.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=2.0 / 2, - class_id=3, - weights=(2.0, 3.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=2.0 / 2, - class_id=3, - weights=(2.0, 3.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=3.0 / 3, - class_id=3, - weights=(3.0, 2.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=3.0 / 3, - class_id=3, - weights=(3.0, 2.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=0.3 / 0.3, - class_id=3, - weights=(0.3, 0.6)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=0.3 / 0.3, - class_id=3, - weights=(0.3, 0.6)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=0.6 / 0.6, - class_id=3, - weights=(0.6, 0.3)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=0.6 / 0.6, - class_id=3, - weights=(0.6, 0.3)) + # Class 3: 1 label, 2 predictions, 1 correct. + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(2.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=NAN, + class_id=3, + weights=(0.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=NAN, + class_id=3, + weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=2.0 / 2, + class_id=3, + weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=2.0 / 2, + class_id=3, + weights=(2.0, 3.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=3.0 / 3, + class_id=3, + weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=3.0 / 3, + class_id=3, + weights=(3.0, 2.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=0.3 / 0.3, + class_id=3, + weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.3 / 0.3, + class_id=3, + weights=(0.3, 0.6)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=0.6 / 0.6, + class_id=3, + weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.6 / 0.6, + class_id=3, + weights=(0.6, 0.3)) - # All classes: 2 labels, 2 predictions, 1 correct. - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=NAN, weights=(0.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=NAN, weights=(0.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,)) + # All classes: 2 labels, 2 predictions, 1 correct. + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=NAN, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, weights=(0.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3)) + + def test_one_label_at_k1_weighted_sparse_labels(self): + sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], + [0, 0, 1, 0]]) + self._test_one_label_at_k1_weighted(sparse_labels) + + def test_one_label_at_k1_weighted_dense_labels(self): + dense_labels = np.array([[3], [2]], dtype=np.int64) + self._test_one_label_at_k1_weighted(dense_labels) def test_three_labels_at_k5_nan(self): predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], @@ -7101,6 +7194,14 @@ class CohenKappaTest(test.TestCase): with self.assertRaises(ValueError): metrics.cohen_kappa(labels, invalid_predictions, 3) + def testConditionalPackingOptimization(self): + placeholder = array_ops.placeholder(dtypes_lib.float32, [None]) + values, update_op = metric_ops.streaming_concat(placeholder) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for feed in range(10): + sess.run(update_op, feed_dict={placeholder: [feed]}) + print(sess.run(values)) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py index 480f5f6eaf493c5c87c27cc9f8e510ea9c085a72..1b0383d24c0c472b4875d15c3650e37dfd2439e1 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py @@ -34,7 +34,7 @@ def _GetExampleIter(inputs): class FixedLossScaleManagerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_basic(self): itr = _GetExampleIter([True] * 10 + [False] * 10) @@ -84,13 +84,13 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): actual_outputs.append(self.evaluate(lsm.get_loss_scale())) self.assertEqual(actual_outputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increase_every_n_steps(self): inputs = [True] * 6 expected_outputs = [1, 2, 2, 4, 4, 8] self._test_helper(inputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_keep_increasing_until_capped(self): init_loss_scale = np.finfo(np.float32).max / 4 + 10 max_float = np.finfo(np.float32).max @@ -104,7 +104,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_decrease_every_n_steps(self): inputs = [False] * 6 init_loss_scale = 1024 @@ -112,7 +112,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_keep_decreasing_until_one(self): inputs = [False] * 10 init_loss_scale = 16 @@ -120,19 +120,19 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_incr_bad_step_clear_good_step(self): inputs = [True, True, True, False, True] expected_outputs = [1, 2, 2, 2, 2] self._test_helper(inputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_incr_good_step_does_not_clear_bad_step(self): inputs = [True, True, True, False, True, False] expected_outputs = [1, 2, 2, 2, 2, 1] self._test_helper(inputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trigger_loss_scale_update_each_step(self): """Test when incr_every_n_step and decr_every_n_nan_or_inf is 1.""" init_loss_scale = 1 @@ -145,7 +145,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale, incr_every_n_step, decr_every_n_nan_or_inf) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_alternating_good_and_bad_gradients_trigger_each_step(self): init_loss_scale = 1 incr_every_n_step = 1 @@ -156,7 +156,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale, incr_every_n_step, decr_every_n_nan_or_inf) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_alternating_good_and_bad_gradients_trigger_incr_every_2steps(self): init_loss_scale = 32 incr_every_n_step = 2 @@ -167,7 +167,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale, incr_every_n_step, decr_every_n_nan_or_inf) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_random_mix_good_and_bad_gradients(self): init_loss_scale = 4 inputs = [ diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py index e4e5ccc33472ad5a12bd8111fb1ff6ebbd6f45f9..ef34f7bf7bf3eba047b50ce8abf883b0ed741a63 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py @@ -26,26 +26,32 @@ from tensorflow.python.training import optimizer class LossScaleOptimizer(optimizer.Optimizer): + # TODO(jamesqin): move mixed precision training explanation to __init__ + # docstring. """An optimizer that applies loss scaling in backprop. - This class is useful for mixed precision training on GPUs (or other potential - accelerators), which is an approach to improve compute throughput without loss - of model quality. - - The commmon configuration of mixed precision models is the following: - * variables are kept in high precision (e.g. float32). - * computations are done in lower precision (e.g. float16). variables are - casted to lower precision before they're used. - * (in training), final gradients are casted back to variable precision and get - applied. - - Because computations happen in lower precision, gradients in the backprop pass - might underflow in the smaller dynamic range, causing a model to converge at a - suboptimal level. This optimizer multiplies the loss by a factor before - backprop starts to prevent underflow. Before gradients are applied, they are - casted to higher precision and down-scaled by the same factor, so - mathematically the variable updates are no different from regular - same-precision training. + This class is useful for "mixed precision training" on GPUs (or other + potential accelerators), an approach to improve compute throughput without + compromising model quality. + + The canonical way to perform mixed precision training is the following: + * Model variables are kept in high precision (e.g. float32). + * Computations are done in lower precision (e.g. float16), which enjoys + performance speedup by virtue of hardware support. Variables are casted to + lower precision before they're used. + * Final gradients are casted back to high precision dtype, then used to update + variables. + + The side-effect of performing computation in lower precision, is that it comes + with smaller numerical range. During backproping, small gradients might + underflow in the reduced numerical range, causing a model to converge at + suboptimal level. + + To prevent underflow, this optimizer multiplies the loss by a factor before + backprop starts. Consequently, the gradients are linearly scaled up by the + same factor, thus not falling into the underflow zone. After that, to perserve + the correctness of backprop, the gradients are down-scaled by the same factor, + casted to the (higher) variable precision, then applied on the variables. See [Nvidia's manual on mixed precision training]( https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py index dded61ccd58eb79b338d7264e8a057c9456c8695..9009df0eefec13146090ba5fc2096e71ba6eb89d 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py @@ -54,7 +54,7 @@ class LossScaleOptimizerTest(test.TestCase): opt = loss_scale_opt_fn(opt) return x, loss, opt - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_float16_underflow_without_loss_scale(self): lr = 1 init_val = 1. @@ -73,7 +73,7 @@ class LossScaleOptimizerTest(test.TestCase): rtol=0, atol=min(symbolic_update, 1e-6)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_float16_with_loss_scale(self): lr = 1. init_val = 1. @@ -95,7 +95,7 @@ class LossScaleOptimizerTest(test.TestCase): rtol=0, atol=min(expected_update, 1e-6)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_compute_gradients_with_loss_scale(self): lr = 1 init_val = 1. @@ -115,7 +115,7 @@ class LossScaleOptimizerTest(test.TestCase): # Gradients aren't applied. self.assertAllClose(init_val, self.evaluate(x), rtol=0, atol=1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_compute_gradients_without_loss_scale(self): lr = 1 init_val = 1. @@ -127,7 +127,7 @@ class LossScaleOptimizerTest(test.TestCase): g_v = self.evaluate(grads_and_vars[0][0]) self.assertAllClose(g_v, 0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_apply_gradients(self): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) @@ -155,7 +155,7 @@ class LossScaleOptimizerTest(test.TestCase): actual_output.append(self.evaluate(x)) self.assertAllClose(expected_output, actual_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_apply_gradients_loss_scale_is_updated(self): class SimpleLossScaleManager(lsm_lib.LossScaleManager): diff --git a/tensorflow/contrib/mpi/mpi_utils.h b/tensorflow/contrib/mpi/mpi_utils.h index 45dc93493456e2a34be370034a73ff92d6c0aabd..4091925fc0d7ab49954bc2e0e91cfc6da2a685a9 100644 --- a/tensorflow/contrib/mpi/mpi_utils.h +++ b/tensorflow/contrib/mpi/mpi_utils.h @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/logging.h" // Skip MPI C++ bindings support, this matches the usage in other places #define OMPI_SKIP_MPICXX diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD index a7be92a35e0d62a61f7923ac61bb2c1267d039c6..ecac06354d2ce796f2a6021cdf2370d7c30ccab7 100644 --- a/tensorflow/contrib/mpi_collectives/BUILD +++ b/tensorflow/contrib/mpi_collectives/BUILD @@ -52,6 +52,7 @@ tf_custom_op_library( deps = [ ":mpi_defines", ":mpi_message_proto_cc", + "//tensorflow/stream_executor:stream_executor_headers_lib", "//third_party/mpi", ], ) diff --git a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc index ed22ee667f1d73b3f86f77e09bad9bfec7e46391..e4b0c2c6541836243347d2950686c60ef06d2bfc 100644 --- a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc @@ -73,7 +73,7 @@ limitations under the License. */ template -using StatusOr = se::port::StatusOr; +using StatusOr = stream_executor::port::StatusOr; using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; diff --git a/tensorflow/contrib/mpi_collectives/kernels/ring.h b/tensorflow/contrib/mpi_collectives/kernels/ring.h index 1d56d588bc49eda542303ae6ebb19602352ae01d..c001615d3ffbdf04194cf8fd1fd242542bf8f89d 100644 --- a/tensorflow/contrib/mpi_collectives/kernels/ring.h +++ b/tensorflow/contrib/mpi_collectives/kernels/ring.h @@ -129,7 +129,7 @@ cudaStream_t CudaStreamForMPI(); * has the fully accumulated Segment 1; and so on. The scatter-reduce is * complete. * - * Next, the allgather distributes these fully accumululated chunks across all + * Next, the allgather distributes these fully accumulated chunks across all * nodes. Communication proceeds in the same ring, once again in N-1 steps. At * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i). * For example, at the first iteration, the following transfers will occur: diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index 334e70318dd88185cecd93ebeb2587861b7999b9..62996d1fd83f46145e9a1b773b1be57e27903127 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -19,17 +19,18 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda") tf_custom_op_library( name = "python/ops/_nccl_ops.so", srcs = [ "ops/nccl_ops.cc", ], - gpu_srcs = [ + gpu_srcs = if_not_windows_cuda([ "kernels/nccl_manager.cc", "kernels/nccl_manager.h", "kernels/nccl_ops.cc", - ], + ]), deps = if_cuda([ "@local_config_nccl//:nccl", "//tensorflow/core:gpu_headers_lib", @@ -97,18 +98,19 @@ tf_gen_op_wrapper_py( deps = [":nccl_ops_op_lib"], ) +# Test only nccl ops lib without dso to test behavior when NCCL lib is not +# installed. See nccl_dependency_test for more details. +# +# Users should use the public nccl_py lib that also adds the dso. tf_custom_op_py_library( - name = "nccl_py", + name = "nccl_ops_lib_without_dso", srcs = [ "__init__.py", "python/ops/nccl_ops.py", ], - dso = [":python/ops/_nccl_ops.so"], kernels = if_cuda([":nccl_kernels"]) + [ ":nccl_ops_op_lib", ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], deps = [ ":nccl_ops", "//tensorflow/contrib/util:util_py", @@ -120,6 +122,15 @@ tf_custom_op_py_library( ], ) +tf_custom_op_py_library( + name = "nccl_py", + dso = [":python/ops/_nccl_ops.so"], + visibility = ["//visibility:public"], + deps = [ + ":nccl_ops_lib_without_dso", + ], +) + cuda_py_test( name = "nccl_ops_test", size = "small", @@ -141,3 +152,25 @@ cuda_py_test( "notap", ], ) + +cuda_py_test( + name = "nccl_dependency_test", + size = "small", + srcs = ["python/ops/nccl_dependency_test.py"], + additional_deps = [ + ":nccl_ops_lib_without_dso", + "//tensorflow/python:constant_op", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + ], + # Disable this test internally as static linking is used internally and only + # run for OSS to verify that NCCL is an optional dynamic dependency. + tags = [ + "manual", + "noguitar", + "notap", + ], +) diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc index 4d8d922cb42d2974dab32cf4562bee3993bef098..5144f7c38c8650ebfced1dfcc9378263ebaad8c0 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc @@ -171,8 +171,7 @@ class NcclManagerTest : public ::testing::Test { private: static Allocator* GpuAllocator(BaseGPUDevice* device) { - return device->GetStepAllocator(AllocatorAttributes(), - nullptr /* step_resource_manager */); + return device->GetAllocator(AllocatorAttributes()); } static se::DeviceMemory AsDeviceMemory(const Scalar* cuda_memory) { diff --git a/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py b/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c766080dbee7c9a6f4383ef6fa8cade7bba158af --- /dev/null +++ b/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py @@ -0,0 +1,59 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dependency test for nccl to test behavior when NCCL is not installed.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import nccl +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect + + +class NcclDependencyTest(test.TestCase): + """Verifies that importing nccl ops lib does not fail even if NCCL is not + installed but nccl ops throws an exception on use if NCCL is not installed. + """ + + def test_nccl_ops(self): + """Tests behavior of nccl ops when NCCL is not installed.""" + + public_methods = [ + m[0] + for m in tf_inspect.getmembers(nccl, tf_inspect.isfunction) + if not m[0].startswith('_') + ] + for method_name in public_methods: + with ops.device('/device:CPU:0'): + tensor = constant_op.constant(1) + + if method_name == 'broadcast': + arg = tensor + else: + arg = [tensor] + + nccl_op = getattr(nccl, method_name) + with ops.device('/device:CPU:0'): + with self.assertRaisesRegexp(errors_impl.NotFoundError, + r'cannot open shared object file'): + nccl_op(arg) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 794372a1f4b0dcc41bcf0da611f5bc2ec9301973..fa597cf3efcf915311047f3a483772c45cc314fd 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -26,8 +26,10 @@ from tensorflow.python.framework import device from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader -_nccl_ops_so = loader.load_op_library( - resource_loader.get_path_to_datafile('_nccl_ops.so')) + +_nccl_ops_so = None +_module_lock = threading.Lock() +_shared_name_counter = 0 def all_sum(tensors): @@ -61,12 +63,12 @@ def _all_sum_grad(op, grad): Raises: LookupError: If `reduction` is not `sum`. """ - if op.get_attr('reduction') != 'sum': + if op.get_attr('reduction') != b'sum': raise LookupError('No gradient defined for NcclAllReduce except sum.') _check_device(grad, expected=op.device) num_devices = op.get_attr('num_devices') - shared_name = op.get_attr('shared_name') + '_grad' + shared_name = op.get_attr('shared_name') + b'_grad' with ops.device(op.device): return gen_nccl_ops.nccl_all_reduce( @@ -160,7 +162,7 @@ def _reduce_sum_grad(op, grad): Raises: LookupError: If the reduction attribute of op is not `sum`. """ - if op.get_attr('reduction') != 'sum': + if op.get_attr('reduction') != b'sum': raise LookupError('No gradient defined for NcclReduce except sum.') _check_device(grad, expected=op.device) @@ -180,7 +182,7 @@ def broadcast(tensor): A tensor with the value of `src_tensor`, which can be used as input to ops on other GPU devices. """ - _check_graph_mode() + _validate_and_load_nccl_so() _check_device(tensor) with ops.device(tensor.device): @@ -212,7 +214,7 @@ def _apply_all_reduce(reduction, tensors): """Helper function for all_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to all reduce operations') - _check_graph_mode() + _validate_and_load_nccl_so() shared_name = _get_shared_name() res = [] @@ -234,7 +236,7 @@ def _apply_reduce(reduction, tensors): """Helper function for reduce_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to reduce operations') - _check_graph_mode() + _validate_and_load_nccl_so() for t in tensors: _check_device(t) @@ -246,14 +248,10 @@ def _apply_reduce(reduction, tensors): return result -_lock = threading.Lock() -_shared_name_counter = 0 - - def _get_shared_name(): global _shared_name_counter - with _lock: + with _module_lock: val = _shared_name_counter _shared_name_counter += 1 return 'c%s' % val @@ -266,6 +264,25 @@ def _check_device(tensor, expected=None): raise ValueError('Expected device %s, got %s' % (expected, tensor.device)) -def _check_graph_mode(): +def _maybe_load_nccl_ops_so(): + """Loads nccl ops so if it hasn't been loaded already.""" + + with _module_lock: + global _nccl_ops_so + if not _nccl_ops_so: + _nccl_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile('_nccl_ops.so')) + + +def _validate_and_load_nccl_so(): + """Validates calling context and loads nccl ops so file. + + Raises: + ValueError: Ops are not supported. + errors_impl.NotFoundError: nccl library is not installed. + """ + if context.executing_eagerly(): raise ValueError('Nccl ops are not supported in eager mode') + + _maybe_load_nccl_ops_so() diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 13aa1d7e7a11877373a848c1ba865aa418790cd0..bbdf962d0480e52045d31f65b3d137ed3f11f2f1 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -19,6 +19,7 @@ py_library( "python/training/drop_stale_gradient_optimizer.py", "python/training/elastic_average_optimizer.py", "python/training/external_optimizer.py", + "python/training/ggt.py", "python/training/lazy_adam_optimizer.py", "python/training/model_average_optimizer.py", "python/training/moving_average_optimizer.py", @@ -28,15 +29,19 @@ py_library( "python/training/reg_adagrad_optimizer.py", "python/training/sign_decay.py", "python/training/variable_clipping_optimizer.py", + "python/training/weight_decay_optimizers.py", ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/optimizer_v2:optimizer_v2_py", "//tensorflow/python:array_ops", "//tensorflow/python:clip_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:state_ops", @@ -194,6 +199,25 @@ py_test( ], ) +py_test( + name = "weight_decay_optimizers_test", + srcs = ["python/training/weight_decay_optimizers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "drop_stale_gradient_optimizer_test", srcs = ["python/training/drop_stale_gradient_optimizer_test.py"], @@ -302,3 +326,21 @@ py_test( "//third_party/py/numpy", ], ) + +py_test( + name = "ggt_test", + srcs = ["python/training/ggt_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 4c13c8e247185213b798eb733ddcf65a07a8f64d..3e63e99030c46c254625ca8fdccce614cd60e8b0 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -22,15 +22,18 @@ from __future__ import print_function from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import * +from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * from tensorflow.contrib.opt.python.training.external_optimizer import * +from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * +from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * -from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * -from tensorflow.contrib.opt.python.training.model_average_optimizer import * +from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -46,6 +49,10 @@ _allowed_symbols = [ 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', + 'MomentumWOptimizer', + 'AdamWOptimizer', + 'DecoupledWeightDecayExtension', + 'extend_with_decoupled_weight_decay', 'ScipyOptimizerInterface', 'VariableClippingOptimizer', 'MultitaskOptimizerWrapper', @@ -53,7 +60,8 @@ _allowed_symbols = [ 'ElasticAverageOptimizer', 'ElasticAverageCustomGetter', 'ModelAverageOptimizer', - 'ModelAverageCustomGetter' + 'ModelAverageCustomGetter', + 'GGTOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py index bc92a7006f1a0a56adafc486a75afa94e965cb2c..915e6504e1e59ff247a2715820bc31a4d4cc1944 100644 --- a/tensorflow/contrib/opt/python/training/adamax_test.py +++ b/tensorflow/contrib/opt/python/training/adamax_test.py @@ -198,11 +198,11 @@ class AdaMaxOptimizerTest(test.TestCase): self.assertTrue(beta1_power is not None) self.assertIn(beta1_power, opt_variables) - with ops.Graph().as_default(): - # Shouldn't return non-slot variables from other graphs. - self.assertEqual(0, len(opt.variables())) - if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values self.assertAllClose([1.0, 2.0], self.evaluate(var0)) @@ -224,8 +224,10 @@ class AdaMaxOptimizerTest(test.TestCase): var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0), + rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1), + rtol=1e-2) if use_resource: self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py new file mode 100644 index 0000000000000000000000000000000000000000..928c453517f825ed2d305ec498d07ac29c065f1a --- /dev/null +++ b/tensorflow/contrib/opt/python/training/ggt.py @@ -0,0 +1,312 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GGT for Tensorflow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import numpy as np +from tensorflow.contrib.optimizer_v2 import optimizer_v2 +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops + + +class GGTOptimizer(optimizer_v2.OptimizerV2): + """Optimizer that implements the GGT algorithm. + + GGT has an advantage over sgd and adam on large models with poor conditioning, + for example language models and CNNs, + see [ABCHSZZ 2018]([pdf](https://arxiv.org/pdf/1806.02958.pdf)). + """ + + def __init__(self, + learning_rate=0.001, + beta1=0.9, + use_locking=False, + name="GGT", + window=10, + eps=1e-4, + svd_eps=1e-6, + sigma_eps=1e-2): + """Construct a new GGT optimizer. + + Initialization: + + ``` + t <- 0 (Initialize timestep) + grad_buffer <- 0 (Initialize buffer for keeping past gradients) + flat_grad <- 0 (Initialize flattened gradient that contains gradients of all + variables) + m_0 <- 0 (Initialize 1st moment vector) + ``` + + Suppose all variables and their gradients are concatenated into vectors + `flat_vars` and `flat_grad`. The update rule for `flat_vars` + uses an optimization described at the beginning of section 2 of the paper: + + ``` + t <- t + 1 + + m_t <- beta1 * m_{t-1} + (1 - beta1) * flat_grad + grad_buffer[(t-1) % window, :] <- m_t + + M <- grad_buffer^T / sqrt(min(t, window)) + U, sigma, _ <- SVD(M^TM + I * svd_eps) + + sigma_sqrt_inv <- (sqrt(sigma) + sigma_eps)^(-3) + sigma_sqrt_min <- min(sqrt(sigma)) + + if sigma_sqrt_min > eps: + new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t + + (m_t - M U diag(1/sigma) U^T M^T m_t) / sigma_sqrt_min + else: + new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t + + flat_vars <- flat_vars - learning_rate * new_step + ``` + + GGT provides the power of full-matrix adaptive regularization at a cost not + much larger than SGD. As a result it is suited for large models where the + gradient covariance matrix has a poor condition number that slows down first + order methods. + GGT uses the preconditioner from full-matrix AdaGrad, with gradient history + attenuated exponentially as in Adam, and truncated to a window parameter. + It has provable guarantees even for non-convex optimization that is never + significantly worse than SGD and in some cases better. + + Args: + learning_rate: A float hyperparameter. The learning rate. + beta1: A float hyperparameter. The exponential decay rate for the 1st + moment estimates. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "GGT". + window: An integer hyperparameter. The number of first moments to keep in + computing the adaptive preconditioner. + eps: A float hyperparameter. Used to truncate small eigenvalues of the + gradient covariance matrix. + svd_eps: A float hyperparameter. Used to stabilize SVD. + sigma_eps: A float hyperparameter. Used to regularize matrix inversion. + """ + super(GGTOptimizer, self).__init__(use_locking, name) + self._set_hyper("lr", learning_rate) + self._set_hyper("beta1", beta1) + self._set_hyper("window", window) + self._set_hyper("eps", eps) + self._set_hyper("svd_eps", svd_eps) + self._set_hyper("sigma_eps", sigma_eps) + + self.index_dict = {} + self.shape_dict = {} + + def _create_vars(self, var_list, state): + # Construct ordered dictionary for variable dimensions, sorted by name. + shape_dict = {} + for v in var_list: + shape_dict[v.name] = np.prod(v.get_shape()).value + self.shape_dict = collections.OrderedDict( + sorted(shape_dict.items(), key=lambda t: t[0])) + + # Assign each variable its location in flat_grad. The locations are based on + # the order of sorted names. + idx = 0 + for v_name, v_dim in self.shape_dict.items(): + self.index_dict[v_name] = idx + idx += v_dim + + state.create_non_slot( + initial_value=math_ops.cast(0., dtype=var_list[0].dtype.base_dtype), + name="global_step") + + # Buffer for keeping past gradients. + window = state.get_hyper("window") + grad_buffer_init = array_ops.zeros( + [window, idx], dtype=var_list[0].dtype.base_dtype) + state.create_non_slot(initial_value=grad_buffer_init, name="grad_buffer") + + state.create_non_slot( + initial_value=array_ops.zeros( + (idx,), dtype=var_list[0].dtype.base_dtype), + name="moment1") + + # Flattened gradient that contains gradients for all variables in the model. + state.create_non_slot( + initial_value=array_ops.zeros( + (idx,), dtype=var_list[0].dtype.base_dtype), + name="flat_grad") + + def _get_global_step(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("global_step") + + def _get_moment1(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("moment1") + + def _get_grad_buffer(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("grad_buffer") + + def _get_flat_grad(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("flat_grad") + + def _apply_sparse(self, grad, var): + raise NotImplementedError("Sparse gradient updates are not supported.") + + def _prepare(self, state): + self._variables = [] + + def _apply_dense(self, grad, var, state): + self._variables.append(var) + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + + # Update flat_gradient at the index associated with the variable. + flat_grad = self._get_flat_grad(state) + new_flat_grad = array_ops.reshape(grad, [-1]) + flat_grad_updated = state_ops.scatter_update( + flat_grad, math_ops.range(start_index, end_index), new_flat_grad) + + return flat_grad_updated + + def _resource_apply_dense(self, grad, var, state): + self._variables.append(var) + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + + # Update flat_gradient at the index associated with the variable. + flat_grad = self._get_flat_grad(state) + new_flat_grad = array_ops.reshape(grad, [-1]) + flat_grad_updated = state_ops.scatter_update( + flat_grad, math_ops.range(start_index, end_index), new_flat_grad) + + return flat_grad_updated + + def _finish(self, state): + var_dtype = self._variables[0].dtype.base_dtype + # Update global step. + global_step = self._get_global_step(state) + update_global_step = state_ops.assign_add(global_step, 1.) + + # Update the first moment estimate. + beta1 = state.get_hyper("beta1", dtype=var_dtype) + moment1 = self._get_moment1(state) + flat_grad = self._get_flat_grad(state) + # moment1_t := beta1 * moment1_{t-1} + (1 - beta1) * flat_grad_t + update_moment1 = moment1.assign(beta1 * moment1 + (1. - beta1) * flat_grad) + + # Update the gradient buffer. + window = state.get_hyper("window") + grad_buffer = self._get_grad_buffer(state) + next_grad_index = math_ops.floormod( + math_ops.to_int32(update_global_step - 1.), window) + # grad_buffer[(t-1) % window] := moment1_t + update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index, + update_moment1) + + # Compute the update step. + eps = state.get_hyper("eps", dtype=var_dtype) + svd_eps = state.get_hyper("svd_eps", dtype=var_dtype) + sigma_eps = state.get_hyper("sigma_eps", dtype=var_dtype) + lr = state.get_hyper("lr", dtype=var_dtype) + denom = math_ops.sqrt( + math_ops.minimum( + ops.convert_to_tensor(update_global_step), + ops.convert_to_tensor(math_ops.cast(window, dtype=var_dtype)))) + moment1_2d = array_ops.expand_dims(update_moment1, -1) + + # m = grad_buffer^T / sqrt(min(t, window)) + # m has shape [model dimension, window], where model dimension is the sum + # of the dimensions of the flattened variables. + m = array_ops.transpose(math_ops.divide(update_grad_buffer, denom)) + + # sigma, u, _ = SVD(m^Tm + I * svd_eps) + mm = math_ops.matmul(m, m, transpose_a=True) + damping = math_ops.cast(linalg_ops.eye(window), dtype=var_dtype) * svd_eps + sigma, u, _ = linalg_ops.svd(mm + damping) + sigma_sqrt = math_ops.sqrt(sigma) + sigma_sqrt_min = math_ops.reduce_min(sigma_sqrt) + + # sigma_sqrt_inv = 1 / (\sqrt{sigma} + sigma_eps) ^ 3 + # We add sigma_eps to alleviate numerical instability. + # Note that (m^Tm)^(-3/2) = u diag(sigma_sqrt_inv) u^T. + sigma_sqrt_inv = math_ops.divide( + math_ops.cast(1.0, dtype=var_dtype), + math_ops.pow(sigma_sqrt + sigma_eps, 3)) + + # In full matrix AdaGrad, the update step computes (mm^T)^(-1/2)g, where the + # inversion of a model dimension by model dimension matrix is needed. To + # speed up this computation we calculate the following instead: + # m(m^Tm)^(-3/2)m^T moment1 = m u diag(sigma_sqrt_inv) u^T m^T moment1. + new_step = array_ops.expand_dims( + array_ops.zeros(flat_grad.get_shape(), dtype=var_dtype), -1) + head = math_ops.matmul( + m, + math_ops.matmul( + u, + math_ops.matmul( + array_ops.diag(sigma_sqrt_inv), + math_ops.matmul( + u, + math_ops.matmul(m, moment1_2d, transpose_a=True), + transpose_a=True)))) + + # When inverting (mm^t)^(1/2), we also add epsilon * I regularization for + # degenerate cases. We expand ((mm^t)^(1/2) + epsilon * I)^(-1) using + # Woodbury's identity. + # For full derivation please see paper at + # https://arxiv.org/pdf/1806.02958.pdf + tail = moment1_2d - math_ops.matmul( + m, + math_ops.matmul( + u, + math_ops.matmul( + array_ops.diag( + math_ops.divide(math_ops.cast(1.0, dtype=var_dtype), + sigma)), + math_ops.matmul( + u, + math_ops.matmul(m, moment1_2d, transpose_a=True), + transpose_a=True)))) + scaled_tail = math_ops.divide(tail, sigma_sqrt_min) + + update_new_step = control_flow_ops.cond( + sigma_sqrt_min > eps, lambda: math_ops.add(head, scaled_tail), + lambda: math_ops.add(new_step, head)) + + # Update each variable. + update_step = [] + for var in self._variables: + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + var_update_correct_shape = array_ops.reshape( + update_new_step[start_index:end_index], var.get_shape()) + var_updated = state_ops.assign_sub(var, lr * var_update_correct_shape) + update_step.append(var_updated) + + return control_flow_ops.group(update_step) diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py new file mode 100644 index 0000000000000000000000000000000000000000..42162960b049cd90c663989fb4fc9d7f179a84ff --- /dev/null +++ b/tensorflow/contrib/opt/python/training/ggt_test.py @@ -0,0 +1,183 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for GGTOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.opt.python.training.ggt import GGTOptimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def ggt_update_numpy(param, + g_t, + lr, + grad_buffer, + m, + window, + t, + beta1=0.9, + eps=1e-4, + svd_eps=1e-6, + sigma_eps=1e-2): + """Tests the correctness of one step of GGT.""" + m_t = m * beta1 + (1 - beta1) * g_t + grad_buffer[((t - 1) % window), :] = m_t + m_matrix = np.transpose(grad_buffer / np.sqrt(np.minimum(t, window))) + mm = np.dot(np.transpose(m_matrix), m_matrix) + damping = np.eye(window) * svd_eps + u, sigma, _ = np.linalg.svd(mm + damping) + + sigma_sqrt_inv = np.power(np.sqrt(sigma) + sigma_eps, -3) + new_step = np.linalg.multi_dot([ + m_matrix, u, + np.diag(sigma_sqrt_inv), + np.transpose(u), + np.transpose(m_matrix), m_t + ]) + + sigma_sqrt_min = np.sqrt(sigma).min() + + if sigma_sqrt_min > eps: + new_step += (m_t - np.linalg.multi_dot([ + m_matrix, u, + np.diag(1.0 / sigma), + np.transpose(u), + np.transpose(m_matrix), m_t + ])) * (1.0 / sigma_sqrt_min) + + param_t = param - lr * new_step + return param_t, m_t, grad_buffer + + +class GGTOptimizerTest(test.TestCase): + + def doTestBasic(self, use_resource=False): + # SVD does not support float16 + for i, dtype in enumerate([dtypes.float32, dtypes.float64]): + with self.test_session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0 = 0.0 + window = 3 + grad_buffer = np.zeros((window, 4), dtype=dtype.as_numpy_dtype) + lr = 0.001 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np, name="var0") + var1 = variables.Variable(var1_np, name="var1") + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = GGTOptimizer(learning_rate=lr, window=window) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + + m_t = opt._get_moment1() + grad_buffer_t = opt._get_grad_buffer() + g_t = opt._get_flat_grad() + self.assertTrue(m_t is not None) + self.assertTrue(grad_buffer_t is not None) + self.assertTrue(g_t is not None) + self.assertIn(m_t, opt_variables) + self.assertIn(grad_buffer_t, opt_variables) + self.assertIn(g_t, opt_variables) + + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + m_t = opt._get_moment1() + grad_buffer_t = opt._get_grad_buffer() + g_t = opt._get_flat_grad() + + # Run 3 steps of GGT + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + if t == 1: + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01, 0.001, 0.001]), self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, 0.001], [0., 0., 0., 0.], + [0., 0., 0., 0.]]), self.evaluate(grad_buffer_t)) + elif t == 2: + self.assertAllCloseAccordingToType( + np.array([0.019, 0.019, 0.0019, 0.0019]), self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, 0.001], + [0.019, 0.019, 0.0019, 0.0019], [0., 0., 0., 0.]]), + self.evaluate(grad_buffer_t)) + else: + self.assertAllCloseAccordingToType( + np.array([0.0271, 0.0271, 0.00271, 0.00271]), + self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, + 0.001], [0.019, 0.019, 0.0019, 0.0019], + [0.0271, 0.0271, 0.00271, 0.00271]]), + self.evaluate(grad_buffer_t)) + + self.assertAllCloseAccordingToType([0.1, 0.1, 0.01, 0.01], + self.evaluate(g_t)) + + var_np = np.append(var0_np, var1_np) + grads_np = np.append(grads0_np, grads1_np) + var_np, m0, grad_buffer = ggt_update_numpy(var_np, grads_np, lr, + grad_buffer, m0, window, t) + + var0_np = var_np[:2] + var1_np = var_np[2:] + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testBasic(self): + with self.test_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py index a7c97a1da2baf29914337094c6153447c997af08..b6b10e500b6af80ab61cbf74077ea8e70800662f 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py @@ -62,7 +62,7 @@ class ModelAverageCustomGetter(object): """ def __init__(self, worker_device): - """Create a new `ElasticAverageCustomGetter`. + """Create a new `ModelAverageCustomGetter`. Args: worker_device: String. Name of the `worker` job. diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..b9cf40eb7b2d11c98b93c51213145ca4e2670318 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -0,0 +1,362 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Base class to make optimizers weight decay ready.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import adam +from tensorflow.python.training import momentum as momentum_opt +from tensorflow.python.training import optimizer +from tensorflow.python.util.tf_export import tf_export + + +class DecoupledWeightDecayExtension(object): + """This class allows to extend optimizers with decoupled weight decay. + + It implements the decoupled weight decay described by Loshchilov & Hutter + (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is + decoupled from the optimization steps w.r.t. to the loss function. + For SGD variants, this simplifies hyperparameter search since it decouples + the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + This class alone is not an optimizer but rather extends existing + optimizers with decoupled weight decay. We explicitly define the two examples + used in the above paper (SGDW and AdamW), but in general this can extend + any OptimizerX by using + `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`. + In order for it to work, it must be the first class the Optimizer with + weight decay inherits from, e.g. + + ```python + class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + def __init__(self, weight_decay, *args, **kwargs): + super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs). + ``` + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + """ + + def __init__(self, weight_decay, **kwargs): + """Construct the extension class that adds weight decay to an optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value, the factor by which + a variable is decayed in the update step. + **kwargs: Optional list or tuple or set of `Variable` objects to + decay. + """ + self._decay_var_list = None # is set in minimize or apply_gradients + self._weight_decay = weight_decay + # The tensors are initialized in call to _prepare + self._weight_decay_tensor = None + super(DecoupledWeightDecayExtension, self).__init__(**kwargs) + + def minimize(self, loss, global_step=None, var_list=None, + gate_gradients=optimizer.Optimizer.GATE_OP, + aggregation_method=None, colocate_gradients_with_ops=False, + name=None, grad_loss=None, decay_var_list=None): + """Add operations to minimize `loss` by updating `var_list` with decay. + + This function is the same as Optimizer.minimize except that it allows to + specify the variables that should be decayed using decay_var_list. + If decay_var_list is None, all variables in var_list are decayed. + + For more information see the documentation of Optimizer.minimize. + + Args: + loss: A `Tensor` containing the value to minimize. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + var_list: Optional list or tuple of `Variable` objects to update to + minimize `loss`. Defaults to the list of variables collected in + the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. + gate_gradients: How to gate the computation of gradients. Can be + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: If True, try colocating gradients with + the corresponding op. + name: Optional name for the returned operation. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + decay_var_list: Optional list of decay variables. + + Returns: + An Operation that updates the variables in `var_list`. If `global_step` + was not `None`, that operation also increments `global_step`. + + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).minimize( + loss, global_step=global_step, var_list=var_list, + gate_gradients=gate_gradients, aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, name=name, + grad_loss=grad_loss) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None, + decay_var_list=None): + """Apply gradients to variables and decay the variables. + + This function is the same as Optimizer.apply_gradients except that it + allows to specify the variables that should be decayed using + decay_var_list. If decay_var_list is None, all variables in var_list + are decayed. + + For more information see the documentation of Optimizer.apply_gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()`. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + name: Optional name for the returned operation. Default to the + name passed to the `Optimizer` constructor. + decay_var_list: Optional list of decay variables. + + Returns: + An `Operation` that applies the specified gradients. If `global_step` + was not None, that operation also increments `global_step`. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).apply_gradients( + grads_and_vars, global_step=global_step, name=name) + + def _prepare(self): + weight_decay = self._weight_decay + if callable(weight_decay): + weight_decay = weight_decay() + self._weight_decay_tensor = ops.convert_to_tensor( + weight_decay, name="weight_decay") + # Call the optimizers _prepare function. + super(DecoupledWeightDecayExtension, self)._prepare() + + def _decay_weights_op(self, var): + if not self._decay_var_list or var in self._decay_var_list: + return var.assign_sub(self._weight_decay * var, self._use_locking) + return control_flow_ops.no_op() + + def _decay_weights_sparse_op(self, var, indices, scatter_add): + if not self._decay_var_list or var in self._decay_var_list: + return scatter_add(var, indices, -self._weight_decay * var, + self._use_locking) + return control_flow_ops.no_op() + + # Here, we overwrite the apply functions that the base optimizer calls. + # super().apply_x resolves to the apply_x function of the BaseOptimizer. + def _apply_dense(self, grad, var): + with ops.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var) + + def _resource_apply_dense(self, grad, var): + with ops.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, self)._resource_apply_dense( + grad, var) + + def _apply_sparse(self, grad, var): + scatter_add = state_ops.scatter_add + decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add) + with ops.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, self)._apply_sparse( + grad, var) + + def _resource_scatter_add(self, x, i, v, _=None): + # last argument allows for one overflow argument, to have the same function + # signature as state_ops.scatter_add + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + scatter_add = self._resource_scatter_add + decay_op = self._decay_weights_sparse_op(var, indices, scatter_add) + with ops.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse( + grad, var, indices) + + +def extend_with_decoupled_weight_decay(base_optimizer): + """Factory function returning an optimizer class with decoupled weight decay. + + Returns an optimizer class. An instance of the returned class computes the + update step of `base_optimizer` and additionally decays the weights. + E.g., the class returned by + `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to + `tf.contrib.opt.AdamWOptimizer`. + + The API of the new optimizer class slightly differs from the API of the + base optimizer: + - The first argument to the constructor is the weight decay rate. + - `minimize` and `apply_gradients` accept the optional keyword argument + `decay_var_list`, which specifies the variables that should be decayed. + If `None`, all variables that are optimized are decayed. + + Usage example: + ```python + # MyAdamW is a new class + MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer) + # Create a MyAdamW object + optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) + sess.run(optimizer.minimize(loss, decay_variables=[var1, var2])) + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + ``` + + Args: + base_optimizer: An optimizer class that inherits from tf.train.Optimizer. + + Returns: + A new optimizer class that inherits from DecoupledWeightDecayExtension + and base_optimizer. + """ + + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, + base_optimizer): + """Base_optimizer with decoupled weight decay. + + This class computes the update step of `base_optimizer` and + additionally decays the variable with the weight decay being decoupled from + the optimization steps w.r.t. to the loss function, as described by + Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf). + For SGD variants, this simplifies hyperparameter search since + it decouples the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. + """ + + def __init__(self, weight_decay, *args, **kwargs): + # super delegation is necessary here + # pylint: disable=useless-super-delegation + super(OptimizerWithDecoupledWeightDecay, self).__init__( + weight_decay, *args, **kwargs) + # pylint: enable=useless-super-delegation + + return OptimizerWithDecoupledWeightDecay + + +@tf_export("contrib.opt.MomentumWOptimizer") +class MomentumWOptimizer(DecoupledWeightDecayExtension, + momentum_opt.MomentumOptimizer): + """Optimizer that implements the Momentum algorithm with weight_decay. + + This is an implementation of the SGDW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + It computes the update step of `train.MomentumOptimizer` and additionally + decays the variable. Note that this is different from adding + L2 regularization on the variables to the loss. Decoupling the weight decay + from other hyperparameters (in particular the learning rate) simplifies + hyperparameter search. + + For further information see the documentation of the Momentum Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.train.MomentumOptimizer, + weight_decay=weight_decay) + ``` + """ + + def __init__(self, weight_decay, learning_rate, momentum, + use_locking=False, name="MomentumW", use_nesterov=False): + """Construct a new MomentumW optimizer. + + For further information see the documentation of the Momentum Optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value. The weight decay. + learning_rate: A `Tensor` or a floating point value. The learning rate. + momentum: A `Tensor` or a floating point value. The momentum. + use_locking: If `True` use locks for update operations. + name: Optional name prefix for the operations created when applying + gradients. Defaults to "Momentum". + use_nesterov: If `True` use Nesterov Momentum. + See [Sutskever et al., 2013]( + http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). + This implementation always computes gradients at the value of the + variable(s) passed to the optimizer. Using Nesterov Momentum makes the + variable(s) track the values called `theta_t + mu*v_t` in the paper. + + @compatibility(eager) + When eager execution is enabled, learning_rate, weight_decay and momentum + can each be a callable that takes no arguments and returns the actual value + to use. This can be useful for changing these values across different + invocations of optimizer functions. + @end_compatibility + """ + super(MomentumWOptimizer, self).__init__( + weight_decay, learning_rate=learning_rate, momentum=momentum, + use_locking=use_locking, name=name, use_nesterov=use_nesterov) + + +@tf_export("contrib.opt.AdamWOptimizer") +class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + """Optimizer that implements the Adam algorithm with weight decay. + + This is an implementation of the AdamW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + + It computes the update step of `train.AdamOptimizer` and additionally decays + the variable. Note that this is different from adding L2 regularization on + the variables to the loss: it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + For further information see the documentation of the Adam Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.train.AdamOptimizer, weight_decay=weight_decay) + ``` + """ + + def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999, + epsilon=1e-8, use_locking=False, name="AdamW"): + """Construct a new AdamW optimizer. + + For further information see the documentation of the Adam Optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value. The weight decay. + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + """ + super(AdamWOptimizer, self).__init__( + weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2, + epsilon=epsilon, use_locking=use_locking, name=name) diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..76d8a5697acb79e7748175c4a81dfdd85807dd49 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py @@ -0,0 +1,188 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optimizers with weight decay.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import weight_decay_optimizers +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adam + +WEIGHT_DECAY = 0.01 + + +def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9, + beta2=0.999, epsilon=1e-8): + lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) - + (param * WEIGHT_DECAY)) + return param_t, m_t, v_t + + +def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_): + # v, t are not needed for momentum optimizer + m = momentum * m + g_t + param_t = param - lr * m - param * WEIGHT_DECAY + return param_t, m, None + + +class WeightDecayOptimizerTest(test.TestCase): + + def doTest(self, optimizer, update_fn, optimizer_name, slot_name, + use_resource=False, do_sparse=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.test_session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + if do_sparse: + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices(constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), + constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices(constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), + constant_op.constant([2])) + else: + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = optimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of the optimizer + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0) + var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/%s:0" % (i, optimizer_name), + opt.get_slot(var=var0, name=slot_name).name) + + +class AdamWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY) + + def testSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True) + + +class MomentumWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9) + + def testSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True) + + +class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay( + adam.AdamOptimizer) + return adamw(WEIGHT_DECAY) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", + use_resource=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py index d538ad0fb02699ed8514f512208914f629a47436..631d4f44dfb646541244bfe1d15136dd29f02703 100644 --- a/tensorflow/contrib/optimizer_v2/adam.py +++ b/tensorflow/contrib/optimizer_v2/adam.py @@ -103,9 +103,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2): def _create_vars(self, var_list, state): # Non-slot variables end up on the same device(s). - state.create_non_slot(initial_value=state.get_hyper("beta1"), + state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"), name="beta1_power") - state.create_non_slot(initial_value=state.get_hyper("beta2"), + state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"), name="beta2_power") # Create slots for the first and second moments. diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 548b494bf1df63ec28629f2a6c8a4be22d0e423e..06ab58188a2fffa0e3a810d451875ca951a077b9 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -33,8 +33,8 @@ from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops @@ -43,15 +43,15 @@ from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util -class NonLayerCheckpointable(checkpointable.Checkpointable): +class NonLayerCheckpointable(tracking.Checkpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() - self.a_variable = checkpointable_utils.add_variable( + self.a_variable = util.add_variable( self, name="a_variable", shape=[]) @@ -88,29 +88,6 @@ class _MirroringSaveable( self._mirrored_variable.assign(tensor)) -class _OwnsMirroredVariables(checkpointable.CheckpointableBase): - """A Checkpointable object which returns a more complex SaveableObject.""" - - def __init__(self): - self.non_dep_variable = variable_scope.get_variable( - name="non_dep_variable", initializer=6., use_resource=True) - self.mirrored = variable_scope.get_variable( - name="mirrored", initializer=15., use_resource=True) - - def _gather_saveables_for_checkpoint(self): - def _saveable_factory(name=self.non_dep_variable.name): - return _MirroringSaveable( - primary_variable=self.non_dep_variable, - mirrored_variable=self.mirrored, - name=name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} - - # The Saver sorts by name before parsing, so we need a name property. - @property - def name(self): - return self.non_dep_variable.name - - class CheckpointingTests(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) @@ -122,7 +99,7 @@ class CheckpointingTests(test.TestCase): other_model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = checkpointable_utils.Checkpoint( + root_checkpointable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) if context.executing_eagerly(): optimizer.minimize( @@ -137,11 +114,11 @@ class CheckpointingTests(test.TestCase): optimizer.minimize( other_model(input_value), global_step=optimizer_step) - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( root_checkpointable)) self.evaluate(train_op) named_variables, serialized_graph, _ = ( - checkpointable_utils._serialize_object_graph( + util._serialize_object_graph( root_checkpointable, saveables_cache=None)) expected_checkpoint_names = ( # Created in the root node, so no prefix. @@ -226,11 +203,11 @@ class CheckpointingTests(test.TestCase): optimizer_node.slot_variables[0] .slot_variable_node_id].attributes[0].checkpoint_key) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestore(self): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root_checkpointable = checkpointable_utils.Checkpoint( + root_checkpointable = util.Checkpoint( optimizer=optimizer, model=model) input_value = constant_op.constant([[3.]]) if context.executing_eagerly(): @@ -240,7 +217,7 @@ class CheckpointingTests(test.TestCase): train_op = optimizer.minimize(model(input_value)) # TODO(allenl): Make initialization more pleasant when graph building. root_checkpointable.save_counter # pylint: disable=pointless-statement - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( root_checkpointable)) self.evaluate(train_op) prefix = os.path.join(self.get_temp_dir(), "ckpt") @@ -266,7 +243,7 @@ class CheckpointingTests(test.TestCase): # Preserve beta1_power and beta2_power when appying gradients so we can # test that they've been restored correctly. beta1=1.0, beta2=1.0) - on_create_root = checkpointable_utils.Checkpoint( + on_create_root = util.Checkpoint( optimizer=on_create_optimizer, model=on_create_model) # Deferred restoration status = on_create_root.restore(save_path=save_path) @@ -298,7 +275,7 @@ class CheckpointingTests(test.TestCase): for training_continuation in range(3): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) root.restore(core_saver.latest_checkpoint(checkpoint_directory)) @@ -322,7 +299,7 @@ class CheckpointingTests(test.TestCase): with ops.Graph().as_default(): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) input_value = constant_op.constant([[3.]]) @@ -347,7 +324,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(training_continuation + 1, session.run(root.save_counter)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAgnosticUsage(self): """Graph/eager agnostic usage.""" # Does create garbage when executing eagerly due to ops.Graph() creation. @@ -359,7 +336,7 @@ class CheckpointingTests(test.TestCase): graph=ops.get_default_graph()), test_util.device(use_gpu=True): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) @@ -381,7 +358,7 @@ class CheckpointingTests(test.TestCase): self.evaluate(root.save_counter)) # pylint: disable=cell-var-from-loop - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWithDefun(self): num_training_steps = 2 checkpoint_directory = self.get_temp_dir() @@ -392,7 +369,7 @@ class CheckpointingTests(test.TestCase): model = MyModel() # Don't actually train so we can test variable values optimizer = adam.AdamOptimizer(0.) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) @@ -442,7 +419,7 @@ class CheckpointingTests(test.TestCase): optimizer = adam.AdamOptimizer(learning_rate=0.05) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - checkpoint = checkpointable_utils.Checkpoint( + checkpoint = util.Checkpoint( model=model, optimizer=optimizer) for _ in range(2): checkpoint.save(checkpoint_prefix) @@ -453,12 +430,12 @@ class CheckpointingTests(test.TestCase): optimizer.apply_gradients( [(g, v) for g, v in zip(grad, model.vars)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = checkpointable.Checkpointable() - root.var = checkpointable_utils.add_variable( + root = tracking.Checkpointable() + root.var = util.add_variable( root, name="var", initializer=0.) optimizer = adam.AdamOptimizer(0.1) if context.executing_eagerly(): @@ -468,28 +445,28 @@ class CheckpointingTests(test.TestCase): # Note that `optimizer` has not been added as a dependency of # `root`. Create a one-off grouping so that slot variables for `root.var` # get initialized too. - self.evaluate(checkpointable_utils.gather_initializers( - checkpointable_utils.Checkpoint(root=root, optimizer=optimizer))) + self.evaluate(util.gather_initializers( + util.Checkpoint(root=root, optimizer=optimizer))) self.evaluate(train_op) self.evaluate(state_ops.assign(root.var, 12.)) - no_slots_path = checkpointable_utils.CheckpointableSaver(root).save( + no_slots_path = util.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "no_slots")) root.optimizer = optimizer self.evaluate(state_ops.assign(root.var, 13.)) self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.)) - slots_path = checkpointable_utils.CheckpointableSaver(root).save( + slots_path = util.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "with_slots")) - new_root = checkpointable.Checkpointable() + new_root = tracking.Checkpointable() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). - slot_status = checkpointable_utils.CheckpointableSaver( + slot_status = util.CheckpointableSaver( new_root).restore(slots_path) - no_slot_status = checkpointable_utils.CheckpointableSaver( + no_slot_status = util.CheckpointableSaver( new_root).restore(no_slots_path) with self.assertRaises(AssertionError): no_slot_status.assert_consumed() - new_root.var = checkpointable_utils.add_variable( + new_root.var = util.add_variable( new_root, name="var", shape=[]) no_slot_status.assert_consumed() no_slot_status.run_restore_ops() @@ -525,12 +502,12 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) - self.evaluate(checkpointable_utils.gather_initializers(obj)) - saver = checkpointable_utils.CheckpointableSaver(obj) + self.evaluate(util.gather_initializers(obj)) + saver = util.CheckpointableSaver(obj) saver.save(checkpoint_prefix) before_ops = graph.get_operations() saver.save(checkpoint_prefix) @@ -543,12 +520,12 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) - self.evaluate(checkpointable_utils.gather_initializers(obj)) - saver = checkpointable_utils.CheckpointableSaver(obj) + self.evaluate(util.gather_initializers(obj)) + saver = util.CheckpointableSaver(obj) save_path = saver.save(checkpoint_prefix) saver.restore(save_path) before_ops = graph.get_operations() @@ -565,10 +542,10 @@ class CheckpointingTests(test.TestCase): first_session = session_lib.Session(graph=first_graph) with first_graph.as_default(), first_session.as_default(): first_variable = resource_variable_ops.ResourceVariable([1.]) - first_root_checkpointable = checkpointable_utils.Checkpoint( + first_root_checkpointable = util.Checkpoint( optimizer=optimizer, variable=first_variable) train_op = optimizer.minimize(first_variable.read_value) - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( first_root_checkpointable)) self.evaluate(train_op) self.evaluate(first_variable.assign([1.])) @@ -581,7 +558,7 @@ class CheckpointingTests(test.TestCase): second_graph = ops.Graph() with second_graph.as_default(), session_lib.Session(graph=second_graph): second_variable = resource_variable_ops.ResourceVariable([1.]) - second_root_checkpointable = checkpointable_utils.Checkpoint( + second_root_checkpointable = util.Checkpoint( optimizer=optimizer, variable=second_variable) train_op = optimizer.minimize(second_variable.read_value) second_root_checkpointable.restore(None).initialize_or_restore() @@ -616,7 +593,7 @@ class CheckpointingTests(test.TestCase): class TemplateTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_checkpointable_save_restore(self): def _templated(): @@ -631,7 +608,7 @@ class TemplateTests(test.TestCase): save_template = template.make_template("s1", _templated) v1_save, _, v2_save = save_template() optimizer = adam.AdamOptimizer(0.0) - save_root = checkpointable_utils.Checkpoint( + save_root = util.Checkpoint( my_template=save_template, optimizer=optimizer) optimizer.minimize(v1_save.read_value) self.evaluate([v.initializer for v in optimizer.variables()]) @@ -643,7 +620,7 @@ class TemplateTests(test.TestCase): load_template = template.make_template("s2", _templated) load_optimizer = adam.AdamOptimizer(0.0) - load_root = checkpointable_utils.Checkpoint( + load_root = util.Checkpoint( my_template=load_template, optimizer=load_optimizer) status = load_root.restore(save_path) var, var_plus_one, var2 = load_template() @@ -664,12 +641,12 @@ class CheckpointCompatibilityTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = checkpointable_utils.Checkpoint( + root_checkpointable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) train_op = optimizer.minimize( functools.partial(model, input_value), global_step=optimizer_step) - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( root_checkpointable)) self.evaluate(train_op) # A regular variable, a slot variable, and a non-slot Optimizer variable @@ -712,7 +689,7 @@ class CheckpointCompatibilityTests(test.TestCase): sess=session, save_path=checkpoint_prefix, global_step=root.optimizer_step) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoadFromNameBasedSaver(self): """Save a name-based checkpoint, load it using the object-based API.""" with test_util.device(use_gpu=True): @@ -721,7 +698,7 @@ class CheckpointCompatibilityTests(test.TestCase): self._set_sentinels(root) with self.assertRaises(AssertionError): self._check_sentinels(root) - object_saver = checkpointable_utils.CheckpointableSaver(root) + object_saver = util.CheckpointableSaver(root) self._set_sentinels(root) status = object_saver.restore(save_path) if context.executing_eagerly(): diff --git a/tensorflow/contrib/optimizer_v2/momentum_test.py b/tensorflow/contrib/optimizer_v2/momentum_test.py index 26724f66c2a1db1d01577b31b739af18f51d3976..24cdab462665adc6297b0e0821455a545c3880af 100644 --- a/tensorflow/contrib/optimizer_v2/momentum_test.py +++ b/tensorflow/contrib/optimizer_v2/momentum_test.py @@ -134,7 +134,6 @@ class MomentumOptimizerTest(test.TestCase): with context.eager_mode(): self.doTestBasic(use_resource=True, use_callable_params=True) - @test_util.run_in_graph_and_eager_modes(reset_test=True) def testVariablesAcrossGraphs(self): optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5) with ops.Graph().as_default(): @@ -142,10 +141,7 @@ class MomentumOptimizerTest(test.TestCase): [1.0, 2.0], dtype=dtypes.float32, name="var0") var1 = resource_variable_ops.ResourceVariable( [3.0, 4.0], dtype=dtypes.float32, name="var1") - if context.executing_eagerly(): - loss = lambda: math_ops.reduce_sum(var0 + var1) - else: - loss = math_ops.reduce_sum(var0 + var1) + loss = math_ops.reduce_sum(var0 + var1) optimizer.minimize(loss) optimizer_variables = optimizer.variables() self.assertStartsWith(optimizer_variables[0].name, "var0") @@ -157,10 +153,7 @@ class MomentumOptimizerTest(test.TestCase): [1.0, 2.0], dtype=dtypes.float32, name="var2") var3 = resource_variable_ops.ResourceVariable( [3.0, 4.0], dtype=dtypes.float32, name="var3") - if context.executing_eagerly(): - loss = lambda: math_ops.reduce_sum(var2 + var3) - else: - loss = math_ops.reduce_sum(var2 + var3) + loss = math_ops.reduce_sum(var2 + var3) optimizer.minimize(loss) optimizer_variables = optimizer.variables() self.assertStartsWith(optimizer_variables[0].name, "var2") diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index f537318b32986c941b6c41eb363929e906027dd7..8c11d8bcfdf76bc12e13ffb58f917978e966476e 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -162,12 +162,12 @@ def _get_processor(v): def _var_key_v2(var): """Key for representing a primary variable, for looking up slots.""" # pylint: disable=protected-access - if hasattr(var, "_mirrored_container"): - mirrored_container = var._mirrored_container() - assert mirrored_container is not None + if hasattr(var, "_distributed_container"): + distributed_container = var._distributed_container() + assert distributed_container is not None if context.executing_eagerly(): - return mirrored_container._unique_id - return mirrored_container._shared_name + return distributed_container._unique_id + return distributed_container._shared_name if context.executing_eagerly(): return var._unique_id return var.op.name @@ -211,8 +211,9 @@ class _OptimizerV2State(object): # This dict starts with a single item with key "None" with the hyper # parameter value converted to a Tensor. Other items have dtype keys # with that Tensor cast to that dtype. - self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)} - for name, (dynamic, value) in hyper.items() if not dynamic} + with ops.init_scope(): + self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)} + for name, (dynamic, value) in hyper.items() if not dynamic} self._slots = {} self._non_slot_dict = {} # Extra state to help Optimizers implement Checkpointable. Holds information @@ -765,7 +766,8 @@ class OptimizerV2(optimizer_v1.Optimizer): # *after* loss() is evaluated, so we know what loss reduction it uses. if scale_loss_by_num_towers is None: scale_loss_by_num_towers = ( - distribute_lib.get_loss_reduction() == "mean") + distribute_lib.get_loss_reduction() == + variable_scope.VariableAggregation.MEAN) if scale_loss_by_num_towers: num_towers = distribute_lib.get_distribution_strategy().num_towers if num_towers > 1: @@ -783,7 +785,8 @@ class OptimizerV2(optimizer_v1.Optimizer): # Scale loss for number of towers (non-callable-loss case). if scale_loss_by_num_towers is None: scale_loss_by_num_towers = ( - distribute_lib.get_loss_reduction() == "mean") + distribute_lib.get_loss_reduction() == + variable_scope.VariableAggregation.MEAN) if scale_loss_by_num_towers: num_towers = distribute_lib.get_distribution_strategy().num_towers if num_towers > 1: @@ -895,7 +898,8 @@ class OptimizerV2(optimizer_v1.Optimizer): def _distributed_apply(self, distribution, grads_and_vars, global_step, name): """`apply_gradients` for use with a `DistributionStrategy`.""" - reduced_grads = distribution.batch_reduce("sum", grads_and_vars) + reduced_grads = distribution.batch_reduce( + variable_scope.VariableAggregation.SUM, grads_and_vars) var_list = [v for _, v in grads_and_vars] grads_and_vars = zip(reduced_grads, var_list) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py index 8599af32f6f4cc5529cd812e83c02ef3812cb71e..ec033c4a0163ba9ed39e55fa9e92dfdadc9a1b2f 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py @@ -35,7 +35,7 @@ from tensorflow.python.platform import test class OptimizerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBasic(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -113,7 +113,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], var1.eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoVariables(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: # pylint: disable=cell-var-from-loop @@ -128,7 +128,7 @@ class OptimizerTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'No.*variables'): sgd_op.minimize(loss) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -146,7 +146,7 @@ class OptimizerTest(test.TestCase): # var1 has no gradient sgd_op.minimize(loss, var_list=[var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_Minimize(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -162,7 +162,7 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.minimize(loss, var_list=[var0, var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_ApplyGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -176,7 +176,7 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.apply_gradients([(None, var0), (None, var1)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientsAsVariables(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -216,7 +216,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([-14., -13.], self.evaluate(var0)) self.assertAllClose([-6., -5.], self.evaluate(var1)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testComputeGradientsWithTensors(self): x = ops.convert_to_tensor(1.0) def f(): diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD index 6ca7fe8b6e59b0dc24be76262d4f54f387e53e48..f2171efc959362c1e4392fefbd5842f0883571d7 100644 --- a/tensorflow/contrib/periodic_resample/BUILD +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -6,12 +6,13 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "py_test", + "tf_cc_test", "tf_gen_op_libs", "tf_custom_op_library", "tf_custom_op_py_library", "tf_gen_op_wrapper_py", ) +load("//tensorflow:tensorflow.bzl", "py_test") cc_library( name = "all_ops", @@ -84,6 +85,22 @@ py_test( ":init_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradient_checker", + ], +) + +tf_cc_test( + name = "periodic_resample_op_cc_test", + size = "small", + srcs = [ + "ops/array_ops_test.cc", + ], + deps = [ + ":all_ops", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc index e18923c8aae74c66ce78f98eb5e615e99463af74..514689cf4543cd08632bd0321a78fa933c456467 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc @@ -22,4 +22,9 @@ namespace tensorflow { REGISTER_KERNEL_BUILDER(Name("PeriodicResample").Device(DEVICE_CPU), PeriodicResampleOp); + +REGISTER_KERNEL_BUILDER(Name("PeriodicResampleOpGrad") + .Device(DEVICE_CPU), + PeriodicResampleOpGrad); + } // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h index 3ab588c45881c8f93b4c1bcdf7ccde39086a1ed7..42fba81a5cb9490c093062048f269704a110756a 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h @@ -25,92 +25,202 @@ #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/work_sharder.h" namespace { -template -IndexT compute_input_index( - IndexVecT* target_dimensions, const IndexT& output_index, - const IndexVecT& original_dimensions, const int& adjustable_dimension, - const std::vector& dimension_ceiling, - const std::vector& cumulative_dimensions, IndexT* result, - std::vector* output_indices, const int& rank) { - *result = 0; - output_indices->clear(); +// Computes input tensor index for given output index during forward +// propagation through periodic_resample operation. +class InputIndexer { + public: + InputIndexer(const std::vector& output_dimensions, + const tensorflow::TensorShape& input_shape, + int adjustable_dimension) + : output_dimensions_(output_dimensions), + adjustable_dimension_(adjustable_dimension), + rank_(input_shape.dims()), + linear_output_index_(0), + linear_input_index_(0), + adjustable_dimension_carriage_sum_(0) { + auto input_dimensions = TensorShapeToVector(input_shape); + // factors by which input_dimensions increases/decreases w.r.t. + // output_dimensions + dimension_ceiling_ = + ComputeDimensionCeiling(output_dimensions, input_dimensions); + cumulative_dimensions_ = ComputeCumulativeDimensions(); + + output_indices_.resize(output_dimensions_.size()); + input_indices_.resize(output_dimensions_.size()); + + // Compute index_factors + index_factors_.resize(rank_); + tensorflow::int64 last_index_factor = 1; + for (auto r = rank_ - 1; r >= 0; --r) { + index_factors_[r] = last_index_factor; + last_index_factor *= input_dimensions[r]; + } + } + + tensorflow::int64 linear_input_index() const { return linear_input_index_; } + + void MoveToOutputIndex(tensorflow::int64 output_index); + void IncrementOutputIndex(); + + private: + void RecomputeInputAdjustableDimensionIndex() { + tensorflow::int64 index = adjustable_dimension_carriage_sum_; + index *= output_dimensions_[adjustable_dimension_]; + index += output_indices_[adjustable_dimension_]; + input_indices_[adjustable_dimension_] = index; + } + + std::vector TensorShapeToVector( + const tensorflow::TensorShape& tensor_shape); + + std::vector ComputeDimensionCeiling( + const std::vector& output_dimensions, + const std::vector& input_dimensions); + + std::vector ComputeCumulativeDimensions(); + + const std::vector output_dimensions_; + std::vector dimension_ceiling_; + std::vector index_factors_; + std::vector cumulative_dimensions_; + std::vector output_indices_; + std::vector input_indices_; + + const int adjustable_dimension_; + const int rank_; + tensorflow::int64 linear_output_index_; + tensorflow::int64 linear_input_index_; + tensorflow::int64 adjustable_dimension_carriage_sum_; +}; + +void InputIndexer::MoveToOutputIndex(tensorflow::int64 output_index) { + linear_output_index_ = output_index; + linear_input_index_ = 0; // un-rasterize the output index auto last_reduced_i = output_index; - for (auto r = rank - 1; r >= 0; --r) { - (*output_indices)[r] = last_reduced_i % (*target_dimensions)[r]; + for (auto r = rank_ - 1; r >= 0; --r) { + output_indices_[r] = last_reduced_i % output_dimensions_[r]; last_reduced_i = - (last_reduced_i - (*output_indices)[r]) / (*target_dimensions)[r]; + (last_reduced_i - output_indices_[r]) / output_dimensions_[r]; } + tensorflow::int64 carriage_sum = 0; + for (int qi = 0; qi < rank_; ++qi) { + if (qi == adjustable_dimension_) continue; + carriage_sum += cumulative_dimensions_[qi] * + (output_indices_[qi] % dimension_ceiling_[qi]); + } + adjustable_dimension_carriage_sum_ = carriage_sum; + // rasterize the input index - IndexT last_index_factor = 1; - for (auto r = rank - 1; r >= 0; --r) { - IndexT index = 0; - if (r != adjustable_dimension) - index = (*output_indices)[r] / dimension_ceiling[r]; - else { - for (int qi = 0; qi < rank; ++qi) { - if (qi == adjustable_dimension) continue; - index += cumulative_dimensions[qi] * - ((*output_indices)[qi] % dimension_ceiling[qi]); - } - index *= (*target_dimensions)[adjustable_dimension]; - index += (*output_indices)[r]; + for (auto r = rank_ - 1; r >= 0; --r) { + if (r != adjustable_dimension_) { + input_indices_[r] = output_indices_[r] / dimension_ceiling_[r]; + } else { + RecomputeInputAdjustableDimensionIndex(); } - *result += last_index_factor * index; - last_index_factor *= original_dimensions[r]; } + for (auto r = rank_ - 1; r >= 0; --r) { + linear_input_index_ += index_factors_[r] * input_indices_[r]; + } +} + +void InputIndexer::IncrementOutputIndex() { + linear_output_index_++; + for (auto r = rank_ - 1; r >= 0; --r) { + auto old_carriage_sum_increment = + cumulative_dimensions_[r] * + (output_indices_[r] % dimension_ceiling_[r]); + output_indices_[r] = (output_indices_[r] + 1) % output_dimensions_[r]; + if (r != adjustable_dimension_) { + auto new_input_index = output_indices_[r] / dimension_ceiling_[r]; + linear_input_index_ += + (new_input_index - input_indices_[r]) * index_factors_[r]; + + input_indices_[r] = new_input_index; + + auto new_carriage_sum_increment = + cumulative_dimensions_[r] * + (output_indices_[r] % dimension_ceiling_[r]); - return *result; + adjustable_dimension_carriage_sum_ = adjustable_dimension_carriage_sum_ - + old_carriage_sum_increment + + new_carriage_sum_increment; + } + + if (output_indices_[r] != 0) { + // No more carries to higher indices. + break; + } + } + auto old_adjustable_dimension_input_index = + input_indices_[adjustable_dimension_]; + RecomputeInputAdjustableDimensionIndex(); + linear_input_index_ += (input_indices_[adjustable_dimension_] - + old_adjustable_dimension_input_index) * + index_factors_[adjustable_dimension_]; } -template // both types are needed here b/c IndexVecT and - // InputDataT are not related - void - fill_periodic_tensor( - tensorflow::OpKernelContext* context, - const IndexVecT& desired_shape, - const tensorflow::Tensor& input_tensor) { - // input is a strided array (last index is fastest, C-ordered) - auto input = input_tensor.flat(); - const int rank = input_tensor.dims(); - // original and target dimensions - std::vector original_dimensions(rank), - target_dimensions(rank); - tensorflow::int64 total_size(input_tensor.NumElements()), new_sliced_size(1); - // factors by which original_dimensions increases/decreases w.r.t. - // target_dimensions - std::vector dimension_ceiling(rank), - cumulative_dimensions(rank); - // index of adjustable dimension - int adjustable_dimension; - tensorflow::TensorShape output_shape; +std::vector InputIndexer::TensorShapeToVector( + const tensorflow::TensorShape& tensor_shape) { + std::vector result(tensor_shape.dims()); + int count = 0; + for (const auto dim_info : tensor_shape) { + result[count] = dim_info.size; + ++count; + } + return result; +} - // requires that the rank of the input tensor and length of the desired shape - // are equal - OP_REQUIRES(context, rank == desired_shape.size(), - tensorflow::errors::InvalidArgument( - "periodic_resample expects the rank of the input tensor, ", - rank, ", to be the same as the length of the desired shape, ", - desired_shape.size(), ".")); +std::vector InputIndexer::ComputeDimensionCeiling( + const std::vector& output_dimensions, + const std::vector& input_dimensions) { + std::vector dimension_ceiling(input_dimensions.size()); + for (size_t i = 0; i < input_dimensions.size(); ++i) { + dimension_ceiling[i] = (output_dimensions[i] + input_dimensions[i] - 1) / + input_dimensions[i]; + } + return dimension_ceiling; +} - bool found = false; - const auto& input_tensor_shape = input_tensor.shape(); +std::vector InputIndexer::ComputeCumulativeDimensions() { + std::vector cumulative_dimensions(rank_); + int count = 0; + for (int i = 0; i < rank_; ++i) { + if (count == 0) { + cumulative_dimensions[count] = 1; + } else { + cumulative_dimensions[count] = + cumulative_dimensions[count - 1] * dimension_ceiling_[count - 1]; + } + ++count; + } + return cumulative_dimensions; +} +template +void process_desired_shape(tensorflow::OpKernelContext* context, + const tensorflow::TensorShape& input_tensor_shape, + const IndexVecT& desired_shape, + int* adjustable_dimension, + std::vector* target_dimensions, + tensorflow::int64* output_size) { + tensorflow::int64 new_sliced_size = 1; + bool found = false; + const int rank = input_tensor_shape.dims(); for (int i = 0; i < rank; ++i) { - // if (desired_shape(i) < 1) { if (desired_shape[i] < 1) { // only one index can be adjustable OP_REQUIRES(context, !found, tensorflow::errors::InvalidArgument( "periodic_resample expects only " "one index to be marked as adjustable.")); - adjustable_dimension = i; + *adjustable_dimension = i; found = true; } else { OP_REQUIRES( @@ -122,9 +232,8 @@ template +void +do_periodic_resample_op(tensorflow::OpKernelContext* context, + const tensorflow::TensorShape& original_shape, + const tensorflow::PartialTensorShape& desired_shape, + const tensorflow::Tensor& source_tensor) { + const int rank = source_tensor.dims(); + + // requires that the rank of the input tensor and length of the desired shape + // are equal + OP_REQUIRES(context, rank == desired_shape.dims(), + tensorflow::errors::InvalidArgument( + "periodic_resample expects the rank of the input tensor, ", + rank, ", to be the same as the length of the desired shape, ", + desired_shape.dims(), ".")); + + std::vector target_dimensions(rank); + tensorflow::int64 new_size = 0; + // index of adjustable dimension + int adjustable_dimension = 0; + process_desired_shape(context, original_shape, desired_shape.dim_sizes(), + &adjustable_dimension, &target_dimensions, &new_size); // ensure that the new dimension is greater than zero OP_REQUIRES(context, target_dimensions[adjustable_dimension] > 0, @@ -160,11 +293,14 @@ template allocate_output(0, output_shape, &output_tensor)); auto output = output_tensor->flat(); - // memory is allocated for these variables outside the inner loop for - // efficiency (although, I could create a separate class scope for - // this purpose instead) - tensorflow::int64 result = 0; - std::vector output_indices(target_dimensions.size()); + // input is a strided array (last index is fastest, C-ordered) + auto input = source_tensor.flat(); // Fill output tensor with periodically resampled input tensor values - for (tensorflow::int64 output_index = 0; output_index < new_size; - ++output_index) { - output(output_index) = input(compute_input_index( - &target_dimensions, output_index, original_dimensions, - adjustable_dimension, dimension_ceiling, cumulative_dimensions, &result, - &output_indices, rank)); - } + InputIndexer input_indexer(target_dimensions, original_shape, + adjustable_dimension); + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + auto fill_output_tensor = [&input_indexer, &output, &input]( + tensorflow::int64 start, tensorflow::int64 limit) { + InputIndexer local_indexer(input_indexer); + local_indexer.MoveToOutputIndex(start); + for (tensorflow::int64 output_index = start; output_index < limit; + ++output_index) { + if (mode == Mode::kForward) { + output(output_index) = input(local_indexer.linear_input_index()); + } else { + output(local_indexer.linear_input_index()) = input(output_index); + } + local_indexer.IncrementOutputIndex(); + } + }; + ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers, + new_size, costPerFillIndex, fill_output_tensor); } +#define DATA_TYPE_SWITCH(data_type, context, CASE) \ + switch (data_type) { \ + CASE(float) \ + CASE(double) \ + CASE(tensorflow::int32) \ + CASE(tensorflow::int64) \ + default: \ + context->CtxFailure(__FILE__, __LINE__, \ + tensorflow::errors::InvalidArgument( \ + "Unsuppored tensor elements type")); \ + break; \ + } + void create_output_tensor( tensorflow::OpKernelContext* context, const tensorflow::Tensor& input_tensor, const tensorflow::DataType& input_tensor_type, - const tensorflow::PartialTensorShape& desired_shape_tensor) { - auto desired_shape = desired_shape_tensor.dim_sizes(); - - // obligatory type switch - switch (input_tensor_type) { - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, input_tensor); + const tensorflow::PartialTensorShape& desired_shape) { +#define CASE(type) \ + case tensorflow::DataTypeToEnum::value: \ + do_periodic_resample_op( \ + context, input_tensor.shape(), desired_shape, input_tensor); \ break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, input_tensor); - break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, - input_tensor); - break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, - input_tensor); + + DATA_TYPE_SWITCH(input_tensor_type, context, CASE); +#undef CASE +} + +void create_grad_tensor(tensorflow::OpKernelContext* context, + const tensorflow::Tensor& grad_tensor, + const tensorflow::DataType& grad_tensor_type, + const tensorflow::TensorShape& original_shape, + const tensorflow::PartialTensorShape& desired_shape) { +#define CASE(type) \ + case tensorflow::DataTypeToEnum::value: \ + do_periodic_resample_op( \ + context, original_shape, desired_shape, grad_tensor); \ break; - default:; - } + + DATA_TYPE_SWITCH(grad_tensor_type, context, CASE); +#undef CASE } } // namespace @@ -238,4 +400,25 @@ class PeriodicResampleOp : public tensorflow::OpKernel { tensorflow::PartialTensorShape desired_shape; }; +class PeriodicResampleOpGrad : public tensorflow::OpKernel { + public: + explicit PeriodicResampleOpGrad(tensorflow::OpKernelConstruction* context) + : tensorflow::OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("original_shape", &original_shape)); + OP_REQUIRES_OK(context, context->GetAttr("desired_shape", &desired_shape)); + } + + void Compute(tensorflow::OpKernelContext* context) override { + const tensorflow::Tensor& grad_tensor = context->input(0); + const tensorflow::DataType grad_tensor_type = context->input_dtype(0); + create_grad_tensor(context, grad_tensor, grad_tensor_type, original_shape, + desired_shape); + } + + private: + tensorflow::TensorShape original_shape; + tensorflow::PartialTensorShape desired_shape; +}; + #endif // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_ diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc index 82bd79695646e3673c2c78ad99dd2bd200fc2fbf..fd38cd09b4d0939d7955f7839763a8e955b71fa5 100644 --- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc +++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc @@ -26,7 +26,42 @@ REGISTER_OP("PeriodicResample") .Input("values: T") .Attr("shape: shape") .Output("output: T") - .SetShapeFn(shape_inference::ExplicitShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + tensorflow::PartialTensorShape desired_shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape)); + shape_inference::ShapeHandle input_tensor_shape = c->input(0); + shape_inference::DimensionHandle num_input_elements = + c->NumElements(input_tensor_shape); + shape_inference::ShapeHandle result_shape_handle; + if (!shape_inference::InferenceContext::ValueKnown(num_input_elements)) { + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + desired_shape, &result_shape_handle)); + } else { + const int rank = c->Rank(input_tensor_shape); + std::vector target_dimensions(rank); + tensorflow::int64 new_sliced_size = 1; + int adjustable_dimension = 0; + for (int i = 0; i < rank; ++i) { + if (desired_shape.dim_size(i) < 1) { + adjustable_dimension = i; + } else { + target_dimensions[i] = desired_shape.dim_size(i); + new_sliced_size *= target_dimensions[i]; + } + } + target_dimensions[adjustable_dimension] = + shape_inference::InferenceContext::Value( + num_input_elements) / new_sliced_size; + tensorflow::TensorShape result_shape; + for (int i = 0; i < rank; ++i) { + result_shape.AddDim(target_dimensions[i]); + } + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape( + result_shape, &result_shape_handle)); + } + c->set_output(0, result_shape_handle); + return Status::OK(); + }) .Doc(R"doc( Periodically resample elements of a tensor to conform to `shape`. @@ -101,4 +136,20 @@ output: Periodically resampled tensor that has dimensions specified as in )doc"); + +REGISTER_OP("PeriodicResampleOpGrad") + .Attr("T: numbertype") + .Input("grad: T") + .Attr("original_shape: shape") + .Attr("desired_shape: shape") + .Output("grad_values: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + tensorflow::TensorShape original_shape; + TF_RETURN_IF_ERROR(c->GetAttr("original_shape", &original_shape)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(original_shape, &s)); + c->set_output(0, s); + return Status::OK(); +}); + } // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..43b7c1799ffb2e27f9d15bc6011d49334867b6ec --- /dev/null +++ b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(ArrayOpsTest, PeriodicResample_ShapeFn) { + ShapeInferenceTestOp op("PeriodicResample"); + // Case 1: output shape can be fully inferreed. + PartialTensorShape shape({4, 4, -1}); + TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + + TF_ASSERT_OK(NodeDefBuilder("test", "PeriodicResample") + .Input({"values", 0, DT_INT32}) + .Attr("shape", shape_proto) + .Finalize(&op.node_def)); + INFER_OK(op, "[2,2,4]", "[4,4,1]"); + // Case 2: output shape can not be inferred - report desired shape. + INFER_OK(op, "[2,2,?]", "[4,4,?]"); +} + +} // end namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py index a25de55e18b223db2b724aafb54b18d8f48a5baa..31a6fe1d94b8a972087e00cf7c676105b0f1129b 100644 --- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py +++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py @@ -21,8 +21,11 @@ from __future__ import print_function import numpy from tensorflow.contrib.periodic_resample import periodic_resample +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -93,7 +96,6 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): def testPeriodicResampleErrors(self): input_tensor = numpy.zeros(shape=[1, 2, 2, 4]) with self.test_session(): - variables.global_variables_initializer().run() with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, 'Dimension 3 input tensor has size 4, desired shape has size 1'): @@ -103,6 +105,29 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): '4, to be the same as the length of the desired shape, 3'): periodic_resample(input_tensor, [None, 4, 4]).eval() + def testPeriodicResampleGradient(self): + desired_shape = numpy.array([4, 4, None]) + result_shape = (4, 4, 1) + input_shape = (2, 2, 4) + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32, shape=input_shape) + output = periodic_resample(x, desired_shape) + error = gradient_checker.compute_gradient_error( + x, input_shape, output, result_shape) + self.assertLess(error, 1e-4) + + def testPeriodicResampleShapeInference(self): + with self.test_session() as sess: + # Case 1: output shape can be fully inferreed. + x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4)) + output = periodic_resample(x, [4, 4, None]) + self.assertEqual(output.shape, [4, 4, 1]) + # Case 2: output shape can not be inferred - report desired shape. + x = array_ops.placeholder(dtypes.float32, shape=(2, 2, None)) + output = periodic_resample(x, [4, 4, None]) + self.assertTrue(output.shape.is_compatible_with([4, 4, None])) + self.assertEqual(output.shape[2].value, None) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py index 348623d8f8d0c2ed60f559eca281343722038100..470e300ccbe7108fd49718341f4a522683366fe3 100644 --- a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py +++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py @@ -21,11 +21,17 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op -from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample +from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample, periodic_resample_op_grad from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader # pylint: enable=unused-import _periodic_resample_op = loader.load_op_library( resource_loader.get_path_to_datafile('_periodic_resample_op.so')) + +@ops.RegisterGradient("PeriodicResample") +def _periodic_resample_grad_cc(op, grad): + return periodic_resample_op_grad( + grad, op.inputs[0].shape, op.get_attr('shape')) diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py index b7a98c68e2343e9c8bb4b41556dc96bfe4ef444c..af3b2ad1b531b835f484a155efcc57bbe634f2df 100644 --- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py @@ -34,7 +34,8 @@ class ContribEstimatorPredictor(predictor.Predictor): prediction_input_fn, input_alternative_key=None, output_alternative_key=None, - graph=None): + graph=None, + config=None): """Initialize a `ContribEstimatorPredictor`. Args: @@ -48,6 +49,7 @@ class ContribEstimatorPredictor(predictor.Predictor): multi-headed models. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. """ self._graph = graph or ops.Graph() with self._graph.as_default(): @@ -58,6 +60,7 @@ class ContribEstimatorPredictor(predictor.Predictor): checkpoint_path = saver.latest_checkpoint(estimator.model_dir) self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( + config=config, checkpoint_filename_with_path=checkpoint_path)) input_alternative_key = ( diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py index d78d94c2699b14c80e7decee2181d190a6d91f99..a725072e72df2db64cde5ea31ab16e7c2dc5d2ce 100644 --- a/tensorflow/contrib/predictor/core_estimator_predictor.py +++ b/tensorflow/contrib/predictor/core_estimator_predictor.py @@ -51,7 +51,8 @@ class CoreEstimatorPredictor(predictor.Predictor): estimator, serving_input_receiver_fn, output_key=None, - graph=None): + graph=None, + config=None): """Initialize a `CoreEstimatorPredictor`. Args: @@ -62,6 +63,7 @@ class CoreEstimatorPredictor(predictor.Predictor): `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. """ self._graph = graph or ops.Graph() with self._graph.as_default(): @@ -71,6 +73,7 @@ class CoreEstimatorPredictor(predictor.Predictor): checkpoint_dir = estimator.model_dir self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( + config=config, checkpoint_dir=checkpoint_dir)) feed_tensor_info = signature_def.inputs diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py index 6e77e934fe19851eea9ed0b74eb7aecc76f6237a..f275bc15adfa0a51a48964dff8edddbd45500e45 100644 --- a/tensorflow/contrib/predictor/predictor_factories.py +++ b/tensorflow/contrib/predictor/predictor_factories.py @@ -30,7 +30,8 @@ def from_contrib_estimator(estimator, prediction_input_fn, input_alternative_key=None, output_alternative_key=None, - graph=None): + graph=None, + config=None): """Constructs a `Predictor` from a `tf.contrib.learn.Estimator`. Args: @@ -44,6 +45,7 @@ def from_contrib_estimator(estimator, multi-headed models. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Returns: An initialized `Predictor`. @@ -62,13 +64,15 @@ def from_contrib_estimator(estimator, prediction_input_fn, input_alternative_key=input_alternative_key, output_alternative_key=output_alternative_key, - graph=graph) + graph=graph, + config=config) def from_estimator(estimator, serving_input_receiver_fn, output_key=None, - graph=None): + graph=None, + config=None): """Constructs a `Predictor` from a `tf.python.estimator.Estimator`. Args: @@ -79,6 +83,7 @@ def from_estimator(estimator, `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Returns: An initialized `Predictor`. @@ -93,14 +98,19 @@ def from_estimator(estimator, 'tf.contrib.learn.Estimator. You likely want to call ' 'from_contrib_estimator.') return core_estimator_predictor.CoreEstimatorPredictor( - estimator, serving_input_receiver_fn, output_key=output_key, graph=graph) + estimator, + serving_input_receiver_fn, + output_key=output_key, + graph=graph, + config=config) def from_saved_model(export_dir, signature_def_key=None, signature_def=None, tags=None, - graph=None): + graph=None, + config=None): """Constructs a `Predictor` from a `SavedModel` on disk. Args: @@ -115,6 +125,7 @@ def from_saved_model(export_dir, `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Returns: An initialized `Predictor`. @@ -128,4 +139,5 @@ def from_saved_model(export_dir, signature_def_key=signature_def_key, signature_def=signature_def, tags=tags, - graph=graph) + graph=graph, + config=config) diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py index 578d9424b25dd38f1d77a267d1fdf1ff9ff2da88..a2ef1dc3af0986afacf646f0dc04b7ef857a7f93 100644 --- a/tensorflow/contrib/predictor/predictor_factories_test.py +++ b/tensorflow/contrib/predictor/predictor_factories_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.predictor import predictor_factories from tensorflow.contrib.predictor import testing_common +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import test MODEL_DIR_NAME = 'contrib/predictor/test_export_dir' @@ -41,6 +42,11 @@ class PredictorFactoriesTest(test.TestCase): """Test loading from_saved_model with tags.""" predictor_factories.from_saved_model(self._export_dir, tags='serve') + def testFromSavedModelWithSessionConfig(self): + """Test loading from_saved_model with session config.""" + predictor_factories.from_saved_model( + self._export_dir, config=config_pb2.ConfigProto()) + def testFromSavedModelWithBadTags(self): """Test that loading fails for bad tags.""" bad_tags_regex = ('.*? could not be found in SavedModel') @@ -53,6 +59,13 @@ class PredictorFactoriesTest(test.TestCase): predictor_factories.from_contrib_estimator( estimator, input_fn, output_alternative_key='sum') + def testFromContribEstimatorWithSessionConfig(self): + estimator = testing_common.get_arithmetic_estimator(core=False) + input_fn = testing_common.get_arithmetic_input_fn(core=False) + predictor_factories.from_contrib_estimator( + estimator, input_fn, output_alternative_key='sum', + config=config_pb2.ConfigProto()) + def testFromContribEstimatorWithCoreEstimatorRaises(self): estimator = testing_common.get_arithmetic_estimator(core=True) input_fn = testing_common.get_arithmetic_input_fn(core=True) @@ -64,6 +77,12 @@ class PredictorFactoriesTest(test.TestCase): input_fn = testing_common.get_arithmetic_input_fn(core=True) predictor_factories.from_estimator(estimator, input_fn) + def testFromCoreEstimatorWithSessionConfig(self): + estimator = testing_common.get_arithmetic_estimator(core=True) + input_fn = testing_common.get_arithmetic_input_fn(core=True) + predictor_factories.from_estimator( + estimator, input_fn, config=config_pb2.ConfigProto()) + def testFromCoreEstimatorWithContribEstimatorRaises(self): estimator = testing_common.get_arithmetic_estimator(core=False) input_fn = testing_common.get_arithmetic_input_fn(core=False) diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py index 0dbca0f8136e4e618234101ee41c80bc085511c0..95da6d04edc5214d1b5c1851c4ab05c6d7080b9b 100644 --- a/tensorflow/contrib/predictor/saved_model_predictor.py +++ b/tensorflow/contrib/predictor/saved_model_predictor.py @@ -121,7 +121,8 @@ class SavedModelPredictor(predictor.Predictor): input_names=None, output_names=None, tags=None, - graph=None): + graph=None, + config=None): """Initialize a `CoreEstimatorPredictor`. Args: @@ -142,6 +143,7 @@ class SavedModelPredictor(predictor.Predictor): the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Raises: ValueError: If more than one of signature_def_key OR signature_def OR (input_names AND output_names) is specified. @@ -152,7 +154,7 @@ class SavedModelPredictor(predictor.Predictor): self._graph = graph or ops.Graph() with self._graph.as_default(): - self._session = session.Session() + self._session = session.Session(config=config) loader.load(self._session, tags.split(','), export_dir) if input_names is None: diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index b9918fdee1ece2bae1ab1459985066a35b6431be..23363617eddd2078db9052a64d70d5f8c234805d 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -155,8 +155,10 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:partitioned_variables", "//tensorflow/python:platform_test", "//tensorflow/python:session", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index c83623ec947c1550991352a9dd9a5c6ee9282290..27a933c0f945e53a1838aefd30aed82fadbbc146 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -6,7 +6,7 @@ inference. The details of the transformation implemented in this package is described here [1]. This is done using the -[fake quantization op](https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization). +[fake quantization op](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization). Literature has shown that fixed point networks provide comparable performance to floating point networks [2]. This is achieved by modeling the quantization diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index 55479bf5f74299bf09f131a6127f9f11d6192d90..e3c48998305e9d9b6c185fd4c0f324fa0449c691 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -121,7 +121,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): scaled_weight_tensor = math_ops.multiply( weights, multiplier_tensor, name='mul_fold') new_layer_tensor = _CloneWithNewOperands( - match.layer_op, match.input_tensor, scaled_weight_tensor) + match.layer_op, match.input_tensor, scaled_weight_tensor, + match.batch_to_space_op) if correction_recip is not None: new_layer_tensor = math_ops.multiply( @@ -149,6 +150,8 @@ def _FindFusedBatchNorms(graph): _FusedBatchNormMatches. """ input_pattern = graph_matcher.OpTypePattern('*') + # In practice, the weight pattern can match a Variable or a SpaceToBatchND + # operation that follows a variable for atrous convolutions. weight_pattern = graph_matcher.OpTypePattern('*') gamma_pattern = graph_matcher.OpTypePattern('*') beta_pattern = graph_matcher.OpTypePattern('*') @@ -160,16 +163,27 @@ def _FindFusedBatchNorms(graph): layer_pattern = graph_matcher.OpTypePattern( 'Conv2D|DepthwiseConv2dNative|MatMul', inputs=[input_pattern, weight_pattern]) + batch_to_space_pattern = graph_matcher.OpTypePattern( + 'BatchToSpaceND', + inputs=[ + layer_pattern, + graph_matcher.OpTypePattern('*'), + graph_matcher.OpTypePattern('*') + ]) + layer_output_pattern = graph_matcher.OneofPattern( + [layer_pattern, batch_to_space_pattern]) # MatMul has a Reshape between it and FusedBatchNorm. matmul_reshape_pattern = graph_matcher.OpTypePattern( - 'Reshape', inputs=[layer_pattern, - graph_matcher.OpTypePattern('*')]) + 'Reshape', + inputs=[layer_output_pattern, + graph_matcher.OpTypePattern('*')]) batch_norm_pattern = graph_matcher.OpTypePattern( 'FusedBatchNorm', inputs=[ - graph_matcher.OneofPattern([matmul_reshape_pattern, layer_pattern]), - gamma_pattern, beta_pattern, mean_pattern, variance_pattern + graph_matcher.OneofPattern( + [matmul_reshape_pattern, layer_output_pattern]), gamma_pattern, + beta_pattern, mean_pattern, variance_pattern ]) matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern( 'Reshape', inputs=[batch_norm_pattern, @@ -192,6 +206,7 @@ def _FindFusedBatchNorms(graph): moving_variance_tensor = None bn_decay_mean_tensor = None bn_decay_var_tensor = None + batch_to_space_op = None layer_op = match_result.get_op(layer_pattern) layer_tensor = match_result.get_tensor(layer_pattern) bn_op = match_result.get_op(batch_norm_pattern) @@ -213,6 +228,7 @@ def _FindFusedBatchNorms(graph): if not output_tensor.consumers(): continue + batch_to_space_op = match_result.get_op(batch_to_space_pattern) input_tensor = match_result.get_tensor(input_pattern) weight_tensor = match_result.get_tensor(weight_pattern) gamma_tensor = match_result.get_tensor(gamma_pattern) @@ -276,7 +292,8 @@ def _FindFusedBatchNorms(graph): moving_variance_tensor=moving_variance_tensor, bn_decay_mean_tensor=bn_decay_mean_tensor, bn_decay_var_tensor=bn_decay_var_tensor, - batch_epsilon=batch_epsilon) + batch_epsilon=batch_epsilon, + batch_to_space_op=batch_to_space_op) def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, @@ -380,7 +397,8 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, return correction_scale, correction_recip, correction_offset -def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): +def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor, + batch_to_space_op): """Clones layer_op with input_tensor and weight_tensor as new inputs.""" new_layer_name = layer_op.name.split('/')[-1] + '_Fold' if layer_op.type == 'Conv2D': @@ -400,12 +418,25 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): transpose_b=layer_op.get_attr('transpose_b'), name=new_layer_name) elif layer_op.type == 'DepthwiseConv2dNative': - return nn.depthwise_conv2d( + conv = nn.depthwise_conv2d( input_tensor, weight_tensor, + rate=layer_op.get_attr('dilations'), strides=layer_op.get_attr('strides'), padding=layer_op.get_attr('padding'), name=new_layer_name) + # Copy the batch to space operation if we have a atrous convolution. + if batch_to_space_op: + batch_to_space_op = layer_op.outputs[0].consumers()[0] + # TODO(suharshs): It's hard to make this name match with the unfused name. + # Restructure this code to not rely on scope at all. + new_batch_to_space_name = batch_to_space_op.name.split('/')[-1] + '_Fold' + conv = array_ops.batch_to_space_nd( + conv, + batch_to_space_op.inputs[1], + batch_to_space_op.inputs[2], + name=new_batch_to_space_name) + return conv else: raise ValueError('Cannot handle operation of type: %s' % layer_op.type) @@ -617,7 +648,8 @@ def _GetBatchNormParams(graph, context, has_scaling): moving_variance_tensor=moving_variance_tensor, bn_decay_mean_tensor=bn_decay_mean_tensor, bn_decay_var_tensor=bn_decay_var_tensor, - batch_epsilon=batch_epsilon) + batch_epsilon=batch_epsilon, + batch_to_space_op=None) def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, @@ -651,6 +683,11 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, '/BatchNorm/batchnorm_1/' + mul_scale_name) op_below = mul_scale.inputs[0].op + # Skip over the BatchToSpace operation in the case of atrous convolutions. + batch_to_space_op = None + if op_below.type == 'BatchToSpaceND': + batch_to_space_op = op_below + op_below = op_below.inputs[0].op weights = op_below.inputs[1] match = _GetBatchNormParams( graph=graph, context=context, has_scaling=has_scaling) @@ -691,7 +728,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, context + '/correction_mult') mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)]) else: - raise ValueError('Cannot handle operation of type: %s' % op_below.op) + raise ValueError('Cannot handle operation of type: %s' % op_below.type) _AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0]) conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold', @@ -701,6 +738,13 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, context + '/BatchNorm/batchnorm_1/add_1') corrected_output = conv_or_fc_folded.outputs[0] + # Copy the batch to space operation if we have a atrous convolution. + if batch_to_space_op: + corrected_output = array_ops.batch_to_space_nd( + corrected_output, + batch_to_space_op.inputs[1], + batch_to_space_op.inputs[2], + name=batch_to_space_op.name + '_Fold') if correction_offset is not None: with ops.device(conv_or_fc_folded.device): corrected_output = math_ops.multiply(correction_recip, corrected_output, @@ -898,7 +942,8 @@ class _BatchNormMatch(object): def __init__(self, layer_op, bn_op, output_tensor, input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, variance_tensor, moving_mean_tensor, moving_variance_tensor, - bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon): + bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon, + batch_to_space_op): self._layer_op = layer_op self._bn_op = bn_op self._output_tensor = output_tensor @@ -913,6 +958,7 @@ class _BatchNormMatch(object): self._bn_decay_mean_tensor = bn_decay_mean_tensor self._bn_decay_var_tensor = bn_decay_var_tensor self._batch_epsilon = batch_epsilon + self._batch_to_space_op = batch_to_space_op @property def layer_op(self): @@ -969,3 +1015,7 @@ class _BatchNormMatch(object): @property def bn_decay_var_tensor(self): return self._bn_decay_var_tensor + + @property + def batch_to_space_op(self): + return self._batch_to_space_op diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index bfa9d3bf705e327091098a8e416b7902f852605a..7c907ffd92c1ae0c762e41cc429b0e6ce053f6b9 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -438,6 +438,90 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): def testFoldDepthwiseConv2d(self): self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) + def _TestFoldAtrousConv2d(self, relu, relu_op_name, with_bypass, has_scaling, + fused_batch_norm, freeze_batch_norm_delay): + """Tests folding: inputs -> AtrousConv2d with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance + """ + g = ops.Graph() + with g.as_default(): + batch_size, height, width = 5, 128, 128 + inputs = array_ops.zeros((batch_size, height, width, 3)) + dilation_rate = 2 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [3, 3], + rate=dilation_rate, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + if fused_batch_norm: + scale_reshape_op_name = scope + '/BatchNorm_Fold/scale_reshape' + else: + scale_reshape_op_name = scope + '/scale_reshape' + self._AssertInputOpsAre(folded_mul, + [scope + '/correction_mult', scale_reshape_op_name]) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) + + scale_reshape = g.get_operation_by_name(scale_reshape_op_name) + self.assertEqual(scale_reshape.type, 'Reshape') + self._AssertInputOpsAre(scale_reshape, [ + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm), + scale_reshape_op_name + '/shape' + ]) + self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) + + folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') + self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') + self._AssertInputOpsAre( + folded_conv, [scope + '/mul_fold', scope + '/depthwise/SpaceToBatchND']) + if fused_batch_norm: + self._AssertOutputGoesToOps(folded_conv, g, + [scope + '/BatchToSpaceND_Fold']) + else: + self._AssertOutputGoesToOps(folded_conv, g, + [scope + '/depthwise/BatchToSpaceND_Fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, [ + scope + '/correction_add', + self._BathNormBiasName(scope, fused_batch_norm) + ]) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + for op in g.get_operations(): + self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) + + def testFoldAtrousConv2d(self): + self._RunTestOverParameters(self._TestFoldAtrousConv2d) + def _TestCompareFoldAndUnfolded(self, relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm, freeze_batch_norm_delay): diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 4e0de24e0e72053dd2497c6e7e492cd21bbd8264..19e5bef1ea48ca4441cdef6b1a74e98e9cf6ddb9 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -194,6 +194,8 @@ def _FindLayersToQuantize(graph): / conv|fc | + [batch_to_space_nd] + | [post_conv_correction] | biasadd|folded_bias @@ -218,8 +220,19 @@ def _FindLayersToQuantize(graph): """ input_pattern = graph_matcher.OpTypePattern('*') weight_var_pattern = graph_matcher.OpTypePattern('Variable|VariableV2') - weight_identity_pattern = graph_matcher.OpTypePattern( + weight_partition_identity_pattern = graph_matcher.OpTypePattern( 'Identity', inputs=[weight_var_pattern]) + weight_partition_concat_pattern = graph_matcher.OpTypePattern( + 'ConcatV2', inputs=[weight_partition_identity_pattern, '*', '*']) + weight_identity_pattern = graph_matcher.OpTypePattern( + 'Identity', + inputs=[ + graph_matcher.OneofPattern([ + weight_partition_identity_pattern, + weight_partition_concat_pattern, + weight_var_pattern, + ]) + ]) weight_resource_var_pattern = graph_matcher.OpTypePattern('ReadVariableOp') folded_weight_pattern = graph_matcher.OpTypePattern('Mul') @@ -236,9 +249,21 @@ def _FindLayersToQuantize(graph): ], ordered_inputs=False) + # For atrous convolutions a BatchToSpaceND will occur after the depthwise + # convolution. + batch_to_space_pattern = graph_matcher.OpTypePattern( + 'BatchToSpaceND', + inputs=[ + layer_pattern, + graph_matcher.OpTypePattern('*'), + graph_matcher.OpTypePattern('*') + ]) + + layer_output_pattern = graph_matcher.OneofPattern( + [batch_to_space_pattern, layer_pattern]) folded_bias_mul_pattern = graph_matcher.OpTypePattern( 'Mul', - inputs=[graph_matcher.OpTypePattern('*'), layer_pattern], + inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern], ordered_inputs=False) post_layer_op_correction_pattern = graph_matcher.OpTypePattern( 'Add', @@ -254,7 +279,7 @@ def _FindLayersToQuantize(graph): ordered_inputs=False) bias_add_pattern = graph_matcher.OpTypePattern( - 'Add|BiasAdd', inputs=[layer_pattern, '*'], ordered_inputs=False) + 'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False) # The bias can come from the bias add or the folded bias add. bypass_pattern = graph_matcher.OpTypePattern( @@ -362,14 +387,6 @@ def _FindLayersToQuantize(graph): return layer_matches -def _HasPostActivationBypass(activation_op): - for activation_tensor in activation_op.outputs: - for output_op in activation_tensor.consumers(): - if output_op.type == 'Add': - return True - return False - - class _LayerMatch(object): """Contains all information related to a matched Layer.""" diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index db745aa56212af6a9c20e06ee9e4e5d6e27cf3c3..5e3af0a567536ef6fcfd86d82e94c0ba21077a85 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -276,6 +276,52 @@ class QuantizeTest(test_util.TensorFlowTestCase): graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass, delay, use_resource) + def testQuantize_AtrousConvWithoutBatchNorm(self): + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_AtrousConvWithoutBatchNorm) + + def _TestQuantize_AtrousConvWithoutBatchNorm( + self, activation, activation_op_name, with_bypass, delay, use_resource): + """Tests quantization: inputs -> atrous conv no batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + use_resource: Bool, when true uses resource variables. + """ + graph = ops.Graph() + with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + dilation_rate = 2 + activation_fn = None if with_bypass else activation + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [3, 3], + rate=dilation_rate, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + quantize.Quantize(graph, True, quant_delay=delay) + + self._AssertCorrectQuantizedGraphWithoutBatchNorm( + graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass, + delay, use_resource) + def _RunBatchNormTestOverParameters(self, test_fn): # TODO(suharshs): Use parameterized test once OSS TF supports it. parameters_list = [ @@ -543,6 +589,61 @@ class QuantizeTest(test_util.TensorFlowTestCase): graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass, delay, use_resource) + def testQuantize_AtrousConvWithBatchNorm(self): + self._RunBatchNormTestOverParameters( + self._TestQuantize_AtrousConvWithBatchNorm) + + def _TestQuantize_AtrousConvWithBatchNorm( + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_resource): + """Tests quantization: inputs -> atrous conv with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. + use_resource: Bool, when true uses resource variables. + """ + graph = ops.Graph() + with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + dilation_rate = 2 + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [3, 3], + rate=dilation_rate, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + + # Manually add a bypass (optional) and an activation. + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + + node = activation(node, name='test/' + activation_op_name) + + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + fold_batch_norms.FoldBatchNorms(graph, is_training=True) + quantize.Quantize(graph, True, quant_delay=delay) + + self._AssertCorrectQuantizedGraphWithBatchNorm( + graph, scope, 'DepthwiseConv2dNative', activation_op_name, + with_bypass, delay, use_resource) + def _AssertIdempotent(self, graph): # Ensure that calling the rewrite again doesn't change the graph. graph_def_before = str(graph.as_graph_def()) diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index e7360ae03ca535146dee007eeec88373adf39f12..92ca4a1b0c3126ebccf2b525f01f4d6455c4d527 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -27,6 +27,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import googletest conv2d = layers.conv2d @@ -327,6 +329,66 @@ class QuantizeTest(test_util.TensorFlowTestCase): # No ops should be inserted or removed. self.assertEqual(op_names_before_quantize, op_names_after_quantize) + def testSinglePartitionedVariable(self): + self._RunTestOverParameters(self._testSinglePartitionedVariable) + + def _testSinglePartitionedVariable(self, is_training): + # When weights are partitioned into a single partition, the weights variable + # is followed by a identity -> identity (An additional identity node). + partitioner = partitioned_variables.fixed_size_partitioner(1) + graph = ops.Graph() + with graph.as_default(): + with variable_scope.variable_scope('part', partitioner=partitioner): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test/test') + node = math_ops.add(conv, input2, name='test/add') + node = nn_ops.relu6(node, name='test/relu6') + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Check that the weight's quant node was added. + op_names = [op.name for op in graph.get_operations()] + self.assertTrue( + 'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names) + + def testMultiplePartitionedVariables(self): + self._RunTestOverParameters(self._testMultiplePartitionedVariables) + + def _testMultiplePartitionedVariables(self, is_training): + # When weights are partitioned into multiple partitions the weights variable + # is followed by a identity -> concat -> identity to group the partitions. + partitioner = partitioned_variables.fixed_size_partitioner(2) + graph = ops.Graph() + with graph.as_default(): + with variable_scope.variable_scope('part', partitioner=partitioner): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test/test') + node = math_ops.add(conv, input2, name='test/add') + node = nn_ops.relu6(node, name='test/relu6') + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Check that the weight's quant node was added. + op_names = [op.name for op in graph.get_operations()] + self.assertTrue( + 'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names) + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md index 3ff85faf611afad71b6e6203453bbe97c56f9242..79b015a9163f5727caa40b54579c71e57621c92f 100644 --- a/tensorflow/contrib/receptive_field/README.md +++ b/tensorflow/contrib/receptive_field/README.md @@ -6,6 +6,32 @@ region your output features depend on. Better yet, using the parameters computed by the library, you can easily find the exact image region which is used to compute each convnet feature. +This library can be used to compute receptive field parameters of popular +convnets: + +

+ +convnet model | receptive field | effective stride | effective padding +:-----------------: | :-------------: | :--------------: | :---------------: +alexnet_v2 | 195 | 32 | 64 +vgg_16 | 212 | 32 | 90 +inception_v2 | 699 | 32 | 318 +inception_v3 | 1311 | 32 | 618 +inception_v4 | 2071 | 32 | 998 +inception_resnet_v2 | 3039 | 32 | 1482 +mobilenet_v1 | 315 | 32 | 126 +mobilenet_v1_075 | 315 | 32 | 126 +resnet_v1_50 | 483 | 32 | 241 +resnet_v1_101 | 1027 | 32 | 513 +resnet_v1_152 | 1507 | 32 | 753 +resnet_v1_200 | 1763 | 32 | 881 + +
+ +A comprehensive table with pre-computed receptive field parameters for different +end-points, input resolutions, and other variants of these networks can be found +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md). + ## Basic usage The main function to be called is `compute_receptive_field_from_graph_def`, @@ -96,9 +122,9 @@ The script will write to stdout the receptive field parameters for many variants of several popular convnets: AlexNet, VGG, ResNet, Inception, Mobilenet. They are also written to the file `/tmp/rf_benchmark_results.csv`. -TODO: include here a plot for receptive field sizes of different convnets. - -TODO: include table/link to pre-computed RF parameters. +A comprehensive table with pre-computed receptive field parameters for different +networks can be found +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md). ## Compute RF parameters from a graph pbtxt diff --git a/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md b/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md new file mode 100644 index 0000000000000000000000000000000000000000..736fbef6e7c66176e74144115f0b1acd6bf6cd2f --- /dev/null +++ b/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md @@ -0,0 +1,629 @@ +# Pre-computed receptive field parameters + +## Table with results + +The table below presents the receptive field parameters for several popular +convolutional neural networks. These are computed using the models from the +[TF-Slim +repository](https://github.com/tensorflow/models/tree/master/research/slim), +by using the [rf_benchmark +script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py). + +Questions? See the [FAQ](#faq). + +CNN | resolution | end-point | RF | effective stride | effective padding +:----------------------------: | :--------: | :------------------: | :--: | :--------------: | :---------------: +alexnet_v2 | None | alexnet_v2/conv1 | 11 | 4 | 0 +alexnet_v2 | None | alexnet_v2/pool1 | 19 | 8 | 0 +alexnet_v2 | None | alexnet_v2/conv2 | 51 | 8 | 16 +alexnet_v2 | None | alexnet_v2/conv3 | 99 | 16 | 32 +alexnet_v2 | None | alexnet_v2/conv4 | 131 | 16 | 48 +alexnet_v2 | None | alexnet_v2/conv5 | 163 | 16 | 64 +alexnet_v2 | None | alexnet_v2/pool5 | 195 | 32 | 64 +alexnet_v2 | 224 | alexnet_v2/conv1 | 11 | 4 | 0 +alexnet_v2 | 224 | alexnet_v2/pool1 | 19 | 8 | 0 +alexnet_v2 | 224 | alexnet_v2/conv2 | 51 | 8 | 16 +alexnet_v2 | 224 | alexnet_v2/conv3 | 99 | 16 | 32 +alexnet_v2 | 224 | alexnet_v2/conv4 | 131 | 16 | 48 +alexnet_v2 | 224 | alexnet_v2/conv5 | 163 | 16 | 64 +alexnet_v2 | 224 | alexnet_v2/pool5 | 195 | 32 | 64 +alexnet_v2 | 321 | alexnet_v2/conv1 | 11 | 4 | 0 +alexnet_v2 | 321 | alexnet_v2/pool1 | 19 | 8 | 0 +alexnet_v2 | 321 | alexnet_v2/conv2 | 51 | 8 | 16 +alexnet_v2 | 321 | alexnet_v2/conv3 | 99 | 16 | 32 +alexnet_v2 | 321 | alexnet_v2/conv4 | 131 | 16 | 48 +alexnet_v2 | 321 | alexnet_v2/conv5 | 163 | 16 | 64 +alexnet_v2 | 321 | alexnet_v2/pool5 | 195 | 32 | 64 +vgg_a | None | vgg_a/conv1/conv1_1 | 3 | 1 | 1 +vgg_a | None | vgg_a/pool1 | 4 | 2 | 1 +vgg_a | None | vgg_a/conv2/conv2_1 | 8 | 2 | 3 +vgg_a | None | vgg_a/pool2 | 10 | 4 | 3 +vgg_a | None | vgg_a/conv3/conv3_1 | 18 | 4 | 7 +vgg_a | None | vgg_a/conv3/conv3_2 | 26 | 4 | 11 +vgg_a | None | vgg_a/pool3 | 30 | 8 | 11 +vgg_a | None | vgg_a/conv4/conv4_1 | 46 | 8 | 19 +vgg_a | None | vgg_a/conv4/conv4_2 | 62 | 8 | 27 +vgg_a | None | vgg_a/pool4 | 70 | 16 | 27 +vgg_a | None | vgg_a/conv5/conv5_1 | 102 | 16 | 43 +vgg_a | None | vgg_a/conv5/conv5_2 | 134 | 16 | 59 +vgg_a | None | vgg_a/pool5 | 150 | 32 | 59 +vgg_a | 224 | vgg_a/conv1/conv1_1 | 3 | 1 | 1 +vgg_a | 224 | vgg_a/pool1 | 4 | 2 | 1 +vgg_a | 224 | vgg_a/conv2/conv2_1 | 8 | 2 | 3 +vgg_a | 224 | vgg_a/pool2 | 10 | 4 | 3 +vgg_a | 224 | vgg_a/conv3/conv3_1 | 18 | 4 | 7 +vgg_a | 224 | vgg_a/conv3/conv3_2 | 26 | 4 | 11 +vgg_a | 224 | vgg_a/pool3 | 30 | 8 | 11 +vgg_a | 224 | vgg_a/conv4/conv4_1 | 46 | 8 | 19 +vgg_a | 224 | vgg_a/conv4/conv4_2 | 62 | 8 | 27 +vgg_a | 224 | vgg_a/pool4 | 70 | 16 | 27 +vgg_a | 224 | vgg_a/conv5/conv5_1 | 102 | 16 | 43 +vgg_a | 224 | vgg_a/conv5/conv5_2 | 134 | 16 | 59 +vgg_a | 224 | vgg_a/pool5 | 150 | 32 | 59 +vgg_a | 321 | vgg_a/conv1/conv1_1 | 3 | 1 | 1 +vgg_a | 321 | vgg_a/pool1 | 4 | 2 | 1 +vgg_a | 321 | vgg_a/conv2/conv2_1 | 8 | 2 | 3 +vgg_a | 321 | vgg_a/pool2 | 10 | 4 | 3 +vgg_a | 321 | vgg_a/conv3/conv3_1 | 18 | 4 | 7 +vgg_a | 321 | vgg_a/conv3/conv3_2 | 26 | 4 | 11 +vgg_a | 321 | vgg_a/pool3 | 30 | 8 | 11 +vgg_a | 321 | vgg_a/conv4/conv4_1 | 46 | 8 | 19 +vgg_a | 321 | vgg_a/conv4/conv4_2 | 62 | 8 | 27 +vgg_a | 321 | vgg_a/pool4 | 70 | 16 | 27 +vgg_a | 321 | vgg_a/conv5/conv5_1 | 102 | 16 | 43 +vgg_a | 321 | vgg_a/conv5/conv5_2 | 134 | 16 | 59 +vgg_a | 321 | vgg_a/pool5 | 150 | 32 | 59 +vgg_16 | None | vgg_16/conv1/conv1_1 | 3 | 1 | 1 +vgg_16 | None | vgg_16/pool1 | 6 | 2 | 2 +vgg_16 | None | vgg_16/conv2/conv2_1 | 10 | 2 | 4 +vgg_16 | None | vgg_16/pool2 | 16 | 4 | 6 +vgg_16 | None | vgg_16/conv3/conv3_1 | 24 | 4 | 10 +vgg_16 | None | vgg_16/conv3/conv3_2 | 32 | 4 | 14 +vgg_16 | None | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | None | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | None | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | None | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | None | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | None | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | None | vgg_16/pool5 | 212 | 32 | 90 +vgg_16 | 224 | vgg_16/conv1/conv1_1 | 3 | 1 | 1 +vgg_16 | 224 | vgg_16/pool1 | 6 | 2 | 2 +vgg_16 | 224 | vgg_16/conv2/conv2_1 | 10 | 2 | 4 +vgg_16 | 224 | vgg_16/pool2 | 16 | 4 | 6 +vgg_16 | 224 | vgg_16/conv3/conv3_1 | 24 | 4 | 10 +vgg_16 | 224 | vgg_16/conv3/conv3_2 | 32 | 4 | 14 +vgg_16 | 224 | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | 224 | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | 224 | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | 224 | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | 224 | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | 224 | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | 224 | vgg_16/pool5 | 212 | 32 | 90 +vgg_16 | 321 | vgg_16/conv1/conv1_1 | 3 | 1 | 1 +vgg_16 | 321 | vgg_16/pool1 | 6 | 2 | 2 +vgg_16 | 321 | vgg_16/conv2/conv2_1 | 10 | 2 | 4 +vgg_16 | 321 | vgg_16/pool2 | 16 | 4 | 6 +vgg_16 | 321 | vgg_16/conv3/conv3_1 | 24 | 4 | 10 +vgg_16 | 321 | vgg_16/conv3/conv3_2 | 32 | 4 | 14 +vgg_16 | 321 | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | 321 | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | 321 | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | 321 | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | 321 | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | 321 | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | 321 | vgg_16/pool5 | 212 | 32 | 90 +inception_v2 | None | Conv2d_1a_7x7 | 7 | 2 | None +inception_v2 | None | MaxPool_2a_3x3 | 11 | 4 | None +inception_v2 | None | Conv2d_2b_1x1 | 11 | 4 | None +inception_v2 | None | Conv2d_2c_3x3 | 19 | 4 | None +inception_v2 | None | MaxPool_3a_3x3 | 27 | 8 | None +inception_v2 | None | Mixed_3b | 59 | 8 | None +inception_v2 | None | Mixed_3c | 91 | 8 | None +inception_v2 | None | Mixed_4a | 123 | 16 | None +inception_v2 | None | Mixed_4b | 187 | 16 | None +inception_v2 | None | Mixed_4c | 251 | 16 | None +inception_v2 | None | Mixed_4d | 315 | 16 | None +inception_v2 | None | Mixed_4e | 379 | 16 | None +inception_v2 | None | Mixed_5a | 443 | 32 | None +inception_v2 | None | Mixed_5b | 571 | 32 | None +inception_v2 | None | Mixed_5c | 699 | 32 | None +inception_v2 | 224 | Conv2d_1a_7x7 | 7 | 2 | 2 +inception_v2 | 224 | MaxPool_2a_3x3 | 11 | 4 | 2 +inception_v2 | 224 | Conv2d_2b_1x1 | 11 | 4 | 2 +inception_v2 | 224 | Conv2d_2c_3x3 | 19 | 4 | 6 +inception_v2 | 224 | MaxPool_3a_3x3 | 27 | 8 | 6 +inception_v2 | 224 | Mixed_3b | 59 | 8 | 22 +inception_v2 | 224 | Mixed_3c | 91 | 8 | 38 +inception_v2 | 224 | Mixed_4a | 123 | 16 | 46 +inception_v2 | 224 | Mixed_4b | 187 | 16 | 78 +inception_v2 | 224 | Mixed_4c | 251 | 16 | 110 +inception_v2 | 224 | Mixed_4d | 315 | 16 | 142 +inception_v2 | 224 | Mixed_4e | 379 | 16 | 174 +inception_v2 | 224 | Mixed_5a | 443 | 32 | 190 +inception_v2 | 224 | Mixed_5b | 571 | 32 | 254 +inception_v2 | 224 | Mixed_5c | 699 | 32 | 318 +inception_v2 | 321 | Conv2d_1a_7x7 | 7 | 2 | 3 +inception_v2 | 321 | MaxPool_2a_3x3 | 11 | 4 | 5 +inception_v2 | 321 | Conv2d_2b_1x1 | 11 | 4 | 5 +inception_v2 | 321 | Conv2d_2c_3x3 | 19 | 4 | 9 +inception_v2 | 321 | MaxPool_3a_3x3 | 27 | 8 | 13 +inception_v2 | 321 | Mixed_3b | 59 | 8 | 29 +inception_v2 | 321 | Mixed_3c | 91 | 8 | 45 +inception_v2 | 321 | Mixed_4a | 123 | 16 | 61 +inception_v2 | 321 | Mixed_4b | 187 | 16 | 93 +inception_v2 | 321 | Mixed_4c | 251 | 16 | 125 +inception_v2 | 321 | Mixed_4d | 315 | 16 | 157 +inception_v2 | 321 | Mixed_4e | 379 | 16 | 189 +inception_v2 | 321 | Mixed_5a | 443 | 32 | 221 +inception_v2 | 321 | Mixed_5b | 571 | 32 | 285 +inception_v2 | 321 | Mixed_5c | 699 | 32 | 349 +inception_v2-no-separable-conv | None | Conv2d_1a_7x7 | 7 | 2 | None +inception_v2-no-separable-conv | None | MaxPool_2a_3x3 | 11 | 4 | None +inception_v2-no-separable-conv | None | Conv2d_2b_1x1 | 11 | 4 | None +inception_v2-no-separable-conv | None | Conv2d_2c_3x3 | 19 | 4 | None +inception_v2-no-separable-conv | None | MaxPool_3a_3x3 | 27 | 8 | None +inception_v2-no-separable-conv | None | Mixed_3b | 59 | 8 | None +inception_v2-no-separable-conv | None | Mixed_3c | 91 | 8 | None +inception_v2-no-separable-conv | None | Mixed_4a | 123 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4b | 187 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4c | 251 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4d | 315 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4e | 379 | 16 | None +inception_v2-no-separable-conv | None | Mixed_5a | 443 | 32 | None +inception_v2-no-separable-conv | None | Mixed_5b | 571 | 32 | None +inception_v2-no-separable-conv | None | Mixed_5c | 699 | 32 | None +inception_v2-no-separable-conv | 224 | Conv2d_1a_7x7 | 7 | 2 | 2 +inception_v2-no-separable-conv | 224 | MaxPool_2a_3x3 | 11 | 4 | 2 +inception_v2-no-separable-conv | 224 | Conv2d_2b_1x1 | 11 | 4 | 2 +inception_v2-no-separable-conv | 224 | Conv2d_2c_3x3 | 19 | 4 | 6 +inception_v2-no-separable-conv | 224 | MaxPool_3a_3x3 | 27 | 8 | 6 +inception_v2-no-separable-conv | 224 | Mixed_3b | 59 | 8 | 22 +inception_v2-no-separable-conv | 224 | Mixed_3c | 91 | 8 | 38 +inception_v2-no-separable-conv | 224 | Mixed_4a | 123 | 16 | 46 +inception_v2-no-separable-conv | 224 | Mixed_4b | 187 | 16 | 78 +inception_v2-no-separable-conv | 224 | Mixed_4c | 251 | 16 | 110 +inception_v2-no-separable-conv | 224 | Mixed_4d | 315 | 16 | 142 +inception_v2-no-separable-conv | 224 | Mixed_4e | 379 | 16 | 174 +inception_v2-no-separable-conv | 224 | Mixed_5a | 443 | 32 | 190 +inception_v2-no-separable-conv | 224 | Mixed_5b | 571 | 32 | 254 +inception_v2-no-separable-conv | 224 | Mixed_5c | 699 | 32 | 318 +inception_v2-no-separable-conv | 321 | Conv2d_1a_7x7 | 7 | 2 | 3 +inception_v2-no-separable-conv | 321 | MaxPool_2a_3x3 | 11 | 4 | 5 +inception_v2-no-separable-conv | 321 | Conv2d_2b_1x1 | 11 | 4 | 5 +inception_v2-no-separable-conv | 321 | Conv2d_2c_3x3 | 19 | 4 | 9 +inception_v2-no-separable-conv | 321 | MaxPool_3a_3x3 | 27 | 8 | 13 +inception_v2-no-separable-conv | 321 | Mixed_3b | 59 | 8 | 29 +inception_v2-no-separable-conv | 321 | Mixed_3c | 91 | 8 | 45 +inception_v2-no-separable-conv | 321 | Mixed_4a | 123 | 16 | 61 +inception_v2-no-separable-conv | 321 | Mixed_4b | 187 | 16 | 93 +inception_v2-no-separable-conv | 321 | Mixed_4c | 251 | 16 | 125 +inception_v2-no-separable-conv | 321 | Mixed_4d | 315 | 16 | 157 +inception_v2-no-separable-conv | 321 | Mixed_4e | 379 | 16 | 189 +inception_v2-no-separable-conv | 321 | Mixed_5a | 443 | 32 | 221 +inception_v2-no-separable-conv | 321 | Mixed_5b | 571 | 32 | 285 +inception_v2-no-separable-conv | 321 | Mixed_5c | 699 | 32 | 349 +inception_v3 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | None | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | None | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | None | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | None | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | None | Mixed_5b | 63 | 8 | 18 +inception_v3 | None | Mixed_5c | 95 | 8 | 34 +inception_v3 | None | Mixed_5d | 127 | 8 | 50 +inception_v3 | None | Mixed_6a | 159 | 16 | 58 +inception_v3 | None | Mixed_6b | 351 | 16 | 154 +inception_v3 | None | Mixed_6c | 543 | 16 | 250 +inception_v3 | None | Mixed_6d | 735 | 16 | 346 +inception_v3 | None | Mixed_6e | 927 | 16 | 442 +inception_v3 | None | Mixed_7a | 1055 | 32 | 490 +inception_v3 | None | Mixed_7b | 1183 | 32 | 554 +inception_v3 | None | Mixed_7c | 1311 | 32 | 618 +inception_v3 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | 224 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | 224 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | 224 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | 224 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | 224 | Mixed_5b | 63 | 8 | 18 +inception_v3 | 224 | Mixed_5c | 95 | 8 | 34 +inception_v3 | 224 | Mixed_5d | 127 | 8 | 50 +inception_v3 | 224 | Mixed_6a | 159 | 16 | 58 +inception_v3 | 224 | Mixed_6b | 351 | 16 | 154 +inception_v3 | 224 | Mixed_6c | 543 | 16 | 250 +inception_v3 | 224 | Mixed_6d | 735 | 16 | 346 +inception_v3 | 224 | Mixed_6e | 927 | 16 | 442 +inception_v3 | 224 | Mixed_7a | 1055 | 32 | 490 +inception_v3 | 224 | Mixed_7b | 1183 | 32 | 554 +inception_v3 | 224 | Mixed_7c | 1311 | 32 | 618 +inception_v3 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | 321 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | 321 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | 321 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | 321 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | 321 | Mixed_5b | 63 | 8 | 18 +inception_v3 | 321 | Mixed_5c | 95 | 8 | 34 +inception_v3 | 321 | Mixed_5d | 127 | 8 | 50 +inception_v3 | 321 | Mixed_6a | 159 | 16 | 58 +inception_v3 | 321 | Mixed_6b | 351 | 16 | 154 +inception_v3 | 321 | Mixed_6c | 543 | 16 | 250 +inception_v3 | 321 | Mixed_6d | 735 | 16 | 346 +inception_v3 | 321 | Mixed_6e | 927 | 16 | 442 +inception_v3 | 321 | Mixed_7a | 1055 | 32 | 490 +inception_v3 | 321 | Mixed_7b | 1183 | 32 | 554 +inception_v3 | 321 | Mixed_7c | 1311 | 32 | 618 +inception_v4 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | None | Mixed_3a | 15 | 4 | 2 +inception_v4 | None | Mixed_4a | 47 | 4 | 14 +inception_v4 | None | Mixed_5a | 55 | 8 | 14 +inception_v4 | None | Mixed_5b | 87 | 8 | 30 +inception_v4 | None | Mixed_5c | 119 | 8 | 46 +inception_v4 | None | Mixed_5d | 151 | 8 | 62 +inception_v4 | None | Mixed_5e | 183 | 8 | 78 +inception_v4 | None | Mixed_6a | 215 | 16 | 86 +inception_v4 | None | Mixed_6b | 407 | 16 | 182 +inception_v4 | None | Mixed_6c | 599 | 16 | 278 +inception_v4 | None | Mixed_6d | 791 | 16 | 374 +inception_v4 | None | Mixed_6e | 983 | 16 | 470 +inception_v4 | None | Mixed_6f | 1175 | 16 | 566 +inception_v4 | None | Mixed_6g | 1367 | 16 | 662 +inception_v4 | None | Mixed_6h | 1559 | 16 | 758 +inception_v4 | None | Mixed_7a | 1687 | 32 | 806 +inception_v4 | None | Mixed_7b | 1815 | 32 | 870 +inception_v4 | None | Mixed_7c | 1943 | 32 | 934 +inception_v4 | None | Mixed_7d | 2071 | 32 | 998 +inception_v4 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | 224 | Mixed_3a | 15 | 4 | 2 +inception_v4 | 224 | Mixed_4a | 47 | 4 | 14 +inception_v4 | 224 | Mixed_5a | 55 | 8 | 14 +inception_v4 | 224 | Mixed_5b | 87 | 8 | 30 +inception_v4 | 224 | Mixed_5c | 119 | 8 | 46 +inception_v4 | 224 | Mixed_5d | 151 | 8 | 62 +inception_v4 | 224 | Mixed_5e | 183 | 8 | 78 +inception_v4 | 224 | Mixed_6a | 215 | 16 | 86 +inception_v4 | 224 | Mixed_6b | 407 | 16 | 182 +inception_v4 | 224 | Mixed_6c | 599 | 16 | 278 +inception_v4 | 224 | Mixed_6d | 791 | 16 | 374 +inception_v4 | 224 | Mixed_6e | 983 | 16 | 470 +inception_v4 | 224 | Mixed_6f | 1175 | 16 | 566 +inception_v4 | 224 | Mixed_6g | 1367 | 16 | 662 +inception_v4 | 224 | Mixed_6h | 1559 | 16 | 758 +inception_v4 | 224 | Mixed_7a | 1687 | 32 | 806 +inception_v4 | 224 | Mixed_7b | 1815 | 32 | 870 +inception_v4 | 224 | Mixed_7c | 1943 | 32 | 934 +inception_v4 | 224 | Mixed_7d | 2071 | 32 | 998 +inception_v4 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | 321 | Mixed_3a | 15 | 4 | 2 +inception_v4 | 321 | Mixed_4a | 47 | 4 | 14 +inception_v4 | 321 | Mixed_5a | 55 | 8 | 14 +inception_v4 | 321 | Mixed_5b | 87 | 8 | 30 +inception_v4 | 321 | Mixed_5c | 119 | 8 | 46 +inception_v4 | 321 | Mixed_5d | 151 | 8 | 62 +inception_v4 | 321 | Mixed_5e | 183 | 8 | 78 +inception_v4 | 321 | Mixed_6a | 215 | 16 | 86 +inception_v4 | 321 | Mixed_6b | 407 | 16 | 182 +inception_v4 | 321 | Mixed_6c | 599 | 16 | 278 +inception_v4 | 321 | Mixed_6d | 791 | 16 | 374 +inception_v4 | 321 | Mixed_6e | 983 | 16 | 470 +inception_v4 | 321 | Mixed_6f | 1175 | 16 | 566 +inception_v4 | 321 | Mixed_6g | 1367 | 16 | 662 +inception_v4 | 321 | Mixed_6h | 1559 | 16 | 758 +inception_v4 | 321 | Mixed_7a | 1687 | 32 | 806 +inception_v4 | 321 | Mixed_7b | 1815 | 32 | 870 +inception_v4 | 321 | Mixed_7c | 1943 | 32 | 934 +inception_v4 | 321 | Mixed_7d | 2071 | 32 | 998 +inception_resnet_v2 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | None | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | None | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | None | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | None | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | None | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | None | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | None | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | None | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | None | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | 224 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | 224 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | 224 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | 224 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | 224 | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | 224 | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | 224 | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | 224 | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | 224 | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | 321 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | 321 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | 321 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | 321 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | 321 | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | 321 | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | 321 | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | 321 | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | 321 | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2-same | None | Conv2d_1a_3x3 | 3 | 2 | None +inception_resnet_v2-same | None | Conv2d_2a_3x3 | 7 | 2 | None +inception_resnet_v2-same | None | Conv2d_2b_3x3 | 11 | 2 | None +inception_resnet_v2-same | None | MaxPool_3a_3x3 | 15 | 4 | None +inception_resnet_v2-same | None | Conv2d_3b_1x1 | 15 | 4 | None +inception_resnet_v2-same | None | Conv2d_4a_3x3 | 23 | 4 | None +inception_resnet_v2-same | None | MaxPool_5a_3x3 | 31 | 8 | None +inception_resnet_v2-same | None | Mixed_5b | 63 | 8 | None +inception_resnet_v2-same | None | Mixed_6a | 415 | 16 | None +inception_resnet_v2-same | None | PreAuxLogits | 2335 | 16 | None +inception_resnet_v2-same | None | Mixed_7a | 2399 | 32 | None +inception_resnet_v2-same | None | Conv2d_7b_1x1 | 3039 | 32 | None +inception_resnet_v2-same | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2-same | 224 | Conv2d_2a_3x3 | 7 | 2 | 2 +inception_resnet_v2-same | 224 | Conv2d_2b_3x3 | 11 | 2 | 4 +inception_resnet_v2-same | 224 | MaxPool_3a_3x3 | 15 | 4 | 4 +inception_resnet_v2-same | 224 | Conv2d_3b_1x1 | 15 | 4 | 4 +inception_resnet_v2-same | 224 | Conv2d_4a_3x3 | 23 | 4 | 8 +inception_resnet_v2-same | 224 | MaxPool_5a_3x3 | 31 | 8 | 8 +inception_resnet_v2-same | 224 | Mixed_5b | 63 | 8 | 24 +inception_resnet_v2-same | 224 | Mixed_6a | 415 | 16 | 192 +inception_resnet_v2-same | 224 | PreAuxLogits | 2335 | 16 | 1152 +inception_resnet_v2-same | 224 | Mixed_7a | 2399 | 32 | 1168 +inception_resnet_v2-same | 224 | Conv2d_7b_1x1 | 3039 | 32 | 1488 +inception_resnet_v2-same | 321 | Conv2d_1a_3x3 | 3 | 2 | 1 +inception_resnet_v2-same | 321 | Conv2d_2a_3x3 | 7 | 2 | 3 +inception_resnet_v2-same | 321 | Conv2d_2b_3x3 | 11 | 2 | 5 +inception_resnet_v2-same | 321 | MaxPool_3a_3x3 | 15 | 4 | 7 +inception_resnet_v2-same | 321 | Conv2d_3b_1x1 | 15 | 4 | 7 +inception_resnet_v2-same | 321 | Conv2d_4a_3x3 | 23 | 4 | 11 +inception_resnet_v2-same | 321 | MaxPool_5a_3x3 | 31 | 8 | 15 +inception_resnet_v2-same | 321 | Mixed_5b | 63 | 8 | 31 +inception_resnet_v2-same | 321 | Mixed_6a | 415 | 16 | 207 +inception_resnet_v2-same | 321 | PreAuxLogits | 2335 | 16 | 1167 +inception_resnet_v2-same | 321 | Mixed_7a | 2399 | 32 | 1199 +inception_resnet_v2-same | 321 | Conv2d_7b_1x1 | 3039 | 32 | 1519 +mobilenet_v1 | None | Conv2d_0 | 3 | 2 | None +mobilenet_v1 | None | Conv2d_1_pointwise | 7 | 2 | None +mobilenet_v1 | None | Conv2d_2_pointwise | 11 | 4 | None +mobilenet_v1 | None | Conv2d_3_pointwise | 19 | 4 | None +mobilenet_v1 | None | Conv2d_4_pointwise | 27 | 8 | None +mobilenet_v1 | None | Conv2d_5_pointwise | 43 | 8 | None +mobilenet_v1 | None | Conv2d_6_pointwise | 59 | 16 | None +mobilenet_v1 | None | Conv2d_7_pointwise | 91 | 16 | None +mobilenet_v1 | None | Conv2d_8_pointwise | 123 | 16 | None +mobilenet_v1 | None | Conv2d_9_pointwise | 155 | 16 | None +mobilenet_v1 | None | Conv2d_10_pointwise | 187 | 16 | None +mobilenet_v1 | None | Conv2d_11_pointwise | 219 | 16 | None +mobilenet_v1 | None | Conv2d_12_pointwise | 251 | 32 | None +mobilenet_v1 | None | Conv2d_13_pointwise | 315 | 32 | None +mobilenet_v1 | 224 | Conv2d_0 | 3 | 2 | 0 +mobilenet_v1 | 224 | Conv2d_1_pointwise | 7 | 2 | 2 +mobilenet_v1 | 224 | Conv2d_2_pointwise | 11 | 4 | 2 +mobilenet_v1 | 224 | Conv2d_3_pointwise | 19 | 4 | 6 +mobilenet_v1 | 224 | Conv2d_4_pointwise | 27 | 8 | 6 +mobilenet_v1 | 224 | Conv2d_5_pointwise | 43 | 8 | 14 +mobilenet_v1 | 224 | Conv2d_6_pointwise | 59 | 16 | 14 +mobilenet_v1 | 224 | Conv2d_7_pointwise | 91 | 16 | 30 +mobilenet_v1 | 224 | Conv2d_8_pointwise | 123 | 16 | 46 +mobilenet_v1 | 224 | Conv2d_9_pointwise | 155 | 16 | 62 +mobilenet_v1 | 224 | Conv2d_10_pointwise | 187 | 16 | 78 +mobilenet_v1 | 224 | Conv2d_11_pointwise | 219 | 16 | 94 +mobilenet_v1 | 224 | Conv2d_12_pointwise | 251 | 32 | 94 +mobilenet_v1 | 224 | Conv2d_13_pointwise | 315 | 32 | 126 +mobilenet_v1 | 321 | Conv2d_0 | 3 | 2 | 1 +mobilenet_v1 | 321 | Conv2d_1_pointwise | 7 | 2 | 3 +mobilenet_v1 | 321 | Conv2d_2_pointwise | 11 | 4 | 5 +mobilenet_v1 | 321 | Conv2d_3_pointwise | 19 | 4 | 9 +mobilenet_v1 | 321 | Conv2d_4_pointwise | 27 | 8 | 13 +mobilenet_v1 | 321 | Conv2d_5_pointwise | 43 | 8 | 21 +mobilenet_v1 | 321 | Conv2d_6_pointwise | 59 | 16 | 29 +mobilenet_v1 | 321 | Conv2d_7_pointwise | 91 | 16 | 45 +mobilenet_v1 | 321 | Conv2d_8_pointwise | 123 | 16 | 61 +mobilenet_v1 | 321 | Conv2d_9_pointwise | 155 | 16 | 77 +mobilenet_v1 | 321 | Conv2d_10_pointwise | 187 | 16 | 93 +mobilenet_v1 | 321 | Conv2d_11_pointwise | 219 | 16 | 109 +mobilenet_v1 | 321 | Conv2d_12_pointwise | 251 | 32 | 125 +mobilenet_v1 | 321 | Conv2d_13_pointwise | 315 | 32 | 157 +mobilenet_v1_075 | None | Conv2d_0 | 3 | 2 | None +mobilenet_v1_075 | None | Conv2d_1_pointwise | 7 | 2 | None +mobilenet_v1_075 | None | Conv2d_2_pointwise | 11 | 4 | None +mobilenet_v1_075 | None | Conv2d_3_pointwise | 19 | 4 | None +mobilenet_v1_075 | None | Conv2d_4_pointwise | 27 | 8 | None +mobilenet_v1_075 | None | Conv2d_5_pointwise | 43 | 8 | None +mobilenet_v1_075 | None | Conv2d_6_pointwise | 59 | 16 | None +mobilenet_v1_075 | None | Conv2d_7_pointwise | 91 | 16 | None +mobilenet_v1_075 | None | Conv2d_8_pointwise | 123 | 16 | None +mobilenet_v1_075 | None | Conv2d_9_pointwise | 155 | 16 | None +mobilenet_v1_075 | None | Conv2d_10_pointwise | 187 | 16 | None +mobilenet_v1_075 | None | Conv2d_11_pointwise | 219 | 16 | None +mobilenet_v1_075 | None | Conv2d_12_pointwise | 251 | 32 | None +mobilenet_v1_075 | None | Conv2d_13_pointwise | 315 | 32 | None +mobilenet_v1_075 | 224 | Conv2d_0 | 3 | 2 | 0 +mobilenet_v1_075 | 224 | Conv2d_1_pointwise | 7 | 2 | 2 +mobilenet_v1_075 | 224 | Conv2d_2_pointwise | 11 | 4 | 2 +mobilenet_v1_075 | 224 | Conv2d_3_pointwise | 19 | 4 | 6 +mobilenet_v1_075 | 224 | Conv2d_4_pointwise | 27 | 8 | 6 +mobilenet_v1_075 | 224 | Conv2d_5_pointwise | 43 | 8 | 14 +mobilenet_v1_075 | 224 | Conv2d_6_pointwise | 59 | 16 | 14 +mobilenet_v1_075 | 224 | Conv2d_7_pointwise | 91 | 16 | 30 +mobilenet_v1_075 | 224 | Conv2d_8_pointwise | 123 | 16 | 46 +mobilenet_v1_075 | 224 | Conv2d_9_pointwise | 155 | 16 | 62 +mobilenet_v1_075 | 224 | Conv2d_10_pointwise | 187 | 16 | 78 +mobilenet_v1_075 | 224 | Conv2d_11_pointwise | 219 | 16 | 94 +mobilenet_v1_075 | 224 | Conv2d_12_pointwise | 251 | 32 | 94 +mobilenet_v1_075 | 224 | Conv2d_13_pointwise | 315 | 32 | 126 +mobilenet_v1_075 | 321 | Conv2d_0 | 3 | 2 | 1 +mobilenet_v1_075 | 321 | Conv2d_1_pointwise | 7 | 2 | 3 +mobilenet_v1_075 | 321 | Conv2d_2_pointwise | 11 | 4 | 5 +mobilenet_v1_075 | 321 | Conv2d_3_pointwise | 19 | 4 | 9 +mobilenet_v1_075 | 321 | Conv2d_4_pointwise | 27 | 8 | 13 +mobilenet_v1_075 | 321 | Conv2d_5_pointwise | 43 | 8 | 21 +mobilenet_v1_075 | 321 | Conv2d_6_pointwise | 59 | 16 | 29 +mobilenet_v1_075 | 321 | Conv2d_7_pointwise | 91 | 16 | 45 +mobilenet_v1_075 | 321 | Conv2d_8_pointwise | 123 | 16 | 61 +mobilenet_v1_075 | 321 | Conv2d_9_pointwise | 155 | 16 | 77 +mobilenet_v1_075 | 321 | Conv2d_10_pointwise | 187 | 16 | 93 +mobilenet_v1_075 | 321 | Conv2d_11_pointwise | 219 | 16 | 109 +mobilenet_v1_075 | 321 | Conv2d_12_pointwise | 251 | 32 | 125 +mobilenet_v1_075 | 321 | Conv2d_13_pointwise | 315 | 32 | 157 +resnet_v1_50 | None | resnet_v1_50/block1 | 35 | 8 | None +resnet_v1_50 | None | resnet_v1_50/block2 | 99 | 16 | None +resnet_v1_50 | None | resnet_v1_50/block3 | 291 | 32 | None +resnet_v1_50 | None | resnet_v1_50/block4 | 483 | 32 | None +resnet_v1_50 | 224 | resnet_v1_50/block1 | 35 | 8 | 15 +resnet_v1_50 | 224 | resnet_v1_50/block2 | 99 | 16 | 47 +resnet_v1_50 | 224 | resnet_v1_50/block3 | 291 | 32 | 143 +resnet_v1_50 | 224 | resnet_v1_50/block4 | 483 | 32 | 239 +resnet_v1_50 | 321 | resnet_v1_50/block1 | 35 | 8 | 17 +resnet_v1_50 | 321 | resnet_v1_50/block2 | 99 | 16 | 49 +resnet_v1_50 | 321 | resnet_v1_50/block3 | 291 | 32 | 145 +resnet_v1_50 | 321 | resnet_v1_50/block4 | 483 | 32 | 241 +resnet_v1_101 | None | resnet_v1_101/block1 | 35 | 8 | None +resnet_v1_101 | None | resnet_v1_101/block2 | 99 | 16 | None +resnet_v1_101 | None | resnet_v1_101/block3 | 835 | 32 | None +resnet_v1_101 | None | resnet_v1_101/block4 | 1027 | 32 | None +resnet_v1_101 | 224 | resnet_v1_101/block1 | 35 | 8 | 15 +resnet_v1_101 | 224 | resnet_v1_101/block2 | 99 | 16 | 47 +resnet_v1_101 | 224 | resnet_v1_101/block3 | 835 | 32 | 415 +resnet_v1_101 | 224 | resnet_v1_101/block4 | 1027 | 32 | 511 +resnet_v1_101 | 321 | resnet_v1_101/block1 | 35 | 8 | 17 +resnet_v1_101 | 321 | resnet_v1_101/block2 | 99 | 16 | 49 +resnet_v1_101 | 321 | resnet_v1_101/block3 | 835 | 32 | 417 +resnet_v1_101 | 321 | resnet_v1_101/block4 | 1027 | 32 | 513 +resnet_v1_152 | None | resnet_v1_152/block1 | 35 | 8 | None +resnet_v1_152 | None | resnet_v1_152/block2 | 163 | 16 | None +resnet_v1_152 | None | resnet_v1_152/block3 | 1315 | 32 | None +resnet_v1_152 | None | resnet_v1_152/block4 | 1507 | 32 | None +resnet_v1_152 | 224 | resnet_v1_152/block1 | 35 | 8 | 15 +resnet_v1_152 | 224 | resnet_v1_152/block2 | 163 | 16 | 79 +resnet_v1_152 | 224 | resnet_v1_152/block3 | 1315 | 32 | 655 +resnet_v1_152 | 224 | resnet_v1_152/block4 | 1507 | 32 | 751 +resnet_v1_152 | 321 | resnet_v1_152/block1 | 35 | 8 | 17 +resnet_v1_152 | 321 | resnet_v1_152/block2 | 163 | 16 | 81 +resnet_v1_152 | 321 | resnet_v1_152/block3 | 1315 | 32 | 657 +resnet_v1_152 | 321 | resnet_v1_152/block4 | 1507 | 32 | 753 +resnet_v1_200 | None | resnet_v1_200/block1 | 35 | 8 | None +resnet_v1_200 | None | resnet_v1_200/block2 | 419 | 16 | None +resnet_v1_200 | None | resnet_v1_200/block3 | 1571 | 32 | None +resnet_v1_200 | None | resnet_v1_200/block4 | 1763 | 32 | None +resnet_v1_200 | 224 | resnet_v1_200/block1 | 35 | 8 | 15 +resnet_v1_200 | 224 | resnet_v1_200/block2 | 419 | 16 | 207 +resnet_v1_200 | 224 | resnet_v1_200/block3 | 1571 | 32 | 783 +resnet_v1_200 | 224 | resnet_v1_200/block4 | 1763 | 32 | 879 +resnet_v1_200 | 321 | resnet_v1_200/block1 | 35 | 8 | 17 +resnet_v1_200 | 321 | resnet_v1_200/block2 | 419 | 16 | 209 +resnet_v1_200 | 321 | resnet_v1_200/block3 | 1571 | 32 | 785 +resnet_v1_200 | 321 | resnet_v1_200/block4 | 1763 | 32 | 881 +resnet_v2_50 | None | resnet_v2_50/block1 | 35 | 8 | None +resnet_v2_50 | None | resnet_v2_50/block2 | 99 | 16 | None +resnet_v2_50 | None | resnet_v2_50/block3 | 291 | 32 | None +resnet_v2_50 | None | resnet_v2_50/block4 | 483 | 32 | None +resnet_v2_50 | 224 | resnet_v2_50/block1 | 35 | 8 | 15 +resnet_v2_50 | 224 | resnet_v2_50/block2 | 99 | 16 | 47 +resnet_v2_50 | 224 | resnet_v2_50/block3 | 291 | 32 | 143 +resnet_v2_50 | 224 | resnet_v2_50/block4 | 483 | 32 | 239 +resnet_v2_50 | 321 | resnet_v2_50/block1 | 35 | 8 | 17 +resnet_v2_50 | 321 | resnet_v2_50/block2 | 99 | 16 | 49 +resnet_v2_50 | 321 | resnet_v2_50/block3 | 291 | 32 | 145 +resnet_v2_50 | 321 | resnet_v2_50/block4 | 483 | 32 | 241 +resnet_v2_101 | None | resnet_v2_101/block1 | 35 | 8 | None +resnet_v2_101 | None | resnet_v2_101/block2 | 99 | 16 | None +resnet_v2_101 | None | resnet_v2_101/block3 | 835 | 32 | None +resnet_v2_101 | None | resnet_v2_101/block4 | 1027 | 32 | None +resnet_v2_101 | 224 | resnet_v2_101/block1 | 35 | 8 | 15 +resnet_v2_101 | 224 | resnet_v2_101/block2 | 99 | 16 | 47 +resnet_v2_101 | 224 | resnet_v2_101/block3 | 835 | 32 | 415 +resnet_v2_101 | 224 | resnet_v2_101/block4 | 1027 | 32 | 511 +resnet_v2_101 | 321 | resnet_v2_101/block1 | 35 | 8 | 17 +resnet_v2_101 | 321 | resnet_v2_101/block2 | 99 | 16 | 49 +resnet_v2_101 | 321 | resnet_v2_101/block3 | 835 | 32 | 417 +resnet_v2_101 | 321 | resnet_v2_101/block4 | 1027 | 32 | 513 +resnet_v2_152 | None | resnet_v2_152/block1 | 35 | 8 | None +resnet_v2_152 | None | resnet_v2_152/block2 | 163 | 16 | None +resnet_v2_152 | None | resnet_v2_152/block3 | 1315 | 32 | None +resnet_v2_152 | None | resnet_v2_152/block4 | 1507 | 32 | None +resnet_v2_152 | 224 | resnet_v2_152/block1 | 35 | 8 | 15 +resnet_v2_152 | 224 | resnet_v2_152/block2 | 163 | 16 | 79 +resnet_v2_152 | 224 | resnet_v2_152/block3 | 1315 | 32 | 655 +resnet_v2_152 | 224 | resnet_v2_152/block4 | 1507 | 32 | 751 +resnet_v2_152 | 321 | resnet_v2_152/block1 | 35 | 8 | 17 +resnet_v2_152 | 321 | resnet_v2_152/block2 | 163 | 16 | 81 +resnet_v2_152 | 321 | resnet_v2_152/block3 | 1315 | 32 | 657 +resnet_v2_152 | 321 | resnet_v2_152/block4 | 1507 | 32 | 753 +resnet_v2_200 | None | resnet_v2_200/block1 | 35 | 8 | None +resnet_v2_200 | None | resnet_v2_200/block2 | 419 | 16 | None +resnet_v2_200 | None | resnet_v2_200/block3 | 1571 | 32 | None +resnet_v2_200 | None | resnet_v2_200/block4 | 1763 | 32 | None +resnet_v2_200 | 224 | resnet_v2_200/block1 | 35 | 8 | 15 +resnet_v2_200 | 224 | resnet_v2_200/block2 | 419 | 16 | 207 +resnet_v2_200 | 224 | resnet_v2_200/block3 | 1571 | 32 | 783 +resnet_v2_200 | 224 | resnet_v2_200/block4 | 1763 | 32 | 879 +resnet_v2_200 | 321 | resnet_v2_200/block1 | 35 | 8 | 17 +resnet_v2_200 | 321 | resnet_v2_200/block2 | 419 | 16 | 209 +resnet_v2_200 | 321 | resnet_v2_200/block3 | 1571 | 32 | 785 +resnet_v2_200 | 321 | resnet_v2_200/block4 | 1763 | 32 | 881 + +## FAQ + +### What does a resolution of 'None' mean? + +In this case, the input resolution is undefined. For most models, the receptive +field parameters can be computed even without knowing the input resolution. + +### For some networks, effective_padding shows as 'None' (eg, for Inception_v2 or Mobilenet_v1 when input size is not specified). Why is that? + +This means that the padding for these networks depends on the input size. So, +unless we know exactly the input image dimensionality to be used, it is not +possible to determine the padding applied at the different layers. Look at the +other entries where the input size is fixed; for those cases, effective_padding +is not None. + +This happens due to Tensorflow's implementation of the 'SAME' padding mode, +which may depend on the input feature map size to a given layer. For background +on this, see [these notes from the TF +documentation](https://www.tensorflow.org/versions/master/api_guides/python/nn#Notes_on_SAME_Convolution_Padding). + +Also, note that in this case the program is not able to check if the network is +aligned (ie, it could be that the different paths from input to output have +receptive fields which are not consistently centered at the same position in the +input image). + +So you should be aware that such networks might not be aligned -- the program +has no way of checking it when the padding cannot be determined. + +### The receptive field parameters for network X seem different from what I expected... maybe your calculation is incorrect? + +First, note that the results presented here are based on the tensorflow +implementations from the [TF-Slim model +library](https://github.com/tensorflow/models/tree/master/research/slim). + +So, it is possible that due to some implementation details the RF parameters are +different. + +One common case of confusion is the TF-Slim Resnet implementation, which applies +stride in the last residual unit of each block, instead of at the input +activations in the first residual unit of each block (which is what is described +in the Resnet paper) -- see [this +comment](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_utils.py#L30). +This makes the stride with respect to each convolution block potentially +different. In this case, though, note that a +[flag](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v1.py#L150) +may be used to recover the original striding convention. + +Second, it could be that we have a bug somewhere. While we include [many +tests](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py) +in our library, it is always possible that we missed something. If you suspect +this is happening, please file a GitHub issue +[here](https://github.com/tensorflow/tensorflow/issues). diff --git a/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py b/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py new file mode 100644 index 0000000000000000000000000000000000000000..4495d74bbf66fa461a05f38b430dd404d7da4b08 --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py @@ -0,0 +1,82 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simple script to convert CSV output from rf_benchmark to Markdown format. + +The input CSV should have the following fields: +- CNN +- input resolution +- end_point +- RF size hor +- RF size ver +- effective stride hor +- effective stride ver +- effective padding hor +- effective padding ver + +Since usually in all cases the parameters in the horizontal and vertical +directions are the same, this is assumed by this script, which only prints one +of them to the Markdown file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import csv +import sys + +from tensorflow.python.platform import app + +cmd_args = None + + +def main(unused_argv): + with open(cmd_args.markdown_path, 'w') as f: + # Write table header and field size. + f.write('CNN | resolution | end-point | RF | effective stride | ' + 'effective padding|\n') + f.write( + ':--------------------: | :----------: | :---------------: | :-----: |' + ' :----: | :----:|\n') + with open(cmd_args.csv_path) as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + # Make sure horizontal and parameters are the same. + assert row['RF size hor'] == row['RF size ver'] + assert row['effective stride hor'] == row['effective stride ver'] + assert row['effective padding hor'] == row['effective padding ver'] + + f.write('%s|%s|%s|%s|%s|%s\n' % + (row['CNN'], row['input resolution'], row['end_point'], + row['RF size hor'], row['effective stride hor'], + row['effective padding hor'])) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--csv_path', + type=str, + default='/tmp/rf.csv', + help='Path where CSV output of rf_benchmark was saved.') + parser.add_argument( + '--markdown_path', + type=str, + default='/tmp/rf.md', + help='Path where Markdown output will be saved.') + cmd_args, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py index bc383a803496380aaba4d0248d2b7f93253b2b50..0e3c46f17d2e2a277418d39e31927db73a509670 100644 --- a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py +++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py @@ -27,7 +27,7 @@ from tensorflow.python.platform import tf_logging as logging _UNCHANGED_RF_LAYER_OPS = [ "Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor", "FusedBatchNorm", "Identity", "Log", "Mul", "Pow", "RealDiv", "Relu", - "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2" + "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN" ] # Different ways in which padding modes may be spelled. diff --git a/tensorflow/contrib/recurrent/BUILD b/tensorflow/contrib/recurrent/BUILD index b3cb04ce26d96333f516f1298c8d5c331964f05b..f9827f766da022b184b3348fc24b1570bac8678f 100644 --- a/tensorflow/contrib/recurrent/BUILD +++ b/tensorflow/contrib/recurrent/BUILD @@ -102,5 +102,8 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], - tags = ["nopip"], + tags = [ + "nopip", + "optonly", + ], ) diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 43c0f7595590802aa80e1012967d377a6ab83d29..4eb5c920b3517a8968ff730003e786ae2a9c9e26 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -193,6 +193,10 @@ tf_py_test( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = [ + "manual", + "notap", + ], ) cuda_py_tests( diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index 67f31785b57fddef67733c18c3b744322532c28c..07227bcb77d353200ee46763d51727ed9c0974a1 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -58,6 +58,7 @@ See @{$python/contrib.rnn} guide. @@Conv3DLSTMCell @@HighwayWrapper @@GLSTMCell +@@SRUCell @@AttentionCellWrapper diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index e512e8db53ed3fc24df6e056d7fac9a6d37cfa50..86f1e27abd53d011f37f06851dd6d0977853c8f4 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import functools +import os import numpy as np @@ -30,6 +31,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -39,6 +41,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=protected-access Linear = core_rnn_cell._Linear # pylint: disable=invalid-name @@ -189,6 +192,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(cell.dtype, None) self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) + cell.get_config() # Should not throw an error g, out_m = cell(x, m) # Layer infers the input type. self.assertEqual(cell.dtype, dtype.name) @@ -439,6 +443,26 @@ class RNNCellTest(test.TestCase): self.assertTrue( float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) + @test_util.run_in_graph_and_eager_modes + def testWrapperCheckpointing(self): + for wrapper_type in [ + rnn_cell_impl.DropoutWrapper, + rnn_cell_impl.ResidualWrapper, + lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: + with self.test_session(): + cell = rnn_cell_impl.BasicRNNCell(1) + wrapper = wrapper_type(cell) + wrapper(array_ops.ones([1, 1]), + state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) + self.evaluate([v.initializer for v in cell.variables]) + checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(cell._bias.assign([40.])) + save_path = checkpoint.save(prefix) + self.evaluate(cell._bias.assign([0.])) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([40.], self.evaluate(cell._bias)) + def testOutputProjectionWrapper(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -485,6 +509,7 @@ class RNNCellTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell) (name, dep), = wrapper_object._checkpoint_dependencies + wrapper_object.get_config() # Should not throw an error self.assertIs(dep, base_cell) self.assertEqual("cell", name) @@ -534,6 +559,7 @@ class RNNCellTest(test.TestCase): wrapped = rnn_cell_impl.GRUCell(3) cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159") (name, dep), = cell._checkpoint_dependencies + cell.get_config() # Should not throw an error self.assertIs(dep, wrapped) self.assertEqual("cell", name) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index be99a5d67a3e49b1d522406601d050392f75e963..1c20d88fe4bcbe2c1f1e3413502dbf276f2d21b3 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -921,7 +921,7 @@ class LSTMTest(test.TestCase): # Smoke test, this should not raise an error rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicRNNWithTupleStates(self): num_units = 3 input_size = 5 @@ -997,7 +997,7 @@ class LSTMTest(test.TestCase): self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic) self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicRNNWithNestedTupleStates(self): num_units = 3 input_size = 5 @@ -1285,7 +1285,7 @@ class LSTMTest(test.TestCase): "Comparing individual variable gradients iteration %d" % i) self.assertAllEqual(a, b) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicEquivalentToStaticRNN(self): self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 184144f64a56358206014a0f75473b4a9b16617a..c7fbeea3105ae4c9c9ec2fd131f3468018990028 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -250,7 +250,7 @@ class BeamSearchDecoder(decoder.Decoder): ``` tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( encoder_outputs, multiplier=beam_width) - tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch( + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( encoder_final_state, multiplier=beam_width) tiled_sequence_length = tf.contrib.seq2seq.tile_batch( sequence_length, multiplier=beam_width) diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index e69725ff8ab1ba4de880c914a6f5fdad5e54566d..f58268eff525a4b592c79acb32207e1a3f62bdc7 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -21,6 +21,7 @@ from __future__ import print_function import abc import six +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -182,19 +183,20 @@ def dynamic_decode(decoder, raise TypeError("Expected decoder to be type Decoder, but saw: %s" % type(decoder)) - def _is_xla_tensor(tensor): - try: - op = tensor.op - except AttributeError: - return False - if control_flow_util.IsInXLAContext(op): - return True - return False - with variable_scope.variable_scope(scope, "decoder") as varscope: - # Properly cache variable values inside the while_loop - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + # Determine context types. + ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None + in_while_loop = ( + control_flow_util.GetContainingWhileContext(ctxt) is not None) + # Properly cache variable values inside the while_loop. + # Don't set a caching device when running in a loop, since it is possible + # that train steps could be wrapped in a tf.while_loop. In that scenario + # caching prevents forward computations in loop iterations from re-reading + # the updated weights. + if not context.executing_eagerly() and not in_while_loop: + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) if maximum_iterations is not None: maximum_iterations = ops.convert_to_tensor( @@ -208,9 +210,6 @@ def dynamic_decode(decoder, decoder.output_dtype, decoder.batch_size) - is_xla = False - if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]): - is_xla = True if is_xla and maximum_iterations is None: raise ValueError("maximum_iterations is required for XLA compilation.") if maximum_iterations is not None: diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index fdecceff526a860a274354e53e824b98d11418a6..6bd58c4d322c04d4d14d04678e24a05c0f876208 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -1,4 +1,4 @@ -package(default_visibility = ["//tensorflow:__subpackages__"]) +package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py index 03d6da7765ba5249a9fb22f56a469cf07c310479..f10d78259a3be3a3a6f7f78c196ab107f18a53aa 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py @@ -147,7 +147,7 @@ class SpectralOpsTest(test.TestCase): inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8, fft_length=16, frame_step=8) expected_length = (stft.shape[0] - 1) * 8 + 8 - self.assertAllEqual([None], inverse_stft.shape.as_list()) + self.assertAllEqual([256], inverse_stft.shape.as_list()) self.assertAllEqual([expected_length], inverse_stft.eval().shape) def test_stft_and_inverse_stft(self): diff --git a/tensorflow/contrib/signal/python/kernel_tests/test_util.py b/tensorflow/contrib/signal/python/kernel_tests/test_util.py index 9a3603b6a97ef7c3a4b940b83281ebceda93c9db..7d6289532addfd4b4b867bf64d9113253bd1c76d 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/test_util.py +++ b/tensorflow/contrib/signal/python/kernel_tests/test_util.py @@ -39,6 +39,7 @@ def grappler_optimize(graph, fetches=None, rewriter_config=None): """ if rewriter_config is None: rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.min_graph_nodes = -1 if fetches is not None: for fetch in fetches: graph.add_to_collection('train_op', fetch) diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index 94fc12ca814721acf62f16b72ffa50473043cc8b..2c97834523424d0fab56330b4d9355a75427e0ef 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -26,7 +26,6 @@ import time import numpy as np from tensorflow.contrib.framework.python.ops import variables as variables_lib -from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.contrib.slim.python.slim import evaluation from tensorflow.contrib.training.python.training import evaluation as evaluation_lib from tensorflow.core.protobuf import saver_pb2 @@ -34,9 +33,9 @@ from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics from tensorflow.python.ops import variables from tensorflow.python.platform import flags from tensorflow.python.platform import gfile @@ -89,8 +88,8 @@ class EvaluationTest(test.TestCase): self._predictions, self._scale = TestModel(self._inputs) def testFinalOpsOnEvaluationLoop(self): - value_op, update_op = metric_ops.streaming_accuracy(self._predictions, - self._labels) + value_op, update_op = metrics.accuracy( + labels=self._labels, predictions=self._predictions) init_op = control_flow_ops.group(variables.global_variables_initializer(), variables.local_variables_initializer()) # Create checkpoint and log directories: @@ -136,9 +135,10 @@ class EvaluationTest(test.TestCase): self.assertTrue(obj.hook_was_run) def _create_names_to_metrics(self, predictions, labels): - accuracy0, update_op0 = metric_ops.streaming_accuracy(predictions, labels) - accuracy1, update_op1 = metric_ops.streaming_accuracy(predictions + 1, - labels) + accuracy0, update_op0 = metrics.accuracy( + labels=labels, predictions=predictions) + accuracy1, update_op1 = metrics.accuracy( + labels=labels, predictions=predictions + 1) names_to_values = {'Accuracy': accuracy0, 'Another_accuracy': accuracy1} names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1} @@ -198,8 +198,8 @@ class EvaluationTest(test.TestCase): predictions_limited = input.limit_epochs(self._predictions, num_epochs=1) labels_limited = input.limit_epochs(self._labels, num_epochs=1) - value_op, update_op = metric_ops.streaming_accuracy( - predictions_limited, labels_limited) + value_op, update_op = metrics.accuracy( + labels=labels_limited, predictions=predictions_limited) init_op = control_flow_ops.group(variables.global_variables_initializer(), variables.local_variables_initializer()) @@ -241,7 +241,7 @@ class SingleEvaluationTest(test.TestCase): checkpoint_path = os.path.join(self.get_temp_dir(), 'this_file_doesnt_exist') log_dir = os.path.join(self.get_temp_dir(), 'error_raised') - with self.assertRaises(errors.NotFoundError): + with self.assertRaises(ValueError): evaluation.evaluate_once('', checkpoint_path, log_dir) def _prepareCheckpoint(self, checkpoint_path): @@ -260,8 +260,8 @@ class SingleEvaluationTest(test.TestCase): self._prepareCheckpoint(checkpoint_path) # Next, determine the metric to evaluate: - value_op, update_op = metric_ops.streaming_accuracy(self._predictions, - self._labels) + value_op, update_op = metrics.accuracy( + labels=self._labels, predictions=self._predictions) # Run the evaluation and verify the results: accuracy_value = evaluation.evaluate_once( @@ -276,8 +276,8 @@ class SingleEvaluationTest(test.TestCase): self._prepareCheckpoint(checkpoint_path) # Next, determine the metric to evaluate: - value_op, update_op = metric_ops.streaming_accuracy(self._predictions, - self._labels) + value_op, update_op = metrics.accuracy( + labels=self._labels, predictions=self._predictions) dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir') dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False) diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py index 9305c6a11c4ec898c82553773e8e7277a54ab82e..85918bf8506623cf5e0c9106ae9ed80e233f5a7d 100644 --- a/tensorflow/contrib/solvers/python/ops/linear_equations.py +++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import linalg_ops def conjugate_gradient(operator, diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD index 30be14c10cd8576ded75b8489cc89d439a9cc282..0b8fc0cdc66ae41807cce92776ada263675b1f94 100644 --- a/tensorflow/contrib/stat_summarizer/BUILD +++ b/tensorflow/contrib/stat_summarizer/BUILD @@ -31,5 +31,8 @@ tf_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:variables", ], - tags = ["no_windows"], + tags = [ + "no_windows", + "notap", # TODO(b/80546574): test is flaky + ], ) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index f1ef218e74bbd225071324a8269fdfeb5de0e038..3e41e3d0b48ea06f9cb8c1862e27eacb5ebc4417 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -81,6 +81,19 @@ class EagerFileTest(test_util.TensorFlowTestCase): # test here that we're calling them correctly. self.assertTrue(gfile.Exists(logdir)) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerMemory(self): + training_util.get_or_create_global_step() + logdir = self.get_temp_dir() + with summary_ops.create_file_writer( + logdir, max_queue=0, + name='t0').as_default(), summary_ops.always_record_summaries(): + summary_ops.generic('tensor', 1, '') + summary_ops.scalar('scalar', 2.0) + summary_ops.histogram('histogram', [1.0]) + summary_ops.image('image', [[[[1.0]]]]) + summary_ops.audio('audio', [[1.0]], 1.0, 1) + def testDefunSummarys(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index e893e1d1c836cc7feef15757dde79d0db362cbaf..d8236a0a6fa6d0d0e383e454eb0146bb10b6f49d 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -21,10 +21,10 @@ import numpy as np from tensorflow.contrib import losses from tensorflow.contrib.learn.python.learn.estimators import prediction_key -from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics from tensorflow.python.ops import nn INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES @@ -38,12 +38,13 @@ def _top_k_generator(k): targets = math_ops.to_int32(targets) if targets.get_shape().ndims > 1: targets = array_ops.squeeze(targets, axis=[1]) - return metric_ops.streaming_mean(nn.in_top_k(probabilities, targets, k)) + return metrics.mean(nn.in_top_k(probabilities, targets, k)) return _top_k def _accuracy(predictions, targets, weights=None): - return metric_ops.streaming_accuracy(predictions, targets, weights=weights) + return metrics.accuracy( + labels=targets, predictions=predictions, weights=weights) def _r2(probabilities, targets, weights=None): @@ -53,7 +54,7 @@ def _r2(probabilities, targets, weights=None): squares_residuals = math_ops.reduce_sum( math_ops.square(targets - probabilities), 0) score = 1 - math_ops.reduce_sum(squares_residuals / squares_total) - return metric_ops.streaming_mean(score, weights=weights) + return metrics.mean(score, weights=weights) def _squeeze_and_onehot(targets, depth): @@ -62,7 +63,7 @@ def _squeeze_and_onehot(targets, depth): def _sigmoid_entropy(probabilities, targets, weights=None): - return metric_ops.streaming_mean( + return metrics.mean( losses.sigmoid_cross_entropy(probabilities, _squeeze_and_onehot( targets, @@ -71,7 +72,7 @@ def _sigmoid_entropy(probabilities, targets, weights=None): def _softmax_entropy(probabilities, targets, weights=None): - return metric_ops.streaming_mean( + return metrics.mean( losses.sparse_softmax_cross_entropy(probabilities, math_ops.to_int32(targets)), weights=weights) @@ -82,7 +83,7 @@ def _predictions(predictions, unused_targets, **unused_kwargs): def _class_log_loss(probabilities, targets, weights=None): - return metric_ops.streaming_mean( + return metrics.mean( losses.log_loss(probabilities, _squeeze_and_onehot(targets, array_ops.shape(probabilities)[1])), @@ -90,34 +91,36 @@ def _class_log_loss(probabilities, targets, weights=None): def _precision(predictions, targets, weights=None): - return metric_ops.streaming_precision(predictions, targets, weights=weights) + return metrics.precision( + labels=targets, predictions=predictions, weights=weights) def _precision_at_thresholds(predictions, targets, weights=None): - return metric_ops.streaming_precision_at_thresholds( - array_ops.slice(predictions, [0, 1], [-1, 1]), - targets, - np.arange( - 0, 1, 0.01, dtype=np.float32), + return metrics.precision_at_thresholds( + labels=targets, + predictions=array_ops.slice(predictions, [0, 1], [-1, 1]), + thresholds=np.arange(0, 1, 0.01, dtype=np.float32), weights=weights) def _recall(predictions, targets, weights=None): - return metric_ops.streaming_recall(predictions, targets, weights=weights) + return metrics.recall( + labels=targets, predictions=predictions, weights=weights) def _recall_at_thresholds(predictions, targets, weights=None): - return metric_ops.streaming_recall_at_thresholds( - array_ops.slice(predictions, [0, 1], [-1, 1]), - targets, - np.arange( - 0, 1, 0.01, dtype=np.float32), + return metrics.recall_at_thresholds( + labels=targets, + predictions=array_ops.slice(predictions, [0, 1], [-1, 1]), + thresholds=np.arange(0, 1, 0.01, dtype=np.float32), weights=weights) def _auc(probs, targets, weights=None): - return metric_ops.streaming_auc(array_ops.slice(probs, [0, 1], [-1, 1]), - targets, weights=weights) + return metrics.auc( + labels=targets, + predictions=array_ops.slice(probs, [0, 1], [-1, 1]), + weights=weights) _EVAL_METRICS = { diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 7a35a70bbe3112e0649cefd8116cc50565978da5..6f62cd11a9733949c350e35b6b0c436dd097cc33 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -295,7 +295,7 @@ def get_epoch_variable(): # A simple container to hold the training variables for a single tree. -class TreeTrainingVariables(object): +class TreeVariables(object): """Stores tf.Variables for training a single random tree. Uses tf.get_variable to get tree-specific names so that this can be used @@ -303,7 +303,7 @@ class TreeTrainingVariables(object): then relies on restoring that model to evaluate). """ - def __init__(self, params, tree_num, training): + def __init__(self, params, tree_num, training, tree_config='', tree_stat=''): if (not hasattr(params, 'params_proto') or not isinstance(params.params_proto, _params_proto.TensorForestParams)): @@ -315,27 +315,28 @@ class TreeTrainingVariables(object): # TODO(gilberth): Manually shard this to be able to fit it on # multiple machines. self.stats = stats_ops.fertile_stats_variable( - params, '', self.get_tree_name('stats', tree_num)) + params, tree_stat, self.get_tree_name('stats', tree_num)) self.tree = model_ops.tree_variable( - params, '', self.stats, self.get_tree_name('tree', tree_num)) + params, tree_config, self.stats, self.get_tree_name('tree', tree_num)) def get_tree_name(self, name, num): return '{0}-{1}'.format(name, num) -class ForestTrainingVariables(object): +class ForestVariables(object): """A container for a forests training data, consisting of multiple trees. - Instantiates a TreeTrainingVariables object for each tree. We override the + Instantiates a TreeVariables object for each tree. We override the __getitem__ and __setitem__ function so that usage looks like this: - forest_variables = ForestTrainingVariables(params) + forest_variables = ForestVariables(params) ... forest_variables.tree ... """ def __init__(self, params, device_assigner, training=True, - tree_variables_class=TreeTrainingVariables): + tree_variables_class=TreeVariables, + tree_configs=None, tree_stats=None): self.variables = [] # Set up some scalar variables to run through the device assigner, then # we can use those to colocate everything related to a tree. @@ -347,7 +348,13 @@ class ForestTrainingVariables(object): for i in range(params.num_trees): with ops.device(self.device_dummies[i].device): - self.variables.append(tree_variables_class(params, i, training)) + kwargs = {} + if tree_configs is not None: + kwargs.update(dict(tree_config=tree_configs[i])) + if tree_stats is not None: + kwargs.update(dict(tree_stat=tree_stats[i])) + self.variables.append(tree_variables_class( + params, i, training, **kwargs)) def __setitem__(self, t, val): self.variables[t] = val @@ -361,9 +368,11 @@ class RandomForestGraphs(object): def __init__(self, params, + tree_configs=None, + tree_stats=None, device_assigner=None, variables=None, - tree_variables_class=TreeTrainingVariables, + tree_variables_class=TreeVariables, tree_graphs=None, training=True): self.params = params @@ -371,9 +380,10 @@ class RandomForestGraphs(object): device_assigner or framework_variables.VariableDeviceChooser()) logging.info('Constructing forest with params = ') logging.info(self.params.__dict__) - self.variables = variables or ForestTrainingVariables( + self.variables = variables or ForestVariables( self.params, device_assigner=self.device_assigner, training=training, - tree_variables_class=tree_variables_class) + tree_variables_class=tree_variables_class, + tree_configs=tree_configs, tree_stats=tree_stats) tree_graph_class = tree_graphs or RandomTreeGraphs self.trees = [ tree_graph_class(self.variables[i], self.params, i) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index bbe627b15773fafe83a0700da696f429876c0968..1c9c81827e0f251c8ae7bc47242334fb202835ac 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -18,10 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from google.protobuf.json_format import ParseDict +from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.ops import resources +from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -110,6 +114,47 @@ class TensorForestTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(paths, ops.Tensor)) self.assertTrue(isinstance(var, ops.Tensor)) + def testInfrenceFromRestoredModel(self): + input_data = [[-1., 0.], [-1., 2.], # node 1 + [1., 0.], [1., -2.]] # node 2 + expected_prediction = [[0.0, 1.0], [0.0, 1.0], + [0.0, 1.0], [0.0, 1.0]] + hparams = tensor_forest.ForestHParams( + num_classes=2, + num_features=2, + num_trees=1, + max_nodes=1000, + split_after_samples=25).fill() + tree_weight = {'decisionTree': + {'nodes': + [{'binaryNode': + {'rightChildId': 2, + 'leftChildId': 1, + 'inequalityLeftChildTest': + {'featureId': {'id': '0'}, + 'threshold': {'floatValue': 0}}}}, + {'leaf': {'vector': + {'value': [{'floatValue': 0.0}, + {'floatValue': 1.0}]}}, + 'nodeId': 1}, + {'leaf': {'vector': + {'value': [{'floatValue': 0.0}, + {'floatValue': 1.0}]}}, + 'nodeId': 2}]}} + restored_tree_param = ParseDict(tree_weight, + _tree_proto.Model()).SerializeToString() + graph_builder = tensor_forest.RandomForestGraphs(hparams, + [restored_tree_param]) + probs, paths, var = graph_builder.inference_graph(input_data) + self.assertTrue(isinstance(probs, ops.Tensor)) + self.assertTrue(isinstance(paths, ops.Tensor)) + self.assertTrue(isinstance(var, ops.Tensor)) + with self.test_session(): + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + self.assertEquals(probs.eval().shape, (4, 2)) + self.assertEquals(probs.eval().tolist(), expected_prediction) + def testTrainingConstructionClassificationSparse(self): input_data = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]], diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index 630c0607ae21d0276a9dd0507346d5dc4ed9f4a9..cfdc884277a025aa11995d329389f3748b17490c 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include + #include "tensorflow/contrib/tensorboard/db/summary_converter.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -66,14 +68,9 @@ const char* kImagePluginName = "images"; const char* kAudioPluginName = "audio"; const char* kHistogramPluginName = "histograms"; -const int kScalarSlots = 10000; -const int kImageSlots = 10; -const int kAudioSlots = 10; -const int kHistogramSlots = 1; -const int kTensorSlots = 10; - const int64 kReserveMinBytes = 32; const double kReserveMultiplier = 1.5; +const int64 kPreallocateRows = 1000; // Flush is a misnomer because what we're actually doing is having lots // of commits inside any SqliteTransaction that writes potentially @@ -139,22 +136,6 @@ void PatchPluginName(SummaryMetadata* metadata, const char* name) { } } -int GetSlots(const Tensor& t, const SummaryMetadata& metadata) { - if (metadata.plugin_data().plugin_name() == kScalarPluginName) { - return kScalarSlots; - } else if (metadata.plugin_data().plugin_name() == kImagePluginName) { - return kImageSlots; - } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) { - return kAudioSlots; - } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) { - return kHistogramSlots; - } else if (t.dims() == 0 && t.dtype() != DT_STRING) { - return kScalarSlots; - } else { - return kTensorSlots; - } -} - Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) { const char* sql = R"sql( INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?) @@ -481,24 +462,6 @@ class RunMetadata { return insert.StepAndReset(); } - Status GetIsWatching(Sqlite* db, bool* is_watching) - SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { - mutex_lock lock(mu_); - if (experiment_id_ == kAbsent) { - *is_watching = true; - return Status::OK(); - } - const char* sql = R"sql( - SELECT is_watching FROM Experiments WHERE experiment_id = ? - )sql"; - SqliteStatement stmt; - TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); - stmt.BindInt(1, experiment_id_); - TF_RETURN_IF_ERROR(stmt.StepOnce()); - *is_watching = stmt.ColumnInt(0) != 0; - return Status::OK(); - } - private: Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (user_id_ != kAbsent || user_name_.empty()) return Status::OK(); @@ -659,43 +622,15 @@ class RunMetadata { /// \brief Tensor writer for a single series, e.g. Tag. /// -/// This class can be used to write an infinite stream of Tensors to the -/// database in a fixed block of contiguous disk space. This is -/// accomplished using Algorithm R reservoir sampling. -/// -/// The reservoir consists of a fixed number of rows, which are inserted -/// using ZEROBLOB upon receiving the first sample, which is used to -/// predict how big the other ones are likely to be. This is done -/// transactionally in a way that tries to be mindful of other processes -/// that might be trying to access the same DB. -/// -/// Once the reservoir fills up, rows are replaced at random, and writes -/// gradually become no-ops. This allows long training to go fast -/// without configuration. The exception is when someone is actually -/// looking at TensorBoard. When that happens, the "keep last" behavior -/// is turned on and Append() will always result in a write. -/// -/// If no one is watching training, this class still holds on to the -/// most recent "dangling" Tensor, so if Finish() is called, the most -/// recent training state can be written to disk. -/// -/// The randomly selected sampling points should be consistent across -/// multiple instances. -/// /// This class is thread safe. class SeriesWriter { public: - SeriesWriter(int64 series, int slots, RunMetadata* meta) - : series_{series}, - slots_{slots}, - meta_{meta}, - rng_{std::mt19937_64::default_seed} { + SeriesWriter(int64 series, RunMetadata* meta) : series_{series}, meta_{meta} { DCHECK(series_ > 0); - DCHECK(slots_ > 0); } Status Append(Sqlite* db, int64 step, uint64 now, double computed_time, - Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db) + const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); if (rowids_.empty()) { @@ -705,41 +640,20 @@ class SeriesWriter { return s; } } - DCHECK(rowids_.size() == slots_); - int64 rowid; - size_t i = count_; - if (i < slots_) { - rowid = last_rowid_ = rowids_[i]; - } else { - i = rng_() % (i + 1); - if (i < slots_) { - rowid = last_rowid_ = rowids_[i]; - } else { - bool keep_last; - TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last)); - if (!keep_last) { - ++count_; - dangling_tensor_.reset(new Tensor(std::move(t))); - dangling_step_ = step; - dangling_computed_time_ = computed_time; - return Status::OK(); - } - rowid = last_rowid_; - } - } + int64 rowid = rowids_.front(); Status s = Write(db, rowid, step, computed_time, t); if (s.ok()) { ++count_; - dangling_tensor_.reset(); } + rowids_.pop_front(); return s; } Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); - // Short runs: Delete unused pre-allocated Tensors. - if (count_ < rowids_.size()) { + // Delete unused pre-allocated Tensors. + if (!rowids_.empty()) { SqliteTransaction txn(*db); const char* sql = R"sql( DELETE FROM Tensors WHERE rowid = ? @@ -747,19 +661,13 @@ class SeriesWriter { SqliteStatement deleter; TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter)); for (size_t i = count_; i < rowids_.size(); ++i) { - deleter.BindInt(1, rowids_[i]); + deleter.BindInt(1, rowids_.front()); TF_RETURN_IF_ERROR(deleter.StepAndReset()); + rowids_.pop_front(); } TF_RETURN_IF_ERROR(txn.Commit()); rowids_.clear(); } - // Long runs: Make last sample be the very most recent one. - if (dangling_tensor_) { - DCHECK(last_rowid_ != kAbsent); - TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_, - dangling_computed_time_, *dangling_tensor_)); - dangling_tensor_.reset(); - } return Status::OK(); } @@ -783,7 +691,6 @@ class SeriesWriter { Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t, const StringPiece& data, int64 rowid) { - // TODO(jart): How can we ensure reservoir fills on replace? const char* sql = R"sql( UPDATE OR REPLACE Tensors @@ -878,7 +785,7 @@ class SeriesWriter { // TODO(jart): Maybe preallocate index pages by setting step. This // is tricky because UPDATE OR REPLACE can have a side // effect of deleting preallocated rows. - for (int64 i = 0; i < slots_; ++i) { + for (int64 i = 0; i < kPreallocateRows; ++i) { insert.BindInt(1, series_); insert.BindInt(2, reserved_bytes); TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i); @@ -902,16 +809,10 @@ class SeriesWriter { mutex mu_; const int64 series_; - const int slots_; RunMetadata* const meta_; - std::mt19937_64 rng_ GUARDED_BY(mu_); uint64 count_ GUARDED_BY(mu_) = 0; - int64 last_rowid_ GUARDED_BY(mu_) = kAbsent; - std::vector rowids_ GUARDED_BY(mu_); + std::deque rowids_ GUARDED_BY(mu_); uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0; - std::unique_ptr dangling_tensor_ GUARDED_BY(mu_); - int64 dangling_step_ GUARDED_BY(mu_) = 0; - double dangling_computed_time_ GUARDED_BY(mu_) = 0.0; TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter); }; @@ -928,10 +829,10 @@ class RunWriter { explicit RunWriter(RunMetadata* meta) : meta_{meta} {} Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now, - double computed_time, Tensor t, int slots) + double computed_time, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { - SeriesWriter* writer = GetSeriesWriter(tag_id, slots); - return writer->Append(db, step, now, computed_time, std::move(t)); + SeriesWriter* writer = GetSeriesWriter(tag_id); + return writer->Append(db, step, now, computed_time, t); } Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) @@ -948,11 +849,11 @@ class RunWriter { } private: - SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) { + SeriesWriter* GetSeriesWriter(int64 tag_id) LOCKS_EXCLUDED(mu_) { mutex_lock sl(mu_); auto spot = series_writers_.find(tag_id); if (spot == series_writers_.end()) { - SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_); + SeriesWriter* writer = new SeriesWriter(tag_id, meta_); series_writers_[tag_id].reset(writer); return writer; } else { @@ -1082,8 +983,7 @@ class SummaryDbWriter : public SummaryWriterInterface { TF_RETURN_IF_ERROR( meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata)); TF_RETURN_WITH_CONTEXT_IF_ERROR( - run_.Append(db_, tag_id, step, now, computed_time, t, - GetSlots(t, metadata)), + run_.Append(db_, tag_id, step, now, computed_time, t), meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(), "/", tag, "@", step); return Status::OK(); @@ -1155,8 +1055,7 @@ class SummaryDbWriter : public SummaryWriterInterface { int64 tag_id; TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t, - GetSlots(t, s->metadata())); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } // TODO(jart): Refactor Summary -> Tensor logic into separate file. @@ -1169,8 +1068,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kScalarPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kScalarSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) { @@ -1201,8 +1099,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kHistogramPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kHistogramSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) { @@ -1216,8 +1113,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kImagePluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kImageSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) { @@ -1230,8 +1126,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kAudioPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kAudioSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Env* const env_; diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index 2044692b6e746bc317843d715fa17ab5ec0bf99d..2e8d4109dd624ab66d774668ad04def9a7d3cdf2 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -189,7 +189,7 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); int64 user_id = QueryInt("SELECT user_id FROM Users"); int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments"); @@ -238,7 +238,7 @@ TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) { ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); } TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { @@ -255,7 +255,7 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); TF_ASSERT_OK(writer_->Flush()); ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(2000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'"); int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'"); EXPECT_GT(tag1_id, 0LL); diff --git a/tensorflow/contrib/tensorboard/graph_explorer/proto/graph_explorer.proto b/tensorflow/contrib/tensorboard/graph_explorer/proto/graph_explorer.proto deleted file mode 100644 index 835337ed5c58d0f0595ce8a88f08c8e63a860a36..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorboard/graph_explorer/proto/graph_explorer.proto +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2015 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the 'License'); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an 'AS IS' BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -// GraphExplorer is a tool that supports interactive, hierarchical visualization -// of graphs. GraphExplorer renders graphs generated by TensorFlow represented -// as GraphDef messages defined in tensorflow/core/framework/graph.proto. The -// GraphDef proto does not allow for explicitly specifying visual attributes of -// the graph such as color, line thickness, fonts, etc. This file introduces a -// new proto for representing graphs and specifying visual attributes of graphs. -// -// The structure of the Graph proto is given by the EBNF grammar below. Consult -// the message definitions below for details. -// -// graph ::= node* edge* node_attribute* metanode_attribute* edge_attribute* -// graph_attribute* -// node ::= node_id node_attribute* metanode_attribute* node_data* -// edge ::= source_id target_id edge_attribute* edge_data* -// -// A graph consists of a list of nodes and a list of edges and attributes for -// nodes, edges and the graph. Attributes have a name and a value and are -// represented as key-value pairs, with {"color", "blue"} being an example. -// Attributes have a scope, where the broadest scope is the graph and the -// narrowest is a node that has no internal structure. -syntax = "proto3"; - -package graph_explorer; - -// There are two types of nodes. A 'metanode' contains other -// nodes and a 'leaf node' has no internal structure. The metanode containment -// relationship is acyclic, meaning that if a metanode 'A' contains the metanode -// 'B', then 'B' cannot contain 'A'. -message Node { - // The identifier of a node is a sequence of strings separated by '/'. The - // identifier provides a unique name for a node and defines its hierarchical - // relation to other nodes. If no label is provided the last part of the - // identifier is used as a label. - // - // Example: In the graph below, metanodes are written with square brackets and - // leaf nodes with parentheses. The metanode 'node1' contains the leaf node - // 'node4' and the metanode 'node2', which contains the leaf node 'node3'. - // - // [node1 [node2 (node3)] (node4)] - // - // The identifiers for these nodes are: "node1", "node1/node2", - // "node1/node2/node3", and "node1/node4". - string name = 1; - - // A node attribute is information used by Graph Explorer to style a node. - map node_attr = 2; - - // A metanode attribute is one that is inherited by all nodes inside the - // current metanode. If an attribute applies only to the current node and - // should not be inherited, it should be specified as a node attribute. - map metanode_attr = 3; -}; - -// An edge consists of a source and a target node, specified by their -// identifiers. An edge has attributes and data that are similar to node -// attributes and node data. Edges do not form a hierarchy so there are no -// metanode attributes. -message Edge { - // The source and target fields must have the format of a Node name. - string source = 1; - string target = 2; - - // Edge attributes. - map edge_attr = 3; -} - -message Graph { - // List of nodes in the graph. - repeated Node node = 1; - - // List of edges in the graph. - repeated Edge edge = 2; - - // Default values of node, metanode and edge attributes. - map node_attr = 3; - map metanode_attr = 4; - map edge_attr = 5; - - // Graph attributes. - map graph_attr = 6; -}; diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 7a8a71ac7f491ec48a47ae1ea1aff750a587beaa..adda0b758b172f5e80c165e4b28dbdbecef2ba16 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -49,7 +49,6 @@ tf_cuda_cc_test( tf_custom_op_library( name = "python/ops/_trt_engine_op.so", srcs = [ - "ops/trt_calib_op.cc", "ops/trt_engine_op.cc", ], deps = [ @@ -76,11 +75,9 @@ tf_cuda_library( cc_library( name = "trt_engine_op_kernel", srcs = [ - "kernels/trt_calib_op.cc", "kernels/trt_engine_op.cc", ], hdrs = [ - "kernels/trt_calib_op.h", "kernels/trt_engine_op.h", ], copts = tf_copts(), @@ -89,20 +86,22 @@ cc_library( ":trt_logging", ":trt_plugins", ":trt_resources", + ":trt_conversion", + ":utils", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/core/grappler/costs:graph_properties", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]) + tf_custom_op_library_additional_deps(), - # TODO(laigd) + # TODO(laigd): fix this by merging header file in cc file. alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs ) tf_gen_op_libs( op_lib_names = [ "trt_engine_op", - "trt_calib_op", ], ) @@ -122,7 +121,6 @@ tf_gen_op_wrapper_py( name = "trt_engine_op", gen_locally = True, deps = [ - ":trt_calib_op_op_lib", ":trt_engine_op_op_lib", ":trt_logging", ":trt_shape_function", @@ -140,7 +138,6 @@ tf_custom_op_py_library( kernels = [ ":trt_engine_op_kernel", ":trt_engine_op_op_lib", - ":trt_calib_op_op_lib", ":trt_shape_function", ], srcs_version = "PY2AND3", @@ -191,8 +188,7 @@ tf_py_wrap_cc( deps = [ ":trt_conversion", ":trt_engine_op_kernel", - "//tensorflow/core:framework_lite", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -211,6 +207,7 @@ tf_cuda_library( ], deps = [ ":trt_logging", + ":utils", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_lite", "//tensorflow/core:lib_proto_parsing", @@ -237,12 +234,12 @@ tf_cuda_library( ":trt_plugins", ":trt_logging", ":trt_resources", + ":utils", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", - "//tensorflow/core:framework", "//tensorflow/core:gpu_runtime", "//tensorflow/core:framework_lite", "//tensorflow/core:graph", @@ -303,7 +300,7 @@ tf_cuda_library( ], deps = [ "//tensorflow/core:framework_lite", - "//tensorflow/core:platform_base", + "//tensorflow/core:lib_proto_parsing", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), @@ -343,3 +340,8 @@ py_test( "//tensorflow/python:framework_test_lib", ], ) + +cc_library( + name = "utils", + hdrs = ["convert/utils.h"], +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index b7b26cfb1c05ae74e932c8b9cb2479cfca308514..189944f29b5a0c24f544e0510a6fb19bd5727229 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include #include #include #include @@ -24,10 +24,17 @@ limitations under the License. #include #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -39,17 +46,39 @@ limitations under the License. #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT +#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT +#include "tensorflow/core/util/device_name_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" #include "tensorrt/include/NvInfer.h" - namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +// Returns compiled TRT version information {Maj, Min, Patch} +std::vector GetLinkedTensorRTVersion() { + return {NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH}; +} + +// Returns loaded TRT library version {Maj, Min, Patch} +std::vector GetLoadedTensorRTVersion() { + int ver = getInferLibVersion(); + int ver_major = ver / 1000; + ver = ver - ver_major * 1000; + int ver_minor = ver / 100; + int ver_patch = ver - ver_minor * 100; + return {ver_major, ver_minor, ver_patch}; +} + namespace { bool IsTensorRTCandidate(const tensorflow::Node* node) { @@ -82,220 +111,6 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } -void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, - const std::set& subgraph_node_ids, - tensorflow::EdgeSet* incoming_edges) { - for (int node_id : subgraph_node_ids) { - const tensorflow::Node* node = graph.FindNodeId(node_id); - for (const tensorflow::Edge* edge : node->in_edges()) { - if (!subgraph_node_ids.count(edge->src()->id()) && - !edge->src()->IsSource() && !edge->IsControlEdge()) { - incoming_edges->insert(edge); - } else { - VLOG(2) << node->name() << " -> " << edge->src()->name() << " N, "; - } - } - } -} - -void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, - const std::set& subgraph_node_ids, - tensorflow::EdgeSet* outgoing_edges) { - for (int node_id : subgraph_node_ids) { - const tensorflow::Node* node = graph.FindNodeId(node_id); - for (const tensorflow::Edge* edge : node->out_edges()) { - if (!subgraph_node_ids.count(edge->dst()->id()) && - !edge->dst()->IsSink() && !edge->IsControlEdge()) { - VLOG(2) << node->name() << " -> " << edge->dst()->name() << " Y, "; - outgoing_edges->insert(edge); - } else { - VLOG(2) << node->name() << " -> " << edge->dst()->name() << " N, "; - } - } - } -} - -std::pair ParseTensorName(const string& name, - int default_idx = 0) { - string name_no_idx = name; - int idx = default_idx; - const size_t sep = name_no_idx.find_last_of(':'); - if (sep != string::npos) { - name_no_idx = name_no_idx.substr(0, sep); - idx = std::stoi(name.substr(sep + 1)); - } - return std::make_pair(name_no_idx, idx); -} - -std::unordered_map> BuildTensorNameMap( - const std::vector& tensor_names) { - std::unordered_map> result; - for (const string& tensor_name : tensor_names) { - string node_name; - int index; - std::tie(node_name, index) = ParseTensorName(tensor_name); - result[node_name].push_back(index); - } - return result; -} - -// TODO(sami): convert references to pointers -struct ConvertGraphParams { - ConvertGraphParams( - tensorflow::Graph& inp_graph, - const std::vector& output_node_names, - const std::set& subgraph_node_id_numbers, - size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& current_graph_properties, - std::unordered_map>* output_edges, - int engine_precision_mode, const string& device_name, - std::shared_ptr allocator, int cuda_gpu_id) - : graph(inp_graph), - output_names(output_node_names), - subgraph_node_ids(subgraph_node_id_numbers), - max_batch_size(max_supported_batch_size), - max_workspace_size_bytes(max_consumed_workspace_size_bytes), - graph_properties(current_graph_properties), - output_edge_map(output_edges), - precision_mode(engine_precision_mode), - device_name_(device_name), - allocator_(allocator), - cuda_gpu_id_(cuda_gpu_id) {} - tensorflow::Graph& graph; - const std::vector& output_names; - const std::set& subgraph_node_ids; - size_t max_batch_size; - size_t max_workspace_size_bytes; - const tensorflow::grappler::GraphProperties& graph_properties; - std::unordered_map>* output_edge_map; - int precision_mode; - string device_name_; - std::shared_ptr allocator_; - int cuda_gpu_id_; - std::vector> subgraph_inputs; - std::vector> subgraph_outputs; - tensorflow::EdgeSet subgraph_incoming_edges; - tensorflow::EdgeSet subgraph_outgoing_edges; -}; - -static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { - GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids, - &p->subgraph_incoming_edges); - for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) { - p->subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); - } - auto output_name_to_index_map = BuildTensorNameMap(p->output_names); - std::set> subgraph_outputs_set; - // Collect outputs referenced from output_names - for (int node_id : p->subgraph_node_ids) { - tensorflow::Node* node = p->graph.FindNodeId(node_id); - if (output_name_to_index_map.count(node->name())) { - for (int index : output_name_to_index_map.at(node->name())) { - subgraph_outputs_set.insert({node_id, index}); - } - } - } - GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids, - &p->subgraph_outgoing_edges); - for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) { - subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()}); - } - p->subgraph_outputs.reserve(subgraph_outputs_set.size()); - p->subgraph_outputs.insert(p->subgraph_outputs.begin(), - subgraph_outputs_set.begin(), - subgraph_outputs_set.end()); - return tensorflow::Status::OK(); -} - -tensorflow::Status GetCalibNode(ConvertGraphParams* params) { - TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params)); - tensorflow::NodeDef trt_node_def; - SubGraphParams s(params->graph, params->subgraph_node_ids, - params->subgraph_inputs, params->subgraph_outputs, - params->max_batch_size, params->max_workspace_size_bytes, - params->graph_properties, params->output_edge_map, - &trt_node_def, params->precision_mode, params->device_name_, - params->allocator_, params->cuda_gpu_id_); - TF_RETURN_IF_ERROR(InjectCalibrationNode(s)); - tensorflow::Status status; - tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status); - - TF_RETURN_IF_ERROR(status); - - for (auto in_edge : - params->subgraph_incoming_edges) { // loop over incoming edges and - // attach them to calib node - // tensorflow::Node* src_node = in_edge->src(); - auto src_output = in_edge->src_output(); - auto dst_node = in_edge->dst(); - auto dst_input = in_edge->dst_input(); - VLOG(1) << " update edge " << trt_node->name() << ":" << src_output - << " -> " << dst_node->name() << ":" << dst_input; - TF_RETURN_IF_ERROR( - params->graph.UpdateEdge(trt_node, src_output, dst_node, dst_input)); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { - TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params)); - tensorflow::NodeDef trt_node_def; - - SubGraphParams s(params->graph, params->subgraph_node_ids, - params->subgraph_inputs, params->subgraph_outputs, - params->max_batch_size, params->max_workspace_size_bytes, - params->graph_properties, params->output_edge_map, - &trt_node_def, params->precision_mode, params->device_name_, - params->allocator_, params->cuda_gpu_id_); - TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(s)); - tensorflow::Status status; - tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status); - - // AddNode does not wire edges. - // Re-map incoming edges to use the new TRT node instead of the orig subgraph - std::map, int> subgraph_edge_to_input_map; - for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) { - subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i}); - } - for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) { - std::pair old_src = {edge->src()->id(), edge->src_output()}; - int new_src_output = subgraph_edge_to_input_map.at(old_src); - params->graph.AddEdge(edge->src(), edge->src_output(), trt_node, - new_src_output); - params->graph.RemoveEdge(edge); - } - - VLOG(2) << "new wiring edges: " << trt_node->in_edges().size(); - for (const tensorflow::Edge* edge : trt_node->in_edges()) { - VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); - } - - TF_RETURN_IF_ERROR(status); - - // Re-map outgoing edges to use the new TRT node instead of the orig subgraph - std::map, int> subgraph_edge_to_output_map; - for (size_t i = 0; i < params->subgraph_outputs.size(); ++i) { - subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i}); - } - TF_RETURN_IF_ERROR(status); - for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) { - std::pair old_src = {edge->src()->id(), edge->src_output()}; - int new_src_output = subgraph_edge_to_output_map.at(old_src); - TF_RETURN_IF_ERROR(params->graph.UpdateEdge( - trt_node, new_src_output, edge->dst(), edge->dst_input())); - } - // Remove the original subgraph - for (int node_id : params->subgraph_node_ids) { - tensorflow::Node* node = params->graph.FindNodeId(node_id); - // Don't remove the input placeholders - if (node->type_string() == "Placeholder") { - continue; - } - params->graph.RemoveNode(node); - } - return tensorflow::Status::OK(); -} - tensorflow::Status BuildNodeMap( const tensorflow::Graph& graph, std::unordered_map* node_map) { @@ -309,48 +124,78 @@ tensorflow::Status BuildNodeMap( } } // namespace + +// Function to get calibration from ResourceMgr and put them into nodedef. tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph) { + const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, + bool is_dyn_op) { VLOG(0) << "Starting Calib Conversion"; - tensorflow::Graph graph(tensorflow::OpRegistry::Global()); - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), graph_def, &graph)); - // get calib nodes - std::vector calib_nodes; - for (auto node : graph.op_nodes()) { - if (node->type_string() == "TRTCalibOp") { - VLOG(1) << "Found Calib Node"; - calib_nodes.push_back(node); - } + infer_graph->CopyFrom(graph_def); + auto trt_rm = TRTResourceManager::instance(); + auto calib_rm = trt_rm->getManager("TRTCalibration"); + int num_nodes = infer_graph->node_size(); + if (!is_dyn_op) { + LOG(WARNING) << "Construction of static int8 engine is not implemented " + "yet!. Dynamic engine will be constructed"; } - VLOG(0) << "Num Calib nodes in graph= " << calib_nodes.size(); - if (calib_nodes.size() == 0) - return tensorflow::errors::FailedPrecondition( - "Graph doesn't contain any calibration nodes!." - " Please generate calibration graph and run calibration first"); - for (auto n : calib_nodes) { - TF_RETURN_IF_ERROR( - tensorrt::convert::ConvertCalibrationNodeToEngineNode(graph, n)); + for (int i = 0; i < num_nodes; ++i) { + auto n = infer_graph->mutable_node(i); + if (n->op() == "TRTEngineOp") { + VLOG(1) << "Processing " << n->name(); + const string& container_name = n->attr().at("segment_funcdef_name").s(); + TRTCalibrationResource* cres = nullptr; + auto status = calib_rm->Lookup(container_name, "Calibrator", &cres); + if (!status.ok()) { + LOG(ERROR) << "Could not get Calibration information. Did you run with " + "calibration data?"; + return tensorflow::errors::FailedPrecondition( + "Need to run graph with calibration data first!"); + } + if (cres->calibrator_) { + cres->calibrator_->waitAndSetDone(); + cres->thr_->join(); + const auto& calibration_table = + cres->calibrator_->getCalibrationTableAsString(); + if (!calibration_table.size()) { + LOG(ERROR) << "Calibration table is empty"; + return tensorflow::errors::Unknown( + "Calibration table is missing. This shouldn't have happened!"); + } + n->mutable_attr()->at("calibration_data").set_s(calibration_table); + } else { + LOG(ERROR) << "Can't get TRTCalibrator from resource manager!"; + return tensorflow::errors::Unknown( + "Can't get TRTCalibrator from resource manager!"); + } + cres->Unref(); + calib_rm->Cleanup(container_name); + } } - graph.ToGraphDef(infer_graph); return tensorflow::Status::OK(); } +// Entry function from Python. tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode = FP32MODE, int minimum_segment_size = 3) { + int precision_mode, int minimum_segment_size, bool is_dyn_op, + int max_cached_engines, std::vector cached_engine_batches) { // optimization pass tensorflow::grappler::GrapplerItem item; item.fetch = output_names; item.graph = graph_def; - + // grappler requires a virtual cluster with a proper GPU device + // in order to calculate flops>0 or fails with FATAL + // We add numbers from a Pascal card here to have flops>0 tensorflow::DeviceProperties device_properties; device_properties.set_type("GPU"); device_properties.mutable_environment()->insert({"architecture", "6"}); - tensorflow::grappler::Cluster* cluster = - new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}}); + device_properties.set_num_cores(3584); + device_properties.set_frequency(1531); + std::unique_ptr cluster( + new tensorflow::grappler::VirtualCluster( + {{"/GPU:0", device_properties}})); // single machine int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); @@ -358,134 +203,633 @@ tensorflow::Status ConvertGraphDefToTensorRT( VLOG(2) << "cpu_cores: " << num_cpu_cores; VLOG(2) << "gpus: " << num_gpus; tensorflow::RewriterConfig rw_cfg; + // use only const folding and layout for the time being since new optimizers + // break the graph for us + rw_cfg.add_optimizers("constfold"); + rw_cfg.add_optimizers("layout"); + rw_cfg.set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE); tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); tensorflow::GraphDef gdef; - TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster, item, &gdef)); + TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, &gdef)); item.graph = gdef; // AJ refactoring shape inference through grappler/GraphProperties. tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); // Build full graph - - return ConvertAfterShapes(gdef, output_names, max_batch_size, - max_workspace_size_bytes, new_graph_def, - precision_mode, minimum_segment_size, - static_graph_properties, nullptr); + ConversionParams cp; + cp.input_graph_def = &gdef; + cp.output_names = &output_names; + cp.max_batch_size = max_batch_size; + cp.output_graph_def = new_graph_def; + cp.precision_mode = precision_mode; + cp.is_dyn_op = is_dyn_op; + cp.max_cached_engines = max_cached_engines; + cp.cached_engine_batches = cached_engine_batches; + cp.minimum_segment_size = minimum_segment_size; + cp.graph_properties = &static_graph_properties; + cp.max_workspace_size_bytes = max_workspace_size_bytes; + if (VLOG_IS_ON(5)) { + std::fstream f; + f.open("TRTConversionInput.pb", + std::fstream::out | std::fstream::binary | std::fstream::trunc); + f << gdef.SerializeAsString(); + f.close(); + } + return ConvertAfterShapes(cp); } -tensorflow::Status ConvertAfterShapes( - const tensorflow::GraphDef& gdef, const std::vector& output_names, - size_t max_batch_size, size_t max_workspace_size_bytes, - tensorflow::GraphDef* new_graph_def, int precision_mode, - int minimum_segment_size, +// Function to get subsegment information structure. +tensorflow::Status GetEngineInfo( + const tensorflow::Graph* g, const tensorflow::grappler::GraphProperties& graph_properties, - const tensorflow::grappler::Cluster* cluster) { - // Segment the graph into subgraphs that can be converted to TensorRT - tensorflow::tensorrt::segment::SegmentOptions segment_options; + const std::set& segment_nodes, + const std::unordered_map& node_map, + const std::vector& reverse_topo_order, + EngineInfo* info) { + std::vector subgraph_node_ids; + std::set segment_devices; + int input_port = 0; + int output_port = 0; + + // Map from src_node_name+port to the unique port numbers of the TRT op, where + // the src_node_name is the name of the source node of the input/output + // edge, thus there must not be any duplicates since source nodes of + // input/output edges must be in different split of the graph. + // TODO(aaroey): consider using node id and port instead. + std::unordered_map created_edges; + for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); + ++it) { + const auto& node_name = (*it)->name(); + + if (segment_nodes.count(node_name) == 0) continue; + auto node = node_map.at(node_name); + auto node_device = node->requested_device(); + if (!node_device.empty()) { + segment_devices.insert(node_device); + } else { + if (node->has_assigned_device_name()) { + segment_devices.insert(node->assigned_device_name()); + } else { + VLOG(2) << "Node " << node->name() + << " neither have requested device nor assigned device"; + } + } + int node_id = node->id(); + subgraph_node_ids.push_back(node_id); + for (const auto edge : node->in_edges()) { + auto input_node = edge->src(); + if (segment_nodes.count(input_node->name()) == 0) { + // Add constant input node into the segment. We don't care if it has + // other output edges going into other engines or TF nodes. Since we add + // it only to the subsegment node list, not the subsegment itself, it + // won't be removed from the graph. If it doesn't have any edges, TF + // will prune it out. + if (input_node->type_string() == "Const") { + subgraph_node_ids.push_back(input_node->id()); + } else if (!edge->IsControlEdge() && !input_node->IsSource()) { + string s(input_node->name()); + StrAppend(&s, ":", edge->src_output()); + VLOG(1) << "Input edge = " << s; + int port = input_port; + if (created_edges.count(s)) { + port = created_edges.at(s); + } else { + created_edges.insert({s, port}); + input_port++; + } + info->connections.emplace_back(input_node->name(), input_node->id(), + edge->src_output(), node_name, node_id, + edge->dst_input(), true, port); + } + } + } + for (const auto edge : node->out_edges()) { + auto output_node = edge->dst(); + if (segment_nodes.count(output_node->name()) == 0 && + !edge->IsControlEdge() && !output_node->IsSink()) { + string s(node_name); + StrAppend(&s, ":", edge->src_output()); + VLOG(1) << "Output edge = " << s; + int port = output_port; + if (created_edges.count(s)) { + port = created_edges.at(s); + } else { + created_edges.insert({s, port}); + output_port++; + } + info->connections.emplace_back(output_node->name(), output_node->id(), + edge->dst_input(), node_name, node_id, + edge->src_output(), false, port); + } + } + } + + TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( + g, graph_properties, subgraph_node_ids, &info->connections, + &info->segment_graph_def, &info->engine_name)); + // TODO(sami): This should not happen once segmenter is updated. + if (segment_devices.size() == 1) { + info->device = *segment_devices.begin(); + } else if (segment_devices.size() > 1) { + LOG(WARNING) << "Detected multiple(" << segment_devices.size() + << ") devices for the segment. Picking first one to continue " + << "but this shouldn't have happened"; + info->device = *segment_devices.begin(); + } else { + VLOG(1) << "Segment devices size is 0"; + } + return Status::OK(); +} + +// Function to insert a TRT node into the graph. The graph is not modified if +// the returned status is not ok. +// 'alloc' is only used for creating static engine. +tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, + const std::vector& infos, int pos, + nvinfer1::IGpuAllocator* alloc, + int max_batch_size) { + const auto& info = infos.at(pos); + std::vector out_shapes; + std::vector input_shapes; + std::vector shapes; + std::vector inputs; + std::vector out_types; + VLOG(1) << "Processing " << info.engine_name; + + // Update the shape and data types of input/output nodes, and find all unique + // inputs. + for (const auto& conn : info.connections) { + if (!conn.is_input_edge) { + // Set the shapes and data types of output edge. + tensorflow::TensorShapeProto out_shape; + // shape of the output node inside segment + conn.inside_shape.AsProto(&out_shape); + if (out_shapes.size() <= conn.port_number) { + out_shapes.resize(conn.port_number + 1); + out_types.resize(conn.port_number + 1); + } + out_shapes.at(conn.port_number) = out_shape; + out_types.at(conn.port_number) = conn.connection_type; + continue; + } + + // Set the shapes and data types of input edge. + tensorflow::TensorShapeProto in_shape; + conn.outside_shape.AsProto(&in_shape); + if (input_shapes.size() <= conn.port_number) { + input_shapes.resize(conn.port_number + 1); + shapes.resize(conn.port_number + 1); + } + input_shapes.at(conn.port_number) = in_shape; + shapes.at(conn.port_number) = conn.outside_shape; + + string input_node = conn.outside_node_name; + int input_port = conn.outside_port; + bool found_engine = false; + // Rewire the inputs to other engines if they contain original input node. + // Note that we use the information of the engine here, not the information + // of the created TRT nodes, so we're able to find all the connections to + // any other engines beforehand. + for (size_t t = 0; t < infos.size(); ++t) { + if (t == pos) continue; + auto& engine_info = infos.at(t); + for (const auto& eng_conn : engine_info.connections) { + if (eng_conn.is_input_edge) continue; + if (eng_conn.inside_node_name == input_node) { + input_node = engine_info.engine_name; + if (eng_conn.inside_port == input_port) { + input_port = eng_conn.port_number; + found_engine = true; + break; + } + } + } + if (found_engine) break; + } + VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> " + << info.engine_name << ":" << inputs.size(); + // Skip duplicate inputs. + bool new_input = true; + for (const auto& inp : inputs) { + if (inp.node == input_node && inp.index == input_port) { + new_input = false; + break; + } + } + if (new_input) { + inputs.emplace_back(input_node, input_port, conn.connection_type); + } + } + + // Build the engine and get its serialized representation. + string segment_string; + if (info.engine_type == EngineInfo::EngineType::TRTStatic || + info.precision_mode == INT8MODE) { + // Create static engine for fp32/fp16 mode, and test validity of the engine + // for int8 mode. We don't want engine to fail at the calibration time. + // So we are constructing a FP32 engine here to check its validity, and if + // it is a valid engine then we put the serialized graphdef to the op. + // Otherwise we skip node creation for this engine. + Logger trt_logger; + TrtUniquePtrType engine; + // TODO(sami): What happens if 1st dim is not batch? + TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( + info.segment_graph_def, + info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode, + max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger, + alloc, /*calibrator=*/nullptr, &engine, + /*convert_successfully=*/nullptr)); + TrtUniquePtrType engine_data(engine->serialize()); + segment_string = + string((const char*)engine_data->data(), engine_data->size()); + if (info.precision_mode == INT8MODE) { + // See above comment about why not putting this inside the 'else' branch. + segment_string = info.segment_graph_def.SerializeAsString(); + } + } else { + segment_string = info.segment_graph_def.SerializeAsString(); + } + + // TODO(aaroey): use enum instead, and add a helper method to do the + // conversion. + string prec_string; + switch (info.precision_mode) { + case FP32MODE: + prec_string = "FP32"; + break; + case FP16MODE: + prec_string = "FP16"; + break; + case INT8MODE: + prec_string = "INT8"; + if (!TRTResourceManager::instance()->getManager("TRTCalibration")) { + LOG(ERROR) << "Failed to construct calibration storage"; + } + break; + default: + return tensorflow::errors::OutOfRange("Unknown precision mode"); + } + tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); + if (!info.device.empty()) node_builder.Device(info.device); + if (VLOG_IS_ON(1)) { + string ins = StrCat(info.engine_name, " inputs= "); + for (const auto& ii : inputs) { + StrAppend(&ins, ii.node, ":", ii.index, " "); + } + VLOG(1) << ins; + } + node_builder.Input(inputs); + if (info.engine_type == EngineInfo::EngineType::TRTStatic && + info.cached_engine_batches.size()) { + LOG(WARNING) << "Cached engine batches are ignored for static engines"; + } + tensorflow::NodeDef trt_node; + tensorflow::Status status = + node_builder.Attr("input_shapes", input_shapes) + .Attr("output_shapes", out_shapes) + .Attr("static_engine", + info.engine_type == EngineInfo::EngineType::TRTStatic) + .Attr("segment_funcdef_name", + StrCat(info.engine_name, "_native_segment")) + .Attr("serialized_segment", segment_string) + .Attr("calibration_data", "") + .Attr("max_cached_engines_count", info.maximum_cached_engines) + .Attr("cached_engine_batches", {max_batch_size}) + .Attr("workspace_size_bytes", info.max_workspace_size_bytes) + .Attr("precision_mode", prec_string) + .Attr("OutT", out_types) + .Finalize(&trt_node); + if (!status.ok()) { + LOG(ERROR) << "Node construction failed with" << status; + return status; + } + VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph"; + + // Up until this point, graph is not modified. If we return !status.ok() from + // here, this segment will be skipped + tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); + if (!status.ok()) { + LOG(ERROR) << "Adding node failed " << status; + return status; + } + // Updates the inputs of output edges destination nodes, and point them to the + // engine node. + for (auto& conn : info.connections) { + if (conn.is_input_edge) continue; + VLOG(1) << " Updating DBG " << engine_node->name() << " out_port " + << conn.port_number << " out_id " << conn.outside_id + << " name=" << conn.outside_node_name; + auto dst_node = graph->FindNodeId(conn.outside_id); + // dst_node can only be removed if it is an input node of another engine. + // In this case, other engines input edge is updated in nodedef to point to + // this engine. Even though edge doesn't exists in the graph, when it is + // deserialized again, correct edges will be constructed. This is a problem + // of graph->AddNode(). + if (!dst_node) continue; + VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number + << " to " << dst_node->name() << ":" << conn.outside_port; + auto new_edge = graph->AddEdge(engine_node, conn.port_number, dst_node, + conn.outside_port); + CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() << ":" + << conn.port_number << " -> " << dst_node->name() << ":" + << conn.outside_port; + } + return status; +} + +// Function to construct a funcdef from the segment and add it to the graph. +tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( + tensorflow::Graph* graph, const tensorflow::GraphDef& segment, + const string& name) { + tensorflow::Graph sgraph(graph->flib_def()); + tensorflow::GraphConstructorOptions gcopts; + TF_RETURN_IF_ERROR( + tensorflow::ConvertGraphDefToGraph(gcopts, segment, &sgraph)); + std::map io_nodes; + int num_inputs = 0; + for (auto n : sgraph.op_nodes()) { + if (tensorflow::str_util::StartsWith(n->name(), kInputPHName)) { + num_inputs++; + io_nodes.insert({n->name(), n}); + } else if (tensorflow::str_util::StartsWith(n->name(), kOutputPHName)) { + io_nodes.insert({n->name(), n}); + } + } + + for (int i = 0; i < num_inputs; ++i) { + auto name = StrCat(kInputPHName, i); + auto node = io_nodes[name]; + tensorflow::NodeDef nd; + tensorflow::NodeDefBuilder node_builder( + StrCat(name, "_Arg"), tensorflow::FunctionLibraryDefinition::kArgOp); + VLOG(1) << "Adding " << StrCat(name, "_Arg"); + TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) + .Attr("index", i) + .Finalize(&nd)); + tensorflow::Status s; + auto node_arg = sgraph.AddNode(nd, &s); + if (!s.ok()) { + LOG(ERROR) << "Couldn't add _Arg node for " << name; + } + for (auto edge : node->out_edges()) { + sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input()); + VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0 + << " - > " << edge->dst()->name() << ":" << edge->dst_input(); + if (!s.ok()) { + LOG(ERROR) << "Failed to update edge from " << node_arg->name() + << " to " << edge->dst()->name() << ":" << edge->dst_input(); + } + } + sgraph.RemoveNode(node); + } + + for (int i = 0; i < io_nodes.size() - num_inputs; ++i) { + auto name = StrCat(kOutputPHName, i); + auto node = io_nodes[name]; + tensorflow::NodeDef nd; + tensorflow::NodeDefBuilder node_builder( + StrCat(name, "_Ret"), tensorflow::FunctionLibraryDefinition::kRetOp); + auto edge = *(node->in_edges().begin()); + tensorflow::NodeDefBuilder::NodeOut nout( + edge->src()->name(), edge->src_output(), + edge->src()->output_type(edge->src_output())); + VLOG(1) << " input " << nout.node << ":" << nout.index + << " dtype=" << tensorflow::DataTypeString(nout.data_type); + node_builder.Input({nout}); + TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) + .Attr("index", i) + .Finalize(&nd)); + if (VLOG_IS_ON(3)) { + VLOG(3) << nd.DebugString(); + } + tensorflow::Status s; + auto node_ret = sgraph.AddNode(nd, &s); + if (!s.ok()) { + LOG(ERROR) << "Couldn't add _Ret node for " << name; + } + VLOG(1) << "Update edge from " << edge->src()->name() << ":" + << edge->src_output() << " - > " << node_ret->name() << ":" << 0; + sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0); + s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0); + if (!s.ok()) { + LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":" + << edge->src_output() << " - > " << node_ret->name() << ":" + << 0; + } + sgraph.RemoveNode(node); + } + tensorflow::FunctionDefLibrary fdeflib; + auto native_segment = fdeflib.add_function(); + TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef( + sgraph, StrCat(name, "_native_segment"), native_segment)); + if (VLOG_IS_ON(7)) { + VLOG(7) << name << " Function_Def "; + VLOG(7) << native_segment->DebugString(); + } + VLOG(1) << "Adding funcdef to graphlib"; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib)); + return tensorflow::Status::OK(); +} + +std::pair GetDeviceAndAllocator( + ConversionParams& params, EngineInfo& engine) { + int cuda_device_id = -1; + auto check_device_id = [](int tfid) -> int { + tensorflow::TfGpuId tf_gpu_id(tfid); + CudaGpuId cuda_gpu_id; + Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + if (s.ok()) { + VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device " + << cuda_gpu_id.value(); + return cuda_gpu_id.value(); + } + VLOG(2) << "TF GPU with id " << tfid << " do not exist " << s; + return -1; + }; + tensorflow::Allocator* dev_allocator = nullptr; + // we need to us PM here since in python path there is no way to get + // to allocators. + // TODO(sami): when grappler devices become available else path will not be + // necessary + auto pm = tensorflow::GPUProcessState::singleton(); + if (params.cluster) { // get allocator + tensorflow::Device* device = nullptr; + if (params.cluster->GetDeviceSet()) { + device = params.cluster->GetDeviceSet()->FindDeviceByName(engine.device); + } + if (device) { + tensorflow::AllocatorAttributes alloc_attr; + dev_allocator = device->GetAllocator(alloc_attr); + VLOG(1) << "Using allocator " << dev_allocator->Name(); + } else { + LOG(WARNING) << "Cluster is set but device '" << engine.device + << "' is not found in the cluster"; + } + } else { // cluster not found, possibly a python call + VLOG(1) << "Cluster is not set, probably called from python"; + int found_device = 0; + bool try_gpu_ids = true; + // if device is set, try to find the device. Might be a problem for multi + // host case but TensorRT do not support multi host setups yet. + if (!engine.device.empty()) { + DeviceNameUtils::ParsedName parsed_name; + if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name)) { + cuda_device_id = parsed_name.has_id ? parsed_name.id : -1; + } + try_gpu_ids = !parsed_name.has_id; + } + if (try_gpu_ids) { + while (found_device < 100) { + cuda_device_id = check_device_id(found_device); + if (cuda_device_id >= 0) break; + found_device++; + } + } + if (found_device == 100) { + LOG(ERROR) << " Can't find a GPU device to work with. Please " + "instantiate a session to initialize devices"; + return std::make_pair(cuda_device_id, dev_allocator); + } + LOG(WARNING) + << "Can't determine the device, constructing an allocator at device " + << found_device; + tensorflow::GPUOptions gpuoptions; + // this will be a noop if device is already initialized + gpuoptions.set_allow_growth(true); + tensorflow::TfGpuId tf_gpu_id(found_device); + dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1); + } + return std::make_pair(cuda_device_id, dev_allocator); +} + +// Entry function from optimization pass. +tensorflow::Status ConvertAfterShapes(ConversionParams& params) { + // Convert graphdef to graph. tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), - gdef.library()); + params.input_graph_def->library()); tensorflow::Graph graph(flib); TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), gdef, &graph)); + tensorflow::GraphConstructorOptions(), *params.input_graph_def, &graph)); + // Segment the graph into subgraphs that can be converted to TensorRT + tensorflow::tensorrt::segment::SegmentOptions segment_options; // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT) - for (auto node : output_names) { + for (auto node : *(params.output_names)) { segment_options.exclude_node_list.insert(node); } - - // TODO(sami): this should be passed as a knob!!!! - segment_options.minimum_segment_size = minimum_segment_size; - tensorflow::tensorrt::segment::SegmentNodesVector segments; + segment_options.minimum_segment_size = params.minimum_segment_size; + tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( - &graph, IsTensorRTCandidate, segment_options, &segments)); - if (segments.size() > 1) { - VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size(); + &graph, IsTensorRTCandidate, segment_options, &initial_segments)); + if (initial_segments.size() > 1) { + VLOG(0) << "MULTIPLE tensorrt candidate conversion: " + << initial_segments.size(); } + + // Get the EngineInfo for each segment. std::unordered_map node_map; TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); - std::unordered_map> output_edge_map; - int count = 0; float total_num_nodes_in_segments = 0.; - for (auto s : segments) { - total_num_nodes_in_segments += s.first.size(); - } - // We create the map here since cluster may not be available in all cases. - std::map name_to_device_map; - if (cluster) { - // TODO(aaroey): consider using DeviceSet::FindDeviceByName(), as in a - // distributed environment, devices from different workers can have same - // short name. - for (const auto dm : cluster->GetDeviceSet()->devices()) { - name_to_device_map[dm->name()] = dm; + std::vector engine_segments; + engine_segments.reserve(initial_segments.size()); + std::vector reverse_topo_order; + tensorflow::GetPostOrder(graph, &reverse_topo_order); + size_t total_engine_bytes_size = 0; + std::vector engine_bytes_size; + tensorflow::tensorrt::segment::SegmentNodesVector converted_segments; + converted_segments.reserve(initial_segments.size()); + for (size_t t = 0; t < initial_segments.size(); t++) { + auto& curr_segment = initial_segments.at(t); + EngineInfo curr_engine; + Status status = + GetEngineInfo(&graph, *params.graph_properties, curr_segment.first, + node_map, reverse_topo_order, &curr_engine); + if (!status.ok()) { + LOG(WARNING) << "Failed to get engine info for segment " << t << ": " + << status; + continue; } - } - for (const auto& segment_nodes_and_device : segments) { - const std::set& subgraph_node_names = - segment_nodes_and_device.first; - std::set subgraph_node_ids; - size_t max_mem_per_engine = - max_workspace_size_bytes * - ((float)subgraph_node_names.size() / total_num_nodes_in_segments); - std::stringstream oss; - for (const string& node_name : subgraph_node_names) { - oss << " " << node_name; - subgraph_node_ids.insert(node_map.at(node_name)->id()); + curr_engine.precision_mode = params.precision_mode; + curr_engine.engine_type = + (params.is_dyn_op || params.precision_mode == INT8MODE + ? EngineInfo::EngineType::TRTDynamic + : EngineInfo::EngineType::TRTStatic); + curr_engine.cached_engine_batches = params.cached_engine_batches; + curr_engine.maximum_cached_engines = params.max_cached_engines; + StrAppend(&curr_engine.engine_name, "my_trt_op_", t); + status = RegisterSegmentFunctionToFunctionLibrary( + &graph, curr_engine.segment_graph_def, curr_engine.engine_name); + if (!status.ok()) { + LOG(WARNING) << "Failed to register segment graphdef as a function " << t + << ": " << status; + continue; } - VLOG(1) << "Subgraph nodes at device " << segment_nodes_and_device.second - << " : " << oss.str(); - auto target_device = - name_to_device_map.find(segment_nodes_and_device.second); - std::shared_ptr allocator(0); + engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong()); + total_engine_bytes_size += engine_bytes_size.back(); + total_num_nodes_in_segments += curr_segment.first.size(); + engine_segments.push_back(std::move(curr_engine)); + converted_segments.push_back(std::move(curr_segment)); + + if (VLOG_IS_ON(8)) { + string fname = curr_engine.engine_name; + StrAppend(&fname, ".pb"); + std::fstream f; + f.open(fname.c_str(), std::fstream::out | std::fstream::binary); + f << engine_segments.at(t).segment_graph_def.SerializeAsString(); + f.close(); + } + } + + // Create a TRT node for each segment using its EngineInfo. + int old_cuda_device = 0; + auto err = cudaGetDevice(&old_cuda_device); + if (err != cudaSuccess) { + LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err); + } + VLOG(1) << "Current cuda device is " << old_cuda_device; + for (int i = 0; i < engine_segments.size(); ++i) { + auto& engine = engine_segments.at(i); + // Partition the workspace size by the average of node ratio and segment + // graphdef size + engine.max_workspace_size_bytes = + params.max_workspace_size_bytes * + (engine_bytes_size.at(i) / total_engine_bytes_size + + converted_segments.at(i).first.size() / total_num_nodes_in_segments) / + 2.0; + // The allocator is used to build the engine. The build and the built engine + // will be destroyed after we get the serialized engine string, so it's fine + // to use unique_ptr here. + std::unique_ptr alloc; + auto device_alloc = GetDeviceAndAllocator(params, engine); int cuda_device_id = 0; - if (target_device != name_to_device_map.end()) { - tensorflow::TfGpuId tf_gpu_id(target_device->second->parsed_name().id); - CudaGpuId cuda_gpu_id; - Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); - if (!s.ok()) { - LOG(ERROR) - << "Cuda device identification failed, using device 0. Error= " - << s; - } else { - cuda_device_id = cuda_gpu_id.value(); - } - tensorflow::GPUOptions gpuoptions; - // we need to us PM here since in python path there is no way to get to - // allocators - auto pm = tensorflow::ProcessState::singleton(); - // this should be instantiated by now - auto dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1); - VLOG(1) << "Got an allocator for device tf_device=" << tf_gpu_id.value() - << " cuda device= " << cuda_device_id << " at " << dev_allocator; - allocator = std::make_shared(dev_allocator); - } else { // device unknown or not available - allocator = std::make_shared(); + if (device_alloc.first >= 0) { + cuda_device_id = device_alloc.first; + alloc.reset(new TRTDeviceAllocator(device_alloc.second)); + } else { + // Setting allocator as nullptr should get revert to the cudamalloc + LOG(WARNING) << "Can't identify the cuda device. Running on device 0 "; } - ConvertGraphParams p(graph, output_names, subgraph_node_ids, max_batch_size, - max_mem_per_engine, graph_properties, &output_edge_map, - precision_mode, segment_nodes_and_device.second, - allocator, cuda_device_id); - if (precision_mode == INT8MODE) { - tensorflow::Status status = GetCalibNode(&p); - if (status != tensorflow::Status::OK()) { - LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count - << " due to: \"" << status.ToString() - << "\" SKIPPING......( " << subgraph_node_names.size() - << " nodes)"; + cudaSetDevice(cuda_device_id); + auto status = CreateTRTNode(&graph, engine_segments, i, alloc.get(), + params.max_batch_size); + // If status is ok, we successfully added the node to the graph and can + // remove segment ops. Otherwise graph is not modified. + if (status.ok()) { + for (auto node_name : converted_segments.at(i).first) { + graph.RemoveNode(node_map.at(node_name)); } } else { - tensorflow::Status status = ConvertSubGraphToTensorRT(&p); - if (status != tensorflow::Status::OK()) { - LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count - << " due to: \"" << status.ToString() - << "\" SKIPPING......( " << subgraph_node_names.size() - << " nodes)"; - } + // Graph is not modified. + LOG(WARNING) << "Engine creation for segment " << i << ", composed of " + << converted_segments.at(i).first.size() + << " nodes failed: " << status << ". Skipping..."; } - count++; } - graph.ToGraphDef(new_graph_def); + cudaSetDevice(old_cuda_device); + graph.ToGraphDef(params.output_graph_def); + VLOG(1) << "Returning from conversion"; return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 65a67d7e73e32f904bd636a4f4aaefe32b0c092d..9d986e489043c0a0e16e379166aa2e8f7ac0b11f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -30,29 +30,60 @@ namespace tensorflow { namespace tensorrt { namespace convert { -// This method converts an already generated calibration graph which was used in -// calibration runs to an inference graph +struct ConversionParams { + ConversionParams() + : input_graph_def(nullptr), + max_batch_size(1), + max_workspace_size_bytes(1 << 30), + output_graph_def(nullptr), + precision_mode(1), + minimum_segment_size(3), + graph_properties(nullptr), + cluster(nullptr), + is_dyn_op(false), + fixed_input_size(true), + max_cached_engines(1) {} + const tensorflow::GraphDef* input_graph_def; + const std::vector* output_names; + size_t max_batch_size; + size_t max_workspace_size_bytes; + tensorflow::GraphDef* output_graph_def; + int precision_mode; + int minimum_segment_size; + const tensorflow::grappler::GraphProperties* graph_properties; + const tensorflow::grappler::Cluster* cluster; + bool is_dyn_op; // Whether to create engine on conversion or execution time + bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed + int max_cached_engines; // maximum number of cached engines + std::vector cached_engine_batches; // list of cached engines +}; + +// This method extracts calibration information from the resource managers +// and puts them in to engine nodedefs. tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def); + const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def, + bool is_dyn_op); -// max_batch_size: maximum batch size which can be used for inference for -// optimization targets inference run with max batch size. -// max_workspace_size_bytes: The upper bound of memory allowance for -// engine building. +// - max_batch_size: maximum batch size which can be used for inference for +// optimization targets inference run with max batch size. +// - max_workspace_size_bytes: The upper bound of memory allowance for engine +// building. tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode, int minimum_segment_size); + int precision_mode = 1, int minimum_segment_size = 3, + bool is_dyn_op = false, int max_cached_engines = 1, + std::vector cached_engine_batches = {}); // Method to call from optimization pass -tensorflow::Status ConvertAfterShapes( - const tensorflow::GraphDef& graph, const std::vector& output_names, - size_t max_batch_size, size_t max_workspace_size_bytes, - tensorflow::GraphDef* new_graph_def, int precision_mode, - int minimum_segment_size, - const tensorflow::grappler::GraphProperties& graph_properties, - const tensorflow::grappler::Cluster* cluster); +tensorflow::Status ConvertAfterShapes(ConversionParams& params); + +// Return compile time TensorRT library version information. +std::vector GetLinkedTensorRTVersion(); + +// Return runtime time TensorRT library version information. +std::vector GetLoadedTensorRTVersion(); } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 32b211dcd1e282d334327b83a27f9401de7f310a..146b9c7344b0a9c2b3ec87b395e9b1096dbef06c 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -25,7 +24,9 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT @@ -37,6 +38,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -54,8 +56,11 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::str_util::Split; + using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; + namespace { inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, @@ -121,12 +126,10 @@ static std::vector> CreateSamePadding( string GetCommonNameScope(const string& op_name_a, const string& op_name_b) { size_t last_scope_separator = 0; - for (size_t i = 0; i < std::min(op_name_a.size(), op_name_b.size()); ++i) { - if (op_name_a[i] != op_name_b[i]) { - break; - } else if (op_name_a[i] == '/') { - last_scope_separator = i + 1; - } + const size_t min_size = std::min(op_name_a.size(), op_name_b.size()); + for (size_t i = 0; i < min_size; ++i) { + if (op_name_a[i] != op_name_b[i]) break; + if (op_name_a[i] == '/') last_scope_separator = i + 1; } return op_name_a.substr(0, last_scope_separator); } @@ -362,10 +365,11 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, break; } case tensorflow::DataType::DT_HALF: { - Reorder2({k, c}, static_cast(iweights.GetValues()), - istrides, static_cast( - const_cast(oweights->GetValues())), - ostrides); + Reorder2( + {k, c}, static_cast(iweights.GetValues()), + istrides, + static_cast(const_cast(oweights->GetValues())), + ostrides); break; } default: @@ -416,20 +420,6 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, } } -struct InferDeleter { - template - void operator()(T* obj) const { - if (obj) { - obj->destroy(); - } - } -}; - -template -inline std::shared_ptr infer_object(T* obj) { - return std::shared_ptr(obj, InferDeleter()); -} - class Converter; using OpConverter = @@ -443,7 +433,7 @@ class Converter { OpConverter plugin_converter_; nvinfer1::INetworkDefinition* trt_network_; std::list> temp_bufs_; - tensorflow::tensorrt::TRTWeightStore* weight_store_; + TRTWeightStore* weight_store_; bool fp16_; void register_op_converters(); tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def, @@ -485,11 +475,11 @@ class Converter { public: explicit Converter(nvinfer1::INetworkDefinition* trt_network, - tensorflow::tensorrt::TRTWeightStore* ws, bool fp16) + TRTWeightStore* ws, bool fp16) : trt_network_(trt_network), weight_store_(ws), fp16_(fp16) { this->register_op_converters(); } - tensorflow::tensorrt::TRTWeightStore* weight_store() { return weight_store_; } + TRTWeightStore* weight_store() { return weight_store_; } TRT_ShapedWeights get_temp_weights(tensorflow::DataType type, nvinfer1::Dims shape) { TRT_ShapedWeights weights(type, nullptr, shape); @@ -1179,9 +1169,9 @@ tensorflow::Status BinaryTensorOpTensor( CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) - return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + - " not supported at: " + - node_def.name()); + return tensorflow::errors::Unimplemented( + "binary op: " + node_def.op() + + " not supported at: " + node_def.name()); nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( *const_cast(tensor_l), @@ -2138,514 +2128,266 @@ void Converter::register_op_converters() { } } // namespace -tensorflow::Status GetTensorRTGraph(tensorrt::convert::SubGraphParams& s) { - return tensorflow::errors::Unimplemented("Not implemented yet"); -} -tensorflow::Status ConvertCalibrationNodeToEngineNode( - tensorflow::Graph& graph, tensorflow::Node* c_node) { - const auto ndef = c_node->def(); - - TFAttrs attrs(ndef); - std::vector segment_nodes( - attrs.get>("segment_nodes")); - std::vector output_nodes( - attrs.get>("segment_output_names")); - std::vector input_names( - attrs.get>("input_names")); - string res_name = attrs.get("resource_name"); - VLOG(1) << "Node name " << c_node->name() << " res_name " << res_name; - string engine_name = "my_trt_op"; - { - const auto node_id = tensorflow::str_util::Split(res_name, "_"); - engine_name += node_id.back(); - } - std::map node_maps; - - for (auto n : graph.op_nodes()) { - node_maps.insert({n->name(), n}); - } - VLOG(1) << "Output Nodes:"; - std::vector out_types; - std::vector out_edges; - for (auto& i : output_nodes) { - auto node_port = tensorflow::str_util::Split(i, ":"); - VLOG(1) << " " << i << " in graph " << node_maps.count(i); - auto out_node_name = node_port.at(0); - if (node_port.size() > 1) { - VLOG(1) << "Multi port output" << node_port.at(0) << " " - << node_port.at(1) << " size=" << node_port.size(); - } - auto node_it = node_maps.find(out_node_name); - if (node_it != node_maps.end()) { - tensorflow::Node* out_node = node_it->second; - int port = 0; - if (node_port.size() == 2) { - port = std::strtoul(node_port.at(1).c_str(), nullptr, 10); - out_types.push_back(out_node->output_type(port)); - } else { - out_types.push_back(out_node->output_type(0)); - } - for (auto out_edge : out_node->out_edges()) { - if (out_edge->src_output() == port) { - out_edges.push_back(out_edge); - break; - } - } - } else { - LOG(WARNING) << " couldn't find output node " << out_node_name; - } - } - VLOG(1) << "Input Nodes:"; - for (auto& i : input_names) { - VLOG(1) << " " << i << " in graph " << node_maps.count(i); - } - auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); - auto resmgr = trt_rm->getManager("TRTCalibOps"); - tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; - auto status = resmgr->Lookup(res_name, res_name, &calib_res); - if (!status.ok() || !calib_res->calibrator_) { - return tensorflow::errors::FailedPrecondition( - "You must run calibration" - " and inference conversion in the same process"); - } - - calib_res->calibrator_->setDone(); - calib_res->thr_->join(); - delete calib_res->thr_; - if (!calib_res->engine_) { - LOG(ERROR) << "Calibration failed!, engine does not exist. Did you run " - "calibration graph?"; - return tensorflow::errors::FailedPrecondition( - "Calibration graph needs to be executed on" - " calibration data before convertsion to inference graph"); - } - auto weight_rmgr = trt_rm->getManager("WeightStore"); - TF_CHECK_OK(weight_rmgr->Delete( - res_name, res_name)); - auto engine_plan = calib_res->engine_->serialize(); - calib_res->engine_->destroy(); - calib_res->network_->destroy(); - calib_res->builder_->destroy(); - calib_res->thr_ = nullptr; - calib_res->engine_ = nullptr; - calib_res->builder_ = nullptr; - tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); - std::vector income_edges; - for (const auto in_edge : c_node->in_edges()) { - auto src = in_edge->src(); - int dest_port = in_edge->dst_input(); - income_edges.emplace_back(src->name(), in_edge->src_output(), - c_node->input_type(dest_port)); - } - tensorflow::gtl::ArraySlice input_list( - income_edges); - op_builder.Input(input_list); - tensorflow::NodeDef engine_node; - const char* engine_plan_data = static_cast(engine_plan->data()); - string engine_plan_string(engine_plan_data, - engine_plan_data + engine_plan->size()); - status = op_builder.Attr("serialized_engine", engine_plan_string) - .Attr("input_nodes", input_names) - .Attr("output_nodes", output_nodes) - .Attr("OutT", out_types) - .Finalize(&engine_node); - if (!status.ok()) { - LOG(ERROR) << "Engine Node creation failed"; - return status; - } - auto trt_engine_node = graph.AddNode(engine_node, &status); - TF_RETURN_IF_ERROR(status); - for (size_t i = 0; i < out_edges.size(); i++) { - VLOG(1) << "Connecting trt_engine_node output " << i << " with " - << out_edges.at(i)->dst()->name() << " port " - << out_edges.at(i)->dst_input(); - TF_RETURN_IF_ERROR(graph.UpdateEdge(trt_engine_node, i, - out_edges.at(i)->dst(), - out_edges.at(i)->dst_input())); - } - VLOG(1) << "Segment nodes:"; - for (auto& i : segment_nodes) { - VLOG(1) << " " << i << " in graph " << node_maps.count(i); - auto it = node_maps.find(i); - if (it != node_maps.end()) { - graph.RemoveNode(it->second); - } - } - graph.RemoveNode(c_node); - return tensorflow::Status::OK(); -} -tensorflow::Status ReverseTopologicalSort( - const tensorrt::convert::SubGraphParams& s, - std::list* order) { - std::vector order_vec; - tensorflow::GetPostOrder(s.graph, &order_vec); - // Select just the subgraph - for (tensorflow::Node* node : order_vec) { - if (s.subgraph_node_ids.count(node->id())) { - // We want topological order to contstruct the - // network layer by layer - order->push_front(node); - } +tensorflow::Status ConvertGraphDefToEngine( + const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, + size_t max_workspace_size_bytes, + const std::vector& input_shapes, + Logger* logger, nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator, + TrtUniquePtrType* engine, + bool* convert_successfully) { + engine->reset(); + if (convert_successfully) *convert_successfully = false; + + // Create the builder. + TrtUniquePtrType builder( + nvinfer1::createInferBuilder(*logger)); + builder->setMaxBatchSize(max_batch_size); + // TODO(aaroey): use the allocator to allocate the TRT workspace. + builder->setMaxWorkspaceSize(max_workspace_size_bytes); +#if NV_TENSORRT_MAJOR > 3 + builder->setGpuAllocator(allocator); +#endif + if (precision_mode == FP16MODE) { + builder->setHalf2Mode(true); + } else if (precision_mode == INT8MODE) { + builder->setInt8Mode(true); + builder->setInt8Calibrator(calibrator); } - return tensorflow::Status::OK(); -} -tensorflow::Status SetInputList( - const tensorrt::convert::SubGraphParams& s, - tensorflow::NodeDefBuilder* op_builder, - const std::vector* input_names, - std::vector* input_dtypes) { - std::vector income_edges; - VLOG(2) << "input edge size: " << input_names->size(); - for (size_t i = 0; i < input_names->size(); ++i) { - VLOG(2) << "input edges: " << i << " " << input_names->at(i); - int output_idx = s.input_inds.at(i).second; - // we wired up the input here already, it is redundant to do it again in - // ConvertSubGraphToTensorRT(convert_graph.cc) - auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut( - input_names->at(i), output_idx, input_dtypes->at(i)); - income_edges.push_back(incoming_edge); - } - tensorflow::gtl::ArraySlice input_list( - income_edges); - op_builder->Input(input_list); - return tensorflow::Status::OK(); -} - -string SubgraphNameScopeGenerator(const std::list* order) { - string subgraph_name_scope; - if (!order->empty()) { - subgraph_name_scope = order->front()->name(); - } - for (const tensorflow::Node* node : *order) { - subgraph_name_scope = GetCommonNameScope(subgraph_name_scope, node->name()); + // Create the network. + auto trt_network = + TrtUniquePtrType(builder->createNetwork()); + if (!trt_network) { + return tensorflow::errors::Internal( + "Failed to create TensorRT network object"); } - // TODO(sami,ben,jie): proper naming! - return subgraph_name_scope; -} - -tensorflow::Status ConvertSubgraph( - Converter& converter, tensorrt::convert::SubGraphParams& s, - std::list* order, std::vector* input_names, - std::vector* input_dtypes, - std::vector* output_names, - std::vector* output_dtypes, - const string& engine_name) { - for (const std::pair& input : s.input_inds) { - VLOG(2) << "parsing input. Node id= " << input.first; - int node_id = input.first; - int output_idx = input.second; - tensorflow::Node* node = s.graph.FindNodeId(node_id); - auto node_name = node->name(); - // input_names should use the node name in the graph - // here it should be the input tensor name -> matching the binding - // insert original node name without port - auto tensor_name = node_name; - if (output_idx != 0) { - tensor_name = StrCat(tensor_name, ":", output_idx); - } - - VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name - << " idx: " << output_idx; - - auto shape_inference_node_name = node_name; - auto shape_inference_output_idx = output_idx; - // rewire the shape inference to original node in the graph - if (s.output_edge_map->count(tensor_name)) { - shape_inference_node_name = s.output_edge_map->at(tensor_name).second; - shape_inference_output_idx = s.output_edge_map->at(tensor_name).first; - } - if (shape_inference_output_idx < 0) continue; - VLOG(2) << "shapeinference name: " << shape_inference_node_name - << " idx: " << shape_inference_output_idx; - - if (!s.graph_properties.HasOutputProperties(shape_inference_node_name)) - return tensorflow::errors::Internal("failed to find input node: " + - shape_inference_node_name); - - auto op_info_vec = - s.graph_properties.GetOutputProperties(shape_inference_node_name); - if (static_cast(op_info_vec.size()) <= shape_inference_output_idx) - return tensorflow::errors::Internal( - "accessing output index of: ", shape_inference_output_idx, - ", at node: ", shape_inference_node_name, - " with output entry from shape_map: ", op_info_vec.size()); - - auto op_info = op_info_vec.at(shape_inference_output_idx); - tensorflow::DataType tf_dtype = op_info.dtype(); - input_dtypes->push_back(tf_dtype); - - nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); - auto type_status = ConvertDType(tf_dtype, &dtype); - if (type_status != tensorflow::Status::OK()) { - LOG(WARNING) << "Type conversion failed for " << node_name; - return type_status; - } + auto ws = std::unique_ptr(new TRTWeightStore()); - VLOG(2) << "Accessing output index of: " << output_idx - << ", at node: " << node_name - << " with output entry from shape_map: " << op_info_vec.size(); - // TODO(ben,jie): update TRT input format/dimension - nvinfer1::DimsCHW input_dim_pseudo_chw; - for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1; - - // TODO(jie): TRT 3.x only support 4 dimensional input tensor. - // update the code once TRT 4.0 comes out. - if (op_info.shape().dim_size() != 4) { - string err_str = "Require 4 dimensional input."; - StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ", - shape_inference_node_name); - return tensorflow::errors::Unimplemented(err_str); - } - - for (int i = 1; i < op_info.shape().dim_size(); i++) { - VLOG(2) << "dimension: " << i - << " , size: " << op_info.shape().dim(i).size(); - input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size(); - } + // Build the network + VLOG(1) << "Starting engine conversion "; + Converter converter(trt_network.get(), ws.get(), precision_mode == FP16MODE); + std::vector> output_tensors; + // Graph nodes are already topologically sorted during construction + for (const auto& node_def : gdef.node()) { + string node_name = node_def.name(); + VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op(); + if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && + (node_def.op() == "Placeholder")) { + nvinfer1::DimsCHW input_dim_pseudo_chw; + for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0; + nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); + auto type_status = + ConvertDType(node_def.attr().at("dtype").type(), &dtype); + if (type_status != tensorflow::Status::OK()) { + LOG(WARNING) << "Type conversion failed for " << node_name; + return type_status; + } + int32 slot_number = -1; + if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8, + &slot_number)) { + LOG(ERROR) << "Failed to parse slot number from " << node_name + << " +8= " << node_name.c_str() + 8; + } + auto shape = input_shapes.at(slot_number); + if (shape.dims() > 8) { + LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name + << " at input slot " << slot_number; + return tensorflow::errors::OutOfRange( + "Input tensor rank is greater than 8"); + } + if (VLOG_IS_ON(1)) { + string dim_str("dims="); + StrAppend(&dim_str, "[ ", shape.dim_size(0)); + for (int i = 1; i < shape.dims(); i++) { + StrAppend(&dim_str, ", ", shape.dim_size(i)); + } + StrAppend(&dim_str, " ]"); + VLOG(1) << dim_str; + } + for (int i = 1; i < shape.dims(); i++) { + input_dim_pseudo_chw.d[i - 1] = shape.dim_size(i); + } - // TODO(ben,jie): proper way to restore input tensor name? - auto input_tensor_name = node_name; - if (output_idx != 0) { - input_tensor_name = StrCat(node_name, ":", output_idx); + input_dim_pseudo_chw.nbDims = shape.dims() - 1; + nvinfer1::ITensor* input_tensor = converter.network()->addInput( + node_name.c_str(), dtype, input_dim_pseudo_chw); + if (!input_tensor) { + return tensorflow::errors::InvalidArgument( + "Failed to create Input layer tensor ", node_name, + " rank=", shape.dims() - 1); + } + VLOG(1) << "Input tensor name :" << node_name; + if (!converter.insert_input_tensor(node_name, input_tensor)) { + return tensorflow::errors::AlreadyExists( + "Output tensor already exists for op: " + node_name); + } + } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && + (node_def.op() == "Identity")) { + int32 slot_number = -1; + if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9, + &slot_number)) { + LOG(ERROR) << "Failed to parse slot number from " << node_name + << " +9=" << node_name.c_str() + 9; + } + if (output_tensors.size() <= slot_number) { + output_tensors.resize(slot_number + 1); + } + output_tensors.at(slot_number) = {node_def.input(0), node_name}; + } else { + VLOG(2) << "Converting node: " << node_def.name() << " , " + << node_def.op(); + TF_RETURN_IF_ERROR(converter.convert_node(node_def)); } - - input_names->push_back(input_tensor_name); - nvinfer1::ITensor* input_tensor = converter.network()->addInput( - input_tensor_name.c_str(), dtype, input_dim_pseudo_chw); - - if (!input_tensor) - return tensorflow::errors::InvalidArgument( - "Failed to create Input layer"); - VLOG(2) << "Input tensor name :" << input_tensor_name; - - if (!converter.insert_input_tensor(input_tensor_name, input_tensor)) - return tensorflow::errors::AlreadyExists( - "Output tensor already exists for op: " + input_tensor_name); - } - - for (const tensorflow::Node* node : *order) { - const tensorflow::NodeDef& node_def = node->def(); - VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op(); - TF_RETURN_IF_ERROR(converter.convert_node(node_def)); - } - - VLOG(2) << "Finished conversion"; - - // Gather output metadata - int trt_engine_op_output_idx = 0; - for (const std::pair& output : s.output_inds) { - int node_id = output.first; - int output_idx = output.second; - tensorflow::Node* node = s.graph.FindNodeId(node_id); - string op_name = node->name(); - string tensor_name = op_name; - - s.output_edge_map->insert( - {trt_engine_op_output_idx == 0 - ? engine_name - : StrCat(engine_name, ":", trt_engine_op_output_idx), - {output_idx, tensor_name}}); - trt_engine_op_output_idx++; - if (output_idx != 0) - tensorflow::strings::StrAppend(&tensor_name, ":", output_idx); - VLOG(2) << "Output tensor name: " << tensor_name; - output_names->push_back(tensor_name); - auto tensor_or_weights = converter.get_tensor(tensor_name); + } + for (const auto& output : output_tensors) { + auto tensor_or_weights = converter.get_tensor(output.first); if (!tensor_or_weights.is_tensor()) { - return tensorflow::errors::InvalidArgument("Output node '" + tensor_name + - "' is weights not tensor"); + return tensorflow::errors::InvalidArgument( + "Output node '" + output.first + "' is weights not tensor"); } nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); + tensor->setName(output.second.c_str()); if (!tensor) { return tensorflow::errors::NotFound("Output tensor not found: " + - tensor_name); + output.first); } - converter.network()->markOutput(*tensor); - tensorflow::DataType tf_dtype = node->output_type(output_idx); - output_dtypes->push_back(tf_dtype); - nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; - TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); - tensor->setType(trt_dtype); - } + VLOG(1) << "Marking output tensor " << output.first << ", as output tensor " + << output.second; - return tensorflow::Status::OK(); -} - -tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) { - // Visit nodes in reverse topological order and construct the TRT network. - // Toposort - std::list order; - TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order)); - - static int static_id = 0; - string subgraph_name_scope = SubgraphNameScopeGenerator(&order); - // TODO(sami,ben,jie): proper naming! - string calib_op_name = - StrCat(subgraph_name_scope, "my_trt_calib_op_", static_id); - string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id); - static_id++; - - auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance(); - auto op_rmgr = trt_rmgr->getManager("TRTCalibOps"); - auto op_res = new tensorflow::tensorrt::TRTCalibrationResource(); - TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res)); - op_res->logger_ = new tensorflow::tensorrt::Logger(); - cudaSetDevice(s.cuda_gpu_id_); - op_res->builder_ = nvinfer1::createInferBuilder(*(op_res->logger_)); - op_res->allocator_ = s.allocator_; -#if NV_TENSORRT_MAJOR > 3 - op_res->builder_->setGpuAllocator(s.allocator_.get()); -#endif - if (!op_res->builder_) { - return tensorflow::errors::Internal( - "failed to create TensorRT builder object"); + converter.network()->markOutput(*tensor); } + if (convert_successfully) *convert_successfully = true; - op_res->network_ = op_res->builder_->createNetwork(); - if (!op_res->network_) { - return tensorflow::errors::Internal( - "failed to create TensorRT network object"); + // Build the engine. + VLOG(1) << "Starting engine creation"; + engine->reset(builder->buildCudaEngine(*converter.network())); + if (engine->get() == nullptr) { + return tensorflow::errors::Internal("Failed to build TensorRT engine"); } - - // Build the network - auto weight_rmgr = trt_rmgr->getManager("WeightStore"); - auto ws = new tensorflow::tensorrt::TRTWeightStore(); - TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws)); - Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE); - - std::vector input_names; - std::vector input_dtypes; - std::vector output_names; - std::vector output_dtypes; - TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names, - &input_dtypes, &output_names, - &output_dtypes, engine_name)); - - VLOG(2) << "Finished processing outputs"; - - // Build the engine - op_res->builder_->setMaxBatchSize(s.max_batch_size); - op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes); - VLOG(0) << "Max batch size= " << s.max_batch_size - << " max workspace size= " << s.max_workspace_size_bytes; - - // Build the TRT op - // TODO(sami,ben,jie): proper naming! - tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp"); - SetInputList(s, &op_builder, &input_names, &input_dtypes); - - std::vector segment_names; - segment_names.reserve(s.subgraph_node_ids.size()); - for (int i : s.subgraph_node_ids) { - auto node = s.graph.FindNodeId(i); - segment_names.push_back(node->name()); - } - LOG(INFO) << "finished op preparation"; - - auto status = op_builder.Attr("segment_nodes", segment_names) - .Attr("input_names", input_names) - .Attr("segment_output_names", output_names) - .Attr("resource_name", calib_op_name) - .Finalize(s.trt_node); - - LOG(INFO) << status.ToString(); - LOG(INFO) << "finished op building"; - + VLOG(1) << "Finished conversion"; return tensorflow::Status::OK(); } -tensorflow::Status ConvertSubGraphToTensorRTNodeDef( - tensorrt::convert::SubGraphParams& s) { - // Visit nodes in reverse topological order and construct the TRT network. - std::list order; - TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order)); - - static int static_id = 0; - string subgraph_name_scope = SubgraphNameScopeGenerator(&order); - string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id++); - - tensorflow::tensorrt::Logger trt_logger; - cudaSetDevice(s.cuda_gpu_id_); - auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger)); - if (!trt_builder) { - return tensorflow::errors::Internal( - "Failed to create TensorRT builder object"); - } -#if NV_TENSORRT_MAJOR > 3 - trt_builder->setGpuAllocator(s.allocator_.get()); -#endif - auto trt_network = infer_object(trt_builder->createNetwork()); - if (!trt_network) { - return tensorflow::errors::Internal( - "Failed to create TensorRT network object"); - } - - auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance(); - auto weight_rmgr = trt_rmgr->getManager("WeightStore"); - auto ws = new tensorflow::tensorrt::TRTWeightStore(); - TF_CHECK_OK(weight_rmgr->Create(engine_name, engine_name, ws)); - - // Build the network - Converter converter(trt_network.get(), ws, s.precision_mode == FP16MODE); - - std::vector input_names; - std::vector input_dtypes; - std::vector output_names; - std::vector output_dtypes; - TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names, - &input_dtypes, &output_names, - &output_dtypes, engine_name)); - - VLOG(2) << "Finished output"; - - // Build the engine - trt_builder->setMaxBatchSize(s.max_batch_size); - trt_builder->setMaxWorkspaceSize(s.max_workspace_size_bytes); - VLOG(0) << "Max batch size= " << s.max_batch_size - << " max workspace size= " << s.max_workspace_size_bytes; - if (s.precision_mode == FP16MODE) { - trt_builder->setHalf2Mode(true); - VLOG(0) << "Using FP16 precision mode"; - } - LOG(INFO) << "starting build engine"; - string engine_plan_string; - { - auto trt_engine = - infer_object(trt_builder->buildCudaEngine(*converter.network())); - VLOG(0) << "Built network"; - if (trt_engine.get() == nullptr) { - return tensorflow::errors::Internal("Engine building failure"); +tensorflow::Status ConvertSegmentToGraphDef( + const tensorflow::Graph* graph, + const tensorflow::grappler::GraphProperties& graph_properties, + const std::vector& subgraph_node_ids, // In topological order + std::vector* connections, + tensorflow::GraphDef* segment_def, string* common_scope) { + std::set marker_nodes; + // Update connection shapes/data types and add corresponding input/output + // nodes in the segment graphdef. + for (size_t i = 0; i < connections->size(); ++i) { + auto& connection = connections->at(i); + auto outside_node = graph->FindNodeId(connection.outside_id); + if (!outside_node) { + // This should never happen, unless the original graph is problematic. + return tensorflow::errors::NotFound( + "Cannot find node with id ", connection.outside_id, " in the graph."); + } + // Updates the shape and data types of input/output connections. + tensorflow::DataType input_type = tensorflow::DT_FLOAT; + tensorflow::PartialTensorShape partial_shape; + if (connection.is_input_edge) { + if (graph_properties.HasOutputProperties(connection.outside_node_name)) { + auto output_params = + graph_properties.GetOutputProperties(connection.outside_node_name); + auto out_shape = output_params.at(connection.outside_port); + input_type = out_shape.dtype(); + std::vector dims; + partial_shape = out_shape.shape(); + connection.outside_shape = partial_shape; + } else { + VLOG(0) << "Unknown output shape" << outside_node->name(); + input_type = graph->FindNodeId(connection.outside_id) + ->output_type(connection.outside_port); + } + connection.connection_type = input_type; + + } else { // output edge + if (graph_properties.HasInputProperties(connection.outside_node_name)) { + auto input_params = + graph_properties.GetInputProperties(connection.outside_node_name); + auto in_shape = input_params.at(connection.outside_port); + input_type = in_shape.dtype(); + partial_shape = in_shape.shape(); + connection.inside_shape = partial_shape; + } else { + input_type = graph->FindNodeId(connection.inside_id) + ->output_type(connection.outside_port); + } + connection.connection_type = input_type; } - auto engine_plan = infer_object(trt_engine->serialize()); - VLOG(0) << "Serialized engine"; - const char* engine_plan_data = - static_cast(engine_plan->data()); - engine_plan_string = - string(engine_plan_data, engine_plan_data + engine_plan->size()); - } - TF_RETURN_IF_ERROR(weight_rmgr->Delete( - engine_name, engine_name)); - LOG(INFO) << "finished engine " << engine_name << " containing " - << s.subgraph_node_ids.size() << " nodes"; - - // Build the TRT op - tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); - SetInputList(s, &op_builder, &input_names, &input_dtypes); - - VLOG(0) << "Finished op preparation"; - - auto status = op_builder.Attr("serialized_engine", engine_plan_string) - .Attr("input_nodes", input_names) - .Attr("output_nodes", output_names) - .Attr("OutT", output_dtypes) - .Device(s.device_name_) - .Finalize(s.trt_node); - - VLOG(0) << status.ToString() << " finished op building for " << engine_name - << " on device " << s.device_name_; + // Add dummy input/output nodes to the segment graphdef. + if (connection.is_input_edge) { + const string node_name = StrCat(kInputPHName, connection.port_number); + if (marker_nodes.count(node_name)) { + VLOG(1) << "Reusing input " << node_name << " for the edge " + << connection.outside_node_name << ":" + << connection.outside_port << " -> " + << connection.inside_node_name << ":" << connection.inside_port; + continue; + } + marker_nodes.insert(node_name); + auto seg_node = segment_def->add_node(); + tensorflow::NodeDefBuilder builder(node_name, "Placeholder"); + auto status = builder.Attr("shape", partial_shape) + .Attr("dtype", input_type) + .Finalize(seg_node); + VLOG(1) << "Constructing input " << node_name << " for the edge " + << connection.outside_node_name << ":" << connection.outside_port + << " -> " << connection.inside_node_name << ":" + << connection.inside_port; + } else { + const string node_name = StrCat(kOutputPHName, connection.port_number); + if (marker_nodes.count(node_name)) { + VLOG(1) << "Reusing output " << node_name << " for the edge " + << connection.inside_node_name << ":" << connection.inside_port + << " -> " << connection.outside_node_name << ":" + << connection.outside_port; + continue; + } + marker_nodes.insert(node_name); + auto seg_node = segment_def->add_node(); + tensorflow::NodeDefBuilder builder(node_name, "Identity"); + auto status = builder.Input(connection.inside_node_name, 0, input_type) + .Finalize(seg_node); + VLOG(1) << "Constructing output " << node_name << " for the edge " + << connection.inside_node_name << ":" << connection.inside_port + << " -> " << connection.outside_node_name << ":" + << connection.outside_port; + } + } // for each connection. + + std::unordered_map old_to_new_id_map; + // Copy internal nodes to new graphdef + string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name(); + for (const auto node_id : subgraph_node_ids) { + const auto node = graph->FindNodeId(node_id); + local_scope = GetCommonNameScope(local_scope, node->name()); + old_to_new_id_map[node_id] = segment_def->node_size(); + auto snode = segment_def->add_node(); + snode->CopyFrom(node->def()); + VLOG(1) << "Copying " << snode->name() << " to subgraph"; + } + // Update the inputs of the new input nodes to point to placeholder nodes. + for (int i = 0; i < connections->size(); ++i) { + auto& connection = connections->at(i); + if (!connection.is_input_edge) continue; + auto snode = + segment_def->mutable_node(old_to_new_id_map[connection.inside_id]); + const string placeholder_name = + StrCat(kInputPHName, connection.port_number); + VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port + << " from " << snode->input(connection.inside_port) << " to " + << placeholder_name; + snode->set_input(connection.inside_port, placeholder_name); + } + *common_scope = local_scope; + VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph"; return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 3f6592cd25ff013cadc0621ba64f0553983dd10b..1a4c0e755d1cd1e88ac26c39996eb3a750421a0a 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -22,69 +22,112 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" + #if GOOGLE_CUDA #if GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { +static const char* kInputPHName = "InputPH_"; +static const char* kOutputPHName = "OutputPH_"; namespace convert { +// TODO(aaroey): use an enum instead. const int FP32MODE = 0; const int FP16MODE = 1; const int INT8MODE = 2; -struct SubGraphParams { - SubGraphParams( - tensorflow::Graph& inp_graph, - const std::set& subgraph_node_id_numbers, - const std::vector>& input_indices, - const std::vector>& output_indices, - size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& current_graph_properties, - std::unordered_map>* output_edges, - tensorflow::NodeDef* constructed_trt_node, - int engine_precision_mode = FP32MODE, const string& device_name = "", - std::shared_ptr allocator = nullptr, - int cuda_gpu_id = 0) - : graph(inp_graph), - subgraph_node_ids(subgraph_node_id_numbers), - input_inds(input_indices), - output_inds(output_indices), - max_batch_size(max_supported_batch_size), - max_workspace_size_bytes(max_consumed_workspace_size_bytes), - graph_properties(current_graph_properties), - output_edge_map(output_edges), - trt_node(constructed_trt_node), - precision_mode(engine_precision_mode), - device_name_(device_name), - allocator_(allocator), - cuda_gpu_id_(cuda_gpu_id) {} - - tensorflow::Graph& graph; - const std::set& subgraph_node_ids; - const std::vector>& input_inds; // {node_id, output_idx} - const std::vector>& output_inds; // {node_id, output_idx} - size_t max_batch_size; - size_t max_workspace_size_bytes; - const tensorflow::grappler::GraphProperties& graph_properties; - std::unordered_map>* output_edge_map; - tensorflow::NodeDef* trt_node; - const int precision_mode; - const string device_name_; - std::shared_ptr allocator_; - const int cuda_gpu_id_; +struct EngineConnection { + EngineConnection(const string& outside, int out_id, int out_port, + const string& inside, int in_id, int in_port, + bool input_edge, int port) + : outside_node_name(outside), + outside_id(out_id), + outside_port(out_port), + inside_node_name(inside), + inside_id(in_id), + inside_port(in_port), + is_input_edge(input_edge), + port_number(port) {} + + const string outside_node_name; + const int outside_id; + const int outside_port; + tensorflow::PartialTensorShape outside_shape; + + const string inside_node_name; + const int inside_id; + const int inside_port; + tensorflow::PartialTensorShape inside_shape; + + tensorflow::DataType connection_type; + bool is_input_edge; + + // The port number of the TRT node connecting to this edge. + int port_number; +}; + +struct EngineInfo { + EngineInfo() + : engine_type(EngineType::TRTStatic), + max_workspace_size_bytes(0), + precision_mode(FP32MODE) {} + + string engine_name; + string device; + tensorflow::GraphDef segment_graph_def; + + // The segment nodes that are on one side of the edges are topological sorted. + std::vector connections; + + enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; + EngineType engine_type; + int64 max_workspace_size_bytes; + int maximum_cached_engines; + std::vector cached_engine_batches; + int precision_mode; }; -// TODO(sami): Replace references with const reference or pointers -tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params); -tensorflow::Status InjectCalibrationNode(SubGraphParams& params); -tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph, - tensorflow::Node* c_node); +// Constructs a graphdef from the segment in the given graph. Adds placeholder +// nodes for input edges (InputPH_*) and identity nodes for output edges +// (OutputPH_*). This function needs to be called before TensorRT nodes +// inserted in order to correctly get sizes from the original graph. +// +// - subgraph_node_ids: the node ids of the subgraph, must be sorted in +// topological order. +// - segment_def: the output GraphDef, whose non-input/output nodedefs will be +// sorted in topological order. +tensorflow::Status ConvertSegmentToGraphDef( + const tensorflow::Graph* graph, + const tensorflow::grappler::GraphProperties& graph_properties, + const std::vector& subgraph_node_ids, + std::vector* connections, + tensorflow::GraphDef* segment_def, string* common_scope); + +// Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff +// 'builder' successfully build the engine. If the result is not ok, 'engine' +// will be set to nullptr +// Once returned, 'builder' is not needed any more and can be safely detroyed. +// +// - convert_successfully: indicates whether the converson to TensorRT network +// is successful. This is different than successfully building the engine: +// building can still fail afterwards. +tensorflow::Status ConvertGraphDefToEngine( + const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, + size_t max_workspace_size_bytes, + const std::vector& input_shapes, + Logger* logger, nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator, + TrtUniquePtrType* engine, + bool* convert_successfully); + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index 8f634b1f74717310a69a6bab5d5224c9bdbf10cc..ec9dbfa13bfd0a158dcf41cf1fdb7128a2adf641 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -45,8 +45,24 @@ tensorflow::Status TRTOptimizationPass::Init( if (params.count("max_batch_size")) { maximum_batch_size_ = params.at("max_batch_size").i(); } - if (params.count("max_workspace_size_bytes")) + is_dynamic_op_ = false; + if (params.count("is_dynamic_op")) { + is_dynamic_op_ = params.at("is_dynamic_op").b(); + } + if (params.count("cached_engine_batches")) { + auto batch_vec = params.at("cached_engine_batches").list(); + batches_.reserve(batch_vec.i_size()); + for (const auto i : batch_vec.i()) { + batches_.push_back(i); + } + } + max_cached_batches_ = 1; + if (params.count("maximum_cached_engines")) { + max_cached_batches_ = params.at("maximum_cached_engines").i(); + } + if (params.count("max_workspace_size_bytes")) { maximum_workspace_size_ = params.at("max_workspace_size_bytes").i(); + } if (params.count("precision_mode")) { string pm = Uppercase(params.at("precision_mode").s()); if (pm == "FP32") { @@ -175,6 +191,17 @@ tensorflow::Status TRTOptimizationPass::Optimize( if (VLOG_IS_ON(1)) { PrintDebugInfo(cluster, item); } + // This is a hack to workaround optimizer issue. MetaOptimizer calls + // optimization passes on function objects as well, we should not modify + // generated funcdefs! This is fragile but we don't have any other option + // until framework fixes it. + if (item.id != "tf_graph") { + LOG(WARNING) << name_ + << " is probably called on funcdef! This optimizer must *NOT* " + "be called on function objects."; + *optimized_graph = item.graph; + return tensorflow::Status::OK(); + } int max_dim = -1; if (item.feed.size()) { for (const auto& f : item.feed) { @@ -204,11 +231,22 @@ tensorflow::Status TRTOptimizationPass::Optimize( } tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); - auto status = tensorflow::tensorrt::convert::ConvertAfterShapes( - item.graph, item.fetch, maximum_batch_size_, maximum_workspace_size_, - optimized_graph, precision_mode_, minimum_segment_size_, - static_graph_properties, cluster); + tensorflow::tensorrt::convert::ConversionParams cp; + cp.input_graph_def = &item.graph; + cp.output_names = &item.fetch; + cp.max_batch_size = maximum_batch_size_; + cp.max_workspace_size_bytes = maximum_workspace_size_; + cp.output_graph_def = optimized_graph; + cp.precision_mode = precision_mode_; + cp.minimum_segment_size = minimum_segment_size_; + cp.graph_properties = &static_graph_properties; + cp.cluster = cluster; + cp.is_dyn_op = is_dynamic_op_; + cp.cached_engine_batches = batches_; + cp.max_cached_engines = max_cached_batches_; + auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp); VLOG(2) << optimized_graph->DebugString(); + VLOG(1) << "Returning from " << name_; return status; } diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h index d8ecead23efaa5c3bab95b8ba481e2307b0af772..463ed3883e4808408104c618a289989472c497ea 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h @@ -61,6 +61,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { int minimum_segment_size_; int precision_mode_; int maximum_batch_size_; + bool is_dynamic_op_; + std::vector batches_; + int max_cached_batches_; int64_t maximum_workspace_size_; }; diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/contrib/tensorrt/convert/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f601c06701fdbf983b708cf5f5c7d22634bb810b --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/utils.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ + +#include + +namespace tensorflow { +namespace tensorrt { + +template +struct TrtDestroyer { + void operator()(T* t) { + if (t) t->destroy(); + } +}; + +template +using TrtUniquePtrType = std::unique_ptr>; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc deleted file mode 100644 index aea44fd8a2fcc4c359a6cb0c98ae34711708326e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/tensorrt/kernels/trt_calib_op.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/stream_executor.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { - -TRTCalibOp::TRTCalibOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("segment_nodes", &segment_nodes_)); - OP_REQUIRES_OK(context, context->GetAttr("input_names", &input_names_)); - OP_REQUIRES_OK(context, context->GetAttr("resource_name", &resource_name_)); -}; - -#define TYPECASE(dt, X, Y) \ - case dt: { \ - return (void*)X->flat::Type>().data(); \ - } - -void* GetTensorAddress(const Tensor* tensor_ptr) { - auto tensor_type = tensor_ptr->dtype(); - switch (tensor_type) { - TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); - TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); - TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); - default: { - LOG(FATAL) << "Unsupported Data type " - << tensorflow::DataTypeString(tensor_type); - return nullptr; - } - } -} - -void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) { - // TODO(aaroey): make sure ctx->resource_mgr() is used in future PR. - auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); - auto res_mgr = trt_rm->getManager("TRTCalibOps"); - tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; - auto status = res_mgr->Lookup(resource_name_, resource_name_, &calib_res); - - if (!status.ok()) { - ctx->SetStatus(status); - return; - } - int num_inputs = ctx->num_inputs(); - // first run instantiate calibrator - if (calib_res->calibrator_ == nullptr) { - dev_tensors_.resize(num_inputs); - int batch_size = ctx->input(0).dim_size(0); - VLOG(1) << " Constructing calibrator"; - for (int i = 0; i < num_inputs; i++) { - // allocate workspace on device for inputs - const tensorflow::Tensor& t = ctx->input(i); - OP_REQUIRES_OK(ctx, - ctx->allocate_persistent(t.dtype(), t.shape(), - &dev_tensors_.at(i), nullptr)); - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); - void* device_address = GetTensorAddress(device_tensor); - device_buffers_.emplace(input_names_.at(i), - std::pair( - device_address, device_tensor->TotalBytes())); - } - - calib_res->calibrator_ = - new TRTInt8Calibrator(device_buffers_, batch_size, resource_name_); - string label(resource_name_); - calib_res->thr_ = new std::thread([calib_res, label]() { - VLOG(1) << "Starting calibration thread, Calibration Resource @ " - << calib_res; - calib_res->builder_->setInt8Calibrator(calib_res->calibrator_); - calib_res->builder_->setInt8Mode(true); - calib_res->engine_ = calib_res->builder_->buildCudaEngine( - *calib_res->network_); // will loop until we terminate calibrator - VLOG(1) << "Calibration loop terminated " << label; - }); - VLOG(1) << "initialized calibrator resource"; - } // calibrator initialized - - // Pass input data to calibrator - std::unordered_map input_data; - for (int i = 0; i < num_inputs; i++) { - const Tensor& t = ctx->input(i); - void* data_address = GetTensorAddress(&t); - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), - device_tensor->TotalBytes()); // use the tensor so FW keeps it - input_data.emplace(input_names_.at(i), data_address); - ctx->set_output(i, t); - } - VLOG(2) << "Filled map for sending"; - // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files - const cudaStream_t* stream = CHECK_NOTNULL( - reinterpret_cast(ctx->op_device_context() - ->stream() - ->implementation() - ->CudaStreamMemberHack())); - calib_res->calibrator_->setBatch(input_data, *stream); - VLOG(2) << "Passed calibration data"; - // TODO(aaroey): make sure we wait for the completion of calibration on the - // last batch in future PR. -}; - -#undef TYPECASE - -REGISTER_KERNEL_BUILDER(Name("TRTCalibOp").Device(DEVICE_GPU), TRTCalibOp); - -} // namespace tensorrt -} // namespace tensorflow -#endif -#endif diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h deleted file mode 100644 index 23df9db32f077a080eaff7479fcbe90d6a504c42..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H -#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H - -#include -#include -#include -#include -#include -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/platform/types.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -namespace tensorflow { -namespace tensorrt { -// TODO(sami): Convert this to async kernel! -class TRTCalibOp : public OpKernel { - public: - explicit TRTCalibOp(OpKernelConstruction* context); - - void Compute(OpKernelContext* context) override; - - private: - string resource_name_; - std::vector segment_nodes_; - std::vector input_names_; - std::vector shapes_; - std::unordered_map> device_buffers_; - std::vector dev_tensors_; -}; -} // namespace tensorrt -} // namespace tensorflow -#endif -#endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 9ac8047944874181de228a6cc58e2dafe46abe50..8a17eb02f1af7c8f148c9cd4e14cc3876b6e13e3 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -14,8 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" +#include +#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" @@ -25,144 +33,556 @@ limitations under the License. #include "cuda/include/cuda_runtime_api.h" namespace tensorflow { -static ::tensorflow::tensorrt::Logger logger; -using IRuntime = nvinfer1::IRuntime; -using Dims = nvinfer1::Dims; - namespace tensorrt { +static Logger logger; +using ::nvinfer1::IRuntime; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +// A helper class to call done() when destructed for asynchronous execution. +// Helps simultaneous execution of native and TRT engines. +class AsyncHelper : public tensorflow::core::RefCounted { + public: + AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; } + ~AsyncHelper() override { done_(); } -TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { + private: + tensorflow::AsyncOpKernel::DoneCallback done_; +}; + +#define TYPECASE(dt, X, Y) \ + case dt: { \ + return (void*)X->flat::Type>().data(); \ + } + +void* GetTensorAddress(const Tensor* tensor_ptr) { + auto tensor_type = tensor_ptr->dtype(); + switch (tensor_type) { + TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); + default: { + LOG(ERROR) << "Unsupported Data type " + << tensorflow::DataTypeString(tensor_type); + return nullptr; + } + } +} + +tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) { + VLOG(1) << "Constructing function handle"; + auto lib = ctx->function_library(); + if (lib == nullptr) { + return tensorflow::errors::Internal("Context function library is null"); + } + auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_); + if (fdef == nullptr) { + return tensorflow::errors::Internal("Native FunctionDef ", funcdef_name_, + " can't be found in function library"); + } + tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops; + inst_ops.overlay_lib = nullptr; + inst_ops.state_handle = ""; + inst_ops.target = ctx->device()->name(); + native_func_ = 0; + auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), + inst_ops, &native_func_); + if (!status.ok()) { + LOG(ERROR) << " Instantiating native function " << funcdef_name_ + << " failed!"; + } + return status; +} + +TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) + : AsyncOpKernel(context) { // read serialized_engine OP_REQUIRES_OK(context, - context->GetAttr("serialized_engine", &serialized_engine_)); + context->GetAttr("serialized_segment", &serialized_segment_)); + OP_REQUIRES_OK(context, + context->GetAttr("workspace_size_bytes", &workspace_size_)); + OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_)); + if (!static_engine_) { + if (!segment_graph_.ParseFromString(serialized_segment_)) { + LOG(ERROR) << "Parsing segment graph failed!"; + context->SetStatus(tensorflow::errors::InvalidArgument( + "Failed to parse segment graphdef!")); + return; + } + serialized_segment_.resize(0); + } + VLOG(1) << "Constructing " << name(); + string precision_string; + OP_REQUIRES_OK(context, + context->GetAttr("precision_mode", &precision_string)); + string calibration_data; + OP_REQUIRES_OK(context, + context->GetAttr("calibration_data", &calibration_data)); + OP_REQUIRES_OK(context, + context->GetAttr("segment_funcdef_name", &funcdef_name_)); + if (precision_string == "FP32") { + precision_mode_ = convert::FP32MODE; + } else if (precision_string == "FP16") { + precision_mode_ = convert::FP16MODE; + } else if (precision_string == "INT8") { + precision_mode_ = convert::INT8MODE; + } + calibration_mode_ = + (precision_mode_ == convert::INT8MODE && calibration_data.size() == 0); + if (calibration_data.size()) { + calibrator_.reset(new TRTInt8Calibrator(calibration_data)); + calibration_data.resize(0); + } + native_func_ = tensorflow::kInvalidHandle; + OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", + &max_cached_engines_)); + OP_REQUIRES_OK(context, + context->GetAttr("fixed_input_size", &fixed_input_size_)); + OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches", + &cached_engine_batches_)); + std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end()); + if (VLOG_IS_ON(1)) { + string s("Engine Batches= "); + for (auto i : cached_engine_batches_) { + StrAppend(&s, i, " "); + } + VLOG(1) << s; + } +} - // register input output node name in trt_sub_graph - OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_)); - OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_)); +void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, + AsyncHelper* helper) { + if (!calibration_mode_) { + VLOG(1) << "Executing native engine"; + } + std::vector inputs; + std::vector* outputs = new std::vector(); + if (native_func_ == tensorflow::kInvalidHandle) { + auto status = ConstructFunctionHandle(ctx); + if (!status.ok()) { + LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_; + ctx->SetStatus(status); + return; + } + } + auto lib = ctx->function_library(); + tensorflow::FunctionLibraryRuntime::Options opts; + opts.step_id = ctx->step_id(); + opts.rendezvous = ctx->rendezvous(); + opts.cancellation_manager = ctx->cancellation_manager(); + opts.runner = ctx->runner(); + for (int i = 0; i < ctx->num_inputs(); i++) { + inputs.push_back(ctx->input(i)); + } + helper->Ref(); // Increment count for calculating native graph + VLOG(1) << "Executing native segment " << name(); + lib->Run(opts, native_func_, inputs, outputs, + [ctx, outputs, helper](const tensorflow::Status& s) { + tensorflow::core::ScopedUnref sc(helper); + VLOG(1) << "Native Segment completed"; + if (!s.ok()) { + ctx->SetStatus(s); + return; + } + for (size_t t = 0; t < outputs->size(); ++t) { + ctx->set_output(t, outputs->at(t)); + } + delete outputs; + }); } -void TRTEngineOp::Compute(OpKernelContext* context) { - // TODO(samikama) runtime should be taken from a resourcemanager as well. - // Only engine should be in the op and context and runtime should be taken - // from resourcemanager +void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, + AsyncHelper* helper) { + helper->Ref(); + tensorflow::core::ScopedUnref sc(helper); + // TODO(aaroey): remove the ResourceMgr singleton. + auto trt_rm = TRTResourceManager::instance(); + auto res_mgr = trt_rm->getManager("TRTCalibration"); + TRTCalibrationResource* calib_res = nullptr; + auto status = res_mgr->LookupOrCreate( + funcdef_name_, "Calibrator", &calib_res, + {[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status { + return this->AllocateCalibrationResources(ctx, cr); + }}); + if (!status.ok()) { + ctx->SetStatus(status); + return; + } + int num_inputs = ctx->num_inputs(); + // Pass input data to calibrator + std::unordered_map input_data; + for (int i = 0; i < num_inputs; i++) { + const Tensor& t = ctx->input(i); + void* data_address = GetTensorAddress(&t); + if (data_address == nullptr) { + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Unsupported data type encountered in input ", i)); + return; + } + // Check the allocated buffer is sufficient for input + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + input_data.emplace(StrCat(kInputPHName, i), data_address); + } + VLOG(2) << "Filled map for sending"; + // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast(ctx->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + calib_res->calibrator_->setBatch(input_data, *stream); + VLOG(2) << "Passed calibration data"; + ExecuteNativeSegment(ctx, helper); +} - if (!trt_execution_context_ptr_) { - IRuntime* infer = nvinfer1::createInferRuntime(logger); -#if NV_TENSORRT_MAJOR > 3 - auto device = context->device(); - auto dev_allocator = - device->GetAllocator(tensorflow::AllocatorAttributes()); - if (!dev_allocator) { - LOG(FATAL) << "Can't find device allocator for gpu device " - << device->name(); - } - allocator_ = std::make_shared(dev_allocator); - infer->setGpuAllocator(allocator_.get()); -#endif - trt_engine_ptr_.reset(infer->deserializeCudaEngine( - serialized_engine_.c_str(), serialized_engine_.size(), - PluginFactoryTensorRT::GetInstance())); - trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); - // Runtime is safe to delete after engine creation - infer->destroy(); - serialized_engine_.clear(); +int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) { + int num_batch = ctx->input(0).shape().dim_size(0); + int smallest_engine = 0; + for (const auto i : cached_engine_batches_) { + if (i >= num_batch) { + smallest_engine = i; + break; + } } - int num_binding = context->num_inputs() + context->num_outputs(); - std::vector buffers(num_binding); + // TODO(sami): Need an LRU here + if (smallest_engine == 0) { + if (max_cached_engines_ > cached_engine_batches_.size()) { + smallest_engine = num_batch; + cached_engine_batches_.push_back(num_batch); + VLOG(1) << "Running with batch size " << num_batch; + } else { + string s("Engine buffer is full. buffer limit= "); + StrAppend(&s, max_cached_engines_, ", current entries= "); + for (auto i : cached_engine_batches_) StrAppend(&s, i, ", "); + StrAppend(&s, "Requested batch= ", num_batch); + LOG(ERROR) << s; + ctx->SetStatus(tensorflow::errors::ResourceExhausted( + "Requested batch size is not available and engine cache is full")); + return -1; + } + } + return smallest_engine; +} - size_t binding_index; - int num_batch = 0; - for (int i = 0; i < context->num_inputs(); i++) { - // Grab the input tensor - binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str()); +void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, + tensorflow::AsyncOpKernel::DoneCallback done) { + auto helper = new AsyncHelper(done); + tensorflow::core::ScopedUnref sc(helper); + if (calibration_mode_) { + ExecuteCalibration(ctx, helper); + return; + } + const int smallest_engine = GetEngineBatch(ctx); + if (smallest_engine < 0) return; // GetEngineBatch already set the status. + + const int num_batch = ctx->input(0).shape().dim_size(0); + auto& engine_ctx_pair = GetEngine(smallest_engine, ctx); + auto& trt_engine_ptr = engine_ctx_pair.first; + if (!trt_engine_ptr) { + LOG(WARNING) << "Engine retrieval for batch size " << num_batch + << " failed Running native segment"; + ExecuteNativeSegment(ctx, helper); + return; + } - const Tensor& input_tensor = context->input(i); + const int num_binding = ctx->num_inputs() + ctx->num_outputs(); + std::vector buffers(num_binding); + for (int i = 0; i < ctx->num_inputs(); i++) { + const string inp_name = StrCat(kInputPHName, i); + const size_t binding_index = + trt_engine_ptr->getBindingIndex(inp_name.c_str()); + + const Tensor& input_tensor = ctx->input(i); const TensorShape& input_shape = input_tensor.shape(); - if (i == 0) { - num_batch = input_shape.dim_size(0); - if (num_batch > trt_engine_ptr_->getMaxBatchSize()) { - LOG(FATAL) << "input tensor batch larger than max_batch_size: " - << trt_engine_ptr_->getMaxBatchSize(); - } - } else if (num_batch != input_shape.dim_size(0)) { - LOG(FATAL) << "input data inconsistent batch size"; - break; + if (num_batch != input_shape.dim_size(0)) { + LOG(ERROR) << "input data inconsistent batch size"; + ctx->SetStatus(tensorflow::errors::FailedPrecondition( + "Different batch sizes between input tensors")); + return; } - auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); + auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = (void*)(input_tensor.flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(FATAL) << "half size is not supported yet!"; - break; + LOG(ERROR) << "FP16 inputs are not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "FP16 inputs are not supported!")); + return; case nvinfer1::DataType::kINT8: - LOG(FATAL) << "int8 is not supported yet!"; - break; + LOG(ERROR) << "INT8 inputs are not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "INT8 inputs are not supported!")); + return; default: - LOG(FATAL) << "Unknown data type: " << int(dtype); - break; + LOG(ERROR) << "Unknown TRT data type: " << int(dtype); + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Unknown output TRT data type! ", static_cast(dtype))); + return; } } - for (int i = 0; i < static_cast(output_nodes_.size()); i++) { - // This is bad that we have to reallocate output buffer every run. + for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor - binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str()); + const string output_name = StrCat(kOutputPHName, i); + const size_t binding_index = + trt_engine_ptr->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; TensorShape output_shape; if (binding_index != -1) { - auto dims = trt_engine_ptr_->getBindingDimensions(binding_index); + auto dims = trt_engine_ptr->getBindingDimensions(binding_index); std::vector trt_shape(dims.nbDims + 1); trt_shape[0] = num_batch; for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; - OP_REQUIRES_OK(context, - TensorShapeUtils::MakeShape( - trt_shape.data(), trt_shape.size(), &output_shape)); + OP_REQUIRES_OK( + ctx, TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(), + &output_shape)); } else { - LOG(FATAL) << "output node not found, at " << output_nodes_[i]; - break; + LOG(ERROR) << "output node not found, at " << output_name; + ctx->SetStatus(tensorflow::errors::Internal("output ", output_name, + " couldn't be found!")); + return; } - - OP_REQUIRES_OK(context, - context->allocate_output(i, output_shape, &output_tensor)); - auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); + auto status = ctx->allocate_output(i, output_shape, &output_tensor); + if (!status.ok()) { + LOG(ERROR) << "Allocating output failed with " << status; + ctx->SetStatus(status); + return; + } + auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = reinterpret_cast(output_tensor->flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(FATAL) << "half size is not supported yet!"; - break; + LOG(ERROR) << "half size is not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Half outputs are not supported!")); + return; case nvinfer1::DataType::kINT8: - LOG(FATAL) << "int8 is not supported yet!"; - break; + LOG(ERROR) << "int8 is not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "INT8 outputs are not supported!")); + return; default: - LOG(FATAL) << "Unknown data type: " << int(dtype); - break; + LOG(ERROR) << "Unknown TRT data type: " << static_cast(dtype); + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Unsupported output data type! ", static_cast(dtype))); + return; } } // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files const cudaStream_t* stream = CHECK_NOTNULL( - reinterpret_cast(context->op_device_context() + reinterpret_cast(ctx->op_device_context() ->stream() ->implementation() ->CudaStreamMemberHack())); // TODO(jie): trt enqueue does not return error - auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], - *stream, nullptr); - VLOG(2) << "enqueue returns: " << ret; + auto& trt_execution_context_ptr = engine_ctx_pair.second; + auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream, + nullptr); + if (!ret) { + LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name(); + ctx->SetStatus(tensorflow::errors::Internal( + "Failed to enqueue batch for TRT engine: ", name())); + } // sync should be done by TF. } + TRTEngineOp::~TRTEngineOp() { - // Order matters! - trt_execution_context_ptr_.reset(); - trt_engine_ptr_.reset(); + // We need to manually destroy the engine and execution context before + // the allocator is destructed. + for (auto& eng : engine_map_) { + eng.second.first.reset(); + eng.second.second.reset(); + } allocator_.reset(); } + +nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { + if (allocator_) return allocator_.get(); + auto device = ctx->device(); + auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(ERROR) << "Can't find device allocator for gpu device " + << device->name(); + ctx->SetStatus(tensorflow::errors::Internal( + "Can't get device allocator for device ", device->name())); + return nullptr; + } + allocator_.reset(new TRTDeviceAllocator(alloc)); + return allocator_.get(); +} + +TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, + OpKernelContext* ctx) { + static EngineCtxPair null_pair = { + TrtUniquePtrType(nullptr), + TrtUniquePtrType(nullptr)}; + // TODO(sami): This method needs to be re-written to use resource manager and + // with LRU mechanism option. + tensorflow::mutex_lock lock(engine_mutex_); + + if (static_engine_) { + if (engine_map_.size()) { + if (engine_map_.begin()->first >= batch_size) { + return engine_map_.begin()->second; + } + return null_pair; + } + TrtUniquePtrType infer(nvinfer1::createInferRuntime(logger)); +#if NV_TENSORRT_MAJOR > 3 + auto allocator = GetAllocator(ctx); + if (allocator == nullptr) { + // GetAllocator already set the Status. + return null_pair; + } + infer->setGpuAllocator(allocator); +#endif + TrtUniquePtrType static_engine( + infer->deserializeCudaEngine(serialized_segment_.c_str(), + serialized_segment_.size(), nullptr)); + auto raw_static_engine = static_engine.get(); + const auto max_batch_size = raw_static_engine->getMaxBatchSize(); + engine_map_[max_batch_size] = { + std::move(static_engine), + TrtUniquePtrType( + raw_static_engine->createExecutionContext())}; + // Runtime is safe to delete after engine creation + serialized_segment_.clear(); + if (max_batch_size < batch_size) return null_pair; + return engine_map_.at(max_batch_size); + } // static_engine_ + + // Handle the dynamic engine case. + auto engine_it = engine_map_.find(batch_size); + if (engine_it == engine_map_.end() && + engine_map_.size() < (size_t)max_cached_engines_) { + nvinfer1::IGpuAllocator* allocator = nullptr; +#if NV_TENSORRT_MAJOR > 3 + allocator = GetAllocator(ctx); + if (allocator == nullptr) { + // GetAllocator already set the Status. + return null_pair; + } +#endif + std::vector shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + shapes.emplace_back(ctx->input(i).shape()); + } + TrtUniquePtrType engine; + bool convert_successfully = false; + VLOG(0) << name() << " Constructing a new engine with batch size " + << batch_size; + // Up to this point, calibrator_ can never be empty, since otherwise it + // means calibration_mode_ is true and this path won't get executed. + auto status = convert::ConvertGraphDefToEngine( + segment_graph_, precision_mode_, batch_size, workspace_size_, shapes, + &logger, allocator, calibrator_.get(), &engine, &convert_successfully); + if (!status.ok()) { + if (convert_successfully) { + // This means it fail to build the engine even when the network is built + // successfully, probably due to internal issues. In this case we don't + // retry in the future. + engine_map_[batch_size] = {nullptr, nullptr}; + } + LOG(ERROR) << "Engine creation for batch size " << batch_size + << " failed " << status; + ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!")); + return null_pair; + } + VLOG(1) << "Conversion is done"; + TrtUniquePtrType exec_context( + engine->createExecutionContext()); + engine_map_[batch_size] = {std::move(engine), std::move(exec_context)}; + } + return engine_map_.at(batch_size); +} + +tensorflow::Status TRTEngineOp::AllocateCalibrationResources( + tensorflow::OpKernelContext* ctx, TRTCalibrationResource** cr) { + auto cres = new TRTCalibrationResource(); + *cr = cres; + // Get the allocator. + auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(WARNING) << "Can't get device allocator will not be able to " + "allocate memory from TensorFlow memory pool"; + cres->allocator_.reset(new TRTCudaAllocator); + } else { + cres->allocator_.reset(new TRTDeviceAllocator(alloc)); + } + // Get the input shapes. + const int batch_size = ctx->input(0).dim_size(0); + const int num_inputs = ctx->num_inputs(); + std::vector shapes; + dev_tensors_.resize(num_inputs); + VLOG(1) << " Constructing calibrator"; + for (int i = 0; i < num_inputs; i++) { + // allocate workspace on device for inputs + const tensorflow::Tensor& t = ctx->input(i); + shapes.emplace_back(t.shape()); + Tensor* device_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor)); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + void* device_address = GetTensorAddress(device_tensor); + if (device_address == nullptr) { + return tensorflow::errors::InvalidArgument( + "Unsupported data type encountered in input ", i); + } + device_buffers_.emplace( + StrCat(kInputPHName, i), + std::pair(device_address, device_tensor->TotalBytes())); + } + cres->calibrator_.reset( + new TRTInt8Calibrator(device_buffers_, batch_size, name())); + const string label(name()); + auto segment_graph = &segment_graph_; + const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id; + if (cuda_gpu_id < 0) { + LOG(ERROR) << "Can't get gpu_device_info from context->device()"; + return tensorflow::errors::InvalidArgument( + "Context->device doesn't contain device info!"); + } + const int64 workspace_size_bytes = workspace_size_; + cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes, + cuda_gpu_id, workspace_size_bytes]() { + VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id + << ", Calibration Resource @ " << cres; + auto err = cudaSetDevice(cuda_gpu_id); + if (err != cudaSuccess) { + // TODO(aaroey): should return error here. + LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id + << " in calibration thread"; + } + // ConvertGraphDefToEngine() will try to build the engine. This thread + // will loop inside buildCudaEngine() consuming the calibration data + // that is set by the TF op, and drive the builder until calibrator returns + // false. Engine is discarded after calibration table is generated + // + // TODO(aaroey): maybe setting the max batch size using the python + // calibration wrapper class. + auto s = convert::ConvertGraphDefToEngine( + *segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(), + workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), + cres->calibrator_.get(), &cres->engine_, + /*convert_successfully=*/nullptr); + if (!s.ok()) { + LOG(ERROR) << "Calibration failed: " << s; + cres->calibrator_->setDone(); // Ignore further pushes + } + VLOG(1) << "Calibration loop terminated " << label; + })); + VLOG(1) << "initialized calibrator resource"; + return tensorflow::Status::OK(); +} + REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp); } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index e613a71422852e60565ba7554516d7eace6b9cc7..6fe318be6a6bc9f01ce3b52e0430f2090b53002b 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -19,9 +19,14 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/mutex.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -30,32 +35,95 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -class Logger; - +class TRTInt8Calibrator; +class TRTCalibrationResource; +class AsyncHelper; // TODO(Sami): Remove this file? -class TRTEngineOp : public OpKernel { + +// This OP can construct TRTEngine on the fly and if construction of engine +// fails, executes equivalent subgraph as a TensorFlow function. +class TRTEngineOp : public AsyncOpKernel { public: explicit TRTEngineOp(OpKernelConstruction* context); - void Compute(OpKernelContext* context) override; + void ComputeAsync(OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; ~TRTEngineOp(); private: - template - struct Destroyer { - void operator()(T* d) { d->destroy(); } - }; - - template - using destroyed_ptr = std::unique_ptr>; - destroyed_ptr trt_engine_ptr_; + // Execute calibration + void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); + + // Construct a function handle for executing native funcdef graph + Status ConstructFunctionHandle(OpKernelContext* ctx); + + // Execute replaced native segment as function Op. + void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); + + // Allocate necessary resources for calibration + Status AllocateCalibrationResources(OpKernelContext* ctx, + TRTCalibrationResource** cr); + // TODO(samikama): context should go to a resource manager! - destroyed_ptr trt_execution_context_ptr_; + typedef std::pair, + TrtUniquePtrType> + EngineCtxPair; + EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx); + // Return engine batch closest to input batch. + int GetEngineBatch(OpKernelContext* ctx); + + nvinfer1::IGpuAllocator* GetAllocator(OpKernelContext* ctx); + + // map to keep engines and their execution context for given batch size. + std::unordered_map engine_map_; std::vector input_nodes_; std::vector output_nodes_; - std::shared_ptr allocator_; - string serialized_engine_; + + // keep device allocator for TRT. + std::unique_ptr allocator_; + + // serialized protobuf segment or trt engine depending on static_engine_ flag. + string serialized_segment_; + + // Name of the function for TF native execution of the segment. + string funcdef_name_; + + // GraphDef representation of the segment. + GraphDef segment_graph_; + + // Lookup table for temporary staging areas of input tensors for calibration. + std::unordered_map> device_buffers_; + + // Temporary staging areas for calibration inputs. + std::vector dev_tensors_; + + // Engine Precision mode. + int precision_mode_; + + // Whether engine is constructed during the conversion or needs to be + // constructed from protobuf segment. + bool static_engine_; + + // Whether to calibrate INT8 engine. + bool calibration_mode_; + + // Whether non-batch ranks of the inputs are assumed to be fixed or not for + // engine construction. + bool fixed_input_size_; + + // Batches of the cached engines + std::vector cached_engine_batches_; + + // Maximum number of cached engines + int max_cached_engines_; + + int64 workspace_size_; + mutex engine_mutex_; + FunctionLibraryRuntime::Handle native_func_; + + // The finalized calibrator for inference. + std::unique_ptr calibrator_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc deleted file mode 100644 index 4835e5065068ec7a59995eb7f6126b31aecf6704..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -namespace tensorflow { - -REGISTER_OP("TRTCalibOp") - .Attr("segment_nodes: list(string)") // names of the ops in segment - .Attr("segment_output_names: list(string)") // names of the output ops in - // segment - .Attr("input_names: list(string)") // names of the inputs for - // passing into tensorrt - .Attr("resource_name: string") - .Attr("InT: list({int8, float16, float32})") - .Input("in_tensor: InT") - .Output("out_tensor: InT") - .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { - for (int i = 0; i < c->num_inputs(); i++) { - c->set_output(i, c->input(i)); - } - return Status::OK(); - }); - -} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc index 079d73f7bec3f9a9740e455b31a259cec287f849..383635f428812984915a8c46ad3b92cc7b28a5f7 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -28,11 +28,19 @@ extern Status TRTEngineOpShapeInference(InferenceContext* c); } REGISTER_OP("TRTEngineOp") - .Attr("serialized_engine: string") - .Attr("input_nodes: list(string)") - .Attr("output_nodes: list(string)") - .Attr("InT: list({float32})") - .Attr("OutT: list({float32})") + .Attr("serialized_segment: string") + .Attr("input_shapes: list(shape)") + .Attr("output_shapes: list(shape)") + .Attr("segment_funcdef_name: string") + .Attr("InT: list({int8,float16,float32})") + .Attr("OutT: list({int8,float16,float32})") + .Attr("static_engine: bool = true") + .Attr("fixed_input_size: bool = true") + .Attr("cached_engine_batches: list(int) = []") + .Attr("max_cached_engines_count: int = 1") + .Attr("workspace_size_bytes: int") + .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}") + .Attr("calibration_data: string = ''") .Input("in_tensor: InT") .Output("out_tensor: OutT") .SetShapeFn(shape_inference::TRTEngineOpShapeInference); diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 338475d90ea55ab2c1bb8df77f27a71a4a36a5dd..79f512dbcf6bd4d84b98cf69630778734566391c 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -21,6 +21,8 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long import six as _six from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert +from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version +from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -29,7 +31,9 @@ from tensorflow.python.framework import errors_impl as _impl from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.platform import tf_logging from tensorflow.python.util import compat + # pylint: enable=unused-import,line-too-long @@ -40,7 +44,10 @@ def create_inference_graph(input_graph_def, max_batch_size=1, max_workspace_size_bytes=2 << 20, precision_mode="FP32", - minimum_segment_size=3): + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]): """Python wrapper for the TRT transformation. Args: @@ -51,6 +58,10 @@ def create_inference_graph(input_graph_def, precision_mode: one of 'FP32', 'FP16' and 'INT8' minimum_segment_size: the minimum number of nodes required for a subgraph to be replaced by TRTEngineOp. + is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT + network and engine at run time. + maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. + cached_engine_batches: batch sizes used to pre-create cached engines. Returns: New GraphDef with TRTEngineOps placed in graph replacing subgraphs. @@ -65,6 +76,30 @@ def create_inference_graph(input_graph_def, "It should be one of {}").format( precision_mode, "{'FP32', 'FP16', 'INT8'}")) mode = supported_precision_modes[precision_mode.upper()] + compiled_version = get_linked_tensorrt_version() + loaded_version = get_loaded_tensorrt_version() + version_mismatch = False + if loaded_version[0] < compiled_version[0]: + tf_logging.error( + "TensorRT version mismatch. Tensorflow was compiled against " + + "TensorRT %s but library loaded from environment is TensorRT %s" % + (".".join([str(x) for x in compiled_version]), + ".".join([str(x) for x in loaded_version])) + + ". Please make sure that correct version of TensorRT " + + "is available in the system and added to ldconfig or LD_LIBRARY_PATH" + ) + raise RuntimeError("Incompatible TensorRT library version") + for i in zip(loaded_version, compiled_version): + if i[0] != i[1]: + tf_logging.warn("TensorRT mismatch. Compiled against version " + + "%s, but loaded %s. Things may not work" % + (".".join([str(x) for x in compiled_version]), + ".".join([str(x) for x in loaded_version]))) + version_mismatch = True + break + if not version_mismatch: + tf_logging.info("Running against TensorRT version %s" % ".".join( + [str(x) for x in loaded_version])) def py2bytes(inp): return inp @@ -100,7 +135,9 @@ def create_inference_graph(input_graph_def, # pair or strings where first one is encoded status and the second # one is the transformed graphs protobuf string. out = trt_convert(input_graph_def_str, out_names, max_batch_size, - max_workspace_size_bytes, mode, minimum_segment_size) + max_workspace_size_bytes, mode, minimum_segment_size, + is_dynamic_op, maximum_cached_engines, + cached_engine_batches) status = to_string(out[0]) output_graph_def_string = out[1] del input_graph_def_str # Save some memory @@ -120,11 +157,12 @@ def create_inference_graph(input_graph_def, return output_graph_def -def calib_graph_to_infer_graph(calibration_graph_def): +def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): """Convert an existing calibration graph to inference graph. Args: calibration_graph_def: the calibration GraphDef object with calibration data + is_dynamic_op: whether to create dynamic static engines from calibration Returns: New GraphDef with TRTEngineOps placed in graph replacing calibration nodes. Raises: @@ -141,9 +179,16 @@ def calib_graph_to_infer_graph(calibration_graph_def): to_string = py2string else: to_string = py3string - + is_calib_graph = False + for n in calibration_graph_def.node: + if n.op == "TRTEngineOp": + is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s + if not is_calib_graph: + tf_logging.error( + "Not a calib graph. Doesn't seem to contain any calibration nodes.") + return None graph_str = calibration_graph_def.SerializeToString() - out = calib_convert(graph_str) + out = calib_convert(graph_str, is_dynamic_op) status = to_string(out[0]) output_graph_def_string = out[1] del graph_str # Save some memory diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc index 0f0508331c13055096714352e83fc360f0ef39b4..9f115990c3a3e6e92093e5f0d82b985af1b25482 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc @@ -50,7 +50,7 @@ TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator) } void TRTDeviceAllocator::free(void* memory) { - VLOG(2) << "Deallocating " << memory; + VLOG(2) << "Deallocating @ " << memory; allocator_->DeallocateRaw(memory); } diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h index a0c2540a7698bc46a65dbd967412351bac2a4dd2..c5d2cec730f4ae97e4c6bcc19897fd9f321122a7 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ #define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ - #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/core/framework/allocator.h" @@ -52,7 +51,9 @@ class TRTDeviceAllocator : public nvinfer1::IGpuAllocator { // Allocator implementation wrapping TF device allocators. public: TRTDeviceAllocator(tensorflow::Allocator* allocator); - virtual ~TRTDeviceAllocator() {} + virtual ~TRTDeviceAllocator() { + VLOG(1) << "Destroying allocator attached to " << allocator_->Name(); + } void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) override; void free(void* memory) override; diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc index dc7c93f869f5ef7c8eaa2a87eed26cfe69597fdb..dab1dd9343be7d5b033a3e04bf0b49fbbf37e9e5 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include -#include #include #include "tensorflow/core/platform/logging.h" @@ -37,20 +36,29 @@ TRTInt8Calibrator::TRTInt8Calibrator( : batch_size_(batch_size), done_(false), dev_buffers_(dev_buffers), - calib_running_(false), + // Make sure setBatch() waits until getBatch() is called (the first time). + calib_running_(true), batch_is_set_(false), engine_name_(engine_name) {} +TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) + : batch_size_(0), + done_(true), + calib_running_(false), + batch_is_set_(false), + calibration_table_(calib_data) {} + bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, const cudaStream_t stream) { tensorflow::mutex_lock lock(cond_mtx_); - while ((calib_running_ || batch_is_set_) && - !done_) { // wait while calibration is running - cond_.wait(lock); - } + + // Wait while the queue is full or calibration is running. + while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock); if (done_) return false; CHECK(!calib_running_ && !batch_is_set_); VLOG(1) << "Set Batch Waiting finished"; + + // Sets the batch. for (const auto it : data) { auto devptr = dev_buffers_.find(it.first); if (devptr == dev_buffers_.end()) { @@ -59,8 +67,6 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, } const auto& d = devptr->second; - // TODO(aaroey): we should not use sync copy on default stream. Make sure - // stream->ThenMemcpy() is used in future PRs. // TODO(sami,aaroey): Need to figure out a way to ensure synchronization // between stream, perhaps using a tensor? auto status = cudaMemcpyAsync(d.first, it.second, d.second, @@ -72,8 +78,8 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, } // TODO(Sami, aaorey): Find an alternative way! - cudaStreamSynchronize( - stream); // we have to wait for the stream before returning! + // we have to wait for the stream before returning! + cudaStreamSynchronize(stream); batch_is_set_ = true; cond_.notify_all(); return true; @@ -82,23 +88,21 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, int num_bindings) { tensorflow::mutex_lock lock(cond_mtx_); + // Notify finish of last round of calibration. calib_running_ = false; cond_.notify_all(); - while ((!batch_is_set_ && !done_)) { // wait until new batch arrives - cond_.wait(lock); - } - if (done_) { - return false; - } + // Wait until new batch arrives + while ((!batch_is_set_ && !done_)) cond_.wait(lock); + if (done_) return false; + // Gets the batch for (int i = 0; i < num_bindings; i++) { auto it = dev_buffers_.find(names[i]); if (it == dev_buffers_.end()) { LOG(FATAL) << "Calibration engine asked for unknown tensor name '" << names[i] << "' at position " << i; } - bindings[i] = it->second.first; } batch_is_set_ = false; @@ -106,8 +110,21 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, return true; } +void TRTInt8Calibrator::waitAndSetDone() { + tensorflow::mutex_lock lock(cond_mtx_); + // Wait while the queue is full or calibration is running, so we don't miss + // the last batch. + while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock); + if (!done_) { + done_ = true; + cond_.notify_all(); + } +} + const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { - return nullptr; + if (calibration_table_.empty()) return nullptr; + length = calibration_table_.size(); + return calibration_table_.data(); } void TRTInt8Calibrator::setDone() { @@ -117,7 +134,11 @@ void TRTInt8Calibrator::setDone() { } void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, - std::size_t length) {} + std::size_t length) { + calibration_table_ = string((const char*)ptr, length); + VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr + << " length=" << length; +} TRTInt8Calibrator::~TRTInt8Calibrator() { VLOG(1) << "Destroying calibrator for " << engine_name_; } diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h index d77aa2c5ab184756adaee38f88180b3c128ebe03..65466c9741989fda5f82fc27d813d026f35fe386 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -36,32 +36,59 @@ namespace tensorrt { struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { public: + // Construct a calibrator for future calibration. TRTInt8Calibrator( const std::unordered_map>& dev_buffers, int batch_size, string engine_name); + + // Construct a finalized calibrator where we don't need to run calibration any + // more, as the calibration data is provided. + TRTInt8Calibrator(const string& calibration_data); + + ~TRTInt8Calibrator(); + int getBatchSize() const override; + bool getBatch(void* bindings[], const char* names[], int num_bindings) override; + bool setBatch(const std::unordered_map& data, const cudaStream_t stream); + + // Wait until the last batch is consumed by the calibrator and set done. + void waitAndSetDone(); + + // Notify that calibration is done and future batches provided by setBatch() + // will be ignored. void setDone(); + + // If not null, calibration is skipped. const void* readCalibrationCache(std::size_t& length) override; + void writeCalibrationCache(const void* ptr, std::size_t length) override; - ~TRTInt8Calibrator(); + + const string& getCalibrationTableAsString() { return calibration_table_; } private: const int batch_size_; - tensorflow::mutex cond_mtx_; // mutex for condition_variable - tensorflow::condition_variable cond_; // condition variable to implement - // producer-consumer queue for - // calibration + + // mutex for condition_variable + tensorflow::mutex cond_mtx_; + + // condition variable to implement producer-consumer queue for calibration + tensorflow::condition_variable cond_; + + // Is calibration finished? bool done_; - const std::unordered_map> - dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with - // buffer names + + // Map to keep tensorrt input buffers and sizes keyed with buffer names + const std::unordered_map> dev_buffers_; + bool calib_running_; bool batch_is_set_; + string engine_name_; + string calibration_table_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index e3469124acd4b9f6f4dd81b9998aa60bfe469b35..b7d5ffd6748ba34c6c4ddbfbfbb44edb6bf2aca8 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" @@ -34,50 +35,48 @@ limitations under the License. namespace tensorflow { namespace tensorrt { + class TRTCalibrationResource : public tensorflow::ResourceBase { public: - TRTCalibrationResource() - : calibrator_(nullptr), - builder_(nullptr), - network_(nullptr), - engine_(nullptr), - logger_(nullptr), - thr_(nullptr) {} - ~TRTCalibrationResource() { VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + builder_.reset(); + engine_.reset(); + // We need to manually destroy the builder and engine before the allocator + // is destroyed. + allocator_.reset(); } string DebugString() override { std::stringstream oss; - oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl - << " Builder = " << std::hex << builder_ << std::dec << std::endl - << " Network = " << std::hex << network_ << std::dec << std::endl - << " Engine = " << std::hex << engine_ << std::dec << std::endl - << " Logger = " << std::hex << logger_ << std::dec << std::endl - << " Allocator = " << std::hex << allocator_.get() << std::dec - << std::endl - << " Thread = " << std::hex << thr_ << std::dec << std::endl; + using std::dec; + using std::endl; + using std::hex; + oss << " Calibrator = " << hex << calibrator_.get() << dec << endl + << " Builder = " << hex << builder_.get() << dec << endl + << " Engine = " << hex << engine_.get() << dec << endl + << " Logger = " << hex << &logger_ << dec << endl + << " Allocator = " << hex << allocator_.get() << dec << endl + << " Thread = " << hex << thr_.get() << dec << endl; return oss.str(); } - TRTInt8Calibrator* calibrator_; - nvinfer1::IBuilder* builder_; - nvinfer1::INetworkDefinition* network_; - nvinfer1::ICudaEngine* engine_; - std::shared_ptr allocator_; - tensorflow::tensorrt::Logger* logger_; + std::unique_ptr calibrator_; + TrtUniquePtrType builder_; + TrtUniquePtrType engine_; + std::unique_ptr allocator_; + tensorflow::tensorrt::Logger logger_; // TODO(sami): Use threadpool threads! - std::thread* thr_; + std::unique_ptr thr_; }; -class TRTWeightStore : public tensorflow::ResourceBase { +class TRTWeightStore { public: TRTWeightStore() {} virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); } - string DebugString() override { + string DebugString() { std::stringstream oss; size_t len_bytes = 0; for (const auto& v : store_) { diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 1568dd915344e6ba982b5a5550cc5386e047ff9f..81b4bfe49fe375d19f4c7811459f38e25d2edea8 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -29,8 +29,9 @@ namespace tensorflow { namespace tensorrt { namespace segment { -// vector of segments, each entry contains a device name and a set of nodes in -// segment +// Vector of segments, each entry contains a set of node names and a device name +// in the segment. +// TODO(aaroey): use node pointer instead of node name. using SegmentNodesVector = std::vector, string>>; struct SegmentOptions { @@ -48,6 +49,8 @@ struct SegmentOptions { // in the vector describes a subgraph by giving a set of the names of // all the NodeDefs in that subgraph. // @return the status. +// +// TODO(aaroey): remove this method. tensorflow::Status SegmentGraph( const tensorflow::GraphDef& gdef, const std::function& candidate_fn, diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 2de3923b06a8ddf89c7e6f922138a85f55a618d6..f5b2d258d70d5577a9d68f2d9f6d6e678ede97ce 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -275,13 +275,13 @@ TEST_F(SegmentTest, Multiple) { // Expect two subgraphs EXPECT_EQ(segments.size(), 2); - std::vector expected0{"add0", "add1", "add2", "add3"}; + std::vector expected0{"add6", "add8"}; for (const auto& ex : expected0) { EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end()) << "Missing expected node " << ex; } - std::vector expected1{"add6", "add8"}; + std::vector expected1{"add0", "add1", "add2", "add3"}; for (const auto& ex : expected1) { EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end()) << "Missing expected node " << ex; diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index f36495f6b69ecb2f2a8d730b9ae4919fea3c04b8..227ac120dde8c986379c687987cd1bd822d559f7 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -29,61 +29,35 @@ namespace tensorflow { namespace shape_inference { tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { - tensorflow::tensorrt::Logger logger; - string serialized_engine; - TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine)); - nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); - nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), - tensorrt::PluginFactoryTensorRT::GetInstance()); - - int num_batch = -1; - std::vector<::tensorflow::DataType> input_type; - TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type)); - for (size_t i = 0; i < context->num_inputs(); i++) { - // Check if input shape is legit - auto input_shape = context->input(i); - for (int j = 0; j < context->Rank(input_shape); j++) { - auto dim_handler = context->Dim(input_shape, j); - if (j == 0) { - if (i == 0) { - num_batch = context->Value(dim_handler); - } else if (num_batch != context->Value(dim_handler)) { - // TODO(jie): TensorRT engine requires consistent batch between inputs - // tensors. Segmenter should be aware of this. - LOG(FATAL) << "TensorRT engine requires consistent batch size"; - } - } - } + std::vector shapes; + for (int i = 0; i < context->num_outputs(); ++i) { + context->set_output(i, context->UnknownShape()); } - - // Arrange input here - std::vector input_nodes; - TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes)); - - // Arrange output here - std::vector output_nodes; - TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes)); - for (size_t i = 0; i < output_nodes.size(); i++) { - int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str()); - ShapeHandle output_shape; - std::vector dim_vec; - dim_vec.emplace_back(context->MakeDim(num_batch)); - if (binding_index != -1) { - auto dims = trt_engine->getBindingDimensions(binding_index); - for (int j = 0; j < dims.nbDims; j++) { - dim_vec.emplace_back(context->MakeDim(dims.d[j])); - } - } else { - LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i]; - } - output_shape = context->MakeShape(dim_vec); - context->set_output(i, output_shape); + auto status = context->GetAttr("input_shapes", &shapes); + // it is ok to not to have shapes + if (!status.ok()) return Status::OK(); + if ((int)shapes.size() != context->num_inputs()) return Status::OK(); + bool different_input = false; + for (int i = 0; i < context->num_inputs(); ++i) { + if (shapes.at(i) != context->input_tensor(i)->shape()) + different_input = true; + } + if (different_input) return Status::OK(); + shapes.resize(0); + status = context->GetAttr("output_shapes", &shapes); + if (!status.ok()) return Status::OK(); + if ((int)shapes.size() != context->num_outputs()) return Status::OK(); + std::vector shape_handles(shapes.size()); + for (size_t i = 0; i < shapes.size(); ++i) { + status = + context->MakeShapeFromTensorShape(shapes.at(i), &shape_handles.at(i)); + if (!status.ok()) return Status::OK(); + } + for (int i = 0; i < context->num_outputs(); ++i) { + context->set_output(i, shape_handles.at(i)); } - return Status::OK(); } - } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index 175ccd800686255092e241aa59568df407d6eebc..090aa8bdb0487973e186631af3b4edac48096a5f 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -20,6 +20,7 @@ from __future__ import print_function import argparse import numpy as np +import six as _six # normally we should do import tensorflow as tf and then # tf.placeholder, tf.constant, tf.nn.conv2d etc but @@ -35,10 +36,75 @@ from tensorflow.python.framework import dtypes as dtypes from tensorflow.python.framework import importer as importer from tensorflow.python.framework import ops as ops from tensorflow.python.ops import array_ops as aops +from tensorflow.python.ops import math_ops as mops from tensorflow.python.ops import nn as nn from tensorflow.python.ops import nn_ops as nn_ops +def py2bytes(inp): + return inp + + +def py3bytes(inp): + return inp.encode("utf-8", errors="surrogateescape") + + +def py2string(inp): + return inp + + +def py3string(inp): + return inp.decode("utf-8") + + +if _six.PY2: + to_bytes = py2bytes + to_string = py2string +else: + to_bytes = py3bytes + to_string = py3string + + +def get_multi_engine_graph_def(mode="FP32"): + """Create a simple graph and return its graph_def.""" + dtype = dtypes.float32 + if mode.upper() == "FP16": + dtype = dtypes.float16 + else: + pass + + g = ops.Graph() + with g.as_default(): + x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype) + with g.name_scope("Global_scope"): + with g.name_scope("first_scope"): + e = cop.constant( + np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype) + conv = nn.conv2d( + input=x, + filter=e, + data_format="NCHW", + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype) + t = conv * b + + b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype) + q = conv / b + edge = mops.sin(q) + edge1 = mops.cos(conv) + with g.name_scope("test_scope"): + de = edge + edge1 + t -= edge1 + q *= edge + t += q + t -= de + k = aops.squeeze(t, name="output") + print(k.dtype) + return g.as_graph_def() + + def get_simple_graph_def(): """Create a simple graph and return its graph_def.""" g = ops.Graph() @@ -65,7 +131,9 @@ def get_simple_graph_def(): def execute_graph(gdef, dumm_inp): """Run given graphdef once.""" print("executing") - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) sessconfig = cpb2.ConfigProto(gpu_options=gpu_options) ops.reset_default_graph() g = ops.Graph() @@ -83,7 +151,9 @@ def execute_graph(gdef, dumm_inp): # for calibration. For this test script it is random data. def execute_calibration(gdef, dumm_inp): """Run given calibration graph multiple times.""" - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) ops.reset_default_graph() g = ops.Graph() with g.as_default(): @@ -100,12 +170,17 @@ def execute_calibration(gdef, dumm_inp): return val -def user(run_graph=execute_graph, run_calibration=execute_calibration): +def user(multi_engine, + run_graph=execute_graph, + run_calibration=execute_calibration): """Example function that converts a graph to TFTRT graph.""" - - inp_dims = (100, 24, 24, 2) + if multi_engine: + inp_dims = (2, 3, 7, 5) + orig_graph = get_multi_engine_graph_def() + else: + inp_dims = (100, 24, 24, 2) + orig_graph = get_simple_graph_def() # use a frozen graph for inference dummy_input = np.random.random_sample(inp_dims) - orig_graph = get_simple_graph_def() # use a frozen graph for inference # Get optimized graph trt_graph = trt.create_inference_graph( input_graph_def=orig_graph, @@ -113,8 +188,10 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration): max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) o1 = run_graph(orig_graph, dummy_input) o2 = run_graph(trt_graph, dummy_input) o3 = run_graph(trt_graph, dummy_input) @@ -126,40 +203,51 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration): max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) int8_calib_gdef = trt.create_inference_graph( input_graph_def=orig_graph, outputs=["output"], max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) o4 = run_graph(fp16_graph, dummy_input) _ = run_calibration(int8_calib_gdef, dummy_input) int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) o5 = run_graph(int8_graph, dummy_input) - assert np.allclose(o1, o4) - assert np.allclose(o1, o5) + print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4)) + print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5)) print("Pass") -def auto(): +def auto(multi_engine): """Run the conversion as an optimization pass.""" - inp_dims = (100, 24, 24, 2) + if multi_engine: + inp_dims = (2, 3, 7, 5) + orig_graph = get_multi_engine_graph_def() + else: + inp_dims = (100, 24, 24, 2) + orig_graph = get_simple_graph_def() # use a frozen graph for inference dummy_input = np.random.random_sample(inp_dims) - orig_graph = get_simple_graph_def() opt_config = rwpb2.RewriterConfig() + opt_config.meta_optimizer_iterations = opt_config.ONE opt_config.optimizers.extend(["constfold", "layout"]) custom_op = opt_config.custom_optimizers.add() custom_op.name = "TensorRTOptimizer" custom_op.parameter_map["minimum_segment_size"].i = 3 - custom_op.parameter_map["precision_mode"].s = "FP32" + custom_op.parameter_map["precision_mode"].s = to_bytes("FP32") custom_op.parameter_map["max_batch_size"].i = inp_dims[0] custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 print(custom_op) - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) graph_options = cpb2.GraphOptions(rewrite_options=opt_config) sessconfig = cpb2.ConfigProto( gpu_options=gpu_options, graph_options=graph_options) @@ -168,7 +256,7 @@ def auto(): ops.reset_default_graph() with g.as_default(): inp, out = importer.import_graph_def( - graph_def=orig_graph, return_elements=["input", "output"]) + graph_def=orig_graph, return_elements=["input", "output"], name="") inp = inp.outputs[0] out = out.outputs[0] with csess.Session(config=sessconfig, graph=g) as sess: @@ -186,8 +274,14 @@ if "__main__" in __name__: action="store_true", help="Do TRT conversion automatically", default=False) + P.add_argument( + "--multi-engine", + "-m", + action="store_true", + help="Use a graph that will result in 2 engines", + default=False) flags, unparsed = P.parse_known_args() if flags.automatic: - auto() + auto(flags.multi_engine) else: - user() + user(flags.multi_engine) diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py index 0403b652d72877196c3537a3181529aeeb997395..d9c41f90d0ab111b48c37aeaae5f0ce3177646c2 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py @@ -18,131 +18,330 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import namedtuple +import itertools import warnings import numpy as np +import six from tensorflow.contrib import tensorrt as trt -from tensorflow.core.protobuf import config_pb2 as cpb2 -from tensorflow.python.framework import constant_op as cop -from tensorflow.python.framework import dtypes as dtypes -from tensorflow.python.framework import importer as importer -from tensorflow.python.framework import ops as ops +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops as aops -from tensorflow.python.ops import nn as nn -from tensorflow.python.ops import nn_ops as nn_ops -from tensorflow.python.platform import googletest +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test +INPUT_NAME = "input" +OUTPUT_NAME = "output" +INPUT_DIMS = [100, 24, 24, 2] +MODE_FP32 = "FP32" +MODE_FP16 = "FP16" +MODE_INT8 = "INT8" -class IntegrationTest(test_util.TensorFlowTestCase): +if six.PY2: + to_bytes = lambda s: s + to_string = lambda s: s +else: + to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape") + to_string = lambda s: s.decode("utf-8") + + +# TODO(aaroey): test graph with different dtypes. +def GetSingleEngineGraphDef(dtype=dtypes.float32): + """Create a graph containing single segment.""" + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME) + with g.device("/GPU:0"): + conv_filter = constant_op.constant( + [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], + name="weights", + dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=conv_filter, + strides=[1, 2, 2, 1], + padding="SAME", + name="conv") + bias = constant_op.constant( + [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype) + added = nn.bias_add(conv, bias, name="bias_add") + relu = nn.relu(added, "relu") + identity = array_ops.identity(relu, "identity") + pool = nn_ops.max_pool( + identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + array_ops.squeeze(pool, name=OUTPUT_NAME) + return g.as_graph_def() + + +# TODO(aaroey): test graph with different dtypes. +def GetMultiEngineGraphDef(dtype=dtypes.float32): + """Create a graph containing multiple segment.""" + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME) + with g.device("/GPU:0"): + conv_filter = constant_op.constant( + [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], + name="weights", + dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=conv_filter, + strides=[1, 2, 2, 1], + padding="SAME", + name="conv") + c1 = constant_op.constant( + np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype) + p = conv * c1 + c2 = constant_op.constant( + np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype) + q = conv / c2 + + edge = math_ops.sin(q) + edge /= edge + r = edge + edge + + p -= edge + q *= edge + s = p + q + s -= r + array_ops.squeeze(s, name=OUTPUT_NAME) + return g.as_graph_def() + + +TestGraph = namedtuple("TestGraph", + ["gdef", "num_expected_engines", "expected_output_dims"]) + +TEST_GRAPHS = { + "SingleEngineGraph": + TestGraph( + gdef=GetSingleEngineGraphDef(), + num_expected_engines=1, + expected_output_dims=(100, 6, 6, 6)), + "MultiEngineGraph": + TestGraph( + gdef=GetMultiEngineGraphDef(), + num_expected_engines=2, + expected_output_dims=(100, 12, 12, 6)), + # TODO(aaroey): add a large complex graph to test. +} + + +class TfTrtIntegrationTest(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration.""" def setUp(self): """Setup method.""" - super(IntegrationTest, self).setUp() + super(TfTrtIntegrationTest, self).setUp() warnings.simplefilter("always") - inp_dims = (100, 24, 24, 2) - self._input = np.random.random_sample(inp_dims) - self._original_graph = self.get_simple_graph_def() - self._gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - self._config = cpb2.ConfigProto(gpu_options=self._gpu_options) - self._reference = self.run_graph(self._original_graph, self._input) - - def get_simple_graph_def(self): - """Create a simple graph and return its graph_def.""" - g = ops.Graph() - with g.as_default(): - a = aops.placeholder( - dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") - e = cop.constant( - [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], - name="weights", - dtype=dtypes.float32) - conv = nn.conv2d( - input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") - b = cop.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32) - t = nn.bias_add(conv, b, name="biasAdd") - relu = nn.relu(t, "relu") - idty = aops.identity(relu, "ID") - v = nn_ops.max_pool( - idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - aops.squeeze(v, name="output") - return g.as_graph_def() - - def run_graph(self, gdef, dumm_inp): - """Run given graphdef once.""" - ops.reset_default_graph() + self._input = np.random.random_sample(INPUT_DIMS) + + def _GetConfigProto(self, + use_optimizer, + precision_mode=None, + is_dynamic_op=None): + if use_optimizer: + rewriter_cfg = rewriter_config_pb2.RewriterConfig() + rewriter_cfg.optimizers.extend(["constfold", "layout"]) + custom_op = rewriter_cfg.custom_optimizers.add() + custom_op.name = "TensorRTOptimizer" + custom_op.parameter_map["minimum_segment_size"].i = 3 + custom_op.parameter_map["max_batch_size"].i = self._input.shape[0] + custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op + custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 + custom_op.parameter_map["precision_mode"].s = to_bytes(precision_mode) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) + else: + graph_options = config_pb2.GraphOptions() + + gpu_options = config_pb2.GPUOptions() + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options.per_process_gpu_memory_fraction = 0.50 + + config = config_pb2.ConfigProto( + gpu_options=gpu_options, graph_options=graph_options) + return config + + def _RunGraph(self, graph_key, gdef, input_data, config, num_runs=2): + """Run given graphdef multiple times.""" g = ops.Graph() with g.as_default(): inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) + graph_def=gdef, return_elements=[INPUT_NAME, OUTPUT_NAME], name="") inp = inp.outputs[0] out = out.outputs[0] with self.test_session( - graph=g, config=self._config, use_gpu=True, force_gpu=True) as sess: - val = sess.run(out, {inp: dumm_inp}) + graph=g, config=config, use_gpu=True, force_gpu=True) as sess: + val = None + # Defaults to 2 runs to verify result across multiple runs is same. + for _ in range(num_runs): + new_val = sess.run(out, {inp: input_data}) + self.assertEquals(TEST_GRAPHS[graph_key].expected_output_dims, + new_val.shape) + if val is not None: + self.assertAllEqual(new_val, val) + val = new_val return val # Use real data that is representative of the inference dataset # for calibration. For this test script it is random data. - def run_calibration(self, gdef, dumm_inp): - """Run given calibration graph multiple times.""" - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - # run over real calibration data here, we are mimicking a calibration - # set of 30 different batches. Use as much calibration data as you want - with self.test_session( - graph=g, config=self._config, use_gpu=True, force_gpu=True) as sess: - for _ in range(30): - val = sess.run(out, {inp: dumm_inp}) - return val + def _RunCalibration(self, graph_key, gdef, input_data, config): + """Run calibration on given graph.""" + return self._RunGraph(graph_key, gdef, input_data, config, 30) - def get_trt_graph(self, mode): + def _GetTrtGraph(self, gdef, precision_mode, is_dynamic_op): """Return trt converted graph.""" - if mode in ["FP32", "FP16", "INT8"]: - return trt.create_inference_graph( - input_graph_def=self._original_graph, - outputs=["output"], - max_batch_size=self._input.shape[0], - max_workspace_size_bytes=1 << 25, - precision_mode=mode, # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) - return None - - def testFP32(self): - """Test FP32 conversion. Results should be identical to native case.""" - trt_graph = self.get_trt_graph("FP32") - result = self.run_graph(trt_graph, self._input) - self.assertAllEqual(self._reference, result) - result1 = self.run_graph(trt_graph, self._input) - self.assertAllEqual(result1, result) - - def testFP16(self): - """Test FP16 conversion. Results may be different from native case.""" - trt_graph = self.get_trt_graph("FP16") - result = self.run_graph(trt_graph, self._input) - self.assertAllClose(self._reference, result, rtol=1.e-03) - result1 = self.run_graph(trt_graph, self._input) - self.assertAllEqual(result1, result) - - def testINT8(self): - """Test INT8 conversion. Results may be different from native case.""" - calib_graph = self.get_trt_graph("INT8") - result = self.run_calibration(calib_graph, self._input) - self.assertAllEqual(self._reference, result) - int8_graph = trt.calib_graph_to_infer_graph(calib_graph) - result = self.run_graph(int8_graph, self._input) - self.assertAllClose(self._reference, result, rtol=1.e-03) - result1 = self.run_graph(int8_graph, self._input) - self.assertAllEqual(result1, result) + return trt.create_inference_graph( + input_graph_def=gdef, + outputs=[OUTPUT_NAME], + max_batch_size=self._input.shape[0], + max_workspace_size_bytes=1 << 25, + precision_mode=precision_mode, + minimum_segment_size=2, + is_dynamic_op=is_dynamic_op) + + def _VerifyGraphDef(self, + graph_key, + gdef, + precision_mode=None, + is_calibrated=None, + dynamic_engine=None): + num_engines = 0 + for n in gdef.node: + if n.op == "TRTEngineOp": + num_engines += 1 + self.assertNotEqual("", n.attr["serialized_segment"].s) + self.assertNotEqual("", n.attr["segment_funcdef_name"].s) + self.assertEquals(n.attr["precision_mode"].s, precision_mode) + self.assertEquals(n.attr["static_engine"].b, not dynamic_engine) + if precision_mode == MODE_INT8 and is_calibrated: + self.assertNotEqual("", n.attr["calibration_data"].s) + else: + self.assertEquals("", n.attr["calibration_data"].s) + if precision_mode is None: + self.assertEquals(num_engines, 0) + else: + self.assertEquals(num_engines, + TEST_GRAPHS[graph_key].num_expected_engines) + + def _RunTest(self, graph_key, use_optimizer, precision_mode, + dynamic_infer_engine, dynamic_calib_engine): + assert precision_mode in [MODE_FP32, MODE_FP16, MODE_INT8] + input_gdef = TEST_GRAPHS[graph_key].gdef + self._VerifyGraphDef(graph_key, input_gdef) + + # Get reference result without running trt. + config_no_trt = self._GetConfigProto(False) + print("Running original graph w/o trt, config:\n%s" % str(config_no_trt)) + ref_result = self._RunGraph(graph_key, input_gdef, self._input, + config_no_trt) + + # Run calibration if necessary. + if precision_mode == MODE_INT8: + + calib_config = self._GetConfigProto(use_optimizer, precision_mode, + dynamic_calib_engine) + print("Running calibration graph, config:\n%s" % str(calib_config)) + if use_optimizer: + self.assertTrue(False) + # TODO(aaroey): uncomment this and get infer_gdef when this mode is + # supported. + # result = self._RunCalibration(graph_key, input_gdef, self._input, + # calib_config) + else: + calib_gdef = self._GetTrtGraph(input_gdef, precision_mode, + dynamic_calib_engine) + self._VerifyGraphDef(graph_key, calib_gdef, precision_mode, False, + dynamic_calib_engine) + result = self._RunCalibration(graph_key, calib_gdef, self._input, + calib_config) + infer_gdef = trt.calib_graph_to_infer_graph(calib_gdef) + self._VerifyGraphDef(graph_key, infer_gdef, precision_mode, True, + dynamic_calib_engine) + self.assertAllClose(ref_result, result, rtol=1.e-03) + else: + infer_gdef = input_gdef + + # Run inference. + infer_config = self._GetConfigProto(use_optimizer, precision_mode, + dynamic_infer_engine) + print("Running final inference graph, config:\n%s" % str(infer_config)) + if use_optimizer: + result = self._RunGraph(graph_key, infer_gdef, self._input, infer_config) + else: + trt_infer_gdef = self._GetTrtGraph(infer_gdef, precision_mode, + dynamic_infer_engine) + self._VerifyGraphDef(graph_key, trt_infer_gdef, precision_mode, True, + dynamic_infer_engine) + result = self._RunGraph(graph_key, trt_infer_gdef, self._input, + infer_config) + self.assertAllClose(ref_result, result, rtol=1.e-03) + + def testIdempotence(self): + # Test that applying tensorrt optimizer or offline conversion tools multiple + # times to the same graph will result in same graph. + # TODO(aaroey): implement this. + pass + + +def GetTests(): + + def _GetTest(g, u, p, i, c): + + def _Test(self): + print("Running test with parameters: graph_key=%s, use_optimizer=%s, " + "precision_mode=%s, dynamic_infer_engine=%s, " + "dynamic_calib_engine=%s" % (g, u, p, i, c)) + self._RunTest(g, u, p, i, c) + + return _Test + + use_optimizer_options = [False, True] + precision_mode_options = [MODE_FP32, MODE_FP16, MODE_INT8] + dynamic_infer_engine_options = [False, True] + dynamic_calib_engine_options = [False, True] + for (graph_key, use_optimizer, precision_mode, + dynamic_infer_engine, dynamic_calib_engine) in itertools.product( + TEST_GRAPHS, use_optimizer_options, precision_mode_options, + dynamic_infer_engine_options, dynamic_calib_engine_options): + if precision_mode == MODE_INT8: + if not dynamic_calib_engine and dynamic_infer_engine: + # TODO(aaroey): test this case, the conversion from static calibration + # engine to dynamic inference engine should be a noop. + continue + if use_optimizer: + # TODO(aaroey): if use_optimizer is True we need to get the inference + # graphdef using custom python wrapper class, which is not currently + # supported yet. + continue + if not dynamic_calib_engine: + # TODO(aaroey): construction of static calibration engine is not + # supported yet. + continue + if dynamic_calib_engine and not dynamic_infer_engine: + # TODO(aaroey): construction of static inference engine using dynamic + # calibration engine is not supported yet. + continue + else: # In non int8 mode. + if dynamic_calib_engine: + # dynamic_calib_engine doesn't affect non-int8 modes, so just let + # related tests run once on dynamic_calib_engine=False. + continue + yield _GetTest(graph_key, use_optimizer, precision_mode, + dynamic_infer_engine, dynamic_calib_engine) if __name__ == "__main__": - googletest.main() + for index, t in enumerate(GetTests()): + setattr(TfTrtIntegrationTest, "testTfTRT_" + str(index), t) + test.main() diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index 46480e99a113afb34702b0ecd71468d4bdc83f98..d6628cd1eb69e46b188de613dee803a2e0dd07d4 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -48,12 +48,53 @@ PyObject* pair_helper(std::pair* in) { } return tuple; } + +struct version_struct{ + int vmajor; + int vminor; + int vpatch; +}; + +PyObject* version_helper(version_struct* in) { + PyObject *tuple(nullptr); + tuple = Py_BuildValue("(iii)", in->vmajor, in->vminor, in->vpatch); + if (!tuple) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, + "Tuple creation from version structure failed!"); + } + return NULL; + } + return tuple; +} +/* Define converters for vector */ +template<> +bool _PyObjAs(PyObject *pyobj, int* dest) { + *dest = PyLong_AsLong(pyobj); + return true; +} + +template<> +PyObject *_PyObjFrom(const int& src) { + return PyLong_FromLong(src); +} + %} + +_LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); + %typemap(out) std::pair { PyObject *tuple = pair_helper(&$1); if (!tuple) SWIG_fail; $result = tuple; } + +%typemap(out) version_struct { + PyObject *tuple = version_helper(&$1); + if (!tuple) SWIG_fail; + $result = tuple; +} + %{ #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -65,6 +106,8 @@ PyObject* pair_helper(std::pair* in) { %unignore tensorflow; %unignore trt_convert; %unignore calib_convert; +%unignore get_linked_tensorrt_version; +%unignore get_loaded_tensorrt_version; %{ @@ -74,7 +117,10 @@ std::pair trt_convert( size_t max_batch_size, size_t max_workspace_size_bytes, int precision_mode, - int minimum_segment_size + int minimum_segment_size, + bool is_dyn_op, + int max_cached_engines, + std::vector cached_engine_batches // Unfortunately we can't use TF_Status here since it // is in c/c_api and brings in a lot of other libraries // which in turn declare ops. These ops are included @@ -102,11 +148,12 @@ std::pair trt_convert( out_status = "InvalidArgument;Size of the output_names vector is 0"; return std::pair{out_status, ""}; } - tensorflow::GraphDef outGraph; + tensorflow::GraphDef out_graph; tensorflow::Status conversion_status = tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( graph_def, output_names, max_batch_size, max_workspace_size_bytes, - &outGraph, precision_mode, minimum_segment_size); + &out_graph, precision_mode, minimum_segment_size, + is_dyn_op, max_cached_engines, cached_engine_batches); if (!conversion_status.ok()) { auto retCode = (int)conversion_status.code(); char buff[2000]; @@ -116,7 +163,7 @@ std::pair trt_convert( return std::pair{out_status, ""}; } string result; - if (!outGraph.SerializeToString(&result)) { + if (!out_graph.SerializeToString(&result)) { out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; return std::pair{out_status, ""}; } @@ -128,7 +175,8 @@ std::pair trt_convert( #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } -std::pair calib_convert(string graph_def_string // const tensorflow::GraphDef& +std::pair calib_convert( + string graph_def_string, bool is_dyn_op // unfortunately we can't use TF_Status here since it // is in c/c_api and brings in a lot of other libraries // which in turn declare ops. These ops are included @@ -147,11 +195,11 @@ std::pair calib_convert(string graph_def_string // const tenso out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; return std::pair{out_status, ""}; } - - tensorflow::GraphDef outGraph; + graph_def_string.resize(0); + tensorflow::GraphDef out_graph; tensorflow::Status conversion_status = - tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def, - &outGraph); + tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph( + graph_def, &out_graph, is_dyn_op); if (!conversion_status.ok()) { auto retCode = (int)conversion_status.code(); char buff[2000]; @@ -161,7 +209,7 @@ std::pair calib_convert(string graph_def_string // const tenso return std::pair{out_status, ""}; } string result; - if (!outGraph.SerializeToString(&result)) { + if (!out_graph.SerializeToString(&result)) { out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; return std::pair{out_status, ""}; } @@ -172,15 +220,43 @@ std::pair calib_convert(string graph_def_string // const tenso return std::pair{"9;TensorRT is not enabled!", ""}; #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } + +version_struct get_linked_tensorrt_version() { + // Return the version at the link time. + version_struct s; +#if GOOGLE_CUDA && GOOGLE_TENSORRT + const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion(); + s.vmajor = lv[0]; + s.vminor = lv[1]; + s.vpatch = lv[2]; +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + return s; +} +version_struct get_loaded_tensorrt_version(){ + // Return the version from the loaded library. + version_struct s; +#if GOOGLE_CUDA && GOOGLE_TENSORRT + const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion(); + s.vmajor = lv[0]; + s.vminor = lv[1]; + s.vpatch = lv[2]; +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + return s; +} + %} -std::pair calib_convert(string graph_def_string); +std::pair calib_convert(string graph_def_string, bool is_dyn_op); std::pair trt_convert(string graph_def_string, std::vector output_names, size_t max_batch_size, size_t max_workspace_size_bytes, - int precision_mode, int minimum_segment_size); - + int precision_mode, int minimum_segment_size, + bool is_dyn_op, + int max_cached_engines, + std::vector cached_engine_batches); +version_struct get_linked_tensorrt_version(); +version_struct get_loaded_tensorrt_version(); %unignoreall diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index e4963596d38dbe8aea98fddbc67dbbf761c215c8..ec9a7861e7f7ef48344f9b60bda40173c2b31f6e 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -184,6 +184,7 @@ py_test( "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:tag_constants", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index ce96180c9271b95991826c2527cec526c1397ae5..d8089453340e894db6af9fc3a3b360c9512207eb 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -30,9 +30,9 @@ from tensorflow.python.estimator import estimator_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.keras._impl.keras.engine import sequential -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 4ec8d26116159fee3ac00581010d1603ac9e19f3..769183f40ad269954dac70db393207c266052144 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -288,7 +288,7 @@ class StateSpaceRegressor(TimeSeriesRegressor): """An Estimator for general state space models.""" def __init__(self, model, state_manager=None, optimizer=None, model_dir=None, - config=None): + config=None, head_type=ts_head_lib.TimeSeriesRegressionHead): """See TimeSeriesRegressor. Uses the ChainingStateManager by default.""" if not isinstance(model, state_space_model.StateSpaceModel): raise ValueError( @@ -301,7 +301,8 @@ class StateSpaceRegressor(TimeSeriesRegressor): state_manager=state_manager, optimizer=optimizer, model_dir=model_dir, - config=config) + config=config, + head_type=head_type) class StructuralEnsembleRegressor(StateSpaceRegressor): @@ -344,7 +345,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): anomaly_prior_probability=None, optimizer=None, model_dir=None, - config=None): + config=None, + head_type=ts_head_lib.TimeSeriesRegressionHead): """Initialize the Estimator. Args: @@ -401,6 +403,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): from tf.train.Optimizer. Defaults to Adam with step size 0.02. model_dir: See `Estimator`. config: See `Estimator`. + head_type: The kind of head to use for the model (inheriting from + `TimeSeriesRegressionHead`). """ if anomaly_prior_probability is not None: filtering_postprocessor = StateInterpolatingAnomalyDetector( @@ -424,4 +428,5 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): model=model, optimizer=optimizer, model_dir=model_dir, - config=config) + config=config, + head_type=head_type) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index a28a5872b850b51630240bdeb3ff22f372613523..8686a803e5bb023bbddb7df3203080fee0e13fea 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -19,11 +19,7 @@ from __future__ import print_function import re -from tensorflow.python.training import training_util -from tensorflow.contrib.layers.python.layers import optimizers - from tensorflow.contrib.timeseries.python.timeseries import feature_keys - from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys @@ -35,8 +31,9 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.util import nest from tensorflow.python.summary import summary +from tensorflow.python.training import training_util +from tensorflow.python.util import nest class _NoStatePredictOutput(export_lib.PredictOutput): @@ -102,12 +99,9 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce use_resource=True): model_outputs = self.create_loss(features, mode) - train_op = optimizers.optimize_loss( + train_op = self.optimizer.minimize( model_outputs.loss, - global_step=training_util.get_global_step(), - optimizer=self.optimizer, - # Learning rate is set in the Optimizer object - learning_rate=None) + global_step=training_util.get_global_step()) return estimator_lib.EstimatorSpec( loss=model_outputs.loss, mode=mode, @@ -132,7 +126,8 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce loss=model_outputs.loss, mode=mode, eval_metric_ops=metrics, - predictions={}) + # needed for custom metrics. + predictions=model_outputs.predictions) def _predict_ops(self, features): """Add ops for prediction to the graph.""" @@ -210,12 +205,12 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce def create_estimator_spec(self, features, mode, labels=None): """Performs basic error checking and returns an EstimatorSpec.""" with ops.name_scope(self._name, "head"): - if labels: + if labels is not None and labels != {}: # for better error messages. raise ValueError( - "The model received a `labels` dictionary, which is " - "not supported. Pass '{}' and '{}' as " - "features.".format(feature_keys.TrainEvalFeatures.TIMES, - feature_keys.TrainEvalFeatures.VALUES)) + "The model received a `labels`, which is not supported. " + "Pass '{}' and '{}' as features.".format( + feature_keys.TrainEvalFeatures.TIMES, + feature_keys.TrainEvalFeatures.VALUES)) del labels features = { name: self._convert_feature_to_tensor(name=name, value=value) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index c606db76a668235ab6a837159b9dec072b5fd801..78c2cec21cf4b6ccf6c314e54de41f3e95466adf 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -18,9 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + +from absl.testing import parameterized import numpy import six +from tensorflow.contrib.estimator.python.estimator import extenders from tensorflow.contrib.timeseries.examples import lstm as lstm_example from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators from tensorflow.contrib.timeseries.python.timeseries import feature_keys @@ -35,6 +39,7 @@ from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import variables @@ -53,9 +58,12 @@ class HeadTest(test.TestCase): model_fn = _stub_model_fn() for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL, estimator_lib.ModeKeys.PREDICT]: - with self.assertRaisesRegexp(ValueError, "labels"): + with self.assertRaisesRegexp(ValueError, "received a `labels`"): model_fn(features={}, labels={"a": "b"}, mode=mode) + with self.assertRaisesRegexp(ValueError, "received a `labels`"): + model_fn(features={}, labels=array_ops.zeros([]), mode=mode) + def test_unknown_mode(self): model_fn = _stub_model_fn() with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"): @@ -128,6 +136,44 @@ class EvaluationMetricsTests(test.TestCase): coordinator.request_stop() coordinator.join() + def test_custom_metrics(self): + """Tests that the custom metrics can be applied to the estimator.""" + model_dir = self.get_temp_dir() + estimator = ts_estimators.TimeSeriesRegressor( + model=lstm_example._LSTMModel(num_features=1, num_units=4), + optimizer=adam.AdamOptimizer(0.001), + config=estimator_lib.RunConfig(tf_random_seed=4), + model_dir=model_dir) + + def input_fn(): + return { + feature_keys.TrainEvalFeatures.TIMES: [[1, 2, 3], [7, 8, 9]], + feature_keys.TrainEvalFeatures.VALUES: + numpy.array([[[0.], [1.], [0.]], [[2.], [3.], [2.]]]) + } + + def metrics_fn(predictions, features): + # checking that the inputs are properly passed. + predict = predictions["mean"] + target = features[feature_keys.TrainEvalFeatures.VALUES][:, -1, 0] + return { + "plain_boring_metric386": + (math_ops.reduce_mean(math_ops.abs(predict - target)), + control_flow_ops.no_op()), + "fun_metric101": (math_ops.reduce_sum(predict + target), + control_flow_ops.no_op()), + } + + # Evaluation without training is enough for testing custom metrics. + estimator = extenders.add_metrics(estimator, metrics_fn) + evaluation = estimator.evaluate(input_fn, steps=1) + self.assertIn("plain_boring_metric386", evaluation) + self.assertIn("fun_metric101", evaluation) + # The values are deterministic because of fixed tf_random_seed. + # However if they become flaky, remove such exacts comparisons. + self.assertAllClose(evaluation["plain_boring_metric386"], 1.130380) + self.assertAllClose(evaluation["fun_metric101"], 10.435442) + class _StubModel(object): num_features = 3 @@ -274,10 +320,38 @@ class PredictFeatureCheckingTests(test.TestCase): mode=estimator_lib.ModeKeys.PREDICT) -class OneShotTests(test.TestCase): - - def test_one_shot_prediction_head_export(self): - model_dir = self.get_temp_dir() +def _custom_time_series_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.TimeSeriesRegressor( + model=lstm_example._LSTMModel( + num_features=5, num_units=128, + exogenous_feature_columns=exogenous_feature_columns), + optimizer=adam.AdamOptimizer(0.001), + config=estimator_lib.RunConfig(tf_random_seed=4), + state_manager=state_management.ChainingStateManager(), + head_type=head_type, + model_dir=model_dir) + + +def _structural_ensemble_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.StructuralEnsembleRegressor( + periodicities=None, + num_features=5, + exogenous_feature_columns=exogenous_feature_columns, + head_type=head_type, + model_dir=model_dir) + + +class OneShotTests(parameterized.TestCase): + + @parameterized.named_parameters( + {"testcase_name": "custom_time_series_regressor", + "estimator_factory": _custom_time_series_regressor}, + {"testcase_name": "structural_ensemble_regressor", + "estimator_factory": _structural_ensemble_regressor}) + def test_one_shot_prediction_head_export(self, estimator_factory): + model_dir = os.path.join(test.get_temp_dir(), str(ops.uid())) categorical_column = feature_column.categorical_column_with_hash_bucket( key="categorical_exogenous_feature", hash_bucket_size=16) exogenous_feature_columns = [ @@ -285,15 +359,10 @@ class OneShotTests(test.TestCase): "2d_exogenous_feature", shape=(2,)), feature_column.embedding_column( categorical_column=categorical_column, dimension=10)] - estimator = ts_estimators.TimeSeriesRegressor( - model=lstm_example._LSTMModel( - num_features=5, num_units=128, - exogenous_feature_columns=exogenous_feature_columns), - optimizer=adam.AdamOptimizer(0.001), - config=estimator_lib.RunConfig(tf_random_seed=4), - state_manager=state_management.ChainingStateManager(), - head_type=ts_head_lib.OneShotPredictionHead, - model_dir=model_dir) + estimator = estimator_factory( + model_dir=model_dir, + exogenous_feature_columns=exogenous_feature_columns, + head_type=ts_head_lib.OneShotPredictionHead) train_features = { feature_keys.TrainEvalFeatures.TIMES: numpy.arange( 20, dtype=numpy.int64), @@ -308,7 +377,7 @@ class OneShotTests(test.TestCase): num_threads=1, batch_size=16, window_size=16) estimator.train(input_fn=train_input_fn, steps=5) input_receiver_fn = estimator.build_raw_serving_input_receiver_fn() - export_location = estimator.export_savedmodel(self.get_temp_dir(), + export_location = estimator.export_savedmodel(test.get_temp_dir(), input_receiver_fn) graph = ops.Graph() with graph.as_default(): @@ -342,7 +411,7 @@ class OneShotTests(test.TestCase): for output_key, output_value in predict_signature.outputs.items()} output = session.run(fetches, feed_dict=feeds) - self.assertAllEqual((2, 15, 5), output["mean"].shape) + self.assertEqual((2, 15, 5), output["mean"].shape) if __name__ == "__main__": diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index f84ff1bfe9b014733205a8e51b43f79c63b227cb..c08f088be78d1cb1caa18a805844541b3d573fad 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -181,6 +181,7 @@ py_library( ":datasets", ":profiler", ":tpu_py", + "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/contrib/tpu/proto:topology_proto_py", "//tensorflow/core:protos_all_py", @@ -306,3 +307,13 @@ tf_py_test( "//tensorflow/python:framework_test_lib", ], ) + +tf_py_test( + name = "topology_test", + size = "small", + srcs = ["python/tpu/topology_test.py"], + additional_deps = [ + ":tpu", + "//tensorflow/python:framework_test_lib", + ], +) diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc index d389050e67f9a9e48b91583e5088058ec4e2832f..06553929dc44ca1f75ce64532a4dcdf1c8aae3eb 100644 --- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc +++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc @@ -23,15 +23,23 @@ REGISTER_OP("CrossReplicaSum") .Input("input: T") .Output("output: T") .Attr("T: {bfloat16, float}") + .Attr("group_assignment: list(int) = []") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( An Op to sum inputs across replicated TPU instances. Each -instance supplies its own input, and the output of each is the sum of -all the inputs. +instance supplies its own input. If group_assignment is empty, the output of +each is the sum of all the inputs, otherwise the output of each is the sum of +the inputs belonging to the same group. + +For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing +group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1. +Thus we get the outputs: `[A+C, B+D, A+C, B+D]`. input: The local input to the sum. output: The sum of all the distributed inputs. T: The type of elements to be summed. +group_assignment: The list of group ids. `group_assignment[i]` represents the + group id of replica i. )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index defed00537c407216703b3bf8651d33cdf311b56..15a2bb17a93212afe9ce5604a28d9dba5825f7d4 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -25,6 +25,7 @@ using shape_inference::ShapeHandle; REGISTER_OP("TPUReplicateMetadata") .Attr("num_replicas: int >= 0") .Attr("topology: string = \"\"") + .Attr("use_tpu: bool = true") .Attr("device_assignment: list(int) = []") .Attr("computation_shape: list(int) = []") .Attr("host_compute_core: list(string) = []") @@ -43,6 +44,27 @@ REGISTER_OP("TPUReplicatedInput") " with other shapes."); } c->set_output(0, cur); + + // If this is a resource, unify the resource shapes. + DataType dtype; + TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype)); + if (dtype == DT_RESOURCE) { + const std::vector* shapes_and_types = + nullptr; + for (int i = c->num_inputs() - 1; i >= 0; --i) { + if (shapes_and_types) { + // The return value of MergeInputHandleShapesAndTypes indicates + // the shape was refined, not that there was an error. + // TODO(phawkins): there seems to be no way to discover errors. + (void)c->MergeInputHandleShapesAndTypes(i, *shapes_and_types); + } else { + shapes_and_types = c->input_handle_shapes_and_types(i); + } + } + if (shapes_and_types) { + c->set_output_handle_shapes_and_types(0, *shapes_and_types); + } + } return Status::OK(); }) .Doc( @@ -72,6 +94,7 @@ REGISTER_OP("TPUReplicate") .Attr("computation: func") .Attr("num_replicas: int >= 1") .Attr("topology: string = \"\"") + .Attr("use_tpu: bool = true") .Attr("device_assignment: list(int) = []") .Attr("host_compute_core: list(string) = []") .Attr("computation_shape: list(int) = []") @@ -93,6 +116,9 @@ computation: a function containing the computation to run. num_replicas: the number of replicas of the computation to run. topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU topology. +use_tpu: a bool indicating if this computation will run on TPU or CPU/GPU. +Currently, only supports a default placement (computation is placed on GPU +if one is available, and on CPU if not). computation_shape: a [mesh_dimension] array describing the shape of each computation replica in numbers of cores in the TPU mesh. device_assignment: a flattened array with shape diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index dbf1ab6bbf0ddc7429d8e19279451eb862981e0c..38d1c3049ef7185f2f9f448361029d066678cdae 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -49,11 +49,11 @@ tf_cc_binary( ":tpu_profiler_analysis_proto_cc", ":tpu_profiler_proto_cc", ":version", + "//tensorflow:grpc++", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/platform/cloud:gcs_file_system", - "@grpc//:grpc++_unsecure", ], ) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 816897499b7a49365060c026d60b977990f3ecdc..f80f5652af79d410946971573ae160fdd0b85f6d 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -18,7 +18,7 @@ limitations under the License. // Initiates a TPU profiling on the TPUProfiler service at service_addr, // receives and dumps the profile data to a tensorboard log directory. -#include "grpc++/grpc++.h" +#include "grpcpp/grpcpp.h" #include #include @@ -79,7 +79,9 @@ ProfileRequest PopulateProfileRequest(int duration_ms, request.set_repository_root(repository_root); request.set_session_id(session_id); } + request.add_tools("op_profile"); request.add_tools("input_pipeline"); + request.add_tools("memory_viewer"); request.add_tools("overview_page"); *request.mutable_opts() = opts; std::cout << "Limiting the number of trace events to " << kMaxEvents diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index 73d941e5e99801d239441e438e9d640478c4e6b6..98cc31f18d2d34765f2c123c3d34207802541036 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -38,6 +38,7 @@ namespace { using ::tensorflow::io::JoinPath; using ::tensorflow::protobuf::util::JsonOptions; using ::tensorflow::protobuf::util::MessageToJsonString; +using ::tensorflow::str_util::EndsWith; using ::tensorflow::strings::StrCat; constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph."; @@ -46,6 +47,9 @@ constexpr char kJsonTraceFileName[] = "trace.json.gz"; constexpr char kProfilePluginDirectory[] = "plugins/profile/"; constexpr char kProtoTraceFileName[] = "trace"; +constexpr char kFlatProfilerFileName[] = "flat_profiler.pb"; +constexpr char kTfStatsHelperSuffix[] = "tf_stats_helper_result"; + Status WriteGzippedDataToFile(const string& filename, const string& data) { std::unique_ptr file; TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(filename, &file)); @@ -107,6 +111,10 @@ Status DumpToolDataToLogDirectory(StringPiece run_dir, const string& host_prefix, const tensorflow::ProfileToolData& tool, std::ostream* os) { + // Don't save the intermediate results for combining the per host tool data. + if (EndsWith(tool.name(), kFlatProfilerFileName) || + EndsWith(tool.name(), kTfStatsHelperSuffix)) + return Status::OK(); string path = JoinPath(run_dir, StrCat(host_prefix, tool.name())); TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data())); if (os) { diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 508c7a842fb82ec080082d7e7f02f8d2f2a79447..7a5d01cca42351f6d4d8b41d43756560ce7874d3 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -17,12 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from absl import flags - import os import subprocess import sys - +from absl import flags +from distutils.version import LooseVersion import tensorflow as tf # Cloud TPU Cluster Resolvers @@ -35,26 +34,26 @@ flags.DEFINE_string( None, help='GCE zone where the Cloud TPU is located in. If not specified, we ' 'will attempt to automatically detect the GCE project from metadata.') -flags.DEFINE_string('tpu_name', None, - 'Name of the Cloud TPU for Cluster Resolvers. You must ' - 'specify either this flag or --service_addr.') +flags.DEFINE_string( + 'tpu', None, 'Name of the Cloud TPU for Cluster Resolvers. You must ' + 'specify either this flag or --service_addr.') # Tool specific parameters flags.DEFINE_string( 'service_addr', None, 'Address of TPU profiler service e.g. ' - 'localhost:8466, you must specify either this flag or --tpu_name.') + 'localhost:8466, you must specify either this flag or --tpu.') flags.DEFINE_string( 'workers_list', None, 'The list of worker TPUs that we are about to profile' - ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu_name or ' + ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu or ' '--service_addr to profile a subset of tpu nodes. You can also use only' - '--tpu_name and leave this flag unspecified to profile all the tpus.') -flags.DEFINE_string('logdir', None, - 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' - 'gs://tb_bucket') + '--tpu and leave this flag unspecified to profile all the tpus.') +flags.DEFINE_string( + 'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' + 'gs://tb_bucket') flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.') -flags.DEFINE_integer('num_tracing_attempts', 3, - 'Automatically retry N times when no trace ' - 'event is collected.') +flags.DEFINE_integer( + 'num_tracing_attempts', 3, 'Automatically retry N times when no trace ' + 'event is collected.') flags.DEFINE_boolean('include_dataset_ops', True, 'Set to false to profile longer TPU ' 'device traces.') @@ -63,42 +62,50 @@ FLAGS = flags.FLAGS EXECUTABLE = 'data/capture_tpu_profile' JOB_NAME = 'worker' + def get_workers_list(cluster_resolver): cluster_spec = cluster_resolver.cluster_spec() task_indices = cluster_spec.task_indices(JOB_NAME) - workers_list = [cluster_spec.task_address(JOB_NAME, i).split(':')[0] - for i in task_indices] + workers_list = [ + cluster_spec.task_address(JOB_NAME, i).split(':')[0] for i in task_indices + ] return ','.join(workers_list) + def run_main(): tf.app.run(main) + def main(unused_argv=None): tf.logging.set_verbosity(tf.logging.INFO) + tf_version = tf.__version__ + print('TensorFlow version %s detected' % tf_version) - if FLAGS.service_addr is None and FLAGS.tpu_name is None: - sys.exit('You must specify either --service_addr or --tpu_name.') + if FLAGS.service_addr is None and FLAGS.tpu is None: + sys.exit('You must specify either --service_addr or --tpu.') tpu_cluster_resolver = None if FLAGS.service_addr is not None: - if FLAGS.tpu_name is not None: - tf.logging.warn('Both --service_addr and --tpu_name are set. Ignoring ' - '--tpu_name and using --service_addr.') + if FLAGS.tpu is not None: + tf.logging.warn('Both --service_addr and --tpu are set. Ignoring ' + '--tpu and using --service_addr.') service_addr = FLAGS.service_addr else: tpu_cluster_resolver = ( tf.contrib.cluster_resolver.TPUClusterResolver( - [FLAGS.tpu_name], - zone=FLAGS.tpu_zone, - project=FLAGS.gcp_project)) + [FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)) service_addr = tpu_cluster_resolver.get_master() service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466') - workers_list = "" - if FLAGS.workers_list is not None: - workers_list = FLAGS.workers_list - elif tpu_cluster_resolver is not None: - workers_list = get_workers_list(tpu_cluster_resolver) + workers_list = '' + if LooseVersion(tf_version) < LooseVersion('1.9'): + tf.logging.warn('Attempt to profile with legacy support under TensorFlow ' + 'version %s' % tf_version) + else: + if FLAGS.workers_list is not None: + workers_list = FLAGS.workers_list + elif tpu_cluster_resolver is not None: + workers_list = get_workers_list(tpu_cluster_resolver) if not FLAGS.logdir: sys.exit('logdir must be provided.') diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index ebd478fd02295108b9d2454963eb06165828b523..19f088f8b862ce7b114490151f2b6a8c260b8580 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -20,7 +20,7 @@ from __future__ import print_function from setuptools import setup -_VERSION = '1.6.0' +_VERSION = '1.9.0' CONSOLE_SCRIPTS = [ 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', @@ -46,7 +46,7 @@ setup( # 3 - Alpha # 4 - Beta # 5 - Production/Stable - 'Development Status :: 4 - Beta', + 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto index b9ac1a550c87e055fd5d555c346d5ec545bbe634..2b13343efa4e82386cb9259432b854be3ec821f7 100644 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -87,6 +87,8 @@ message StepInfoResult { optional uint64 wait_duration_ps = 5; // The time spent on cross-replica-sum in picoseconds. optional uint64 crs_duration_ps = 6; + // Percentage of unit b time spent on infeed. + optional double unit_b_infeed_percent = 7; } // Result proto for a sequence of steps. diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index 7be694e866729c58efae4ccf7932dd929c03ed91..f0fca63db0bca80cdaa27e491b2a03ae2246c007 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -68,7 +68,8 @@ message ProfileRequest { } message ProfileToolData { - // The tool's name which this data is associated. (e.g. "input_pipeline".) + // The file name which this data is associated (e.g. "input_pipeline.json", + // "cluster_xxx.memory_viewer.json"). string name = 1; // The data payload (likely json) for the specific tool. diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto index 8b0bbde98e6a1dee8ade789328f3ba0624049562..d3c34bfd490080b86cf3d8b893c550f3a87bbbed 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto @@ -38,6 +38,9 @@ message EnumProfileSessionsAndToolsResponse { message ProfileSessionDataRequest { string repository_root = 1; string session_id = 2; + // Which host the data is associated. if empty, data from all hosts are + // aggregated. + string host_name = 5; // Which tool string tool_name = 3; // Tool's specific parameters. e.g. TraceViewer's viewport etc diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h index 618479e1a6ccf26a4103ea1f182b662d7d9998da..1bf49966d12db83f1e6904f8c00453bba278847c 100644 --- a/tensorflow/contrib/tpu/profiler/version.h +++ b/tensorflow/contrib/tpu/profiler/version.h @@ -16,6 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ #define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ -#define TPU_PROFILER_VERSION "1.6.0" +#define TPU_PROFILER_VERSION "1.9.0" #endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD index 7ecb36852c53bb74d70ed0f8c70ca1ce860a037a..26016f47dfb36990fd73267c70619878ac3450e5 100644 --- a/tensorflow/contrib/tpu/proto/BUILD +++ b/tensorflow/contrib/tpu/proto/BUILD @@ -2,7 +2,12 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_additional_all_protos", + "tf_proto_library", + "tf_proto_library_py", +) tf_proto_library( name = "tpu_embedding_config_proto", @@ -22,12 +27,14 @@ tf_proto_library( visibility = ["//visibility:public"], ) -tf_proto_library( +tf_proto_library_py( name = "compilation_result_proto", srcs = [ "compilation_result.proto", ], - cc_api_version = 2, - protodeps = ["//tensorflow/core:protos_all"], + protodeps = tf_additional_all_protos() + [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/tpu/proto/compilation_result.proto b/tensorflow/contrib/tpu/proto/compilation_result.proto index cf52897de3d0fefa55e68a6b889ae9af7b45864a..88585a5bd10fc28aa34bb0de72de970e21b2adb2 100644 --- a/tensorflow/contrib/tpu/proto/compilation_result.proto +++ b/tensorflow/contrib/tpu/proto/compilation_result.proto @@ -3,6 +3,7 @@ syntax = "proto3"; option cc_enable_arenas = true; package tensorflow.tpu; +import "tensorflow/compiler/xla/service/hlo.proto"; import "tensorflow/core/lib/core/error_codes.proto"; // Describes the result of a TPU compilation. @@ -10,4 +11,7 @@ message CompilationResultProto { // The error message, if any, returned during compilation. error.Code status_code = 1; string status_error_message = 2; + + // HLO proto. + repeated xla.HloProto hlo_protos = 3; } diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 14c63a79763300dcfe8d6c8e09b90f8e9c772358..bf442d9116d2ceca499ffc66258c64b5b94dd881 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -38,9 +38,8 @@ if platform.system() != "Windows": @ops.RegisterGradient("CrossReplicaSum") def _cross_replica_sum_grad(op, grad): - del op # Unused # The gradient of a cross replica sum is also a cross-replica sum. - return gen_tpu_ops.cross_replica_sum(grad) + return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment")) # This extra type checking exists to give a more helpful error message in # the common case that uint8 and int64 values are infed. Remove when both diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index 2e472a2805f98b15505f56af403aa6223e28c667..d879170b6875b3088d284459b70dc91567e33bab 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -166,11 +166,21 @@ def StreamingFilesDataset(files, return remote_iterator.get_next() def MapFn(unused_input): - return functional_ops.remote_call( + if isinstance(source_dataset.output_types, dtypes.DType): + output_types = [source_dataset.output_types] + elif isinstance(source_dataset.output_types, (list, tuple)): + output_types = source_dataset.output_types + else: + raise ValueError('source dataset has invalid output types') + remote_calls = functional_ops.remote_call( args=[source_handle], - Tout=[dtypes.string], + Tout=output_types, f=LoadingFunc, - target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)[0] + target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job) + if len(remote_calls) == 1: + return remote_calls[0] + else: + return remote_calls with ops.device('/job:%s' % worker_job): output_dataset = dataset_ops.Dataset.range(2).repeat().map( diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index 918cf0ed8e513de0d4207f7d2aac61ad886c8288..b58d05eac56f3586e183333f7c1a3867ee57456c 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -26,6 +26,8 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -162,6 +164,30 @@ class DatasetsTest(test.TestCase): self.assertEqual(set(all_contents), set(retrieved_values)) + def testArbitraryReaderFuncFromDatasetGenerator(self): + + def my_generator(): + yield (1, [1] * 10) + + def gen_dataset(dummy): + return dataset_ops.Dataset.from_generator( + my_generator, (dtypes.int64, dtypes.int64), + (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10]))) + + dataset = datasets.StreamingFilesDataset( + dataset_ops.Dataset.range(10), filetype=gen_dataset) + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = self._sess.run(get_next) + + self.assertIsInstance(retrieved_values, (list, tuple)) + self.assertEqual(len(retrieved_values), 2) + self.assertEqual(retrieved_values[0], 1) + self.assertItemsEqual(retrieved_values[1], [1] * 10) + def testUnexpectedFiletypeString(self): with self.assertRaises(ValueError): datasets.StreamingFilesDataset( diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 9cc841f7f26be3faf7ea172f0ffdef69ac6bfa98..754154438235f4c5e9e8db996acc8d843ab18431 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -19,15 +19,16 @@ To use, wrap your model with the `keras_support.tpu_model` function. Example usage: ``` -# Must activate before building TPU models -keras_support.setup_tpu_session(master_address) - image = tf.keras.layers.Input(shape=(28, 28, 3), name='image') c1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3))( image) flattened = tf.keras.layers.Flatten()(c1) logits = tf.keras.layers.Dense(10, activation='softmax')(flattened) model = tf.keras.Model(inputs=[image], outputs=[logits]) -model = keras_support.tpu_model(model) + +strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8) +model = keras_support.tpu_model(model, + strategy=strategy, + tpu_name_or_address=tpu_name) # Only TF optimizers are currently supported. model.compile(optimizer=tf.train.AdamOptimizer(), ...) @@ -35,9 +36,6 @@ model.compile(optimizer=tf.train.AdamOptimizer(), ...) # `images` and `labels` should be Numpy arrays. Support for tensor input # (e.g. datasets) is planned. model.fit(images, labels) - -# Invoke before shutting down -keras_support.shutdown_tpu_session() ``` """ @@ -48,9 +46,15 @@ from __future__ import division from __future__ import print_function import collections +import contextlib import re +import sys import time +import numpy as np + +from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver +from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops @@ -61,15 +65,18 @@ from tensorflow.python.client import session as tf_session from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import layers -from tensorflow.python.keras._impl.keras import models -from tensorflow.python.keras._impl.keras import optimizers as keras_optimizers -from tensorflow.python.keras._impl.keras.layers import embeddings +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import models +from tensorflow.python.keras import optimizers as keras_optimizers +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.layers import embeddings from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging +TPUDistributionStrategy = tpu_strategy.TPUStrategy # pylint: disable=invalid-name + class TPUEmbedding(embeddings.Embedding): """TPU compatible embedding layer. @@ -93,10 +100,9 @@ class TPUEmbedding(embeddings.Embedding): class TPUModelOp( - collections.namedtuple( - 'TPUModelOp', - ['compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', - 'outfeed_op'])): + collections.namedtuple('TPUModelOp', [ + 'compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op' + ])): pass @@ -105,13 +111,69 @@ def _valid_name(tensor_name): return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name) -def _replicated_optimizer(opt, num_replicas): +def _replicated_optimizer(opt): """Wrap the optimizer `opt` with CrossShardOptimizer if applicable.""" - if num_replicas == 1: - return opt return keras_optimizers.TFOptimizer( - optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer) - ) + optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer)) + + +class TPURewriteContext(object): + """Prepare the environment for a Keras model during `tpu.rewrite`. + + This overrides the default placeholder behaviour to instead refer to a preset + input mapping. Placeholders are unsupported in TPU compiled code, and must + be replaced with explicit inputs or values from the infeed queue. + + Instead of explicitly threading inputs all the way through the Keras codebase, + we override the behavior of the placeholder while compiling and inject the + Tensors from the infeed in place of the placeholder. + + Similarly, as we compile a new sub-graph for each unique shape and execution + mode, we need to override the behavior of an embedded `name_scope` call in + the base Keras layer code. This allows us to re-use the same weights across + many compiles and share a single session/graph. + """ + + def __init__(self, input_map): + self._input_map = input_map + self._default_placeholder = None + self._default_name_scope = None + + def __enter__(self): + + def _placeholder(dtype, shape=None, name=None): # pylint: disable=unused-argument + logging.info('Remapping placeholder for %s', name) + if name in self._input_map: + return self._input_map[name] + else: + logging.info('Default: %s', name) + return self._default_placeholder(dtype, shape, name) + + def _name_scope(name, default_name=None, values=None): + caller_frame = sys._getframe().f_back + caller_obj = caller_frame.f_locals.get('self') + if (caller_obj is not None and + isinstance(caller_obj, base_layer.Layer) and name is not None): + logging.info('Intercepted name_scope: %s', caller_obj) + return variable_scope.variable_scope( + name, default_name, values, reuse=variable_scope.AUTO_REUSE) + + return self._default_name_scope(name, default_name, values) + + self._default_placeholder = array_ops.placeholder + self._default_name_scope = ops.name_scope + self._default_make_variable = base_layer.make_variable + + array_ops.placeholder = _placeholder + ops.name_scope = _name_scope + base_layer.make_variable = variable_scope.get_variable + logging.info('Overriding default placeholder.') + return + + def __exit__(self, exc_type, exc_val, exc_tb): + array_ops.placeholder = self._default_placeholder + ops.name_scope = self._default_name_scope + base_layer.make_variable = self._default_make_variable class TPUFunction(object): @@ -126,19 +188,18 @@ class TPUFunction(object): instead of being injected as `feed_dict` items or fetches. """ - def __init__(self, model, execution_mode, num_replicas=1): + def __init__(self, model, execution_mode, strategy): self.model = model self.execution_mode = execution_mode + self._strategy = strategy self._compilation_cache = {} - self.num_replicas = num_replicas + self._cloned_model = None def _specialize_model(self, input_specs): """Specialize `self.model` (a Keras model) for the given input shapes.""" # Re-create our input and output layers inside our subgraph. They will be # attached to the true computation when we clone our model in `tpu_fn`. - K.set_learning_phase( - self.execution_mode == model_fn_lib.ModeKeys.TRAIN - ) + K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN) # functools.partial and callable objects are not supported by tpu.rewrite def _model_fn(): @@ -164,23 +225,22 @@ class TPUFunction(object): infeed_tensors)) tpu_targets = [] - tpu_inputs = [] + tpu_input_map = {} # Sort infeed outputs into inputs and labels for calling our Keras model. for tensor, layer in zip(infeed_tensors, infeed_layers): if layer in self.model._input_layers: - tpu_inputs.append(layers.Input(name=layer.name, tensor=tensor)) + tpu_input_map[layer.name] = tensor if layer in self.model._output_layers: tpu_targets.append(tensor) - # Call our model with our infeed inputs (re-using the weights). - model_outputs = self.model(tpu_inputs) - child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs) + # Clone our CPU model, running within the TPU device context. + with TPURewriteContext(tpu_input_map): + self._cloned_model = models.clone_model(self.model) if is_training or is_test: - child_model.compile( - optimizer=_replicated_optimizer(self.model.optimizer, - self.num_replicas), + self._cloned_model.compile( + optimizer=_replicated_optimizer(self.model.optimizer), loss=self.model.loss, loss_weights=self.model.loss_weights, metrics=self.model.metrics, @@ -190,37 +250,37 @@ class TPUFunction(object): # Compute our outfeed depending on the execution mode if is_training: - child_model._make_train_function() + self._cloned_model._make_train_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in child_model.train_function.outputs + for tensor in self._cloned_model.train_function.outputs ] return [ - child_model.train_function.updates_op, + self._cloned_model.train_function.updates_op, tpu_ops.outfeed_enqueue_tuple( - child_model.train_function.outputs, + self._cloned_model.train_function.outputs, name='outfeed-enqueue-train') ] elif is_test: - child_model._make_test_function() + self._cloned_model._make_test_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in child_model.test_function.outputs + for tensor in self._cloned_model.test_function.outputs ] return [ tpu_ops.outfeed_enqueue_tuple( - child_model.test_function.outputs, + self._cloned_model.test_function.outputs, name='outfeed-enqueue-test') ] elif is_predict: - child_model._make_predict_function() + self._cloned_model._make_predict_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in child_model.predict_function.outputs + for tensor in self._cloned_model.predict_function.outputs ] return [ tpu_ops.outfeed_enqueue_tuple( - child_model.predict_function.outputs, + self._cloned_model.predict_function.outputs, name='outfeed-enqueue-predict', ) ] @@ -235,7 +295,7 @@ class TPUFunction(object): # `execute op` replicates `_model_fn` `num_replicas` times, with each shard # running on a different logical core. compile_op, execute_op = tpu.split_compile_and_replicate( - _model_fn, inputs=[[]] * self.num_replicas) + _model_fn, inputs=[[]] * self._strategy.num_towers) # Generate CPU side operations to enqueue features/labels and dequeue # outputs from the model call. @@ -243,7 +303,7 @@ class TPUFunction(object): outfeed_op = [] shard_infeed_tensors = [] - for shard_id in range(self.num_replicas): + for shard_id in range(self._strategy.num_towers): with ops.device('/device:TPU:%d' % shard_id): infeed_tensors = [] for spec in input_specs: @@ -254,32 +314,35 @@ class TPUFunction(object): name='infeed-enqueue-%s-%d' % (spec.name, shard_id))) shard_infeed_tensors.append(infeed_tensors) - infeed_op.append(tpu_ops.infeed_enqueue_tuple( - infeed_tensors, [spec.shape for spec in input_specs], - name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id))) + infeed_op.append( + tpu_ops.infeed_enqueue_tuple( + infeed_tensors, [spec.shape for spec in input_specs], + name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id))) - outfeed_op.extend(tpu_ops.outfeed_dequeue_tuple( - dtypes=[spec.dtype for spec in self._outfeed_spec], - shapes=[spec.shape for spec in self._outfeed_spec], - name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id))) + outfeed_op.extend( + tpu_ops.outfeed_dequeue_tuple( + dtypes=[spec.dtype for spec in self._outfeed_spec], + shapes=[spec.shape for spec in self._outfeed_spec], + name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id))) return TPUModelOp( - compile_op, execute_op, infeed_tensors=shard_infeed_tensors, - infeed_op=infeed_op, outfeed_op=outfeed_op) + compile_op, + execute_op, + infeed_tensors=shard_infeed_tensors, + infeed_op=infeed_op, + outfeed_op=outfeed_op) def _test_model_compiles(self, tpu_model_ops): """Verifies that the given TPUModelOp can be compiled via XLA.""" - session = K.get_session() - logging.info('Started compiling') start_time = time.clock() - result = session.run(tpu_model_ops.compile_op) + result = K.get_session().run(tpu_model_ops.compile_op) proto = tpu_compilation_result.CompilationResultProto() proto.ParseFromString(result) if proto.status_error_message: - raise RuntimeError( - 'Compilation failed: {}'.format(proto.status_error_message)) + raise RuntimeError('Compilation failed: {}'.format( + proto.status_error_message)) end_time = time.clock() logging.info('Finished compiling. Time elapsed: %s secs', @@ -296,17 +359,20 @@ class TPUFunction(object): Returns: List of lists containing the input to feed to each TPU shard. """ - if self.num_replicas == 1: + if self._strategy.num_towers == 1: return [inputs] batch_size = inputs[0].shape[0] - assert batch_size % self.num_replicas == 0, ( - 'batch_size must be divisible by num_replicas') - shard_size = batch_size // self.num_replicas + assert batch_size % self._strategy.num_towers == 0, ( + 'batch_size must be divisible by strategy.num_towers (%s vs %s)' % + (batch_size, self._strategy.num_towers) + ) + shard_size = batch_size // self._strategy.num_towers input_list = [] - for index in range(self.num_replicas): - shard_inputs = [x[index * shard_size:(index + 1) * shard_size] - for x in inputs] + for index in range(self._strategy.num_towers): + shard_inputs = [ + x[index * shard_size:(index + 1) * shard_size] for x in inputs + ] input_list.append(shard_inputs) return input_list @@ -343,12 +409,15 @@ class TPUFunction(object): shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs]) if shape_key not in self._compilation_cache: - logging.info('New input shapes; (re-)compiling: mode=%s, %s', - self.execution_mode, input_specs) - new_tpu_model_ops = self._specialize_model(input_specs) - self._compilation_cache[shape_key] = new_tpu_model_ops - self._test_model_compiles(new_tpu_model_ops) - + with self.model.tpu_session(): + logging.info('New input shapes; (re-)compiling: mode=%s, %s', + self.execution_mode, input_specs) + new_tpu_model_ops = self._specialize_model(input_specs) + self._compilation_cache[shape_key] = new_tpu_model_ops + self._test_model_compiles(new_tpu_model_ops) + + # Initialize our TPU weights on the first compile. + self.model._initialize_weights(self._cloned_model) tpu_model_ops = self._compilation_cache[shape_key] infeed_dict = {} @@ -357,58 +426,83 @@ class TPUFunction(object): for tensor, value in zip(infeed_tensors, inputs): infeed_dict[tensor] = value - session = K.get_session() - _, _, outfeed_outputs = session.run([ - tpu_model_ops.infeed_op, tpu_model_ops.execute_op, - tpu_model_ops.outfeed_op - ], infeed_dict) + with self.model.tpu_session() as session: + _, _, outfeed_outputs = session.run([ + tpu_model_ops.infeed_op, tpu_model_ops.execute_op, + tpu_model_ops.outfeed_op + ], infeed_dict) # TODO(xiejw): Decide how to reduce outputs, or just discard all but first. - return outfeed_outputs[:len(outfeed_outputs) // self.num_replicas] - - -@experimental -def setup_tpu_session(master): - """Initializes and returns a Keras/TF session connected the TPU `master`.""" - session = tf_session.Session( - target=master, config=config_pb2.ConfigProto(isolate_session_state=True)) - K.set_session(session) - K.get_session().run(tpu.initialize_system()) - return session - - -@experimental -def shutdown_tpu_session(session=None): - """Shutdown the TPU attached to session. + if self.execution_mode == model_fn_lib.ModeKeys.PREDICT: + outputs = [[]] * len(self._outfeed_spec) + outputs_per_replica = len(self._outfeed_spec) - This should be called to cleanly shut down the TPU system before the client - exits. - - Args: - session: Session to shutdown, or None to use the default session. - - Returns: - - """ - if session is None: - session = K.get_session() + for i in range(self._strategy.num_towers): + output_group = outfeed_outputs[ + i * outputs_per_replica:(i+1) * outputs_per_replica + ] + for j in range(outputs_per_replica): + outputs[j].append(output_group[j]) - session.run(tpu.shutdown_system()) + return [np.concatenate(group) for group in outputs] + else: + return outfeed_outputs[:len(outfeed_outputs) // self._strategy.num_towers] class KerasTPUModel(models.Model): """TPU compatible Keras model wrapper.""" - def __init__(self, inputs, outputs, name, replicas=1): - super(models.Model, self).__init__( - inputs=inputs, - outputs=outputs, - name=name, + def __init__(self, cpu_model, tpu_name_or_address, strategy): + super(models.Model, self).__init__( # pylint: disable=bad-super-call + inputs=cpu_model.inputs, + outputs=cpu_model.outputs, + name=cpu_model.name, ) + self.predict_function = None self.test_function = None self.train_function = None - self.replicas = replicas + self._strategy = strategy + + self._tpu_name_or_address = tpu_name_or_address + self._cpu_model = cpu_model + self._tpu_model = None + self._tpu_weights_initialized = False + self._graph = ops.Graph() + + cluster_resolver = tpu_cluster_resolver.TPUClusterResolver( + tpu_name_or_address) + cluster_spec = cluster_resolver.cluster_spec() + self._session = tf_session.Session( + graph=self._graph, + target=cluster_resolver.master(), + config=config_pb2.ConfigProto(isolate_session_state=True)) + + if cluster_spec: + self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + + with self._graph.as_default(): + self._session.run(tpu.initialize_system()) + + # If the input CPU model has already been compiled, compile our TPU model + # immediately. + if self._cpu_model.optimizer: + self.compile( + self._cpu_model.optimizer, + self._cpu_model.loss, + self._cpu_model.metrics, + self._cpu_model.loss_weights, + self._cpu_model.sample_weight_mode, + self._cpu_model.weighted_metrics, + self._cpu_model.target_tensors, + ) + + def get_config(self): + return { + 'cpu_model': self._cpu_model, + 'tpu_name_or_address': self._tpu_name_or_address, + 'strategy': self._strategy, + } def compile(self, optimizer, @@ -430,6 +524,11 @@ class KerasTPUModel(models.Model): sample_weight_mode, weighted_metrics, target_tensors, **kwargs) + if not self._cpu_model.optimizer: + self._cpu_model.compile(optimizer, loss, metrics, loss_weights, + sample_weight_mode, weighted_metrics, + target_tensors, **kwargs) + # Keras optimizers are not compatible with TPU rewrite if not isinstance(self.optimizer, keras_optimizers.TFOptimizer): raise ValueError( @@ -437,37 +536,90 @@ class KerasTPUModel(models.Model): def _make_train_function(self): if not self.train_function: - self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN, - num_replicas=self.replicas) + self.train_function = TPUFunction( + self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy) return self.train_function def _make_test_function(self): if not self.test_function: - self.test_function = TPUFunction(self, model_fn_lib.ModeKeys.EVAL) + self.test_function = TPUFunction( + self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy) return self.test_function def _make_predict_function(self): if not self.predict_function: - self.predict_function = TPUFunction(self, model_fn_lib.ModeKeys.PREDICT) + self.predict_function = TPUFunction( + self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy) return self.predict_function - def cpu_model(self): - cpu_model = models.Model( - inputs=self.inputs, - outputs=self.outputs, - name=self.name, - ) + def _initialize_weights(self, cloned_model): + """Initialize TPU weights. - if self.optimizer: - cpu_model.compile( - optimizer=self.optimizer, - loss=self.loss, - metrics=self.metrics, - loss_weights=self.loss_weights, - ) + This is called on the first compile of the TPU model (first call to + fit/predict/evaluate). - return cpu_model + Args: + cloned_model: `keras.Model`, TPU model to initialize. + """ + if self._tpu_weights_initialized: + return + + self._tpu_model = cloned_model + self._tpu_weights_initialized = True + + weights = self._cpu_model.get_weights() + with self.tpu_session(): + logging.info('Setting weights on TPU model.') + cloned_model.set_weights(weights) + + def sync_to_cpu(self): + """Copy weights from the CPU, returning a synchronized CPU model.""" + if self._tpu_weights_initialized: + with self.tpu_session(): + logging.info('Copying TPU weights to the CPU') + tpu_weights = self._tpu_model.get_weights() + + self._cpu_model.set_weights(tpu_weights) + + return self._cpu_model + + def get_weights(self): + return self.sync_to_cpu().get_weights() + + def save_weights(self, *args, **kw): + return self.sync_to_cpu().save_weights(*args, **kw) + + def save(self, *args, **kw): + return self.sync_to_cpu().save(*args, **kw) + + def set_weights(self, weights): + # We may not have a TPU model available if we haven't run fit/predict, so + # we can't directly set the TPU weights here. + # Instead, reset CPU model weights and force TPU re-initialization at the + # next call. + self._cpu_model.set_weights(weights) + self._tpu_weights_initialized = False + + @contextlib.contextmanager + def tpu_session(self): + """Yields a TPU session and sets it as the default Keras session.""" + with self._graph.as_default(): + default_session = K.get_session() + # N.B. We have to call `K.set_session()` AND set our session as the + # TF default. `K.get_session()` surprisingly does not return the value + # supplied by K.set_session otherwise. + K.set_session(self._session) + with self._session.as_default(): + yield self._session + K.set_session(default_session) + + def shutdown(self): + logging.info('Shutting down TPU session.') + with self.tpu_session() as session: + session.run(tpu.shutdown_system()) + + self._session.close() def _validate_shapes(model): @@ -504,26 +656,8 @@ Output shape: %(output_shape)s @experimental -def tpu_model(model, replicas=None): - """Runs a model on TPU(s). - - Usage: - ``` - a = Input(shape=(32,)) - b = Dense(32)(a) - model = Model(inputs=a, outputs=b) - - model = keras_support.tpu_model(model) - model.compile( - optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), - ...) - ``` - - If `replicas` is set, replicates the model computation on all TPU cores. The - model computation is replicated `num_replicas` times; each shard will run on a - different TPU core. - - Limitation: Currently, replication is only supported for training. +def tpu_model(model, tpu_name_or_address=None, strategy=None): + """Copy `model` along with weights to the TPU. Returns a TPU model. Usage: ``` @@ -531,17 +665,24 @@ def tpu_model(model, replicas=None): b = Dense(32)(a) model = Model(inputs=a, outputs=b) - model = keras_support.tpu_model(model, replicas=2) + # If `num_cores_per_host` is greater than one, batch parallelism will be used + # to run on multiple TPU cores. + strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8) + model = keras_support.tpu_model(model, strategy) model.compile( optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), ...) + model.shutdown() ``` Args: model: A `KerasTPUModel`. - replicas: (Optional) Int, number of TPU cores which to create model - replicas. If `None`, the model runs on single core only, i.e., no - replication. + tpu_name_or_address: A string that is either the name of the Cloud TPU, + the grpc address of the Cloud TPU, or (Googlers only) the BNS name of the + Cloud TPU. If tpu_name_or_address is None, the TPUClusterResolver will + examine the environment to determine a potential Cloud TPU to use. + strategy: `TPUDistributionStrategy`. The strategy to use for replicating + model across multiple TPU cores. Returns: A new `KerasTPUModel` instance. @@ -550,7 +691,9 @@ def tpu_model(model, replicas=None): # TODO(xiejw): Validate TPU model. TPUModel only? # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset? # TODO(xiejw): Adds reduction option. - replicas = 1 if replicas is None else replicas + if strategy is None: + strategy = TPUDistributionStrategy(num_cores_per_host=1) return KerasTPUModel( - inputs=model.inputs, outputs=model.outputs, name=model.name, - replicas=replicas) + cpu_model=model, + tpu_name_or_address=tpu_name_or_address, + strategy=strategy) diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index cda9a63f204ed686b527c95dd5b4fd7786ac60cf..1fb26e701a392d5ef3bc40d5772d4541fa38f773 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -55,8 +55,9 @@ class Topology(object): rank 3 numpy int32 array that describes a valid coordinate mapping. """ + self._serialized = serialized + if serialized: - self._serialized = serialized self._parse_topology(serialized) else: self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32) @@ -131,7 +132,7 @@ class Topology(object): proto.mesh_shape[:] = list(self._mesh_shape) proto.num_tasks = self._device_coordinates.shape[0] proto.num_tpu_devices_per_task = self._device_coordinates.shape[1] - proto.device_coordinates = list(self._device_coordinates.flatten()) + proto.device_coordinates.extend(list(self._device_coordinates.flatten())) self._serialized = proto.SerializeToString() return self._serialized diff --git a/tensorflow/contrib/tpu/python/tpu/topology_test.py b/tensorflow/contrib/tpu/python/tpu/topology_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e67fdb263aa48a37f65c3623365ebcf8f98bebd4 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/topology_test.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Tests for topology.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import topology + +from tensorflow.python.platform import test + + +class TopologyTest(test.TestCase): + + def testSerialization(self): + """Test if the class is able to generate serialzied string.""" + original_topology = topology.Topology( + mesh_shape=[1, 1, 2], + device_coordinates=[[[0, 0, 0], [0, 0, 1]]], + ) + serialized_str = original_topology.serialized() + new_topology = topology.Topology(serialized=serialized_str) + + # Make sure the topology recovered from serialized str is same as the + # original topology. + self.assertAllEqual( + original_topology.mesh_shape, new_topology.mesh_shape) + self.assertAllEqual( + original_topology.device_coordinates, new_topology.device_coordinates) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index c8f24ed01d13a1325ed3d77d1d91d4df79b0e379..6a64893d9abcd64360554ab00502cdf360b820b6 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -21,6 +21,7 @@ from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function @@ -125,7 +126,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): outside the replicated computation. """ - def __init__(self, name, num_replicas): + def __init__(self, name, num_replicas, pivot): + """Builds a new TPUReplicateContext. + + Args: + name: a unique name for the context, used to populate the `_tpu_replicate` + attribute. + num_replicas: an integer that gives the number of replicas for the + computation. + pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any + inputs will have a control dependency on the pivot node. This ensures + that nodes are correctly included in any enclosing control flow + contexts. + """ super(TPUReplicateContext, self).__init__() self._num_replicas = num_replicas self._outer_device_function_stack = None @@ -137,6 +150,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._host_compute_core = [] self._name = name self._unsupported_ops = [] + self._pivot = pivot def report_unsupported_operations(self): if self._unsupported_ops: @@ -213,19 +227,26 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): class FakeOp(object): """A helper class to determine the current device. - Supports only the device set/get methods needed to run the + Supports only the type and device set/get methods needed to run the graph's _apply_device_function method. """ def __init__(self): self._device = "" + @property + def type(self): + return "FakeOp" + @property def device(self): return self._device def _set_device(self, device): - self._device = device.to_string() + if isinstance(device, pydev.DeviceSpec): + self._device = device.to_string() + else: + self._device = device if self._outside_compilation_cluster: raise NotImplementedError("Cannot nest outside_compilation clusters") @@ -261,9 +282,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access super(TPUReplicateContext, self).Enter() - def Exit(self): - super(TPUReplicateContext, self).Exit() - def HostComputeCore(self): return self._host_compute_core @@ -299,10 +317,64 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) + # Remove any control edges from outer control flow contexts. These may cause + # mismatched frame errors. + control_inputs, external_inputs = self._RemoveExternalControlEdges(op) + + if not op.inputs: + # Add a control edge from the control pivot to this op. + if not control_inputs: + # pylint: disable=protected-access + op._add_control_input(self.GetControlPivot()) + # pylint: enable=protected-access + else: + for index in xrange(len(op.inputs)): + x = op.inputs[index] + real_x = self.AddValue(x) + if real_x != x: + op._update_input(index, real_x) # pylint: disable=protected-access + + if external_inputs: + # Use an identity to pull control inputs as data inputs. Note that we + # ignore ops which don't have outputs. TODO(phawkins): fix that. + with ops.control_dependencies(None): + self.Enter() + external_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_inputs + if x.outputs + ] + self.Exit() + # pylint: disable=protected-access + op._add_control_inputs(external_inputs) + # pylint: enable=protected-access + + # Mark op's outputs as seen by this context and any outer contexts. + output_names = [x.name for x in op.outputs] + context = self + while context is not None: + # pylint: disable=protected-access + context._values.update(output_names) + context = context._outer_context + # pylint: enable=protected-access + + if self._outer_context: + self._outer_context.AddInnerOp(op) + def AddValue(self, val): + if val.name in self._values: + # Use the real value if it comes from outer context. + result = self._external_values.get(val.name) + return val if result is None else result + result = val + self._values.add(val.name) if self._outer_context: result = self._outer_context.AddValue(val) + self._values.add(result.name) + + self._external_values[val.name] = result + return result def AddInnerOp(self, op): @@ -318,17 +390,30 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # grad_state should be as if this is the top-level gradient state. return None + @property + def back_prop(self): + """Forwards to the enclosing while context, if any.""" + if self.GetWhileContext(): + return self.GetWhileContext().back_prop + return False + + def GetControlPivot(self): + return self._pivot + -def outside_compilation(computation, args=None): +def outside_compilation(computation, *args, **kwargs): """Builds part of a computation outside any current TPU replicate scope. Args: computation: A Python function that builds the computation to place on the host. - args: Inputs to pass to computation. + *args: the positional arguments for the computation. + **kwargs: the keyword arguments for the computation. + Returns: The Tensors returned by computation. """ + args = [] if args is None else args graph = ops.get_default_graph() # If we are in a TPUReplicateContext, signal that we are now @@ -340,7 +425,7 @@ def outside_compilation(computation, args=None): context._EnterOutsideCompilationScope() # pylint: disable=protected-access context = context.outer_context - retval = computation(*args) + retval = computation(*args, **kwargs) # If we are in a TPUReplicateContext, signal that we are no longer # outside_compilation @@ -394,7 +479,8 @@ def split_compile_and_replicate(computation, inputs=None, infeed_queue=None, device_assignment=None, - name=None): + name=None, + use_tpu=True): """Builds graph operators that runs compilation and replicated computation. This is a lower level interface than replicate that returns a separate compile @@ -417,6 +503,9 @@ def split_compile_and_replicate(computation, only one core, and there is either only one replica, or the number of replicas is equal to the number of cores in the TPU system. name: (Deprecated) Does nothing. + use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU + backends. Currently, only supports a default placement (computation is + placed on GPU if one is available, and on CPU if not). Returns: A list of lists with the first list corresponding to the compile op and the second a list of output tensors, indexed by `[replica_num][output_num]`. @@ -497,26 +586,34 @@ def split_compile_and_replicate(computation, tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) cluster_name = graph.unique_name("cluster") - context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas) + pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") + context = TPUReplicateContext( + name=cluster_name, num_replicas=num_replicas, pivot=pivot) try: context.Enter() metadata = tpu_ops.tpu_replicate_metadata( - num_replicas=num_replicas, **metadata_kwargs) + num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) with tpu_function.tpu_shard_context( num_replicas), ops.control_dependencies([metadata]): - # The EncapsulateTPUComputations rewrite needs to identify the - # replicated arguments inside each computation. Adds identity operators - # tagged with an attribute _tpu_replicated_input to identify the - # replicated inputs. + # For backward compatibility reasons, we tag replicated inputs with the + # _tpu_replicated_input attribute. This does nothing and exists only for + # backward compatibility. + # TODO(phawkins): delete the attr_scope after 6/28/2018. # pylint: disable=protected-access - with graph._attr_scope({"_tpu_replicated_input": - attr_value_pb2.AttrValue(b=True)}): + with graph._attr_scope({ + "_tpu_replicated_input": attr_value_pb2.AttrValue(b=True) + }): + # Add identity ops so even unused inputs are "consumed" by the + # computation. This is to avoid orphaned TPUReplicatedInput nodes. + # TODO(phawkins): consider instead pruning unused TPUReplicatedInput + # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. computation_inputs = [ array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(computation_inputs)] + for i, x in enumerate(computation_inputs) + ] # pylint: enable=protected-access # If there is an infeed queue, adds the dequeued values to the @@ -539,10 +636,16 @@ def split_compile_and_replicate(computation, vscope.set_use_resource(saved_use_resource) + # If the computation returns `None`, make it an empty tuple. + if outputs is None: + outputs = tuple() # If the computation only returned one value, makes it a tuple. if not isinstance(outputs, (list, tuple)): outputs = (outputs,) + # Append `no_op` here so that fetching any return value of this function + # will trigger TPUExecute node. + outputs += (control_flow_ops.no_op(),) try: with ops.device(core(0)): outputs = [ @@ -574,6 +677,7 @@ def split_compile_and_replicate(computation, with ops.device(t.device if t.device else core(0)): new_output_tensors.append(array_ops.identity(t)) output_tensors = new_output_tensors + context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() @@ -590,10 +694,13 @@ def split_compile_and_replicate(computation, for i in xrange(output_arity)] with ops.control_dependencies([metadata]): - compile_status = tpu_ops.tpu_compilation_result() - op = compile_status.op - attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) - op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access + if use_tpu: + compile_status = tpu_ops.tpu_compilation_result() + op = compile_status.op + attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) + op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access + else: + compile_status = control_flow_ops.no_op(name="compilation_status") with ops.control_dependencies(output_operations): if output_arity == 0: @@ -860,3 +967,152 @@ def rewrite(computation, device_assignment=device_assignment, name=name)[0] # pylint: enable=indexing-exception + + # Operations that indicate some error in the user's inference graph. +_BLACKLISTED_INFERENCE_OPS = set([ + "ReadVariableOp", + "AssignVariableOp", + "AssignAddVariableOp", + "AssignSubVariableOp", + "VarHandleOp", + "Variable", + "VariableV2", +]) + + +class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): + """A `ControlFlowContext` for nodes inside a TPU inference computation. + + The primary role of `TPUReplicateContext` is to sanity check operators inside + a tpu.rewrite_for_inference() computation. + """ + + def __init__(self, name): + super(_TPUInferenceContext, self).__init__() + self._name = name + + def AddOp(self, op): + self._AddOpInternal(op) + + def _AddOpInternal(self, op): + # pylint: disable=protected-access + if op.type in _BLACKLISTED_INFERENCE_OPS: + raise NotImplementedError( + "Operation of type %s (%s) is not supported on the TPU for inference." + " Execution will fail if this op is used in the graph. Make sure your" + " variables are using variable_scope." % (op.type, op.name)) + if self._outer_context: + self._outer_context.AddInnerOp(op) + + def AddValue(self, val): + result = val + if self._outer_context: + result = self._outer_context.AddValue(val) + return result + + def AddInnerOp(self, op): + self._AddOpInternal(op) + + @property + def grad_state(self): + return None + + +@experimental +def validate_inference_rewrite_for_variables(graph): + """Validates whether rewrite_for_inference() 'worked' for variables. + + The rewrite_for_inference() method is supposed to append + GuaranteeConstOps after ReadVariableOps, but this mechanism works only + if you are using tf.get_variable() to create and access variables in your + tpu computation. This validation method can be called immediately after + calling tpu.rewrite_for_inference() to check whether GuaranteeConstOps + where added to the graph. + + Typical usages: + tpu.validate_inference_rewrite_for_variables(tf.get_default_graph()) + + tpu.validate_inference_rewrite_for_variables(sess.graph) + + Args: + graph: The graph which needs to be validated. + Raises: + RuntimeError: if validation failed. + """ + if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]): + raise RuntimeError( + "No GuaranteeConst ops found in the graph after " + "running tpu.rewrite_for_inference(...). Please " + "check that you are using tf.get_variable() to " + "create and access variables in your tpu " + "computation.") + + +@experimental +def rewrite_for_inference(computation, + inputs=None, + infeed_queue=None, + device_assignment=None, + name=None): + """Rewrites `computation` for inference on a TPU system. + + Other than 'rewriting' the computation to run on a TPU, if using variables + in your computation, it moves the ReadVariableOps outside the TPU + computation, and adds GuaranteeConst ops just after the ReadVariableOps. + This mechanism works only if you are using tf.get_variable() to create and + access variables in your tpu computation. You can validate whether + this worked, by calling validate_inference_rewrite_for_variables() method + immediately after this method to check whether GuaranteeConstOps where + added to the graph. + + Args: + computation: A Python function that builds a computation to apply + to the input. If the function takes n inputs, 'inputs' should be + a list of n tensors. If the function returns m outputs, rewrite + will return a list of m tensors. + inputs: A list of input tensors or `None` (equivalent to an empty list). + infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple + of arguments as inputs to `computation`. + device_assignment: if not `None`, a `DeviceAssignment` describing the + mapping between logical cores in the computation with physical cores in + the TPU topology. May be omitted for a single-core computation, in which + case the core attached to task 0, TPU device 0 is used. + name: The name of the operator. + Returns: + A list of output tensors. + """ + + def guarantee_const_getter(getter, name, *args, **kwargs): + with ops.control_dependencies(None): + return array_ops.guarantee_const( + getter(name, *args, **kwargs), name=name + "/GuaranteeConst") + + def wrapped_computation(*args, **kwargs): + """Execute computation under `_TPUInferenceContext`.""" + context = _TPUInferenceContext( + name=ops.get_default_graph().unique_name("rewrite_for_inference")) + try: + context.Enter() + + vscope = variable_scope.get_variable_scope() + prev_custom_getter = vscope.custom_getter + prev_caching_device = vscope.caching_device + vscope.set_custom_getter(guarantee_const_getter) + vscope.set_caching_device(lambda op: op.device) + + result = computation(*args, **kwargs) + + vscope.set_custom_getter(prev_custom_getter) + vscope.set_caching_device(prev_caching_device) + finally: + context.Exit() + return result + + # pylint: disable=undefined-variable + return rewrite( + wrapped_computation, + inputs=inputs, + infeed_queue=infeed_queue, + device_assignment=device_assignment, + name=name) + # pylint: enable=undefined-variable diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 5b9aeaa8797b92b4cc596744812f440607054dce..aec59f3885ca7a2046c24ce5b94917ad6c3693e7 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -92,6 +92,19 @@ class TPUContext(object): """ return self._internal_ctx.num_replicas + @property + def num_hosts(self): + """The number of hosts for the TPU system.""" + return self._internal_ctx.num_hosts + + @property + def num_of_replicas_per_host(self): + """The number of replicas for each host.""" + if self._internal_ctx.model_parallelism_enabled: + raise ValueError( + 'num_of_replicas_per_host is not supported for model_parallelism') + return self._internal_ctx.num_of_replicas_per_host + def device_for_replica(self, replica_id): """Returns the tuple of (CPU device and device ordinal) for replica. @@ -384,9 +397,7 @@ class _InternalTPUContext(object): # On TPU if self.is_input_sharded_per_core() or ( self.is_input_per_host_with_iterators()): - # We prohibit per core input sharding for the model parallelism case, - # therefore it is safe to use num_cores here. - return global_batch_size // self.num_cores + return global_batch_size // self.num_replicas else: return global_batch_size // self.num_hosts @@ -484,25 +495,27 @@ class _InternalTPUContext(object): return _placement_function - @property - def tpu_ordinal_function(self): + def tpu_ordinal_function(self, host_id): """Returns the TPU ordinal fn.""" - def _tpu_ordinal_function(index): + def _tpu_ordinal_function(shard_index_in_host): """Return the TPU ordinal associated with a shard. Required because the enqueue ops are placed on CPU. Args: - index: the shard index + shard_index_in_host: the shard index Returns: The ordinal of the TPU device the shard's infeed should be placed on. """ if self.model_parallelism_enabled: - return self.device_assignment.tpu_ordinal(replica=index) + # We put both enqueue/dequeue ops at tpu.core(0) in each replica. + replica = self.device_assignment.lookup_replicas( + host_id, (0, 0, 0))[shard_index_in_host] + return self.device_assignment.tpu_ordinal(replica=replica) else: - return index % self.num_of_cores_per_host + return shard_index_in_host % self.num_of_cores_per_host return _tpu_ordinal_function diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 808545bb56134ac1b5da32a7c687119b8cfe6f97..49cd318b8956369f49d77d3cb1b030e171fa07aa 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -46,6 +46,8 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -61,6 +63,7 @@ from tensorflow.python.ops import summary_ops_v2 as contrib_summary from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import evaluation @@ -71,17 +74,24 @@ from tensorflow.python.util import function_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect + _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CTX_KEY = 'context' +_USE_TPU_KEY = 'use_tpu' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' +_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' +# Ideally _USE_TPU_KEY should be reserved as well. However there are already +# models that make use of this key, thus it can not be reserved now to prevent +# breakage. In the long run, we would like to mitigate this by migrating models +# off of using _USE_TPU_KEY. _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] @@ -117,6 +127,33 @@ def _create_global_step(graph): def _create_or_get_iterations_per_loop(): + """Creates or gets the iterations_per_loop variable. + + In TPUEstimator, the user provided computation, the model_fn, is wrapped + inside a tf.while_loop for peak performance. The iterations of the loop are + specified by this variable, which adjusts its value on the CPU after each TPU + program execution and before the next TPU execution. + + The purpose of using a variable, rather then a constant, is to allow + TPUEstimator adapt the TPU training iterations according to the final steps + specified by users. For example, if the user sets the iterations_per_loop as 4 + in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop + variable will have the following value before each TPU training. + + - 1-th TPU execution: iterations_per_loop = 4 + - 2-th TPU execution: iterations_per_loop = 4 + - 3-th TPU execution: iterations_per_loop = 2 + + As model_fn increases the global step once per train_op invocation, the global + step is 10 after all TPU executions, matching the steps=10 inputs passed in by + users. + + Returns: + A TF non-trainable resource variable. + + Raises: + RuntimeError: If multi iterations_per_loop variables were found. + """ graph = ops.get_default_graph() collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) iter_vars = graph.get_collection(collection_name) @@ -179,8 +216,8 @@ class _SIGNAL(object): class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. - See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and - 'export_outputs`. + See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and + `export_outputs`. For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where `metric_fn` runs on CPU to generate metrics and `tensors` represents the @@ -194,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote size is the first dimension. Once all tensors are available at CPU host from all shards, they are concatenated (on CPU) and passed as positional arguments to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is - dict. `metric_fn` takes the `tensors` and returns a dict from metric string + a dict. `metric_fn` takes the `tensors` and returns a dict from metric string name to the result of calling a metric function, namely a `(metric_tensor, update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the `eval_metrics`. @@ -383,20 +420,21 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): return def _cancel_session(): - # Close the session to avoid the main thread from hanging. If input - # pipeline triggers any error, the infeed thread dies but the main thread - # for TPU computation waits for the infeed enqueue forever. Close the - # Session to cancel the main thread Session.run execution. - # - # We sleep for a few seconds before closing to give some time - # for the TPU compilation error, if any, propagating, from TPU to CPU - # host. Compilation errors should be reported by the main thread so that - # the program can be interrupted and users can take action. Due to a race - # condition, the infeed thread might see an error first. Closing the - # session here immediately would result in a session cancellation - # exception in the main thread, instead of the expected compile error. - # User code that depends on having the proper exception type will - # therefore be confused. + """Close the session to avoid the main thread from hanging. + + If input pipeline triggers any error, the infeed thread dies but the main + thread for TPU computation waits for the infeed enqueue forever. Close the + Session to cancel the main thread Session.run execution. + + We sleep for a few seconds before closing to give some time for the TPU + compilation error, if any, propagating, from TPU to CPU host. Compilation + errors should be reported by the main thread so that the program can be + interrupted and users can take action. Due to a race condition, the + infeed thread might see an error first. Closing the session here + immediately would result in a session cancellation exception in the main + thread, instead of the expected compile error. User code that depends on + having the proper exception type will therefore be confused. + """ time.sleep(5) # If the main session is still running, the infeed/outfeed errors are @@ -631,6 +669,7 @@ def generate_per_core_enqueue_ops_fn_for_host( ctx, input_fn, inputs_structure_recorder, host_device, host_id): """Generates infeed enqueue ops for per-core input_fn on a single host.""" captured_infeed_queue = _CapturedObject() + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """A fn returns enqueue_ops.""" @@ -666,7 +705,7 @@ def generate_per_core_enqueue_ops_fn_for_host( per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function) + per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) return per_host_enqueue_ops return enqueue_ops_fn, captured_infeed_queue @@ -701,21 +740,18 @@ def generate_per_host_enqueue_ops_fn_for_host( if is_dataset: hooks.append(inputs.dataset_initializer_hook()) - # TODO(ylc): Refactoring the code to merge the tpu ordinal logic here and the - # _InternalTPUContext.tpu_ordinal_function. We should either introduce another - # abstraction or a different helper method. - def _tpu_ordinal_function_impl(shard_index_in_host): - # We put both enqueue/dequeue op at tpu.core(0) in each replica. - replica = ctx.device_assignment.lookup_replicas( - host_id, (0, 0, 0))[shard_index_in_host] - return ctx.device_assignment.tpu_ordinal(replica=replica) - - if ctx.model_parallelism_enabled: - tpu_ordinal_function = _tpu_ordinal_function_impl - else: - tpu_ordinal_function = None + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): + """A Fn returning the TPU infeed enqueue ops. + + By providing as a Fn, it can be invoked inside the tf.while_loop such that + the input pipeline for multiple iterations can be executed by one + Session.run call. + + Returns: + list of dict of ops. + """ with ops.device(device): num_of_replicas_per_host = ctx.num_of_replicas_per_host # Convert user input to features and labels. If the user returns a @@ -740,7 +776,7 @@ def generate_per_host_enqueue_ops_fn_for_host( infeed_queue.split_inputs_and_generate_enqueue_ops( unsharded_tensor_list, placement_function=lambda x: device, - tpu_ordinal_function=tpu_ordinal_function)) + tpu_ordinal_function=tpu_ordinal_function_impl)) if signals is None: return per_host_enqueue_ops else: @@ -774,6 +810,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( raise TypeError('Most PREDICT not yet supported in PER_HOST_V2 mode.') hooks.append(inputs.dataset_initializer_hook()) + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """Generates the per_host enqueue ops.""" @@ -804,7 +841,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function) + per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) return per_host_enqueue_ops return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset @@ -1090,15 +1127,21 @@ class _InputPipeline(object): return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator def _validate_input_pipeline(self): - # Perform some sanity checks to log user friendly information. We should - # error out to give users better error message. But, if - # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break - # user code, so, log a warning. + """Validates the input pipeline. + + Perform some sanity checks to log user friendly information. We should + error out to give users better error message. But, if + _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break + user code, so, log a warning. + + Raises: + RuntimeError: If the validation failed. + """ if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): err_msg = ('Input pipeline contains one or more QueueRunners. ' 'It could be slow and not scalable. Please consider ' 'converting your input pipeline to use `tf.data` instead (see ' - 'https://www.tensorflow.org/programmers_guide/datasets for ' + 'https://www.tensorflow.org/guide/datasets for ' 'instructions.') if _WRAP_INPUT_FN_INTO_WHILE_LOOP: raise RuntimeError(err_msg) @@ -1264,13 +1307,11 @@ class _ModelFnWrapper(object): 'estimator_spec used by TPU prediction must have type' '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) + self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) + captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) to_record = {} identity_fn = lambda **kwargs: kwargs - # TODO(xiejw): Adds validation for prediction dictionrary. - # TODO(xiejw): Adds support for single tensor as predictions. - if not isinstance(tpu_estimator_spec.predictions, dict): - raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] to_record['signals'] = [identity_fn, stopping_signals] if tpu_estimator_spec.host_call is not None: @@ -1282,8 +1323,70 @@ class _ModelFnWrapper(object): return predict_step, host_calls, captured_scaffold_fn + def _verify_tpu_spec_predictions(self, predictions): + """Validates TPUEstimatorSpec.predictions dict.""" + # TODO(xiejw): Adds validation for prediction dictionrary. + # TODO(xiejw): Adds support for single tensor as predictions. + if not isinstance(predictions, dict): + raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') + + for (key, tensor) in predictions.items(): + if tensor.shape[0].value is None: + raise ValueError( + 'The tensor with key ({}) in TPUEstimatorSpec.predictions has ' + 'dynamic shape (should be static). Tensor: {}'.format( + key, tensor)) + return predictions + + def _validate_model_features_and_labels(self, + features, + labels, + is_export_mode): + """Validates that the features and labels for the model function are valid. + + A valid features/labels object is the one with: + - Type: Tensor or a dictionary of Tensors + - Static shape if is_export_mode is False. + + Args: + features: the features that would be input to the model function. + labels: the labels that would be input to the model function. + is_export_mode: boolean value specifying if in export mode. + + Raises: + TypeError: If features/labels are not of the correct type. + ValueError: If features/labels have dynamic shape. + """ + + def validate(obj, obj_name): + """Helper validate function.""" + if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict): + raise TypeError( + 'The {} to the model returned by input_fn must be either a Tensor ' + 'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name, + obj)) + if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode): + return + if isinstance(obj, ops.Tensor): + if not obj.get_shape().is_fully_defined(): + raise ValueError( + 'The {} to the model returned by input_fn must have static shape.' + ' Tensor: {}'.format(obj_name, obj)) + else: + for (key, tensor) in obj.items(): + if not tensor.get_shape().is_fully_defined(): + raise ValueError( + 'The {} to the model returned by input_fn must have static ' + 'shape. Key: \'{}\', Tensor: {}'.format( + obj_name, key, tensor)) + + validate(features, 'features') + if labels is not None: + validate(labels, 'labels') + def _call_model_fn(self, features, labels, is_export_mode=False): """Calls the model_fn with required parameters.""" + self._validate_model_features_and_labels(features, labels, is_export_mode) model_fn_args = function_utils.fn_args(self._model_fn) kwargs = {} @@ -1314,13 +1417,13 @@ class _ModelFnWrapper(object): batch_size_for_model_fn = self._ctx.batch_size_for_model_fn if batch_size_for_model_fn is not None: - if isinstance(params, hparam.HParams): - params.add_hparam(_BATCH_SIZE_KEY, batch_size_for_model_fn) - else: - params[_BATCH_SIZE_KEY] = batch_size_for_model_fn + _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) + + running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) + _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) estimator_spec = self._model_fn(features=features, **kwargs) - if (self._ctx.is_running_on_cpu(is_export_mode) and + if (running_on_cpu and isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access # The estimator_spec will be passed to `Estimator` directly, which expects # type `EstimatorSpec`. @@ -1763,8 +1866,40 @@ class TPUEstimator(estimator_lib.Estimator): Exporting ========= - Exporting `SavedModel` support on TPU is not yet implemented. So, - `export_savedmodel` is executed on CPU, even if `use_tpu` is true. + `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`, + and another with `tag_constants.SERVING` and `tag_constants.TPU`. + At serving time, these tags are used to select metagraph to load. + + Before running the graph on TPU, TPU system needs to be initialized. If + TensorFlow Serving model-server is used, this is done automatically. If + not, please call `session.run(tpu.initialize_system())`. + + `tpu.outside_compilation` can be used to wrap TPU incompatible ops in + `model_fn`. + + Example: + ---------------- + + ``` + def model_fn(features, labels, mode, config, params): + ... + logits = ... + export_outputs = { + 'logits': export_output_lib.PredictOutput( + {'logits': logits}) + } + + def host_call(logits): + class_ids = math_ops.argmax(logits) + classes = string_ops.as_string(class_ids) + export_outputs['classes'] = + export_output_lib.ClassificationOutput(classes=classes) + + tpu.outside_compilation(host_call, logits) + + ... + ``` + """ def __init__(self, @@ -1778,13 +1913,15 @@ class TPUEstimator(estimator_lib.Estimator): predict_batch_size=None, batch_axis=None, eval_on_tpu=True, + export_to_tpu=True, warm_start_from=None): """Constructs an `TPUEstimator` instance. Args: model_fn: Model function as required by `Estimator`. For training, the returned `EstimatorSpec` cannot have hooks as it is not supported in - `TPUEstimator`. + `TPUEstimator`. Instead, the user can pass the training hooks as + an argument to `TPUEstimator.train()`. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If `None`, the model_dir in @@ -1820,6 +1957,8 @@ class TPUEstimator(estimator_lib.Estimator): False or `PER_HOST_V2`, batch_axis is ignored. eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. + export_to_tpu: If True, `export_savedmodel()` exports a metagraph for + serving on TPU besides the one on CPU. warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a `tf.estimator.WarmStartSettings` object to fully configure warm-starting. If the string @@ -1891,8 +2030,126 @@ class TPUEstimator(estimator_lib.Estimator): use_tpu, eval_on_tpu) + self._export_to_tpu = export_to_tpu + self._is_input_fn_invoked = None + def _add_meta_graph_for_mode(self, + builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=True, + mode=model_fn_lib.ModeKeys.PREDICT, + export_tags=None, + check_variables=True): + if mode != model_fn_lib.ModeKeys.PREDICT: + raise NotImplementedError( + 'TPUEstimator only handles mode PREDICT for export_savedmodel(); ' + 'got {}.'.format(mode)) + + (super(TPUEstimator, self). + _add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables, + mode=mode, + export_tags=export_tags, + check_variables=check_variables)) + + if self._export_to_tpu: + input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE: + input_receiver_fn_map[mode]} + export_tags = [tag_constants.SERVING, tag_constants.TPU] + mode = _REWRITE_FOR_INFERENCE_MODE + # See b/110052256 for why `check_variables` is `False`. + (super(TPUEstimator, self). + _add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=False, + mode=mode, + export_tags=export_tags, + check_variables=False)) + + def _call_model_fn(self, features, labels, mode, config): + if mode == _REWRITE_FOR_INFERENCE_MODE: + return self._call_model_fn_for_inference(features, labels, mode, config) + else: + return super(TPUEstimator, self)._call_model_fn( + features, labels, mode, config) + + def _call_model_fn_for_inference(self, features, labels, mode, config): + """Wraps `_call_model_fn` for `export_savedmodel`.""" + if mode != _REWRITE_FOR_INFERENCE_MODE: + raise ValueError('mode must be {}; ' + 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) + + capture = _CapturedObject() + + def computation(): + """Compute tpu tensors used in export_outputs. + + Passed to rewrite_for_inference so that model_fn will be called under + the rewriting contexts. Only tpu tensors are returned, but export_outputs + and scaffold are captured. + + Returns: + A list of Tensors used in export_outputs and not marked for + outside_compilation. + """ + # We should only call model fn once and it should be inside `computation` + # so that building the graph will happen under `rewrite_for_inference`. + mode = model_fn_lib.ModeKeys.PREDICT + estimator_spec = self._call_model_fn(features, labels, mode, config) + + # We pick the TPU tensors out from `export_output` and later return them + # from `computation` for rewriting. + tensors_dict = collections.OrderedDict( + (k, _export_output_to_tensors(v)) + for k, v in six.iteritems(estimator_spec.export_outputs) + ) + tensors = nest.flatten(tensors_dict) + tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] + + # We cannot return anything other than `tpu_tensors` here so we capture + # the rest for later use. + capture.capture((estimator_spec, tensors_dict, tensors)) + return tpu_tensors + + tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) + estimator_spec, tensors_dict, tensors = capture.get() + + # Reconstruct `tensors`, but with `tpu_tensors` replaced with + # `tpu_tensors_on_cpu`. + new_tensors = [] + for t in tensors: + if _is_tpu_tensor(t): + new_tensors.append(tpu_tensors_on_cpu.pop(0)) + elif t is None: + new_tensors.append(None) + else: + # Only fetching `tpu_tensors_on_cpu` does not trigger + # TPU computation and blocks, so we add the control dependency here. + control_inputs = (tpu_tensors_on_cpu + if isinstance(tpu_tensors_on_cpu, (list, tuple)) + else (tpu_tensors_on_cpu,)) + with ops.control_dependencies(control_inputs): + new_tensors.append(array_ops.identity(t)) + + # Reconstruct `tensors_dict`. + new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) + # Reconstruct `export_outputs`. + export_outputs = estimator_spec.export_outputs + new_export_outputs = collections.OrderedDict( + (k, _clone_export_output_with_tensors(export_outputs[k], v)) + for k, v in six.iteritems(new_tensors_dict) + ) + + return estimator_spec._replace(export_outputs=new_export_outputs) + def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -1983,10 +2240,8 @@ class TPUEstimator(estimator_lib.Estimator): # input_fn for use_tpu=True/False. batch_size_for_input_fn = ctx.batch_size_for_input_fn if batch_size_for_input_fn is not None: - if isinstance(kwargs['params'], hparam.HParams): - kwargs['params'].add_hparam(_BATCH_SIZE_KEY, batch_size_for_input_fn) - else: - kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn + _add_item_to_params(kwargs['params'], + _BATCH_SIZE_KEY, batch_size_for_input_fn) # For export_savedmodel, input_fn is never passed to Estimator. So, # `is_export_mode` must be False. @@ -2005,7 +2260,7 @@ class TPUEstimator(estimator_lib.Estimator): # dequeue_fn to model_fn. Here, `input_fn` is passed directly as # `features` in `model_fn` signature. def _input_fn(ctx): - kwargs['params'][_CTX_KEY] = ctx + _add_item_to_params(kwargs['params'], _CTX_KEY, ctx) return input_fn(**kwargs) return _input_fn @@ -2077,11 +2332,11 @@ class TPUEstimator(estimator_lib.Estimator): if shutdown_mode: if shutdown_mode == 'shutdown_worker': finalizer_hooks = [ - session_support.ShutdownLameWorkers(timeout_ms=1000), + session_support.ShutdownLameWorkers(timeout_ms=60*1000), ] elif shutdown_mode == 'shutdown_computation': finalizer_hooks = [ - session_support.RestartComputation(timeout_ms=1000), + session_support.RestartComputation(timeout_ms=60*1000), ] else: raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % @@ -2270,6 +2525,76 @@ class TPUEstimator(estimator_lib.Estimator): return _model_fn +def _is_tpu_tensor(tensor): + if not isinstance(tensor, ops.Tensor): + return False + try: + tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access + except ValueError: + return True + else: + return False + + +def _export_output_to_tensors(export_output): + """Get a list of `Tensors` used in `export_output`. + + Args: + export_output: an `ExportOutput` object such as `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + Returns: + a list of tensors used in export_output. + + Raises: + ValueError: if `export_output` is not one of `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + """ + if isinstance(export_output, export_output_lib.ClassificationOutput): + return [export_output.scores, export_output.classes] + elif isinstance(export_output, export_output_lib.RegressionOutput): + return [export_output.value] + elif isinstance(export_output, export_output_lib.PredictOutput): + return export_output.outputs.values() + else: + raise ValueError( + '`export_output` must be have type `ClassificationOutput`, ' + '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) + + +def _clone_export_output_with_tensors(export_output, tensors): + """Clones `export_output` but with new `tensors`. + + Args: + export_output: an `ExportOutput` object such as `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + tensors: a list of `Tensors` used to construct a new `export_output`. + + Returns: + A dict similar to `export_output` but with `tensors`. + + Raises: + ValueError: if `export_output` is not one of `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + """ + if isinstance(export_output, export_output_lib.ClassificationOutput): + if len(tensors) != 2: + raise ValueError('tensors must be of length 2; ' + 'got {}.'.format(len(tensors))) + return export_output_lib.ClassificationOutput(*tensors) + elif isinstance(export_output, export_output_lib.RegressionOutput): + if len(tensors) != 1: + raise ValueError('tensors must be of length 1; ' + 'got {}'.format(len(tensors))) + return export_output_lib.RegressionOutput(*tensors) + elif isinstance(export_output, export_output_lib.PredictOutput): + return export_output_lib.PredictOutput( + dict(zip(export_output.outputs.keys(), tensors))) + else: + raise ValueError( + '`export_output` must be have type `ClassificationOutput`, ' + '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) + + def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() @@ -2417,7 +2742,7 @@ class _CapturedObject(object): def capture(self, o): if self._captured: raise RuntimeError( - 'InternalError: Object can be captured only. Please file bug .') + 'InternalError: Object can capture only once. Please file bug.') self._captured = True self._object = o @@ -2426,7 +2751,7 @@ class _CapturedObject(object): if not self._captured: raise RuntimeError( 'InternalError: Object is not captured properly before `get`. ' - 'Please file bug .') + 'Please file bug.') return self._object @@ -2527,7 +2852,8 @@ class _Inputs(object): """ iterator = self._dataset.make_initializable_iterator() # pylint: disable=protected-access - hook = estimator_lib._DatasetInitializerHook(iterator) + hook = estimator_util._DatasetInitializerHook(iterator) + # pylint: enable=protected-access self._iterator = iterator return hook @@ -2673,6 +2999,7 @@ class _StopSignals(object): @staticmethod def should_stop(scalar_stopping_signal): + """Detects whether scalar_stopping_signal indicates stopping.""" if isinstance(scalar_stopping_signal, ops.Tensor): # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF # way to express the bool check whether scalar_stopping_signal is True. @@ -2792,7 +3119,7 @@ class _SignalsHelper(object): def __init__(self, signals): self._signal_keys = [] - for key in sorted(signals.iterkeys()): + for key in sorted(iter(signals.keys())): self._signal_keys.append(key) @property @@ -2804,7 +3131,7 @@ class _SignalsHelper(object): @staticmethod def as_tensor_list(signals): - return [signals[key] for key in sorted(signals.iterkeys())] + return [signals[key] for key in sorted(iter(signals.keys()))] def _verify_cross_hosts_transfer_size(tensor_dict, message): @@ -2823,3 +3150,16 @@ def _verify_cross_hosts_transfer_size(tensor_dict, message): '{}'.format(message, '\n'.join([ ' -- Key: {}, Shape: {}'.format(k, v) for k, v in tensor_structure.items()]))) + + +def _add_item_to_params(params, key, value): + """Adds a new item into `params`.""" + if isinstance(params, hparam.HParams): + # For HParams, we need to use special API. + if key in params: + params.set_hparam(key, value) + else: + params.add_hparam(key, value) + else: + # Now params is Python dict. + params[key] = value diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py index e76cf83e4ddcd86ab3971bcecefe2e2dc979bf63..15f99d7eebddd46f9f6902b68f01e42359a72cbe 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.ops.losses import losses @@ -32,7 +34,8 @@ class CrossShardOptimizer(optimizer.Optimizer): def __init__(self, opt, reduction=losses.Reduction.MEAN, - name="CrossShardOptimizer"): + name="CrossShardOptimizer", + group_assignment=None): """Construct a new cross-shard optimizer. Args: @@ -40,6 +43,8 @@ class CrossShardOptimizer(optimizer.Optimizer): reduction: The reduction to apply to the shard losses. name: Optional name prefix for the operations created when applying gradients. Defaults to "CrossShardOptimizer". + group_assignment: Optional list of group ids for applying the optimizer + to subgroups. Raises: ValueError: If reduction is not a valid cross-shard reduction. @@ -50,6 +55,35 @@ class CrossShardOptimizer(optimizer.Optimizer): super(CrossShardOptimizer, self).__init__(False, name) self._opt = opt self._reduction = reduction + self._group_assignment = group_assignment + + def _verify_and_get_subgroup_size(self, group_assignment, num_shards): + """Verify group_assignment and get the subgroup size". + + Args: + group_assignment: list of group ids for applying the optimizer + to subgroups. + num_shards: The number of TPU shards. + + Returns: + The size of one subgroup in group_assignment. + + Raises: + ValueError: If group_assignment is invalid. + """ + if not group_assignment: + return None + if len(group_assignment) != num_shards: + raise ValueError("The size of group_assignment does not equal to " + "num_shard({0}). Got group_assignment={1}".format( + num_shards, self._group_assignment)) + subgroup_size_list = dict(collections.Counter(group_assignment)).values() + if all(subgroup_size_list[0] == size for size in subgroup_size_list): + return subgroup_size_list[0] + else: + raise ValueError("The size of each subgroup in group_assignment must " + "be equal. Got group_assignment={}".format( + self._group_assignment)) def compute_gradients(self, loss, var_list=None, **kwargs): """Compute gradients of "loss" for the variables in "var_list". @@ -71,7 +105,8 @@ class CrossShardOptimizer(optimizer.Optimizer): A list of (gradient, variable) pairs. Raises: - ValueError: If not within a tpu_shard_context. + ValueError: If not within a tpu_shard_context or group_assignment is + invalid. """ num_shards = tpu_function.get_tpu_context().number_of_shards if num_shards is None: @@ -79,9 +114,17 @@ class CrossShardOptimizer(optimizer.Optimizer): "CrossShardOptimizer should be used within a tpu_shard_context, but " "got unset number_of_shards. Assuming 1.") num_shards = 1 + + subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment, + num_shards) + if num_shards > 1 and self._reduction == losses.Reduction.MEAN: - scale = 1.0 / num_shards + if self._group_assignment: + scale = 1.0 / subgroup_size + else: + scale = 1.0 / num_shards loss *= scale + return self._opt.compute_gradients(loss, var_list=var_list, **kwargs) def apply_gradients(self, grads_and_vars, global_step=None, name=None): @@ -110,7 +153,8 @@ class CrossShardOptimizer(optimizer.Optimizer): if grad is None: summed_grads_and_vars.append((grad, var)) else: - summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var)) + summed_grads_and_vars.append((tpu_ops.cross_replica_sum( + grad, self._group_assignment), var)) return self._opt.apply_gradients(summed_grads_and_vars, global_step, name) def get_slot(self, *args, **kwargs): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py index c3882b8a27bc835f906c47dc5219f280c53800b8..6bdaa528f9f946ae4b9813d554409da2406b1f8d 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.python.framework import dtypes from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops @@ -37,7 +38,8 @@ class TPUContextTest(test.TestCase): def testIsInContext(self): """Test that control_flow_util can check that we're in a TPU context.""" z1 = array_ops.identity(1) - context = tpu.TPUReplicateContext(b"context", 1) + pivot = control_flow_ops.no_op() + context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot) context.Enter() z2 = array_ops.identity(1) context.Exit() diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 5de55b5f7f2a41ac6edd27e5a102e565f33df12c..76927e62e82d02de172a0851819716dc63180371 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -295,7 +295,7 @@ py_test( tags = ["notsan"], deps = [ ":training_py", - "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", + "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py index 409aba817c1ec37003eb98f000f6cf8918234c5d..a2444934bc21d58ed57d15494b3548a31ce3a2df 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -45,14 +46,14 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset): self._input_dataset = input_dataset self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - # pylint: disable=protected-access if padded_shapes is None: self._padded_shapes = nest.map_structure( - dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes) + convert.partial_shape_to_tensor, input_dataset.output_shapes) else: self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor, + input_dataset.output_shapes, convert.partial_shape_to_tensor, padded_shapes) + # pylint: disable=protected-access padding_values = ( padding_values if padding_values is not None else dataset_ops._default_padding(input_dataset)) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py index 0338f409a203c232e63e99534a8f6d6a43fa661e..df0a186f4f6963d7e874bb4ab74a8db7e10a52ee 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py @@ -19,7 +19,7 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index 9720fd6e8657de18cf8d7565f834568ae52fdbda..19cb8983b6836266ebfac70c54657a96324e8435 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -53,12 +53,12 @@ cc_library( ":grpc_verbs_service_impl", ":rdma_mgr", ":verbs_service_proto_cc", + "//tensorflow:grpc++", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++_unsecure", ], alwayslink = 1, ) @@ -69,7 +69,7 @@ cc_library( hdrs = ["grpc_verbs_service_impl.h"], deps = [ ":verbs_service_proto_cc", - "@grpc//:grpc++_unsecure", + "//tensorflow:grpc++", ], ) diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc index 742f946c9536973eb8a6a11afda1b32ae4a7726b..af29abd91feda22824e57c19c13a3f48fb1d61b7 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc @@ -15,9 +15,9 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS -#include "grpc++/alarm.h" -#include "grpc++/grpc++.h" -#include "grpc++/server_builder.h" +#include "grpcpp/alarm.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/server_builder.h" #include "tensorflow/contrib/verbs/grpc_verbs_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc index 991f9a9d8bdf883b1b68bfa1fb6af7bf51b7e66a..4da7b59c69c88a4d04be37543aae7f03decd2c52 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h" -#include "grpc++/impl/codegen/async_stream.h" -#include "grpc++/impl/codegen/async_unary_call.h" -#include "grpc++/impl/codegen/channel_interface.h" -#include "grpc++/impl/codegen/client_unary_call.h" -#include "grpc++/impl/codegen/method_handler_impl.h" -#include "grpc++/impl/codegen/rpc_service_method.h" -#include "grpc++/impl/codegen/service_type.h" -#include "grpc++/impl/codegen/sync_stream.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/channel_interface.h" +#include "grpcpp/impl/codegen/client_unary_call.h" +#include "grpcpp/impl/codegen/method_handler_impl.h" +#include "grpcpp/impl/codegen/rpc_service_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/sync_stream.h" namespace tensorflow { diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index 1f0f10517e98a32ae882c027330091928f1a6ee2..abe5e08b07cd71b7ca28321e6eb2cf0eec5d1b0f 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -16,14 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ #define TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ -#include "grpc++/impl/codegen/async_stream.h" -#include "grpc++/impl/codegen/async_unary_call.h" -#include "grpc++/impl/codegen/proto_utils.h" -#include "grpc++/impl/codegen/rpc_method.h" -#include "grpc++/impl/codegen/service_type.h" -#include "grpc++/impl/codegen/status.h" -#include "grpc++/impl/codegen/stub_options.h" -#include "grpc++/impl/codegen/sync_stream.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/proto_utils.h" +#include "grpcpp/impl/codegen/rpc_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/status.h" +#include "grpcpp/impl/codegen/stub_options.h" +#include "grpcpp/impl/codegen/sync_stream.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 86350a08e57e5050f18d019fe80d70f6381c1f7d..f7c979e86320d59ad033e2b8d7fcdff89ce0d133 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -24,8 +24,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/process_util.h" #if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" #endif #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" @@ -1084,7 +1084,7 @@ void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed, // The tensor must be copied from GPU to CPU, because either: // 1. The tensor is located on a non GDR compatible GPU. // 2. The tensor's meta-data has changed. - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); copy = Tensor(alloc, in.dtype(), in.shape()); CountCopies(rm_.name_, (void*)DMAHelper::base(&in), (void*)DMAHelper::base(©), in.TotalBytes(), true); @@ -1541,7 +1541,7 @@ bool RdmaTensorRequest::AllocateTensors() { if (mr_ == nullptr) { // Can't RDMA directly to result. Use a proxy. proxy_tensor_ = - new Tensor(ProcessState::singleton()->GetCUDAHostAllocator(0), + new Tensor(GPUProcessState::singleton()->GetCUDAHostAllocator(0), result_tensor_->dtype(), result_tensor_->shape()); rdma_addr_ = DMAHelper::base(proxy_tensor_); mr_ = diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 369bd986df5313955bc22d6e5c6d38815908ada3..9cb3d1fbbfdbc6d85a7a9799bd82438f0bf70c4f 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -21,8 +21,9 @@ limitations under the License. #include "tensorflow/contrib/verbs/grpc_verbs_client.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" #include "tensorflow/core/common_runtime/bfc_allocator.h" +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/framework/allocator_registry.h" @@ -282,7 +283,7 @@ void RdmaMgr::InitAllocators() { Allocator* allocators[] = { #if GOOGLE_CUDA - ProcessState::singleton()->GetCUDAHostAllocator(0), + GPUProcessState::singleton()->GetCUDAHostAllocator(0), ProcessState::singleton()->GetCPUAllocator(0), #endif // GOOGLE_CUDA cpu_allocator(), @@ -323,7 +324,8 @@ void RdmaMgr::InitAllocators() { std::bind(&RdmaMemoryMgr::InsertMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf)); - ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); + GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, + cuda_alloc_visitor); LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; } #endif // GOOGLE_CUDA diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b64ae7f759799073c8d3ed4f70a42f44d1ef641c..97880219b80d663e9ee4eb8f0373786b23284b54 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -72,77 +72,77 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "cc_header_only_library", "full_path", "if_android", - "if_not_android_mips_and_mips64", "if_ios", "if_linux_x86_64", "if_mobile", "if_not_mobile", - "if_windows", "if_not_windows", - "tf_copts", + "if_windows", "tf_cc_test", "tf_cc_tests", + "tf_copts", "tf_cuda_library", "tf_gen_op_libs", "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", - "cc_header_only_library", + "tf_features_nomodules_if_android", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule") load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test") # For platform specific build config load( "//tensorflow/core:platform/default/build_config.bzl", - "tf_platform_hdrs", - "tf_platform_srcs", - "tf_proto_library", - "tf_proto_library_cc", "tf_additional_all_protos", + "tf_additional_cloud_kernel_deps", + "tf_additional_cloud_op_deps", "tf_additional_core_deps", + "tf_additional_cupti_wrapper_deps", + "tf_additional_device_tracer_cuda_deps", + "tf_additional_device_tracer_deps", + "tf_additional_device_tracer_srcs", + "tf_additional_gdr_lib_defines", + "tf_additional_human_readable_json_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", "tf_additional_lib_hdrs", "tf_additional_lib_srcs", - "tf_additional_framework_hdrs", - "tf_additional_framework_srcs", - "tf_additional_minimal_lib_srcs", - "tf_additional_proto_hdrs", - "tf_additional_proto_srcs", - "tf_additional_cupti_wrapper_deps", "tf_additional_libdevice_data", "tf_additional_libdevice_deps", "tf_additional_libdevice_srcs", + "tf_additional_minimal_lib_srcs", + "tf_additional_mpi_lib_defines", + "tf_additional_proto_hdrs", + "tf_additional_proto_srcs", "tf_additional_test_deps", "tf_additional_test_srcs", - "tf_kernel_tests_linkstatic", - "tf_additional_cloud_op_deps", - "tf_additional_cloud_kernel_deps", - "tf_lib_proto_parsing_deps", "tf_additional_verbs_lib_defines", - "tf_additional_mpi_lib_defines", - "tf_additional_gdr_lib_defines", - "tf_additional_device_tracer_srcs", - "tf_additional_device_tracer_deps", - "tf_additional_device_tracer_cuda_deps", - "tf_pyclif_proto_library", "tf_jspb_proto_library", + "tf_kernel_tests_linkstatic", + "tf_lib_proto_parsing_deps", "tf_nano_proto_library", + "tf_platform_hdrs", + "tf_platform_srcs", + "tf_proto_library", + "tf_proto_library_cc", "tf_protos_all", "tf_protos_all_impl", "tf_protos_grappler", "tf_protos_grappler_impl", + "tf_pyclif_proto_library", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", "if_static", + "tf_cuda_tests_tags", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") @@ -234,7 +234,6 @@ tf_proto_library( srcs = [], cc_api_version = 2, default_header = True, - j2objc_api_version = 1, java_api_version = 2, js_api_version = 2, protodeps = [ @@ -294,43 +293,18 @@ cc_library( ], ) -PLATFORM_BASE_HDRS = [ - "platform/env_time.h", - "platform/logging.h", - "platform/macros.h", - "platform/types.h", - "platform/byte_order.h", -] - -PLATFORM_OTHER_HDRS = [ - "platform/abi.h", - "platform/stacktrace.h", - "platform/stacktrace_handler.h", - "platform/context.h", - "platform/cpu_info.h", - "platform/cpu_feature_guard.h", - "platform/dynamic_annotations.h", - "platform/error.h", - "platform/env.h", - "platform/file_system.h", - "platform/file_system_helper.h", - "platform/fingerprint.h", - "platform/init_main.h", - "platform/mem.h", - "platform/mutex.h", - "platform/net.h", - "platform/notification.h", - "platform/null_file_system.h", - "platform/prefetch.h", - "platform/profile_utils/clock_cycle_profiler.h", - "platform/profile_utils/cpu_utils.h", - "platform/protobuf.h", - "platform/strong_hash.h", - "platform/subprocess.h", - "platform/thread_annotations.h", -] +filegroup( + name = "platform_base_hdrs", + srcs = [ + "platform/byte_order.h", + "platform/env_time.h", + "platform/logging.h", + "platform/macros.h", + "platform/types.h", + ], + visibility = ["//visibility:private"], +) -# Smaller platform libraries that don't depend on "lib" or "lib_internal". cc_library( name = "platform_base", srcs = tf_platform_hdrs([ @@ -342,16 +316,275 @@ cc_library( ]) + [ "platform/env_time.cc", ], - hdrs = PLATFORM_BASE_HDRS, + hdrs = [":platform_base_hdrs"], copts = tf_copts(), - # TODO(ahentz): remove use of this library so we can move it into 'platform' tags = ["avoid_dep"], + visibility = ["//tensorflow/core:__subpackages__"], deps = [ ":lib_platform", "//tensorflow/core/platform/default/build_config:base", ], ) +filegroup( + name = "platform_port_hdrs", + srcs = [ + "platform/cpu_info.h", + "platform/dynamic_annotations.h", + "platform/init_main.h", + "platform/mem.h", + "platform/mutex.h", + "platform/thread_annotations.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". +filegroup( + name = "platform_port_internal_hdrs", + srcs = [ + "platform/demangle.h", + "platform/host_info.h", + "platform/snappy.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_port", + srcs = tf_platform_hdrs([ + "cpu_info.h", + "dynamic_annotations.h", + "thread_annotations.h", + "mutex.h", + ]) + tf_platform_srcs([ + "port.cc", + ]) + [ + "platform/cpu_info.cc", + ], + hdrs = [ + ":platform_port_hdrs", + ":platform_port_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib_platform", + ":platform_base", + "//tensorflow/core/platform/default/build_config:port", + "@snappy", + ], +) + +filegroup( + name = "platform_protobuf_hdrs", + srcs = [ + "platform/protobuf.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". +filegroup( + name = "platform_protobuf_internal_hdrs", + srcs = [ + "platform/protobuf_internal.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_protobuf", + srcs = tf_platform_hdrs([ + "protobuf.h", + ]) + tf_platform_srcs([ + "protobuf.cc", + ]) + [ + "platform/protobuf_util.cc", + "lib/core/status.h", + ], + hdrs = [ + ":platform_protobuf_hdrs", + ":platform_protobuf_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib_platform", + ":platform_base", + ":platform_port", + "//tensorflow/core/platform/default/build_config:protobuf", + "@protobuf_archive//:protobuf", + ], +) + +cc_library( + name = "human_readable_json", + srcs = tf_platform_srcs(["human_readable_json.cc"]), + hdrs = ["platform/human_readable_json.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":lib", + ":lib_internal", + ] + tf_additional_human_readable_json_deps(), +) + +filegroup( + name = "platform_env_hdrs", + srcs = [ + "platform/env.h", + "platform/file_statistics.h", + "platform/file_system.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". +filegroup( + name = "platform_env_internal_hdrs", + srcs = [ + "platform/load_library.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_env", + srcs = tf_platform_srcs([ + "env.cc", + "load_library.cc", + ]) + tf_platform_hdrs([ + "wide_char.h", + ]) + [ + "platform/env.cc", + "platform/file_system.cc", + ], + hdrs = [ + ":platform_env_hdrs", + ":platform_env_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":error_codes_proto_cc", + ":lib", + ":lib_internal", + ":lib_platform", + ":platform_base", + ":platform_port", + ":platform_protobuf", + "//tensorflow/core/platform/default/build_config:env", + ], +) + +filegroup( + name = "platform_file_system_hdrs", + srcs = [ + "platform/file_system_helper.h", + "platform/null_file_system.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_file_system", + srcs = tf_platform_srcs([ + ]) + tf_platform_hdrs([ + "windows_file_system.h", + ]) + [ + "platform/file_system_helper.cc", + ], + hdrs = [ + ":platform_file_system_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib", + ":lib_platform", + ":platform_env", + ], +) + +filegroup( + name = "platform_other_hdrs", + srcs = [ + "platform/abi.h", + "platform/context.h", + "platform/cpu_feature_guard.h", + "platform/error.h", + "platform/fingerprint.h", + "platform/net.h", + "platform/notification.h", + "platform/prefetch.h", + "platform/profile_utils/android_armv7a_cpu_utils_helper.h", + "platform/profile_utils/clock_cycle_profiler.h", + "platform/profile_utils/cpu_utils.h", + "platform/profile_utils/i_cpu_utils_helper.h", + "platform/stacktrace.h", + "platform/stacktrace_handler.h", + "platform/strong_hash.h", + "platform/subprocess.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". +filegroup( + name = "platform_other_internal_hdrs", + srcs = [ + "platform/denormal.h", + "platform/setround.h", + "platform/tracing.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_other", + srcs = tf_platform_srcs([ + "subprocess.cc", + "net.cc", + "tracing.cc", + ]) + tf_platform_hdrs([ + "tracing.h", + "error.h", + "context.h", + "fingerprint.h", + "notification.h", + "stacktrace.h", + "strong_hash.h", + "subprocess.h", + "tracing_impl.h", + ]) + [ + "platform/cpu_feature_guard.cc", + "platform/setround.cc", + "platform/tracing.cc", + "platform/denormal.cc", + "platform/profile_utils/android_armv7a_cpu_utils_helper.cc", + "platform/profile_utils/clock_cycle_profiler.cc", + "platform/profile_utils/cpu_utils.cc", + ], + hdrs = [ + ":platform_other_hdrs", + ":platform_other_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib", + ":lib_platform", + ":platform_base", + ":platform_env", + ":platform_port", + ":platform_protobuf", + "//tensorflow/core/platform/default/build_config:other", + "//tensorflow/core/platform/default/build_config:platformlib", + "//tensorflow/core/platform/default/build_config:port", + ], +) + # Minimal lib so that tools used for mobile compilation # don't have to depend on lib/platformlib. cc_library( @@ -385,8 +618,7 @@ cc_library( # tf_cc_test and tf_cc_binary will include the necessary symbols. cc_library( name = "lib", - hdrs = PLATFORM_BASE_HDRS + - PLATFORM_OTHER_HDRS + [ + hdrs = [ "lib/bfloat16/bfloat16.h", "lib/core/arena.h", "lib/core/bitmap.h", @@ -433,6 +665,12 @@ cc_library( "lib/strings/str_util.h", "lib/strings/strcat.h", "lib/strings/stringprintf.h", + ":platform_base_hdrs", + ":platform_env_hdrs", + ":platform_file_system_hdrs", + ":platform_other_hdrs", + ":platform_port_hdrs", + ":platform_protobuf_hdrs", ], visibility = ["//visibility:public"], deps = [ @@ -462,7 +700,9 @@ cc_library( srcs = ["platform/stacktrace_handler.cc"], hdrs = ["platform/stacktrace_handler.h"], deps = [ + ":abi", ":lib_platform", + ":stacktrace", ], ) @@ -552,6 +792,7 @@ tf_cuda_library( "framework/graph_def_util.h", "framework/graph_to_functiondef.h", "framework/kernel_def_builder.h", + "framework/kernel_def_util.h", "framework/log_memory.h", "framework/lookup_interface.h", "framework/memory_types.h", @@ -586,6 +827,7 @@ tf_cuda_library( "framework/types.h", "public/version.h", "util/activation_mode.h", + "util/batch_util.h", "util/bcast.h", "util/cuda_kernel_helper.h", "util/device_name_utils.h", @@ -603,6 +845,7 @@ tf_cuda_library( "util/sparse/group_iterator.h", "util/sparse/sparse_tensor.h", "util/stat_summarizer.h", + "util/stat_summarizer_options.h", "util/stream_executor_util.h", "util/strided_slice_op.h", "util/tensor_format.h", @@ -627,6 +870,18 @@ tf_cuda_library( deps = [":framework_internal"], ) +cc_library( + name = "stats_calculator_portable", + srcs = [ + "util/stat_summarizer_options.h", + "util/stats_calculator.cc", + ], + hdrs = [ + "util/stats_calculator.h", + ], + copts = tf_copts(), +) + cc_library( name = "overflow", hdrs = ["util/overflow.h"], @@ -636,11 +891,26 @@ cc_library( ], ) +cc_library( + name = "exec_on_stall", + hdrs = ["util/exec_on_stall.h"], + deps = [":framework_lite"], +) + cc_library( name = "ptr_util", hdrs = ["util/ptr_util.h"], ) +cc_library( + name = "status_util", + hdrs = ["util/status_util.h"], + deps = [ + ":graph", + ":lib", + ], +) + cc_library( name = "reader_base", srcs = ["framework/reader_base.cc"], @@ -738,6 +1008,7 @@ tf_gen_op_libs( "nn_ops", "no_op", "parsing_ops", + "random_grad", "random_ops", "remote_fused_graph_ops", "resource_variable_ops", @@ -936,6 +1207,7 @@ tf_cuda_library( hdrs = [ "common_runtime/device.h", "common_runtime/device_factory.h", + "common_runtime/function.h", "common_runtime/optimization_registry.h", "common_runtime/shape_refiner.h", "graph/algorithm.h", @@ -990,6 +1262,7 @@ cc_library( "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:functional_ops", + "//tensorflow/core/kernels:grappler", "//tensorflow/core/kernels:histogram_op", "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", @@ -1107,6 +1380,7 @@ cc_library( ":shape_inference_testutil", ":tensor_testutil", ":test", + ":testlib_ops", "//tensorflow/cc:scope", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:ops_testutil", @@ -1114,6 +1388,18 @@ cc_library( ], ) +cc_library( + name = "testlib_ops", + testonly = 1, + srcs = ["common_runtime/testlib_ops.cc"], + linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + # This is a link-only library to provide a DirectSession # implementation of the Session interface. tf_cuda_library( @@ -1175,6 +1461,7 @@ filegroup( "lib/png/**/*", "lib/gif/**/*", "util/events_writer.*", + "util/stats_calculator.*", "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/default/test_benchmark.*", @@ -1258,6 +1545,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", + ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", @@ -1298,6 +1586,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", + ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", @@ -1634,6 +1923,7 @@ tf_proto_library_cc( srcs = ["protobuf/master_service.proto"], has_services = 1, cc_api_version = 2, + cc_grpc_version = 1, cc_stubby_versions = ["2"], protodeps = [":master_proto"], visibility = [ @@ -1763,9 +2053,8 @@ cc_library( "platform/**/cuda_libdevice_path.cc", "platform/**/device_tracer.cc", "platform/**/logging.cc", + "platform/**/human_readable_json.cc", "platform/abi.cc", - "platform/variant_coding.cc", - "platform/**/variant_cord_coding.cc", ], ) + tf_additional_lib_srcs( exclude = [ @@ -1777,9 +2066,8 @@ cc_library( "platform/**/env_time.cc", "platform/**/device_tracer.cc", "platform/**/logging.cc", + "platform/**/human_readable_json.cc", "platform/abi.cc", - "platform/variant_coding.cc", - "platform/**/variant_cord_coding.cc", ] + # Protobuf deps already included through the ":lib_proto_parsing" # dependency. @@ -1964,9 +2252,9 @@ tf_proto_library( srcs = ERROR_CODES_PROTO_SRCS, cc_api_version = 2, default_header = True, - j2objc_api_version = 1, java_api_version = 2, js_api_version = 2, + provide_cc_alias = True, ) tf_generate_proto_text_sources( @@ -1985,7 +2273,6 @@ tf_proto_library( srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS, cc_api_version = 2, default_header = True, - j2objc_api_version = 1, java_api_version = 2, js_api_version = 2, protodeps = [ @@ -2028,7 +2315,6 @@ cc_library( ) FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ - "platform/variant_coding.h", "graph/edgeset.h", "graph/graph.h", "graph/graph_def_builder.h", @@ -2065,18 +2351,18 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/op_segment.h", "framework/rendezvous.h", # only needed for tests + "framework/resource_var.h", "framework/tensor_reference.h", "framework/tracking_allocator.h", # only needed for tests "framework/unique_tensor_references.h", "framework/variant.h", - "platform/variant_coding.h", "util/command_line_flags.h", "util/env_var.h", "util/equal_graph_def.h", "util/presized_cuckoo_map.h", "util/tensor_slice_set.h", "util/tensor_slice_util.h", -] + tf_additional_framework_hdrs() +] tf_cuda_library( name = "framework_internal", @@ -2118,9 +2404,7 @@ cc_header_only_library( tf_cuda_library( name = "framework_internal_impl", - srcs = FRAMEWORK_INTERNAL_PRIVATE_HEADERS + [ - "platform/variant_coding.cc", - ] + glob( + srcs = FRAMEWORK_INTERNAL_PRIVATE_HEADERS + glob( [ "example/**/*.cc", "framework/**/*.cc", @@ -2154,7 +2438,7 @@ tf_cuda_library( "util/memmapped_file_system.cc", "util/memmapped_file_system_writer.cc", ], - }) + tf_additional_framework_srcs(), + }), hdrs = FRAMEWORK_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), linkopts = select({ @@ -2363,6 +2647,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/dma_helper.h", "common_runtime/eigen_thread_pool.h", "common_runtime/executor.h", + "common_runtime/executor_factory.h", "common_runtime/graph_optimizer.h", "common_runtime/local_device.h", "common_runtime/lower_if_op.h", @@ -2385,6 +2670,8 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/step_stats_collector.h", "common_runtime/threadpool_device.h", "common_runtime/visitable_allocator.h", + "common_runtime/process_state.h", + "common_runtime/pool_allocator.h", "graph/gradients.h", "graph/quantize_training.h", ] + if_mkl(["graph/mkl_graph_util.h"]) @@ -2412,6 +2699,7 @@ tf_cuda_library( "common_runtime/device_resolver_local.cc", "common_runtime/device_set.cc", "common_runtime/executor.cc", + "common_runtime/executor_factory.cc", "common_runtime/function.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", @@ -2422,7 +2710,9 @@ tf_cuda_library( "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", "common_runtime/placer.cc", + "common_runtime/pool_allocator.cc", "common_runtime/process_function_library_runtime.cc", + "common_runtime/process_state.cc", "common_runtime/process_util.cc", "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", @@ -2519,6 +2809,7 @@ cc_library( ], visibility = [ "//tensorflow/compiler:__subpackages__", + "//tensorflow/core/kernels:__subpackages__", "//tensorflow/core/profiler:__subpackages__", ], deps = [":lib_internal"], @@ -2608,6 +2899,7 @@ cc_library( ) GPU_RUNTIME_HEADERS = [ + "common_runtime/gpu/cuda_host_allocator.h", "common_runtime/gpu/gpu_bfc_allocator.h", "common_runtime/gpu/gpu_cudamalloc_allocator.h", "common_runtime/gpu/gpu_debug_allocator.h", @@ -2617,10 +2909,9 @@ GPU_RUNTIME_HEADERS = [ "common_runtime/gpu/gpu_id_utils.h", "common_runtime/gpu/gpu_init.h", "common_runtime/gpu/gpu_managed_allocator.h", + "common_runtime/gpu/gpu_process_state.h", "common_runtime/gpu/gpu_stream_util.h", "common_runtime/gpu/gpu_util.h", - "common_runtime/gpu/pool_allocator.h", - "common_runtime/gpu/process_state.h", "common_runtime/gpu_device_context.h", ] @@ -2633,11 +2924,10 @@ tf_cuda_library( "common_runtime/gpu/gpu_device.cc", "common_runtime/gpu/gpu_device_factory.cc", "common_runtime/gpu/gpu_managed_allocator.cc", + "common_runtime/gpu/gpu_process_state.cc", "common_runtime/gpu/gpu_stream_util.cc", "common_runtime/gpu/gpu_util.cc", "common_runtime/gpu/gpu_util_platform_specific.cc", - "common_runtime/gpu/pool_allocator.cc", - "common_runtime/gpu/process_state.cc", ], hdrs = GPU_RUNTIME_HEADERS, copts = tf_copts(), @@ -2817,6 +3107,8 @@ cc_library( # we now need at least "str_util". ":lib", ":lib_platform", + ":stacktrace_handler", + ":test_lite", "//tensorflow/core/platform/default/build_config:test_lite_main", ], alwayslink = 1, @@ -2990,6 +3282,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "exec_on_stall_test", + size = "small", + srcs = ["util/exec_on_stall_test.cc"], + deps = [ + ":exec_on_stall", + ":framework_lite", + ":test", + ":test_main", + ], +) + tf_cc_test( name = "lib_jpeg_jpeg_mem_unittest", srcs = ["lib/jpeg/jpeg_mem_unittest.cc"], @@ -3081,10 +3385,12 @@ tf_cc_tests( "framework/bfloat16_test.cc", "framework/cancellation_test.cc", "framework/common_shape_fns_test.cc", + "framework/device_base_test.cc", "framework/function_test.cc", "framework/graph_def_util_test.cc", "framework/graph_to_functiondef_test.cc", "framework/kernel_def_builder_test.cc", + "framework/kernel_def_util_test.cc", "framework/memory_types_test.cc", "framework/node_def_builder_test.cc", "framework/node_def_util_test.cc", @@ -3109,6 +3415,7 @@ tf_cc_tests( "framework/variant_op_registry_test.cc", "framework/variant_test.cc", "graph/algorithm_test.cc", + "graph/control_flow_test.cc", "graph/edgeset_test.cc", "graph/graph_def_builder_test.cc", "graph/graph_partition_test.cc", @@ -3133,6 +3440,7 @@ tf_cc_tests( "util/semver_test.cc", "util/sparse/sparse_tensor_test.cc", "util/stat_summarizer_test.cc", + "util/status_util_test.cc", "util/tensor_format_test.cc", "util/tensor_slice_reader_test.cc", "util/tensor_slice_set_test.cc", @@ -3157,6 +3465,7 @@ tf_cc_tests( ":ops", ":protos_all_cc", ":protos_test_cc", + ":status_util", ":test", ":test_main", ":testlib", @@ -3284,7 +3593,10 @@ tf_cc_tests_gpu( tf_cc_test_mkl( name = "mkl_runtime_tests", size = "small", - srcs = ["common_runtime/mkl_cpu_allocator_test.cc"], + srcs = [ + "common_runtime/mkl_cpu_allocator_test.cc", + "common_runtime/mkl_threadpool_device_test.cc", + ], linkstatic = 1, deps = [ ":core", @@ -3386,6 +3698,37 @@ tf_cc_tests_gpu( ], ) +tf_cuda_cc_test( + name = "gpu_device_unified_memory_test", + size = "small", + srcs = [ + "common_runtime/gpu/gpu_device_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + # Runs test on a Guitar cluster that uses P100s to test unified memory + # allocations. + tags = tf_cuda_tests_tags() + [ + "guitar", + "multi_gpu", + ], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":direct_session", + ":framework", + ":framework_internal", + ":gpu_id", + ":lib", + ":lib_internal", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:ops_util", + ], +) + tf_cc_test_gpu( name = "cuda_libdevice_path_test", size = "small", @@ -3581,13 +3924,13 @@ tf_cc_test( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "common_runtime_direct_session_test", size = "small", srcs = ["common_runtime/direct_session_test.cc"], + args = [] + if_cuda(["--heap_check=local"]), # The GPU tracer leaks memory linkstatic = tf_kernel_tests_linkstatic(), deps = [ - ":core", ":core_cpu", ":core_cpu_internal", ":direct_session_internal", @@ -3600,6 +3943,7 @@ tf_cc_test( ":test", ":test_main", ":testlib", + "//third_party/eigen3", "//tensorflow/cc:cc_ops", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:cwise_op", @@ -3613,8 +3957,7 @@ tf_cc_test( "//tensorflow/core/kernels:queue_ops", "//tensorflow/core/kernels:session_ops", "//tensorflow/core/kernels:variable_ops", - "//third_party/eigen3", - ], + ] + if_cuda([":cuda"]), ) # This is identical to :common_runtime_direct_session_test with the addition of @@ -3719,6 +4062,31 @@ tf_cc_test( ], ) +tf_cc_test( + name = "common_runtime_executor_test", + size = "small", + srcs = ["common_runtime/executor_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:state", + ], +) + tf_cc_test( name = "common_runtime_function_test", size = "small", diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 19d643880966f7607405539a5ad43d8e03dc13fb..06b797e32edc046bab498f8d775040d57ef62ce9 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -4,6 +4,7 @@ # The following targets can be used to access ApiDefs: # :base_api_def # :python_api_def +# :java_api_def package( default_visibility = ["//visibility:private"], @@ -29,6 +30,12 @@ filegroup( visibility = ["//tensorflow:internal"], ) +filegroup( + name = "java_api_def", + srcs = glob(["java_api/*"]), + visibility = ["//tensorflow:internal"], +) + cc_library( name = "excluded_ops_lib", srcs = ["excluded_ops.cc"], diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 477a0b670e49f8aa4ee8c250d4957886eb865ed5..ae03a61ae66ec8d0119d91eefe8c64e61348e9b4 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -149,6 +149,33 @@ void TestAllApiDefAttributeNamesAreValid( } } } + +void TestDeprecatedAttributesSetCorrectly( + const std::unordered_map& api_defs_map) { + for (const auto& name_and_api_def : api_defs_map) { + int num_deprecated_endpoints = 0; + const auto& api_def = name_and_api_def.second; + for (const auto& endpoint : api_def.endpoint()) { + if (endpoint.deprecated()) { + ++num_deprecated_endpoints; + } + } + + const auto& name = name_and_api_def.first; + ASSERT_TRUE(api_def.deprecation_message().empty() || + num_deprecated_endpoints == 0) + << "Endpoints are set to 'deprecated' for deprecated op " << name + << ". If an op is deprecated (i.e. deprecation_message is set), " + << "all the endpoints are deprecated implicitly and 'deprecated' " + << "field should not be set."; + if (num_deprecated_endpoints > 0) { + ASSERT_NE(num_deprecated_endpoints, api_def.endpoint_size()) + << "All " << name << " endpoints are deprecated. Please, set " + << "deprecation_message in api_def_" << name << ".pbtxt instead. " + << "to indicate that the op is deprecated."; + } + } +} } // namespace class BaseApiTest : public ::testing::Test { @@ -171,7 +198,7 @@ TEST_F(BaseApiTest, AllOpsAreInApiDef) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { continue; } - ASSERT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) + EXPECT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) << op.name() << " op does not have api_def_*.pbtxt file. " << "Please add api_def_" << op.name() << ".pbtxt file " << "under tensorflow/core/api_def/base_api/ directory."; @@ -236,6 +263,11 @@ TEST_F(BaseApiTest, AllApiDefAttributeNamesAreValid) { TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); } +// Checks that deprecation is set correctly. +TEST_F(BaseApiTest, DeprecationSetCorrectly) { + TestDeprecatedAttributesSetCorrectly(api_defs_map_); +} + class PythonApiTest : public ::testing::Test { protected: PythonApiTest() { @@ -272,4 +304,9 @@ TEST_F(PythonApiTest, AllApiDefAttributeNamesAreValid) { TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); } +// Checks that deprecation is set correctly. +TEST_F(PythonApiTest, DeprecationSetCorrectly) { + TestDeprecatedAttributesSetCorrectly(api_defs_map_); +} + } // namespace tensorflow diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d8c2ed40a324d4854d83c471e8eef50e50277b93 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt @@ -0,0 +1,13 @@ +op { + graph_op_name: "AnonymousIterator" + out_arg { + name: "handle" + description: < ## Validate your installation @@ -491,13 +489,7 @@ TensorFlow programs: If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -If you are new to machine learning, we recommend the following: - -* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) -* @{$get_started/eager} - -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. +To learn more, see the [TensorFlow tutorials](../tutorials/). ## TensorFlow GPU support @@ -517,7 +509,7 @@ on your system: from source. To use the TensorFlow binaries, version 3.5 or higher is required. See the [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a list of supported GPU cards. -* [GPU drivers](http://nvidia.com/driver) that support your version of the CUDA +* [GPU drivers](http://nvidia.com/drivers) that support your version of the CUDA Toolkit. * The `libcupti-dev` library is the NVIDIA CUDA Profile Tools Interface. This library provides advanced profiling support. To install this library, @@ -684,14 +676,14 @@ This section documents the relevant values for Linux installations. CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp27-none-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp27-none-linux_x86_64.whl
 
Note that GPU support requires the NVIDIA hardware and software described in @@ -703,14 +695,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
 
Note that GPU support requires the NVIDIA hardware and software described in @@ -722,14 +714,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp35-cp35m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp35-cp35m-linux_x86_64.whl
 
@@ -741,14 +733,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp36-cp36m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp36-cp36m-linux_x86_64.whl
 
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index 016e7bf1b90d35cf0199eca3209fb2697856547c..c6f0c17924c95e11d22b08c8976d9044c365dce2 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -119,7 +119,7 @@ Take the following steps to install TensorFlow with Virtualenv: TensorFlow in the active Virtualenv is as follows:
 $ pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py3-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl If you encounter installation problems, see [Common Installation Problems](#common-installation-problems). @@ -242,7 +242,7 @@ take the following steps: issue the following command:
 $ sudo pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py3-none-any.whl 
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl If the preceding command fails, see [installation problems](#common-installation-problems). @@ -350,7 +350,7 @@ Take the following steps to install TensorFlow in an Anaconda environment: TensorFlow for Python 2.7:
 (targetDirectory)$ pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py2-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-any.whl @@ -403,14 +403,7 @@ writing TensorFlow programs: If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -If you are new to machine learning, we recommend the following: - -* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) -* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners} - -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. - +To learn more, see the [TensorFlow tutorials](../tutorials/). ## Common installation problems @@ -524,7 +517,7 @@ The value you specify depends on your Python version.
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-any.whl
 
@@ -532,5 +525,5 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py2-none-any.
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl
 
diff --git a/tensorflow/docs_src/install/install_raspbian.md b/tensorflow/docs_src/install/install_raspbian.md new file mode 100644 index 0000000000000000000000000000000000000000..46c4944ca7448df2c993ee44d5099494b759dea8 --- /dev/null +++ b/tensorflow/docs_src/install/install_raspbian.md @@ -0,0 +1,313 @@ +# Installing TensorFlow on Raspbian + +This guide explains how to install TensorFlow on a Raspberry Pi running +Raspbian. Although these instructions might also work on other Pi variants, we +have only tested (and we only support) these instructions on machines meeting +the following requirements: + +* Raspberry Pi devices running Raspbian 9.0 or higher + +## Determine how to install TensorFlow + +You must pick the mechanism by which you install TensorFlow. The supported +choices are as follows: + +* "Native" pip. +* Cross-compiling from sources. + +**We recommend pip installation.** + +## Installing with native pip + +We have uploaded the TensorFlow binaries to piwheels.org. Therefore, you can +install TensorFlow through pip. + +The [REQUIRED_PACKAGES section of +setup.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/setup.py) +lists the packages that pip will install or upgrade. + +### Prerequisite: Python + +In order to install TensorFlow, your system must contain one of the following +Python versions: + +* Python 2.7 +* Python 3.4+ + +If your system does not already have one of the preceding Python versions, +[install](https://wiki.python.org/moin/BeginnersGuide/Download) it now. It +should already be included when Raspbian was installed though, so no extra steps +should be needed. + +### Prerequisite: pip + +[Pip](https://en.wikipedia.org/wiki/Pip_\(package_manager\)) installs and +manages software packages written in Python. If you intend to install with +native pip, then one of the following flavors of pip must be installed on your +system: + +* `pip3`, for Python 3.n (preferred). +* `pip`, for Python 2.7. + +`pip` or `pip3` was probably installed on your system when you installed Python. +To determine whether pip or pip3 is actually installed on your system, issue one +of the following commands: + +
$ pip3 -V # for Python 3.n
+$ pip -V  # for Python 2.7
+ +If it gives the error "Command not found", then the package has not been +installed yet. To install if for the first time, run: + +
$ sudo apt-get install python3-pip # for Python 3.n
+sudo apt-get install python-pip # for Python 2.7
+ +You can find more help on installing and upgrading pip in +[the Raspberry Pi documentation](https://www.raspberrypi.org/documentation/linux/software/python.md). + +### Prerequisite: Atlas + +[Atlas](http://math-atlas.sourceforge.net/) is a linear algebra library that +numpy depends on, and so needs to be installed before TensorFlow. To add it to +your system, run the following command: + +
$ sudo apt install libatlas-base-dev
+ +### Install TensorFlow + +Assuming the prerequisite software is installed on your Pi, install TensorFlow +by invoking **one** of the following commands: + +
 $ pip3 install tensorflow     # Python 3.n
+     $ pip install tensorflow      # Python 2.7
+ +This can take some time on certain platforms like the Pi Zero, where some Python +packages like scipy that TensorFlow depends on need to be compiled before the +installation can complete. The Python 3 version will typically be faster to +install because piwheels.org has pre-built versions of the dependencies +available, so this is our recommended option. + +### Next Steps + +After installing TensorFlow, [validate your +installation](#ValidateYourInstallation) to confirm that the installation worked +properly. + +### Uninstalling TensorFlow + +To uninstall TensorFlow, issue one of following commands: + +
$ pip uninstall tensorflow
+$ pip3 uninstall tensorflow 
+ +## Cross-compiling from sources + +Cross-compilation means building on a different machine than than you'll be +deploying on. Since Raspberry Pi's only have limited RAM and comparatively slow +processors, and TensorFlow has a large amount of source code to compile, it's +easier to use a MacOS or Linux desktop or laptop to handle the build process. +Because it can take over 24 hours to build on a Pi, and requires external swap +space to cope with the memory shortage, we recommend using cross-compilation if +you do need to compile TensorFlow from source. To make the dependency management +process easier, we also recommend using Docker to help simplify building. + +Note that we provide well-tested, pre-built TensorFlow binaries for Raspbian +systems. So, don't build a TensorFlow binary yourself unless you are very +comfortable building complex packages from source and dealing with the +inevitable aftermath should things not go exactly as documented + +### Prerequisite: Docker + +Install Docker on your machine as described in the [Docker +documentation](https://docs.docker.com/engine/installation/#/on-macos-and-windows). + +### Clone the TensorFlow repository + +Start the process of building TensorFlow by cloning a TensorFlow repository. + +To clone **the latest** TensorFlow repository, issue the following command: + +
$ git clone https://github.com/tensorflow/tensorflow 
+ +The preceding git clone command creates a subdirectory named +`tensorflow`. After cloning, you may optionally build a **specific branch** +(such as a release branch) by invoking the following commands: + +
+$ cd tensorflow
+$ git checkout Branch # where Branch is the desired branch
+
+ +For example, to work with the `r1.0` release instead of the master release, +issue the following command: + +
$ git checkout r1.0
+ +### Build from source + +To compile TensorFlow and produce a binary pip can install, do the following: + +1. Start a terminal. +2. Navigate to the directory containing the tensorflow source code. +3. Run a command to cross-compile the library, for example: + +
$ CI_DOCKER_EXTRA_PARAMS="-e CI_BUILD_PYTHON=python3 -e CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/include/python3.4" \
+tensorflow/tools/ci_build/ci_build.sh PI-PYTHON3 tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
+ 
+ +This will build a pip .whl file for Python 3.4, with Arm v7 instructions that +will only work on the Pi models 2 or 3. These NEON instructions are required for +the fastest operation on those devices, but you can build a library that will +run across all Pi devices by passing `PI_ONE` at the end of the command line. +You can also target Python 2.7 by omitting the initial docker parameters. Here's +an example of building for Python 2.7 and Raspberry Pi model Zero or One +devices: + +
$ tensorflow/tools/ci_build/ci_build.sh PI tensorflow/tools/ci_build/pi/build_raspberry_pi.sh PI_ONE
+ +This will take some time to complete, typically twenty or thirty minutes, and +should produce a .whl file in an output-artifacts sub-folder inside your source +tree at the end. This wheel file can be installed through pip or pip3 (depending +on your Python version) by copying it to a Raspberry Pi and running a terminal +command like this (with the name of your actual file substituted): + +
$ pip3 install tensorflow-1.9.0-cp34-none-linux_armv7l.whl
+ +### Troubleshooting the build + +The build script uses Docker internally to create a Linux virtual machine to +handle the compilation. If you do have problems running the script, first check +that you're able to run Docker tests like `docker run hello-world` on your +system. + +If you're building from the latest development branch, try syncing to an older +version that's known to work, for example release 1.9, with a command like this: + +
$ git checkout r1.0
+ + + +## Validate your installation + +To validate your TensorFlow installation, do the following: + +1. Ensure that your environment is prepared to run TensorFlow programs. +2. Run a short TensorFlow program. + +### Prepare your environment + +If you installed on native pip, Virtualenv, or Anaconda, then do the following: + +1. Start a terminal. +2. If you installed TensorFlow source code, navigate to any directory *except* + one containing TensorFlow source code. + +### Run a short TensorFlow program + +Invoke python from your shell as follows: + +
$ python
+ +Enter the following short program inside the python interactive shell: + +```python +# Python +import tensorflow as tf +hello = tf.constant('Hello, TensorFlow!') +sess = tf.Session() +print(sess.run(hello)) +``` + +If the system outputs the following, then you are ready to begin writing +TensorFlow programs: + +
Hello, TensorFlow!
+ +If you're running with Python 3.5, you may see a warning when you first import +TensorFlow. This is not an error, and TensorFlow should continue to run with no +problems, despite the log message. + +If the system outputs an error message instead of a greeting, see [Common +installation problems](#common_installation_problems). + +To learn more, see the [TensorFlow tutorials](../tutorials/). + +## Common installation problems + +We are relying on Stack Overflow to document TensorFlow installation problems +and their remedies. The following table contains links to Stack Overflow answers +for some common installation problems. If you encounter an error message or +other installation problem not listed in the following table, search for it on +Stack Overflow. If Stack Overflow doesn't show the error message, ask a new +question about it on Stack Overflow and specify the `tensorflow` tag. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Stack Overflow Link Error Message
42006320
ImportError: Traceback (most recent call last):
+File ".../tensorflow/core/framework/graph_pb2.py", line 6, in 
+from google.protobuf import descriptor as _descriptor
+ImportError: cannot import name 'descriptor'
+
33623453
IOError: [Errno 2] No such file or directory:
+  '/tmp/pip-o6Tpui-build/setup.py'
+
35190574
SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify
+  failed
42009190
+  Installing collected packages: setuptools, protobuf, wheel, numpy, tensorflow
+  Found existing installation: setuptools 1.1.6
+  Uninstalling setuptools-1.1.6:
+  Exception:
+  ...
+  [Errno 1] Operation not permitted:
+  '/tmp/pip-a1DXRT-uninstall/.../lib/python/_markerlib' 
33622019
ImportError: No module named copyreg
37810228During a pip install operation, the system returns: +
OSError: [Errno 1] Operation not permitted
+
33622842An import tensorflow statement triggers an error such as the + following:
Traceback (most recent call last):
+  File "", line 1, in 
+  File "/usr/local/lib/python2.7/site-packages/tensorflow/__init__.py",
+    line 4, in 
+    from tensorflow.python import *
+    ...
+  File "/usr/local/lib/python2.7/site-packages/tensorflow/core/framework/tensor_shape_pb2.py",
+    line 22, in 
+    serialized_pb=_b('\n,tensorflow/core/framework/tensor_shape.proto\x12\ntensorflow\"d\n\x10TensorShapeProto\x12-\n\x03\x64im\x18\x02
+      \x03(\x0b\x32
+      .tensorflow.TensorShapeProto.Dim\x1a!\n\x03\x44im\x12\x0c\n\x04size\x18\x01
+      \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\tb\x06proto3')
+  TypeError: __init__() got an unexpected keyword argument 'syntax'
+
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md index 5ba522b436137bc5588382fd79f7559c6e9d11ed..fc1f6d05bdc26785090e1fc2c6f47826660090ac 100644 --- a/tensorflow/docs_src/install/install_sources.md +++ b/tensorflow/docs_src/install/install_sources.md @@ -81,7 +81,7 @@ or [macOS](#PrepareMac) - + ## Prepare environment for Linux Before building TensorFlow on Linux, install the following build @@ -289,17 +289,27 @@ Note: If you're only interested in building the libraries for the TensorFlow C or Java APIs, see [Build the C or Java libraries](#BuildCorJava), you do not need to build the pip package in that case. -To build a pip package for TensorFlow with CPU-only support, -you would typically invoke the following command: +### CPU-only support + +To build a pip package for TensorFlow with CPU-only support: + +
+$ bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
+
+ +To build a pip package for TensorFlow with CPU-only support for the Intel® MKL-DNN:
-$ bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
+$ bazel build --config=mkl --config=opt //tensorflow/tools/pip_package:build_pip_package
 
-To build a pip package for TensorFlow with GPU support, -invoke the following command: +### GPU support + +To build a pip package for TensorFlow with GPU support: -
$ bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package 
+
+$ bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
+
**NOTE on gcc 5 or later:** the binary pip packages available on the TensorFlow website are built with gcc 4, which uses the older ABI. To @@ -328,10 +338,10 @@ Invoke `pip install` to install that pip package. The filename of the `.whl` file depends on your platform. For example, the following command will install the pip package -for TensorFlow 1.8.0 on Linux: +for TensorFlow 1.9.0rc0 on Linux:
-$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.8.0-py2-none-any.whl
+$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.9.0rc0-py2-none-any.whl
 
## Validate your installation @@ -362,7 +372,7 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started/eager}. +To learn more, see the [TensorFlow tutorials](../tutorials/). If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). @@ -373,9 +383,9 @@ The build and installation problems you encounter typically depend on the operating system. See the "Common installation problems" section of one of the following guides: - * @{$install_linux#CommonInstallationProblems$Installing TensorFlow on Linux} - * @{$install_mac#CommonInstallationProblems$Installing TensorFlow on Mac OS} - * @{$install_windows#CommonInstallationProblems$Installing TensorFlow on Windows} + * @{$install_linux#common_installation_problems$Installing TensorFlow on Linux} + * @{$install_mac#common_installation_problems$Installing TensorFlow on Mac OS} + * @{$install_windows#common_installation_problems$Installing TensorFlow on Windows} Beyond the errors documented in those two guides, the following table notes additional errors specific to building TensorFlow. Note that we @@ -433,6 +443,8 @@ Stack Overflow and specify the `tensorflow` tag. **Linux** + + @@ -456,6 +468,7 @@ Stack Overflow and specify the `tensorflow` tag. **Mac**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.9.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.11.0N/AN/A
tensorflow_gpu-1.9.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.11.079
tensorflow-1.8.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.10.0N/AN/A
tensorflow_gpu-1.8.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.079
tensorflow-1.7.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.10.0N/AN/A
+ @@ -472,6 +485,8 @@ Stack Overflow and specify the `tensorflow` tag. **Windows**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.9.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.11.0N/AN/A
tensorflow-1.8.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.10.1N/AN/A
tensorflow-1.7.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.10.1N/AN/A
tensorflow-1.6.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
+ + diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md index a139a49661ee4eb8606d4b9fa1bf96ae12d16b8e..7b7b17ce81407bbbff837a00bb43162b4b2d44f3 100644 --- a/tensorflow/docs_src/install/install_windows.md +++ b/tensorflow/docs_src/install/install_windows.md @@ -157,14 +157,7 @@ TensorFlow programs: If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -If you are new to machine learning, we recommend the following: - -* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) -* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners} - -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. - +To learn more, see the [TensorFlow tutorials](../tutorials/). ## Common installation problems diff --git a/tensorflow/docs_src/install/leftnav_files b/tensorflow/docs_src/install/leftnav_files index e523e06f67aad508238ee0965f34ebe16c77bf90..ace275c0e82b794708bfc63c0e61d6bb3251a152 100644 --- a/tensorflow/docs_src/install/leftnav_files +++ b/tensorflow/docs_src/install/leftnav_files @@ -4,6 +4,7 @@ index.md install_linux.md: Ubuntu install_mac.md: MacOS install_windows.md: Windows +install_raspbian.md: Raspbian install_sources.md: From source >>> migration.md diff --git a/tensorflow/docs_src/mobile/leftnav_files b/tensorflow/docs_src/mobile/leftnav_files index 585470d5f0847716863ba6129bf75c26631fecbd..97340ef7e1af64634f8590b5d21a344b5181cb73 100644 --- a/tensorflow/docs_src/mobile/leftnav_files +++ b/tensorflow/docs_src/mobile/leftnav_files @@ -4,6 +4,7 @@ tflite/index.md tflite/devguide.md tflite/demo_android.md tflite/demo_ios.md +tflite/performance.md >>> ### TensorFlow Mobile mobile_intro.md diff --git a/tensorflow/docs_src/mobile/linking_libs.md b/tensorflow/docs_src/mobile/linking_libs.md index cf0db590210593914d42105c2cfae5bd99e18287..efef5dd0daa0b267d8384d32d62d9ce0226dc102 100644 --- a/tensorflow/docs_src/mobile/linking_libs.md +++ b/tensorflow/docs_src/mobile/linking_libs.md @@ -27,7 +27,7 @@ called `libandroid_tensorflow_inference_java.jar`. There are three ways to include this functionality in your program: 1. Include the jcenter AAR which contains it, as in this - [example app](https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/android/build.gradle#L59-L65) + [example app](https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/android/tfmobile/build.gradle#L59-L65) 2. Download the nightly precompiled version from [ci.tensorflow.org](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/). diff --git a/tensorflow/docs_src/mobile/mobile_intro.md b/tensorflow/docs_src/mobile/mobile_intro.md index 241f01d460ae35e818a61be4c4914b3bd8dae00a..baad4433083d18a19ea3dd5ec0c1bae498ac2da9 100644 --- a/tensorflow/docs_src/mobile/mobile_intro.md +++ b/tensorflow/docs_src/mobile/mobile_intro.md @@ -38,7 +38,8 @@ speech-driven interface, and many of these require on-device processing. Most of the time a user isn’t giving commands, and so streaming audio continuously to a remote server would be a waste of bandwidth, since it would mostly be silence or background noises. To solve this problem it’s common to have a small neural -network running on-device @{$tutorials/audio_recognition$listening out for a particular keyword}. +network running on-device +[listening out for a particular keyword](../tutorials/sequences/audio_recognition). Once that keyword has been spotted, the rest of the conversation can be transmitted over to the server for further processing if more computing power is needed. diff --git a/tensorflow/docs_src/mobile/prepare_models.md b/tensorflow/docs_src/mobile/prepare_models.md index 8b22c04d872f18607c485775cb8f096f0a361995..2b84dbb97388b16c6a4ae1d3472e0b1a993285f0 100644 --- a/tensorflow/docs_src/mobile/prepare_models.md +++ b/tensorflow/docs_src/mobile/prepare_models.md @@ -105,8 +105,8 @@ inline constants so everything’s in one file. To handle the conversion, you need the `freeze_graph.py` script, that’s held in [`tensorflow/python/tools/freeze_graph.py`](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py). You’ll run it like this: - bazel build tensorflow/tools:freeze_graph - bazel-bin/tensorflow/tools/freeze_graph \ + bazel build tensorflow/python/tools:freeze_graph + bazel-bin/tensorflow/python/tools/freeze_graph \ --input_graph=/tmp/model/my_graph.pb \ --input_checkpoint=/tmp/model/model.ckpt-1000 \ --output_graph=/tmp/frozen_graph.pb \ diff --git a/tensorflow/docs_src/mobile/tflite/demo_android.md b/tensorflow/docs_src/mobile/tflite/demo_android.md index 7f2f8882a24702d167599452e66afbe720026808..fdf0bcf3c1135f0e702c7dda4d1d608a26169470 100644 --- a/tensorflow/docs_src/mobile/tflite/demo_android.md +++ b/tensorflow/docs_src/mobile/tflite/demo_android.md @@ -1,7 +1,7 @@ # Android Demo App An example Android application using TensorFLow Lite is available -[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). +[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo). The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. To run the demo, a device running Android 5.0 ( API 21) or higher is required. @@ -44,20 +44,22 @@ app: Android Studio project. * Install all the Gradle extensions it requests. -To get a model, either: +Now you can build and run the demo app. -* Download the quantized [Mobilenet TensorFlow Lite model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) - and unzip and copy `mobilenet_quant_v1_224.tflite` to the assets directory: - `tensorflow/contrib/lite/java/demo/app/src/main/assets/`. -* Or, download the floating point [Inception-v3 model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) - and unzip and copy `inceptionv3_non_slim_2015.tflite` to the assets - directory. Change the chosen classifier in - [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java)
+The build process downloads the quantized [Mobilenet TensorFlow Lite model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip), and unzips it into the assets directory: `tensorflow/contrib/lite/java/demo/app/src/main/assets/`. + +Some additional details are available on the +[TF Lite Android App page](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). + +### Using other models + +To use a different model: +* Download the floating point [Inception-v3 model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip). +* Unzip and copy `inceptionv3_non_slim_2015.tflite` to the assets directory. +* Change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java)
from: `classifier = new ImageClassifierQuantizedMobileNet(getActivity());`
to: `classifier = new ImageClassifierFloatInception(getActivity());`. -Now you can build and run the demo app. - ## Build TensorFlow Lite and the demo app from source diff --git a/tensorflow/docs_src/mobile/tflite/devguide.md b/tensorflow/docs_src/mobile/tflite/devguide.md index 4133bc172a1924f0ce8bb515d66fc03d716923c8..b168d6c18366708ebaa7216481d262b02051168d 100644 --- a/tensorflow/docs_src/mobile/tflite/devguide.md +++ b/tensorflow/docs_src/mobile/tflite/devguide.md @@ -54,10 +54,11 @@ both floating point and quantized inference. ### Train a custom model A developer may choose to train a custom model using Tensorflow (see the -@{$tutorials} for examples of building and training models). If you have already -written a model, the first step is to export this to a @{tf.GraphDef} file. This -is required because some formats do not store the model structure outside the -code, and we must communicate with other parts of the framework. See +[TensorFlow tutorials](../../tutorials/) for examples of building and training +models). If you have already written a model, the first step is to export this +to a @{tf.GraphDef} file. This is required because some formats do not store the +model structure outside the code, and we must communicate with other parts of the +framework. See [Exporting the Inference Graph](https://github.com/tensorflow/models/blob/master/research/slim/README.md) to create .pb file for the custom model. diff --git a/tensorflow/docs_src/mobile/tflite/index.md b/tensorflow/docs_src/mobile/tflite/index.md index 562203482763991c412b523bd261b3163d361134..3d1733024e493042a2cc85aa9f2fec4b75eefa94 100644 --- a/tensorflow/docs_src/mobile/tflite/index.md +++ b/tensorflow/docs_src/mobile/tflite/index.md @@ -37,8 +37,9 @@ a custom (less-dynamic) memory allocator to ensure minimal load, initialization, and execution latency. TensorFlow Lite provides an interface to leverage hardware acceleration, if -available on the device. It does so via the Android Neural Networks library, -released as part of Android O-MR1. +available on the device. It does so via the +[Android Neural Networks API](https://developer.android.com/ndk/guides/neuralnetworks/index.html), +available on Android 8.1 (API level 27) and higher. ## Why do we need a new mobile-specific library? @@ -116,6 +117,10 @@ following: Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) to all first-party and third-party apps. + Also see the complete list of + [TensorFlow Lite's supported models](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md), + including the model sizes, performance numbers, and downloadable model files. + - Quantized versions of the MobileNet model, which runs faster than the non-quantized (float) version on CPU. @@ -131,10 +136,10 @@ compatibility with this release. ## Getting Started We recommend you try out TensorFlow Lite with the pre-tested models indicated -above. If you have an existing mode, you will need to test whether your model is -compatible with both the converter and the supported operator set. To test your -model, see the [documentation on -GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite). +above. If you have an existing model, you will need to test whether your model +is compatible with both the converter and the supported operator set. To test +your model, see the +[documentation on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite). ### Retrain Inception-V3 or MobileNet for a custom data set diff --git a/tensorflow/docs_src/mobile/tflite/performance.md b/tensorflow/docs_src/mobile/tflite/performance.md new file mode 100644 index 0000000000000000000000000000000000000000..79bacaaa1b889a8711e5c09c7fd4e4912e70d3bd --- /dev/null +++ b/tensorflow/docs_src/mobile/tflite/performance.md @@ -0,0 +1,174 @@ +# Performance + +This document lists TensorFlow Lite performance benchmarks when running well +known models on some Android and iOS devices. + +These performance benchmark numbers were generated with the +[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) +and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios). + +# Android performance benchmarks + +For Android benchmarks, the CPU affinity is set to use big cores on the device to +reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)). + +It assumes that models were download and unzipped to the +`/data/local/tmp/tflite_models` directory. The benchmark binary is built +using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android) +and assumed in the `/data/local/tmp` directory. + +To run the benchmark: + +``` +adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \ + --num_threads=1 \ + --graph=/data/local/tmp/tflite_models/${GRAPH} \ + --warmup_runs=1 \ + --num_runs=50 \ + --use_nnapi=false +``` + +Here, `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity +chosen according to the following table: + +Device | CPU_MASK | +-------| ---------- +Pixel 2 | f0 | +Pixel xl | 0c | + + +
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.9.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.9.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.8.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.8.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.7.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model NameDevice Mean inference time (std dev)
+ Mobilenet_1.0_224(float) + Pixel 2 166.5 ms (2.6 ms)
Pixel xl 122.9 ms (1.8 ms)
+ Mobilenet_1.0_224 (quant) + Pixel 2 69.5 ms (0.9 ms)
Pixel xl 78.9 ms (2.2 ms)
+ NASNet mobile + Pixel 2 273.8 ms (3.5 ms)
Pixel xl 210.8 ms (4.2 ms)
+ SqueezeNet + Pixel 2 234.0 ms (2.1 ms)
Pixel xl 158.0 ms (2.1 ms)
+ Inception_ResNet_V2 + Pixel 2 2846.0 ms (15.0 ms)
Pixel xl 1973.0 ms (15.0 ms)
+ Inception_V4 + Pixel 2 3180.0 ms (11.7 ms)
Pixel xl 2262.0 ms (21.0 ms)
+ +# iOS benchmarks + +To run iOS benchmarks, the [benchmark +app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios) +was modified to include the appropriate model and `benchmark_params.json` was +modified to set `num_threads` to 1. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model NameDevice Mean inference time (std dev)
+ Mobilenet_1.0_224(float) + iPhone 8 32.2 ms (0.8 ms)
+ Mobilenet_1.0_224 (quant) + iPhone 8 24.4 ms (0.8 ms)
+ NASNet mobile + iPhone 8 60.3 ms (0.6 ms)
+ SqueezeNet + iPhone 8 44.3 (0.7 ms)
+ Inception_ResNet_V2 + iPhone 8562.4 ms (18.2 ms)
+ Inception_V4 + iPhone 8 661.0 ms (29.2 ms)
diff --git a/tensorflow/docs_src/performance/benchmarks.md b/tensorflow/docs_src/performance/benchmarks.md index 20165a090efcf26133ff2677fa4914c5153d5249..a5fa551dd4904df3a73c0c2357ab7c79685f0393 100644 --- a/tensorflow/docs_src/performance/benchmarks.md +++ b/tensorflow/docs_src/performance/benchmarks.md @@ -403,8 +403,6 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) This [script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks) was run on the various platforms to generate the above results. -@{$performance_models$High-Performance Models} details techniques in the script -along with examples of how to execute the script. In order to create results that are as repeatable as possible, each test was run 5 times and then the times were averaged together. GPUs are run in their default diff --git a/tensorflow/docs_src/performance/index.md b/tensorflow/docs_src/performance/index.md index 49343eaac7f0a785691a7633d19cc41d281efb99..131d28fa3eb47ff363888934c728e9971283c45d 100644 --- a/tensorflow/docs_src/performance/index.md +++ b/tensorflow/docs_src/performance/index.md @@ -1,19 +1,31 @@ # Performance -Performance is often a significant issue when training a machine learning -model. This section explains various ways to optimize performance. Start -your investigation with the @{$performance_guide$Performance Guide} and then go -deeper with techniques detailed in @{$performance_models$High-Performance Models}: - - * @{$performance_guide$Performance Guide}, which contains a collection of best +Performance is an important consideration when training machine learning +models. Performance speeds up and scales research while +also providing end users with near instant predictions. This section provides +details on the high level APIs to use along with best practices to build +and train high performance models, and quantize models for the least latency +and highest throughput for inference. + + * @{$performance_guide$Performance Guide} contains a collection of best practices for optimizing your TensorFlow code. - * @{$performance_models$High-Performance Models}, which contains a collection - of advanced techniques to build highly scalable models targeting different - system types and network topologies. + * @{$datasets_performance$Data input pipeline guide} describes the tf.data + API for building efficient data input pipelines for TensorFlow. + + * @{$performance/benchmarks$Benchmarks} contains a collection of + benchmark results for a variety of hardware configurations. + + * For improving inference efficiency on mobile and + embedded hardware, see + @{$quantization$How to Quantize Neural Networks with TensorFlow}, which + explains how to use quantization to reduce model size, both in storage + and at runtime. + + * For optimizing inference on GPUs, refer to [NVIDIA TensorRT™ + integration with TensorFlow.]( + https://medium.com/tensorflow/speed-up-tensorflow-inference-on-gpus-with-tensorrt-13b49f3db3fa) - * @{$performance/benchmarks$Benchmarks}, which contains a collection of - benchmark results. XLA (Accelerated Linear Algebra) is an experimental compiler for linear algebra that optimizes TensorFlow computations. The following guides explore @@ -36,10 +48,5 @@ XLA: standalone tool that compiles TensorFlow graphs into executable code in order to optimize performance. -And finally, we offer the following guide: - * @{$quantization$How to Quantize Neural Networks with TensorFlow}, which - can explains how to use quantization to reduce model size, both in storage - and at runtime. Quantization can improve performance, especially on - mobile hardware. diff --git a/tensorflow/docs_src/performance/leftnav_files b/tensorflow/docs_src/performance/leftnav_files index 1f894c39fe4554261cd37ebc8cd48af6b36eef43..12e0dbd48ac4913e20a401f5fa1a1fd05a273fc3 100644 --- a/tensorflow/docs_src/performance/leftnav_files +++ b/tensorflow/docs_src/performance/leftnav_files @@ -1,7 +1,6 @@ index.md performance_guide.md datasets_performance.md -performance_models.md benchmarks.md quantization.md diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md index b1796cf9b2d0bf7459e70ab542b6e6fcb203667a..cb0f5ca9242098d06aa0a9898e4a3774fab527b8 100644 --- a/tensorflow/docs_src/performance/performance_guide.md +++ b/tensorflow/docs_src/performance/performance_guide.md @@ -78,7 +78,7 @@ training CIFAR-10 illustrates the use of the `tf.data` API along with The `tf.data` API utilizes C++ multi-threading and has a much lower overhead than the Python-based `queue_runner` that is limited by Python's multi-threading performance. A detailed performance guide for the `tf.data` API can be found -[here](@{$datasets_performance}). +@{$datasets_performance$here}. While feeding data using a `feed_dict` offers a high level of flexibility, in general `feed_dict` does not provide a scalable solution. If only a single GPU diff --git a/tensorflow/docs_src/performance/quantization.md b/tensorflow/docs_src/performance/quantization.md index 2fea02d861d314cc61f2ba20475bf08ebea8fb5f..c97f74139c6ee852bf29724a3ac335d349a73fd3 100644 --- a/tensorflow/docs_src/performance/quantization.md +++ b/tensorflow/docs_src/performance/quantization.md @@ -227,8 +227,8 @@ of 30.0f, and an 8-bit array, the quantized values represent the following: - +
QuantizedFloat
0-10.0
25530.0
12810.0
25530.0
Table 2: Example quantized value range diff --git a/tensorflow/docs_src/performance/xla/jit.md b/tensorflow/docs_src/performance/xla/jit.md index d9a979ccbd31773b9d227ff946486706844a8f81..6724d1eaf8f85320b963eddc37947d69dcaa8471 100644 --- a/tensorflow/docs_src/performance/xla/jit.md +++ b/tensorflow/docs_src/performance/xla/jit.md @@ -137,12 +137,12 @@ TF_XLA_FLAGS=--xla_generate_hlo_graph=.* python mnist_softmax_xla.py ``` Open the timeline file created (`timeline.ctf.json`). The rendered timeline -should look similar to the picture below with one long bar labeled `_XlaLaunch`. +should look similar to the picture below with one long bar labeled `XlaLaunch`.
-To understand what is happening in `_XlaLaunch`, look at the console output for +To understand what is happening in `XlaLaunch`, look at the console output for statements similar to the following: ```shell diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 5887c3d88bf8c7844349cc1cc0db224586e56719..4c4f3f39348f59aa018d19d4a7368f09bcef89ed 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -581,12 +581,21 @@ Computes a sum across replicas. Arguments | Type | Semantics --------- | ------- | ----------------------------- `operand` | `XlaOp` | Array to sum across replicas. +| `replica_group_ids` | `int64` vector | Group ID for each replica. | The output shape is the same as the input shape. For example, if there are two replicas and the operand has the value `(1.0, 2.5)` and `(3.0, 5.25)` respectively on the two replicas, then the output value from this op will be `(4.0, 7.75)` on both replicas. +`replica_group_ids` identifies the group ID of each replica. The group ID must +either be empty (all replicas belong to a single group), or contain the same +number of elements as the number of replicas. For example, if +`replica_group_ids` = {0, 1, 2, 3, 0, 1, 2, 3} has eight replicas, there are +four subgroups of replica IDs: {0, 4}, {1, 5}, {2, 6}, and {3, 7}. The size of +each subgroup *must* be identical, so, for example, using: +`replica_group_ids` = {0, 1, 2, 0} for four replicas is invalid. + Computing the result of CrossReplicaSum requires having one input from each replica, so if one replica executes a CrossReplicaSum node more times than another, then the former replica will wait forever. Since the replicas are all @@ -1299,12 +1308,10 @@ See also : : : parameters of type T and M of : : : : arbitrary type : | `dimensions` | `int64` array | array of map dimensions | -| `static_operands` | sequence of M `XlaOp`s | M arrays of arbitrary type | Applies a scalar function over the given `operands` arrays, producing an array of the same dimensions where each element is the result of the mapped function -applied to the corresponding elements in the input arrays with `static_operands` -given as additional input to `computation`. +applied to the corresponding elements in the input arrays. The mapped function is an arbitrary computation with the restriction that it has N inputs of scalar type `T` and a single output with type `S`. The output has @@ -2003,13 +2010,35 @@ Slice(b, {2, 1}, {4, 3}) produces: See also [`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). -Sorts the elements in the operand. +There are two versions of the Sort instruction: a single-operand and a +two-operand version. `Sort(operand)` +Arguments | Type | Semantics +--------- | ------- | -------------------- +`operand` | `XlaOp` | The operand to sort. + +Sorts the elements in the operand in ascending order. The operand must be rank-1. +If the operand's elements have floating point type, and the operand contains +NaN elements, the order of elements in the output is implementation-defined. + +`Sort(key, value)` + +Sorts both the key and the value operands. The keys are sorted as in the +single-operand version. The values are sorted according to the order of their +corresponding keys. For example, if the inputs are `keys = [3, 1]` and +`values = [42, 50]`, then the output of the sort is the tuple `{[1, 3], [50, 42]}`. +The sort is not guaranteed to be stable, that is, if the keys array contains +duplicates, the order of their corresponding values may not be preserved. + Arguments | Type | Semantics --------- | ------- | ------------------- -`operand` | `XlaOp` | The operand to sort +`keys` | `XlaOp` | The sort keys. +`values` | `XlaOp` | The values to sort. + +The `keys` and `values` operand must both be rank-1, and must have the same +dimensions, but may have different element types. ## Transpose diff --git a/tensorflow/docs_src/tutorials/_index.yaml b/tensorflow/docs_src/tutorials/_index.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6fc8155669bb8672eef3ed4a62af00516648c90e --- /dev/null +++ b/tensorflow/docs_src/tutorials/_index.yaml @@ -0,0 +1,251 @@ +project_path: /_project.yaml +book_path: /_book.yaml +description: +landing_page: + show_side_navs: True + rows: + - description: > +

Get Started with TensorFlow

+

+ TensorFlow is an open-source machine learning library for research and + production. TensorFlow offers APIs for beginners and experts to develop + for desktop, mobile, web, and cloud. See the sections below to get + started. +

+ items: + - custom_html: > + +
+

Learn and use ML

+
+

+ The high-level Keras API provides building blocks to create and + train deep learning models. Start with these beginner-friendly + notebook examples, then read the + TensorFlow Keras guide. +

+
    +
  1. Basic classification
  2. +
  3. Text classification
  4. +
  5. Regression
  6. +
  7. Overfitting and underfitting
  8. +
  9. Save and load
  10. +
+
+ +
+ - classname: tfo-landing-row-item-code-block + code_block: | +
+        import tensorflow as tf
+        mnist = tf.keras.datasets.mnist
+
+        (x_train, y_train),(x_test, y_test) = mnist.load_data()
+        x_train, x_test = x_train / 255.0, x_test / 255.0
+
+        model = tf.keras.models.Sequential([
+          tf.keras.layers.Flatten(),
+          tf.keras.layers.Dense(512, activation=tf.nn.relu),
+          tf.keras.layers.Dropout(0.2),
+          tf.keras.layers.Dense(10, activation=tf.nn.softmax)
+        ])
+        model.compile(optimizer='adam',
+                      loss='sparse_categorical_crossentropy',
+                      metrics=['accuracy'])
+
+        model.fit(x_train, y_train, epochs=5)
+        model.evaluate(x_test, y_test)
+        
+ {% dynamic if request.tld != 'cn' %} + Run in a Notebook + {% dynamic endif %} + + - items: + - custom_html: > +
+

Research and experimentation

+
+

+ Eager execution provides an imperative, define-by-run interface for advanced operations. Write custom layers, forward passes, and training loops with auto‑differentiation. Start with + these notebooks, then read the eager execution guide. +

+
    +
  1. + {% dynamic if request.tld == 'cn' %} + Eager execution basics + {% dynamic else %} + Eager execution basics + {% dynamic endif %} +
  2. +
  3. + {% dynamic if request.tld == 'cn' %} + Automatic differentiation and gradient tape + {% dynamic else %} + Automatic differentiation and gradient tape + {% dynamic endif %} +
  4. +
  5. + {% dynamic if request.tld == 'cn' %} + Custom training: basics + {% dynamic else %} + Custom training: basics + {% dynamic endif %} +
  6. +
  7. + {% dynamic if request.tld == 'cn' %} + Custom layers + {% dynamic else %} + Custom layers + {% dynamic endif %} +
  8. +
  9. Custom training: walkthrough
  10. +
  11. + {% dynamic if request.tld == 'cn' %} + Example: Neural machine translation w/ attention + {% dynamic else %} + Example: Neural machine translation w/ attention + {% dynamic endif %} +
  12. +
+
+ +
+ - custom_html: > +
+

ML at production scale

+ + +
+ + - description: > +

Google Colab: An easy way to learn and use TensorFlow

+

+ Colaboratory + is a Google research project created to help disseminate machine learning + education and research. It's a Jupyter notebook environment that requires + no setup to use and runs entirely in the cloud. + Read the blog post. +

+ + - description: > +

Build your first ML app

+

Create and deploy TensorFlow models on web and mobile.

+ background: grey + items: + - custom_html: > +
+ +

Web developers

+
+
+ TensorFlow.js is a WebGL accelerated, JavaScript library to train and + deploy ML models in the browser and for Node.js. +
+
+ - custom_html: > +
+ +

Mobile developers

+
+
+ TensorFlow Lite is lightweight solution for mobile and embedded devices. +
+
+ + - description: > +

Videos and updates

+

+ Subscribe to the TensorFlow + YouTube channel + and blog for + the latest videos and updates. +

+ items: + - description: > +

Get started with TensorFlow's High-Level APIs

+ youtube_id: tjsHSIG8I08 + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=tjsHSIG8I08 + - description: > +

Eager execution

+ youtube_id: T8AW0fKP0Hs + background: grey + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=T8AW0fKP0Hs + - description: > +

tf.data: Fast, flexible, and easy-to-use input pipelines

+ youtube_id: uIcqeP7MFH0 + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=uIcqeP7MFH0 diff --git a/tensorflow/docs_src/tutorials/_toc.yaml b/tensorflow/docs_src/tutorials/_toc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d46d570a93c7da03ab12e960e65d46d5db793cbd --- /dev/null +++ b/tensorflow/docs_src/tutorials/_toc.yaml @@ -0,0 +1,93 @@ +toc: +- title: Get started with TensorFlow + path: /tutorials/ + +- title: Learn and use ML + style: accordion + section: + - title: Overview + path: /tutorials/keras/ + - title: Basic classification + path: /tutorials/keras/basic_classification + - title: Text classification + path: /tutorials/keras/basic_text_classification + - title: Regression + path: /tutorials/keras/basic_regression + - title: Overfitting and underfitting + path: /tutorials/keras/overfit_and_underfit + - title: Save and restore models + path: /tutorials/keras/save_and_restore_models + +- title: Research and experimentation + style: accordion + section: + - title: Overview + path: /tutorials/eager/ + - title: Eager execution + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_intro.ipynb + status: external + - title: Automatic differentiation + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb + status: external + - title: "Custom training: basics" + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb + status: external + - title: Custom layers + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb + status: external + - title: "Custom training: walkthrough" + path: /tutorials/eager/custom_training_walkthrough + - title: Neural machine translation + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb + status: external + +- title: Images + style: accordion + section: + - title: Build a CNN using Estimators + path: /tutorials/images/layers + - title: Image recognition + path: /tutorials/images/image_recognition + - title: Image retraining + path: /hub/tutorials/image_retraining + - title: Advanced CNN + path: /tutorials/images/deep_cnn + +- title: Sequences + style: accordion + section: + - title: Recurrent neural network + path: /tutorials/sequences/recurrent + - title: Drawing classification + path: /tutorials/sequences/recurrent_quickdraw + - title: Simple audio recognition + path: /tutorials/sequences/audio_recognition + - title: Neural machine translation + path: https://github.com/tensorflow/nmt + status: external + +- title: Data representation + style: accordion + section: + - title: Linear models + path: /tutorials/representation/wide + - title: Wide and deep learning + path: /tutorials/representation/wide_and_deep + - title: Vector representations of words + path: /tutorials/representation/word2vec + - title: Kernel methods + path: /tutorials/representation/kernel_methods + - title: Large-scale linear models + path: /tutorials/representation/linear + +- title: Non-ML + style: accordion + section: + - title: Mandelbrot set + path: /tutorials/non-ml/mandelbrot + - title: Partial differential equations + path: /tutorials/non-ml/pdes + +- break: True +- title: Next steps + path: /tutorials/next_steps diff --git a/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md b/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md new file mode 100644 index 0000000000000000000000000000000000000000..b45fbefac01c575515798af4692318ea1e905607 --- /dev/null +++ b/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md @@ -0,0 +1,3 @@ +# Custom training: walkthrough + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/eager.ipynb) diff --git a/tensorflow/docs_src/tutorials/eager/index.md b/tensorflow/docs_src/tutorials/eager/index.md new file mode 100644 index 0000000000000000000000000000000000000000..5445e0c3439392d4eeb8a6b3e9d229407b5b014e --- /dev/null +++ b/tensorflow/docs_src/tutorials/eager/index.md @@ -0,0 +1,13 @@ +# Research and experimentation + +Eager execution provides an imperative, define-by-run interface for advanced +operations. Write custom layers, forward passes, and training loops with +auto differentiation. Start with these notebooks, then read the +[eager execution guide](../../guide/eager). + +1. [Eager execution](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_intro.ipynb){:.external} +2. [Automatic differentiation and gradient tape](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb){:.external} +3. [Custom training: basics](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb){:.external} +4. [Custom layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb){:.external} +5. [Custom training: walkthrough](/tutorials/eager/custom_training_walkthrough) +6. [Advanced example: Neural machine translation with attention](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb){:.external} diff --git a/tensorflow/docs_src/tutorials/image_retraining.md b/tensorflow/docs_src/tutorials/image_retraining.md deleted file mode 100644 index 27784eef9cdb5c6f8b9af44b3fc3f876cda39d13..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/tutorials/image_retraining.md +++ /dev/null @@ -1,4 +0,0 @@ -# How to Retrain Inception's Final Layer for New Categories - -**NOTE: This tutorial has moved to** -https://github.com/tensorflow/hub/tree/master/docs/tutorials/image_retraining.md diff --git a/tensorflow/docs_src/tutorials/deep_cnn.md b/tensorflow/docs_src/tutorials/images/deep_cnn.md similarity index 98% rename from tensorflow/docs_src/tutorials/deep_cnn.md rename to tensorflow/docs_src/tutorials/images/deep_cnn.md index 6a4c9a9b0727208a158b1b57d13ca70290961ec2..1590f15eb91a0f20a91af3d899c3e08428f6c997 100644 --- a/tensorflow/docs_src/tutorials/deep_cnn.md +++ b/tensorflow/docs_src/tutorials/images/deep_cnn.md @@ -1,7 +1,4 @@ -# Convolutional Neural Networks - -> **NOTE:** This tutorial is intended for *advanced* users of TensorFlow -and assumes expertise and experience in machine learning. +# Advanced Convolutional Neural Networks ## Overview @@ -268,7 +265,7 @@ in `cifar10_input.py`. `cifar10_train.py` periodically @{tf.train.Saver$saves} all model parameters in -@{$programmers_guide/saved_model$checkpoint files} +@{$guide/saved_model$checkpoint files} but it does *not* evaluate the model. The checkpoint file will be used by `cifar10_eval.py` to measure the predictive performance (see [Evaluating a Model](#evaluating-a-model) below). @@ -438,9 +435,6 @@ with a batch size of 64 and compare the training speed. ## Next Steps -[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0) You have -completed the CIFAR-10 tutorial. - If you are now interested in developing and training your own image classification system, we recommend forking this tutorial and replacing components to address your image classification problem. diff --git a/tensorflow/docs_src/tutorials/image_recognition.md b/tensorflow/docs_src/tutorials/images/image_recognition.md similarity index 99% rename from tensorflow/docs_src/tutorials/image_recognition.md rename to tensorflow/docs_src/tutorials/images/image_recognition.md index 332bcf54f02e6e3c7d805746011dfab642943cfe..432d470d0cd281f688b28761d6d6a49f4d3e1efe 100644 --- a/tensorflow/docs_src/tutorials/image_recognition.md +++ b/tensorflow/docs_src/tutorials/images/image_recognition.md @@ -434,7 +434,6 @@ should be able to transfer some of that understanding to solving related problems. One way to perform transfer learning is to remove the final classification layer of the network and extract the [next-to-last layer of the CNN](https://arxiv.org/abs/1310.1531), in this case a 2048 dimensional vector. -There's a guide to doing this @{$image_retraining$in the how-to section}. ## Resources for Learning More diff --git a/tensorflow/docs_src/tutorials/layers.md b/tensorflow/docs_src/tutorials/images/layers.md similarity index 92% rename from tensorflow/docs_src/tutorials/layers.md rename to tensorflow/docs_src/tutorials/images/layers.md index 496b1e4da9d3b85d88be4dd86086e04fda51b8be..12a215b50c54f276f3c084885810c7a496769681 100644 --- a/tensorflow/docs_src/tutorials/layers.md +++ b/tensorflow/docs_src/tutorials/images/layers.md @@ -1,4 +1,4 @@ -# A Guide to TF Layers: Building a Convolutional Neural Network +# Build a Convolutional Neural Network using Estimators The TensorFlow @{tf.layers$`layers` module} provides a high-level API that makes it easy to construct a neural network. It provides methods that facilitate the @@ -190,7 +190,7 @@ def cnn_model_fn(features, labels, mode): The following sections (with headings corresponding to each code block above) dive deeper into the `tf.layers` code used to create each layer, as well as how to calculate loss, configure the training op, and generate predictions. If -you're already experienced with CNNs and @{$get_started/custom_estimators$TensorFlow `Estimator`s}, +you're already experienced with CNNs and @{$custom_estimators$TensorFlow `Estimator`s}, and find the above code intuitive, you may want to skim these sections or just skip ahead to ["Training and Evaluating the CNN MNIST Classifier"](#train_eval_mnist). @@ -470,51 +470,18 @@ as the loss metric. The following code calculates cross entropy when the model runs in either `TRAIN` or `EVAL` mode: ```python -onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10) -loss = tf.losses.softmax_cross_entropy( - onehot_labels=onehot_labels, logits=logits) +loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) ``` Let's take a closer look at what's happening above. -Our `labels` tensor contains a list of predictions for our examples, e.g. `[1, -9, ...]`. In order to calculate cross-entropy, first we need to convert `labels` -to the corresponding -[one-hot encoding](https://www.quora.com/What-is-one-hot-encoding-and-when-is-it-used-in-data-science): +Our `labels` tensor contains a list of prediction indices for our examples, e.g. `[1, +9, ...]`. `logits` contains the linear outputs of our last layer. -```none -[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - ...] -``` - -We use the @{tf.one_hot} function -to perform this conversion. `tf.one_hot()` has two required arguments: - -* `indices`. The locations in the one-hot tensor that will have "on - values"—i.e., the locations of `1` values in the tensor shown above. -* `depth`. The depth of the one-hot tensor—i.e., the number of target classes. - Here, the depth is `10`. +`tf.losses.sparse_softmax_cross_entropy`, calculates the softmax crossentropy +(aka: categorical crossentropy, negative log-likelihood) from these two inputs +in an efficient, numerically stable way. -The following code creates the one-hot tensor for our labels, `onehot_labels`: - -```python -onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10) -``` - -Because `labels` contains a series of values from 0–9, `indices` is just our -`labels` tensor, with values cast to integers. The `depth` is `10` because we -have 10 possible target classes, one for each digit. - -Next, we compute cross-entropy of `onehot_labels` and the softmax of the -predictions from our logits layer. `tf.losses.softmax_cross_entropy()` takes -`onehot_labels` and `logits` as arguments, performs softmax activation on -`logits`, calculates cross-entropy, and returns our `loss` as a scalar `Tensor`: - -```python -loss = tf.losses.softmax_cross_entropy( - onehot_labels=onehot_labels, logits=logits) -``` ### Configure the Training Op @@ -534,8 +501,8 @@ if mode == tf.estimator.ModeKeys.TRAIN: ``` > Note: For a more in-depth look at configuring training ops for Estimator model -> functions, see @{$get_started/custom_estimators#defining-the-training-op-for-the-model$"Defining the training op for the model"} -> in the @{$get_started/custom_estimators$"Creating Estimations in tf.estimator"} tutorial. +> functions, see @{$custom_estimators#defining-the-training-op-for-the-model$"Defining the training op for the model"} +> in the @{$custom_estimators$"Creating Estimations in tf.estimator"} tutorial. ### Add evaluation metrics @@ -600,7 +567,7 @@ be saved (here, we specify the temp directory `/tmp/mnist_convnet_model`, but feel free to change to another directory of your choice). > Note: For an in-depth walkthrough of the TensorFlow `Estimator` API, see the -> tutorial @{$get_started/custom_estimators$"Creating Estimators in tf.estimator."} +> tutorial @{$custom_estimators$"Creating Estimators in tf.estimator."} ### Set Up a Logging Hook {#set_up_a_logging_hook} @@ -627,7 +594,7 @@ operation earlier when we generated the probabilities in `cnn_model_fn`. > argument, TensorFlow will assign a default name. A couple easy ways to > discover the names applied to operations are to visualize your graph on > @{$graph_viz$TensorBoard}) or to enable the -> @{$programmers_guide/debugger$TensorFlow Debugger (tfdbg)}. +> @{$guide/debugger$TensorFlow Debugger (tfdbg)}. Next, we create the `LoggingTensorHook`, passing `tensors_to_log` to the `tensors` argument. We set `every_n_iter=50`, which specifies that probabilities @@ -719,7 +686,7 @@ Here, we've achieved an accuracy of 97.3% on our test data set. To learn more about TensorFlow Estimators and CNNs in TensorFlow, see the following resources: -* @{$get_started/custom_estimators$Creating Estimators in tf.estimator} +* @{$custom_estimators$Creating Estimators in tf.estimator} provides an introduction to the TensorFlow Estimator API. It walks through configuring an Estimator, writing a model function, calculating loss, and defining a training op. diff --git a/tensorflow/docs_src/tutorials/index.md b/tensorflow/docs_src/tutorials/index.md deleted file mode 100644 index af01d3eaa12157f82c981de005708509f6652cca..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/tutorials/index.md +++ /dev/null @@ -1,60 +0,0 @@ -# Tutorials - - -This section contains tutorials demonstrating how to do specific tasks -in TensorFlow. If you are new to TensorFlow, we recommend reading the -documents in the "@{$get_started$Get Started}" section before reading -these tutorials. - -## Images - -These tutorials cover different aspects of image recognition: - - * @{$layers$MNIST}, which introduces convolutional neural networks (CNNs) and - demonstrates how to build a CNN in TensorFlow. - * @{$image_recognition}, which introduces the field of image recognition and - uses a pre-trained model (Inception) for recognizing images. - * @{$image_retraining}, which has a wonderfully self-explanatory title. - * @{$deep_cnn}, which demonstrates how to build a small CNN for recognizing - images. This tutorial is aimed at advanced TensorFlow users. - - -## Sequences - -These tutorials focus on machine learning problems dealing with sequence data. - - * @{$recurrent}, which demonstrates how to use a - recurrent neural network to predict the next word in a sentence. - * @{$seq2seq}, which demonstrates how to use a - sequence-to-sequence model to translate text from English to French. - * @{$recurrent_quickdraw} - builds a classification model for drawings, directly from the sequence of - pen strokes. - * @{$audio_recognition}, which shows how to - build a basic speech recognition network. - -## Data representation - -These tutorials demonstrate various data representations that can be used in -TensorFlow. - - * @{$wide}, uses - @{tf.feature_column$feature columns} to feed a variety of data types - to linear model, to solve a classification problem. - * @{$wide_and_deep}, builds on the - above linear model tutorial, adding a deep feed-forward neural network - component and a DNN-compatible data representation. - * @{$word2vec}, which demonstrates how to - create an embedding for words. - * @{$kernel_methods}, - which shows how to improve the quality of a linear model by using explicit - kernel mappings. - -## Non Machine Learning - -Although TensorFlow specializes in machine learning, the core of TensorFlow is -a powerful numeric computation system which you can also use to solve other -kinds of math problems. For example: - - * @{$mandelbrot} - * @{$pdes} diff --git a/tensorflow/docs_src/tutorials/keras/basic_classification.md b/tensorflow/docs_src/tutorials/keras/basic_classification.md new file mode 100644 index 0000000000000000000000000000000000000000..91bbd85b2442522ef34eba236bf5bab2fc8654a7 --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/basic_classification.md @@ -0,0 +1,3 @@ +# Basic Classification + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/basic_classification.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/basic_regression.md b/tensorflow/docs_src/tutorials/keras/basic_regression.md new file mode 100644 index 0000000000000000000000000000000000000000..a535f22f5a41e7cb34cb8424b60d10d4ad43940e --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/basic_regression.md @@ -0,0 +1,3 @@ +# Basic Regression + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/basic_regression.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/basic_text_classification.md b/tensorflow/docs_src/tutorials/keras/basic_text_classification.md new file mode 100644 index 0000000000000000000000000000000000000000..7c5d4f78968f94e4d5685a2dffe75ab649431e38 --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/basic_text_classification.md @@ -0,0 +1,3 @@ +# Basic Text Classification + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/basic_text_classification.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/index.md b/tensorflow/docs_src/tutorials/keras/index.md new file mode 100644 index 0000000000000000000000000000000000000000..9d42281c8f97fd8930770c0bc30c9bcf1e50fde6 --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/index.md @@ -0,0 +1,22 @@ +# Learn and use machine learning + +This notebook collection is inspired by the book +*[Deep Learning with Python](https://books.google.com/books?id=Yo3CAQAACAAJ)*. +These tutorials use `tf.keras`, TensorFlow's high-level Python API for building +and training deep learning models. To learn more about using Keras with +TensorFlow, see the [TensorFlow Keras Guide](../../guide/keras). + +Publisher's note: *Deep Learning with Python* introduces the field of deep +learning using the Python language and the powerful Keras library. Written by +Keras creator and Google AI researcher François Chollet, this book builds your +understanding through intuitive explanations and practical examples. + +To learn about machine learning fundamentals and concepts, consider taking the +[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/). +Additional TensorFlow and machine learning resources are listed in [next steps](../next_steps). + +1. [Basic classification](./basic_classification) +2. [Text classification](./basic_text_classification) +3. [Regression](./basic_regression) +4. [Overfitting and underfitting](./overfit_and_underfit) +5. [Save and restore models](./save_and_restore_models) diff --git a/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md b/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md new file mode 100644 index 0000000000000000000000000000000000000000..e5b5ae7b5a70f476c25cc7bb76572bf6433c289f --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md @@ -0,0 +1,3 @@ +# Overfitting and Underfitting + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/overfit_and_underfit.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md b/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md new file mode 100644 index 0000000000000000000000000000000000000000..44b377294562cf5a0c8139e88d0c7226506b32ba --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md @@ -0,0 +1,3 @@ +# Save and restore Models + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/save_and_restore_models.ipynb) diff --git a/tensorflow/docs_src/tutorials/leftnav_files b/tensorflow/docs_src/tutorials/leftnav_files deleted file mode 100644 index 888052428f951fa1a7cbd9c6d35497a056387097..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/tutorials/leftnav_files +++ /dev/null @@ -1,23 +0,0 @@ -index.md - -### Images -layers.md: MNIST -image_recognition.md: Image Recognition -image_retraining.md: Image Retraining -deep_cnn.md - -### Sequences -recurrent.md -seq2seq.md: Neural Machine Translation -recurrent_quickdraw.md: Drawing Classification -audio_recognition.md - -### Data Representation -wide.md: Linear Models -wide_and_deep.md: Wide & Deep Learning -word2vec.md -kernel_methods.md: Kernel Methods - -### Non-ML -mandelbrot.md -pdes.md diff --git a/tensorflow/docs_src/tutorials/next_steps.md b/tensorflow/docs_src/tutorials/next_steps.md new file mode 100644 index 0000000000000000000000000000000000000000..01c9f7204a7ddae16bcbd9eb5702516a39f8ce4c --- /dev/null +++ b/tensorflow/docs_src/tutorials/next_steps.md @@ -0,0 +1,36 @@ +# Next steps + +## Learn more about TensorFlow + +* The [TensorFlow Guide](/guide) includes usage guides for the + high-level APIs, as well as advanced TensorFlow operations. +* [Premade Estimators](/guide/premade_estimators) are designed to + get results out of the box. Use TensorFlow without building your own models. +* [TensorFlow.js](https://js.tensorflow.org/) allows web developers to train and + deploy ML models in the browser and using Node.js. +* [TFLite](/mobile/tflite) allows mobile developers to do inference efficiently + on mobile devices. +* [TensorFlow Serving](/serving) is an open-source project that can put + TensorFlow models in production quickly. +* The [ecosystem](/ecosystem) contains more projects, including + [Magenta](https://magenta.tensorflow.org/), [TFX](/tfx), + [Swift for TensorFlow](https://github.com/tensorflow/swift), and more. + +## Learn more about machine learning + +Recommended resources include: + +* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/), + a course from Google that introduces machine learning concepts. +* [CS 20: Tensorflow for Deep Learning Research](http://web.stanford.edu/class/cs20si/), + notes from an intro course from Stanford. +* [CS231n: Convolutional Neural Networks for Visual Recognition](http://cs231n.stanford.edu/), + a course that teaches how convolutional networks work. +* [Machine Learning Recipes](https://www.youtube.com/watch?v=cKxRvEZd3Mw&list=PLOU2XLYxmsIIuiBfYad6rFYQU_jL2ryal), + a video series that introduces basic machine learning concepts with few prerequisites. +* [Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python), + a book by Francois Chollet about the Keras API, as well as an excellent hands on intro to Deep Learning. +* [Hands-on Machine Learning with Scikit-Learn and TensorFlow](https://github.com/ageron/handson-ml), + a book by Aurélien Geron's that is a clear getting-started guide to data science and deep learning. +* [Deep Learning](https://www.deeplearningbook.org/), a book by Ian Goodfellow et al. + that provides a technical dive into learning machine learning. diff --git a/tensorflow/docs_src/tutorials/mandelbrot.md b/tensorflow/docs_src/tutorials/non-ml/mandelbrot.md old mode 100755 new mode 100644 similarity index 100% rename from tensorflow/docs_src/tutorials/mandelbrot.md rename to tensorflow/docs_src/tutorials/non-ml/mandelbrot.md diff --git a/tensorflow/docs_src/tutorials/pdes.md b/tensorflow/docs_src/tutorials/non-ml/pdes.md old mode 100755 new mode 100644 similarity index 98% rename from tensorflow/docs_src/tutorials/pdes.md rename to tensorflow/docs_src/tutorials/non-ml/pdes.md index 425e8d7084e7f2505b7a3013b431345b72b38cf0..b5a0fa834a8a0a51421657180f8c7817c0e3d140 --- a/tensorflow/docs_src/tutorials/pdes.md +++ b/tensorflow/docs_src/tutorials/non-ml/pdes.md @@ -135,7 +135,6 @@ for i in range(1000): DisplayArray(U.eval(), rng=[-0.1, 0.1]) ``` -![jpeg](../images/pde_output_2.jpg) +![jpeg](../../images/pde_output_2.jpg) Look! Ripples! - diff --git a/tensorflow/docs_src/tutorials/kernel_methods.md b/tensorflow/docs_src/tutorials/representation/kernel_methods.md similarity index 98% rename from tensorflow/docs_src/tutorials/kernel_methods.md rename to tensorflow/docs_src/tutorials/representation/kernel_methods.md index 73e5c5105784ddc9729b8cea6cd31921572837e1..f3c232c51155927a4b8e5abdd6e1e04403f8caa4 100644 --- a/tensorflow/docs_src/tutorials/kernel_methods.md +++ b/tensorflow/docs_src/tutorials/representation/kernel_methods.md @@ -27,7 +27,7 @@ TensorFlow will provide support for sparse features at a later release. This tutorial uses [tf.contrib.learn](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn) (TensorFlow's high-level Machine Learning API) Estimators for our ML models. -If you are not familiar with this API, [tf.estimator Quickstart](https://www.tensorflow.org/get_started/estimator) +If you are not familiar with this API, The [Estimator guide](../../guide/estimators.md) is a good place to start. We will use the MNIST dataset. The tutorial consists of the following steps: @@ -53,7 +53,7 @@ In order to feed data to a `tf.contrib.learn Estimator`, it is helpful to conver it to Tensors. For this, we will use an `input function` which adds Ops to the TensorFlow graph that, when executed, create mini-batches of Tensors to be used downstream. For more background on input functions, check -@{$get_started/premade_estimators#create_input_functions$this section on input functions}. +@{$premade_estimators#create_input_functions$this section on input functions}. In this example, we will use the `tf.train.shuffle_batch` Op which, besides converting numpy arrays to Tensors, allows us to specify the batch_size and whether to randomize the input every time the input_fn Ops are executed diff --git a/tensorflow/docs_src/tutorials/linear.md b/tensorflow/docs_src/tutorials/representation/linear.md similarity index 99% rename from tensorflow/docs_src/tutorials/linear.md rename to tensorflow/docs_src/tutorials/representation/linear.md index 265ded877d1ff9fb0b1cc2ad678729a3b7247aa8..3f247ade266d2675eac4d0f59a4744daa61f27ea 100644 --- a/tensorflow/docs_src/tutorials/linear.md +++ b/tensorflow/docs_src/tutorials/representation/linear.md @@ -17,7 +17,7 @@ tutorial walks through the code in greater detail. To understand this overview it will help to have some familiarity with basic machine learning concepts, and also with -@{$get_started/premade_estimators$Estimators}. +@{$premade_estimators$Estimators}. [TOC] diff --git a/tensorflow/docs_src/tutorials/wide.md b/tensorflow/docs_src/tutorials/representation/wide.md similarity index 100% rename from tensorflow/docs_src/tutorials/wide.md rename to tensorflow/docs_src/tutorials/representation/wide.md diff --git a/tensorflow/docs_src/tutorials/wide_and_deep.md b/tensorflow/docs_src/tutorials/representation/wide_and_deep.md similarity index 100% rename from tensorflow/docs_src/tutorials/wide_and_deep.md rename to tensorflow/docs_src/tutorials/representation/wide_and_deep.md diff --git a/tensorflow/docs_src/tutorials/word2vec.md b/tensorflow/docs_src/tutorials/representation/word2vec.md similarity index 100% rename from tensorflow/docs_src/tutorials/word2vec.md rename to tensorflow/docs_src/tutorials/representation/word2vec.md diff --git a/tensorflow/docs_src/tutorials/seq2seq.md b/tensorflow/docs_src/tutorials/seq2seq.md deleted file mode 100644 index 8928ba4f7da26ae2e8e9351e2c7c03f0e657f613..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/tutorials/seq2seq.md +++ /dev/null @@ -1,5 +0,0 @@ -# Sequence-to-Sequence Models - -Please check out the -[tensorflow neural machine translation tutorial](https://github.com/tensorflow/nmt) -for building sequence-to-sequence models with the latest Tensorflow API. diff --git a/tensorflow/docs_src/tutorials/audio_recognition.md b/tensorflow/docs_src/tutorials/sequences/audio_recognition.md similarity index 100% rename from tensorflow/docs_src/tutorials/audio_recognition.md rename to tensorflow/docs_src/tutorials/sequences/audio_recognition.md diff --git a/tensorflow/docs_src/tutorials/recurrent.md b/tensorflow/docs_src/tutorials/sequences/recurrent.md similarity index 98% rename from tensorflow/docs_src/tutorials/recurrent.md rename to tensorflow/docs_src/tutorials/sequences/recurrent.md index 14da2c8785276abb34d6959d738f5b39e6c6a2e8..715cc7856af1d6a3422b65a796a3d48b6c1c3e0f 100644 --- a/tensorflow/docs_src/tutorials/recurrent.md +++ b/tensorflow/docs_src/tutorials/sequences/recurrent.md @@ -2,8 +2,8 @@ ## Introduction -Take a look at [this great article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) -for an introduction to recurrent neural networks and LSTMs in particular. +See [Understanding LSTM Networks](https://colah.github.io/posts/2015-08-Understanding-LSTMs/){:.external} +for an introduction to recurrent neural networks and LSTMs. ## Language Modeling diff --git a/tensorflow/docs_src/tutorials/recurrent_quickdraw.md b/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md similarity index 98% rename from tensorflow/docs_src/tutorials/recurrent_quickdraw.md rename to tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md index 5d83fbe2a3709c0834f448cbc316453f80428dd1..37bce5b76d46741dfe04cbf3612f71863adb02c6 100644 --- a/tensorflow/docs_src/tutorials/recurrent_quickdraw.md +++ b/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md @@ -13,7 +13,7 @@ In this tutorial we'll show how to build an RNN-based recognizer for this problem. The model will use a combination of convolutional layers, LSTM layers, and a softmax output layer to classify the drawings: -
![RNN model structure](../images/quickdraw_model.png)
+
![RNN model structure](../../images/quickdraw_model.png)
The figure above shows the structure of the model that we will build in this tutorial. The input is a drawing that is encoded as a sequence of strokes of @@ -208,7 +208,7 @@ This data is then reformatted into a tensor of shape `[num_training_samples, max_length, 3]`. Then we determine the bounding box of the original drawing in screen coordinates and normalize the size such that the drawing has unit height. -
![Size normalization](../images/quickdraw_sizenormalization.png)
+
![Size normalization](../../images/quickdraw_sizenormalization.png)
Finally, we compute the differences between consecutive points and store these as a `VarLenFeature` in a @@ -220,7 +220,7 @@ length 2. ### Defining the model To define the model we create a new `Estimator`. If you want to read more about -estimators, we recommend @{$get_started/custom_estimators$this tutorial}. +estimators, we recommend @{$custom_estimators$this tutorial}. To build the model, we: diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index 07f096418f53219c9ec7000a4560d78a3ff609e1..f327b645f58f35cedd27baa8ab521e334c8e7b15 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow camera demo app for Android. +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py index 307eede5c03780e9244b035f020fc7846290d4d9..740224744860fdd76bea9c4531242a4976b20784 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py @@ -17,7 +17,7 @@ This version is like fully_connected_feed.py but uses data converted to a TFRecords file containing tf.train.Example protocol buffers. See: -https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files +https://www.tensorflow.org/guide/reading_data#reading_from_files for context. YOU MUST run convert_to_records before running this (but you only need to diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py index 03e60972aa660fad4af8d3535e31463c96f7c69b..86f5204ec3e8713d5d22156419b6414acb2fa677 100644 --- a/tensorflow/examples/learn/iris.py +++ b/tensorflow/examples/learn/iris.py @@ -21,7 +21,8 @@ from __future__ import division from __future__ import print_function import os -import urllib + +from six.moves.urllib.request import urlretrieve import tensorflow as tf @@ -38,9 +39,7 @@ FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'] def maybe_download_iris_data(file_name, download_url): """Downloads the file and returns the number of data.""" if not os.path.exists(file_name): - raw = urllib.urlopen(download_url).read() - with open(file_name, 'w') as f: - f.write(raw) + urlretrieve(download_url, file_name) # The first line is a comma-separated string. The first one is the number of # total data in the file. diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD index d7bc6a5a7d1e4cd3927c7c5067ccc22993885994..d4070fdd1e015fb78dcf2ff72fe30b6f1746c8fb 100644 --- a/tensorflow/examples/tutorials/mnist/BUILD +++ b/tensorflow/examples/tutorials/mnist/BUILD @@ -97,7 +97,7 @@ py_binary( py_test( name = "fully_connected_feed_test", - size = "small", + size = "medium", srcs = [ "fully_connected_feed.py", ], diff --git a/tensorflow/go/attrs.go b/tensorflow/go/attrs.go new file mode 100644 index 0000000000000000000000000000000000000000..f86c5737bc79f1e349e442669615598949ecd333 --- /dev/null +++ b/tensorflow/go/attrs.go @@ -0,0 +1,245 @@ +/* +Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tensorflow + +// #include +// #include "tensorflow/c/c_api.h" +import "C" +import ( + "fmt" + "unsafe" +) + +// makeCShape converts a shape specified in C.int64_t into a Shape. +func makeCShape(shape []C.int64_t) Shape { + s := Shape{dims: make([]int64, len(shape))} + for i, n := range shape { + s.dims[i] = int64(n) + } + return s +} + +// Attr returns the value of an attribute on op. It returns an error if the +// attribute does not exist. +func (op *Operation) Attr(name string) (interface{}, error) { + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + + status := newStatus() + meta := C.TF_OperationGetAttrMetadata(op.c, cname, status.c) + if err := status.Err(); err != nil { + return nil, err + } + + if meta.is_list == 1 { + return listAttribute(op, cname, meta) + } + return scalarAttribute(op, cname, meta) +} + +func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interface{}, error) { + status := newStatus() + + switch meta._type { + case C.TF_ATTR_STRING: + if meta.list_size == 0 { + return []string(nil), nil + } + values := make([]unsafe.Pointer, meta.list_size) + lengths := make([]C.size_t, meta.list_size) + // Add one element in case total_size is zero. + storage := make([]C.char, meta.total_size+1) + C.TF_OperationGetAttrStringList(op.c, cname, &values[0], &lengths[0], C.int(meta.list_size), unsafe.Pointer(&storage[0]), C.size_t(meta.total_size), status.c) + if err := status.Err(); err != nil { + return nil, err + } + list := make([]string, meta.list_size) + for i, val := range values { + length := lengths[i] + list[i] = C.GoStringN((*C.char)(val), C.int(length)) + } + return list, nil + + case C.TF_ATTR_INT: + if meta.list_size == 0 { + return []int64(nil), nil + } + list := make([]C.int64_t, meta.list_size) + C.TF_OperationGetAttrIntList(op.c, cname, &list[0], C.int(meta.list_size), status.c) + if err := status.Err(); err != nil { + return nil, err + } + vals := make([]int64, meta.list_size) + for i, val := range list { + vals[i] = int64(val) + } + return vals, nil + + case C.TF_ATTR_FLOAT: + if meta.list_size == 0 { + return []float32(nil), nil + } + list := make([]C.float, meta.list_size) + C.TF_OperationGetAttrFloatList(op.c, cname, &list[0], C.int(meta.list_size), status.c) + if err := status.Err(); err != nil { + return nil, err + } + vals := make([]float32, meta.list_size) + for i, val := range list { + vals[i] = float32(val) + } + return vals, nil + + case C.TF_ATTR_BOOL: + if meta.list_size == 0 { + return []bool(nil), nil + } + list := make([]C.uchar, meta.list_size) + C.TF_OperationGetAttrBoolList(op.c, cname, &list[0], C.int(meta.list_size), status.c) + if err := status.Err(); err != nil { + return nil, err + } + vals := make([]bool, meta.list_size) + for i, val := range list { + vals[i] = val == 1 + } + return vals, nil + + case C.TF_ATTR_TYPE: + if meta.list_size == 0 { + return []DataType(nil), nil + } + list := make([]C.TF_DataType, meta.list_size) + C.TF_OperationGetAttrTypeList(op.c, cname, &list[0], C.int(meta.list_size), status.c) + if err := status.Err(); err != nil { + return nil, err + } + vals := make([]DataType, meta.list_size) + for i, val := range list { + vals[i] = DataType(val) + } + return vals, nil + + case C.TF_ATTR_TENSOR: + if meta.list_size == 0 { + return []*Tensor(nil), nil + } + list := make([]*C.TF_Tensor, meta.list_size) + C.TF_OperationGetAttrTensorList(op.c, cname, &list[0], C.int(meta.list_size), status.c) + if err := status.Err(); err != nil { + return nil, err + } + vals := make([]*Tensor, meta.list_size) + for i, t := range list { + vals[i] = newTensorFromC(t) + } + return vals, nil + + case C.TF_ATTR_SHAPE: + if meta.list_size == 0 { + return []Shape(nil), nil + } + dims := make([]*C.int64_t, meta.list_size) + numDims := make([]C.int, meta.list_size) + // Add one element in case total_size is zero. + storage := make([]C.int64_t, meta.total_size+1) + C.TF_OperationGetAttrShapeList(op.c, cname, &dims[0], &numDims[0], C.int(meta.list_size), &storage[0], C.int(meta.total_size), status.c) + if err := status.Err(); err != nil { + return nil, err + } + list := make([]Shape, meta.list_size) + for i, dim := range dims { + numDim := numDims[i] + // If the number of dimensions is unknown, default to empty shape. + if numDim < 0 { + continue + } + // A []C.int64_t slice backed by C memory. + // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + slice := (*[1 << 30]C.int64_t)(unsafe.Pointer(dim))[:numDim:numDim] + list[i] = makeCShape(slice) + } + return list, nil + + default: + return nil, fmt.Errorf("list type %v not supported", meta._type) + } +} + +func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interface{}, error) { + status := newStatus() + + switch meta._type { + case C.TF_ATTR_STRING: + if meta.total_size == 0 { + return "", nil + } + v := make([]C.char, meta.total_size) + C.TF_OperationGetAttrString(op.c, cname, unsafe.Pointer(&v[0]), C.size_t(meta.total_size), status.c) + if err := status.Err(); err != nil { + return nil, err + } + return C.GoStringN(&v[0], C.int(meta.total_size)), nil + + case C.TF_ATTR_INT: + var v C.int64_t + C.TF_OperationGetAttrInt(op.c, cname, &v, status.c) + return int64(v), status.Err() + + case C.TF_ATTR_FLOAT: + var v C.float + C.TF_OperationGetAttrFloat(op.c, cname, &v, status.c) + return float32(v), status.Err() + + case C.TF_ATTR_BOOL: + var v C.uchar + C.TF_OperationGetAttrBool(op.c, cname, &v, status.c) + return v == 1, status.Err() + + case C.TF_ATTR_TYPE: + var v C.TF_DataType + C.TF_OperationGetAttrType(op.c, cname, &v, status.c) + return DataType(v), status.Err() + + case C.TF_ATTR_TENSOR: + var v *C.TF_Tensor + C.TF_OperationGetAttrTensor(op.c, cname, &v, status.c) + if err := status.Err(); err != nil { + return nil, err + } + return newTensorFromC(v), nil + + case C.TF_ATTR_SHAPE: + numDims := meta.total_size + // If number of dims is unknown return empty shape to indicate that. + if numDims < 0 { + return Shape{}, nil + } + if numDims == 0 { + return ScalarShape(), nil + } + dims := make([]C.int64_t, numDims) + C.TF_OperationGetAttrShape(op.c, cname, (*C.int64_t)(unsafe.Pointer(&dims[0])), C.int(numDims), status.c) + if err := status.Err(); err != nil { + return nil, err + } + return makeCShape(dims), nil + + default: + return nil, fmt.Errorf("type %v not supported", meta._type) + } +} diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ea8af221aeef3bf1d2edeab4372ae00f0cc7e92d --- /dev/null +++ b/tensorflow/go/attrs_test.go @@ -0,0 +1,193 @@ +/* +Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tensorflow + +import ( + "fmt" + "reflect" + "testing" +) + +func TestOperationAttrs(t *testing.T) { + g := NewGraph() + + i := 0 + makeConst := func(v interface{}) Output { + op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v) + i++ + if err != nil { + t.Fatal(err) + } + return op + } + + makeTensor := func(v interface{}) *Tensor { + tensor, err := NewTensor(v) + if err != nil { + t.Fatal(err) + } + return tensor + } + + cases := []OpSpec{ + { + Name: "type", + Type: "Placeholder", + Attrs: map[string]interface{}{ + "dtype": Float, + }, + }, + { + Name: "list(float)", + Type: "Bucketize", + Input: []Input{ + makeConst([]float32{1, 2, 3, 4}), + }, + Attrs: map[string]interface{}{ + "boundaries": []float32{0, 1, 2, 3, 4, 5}, + }, + }, + { + Name: "list(float) empty", + Type: "Bucketize", + Input: []Input{ + makeConst([]float32{}), + }, + Attrs: map[string]interface{}{ + "boundaries": []float32(nil), + }, + }, + /* TODO(ashankar): debug this issue and add it back later. + { + Name: "list(type),list(shape)", + Type: "InfeedEnqueueTuple", + Input: []Input{ + OutputList([]Output{ + makeConst(float32(1)), + makeConst([][]int32{{2}}), + }), + }, + Attrs: map[string]interface{}{ + "dtypes": []DataType{Float, Int32}, + "shapes": []Shape{ScalarShape(), MakeShape(1, 1)}, + }, + }, + { + Name: "list(type),list(shape) empty", + Type: "InfeedEnqueueTuple", + Input: []Input{ + OutputList([]Output{ + makeConst([][]int32{{2}}), + }), + }, + Attrs: map[string]interface{}{ + "dtypes": []DataType{Int32}, + "shapes": []Shape(nil), + }, + }, + { + Name: "list(type) empty,string empty,int", + Type: "_XlaSendFromHost", + Input: []Input{ + OutputList([]Output{}), + makeConst(""), + }, + Attrs: map[string]interface{}{ + "Tinputs": []DataType(nil), + "key": "", + "device_ordinal": int64(0), + }, + }, + */ + { + Name: "list(int),int", + Type: "StringToHashBucketStrong", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "num_buckets": int64(2), + "key": []int64{1, 2}, + }, + }, + { + Name: "list(int) empty,int", + Type: "StringToHashBucketStrong", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "num_buckets": int64(2), + "key": ([]int64)(nil), + }, + }, + { + Name: "list(string),type", + Type: "TensorSummary", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "T": String, + "labels": []string{"foo", "bar"}, + }, + }, + { + Name: "list(string) empty,type", + Type: "TensorSummary", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "T": String, + "labels": ([]string)(nil), + }, + }, + { + Name: "tensor", + Type: "Const", + Attrs: map[string]interface{}{ + "dtype": String, + "value": makeTensor("foo"), + }, + }, + } + + for i, spec := range cases { + op, err := g.AddOperation(spec) + if err != nil { + t.Fatal(err) + } + for key, want := range spec.Attrs { + out, err := op.Attr(key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(out, want) { + t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want) + } + wantT, ok := want.(*Tensor) + if ok { + wantVal := wantT.Value() + outVal := out.(*Tensor).Value() + if !reflect.DeepEqual(outVal, wantVal) { + t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal) + } + } + } + } +} diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 36db3dda6bcf0738fc840d449d203f1e44f55035..d20e88e95b02b6c4f12fbaec3a9576ffc0266180 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -2674,206 +2674,463 @@ func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_uppe return op.Output(0) } -// Clips tensor values to a specified min and max. +// Returns the batched diagonal part of a batched tensor. // -// Given a tensor `t`, this operation returns a tensor of the same type and -// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`. -// Any values less than `clip_value_min` are set to `clip_value_min`. Any values -// greater than `clip_value_max` are set to `clip_value_max`. +// This operation returns a tensor with the `diagonal` part +// of the batched `input`. The `diagonal` part is computed as follows: +// +// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a +// tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where: +// +// `diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`. +// +// The input must be at least a matrix. +// +// For example: +// +// ``` +// # 'input' is [[[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]], +// [[5, 0, 0, 0] +// [0, 6, 0, 0] +// [0, 0, 7, 0] +// [0, 0, 0, 8]]] +// +// and input.shape = (2, 4, 4) +// +// tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]] +// +// which has shape (2, 4) +// ``` // // Arguments: -// t: A `Tensor`. -// clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape -// as `t`. The minimum value to clip by. -// clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape -// as `t`. The maximum value to clip by. +// input: Rank `k` tensor where `k >= 2`. // -// Returns A clipped `Tensor` with the same shape as input 't'. -func ClipByValue(scope *Scope, t tf.Output, clip_value_min tf.Output, clip_value_max tf.Output) (output tf.Output) { +// Returns The extracted diagonal(s) having shape +// `diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]`. +func MatrixDiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ClipByValue", + Type: "MatrixDiagPart", Input: []tf.Input{ - t, clip_value_min, clip_value_max, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a sequence of numbers. +// Returns a batched diagonal tensor with a given batched diagonal values. // -// This operation creates a sequence of numbers that begins at `start` and -// extends by increments of `delta` up to but not including `limit`. +// Given a `diagonal`, this operation returns a tensor with the `diagonal` and +// everything else padded with zeros. The diagonal is computed as follows: +// +// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a +// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: +// +// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. // // For example: // // ``` -// # 'start' is 3 -// # 'limit' is 18 -// # 'delta' is 3 -// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] +// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] +// +// and diagonal.shape = (2, 4) +// +// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]], +// [[5, 0, 0, 0] +// [0, 6, 0, 0] +// [0, 0, 7, 0] +// [0, 0, 0, 8]]] +// +// which has shape (2, 4, 4) // ``` // // Arguments: -// start: 0-D (scalar). First entry in the sequence. -// limit: 0-D (scalar). Upper limit of sequence, exclusive. -// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. +// diagonal: Rank `k`, where `k >= 1`. // -// Returns 1-D. -func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) { +// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. +func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Range", + Type: "MatrixDiag", Input: []tf.Input{ - start, limit, delta, + diagonal, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes gradients for SparseSegmentSqrtN. -// -// Returns tensor "output" with same shape as grad, except for dimension 0 whose -// value is output_dim0. +// QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm. +type QuantizedInstanceNormAttr func(optionalAttr) + +// QuantizedInstanceNormOutputRangeGiven sets the optional output_range_given attribute to value. // -// Arguments: -// grad: gradient propagated to the SparseSegmentSqrtN op. -// indices: indices passed to the corresponding SparseSegmentSqrtN op. -// segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op. -// output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op. -func SparseSegmentSqrtNGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtNGrad", - Input: []tf.Input{ - grad, indices, segment_ids, output_dim0, - }, +// value: If True, `given_y_min` and `given_y_min` +// and `given_y_max` are used as the output range. Otherwise, +// the implementation computes the output range. +// If not specified, defaults to false +func QuantizedInstanceNormOutputRangeGiven(value bool) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["output_range_given"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes the mean along sparse segments of a tensor. -// -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. -// -// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first -// dimension, selecting a subset of dimension 0, specified by `indices`. +// QuantizedInstanceNormGivenYMin sets the optional given_y_min attribute to value. // -// Arguments: +// value: Output in `y_min` if `output_range_given` is True. +// If not specified, defaults to 0 +func QuantizedInstanceNormGivenYMin(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["given_y_min"] = value + } +} + +// QuantizedInstanceNormGivenYMax sets the optional given_y_max attribute to value. // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// value: Output in `y_max` if `output_range_given` is True. +// If not specified, defaults to 0 +func QuantizedInstanceNormGivenYMax(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["given_y_max"] = value + } +} + +// QuantizedInstanceNormVarianceEpsilon sets the optional variance_epsilon attribute to value. // -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return +// value: A small float number to avoid dividing by 0. +// If not specified, defaults to 1e-05 +func QuantizedInstanceNormVarianceEpsilon(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["variance_epsilon"] = value } - opspec := tf.OpSpec{ - Type: "SparseSegmentMean", - Input: []tf.Input{ - data, indices, segment_ids, - }, +} + +// QuantizedInstanceNormMinSeparation sets the optional min_separation attribute to value. +// +// value: Minimum value of `y_max - y_min` +// If not specified, defaults to 0.001 +func QuantizedInstanceNormMinSeparation(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["min_separation"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Pop the element at the top of the stack. +// Quantized Instance normalization. // // Arguments: -// handle: The handle to a stack. -// elem_type: The type of the elem that is popped. +// x: A 4D input Tensor. +// x_min: The value represented by the lowest quantized input. +// x_max: The value represented by the highest quantized input. // -// Returns The tensor that is popped from the top of the stack. -func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) { +// Returns A 4D Tensor.The value represented by the lowest quantized output.The value represented by the highest quantized output. +func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf.Output, optional ...QuantizedInstanceNormAttr) (y tf.Output, y_min tf.Output, y_max tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"elem_type": elem_type} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "StackPopV2", + Type: "QuantizedInstanceNorm", Input: []tf.Input{ - handle, + x, x_min, x_max, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Computes the sum along sparse segments of a tensor. +// Returns the diagonal part of the tensor. // -// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is -// misisng, the `output` tensor at that position will be zeroed. +// This operation returns a tensor with the `diagonal` part +// of the `input`. The `diagonal` part is computed as follows: // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a +// tensor of rank `k` with dimensions `[D1,..., Dk]` where: // -// For example: +// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`. // -// ```python -// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) +// For example: // -// tf.sparse_segment_sum_with_num_segments( -// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) -// # => [[0 0 0 0] -// # [0 0 0 0] -// # [0 0 0 0]] +// ``` +// # 'input' is [[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]] // -// tf.sparse_segment_sum_with_num_segments(c, -// tf.constant([0, 1]), -// tf.constant([0, 2], -// num_segments=4)) -// # => [[ 1 2 3 4] -// # [ 0 0 0 0] -// # [-1 -2 -3 -4] -// # [ 0 0 0 0]] +// tf.diag_part(input) ==> [1, 2, 3, 4] // ``` // // Arguments: +// input: Rank k tensor where k is even and not zero. // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// num_segments: Should equal the number of distinct segment IDs. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `num_segments`. -func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +// Returns The extracted diagonal. +func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSegmentSumWithNumSegments", + Type: "DiagPart", Input: []tf.Input{ - data, indices, segment_ids, num_segments, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// PreventGradientAttr is an optional argument to PreventGradient. -type PreventGradientAttr func(optionalAttr) - -// PreventGradientMessage sets the optional message attribute to value. +// Gives a guarantee to the TF runtime that the input tensor is a constant. // -// value: Will be printed in the error when anyone tries to differentiate -// this operation. -// If not specified, defaults to "" -func PreventGradientMessage(value string) PreventGradientAttr { - return func(m optionalAttr) { +// The runtime is then free to make optimizations based on this. +// +// Only accepts value typed tensors as inputs and rejects resource variable handles +// as input. +// +// Returns the input tensor without modification. +func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "GuaranteeConst", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Splits a tensor into `num_split` tensors along one dimension. +// +// Arguments: +// value: The tensor to split. +// size_splits: list containing the sizes of each output tensor along the split +// dimension. Must sum to the dimension of value along split_dim. +// Can contain one -1 indicating that dimension is to be inferred. +// axis: 0-D. The dimension along which to split. Must be in the range +// `[-rank(value), rank(value))`. +// +// +// Returns Tensors whose shape matches that of `value` +// except along `axis`, where their sizes are +// `size_splits[i]`. +func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_split": num_split} + opspec := tf.OpSpec{ + Type: "SplitV", + Input: []tf.Input{ + value, size_splits, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("SplitV", err) + return + } + return output +} + +// Splits a tensor into `num_split` tensors along one dimension. +// +// Arguments: +// axis: 0-D. The dimension along which to split. Must be in the range +// `[-rank(value), rank(value))`. +// value: The tensor to split. +// num_split: The number of ways to split. Must evenly divide +// `value.shape[split_dim]`. +// +// Returns They are identically shaped tensors, whose shape matches that of `value` +// except along `axis`, where their sizes are +// `values.shape[split_dim] / num_split`. +func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_split": num_split} + opspec := tf.OpSpec{ + Type: "Split", + Input: []tf.Input{ + axis, value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("Split", err) + return + } + return output +} + +// Concatenates tensors along one dimension. +// +// Arguments: +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Concat", + Input: []tf.Input{ + concat_dim, tf.OutputList(values), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Converts a flat index or array of flat indices into a tuple of +// +// coordinate arrays. +// +// @compatibility(numpy) +// Equivalent to np.unravel_index +// @end_compatibility +// +// Arguments: +// indices: An 0-D or 1-D `int` Tensor whose elements are indices into the +// flattened version of an array of dimensions dims. +// dims: An 1-D `int` Tensor. The shape of the array to use for unraveling +// indices. +// +// Returns An 2-D (or 1-D if indices is 0-D) tensor where each row has the +// same shape as the indices array. +func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "UnravelIndex", + Input: []tf.Input{ + indices, dims, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Subtracts `v` into specified rows of `x`. +// +// Computes y = x; y[i, :] -= v; return y. +// +// Arguments: +// x: A `Tensor` of type T. +// i: A vector. Indices into the left-most dimension of `x`. +// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. +// +// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. +func InplaceSub(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InplaceSub", + Input: []tf.Input{ + x, i, v, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the mean along sparse segments of a tensor. +// +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. +// +// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first +// dimension, selecting a subset of dimension 0, specified by `indices`. +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSegmentMean", + Input: []tf.Input{ + data, indices, segment_ids, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Pop the element at the top of the stack. +// +// Arguments: +// handle: The handle to a stack. +// elem_type: The type of the elem that is popped. +// +// Returns The tensor that is popped from the top of the stack. +func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"elem_type": elem_type} + opspec := tf.OpSpec{ + Type: "StackPopV2", + Input: []tf.Input{ + handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// PreventGradientAttr is an optional argument to PreventGradient. +type PreventGradientAttr func(optionalAttr) + +// PreventGradientMessage sets the optional message attribute to value. +// +// value: Will be printed in the error when anyone tries to differentiate +// this operation. +// If not specified, defaults to "" +func PreventGradientMessage(value string) PreventGradientAttr { + return func(m optionalAttr) { m["message"] = value } } @@ -3687,24 +3944,6 @@ func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Returns x + y element-wise. -// -// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Add", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // NthElementAttr is an optional argument to NthElement. type NthElementAttr func(optionalAttr) @@ -3974,69 +4213,6 @@ func Digamma(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Shuffle dimensions of x according to a permutation. -// -// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: -// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` -func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Transpose", - Input: []tf.Input{ - x, perm, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MinAttr is an optional argument to Min. -type MinAttr func(optionalAttr) - -// MinKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func MinKeepDims(value bool) MinAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the minimum of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Min", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter. type Conv2DBackpropFilterAttr func(optionalAttr) @@ -4511,6 +4687,24 @@ func MatrixInverse(scope *Scope, input tf.Output, optional ...MatrixInverseAttr) return op.Output(0) } +// Returns x + y element-wise. +// +// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Add", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes square of x element-wise. // // I.e., \\(y = x * x = x^2\\). @@ -4563,6 +4757,68 @@ func Reciprocal(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } +// Returns a batched matrix tensor with new batched diagonal values. +// +// Given `input` and `diagonal`, this operation returns a tensor with the +// same shape and values as `input`, except for the main diagonal of the +// innermost matrices. These will be overwritten by the values in `diagonal`. +// +// The output is computed as follows: +// +// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has +// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a +// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: +// +// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. +// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. +// +// Arguments: +// input: Rank `k+1`, where `k >= 1`. +// diagonal: Rank `k`, where `k >= 1`. +// +// Returns Rank `k+1`, with `output.shape = input.shape`. +func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixSetDiag", + Input: []tf.Input{ + input, diagonal, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the element-wise max of two SparseTensors. +// +// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. +// +// Arguments: +// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, in the canonical lexicographic ordering. +// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. +// a_shape: 1-D. Shape of the input SparseTensor. +// b_indices: counterpart to `a_indices` for the other operand. +// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. +// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. +// +// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. +func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSparseMaximum", + Input: []tf.Input{ + a_indices, a_values, a_shape, b_indices, b_values, b_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // OrderedMapClearAttr is an optional argument to OrderedMapClear. type OrderedMapClearAttr func(optionalAttr) @@ -5115,53 +5371,6 @@ func FloorDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Returns a batched diagonal tensor with a given batched diagonal values. -// -// Given a `diagonal`, this operation returns a tensor with the `diagonal` and -// everything else padded with zeros. The diagonal is computed as follows: -// -// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a -// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: -// -// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. -// -// For example: -// -// ``` -// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] -// -// and diagonal.shape = (2, 4) -// -// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]], -// [[5, 0, 0, 0] -// [0, 6, 0, 0] -// [0, 0, 7, 0] -// [0, 0, 0, 8]]] -// -// which has shape (2, 4, 4) -// ``` -// -// Arguments: -// diagonal: Rank `k`, where `k >= 1`. -// -// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. -func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixDiag", - Input: []tf.Input{ - diagonal, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the inverse permutation of a tensor. // // This operation computes the inverse of an index permutation. It takes a 1-D @@ -5930,69 +6139,140 @@ func Mod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// DepthToSpaceAttr is an optional argument to DepthToSpace. -type DepthToSpaceAttr func(optionalAttr) - -// DepthToSpaceDataFormat sets the optional data_format attribute to value. -// If not specified, defaults to "NHWC" -func DepthToSpaceDataFormat(value string) DepthToSpaceAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// DepthToSpace for tensors of type T. -// -// Rearranges data from depth into blocks of spatial data. -// This is the reverse transformation of SpaceToDepth. More specifically, -// this op outputs a copy of the input tensor where values from the `depth` -// dimension are moved in spatial blocks to the `height` and `width` dimensions. -// The attr `block_size` indicates the input block size and how the data is moved. -// -// * Chunks of data of size `block_size * block_size` from depth are rearranged -// into non-overlapping blocks of size `block_size x block_size` -// * The width the output tensor is `input_depth * block_size`, whereas the -// height is `input_height * block_size`. -// * The Y, X coordinates within each block of the output image are determined -// by the high order component of the input channel index. -// * The depth of the input tensor must be divisible by -// `block_size * block_size`. -// -// The `data_format` attr specifies the layout of the input and output tensors -// with the following options: -// "NHWC": `[ batch, height, width, channels ]` -// "NCHW": `[ batch, channels, height, width ]` -// "NCHW_VECT_C": -// `qint8 [ batch, channels / 4, height, width, 4 ]` -// -// It is useful to consider the operation as transforming a 6-D Tensor. -// e.g. for data_format = NHWC, -// Each element in the input tensor can be specified via 6 coordinates, -// ordered by decreasing memory layout significance as: -// n,iY,iX,bY,bX,oC (where n=batch index, iX, iY means X or Y coordinates -// within the input image, bX, bY means coordinates -// within the output block, oC means output channels). -// The output would be the input transposed to the following layout: -// n,iY,bY,iX,bX,oC -// -// This operation is useful for resizing the activations between convolutions -// (but keeping all data), e.g. instead of pooling. It is also useful for training -// purely convolutional models. +// Computes offsets of concat inputs within its output. // -// For example, given an input of shape `[1, 1, 1, 4]`, data_format = "NHWC" and -// block_size = 2: +// For example: // // ``` -// x = [[[[1, 2, 3, 4]]]] -// +// # 'x' is [2, 2, 7] +// # 'y' is [2, 3, 7] +// # 'z' is [2, 5, 7] +// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] // ``` // -// This operation will output a tensor of shape `[1, 2, 2, 1]`: +// This is typically used by gradient computations for a concat operation. // -// ``` -// [[[[1], [2]], -// [[3], [4]]]] -// ``` +// Arguments: +// concat_dim: The dimension along which to concatenate. +// shape: The `N` int32 vectors representing shape of tensors being concatenated. +// +// Returns The `N` int32 vectors representing the starting offset +// of input tensors within the concatenated output. +func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ConcatOffset", + Input: []tf.Input{ + concat_dim, tf.OutputList(shape), + }, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil { + scope.UpdateErr("ConcatOffset", err) + return + } + return offset +} + +// Compute the lower regularized incomplete Gamma function `Q(a, x)`. +// +// The lower regularized incomplete Gamma function is defined as: +// +// +// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) +// +// where +// +// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\) +// +// is the lower incomplete Gamma function. +// +// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete +// Gamma function. +func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Igamma", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DepthToSpaceAttr is an optional argument to DepthToSpace. +type DepthToSpaceAttr func(optionalAttr) + +// DepthToSpaceDataFormat sets the optional data_format attribute to value. +// If not specified, defaults to "NHWC" +func DepthToSpaceDataFormat(value string) DepthToSpaceAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// DepthToSpace for tensors of type T. +// +// Rearranges data from depth into blocks of spatial data. +// This is the reverse transformation of SpaceToDepth. More specifically, +// this op outputs a copy of the input tensor where values from the `depth` +// dimension are moved in spatial blocks to the `height` and `width` dimensions. +// The attr `block_size` indicates the input block size and how the data is moved. +// +// * Chunks of data of size `block_size * block_size` from depth are rearranged +// into non-overlapping blocks of size `block_size x block_size` +// * The width the output tensor is `input_depth * block_size`, whereas the +// height is `input_height * block_size`. +// * The Y, X coordinates within each block of the output image are determined +// by the high order component of the input channel index. +// * The depth of the input tensor must be divisible by +// `block_size * block_size`. +// +// The `data_format` attr specifies the layout of the input and output tensors +// with the following options: +// "NHWC": `[ batch, height, width, channels ]` +// "NCHW": `[ batch, channels, height, width ]` +// "NCHW_VECT_C": +// `qint8 [ batch, channels / 4, height, width, 4 ]` +// +// It is useful to consider the operation as transforming a 6-D Tensor. +// e.g. for data_format = NHWC, +// Each element in the input tensor can be specified via 6 coordinates, +// ordered by decreasing memory layout significance as: +// n,iY,iX,bY,bX,oC (where n=batch index, iX, iY means X or Y coordinates +// within the input image, bX, bY means coordinates +// within the output block, oC means output channels). +// The output would be the input transposed to the following layout: +// n,iY,bY,iX,bX,oC +// +// This operation is useful for resizing the activations between convolutions +// (but keeping all data), e.g. instead of pooling. It is also useful for training +// purely convolutional models. +// +// For example, given an input of shape `[1, 1, 1, 4]`, data_format = "NHWC" and +// block_size = 2: +// +// ``` +// x = [[[[1, 2, 3, 4]]]] +// +// ``` +// +// This operation will output a tensor of shape `[1, 2, 2, 1]`: +// +// ``` +// [[[[1], [2]], +// [[3], [4]]]] +// ``` // // Here, the input has a batch of 1 and each batch element has shape `[1, 1, 4]`, // the corresponding output will have 2x2 elements and will have a depth of @@ -6749,55 +7029,51 @@ func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output) return op.Output(0) } -// Transforms a Tensor into a serialized TensorProto proto. -// -// Arguments: -// tensor: A Tensor of type `T`. +// Shuffle dimensions of x according to a permutation. // -// Returns A serialized TensorProto proto of the input tensor. -func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { +// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: +// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` +func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SerializeTensor", + Type: "Transpose", Input: []tf.Input{ - tensor, + x, perm, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MatrixSolveAttr is an optional argument to MatrixSolve. -type MatrixSolveAttr func(optionalAttr) +// MinAttr is an optional argument to Min. +type MinAttr func(optionalAttr) -// MatrixSolveAdjoint sets the optional adjoint attribute to value. +// MinKeepDims sets the optional keep_dims attribute to value. // -// value: Boolean indicating whether to solve with `matrix` or its (block-wise) -// adjoint. +// value: If true, retain reduced dimensions with length 1. // If not specified, defaults to false -func MatrixSolveAdjoint(value bool) MatrixSolveAttr { +func MinKeepDims(value bool) MinAttr { return func(m optionalAttr) { - m["adjoint"] = value + m["keep_dims"] = value } } -// Solves systems of linear equations. +// Computes the minimum of elements across dimensions of a tensor. // -// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is -// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix -// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. -// If `adjoint` is `True` then each output matrix satisfies -// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// matrix: Shape is `[..., M, M]`. -// rhs: Shape is `[..., M, K]`. +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns Shape is `[..., M, K]`. -func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) { +// Returns The reduced tensor. +func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -6806,9 +7082,9 @@ func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...Matr a(attrs) } opspec := tf.OpSpec{ - Type: "MatrixSolve", + Type: "Min", Input: []tf.Input{ - matrix, rhs, + input, axis, }, Attrs: attrs, } @@ -6816,6 +7092,26 @@ func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...Matr return op.Output(0) } +// Transforms a Tensor into a serialized TensorProto proto. +// +// Arguments: +// tensor: A Tensor of type `T`. +// +// Returns A serialized TensorProto proto of the input tensor. +func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SerializeTensor", + Input: []tf.Input{ + tensor, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes acos of x element-wise. func Acos(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { @@ -7310,69 +7606,40 @@ func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ... return op.Output(0) } -// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. -type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) +// RandomPoissonAttr is an optional argument to RandomPoisson. +type RandomPoissonAttr func(optionalAttr) -// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { +// RandomPoissonSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed(value int64) RandomPoissonAttr { return func(m optionalAttr) { - m["data_format"] = value + m["seed"] = value } } -// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { +// RandomPoissonSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed2(value int64) RandomPoissonAttr { return func(m optionalAttr) { - m["dilations"] = value + m["seed2"] = value } } -// Computes the gradients of depthwise convolution with respect to the filter. -// -// Arguments: -// input: 4-D with shape based on `data_format`. For example, if -// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, -// in_width, in_channels]` tensor. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 4-D -// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. -// out_backprop: 4-D with shape based on `data_format`. -// For example, if `data_format` is 'NHWC' then -// out_backprop shape is `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. -// padding: The type of padding algorithm to use. +// Use RandomPoissonV2 instead. // -// Returns 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. -// the `filter` input of the convolution. -func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { +// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 +func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DepthwiseConv2dNativeBackpropFilter", + Type: "RandomPoisson", Input: []tf.Input{ - input, filter_sizes, out_backprop, + shape, rate, }, Attrs: attrs, } @@ -7380,69 +7647,29 @@ func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_s return op.Output(0) } -// LRNGradAttr is an optional argument to LRNGrad. -type LRNGradAttr func(optionalAttr) - -// LRNGradDepthRadius sets the optional depth_radius attribute to value. +// Returns the element-wise sum of a list of tensors. // -// value: A depth radius. -// If not specified, defaults to 5 -func LRNGradDepthRadius(value int64) LRNGradAttr { - return func(m optionalAttr) { - m["depth_radius"] = value - } -} - -// LRNGradBias sets the optional bias attribute to value. -// -// value: An offset (usually > 0 to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNGradBias(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["bias"] = value - } -} - -// LRNGradAlpha sets the optional alpha attribute to value. +// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not +// wait for all of its inputs to be ready before beginning to sum. This can +// save memory if inputs are ready at different times, since minimum temporary +// storage is proportional to the output size rather than the inputs size. // -// value: A scale factor, usually positive. -// If not specified, defaults to 1 -func LRNGradAlpha(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["alpha"] = value - } -} - -// LRNGradBeta sets the optional beta attribute to value. +// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. // -// value: An exponent. -// If not specified, defaults to 0.5 -func LRNGradBeta(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["beta"] = value - } -} - -// Gradients for Local Response Normalization. +// Returns a `Tensor` of same shape and type as the elements of `inputs`. // // Arguments: -// input_grads: 4-D with shape `[batch, height, width, channels]`. -// input_image: 4-D with shape `[batch, height, width, channels]`. -// output_image: 4-D with shape `[batch, height, width, channels]`. -// -// Returns The gradients for LRN. -func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { +// inputs: A list of `Tensor` objects, each with same shape and type. +// shape: Shape of elements of `inputs`. +func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "LRNGrad", + Type: "AccumulateNV2", Input: []tf.Input{ - input_grads, input_image, output_image, + tf.OutputList(inputs), }, Attrs: attrs, } @@ -7450,33 +7677,49 @@ func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_ return op.Output(0) } -// AnyAttr is an optional argument to Any. -type AnyAttr func(optionalAttr) +// RandomShuffleAttr is an optional argument to RandomShuffle. +type RandomShuffleAttr func(optionalAttr) -// AnyKeepDims sets the optional keep_dims attribute to value. +// RandomShuffleSeed sets the optional seed attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func AnyKeepDims(value bool) AnyAttr { +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomShuffleSeed(value int64) RandomShuffleAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["seed"] = value } } -// Computes the "logical or" of elements across dimensions of a tensor. +// RandomShuffleSeed2 sets the optional seed2 attribute to value. // -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomShuffleSeed2(value int64) RandomShuffleAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Randomly shuffles a tensor along its first dimension. +// +// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped +// to one and only one `output[i]`. For example, a mapping that might occur for a +// 3x2 tensor is: +// +// ``` +// [[1, 2], [[5, 6], +// [3, 4], ==> [1, 2], +// [5, 6]] [3, 4]] +// ``` // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// value: The tensor to be shuffled. // -// Returns The reduced tensor. -func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { +// Returns A tensor of same shape and type as `value`, shuffled along its first +// dimension. +func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -7485,9 +7728,9 @@ func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (ou a(attrs) } opspec := tf.OpSpec{ - Type: "Any", + Type: "RandomShuffle", Input: []tf.Input{ - input, axis, + value, }, Attrs: attrs, } @@ -7495,105 +7738,126 @@ func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (ou return op.Output(0) } -// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. -type ResourceApplyFtrlAttr func(optionalAttr) +// OrderedMapIncompleteSizeAttr is an optional argument to OrderedMapIncompleteSize. +type OrderedMapIncompleteSizeAttr func(optionalAttr) -// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value. +// OrderedMapIncompleteSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr { +// REQUIRES: value >= 0 +func OrderedMapIncompleteSizeCapacity(value int64) OrderedMapIncompleteSizeAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["capacity"] = value } } -// Update '*var' according to the Ftrl-proximal scheme. -// -// accum_new = accum + grad * grad -// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regulariation. Must be a scalar. -// l2: L2 regulariation. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. +// OrderedMapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// Returns the created operation. -func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) { +// REQUIRES: value >= 0 +func OrderedMapIncompleteSizeMemoryLimit(value int64) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// OrderedMapIncompleteSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapIncompleteSizeContainer(value string) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapIncompleteSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapIncompleteSizeSharedName(value string) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of incomplete elements in the underlying container. +func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapIncompleteSizeAttr) (size tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyFtrl", - Input: []tf.Input{ - var_, accum, linear, grad, lr, l1, l2, lr_power, - }, + Type: "OrderedMapIncompleteSize", + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// RandomUniformAttr is an optional argument to RandomUniform. -type RandomUniformAttr func(optionalAttr) +// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. +type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) -// RandomUniformSeed sets the optional seed attribute to value. +// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomUniformSeed(value int64) RandomUniformAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { - m["seed"] = value + m["data_format"] = value } } -// RandomUniformSeed2 sets the optional seed2 attribute to value. +// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomUniformSeed2(value int64) RandomUniformAttr { +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { - m["seed2"] = value + m["dilations"] = value } } -// Outputs random values from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// Computes the gradients of depthwise convolution with respect to the filter. // // Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. +// input: 4-D with shape based on `data_format`. For example, if +// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, +// in_width, in_channels]` tensor. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 4-D +// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. +// out_backprop: 4-D with shape based on `data_format`. +// For example, if `data_format` is 'NHWC' then +// out_backprop shape is `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. +// padding: The type of padding algorithm to use. // -// Returns A tensor of the specified shape filled with uniform random values. -func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomUniformAttr) (output tf.Output) { +// Returns 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. +// the `filter` input of the convolution. +func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomUniform", + Type: "DepthwiseConv2dNativeBackpropFilter", Input: []tf.Input{ - shape, + input, filter_sizes, out_backprop, }, Attrs: attrs, } @@ -7601,8 +7865,177 @@ func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional .. return op.Output(0) } -// AssertAttr is an optional argument to Assert. -type AssertAttr func(optionalAttr) +// Returns immutable tensor from memory region. +// +// The current implementation memmaps the tensor from a file. +// +// Arguments: +// dtype: Type of the returned tensor. +// shape: Shape of the returned tensor. +// memory_region_name: Name of readonly memory region used by the tensor, see +// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. +func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name} + opspec := tf.OpSpec{ + Type: "ImmutableConst", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StringJoinAttr is an optional argument to StringJoin. +type StringJoinAttr func(optionalAttr) + +// StringJoinSeparator sets the optional separator attribute to value. +// +// value: string, an optional join separator. +// If not specified, defaults to "" +func StringJoinSeparator(value string) StringJoinAttr { + return func(m optionalAttr) { + m["separator"] = value + } +} + +// Joins the strings in the given list of string tensors into one tensor; +// +// with the given separator (default is an empty separator). +// +// Arguments: +// inputs: A list of string tensors. The tensors must all have the same shape, +// or be scalars. Scalars may be mixed in; these will be broadcast to the shape +// of non-scalar inputs. +func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StringJoin", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. +type ResourceApplyFtrlAttr func(optionalAttr) + +// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the Ftrl-proximal scheme. +// +// accum_new = accum + grad * grad +// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regulariation. Must be a scalar. +// l2: L2 regulariation. Must be a scalar. +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation. +func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyFtrl", + Input: []tf.Input{ + var_, accum, linear, grad, lr, l1, l2, lr_power, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// RandomUniformAttr is an optional argument to RandomUniform. +type RandomUniformAttr func(optionalAttr) + +// RandomUniformSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomUniformSeed(value int64) RandomUniformAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomUniformSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomUniformSeed2(value int64) RandomUniformAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// +// Arguments: +// shape: The shape of the output tensor. +// dtype: The type of the output. +// +// Returns A tensor of the specified shape filled with uniform random values. +func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomUniformAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomUniform", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AssertAttr is an optional argument to Assert. +type AssertAttr func(optionalAttr) // AssertSummarize sets the optional summarize attribute to value. // @@ -7768,47 +8201,6 @@ func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf return output_indices, output_values, output_shape } -// RandomPoissonAttr is an optional argument to RandomPoisson. -type RandomPoissonAttr func(optionalAttr) - -// RandomPoissonSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed(value int64) RandomPoissonAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomPoissonSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed2(value int64) RandomPoissonAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Use RandomPoissonV2 instead. -// -// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 -func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomPoisson", - Input: []tf.Input{ - shape, rate, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2. type ResourceSparseApplyFtrlV2Attr func(optionalAttr) @@ -7916,27 +8308,29 @@ func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPe return op.Output(0) } -// Reads the value of a variable. +// Converts each string in the input Tensor to its hash mod by a number of buckets. // -// The tensor returned by this operation is immutable. +// The hash function is deterministic on the content of the string within the +// process. // -// The value returned by this operation is guaranteed to be influenced by all the -// writes on which this operation depends directly or indirectly, and to not be -// influenced by any of the writes which depend directly or indirectly on this -// operation. +// Note that the hash function may change from time to time. +// This functionality will be deprecated and it's recommended to use +// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. // // Arguments: -// resource: handle to the resource in which to store the variable. -// dtype: the dtype of the value. -func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { +// +// num_buckets: The number of buckets. +// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "ReadVariableOp", + Type: "StringToHashBucket", Input: []tf.Input{ - resource, + string_tensor, }, Attrs: attrs, } @@ -7944,24 +8338,99 @@ func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value return op.Output(0) } -// Computes tan of x element-wise. -func Tan(scope *Scope, x tf.Output) (y tf.Output) { +// Computes gradients for the exponential linear (Elu) operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Elu operation. +// outputs: The outputs of the corresponding Elu operation. +// +// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, +// `gradients` otherwise. +func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Tan", + Type: "EluGrad", Input: []tf.Input{ - x, + gradients, outputs, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Updates the tree ensemble by either adding a layer to the last tree being grown +// Creates a dataset that contains `count` elements from the `input_dataset`. // -// or by starting a new tree. +// Arguments: +// +// count: A scalar representing the number of elements from the `input_dataset` +// that should be taken. A value of `-1` indicates that all of `input_dataset` +// is taken. +// +// +func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "TakeDataset", + Input: []tf.Input{ + input_dataset, count, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reads the value of a variable. +// +// The tensor returned by this operation is immutable. +// +// The value returned by this operation is guaranteed to be influenced by all the +// writes on which this operation depends directly or indirectly, and to not be +// influenced by any of the writes which depend directly or indirectly on this +// operation. +// +// Arguments: +// resource: handle to the resource in which to store the variable. +// dtype: the dtype of the value. +func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "ReadVariableOp", + Input: []tf.Input{ + resource, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes tan of x element-wise. +func Tan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Tan", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Updates the tree ensemble by either adding a layer to the last tree being grown +// +// or by starting a new tree. // // Arguments: // tree_ensemble_handle: Handle to the ensemble variable. @@ -7999,157 +8468,184 @@ func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, fe return scope.AddOperation(opspec) } -// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. -type ResourceSparseApplyFtrlAttr func(optionalAttr) +// EncodeJpegAttr is an optional argument to EncodeJpeg. +type EncodeJpegAttr func(optionalAttr) -// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. +// EncodeJpegFormat sets the optional format attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { +// value: Per pixel image format. +// If not specified, defaults to "" +func EncodeJpegFormat(value string) EncodeJpegAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["format"] = value } } -// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// EncodeJpegQuality sets the optional quality attribute to value. // -// That is for rows we have grad for, we update var, accum and linear as follows: -// accum_new = accum + grad * grad -// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// value: Quality of the compression from 0 to 100 (higher is better and slower). +// If not specified, defaults to 95 +func EncodeJpegQuality(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["quality"] = value + } +} + +// EncodeJpegProgressive sets the optional progressive attribute to value. // -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. +// value: If True, create a JPEG that loads progressively (coarse to fine). +// If not specified, defaults to false +func EncodeJpegProgressive(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["progressive"] = value + } +} + +// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. // -// Returns the created operation. -func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { - if scope.Err() != nil { - return +// value: If True, spend CPU/RAM to reduce size with no quality change. +// If not specified, defaults to false +func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["optimize_size"] = value } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) +} + +// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. +// +// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. +// If not specified, defaults to true +func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["chroma_downsampling"] = value } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrl", - Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, lr_power, - }, - Attrs: attrs, +} + +// EncodeJpegDensityUnit sets the optional density_unit attribute to value. +// +// value: Unit used to specify `x_density` and `y_density`: +// pixels per inch (`'in'`) or centimeter (`'cm'`). +// If not specified, defaults to "in" +func EncodeJpegDensityUnit(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["density_unit"] = value } - return scope.AddOperation(opspec) } -// Returns which elements of x are Inf. +// EncodeJpegXDensity sets the optional x_density attribute to value. // -// @compatibility(numpy) -// Equivalent to np.isinf -// @end_compatibility -func IsInf(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return +// value: Horizontal pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegXDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["x_density"] = value } - opspec := tf.OpSpec{ - Type: "IsInf", - Input: []tf.Input{ - x, - }, +} + +// EncodeJpegYDensity sets the optional y_density attribute to value. +// +// value: Vertical pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegYDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["y_density"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes the sum along sparse segments of a tensor divided by the sqrt of N. +// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. // -// N is the size of the segment being reduced. +// value: If not empty, embed this XMP metadata in the image header. +// If not specified, defaults to "" +func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["xmp_metadata"] = value + } +} + +// JPEG-encode an image. // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. // -// Arguments: +// The attr `format` can be used to override the color format of the encoded +// output. Values can be: // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// * `''`: Use a default format based on the number of channels in the image. +// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension +// of `image` must be 1. +// * `rgb`: Output an RGB JPEG image. The `channels` dimension +// of `image` must be 3. // -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtN", - Input: []tf.Input{ - data, indices, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. +// If `format` is not specified or is the empty string, a default format is picked +// in function of the number of channels in `image`: // -// This Op does not require `a_indices` be sorted in standard lexicographic order. +// * 1: Output a grayscale image. +// * 3: Output an RGB image. // // Arguments: -// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. -// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. -// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. -// b: `ndims`-D Tensor. With shape `a_shape`. -func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { +// image: 3-D with shape `[height, width, channels]`. +// +// Returns 0-D. JPEG-encoded image. +func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseTensorDenseAdd", + Type: "EncodeJpeg", Input: []tf.Input{ - a_indices, a_values, a_shape, b, + image, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. -type StatelessTruncatedNormalAttr func(optionalAttr) - -// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. +// MultinomialAttr is an optional argument to Multinomial. +type MultinomialAttr func(optionalAttr) + +// MultinomialSeed sets the optional seed attribute to value. // -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { +// value: If either seed or seed2 is set to be non-zero, the internal random number +// generator is seeded by the given seed. Otherwise, a random seed is used. +// If not specified, defaults to 0 +func MultinomialSeed(value int64) MultinomialAttr { return func(m optionalAttr) { - m["dtype"] = value + m["seed"] = value } } -// Outputs deterministic pseudorandom values from a truncated normal distribution. -// -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. +// MultinomialSeed2 sets the optional seed2 attribute to value. // -// The outputs are a deterministic function of `shape` and `seed`. +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func MultinomialSeed2(value int64) MultinomialAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// MultinomialOutputDtype sets the optional output_dtype attribute to value. +// If not specified, defaults to DT_INT64 +func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { + return func(m optionalAttr) { + m["output_dtype"] = value + } +} + +// Draws samples from a multinomial distribution. // // Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` +// represents the unnormalized log probabilities for all classes. +// num_samples: 0-D. Number of independent samples to draw for each row slice. // -// Returns Random values with specified shape. -func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { +// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` +// contains the drawn class labels with range `[0, num_classes)`. +func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -8158,9 +8654,9 @@ func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, opt a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessTruncatedNormal", + Type: "Multinomial", Input: []tf.Input{ - shape, seed, + logits, num_samples, }, Attrs: attrs, } @@ -8168,83 +8664,89 @@ func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, opt return op.Output(0) } -// RestoreSliceAttr is an optional argument to RestoreSlice. -type RestoreSliceAttr func(optionalAttr) +// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. +type ResourceSparseApplyAdagradDAAttr func(optionalAttr) -// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. +// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. // -// value: Index of file to open first if multiple files match -// `file_pattern`. See the documentation for `Restore`. -// If not specified, defaults to -1 -func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { return func(m optionalAttr) { - m["preferred_shard"] = value + m["use_locking"] = value } } -// Restores a tensor from checkpoint files. -// -// This is like `Restore` except that restored tensor can be listed as filling -// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the -// larger tensor and the slice that the restored tensor covers. -// -// The `shape_and_slice` input has the same format as the -// elements of the `shapes_and_slices` input of the `SaveSlices` op. +// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. // // Arguments: -// file_pattern: Must have a single element. The pattern of the files from -// which we read the tensor. -// tensor_name: Must have a single element. The name of the tensor to be -// restored. -// shape_and_slice: Scalar. The shapes and slice specifications to use when -// restoring a tensors. -// dt: The type of the tensor to be restored. +// var_: Should be from a Variable(). +// gradient_accumulator: Should be from a Variable(). +// gradient_squared_accumulator: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Learning rate. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// global_step: Training step number. Must be a scalar. // -// Returns The restored tensor. -func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { +// Returns the created operation. +func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dt": dt} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RestoreSlice", + Type: "ResourceSparseApplyAdagradDA", Input: []tf.Input{ - file_pattern, tensor_name, shape_and_slice, + var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ImagAttr is an optional argument to Imag. -type ImagAttr func(optionalAttr) +// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. +type ResourceSparseApplyFtrlAttr func(optionalAttr) -// ImagTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func ImagTout(value tf.DataType) ImagAttr { +// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { return func(m optionalAttr) { - m["Tout"] = value + m["use_locking"] = value } } -// Returns the imaginary part of a complex number. +// Update relevant entries in '*var' according to the Ftrl-proximal scheme. // -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the imaginary part of each element in `input`. All -// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a* -// is the real part and *b* is the imaginary part returned by this operation. +// That is for rows we have grad for, we update var, accum and linear as follows: +// accum_new = accum + grad * grad +// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new // -// For example: +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// lr_power: Scaling factor. Must be a scalar. // -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.imag(input) ==> [4.75, 5.75] -// ``` -func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) { +// Returns the created operation. +func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -8253,119 +8755,103 @@ func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output a(attrs) } opspec := tf.OpSpec{ - Type: "Imag", + Type: "ResourceSparseApplyFtrl", Input: []tf.Input{ - input, + var_, accum, linear, grad, indices, lr, l1, l2, lr_power, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ComplexAttr is an optional argument to Complex. -type ComplexAttr func(optionalAttr) - -// ComplexTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_COMPLEX64 -func ComplexTout(value tf.DataType) ComplexAttr { - return func(m optionalAttr) { - m["Tout"] = value +// Returns which elements of x are Inf. +// +// @compatibility(numpy) +// Equivalent to np.isinf +// @end_compatibility +func IsInf(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IsInf", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Converts two real numbers to a complex number. +// Computes the sum along sparse segments of a tensor divided by the sqrt of N. // -// Given a tensor `real` representing the real part of a complex number, and a -// tensor `imag` representing the imaginary part of a complex number, this -// operation returns complex numbers elementwise of the form \\(a + bj\\), where -// *a* represents the `real` part and *b* represents the `imag` part. +// N is the size of the segment being reduced. // -// The input tensors `real` and `imag` must have the same shape. +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. // -// For example: +// Arguments: // -// ``` -// # tensor 'real' is [2.25, 3.25] -// # tensor `imag` is [4.75, 5.75] -// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] -// ``` -func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) { +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Complex", + Type: "SparseSegmentSqrtN", Input: []tf.Input{ - real, imag, + data, indices, segment_ids, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Divides sparse updates into the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] /= updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] /= updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions multiply. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. // -//
-// -//
+// This Op does not require `a_indices` be sorted in standard lexicographic order. // // Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. +// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. +// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. +// b: `ndims`-D Tensor. With shape `a_shape`. +func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ResourceScatterDiv", + Type: "SparseTensorDenseAdd", Input: []tf.Input{ - resource, indices, updates, + a_indices, a_values, a_shape, b, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. -type StatelessRandomNormalAttr func(optionalAttr) +// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. +type StatelessTruncatedNormalAttr func(optionalAttr) -// StatelessRandomNormalDtype sets the optional dtype attribute to value. +// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. // // value: The type of the output. // If not specified, defaults to DT_FLOAT -func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { +func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { return func(m optionalAttr) { m["dtype"] = value } } -// Outputs deterministic pseudorandom values from a normal distribution. +// Outputs deterministic pseudorandom values from a truncated normal distribution. // -// The generated values will have mean 0 and standard deviation 1. +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. // // The outputs are a deterministic function of `shape` and `seed`. // @@ -8374,7 +8860,7 @@ func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { // seed: 2 seeds (shape [2]). // // Returns Random values with specified shape. -func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { +func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -8383,7 +8869,7 @@ func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, option a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessRandomNormal", + Type: "StatelessTruncatedNormal", Input: []tf.Input{ shape, seed, }, @@ -8393,21 +8879,73 @@ func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, option return op.Output(0) } -// Reduces sparse updates into the variable referenced by `resource` using the `min` operation. +// RestoreSliceAttr is an optional argument to RestoreSlice. +type RestoreSliceAttr func(optionalAttr) + +// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. +// +// value: Index of file to open first if multiple files match +// `file_pattern`. See the documentation for `Restore`. +// If not specified, defaults to -1 +func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { + return func(m optionalAttr) { + m["preferred_shard"] = value + } +} + +// Restores a tensor from checkpoint files. +// +// This is like `Restore` except that restored tensor can be listed as filling +// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the +// larger tensor and the slice that the restored tensor covers. +// +// The `shape_and_slice` input has the same format as the +// elements of the `shapes_and_slices` input of the `SaveSlices` op. +// +// Arguments: +// file_pattern: Must have a single element. The pattern of the files from +// which we read the tensor. +// tensor_name: Must have a single element. The name of the tensor to be +// restored. +// shape_and_slice: Scalar. The shapes and slice specifications to use when +// restoring a tensors. +// dt: The type of the tensor to be restored. +// +// Returns The restored tensor. +func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dt": dt} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RestoreSlice", + Input: []tf.Input{ + file_pattern, tensor_name, shape_and_slice, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Divides sparse updates into the variable referenced by `resource`. // // This operation computes // // # Scalar indices -// ref[indices, ...] = min(ref[indices, ...], updates[...]) +// ref[indices, ...] /= updates[...] // // # Vector indices (for each i) -// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) +// ref[indices[i], ...] /= updates[i, ...] // // # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) +// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] // // Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions are combined. +// the same location, their contributions multiply. // // Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. // @@ -8421,12 +8959,12 @@ func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, option // updates: A tensor of updated values to add to `ref`. // // Returns the created operation. -func ResourceScatterMin(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ResourceScatterMin", + Type: "ResourceScatterDiv", Input: []tf.Input{ resource, indices, updates, }, @@ -8434,271 +8972,157 @@ func ResourceScatterMin(scope *Scope, resource tf.Output, indices tf.Output, upd return scope.AddOperation(opspec) } -// Reshapes a quantized tensor as per the Reshape op. -// -// ``` -// -// Arguments: -// -// shape: Defines the shape of the output tensor. -// input_min: The minimum value of the input. -// input_max: The maximum value of the input. -// -// Returns This value is copied from input_min.This value is copied from input_max. -func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min tf.Output, input_max tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// Mutually reduces multiple tensors of identical type and shape. +func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets} opspec := tf.OpSpec{ - Type: "QuantizedReshape", + Type: "CollectiveReduce", Input: []tf.Input{ - tensor, shape, input_min, input_max, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Returns the truth value of (x != y) element-wise. +// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. +type StatelessRandomNormalAttr func(optionalAttr) + +// StatelessRandomNormalDtype sets the optional dtype attribute to value. // -// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NotEqual", - Input: []tf.Input{ - x, y, - }, +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { + return func(m optionalAttr) { + m["dtype"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Inverse 3D real-valued fast Fourier transform. -// -// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most 3 dimensions of `input`. +// Outputs deterministic pseudorandom values from a normal distribution. // -// The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: -// The inner-most dimension contains the `fft_length / 2 + 1` unique components of -// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed -// from the size of the inner-most 3 dimensions of `input`. If the FFT length used -// to compute `input` is odd, it should be provided since it cannot be inferred -// properly. +// The generated values will have mean 0 and standard deviation 1. // -// Along each axis `IRFFT3D` is computed on, if `fft_length` (or -// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. +// The outputs are a deterministic function of `shape` and `seed`. // // Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. -// -// Returns A float32 tensor of the same rank as `input`. The inner-most 3 -// dimensions of `input` are replaced with the `fft_length` samples of their -// inverse 3D real Fourier transform. +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). // -// @compatibility(numpy) -// Equivalent to np.irfftn with 3 dimensions. -// @end_compatibility -func IRFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// Returns Random values with specified shape. +func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IRFFT3D", + Type: "StatelessRandomNormal", Input: []tf.Input{ - input, fft_length, + shape, seed, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StringSplitAttr is an optional argument to StringSplit. -type StringSplitAttr func(optionalAttr) +// MaxPoolAttr is an optional argument to MaxPool. +type MaxPoolAttr func(optionalAttr) -// StringSplitSkipEmpty sets the optional skip_empty attribute to value. +// MaxPoolDataFormat sets the optional data_format attribute to value. // -// value: A `bool`. If `True`, skip the empty strings from the result. -// If not specified, defaults to true -func StringSplitSkipEmpty(value bool) StringSplitAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolDataFormat(value string) MaxPoolAttr { return func(m optionalAttr) { - m["skip_empty"] = value + m["data_format"] = value } } -// Split elements of `input` based on `delimiter` into a `SparseTensor`. -// -// Let N be the size of source (typically N will be the batch size). Split each -// element of `input` based on `delimiter` and return a `SparseTensor` -// containing the splitted tokens. Empty tokens are ignored. -// -// `delimiter` can be empty, or a string of split characters. If `delimiter` is an -// empty string, each element of `input` is split into individual single-byte -// character strings, including splitting of UTF-8 multibyte sequences. Otherwise -// every character of `delimiter` is a potential split point. -// -// For example: -// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output -// will be -// -// indices = [0, 0; -// 0, 1; -// 1, 0; -// 1, 1; -// 1, 2] -// shape = [2, 3] -// values = ['hello', 'world', 'a', 'b', 'c'] +// Performs max pooling on the input. // // Arguments: -// input: 1-D. Strings to split. -// delimiter: 0-D. Delimiter characters (bytes), or empty string. +// input: 4-D input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns A dense matrix of int64 representing the indices of the sparse tensor.A vector of strings corresponding to the splited values.a length-2 vector of int64 representing the shape of the sparse -// tensor, where the first value is N and the second value is the maximum number -// of tokens in a single input entry. -func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ...StringSplitAttr) (indices tf.Output, values tf.Output, shape tf.Output) { +// Returns The max pooled output tensor. +func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StringSplit", + Type: "MaxPool", Input: []tf.Input{ - input, delimiter, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// ResourceSparseApplyMomentumAttr is an optional argument to ResourceSparseApplyMomentum. -type ResourceSparseApplyMomentumAttr func(optionalAttr) +// SparseMatMulAttr is an optional argument to SparseMatMul. +type SparseMatMulAttr func(optionalAttr) -// ResourceSparseApplyMomentumUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. +// SparseMatMulTransposeA sets the optional transpose_a attribute to value. // If not specified, defaults to false -func ResourceSparseApplyMomentumUseLocking(value bool) ResourceSparseApplyMomentumAttr { +func SparseMatMulTransposeA(value bool) SparseMatMulAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["transpose_a"] = value } } -// ResourceSparseApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, the tensor passed to compute grad will be -// var - lr * momentum * accum, so in the end, the var you get is actually -// var - lr * momentum * accum. +// SparseMatMulTransposeB sets the optional transpose_b attribute to value. // If not specified, defaults to false -func ResourceSparseApplyMomentumUseNesterov(value bool) ResourceSparseApplyMomentumAttr { +func SparseMatMulTransposeB(value bool) SparseMatMulAttr { return func(m optionalAttr) { - m["use_nesterov"] = value + m["transpose_b"] = value } } -// Update relevant entries in '*var' and '*accum' according to the momentum scheme. -// -// Set use_nesterov = True if you want to use Nesterov momentum. -// -// That is for rows we have grad for, we update var and accum as follows: -// -// accum = accum * momentum + grad -// var -= lr * accum -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// momentum: Momentum. Must be a scalar. -// -// Returns the created operation. -func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyMomentumAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyMomentum", - Input: []tf.Input{ - var_, accum, lr, grad, indices, momentum, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Returns the complex conjugate of a complex number. -// -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// complex numbers that are the complex conjugate of each element in `input`. The -// complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the -// real part and *b* is the imaginary part. -// -// The complex conjugate returned by this operation is of the form \\(a - bj\\). -// -// For example: -// -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] -// ``` -func Conj(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Conj", - Input: []tf.Input{ - input, - }, +// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["a_is_sparse"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// ResizeBilinearAttr is an optional argument to ResizeBilinear. -type ResizeBilinearAttr func(optionalAttr) - -// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. +// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. // If not specified, defaults to false -func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { +func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["b_is_sparse"] = value } } -// Resize `images` to `size` using bilinear interpolation. -// -// Input images can be of different types but output images are always float. +// Multiply matrix "a" by matrix "b". // -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// The inputs must be two-dimensional matrices and the inner dimension of "a" must +// match the outer dimension of "b". This op is optimized for the case where at +// least one of "a" or "b" is sparse. The breakeven for using this versus a dense +// matrix multiply on one platform was 30% zero values in the sparse matrix. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { +// The gradient computation of this operation will only take advantage of sparsity +// in the input gradient when that gradient comes from a Relu. +func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { if scope.Err() != nil { return } @@ -8707,9 +9131,9 @@ func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ... a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeBilinear", + Type: "SparseMatMul", Input: []tf.Input{ - images, size, + a, b, }, Attrs: attrs, } @@ -8717,128 +9141,98 @@ func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ... return op.Output(0) } -// Computes softsign: `features / (abs(features) + 1)`. -func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Softsign", - Input: []tf.Input{ - features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a TensorList which, when stacked, has the value of `tensor`. +// Concatenates quantized tensors along one dimension. // -// Each tensor in the result list corresponds to one row of the input tensor. +// Arguments: +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// input_mins: The minimum scalar values for each of the input tensors. +// input_maxes: The maximum scalar values for each of the input tensors. // -// tensor: The input tensor. -// output_handle: The list. -func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Output) (output_handle tf.Output) { +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. +func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListFromTensor", + Type: "QuantizedConcat", Input: []tf.Input{ - tensor, element_shape, + concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), }, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping. -type GenerateVocabRemappingAttr func(optionalAttr) - -// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value. -// -// value: Number of entries in the old vocab file to consider. If -1, -// use the entire old vocabulary. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr { - return func(m optionalAttr) { - m["old_vocab_size"] = value - } + return op.Output(0), op.Output(1), op.Output(2) } -// Given a path to new and old vocabulary files, returns a remapping Tensor of +// Slice a `SparseTensor` based on the `start` and `size`. // -// length `num_new_vocab`, where `remapping[i]` contains the row number in the old -// vocabulary that corresponds to row `i` in the new vocabulary (starting at line -// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i` -// in the new vocabulary is not in the old vocabulary. The old vocabulary is -// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the -// default value of -1. +// For example, if the input is // -// `num_vocab_offset` enables -// use in the partitioned variable case, and should generally be set through -// examining partitioning info. The format of the files should be a text file, -// with each line containing a single entity within the vocabulary. +// input_tensor = shape = [2, 7] +// [ a d e ] +// [b c ] // -// For example, with `new_vocab_file` a text file containing each of the following -// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3], -// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be -// `[0, -1, 2]`. +// Graphically the output tensors are: // -// The op also returns a count of how many entries in the new vocabulary -// were present in the old vocabulary, which is used to calculate the number of -// values to initialize in a weight matrix remapping +// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] +// [ a ] +// [b c ] // -// This functionality can be used to remap both row vocabularies (typically, -// features) and column vocabularies (typically, classes) from TensorFlow -// checkpoints. Note that the partitioning logic relies on contiguous vocabularies -// corresponding to div-partitioned variables. Moreover, the underlying remapping -// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should -// use the corresponding index_table_from_file() as the FeatureColumn framework -// does (as opposed to tf.feature_to_id(), which uses a CuckooTable). +// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] +// [ d e ] +// [ ] // // Arguments: -// new_vocab_file: Path to the new vocab file. -// old_vocab_file: Path to the old vocab file. -// new_vocab_offset: How many entries into the new vocab file to start reading. -// num_new_vocab: Number of entries in the new vocab file to remap. +// indices: 2-D tensor represents the indices of the sparse tensor. +// values: 1-D tensor represents the values of the sparse tensor. +// shape: 1-D. tensor represents the shape of the sparse tensor. +// start: 1-D. tensor represents the start of the slice. +// size: 1-D. tensor represents the size of the slice. +// output indices: A list of 1-D tensors represents the indices of the output +// sparse tensors. // -// Returns A Tensor of length num_new_vocab where the element at index i -// is equal to the old ID that maps to the new ID i. This element is -1 for any -// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab. -func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) { +// Returns A list of 1-D tensors represents the values of the output sparse +// tensors.A list of 1-D tensors represents the shape of the output sparse +// tensors. +func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "GenerateVocabRemapping", + Type: "SparseSlice", Input: []tf.Input{ - new_vocab_file, old_vocab_file, + indices, values, shape, start, size, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0), op.Output(1), op.Output(2) } -// Assigns sparse updates to the variable referenced by `resource`. +// Reduces sparse updates into the variable referenced by `resource` using the `min` operation. // // This operation computes // // # Scalar indices -// ref[indices, ...] = updates[...] +// ref[indices, ...] = min(ref[indices, ...], updates[...]) // // # Vector indices (for each i) -// ref[indices[i], ...] = updates[i, ...] +// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) // // # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] +// ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions are combined. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
// // Arguments: // resource: Should be from a `Variable` node. @@ -8846,12 +9240,12 @@ func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_fi // updates: A tensor of updated values to add to `ref`. // // Returns the created operation. -func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +func ResourceScatterMin(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ResourceScatterUpdate", + Type: "ResourceScatterMin", Input: []tf.Input{ resource, indices, updates, }, @@ -8859,867 +9253,945 @@ func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, return scope.AddOperation(opspec) } -// Creates and returns an empty tensor list. +// Reshapes a quantized tensor as per the Reshape op. // -// All list elements must be tensors of dtype element_dtype and shape compatible -// with element_shape. +// ``` // -// handle: an empty tensor list. -// element_dtype: the type of elements in the list. -// element_shape: a shape compatible with that of elements in the list. -func EmptyTensorList(scope *Scope, element_shape tf.Output, element_dtype tf.DataType) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - opspec := tf.OpSpec{ - Type: "EmptyTensorList", +// Arguments: +// +// shape: Defines the shape of the output tensor. +// input_min: The minimum value of the input. +// input_max: The maximum value of the input. +// +// Returns This value is copied from input_min.This value is copied from input_max. +func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min tf.Output, input_max tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "QuantizedReshape", Input: []tf.Input{ - element_shape, + tensor, shape, input_min, input_max, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// AvgPoolGradAttr is an optional argument to AvgPoolGrad. -type AvgPoolGradAttr func(optionalAttr) - -// AvgPoolGradDataFormat sets the optional data_format attribute to value. +// Returns the truth value of (x != y) element-wise. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { - return func(m optionalAttr) { - m["data_format"] = value +// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NotEqual", + Input: []tf.Input{ + x, y, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes gradients of the average pooling function. +// Inverse 3D real-valued fast Fourier transform. +// +// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most 3 dimensions of `input`. +// +// The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: +// The inner-most dimension contains the `fft_length / 2 + 1` unique components of +// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed +// from the size of the inner-most 3 dimensions of `input`. If the FFT length used +// to compute `input` is odd, it should be provided since it cannot be inferred +// properly. +// +// Along each axis `IRFFT3D` is computed on, if `fft_length` (or +// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: -// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. -// the output of `avg_pool`. -// ksize: The size of the sliding window for each dimension of the input. -// strides: The stride of the sliding window for each dimension of the input. -// padding: The type of padding algorithm to use. +// input: A complex64 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. // -// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. -func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { +// Returns A float32 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the `fft_length` samples of their +// inverse 3D real Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.irfftn with 3 dimensions. +// @end_compatibility +func IRFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "AvgPoolGrad", + Type: "IRFFT3D", Input: []tf.Input{ - orig_input_shape, grad, + input, fft_length, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StageClearAttr is an optional argument to StageClear. -type StageClearAttr func(optionalAttr) +// StringSplitAttr is an optional argument to StringSplit. +type StringSplitAttr func(optionalAttr) -// StageClearCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// StringSplitSkipEmpty sets the optional skip_empty attribute to value. // -// REQUIRES: value >= 0 -func StageClearCapacity(value int64) StageClearAttr { +// value: A `bool`. If `True`, skip the empty strings from the result. +// If not specified, defaults to true +func StringSplitSkipEmpty(value bool) StringSplitAttr { return func(m optionalAttr) { - m["capacity"] = value + m["skip_empty"] = value } } -// StageClearMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Split elements of `input` based on `delimiter` into a `SparseTensor`. // -// REQUIRES: value >= 0 -func StageClearMemoryLimit(value int64) StageClearAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// StageClearContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StageClearContainer(value string) StageClearAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// StageClearSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StageClearSharedName(value string) StageClearAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes all elements in the underlying container. +// Let N be the size of source (typically N will be the batch size). Split each +// element of `input` based on `delimiter` and return a `SparseTensor` +// containing the splitted tokens. Empty tokens are ignored. // -// Returns the created operation. -func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) { +// `delimiter` can be empty, or a string of split characters. If `delimiter` is an +// empty string, each element of `input` is split into individual single-byte +// character strings, including splitting of UTF-8 multibyte sequences. Otherwise +// every character of `delimiter` is a potential split point. +// +// For example: +// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output +// will be +// +// indices = [0, 0; +// 0, 1; +// 1, 0; +// 1, 1; +// 1, 2] +// shape = [2, 3] +// values = ['hello', 'world', 'a', 'b', 'c'] +// +// Arguments: +// input: 1-D. Strings to split. +// delimiter: 0-D. Delimiter characters (bytes), or empty string. +// +// Returns A dense matrix of int64 representing the indices of the sparse tensor.A vector of strings corresponding to the splited values.a length-2 vector of int64 representing the shape of the sparse +// tensor, where the first value is N and the second value is the maximum number +// of tokens in a single input entry. +func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ...StringSplitAttr) (indices tf.Output, values tf.Output, shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StageClear", - + Type: "StringSplit", + Input: []tf.Input{ + input, delimiter, + }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits. -type ComputeAccidentalHitsAttr func(optionalAttr) +// ResourceSparseApplyMomentumAttr is an optional argument to ResourceSparseApplyMomentum. +type ResourceSparseApplyMomentumAttr func(optionalAttr) -// ComputeAccidentalHitsSeed sets the optional seed attribute to value. +// ResourceSparseApplyMomentumUseLocking sets the optional use_locking attribute to value. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr { +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyMomentumUseLocking(value bool) ResourceSparseApplyMomentumAttr { return func(m optionalAttr) { - m["seed"] = value + m["use_locking"] = value } } -// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value. +// ResourceSparseApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr { +// value: If `True`, the tensor passed to compute grad will be +// var - lr * momentum * accum, so in the end, the var you get is actually +// var - lr * momentum * accum. +// If not specified, defaults to false +func ResourceSparseApplyMomentumUseNesterov(value bool) ResourceSparseApplyMomentumAttr { return func(m optionalAttr) { - m["seed2"] = value + m["use_nesterov"] = value } } -// Computes the ids of the positions in sampled_candidates that match true_labels. +// Update relevant entries in '*var' and '*accum' according to the momentum scheme. // -// When doing log-odds NCE, the result of this op should be passed through a -// SparseToDense op, then added to the logits of the sampled candidates. This has -// the effect of 'removing' the sampled labels that match the true labels by -// making the classifier sure that they are sampled labels. +// Set use_nesterov = True if you want to use Nesterov momentum. +// +// That is for rows we have grad for, we update var and accum as follows: +// +// accum = accum * momentum + grad +// var -= lr * accum // // Arguments: -// true_classes: The true_classes output of UnpackSparseLabels. -// sampled_candidates: The sampled_candidates output of CandidateSampler. -// num_true: Number of true labels per context. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// momentum: Momentum. Must be a scalar. // -// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label -// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element -// is -FLOAT_MAX. -func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) { +// Returns the created operation. +func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyMomentumAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ComputeAccidentalHits", + Type: "ResourceSparseApplyMomentum", Input: []tf.Input{ - true_classes, sampled_candidates, + var_, accum, lr, grad, indices, momentum, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// QuantizedRelu6Attr is an optional argument to QuantizedRelu6. -type QuantizedRelu6Attr func(optionalAttr) - -// QuantizedRelu6OutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QUINT8 -func QuantizedRelu6OutType(value tf.DataType) QuantizedRelu6Attr { - return func(m optionalAttr) { - m["out_type"] = value - } + return scope.AddOperation(opspec) } -// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` +// Returns the complex conjugate of a complex number. // -// Arguments: +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// complex numbers that are the complex conjugate of each element in `input`. The +// complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the +// real part and *b* is the imaginary part. // -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. +// The complex conjugate returned by this operation is of the form \\(a - bj\\). // -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. -func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedRelu6Attr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { +// For example: +// +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] +// ``` +func Conj(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "QuantizedRelu6", + Type: "Conj", Input: []tf.Input{ - features, min_features, max_features, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// FixedLengthRecordReaderV2Attr is an optional argument to FixedLengthRecordReaderV2. -type FixedLengthRecordReaderV2Attr func(optionalAttr) - -// FixedLengthRecordReaderV2HeaderBytes sets the optional header_bytes attribute to value. -// -// value: Number of bytes in the header, defaults to 0. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2HeaderBytes(value int64) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["header_bytes"] = value - } -} - -// FixedLengthRecordReaderV2FooterBytes sets the optional footer_bytes attribute to value. -// -// value: Number of bytes in the footer, defaults to 0. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2FooterBytes(value int64) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["footer_bytes"] = value - } -} - -// FixedLengthRecordReaderV2HopBytes sets the optional hop_bytes attribute to value. -// -// value: Number of bytes to hop before each read. Default of 0 means using -// record_bytes. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2HopBytes(value int64) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["hop_bytes"] = value - } + return op.Output(0) } -// FixedLengthRecordReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2Container(value string) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} +// ResizeBilinearAttr is an optional argument to ResizeBilinear. +type ResizeBilinearAttr func(optionalAttr) -// FixedLengthRecordReaderV2SharedName sets the optional shared_name attribute to value. +// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2SharedName(value string) FixedLengthRecordReaderV2Attr { +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["align_corners"] = value } } -// FixedLengthRecordReaderV2Encoding sets the optional encoding attribute to value. +// Resize `images` to `size` using bilinear interpolation. // -// value: The type of encoding for the file. Currently ZLIB and GZIP -// are supported. Defaults to none. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2Encoding(value string) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["encoding"] = value - } -} - -// A Reader that outputs fixed-length records from a file. +// Input images can be of different types but output images are always float. // // Arguments: -// record_bytes: Number of bytes in the record. +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. // -// Returns The handle to reference the Reader. -func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...FixedLengthRecordReaderV2Attr) (reader_handle tf.Output) { +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"record_bytes": record_bytes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FixedLengthRecordReaderV2", - + Type: "ResizeBilinear", + Input: []tf.Input{ + images, size, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process. -// -// Note that the hash function may change from time to time. -// This functionality will be deprecated and it's recommended to use -// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. -// -// Arguments: -// -// num_buckets: The number of buckets. -// -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { +// Computes softsign: `features / (abs(features) + 1)`. +func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "StringToHashBucket", + Type: "Softsign", Input: []tf.Input{ - string_tensor, + features, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes gradients for the exponential linear (Elu) operation. +// Creates a TensorList which, when stacked, has the value of `tensor`. // -// Arguments: -// gradients: The backpropagated gradients to the corresponding Elu operation. -// outputs: The outputs of the corresponding Elu operation. +// Each tensor in the result list corresponds to one row of the input tensor. // -// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, -// `gradients` otherwise. -func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { +// tensor: The input tensor. +// output_handle: The list. +func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "EluGrad", + Type: "TensorListFromTensor", Input: []tf.Input{ - gradients, outputs, + tensor, element_shape, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that contains `count` elements from the `input_dataset`. +// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping. +type GenerateVocabRemappingAttr func(optionalAttr) + +// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value. // -// Arguments: +// value: Number of entries in the old vocab file to consider. If -1, +// use the entire old vocabulary. +// If not specified, defaults to -1 // -// count: A scalar representing the number of elements from the `input_dataset` -// that should be taken. A value of `-1` indicates that all of `input_dataset` -// is taken. +// REQUIRES: value >= -1 +func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr { + return func(m optionalAttr) { + m["old_vocab_size"] = value + } +} + +// Given a path to new and old vocabulary files, returns a remapping Tensor of +// +// length `num_new_vocab`, where `remapping[i]` contains the row number in the old +// vocabulary that corresponds to row `i` in the new vocabulary (starting at line +// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i` +// in the new vocabulary is not in the old vocabulary. The old vocabulary is +// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the +// default value of -1. // +// `num_vocab_offset` enables +// use in the partitioned variable case, and should generally be set through +// examining partitioning info. The format of the files should be a text file, +// with each line containing a single entity within the vocabulary. // -func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// For example, with `new_vocab_file` a text file containing each of the following +// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3], +// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be +// `[0, -1, 2]`. +// +// The op also returns a count of how many entries in the new vocabulary +// were present in the old vocabulary, which is used to calculate the number of +// values to initialize in a weight matrix remapping +// +// This functionality can be used to remap both row vocabularies (typically, +// features) and column vocabularies (typically, classes) from TensorFlow +// checkpoints. Note that the partitioning logic relies on contiguous vocabularies +// corresponding to div-partitioned variables. Moreover, the underlying remapping +// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should +// use the corresponding index_table_from_file() as the FeatureColumn framework +// does (as opposed to tf.feature_to_id(), which uses a CuckooTable). +// +// Arguments: +// new_vocab_file: Path to the new vocab file. +// old_vocab_file: Path to the old vocab file. +// new_vocab_offset: How many entries into the new vocab file to start reading. +// num_new_vocab: Number of entries in the new vocab file to remap. +// +// Returns A Tensor of length num_new_vocab where the element at index i +// is equal to the old ID that maps to the new ID i. This element is -1 for any +// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab. +func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TakeDataset", + Type: "GenerateVocabRemapping", Input: []tf.Input{ - input_dataset, count, + new_vocab_file, old_vocab_file, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// The gradient operator for the SparseAdd op. +// Assigns sparse updates to the variable referenced by `resource`. // -// The SparseAdd op calculates A + B, where A, B, and the sum are all represented -// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t. -// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty -// values of A and B. +// This operation computes +// +// # Scalar indices +// ref[indices, ...] = updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] = updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] // // Arguments: -// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to -// the non-empty values of the sum. -// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. -// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. -// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size -// `[nnz(sum), ndims]`. +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. // -// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the -// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the -// non-empty values of B. -func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { +// Returns the created operation. +func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseAddGrad", + Type: "ResourceScatterUpdate", Input: []tf.Input{ - backprop_val_grad, a_indices, b_indices, sum_indices, + resource, indices, updates, }, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// Computes atan of x element-wise. -func Atan(scope *Scope, x tf.Output) (y tf.Output) { +// Creates and returns an empty tensor list. +// +// All list elements must be tensors of dtype element_dtype and shape compatible +// with element_shape. +// +// handle: an empty tensor list. +// element_dtype: the type of elements in the list. +// element_shape: a shape compatible with that of elements in the list. +func EmptyTensorList(scope *Scope, element_shape tf.Output, element_dtype tf.DataType) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "Atan", + Type: "EmptyTensorList", Input: []tf.Input{ - x, + element_shape, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Encode audio data using the WAV file format. -// -// This operation will generate a string suitable to be saved out to create a .wav -// audio file. It will be encoded in the 16-bit PCM format. It takes in float -// values in the range -1.0f to 1.0f, and any outside that value will be clamped to -// that range. +// AvgPoolGradAttr is an optional argument to AvgPoolGrad. +type AvgPoolGradAttr func(optionalAttr) + +// AvgPoolGradDataFormat sets the optional data_format attribute to value. // -// `audio` is a 2-D float Tensor of shape `[length, channels]`. -// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of the average pooling function. // // Arguments: -// audio: 2-D with shape `[length, channels]`. -// sample_rate: Scalar containing the sample frequency. +// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. +// the output of `avg_pool`. +// ksize: The size of the sliding window for each dimension of the input. +// strides: The stride of the sliding window for each dimension of the input. +// padding: The type of padding algorithm to use. // -// Returns 0-D. WAV-encoded file contents. -func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { +// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. +func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "EncodeWav", + Type: "AvgPoolGrad", Input: []tf.Input{ - audio, sample_rate, + orig_input_shape, grad, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process. The hash function is a keyed hash function, where attribute `key` -// defines the key of the hash function. `key` is an array of 2 elements. +// StageClearAttr is an optional argument to StageClear. +type StageClearAttr func(optionalAttr) + +// StageClearCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// A strong hash is important when inputs may be malicious, e.g. URLs with -// additional components. Adversaries could try to make their inputs hash to the -// same bucket for a denial-of-service attack or to skew the results. A strong -// hash prevents this by making it difficult, if not infeasible, to compute inputs -// that hash to the same bucket. This comes at a cost of roughly 4x higher compute -// time than `tf.string_to_hash_bucket_fast`. +// REQUIRES: value >= 0 +func StageClearCapacity(value int64) StageClearAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// StageClearMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// Arguments: -// input: The strings to assign a hash bucket. -// num_buckets: The number of buckets. -// key: The key for the keyed hash function passed as a list of two uint64 -// elements. +// REQUIRES: value >= 0 +func StageClearMemoryLimit(value int64) StageClearAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// StageClearContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func StageClearContainer(value string) StageClearAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// StageClearSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func StageClearSharedName(value string) StageClearAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes all elements in the underlying container. // -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, key []int64) (output tf.Output) { +// Returns the created operation. +func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_buckets": num_buckets, "key": key} + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "StringToHashBucketStrong", - Input: []tf.Input{ - input, - }, + Type: "StageClear", + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// RegexReplaceAttr is an optional argument to RegexReplace. -type RegexReplaceAttr func(optionalAttr) +// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits. +type ComputeAccidentalHitsAttr func(optionalAttr) -// RegexReplaceReplaceGlobal sets the optional replace_global attribute to value. +// ComputeAccidentalHitsSeed sets the optional seed attribute to value. // -// value: If True, the replacement is global, otherwise the replacement -// is done only on the first match. -// If not specified, defaults to true -func RegexReplaceReplaceGlobal(value bool) RegexReplaceAttr { +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr { return func(m optionalAttr) { - m["replace_global"] = value + m["seed"] = value } } -// Replaces the match of pattern in input with rewrite. +// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value. // -// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Computes the ids of the positions in sampled_candidates that match true_labels. +// +// When doing log-odds NCE, the result of this op should be passed through a +// SparseToDense op, then added to the logits of the sampled candidates. This has +// the effect of 'removing' the sampled labels that match the true labels by +// making the classifier sure that they are sampled labels. // // Arguments: -// input: The text to be processed. -// pattern: The regular expression to match the input. -// rewrite: The rewrite to be applied to the matched expresion. +// true_classes: The true_classes output of UnpackSparseLabels. +// sampled_candidates: The sampled_candidates output of CandidateSampler. +// num_true: Number of true labels per context. // -// Returns The text after applying pattern and rewrite. -func RegexReplace(scope *Scope, input tf.Output, pattern tf.Output, rewrite tf.Output, optional ...RegexReplaceAttr) (output tf.Output) { +// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label +// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element +// is -FLOAT_MAX. +func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_true": num_true} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RegexReplace", + Type: "ComputeAccidentalHits", Input: []tf.Input{ - input, pattern, rewrite, + true_classes, sampled_candidates, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Computes numerical negative value element-wise. -// -// I.e., \\(y = -x\\). -func Neg(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Neg", - Input: []tf.Input{ - x, - }, +// QuantizedRelu6Attr is an optional argument to QuantizedRelu6. +type QuantizedRelu6Attr func(optionalAttr) + +// QuantizedRelu6OutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedRelu6OutType(value tf.DataType) QuantizedRelu6Attr { + return func(m optionalAttr) { + m["out_type"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Execute a sub graph on a remote processor. -// -// The graph specifications(such as graph itself, input tensors and output names) -// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo -// as serialized_remote_fused_graph_execute_info. -// The specifications will be passed to a dedicated registered -// remote fused graph executor. The executor will send the graph specifications -// to a remote processor and execute that graph. The execution results -// will be passed to consumer nodes as outputs of this node. +// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` // // Arguments: -// inputs: Arbitrary number of tensors with arbitrary data types // -// serialized_remote_fused_graph_execute_info: Serialized protocol buffer -// of RemoteFusedGraphExecuteInfo which contains graph specifications. +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. // -// Returns Arbitrary number of tensors with arbitrary data types -func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. +func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedRelu6Attr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RemoteFusedGraphExecute", + Type: "QuantizedRelu6", Input: []tf.Input{ - tf.OutputList(inputs), + features, min_features, max_features, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("RemoteFusedGraphExecute", err) - return - } - return outputs + return op.Output(0), op.Output(1), op.Output(2) } -// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. -type MaxPool3DGradGradAttr func(optionalAttr) +// FixedLengthRecordReaderV2Attr is an optional argument to FixedLengthRecordReaderV2. +type FixedLengthRecordReaderV2Attr func(optionalAttr) -// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. +// FixedLengthRecordReaderV2HeaderBytes sets the optional header_bytes attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { +// value: Number of bytes in the header, defaults to 0. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2HeaderBytes(value int64) FixedLengthRecordReaderV2Attr { return func(m optionalAttr) { - m["data_format"] = value + m["header_bytes"] = value } } -// Computes second-order gradients of the maxpooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. +// FixedLengthRecordReaderV2FooterBytes sets the optional footer_bytes attribute to value. // -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) +// value: Number of bytes in the footer, defaults to 0. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2FooterBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["footer_bytes"] = value } - opspec := tf.OpSpec{ - Type: "MaxPool3DGradGrad", - Input: []tf.Input{ - orig_input, orig_output, grad, - }, - Attrs: attrs, +} + +// FixedLengthRecordReaderV2HopBytes sets the optional hop_bytes attribute to value. +// +// value: Number of bytes to hop before each read. Default of 0 means using +// record_bytes. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2HopBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["hop_bytes"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. -type Conv3DBackpropFilterV2Attr func(optionalAttr) +// FixedLengthRecordReaderV2Container sets the optional container attribute to value. +// +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2Container(value string) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} -// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. +// FixedLengthRecordReaderV2SharedName sets the optional shared_name attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2SharedName(value string) FixedLengthRecordReaderV2Attr { return func(m optionalAttr) { - m["data_format"] = value + m["shared_name"] = value } } -// Conv3DBackpropFilterV2Dilations sets the optional dilations attribute to value. +// FixedLengthRecordReaderV2Encoding sets the optional encoding attribute to value. // -// value: 1-D tensor of length 5. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { +// value: The type of encoding for the file. Currently ZLIB and GZIP +// are supported. Defaults to none. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2Encoding(value string) FixedLengthRecordReaderV2Attr { return func(m optionalAttr) { - m["dilations"] = value + m["encoding"] = value } } -// Computes the gradients of 3-D convolution with respect to the filter. +// A Reader that outputs fixed-length records from a file. // // Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 5-D -// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` -// tensor. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { +// record_bytes: Number of bytes in the record. +// +// Returns The handle to reference the Reader. +func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...FixedLengthRecordReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{"record_bytes": record_bytes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropFilterV2", - Input: []tf.Input{ - input, filter_sizes, out_backprop, - }, + Type: "FixedLengthRecordReaderV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. -type FakeQuantWithMinMaxVarsAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["num_bits"] = value +// The gradient operator for the SparseAdd op. +// +// The SparseAdd op calculates A + B, where A, B, and the sum are all represented +// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t. +// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty +// values of A and B. +// +// Arguments: +// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to +// the non-empty values of the sum. +// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. +// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. +// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size +// `[nnz(sum), ndims]`. +// +// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the +// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the +// non-empty values of B. +func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseAddGrad", + Input: []tf.Input{ + backprop_val_grad, a_indices, b_indices, sum_indices, + }, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["narrow_range"] = value +// Computes atan of x element-wise. +func Atan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Atan", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` +// Encode audio data using the WAV file format. // -// and `max` to 'outputs' tensor of same shape as `inputs`. +// This operation will generate a string suitable to be saved out to create a .wav +// audio file. It will be encoded in the 16-bit PCM format. It takes in float +// values in the range -1.0f to 1.0f, and any outside that value will be clamped to +// that range. // -// `[min; max]` define the clamping range for the `inputs` data. -// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -// then de-quantized and output as floats in `[min; max]` interval. -// `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. +// `audio` is a 2-D float Tensor of shape `[length, channels]`. +// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). // -// This operation has a gradient and thus allows for training `min` and `max` -// values. -func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { +// Arguments: +// audio: 2-D with shape `[length, channels]`. +// sample_rate: Scalar containing the sample frequency. +// +// Returns 0-D. WAV-encoded file contents. +func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVars", + Type: "EncodeWav", Input: []tf.Input{ - inputs, min, max, + audio, sample_rate, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Applies softmax to a batched N-D `SparseTensor`. -// -// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` -// (where `N >= 2`), and with indices sorted in the canonical lexicographic order. -// -// This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost -// logical submatrix with shape `[B, C]`, but with the catch that *the implicitly -// zero elements do not participate*. Specifically, the algorithm is equivalent -// to the following: +// Converts each string in the input Tensor to its hash mod by a number of buckets. // -// (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix -// with shape `[B, C]`, along the size-C dimension; -// (2) Masks out the original implicitly-zero locations; -// (3) Renormalizes the remaining elements. +// The hash function is deterministic on the content of the string within the +// process. The hash function is a keyed hash function, where attribute `key` +// defines the key of the hash function. `key` is an array of 2 elements. // -// Hence, the `SparseTensor` result has exactly the same non-zero indices and -// shape. +// A strong hash is important when inputs may be malicious, e.g. URLs with +// additional components. Adversaries could try to make their inputs hash to the +// same bucket for a denial-of-service attack or to skew the results. A strong +// hash prevents this by making it difficult, if not infeasible, to compute inputs +// that hash to the same bucket. This comes at a cost of roughly 4x higher compute +// time than `tf.string_to_hash_bucket_fast`. // // Arguments: -// sp_indices: 2-D. `NNZ x R` matrix with the indices of non-empty values in a -// SparseTensor, in canonical ordering. -// sp_values: 1-D. `NNZ` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. +// input: The strings to assign a hash bucket. +// num_buckets: The number of buckets. +// key: The key for the keyed hash function passed as a list of two uint64 +// elements. // -// Returns 1-D. The `NNZ` values for the result `SparseTensor`. -func SparseSoftmax(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output) (output tf.Output) { +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, key []int64) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_buckets": num_buckets, "key": key} opspec := tf.OpSpec{ - Type: "SparseSoftmax", + Type: "StringToHashBucketStrong", Input: []tf.Input{ - sp_indices, sp_values, sp_shape, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Partitions `data` into `num_partitions` tensors using indices from `partitions`. -// -// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` -// becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` -// are placed in `outputs[i]` in lexicographic order of `js`, and the first -// dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. -// In detail, -// -// ```python -// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] -// -// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) -// ``` +// RegexReplaceAttr is an optional argument to RegexReplace. +type RegexReplaceAttr func(optionalAttr) + +// RegexReplaceReplaceGlobal sets the optional replace_global attribute to value. // -// `data.shape` must start with `partitions.shape`. +// value: If True, the replacement is global, otherwise the replacement +// is done only on the first match. +// If not specified, defaults to true +func RegexReplaceReplaceGlobal(value bool) RegexReplaceAttr { + return func(m optionalAttr) { + m["replace_global"] = value + } +} + +// Replaces the match of pattern in input with rewrite. // -// For example: +// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) // -// ```python -// # Scalar partitions. -// partitions = 1 -// num_partitions = 2 -// data = [10, 20] -// outputs[0] = [] # Empty with shape [0, 2] -// outputs[1] = [[10, 20]] +// Arguments: +// input: The text to be processed. +// pattern: The regular expression to match the input. +// rewrite: The rewrite to be applied to the matched expresion. // -// # Vector partitions. -// partitions = [0, 0, 1, 1, 0] -// num_partitions = 2 -// data = [10, 20, 30, 40, 50] -// outputs[0] = [10, 20, 50] -// outputs[1] = [30, 40] -// ``` +// Returns The text after applying pattern and rewrite. +func RegexReplace(scope *Scope, input tf.Output, pattern tf.Output, rewrite tf.Output, optional ...RegexReplaceAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RegexReplace", + Input: []tf.Input{ + input, pattern, rewrite, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes numerical negative value element-wise. // -// See `dynamic_stitch` for an example on how to merge partitions back. +// I.e., \\(y = -x\\). +func Neg(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Neg", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Execute a sub graph on a remote processor. // -//
-// -//
+// The graph specifications(such as graph itself, input tensors and output names) +// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo +// as serialized_remote_fused_graph_execute_info. +// The specifications will be passed to a dedicated registered +// remote fused graph executor. The executor will send the graph specifications +// to a remote processor and execute that graph. The execution results +// will be passed to consumer nodes as outputs of this node. // // Arguments: +// inputs: Arbitrary number of tensors with arbitrary data types // -// partitions: Any shape. Indices in the range `[0, num_partitions)`. -// num_partitions: The number of partitions to output. -func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_partitions int64) (outputs []tf.Output) { +// serialized_remote_fused_graph_execute_info: Serialized protocol buffer +// of RemoteFusedGraphExecuteInfo which contains graph specifications. +// +// Returns Arbitrary number of tensors with arbitrary data types +func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_partitions": num_partitions} + attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} opspec := tf.OpSpec{ - Type: "DynamicPartition", + Type: "RemoteFusedGraphExecute", Input: []tf.Input{ - data, partitions, + tf.OutputList(inputs), }, Attrs: attrs, } @@ -9730,127 +10202,117 @@ func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_pa var idx int var err error if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("DynamicPartition", err) + scope.UpdateErr("RemoteFusedGraphExecute", err) return } return outputs } -// ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. -type ResourceApplyAdagradAttr func(optionalAttr) +// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. +type MaxPool3DGradGradAttr func(optionalAttr) -// ResourceApplyAdagradUseLocking sets the optional use_locking attribute to value. +// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceApplyAdagradUpdateSlots sets the optional update_slots attribute to value. -// If not specified, defaults to true -func ResourceApplyAdagradUpdateSlots(value bool) ResourceApplyAdagradAttr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { return func(m optionalAttr) { - m["update_slots"] = value + m["data_format"] = value } } -// Update '*var' according to the adagrad scheme. -// -// accum += grad * grad -// var -= lr * grad * (1 / sqrt(accum)) +// Computes second-order gradients of the maxpooling function. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. // -// Returns the created operation. -func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, optional ...ResourceApplyAdagradAttr) (o *tf.Operation) { +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdagrad", + Type: "MaxPool3DGradGrad", Input: []tf.Input{ - var_, accum, lr, grad, + orig_input, orig_output, grad, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// Return the shape of s0 op s1 with broadcast. -// -// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the -// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. -func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BroadcastArgs", - Input: []tf.Input{ - s0, s1, - }, - } op := scope.AddOperation(opspec) return op.Output(0) } -// DataFormatDimMapAttr is an optional argument to DataFormatDimMap. -type DataFormatDimMapAttr func(optionalAttr) +// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. +type Conv3DBackpropFilterV2Attr func(optionalAttr) -// DataFormatDimMapSrcFormat sets the optional src_format attribute to value. +// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. // -// value: source data format. -// If not specified, defaults to "NHWC" -func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { - m["src_format"] = value + m["data_format"] = value } } -// DataFormatDimMapDstFormat sets the optional dst_format attribute to value. +// Conv3DBackpropFilterV2Dilations sets the optional dilations attribute to value. // -// value: destination data format. -// If not specified, defaults to "NCHW" -func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr { +// value: 1-D tensor of length 5. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { - m["dst_format"] = value + m["dilations"] = value } } -// Returns the dimension index in the destination data format given the one in -// -// the source data format. +// Computes the gradients of 3-D convolution with respect to the filter. // // Arguments: -// x: A Tensor with each element as a dimension index in source data format. -// Must be in the range [-4, 4). -// -// Returns A Tensor with each element as a dimension index in destination data format. -func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) { +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 5-D +// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` +// tensor. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DataFormatDimMap", + Type: "Conv3DBackpropFilterV2", Input: []tf.Input{ - x, + input, filter_sizes, out_backprop, }, Attrs: attrs, } @@ -9858,38 +10320,38 @@ func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAtt return op.Output(0) } -// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign. -type ResourceApplyPowerSignAttr func(optionalAttr) +// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. +type FakeQuantWithMinMaxVarsAttr func(optionalAttr) -// ResourceApplyPowerSignUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and m tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyPowerSignUseLocking(value bool) ResourceApplyPowerSignAttr { +// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["num_bits"] = value } } -// Update '*var' according to the AddSign update. +// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` // -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g -// update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g -// variable <- variable - lr_t * update +// and `max` to 'outputs' tensor of same shape as `inputs`. // -// Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// logbase: Must be a scalar. -// sign_decay: Must be a scalar. -// beta: Must be a scalar. -// grad: The gradient. +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. // -// Returns the created operation. -func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, logbase tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyPowerSignAttr) (o *tf.Operation) { +// This operation has a gradient and thus allows for training `min` and `max` +// values. +func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { if scope.Err() != nil { return } @@ -9898,161 +10360,160 @@ func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Out a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyPowerSign", + Type: "FakeQuantWithMinMaxVars", Input: []tf.Input{ - var_, m, lr, logbase, sign_decay, beta, grad, + inputs, min, max, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Locks a mutex resource. The output is the lock. So long as the lock tensor -// -// is alive, any other request to use `MutexLock` with this mutex will wait. -// -// This is particularly useful for creating a critical section when used in -// conjunction with `MutexLockIdentity`: -// -// ```python -// -// mutex = mutex_v2( -// shared_name=handle_name, container=container, name=name) -// -// def execute_in_critical_section(fn, *args, **kwargs): -// lock = gen_resource_variable_ops.mutex_lock(mutex) -// -// with ops.control_dependencies([lock]): -// r = fn(*args, **kwargs) -// -// with ops.control_dependencies(nest.flatten(r)): -// with ops.colocate_with(mutex): -// ensure_lock_exists = mutex_lock_identity(lock) -// -// # Make sure that if any element of r is accessed, all of -// # them are executed together. -// r = nest.map_structure(tf.identity, r) +// Applies softmax to a batched N-D `SparseTensor`. // -// with ops.control_dependencies([ensure_lock_exists]): -// return nest.map_structure(tf.identity, r) -// ``` +// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` +// (where `N >= 2`), and with indices sorted in the canonical lexicographic order. // -// While `fn` is running in the critical section, no other functions which wish to -// use this critical section may run. +// This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost +// logical submatrix with shape `[B, C]`, but with the catch that *the implicitly +// zero elements do not participate*. Specifically, the algorithm is equivalent +// to the following: // -// Often the use case is that two executions of the same graph, in parallel, -// wish to run `fn`; and we wish to ensure that only one of them executes -// at a time. This is especially important if `fn` modifies one or more -// variables at a time. +// (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix +// with shape `[B, C]`, along the size-C dimension; +// (2) Masks out the original implicitly-zero locations; +// (3) Renormalizes the remaining elements. // -// It is also useful if two separate functions must share a resource, but we -// wish to ensure the usage is exclusive. +// Hence, the `SparseTensor` result has exactly the same non-zero indices and +// shape. // // Arguments: -// mutex: The mutex resource to lock. +// sp_indices: 2-D. `NNZ x R` matrix with the indices of non-empty values in a +// SparseTensor, in canonical ordering. +// sp_values: 1-D. `NNZ` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. // -// Returns A tensor that keeps a shared pointer to a lock on the mutex; -// when the Tensor is destroyed, the use count on the shared pointer is decreased -// by 1. When it reaches 0, the lock is released. -func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) { +// Returns 1-D. The `NNZ` values for the result `SparseTensor`. +func SparseSoftmax(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MutexLock", + Type: "SparseSoftmax", Input: []tf.Input{ - mutex, + sp_indices, sp_values, sp_shape, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the mean along segments of a tensor. +// Partitions `data` into `num_partitions` tensors using indices from `partitions`. // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` +// becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` +// are placed in `outputs[i]` in lexicographic order of `js`, and the first +// dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. +// In detail, // -// Computes a tensor such that -// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is -// over `j` such that `segment_ids[j] == i` and `N` is the total number of -// values summed. +// ```python +// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] // -// If the mean is empty for a given segment ID `i`, `output[i] = 0`. +// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) +// ``` +// +// `data.shape` must start with `partitions.shape`. +// +// For example: +// +// ```python +// # Scalar partitions. +// partitions = 1 +// num_partitions = 2 +// data = [10, 20] +// outputs[0] = [] # Empty with shape [0, 2] +// outputs[1] = [[10, 20]] +// +// # Vector partitions. +// partitions = [0, 0, 1, 1, 0] +// num_partitions = 2 +// data = [10, 20, 30, 40, 50] +// outputs[0] = [10, 20, 50] +// outputs[1] = [30, 40] +// ``` +// +// See `dynamic_stitch` for an example on how to merge partitions back. // //
-// +// //
// // Arguments: // -// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// partitions: Any shape. Indices in the range `[0, num_partitions)`. +// num_partitions: The number of partitions to output. +func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_partitions int64) (outputs []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_partitions": num_partitions} opspec := tf.OpSpec{ - Type: "SegmentMean", + Type: "DynamicPartition", Input: []tf.Input{ - data, segment_ids, + data, partitions, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("DynamicPartition", err) + return + } + return outputs } -// ResourceSparseApplyCenteredRMSPropAttr is an optional argument to ResourceSparseApplyCenteredRMSProp. -type ResourceSparseApplyCenteredRMSPropAttr func(optionalAttr) +// ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. +type ResourceApplyAdagradAttr func(optionalAttr) -// ResourceSparseApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. +// ResourceApplyAdagradUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var, mg, ms, and mom tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less // contention. // If not specified, defaults to false -func ResourceSparseApplyCenteredRMSPropUseLocking(value bool) ResourceSparseApplyCenteredRMSPropAttr { +func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the centered RMSProp algorithm. -// -// The centered RMSProp algorithm uses an estimate of the centered second moment -// (i.e., the variance) for normalization, as opposed to regular RMSProp, which -// uses the (uncentered) second moment. This often helps with training, but is -// slightly more expensive in terms of computation and memory. -// -// Note that in dense implementation of this algorithm, mg, ms, and mom will -// update even if the grad is zero, but in this sparse implementation, mg, ms, -// and mom will not update in iterations during which the grad is zero. -// -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// mean_grad = decay * mean_grad + (1-decay) * gradient -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) +// ResourceApplyAdagradUpdateSlots sets the optional update_slots attribute to value. +// If not specified, defaults to true +func ResourceApplyAdagradUpdateSlots(value bool) ResourceApplyAdagradAttr { + return func(m optionalAttr) { + m["update_slots"] = value + } +} + +// Update '*var' according to the adagrad scheme. // -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom +// accum += grad * grad +// var -= lr * grad * (1 / sqrt(accum)) // // Arguments: // var_: Should be from a Variable(). -// mg: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). +// accum: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. // grad: The gradient. -// indices: A vector of indices into the first dimension of var, ms and mom. // // Returns the created operation. -func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyCenteredRMSPropAttr) (o *tf.Operation) { +func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, optional ...ResourceApplyAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -10061,137 +10522,66 @@ func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Outp a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyCenteredRMSProp", + Type: "ResourceApplyAdagrad", Input: []tf.Input{ - var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices, + var_, accum, lr, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Creates a dataset that batches `batch_size` elements from `input_dataset`. -// -// Arguments: -// -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// +// Return the shape of s0 op s1 with broadcast. // -func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the +// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. +func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "BatchDataset", + Type: "BroadcastArgs", Input: []tf.Input{ - input_dataset, batch_size, + s0, s1, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg. -type DecodeAndCropJpegAttr func(optionalAttr) +// DataFormatDimMapAttr is an optional argument to DataFormatDimMap. +type DataFormatDimMapAttr func(optionalAttr) -// DecodeAndCropJpegChannels sets the optional channels attribute to value. +// DataFormatDimMapSrcFormat sets the optional src_format attribute to value. // -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr { +// value: source data format. +// If not specified, defaults to "NHWC" +func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr { return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodeAndCropJpegRatio sets the optional ratio attribute to value. -// -// value: Downscaling ratio. -// If not specified, defaults to 1 -func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["ratio"] = value - } -} - -// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. -// -// value: If true use a slower but nicer upscaling of the -// chroma planes (yuv420/422 only). -// If not specified, defaults to true -func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["fancy_upscaling"] = value - } -} - -// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. -// -// value: If true try to recover an image from truncated input. -// If not specified, defaults to false -func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["try_recover_truncated"] = value - } -} - -// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. -// -// value: The minimum required fraction of lines before a truncated -// input is accepted. -// If not specified, defaults to 1 -func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["acceptable_fraction"] = value + m["src_format"] = value } } -// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value. +// DataFormatDimMapDstFormat sets the optional dst_format attribute to value. // -// value: string specifying a hint about the algorithm used for -// decompression. Defaults to "" which maps to a system-specific -// default. Currently valid values are ["INTEGER_FAST", -// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal -// jpeg library changes to a version that does not have that specific -// option.) -// If not specified, defaults to "" -func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr { +// value: destination data format. +// If not specified, defaults to "NCHW" +func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr { return func(m optionalAttr) { - m["dct_method"] = value + m["dst_format"] = value } } -// Decode and Crop a JPEG-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the JPEG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// -// If needed, the JPEG-encoded image is transformed to match the requested number -// of color channels. -// -// The attr `ratio` allows downscaling the image by an integer factor during -// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than -// downscaling the image later. -// +// Returns the dimension index in the destination data format given the one in // -// It is equivalent to a combination of decode and crop, but much faster by only -// decoding partial jpeg image. +// the source data format. // // Arguments: -// contents: 0-D. The JPEG-encoded image. -// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. +// x: A Tensor with each element as a dimension index in source data format. +// Must be in the range [-4, 4). // -// Returns 3-D with shape `[height, width, channels]`.. -func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) { +// Returns A Tensor with each element as a dimension index in destination data format. +func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) { if scope.Err() != nil { return } @@ -10200,9 +10590,9 @@ func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeAndCropJpeg", + Type: "DataFormatDimMap", Input: []tf.Input{ - contents, crop_window, + x, }, Attrs: attrs, } @@ -10210,313 +10600,182 @@ func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, return op.Output(0) } -// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. -type AllCandidateSamplerAttr func(optionalAttr) - -// AllCandidateSamplerSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} +// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign. +type ResourceApplyPowerSignAttr func(optionalAttr) -// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// ResourceApplyPowerSignUseLocking sets the optional use_locking attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr { +// value: If `True`, updating of the var and m tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyPowerSignUseLocking(value bool) ResourceApplyPowerSignAttr { return func(m optionalAttr) { - m["seed2"] = value + m["use_locking"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. -// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. +// Update '*var' according to the AddSign update. // -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g +// variable <- variable - lr_t * update // // Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to produce. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// logbase: Must be a scalar. +// sign_decay: Must be a scalar. +// beta: Must be a scalar. +// grad: The gradient. // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// Returns the created operation. +func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, logbase tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyPowerSignAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AllCandidateSampler", + Type: "ResourceApplyPowerSign", Input: []tf.Input{ - true_classes, + var_, m, lr, logbase, sign_decay, beta, grad, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// Adds two `SparseTensor` objects to produce another `SparseTensor`. +// Locks a mutex resource. The output is the lock. So long as the lock tensor // -// The input `SparseTensor` objects' indices are assumed ordered in standard -// lexicographic order. If this is not the case, before this step run -// `SparseReorder` to restore index ordering. +// is alive, any other request to use `MutexLock` with this mutex will wait. // -// By default, if two values sum to zero at some index, the output `SparseTensor` -// would still include that particular location in its index, storing a zero in the -// corresponding value slot. To override this, callers can specify `thresh`, -// indicating that if the sum has a magnitude strictly smaller than `thresh`, its -// corresponding value and index would then not be included. In particular, -// `thresh == 0` (default) means everything is kept and actual thresholding happens -// only for a positive value. +// This is particularly useful for creating a critical section when used in +// conjunction with `MutexLockIdentity`: // -// In the following shapes, `nnz` is the count after taking `thresh` into account. +// ```python +// +// mutex = mutex_v2( +// shared_name=handle_name, container=container, name=name) +// +// def execute_in_critical_section(fn, *args, **kwargs): +// lock = gen_resource_variable_ops.mutex_lock(mutex) +// +// with ops.control_dependencies([lock]): +// r = fn(*args, **kwargs) +// +// with ops.control_dependencies(nest.flatten(r)): +// with ops.colocate_with(mutex): +// ensure_lock_exists = mutex_lock_identity(lock) +// +// # Make sure that if any element of r is accessed, all of +// # them are executed together. +// r = nest.map_structure(tf.identity, r) +// +// with ops.control_dependencies([ensure_lock_exists]): +// return nest.map_structure(tf.identity, r) +// ``` +// +// While `fn` is running in the critical section, no other functions which wish to +// use this critical section may run. +// +// Often the use case is that two executions of the same graph, in parallel, +// wish to run `fn`; and we wish to ensure that only one of them executes +// at a time. This is especially important if `fn` modifies one or more +// variables at a time. +// +// It is also useful if two separate functions must share a resource, but we +// wish to ensure the usage is exclusive. // // Arguments: -// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix. -// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector. -// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector. -// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix. -// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector. -// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector. -// thresh: 0-D. The magnitude threshold that determines if an output value/index -// pair takes space. -func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) { +// mutex: The mutex resource to lock. +// +// Returns A tensor that keeps a shared pointer to a lock on the mutex; +// when the Tensor is destroyed, the use count on the shared pointer is decreased +// by 1. When it reaches 0, the lock is released. +func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseAdd", + Type: "MutexLock", Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh, + mutex, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// OrderedMapPeekAttr is an optional argument to OrderedMapPeek. -type OrderedMapPeekAttr func(optionalAttr) - -// OrderedMapPeekCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Computes the mean along segments of a tensor. // -// REQUIRES: value >= 0 -func OrderedMapPeekCapacity(value int64) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapPeekMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. // -// REQUIRES: value >= 0 -func OrderedMapPeekMemoryLimit(value int64) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapPeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapPeekContainer(value string) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// OrderedMapPeekSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapPeekSharedName(value string) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op peeks at the values at the specified key. If the +// Computes a tensor such that +// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is +// over `j` such that `segment_ids[j] == i` and `N` is the total number of +// values summed. // -// underlying container does not contain this key -// this op will block until it does. This Op is optimized for -// performance. -func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapPeekAttr) (values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "OrderedMapPeek", - Input: []tf.Input{ - key, indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("OrderedMapPeek", err) - return - } - return values -} - -// Inverse fast Fourier transform. +// If the mean is empty for a given segment ID `i`, `output[i] = 0`. // -// Computes the inverse 1-dimensional discrete Fourier transform over the -// inner-most dimension of `input`. +//
+// +//
// // Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its inverse 1D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.ifft -// @end_compatibility -func IFFT(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IFFT", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Generates values in an interval. -// -// A sequence of `num` evenly-spaced values are generated beginning at `start`. -// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, -// so that the last one is exactly `stop`. -// -// For example: -// -// ``` -// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] -// ``` // -// Arguments: -// start: First entry in the range. -// stop: Last entry in the range. -// num: Number of values to generate. +// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +// first dimension. Values should be sorted and can be repeated. // -// Returns 1-D. The generated values. -func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LinSpace", + Type: "SegmentMean", Input: []tf.Input{ - start, stop, num, + data, segment_ids, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. -type DestroyResourceOpAttr func(optionalAttr) - -// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. -// -// value: whether to ignore the error when the resource -// doesn't exist. -// If not specified, defaults to true -func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { - return func(m optionalAttr) { - m["ignore_lookup_error"] = value - } -} - -// Deletes the resource specified by the handle. -// -// All subsequent operations using the resource will result in a NotFound -// error status. -// -// Arguments: -// resource: handle to the resource to delete. -// -// Returns the created operation. -func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DestroyResourceOp", - Input: []tf.Input{ - resource, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. -type ResourceSparseApplyRMSPropAttr func(optionalAttr) +// ResourceSparseApplyCenteredRMSPropAttr is an optional argument to ResourceSparseApplyCenteredRMSProp. +type ResourceSparseApplyCenteredRMSPropAttr func(optionalAttr) -// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. +// ResourceSparseApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less +// value: If `True`, updating of the var, mg, ms, and mom tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less // contention. // If not specified, defaults to false -func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { +func ResourceSparseApplyCenteredRMSPropUseLocking(value bool) ResourceSparseApplyCenteredRMSPropAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the RMSProp algorithm. +// Update '*var' according to the centered RMSProp algorithm. // -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms +// The centered RMSProp algorithm uses an estimate of the centered second moment +// (i.e., the variance) for normalization, as opposed to regular RMSProp, which +// uses the (uncentered) second moment. This often helps with training, but is +// slightly more expensive in terms of computation and memory. +// +// Note that in dense implementation of this algorithm, mg, ms, and mom will +// update even if the grad is zero, but in this sparse implementation, mg, ms, // and mom will not update in iterations during which the grad is zero. // // mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// mean_grad = decay * mean_grad + (1-decay) * gradient +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) // // ms <- rho * ms_{t-1} + (1-rho) * grad * grad // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) @@ -10524,6 +10783,7 @@ func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSProp // // Arguments: // var_: Should be from a Variable(). +// mg: Should be from a Variable(). // ms: Should be from a Variable(). // mom: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. @@ -10534,7 +10794,7 @@ func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSProp // indices: A vector of indices into the first dimension of var, ms and mom. // // Returns the created operation. -func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { +func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyCenteredRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -10543,168 +10803,174 @@ func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyRMSProp", + Type: "ResourceSparseApplyCenteredRMSProp", Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, + var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Returns the truth value of (x > y) element-wise. +// Creates a dataset that batches `batch_size` elements from `input_dataset`. // -// *NOTE*: `Greater` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// +// +func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Greater", + Type: "BatchDataset", Input: []tf.Input{ - x, y, + input_dataset, batch_size, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox. -type SampleDistortedBoundingBoxAttr func(optionalAttr) +// Says whether the targets are in the top `K` predictions. +// +// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +// prediction for the target class is among the top `k` predictions among +// all predictions for example `i`. Note that the behavior of `InTopK` differs +// from the `TopK` op in its handling of ties; if multiple classes have the +// same prediction value and straddle the top-`k` boundary, all of those +// classes are considered to be in the top `k`. +// +// More formally, let +// +// \\(predictions_i\\) be the predictions for all classes for example `i`, +// \\(targets_i\\) be the target class for example `i`, +// \\(out_i\\) be the output for example `i`, +// +// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// +// Arguments: +// predictions: A `batch_size` x `classes` tensor. +// targets: A `batch_size` vector of class ids. +// k: Number of top elements to look at for computing precision. +// +// Returns Computed precision at `k` as a `bool Tensor`. +func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InTopKV2", + Input: []tf.Input{ + predictions, targets, k, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value. +// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg. +type DecodeAndCropJpegAttr func(optionalAttr) + +// DecodeAndCropJpegChannels sets the optional channels attribute to value. // -// value: If either `seed` or `seed2` are set to non-zero, the random number -// generator is seeded by the given `seed`. Otherwise, it is seeded by a random -// seed. +// value: Number of color channels for the decoded image. // If not specified, defaults to 0 -func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr { +func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["seed"] = value + m["channels"] = value } } -// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value. +// DecodeAndCropJpegRatio sets the optional ratio attribute to value. // -// value: The cropped area of the image must contain at least this -// fraction of any bounding box supplied. The value of this parameter should be -// non-negative. In the case of 0, the cropped area does not need to overlap -// any of the bounding boxes supplied. -// If not specified, defaults to 0.1 -func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr { +// value: Downscaling ratio. +// If not specified, defaults to 1 +func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["min_object_covered"] = value + m["ratio"] = value } } -// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value. +// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. // -// value: The cropped area of the image must have an aspect ratio = -// width / height within this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { +// value: If true use a slower but nicer upscaling of the +// chroma planes (yuv420/422 only). +// If not specified, defaults to true +func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["aspect_ratio_range"] = value + m["fancy_upscaling"] = value } } -// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value. +// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. // -// value: The cropped area of the image must contain a fraction of the -// supplied image within in this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { +// value: If true try to recover an image from truncated input. +// If not specified, defaults to false +func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["area_range"] = value + m["try_recover_truncated"] = value } } -// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value. +// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. // -// value: Number of attempts at generating a cropped region of the image -// of the specified constraints. After `max_attempts` failures, return the entire -// image. -// If not specified, defaults to 100 -func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr { +// value: The minimum required fraction of lines before a truncated +// input is accepted. +// If not specified, defaults to 1 +func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["max_attempts"] = value + m["acceptable_fraction"] = value } } -// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. +// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value. // -// value: Controls behavior if no bounding boxes supplied. -// If true, assume an implicit bounding box covering the whole input. If false, -// raise an error. -// If not specified, defaults to false -func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr { +// value: string specifying a hint about the algorithm used for +// decompression. Defaults to "" which maps to a system-specific +// default. Currently valid values are ["INTEGER_FAST", +// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal +// jpeg library changes to a version that does not have that specific +// option.) +// If not specified, defaults to "" +func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["use_image_if_no_bounding_boxes"] = value + m["dct_method"] = value } } -// Generate a single randomly distorted bounding box for an image. -// -// Bounding box annotations are often supplied in addition to ground-truth labels -// in image recognition or object localization tasks. A common technique for -// training such a system is to randomly distort an image while preserving -// its content, i.e. *data augmentation*. This Op outputs a randomly distorted -// localization of an object, i.e. bounding box, given an `image_size`, -// `bounding_boxes` and a series of constraints. +// Decode and Crop a JPEG-encoded image to a uint8 tensor. // -// The output of this Op is a single bounding box that may be used to crop the -// original image. The output is returned as 3 tensors: `begin`, `size` and -// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the -// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize -// what the bounding box looks like. +// The attr `channels` indicates the desired number of color channels for the +// decoded image. // -// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. +// Accepted values are: // -// For example, +// * 0: Use the number of channels in the JPEG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. // -// ```python -// # Generate a single distorted bounding box. -// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( -// tf.shape(image), -// bounding_boxes=bounding_boxes) +// If needed, the JPEG-encoded image is transformed to match the requested number +// of color channels. // -// # Draw the bounding box in an image summary. -// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), -// bbox_for_draw) -// tf.summary.image('images_with_box', image_with_box) +// The attr `ratio` allows downscaling the image by an integer factor during +// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +// downscaling the image later. // -// # Employ the bounding box to distort the image. -// distorted_image = tf.slice(image, begin, size) -// ``` // -// Note that if no bounding box information is available, setting -// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit -// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is -// false and no bounding boxes are supplied, an error is raised. +// It is equivalent to a combination of decode and crop, but much faster by only +// decoding partial jpeg image. // // Arguments: -// image_size: 1-D, containing `[height, width, channels]`. -// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes -// associated with the image. +// contents: 0-D. The JPEG-encoded image. +// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. // -// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to -// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to -// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. -// Provide as input to `tf.image.draw_bounding_boxes`. -func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) { +// Returns 3-D with shape `[height, width, channels]`.. +func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) { if scope.Err() != nil { return } @@ -10713,273 +10979,268 @@ func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_box a(attrs) } opspec := tf.OpSpec{ - Type: "SampleDistortedBoundingBox", + Type: "DecodeAndCropJpeg", Input: []tf.Input{ - image_size, bounding_boxes, + contents, crop_window, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// LRNAttr is an optional argument to LRN. -type LRNAttr func(optionalAttr) - -// LRNDepthRadius sets the optional depth_radius attribute to value. -// -// value: 0-D. Half-width of the 1-D normalization window. -// If not specified, defaults to 5 -func LRNDepthRadius(value int64) LRNAttr { - return func(m optionalAttr) { - m["depth_radius"] = value - } + return op.Output(0) } -// LRNBias sets the optional bias attribute to value. -// -// value: An offset (usually positive to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNBias(value float32) LRNAttr { - return func(m optionalAttr) { - m["bias"] = value - } -} +// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. +type AllCandidateSamplerAttr func(optionalAttr) -// LRNAlpha sets the optional alpha attribute to value. +// AllCandidateSamplerSeed sets the optional seed attribute to value. // -// value: A scale factor, usually positive. -// If not specified, defaults to 1 -func LRNAlpha(value float32) LRNAttr { +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr { return func(m optionalAttr) { - m["alpha"] = value + m["seed"] = value } } -// LRNBeta sets the optional beta attribute to value. +// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value. // -// value: An exponent. -// If not specified, defaults to 0.5 -func LRNBeta(value float32) LRNAttr { +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr { return func(m optionalAttr) { - m["beta"] = value + m["seed2"] = value } } -// Local Response Normalization. +// Generates labels for candidate sampling with a learned unigram distribution. // -// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last -// dimension), and each vector is normalized independently. Within a given vector, -// each component is divided by the weighted, squared sum of inputs within -// `depth_radius`. In detail, +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. // -// sqr_sum[a, b, c, d] = -// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) -// output = input / (bias + alpha * sqr_sum) ** beta +// For each batch, this op picks a single set of sampled candidate labels. // -// For details, see [Krizhevsky et al., ImageNet classification with deep -// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. // // Arguments: -// input: 4-D. -func LRN(scope *Scope, input tf.Output, optional ...LRNAttr) (output tf.Output) { +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to produce. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LRN", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that zips together `input_datasets`. -func ZipDataset(scope *Scope, input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ZipDataset", + Type: "AllCandidateSampler", Input: []tf.Input{ - tf.OutputList(input_datasets), + true_classes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// ResourceSparseApplyAdagradAttr is an optional argument to ResourceSparseApplyAdagrad. -type ResourceSparseApplyAdagradAttr func(optionalAttr) - -// ResourceSparseApplyAdagradUseLocking sets the optional use_locking attribute to value. +// Adds two `SparseTensor` objects to produce another `SparseTensor`. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceSparseApplyAdagradUpdateSlots sets the optional update_slots attribute to value. -// If not specified, defaults to true -func ResourceSparseApplyAdagradUpdateSlots(value bool) ResourceSparseApplyAdagradAttr { - return func(m optionalAttr) { - m["update_slots"] = value - } -} - -// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. +// The input `SparseTensor` objects' indices are assumed ordered in standard +// lexicographic order. If this is not the case, before this step run +// `SparseReorder` to restore index ordering. // -// That is for rows we have grad for, we update var and accum as follows: -// accum += grad * grad -// var -= lr * grad * (1 / sqrt(accum)) +// By default, if two values sum to zero at some index, the output `SparseTensor` +// would still include that particular location in its index, storing a zero in the +// corresponding value slot. To override this, callers can specify `thresh`, +// indicating that if the sum has a magnitude strictly smaller than `thresh`, its +// corresponding value and index would then not be included. In particular, +// `thresh == 0` (default) means everything is kept and actual thresholding happens +// only for a positive value. // -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. +// In the following shapes, `nnz` is the count after taking `thresh` into account. // -// Returns the created operation. -func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdagradAttr) (o *tf.Operation) { +// Arguments: +// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix. +// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector. +// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector. +// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix. +// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector. +// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector. +// thresh: 0-D. The magnitude threshold that determines if an output value/index +// pair takes space. +func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagrad", + Type: "SparseAdd", Input: []tf.Input{ - var_, accum, lr, grad, indices, + a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. -type StatelessRandomUniformAttr func(optionalAttr) +// OrderedMapPeekAttr is an optional argument to OrderedMapPeek. +type OrderedMapPeekAttr func(optionalAttr) -// StatelessRandomUniformDtype sets the optional dtype attribute to value. +// OrderedMapPeekCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { +// REQUIRES: value >= 0 +func OrderedMapPeekCapacity(value int64) OrderedMapPeekAttr { return func(m optionalAttr) { - m["dtype"] = value + m["capacity"] = value } } -// Outputs deterministic pseudorandom random values from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. -// -// The outputs are a deterministic function of `shape` and `seed`. +// OrderedMapPeekMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// REQUIRES: value >= 0 +func OrderedMapPeekMemoryLimit(value int64) OrderedMapPeekAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// OrderedMapPeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapPeekContainer(value string) OrderedMapPeekAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapPeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapPeekSharedName(value string) OrderedMapPeekAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op peeks at the values at the specified key. If the // -// Returns Random values with specified shape. -func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { +// underlying container does not contain this key +// this op will block until it does. This Op is optimized for +// performance. +func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapPeekAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessRandomUniform", + Type: "OrderedMapPeek", Input: []tf.Input{ - shape, seed, + key, indices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("OrderedMapPeek", err) + return + } + return values } -// Makes its input available to the next iteration. +// Inverse fast Fourier transform. +// +// Computes the inverse 1-dimensional discrete Fourier transform over the +// inner-most dimension of `input`. // // Arguments: -// data: The tensor to be made available to the next iteration. +// input: A complex64 tensor. // -// Returns The same tensor as `data`. -func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { +// Returns A complex64 tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its inverse 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifft +// @end_compatibility +func IFFT(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "NextIteration", + Type: "IFFT", Input: []tf.Input{ - data, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Output a fact about factorials. -func Fact(scope *Scope) (fact tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Fact", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AngleAttr is an optional argument to Angle. -type AngleAttr func(optionalAttr) +// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. +type ResourceSparseApplyRMSPropAttr func(optionalAttr) -// AngleTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func AngleTout(value tf.DataType) AngleAttr { +// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, ms, and mom tensors is protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { return func(m optionalAttr) { - m["Tout"] = value + m["use_locking"] = value } } -// Returns the argument of a complex number. +// Update '*var' according to the RMSProp algorithm. // -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the argument of each element in `input`. All elements in -// `input` must be complex numbers of the form \\(a + bj\\), where *a* -// is the real part and *b* is the imaginary part. +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. // -// The argument returned by this operation is of the form \\(atan2(b, a)\\). +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) // -// For example: +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom // -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.angle(input) ==> [2.0132, 1.056] -// ``` +// Arguments: +// var_: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. // -// @compatibility(numpy) -// Equivalent to np.angle. -// @end_compatibility -func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Output) { +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var, ms and mom. +// +// Returns the created operation. +func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -10988,137 +11249,179 @@ func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Outp a(attrs) } opspec := tf.OpSpec{ - Type: "Angle", + Type: "ResourceSparseApplyRMSProp", Input: []tf.Input{ - input, + var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// Returns the truth value of (x > y) element-wise. +// +// *NOTE*: `Greater` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Greater", + Input: []tf.Input{ + x, y, + }, + } op := scope.AddOperation(opspec) return op.Output(0) } -// VarHandleOpAttr is an optional argument to VarHandleOp. -type VarHandleOpAttr func(optionalAttr) +// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox. +type SampleDistortedBoundingBoxAttr func(optionalAttr) -// VarHandleOpContainer sets the optional container attribute to value. +// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value. // -// value: the container this variable is placed in. -// If not specified, defaults to "" -func VarHandleOpContainer(value string) VarHandleOpAttr { +// value: If either `seed` or `seed2` are set to non-zero, the random number +// generator is seeded by the given `seed`. Otherwise, it is seeded by a random +// seed. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { - m["container"] = value + m["seed"] = value } } -// VarHandleOpSharedName sets the optional shared_name attribute to value. +// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value. // -// value: the name by which this variable is referred to. -// If not specified, defaults to "" -func VarHandleOpSharedName(value string) VarHandleOpAttr { +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["seed2"] = value } } -// Creates a handle to a Variable resource. +// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value. // -// Arguments: -// dtype: the type of this variable. Must agree with the dtypes -// of all ops using this variable. -// shape: The (possibly partially specified) shape of this variable. -func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...VarHandleOpAttr) (resource tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape} - for _, a := range optional { - a(attrs) +// value: The cropped area of the image must contain at least this +// fraction of any bounding box supplied. The value of this parameter should be +// non-negative. In the case of 0, the cropped area does not need to overlap +// any of the bounding boxes supplied. +// If not specified, defaults to 0.1 +func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["min_object_covered"] = value } - opspec := tf.OpSpec{ - Type: "VarHandleOp", +} - Attrs: attrs, +// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value. +// +// value: The cropped area of the image must have an aspect ratio = +// width / height within this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["aspect_ratio_range"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Elementwise computes the bitwise XOR of `x` and `y`. +// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value. // -// The result will have those bits set, that are different in `x` and `y`. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return +// value: The cropped area of the image must contain a fraction of the +// supplied image within in this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["area_range"] = value } - opspec := tf.OpSpec{ - Type: "BitwiseXor", - Input: []tf.Input{ - x, y, - }, +} + +// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value. +// +// value: Number of attempts at generating a cropped region of the image +// of the specified constraints. After `max_attempts` failures, return the entire +// image. +// If not specified, defaults to 100 +func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["max_attempts"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Deserialize `SparseTensor` objects. +// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. // -// The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where -// the last dimension stores serialized `SparseTensor` objects and the other N -// dimensions (N >= 0) correspond to a batch. The ranks of the original -// `SparseTensor` objects must all match. When the final `SparseTensor` is -// created, its rank is the rank of the incoming `SparseTensor` objects plus N; -// the sparse tensors have been concatenated along new dimensions, one for each -// batch. +// value: Controls behavior if no bounding boxes supplied. +// If true, assume an implicit bounding box covering the whole input. If false, +// raise an error. +// If not specified, defaults to false +func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["use_image_if_no_bounding_boxes"] = value + } +} + +// Generate a single randomly distorted bounding box for an image. // -// The output `SparseTensor` object's shape values for the original dimensions -// are the max across the input `SparseTensor` objects' shape values for the -// corresponding dimensions. The new dimensions match the size of the batch. +// Bounding box annotations are often supplied in addition to ground-truth labels +// in image recognition or object localization tasks. A common technique for +// training such a system is to randomly distort an image while preserving +// its content, i.e. *data augmentation*. This Op outputs a randomly distorted +// localization of an object, i.e. bounding box, given an `image_size`, +// `bounding_boxes` and a series of constraints. // -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. +// The output of this Op is a single bounding box that may be used to crop the +// original image. The output is returned as 3 tensors: `begin`, `size` and +// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the +// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize +// what the bounding box looks like. // -// For example, if the serialized input is a `[2 x 3]` matrix representing two -// original `SparseTensor` objects: +// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. // -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] +// For example, // -// and +// ```python +// # Generate a single distorted bounding box. +// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( +// tf.shape(image), +// bounding_boxes=bounding_boxes) // -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] +// # Draw the bounding box in an image summary. +// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), +// bbox_for_draw) +// tf.summary.image('images_with_box', image_with_box) // -// then the final deserialized `SparseTensor` will be: +// # Employ the bounding box to distort the image. +// distorted_image = tf.slice(image, begin, size) +// ``` // -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] +// Note that if no bounding box information is available, setting +// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit +// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is +// false and no bounding boxes are supplied, an error is raised. // // Arguments: -// serialized_sparse: The serialized `SparseTensor` objects. The last dimension -// must have 3 columns. -// dtype: The `dtype` of the serialized `SparseTensor` objects. -func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { +// image_size: 1-D, containing `[height, width, channels]`. +// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes +// associated with the image. +// +// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to +// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to +// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. +// Provide as input to `tf.image.draw_bounding_boxes`. +func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DeserializeSparse", + Type: "SampleDistortedBoundingBox", Input: []tf.Input{ - serialized_sparse, + image_size, bounding_boxes, }, Attrs: attrs, } @@ -11126,46 +11429,66 @@ func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataT return op.Output(0), op.Output(1), op.Output(2) } -// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. -type ResourceApplyRMSPropAttr func(optionalAttr) +// LRNAttr is an optional argument to LRN. +type LRNAttr func(optionalAttr) -// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value. +// LRNDepthRadius sets the optional depth_radius attribute to value. // -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { +// value: 0-D. Half-width of the 1-D normalization window. +// If not specified, defaults to 5 +func LRNDepthRadius(value int64) LRNAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["depth_radius"] = value } } -// Update '*var' according to the RMSProp algorithm. +// LRNBias sets the optional bias attribute to value. // -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms -// and mom will not update in iterations during which the grad is zero. +// value: An offset (usually positive to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNBias(value float32) LRNAttr { + return func(m optionalAttr) { + m["bias"] = value + } +} + +// LRNAlpha sets the optional alpha attribute to value. // -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNAlpha(value float32) LRNAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// LRNBeta sets the optional beta attribute to value. // -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNBeta(value float32) LRNAttr { + return func(m optionalAttr) { + m["beta"] = value + } +} + +// Local Response Normalization. // -// Arguments: -// var_: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. +// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last +// dimension), and each vector is normalized independently. Within a given vector, +// each component is divided by the weighted, squared sum of inputs within +// `depth_radius`. In detail, // -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. +// sqr_sum[a, b, c, d] = +// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) +// output = input / (bias + alpha * sqr_sum) ** beta // -// Returns the created operation. -func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) { +// For details, see [Krizhevsky et al., ImageNet classification with deep +// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). +// +// Arguments: +// input: 4-D. +func LRN(scope *Scope, input tf.Output, optional ...LRNAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -11174,17 +11497,248 @@ func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Out a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyRMSProp", + Type: "LRN", Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, + input, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate. -type ResourceScatterNdUpdateAttr func(optionalAttr) +// Creates a dataset that zips together `input_datasets`. +func ZipDataset(scope *Scope, input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ZipDataset", + Input: []tf.Input{ + tf.OutputList(input_datasets), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyAdagradAttr is an optional argument to ResourceSparseApplyAdagrad. +type ResourceSparseApplyAdagradAttr func(optionalAttr) + +// ResourceSparseApplyAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceSparseApplyAdagradUpdateSlots sets the optional update_slots attribute to value. +// If not specified, defaults to true +func ResourceSparseApplyAdagradUpdateSlots(value bool) ResourceSparseApplyAdagradAttr { + return func(m optionalAttr) { + m["update_slots"] = value + } +} + +// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. +// +// That is for rows we have grad for, we update var and accum as follows: +// accum += grad * grad +// var -= lr * grad * (1 / sqrt(accum)) +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdagradAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyAdagrad", + Input: []tf.Input{ + var_, accum, lr, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. +type StatelessRandomUniformAttr func(optionalAttr) + +// StatelessRandomUniformDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs deterministic pseudorandom random values from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// +// The outputs are a deterministic function of `shape` and `seed`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// +// Returns Random values with specified shape. +func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatelessRandomUniform", + Input: []tf.Input{ + shape, seed, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Makes its input available to the next iteration. +// +// Arguments: +// data: The tensor to be made available to the next iteration. +// +// Returns The same tensor as `data`. +func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NextIteration", + Input: []tf.Input{ + data, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Output a fact about factorials. +func Fact(scope *Scope) (fact tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Fact", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Elementwise computes the bitwise XOR of `x` and `y`. +// +// The result will have those bits set, that are different in `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BitwiseXor", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deserialize `SparseTensor` objects. +// +// The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where +// the last dimension stores serialized `SparseTensor` objects and the other N +// dimensions (N >= 0) correspond to a batch. The ranks of the original +// `SparseTensor` objects must all match. When the final `SparseTensor` is +// created, its rank is the rank of the incoming `SparseTensor` objects plus N; +// the sparse tensors have been concatenated along new dimensions, one for each +// batch. +// +// The output `SparseTensor` object's shape values for the original dimensions +// are the max across the input `SparseTensor` objects' shape values for the +// corresponding dimensions. The new dimensions match the size of the batch. +// +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order. If this is not the case, after this +// step run `SparseReorder` to restore index ordering. +// +// For example, if the serialized input is a `[2 x 3]` matrix representing two +// original `SparseTensor` objects: +// +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] +// +// and +// +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// +// then the final deserialized `SparseTensor` will be: +// +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] +// +// Arguments: +// serialized_sparse: The serialized `SparseTensor` objects. The last dimension +// must have 3 columns. +// dtype: The `dtype` of the serialized `SparseTensor` objects. +func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "DeserializeSparse", + Input: []tf.Input{ + serialized_sparse, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate. +type ResourceScatterNdUpdateAttr func(optionalAttr) // ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value. // @@ -11482,69 +12036,15 @@ func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output return op.Output(0) } -// ResizeAreaAttr is an optional argument to ResizeArea. -type ResizeAreaAttr func(optionalAttr) - -// ResizeAreaAlignCorners sets the optional align_corners attribute to value. +// 2D real-valued fast Fourier transform. // -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeAreaAlignCorners(value bool) ResizeAreaAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Resize `images` to `size` using area interpolation. +// Computes the 2-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 2 dimensions of `input`. // -// Input images can be of different types but output images are always float. -// -// The range of pixel values for the output image might be slightly different -// from the range for the input image because of limited numerical precision. -// To guarantee an output range, for example `[0.0, 1.0]`, apply -// `tf.clip_by_value` to the output. -// -// Each output pixel is computed by first transforming the pixel's footprint into -// the input tensor and then averaging the pixels that intersect the footprint. An -// input pixel's contribution to the average is weighted by the fraction of its -// area that intersects the footprint. This is the same as OpenCV's INTER_AREA. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeArea", - Input: []tf.Input{ - images, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 2D real-valued fast Fourier transform. -// -// Computes the 2-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 2 dimensions of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. // // Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the // corresponding dimension of `input`, the dimension is cropped. If it is larger, @@ -11711,23 +12211,6 @@ func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow return op.Output(0) } -// Mutually reduces multiple tensors of identical type and shape. -func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets} - opspec := tf.OpSpec{ - Type: "CollectiveReduce", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // This op consumes a lock created by `MutexLock`. // // This op exists to consume a tensor created by `MutexLock` (other than @@ -11829,81 +12312,6 @@ func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and return tensors } -// Creates a dataset that skips `count` elements from the `input_dataset`. -// -// Arguments: -// -// count: A scalar representing the number of elements from the `input_dataset` -// that should be skipped. If count is -1, skips everything. -// -// -func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "SkipDataset", - Input: []tf.Input{ - input_dataset, count, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the maximum along segments of a tensor. -// -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. -// -// Computes a tensor such that -// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such -// that `segment_ids[j] == i`. -// -// If the max is empty for a given segment ID `i`, `output[i] = 0`. -// -//
-// -//
-// -// Arguments: -// -// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SegmentMax", - Input: []tf.Input{ - data, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes hyperbolic tangent of `x` element-wise. -func Tanh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Tanh", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Receives a tensor value broadcast from another device. func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { if scope.Err() != nil { @@ -12361,342 +12769,100 @@ func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []i return op.Output(0), op.Output(1) } -// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. -type ResourceSparseApplyAdagradDAAttr func(optionalAttr) - -// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { - return func(m optionalAttr) { - m["use_locking"] = value +// Returns the truth value of NOT x element-wise. +func LogicalNot(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LogicalNot", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. +// 3D real-valued fast Fourier transform. +// +// Computes the 3-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 3 dimensions of `input`. +// +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. +// +// Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: -// var_: Should be from a Variable(). -// gradient_accumulator: Should be from a Variable(). -// gradient_squared_accumulator: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Learning rate. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// global_step: Training step number. Must be a scalar. +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. // -// Returns the created operation. -func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { +// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the their 3D Fourier transform. The +// inner-most dimension contains `fft_length / 2 + 1` unique frequency +// components. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfftn with 3 dimensions. +// @end_compatibility +func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagradDA", + Type: "RFFT3D", Input: []tf.Input{ - var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, + input, fft_length, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// EncodeJpegAttr is an optional argument to EncodeJpeg. -type EncodeJpegAttr func(optionalAttr) - -// EncodeJpegFormat sets the optional format attribute to value. -// -// value: Per pixel image format. -// If not specified, defaults to "" -func EncodeJpegFormat(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["format"] = value - } -} +// TensorArrayV3Attr is an optional argument to TensorArrayV3. +type TensorArrayV3Attr func(optionalAttr) -// EncodeJpegQuality sets the optional quality attribute to value. +// TensorArrayV3ElementShape sets the optional element_shape attribute to value. // -// value: Quality of the compression from 0 to 100 (higher is better and slower). -// If not specified, defaults to 95 -func EncodeJpegQuality(value int64) EncodeJpegAttr { +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr { return func(m optionalAttr) { - m["quality"] = value + m["element_shape"] = value } } -// EncodeJpegProgressive sets the optional progressive attribute to value. +// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value. // -// value: If True, create a JPEG that loads progressively (coarse to fine). +// value: A boolean that determines whether writes to the TensorArray +// are allowed to grow the size. By default, this is not allowed. // If not specified, defaults to false -func EncodeJpegProgressive(value bool) EncodeJpegAttr { +func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr { return func(m optionalAttr) { - m["progressive"] = value + m["dynamic_size"] = value } } -// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. +// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value. // -// value: If True, spend CPU/RAM to reduce size with no quality change. -// If not specified, defaults to false -func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { +// value: If true (default), Tensors in the TensorArray are cleared +// after being read. This disables multiple read semantics but allows early +// release of memory. +// If not specified, defaults to true +func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { return func(m optionalAttr) { - m["optimize_size"] = value + m["clear_after_read"] = value } } -// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. -// -// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. -// If not specified, defaults to true -func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["chroma_downsampling"] = value - } -} - -// EncodeJpegDensityUnit sets the optional density_unit attribute to value. -// -// value: Unit used to specify `x_density` and `y_density`: -// pixels per inch (`'in'`) or centimeter (`'cm'`). -// If not specified, defaults to "in" -func EncodeJpegDensityUnit(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["density_unit"] = value - } -} - -// EncodeJpegXDensity sets the optional x_density attribute to value. -// -// value: Horizontal pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegXDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["x_density"] = value - } -} - -// EncodeJpegYDensity sets the optional y_density attribute to value. -// -// value: Vertical pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegYDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["y_density"] = value - } -} - -// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. -// -// value: If not empty, embed this XMP metadata in the image header. -// If not specified, defaults to "" -func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["xmp_metadata"] = value - } -} - -// JPEG-encode an image. -// -// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. -// -// The attr `format` can be used to override the color format of the encoded -// output. Values can be: -// -// * `''`: Use a default format based on the number of channels in the image. -// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension -// of `image` must be 1. -// * `rgb`: Output an RGB JPEG image. The `channels` dimension -// of `image` must be 3. -// -// If `format` is not specified or is the empty string, a default format is picked -// in function of the number of channels in `image`: -// -// * 1: Output a grayscale image. -// * 3: Output an RGB image. -// -// Arguments: -// image: 3-D with shape `[height, width, channels]`. -// -// Returns 0-D. JPEG-encoded image. -func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EncodeJpeg", - Input: []tf.Input{ - image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MultinomialAttr is an optional argument to Multinomial. -type MultinomialAttr func(optionalAttr) - -// MultinomialSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 is set to be non-zero, the internal random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func MultinomialSeed(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// MultinomialSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func MultinomialSeed2(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// MultinomialOutputDtype sets the optional output_dtype attribute to value. -// If not specified, defaults to DT_INT64 -func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { - return func(m optionalAttr) { - m["output_dtype"] = value - } -} - -// Draws samples from a multinomial distribution. -// -// Arguments: -// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` -// represents the unnormalized log probabilities for all classes. -// num_samples: 0-D. Number of independent samples to draw for each row slice. -// -// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` -// contains the drawn class labels with range `[0, num_classes)`. -func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Multinomial", - Input: []tf.Input{ - logits, num_samples, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the truth value of NOT x element-wise. -func LogicalNot(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LogicalNot", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 3D real-valued fast Fourier transform. -// -// Computes the 3-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 3 dimensions of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. -// -// Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 -// dimensions of `input` are replaced with the their 3D Fourier transform. The -// inner-most dimension contains `fft_length / 2 + 1` unique frequency -// components. -// -// @compatibility(numpy) -// Equivalent to np.fft.rfftn with 3 dimensions. -// @end_compatibility -func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RFFT3D", - Input: []tf.Input{ - input, fft_length, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TensorArrayV3Attr is an optional argument to TensorArrayV3. -type TensorArrayV3Attr func(optionalAttr) - -// TensorArrayV3ElementShape sets the optional element_shape attribute to value. -// -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value. -// -// value: A boolean that determines whether writes to the TensorArray -// are allowed to grow the size. By default, this is not allowed. -// If not specified, defaults to false -func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr { - return func(m optionalAttr) { - m["dynamic_size"] = value - } -} - -// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value. -// -// value: If true (default), Tensors in the TensorArray are cleared -// after being read. This disables multiple read semantics but allows early -// release of memory. -// If not specified, defaults to true -func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { - return func(m optionalAttr) { - m["clear_after_read"] = value - } -} - -// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. +// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. // // value: If true (default is false), then all // elements in the TensorArray will be expected to have have identical shapes. @@ -12965,123 +13131,7 @@ func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_ba return op.Output(0) } -// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. -type ResourceApplyProximalAdagradAttr func(optionalAttr) - -// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. -// -// accum += grad * grad -// prox_v = var - lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyProximalAdagrad", - Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. -type MutableHashTableOfTensorsV2Attr func(optionalAttr) - -// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// If not specified, defaults to false -func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. -// If not specified, defaults to <> -func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["value_shape"] = value - } -} - -// Creates an empty hash table. -// -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a vector. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. -// -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. -// -// Returns Handle to a table. -func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MutableHashTableOfTensorsV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Subtracts sparse updates from the variable referenced by `resource`. +// Subtracts sparse updates from the variable referenced by `resource`. // // This operation computes // @@ -13122,62 +13172,6 @@ func ResourceScatterSub(scope *Scope, resource tf.Output, indices tf.Output, upd return scope.AddOperation(opspec) } -// Inverse 2D fast Fourier transform. -// -// Computes the inverse 2-dimensional discrete Fourier transform over the -// inner-most 2 dimensions of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their inverse 2D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.ifft2 -// @end_compatibility -func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IFFT2D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 2D fast Fourier transform. -// -// Computes the 2-dimensional discrete Fourier transform over the inner-most -// 2 dimensions of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.fft2 -// @end_compatibility -func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT2D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. type ResourceApplyProximalGradientDescentAttr func(optionalAttr) @@ -13846,44 +13840,37 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms return scope.AddOperation(opspec) } -// RealAttr is an optional argument to Real. -type RealAttr func(optionalAttr) - -// RealTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func RealTout(value tf.DataType) RealAttr { - return func(m optionalAttr) { - m["Tout"] = value +// Computes the gradient for the inverse of `x` wrt its input. +// +// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +// is the corresponding input gradient. +func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReciprocalGrad", + Input: []tf.Input{ + y, dy, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns the real part of a complex number. -// -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the real part of each element in `input`. All elements in -// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real -// part returned by this operation and *b* is the imaginary part. -// -// For example: +// Returns the min of x and y (i.e. x < y ? x : y) element-wise. // -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.real(input) ==> [-2.25, 3.25] -// ``` -func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) { +// *NOTE*: `Minimum` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Real", + Type: "Minimum", Input: []tf.Input{ - input, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -15296,31 +15283,6 @@ func BoostedTreesEnsembleResourceHandleOp(scope *Scope, optional ...BoostedTrees return op.Output(0) } -// Concatenates tensors along one dimension. -// -// Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Concat", - Input: []tf.Input{ - concat_dim, tf.OutputList(values), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. type ResourceApplyMomentumAttr func(optionalAttr) @@ -16239,23 +16201,136 @@ func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.D return op.Output(0) } -// Returns element-wise remainder of division. This emulates C semantics in that +// 2D fast Fourier transform. // -// the result here is consistent with a truncating divide. E.g. `truncate(x / y) * -// y + truncate_mod(x, y) = x`. +// Computes the 2-dimensional discrete Fourier transform over the inner-most +// 2 dimensions of `input`. // -// *NOTE*: `TruncateMod` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TruncateMod", - Input: []tf.Input{ - x, y, - }, - } +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft2 +// @end_compatibility +func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT2D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Inverse 2D fast Fourier transform. +// +// Computes the inverse 2-dimensional discrete Fourier transform over the +// inner-most 2 dimensions of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their inverse 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifft2 +// @end_compatibility +func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IFFT2D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. +type ResourceApplyRMSPropAttr func(optionalAttr) + +// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, ms, and mom tensors is protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the RMSProp algorithm. +// +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom +// +// Arguments: +// var_: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. +// +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyRMSProp", + Input: []tf.Input{ + var_, ms, mom, lr, rho, momentum, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Returns element-wise remainder of division. This emulates C semantics in that +// +// the result here is consistent with a truncating divide. E.g. `truncate(x / y) * +// y + truncate_mod(x, y) = x`. +// +// *NOTE*: `TruncateMod` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TruncateMod", + Input: []tf.Input{ + x, y, + }, + } op := scope.AddOperation(opspec) return op.Output(0) } @@ -17512,69 +17587,6 @@ func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.D return op.Output(0), op.Output(1), op.Output(2) } -// StringJoinAttr is an optional argument to StringJoin. -type StringJoinAttr func(optionalAttr) - -// StringJoinSeparator sets the optional separator attribute to value. -// -// value: string, an optional join separator. -// If not specified, defaults to "" -func StringJoinSeparator(value string) StringJoinAttr { - return func(m optionalAttr) { - m["separator"] = value - } -} - -// Joins the strings in the given list of string tensors into one tensor; -// -// with the given separator (default is an empty separator). -// -// Arguments: -// inputs: A list of string tensors. The tensors must all have the same shape, -// or be scalars. Scalars may be mixed in; these will be broadcast to the shape -// of non-scalar inputs. -func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StringJoin", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns immutable tensor from memory region. -// -// The current implementation memmaps the tensor from a file. -// -// Arguments: -// dtype: Type of the returned tensor. -// shape: Shape of the returned tensor. -// memory_region_name: Name of readonly memory region used by the tensor, see -// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. -func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name} - opspec := tf.OpSpec{ - Type: "ImmutableConst", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Inverse real-valued fast Fourier transform. // // Computes the inverse 1-dimensional discrete Fourier transform of a real-valued @@ -17755,75 +17767,185 @@ func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes [ return op.Output(0), op.Output(1), op.Output(2) } -// Concatenates quantized tensors along one dimension. +// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. +type ResourceApplyProximalAdagradAttr func(optionalAttr) + +// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. +// +// accum += grad * grad +// prox_v = var - lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} // // Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// input_mins: The minimum scalar values for each of the input tensors. -// input_maxes: The maximum scalar values for each of the input tensors. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. // -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. -func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// Returns the created operation. +func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "QuantizedConcat", + Type: "ResourceApplyProximalAdagrad", Input: []tf.Input{ - concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), + var_, accum, lr, l1, l2, grad, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// Slice a `SparseTensor` based on the `start` and `size`. +// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. +type MutableHashTableOfTensorsV2Attr func(optionalAttr) + +// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. // -// For example, if the input is +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. // -// input_tensor = shape = [2, 7] -// [ a d e ] -// [b c ] +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// If not specified, defaults to false +func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. +// If not specified, defaults to <> +func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["value_shape"] = value + } +} + +// Creates an empty hash table. // -// Graphically the output tensors are: +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a vector. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. // -// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] -// [ a ] -// [b c ] +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. // -// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] -// [ d e ] -// [ ] +// Returns Handle to a table. +func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableOfTensorsV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the gradient of the sigmoid of `x` wrt its input. +// +// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and +// `dy` is the corresponding input gradient. +func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SigmoidGrad", + Input: []tf.Input{ + y, dy, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Convert one or more images from HSV to RGB. +// +// Outputs a tensor of the same shape as the `images` tensor, containing the RGB +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. +// +// See `rgb_to_hsv` for a description of the HSV encoding. // // Arguments: -// indices: 2-D tensor represents the indices of the sparse tensor. -// values: 1-D tensor represents the values of the sparse tensor. -// shape: 1-D. tensor represents the shape of the sparse tensor. -// start: 1-D. tensor represents the start of the slice. -// size: 1-D. tensor represents the size of the slice. -// output indices: A list of 1-D tensors represents the indices of the output -// sparse tensors. +// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. // -// Returns A list of 1-D tensors represents the values of the output sparse -// tensors.A list of 1-D tensors represents the shape of the output sparse -// tensors. -func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { +// Returns `images` converted to RGB. +func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSlice", + Type: "HSVToRGB", Input: []tf.Input{ - indices, values, shape, start, size, + images, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) +} + +// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics. +// +// Arguments: +// tree_ensemble_handle: Handle to the tree ensemble. +// +// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest +// layer. +func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BoostedTreesGetEnsembleStates", + Input: []tf.Input{ + tree_ensemble_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } // Returns the element-wise min of two SparseTensors. @@ -17956,89 +18078,6 @@ func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype return op.Output(0), op.Output(1), op.Output(2) } -// MaxPoolAttr is an optional argument to MaxPool. -type MaxPoolAttr func(optionalAttr) - -// MaxPoolDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolDataFormat(value string) MaxPoolAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Performs max pooling on the input. -// -// Arguments: -// input: 4-D input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns The max pooled output tensor. -func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPool", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Says whether the targets are in the top `K` predictions. -// -// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the -// prediction for the target class is among the top `k` predictions among -// all predictions for example `i`. Note that the behavior of `InTopK` differs -// from the `TopK` op in its handling of ties; if multiple classes have the -// same prediction value and straddle the top-`k` boundary, all of those -// classes are considered to be in the top `k`. -// -// More formally, let -// -// \\(predictions_i\\) be the predictions for all classes for example `i`, -// \\(targets_i\\) be the target class for example `i`, -// \\(out_i\\) be the output for example `i`, -// -// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ -// -// Arguments: -// predictions: A `batch_size` x `classes` tensor. -// targets: A `batch_size` vector of class ids. -// k: Number of top elements to look at for computing precision. -// -// Returns Computed precision at `k` as a `bool Tensor`. -func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InTopKV2", - Input: []tf.Input{ - predictions, targets, k, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Assigns a new value to a variable. // // Any ReadVariableOp with a control dependency on this op is guaranteed to return @@ -18620,107 +18659,44 @@ func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feat return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights } -// SparseMatMulAttr is an optional argument to SparseMatMul. -type SparseMatMulAttr func(optionalAttr) +// ShapeAttr is an optional argument to Shape. +type ShapeAttr func(optionalAttr) -// SparseMatMulTransposeA sets the optional transpose_a attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeA(value bool) SparseMatMulAttr { +// ShapeOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func ShapeOutType(value tf.DataType) ShapeAttr { return func(m optionalAttr) { - m["transpose_a"] = value + m["out_type"] = value } } -// SparseMatMulTransposeB sets the optional transpose_b attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeB(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value +// Returns the shape of a tensor. +// +// This operation returns a 1-D integer tensor representing the shape of `input`. +// +// For example: +// +// ``` +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// shape(t) ==> [2, 2, 3] +// ``` +func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) { + if scope.Err() != nil { + return } -} - -// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["a_is_sparse"] = value + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Shape", + Input: []tf.Input{ + input, + }, + Attrs: attrs, } -} - -// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["b_is_sparse"] = value - } -} - -// Multiply matrix "a" by matrix "b". -// -// The inputs must be two-dimensional matrices and the inner dimension of "a" must -// match the outer dimension of "b". This op is optimized for the case where at -// least one of "a" or "b" is sparse. The breakeven for using this versus a dense -// matrix multiply on one platform was 30% zero values in the sparse matrix. -// -// The gradient computation of this operation will only take advantage of sparsity -// in the input gradient when that gradient comes from a Relu. -func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseMatMul", - Input: []tf.Input{ - a, b, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ShapeAttr is an optional argument to Shape. -type ShapeAttr func(optionalAttr) - -// ShapeOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func ShapeOutType(value tf.DataType) ShapeAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Returns the shape of a tensor. -// -// This operation returns a 1-D integer tensor representing the shape of `input`. -// -// For example: -// -// ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// shape(t) ==> [2, 2, 3] -// ``` -func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Shape", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) + op := scope.AddOperation(opspec) + return op.Output(0) } // Computes the power of one value to another. @@ -19183,88 +19159,58 @@ func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Forwards the input to the output. -// -// This operator represents the loop termination condition used by the -// "pivot" switches of a loop. -// -// Arguments: -// input: A boolean scalar, representing the branch predicate of the Switch op. -// -// Returns The same tensor as `input`. -func LoopCond(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LoopCond", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// RandomGammaAttr is an optional argument to RandomGamma. +type RandomGammaAttr func(optionalAttr) -// Computes the gradient for the inverse of `x` wrt its input. +// RandomGammaSeed sets the optional seed attribute to value. // -// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` -// is the corresponding input gradient. -func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReciprocalGrad", - Input: []tf.Input{ - y, dy, - }, +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomGammaSeed(value int64) RandomGammaAttr { + return func(m optionalAttr) { + m["seed"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Returns the min of x and y (i.e. x < y ? x : y) element-wise. +// RandomGammaSeed2 sets the optional seed2 attribute to value. // -// *NOTE*: `Minimum` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Minimum", - Input: []tf.Input{ - x, y, - }, +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomGammaSeed2(value int64) RandomGammaAttr { + return func(m optionalAttr) { + m["seed2"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Returns the element-wise sum of a list of tensors. -// -// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not -// wait for all of its inputs to be ready before beginning to sum. This can -// save memory if inputs are ready at different times, since minimum temporary -// storage is proportional to the output size rather than the inputs size. -// -// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. +// Outputs random values from the Gamma distribution(s) described by alpha. // -// Returns a `Tensor` of same shape and type as the elements of `inputs`. +// This op uses the algorithm by Marsaglia et al. to acquire samples via +// transformation-rejection from pairs of uniform and normal random variables. +// See http://dl.acm.org/citation.cfm?id=358414 // // Arguments: -// inputs: A list of `Tensor` objects, each with same shape and type. -// shape: Shape of elements of `inputs`. -func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { +// shape: 1-D integer tensor. Shape of independent samples to draw from each +// distribution described by the shape parameters given in alpha. +// alpha: A tensor in which each scalar is a "shape" parameter describing the +// associated gamma distribution. +// +// Returns A tensor with shape `shape + shape(alpha)`. Each slice +// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for +// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. +func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "AccumulateNV2", + Type: "RandomGamma", Input: []tf.Input{ - tf.OutputList(inputs), + shape, alpha, }, Attrs: attrs, } @@ -19320,60 +19266,24 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp return op.Output(0), op.Output(1), op.Output(2) } -// RandomGammaAttr is an optional argument to RandomGamma. -type RandomGammaAttr func(optionalAttr) - -// RandomGammaSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomGammaSeed(value int64) RandomGammaAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomGammaSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomGammaSeed2(value int64) RandomGammaAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random values from the Gamma distribution(s) described by alpha. +// Forwards the input to the output. // -// This op uses the algorithm by Marsaglia et al. to acquire samples via -// transformation-rejection from pairs of uniform and normal random variables. -// See http://dl.acm.org/citation.cfm?id=358414 +// This operator represents the loop termination condition used by the +// "pivot" switches of a loop. // // Arguments: -// shape: 1-D integer tensor. Shape of independent samples to draw from each -// distribution described by the shape parameters given in alpha. -// alpha: A tensor in which each scalar is a "shape" parameter describing the -// associated gamma distribution. +// input: A boolean scalar, representing the branch predicate of the Switch op. // -// Returns A tensor with shape `shape + shape(alpha)`. Each slice -// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for -// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. -func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) { +// Returns The same tensor as `input`. +func LoopCond(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "RandomGamma", + Type: "LoopCond", Input: []tf.Input{ - shape, alpha, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -19476,374 +19386,323 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf return op.Output(0) } -// RandomShuffleAttr is an optional argument to RandomShuffle. -type RandomShuffleAttr func(optionalAttr) - -// RandomShuffleSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomShuffleSeed(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomShuffleSeed2 sets the optional seed2 attribute to value. +// Computes gradients for SparseSegmentSqrtN. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomShuffleSeed2(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Randomly shuffles a tensor along its first dimension. -// -// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped -// to one and only one `output[i]`. For example, a mapping that might occur for a -// 3x2 tensor is: -// -// ``` -// [[1, 2], [[5, 6], -// [3, 4], ==> [1, 2], -// [5, 6]] [3, 4]] -// ``` +// Returns tensor "output" with same shape as grad, except for dimension 0 whose +// value is output_dim0. // // Arguments: -// value: The tensor to be shuffled. -// -// Returns A tensor of same shape and type as `value`, shuffled along its first -// dimension. -func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { +// grad: gradient propagated to the SparseSegmentSqrtN op. +// indices: indices passed to the corresponding SparseSegmentSqrtN op. +// segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op. +// output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op. +func SparseSegmentSqrtNGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "RandomShuffle", + Type: "SparseSegmentSqrtNGrad", Input: []tf.Input{ - value, + grad, indices, segment_ids, output_dim0, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// OrderedMapIncompleteSizeAttr is an optional argument to OrderedMapIncompleteSize. -type OrderedMapIncompleteSizeAttr func(optionalAttr) +// LRNGradAttr is an optional argument to LRNGrad. +type LRNGradAttr func(optionalAttr) -// OrderedMapIncompleteSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// LRNGradDepthRadius sets the optional depth_radius attribute to value. // -// REQUIRES: value >= 0 -func OrderedMapIncompleteSizeCapacity(value int64) OrderedMapIncompleteSizeAttr { +// value: A depth radius. +// If not specified, defaults to 5 +func LRNGradDepthRadius(value int64) LRNGradAttr { return func(m optionalAttr) { - m["capacity"] = value + m["depth_radius"] = value } } -// OrderedMapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// LRNGradBias sets the optional bias attribute to value. // -// REQUIRES: value >= 0 -func OrderedMapIncompleteSizeMemoryLimit(value int64) OrderedMapIncompleteSizeAttr { +// value: An offset (usually > 0 to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNGradBias(value float32) LRNGradAttr { return func(m optionalAttr) { - m["memory_limit"] = value + m["bias"] = value } } -// OrderedMapIncompleteSizeContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapIncompleteSizeContainer(value string) OrderedMapIncompleteSizeAttr { +// LRNGradAlpha sets the optional alpha attribute to value. +// +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNGradAlpha(value float32) LRNGradAttr { return func(m optionalAttr) { - m["container"] = value + m["alpha"] = value } } -// OrderedMapIncompleteSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapIncompleteSizeSharedName(value string) OrderedMapIncompleteSizeAttr { +// LRNGradBeta sets the optional beta attribute to value. +// +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNGradBeta(value float32) LRNGradAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["beta"] = value } } -// Op returns the number of incomplete elements in the underlying container. -func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapIncompleteSizeAttr) (size tf.Output) { +// Gradients for Local Response Normalization. +// +// Arguments: +// input_grads: 4-D with shape `[batch, height, width, channels]`. +// input_image: 4-D with shape `[batch, height, width, channels]`. +// output_image: 4-D with shape `[batch, height, width, channels]`. +// +// Returns The gradients for LRN. +func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapIncompleteSize", - + Type: "LRNGrad", + Input: []tf.Input{ + input_grads, input_image, output_image, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Counts the number of occurrences of each value in an integer array. +// AnyAttr is an optional argument to Any. +type AnyAttr func(optionalAttr) + +// AnyKeepDims sets the optional keep_dims attribute to value. // -// Outputs a vector with length `size` and the same dtype as `weights`. If -// `weights` are empty, then index `i` stores the number of times the value `i` is -// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of -// the value in `weights` at each index where the corresponding value in `arr` is -// `i`. +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func AnyKeepDims(value bool) AnyAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the "logical or" of elements across dimensions of a tensor. // -// Values in `arr` outside of the range [0, size) are ignored. +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// arr: int32 `Tensor`. -// size: non-negative int32 scalar `Tensor`. -// weights: is an int32, int64, float32, or float64 `Tensor` with the same -// shape as `arr`, or a length-0 `Tensor`, in which case it acts as all weights -// equal to 1. +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns 1D `Tensor` with length equal to `size`. The counts or summed weights for -// each value in the range [0, size). -func Bincount(scope *Scope, arr tf.Output, size tf.Output, weights tf.Output) (bins tf.Output) { +// Returns The reduced tensor. +func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Bincount", + Type: "Any", Input: []tf.Input{ - arr, size, weights, + input, axis, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// CumsumAttr is an optional argument to Cumsum. -type CumsumAttr func(optionalAttr) - -// CumsumExclusive sets the optional exclusive attribute to value. -// -// value: If `True`, perform exclusive cumsum. -// If not specified, defaults to false -func CumsumExclusive(value bool) CumsumAttr { - return func(m optionalAttr) { - m["exclusive"] = value - } -} - -// CumsumReverse sets the optional reverse attribute to value. -// -// value: A `bool` (default: False). -// If not specified, defaults to false -func CumsumReverse(value bool) CumsumAttr { - return func(m optionalAttr) { - m["reverse"] = value - } -} - -// Compute the cumulative sum of the tensor `x` along `axis`. -// -// By default, this op performs an inclusive cumsum, which means that the first -// element of the input is identical to the first element of the output: -// -// ```python -// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c] -// ``` -// -// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is -// performed instead: +// Creates a sequence of numbers. // -// ```python -// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] -// ``` +// This operation creates a sequence of numbers that begins at `start` and +// extends by increments of `delta` up to but not including `limit`. // -// By setting the `reverse` kwarg to `True`, the cumsum is performed in the -// opposite direction: +// For example: // -// ```python -// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] // ``` -// -// This is more efficient than using separate `tf.reverse` ops. -// -// The `reverse` and `exclusive` kwargs can also be combined: -// -// ```python -// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] +// # 'start' is 3 +// # 'limit' is 18 +// # 'delta' is 3 +// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] // ``` // // Arguments: -// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, -// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, -// `complex128`, `qint8`, `quint8`, `qint32`, `half`. -// axis: A `Tensor` of type `int32` (default: 0). Must be in the range -// `[-rank(x), rank(x))`. -func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) { +// start: 0-D (scalar). First entry in the sequence. +// limit: 0-D (scalar). Upper limit of sequence, exclusive. +// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. +// +// Returns 1-D. +func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Cumsum", + Type: "Range", Input: []tf.Input{ - x, axis, + start, limit, delta, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// CumprodAttr is an optional argument to Cumprod. -type CumprodAttr func(optionalAttr) - -// CumprodExclusive sets the optional exclusive attribute to value. -// -// value: If `True`, perform exclusive cumprod. -// If not specified, defaults to false -func CumprodExclusive(value bool) CumprodAttr { - return func(m optionalAttr) { - m["exclusive"] = value - } -} +// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. +type DestroyResourceOpAttr func(optionalAttr) -// CumprodReverse sets the optional reverse attribute to value. +// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. // -// value: A `bool` (default: False). -// If not specified, defaults to false -func CumprodReverse(value bool) CumprodAttr { +// value: whether to ignore the error when the resource +// doesn't exist. +// If not specified, defaults to true +func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { return func(m optionalAttr) { - m["reverse"] = value + m["ignore_lookup_error"] = value } } -// Compute the cumulative product of the tensor `x` along `axis`. -// -// By default, this op performs an inclusive cumprod, which means that the first -// element of the input is identical to the first element of the output: +// Deletes the resource specified by the handle. // -// ```python -// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] -// ``` +// All subsequent operations using the resource will result in a NotFound +// error status. // -// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is -// performed instead: +// Arguments: +// resource: handle to the resource to delete. // -// ```python -// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] -// ``` -// -// By setting the `reverse` kwarg to `True`, the cumprod is performed in the -// opposite direction: -// -// ```python -// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] -// ``` +// Returns the created operation. +func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DestroyResourceOp", + Input: []tf.Input{ + resource, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Generates values in an interval. // -// This is more efficient than using separate `tf.reverse` ops. +// A sequence of `num` evenly-spaced values are generated beginning at `start`. +// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, +// so that the last one is exactly `stop`. // -// The `reverse` and `exclusive` kwargs can also be combined: +// For example: // -// ```python -// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] +// ``` +// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] // ``` // // Arguments: -// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, -// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, -// `complex128`, `qint8`, `quint8`, `qint32`, `half`. -// axis: A `Tensor` of type `int32` (default: 0). Must be in the range -// `[-rank(x), rank(x))`. -func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { +// start: First entry in the range. +// stop: Last entry in the range. +// num: Number of values to generate. +// +// Returns 1-D. The generated values. +func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Cumprod", + Type: "LinSpace", Input: []tf.Input{ - x, axis, + start, stop, num, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizedMatMulAttr is an optional argument to QuantizedMatMul. -type QuantizedMatMulAttr func(optionalAttr) +// ComplexAttr is an optional argument to Complex. +type ComplexAttr func(optionalAttr) -// QuantizedMatMulToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr { +// ComplexTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_COMPLEX64 +func ComplexTout(value tf.DataType) ComplexAttr { return func(m optionalAttr) { - m["Toutput"] = value + m["Tout"] = value } } -// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value. +// Converts two real numbers to a complex number. // -// value: If true, `a` is transposed before multiplication. -// If not specified, defaults to false -func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value. +// Given a tensor `real` representing the real part of a complex number, and a +// tensor `imag` representing the imaginary part of a complex number, this +// operation returns complex numbers elementwise of the form \\(a + bj\\), where +// *a* represents the `real` part and *b* represents the `imag` part. // -// value: If true, `b` is transposed before multiplication. -// If not specified, defaults to false -func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value +// The input tensors `real` and `imag` must have the same shape. +// +// For example: +// +// ``` +// # tensor 'real' is [2.25, 3.25] +// # tensor `imag` is [4.75, 5.75] +// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] +// ``` +func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Complex", + Input: []tf.Input{ + real, imag, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// QuantizedMatMulTactivation sets the optional Tactivation attribute to value. -// -// value: The type of output produced by activation function -// following this operation. -// If not specified, defaults to DT_QUINT8 -func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr { +// ImagAttr is an optional argument to Imag. +type ImagAttr func(optionalAttr) + +// ImagTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func ImagTout(value tf.DataType) ImagAttr { return func(m optionalAttr) { - m["Tactivation"] = value + m["Tout"] = value } } -// Perform a quantized matrix multiplication of `a` by the matrix `b`. +// Returns the imaginary part of a complex number. // -// The inputs must be two-dimensional matrices and the inner dimension of -// `a` (after being transposed if `transpose_a` is non-zero) must match the -// outer dimension of `b` (after being transposed if `transposed_b` is -// non-zero). +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the imaginary part of each element in `input`. All +// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a* +// is the real part and *b* is the imaginary part returned by this operation. // -// Arguments: -// a: Must be a two-dimensional tensor. -// b: Must be a two-dimensional tensor. -// min_a: The float value that the lowest quantized `a` value represents. -// max_a: The float value that the highest quantized `a` value represents. -// min_b: The float value that the lowest quantized `b` value represents. -// max_b: The float value that the highest quantized `b` value represents. +// For example: // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.imag(input) ==> [4.75, 5.75] +// ``` +func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -19852,81 +19711,84 @@ func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, ma a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedMatMul", + Type: "Imag", Input: []tf.Input{ - a, b, min_a, max_a, min_b, max_b, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Does nothing. Serves as a control trigger for scheduling. +// Computes the maximum along segments of a tensor. // -// Only useful as a placeholder for control edges. +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. // -// Returns the created operation. -func ControlTrigger(scope *Scope) (o *tf.Operation) { +// Computes a tensor such that +// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such +// that `segment_ids[j] == i`. +// +// If the max is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +//
+// +// Arguments: +// +// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ControlTrigger", + Type: "SegmentMax", + Input: []tf.Input{ + data, segment_ids, + }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Batch normalization. -// -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() -// -// This op is deprecated. Prefer `tf.nn.batch_normalization`. -// -// Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// beta: A 1D beta Tensor with size matching the last dimension of t. -// An offset to be added to the normalized tensor. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this tensor will be multiplied -// with the normalized tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. -func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) { +// Computes hyperbolic tangent of `x` element-wise. +func Tanh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalization", + Type: "Tanh", Input: []tf.Input{ - t, m, v, beta, gamma, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Deprecated. Use TensorArrayReadV3 +// Creates a dataset that skips `count` elements from the `input_dataset`. // -// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 -func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { +// Arguments: +// +// count: A scalar representing the number of elements from the `input_dataset` +// that should be skipped. If count is -1, skips everything. +// +// +func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorArrayReadV2", + Type: "SkipDataset", Input: []tf.Input{ - handle, index, flow_in, + input_dataset, count, }, Attrs: attrs, } @@ -19934,32 +19796,31 @@ func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in return op.Output(0) } -// QuantizedMulAttr is an optional argument to QuantizedMul. -type QuantizedMulAttr func(optionalAttr) +// RealAttr is an optional argument to Real. +type RealAttr func(optionalAttr) -// QuantizedMulToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr { +// RealTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func RealTout(value tf.DataType) RealAttr { return func(m optionalAttr) { - m["Toutput"] = value + m["Tout"] = value } } -// Returns x * y element-wise, working on quantized buffers. +// Returns the real part of a complex number. // -// Arguments: -// -// -// min_x: The float value that the lowest quantized `x` value represents. -// max_x: The float value that the highest quantized `x` value represents. -// min_y: The float value that the lowest quantized `y` value represents. -// max_y: The float value that the highest quantized `y` value represents. +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the real part of each element in `input`. All elements in +// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real +// part returned by this operation and *b* is the imaginary part. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// For example: // -// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about -// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.real(input) ==> [-2.25, 3.25] +// ``` +func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -19968,42 +19829,52 @@ func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedMul", + Type: "Real", Input: []tf.Input{ - x, y, min_x, max_x, min_y, max_y, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// QuantizedAddAttr is an optional argument to QuantizedAdd. -type QuantizedAddAttr func(optionalAttr) +// ResizeAreaAttr is an optional argument to ResizeArea. +type ResizeAreaAttr func(optionalAttr) -// QuantizedAddToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr { +// ResizeAreaAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeAreaAlignCorners(value bool) ResizeAreaAttr { return func(m optionalAttr) { - m["Toutput"] = value + m["align_corners"] = value } } -// Returns x + y element-wise, working on quantized buffers. +// Resize `images` to `size` using area interpolation. // -// Arguments: +// Input images can be of different types but output images are always float. // +// The range of pixel values for the output image might be slightly different +// from the range for the input image because of limited numerical precision. +// To guarantee an output range, for example `[0.0, 1.0]`, apply +// `tf.clip_by_value` to the output. // -// min_x: The float value that the lowest quantized `x` value represents. -// max_x: The float value that the highest quantized `x` value represents. -// min_y: The float value that the lowest quantized `y` value represents. -// max_y: The float value that the highest quantized `y` value represents. +// Each output pixel is computed by first transforming the pixel's footprint into +// the input tensor and then averaging the pixels that intersect the footprint. An +// input pixel's contribution to the average is weighted by the fraction of its +// area that intersects the footprint. This is the same as OpenCV's INTER_AREA. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. // -// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about -// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) { if scope.Err() != nil { return } @@ -20012,75 +19883,93 @@ func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedAdd", + Type: "ResizeArea", Input: []tf.Input{ - x, y, min_x, max_x, min_y, max_y, + images, size, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// MfccAttr is an optional argument to Mfcc. -type MfccAttr func(optionalAttr) +// VarHandleOpAttr is an optional argument to VarHandleOp. +type VarHandleOpAttr func(optionalAttr) -// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. +// VarHandleOpContainer sets the optional container attribute to value. // -// value: The highest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 4000 -func MfccUpperFrequencyLimit(value float32) MfccAttr { +// value: the container this variable is placed in. +// If not specified, defaults to "" +func VarHandleOpContainer(value string) VarHandleOpAttr { return func(m optionalAttr) { - m["upper_frequency_limit"] = value + m["container"] = value } } -// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. +// VarHandleOpSharedName sets the optional shared_name attribute to value. // -// value: The lowest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 20 -func MfccLowerFrequencyLimit(value float32) MfccAttr { +// value: the name by which this variable is referred to. +// If not specified, defaults to "" +func VarHandleOpSharedName(value string) VarHandleOpAttr { return func(m optionalAttr) { - m["lower_frequency_limit"] = value + m["shared_name"] = value } } -// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. +// Creates a handle to a Variable resource. // -// value: Resolution of the Mel bank used internally. -// If not specified, defaults to 40 -func MfccFilterbankChannelCount(value int64) MfccAttr { - return func(m optionalAttr) { - m["filterbank_channel_count"] = value +// Arguments: +// dtype: the type of this variable. Must agree with the dtypes +// of all ops using this variable. +// shape: The (possibly partially specified) shape of this variable. +func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...VarHandleOpAttr) (resource tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "VarHandleOp", + + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. -// -// value: How many output channels to produce per time slice. -// If not specified, defaults to 13 -func MfccDctCoefficientCount(value int64) MfccAttr { +// AngleAttr is an optional argument to Angle. +type AngleAttr func(optionalAttr) + +// AngleTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func AngleTout(value tf.DataType) AngleAttr { return func(m optionalAttr) { - m["dct_coefficient_count"] = value + m["Tout"] = value } } -// Transforms a spectrogram into a form that's useful for speech recognition. +// Returns the argument of a complex number. // -// Mel Frequency Cepstral Coefficients are a way of representing audio data that's -// been effective as an input feature for machine learning. They are created by -// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the -// higher frequencies that are less significant to the human ear. They have a long -// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum -// is a good resource to learn more. +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the argument of each element in `input`. All elements in +// `input` must be complex numbers of the form \\(a + bj\\), where *a* +// is the real part and *b* is the imaginary part. // -// Arguments: -// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared -// set to true. -// sample_rate: How many samples per second the source audio used. -func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { +// The argument returned by this operation is of the form \\(atan2(b, a)\\). +// +// For example: +// +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.angle(input) ==> [2.0132, 1.056] +// ``` +// +// @compatibility(numpy) +// Equivalent to np.angle. +// @end_compatibility +func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -20089,9 +19978,9 @@ func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional . a(attrs) } opspec := tf.OpSpec{ - Type: "Mfcc", + Type: "Angle", Input: []tf.Input{ - spectrogram, sample_rate, + input, }, Attrs: attrs, } @@ -20099,326 +19988,349 @@ func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional . return op.Output(0) } -// Given a quantized tensor described by (input, input_min, input_max), outputs a +// Clips tensor values to a specified min and max. // -// range that covers the actual values present in that tensor. This op is -// typically used to produce the requested_output_min and requested_output_max for -// Requantize. +// Given a tensor `t`, this operation returns a tensor of the same type and +// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`. +// Any values less than `clip_value_min` are set to `clip_value_min`. Any values +// greater than `clip_value_max` are set to `clip_value_max`. // // Arguments: +// t: A `Tensor`. +// clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape +// as `t`. The minimum value to clip by. +// clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape +// as `t`. The maximum value to clip by. // -// input_min: The float value that the minimum quantized input value represents. -// input_max: The float value that the maximum quantized input value represents. -// -// Returns The computed min output.the computed max output. -func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) { +// Returns A clipped `Tensor` with the same shape as input 't'. +func ClipByValue(scope *Scope, t tf.Output, clip_value_min tf.Output, clip_value_max tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RequantizationRange", + Type: "ClipByValue", Input: []tf.Input{ - input, input_min, input_max, + t, clip_value_min, clip_value_max, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Rolls the elements of a tensor along an axis. -// -// The elements are shifted positively (towards larger indices) by the offset of -// `shift` along the dimension of `axis`. Negative `shift` values will shift -// elements in the opposite direction. Elements that roll passed the last position -// will wrap around to the first and vice versa. Multiple shifts along multiple -// axes may be specified. -// -// For example: -// -// ``` -// # 't' is [0, 1, 2, 3, 4] -// roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2] +// Counts the number of occurrences of each value in an integer array. // -// # shifting along multiple dimensions -// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] -// roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]] +// Outputs a vector with length `size` and the same dtype as `weights`. If +// `weights` are empty, then index `i` stores the number of times the value `i` is +// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of +// the value in `weights` at each index where the corresponding value in `arr` is +// `i`. // -// # shifting along the same axis multiple times -// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] -// roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] -// ``` +// Values in `arr` outside of the range [0, size) are ignored. // // Arguments: +// arr: int32 `Tensor`. +// size: non-negative int32 scalar `Tensor`. +// weights: is an int32, int64, float32, or float64 `Tensor` with the same +// shape as `arr`, or a length-0 `Tensor`, in which case it acts as all weights +// equal to 1. // -// shift: Dimension must be 0-D or 1-D. `shift[i]` specifies the number of places by which -// elements are shifted positively (towards larger indices) along the dimension -// specified by `axis[i]`. Negative shifts will roll the elements in the opposite -// direction. -// axis: Dimension must be 0-D or 1-D. `axis[i]` specifies the dimension that the shift -// `shift[i]` should occur. If the same axis is referenced more than once, the -// total shift for that axis will be the sum of all the shifts that belong to that -// axis. -// -// Returns Has the same shape and size as the input. The elements are shifted -// positively (towards larger indices) by the offsets of `shift` along the -// dimensions of `axis`. -func Roll(scope *Scope, input tf.Output, shift tf.Output, axis tf.Output) (output tf.Output) { +// Returns 1D `Tensor` with length equal to `size`. The counts or summed weights for +// each value in the range [0, size). +func Bincount(scope *Scope, arr tf.Output, size tf.Output, weights tf.Output) (bins tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Roll", + Type: "Bincount", Input: []tf.Input{ - input, shift, axis, + arr, size, weights, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MapPeekAttr is an optional argument to MapPeek. -type MapPeekAttr func(optionalAttr) +// CumsumAttr is an optional argument to Cumsum. +type CumsumAttr func(optionalAttr) -// MapPeekCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// CumsumExclusive sets the optional exclusive attribute to value. // -// REQUIRES: value >= 0 -func MapPeekCapacity(value int64) MapPeekAttr { +// value: If `True`, perform exclusive cumsum. +// If not specified, defaults to false +func CumsumExclusive(value bool) CumsumAttr { return func(m optionalAttr) { - m["capacity"] = value + m["exclusive"] = value } } -// MapPeekMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// CumsumReverse sets the optional reverse attribute to value. // -// REQUIRES: value >= 0 -func MapPeekMemoryLimit(value int64) MapPeekAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapPeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapPeekContainer(value string) MapPeekAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapPeekSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapPeekSharedName(value string) MapPeekAttr { +// value: A `bool` (default: False). +// If not specified, defaults to false +func CumsumReverse(value bool) CumsumAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["reverse"] = value } } -// Op peeks at the values at the specified key. If the +// Compute the cumulative sum of the tensor `x` along `axis`. // -// underlying container does not contain this key -// this op will block until it does. -func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { +// By default, this op performs an inclusive cumsum, which means that the first +// element of the input is identical to the first element of the output: +// +// ```python +// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c] +// ``` +// +// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is +// performed instead: +// +// ```python +// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] +// ``` +// +// By setting the `reverse` kwarg to `True`, the cumsum is performed in the +// opposite direction: +// +// ```python +// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] +// ``` +// +// This is more efficient than using separate `tf.reverse` ops. +// +// The `reverse` and `exclusive` kwargs can also be combined: +// +// ```python +// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] +// ``` +// +// Arguments: +// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, +// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, +// `complex128`, `qint8`, `quint8`, `qint32`, `half`. +// axis: A `Tensor` of type `int32` (default: 0). Must be in the range +// `[-rank(x), rank(x))`. +func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapPeek", + Type: "Cumsum", Input: []tf.Input{ - key, indices, + x, axis, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return + return op.Output(0) +} + +// CumprodAttr is an optional argument to Cumprod. +type CumprodAttr func(optionalAttr) + +// CumprodExclusive sets the optional exclusive attribute to value. +// +// value: If `True`, perform exclusive cumprod. +// If not specified, defaults to false +func CumprodExclusive(value bool) CumprodAttr { + return func(m optionalAttr) { + m["exclusive"] = value } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapPeek", err) - return +} + +// CumprodReverse sets the optional reverse attribute to value. +// +// value: A `bool` (default: False). +// If not specified, defaults to false +func CumprodReverse(value bool) CumprodAttr { + return func(m optionalAttr) { + m["reverse"] = value } - return values } -// Looks up keys in a table, outputs the corresponding values. +// Compute the cumulative product of the tensor `x` along `axis`. // -// The tensor `keys` must of the same type as the keys of the table. -// The output `values` is of the type of the table values. +// By default, this op performs an inclusive cumprod, which means that the first +// element of the input is identical to the first element of the output: // -// The scalar `default_value` is the value output for keys not present in the -// table. It must also be of the same type as the table values. +// ```python +// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] +// ``` // -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. +// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is +// performed instead: // +// ```python +// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] +// ``` // -// Returns Same shape as `keys`. Values found in the table, or `default_values` -// for missing keys. -func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { +// By setting the `reverse` kwarg to `True`, the cumprod is performed in the +// opposite direction: +// +// ```python +// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] +// ``` +// +// This is more efficient than using separate `tf.reverse` ops. +// +// The `reverse` and `exclusive` kwargs can also be combined: +// +// ```python +// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] +// ``` +// +// Arguments: +// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, +// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, +// `complex128`, `qint8`, `quint8`, `qint32`, `half`. +// axis: A `Tensor` of type `int32` (default: 0). Must be in the range +// `[-rank(x), rank(x))`. +func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LookupTableFindV2", + Type: "Cumprod", Input: []tf.Input{ - table_handle, keys, default_value, + x, axis, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Bucketizes 'input' based on 'boundaries'. +// QuantizedMatMulAttr is an optional argument to QuantizedMatMul. +type QuantizedMatMulAttr func(optionalAttr) + +// QuantizedMatMulToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["Toutput"] = value + } +} + +// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value. // -// For example, if the inputs are -// boundaries = [0, 10, 100] -// input = [[-5, 10000] -// [150, 10] -// [5, 100]] +// value: If true, `a` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} + +// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value. // -// then the output will be -// output = [[0, 3] -// [3, 2] -// [1, 3]] +// value: If true, `b` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["transpose_b"] = value + } +} + +// QuantizedMatMulTactivation sets the optional Tactivation attribute to value. // -// Arguments: -// input: Any shape of Tensor contains with int or float type. -// boundaries: A sorted list of floats gives the boundary of the buckets. +// value: The type of output produced by activation function +// following this operation. +// If not specified, defaults to DT_QUINT8 +func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["Tactivation"] = value + } +} + +// Perform a quantized matrix multiplication of `a` by the matrix `b`. // -// Returns Same shape with 'input', each value of input replaced with bucket index. +// The inputs must be two-dimensional matrices and the inner dimension of +// `a` (after being transposed if `transpose_a` is non-zero) must match the +// outer dimension of `b` (after being transposed if `transposed_b` is +// non-zero). // -// @compatibility(numpy) -// Equivalent to np.digitize. -// @end_compatibility -func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) { +// Arguments: +// a: Must be a two-dimensional tensor. +// b: Must be a two-dimensional tensor. +// min_a: The float value that the lowest quantized `a` value represents. +// max_a: The float value that the highest quantized `a` value represents. +// min_b: The float value that the lowest quantized `b` value represents. +// max_b: The float value that the highest quantized `b` value represents. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"boundaries": boundaries} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Bucketize", + Type: "QuantizedMatMul", Input: []tf.Input{ - input, + a, b, min_a, max_a, min_b, max_b, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Calculates gains for each feature and returns the best possible split information for the feature. -// -// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature. +// Does nothing. Serves as a control trigger for scheduling. // -// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split. +// Only useful as a placeholder for control edges. // -// In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features). -// -// The length of output lists are all of the same length, `num_features`. -// The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature. -// -// Arguments: -// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive). -// stats_summary_list: A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used. -// l1: l1 regularization factor on leaf weights, per instance based. -// l2: l2 regularization factor on leaf weights, per instance based. -// tree_complexity: adjustment to the gain, per leaf based. -// min_node_weight: mininum avg of hessians in a node before required for the node to be considered for splitting. -// max_splits: the number of nodes that can be split in the whole tree. Used as a dimension of output tensors. -// -// Returns An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node. -func BoostedTreesCalculateBestGainsPerFeature(scope *Scope, node_id_range tf.Output, stats_summary_list []tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, max_splits int64) (node_ids_list []tf.Output, gains_list []tf.Output, thresholds_list []tf.Output, left_node_contribs_list []tf.Output, right_node_contribs_list []tf.Output) { +// Returns the created operation. +func ControlTrigger(scope *Scope) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"max_splits": max_splits} opspec := tf.OpSpec{ - Type: "BoostedTreesCalculateBestGainsPerFeature", - Input: []tf.Input{ - node_id_range, tf.OutputList(stats_summary_list), l1, l2, tree_complexity, min_node_weight, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if node_ids_list, idx, err = makeOutputList(op, idx, "node_ids_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if gains_list, idx, err = makeOutputList(op, idx, "gains_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if thresholds_list, idx, err = makeOutputList(op, idx, "thresholds_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if left_node_contribs_list, idx, err = makeOutputList(op, idx, "left_node_contribs_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if right_node_contribs_list, idx, err = makeOutputList(op, idx, "right_node_contribs_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - return node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list -} - -// EncodePngAttr is an optional argument to EncodePng. -type EncodePngAttr func(optionalAttr) - -// EncodePngCompression sets the optional compression attribute to value. -// -// value: Compression level. -// If not specified, defaults to -1 -func EncodePngCompression(value int64) EncodePngAttr { - return func(m optionalAttr) { - m["compression"] = value + Type: "ControlTrigger", } + return scope.AddOperation(opspec) } -// PNG-encode an image. -// -// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` -// where `channels` is: +// Batch normalization. // -// * 1: for grayscale. -// * 2: for grayscale + alpha. -// * 3: for RGB. -// * 4: for RGBA. +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() // -// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder -// default or a value from 0 to 9. 9 is the highest compression level, generating -// the smallest output, but is slower. +// This op is deprecated. Prefer `tf.nn.batch_normalization`. // // Arguments: -// image: 3-D with shape `[height, width, channels]`. -// -// Returns 0-D. PNG-encoded image. -func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// beta: A 1D beta Tensor with size matching the last dimension of t. +// An offset to be added to the normalized tensor. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this tensor will be multiplied +// with the normalized tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. +func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "EncodePng", + Type: "BatchNormWithGlobalNormalization", Input: []tf.Input{ - image, + t, m, v, beta, gamma, }, Attrs: attrs, } @@ -20426,90 +20338,95 @@ func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (conten return op.Output(0) } -// Updates the table to associates keys with values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. +// Deprecated. Use TensorArrayReadV3 // -// Returns the created operation. -func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { +// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 +func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "LookupTableInsertV2", + Type: "TensorArrayReadV2", Input: []tf.Input{ - table_handle, keys, values, + handle, index, flow_in, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns element-wise smallest integer in not less than x. -func Ceil(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Ceil", - Input: []tf.Input{ - x, - }, +// QuantizedMulAttr is an optional argument to QuantizedMul. +type QuantizedMulAttr func(optionalAttr) + +// QuantizedMulToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr { + return func(m optionalAttr) { + m["Toutput"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes the number of elements in the given table. +// Returns x * y element-wise, working on quantized buffers. // // Arguments: -// table_handle: Handle to the table. // -// Returns Scalar that contains number of elements in the table. -func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { +// +// min_x: The float value that the lowest quantized `x` value represents. +// max_x: The float value that the highest quantized `x` value represents. +// min_y: The float value that the lowest quantized `y` value represents. +// max_y: The float value that the highest quantized `y` value represents. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// +// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about +// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LookupTableSizeV2", + Type: "QuantizedMul", Input: []tf.Input{ - table_handle, + x, y, min_x, max_x, min_y, max_y, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. -type ResizeBilinearGradAttr func(optionalAttr) +// QuantizedAddAttr is an optional argument to QuantizedAdd. +type QuantizedAddAttr func(optionalAttr) -// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and grad tensors are -// aligned. Defaults to false. -// If not specified, defaults to false -func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { +// QuantizedAddToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["Toutput"] = value } } -// Computes the gradient of bilinear interpolation. +// Returns x + y element-wise, working on quantized buffers. // // Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, -// The image tensor that was resized. // -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. -// Gradients with respect to the input image. Input image must have been -// float or double. -func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { +// +// min_x: The float value that the lowest quantized `x` value represents. +// max_x: The float value that the highest quantized `x` value represents. +// min_y: The float value that the lowest quantized `y` value represents. +// max_y: The float value that the highest quantized `y` value represents. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// +// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about +// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { if scope.Err() != nil { return } @@ -20518,108 +20435,214 @@ func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeBilinearGrad", + Type: "QuantizedAdd", Input: []tf.Input{ - grads, original_image, + x, y, min_x, max_x, min_y, max_y, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Outputs all keys and values in the table. -// -// Arguments: -// table_handle: Handle to the table. -// -// +// MfccAttr is an optional argument to Mfcc. +type MfccAttr func(optionalAttr) + +// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. // -// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. -func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} - opspec := tf.OpSpec{ - Type: "LookupTableExportV2", - Input: []tf.Input{ - table_handle, - }, - Attrs: attrs, +// value: The highest frequency to use when calculating the +// ceptstrum. +// If not specified, defaults to 4000 +func MfccUpperFrequencyLimit(value float32) MfccAttr { + return func(m optionalAttr) { + m["upper_frequency_limit"] = value } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) } -// Replaces the contents of the table with the specified keys and values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. +// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. // -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. +// value: The lowest frequency to use when calculating the +// ceptstrum. +// If not specified, defaults to 20 +func MfccLowerFrequencyLimit(value float32) MfccAttr { + return func(m optionalAttr) { + m["lower_frequency_limit"] = value + } +} + +// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. // -// Returns the created operation. -func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { +// value: Resolution of the Mel bank used internally. +// If not specified, defaults to 40 +func MfccFilterbankChannelCount(value int64) MfccAttr { + return func(m optionalAttr) { + m["filterbank_channel_count"] = value + } +} + +// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. +// +// value: How many output channels to produce per time slice. +// If not specified, defaults to 13 +func MfccDctCoefficientCount(value int64) MfccAttr { + return func(m optionalAttr) { + m["dct_coefficient_count"] = value + } +} + +// Transforms a spectrogram into a form that's useful for speech recognition. +// +// Mel Frequency Cepstral Coefficients are a way of representing audio data that's +// been effective as an input feature for machine learning. They are created by +// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the +// higher frequencies that are less significant to the human ear. They have a long +// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum +// is a good resource to learn more. +// +// Arguments: +// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared +// set to true. +// sample_rate: How many samples per second the source audio used. +func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LookupTableImportV2", + Type: "Mfcc", Input: []tf.Input{ - table_handle, keys, values, + spectrogram, sample_rate, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapUnstageNoKeyAttr is an optional argument to MapUnstageNoKey. -type MapUnstageNoKeyAttr func(optionalAttr) +// Given a quantized tensor described by (input, input_min, input_max), outputs a +// +// range that covers the actual values present in that tensor. This op is +// typically used to produce the requested_output_min and requested_output_max for +// Requantize. +// +// Arguments: +// +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. +// +// Returns The computed min output.the computed max output. +func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RequantizationRange", + Input: []tf.Input{ + input, input_min, input_max, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} -// MapUnstageNoKeyCapacity sets the optional capacity attribute to value. +// Rolls the elements of a tensor along an axis. +// +// The elements are shifted positively (towards larger indices) by the offset of +// `shift` along the dimension of `axis`. Negative `shift` values will shift +// elements in the opposite direction. Elements that roll passed the last position +// will wrap around to the first and vice versa. Multiple shifts along multiple +// axes may be specified. +// +// For example: +// +// ``` +// # 't' is [0, 1, 2, 3, 4] +// roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2] +// +// # shifting along multiple dimensions +// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +// roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]] +// +// # shifting along the same axis multiple times +// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +// roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] +// ``` +// +// Arguments: +// +// shift: Dimension must be 0-D or 1-D. `shift[i]` specifies the number of places by which +// elements are shifted positively (towards larger indices) along the dimension +// specified by `axis[i]`. Negative shifts will roll the elements in the opposite +// direction. +// axis: Dimension must be 0-D or 1-D. `axis[i]` specifies the dimension that the shift +// `shift[i]` should occur. If the same axis is referenced more than once, the +// total shift for that axis will be the sum of all the shifts that belong to that +// axis. +// +// Returns Has the same shape and size as the input. The elements are shifted +// positively (towards larger indices) by the offsets of `shift` along the +// dimensions of `axis`. +func Roll(scope *Scope, input tf.Output, shift tf.Output, axis tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Roll", + Input: []tf.Input{ + input, shift, axis, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MapPeekAttr is an optional argument to MapPeek. +type MapPeekAttr func(optionalAttr) + +// MapPeekCapacity sets the optional capacity attribute to value. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func MapUnstageNoKeyCapacity(value int64) MapUnstageNoKeyAttr { +func MapPeekCapacity(value int64) MapPeekAttr { return func(m optionalAttr) { m["capacity"] = value } } -// MapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. +// MapPeekMemoryLimit sets the optional memory_limit attribute to value. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func MapUnstageNoKeyMemoryLimit(value int64) MapUnstageNoKeyAttr { +func MapPeekMemoryLimit(value int64) MapPeekAttr { return func(m optionalAttr) { m["memory_limit"] = value } } -// MapUnstageNoKeyContainer sets the optional container attribute to value. +// MapPeekContainer sets the optional container attribute to value. // If not specified, defaults to "" -func MapUnstageNoKeyContainer(value string) MapUnstageNoKeyAttr { +func MapPeekContainer(value string) MapPeekAttr { return func(m optionalAttr) { m["container"] = value } } -// MapUnstageNoKeySharedName sets the optional shared_name attribute to value. +// MapPeekSharedName sets the optional shared_name attribute to value. // If not specified, defaults to "" -func MapUnstageNoKeySharedName(value string) MapUnstageNoKeyAttr { +func MapPeekSharedName(value string) MapPeekAttr { return func(m optionalAttr) { m["shared_name"] = value } } -// Op removes and returns a random (key, value) +// Op peeks at the values at the specified key. If the // -// from the underlying container. If the underlying container -// does not contain elements, the op will block until it does. -func MapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { +// underlying container does not contain this key +// this op will block until it does. +func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { if scope.Err() != nil { return } @@ -20628,9 +20651,9 @@ func MapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, opti a(attrs) } opspec := tf.OpSpec{ - Type: "MapUnstageNoKey", + Type: "MapPeek", Input: []tf.Input{ - indices, + key, indices, }, Attrs: attrs, } @@ -20640,234 +20663,174 @@ func MapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, opti } var idx int var err error - key = op.Output(idx) if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapUnstageNoKey", err) + scope.UpdateErr("MapPeek", err) return } - return key, values -} - -// HashTableV2Attr is an optional argument to HashTableV2. -type HashTableV2Attr func(optionalAttr) - -// HashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func HashTableV2Container(value string) HashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// HashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func HashTableV2SharedName(value string) HashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } + return values } -// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// Looks up keys in a table, outputs the corresponding values. // -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// Creates a non-initialized hash table. +// The tensor `keys` must of the same type as the keys of the table. +// The output `values` is of the type of the table values. // -// This op creates a hash table, specifying the type of its keys and values. -// Before using the table you will have to initialize it. After initialization the -// table will be immutable. +// The scalar `default_value` is the value output for keys not present in the +// table. It must also be of the same type as the table values. // // Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. // -// Returns Handle to a table. -func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { +// +// Returns Same shape as `keys`. Values found in the table, or `default_values` +// for missing keys. +func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "HashTableV2", - - Attrs: attrs, + Type: "LookupTableFindV2", + Input: []tf.Input{ + table_handle, keys, default_value, + }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. -type MutableHashTableV2Attr func(optionalAttr) - -// MutableHashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableV2Container(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MutableHashTableV2SharedName sets the optional shared_name attribute to value. +// Bucketizes 'input' based on 'boundaries'. // -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// For example, if the inputs are +// boundaries = [0, 10, 100] +// input = [[-5, 10000] +// [150, 10] +// [5, 100]] // -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value +// then the output will be +// output = [[0, 3] +// [3, 2] +// [1, 3]] +// +// Arguments: +// input: Any shape of Tensor contains with int or float type. +// boundaries: A sorted list of floats gives the boundary of the buckets. +// +// Returns Same shape with 'input', each value of input replaced with bucket index. +// +// @compatibility(numpy) +// Equivalent to np.digitize. +// @end_compatibility +func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"boundaries": boundaries} + opspec := tf.OpSpec{ + Type: "Bucketize", + Input: []tf.Input{ + input, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Creates an empty hash table. +// Calculates gains for each feature and returns the best possible split information for the feature. // -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a scalar. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. +// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature. +// +// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split. +// +// In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features). +// +// The length of output lists are all of the same length, `num_features`. +// The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature. // // Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. +// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive). +// stats_summary_list: A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used. +// l1: l1 regularization factor on leaf weights, per instance based. +// l2: l2 regularization factor on leaf weights, per instance based. +// tree_complexity: adjustment to the gain, per leaf based. +// min_node_weight: mininum avg of hessians in a node before required for the node to be considered for splitting. +// max_splits: the number of nodes that can be split in the whole tree. Used as a dimension of output tensors. // -// Returns Handle to a table. -func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { +// Returns An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node. +func BoostedTreesCalculateBestGainsPerFeature(scope *Scope, node_id_range tf.Output, stats_summary_list []tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, max_splits int64) (node_ids_list []tf.Output, gains_list []tf.Output, thresholds_list []tf.Output, left_node_contribs_list []tf.Output, right_node_contribs_list []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"max_splits": max_splits} opspec := tf.OpSpec{ - Type: "MutableHashTableV2", - + Type: "BoostedTreesCalculateBestGainsPerFeature", + Input: []tf.Input{ + node_id_range, tf.OutputList(stats_summary_list), l1, l2, tree_complexity, min_node_weight, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if node_ids_list, idx, err = makeOutputList(op, idx, "node_ids_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if gains_list, idx, err = makeOutputList(op, idx, "gains_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if thresholds_list, idx, err = makeOutputList(op, idx, "thresholds_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if left_node_contribs_list, idx, err = makeOutputList(op, idx, "left_node_contribs_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if right_node_contribs_list, idx, err = makeOutputList(op, idx, "right_node_contribs_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + return node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list } -// DequantizeAttr is an optional argument to Dequantize. -type DequantizeAttr func(optionalAttr) +// EncodePngAttr is an optional argument to EncodePng. +type EncodePngAttr func(optionalAttr) -// DequantizeMode sets the optional mode attribute to value. -// If not specified, defaults to "MIN_COMBINED" -func DequantizeMode(value string) DequantizeAttr { +// EncodePngCompression sets the optional compression attribute to value. +// +// value: Compression level. +// If not specified, defaults to -1 +func EncodePngCompression(value int64) EncodePngAttr { return func(m optionalAttr) { - m["mode"] = value + m["compression"] = value } } -// Dequantize the 'input' tensor into a float Tensor. -// -// [min_range, max_range] are scalar floats that specify the range for -// the 'input' data. The 'mode' attribute controls exactly which calculations are -// used to convert the float values to their quantized equivalents. -// -// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: -// -// ``` -// if T == qint8, in[i] += (range(T) + 1)/ 2.0 -// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) -// ``` -// here `range(T) = numeric_limits::max() - numeric_limits::min()` -// -// *MIN_COMBINED Mode Example* -// -// If the input comes from a QuantizedRelu6, the output type is -// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is -// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. -// Dequantize on quint8 will take each value, cast to float, and multiply -// by 6 / 255. -// Note that if quantizedtype is qint8, the operation will additionally add -// each value by 128 prior to casting. -// -// If the mode is 'MIN_FIRST', then this approach is used: -// -// ```c++ -// num_discrete_values = 1 << (# of bits in T) -// range_adjust = num_discrete_values / (num_discrete_values - 1) -// range = (range_max - range_min) * range_adjust -// range_scale = range / num_discrete_values -// const double offset_input = static_cast(input) - lowest_quantized; -// result = range_min + ((input - numeric_limits::min()) * range_scale) -// ``` -// -// *SCALED mode Example* -// -// `SCALED` mode matches the quantization approach used in -// `QuantizeAndDequantize{V2|V3}`. -// -// If the mode is `SCALED`, we do not use the full range of the output type, -// choosing to elide the lowest possible value for symmetry (e.g., output range is -// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to -// 0. -// -// We first find the range of values in our tensor. The -// range we use is always centered on 0, so we find m such that -// ```c++ -// m = max(abs(input_min), abs(input_max)) -// ``` -// -// Our input tensor range is then `[-m, m]`. -// -// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. -// If T is signed, this is -// ``` -// num_bits = sizeof(T) * 8 -// [min_fixed, max_fixed] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] -// ``` +// PNG-encode an image. // -// Otherwise, if T is unsigned, the fixed-point range is -// ``` -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] -// ``` +// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` +// where `channels` is: // -// From this we compute our scaling factor, s: -// ```c++ -// s = (2 * m) / (max_fixed - min_fixed) -// ``` +// * 1: for grayscale. +// * 2: for grayscale + alpha. +// * 3: for RGB. +// * 4: for RGBA. // -// Now we can dequantize the elements of our tensor: -// ```c++ -// result = input * s -// ``` +// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder +// default or a value from 0 to 9. 9 is the highest compression level, generating +// the smallest output, but is slower. // // Arguments: +// image: 3-D with shape `[height, width, channels]`. // -// min_range: The minimum scalar value possibly produced for the input. -// max_range: The maximum scalar value possibly produced for the input. -func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, optional ...DequantizeAttr) (output tf.Output) { +// Returns 0-D. PNG-encoded image. +func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { if scope.Err() != nil { return } @@ -20876,9 +20839,9 @@ func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf a(attrs) } opspec := tf.OpSpec{ - Type: "Dequantize", + Type: "EncodePng", Input: []tf.Input{ - input, min_range, max_range, + image, }, Attrs: attrs, } @@ -20886,16 +20849,37 @@ func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf return op.Output(0) } -// Flips all bits elementwise. +// Updates the table to associates keys with values. // -// The result will have exactly those bits set, that are not set in `x`. The -// computation is performed on the underlying representation of x. -func Invert(scope *Scope, x tf.Output) (y tf.Output) { +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Invert", + Type: "LookupTableInsertV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + +// Returns element-wise smallest integer in not less than x. +func Ceil(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Ceil", Input: []tf.Input{ x, }, @@ -20904,65 +20888,110 @@ func Invert(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Inverse 3D fast Fourier transform. -// -// Computes the inverse 3-dimensional discrete Fourier transform over the -// inner-most 3 dimensions of `input`. +// Computes the number of elements in the given table. // // Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their inverse 3D Fourier transform. +// table_handle: Handle to the table. // -// @compatibility(numpy) -// Equivalent to np.fft.ifftn with 3 dimensions. -// @end_compatibility -func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) { +// Returns Scalar that contains number of elements in the table. +func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IFFT3D", + Type: "LookupTableSizeV2", Input: []tf.Input{ - input, + table_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Deprecated. Disallowed in GraphDef version >= 2. +// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. +type ResizeBilinearGradAttr func(optionalAttr) + +// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. // -// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead -func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { +// value: If true, the centers of the 4 corner pixels of the input and grad tensors are +// aligned. Defaults to false. +// If not specified, defaults to false +func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Computes the gradient of bilinear interpolation. +// +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, +// The image tensor that was resized. +// +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. +// Gradients with respect to the input image. Input image must have been +// float or double. +func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "AdjustContrast", + Type: "ResizeBilinearGrad", Input: []tf.Input{ - images, contrast_factor, min_value, max_value, + grads, original_image, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Table initializer that takes two tensors for keys and values respectively. +// Outputs all keys and values in the table. // // Arguments: -// table_handle: Handle to a table which will be initialized. -// keys: Keys of type Tkey. -// values: Values of type Tval. +// table_handle: Handle to the table. +// +// +// +// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. +func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} + opspec := tf.OpSpec{ + Type: "LookupTableExportV2", + Input: []tf.Input{ + table_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Replaces the contents of the table with the specified keys and values. +// +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. // // Returns the created operation. -func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { +func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "InitializeTableV2", + Type: "LookupTableImportV2", Input: []tf.Input{ table_handle, keys, values, }, @@ -20970,370 +20999,443 @@ func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, val return scope.AddOperation(opspec) } -// PrintAttr is an optional argument to Print. -type PrintAttr func(optionalAttr) +// MapUnstageNoKeyAttr is an optional argument to MapUnstageNoKey. +type MapUnstageNoKeyAttr func(optionalAttr) -// PrintMessage sets the optional message attribute to value. +// MapUnstageNoKeyCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: A string, prefix of the error message. -// If not specified, defaults to "" -func PrintMessage(value string) PrintAttr { +// REQUIRES: value >= 0 +func MapUnstageNoKeyCapacity(value int64) MapUnstageNoKeyAttr { return func(m optionalAttr) { - m["message"] = value + m["capacity"] = value } } -// PrintFirstN sets the optional first_n attribute to value. +// MapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: Only log `first_n` number of times. -1 disables logging. -// If not specified, defaults to -1 -func PrintFirstN(value int64) PrintAttr { +// REQUIRES: value >= 0 +func MapUnstageNoKeyMemoryLimit(value int64) MapUnstageNoKeyAttr { return func(m optionalAttr) { - m["first_n"] = value + m["memory_limit"] = value } } -// PrintSummarize sets the optional summarize attribute to value. -// -// value: Only print this many entries of each tensor. -// If not specified, defaults to 3 -func PrintSummarize(value int64) PrintAttr { +// MapUnstageNoKeyContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapUnstageNoKeyContainer(value string) MapUnstageNoKeyAttr { return func(m optionalAttr) { - m["summarize"] = value + m["container"] = value } } -// Prints a list of tensors. -// -// Passes `input` through to `output` and prints `data` when evaluating. -// -// Arguments: -// input: The tensor passed to `output` -// data: A list of tensors to print out when op is evaluated. +// MapUnstageNoKeySharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapUnstageNoKeySharedName(value string) MapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes and returns a random (key, value) // -// Returns = The unmodified `input` tensor -func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAttr) (output tf.Output) { +// from the underlying container. If the underlying container +// does not contain elements, the op will block until it does. +func MapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Print", + Type: "MapUnstageNoKey", Input: []tf.Input{ - input, tf.OutputList(data), + indices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Outputs a `Summary` protocol buffer with a tensor and per-plugin data. -// -// Arguments: -// tag: A string attached to this summary. Used for organization in TensorBoard. -// tensor: A tensor to serialize. -// serialized_summary_metadata: A serialized SummaryMetadata proto. Contains plugin -// data. -func TensorSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, serialized_summary_metadata tf.Output) (summary tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "TensorSummaryV2", - Input: []tf.Input{ - tag, tensor, serialized_summary_metadata, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that asynchronously prefetches elements from `input_dataset`. -// -// Arguments: -// -// buffer_size: The maximum number of elements to buffer in an iterator over -// this dataset. -// -// -func PrefetchDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { + var idx int + var err error + key = op.Output(idx) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapUnstageNoKey", err) return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "PrefetchDataset", - Input: []tf.Input{ - input_dataset, buffer_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) + return key, values } -// TensorSummaryAttr is an optional argument to TensorSummary. -type TensorSummaryAttr func(optionalAttr) +// HashTableV2Attr is an optional argument to HashTableV2. +type HashTableV2Attr func(optionalAttr) -// TensorSummaryDescription sets the optional description attribute to value. +// HashTableV2Container sets the optional container attribute to value. // -// value: A json-encoded SummaryDescription proto. +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. // If not specified, defaults to "" -func TensorSummaryDescription(value string) TensorSummaryAttr { +func HashTableV2Container(value string) HashTableV2Attr { return func(m optionalAttr) { - m["description"] = value + m["container"] = value } } -// TensorSummaryLabels sets the optional labels attribute to value. +// HashTableV2SharedName sets the optional shared_name attribute to value. // -// value: An unused list of strings. -// If not specified, defaults to <> -func TensorSummaryLabels(value []string) TensorSummaryAttr { +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func HashTableV2SharedName(value string) HashTableV2Attr { return func(m optionalAttr) { - m["labels"] = value + m["shared_name"] = value } } -// TensorSummaryDisplayName sets the optional display_name attribute to value. +// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. // -// value: An unused string. -// If not specified, defaults to "" -func TensorSummaryDisplayName(value string) TensorSummaryAttr { +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { return func(m optionalAttr) { - m["display_name"] = value + m["use_node_name_sharing"] = value } } -// Outputs a `Summary` protocol buffer with a tensor. +// Creates a non-initialized hash table. // -// This op is being phased out in favor of TensorSummaryV2, which lets callers pass -// a tag as well as a serialized SummaryMetadata proto string that contains -// plugin-specific data. We will keep this op to maintain backwards compatibility. +// This op creates a hash table, specifying the type of its keys and values. +// Before using the table you will have to initialize it. After initialization the +// table will be immutable. // // Arguments: -// tensor: A tensor to serialize. -func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr) (summary tf.Output) { +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorSummary", - Input: []tf.Input{ - tensor, - }, + Type: "HashTableV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the gradient for the tanh of `x` wrt its input. +// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. +type MutableHashTableV2Attr func(optionalAttr) + +// MutableHashTableV2Container sets the optional container attribute to value. // -// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy` -// is the corresponding input gradient. -func TanhGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableV2Container(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TanhGrad", - Input: []tf.Input{ - y, dy, - }, + Type: "MutableHashTableV2", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Reduces sparse updates into the variable referenced by `resource` using the `max` operation. +// DequantizeAttr is an optional argument to Dequantize. +type DequantizeAttr func(optionalAttr) + +// DequantizeMode sets the optional mode attribute to value. +// If not specified, defaults to "MIN_COMBINED" +func DequantizeMode(value string) DequantizeAttr { + return func(m optionalAttr) { + m["mode"] = value + } +} + +// Dequantize the 'input' tensor into a float Tensor. // -// This operation computes +// [min_range, max_range] are scalar floats that specify the range for +// the 'input' data. The 'mode' attribute controls exactly which calculations are +// used to convert the float values to their quantized equivalents. // -// # Scalar indices -// ref[indices, ...] = max(ref[indices, ...], updates[...]) +// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: // -// # Vector indices (for each i) -// ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) +// ``` +// if T == qint8, in[i] += (range(T) + 1)/ 2.0 +// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) +// ``` +// here `range(T) = numeric_limits::max() - numeric_limits::min()` // -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) +// *MIN_COMBINED Mode Example* // -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions are combined. +// If the input comes from a QuantizedRelu6, the output type is +// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is +// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. +// Dequantize on quint8 will take each value, cast to float, and multiply +// by 6 / 255. +// Note that if quantizedtype is qint8, the operation will additionally add +// each value by 128 prior to casting. // -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// If the mode is 'MIN_FIRST', then this approach is used: // -//
-// -//
+// ```c++ +// num_discrete_values = 1 << (# of bits in T) +// range_adjust = num_discrete_values / (num_discrete_values - 1) +// range = (range_max - range_min) * range_adjust +// range_scale = range / num_discrete_values +// const double offset_input = static_cast(input) - lowest_quantized; +// result = range_min + ((input - numeric_limits::min()) * range_scale) +// ``` +// +// *SCALED mode Example* +// +// `SCALED` mode matches the quantization approach used in +// `QuantizeAndDequantize{V2|V3}`. +// +// If the mode is `SCALED`, we do not use the full range of the output type, +// choosing to elide the lowest possible value for symmetry (e.g., output range is +// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to +// 0. +// +// We first find the range of values in our tensor. The +// range we use is always centered on 0, so we find m such that +// ```c++ +// m = max(abs(input_min), abs(input_max)) +// ``` +// +// Our input tensor range is then `[-m, m]`. +// +// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. +// If T is signed, this is +// ``` +// num_bits = sizeof(T) * 8 +// [min_fixed, max_fixed] = +// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] +// ``` +// +// Otherwise, if T is unsigned, the fixed-point range is +// ``` +// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] +// ``` +// +// From this we compute our scaling factor, s: +// ```c++ +// s = (2 * m) / (max_fixed - min_fixed) +// ``` +// +// Now we can dequantize the elements of our tensor: +// ```c++ +// result = input * s +// ``` // // Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. // -// Returns the created operation. -func ResourceScatterMax(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +// min_range: The minimum scalar value possibly produced for the input. +// max_range: The maximum scalar value possibly produced for the input. +func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, optional ...DequantizeAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ResourceScatterMax", + Type: "Dequantize", Input: []tf.Input{ - resource, indices, updates, + input, min_range, max_range, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Outputs a `Summary` protocol buffer with scalar values. -// -// The input `tags` and `values` must have the same shape. The generated summary -// has a summary value for each tag-value pair in `tags` and `values`. -// -// Arguments: -// tags: Tags for the summary. -// values: Same shape as `tags. Values for the summary. +// Flips all bits elementwise. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { +// The result will have exactly those bits set, that are not set in `x`. The +// computation is performed on the underlying representation of x. +func Invert(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ScalarSummary", + Type: "Invert", Input: []tf.Input{ - tags, values, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Outputs a `Summary` protocol buffer with a histogram. -// -// The generated -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// has one summary value containing a histogram for `values`. +// Inverse 3D fast Fourier transform. // -// This op reports an `InvalidArgument` error if any value is not finite. +// Computes the inverse 3-dimensional discrete Fourier transform over the +// inner-most 3 dimensions of `input`. // // Arguments: -// tag: Scalar. Tag to use for the `Summary.Value`. -// values: Any shape. Values to use to build the histogram. +// input: A complex64 tensor. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func HistogramSummary(scope *Scope, tag tf.Output, values tf.Output) (summary tf.Output) { +// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 +// dimensions of `input` are replaced with their inverse 3D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifftn with 3 dimensions. +// @end_compatibility +func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "HistogramSummary", + Type: "IFFT3D", Input: []tf.Input{ - tag, values, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the number of elements in the given queue. -// -// Arguments: -// handle: The handle to a queue. +// Deprecated. Disallowed in GraphDef version >= 2. // -// Returns The number of elements in the given queue. -func QueueSizeV2(scope *Scope, handle tf.Output) (size tf.Output) { +// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead +func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "QueueSizeV2", + Type: "AdjustContrast", Input: []tf.Input{ - handle, + images, contrast_factor, min_value, max_value, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ImageSummaryAttr is an optional argument to ImageSummary. -type ImageSummaryAttr func(optionalAttr) - -// ImageSummaryMaxImages sets the optional max_images attribute to value. +// Table initializer that takes two tensors for keys and values respectively. // -// value: Max number of batch elements to generate images for. -// If not specified, defaults to 3 +// Arguments: +// table_handle: Handle to a table which will be initialized. +// keys: Keys of type Tkey. +// values: Values of type Tval. // -// REQUIRES: value >= 1 -func ImageSummaryMaxImages(value int64) ImageSummaryAttr { - return func(m optionalAttr) { - m["max_images"] = value +// Returns the created operation. +func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InitializeTableV2", + Input: []tf.Input{ + table_handle, keys, values, + }, } + return scope.AddOperation(opspec) } -// ImageSummaryBadColor sets the optional bad_color attribute to value. -// -// value: Color to use for pixels with non-finite values. -// If not specified, defaults to > int_val:255 int_val:0 int_val:0 int_val:255 > -func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { +// PrintAttr is an optional argument to Print. +type PrintAttr func(optionalAttr) + +// PrintMessage sets the optional message attribute to value. +// +// value: A string, prefix of the error message. +// If not specified, defaults to "" +func PrintMessage(value string) PrintAttr { return func(m optionalAttr) { - m["bad_color"] = value + m["message"] = value } } -// Outputs a `Summary` protocol buffer with images. -// -// The summary has up to `max_images` summary values containing images. The -// images are built from `tensor` which must be 4-D with shape `[batch_size, -// height, width, channels]` and where `channels` can be: -// -// * 1: `tensor` is interpreted as Grayscale. -// * 3: `tensor` is interpreted as RGB. -// * 4: `tensor` is interpreted as RGBA. -// -// The images have the same number of channels as the input tensor. For float -// input, the values are normalized one image at a time to fit in the range -// `[0, 255]`. `uint8` values are unchanged. The op uses two different -// normalization algorithms: -// -// * If the input values are all positive, they are rescaled so the largest one -// is 255. -// -// * If any input value is negative, the values are shifted so input value 0.0 -// is at 127. They are then rescaled so that either the smallest value is 0, -// or the largest one is 255. +// PrintFirstN sets the optional first_n attribute to value. // -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: +// value: Only log `first_n` number of times. -1 disables logging. +// If not specified, defaults to -1 +func PrintFirstN(value int64) PrintAttr { + return func(m optionalAttr) { + m["first_n"] = value + } +} + +// PrintSummarize sets the optional summarize attribute to value. // -// * If `max_images` is 1, the summary value tag is '*tag*/image'. -// * If `max_images` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. +// value: Only print this many entries of each tensor. +// If not specified, defaults to 3 +func PrintSummarize(value int64) PrintAttr { + return func(m optionalAttr) { + m["summarize"] = value + } +} + +// Prints a list of tensors. // -// The `bad_color` argument is the color to use in the generated images for -// non-finite input values. It is a `uint8` 1-D tensor of length `channels`. -// Each element must be in the range `[0, 255]` (It represents the value of a -// pixel in the output image). Non-finite values in the input tensor are -// replaced by this tensor in the output image. The default value is the color -// red. +// Passes `input` through to `output` and prints `data` when evaluating. // // Arguments: -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 4-D of shape `[batch_size, height, width, channels]` where -// `channels` is 1, 3, or 4. +// input: The tensor passed to `output` +// data: A list of tensors to print out when op is evaluated. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...ImageSummaryAttr) (summary tf.Output) { +// Returns = The unmodified `input` tensor +func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -21342,9 +21444,9 @@ func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...Ima a(attrs) } opspec := tf.OpSpec{ - Type: "ImageSummary", + Type: "Print", Input: []tf.Input{ - tag, tensor, + input, tf.OutputList(data), }, Attrs: attrs, } @@ -21352,53 +21454,44 @@ func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...Ima return op.Output(0) } -// AudioSummaryV2Attr is an optional argument to AudioSummaryV2. -type AudioSummaryV2Attr func(optionalAttr) - -// AudioSummaryV2MaxOutputs sets the optional max_outputs attribute to value. -// -// value: Max number of batch elements to generate audio for. -// If not specified, defaults to 3 +// Outputs a `Summary` protocol buffer with a tensor and per-plugin data. // -// REQUIRES: value >= 1 -func AudioSummaryV2MaxOutputs(value int64) AudioSummaryV2Attr { - return func(m optionalAttr) { - m["max_outputs"] = value +// Arguments: +// tag: A string attached to this summary. Used for organization in TensorBoard. +// tensor: A tensor to serialize. +// serialized_summary_metadata: A serialized SummaryMetadata proto. Contains plugin +// data. +func TensorSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, serialized_summary_metadata tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorSummaryV2", + Input: []tf.Input{ + tag, tensor, serialized_summary_metadata, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Outputs a `Summary` protocol buffer with audio. -// -// The summary has up to `max_outputs` summary values containing audio. The -// audio is built from `tensor` which must be 3-D with shape `[batch_size, -// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are -// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. +// Creates a dataset that asynchronously prefetches elements from `input_dataset`. // -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: +// Arguments: // -// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. -// * If `max_outputs` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. +// buffer_size: The maximum number of elements to buffer in an iterator over +// this dataset. // -// Arguments: -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 2-D of shape `[batch_size, frames]`. -// sample_rate: The sample rate of the signal in hertz. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate tf.Output, optional ...AudioSummaryV2Attr) (summary tf.Output) { +func PrefetchDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "AudioSummaryV2", + Type: "PrefetchDataset", Input: []tf.Input{ - tag, tensor, sample_rate, + input_dataset, buffer_size, }, Attrs: attrs, } @@ -21406,47 +21499,59 @@ func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate t return op.Output(0) } -// AvgPoolAttr is an optional argument to AvgPool. -type AvgPoolAttr func(optionalAttr) +// TensorSummaryAttr is an optional argument to TensorSummary. +type TensorSummaryAttr func(optionalAttr) -// AvgPoolDataFormat sets the optional data_format attribute to value. +// TensorSummaryDescription sets the optional description attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func AvgPoolDataFormat(value string) AvgPoolAttr { +// value: A json-encoded SummaryDescription proto. +// If not specified, defaults to "" +func TensorSummaryDescription(value string) TensorSummaryAttr { return func(m optionalAttr) { - m["data_format"] = value + m["description"] = value } } -// Performs average pooling on the input. +// TensorSummaryLabels sets the optional labels attribute to value. // -// Each entry in `output` is the mean of the corresponding size `ksize` -// window in `value`. +// value: An unused list of strings. +// If not specified, defaults to <> +func TensorSummaryLabels(value []string) TensorSummaryAttr { + return func(m optionalAttr) { + m["labels"] = value + } +} + +// TensorSummaryDisplayName sets the optional display_name attribute to value. // -// Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// ksize: The size of the sliding window for each dimension of `value`. -// strides: The stride of the sliding window for each dimension of `value`. -// padding: The type of padding algorithm to use. +// value: An unused string. +// If not specified, defaults to "" +func TensorSummaryDisplayName(value string) TensorSummaryAttr { + return func(m optionalAttr) { + m["display_name"] = value + } +} + +// Outputs a `Summary` protocol buffer with a tensor. // -// Returns The average pooled output tensor. -func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolAttr) (output tf.Output) { +// This op is being phased out in favor of TensorSummaryV2, which lets callers pass +// a tag as well as a serialized SummaryMetadata proto string that contains +// plugin-specific data. We will keep this op to maintain backwards compatibility. +// +// Arguments: +// tensor: A tensor to serialize. +func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr) (summary tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPool", + Type: "TensorSummary", Input: []tf.Input{ - value, + tensor, }, Attrs: attrs, } @@ -21454,206 +21559,215 @@ func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padd return op.Output(0) } -// Merges summaries. -// -// This op creates a -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// protocol buffer that contains the union of all the values in the input -// summaries. -// -// When the Op is run, it reports an `InvalidArgument` error if multiple values -// in the summaries to merge use the same tag. -// -// Arguments: -// inputs: Can be of any shape. Each must contain serialized `Summary` protocol -// buffers. +// Computes the gradient for the tanh of `x` wrt its input. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { +// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy` +// is the corresponding input gradient. +func TanhGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MergeSummary", + Type: "TanhGrad", Input: []tf.Input{ - tf.OutputList(inputs), + y, dy, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the gradient of morphological 2-D dilation with respect to the filter. +// Reduces sparse updates into the variable referenced by `resource` using the `max` operation. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] = max(ref[indices, ...], updates[...]) +// +// # Vector indices (for each i) +// ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions are combined. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
// // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. -// strides: 1-D of length 4. The stride of the sliding window for each dimension of -// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. -// rates: 1-D of length 4. The input stride for atrous morphological dilation. -// Must be: `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. // -// Returns 3-D with shape `[filter_height, filter_width, depth]`. -func Dilation2DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (filter_backprop tf.Output) { +// Returns the created operation. +func ResourceScatterMax(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "Dilation2DBackpropFilter", + Type: "ResourceScatterMax", Input: []tf.Input{ - input, filter, out_backprop, + resource, indices, updates, }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AddSparseToTensorsMapAttr is an optional argument to AddSparseToTensorsMap. -type AddSparseToTensorsMapAttr func(optionalAttr) - -// AddSparseToTensorsMapContainer sets the optional container attribute to value. -// -// value: The container name for the `SparseTensorsMap` created by this op. -// If not specified, defaults to "" -func AddSparseToTensorsMapContainer(value string) AddSparseToTensorsMapAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// AddSparseToTensorsMapSharedName sets the optional shared_name attribute to value. -// -// value: The shared name for the `SparseTensorsMap` created by this op. -// If blank, the new Operation's unique name is used. -// If not specified, defaults to "" -func AddSparseToTensorsMapSharedName(value string) AddSparseToTensorsMapAttr { - return func(m optionalAttr) { - m["shared_name"] = value } + return scope.AddOperation(opspec) } -// Add a `SparseTensor` to a `SparseTensorsMap` return its handle. -// -// A `SparseTensor` is represented by three tensors: `sparse_indices`, -// `sparse_values`, and `sparse_shape`. -// -// This operator takes the given `SparseTensor` and adds it to a container -// object (a `SparseTensorsMap`). A unique key within this container is generated -// in the form of an `int64`, and this is the value that is returned. +// Outputs a `Summary` protocol buffer with scalar values. // -// The `SparseTensor` can then be read out as part of a minibatch by passing -// the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure -// the correct `SparseTensorsMap` is accessed, ensure that the same -// `container` and `shared_name` are passed to that Op. If no `shared_name` -// is provided here, instead use the *name* of the Operation created by calling -// `AddSparseToTensorsMap` as the `shared_name` passed to -// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. +// The input `tags` and `values` must have the same shape. The generated summary +// has a summary value for each tag-value pair in `tags` and `values`. // // Arguments: -// sparse_indices: 2-D. The `indices` of the `SparseTensor`. -// sparse_values: 1-D. The `values` of the `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the `SparseTensor`. +// tags: Tags for the summary. +// values: Same shape as `tags. Values for the summary. // -// Returns 0-D. The handle of the `SparseTensor` now stored in the -// `SparseTensorsMap`. -func AddSparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddSparseToTensorsMapAttr) (sparse_handle tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. +func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "AddSparseToTensorsMap", + Type: "ScalarSummary", Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, + tags, values, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns a list list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`. +// Outputs a `Summary` protocol buffer with a histogram. // -// tensor: The tensor to put on the list. -// input_handle: The old list. -// output_handle: A list with the elements of the old list followed by tensor. -// element_dtype: the type of elements in the list. -// element_shape: a shape compatible with that of elements in the list. -func TensorListPushBack(scope *Scope, input_handle tf.Output, tensor tf.Output) (output_handle tf.Output) { +// The generated +// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +// has one summary value containing a histogram for `values`. +// +// This op reports an `InvalidArgument` error if any value is not finite. +// +// Arguments: +// tag: Scalar. Tag to use for the `Summary.Value`. +// values: Any shape. Values to use to build the histogram. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func HistogramSummary(scope *Scope, tag tf.Output, values tf.Output) (summary tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListPushBack", + Type: "HistogramSummary", Input: []tf.Input{ - input_handle, tensor, + tag, values, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the number of tensors in the input tensor list. +// Computes the number of elements in the given queue. // -// input_handle: the input list -// length: the number of tensors in the list -func TensorListLength(scope *Scope, input_handle tf.Output) (length tf.Output) { +// Arguments: +// handle: The handle to a queue. +// +// Returns The number of elements in the given queue. +func QueueSizeV2(scope *Scope, handle tf.Output) (size tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListLength", + Type: "QueueSizeV2", Input: []tf.Input{ - input_handle, + handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// The shape of the elements of the given list, as a tensor. +// ImageSummaryAttr is an optional argument to ImageSummary. +type ImageSummaryAttr func(optionalAttr) + +// ImageSummaryMaxImages sets the optional max_images attribute to value. // -// input_handle: the list -// element_shape: the shape of elements of the list -func TensorListElementShape(scope *Scope, input_handle tf.Output, shape_type tf.DataType) (element_shape tf.Output) { - if scope.Err() != nil { - return +// value: Max number of batch elements to generate images for. +// If not specified, defaults to 3 +// +// REQUIRES: value >= 1 +func ImageSummaryMaxImages(value int64) ImageSummaryAttr { + return func(m optionalAttr) { + m["max_images"] = value } - attrs := map[string]interface{}{"shape_type": shape_type} - opspec := tf.OpSpec{ - Type: "TensorListElementShape", - Input: []tf.Input{ - input_handle, - }, - Attrs: attrs, +} + +// ImageSummaryBadColor sets the optional bad_color attribute to value. +// +// value: Color to use for pixels with non-finite values. +// If not specified, defaults to > int_val:255 int_val:0 int_val:0 int_val:255 > +func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { + return func(m optionalAttr) { + m["bad_color"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Returns the item in the list with the given index. +// Outputs a `Summary` protocol buffer with images. // -// input_handle: the list -// index: the position in the list from which an element will be retrieved -// item: the element at that position +// The summary has up to `max_images` summary values containing images. The +// images are built from `tensor` which must be 4-D with shape `[batch_size, +// height, width, channels]` and where `channels` can be: +// +// * 1: `tensor` is interpreted as Grayscale. +// * 3: `tensor` is interpreted as RGB. +// * 4: `tensor` is interpreted as RGBA. // +// The images have the same number of channels as the input tensor. For float +// input, the values are normalized one image at a time to fit in the range +// `[0, 255]`. `uint8` values are unchanged. The op uses two different +// normalization algorithms: // -func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, element_dtype tf.DataType) (item tf.Output) { +// * If the input values are all positive, they are rescaled so the largest one +// is 255. +// +// * If any input value is negative, the values are shifted so input value 0.0 +// is at 127. They are then rescaled so that either the smallest value is 0, +// or the largest one is 255. +// +// The `tag` argument is a scalar `Tensor` of type `string`. It is used to +// build the `tag` of the summary values: +// +// * If `max_images` is 1, the summary value tag is '*tag*/image'. +// * If `max_images` is greater than 1, the summary value tags are +// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. +// +// The `bad_color` argument is the color to use in the generated images for +// non-finite input values. It is a `unit8` 1-D tensor of length `channels`. +// Each element must be in the range `[0, 255]` (It represents the value of a +// pixel in the output image). Non-finite values in the input tensor are +// replaced by this tensor in the output image. The default value is the color +// red. +// +// Arguments: +// tag: Scalar. Used to build the `tag` attribute of the summary values. +// tensor: 4-D of shape `[batch_size, height, width, channels]` where +// `channels` is 1, 3, or 4. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...ImageSummaryAttr) (summary tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorListGetItem", + Type: "ImageSummary", Input: []tf.Input{ - input_handle, index, + tag, tensor, }, Attrs: attrs, } @@ -21661,233 +21775,215 @@ func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, el return op.Output(0) } -// Computes the matrix exponential of one or more square matrices: +// AudioSummaryV2Attr is an optional argument to AudioSummaryV2. +type AudioSummaryV2Attr func(optionalAttr) + +// AudioSummaryV2MaxOutputs sets the optional max_outputs attribute to value. // -// exp(A) = \sum_{n=0}^\infty A^n/n! +// value: Max number of batch elements to generate audio for. +// If not specified, defaults to 3 // -// The exponential is computed using a combination of the scaling and squaring -// method and the Pade approximation. Details can be founds in: -// Nicholas J. Higham, "The scaling and squaring method for the matrix exponential -// revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005. +// REQUIRES: value >= 1 +func AudioSummaryV2MaxOutputs(value int64) AudioSummaryV2Attr { + return func(m optionalAttr) { + m["max_outputs"] = value + } +} + +// Outputs a `Summary` protocol buffer with audio. // -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. The output is a tensor of the same shape as the input -// containing the exponential for all input submatrices `[..., :, :]`. +// The summary has up to `max_outputs` summary values containing audio. The +// audio is built from `tensor` which must be 3-D with shape `[batch_size, +// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are +// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. // -// Arguments: -// input: Shape is `[..., M, M]`. +// The `tag` argument is a scalar `Tensor` of type `string`. It is used to +// build the `tag` of the summary values: // -// Returns Shape is `[..., M, M]`. +// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. +// * If `max_outputs` is greater than 1, the summary value tags are +// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. // -// @compatibility(scipy) -// Equivalent to scipy.linalg.expm -// @end_compatibility -func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) { +// Arguments: +// tag: Scalar. Used to build the `tag` attribute of the summary values. +// tensor: 2-D of shape `[batch_size, frames]`. +// sample_rate: The sample rate of the signal in hertz. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate tf.Output, optional ...AudioSummaryV2Attr) (summary tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "MatrixExponential", + Type: "AudioSummaryV2", Input: []tf.Input{ - input, + tag, tensor, sample_rate, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the matrix logarithm of one or more square matrices: -// +// AvgPoolAttr is an optional argument to AvgPool. +type AvgPoolAttr func(optionalAttr) + +// AvgPoolDataFormat sets the optional data_format attribute to value. // -// log(exp(A)) = A -// -// This op is only defined for complex matrices. If A is positive-definite and -// real, then casting to a complex matrix, taking the logarithm and casting back -// to a real matrix will give the correct result. -// -// This function computes the matrix logarithm using the Schur-Parlett algorithm. -// Details of the algorithm can be found in Section 11.6.2 of: -// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008. -// ISBN 978-0-898716-46-7. -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. The output is a tensor of the same shape as the input -// containing the exponential for all input submatrices `[..., :, :]`. -// -// Arguments: -// input: Shape is `[..., M, M]`. -// -// Returns Shape is `[..., M, M]`. -// -// @compatibility(scipy) -// Equivalent to scipy.linalg.logm -// @end_compatibility -func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixLogarithm", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2. -type QueueDequeueUpToV2Attr func(optionalAttr) - -// QueueDequeueUpToV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue has fewer than n elements, this operation -// will block for up to timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueDequeueUpToV2TimeoutMs(value int64) QueueDequeueUpToV2Attr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func AvgPoolDataFormat(value string) AvgPoolAttr { return func(m optionalAttr) { - m["timeout_ms"] = value + m["data_format"] = value } } -// Dequeues `n` tuples of one or more tensors from the given queue. -// -// This operation is not supported by all queues. If a queue does not support -// DequeueUpTo, then an Unimplemented error is returned. -// -// If the queue is closed and there are more than 0 but less than `n` -// elements remaining, then instead of returning an OutOfRange error like -// QueueDequeueMany, less than `n` elements are returned immediately. If -// the queue is closed and there are 0 elements left in the queue, then -// an OutOfRange error is returned just like in QueueDequeueMany. -// Otherwise the behavior is identical to QueueDequeueMany: -// -// This operation concatenates queue-element component tensors along the -// 0th dimension to make a single component tensor. All of the components -// in the dequeued tuple will have size n in the 0th dimension. +// Performs average pooling on the input. // -// This operation has `k` outputs, where `k` is the number of components in -// the tuples stored in the given queue, and output `i` is the ith -// component of the dequeued tuple. +// Each entry in `output` is the mean of the corresponding size `ksize` +// window in `value`. // // Arguments: -// handle: The handle to a queue. -// n: The number of tuples to dequeue. -// component_types: The type of each component in a tuple. +// value: 4-D with shape `[batch, height, width, channels]`. +// ksize: The size of the sliding window for each dimension of `value`. +// strides: The stride of the sliding window for each dimension of `value`. +// padding: The type of padding algorithm to use. // -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueUpToV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueUpToV2Attr) (components []tf.Output) { +// Returns The average pooled output tensor. +func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QueueDequeueUpToV2", + Type: "AvgPool", Input: []tf.Input{ - handle, n, + value, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueUpToV2", err) - return - } - return components + return op.Output(0) } -// Computes the Cholesky decomposition of one or more square matrices. -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. -// -// The input has to be symmetric and positive definite. Only the lower-triangular -// part of the input will be used for this operation. The upper-triangular part -// will not be read. +// Merges summaries. // -// The output is a tensor of the same shape as the input -// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. +// This op creates a +// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +// protocol buffer that contains the union of all the values in the input +// summaries. // -// **Note**: The gradient computation on GPU is faster for large matrices but -// not for large batch dimensions when the submatrices are small. In this -// case it might be faster to use the CPU. +// When the Op is run, it reports an `InvalidArgument` error if multiple values +// in the summaries to merge use the same tag. // // Arguments: -// input: Shape is `[..., M, M]`. +// inputs: Can be of any shape. Each must contain serialized `Summary` protocol +// buffers. // -// Returns Shape is `[..., M, M]`. -func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. +func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Cholesky", + Type: "MergeSummary", Input: []tf.Input{ - input, + tf.OutputList(inputs), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Writes contents to the file at input filename. Creates file and recursively -// -// creates directory if not existing. +// Computes the gradient of morphological 2-D dilation with respect to the filter. // // Arguments: -// filename: scalar. The name of the file to which we write the contents. -// contents: scalar. The content to be written to the output file. +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. +// strides: 1-D of length 4. The stride of the sliding window for each dimension of +// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: 1-D of length 4. The input stride for atrous morphological dilation. +// Must be: `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. // -// Returns the created operation. -func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { +// Returns 3-D with shape `[filter_height, filter_width, depth]`. +func Dilation2DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (filter_backprop tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "WriteFile", + Type: "Dilation2DBackpropFilter", Input: []tf.Input{ - filename, contents, + input, filter, out_backprop, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// AllAttr is an optional argument to All. -type AllAttr func(optionalAttr) +// AddSparseToTensorsMapAttr is an optional argument to AddSparseToTensorsMap. +type AddSparseToTensorsMapAttr func(optionalAttr) -// AllKeepDims sets the optional keep_dims attribute to value. +// AddSparseToTensorsMapContainer sets the optional container attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func AllKeepDims(value bool) AllAttr { +// value: The container name for the `SparseTensorsMap` created by this op. +// If not specified, defaults to "" +func AddSparseToTensorsMapContainer(value string) AddSparseToTensorsMapAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["container"] = value } } -// Computes the "logical and" of elements across dimensions of a tensor. +// AddSparseToTensorsMapSharedName sets the optional shared_name attribute to value. // -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// value: The shared name for the `SparseTensorsMap` created by this op. +// If blank, the new Operation's unique name is used. +// If not specified, defaults to "" +func AddSparseToTensorsMapSharedName(value string) AddSparseToTensorsMapAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Add a `SparseTensor` to a `SparseTensorsMap` return its handle. +// +// A `SparseTensor` is represented by three tensors: `sparse_indices`, +// `sparse_values`, and `sparse_shape`. +// +// This operator takes the given `SparseTensor` and adds it to a container +// object (a `SparseTensorsMap`). A unique key within this container is generated +// in the form of an `int64`, and this is the value that is returned. +// +// The `SparseTensor` can then be read out as part of a minibatch by passing +// the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure +// the correct `SparseTensorsMap` is accessed, ensure that the same +// `container` and `shared_name` are passed to that Op. If no `shared_name` +// is provided here, instead use the *name* of the Operation created by calling +// `AddSparseToTensorsMap` as the `shared_name` passed to +// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// sparse_indices: 2-D. The `indices` of the `SparseTensor`. +// sparse_values: 1-D. The `values` of the `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the `SparseTensor`. // -// Returns The reduced tensor. -func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (output tf.Output) { +// Returns 0-D. The handle of the `SparseTensor` now stored in the +// `SparseTensorsMap`. +func AddSparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddSparseToTensorsMapAttr) (sparse_handle tf.Output) { if scope.Err() != nil { return } @@ -21896,9 +21992,9 @@ func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (ou a(attrs) } opspec := tf.OpSpec{ - Type: "All", + Type: "AddSparseToTensorsMap", Input: []tf.Input{ - input, axis, + sparse_indices, sparse_values, sparse_shape, }, Attrs: attrs, } @@ -21906,187 +22002,165 @@ func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (ou return op.Output(0) } -// Computes the Eigen Decomposition of a batch of square self-adjoint matrices. -// -// DEPRECATED at GraphDef version 11: Use SelfAdjointEigV2 instead. -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices, with the same constraints as the single matrix -// SelfAdjointEig. -// -// The result is a [..., M+1, M] matrix with [..., 0,:] containing the -// eigenvalues, and subsequent [...,1:, :] containing the eigenvectors. The eigenvalues -// are sorted in non-decreasing order. -// -// Arguments: -// input: Shape is `[..., M, M]`. +// Returns a list list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`. // -// Returns Shape is `[..., M+1, M]`. -func SelfAdjointEig(scope *Scope, input tf.Output) (output tf.Output) { +// tensor: The tensor to put on the list. +// input_handle: The old list. +// output_handle: A list with the elements of the old list followed by tensor. +// element_dtype: the type of elements in the list. +// element_shape: a shape compatible with that of elements in the list. +func TensorListPushBack(scope *Scope, input_handle tf.Output, tensor tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SelfAdjointEig", + Type: "TensorListPushBack", Input: []tf.Input{ - input, + input_handle, tensor, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes softplus gradients for a softplus operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding softplus operation. -// features: The features passed as input to the corresponding softplus operation. +// Returns the number of tensors in the input tensor list. // -// Returns The gradients: `gradients / (1 + exp(-features))`. -func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { +// input_handle: the input list +// length: the number of tensors in the list +func TensorListLength(scope *Scope, input_handle tf.Output) (length tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SoftplusGrad", + Type: "TensorListLength", Input: []tf.Input{ - gradients, features, + input_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. -type SelfAdjointEigV2Attr func(optionalAttr) - -// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value. +// The shape of the elements of the given list, as a tensor. // -// value: If `True` then eigenvectors will be computed and returned in `v`. -// Otherwise, only the eigenvalues will be computed. -// If not specified, defaults to true -func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr { - return func(m optionalAttr) { - m["compute_v"] = value +// input_handle: the list +// element_shape: the shape of elements of the list +func TensorListElementShape(scope *Scope, input_handle tf.Output, shape_type tf.DataType) (element_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shape_type": shape_type} + opspec := tf.OpSpec{ + Type: "TensorListElementShape", + Input: []tf.Input{ + input_handle, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the eigen decomposition of one or more square self-adjoint matrices. -// -// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in -// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues -// are sorted in non-decreasing order. +// Returns the item in the list with the given index. // -// ```python -// # a is a tensor. -// # e is a tensor of eigenvalues. -// # v is a tensor of eigenvectors. -// e, v = self_adjoint_eig(a) -// e = self_adjoint_eig(a, compute_v=False) -// ``` +// input_handle: the list +// index: the position in the list from which an element will be retrieved +// item: the element at that position // -// Arguments: -// input: `Tensor` input of shape `[N, N]`. // -// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`. -func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) { +func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, element_dtype tf.DataType) (item tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "SelfAdjointEigV2", + Type: "TensorListGetItem", Input: []tf.Input{ - input, + input_handle, index, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Adjust the saturation of one or more images. +// Returns a diagonal tensor with a given diagonal values. // -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. +// Given a `diagonal`, this operation returns a tensor with the `diagonal` and +// everything else padded with zeros. The diagonal is computed as follows: // -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A scale is then applied all the saturation -// values, and then remapped back to RGB colorspace. +// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of +// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: // -// Arguments: -// images: Images to adjust. At least 3-D. -// scale: A float scale to add to the saturation. +// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. // -// Returns The hue-adjusted image or images. -func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { +// For example: +// +// ``` +// # 'diagonal' is [1, 2, 3, 4] +// tf.diag(diagonal) ==> [[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]] +// ``` +// +// Arguments: +// diagonal: Rank k tensor where k is at most 1. +func Diag(scope *Scope, diagonal tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "AdjustSaturation", + Type: "Diag", Input: []tf.Input{ - images, scale, + diagonal, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SvdAttr is an optional argument to Svd. -type SvdAttr func(optionalAttr) +// ParameterizedTruncatedNormalAttr is an optional argument to ParameterizedTruncatedNormal. +type ParameterizedTruncatedNormalAttr func(optionalAttr) -// SvdComputeUv sets the optional compute_uv attribute to value. +// ParameterizedTruncatedNormalSeed sets the optional seed attribute to value. // -// value: If true, left and right singular vectors will be -// computed and returned in `u` and `v`, respectively. -// If false, `u` and `v` are not set and should never referenced. -// If not specified, defaults to true -func SvdComputeUv(value bool) SvdAttr { +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func ParameterizedTruncatedNormalSeed(value int64) ParameterizedTruncatedNormalAttr { return func(m optionalAttr) { - m["compute_uv"] = value + m["seed"] = value } } -// SvdFullMatrices sets the optional full_matrices attribute to value. +// ParameterizedTruncatedNormalSeed2 sets the optional seed2 attribute to value. // -// value: If true, compute full-sized `u` and `v`. If false -// (the default), compute only the leading `P` singular vectors. -// Ignored if `compute_uv` is `False`. -// If not specified, defaults to false -func SvdFullMatrices(value bool) SvdAttr { +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func ParameterizedTruncatedNormalSeed2(value int64) ParameterizedTruncatedNormalAttr { return func(m optionalAttr) { - m["full_matrices"] = value + m["seed2"] = value } } -// Computes the singular value decompositions of one or more matrices. -// -// Computes the SVD of each inner matrix in `input` such that -// `input[..., :, :] = u[..., :, :] * diag(s[..., :, :]) * transpose(v[..., :, :])` +// Outputs random values from a normal distribution. The parameters may each be a // -// ```python -// # a is a tensor containing a batch of matrices. -// # s is a tensor of singular values for each matrix. -// # u is the tensor containing of left singular vectors for each matrix. -// # v is the tensor containing of right singular vectors for each matrix. -// s, u, v = svd(a) -// s, _, _ = svd(a, compute_uv=False) -// ``` +// scalar which applies to the entire output, or a vector of length shape[0] which +// stores the parameters for each batch. // // Arguments: -// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. +// shape: The shape of the output tensor. Batches are indexed by the 0th dimension. +// means: The mean parameter of each batch. +// stdevs: The standard deviation parameter of each batch. Must be greater than 0. +// minvals: The minimum cutoff. May be -infinity. +// maxvals: The maximum cutoff. May be +infinity, and must be more than the minval +// for each batch. // -// Returns Singular values. Shape is `[..., P]`.Left singular vectors. If `full_matrices` is `False` then shape is -// `[..., M, P]`; if `full_matrices` is `True` then shape is -// `[..., M, M]`. Undefined if `compute_uv` is `False`.Left singular vectors. If `full_matrices` is `False` then shape is -// `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`. -// Undefined if `compute_uv` is false. -func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.Output, v tf.Output) { +// Returns A matrix of shape num_batches x samples_per_batch, filled with random +// truncated normal values using the parameters for each row. +func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output, stdevs tf.Output, minvals tf.Output, maxvals tf.Output, optional ...ParameterizedTruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -22095,602 +22169,470 @@ func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf. a(attrs) } opspec := tf.OpSpec{ - Type: "Svd", + Type: "ParameterizedTruncatedNormal", Input: []tf.Input{ - input, + shape, means, stdevs, minvals, maxvals, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// QueueEnqueueManyV2Attr is an optional argument to QueueEnqueueManyV2. -type QueueEnqueueManyV2Attr func(optionalAttr) - -// QueueEnqueueManyV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue is too full, this operation will block for up -// to timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueEnqueueManyV2TimeoutMs(value int64) QueueEnqueueManyV2Attr { - return func(m optionalAttr) { - m["timeout_ms"] = value - } + return op.Output(0) } -// Enqueues zero or more tuples of one or more tensors in the given queue. -// -// This operation slices each component tensor along the 0th dimension to -// make multiple queue elements. All of the tuple components must have the -// same size in the 0th dimension. -// -// The components input has k elements, which correspond to the components of -// tuples stored in the given queue. -// -// N.B. If the queue is full, this operation will block until the given -// elements have been enqueued (or 'timeout_ms' elapses, if specified). +// Sets the index-th position of the list to contain the given tensor. // -// Arguments: -// handle: The handle to a queue. -// components: One or more tensors from which the enqueued tensors should -// be taken. +// input_handle: the list +// index: the position in the list to which the tensor will be assigned +// item: the element to be assigned to that position +// output_handle: the new list, with the element in the proper position // -// Returns the created operation. -func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueManyV2Attr) (o *tf.Operation) { +func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, item tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "QueueEnqueueManyV2", + Type: "TensorListSetItem", Input: []tf.Input{ - handle, tf.OutputList(components), + input_handle, index, item, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the product along segments of a tensor. -// -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// Computes the matrix exponential of one or more square matrices: // -// Computes a tensor such that -// \\(output_i = \prod_j data_j\\) where the product is over `j` such -// that `segment_ids[j] == i`. +// exp(A) = \sum_{n=0}^\infty A^n/n! // -// If the product is empty for a given segment ID `i`, `output[i] = 1`. +// The exponential is computed using a combination of the scaling and squaring +// method and the Pade approximation. Details can be founds in: +// Nicholas J. Higham, "The scaling and squaring method for the matrix exponential +// revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005. // -//
-// -//
+// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. The output is a tensor of the same shape as the input +// containing the exponential for all input submatrices `[..., :, :]`. // // Arguments: +// input: Shape is `[..., M, M]`. // -// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s -// first dimension. Values should be sorted and can be repeated. +// Returns Shape is `[..., M, M]`. // -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// @compatibility(scipy) +// Equivalent to scipy.linalg.expm +// @end_compatibility +func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SegmentProd", + Type: "MatrixExponential", Input: []tf.Input{ - data, segment_ids, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts one or more images from RGB to HSV. -// -// Outputs a tensor of the same shape as the `images` tensor, containing the HSV -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. +// QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2. +type QueueDequeueUpToV2Attr func(optionalAttr) + +// QueueDequeueUpToV2TimeoutMs sets the optional timeout_ms attribute to value. // -// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and -// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 -// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. +// value: If the queue has fewer than n elements, this operation +// will block for up to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueDequeueUpToV2TimeoutMs(value int64) QueueDequeueUpToV2Attr { + return func(m optionalAttr) { + m["timeout_ms"] = value + } +} + +// Dequeues `n` tuples of one or more tensors from the given queue. +// +// This operation is not supported by all queues. If a queue does not support +// DequeueUpTo, then an Unimplemented error is returned. +// +// If the queue is closed and there are more than 0 but less than `n` +// elements remaining, then instead of returning an OutOfRange error like +// QueueDequeueMany, less than `n` elements are returned immediately. If +// the queue is closed and there are 0 elements left in the queue, then +// an OutOfRange error is returned just like in QueueDequeueMany. +// Otherwise the behavior is identical to QueueDequeueMany: +// +// This operation concatenates queue-element component tensors along the +// 0th dimension to make a single component tensor. All of the components +// in the dequeued tuple will have size n in the 0th dimension. +// +// This operation has `k` outputs, where `k` is the number of components in +// the tuples stored in the given queue, and output `i` is the ith +// component of the dequeued tuple. // // Arguments: -// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. +// handle: The handle to a queue. +// n: The number of tuples to dequeue. +// component_types: The type of each component in a tuple. // -// Returns `images` converted to HSV. -func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { +// Returns One or more tensors that were dequeued as a tuple. +func QueueDequeueUpToV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueUpToV2Attr) (components []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RGBToHSV", + Type: "QueueDequeueUpToV2", Input: []tf.Input{ - images, + handle, n, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Does nothing. Only useful as a placeholder for control edges. -// -// Returns the created operation. -func NoOp(scope *Scope) (o *tf.Operation) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "NoOp", + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("QueueDequeueUpToV2", err) + return } - return scope.AddOperation(opspec) + return components } -// MergeV2CheckpointsAttr is an optional argument to MergeV2Checkpoints. -type MergeV2CheckpointsAttr func(optionalAttr) - -// MergeV2CheckpointsDeleteOldDirs sets the optional delete_old_dirs attribute to value. +// Computes the Cholesky decomposition of one or more square matrices. // -// value: see above. -// If not specified, defaults to true -func MergeV2CheckpointsDeleteOldDirs(value bool) MergeV2CheckpointsAttr { - return func(m optionalAttr) { - m["delete_old_dirs"] = value - } -} - -// V2 format specific: merges the metadata files of sharded checkpoints. The +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. // -// result is one logical checkpoint, with one physical metadata file and renamed -// data files. +// The input has to be symmetric and positive definite. Only the lower-triangular +// part of the input will be used for this operation. The upper-triangular part +// will not be read. // -// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup. +// The output is a tensor of the same shape as the input +// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. // -// If delete_old_dirs is true, attempts to delete recursively the dirname of each -// path in the input checkpoint_prefixes. This is useful when those paths are non -// user-facing temporary locations. +// **Note**: The gradient computation on GPU is faster for large matrices but +// not for large batch dimensions when the submatrices are small. In this +// case it might be faster to use the CPU. // // Arguments: -// checkpoint_prefixes: prefixes of V2 checkpoints to merge. -// destination_prefix: scalar. The desired final prefix. Allowed to be the same -// as one of the checkpoint_prefixes. +// input: Shape is `[..., M, M]`. // -// Returns the created operation. -func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination_prefix tf.Output, optional ...MergeV2CheckpointsAttr) (o *tf.Operation) { +// Returns Shape is `[..., M, M]`. +func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MergeV2Checkpoints", + Type: "Cholesky", Input: []tf.Input{ - checkpoint_prefixes, destination_prefix, + input, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Saves input tensors slices to disk. -// -// This is like `Save` except that tensors can be listed in the saved file as being -// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the -// larger tensor and the slice that this tensor covers. `shapes_and_slices` must -// have as many elements as `tensor_names`. -// -// Elements of the `shapes_and_slices` input must either be: -// -// * The empty string, in which case the corresponding tensor is -// saved normally. -// * A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the -// `dimI` are the dimensions of the larger tensor and `slice-spec` -// specifies what part is covered by the tensor to save. -// -// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1` -// where each `sliceI` is either: -// -// * The string `-` meaning that the slice covers all indices of this dimension -// * `start,length` where `start` and `length` are integers. In that -// case the slice covers `length` indices starting at `start`. +// Writes contents to the file at input filename. Creates file and recursively // -// See also `Save`. +// creates directory if not existing. // // Arguments: -// filename: Must have a single element. The name of the file to which we write the -// tensor. -// tensor_names: Shape `[N]`. The names of the tensors to be saved. -// shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when -// saving the tensors. -// data: `N` tensors to save. +// filename: scalar. The name of the file to which we write the contents. +// contents: scalar. The content to be written to the output file. // // Returns the created operation. -func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes_and_slices tf.Output, data []tf.Output) (o *tf.Operation) { +func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SaveSlices", + Type: "WriteFile", Input: []tf.Input{ - filename, tensor_names, shapes_and_slices, tf.OutputList(data), + filename, contents, }, } return scope.AddOperation(opspec) } -// DenseToDenseSetOperationAttr is an optional argument to DenseToDenseSetOperation. -type DenseToDenseSetOperationAttr func(optionalAttr) +// AllAttr is an optional argument to All. +type AllAttr func(optionalAttr) -// DenseToDenseSetOperationValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func DenseToDenseSetOperationValidateIndices(value bool) DenseToDenseSetOperationAttr { +// AllKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func AllKeepDims(value bool) AllAttr { return func(m optionalAttr) { - m["validate_indices"] = value + m["keep_dims"] = value } } -// Applies set operation along last dimension of 2 `Tensor` inputs. -// -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. +// Computes the "logical and" of elements across dimensions of a tensor. // -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// set2: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set1`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. -func DenseToDenseSetOperation(scope *Scope, set1 tf.Output, set2 tf.Output, set_operation string, optional ...DenseToDenseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { +// Returns The reduced tensor. +func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"set_operation": set_operation} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DenseToDenseSetOperation", + Type: "All", Input: []tf.Input{ - set1, set2, + input, axis, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Generate a sharded filename. The filename is printf formatted as +// Computes the Eigen Decomposition of a batch of square self-adjoint matrices. // -// %s-%05d-of-%05d, basename, shard, num_shards. -func ShardedFilename(scope *Scope, basename tf.Output, shard tf.Output, num_shards tf.Output) (filename tf.Output) { +// DEPRECATED at GraphDef version 11: Use SelfAdjointEigV2 instead. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices, with the same constraints as the single matrix +// SelfAdjointEig. +// +// The result is a [..., M+1, M] matrix with [..., 0,:] containing the +// eigenvalues, and subsequent [...,1:, :] containing the eigenvectors. The eigenvalues +// are sorted in non-decreasing order. +// +// Arguments: +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[..., M+1, M]`. +func SelfAdjointEig(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ShardedFilename", + Type: "SelfAdjointEig", Input: []tf.Input{ - basename, shard, num_shards, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// BatchToSpace for N-D tensors of type T. -// -// This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape -// `block_shape + [batch]`, interleaves these blocks back into the grid defined by -// the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as -// the input. The spatial dimensions of this intermediate result are then -// optionally cropped according to `crops` to produce the output. This is the -// reverse of SpaceToBatch. See below for a precise description. +// Computes softplus gradients for a softplus operation. // // Arguments: -// input: N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`, -// where spatial_shape has M dimensions. -// block_shape: 1-D with shape `[M]`, all values must be >= 1. -// crops: 2-D with shape `[M, 2]`, all values must be >= 0. -// `crops[i] = [crop_start, crop_end]` specifies the amount to crop from input -// dimension `i + 1`, which corresponds to spatial dimension `i`. It is -// required that -// `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`. -// -// This operation is equivalent to the following steps: +// gradients: The backpropagated gradients to the corresponding softplus operation. +// features: The features passed as input to the corresponding softplus operation. // -// 1. Reshape `input` to `reshaped` of shape: -// [block_shape[0], ..., block_shape[M-1], -// batch / prod(block_shape), -// input_shape[1], ..., input_shape[N-1]] -// -// 2. Permute dimensions of `reshaped` to produce `permuted` of shape -// [batch / prod(block_shape), -// -// input_shape[1], block_shape[0], -// ..., -// input_shape[M], block_shape[M-1], -// -// input_shape[M+1], ..., input_shape[N-1]] -// -// 3. Reshape `permuted` to produce `reshaped_permuted` of shape -// [batch / prod(block_shape), -// -// input_shape[1] * block_shape[0], -// ..., -// input_shape[M] * block_shape[M-1], -// -// input_shape[M+1], -// ..., -// input_shape[N-1]] -// -// 4. Crop the start and end of dimensions `[1, ..., M]` of -// `reshaped_permuted` according to `crops` to produce the output of shape: -// [batch / prod(block_shape), -// -// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], -// ..., -// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], -// -// input_shape[M+1], ..., input_shape[N-1]] -// -// Some examples: -// -// (1) For the following input of shape `[4, 1, 1, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: -// -// ``` -// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 1]` and value: -// -// ``` -// x = [[[[1], [2]], [[3], [4]]]] -// ``` -// -// (2) For the following input of shape `[4, 1, 1, 3]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: -// -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 3]` and value: -// -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// ``` -// -// (3) For the following input of shape `[4, 2, 2, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: -// -// ``` -// x = [[[[1], [3]], [[9], [11]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -// -// The output tensor has shape `[1, 4, 4, 1]` and value: -// -// ``` -// x = [[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]] -// ``` -// -// (4) For the following input of shape `[8, 1, 3, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [2, 0]]`: -// -// ``` -// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], -// [[[0], [2], [4]]], [[[0], [10], [12]]], -// [[[0], [5], [7]]], [[[0], [13], [15]]], -// [[[0], [6], [8]]], [[[0], [14], [16]]]] -// ``` -// -// The output tensor has shape `[2, 2, 4, 1]` and value: -// -// ``` -// x = [[[[1], [2], [3], [4]], -// [[5], [6], [7], [8]]], -// [[[9], [10], [11], [12]], -// [[13], [14], [15], [16]]]] -// ``` -func BatchToSpaceND(scope *Scope, input tf.Output, block_shape tf.Output, crops tf.Output) (output tf.Output) { +// Returns The gradients: `gradients / (1 + exp(-features))`. +func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BatchToSpaceND", + Type: "SoftplusGrad", Input: []tf.Input{ - input, block_shape, crops, + gradients, features, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// UnpackAttr is an optional argument to Unpack. -type UnpackAttr func(optionalAttr) +// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. +type SelfAdjointEigV2Attr func(optionalAttr) -// UnpackAxis sets the optional axis attribute to value. +// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value. // -// value: Dimension along which to unpack. Negative values wrap around, so the -// valid range is `[-R, R)`. -// If not specified, defaults to 0 -func UnpackAxis(value int64) UnpackAttr { +// value: If `True` then eigenvectors will be computed and returned in `v`. +// Otherwise, only the eigenvalues will be computed. +// If not specified, defaults to true +func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr { return func(m optionalAttr) { - m["axis"] = value + m["compute_v"] = value } } -// Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. -// -// Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. -// For example, given a tensor of shape `(A, B, C, D)`; -// -// If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` -// and each tensor in `output` will have shape `(B, C, D)`. (Note that the -// dimension unpacked along is gone, unlike `split`). +// Computes the eigen decomposition of one or more square self-adjoint matrices. // -// If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` -// and each tensor in `output` will have shape `(A, C, D)`. -// Etc. +// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in +// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues +// are sorted in non-decreasing order. // -// This is the opposite of `pack`. +// ```python +// # a is a tensor. +// # e is a tensor of eigenvalues. +// # v is a tensor of eigenvectors. +// e, v = self_adjoint_eig(a) +// e = self_adjoint_eig(a, compute_v=False) +// ``` // // Arguments: -// value: 1-D or higher, with `axis` dimension size equal to `num`. -// +// input: `Tensor` input of shape `[N, N]`. // -// Returns The list of tensors unpacked from `value`. -func Unpack(scope *Scope, value tf.Output, num int64, optional ...UnpackAttr) (output []tf.Output) { +// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`. +func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num": num} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Unpack", + Type: "SelfAdjointEigV2", Input: []tf.Input{ - value, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("Unpack", err) - return - } - return output + return op.Output(0), op.Output(1) } -// Increments variable pointed to by 'resource' until it reaches 'limit'. +// Adjust the saturation of one or more images. // -// Arguments: -// resource: Should be from a scalar `Variable` node. -// limit: If incrementing ref would bring it above limit, instead generates an -// 'OutOfRange' error. +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpretted as channels, and must be three. +// +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A scale is then applied all the saturation +// values, and then remapped back to RGB colorspace. // +// Arguments: +// images: Images to adjust. At least 3-D. +// scale: A float scale to add to the saturation. // -// Returns A copy of the input before increment. If nothing else modifies the -// input, the values produced will all be distinct. -func ResourceCountUpTo(scope *Scope, resource tf.Output, limit int64, T tf.DataType) (output tf.Output) { +// Returns The hue-adjusted image or images. +func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"limit": limit, "T": T} opspec := tf.OpSpec{ - Type: "ResourceCountUpTo", + Type: "AdjustSaturation", Input: []tf.Input{ - resource, + images, scale, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Delete the stack from its resource container. -// -// Arguments: -// handle: The handle to a stack. +// MatrixSolveAttr is an optional argument to MatrixSolve. +type MatrixSolveAttr func(optionalAttr) + +// MatrixSolveAdjoint sets the optional adjoint attribute to value. // -// Returns the created operation. -func StackCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "StackCloseV2", - Input: []tf.Input{ - handle, - }, +// value: Boolean indicating whether to solve with `matrix` or its (block-wise) +// adjoint. +// If not specified, defaults to false +func MatrixSolveAdjoint(value bool) MatrixSolveAttr { + return func(m optionalAttr) { + m["adjoint"] = value } - return scope.AddOperation(opspec) } -// Generate a glob pattern matching all sharded file names. -func ShardedFilespec(scope *Scope, basename tf.Output, num_shards tf.Output) (filename tf.Output) { +// Solves systems of linear equations. +// +// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is +// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix +// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +// If `adjoint` is `True` then each output matrix satisfies +// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. +// +// Arguments: +// matrix: Shape is `[..., M, M]`. +// rhs: Shape is `[..., M, K]`. +// +// Returns Shape is `[..., M, K]`. +func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ShardedFilespec", + Type: "MatrixSolve", Input: []tf.Input{ - basename, num_shards, + matrix, rhs, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// TextLineReaderV2Attr is an optional argument to TextLineReaderV2. -type TextLineReaderV2Attr func(optionalAttr) +// SvdAttr is an optional argument to Svd. +type SvdAttr func(optionalAttr) -// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value. +// SvdComputeUv sets the optional compute_uv attribute to value. // -// value: Number of lines to skip from the beginning of every file. -// If not specified, defaults to 0 -func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["skip_header_lines"] = value - } -} - -// TextLineReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func TextLineReaderV2Container(value string) TextLineReaderV2Attr { +// value: If true, left and right singular vectors will be +// computed and returned in `u` and `v`, respectively. +// If false, `u` and `v` are not set and should never referenced. +// If not specified, defaults to true +func SvdComputeUv(value bool) SvdAttr { return func(m optionalAttr) { - m["container"] = value + m["compute_uv"] = value } } -// TextLineReaderV2SharedName sets the optional shared_name attribute to value. +// SvdFullMatrices sets the optional full_matrices attribute to value. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr { +// value: If true, compute full-sized `u` and `v`. If false +// (the default), compute only the leading `P` singular vectors. +// Ignored if `compute_uv` is `False`. +// If not specified, defaults to false +func SvdFullMatrices(value bool) SvdAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["full_matrices"] = value } } -// A Reader that outputs the lines of a file delimited by '\n'. +// Computes the singular value decompositions of one or more matrices. // -// Returns The handle to reference the Reader. -func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) { +// Computes the SVD of each inner matrix in `input` such that +// `input[..., :, :] = u[..., :, :] * diag(s[..., :, :]) * transpose(v[..., :, :])` +// +// ```python +// # a is a tensor containing a batch of matrices. +// # s is a tensor of singular values for each matrix. +// # u is the tensor containing of left singular vectors for each matrix. +// # v is the tensor containing of right singular vectors for each matrix. +// s, u, v = svd(a) +// s, _, _ = svd(a, compute_uv=False) +// ``` +// +// Arguments: +// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. +// +// Returns Singular values. Shape is `[..., P]`.Left singular vectors. If `full_matrices` is `False` then shape is +// `[..., M, P]`; if `full_matrices` is `True` then shape is +// `[..., M, M]`. Undefined if `compute_uv` is `False`.Left singular vectors. If `full_matrices` is `False` then shape is +// `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`. +// Undefined if `compute_uv` is false. +func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.Output, v tf.Output) { if scope.Err() != nil { return } @@ -22699,270 +22641,175 @@ func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_ha a(attrs) } opspec := tf.OpSpec{ - Type: "TextLineReaderV2", - + Type: "Svd", + Input: []tf.Input{ + input, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// LoadAndRemapMatrixAttr is an optional argument to LoadAndRemapMatrix. -type LoadAndRemapMatrixAttr func(optionalAttr) +// QueueEnqueueManyV2Attr is an optional argument to QueueEnqueueManyV2. +type QueueEnqueueManyV2Attr func(optionalAttr) -// LoadAndRemapMatrixMaxRowsInMemory sets the optional max_rows_in_memory attribute to value. +// QueueEnqueueManyV2TimeoutMs sets the optional timeout_ms attribute to value. // -// value: The maximum number of rows to load from the checkpoint at -// once. If less than or equal to 0, the entire matrix will be loaded into -// memory. Setting this arg trades increased disk reads for lower memory usage. +// value: If the queue is too full, this operation will block for up +// to timeout_ms milliseconds. +// Note: This option is not supported yet. // If not specified, defaults to -1 -func LoadAndRemapMatrixMaxRowsInMemory(value int64) LoadAndRemapMatrixAttr { +func QueueEnqueueManyV2TimeoutMs(value int64) QueueEnqueueManyV2Attr { return func(m optionalAttr) { - m["max_rows_in_memory"] = value + m["timeout_ms"] = value } } -// Loads a 2-D (matrix) `Tensor` with name `old_tensor_name` from the checkpoint -// -// at `ckpt_path` and potentially reorders its rows and columns using the -// specified remappings. -// -// Most users should use one of the wrapper initializers (such as -// `tf.contrib.framework.load_and_remap_matrix_initializer`) instead of this -// function directly. -// -// The remappings are 1-D tensors with the following properties: -// -// * `row_remapping` must have exactly `num_rows` entries. Row `i` of the output -// matrix will be initialized from the row corresponding to index -// `row_remapping[i]` in the old `Tensor` from the checkpoint. -// * `col_remapping` must have either 0 entries (indicating that no column -// reordering is needed) or `num_cols` entries. If specified, column `j` of the -// output matrix will be initialized from the column corresponding to index -// `col_remapping[j]` in the old `Tensor` from the checkpoint. -// * A value of -1 in either of the remappings signifies a "missing" entry. In that -// case, values from the `initializing_values` tensor will be used to fill that -// missing row or column. If `row_remapping` has `r` missing entries and -// `col_remapping` has `c` missing entries, then the following condition must be -// true: -// -// `(r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values)` +// Enqueues zero or more tuples of one or more tensors in the given queue. // -// The remapping tensors can be generated using the GenerateVocabRemapping op. +// This operation slices each component tensor along the 0th dimension to +// make multiple queue elements. All of the tuple components must have the +// same size in the 0th dimension. // -// As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1], -// initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing -// the value from row i, column j of the old tensor in the checkpoint, the output -// matrix will look like the following: +// The components input has k elements, which correspond to the components of +// tuples stored in the given queue. // -// [[w(1, 0), w(1, 2), 0.5], -// [w(0, 0), w(0, 2), -0.5], -// [0.25, -0.25, 42]] +// N.B. If the queue is full, this operation will block until the given +// elements have been enqueued (or 'timeout_ms' elapses, if specified). // // Arguments: -// ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from -// which the old matrix `Tensor` will be loaded. -// old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. -// row_remapping: An int `Tensor` of row remappings (generally created by -// `generate_vocab_remapping`). Even if no row remapping is needed, this must -// still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted -// index-valued `Tensor` (e.g. [8, 9, 10, ...], for partitioned `Variables`). -// col_remapping: An int `Tensor` of column remappings (generally created by -// `generate_vocab_remapping`). May be a size-0 `Tensor` if only row remapping -// is to be done (e.g. column ordering is the same). -// initializing_values: A float `Tensor` containing values to fill in for cells -// in the output matrix that are not loaded from the checkpoint. Length must be -// exactly the same as the number of missing / new cells. -// num_rows: Number of rows (length of the 1st dimension) in the output matrix. -// num_cols: Number of columns (length of the 2nd dimension) in the output matrix. +// handle: The handle to a queue. +// components: One or more tensors from which the enqueued tensors should +// be taken. // -// Returns Output matrix containing existing values loaded from the -// checkpoint, and with any missing values filled in from initializing_values. -func LoadAndRemapMatrix(scope *Scope, ckpt_path tf.Output, old_tensor_name tf.Output, row_remapping tf.Output, col_remapping tf.Output, initializing_values tf.Output, num_rows int64, num_cols int64, optional ...LoadAndRemapMatrixAttr) (output_matrix tf.Output) { +// Returns the created operation. +func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueManyV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_rows": num_rows, "num_cols": num_cols} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LoadAndRemapMatrix", + Type: "QueueEnqueueManyV2", Input: []tf.Input{ - ckpt_path, old_tensor_name, row_remapping, col_remapping, initializing_values, + handle, tf.OutputList(components), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// TFRecordReaderV2Attr is an optional argument to TFRecordReaderV2. -type TFRecordReaderV2Attr func(optionalAttr) - -// TFRecordReaderV2Container sets the optional container attribute to value. +// Computes the product along segments of a tensor. // -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func TFRecordReaderV2Container(value string) TFRecordReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// TFRecordReaderV2SharedName sets the optional shared_name attribute to value. +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func TFRecordReaderV2SharedName(value string) TFRecordReaderV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value +// Computes a tensor such that +// \\(output_i = \prod_j data_j\\) where the product is over `j` such +// that `segment_ids[j] == i`. +// +// If the product is empty for a given segment ID `i`, `output[i] = 1`. +// +//
+// +//
+// +// Arguments: +// +// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return } -} - -// TFRecordReaderV2CompressionType sets the optional compression_type attribute to value. -// If not specified, defaults to "" -func TFRecordReaderV2CompressionType(value string) TFRecordReaderV2Attr { - return func(m optionalAttr) { - m["compression_type"] = value + opspec := tf.OpSpec{ + Type: "SegmentProd", + Input: []tf.Input{ + data, segment_ids, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// A Reader that outputs the records from a TensorFlow Records file. +// Converts one or more images from RGB to HSV. // -// Returns The handle to reference the Reader. -func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_handle tf.Output) { +// Outputs a tensor of the same shape as the `images` tensor, containing the HSV +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. +// +// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and +// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 +// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. +// +// Arguments: +// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. +// +// Returns `images` converted to HSV. +func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "TFRecordReaderV2", - - Attrs: attrs, + Type: "RGBToHSV", + Input: []tf.Input{ + images, + }, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. -type QuantizeAndDequantizeV3Attr func(optionalAttr) - -// QuantizeAndDequantizeV3SignedInput sets the optional signed_input attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeV3SignedInput(value bool) QuantizeAndDequantizeV3Attr { - return func(m optionalAttr) { - m["signed_input"] = value +// Does nothing. Only useful as a placeholder for control edges. +// +// Returns the created operation. +func NoOp(scope *Scope) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NoOp", } + return scope.AddOperation(opspec) } -// QuantizeAndDequantizeV3RangeGiven sets the optional range_given attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeV3RangeGiven(value bool) QuantizeAndDequantizeV3Attr { - return func(m optionalAttr) { - m["range_given"] = value - } -} - -// Quantizes then dequantizes a tensor. -// -// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a -// tensor, so its value can change during training. -func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, num_bits tf.Output, optional ...QuantizeAndDequantizeV3Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizeAndDequantizeV3", - Input: []tf.Input{ - input, input_min, input_max, num_bits, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// IdentityReaderV2Attr is an optional argument to IdentityReaderV2. -type IdentityReaderV2Attr func(optionalAttr) - -// IdentityReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func IdentityReaderV2Container(value string) IdentityReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} +// MergeV2CheckpointsAttr is an optional argument to MergeV2Checkpoints. +type MergeV2CheckpointsAttr func(optionalAttr) -// IdentityReaderV2SharedName sets the optional shared_name attribute to value. +// MergeV2CheckpointsDeleteOldDirs sets the optional delete_old_dirs attribute to value. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func IdentityReaderV2SharedName(value string) IdentityReaderV2Attr { +// value: see above. +// If not specified, defaults to true +func MergeV2CheckpointsDeleteOldDirs(value bool) MergeV2CheckpointsAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["delete_old_dirs"] = value } } -// A Reader that outputs the queued work as both the key and value. +// V2 format specific: merges the metadata files of sharded checkpoints. The // -// To use, enqueue strings in a Queue. ReaderRead will take the front -// work string and output (work, work). +// result is one logical checkpoint, with one physical metadata file and renamed +// data files. // -// Returns The handle to reference the Reader. -func IdentityReaderV2(scope *Scope, optional ...IdentityReaderV2Attr) (reader_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "IdentityReaderV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyGradientDescentAttr is an optional argument to ResourceApplyGradientDescent. -type ResourceApplyGradientDescentAttr func(optionalAttr) - -// ResourceApplyGradientDescentUseLocking sets the optional use_locking attribute to value. +// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup. // -// value: If `True`, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyGradientDescentUseLocking(value bool) ResourceApplyGradientDescentAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' by subtracting 'alpha' * 'delta' from it. +// If delete_old_dirs is true, attempts to delete recursively the dirname of each +// path in the input checkpoint_prefixes. This is useful when those paths are non +// user-facing temporary locations. // // Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// delta: The change. +// checkpoint_prefixes: prefixes of V2 checkpoints to merge. +// destination_prefix: scalar. The desired final prefix. Allowed to be the same +// as one of the checkpoint_prefixes. // // Returns the created operation. -func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, delta tf.Output, optional ...ResourceApplyGradientDescentAttr) (o *tf.Operation) { +func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination_prefix tf.Output, optional ...MergeV2CheckpointsAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -22971,271 +22818,345 @@ func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyGradientDescent", + Type: "MergeV2Checkpoints", Input: []tf.Input{ - var_, alpha, delta, + checkpoint_prefixes, destination_prefix, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Returns the next record (key, value pair) produced by a Reader. +// Saves input tensors slices to disk. // -// Will dequeue from the input queue if necessary (e.g. when the -// Reader needs to start reading from a new file since it has finished -// with the previous file). +// This is like `Save` except that tensors can be listed in the saved file as being +// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the +// larger tensor and the slice that this tensor covers. `shapes_and_slices` must +// have as many elements as `tensor_names`. // -// Arguments: -// reader_handle: Handle to a Reader. -// queue_handle: Handle to a Queue, with string work items. +// Elements of the `shapes_and_slices` input must either be: // -// Returns A scalar.A scalar. -func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderReadV2", - Input: []tf.Input{ - reader_handle, queue_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Returns up to `num_records` (key, value) pairs produced by a Reader. +// * The empty string, in which case the corresponding tensor is +// saved normally. +// * A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the +// `dimI` are the dimensions of the larger tensor and `slice-spec` +// specifies what part is covered by the tensor to save. // -// Will dequeue from the input queue if necessary (e.g. when the -// Reader needs to start reading from a new file since it has finished -// with the previous file). -// It may return less than `num_records` even before the last batch. +// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1` +// where each `sliceI` is either: +// +// * The string `-` meaning that the slice covers all indices of this dimension +// * `start,length` where `start` and `length` are integers. In that +// case the slice covers `length` indices starting at `start`. +// +// See also `Save`. // // Arguments: -// reader_handle: Handle to a `Reader`. -// queue_handle: Handle to a `Queue`, with string work items. -// num_records: number of records to read from `Reader`. +// filename: Must have a single element. The name of the file to which we write the +// tensor. +// tensor_names: Shape `[N]`. The names of the tensors to be saved. +// shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when +// saving the tensors. +// data: `N` tensors to save. // -// Returns A 1-D tensor.A 1-D tensor. -func ReaderReadUpToV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output, num_records tf.Output) (keys tf.Output, values tf.Output) { +// Returns the created operation. +func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes_and_slices tf.Output, data []tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderReadUpToV2", + Type: "SaveSlices", Input: []tf.Input{ - reader_handle, queue_handle, num_records, + filename, tensor_names, shapes_and_slices, tf.OutputList(data), }, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. -type ResourceApplyAdamAttr func(optionalAttr) +// DenseToDenseSetOperationAttr is an optional argument to DenseToDenseSetOperation. +type DenseToDenseSetOperationAttr func(optionalAttr) -// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, m, and v tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { +// DenseToDenseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func DenseToDenseSetOperationValidateIndices(value bool) DenseToDenseSetOperationAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["validate_indices"] = value } } -// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. +// Applies set operation along last dimension of 2 `Tensor` inputs. // -// value: If `True`, uses the nesterov update. -// If not specified, defaults to false -func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update '*var' according to the Adam algorithm. +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. // -// lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t -// v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t -// variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon) +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. // // Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// v: Should be from a Variable(). -// beta1_power: Must be a scalar. -// beta2_power: Must be a scalar. -// lr: Scaling factor. Must be a scalar. -// beta1: Momentum factor. Must be a scalar. -// beta2: Momentum factor. Must be a scalar. -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. +// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored. +// set2: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set1`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored. // -// Returns the created operation. -func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { +// +// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func DenseToDenseSetOperation(scope *Scope, set1 tf.Output, set2 tf.Output, set_operation string, optional ...DenseToDenseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"set_operation": set_operation} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdam", + Type: "DenseToDenseSetOperation", Input: []tf.Input{ - var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + set1, set2, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// Store the input tensor in the state of the current session. -// -// Arguments: -// value: The tensor to be stored. +// Generate a sharded filename. The filename is printf formatted as // -// Returns The handle for the tensor stored in the session state, represented -// as a ResourceHandle object. -func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) { +// %s-%05d-of-%05d, basename, shard, num_shards. +func ShardedFilename(scope *Scope, basename tf.Output, shard tf.Output, num_shards tf.Output) (filename tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "GetSessionHandleV2", + Type: "ShardedFilename", Input: []tf.Input{ - value, + basename, shard, num_shards, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad. -type ResizeBicubicGradAttr func(optionalAttr) - -// ResizeBicubicGradAlignCorners sets the optional align_corners attribute to value. +// BatchToSpace for N-D tensors of type T. // -// value: If true, the centers of the 4 corner pixels of the input and grad tensors are -// aligned. Defaults to false. -// If not specified, defaults to false -func ResizeBicubicGradAlignCorners(value bool) ResizeBicubicGradAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Computes the gradient of bicubic interpolation. +// This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape +// `block_shape + [batch]`, interleaves these blocks back into the grid defined by +// the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as +// the input. The spatial dimensions of this intermediate result are then +// optionally cropped according to `crops` to produce the output. This is the +// reverse of SpaceToBatch. See below for a precise description. // // Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, -// The image tensor that was resized. +// input: N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`, +// where spatial_shape has M dimensions. +// block_shape: 1-D with shape `[M]`, all values must be >= 1. +// crops: 2-D with shape `[M, 2]`, all values must be >= 0. +// `crops[i] = [crop_start, crop_end]` specifies the amount to crop from input +// dimension `i + 1`, which corresponds to spatial dimension `i`. It is +// required that +// `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`. // -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. -// Gradients with respect to the input image. Input image must have been -// float or double. -func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBicubicGradAttr) (output tf.Output) { +// This operation is equivalent to the following steps: +// +// 1. Reshape `input` to `reshaped` of shape: +// [block_shape[0], ..., block_shape[M-1], +// batch / prod(block_shape), +// input_shape[1], ..., input_shape[N-1]] +// +// 2. Permute dimensions of `reshaped` to produce `permuted` of shape +// [batch / prod(block_shape), +// +// input_shape[1], block_shape[0], +// ..., +// input_shape[M], block_shape[M-1], +// +// input_shape[M+1], ..., input_shape[N-1]] +// +// 3. Reshape `permuted` to produce `reshaped_permuted` of shape +// [batch / prod(block_shape), +// +// input_shape[1] * block_shape[0], +// ..., +// input_shape[M] * block_shape[M-1], +// +// input_shape[M+1], +// ..., +// input_shape[N-1]] +// +// 4. Crop the start and end of dimensions `[1, ..., M]` of +// `reshaped_permuted` according to `crops` to produce the output of shape: +// [batch / prod(block_shape), +// +// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], +// ..., +// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], +// +// input_shape[M+1], ..., input_shape[N-1]] +// +// Some examples: +// +// (1) For the following input of shape `[4, 1, 1, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: +// +// ``` +// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 1]` and value: +// +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// (2) For the following input of shape `[4, 1, 1, 3]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: +// +// ``` +// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 3]` and value: +// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// (3) For the following input of shape `[4, 2, 2, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: +// +// ``` +// x = [[[[1], [3]], [[9], [11]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +// +// The output tensor has shape `[1, 4, 4, 1]` and value: +// +// ``` +// x = [[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]] +// ``` +// +// (4) For the following input of shape `[8, 1, 3, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [2, 0]]`: +// +// ``` +// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], +// [[[0], [2], [4]]], [[[0], [10], [12]]], +// [[[0], [5], [7]]], [[[0], [13], [15]]], +// [[[0], [6], [8]]], [[[0], [14], [16]]]] +// ``` +// +// The output tensor has shape `[2, 2, 4, 1]` and value: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]]], +// [[[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +func BatchToSpaceND(scope *Scope, input tf.Output, block_shape tf.Output, crops tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResizeBicubicGrad", + Type: "BatchToSpaceND", Input: []tf.Input{ - grads, original_image, + input, block_shape, crops, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. -type ResizeNearestNeighborAttr func(optionalAttr) +// UnpackAttr is an optional argument to Unpack. +type UnpackAttr func(optionalAttr) -// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. +// UnpackAxis sets the optional axis attribute to value. // -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { +// value: Dimension along which to unpack. Negative values wrap around, so the +// valid range is `[-R, R)`. +// If not specified, defaults to 0 +func UnpackAxis(value int64) UnpackAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["axis"] = value } } -// Resize `images` to `size` using nearest neighbor interpolation. +// Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. +// +// Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. +// For example, given a tensor of shape `(A, B, C, D)`; +// +// If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` +// and each tensor in `output` will have shape `(B, C, D)`. (Note that the +// dimension unpacked along is gone, unlike `split`). +// +// If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` +// and each tensor in `output` will have shape `(A, C, D)`. +// Etc. +// +// This is the opposite of `pack`. // // Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// value: 1-D or higher, with `axis` dimension size equal to `num`. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { +// +// Returns The list of tensors unpacked from `value`. +func Unpack(scope *Scope, value tf.Output, num int64, optional ...UnpackAttr) (output []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num": num} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeNearestNeighbor", + Type: "Unpack", Input: []tf.Input{ - images, size, + value, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. -type ResizeNearestNeighborGradAttr func(optionalAttr) - -// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and grad tensors are -// aligned. Defaults to false. -// If not specified, defaults to false -func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { - return func(m optionalAttr) { - m["align_corners"] = value + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("Unpack", err) + return } + return output } -// Computes the gradient of nearest neighbor interpolation. +// Increments variable pointed to by 'resource' until it reaches 'limit'. // // Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The -// original input size. +// resource: Should be from a scalar `Variable` node. +// limit: If incrementing ref would bring it above limit, instead generates an +// 'OutOfRange' error. // -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients -// with respect to the input image. -func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { +// +// Returns A copy of the input before increment. If nothing else modifies the +// input, the values produced will all be distinct. +func ResourceCountUpTo(scope *Scope, resource tf.Output, limit int64, T tf.DataType) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"limit": limit, "T": T} opspec := tf.OpSpec{ - Type: "ResizeNearestNeighborGrad", + Type: "ResourceCountUpTo", Input: []tf.Input{ - grads, size, + resource, }, Attrs: attrs, } @@ -23243,122 +23164,88 @@ func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, op return op.Output(0) } -// ExtractJpegShapeAttr is an optional argument to ExtractJpegShape. -type ExtractJpegShapeAttr func(optionalAttr) - -// ExtractJpegShapeOutputType sets the optional output_type attribute to value. -// -// value: (Optional) The output type of the operation (int32 or int64). -// Defaults to int32. -// If not specified, defaults to DT_INT32 -func ExtractJpegShapeOutputType(value tf.DataType) ExtractJpegShapeAttr { - return func(m optionalAttr) { - m["output_type"] = value - } -} - -// Extract the shape information of a JPEG-encoded image. -// -// This op only parses the image header, so it is much faster than DecodeJpeg. +// Delete the stack from its resource container. // // Arguments: -// contents: 0-D. The JPEG-encoded image. +// handle: The handle to a stack. // -// Returns 1-D. The image shape with format [height, width, channels]. -func ExtractJpegShape(scope *Scope, contents tf.Output, optional ...ExtractJpegShapeAttr) (image_shape tf.Output) { +// Returns the created operation. +func StackCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) + opspec := tf.OpSpec{ + Type: "StackCloseV2", + Input: []tf.Input{ + handle, + }, + } + return scope.AddOperation(opspec) +} + +// Generate a glob pattern matching all sharded file names. +func ShardedFilespec(scope *Scope, basename tf.Output, num_shards tf.Output) (filename tf.Output) { + if scope.Err() != nil { + return } opspec := tf.OpSpec{ - Type: "ExtractJpegShape", + Type: "ShardedFilespec", Input: []tf.Input{ - contents, + basename, num_shards, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// PaddingFIFOQueueV2Attr is an optional argument to PaddingFIFOQueueV2. -type PaddingFIFOQueueV2Attr func(optionalAttr) - -// PaddingFIFOQueueV2Shapes sets the optional shapes attribute to value. -// -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. -// Shapes of fixed rank but variable size are allowed by setting -// any shape dimension to -1. In this case, the inputs' shape may vary along -// the given dimension, and DequeueMany will pad the given dimension with -// zeros up to the maximum shape of all elements in the given batch. -// If the length of this attr is 0, different queue elements may have -// different ranks and shapes, but only one element may be dequeued at a time. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func PaddingFIFOQueueV2Shapes(value []tf.Shape) PaddingFIFOQueueV2Attr { - return func(m optionalAttr) { - m["shapes"] = value - } -} +// TextLineReaderV2Attr is an optional argument to TextLineReaderV2. +type TextLineReaderV2Attr func(optionalAttr) -// PaddingFIFOQueueV2Capacity sets the optional capacity attribute to value. +// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value. // -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func PaddingFIFOQueueV2Capacity(value int64) PaddingFIFOQueueV2Attr { +// value: Number of lines to skip from the beginning of every file. +// If not specified, defaults to 0 +func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr { return func(m optionalAttr) { - m["capacity"] = value + m["skip_header_lines"] = value } } -// PaddingFIFOQueueV2Container sets the optional container attribute to value. +// TextLineReaderV2Container sets the optional container attribute to value. // -// value: If non-empty, this queue is placed in the given container. +// value: If non-empty, this reader is placed in the given container. // Otherwise, a default container is used. // If not specified, defaults to "" -func PaddingFIFOQueueV2Container(value string) PaddingFIFOQueueV2Attr { +func TextLineReaderV2Container(value string) TextLineReaderV2Attr { return func(m optionalAttr) { m["container"] = value } } -// PaddingFIFOQueueV2SharedName sets the optional shared_name attribute to value. +// TextLineReaderV2SharedName sets the optional shared_name attribute to value. // -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. // If not specified, defaults to "" -func PaddingFIFOQueueV2SharedName(value string) PaddingFIFOQueueV2Attr { +func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr { return func(m optionalAttr) { m["shared_name"] = value } } -// A queue that produces elements in first-in first-out order. -// -// Variable-size shapes are allowed by setting the corresponding shape dimensions -// to 0 in the shape attr. In this case DequeueMany will pad up to the maximum -// size of any given element in the minibatch. See below for details. -// -// Arguments: -// component_types: The type of each component in a value. +// A Reader that outputs the lines of a file delimited by '\n'. // -// Returns The handle to the queue. -func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...PaddingFIFOQueueV2Attr) (handle tf.Output) { +// Returns The handle to reference the Reader. +func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "PaddingFIFOQueueV2", + Type: "TextLineReaderV2", Attrs: attrs, } @@ -23366,61 +23253,89 @@ func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional .. return op.Output(0) } -// DecodePngAttr is an optional argument to DecodePng. -type DecodePngAttr func(optionalAttr) +// LoadAndRemapMatrixAttr is an optional argument to LoadAndRemapMatrix. +type LoadAndRemapMatrixAttr func(optionalAttr) -// DecodePngChannels sets the optional channels attribute to value. +// LoadAndRemapMatrixMaxRowsInMemory sets the optional max_rows_in_memory attribute to value. // -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodePngChannels(value int64) DecodePngAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodePngDtype sets the optional dtype attribute to value. -// If not specified, defaults to DT_UINT8 -func DecodePngDtype(value tf.DataType) DecodePngAttr { +// value: The maximum number of rows to load from the checkpoint at +// once. If less than or equal to 0, the entire matrix will be loaded into +// memory. Setting this arg trades increased disk reads for lower memory usage. +// If not specified, defaults to -1 +func LoadAndRemapMatrixMaxRowsInMemory(value int64) LoadAndRemapMatrixAttr { return func(m optionalAttr) { - m["dtype"] = value + m["max_rows_in_memory"] = value } } -// Decode a PNG-encoded image to a uint8 or uint16 tensor. +// Loads a 2-D (matrix) `Tensor` with name `old_tensor_name` from the checkpoint // -// The attr `channels` indicates the desired number of color channels for the -// decoded image. +// at `ckpt_path` and potentially reorders its rows and columns using the +// specified remappings. // -// Accepted values are: +// Most users should use one of the wrapper initializers (such as +// `tf.contrib.framework.load_and_remap_matrix_initializer`) instead of this +// function directly. // -// * 0: Use the number of channels in the PNG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// * 4: output an RGBA image. +// The remappings are 1-D tensors with the following properties: // -// If needed, the PNG-encoded image is transformed to match the requested number -// of color channels. +// * `row_remapping` must have exactly `num_rows` entries. Row `i` of the output +// matrix will be initialized from the row corresponding to index +// `row_remapping[i]` in the old `Tensor` from the checkpoint. +// * `col_remapping` must have either 0 entries (indicating that no column +// reordering is needed) or `num_cols` entries. If specified, column `j` of the +// output matrix will be initialized from the column corresponding to index +// `col_remapping[j]` in the old `Tensor` from the checkpoint. +// * A value of -1 in either of the remappings signifies a "missing" entry. In that +// case, values from the `initializing_values` tensor will be used to fill that +// missing row or column. If `row_remapping` has `r` missing entries and +// `col_remapping` has `c` missing entries, then the following condition must be +// true: // -// This op also supports decoding JPEGs and non-animated GIFs since the interface -// is the same, though it is cleaner to use `tf.image.decode_image`. +// `(r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values)` +// +// The remapping tensors can be generated using the GenerateVocabRemapping op. +// +// As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1], +// initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing +// the value from row i, column j of the old tensor in the checkpoint, the output +// matrix will look like the following: +// +// [[w(1, 0), w(1, 2), 0.5], +// [w(0, 0), w(0, 2), -0.5], +// [0.25, -0.25, 42]] // // Arguments: -// contents: 0-D. The PNG-encoded image. +// ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from +// which the old matrix `Tensor` will be loaded. +// old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. +// row_remapping: An int `Tensor` of row remappings (generally created by +// `generate_vocab_remapping`). Even if no row remapping is needed, this must +// still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted +// index-valued `Tensor` (e.g. [8, 9, 10, ...], for partitioned `Variables`). +// col_remapping: An int `Tensor` of column remappings (generally created by +// `generate_vocab_remapping`). May be a size-0 `Tensor` if only row remapping +// is to be done (e.g. column ordering is the same). +// initializing_values: A float `Tensor` containing values to fill in for cells +// in the output matrix that are not loaded from the checkpoint. Length must be +// exactly the same as the number of missing / new cells. +// num_rows: Number of rows (length of the 1st dimension) in the output matrix. +// num_cols: Number of columns (length of the 2nd dimension) in the output matrix. // -// Returns 3-D with shape `[height, width, channels]`. -func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { +// Returns Output matrix containing existing values loaded from the +// checkpoint, and with any missing values filled in from initializing_values. +func LoadAndRemapMatrix(scope *Scope, ckpt_path tf.Output, old_tensor_name tf.Output, row_remapping tf.Output, col_remapping tf.Output, initializing_values tf.Output, num_rows int64, num_cols int64, optional ...LoadAndRemapMatrixAttr) (output_matrix tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_rows": num_rows, "num_cols": num_cols} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodePng", + Type: "LoadAndRemapMatrix", Input: []tf.Input{ - contents, + ckpt_path, old_tensor_name, row_remapping, col_remapping, initializing_values, }, Attrs: attrs, } @@ -23428,256 +23343,350 @@ func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (ima return op.Output(0) } -// Decode the first frame of a GIF-encoded image to a uint8 tensor. -// -// GIF with frame or transparency compression are not supported -// convert animated GIF from compressed to uncompressed by: -// -// convert $src.gif -coalesce $dst.gif +// TFRecordReaderV2Attr is an optional argument to TFRecordReaderV2. +type TFRecordReaderV2Attr func(optionalAttr) + +// TFRecordReaderV2Container sets the optional container attribute to value. // -// This op also supports decoding JPEGs and PNGs, though it is cleaner to use -// `tf.image.decode_image`. +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func TFRecordReaderV2Container(value string) TFRecordReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// TFRecordReaderV2SharedName sets the optional shared_name attribute to value. // -// Arguments: -// contents: 0-D. The GIF-encoded image. +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func TFRecordReaderV2SharedName(value string) TFRecordReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// TFRecordReaderV2CompressionType sets the optional compression_type attribute to value. +// If not specified, defaults to "" +func TFRecordReaderV2CompressionType(value string) TFRecordReaderV2Attr { + return func(m optionalAttr) { + m["compression_type"] = value + } +} + +// A Reader that outputs the records from a TensorFlow Records file. // -// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order -func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { +// Returns The handle to reference the Reader. +func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DecodeGif", - Input: []tf.Input{ - contents, - }, + Type: "TFRecordReaderV2", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the gradient of the sigmoid of `x` wrt its input. -// -// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and -// `dy` is the corresponding input gradient. -func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } +// QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. +type QuantizeAndDequantizeV3Attr func(optionalAttr) + +// QuantizeAndDequantizeV3SignedInput sets the optional signed_input attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeV3SignedInput(value bool) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["signed_input"] = value + } +} + +// QuantizeAndDequantizeV3RangeGiven sets the optional range_given attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeV3RangeGiven(value bool) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["range_given"] = value + } +} + +// Quantizes then dequantizes a tensor. +// +// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a +// tensor, so its value can change during training. +func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, num_bits tf.Output, optional ...QuantizeAndDequantizeV3Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SigmoidGrad", + Type: "QuantizeAndDequantizeV3", Input: []tf.Input{ - y, dy, + input, input_min, input_max, num_bits, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Convert one or more images from HSV to RGB. +// IdentityReaderV2Attr is an optional argument to IdentityReaderV2. +type IdentityReaderV2Attr func(optionalAttr) + +// IdentityReaderV2Container sets the optional container attribute to value. // -// Outputs a tensor of the same shape as the `images` tensor, containing the RGB -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func IdentityReaderV2Container(value string) IdentityReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// IdentityReaderV2SharedName sets the optional shared_name attribute to value. // -// See `rgb_to_hsv` for a description of the HSV encoding. +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func IdentityReaderV2SharedName(value string) IdentityReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A Reader that outputs the queued work as both the key and value. // -// Arguments: -// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// To use, enqueue strings in a Queue. ReaderRead will take the front +// work string and output (work, work). // -// Returns `images` converted to RGB. -func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { +// Returns The handle to reference the Reader. +func IdentityReaderV2(scope *Scope, optional ...IdentityReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "HSVToRGB", - Input: []tf.Input{ - images, - }, + Type: "IdentityReaderV2", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics. +// ResourceApplyGradientDescentAttr is an optional argument to ResourceApplyGradientDescent. +type ResourceApplyGradientDescentAttr func(optionalAttr) + +// ResourceApplyGradientDescentUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyGradientDescentUseLocking(value bool) ResourceApplyGradientDescentAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' by subtracting 'alpha' * 'delta' from it. // // Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// delta: The change. // -// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest -// layer. -func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) { +// Returns the created operation. +func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, delta tf.Output, optional ...ResourceApplyGradientDescentAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BoostedTreesGetEnsembleStates", + Type: "ResourceApplyGradientDescent", Input: []tf.Input{ - tree_ensemble_handle, + var_, alpha, delta, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return scope.AddOperation(opspec) } -// Gets the next output from the given iterator. +// Returns the next record (key, value pair) produced by a Reader. // -// This operation is a synchronous version IteratorGetNext. It should only be used -// in situations where the iterator does not block the calling thread, or where -// the calling thread is not a member of the thread pool used to execute parallel -// operations (e.g. in eager mode). -func IteratorGetNextSync(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { +// Will dequeue from the input queue if necessary (e.g. when the +// Reader needs to start reading from a new file since it has finished +// with the previous file). +// +// Arguments: +// reader_handle: Handle to a Reader. +// queue_handle: Handle to a Queue, with string work items. +// +// Returns A scalar.A scalar. +func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "IteratorGetNextSync", + Type: "ReaderReadV2", Input: []tf.Input{ - iterator, + reader_handle, queue_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Returns up to `num_records` (key, value) pairs produced by a Reader. +// +// Will dequeue from the input queue if necessary (e.g. when the +// Reader needs to start reading from a new file since it has finished +// with the previous file). +// It may return less than `num_records` even before the last batch. +// +// Arguments: +// reader_handle: Handle to a `Reader`. +// queue_handle: Handle to a `Queue`, with string work items. +// num_records: number of records to read from `Reader`. +// +// Returns A 1-D tensor.A 1-D tensor. +func ReaderReadUpToV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output, num_records tf.Output) (keys tf.Output, values tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("IteratorGetNextSync", err) - return + opspec := tf.OpSpec{ + Type: "ReaderReadUpToV2", + Input: []tf.Input{ + reader_handle, queue_handle, num_records, + }, } - return components + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// SampleDistortedBoundingBoxV2Attr is an optional argument to SampleDistortedBoundingBoxV2. -type SampleDistortedBoundingBoxV2Attr func(optionalAttr) +// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. +type ResourceApplyAdamAttr func(optionalAttr) -// SampleDistortedBoundingBoxV2Seed sets the optional seed attribute to value. +// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. // -// value: If either `seed` or `seed2` are set to non-zero, the random number -// generator is seeded by the given `seed`. Otherwise, it is seeded by a random -// seed. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxV2Seed(value int64) SampleDistortedBoundingBoxV2Attr { +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { return func(m optionalAttr) { - m["seed"] = value + m["use_locking"] = value } } -// SampleDistortedBoundingBoxV2Seed2 sets the optional seed2 attribute to value. +// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2Attr { +// value: If `True`, uses the nesterov update. +// If not specified, defaults to false +func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { return func(m optionalAttr) { - m["seed2"] = value + m["use_nesterov"] = value } } -// SampleDistortedBoundingBoxV2AspectRatioRange sets the optional aspect_ratio_range attribute to value. +// Update '*var' according to the Adam algorithm. // -// value: The cropped area of the image must have an aspect ratio = -// width / height within this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { - return func(m optionalAttr) { - m["aspect_ratio_range"] = value - } -} - -// SampleDistortedBoundingBoxV2AreaRange sets the optional area_range attribute to value. +// lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t +// v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t +// variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon) // -// value: The cropped area of the image must contain a fraction of the -// supplied image within in this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { - return func(m optionalAttr) { - m["area_range"] = value +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// v: Should be from a Variable(). +// beta1_power: Must be a scalar. +// beta2_power: Must be a scalar. +// lr: Scaling factor. Must be a scalar. +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdam", + Input: []tf.Input{ + var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) } -// SampleDistortedBoundingBoxV2MaxAttempts sets the optional max_attempts attribute to value. +// Store the input tensor in the state of the current session. // -// value: Number of attempts at generating a cropped region of the image -// of the specified constraints. After `max_attempts` failures, return the entire -// image. -// If not specified, defaults to 100 -func SampleDistortedBoundingBoxV2MaxAttempts(value int64) SampleDistortedBoundingBoxV2Attr { - return func(m optionalAttr) { - m["max_attempts"] = value +// Arguments: +// value: The tensor to be stored. +// +// Returns The handle for the tensor stored in the session state, represented +// as a ResourceHandle object. +func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "GetSessionHandleV2", + Input: []tf.Input{ + value, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. +// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad. +type ResizeBicubicGradAttr func(optionalAttr) + +// ResizeBicubicGradAlignCorners sets the optional align_corners attribute to value. // -// value: Controls behavior if no bounding boxes supplied. -// If true, assume an implicit bounding box covering the whole input. If false, -// raise an error. +// value: If true, the centers of the 4 corner pixels of the input and grad tensors are +// aligned. Defaults to false. // If not specified, defaults to false -func SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxV2Attr { +func ResizeBicubicGradAlignCorners(value bool) ResizeBicubicGradAttr { return func(m optionalAttr) { - m["use_image_if_no_bounding_boxes"] = value + m["align_corners"] = value } } -// Generate a single randomly distorted bounding box for an image. -// -// Bounding box annotations are often supplied in addition to ground-truth labels -// in image recognition or object localization tasks. A common technique for -// training such a system is to randomly distort an image while preserving -// its content, i.e. *data augmentation*. This Op outputs a randomly distorted -// localization of an object, i.e. bounding box, given an `image_size`, -// `bounding_boxes` and a series of constraints. -// -// The output of this Op is a single bounding box that may be used to crop the -// original image. The output is returned as 3 tensors: `begin`, `size` and -// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the -// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize -// what the bounding box looks like. -// -// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. -// -// For example, -// -// ```python -// # Generate a single distorted bounding box. -// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( -// tf.shape(image), -// bounding_boxes=bounding_boxes) -// -// # Draw the bounding box in an image summary. -// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), -// bbox_for_draw) -// tf.summary.image('images_with_box', image_with_box) -// -// # Employ the bounding box to distort the image. -// distorted_image = tf.slice(image, begin, size) -// ``` -// -// Note that if no bounding box information is available, setting -// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit -// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is -// false and no bounding boxes are supplied, an error is raised. +// Computes the gradient of bicubic interpolation. // // Arguments: -// image_size: 1-D, containing `[height, width, channels]`. -// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes -// associated with the image. -// min_object_covered: The cropped area of the image must contain at least this -// fraction of any bounding box supplied. The value of this parameter should be -// non-negative. In the case of 0, the cropped area does not need to overlap -// any of the bounding boxes supplied. +// grads: 4-D with shape `[batch, height, width, channels]`. +// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, +// The image tensor that was resized. // -// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to -// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to -// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. -// Provide as input to `tf.image.draw_bounding_boxes`. -func SampleDistortedBoundingBoxV2(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, min_object_covered tf.Output, optional ...SampleDistortedBoundingBoxV2Attr) (begin tf.Output, size tf.Output, bboxes tf.Output) { +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. +// Gradients with respect to the input image. Input image must have been +// float or double. +func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBicubicGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -23686,88 +23695,40 @@ func SampleDistortedBoundingBoxV2(scope *Scope, image_size tf.Output, bounding_b a(attrs) } opspec := tf.OpSpec{ - Type: "SampleDistortedBoundingBoxV2", + Type: "ResizeBicubicGrad", Input: []tf.Input{ - image_size, bounding_boxes, min_object_covered, + grads, original_image, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// ExtractGlimpseAttr is an optional argument to ExtractGlimpse. -type ExtractGlimpseAttr func(optionalAttr) - -// ExtractGlimpseCentered sets the optional centered attribute to value. -// -// value: indicates if the offset coordinates are centered relative to -// the image, in which case the (0, 0) offset is relative to the center -// of the input images. If false, the (0,0) offset corresponds to the -// upper left corner of the input images. -// If not specified, defaults to true -func ExtractGlimpseCentered(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["centered"] = value - } + return op.Output(0) } -// ExtractGlimpseNormalized sets the optional normalized attribute to value. -// -// value: indicates if the offset coordinates are normalized. -// If not specified, defaults to true -func ExtractGlimpseNormalized(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["normalized"] = value - } -} +// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. +type ResizeNearestNeighborAttr func(optionalAttr) -// ExtractGlimpseUniformNoise sets the optional uniform_noise attribute to value. +// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. // -// value: indicates if the noise should be generated using a -// uniform distribution or a Gaussian distribution. -// If not specified, defaults to true -func ExtractGlimpseUniformNoise(value bool) ExtractGlimpseAttr { +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { return func(m optionalAttr) { - m["uniform_noise"] = value + m["align_corners"] = value } } -// Extracts a glimpse from the input tensor. -// -// Returns a set of windows called glimpses extracted at location -// `offsets` from the input tensor. If the windows only partially -// overlaps the inputs, the non overlapping areas will be filled with -// random noise. -// -// The result is a 4-D tensor of shape `[batch_size, glimpse_height, -// glimpse_width, channels]`. The channels and batch dimensions are the -// same as that of the input tensor. The height and width of the output -// windows are specified in the `size` parameter. -// -// The argument `normalized` and `centered` controls how the windows are built: -// -// * If the coordinates are normalized but not centered, 0.0 and 1.0 -// correspond to the minimum and maximum of each height and width -// dimension. -// * If the coordinates are both normalized and centered, they range from -// -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper -// left corner, the lower right corner is located at (1.0, 1.0) and the -// center is at (0, 0). -// * If the coordinates are not normalized they are interpreted as -// numbers of pixels. +// Resize `images` to `size` using nearest neighbor interpolation. // // Arguments: -// input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. -// size: A 1-D tensor of 2 elements containing the size of the glimpses -// to extract. The glimpse height must be specified first, following -// by the glimpse width. -// offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing -// the y, x locations of the center of each window. +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. // -// Returns A tensor representing the glimpses `[batch_size, -// glimpse_height, glimpse_width, channels]`. -func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Output, optional ...ExtractGlimpseAttr) (glimpse tf.Output) { +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { if scope.Err() != nil { return } @@ -23776,9 +23737,9 @@ func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Ou a(attrs) } opspec := tf.OpSpec{ - Type: "ExtractGlimpse", + Type: "ResizeNearestNeighbor", Input: []tf.Input{ - input, size, offsets, + images, size, }, Attrs: attrs, } @@ -23786,72 +23747,41 @@ func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Ou return op.Output(0) } -// A container for an iterator resource. -// -// Returns A handle to the iterator that can be passed to a "MakeIterator" -// or "IteratorGetNext" op. -func Iterator(scope *Scope, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "Iterator", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. -type CropAndResizeGradImageAttr func(optionalAttr) +// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. +type ResizeNearestNeighborGradAttr func(optionalAttr) -// CropAndResizeGradImageMethod sets the optional method attribute to value. +// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. // -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { +// value: If true, the centers of the 4 corner pixels of the input and grad tensors are +// aligned. Defaults to false. +// If not specified, defaults to false +func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { return func(m optionalAttr) { - m["method"] = value + m["align_corners"] = value } } -// Computes the gradient of the crop_and_resize op wrt the input image tensor. +// Computes the gradient of nearest neighbor interpolation. // // Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` -// containing the original image size. Both `image_height` and `image_width` need -// to be positive. -// +// grads: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The +// original input size. // -// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients +// with respect to the input image. +func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"T": T} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "CropAndResizeGradImage", + Type: "ResizeNearestNeighborGrad", Input: []tf.Input{ - grads, boxes, box_ind, image_size, + grads, size, }, Attrs: attrs, } @@ -23859,48 +23789,40 @@ func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ return op.Output(0) } -// ShuffleDatasetAttr is an optional argument to ShuffleDataset. -type ShuffleDatasetAttr func(optionalAttr) +// ExtractJpegShapeAttr is an optional argument to ExtractJpegShape. +type ExtractJpegShapeAttr func(optionalAttr) -// ShuffleDatasetReshuffleEachIteration sets the optional reshuffle_each_iteration attribute to value. +// ExtractJpegShapeOutputType sets the optional output_type attribute to value. // -// value: If true, each iterator over this dataset will be given -// a different pseudorandomly generated seed, based on a sequence seeded by the -// `seed` and `seed2` inputs. If false, each iterator will be given the same -// seed, and repeated iteration over this dataset will yield the exact same -// sequence of results. -// If not specified, defaults to true -func ShuffleDatasetReshuffleEachIteration(value bool) ShuffleDatasetAttr { +// value: (Optional) The output type of the operation (int32 or int64). +// Defaults to int32. +// If not specified, defaults to DT_INT32 +func ExtractJpegShapeOutputType(value tf.DataType) ExtractJpegShapeAttr { return func(m optionalAttr) { - m["reshuffle_each_iteration"] = value + m["output_type"] = value } } -// Creates a dataset that shuffles elements from `input_dataset` pseudorandomly. -// -// Arguments: +// Extract the shape information of a JPEG-encoded image. // -// buffer_size: The number of output elements to buffer in an iterator over -// this dataset. Compare with the `min_after_dequeue` attr when creating a -// `RandomShuffleQueue`. -// seed: A scalar seed for the random number generator. If either `seed` or -// `seed2` is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. +// This op only parses the image header, so it is much faster than DecodeJpeg. // +// Arguments: +// contents: 0-D. The JPEG-encoded image. // -func ShuffleDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ShuffleDatasetAttr) (handle tf.Output) { +// Returns 1-D. The image shape with format [height, width, channels]. +func ExtractJpegShape(scope *Scope, contents tf.Output, optional ...ExtractJpegShapeAttr) (image_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ShuffleDataset", + Type: "ExtractJpegShape", Input: []tf.Input{ - input_dataset, buffer_size, seed, seed2, + contents, }, Attrs: attrs, } @@ -23908,69 +23830,132 @@ func ShuffleDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output return op.Output(0) } -// 3D fast Fourier transform. +// PaddingFIFOQueueV2Attr is an optional argument to PaddingFIFOQueueV2. +type PaddingFIFOQueueV2Attr func(optionalAttr) + +// PaddingFIFOQueueV2Shapes sets the optional shapes attribute to value. // -// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 -// dimensions of `input`. +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. +// Shapes of fixed rank but variable size are allowed by setting +// any shape dimension to -1. In this case, the inputs' shape may vary along +// the given dimension, and DequeueMany will pad the given dimension with +// zeros up to the maximum shape of all elements in the given batch. +// If the length of this attr is 0, different queue elements may have +// different ranks and shapes, but only one element may be dequeued at a time. +// If not specified, defaults to <> // -// Arguments: -// input: A complex64 tensor. +// REQUIRES: len(value) >= 0 +func PaddingFIFOQueueV2Shapes(value []tf.Shape) PaddingFIFOQueueV2Attr { + return func(m optionalAttr) { + m["shapes"] = value + } +} + +// PaddingFIFOQueueV2Capacity sets the optional capacity attribute to value. // -// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their 3D Fourier transform. +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func PaddingFIFOQueueV2Capacity(value int64) PaddingFIFOQueueV2Attr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// PaddingFIFOQueueV2Container sets the optional container attribute to value. // -// @compatibility(numpy) -// Equivalent to np.fft.fftn with 3 dimensions. -// @end_compatibility -func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func PaddingFIFOQueueV2Container(value string) PaddingFIFOQueueV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// PaddingFIFOQueueV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. +// If not specified, defaults to "" +func PaddingFIFOQueueV2SharedName(value string) PaddingFIFOQueueV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A queue that produces elements in first-in first-out order. +// +// Variable-size shapes are allowed by setting the corresponding shape dimensions +// to 0 in the shape attr. In this case DequeueMany will pad up to the maximum +// size of any given element in the minibatch. See below for details. +// +// Arguments: +// component_types: The type of each component in a value. +// +// Returns The handle to the queue. +func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...PaddingFIFOQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "FFT3D", - Input: []tf.Input{ - input, - }, + Type: "PaddingFIFOQueueV2", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// CropAndResizeGradBoxesAttr is an optional argument to CropAndResizeGradBoxes. -type CropAndResizeGradBoxesAttr func(optionalAttr) +// DecodePngAttr is an optional argument to DecodePng. +type DecodePngAttr func(optionalAttr) -// CropAndResizeGradBoxesMethod sets the optional method attribute to value. +// DecodePngChannels sets the optional channels attribute to value. // -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradBoxesMethod(value string) CropAndResizeGradBoxesAttr { +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodePngChannels(value int64) DecodePngAttr { return func(m optionalAttr) { - m["method"] = value + m["channels"] = value } } -// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. +// DecodePngDtype sets the optional dtype attribute to value. +// If not specified, defaults to DT_UINT8 +func DecodePngDtype(value tf.DataType) DecodePngAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Decode a PNG-encoded image to a uint8 or uint16 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the PNG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// * 4: output an RGBA image. +// +// If needed, the PNG-encoded image is transformed to match the requested number +// of color channels. +// +// This op also supports decoding JPEGs and non-animated GIFs since the interface +// is the same, though it is cleaner to use `tf.image.decode_image`. // // Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -// Both `image_height` and `image_width` need to be positive. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// contents: 0-D. The PNG-encoded image. // -// Returns A 2-D tensor of shape `[num_boxes, 4]`. -func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxes tf.Output, box_ind tf.Output, optional ...CropAndResizeGradBoxesAttr) (output tf.Output) { +// Returns 3-D with shape `[height, width, channels]`. +func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { if scope.Err() != nil { return } @@ -23979,9 +23964,9 @@ func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxe a(attrs) } opspec := tf.OpSpec{ - Type: "CropAndResizeGradBoxes", + Type: "DecodePng", Input: []tf.Input{ - grads, image, boxes, box_ind, + contents, }, Attrs: attrs, } @@ -23989,559 +23974,484 @@ func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxe return op.Output(0) } -// Saves tensors in V2 checkpoint format. +// Decode the first frame of a GIF-encoded image to a uint8 tensor. // -// By default, saves the named tensors in full. If the caller wishes to save -// specific slices of full tensors, "shape_and_slices" should be non-empty strings -// and correspondingly well-formed. +// GIF with frame or transparency compression are not supported +// convert animated GIF from compressed to uncompressed by: +// +// convert $src.gif -coalesce $dst.gif +// +// This op also supports decoding JPEGs and PNGs, though it is cleaner to use +// `tf.image.decode_image`. // // Arguments: -// prefix: Must have a single element. The prefix of the V2 checkpoint to which we -// write the tensors. -// tensor_names: shape {N}. The names of the tensors to be saved. -// shape_and_slices: shape {N}. The slice specs of the tensors to be saved. -// Empty strings indicate that they are non-partitioned tensors. -// tensors: `N` tensors to save. +// contents: 0-D. The GIF-encoded image. // -// Returns the created operation. -func SaveV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, tensors []tf.Output) (o *tf.Operation) { +// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order +func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SaveV2", + Type: "DecodeGif", Input: []tf.Input{ - prefix, tensor_names, shape_and_slices, tf.OutputList(tensors), + contents, }, } - return scope.AddOperation(opspec) -} - -// StatsAggregatorHandleAttr is an optional argument to StatsAggregatorHandle. -type StatsAggregatorHandleAttr func(optionalAttr) - -// StatsAggregatorHandleContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StatsAggregatorHandleContainer(value string) StatsAggregatorHandleAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// StatsAggregatorHandleSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StatsAggregatorHandleSharedName(value string) StatsAggregatorHandleAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Creates a statistics manager resource. -func StatsAggregatorHandle(scope *Scope, optional ...StatsAggregatorHandleAttr) (handle tf.Output) { +// Gets the next output from the given iterator. +// +// This operation is a synchronous version IteratorGetNext. It should only be used +// in situations where the iterator does not block the calling thread, or where +// the calling thread is not a member of the thread pool used to execute parallel +// operations (e.g. in eager mode). +func IteratorGetNextSync(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "StatsAggregatorHandle", - + Type: "IteratorGetNextSync", + Input: []tf.Input{ + iterator, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Greedily selects a subset of bounding boxes in descending order of score, -// -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system. Note that this -// algorithm is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// -// selected_indices = tf.image.non_max_suppression_v2( -// boxes, scores, max_output_size, iou_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) -// -// Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// iou_threshold: A 0-D float tensor representing the threshold for deciding whether -// boxes overlap too much with respect to IOU. -// -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output) (selected_indices tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "NonMaxSuppressionV2", - Input: []tf.Input{ - boxes, scores, max_output_size, iou_threshold, - }, + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("IteratorGetNextSync", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return components } -// EncodeProtoAttr is an optional argument to EncodeProto. -type EncodeProtoAttr func(optionalAttr) +// SampleDistortedBoundingBoxV2Attr is an optional argument to SampleDistortedBoundingBoxV2. +type SampleDistortedBoundingBoxV2Attr func(optionalAttr) -// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value. -// If not specified, defaults to "local://" -func EncodeProtoDescriptorSource(value string) EncodeProtoAttr { +// SampleDistortedBoundingBoxV2Seed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to non-zero, the random number +// generator is seeded by the given `seed`. Otherwise, it is seeded by a random +// seed. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxV2Seed(value int64) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { - m["descriptor_source"] = value + m["seed"] = value } } -// The op serializes protobuf messages provided in the input tensors. +// SampleDistortedBoundingBoxV2Seed2 sets the optional seed2 attribute to value. // -// The types of the tensors in `values` must match the schema for the -// fields specified in `field_names`. All the tensors in `values` must -// have a common shape prefix, *batch_shape*. +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// SampleDistortedBoundingBoxV2AspectRatioRange sets the optional aspect_ratio_range attribute to value. // -// The `sizes` tensor specifies repeat counts for each field. The repeat -// count (last dimension) of a each tensor in `values` must be greater -// than or equal to corresponding repeat count in `sizes`. +// value: The cropped area of the image must have an aspect ratio = +// width / height within this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["aspect_ratio_range"] = value + } +} + +// SampleDistortedBoundingBoxV2AreaRange sets the optional area_range attribute to value. // -// A `message_type` name must be provided to give context for the field -// names. The actual message descriptor can be looked up either in the -// linked-in descriptor pool or a filename provided by the caller using -// the `descriptor_source` attribute. +// value: The cropped area of the image must contain a fraction of the +// supplied image within in this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["area_range"] = value + } +} + +// SampleDistortedBoundingBoxV2MaxAttempts sets the optional max_attempts attribute to value. // -// The `descriptor_source` attribute selects a source of protocol -// descriptors to consult when looking up `message_type`. This may be a -// filename containing a serialized `FileDescriptorSet` message, -// or the special value `local://`, in which case only descriptors linked -// into the code will be searched; the filename can be on any filesystem -// accessible to TensorFlow. +// value: Number of attempts at generating a cropped region of the image +// of the specified constraints. After `max_attempts` failures, return the entire +// image. +// If not specified, defaults to 100 +func SampleDistortedBoundingBoxV2MaxAttempts(value int64) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["max_attempts"] = value + } +} + +// SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. // -// You can build a `descriptor_source` file using the `--descriptor_set_out` -// and `--include_imports` options to the protocol compiler `protoc`. +// value: Controls behavior if no bounding boxes supplied. +// If true, assume an implicit bounding box covering the whole input. If false, +// raise an error. +// If not specified, defaults to false +func SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["use_image_if_no_bounding_boxes"] = value + } +} + +// Generate a single randomly distorted bounding box for an image. // -// The `local://` database only covers descriptors linked into the -// code via C++ libraries, not Python imports. You can link in a proto descriptor -// by creating a cc_library target with alwayslink=1. +// Bounding box annotations are often supplied in addition to ground-truth labels +// in image recognition or object localization tasks. A common technique for +// training such a system is to randomly distort an image while preserving +// its content, i.e. *data augmentation*. This Op outputs a randomly distorted +// localization of an object, i.e. bounding box, given an `image_size`, +// `bounding_boxes` and a series of constraints. // -// There are a few special cases in the value mapping: +// The output of this Op is a single bounding box that may be used to crop the +// original image. The output is returned as 3 tensors: `begin`, `size` and +// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the +// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize +// what the bounding box looks like. // -// Submessage and group fields must be pre-serialized as TensorFlow strings. +// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. // -// TensorFlow lacks support for unsigned int64s, so they must be -// represented as `tf.int64` with the same twos-complement bit pattern -// (the obvious way). +// For example, // -// Unsigned int32 values can be represented exactly with `tf.int64`, or -// with sign wrapping if the input is of type `tf.int32`. +// ```python +// # Generate a single distorted bounding box. +// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( +// tf.shape(image), +// bounding_boxes=bounding_boxes) +// +// # Draw the bounding box in an image summary. +// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), +// bbox_for_draw) +// tf.summary.image('images_with_box', image_with_box) +// +// # Employ the bounding box to distort the image. +// distorted_image = tf.slice(image, begin, size) +// ``` +// +// Note that if no bounding box information is available, setting +// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit +// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is +// false and no bounding boxes are supplied, an error is raised. // // Arguments: -// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`. -// values: List of tensors containing values for the corresponding field. -// field_names: List of strings containing proto field names. -// message_type: Name of the proto message type to decode. +// image_size: 1-D, containing `[height, width, channels]`. +// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes +// associated with the image. +// min_object_covered: The cropped area of the image must contain at least this +// fraction of any bounding box supplied. The value of this parameter should be +// non-negative. In the case of 0, the cropped area does not need to overlap +// any of the bounding boxes supplied. // -// Returns Tensor of serialized protos with shape `batch_shape`. -func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) { +// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to +// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to +// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. +// Provide as input to `tf.image.draw_bounding_boxes`. +func SampleDistortedBoundingBoxV2(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, min_object_covered tf.Output, optional ...SampleDistortedBoundingBoxV2Attr) (begin tf.Output, size tf.Output, bboxes tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "EncodeProto", + Type: "SampleDistortedBoundingBoxV2", Input: []tf.Input{ - sizes, tf.OutputList(values), + image_size, bounding_boxes, min_object_covered, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Creates a TensorArray for storing the gradients of values in the given handle. +// ExtractGlimpseAttr is an optional argument to ExtractGlimpse. +type ExtractGlimpseAttr func(optionalAttr) + +// ExtractGlimpseCentered sets the optional centered attribute to value. // -// If the given TensorArray gradient already exists, returns a reference to it. +// value: indicates if the offset coordinates are centered relative to +// the image, in which case the (0, 0) offset is relative to the center +// of the input images. If false, the (0,0) offset corresponds to the +// upper left corner of the input images. +// If not specified, defaults to true +func ExtractGlimpseCentered(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["centered"] = value + } +} + +// ExtractGlimpseNormalized sets the optional normalized attribute to value. // -// Locks the size of the original TensorArray by disabling its dynamic size flag. +// value: indicates if the offset coordinates are normalized. +// If not specified, defaults to true +func ExtractGlimpseNormalized(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["normalized"] = value + } +} + +// ExtractGlimpseUniformNoise sets the optional uniform_noise attribute to value. // -// **A note about the input flow_in:** +// value: indicates if the noise should be generated using a +// uniform distribution or a Gaussian distribution. +// If not specified, defaults to true +func ExtractGlimpseUniformNoise(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["uniform_noise"] = value + } +} + +// Extracts a glimpse from the input tensor. // -// The handle flow_in forces the execution of the gradient lookup to occur -// only after certain other operations have occurred. For example, when -// the forward TensorArray is dynamically sized, writes to this TensorArray -// may resize the object. The gradient TensorArray is statically sized based -// on the size of the forward TensorArray when this operation executes. -// Furthermore, the size of the forward TensorArray is frozen by this call. -// As a result, the flow is used to ensure that the call to generate the gradient -// TensorArray only happens after all writes are executed. +// Returns a set of windows called glimpses extracted at location +// `offsets` from the input tensor. If the windows only partially +// overlaps the inputs, the non overlapping areas will be filled with +// random noise. // -// In the case of dynamically sized TensorArrays, gradient computation should -// only be performed on read operations that have themselves been chained via -// flow to occur only after all writes have executed. That way the final size -// of the forward TensorArray is known when this operation is called. -// -// **A note about the source attribute:** -// -// TensorArray gradient calls use an accumulator TensorArray object. If -// multiple gradients are calculated and run in the same session, the multiple -// gradient nodes may accidentally flow through the same accumulator TensorArray. -// This double counts and generally breaks the TensorArray gradient flow. +// The result is a 4-D tensor of shape `[batch_size, glimpse_height, +// glimpse_width, channels]`. The channels and batch dimensions are the +// same as that of the input tensor. The height and width of the output +// windows are specified in the `size` parameter. // -// The solution is to identify which gradient call this particular -// TensorArray gradient is being called in. This is performed by identifying -// a unique string (e.g. "gradients", "gradients_1", ...) from the input -// gradient Tensor's name. This string is used as a suffix when creating -// the TensorArray gradient object here (the attribute `source`). +// The argument `normalized` and `centered` controls how the windows are built: // -// The attribute `source` is added as a suffix to the forward TensorArray's -// name when performing the creation / lookup, so that each separate gradient -// calculation gets its own TensorArray accumulator. +// * If the coordinates are normalized but not centered, 0.0 and 1.0 +// correspond to the minimum and maximum of each height and width +// dimension. +// * If the coordinates are both normalized and centered, they range from +// -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper +// left corner, the lower right corner is located at (1.0, 1.0) and the +// center is at (0, 0). +// * If the coordinates are not normalized they are interpreted as +// numbers of pixels. // // Arguments: -// handle: The handle to the forward TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// source: The gradient source string, used to decide which gradient TensorArray -// to return. -func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) { +// input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. +// size: A 1-D tensor of 2 elements containing the size of the glimpses +// to extract. The glimpse height must be specified first, following +// by the glimpse width. +// offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing +// the y, x locations of the center of each window. +// +// Returns A tensor representing the glimpses `[batch_size, +// glimpse_height, glimpse_width, channels]`. +func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Output, optional ...ExtractGlimpseAttr) (glimpse tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"source": source} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArrayGradV3", + Type: "ExtractGlimpse", Input: []tf.Input{ - handle, flow_in, + input, size, offsets, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// DecodeProtoV2Attr is an optional argument to DecodeProtoV2. -type DecodeProtoV2Attr func(optionalAttr) - -// DecodeProtoV2DescriptorSource sets the optional descriptor_source attribute to value. +// A container for an iterator resource. // -// value: Either the special value `local://` or a path to a file containing -// a serialized `FileDescriptorSet`. -// If not specified, defaults to "local://" -func DecodeProtoV2DescriptorSource(value string) DecodeProtoV2Attr { - return func(m optionalAttr) { - m["descriptor_source"] = value +// Returns A handle to the iterator that can be passed to a "MakeIterator" +// or "IteratorGetNext" op. +func Iterator(scope *Scope, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return } -} + attrs := map[string]interface{}{"shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "Iterator", -// DecodeProtoV2MessageFormat sets the optional message_format attribute to value. -// -// value: Either `binary` or `text`. -// If not specified, defaults to "binary" -func DecodeProtoV2MessageFormat(value string) DecodeProtoV2Attr { - return func(m optionalAttr) { - m["message_format"] = value + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// DecodeProtoV2Sanitize sets the optional sanitize attribute to value. +// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. +type CropAndResizeGradImageAttr func(optionalAttr) + +// CropAndResizeGradImageMethod sets the optional method attribute to value. // -// value: Whether to sanitize the result or not. -// If not specified, defaults to false -func DecodeProtoV2Sanitize(value bool) DecodeProtoV2Attr { +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { return func(m optionalAttr) { - m["sanitize"] = value + m["method"] = value } } -// The op extracts fields from a serialized protocol buffers message into tensors. -// -// The `decode_proto` op extracts fields from a serialized protocol buffers -// message into tensors. The fields in `field_names` are decoded and converted -// to the corresponding `output_types` if possible. -// -// A `message_type` name must be provided to give context for the field -// names. The actual message descriptor can be looked up either in the -// linked-in descriptor pool or a filename provided by the caller using -// the `descriptor_source` attribute. -// -// Each output tensor is a dense tensor. This means that it is padded to -// hold the largest number of repeated elements seen in the input -// minibatch. (The shape is also padded by one to prevent zero-sized -// dimensions). The actual repeat counts for each example in the -// minibatch can be found in the `sizes` output. In many cases the output -// of `decode_proto` is fed immediately into tf.squeeze if missing values -// are not a concern. When using tf.squeeze, always pass the squeeze -// dimension explicitly to avoid surprises. -// -// For the most part, the mapping between Proto field types and -// TensorFlow dtypes is straightforward. However, there are a few -// special cases: -// -// - A proto field that contains a submessage or group can only be converted -// to `DT_STRING` (the serialized submessage). This is to reduce the -// complexity of the API. The resulting string can be used as input -// to another instance of the decode_proto op. -// -// - TensorFlow lacks support for unsigned integers. The ops represent uint64 -// types as a `DT_INT64` with the same twos-complement bit pattern -// (the obvious way). Unsigned int32 values can be represented exactly by -// specifying type `DT_INT64`, or using twos-complement if the caller -// specifies `DT_INT32` in the `output_types` attribute. -// -// The `descriptor_source` attribute selects a source of protocol -// descriptors to consult when looking up `message_type`. This may be a -// filename containing a serialized `FileDescriptorSet` message, -// or the special value `local://`, in which case only descriptors linked -// into the code will be searched; the filename can be on any filesystem -// accessible to TensorFlow. -// -// You can build a `descriptor_source` file using the `--descriptor_set_out` -// and `--include_imports` options to the protocol compiler `protoc`. -// -// The `local://` database only covers descriptors linked into the -// code via C++ libraries, not Python imports. You can link in a proto descriptor -// by creating a cc_library target with alwayslink=1. -// -// Both binary and text proto serializations are supported, and can be -// chosen using the `format` attribute. +// Computes the gradient of the crop_and_resize op wrt the input image tensor. // // Arguments: -// bytes: Tensor of serialized protos with shape `batch_shape`. -// message_type: Name of the proto message type to decode. -// field_names: List of strings containing proto field names. -// output_types: List of TF types to use for the respective field in field_names. +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` +// containing the original image size. Both `image_height` and `image_width` need +// to be positive. // -// Returns Tensor of int32 with shape `[batch_shape, len(field_names)]`. -// Each entry is the number of values found for the corresponding field. -// Optional fields may have 0 or 1 values.List of tensors containing values for the corresponding field. -// `values[i]` has datatype `output_types[i]` -// and shape `[batch_shape, max(sizes[...,i])]`. -func DecodeProtoV2(scope *Scope, bytes tf.Output, message_type string, field_names []string, output_types []tf.DataType, optional ...DecodeProtoV2Attr) (sizes tf.Output, values []tf.Output) { +// +// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"message_type": message_type, "field_names": field_names, "output_types": output_types} + attrs := map[string]interface{}{"T": T} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeProtoV2", + Type: "CropAndResizeGradImage", Input: []tf.Input{ - bytes, + grads, boxes, box_ind, image_size, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - sizes = op.Output(idx) - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("DecodeProtoV2", err) - return - } - return sizes, values + return op.Output(0) } -// Creates a dataset that splits a SparseTensor into elements row-wise. -func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseTensorSliceDataset", - Input: []tf.Input{ - indices, values, dense_shape, - }, +// ShuffleDatasetAttr is an optional argument to ShuffleDataset. +type ShuffleDatasetAttr func(optionalAttr) + +// ShuffleDatasetReshuffleEachIteration sets the optional reshuffle_each_iteration attribute to value. +// +// value: If true, each iterator over this dataset will be given +// a different pseudorandomly generated seed, based on a sequence seeded by the +// `seed` and `seed2` inputs. If false, each iterator will be given the same +// seed, and repeated iteration over this dataset will yield the exact same +// sequence of results. +// If not specified, defaults to true +func ShuffleDatasetReshuffleEachIteration(value bool) ShuffleDatasetAttr { + return func(m optionalAttr) { + m["reshuffle_each_iteration"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Returns x / y element-wise for real types. +// Creates a dataset that shuffles elements from `input_dataset` pseudorandomly. // -// If `x` and `y` are reals, this will return the floating-point division. +// Arguments: // -// *NOTE*: `Div` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// buffer_size: The number of output elements to buffer in an iterator over +// this dataset. Compare with the `min_after_dequeue` attr when creating a +// `RandomShuffleQueue`. +// seed: A scalar seed for the random number generator. If either `seed` or +// `seed2` is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. +// +// +func ShuffleDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ShuffleDatasetAttr) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RealDiv", + Type: "ShuffleDataset", Input: []tf.Input{ - x, y, + input_dataset, buffer_size, seed, seed2, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Adds v into specified rows of x. +// 3D fast Fourier transform. // -// Computes y = x; y[i, :] += v; return y. +// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 +// dimensions of `input`. // // Arguments: -// x: A `Tensor` of type T. -// i: A vector. Indices into the left-most dimension of `x`. -// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. +// input: A complex64 tensor. // -// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. -func InplaceAdd(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { +// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 +// dimensions of `input` are replaced with their 3D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fftn with 3 dimensions. +// @end_compatibility +func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "InplaceAdd", + Type: "FFT3D", Input: []tf.Input{ - x, i, v, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Restore a Reader to its initial clean state. -// -// Arguments: -// reader_handle: Handle to a Reader. -// -// Returns the created operation. -func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderResetV2", - Input: []tf.Input{ - reader_handle, - }, - } - return scope.AddOperation(opspec) -} - -// RpcAttr is an optional argument to Rpc. -type RpcAttr func(optionalAttr) - -// RpcProtocol sets the optional protocol attribute to value. -// -// value: RPC protocol to use. Empty string means use the default protocol. -// Options include 'grpc'. -// If not specified, defaults to "" -func RpcProtocol(value string) RpcAttr { - return func(m optionalAttr) { - m["protocol"] = value - } -} - -// RpcFailFast sets the optional fail_fast attribute to value. -// -// value: `boolean`. If `true` (default), then failures to connect -// (i.e., the server does not immediately respond) cause an RPC failure. -// If not specified, defaults to true -func RpcFailFast(value bool) RpcAttr { - return func(m optionalAttr) { - m["fail_fast"] = value - } -} +// CropAndResizeGradBoxesAttr is an optional argument to CropAndResizeGradBoxes. +type CropAndResizeGradBoxesAttr func(optionalAttr) -// RpcTimeoutInMs sets the optional timeout_in_ms attribute to value. +// CropAndResizeGradBoxesMethod sets the optional method attribute to value. // -// value: `int`. If `0` (default), then the kernel will run the RPC -// request and only time out if the RPC deadline passes or the session times out. -// If this value is greater than `0`, then the op will raise an exception if -// the RPC takes longer than `timeout_in_ms`. -// If not specified, defaults to 0 -func RpcTimeoutInMs(value int64) RpcAttr { +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradBoxesMethod(value string) CropAndResizeGradBoxesAttr { return func(m optionalAttr) { - m["timeout_in_ms"] = value + m["method"] = value } } -// Perform batches of RPC requests. -// -// This op asynchronously performs either a single RPC request, or a batch -// of requests. RPC requests are defined by three main parameters: -// -// - `address` (the host+port or BNS address of the request) -// - `method` (the RPC method name for the request) -// - `request` (the serialized proto string, or vector of strings, -// of the RPC request argument). -// -// For example, if you have an RPC service running on port localhost:2345, -// and its interface is configured with the following proto declaration: -// -// ``` -// service MyService { -// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { -// } -// }; -// ``` -// -// then call this op with arguments: -// -// ``` -// address = "localhost:2345" -// method = "MyService/MyMethod" -// ``` -// -// The `request` tensor is a string tensor representing serialized `MyRequestProto` -// strings; and the output string tensor `response` will have the same shape -// and contain (upon successful completion) corresponding serialized -// `MyResponseProto` strings. -// -// For example, to send a single, empty, `MyRequestProto`, call -// this op with `request = ""`. To send 5 **parallel** empty requests, -// call this op with `request = ["", "", "", "", ""]`. -// -// More generally, one can create a batch of `MyRequestProto` serialized protos -// from regular batched tensors using the `encode_proto` op, and convert -// the response `MyResponseProto` serialized protos to batched tensors -// using the `decode_proto` op. -// -// **NOTE** Working with serialized proto strings is faster than instantiating -// actual proto objects in memory, so no performance degradation is expected -// compared to writing custom kernels for this workflow. -// -// If the connection fails or the remote worker returns an error -// status, the op reraises this exception locally. -// -// See the `TryRpc` op if you prefer to handle RPC failures manually in the graph. +// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. // // Arguments: -// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `method` and `request`. -// method: `0-D` or `1-D`. The method address on the RPC server. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `address` and `request`. -// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `address` and `method`. +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +// Both `image_height` and `image_width` need to be positive. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. // -// Returns Same shape as `request`. Serialized proto strings: the rpc responses. -func Rpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...RpcAttr) (response tf.Output) { +// Returns A 2-D tensor of shape `[num_boxes, 4]`. +func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxes tf.Output, box_ind tf.Output, optional ...CropAndResizeGradBoxesAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -24550,9 +24460,9 @@ func Rpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, o a(attrs) } opspec := tf.OpSpec{ - Type: "Rpc", + Type: "CropAndResizeGradBoxes", Input: []tf.Input{ - address, method, request, + grads, image, boxes, box_ind, }, Attrs: attrs, } @@ -24560,483 +24470,726 @@ func Rpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, o return op.Output(0) } -// OrderedMapStageAttr is an optional argument to OrderedMapStage. -type OrderedMapStageAttr func(optionalAttr) - -// OrderedMapStageCapacity sets the optional capacity attribute to value. +// Saves tensors in V2 checkpoint format. // -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. -// If not specified, defaults to 0 +// By default, saves the named tensors in full. If the caller wishes to save +// specific slices of full tensors, "shape_and_slices" should be non-empty strings +// and correspondingly well-formed. // -// REQUIRES: value >= 0 -func OrderedMapStageCapacity(value int64) OrderedMapStageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapStageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Arguments: +// prefix: Must have a single element. The prefix of the V2 checkpoint to which we +// write the tensors. +// tensor_names: shape {N}. The names of the tensors to be saved. +// shape_and_slices: shape {N}. The slice specs of the tensors to be saved. +// Empty strings indicate that they are non-partitioned tensors. +// tensors: `N` tensors to save. // -// REQUIRES: value >= 0 -func OrderedMapStageMemoryLimit(value int64) OrderedMapStageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value +// Returns the created operation. +func SaveV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, tensors []tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SaveV2", + Input: []tf.Input{ + prefix, tensor_names, shape_and_slices, tf.OutputList(tensors), + }, } + return scope.AddOperation(opspec) } -// OrderedMapStageContainer sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. +// StatsAggregatorHandleAttr is an optional argument to StatsAggregatorHandle. +type StatsAggregatorHandleAttr func(optionalAttr) + +// StatsAggregatorHandleContainer sets the optional container attribute to value. // If not specified, defaults to "" -func OrderedMapStageContainer(value string) OrderedMapStageAttr { +func StatsAggregatorHandleContainer(value string) StatsAggregatorHandleAttr { return func(m optionalAttr) { m["container"] = value } } -// OrderedMapStageSharedName sets the optional shared_name attribute to value. -// -// value: It is necessary to match this name to the matching Unstage Op. +// StatsAggregatorHandleSharedName sets the optional shared_name attribute to value. // If not specified, defaults to "" -func OrderedMapStageSharedName(value string) OrderedMapStageAttr { +func StatsAggregatorHandleSharedName(value string) StatsAggregatorHandleAttr { return func(m optionalAttr) { m["shared_name"] = value } } -// Stage (key, values) in the underlying container which behaves like a ordered -// -// associative container. Elements are ordered by key. -// -// Arguments: -// key: int64 -// -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. -// -// -// Returns the created operation. -func OrderedMapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...OrderedMapStageAttr) (o *tf.Operation) { +// Creates a statistics manager resource. +func StatsAggregatorHandle(scope *Scope, optional ...StatsAggregatorHandleAttr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapStage", - Input: []tf.Input{ - key, indices, tf.OutputList(values), - }, + Type: "StatsAggregatorHandle", + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// StackPushV2Attr is an optional argument to StackPushV2. -type StackPushV2Attr func(optionalAttr) - -// StackPushV2SwapMemory sets the optional swap_memory attribute to value. +// Greedily selects a subset of bounding boxes in descending order of score, // -// value: Swap `elem` to CPU. Default to false. -// If not specified, defaults to false -func StackPushV2SwapMemory(value bool) StackPushV2Attr { - return func(m optionalAttr) { - m["swap_memory"] = value - } -} - -// Push an element onto the stack. +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system. Note that this +// algorithm is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// +// selected_indices = tf.image.non_max_suppression_v2( +// boxes, scores, max_output_size, iou_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) // // Arguments: -// handle: The handle to a stack. -// elem: The tensor to be pushed onto the stack. +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// iou_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too much with respect to IOU. // -// Returns The same tensor as the input 'elem'. -func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...StackPushV2Attr) (output tf.Output) { +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output) (selected_indices tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "StackPushV2", + Type: "NonMaxSuppressionV2", Input: []tf.Input{ - handle, elem, + boxes, scores, max_output_size, iou_threshold, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that concatenates `input_dataset` with `another_dataset`. -func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Computes the matrix logarithm of one or more square matrices: +// +// +// log(exp(A)) = A +// +// This op is only defined for complex matrices. If A is positive-definite and +// real, then casting to a complex matrix, taking the logarithm and casting back +// to a real matrix will give the correct result. +// +// This function computes the matrix logarithm using the Schur-Parlett algorithm. +// Details of the algorithm can be found in Section 11.6.2 of: +// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008. +// ISBN 978-0-898716-46-7. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. The output is a tensor of the same shape as the input +// containing the exponential for all input submatrices `[..., :, :]`. +// +// Arguments: +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[..., M, M]`. +// +// @compatibility(scipy) +// Equivalent to scipy.linalg.logm +// @end_compatibility +func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ConcatenateDataset", + Type: "MatrixLogarithm", Input: []tf.Input{ - input_dataset, another_dataset, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Adds a value to the current value of a variable. +// EncodeProtoAttr is an optional argument to EncodeProto. +type EncodeProtoAttr func(optionalAttr) + +// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value. +// If not specified, defaults to "local://" +func EncodeProtoDescriptorSource(value string) EncodeProtoAttr { + return func(m optionalAttr) { + m["descriptor_source"] = value + } +} + +// The op serializes protobuf messages provided in the input tensors. // -// Any ReadVariableOp with a control dependency on this op is guaranteed to -// see the incremented value or a subsequent newer one. +// The types of the tensors in `values` must match the schema for the +// fields specified in `field_names`. All the tensors in `values` must +// have a common shape prefix, *batch_shape*. +// +// The `sizes` tensor specifies repeat counts for each field. The repeat +// count (last dimension) of a each tensor in `values` must be greater +// than or equal to corresponding repeat count in `sizes`. +// +// A `message_type` name must be provided to give context for the field +// names. The actual message descriptor can be looked up either in the +// linked-in descriptor pool or a filename provided by the caller using +// the `descriptor_source` attribute. +// +// The `descriptor_source` attribute selects a source of protocol +// descriptors to consult when looking up `message_type`. This may be a +// filename containing a serialized `FileDescriptorSet` message, +// or the special value `local://`, in which case only descriptors linked +// into the code will be searched; the filename can be on any filesystem +// accessible to TensorFlow. +// +// You can build a `descriptor_source` file using the `--descriptor_set_out` +// and `--include_imports` options to the protocol compiler `protoc`. +// +// The `local://` database only covers descriptors linked into the +// code via C++ libraries, not Python imports. You can link in a proto descriptor +// by creating a cc_library target with alwayslink=1. +// +// There are a few special cases in the value mapping: +// +// Submessage and group fields must be pre-serialized as TensorFlow strings. +// +// TensorFlow lacks support for unsigned int64s, so they must be +// represented as `tf.int64` with the same twos-complement bit pattern +// (the obvious way). +// +// Unsigned int32 values can be represented exactly with `tf.int64`, or +// with sign wrapping if the input is of type `tf.int32`. // // Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value by which the variable will be incremented. +// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`. +// values: List of tensors containing values for the corresponding field. +// field_names: List of strings containing proto field names. +// message_type: Name of the proto message type to decode. // -// Returns the created operation. -func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { +// Returns Tensor of serialized protos with shape `batch_shape`. +func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "AssignAddVariableOp", + Type: "EncodeProto", Input: []tf.Input{ - resource, value, + sizes, tf.OutputList(values), }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Records the latency of producing `input_dataset` elements in a StatsAggregator. -func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Creates a TensorArray for storing the gradients of values in the given handle. +// +// If the given TensorArray gradient already exists, returns a reference to it. +// +// Locks the size of the original TensorArray by disabling its dynamic size flag. +// +// **A note about the input flow_in:** +// +// The handle flow_in forces the execution of the gradient lookup to occur +// only after certain other operations have occurred. For example, when +// the forward TensorArray is dynamically sized, writes to this TensorArray +// may resize the object. The gradient TensorArray is statically sized based +// on the size of the forward TensorArray when this operation executes. +// Furthermore, the size of the forward TensorArray is frozen by this call. +// As a result, the flow is used to ensure that the call to generate the gradient +// TensorArray only happens after all writes are executed. +// +// In the case of dynamically sized TensorArrays, gradient computation should +// only be performed on read operations that have themselves been chained via +// flow to occur only after all writes have executed. That way the final size +// of the forward TensorArray is known when this operation is called. +// +// **A note about the source attribute:** +// +// TensorArray gradient calls use an accumulator TensorArray object. If +// multiple gradients are calculated and run in the same session, the multiple +// gradient nodes may accidentally flow through the same accumulator TensorArray. +// This double counts and generally breaks the TensorArray gradient flow. +// +// The solution is to identify which gradient call this particular +// TensorArray gradient is being called in. This is performed by identifying +// a unique string (e.g. "gradients", "gradients_1", ...) from the input +// gradient Tensor's name. This string is used as a suffix when creating +// the TensorArray gradient object here (the attribute `source`). +// +// The attribute `source` is added as a suffix to the forward TensorArray's +// name when performing the creation / lookup, so that each separate gradient +// calculation gets its own TensorArray accumulator. +// +// Arguments: +// handle: The handle to the forward TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// source: The gradient source string, used to decide which gradient TensorArray +// to return. +func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"source": source} opspec := tf.OpSpec{ - Type: "LatencyStatsDataset", + Type: "TensorArrayGradV3", Input: []tf.Input{ - input_dataset, tag, + handle, flow_in, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// MapSizeAttr is an optional argument to MapSize. -type MapSizeAttr func(optionalAttr) +// DecodeProtoV2Attr is an optional argument to DecodeProtoV2. +type DecodeProtoV2Attr func(optionalAttr) -// MapSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// DecodeProtoV2DescriptorSource sets the optional descriptor_source attribute to value. // -// REQUIRES: value >= 0 -func MapSizeCapacity(value int64) MapSizeAttr { +// value: Either the special value `local://` or a path to a file containing +// a serialized `FileDescriptorSet`. +// If not specified, defaults to "local://" +func DecodeProtoV2DescriptorSource(value string) DecodeProtoV2Attr { return func(m optionalAttr) { - m["capacity"] = value + m["descriptor_source"] = value } } -// MapSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// DecodeProtoV2MessageFormat sets the optional message_format attribute to value. // -// REQUIRES: value >= 0 -func MapSizeMemoryLimit(value int64) MapSizeAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapSizeContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapSizeContainer(value string) MapSizeAttr { +// value: Either `binary` or `text`. +// If not specified, defaults to "binary" +func DecodeProtoV2MessageFormat(value string) DecodeProtoV2Attr { return func(m optionalAttr) { - m["container"] = value + m["message_format"] = value } } -// MapSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapSizeSharedName(value string) MapSizeAttr { +// DecodeProtoV2Sanitize sets the optional sanitize attribute to value. +// +// value: Whether to sanitize the result or not. +// If not specified, defaults to false +func DecodeProtoV2Sanitize(value bool) DecodeProtoV2Attr { return func(m optionalAttr) { - m["shared_name"] = value + m["sanitize"] = value } } -// Op returns the number of elements in the underlying container. -func MapSize(scope *Scope, dtypes []tf.DataType, optional ...MapSizeAttr) (size tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} +// The op extracts fields from a serialized protocol buffers message into tensors. +// +// The `decode_proto` op extracts fields from a serialized protocol buffers +// message into tensors. The fields in `field_names` are decoded and converted +// to the corresponding `output_types` if possible. +// +// A `message_type` name must be provided to give context for the field +// names. The actual message descriptor can be looked up either in the +// linked-in descriptor pool or a filename provided by the caller using +// the `descriptor_source` attribute. +// +// Each output tensor is a dense tensor. This means that it is padded to +// hold the largest number of repeated elements seen in the input +// minibatch. (The shape is also padded by one to prevent zero-sized +// dimensions). The actual repeat counts for each example in the +// minibatch can be found in the `sizes` output. In many cases the output +// of `decode_proto` is fed immediately into tf.squeeze if missing values +// are not a concern. When using tf.squeeze, always pass the squeeze +// dimension explicitly to avoid surprises. +// +// For the most part, the mapping between Proto field types and +// TensorFlow dtypes is straightforward. However, there are a few +// special cases: +// +// - A proto field that contains a submessage or group can only be converted +// to `DT_STRING` (the serialized submessage). This is to reduce the +// complexity of the API. The resulting string can be used as input +// to another instance of the decode_proto op. +// +// - TensorFlow lacks support for unsigned integers. The ops represent uint64 +// types as a `DT_INT64` with the same twos-complement bit pattern +// (the obvious way). Unsigned int32 values can be represented exactly by +// specifying type `DT_INT64`, or using twos-complement if the caller +// specifies `DT_INT32` in the `output_types` attribute. +// +// The `descriptor_source` attribute selects a source of protocol +// descriptors to consult when looking up `message_type`. This may be a +// filename containing a serialized `FileDescriptorSet` message, +// or the special value `local://`, in which case only descriptors linked +// into the code will be searched; the filename can be on any filesystem +// accessible to TensorFlow. +// +// You can build a `descriptor_source` file using the `--descriptor_set_out` +// and `--include_imports` options to the protocol compiler `protoc`. +// +// The `local://` database only covers descriptors linked into the +// code via C++ libraries, not Python imports. You can link in a proto descriptor +// by creating a cc_library target with alwayslink=1. +// +// Both binary and text proto serializations are supported, and can be +// chosen using the `format` attribute. +// +// Arguments: +// bytes: Tensor of serialized protos with shape `batch_shape`. +// message_type: Name of the proto message type to decode. +// field_names: List of strings containing proto field names. +// output_types: List of TF types to use for the respective field in field_names. +// +// Returns Tensor of int32 with shape `[batch_shape, len(field_names)]`. +// Each entry is the number of values found for the corresponding field. +// Optional fields may have 0 or 1 values.List of tensors containing values for the corresponding field. +// `values[i]` has datatype `output_types[i]` +// and shape `[batch_shape, max(sizes[...,i])]`. +func DecodeProtoV2(scope *Scope, bytes tf.Output, message_type string, field_names []string, output_types []tf.DataType, optional ...DecodeProtoV2Attr) (sizes tf.Output, values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"message_type": message_type, "field_names": field_names, "output_types": output_types} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapSize", - + Type: "DecodeProtoV2", + Input: []tf.Input{ + bytes, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + sizes = op.Output(idx) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("DecodeProtoV2", err) + return + } + return sizes, values } -// Convert JSON-encoded Example records to binary protocol buffer strings. -// -// This op translates a tensor containing Example records, encoded using -// the [standard JSON -// mapping](https://developers.google.com/protocol-buffers/docs/proto3#json), -// into a tensor containing the same records encoded as binary protocol -// buffers. The resulting tensor can then be fed to any of the other -// Example-parsing ops. -// -// Arguments: -// json_examples: Each string is a JSON object serialized according to the JSON -// mapping of the Example proto. -// -// Returns Each string is a binary Example protocol buffer corresponding -// to the respective element of `json_examples`. -func DecodeJSONExample(scope *Scope, json_examples tf.Output) (binary_examples tf.Output) { +// Creates a dataset that splits a SparseTensor into elements row-wise. +func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DecodeJSONExample", + Type: "SparseTensorSliceDataset", Input: []tf.Input{ - json_examples, + indices, values, dense_shape, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SparseToDenseAttr is an optional argument to SparseToDense. -type SparseToDenseAttr func(optionalAttr) - -// SparseToDenseValidateIndices sets the optional validate_indices attribute to value. -// -// value: If true, indices are checked to make sure they are sorted in -// lexicographic order and that there are no repeats. -// If not specified, defaults to true -func SparseToDenseValidateIndices(value bool) SparseToDenseAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } -} - -// Converts a sparse representation into a dense tensor. -// -// Builds an array `dense` with shape `output_shape` such that -// -// ``` -// # If sparse_indices is scalar -// dense[i] = (i == sparse_indices ? sparse_values : default_value) -// -// # If sparse_indices is a vector, then for each i -// dense[sparse_indices[i]] = sparse_values[i] -// -// # If sparse_indices is an n by d matrix, then for each i in [0, n) -// dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] -// ``` -// -// All other values in `dense` are set to `default_value`. If `sparse_values` is a -// scalar, all sparse indices are set to this single value. -// -// Indices should be sorted in lexicographic order, and indices must not -// contain any repeats. If `validate_indices` is true, these properties -// are checked during execution. +// Returns x / y element-wise for real types. // -// Arguments: -// sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete -// index where `sparse_values[i]` will be placed. -// output_shape: 1-D. Shape of the dense output tensor. -// sparse_values: 1-D. Values corresponding to each row of `sparse_indices`, -// or a scalar value to be used for all sparse indices. -// default_value: Scalar value to set for indices not specified in -// `sparse_indices`. +// If `x` and `y` are reals, this will return the floating-point division. // -// Returns Dense output tensor of shape `output_shape`. -func SparseToDense(scope *Scope, sparse_indices tf.Output, output_shape tf.Output, sparse_values tf.Output, default_value tf.Output, optional ...SparseToDenseAttr) (dense tf.Output) { +// *NOTE*: `Div` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SparseToDense", + Type: "RealDiv", Input: []tf.Input{ - sparse_indices, output_shape, sparse_values, default_value, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. -// -// The `input` tensor has shape `[batch, in_height, in_width, depth]` and the -// `filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each -// input channel is processed independently of the others with its own structuring -// function. The `output` tensor has shape -// `[batch, out_height, out_width, depth]`. The spatial dimensions of the output -// tensor depend on the `padding` algorithm. We currently only support the default -// "NHWC" `data_format`. -// -// In detail, the grayscale morphological 2-D dilation is the max-sum correlation -// (for consistency with `conv2d`, we use unmirrored filters): -// -// output[b, y, x, c] = -// max_{dy, dx} input[b, -// strides[1] * y + rates[1] * dy, -// strides[2] * x + rates[2] * dx, -// c] + -// filter[dy, dx, c] -// -// Max-pooling is a special case when the filter has size equal to the pooling -// kernel size and contains all zeros. +// Adds v into specified rows of x. // -// Note on duality: The dilation of `input` by the `filter` is equal to the -// negation of the erosion of `-input` by the reflected `filter`. +// Computes y = x; y[i, :] += v; return y. // // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// strides: The stride of the sliding window for each dimension of the input -// tensor. Must be: `[1, stride_height, stride_width, 1]`. -// rates: The input stride for atrous morphological dilation. Must be: -// `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. +// x: A `Tensor` of type T. +// i: A vector. Indices into the left-most dimension of `x`. +// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. // -// Returns 4-D with shape `[batch, out_height, out_width, depth]`. -func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, rates []int64, padding string) (output tf.Output) { +// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. +func InplaceAdd(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "Dilation2D", + Type: "InplaceAdd", Input: []tf.Input{ - input, filter, + x, i, v, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts the given variant tensor to an iterator and stores it in the given resource. +// Restore a Reader to its initial clean state. // // Arguments: -// resource_handle: A handle to an iterator resource. -// serialized: A variant tensor storing the state of the iterator contained in the -// resource. +// reader_handle: Handle to a Reader. // // Returns the created operation. -func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) { +func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DeserializeIterator", + Type: "ReaderResetV2", Input: []tf.Input{ - resource_handle, serialized, + reader_handle, }, } return scope.AddOperation(opspec) } -// TensorArrayConcatV2Attr is an optional argument to TensorArrayConcatV2. -type TensorArrayConcatV2Attr func(optionalAttr) +// RpcAttr is an optional argument to Rpc. +type RpcAttr func(optionalAttr) -// TensorArrayConcatV2ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. -// If not specified, defaults to -func TensorArrayConcatV2ElementShapeExcept0(value tf.Shape) TensorArrayConcatV2Attr { +// RpcProtocol sets the optional protocol attribute to value. +// +// value: RPC protocol to use. Empty string means use the default protocol. +// Options include 'grpc'. +// If not specified, defaults to "" +func RpcProtocol(value string) RpcAttr { return func(m optionalAttr) { - m["element_shape_except0"] = value + m["protocol"] = value } } -// Deprecated. Use TensorArrayConcatV3 -func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV2Attr) (value tf.Output, lengths tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) +// RpcFailFast sets the optional fail_fast attribute to value. +// +// value: `boolean`. If `true` (default), then failures to connect +// (i.e., the server does not immediately respond) cause an RPC failure. +// If not specified, defaults to true +func RpcFailFast(value bool) RpcAttr { + return func(m optionalAttr) { + m["fail_fast"] = value + } +} + +// RpcTimeoutInMs sets the optional timeout_in_ms attribute to value. +// +// value: `int`. If `0` (default), then the kernel will run the RPC +// request and only time out if the RPC deadline passes or the session times out. +// If this value is greater than `0`, then the op will raise an exception if +// the RPC takes longer than `timeout_in_ms`. +// If not specified, defaults to 0 +func RpcTimeoutInMs(value int64) RpcAttr { + return func(m optionalAttr) { + m["timeout_in_ms"] = value + } +} + +// Perform batches of RPC requests. +// +// This op asynchronously performs either a single RPC request, or a batch +// of requests. RPC requests are defined by three main parameters: +// +// - `address` (the host+port or BNS address of the request) +// - `method` (the RPC method name for the request) +// - `request` (the serialized proto string, or vector of strings, +// of the RPC request argument). +// +// For example, if you have an RPC service running on port localhost:2345, +// and its interface is configured with the following proto declaration: +// +// ``` +// service MyService { +// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { +// } +// }; +// ``` +// +// then call this op with arguments: +// +// ``` +// address = "localhost:2345" +// method = "MyService/MyMethod" +// ``` +// +// The `request` tensor is a string tensor representing serialized `MyRequestProto` +// strings; and the output string tensor `response` will have the same shape +// and contain (upon successful completion) corresponding serialized +// `MyResponseProto` strings. +// +// For example, to send a single, empty, `MyRequestProto`, call +// this op with `request = ""`. To send 5 **parallel** empty requests, +// call this op with `request = ["", "", "", "", ""]`. +// +// More generally, one can create a batch of `MyRequestProto` serialized protos +// from regular batched tensors using the `encode_proto` op, and convert +// the response `MyResponseProto` serialized protos to batched tensors +// using the `decode_proto` op. +// +// **NOTE** Working with serialized proto strings is faster than instantiating +// actual proto objects in memory, so no performance degradation is expected +// compared to writing custom kernels for this workflow. +// +// If the connection fails or the remote worker returns an error +// status, the op reraises this exception locally. +// +// See the `TryRpc` op if you prefer to handle RPC failures manually in the graph. +// +// Arguments: +// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `method` and `request`. +// method: `0-D` or `1-D`. The method address on the RPC server. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `address` and `request`. +// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `address` and `method`. +// +// Returns Same shape as `request`. Serialized proto strings: the rpc responses. +func Rpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...RpcAttr) (response tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayConcatV2", + Type: "Rpc", Input: []tf.Input{ - handle, flow_in, + address, method, request, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Creates a dataset that batches and pads `batch_size` elements from the input. +// OrderedMapStageAttr is an optional argument to OrderedMapStage. +type OrderedMapStageAttr func(optionalAttr) + +// OrderedMapStageCapacity sets the optional capacity attribute to value. +// +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func OrderedMapStageCapacity(value int64) OrderedMapStageAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// OrderedMapStageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func OrderedMapStageMemoryLimit(value int64) OrderedMapStageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// OrderedMapStageContainer sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. +// If not specified, defaults to "" +func OrderedMapStageContainer(value string) OrderedMapStageAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapStageSharedName sets the optional shared_name attribute to value. +// +// value: It is necessary to match this name to the matching Unstage Op. +// If not specified, defaults to "" +func OrderedMapStageSharedName(value string) OrderedMapStageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Stage (key, values) in the underlying container which behaves like a ordered +// +// associative container. Elements are ordered by key. // // Arguments: +// key: int64 // -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// padded_shapes: A list of int64 tensors representing the desired padded shapes -// of the corresponding output components. These shapes may be partially -// specified, using `-1` to indicate that a particular dimension should be -// padded to the maximum size of all batch elements. -// padding_values: A list of scalars containing the padding value to use for -// each of the outputs. +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. // -func PaddedBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { +// +// Returns the created operation. +func OrderedMapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...OrderedMapStageAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_shapes": output_shapes} + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "PaddedBatchDataset", + Type: "OrderedMapStage", Input: []tf.Input{ - input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), + key, indices, tf.OutputList(values), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Creates a dataset that batches input elements into a SparseTensor. +// StackPushV2Attr is an optional argument to StackPushV2. +type StackPushV2Attr func(optionalAttr) + +// StackPushV2SwapMemory sets the optional swap_memory attribute to value. // -// Arguments: -// input_dataset: A handle to an input dataset. Must have a single component. -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// row_shape: A vector representing the dense shape of each row in the produced -// SparseTensor. The shape may be partially specified, using `-1` to indicate -// that a particular dimension should use the maximum size of all batch elements. +// value: Swap `elem` to CPU. Default to false. +// If not specified, defaults to false +func StackPushV2SwapMemory(value bool) StackPushV2Attr { + return func(m optionalAttr) { + m["swap_memory"] = value + } +} + +// Push an element onto the stack. // +// Arguments: +// handle: The handle to a stack. +// elem: The tensor to be pushed onto the stack. // -func DenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns The same tensor as the input 'elem'. +func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...StackPushV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DenseToSparseBatchDataset", + Type: "StackPushV2", Input: []tf.Input{ - input_dataset, batch_size, row_shape, + handle, elem, }, Attrs: attrs, } @@ -25044,18 +25197,16 @@ func DenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size return op.Output(0) } -// Deprecated. Use TensorArrayGradV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArrayGradV3 -func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) { +// Creates a dataset that concatenates `input_dataset` with `another_dataset`. +func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"source": source} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorArrayGradV2", + Type: "ConcatenateDataset", Input: []tf.Input{ - handle, flow_in, + input_dataset, another_dataset, }, Attrs: attrs, } @@ -25063,120 +25214,39 @@ func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source return op.Output(0) } -// Return substrings from `Tensor` of strings. -// -// For each string in the input `Tensor`, creates a substring starting at index -// `pos` with a total length of `len`. -// -// If `len` defines a substring that would extend beyond the length of the input -// string, then as many characters as possible are used. -// -// If `pos` is negative or specifies a character index larger than any of the input -// strings, then an `InvalidArgumentError` is thrown. -// -// `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on -// Op creation. -// -// *NOTE*: `Substr` supports broadcasting up to two dimensions. More about -// broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -// -// --- -// -// Examples -// -// Using scalar `pos` and `len`: -// -// ```python -// input = [b'Hello', b'World'] -// position = 1 -// length = 3 +// Adds a value to the current value of a variable. // -// output = [b'ell', b'orl'] -// ``` +// Any ReadVariableOp with a control dependency on this op is guaranteed to +// see the incremented value or a subsequent newer one. // -// Using `pos` and `len` with same shape as `input`: +// Arguments: +// resource: handle to the resource in which to store the variable. +// value: the value by which the variable will be incremented. // -// ```python -// input = [[b'ten', b'eleven', b'twelve'], -// [b'thirteen', b'fourteen', b'fifteen'], -// [b'sixteen', b'seventeen', b'eighteen']] -// position = [[1, 2, 3], -// [1, 2, 3], -// [1, 2, 3]] -// length = [[2, 3, 4], -// [4, 3, 2], -// [5, 5, 5]] -// -// output = [[b'en', b'eve', b'lve'], -// [b'hirt', b'urt', b'te'], -// [b'ixtee', b'vente', b'hteen']] -// ``` -// -// Broadcasting `pos` and `len` onto `input`: -// -// ``` -// input = [[b'ten', b'eleven', b'twelve'], -// [b'thirteen', b'fourteen', b'fifteen'], -// [b'sixteen', b'seventeen', b'eighteen'], -// [b'nineteen', b'twenty', b'twentyone']] -// position = [1, 2, 3] -// length = [1, 2, 3] -// -// output = [[b'e', b'ev', b'lve'], -// [b'h', b'ur', b'tee'], -// [b'i', b've', b'hte'], -// [b'i', b'en', b'nty']] -// ``` -// -// Broadcasting `input` onto `pos` and `len`: -// -// ``` -// input = b'thirteen' -// position = [1, 5, 7] -// length = [3, 2, 1] -// -// output = [b'hir', b'ee', b'n'] -// ``` -// -// Arguments: -// input: Tensor of strings -// pos: Scalar defining the position of first character in each substring -// len: Scalar defining the number of characters to include in each substring -// -// Returns Tensor of substrings -func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output) (output tf.Output) { +// Returns the created operation. +func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Substr", + Type: "AssignAddVariableOp", Input: []tf.Input{ - input, pos, len, + resource, value, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Creates a Dataset that returns pseudorandom numbers. -// -// Arguments: -// seed: A scalar seed for the random number generator. If either seed or -// seed2 is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. -// -// -func RandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Records the latency of producing `input_dataset` elements in a StatsAggregator. +func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "RandomDataset", + Type: "LatencyStatsDataset", Input: []tf.Input{ - seed, seed2, + input_dataset, tag, }, Attrs: attrs, } @@ -25184,277 +25254,237 @@ func RandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types [ return op.Output(0) } -// Creates a dataset that shuffles and repeats elements from `input_dataset` +// Runs multiple additive regression ensemble predictors on input instances and // -// pseudorandomly. +// computes the update to cached logits. It is designed to be used during training. +// It traverses the trees starting from cached tree id and cached node id and +// calculates the updates to be pushed to the cache. // // Arguments: // -// buffer_size: The number of output elements to buffer in an iterator over -// this dataset. Compare with the `min_after_dequeue` attr when creating a -// `RandomShuffleQueue`. -// seed: A scalar seed for the random number generator. If either `seed` or -// `seed2` is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. -// count: A scalar representing the number of times the underlying dataset -// should be repeated. The default is `-1`, which results in infinite repetition. -// +// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting +// tree of prediction. +// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting +// node of prediction. +// bucketized_features: A list of rank 1 Tensors containing bucket id for each +// feature. +// logits_dimension: scalar, dimension of the logits, to be used for partial logits +// shape. // -func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns Rank 2 Tensor containing logits update (with respect to cached +// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids. +func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"logits_dimension": logits_dimension} opspec := tf.OpSpec{ - Type: "ShuffleAndRepeatDataset", + Type: "BoostedTreesTrainingPredict", Input: []tf.Input{ - input_dataset, buffer_size, seed, seed2, count, + tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features), }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Creates a dataset that caches elements from `input_dataset`. -// -// A CacheDataset will iterate over the input_dataset, and store tensors. If the -// cache already exists, the cache will be used. If the cache is inappropriate -// (e.g. cannot be opened, contains tensors of the wrong shape / size), an error -// will the returned when used. -// -// Arguments: -// -// filename: A path on the filesystem where we should cache the dataset. Note: this -// will be a directory. +// MapSizeAttr is an optional argument to MapSize. +type MapSizeAttr func(optionalAttr) + +// MapSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // +// REQUIRES: value >= 0 +func MapSizeCapacity(value int64) MapSizeAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// REQUIRES: value >= 0 +func MapSizeMemoryLimit(value int64) MapSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapSizeContainer(value string) MapSizeAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapSizeSharedName(value string) MapSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of elements in the underlying container. +func MapSize(scope *Scope, dtypes []tf.DataType, optional ...MapSizeAttr) (size tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "CacheDataset", - Input: []tf.Input{ - input_dataset, filename, - }, + Type: "MapSize", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that executes a SQL query and emits rows of the result set. +// Convert JSON-encoded Example records to binary protocol buffer strings. // -// Arguments: -// driver_name: The database type. Currently, the only supported type is 'sqlite'. -// data_source_name: A connection string to connect to the database. -// query: A SQL query to execute. +// This op translates a tensor containing Example records, encoded using +// the [standard JSON +// mapping](https://developers.google.com/protocol-buffers/docs/proto3#json), +// into a tensor containing the same records encoded as binary protocol +// buffers. The resulting tensor can then be fed to any of the other +// Example-parsing ops. // +// Arguments: +// json_examples: Each string is a JSON object serialized according to the JSON +// mapping of the Example proto. // -func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns Each string is a binary Example protocol buffer corresponding +// to the respective element of `json_examples`. +func DecodeJSONExample(scope *Scope, json_examples tf.Output) (binary_examples tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "SqlDataset", + Type: "DecodeJSONExample", Input: []tf.Input{ - driver_name, data_source_name, query, + json_examples, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that emits the records from one or more binary files. +// SparseToDenseAttr is an optional argument to SparseToDense. +type SparseToDenseAttr func(optionalAttr) + +// SparseToDenseValidateIndices sets the optional validate_indices attribute to value. // -// Arguments: -// filenames: A scalar or a vector containing the name(s) of the file(s) to be -// read. -// header_bytes: A scalar representing the number of bytes to skip at the -// beginning of a file. -// record_bytes: A scalar representing the number of bytes in each record. -// footer_bytes: A scalar representing the number of bytes to skip at the end -// of a file. -// buffer_size: A scalar representing the number of bytes to buffer. Must be > 0. -func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf.Output, record_bytes tf.Output, footer_bytes tf.Output, buffer_size tf.Output) (handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FixedLengthRecordDataset", - Input: []tf.Input{ - filenames, header_bytes, record_bytes, footer_bytes, buffer_size, - }, +// value: If true, indices are checked to make sure they are sorted in +// lexicographic order and that there are no repeats. +// If not specified, defaults to true +func SparseToDenseValidateIndices(value bool) SparseToDenseAttr { + return func(m optionalAttr) { + m["validate_indices"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Gradients for batch normalization. +// Converts a sparse representation into a dense tensor. // -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() +// Builds an array `dense` with shape `output_shape` such that // -// This op is deprecated. See `tf.nn.batch_normalization`. +// ``` +// # If sparse_indices is scalar +// dense[i] = (i == sparse_indices ? sparse_values : default_value) // -// Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this Tensor will be multiplied -// with the normalized Tensor. -// backprop: 4D backprop Tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. +// # If sparse_indices is a vector, then for each i +// dense[sparse_indices[i]] = sparse_values[i] // -// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. -func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} - opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalizationGrad", - Input: []tf.Input{ - t, m, v, gamma, backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - -// Creates a dataset that emits the records from one or more TFRecord files. +// # If sparse_indices is an n by d matrix, then for each i in [0, n) +// dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] +// ``` +// +// All other values in `dense` are set to `default_value`. If `sparse_values` is a +// scalar, all sparse indices are set to this single value. +// +// Indices should be sorted in lexicographic order, and indices must not +// contain any repeats. If `validate_indices` is true, these properties +// are checked during execution. // // Arguments: -// filenames: A scalar or vector containing the name(s) of the file(s) to be -// read. -// compression_type: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// buffer_size: A scalar representing the number of bytes to buffer. A value of -// 0 means no buffering will be performed. -func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { +// sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete +// index where `sparse_values[i]` will be placed. +// output_shape: 1-D. Shape of the dense output tensor. +// sparse_values: 1-D. Values corresponding to each row of `sparse_indices`, +// or a scalar value to be used for all sparse indices. +// default_value: Scalar value to set for indices not specified in +// `sparse_indices`. +// +// Returns Dense output tensor of shape `output_shape`. +func SparseToDense(scope *Scope, sparse_indices tf.Output, output_shape tf.Output, sparse_values tf.Output, default_value tf.Output, optional ...SparseToDenseAttr) (dense tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TFRecordDataset", + Type: "SparseToDense", Input: []tf.Input{ - filenames, compression_type, buffer_size, + sparse_indices, output_shape, sparse_values, default_value, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// BatchToSpace for 4-D tensors of type T. -// -// This is a legacy version of the more general BatchToSpaceND. -// -// Rearranges (permutes) data from batch into blocks of spatial data, followed by -// cropping. This is the reverse transformation of SpaceToBatch. More specifically, -// this op outputs a copy of the input tensor where values from the `batch` -// dimension are moved in spatial blocks to the `height` and `width` dimensions, -// followed by cropping along the `height` and `width` dimensions. -// -// Arguments: -// input: 4-D tensor with shape -// `[batch*block_size*block_size, height_pad/block_size, width_pad/block_size, -// depth]`. Note that the batch size of the input tensor must be divisible by -// `block_size * block_size`. -// crops: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies -// how many elements to crop from the intermediate result across the spatial -// dimensions as follows: -// -// crops = [[crop_top, crop_bottom], [crop_left, crop_right]] -// -// -// Returns 4-D with shape `[batch, height, width, depth]`, where: -// -// height = height_pad - crop_top - crop_bottom -// width = width_pad - crop_left - crop_right -// -// The attr `block_size` must be greater than one. It indicates the block size. -// -// Some examples: -// -// (1) For the following input of shape `[4, 1, 1, 1]` and block_size of 2: -// -// ``` -// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 1]` and value: -// -// ``` -// x = [[[[1], [2]], [[3], [4]]]] -// ``` -// -// (2) For the following input of shape `[4, 1, 1, 3]` and block_size of 2: -// -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 3]` and value: -// -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// ``` -// -// (3) For the following input of shape `[4, 2, 2, 1]` and block_size of 2: +// Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. // -// ``` -// x = [[[[1], [3]], [[9], [11]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` +// The `input` tensor has shape `[batch, in_height, in_width, depth]` and the +// `filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each +// input channel is processed independently of the others with its own structuring +// function. The `output` tensor has shape +// `[batch, out_height, out_width, depth]`. The spatial dimensions of the output +// tensor depend on the `padding` algorithm. We currently only support the default +// "NHWC" `data_format`. // -// The output tensor has shape `[1, 4, 4, 1]` and value: +// In detail, the grayscale morphological 2-D dilation is the max-sum correlation +// (for consistency with `conv2d`, we use unmirrored filters): // -// ``` -// x = [[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]] -// ``` +// output[b, y, x, c] = +// max_{dy, dx} input[b, +// strides[1] * y + rates[1] * dy, +// strides[2] * x + rates[2] * dx, +// c] + +// filter[dy, dx, c] // -// (4) For the following input of shape `[8, 1, 2, 1]` and block_size of 2: +// Max-pooling is a special case when the filter has size equal to the pooling +// kernel size and contains all zeros. // -// ``` -// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], -// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] -// ``` +// Note on duality: The dilation of `input` by the `filter` is equal to the +// negation of the erosion of `-input` by the reflected `filter`. // -// The output tensor has shape `[2, 2, 4, 1]` and value: +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// strides: The stride of the sliding window for each dimension of the input +// tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: The input stride for atrous morphological dilation. Must be: +// `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. // -// ``` -// x = [[[[1], [3]], [[5], [7]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -func BatchToSpace(scope *Scope, input tf.Output, crops tf.Output, block_size int64) (output tf.Output) { +// Returns 4-D with shape `[batch, out_height, out_width, depth]`. +func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, rates []int64, padding string) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"block_size": block_size} + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "BatchToSpace", + Type: "Dilation2D", Input: []tf.Input{ - input, crops, + input, filter, }, Attrs: attrs, } @@ -25462,1928 +25492,871 @@ func BatchToSpace(scope *Scope, input tf.Output, crops tf.Output, block_size int return op.Output(0) } -// Makes a new iterator from the given `dataset` and stores it in `iterator`. +// Converts the given variant tensor to an iterator and stores it in the given resource. // -// This operation may be executed multiple times. Each execution will reset the -// iterator in `iterator` to the first element of `dataset`. +// Arguments: +// resource_handle: A handle to an iterator resource. +// serialized: A variant tensor storing the state of the iterator contained in the +// resource. // // Returns the created operation. -func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Operation) { +func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MakeIterator", + Type: "DeserializeIterator", Input: []tf.Input{ - dataset, iterator, + resource_handle, serialized, }, } return scope.AddOperation(opspec) } -// Makes the summary of accumulated stats for the batch. -// -// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example. -// -// Arguments: -// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer. -// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients. -// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians. -// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column). -// max_splits: int; the maximum number of splits possible in the whole tree. -// num_buckets: int; equals to the maximum possible value of bucketized feature. -// -// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians. -func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) { +// TensorArrayConcatV2Attr is an optional argument to TensorArrayConcatV2. +type TensorArrayConcatV2Attr func(optionalAttr) + +// TensorArrayConcatV2ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. +// If not specified, defaults to +func TensorArrayConcatV2ElementShapeExcept0(value tf.Shape) TensorArrayConcatV2Attr { + return func(m optionalAttr) { + m["element_shape_except0"] = value + } +} + +// Deprecated. Use TensorArrayConcatV3 +func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV2Attr) (value tf.Output, lengths tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets} + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BoostedTreesMakeStatsSummary", + Type: "TensorArrayConcatV2", Input: []tf.Input{ - node_ids, gradients, hessians, tf.OutputList(bucketized_features_list), + handle, flow_in, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Adjust the contrast of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are -// interpreted as `[height, width, channels]`. The other dimensions only -// represent a collection of images, such as `[batch, height, width, channels].` -// -// Contrast is adjusted independently for each channel of each image. -// -// For each channel, the Op first computes the mean of the image pixels in the -// channel and then adjusts each component of each pixel to -// `(x - mean) * contrast_factor + mean`. +// Creates a dataset that batches and pads `batch_size` elements from the input. // // Arguments: -// images: Images to adjust. At least 3-D. -// contrast_factor: A float multiplier for adjusting contrast. // -// Returns The contrast-adjusted image or images. -func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// padded_shapes: A list of int64 tensors representing the desired padded shapes +// of the corresponding output components. These shapes may be partially +// specified, using `-1` to indicate that a particular dimension should be +// padded to the maximum size of all batch elements. +// padding_values: A list of scalars containing the padding value to use for +// each of the outputs. +// +func PaddedBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "AdjustContrastv2", + Type: "PaddedBatchDataset", Input: []tf.Input{ - images, contrast_factor, + input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Gets the next output from the given iterator. -func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "IteratorGetNext", - Input: []tf.Input{ - iterator, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("IteratorGetNext", err) - return - } - return components -} - -// Outputs the single element from the given dataset. +// Creates a dataset that batches input elements into a SparseTensor. // // Arguments: -// dataset: A handle to a dataset that contains a single element. -// +// input_dataset: A handle to an input dataset. Must have a single component. +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// row_shape: A vector representing the dense shape of each row in the produced +// SparseTensor. The shape may be partially specified, using `-1` to indicate +// that a particular dimension should use the maximum size of all batch elements. // // -// Returns The components of the single element of `input`. -func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { +func DenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "DatasetToSingleElement", + Type: "DenseToSparseBatchDataset", Input: []tf.Input{ - dataset, + input_dataset, batch_size, row_shape, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("DatasetToSingleElement", err) - return - } - return components + return op.Output(0) } -// Converts the given `resource_handle` representing an iterator to a string. -// -// Arguments: -// resource_handle: A handle to an iterator resource. +// Deprecated. Use TensorArrayGradV3 // -// Returns A string representation of the given handle. -func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_handle tf.Output) { +// DEPRECATED at GraphDef version 26: Use TensorArrayGradV3 +func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"source": source} opspec := tf.OpSpec{ - Type: "IteratorToStringHandle", + Type: "TensorArrayGradV2", Input: []tf.Input{ - resource_handle, + handle, flow_in, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// IteratorFromStringHandleAttr is an optional argument to IteratorFromStringHandle. -type IteratorFromStringHandleAttr func(optionalAttr) - -// IteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. +// Return substrings from `Tensor` of strings. // -// value: If specified, defines the type of each tuple component in an -// element produced by the resulting iterator. -// If not specified, defaults to <> +// For each string in the input `Tensor`, creates a substring starting at index +// `pos` with a total length of `len`. // -// REQUIRES: len(value) >= 0 -func IteratorFromStringHandleOutputTypes(value []tf.DataType) IteratorFromStringHandleAttr { - return func(m optionalAttr) { - m["output_types"] = value - } -} - -// IteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. +// If `len` defines a substring that would extend beyond the length of the input +// string, then as many characters as possible are used. // -// value: If specified, defines the shape of each tuple component in an -// element produced by the resulting iterator. -// If not specified, defaults to <> +// If `pos` is negative or specifies a character index larger than any of the input +// strings, then an `InvalidArgumentError` is thrown. // -// REQUIRES: len(value) >= 0 -func IteratorFromStringHandleOutputShapes(value []tf.Shape) IteratorFromStringHandleAttr { - return func(m optionalAttr) { - m["output_shapes"] = value - } -} - -// Converts the given string representing a handle to an iterator to a resource. +// `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on +// Op creation. +// +// *NOTE*: `Substr` supports broadcasting up to two dimensions. More about +// broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +// +// --- +// +// Examples +// +// Using scalar `pos` and `len`: +// +// ```python +// input = [b'Hello', b'World'] +// position = 1 +// length = 3 +// +// output = [b'ell', b'orl'] +// ``` +// +// Using `pos` and `len` with same shape as `input`: +// +// ```python +// input = [[b'ten', b'eleven', b'twelve'], +// [b'thirteen', b'fourteen', b'fifteen'], +// [b'sixteen', b'seventeen', b'eighteen']] +// position = [[1, 2, 3], +// [1, 2, 3], +// [1, 2, 3]] +// length = [[2, 3, 4], +// [4, 3, 2], +// [5, 5, 5]] +// +// output = [[b'en', b'eve', b'lve'], +// [b'hirt', b'urt', b'te'], +// [b'ixtee', b'vente', b'hteen']] +// ``` +// +// Broadcasting `pos` and `len` onto `input`: +// +// ``` +// input = [[b'ten', b'eleven', b'twelve'], +// [b'thirteen', b'fourteen', b'fifteen'], +// [b'sixteen', b'seventeen', b'eighteen'], +// [b'nineteen', b'twenty', b'twentyone']] +// position = [1, 2, 3] +// length = [1, 2, 3] +// +// output = [[b'e', b'ev', b'lve'], +// [b'h', b'ur', b'tee'], +// [b'i', b've', b'hte'], +// [b'i', b'en', b'nty']] +// ``` +// +// Broadcasting `input` onto `pos` and `len`: +// +// ``` +// input = b'thirteen' +// position = [1, 5, 7] +// length = [3, 2, 1] +// +// output = [b'hir', b'ee', b'n'] +// ``` // // Arguments: -// string_handle: A string representation of the given handle. +// input: Tensor of strings +// pos: Scalar defining the position of first character in each substring +// len: Scalar defining the number of characters to include in each substring // -// Returns A handle to an iterator resource. -func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...IteratorFromStringHandleAttr) (resource_handle tf.Output) { +// Returns Tensor of substrings +func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "IteratorFromStringHandle", + Type: "Substr", Input: []tf.Input{ - string_handle, + input, pos, len, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Gather slices from `params` axis `axis` according to `indices`. -// -// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -// Produces an output tensor with shape `params.shape[:axis] + indices.shape + -// params.shape[axis + 1:]` where: -// -// ```python -// # Scalar indices (output is rank(params) - 1). -// output[a_0, ..., a_n, b_0, ..., b_n] = -// params[a_0, ..., a_n, indices, b_0, ..., b_n] -// -// # Vector indices (output is rank(params)). -// output[a_0, ..., a_n, i, b_0, ..., b_n] = -// params[a_0, ..., a_n, indices[i], b_0, ..., b_n] -// -// # Higher rank indices (output is rank(params) + rank(indices) - 1). -// output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = -// params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] -// ``` -// -//
-// -//
-// -// Note that on CPU, if an out of bound index is found, an error is returned. -// On GPU, if an out of bound index is found, a 0 is stored in the -// corresponding output value. +// Creates a Dataset that returns pseudorandom numbers. // // Arguments: -// params: The tensor from which to gather values. Must be at least rank -// `axis + 1`. -// indices: Index tensor. Must be in range `[0, params.shape[axis])`. -// axis: The axis in `params` to gather `indices` from. Defaults to the first -// dimension. Supports negative indexes. +// seed: A scalar seed for the random number generator. If either seed or +// seed2 is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. // -// Returns Values from `params` gathered from indices given by `indices`, with -// shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`. -func GatherV2(scope *Scope, params tf.Output, indices tf.Output, axis tf.Output) (output tf.Output) { +// +func RandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "GatherV2", + Type: "RandomDataset", Input: []tf.Input{ - params, indices, axis, + seed, seed2, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts the given `resource_handle` representing an iterator to a variant tensor. +// Creates a dataset that shuffles and repeats elements from `input_dataset` +// +// pseudorandomly. // // Arguments: -// resource_handle: A handle to an iterator resource. // -// Returns A variant tensor storing the state of the iterator contained in the -// resource. -func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) { +// buffer_size: The number of output elements to buffer in an iterator over +// this dataset. Compare with the `min_after_dequeue` attr when creating a +// `RandomShuffleQueue`. +// seed: A scalar seed for the random number generator. If either `seed` or +// `seed2` is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. +// count: A scalar representing the number of times the underlying dataset +// should be repeated. The default is `-1`, which results in infinite repetition. +// +// +func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "SerializeIterator", + Type: "ShuffleAndRepeatDataset", Input: []tf.Input{ - resource_handle, + input_dataset, buffer_size, seed, seed2, count, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// FIFOQueueV2Attr is an optional argument to FIFOQueueV2. -type FIFOQueueV2Attr func(optionalAttr) - -// FIFOQueueV2Shapes sets the optional shapes attribute to value. -// -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["shapes"] = value - } -} - -// FIFOQueueV2Capacity sets the optional capacity attribute to value. +// Creates a dataset that caches elements from `input_dataset`. // -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// FIFOQueueV2Container sets the optional container attribute to value. +// A CacheDataset will iterate over the input_dataset, and store tensors. If the +// cache already exists, the cache will be used. If the cache is inappropriate +// (e.g. cannot be opened, contains tensors of the wrong shape / size), an error +// will the returned when used. // -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func FIFOQueueV2Container(value string) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// FIFOQueueV2SharedName sets the optional shared_name attribute to value. +// Arguments: // -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that produces elements in first-in first-out order. +// filename: A path on the filesystem where we should cache the dataset. Note: this +// will be a directory. // -// Arguments: -// component_types: The type of each component in a value. // -// Returns The handle to the queue. -func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) { +func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "FIFOQueueV2", - + Type: "CacheDataset", + Input: []tf.Input{ + input_dataset, filename, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Produces a summary of any statistics recorded by the given statistics manager. -func StatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output) { +// Computes the sum along sparse segments of a tensor. +// +// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is +// misisng, the `output` tensor at that position will be zeroed. +// +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. +// +// For example: +// +// ```python +// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) +// +// tf.sparse_segment_sum_with_num_segments( +// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) +// # => [[0 0 0 0] +// # [0 0 0 0] +// # [0 0 0 0]] +// +// tf.sparse_segment_sum_with_num_segments(c, +// tf.constant([0, 1]), +// tf.constant([0, 2], +// num_segments=4)) +// # => [[ 1 2 3 4] +// # [ 0 0 0 0] +// # [-1 -2 -3 -4] +// # [ 0 0 0 0]] +// ``` +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// num_segments: Should equal the number of distinct segment IDs. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `num_segments`. +func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "StatsAggregatorSummary", + Type: "SparseSegmentSumWithNumSegments", Input: []tf.Input{ - iterator, + data, indices, segment_ids, num_segments, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Compute the pairwise cross product. -// -// `a` and `b` must be the same shape; they can either be simple 3-element vectors, -// or any shape where the innermost dimension is 3. In the latter case, each pair -// of corresponding 3-element vectors is cross-multiplied independently. +// Creates a dataset that executes a SQL query and emits rows of the result set. // // Arguments: -// a: A tensor containing 3-element vectors. -// b: Another tensor, of same type and shape as `a`. +// driver_name: The database type. Currently, the only supported type is 'sqlite'. +// data_source_name: A connection string to connect to the database. +// query: A SQL query to execute. // -// Returns Pairwise cross product of the vectors in `a` and `b`. -func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { +// +func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Cross", + Type: "SqlDataset", Input: []tf.Input{ - a, b, + driver_name, data_source_name, query, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Performs a padding as a preprocess during a convolution. -// -// Similar to FusedResizeAndPadConv2d, this op allows for an optimized -// implementation where the spatial padding transformation stage is fused with the -// im2col lookup, but in this case without the bilinear filtering required for -// resizing. Fusing the padding prevents the need to write out the intermediate -// results as whole tensors, reducing memory pressure, and we can get some latency -// gains by merging the transformation calculations. -// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' -// order is used instead. -// Internally this op uses a single per-graph scratch buffer, which means that it -// will block if multiple versions are being run in parallel. This is because this -// operator is primarily an optimization to minimize memory usage. +// Creates a dataset that emits the records from one or more binary files. // // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. -// -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. Must be in the same order as the dimension specified with format. -// padding: The type of padding algorithm to use. -func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string) (output tf.Output) { +// filenames: A scalar or a vector containing the name(s) of the file(s) to be +// read. +// header_bytes: A scalar representing the number of bytes to skip at the +// beginning of a file. +// record_bytes: A scalar representing the number of bytes in each record. +// footer_bytes: A scalar representing the number of bytes to skip at the end +// of a file. +// buffer_size: A scalar representing the number of bytes to buffer. Must be > 0. +func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf.Output, record_bytes tf.Output, footer_bytes tf.Output, buffer_size tf.Output) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "FusedPadConv2D", + Type: "FixedLengthRecordDataset", Input: []tf.Input{ - input, paddings, filter, + filenames, header_bytes, record_bytes, footer_bytes, buffer_size, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. -type Conv2DBackpropInputAttr func(optionalAttr) - -// Conv2DBackpropInputUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DBackpropInputUseCudnnOnGpu(value bool) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value - } -} - -// Conv2DBackpropInputDataFormat sets the optional data_format attribute to value. +// Gradients for batch normalization. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Conv2DBackpropInputDilations sets the optional dilations attribute to value. +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of convolution with respect to the input. +// This op is deprecated. See `tf.nn.batch_normalization`. // // Arguments: -// input_sizes: An integer vector representing the shape of `input`, -// where `input` is a 4-D `[batch, height, width, channels]` tensor. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. Must be in the same order as the dimension specified with -// format. -// padding: The type of padding algorithm to use. +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this Tensor will be multiplied +// with the normalized Tensor. +// backprop: 4D backprop Tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. // -// Returns 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient -// w.r.t. the input of the convolution. -func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropInputAttr) (output tf.Output) { +// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. +func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "Conv2DBackpropInput", + Type: "BatchNormWithGlobalNormalizationGrad", Input: []tf.Input{ - input_sizes, filter, out_backprop, + t, m, v, gamma, backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// Interleave the values from the `data` tensors into a single tensor. -// -// Builds a merged tensor such that -// -// ```python -// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] -// ``` -// -// For example, if each `indices[m]` is scalar or vector, we have -// -// ```python -// # Scalar indices: -// merged[indices[m], ...] = data[m][...] -// -// # Vector indices: -// merged[indices[m][i], ...] = data[m][i, ...] -// ``` -// -// Each `data[i].shape` must start with the corresponding `indices[i].shape`, -// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we -// must have `data[i].shape = indices[i].shape + constant`. In terms of this -// `constant`, the output shape is -// -// merged.shape = [max(indices)] + constant -// -// Values are merged in order, so if an index appears in both `indices[m][i]` and -// `indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the -// merged result. If you do not need this guarantee, ParallelDynamicStitch might -// perform better on some devices. -// -// For example: -// -// ```python -// indices[0] = 6 -// indices[1] = [4, 1] -// indices[2] = [[5, 2], [0, 3]] -// data[0] = [61, 62] -// data[1] = [[41, 42], [11, 12]] -// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] -// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], -// [51, 52], [61, 62]] -// ``` -// -// This method can be used to merge partitions created by `dynamic_partition` -// as illustrated on the following example: -// -// ```python -// # Apply function (increments x_i) on elements for which a certain condition -// # apply (x_i != -1 in this example). -// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) -// condition_mask=tf.not_equal(x,tf.constant(-1.)) -// partitioned_data = tf.dynamic_partition( -// x, tf.cast(condition_mask, tf.int32) , 2) -// partitioned_data[1] = partitioned_data[1] + 1.0 -// condition_indices = tf.dynamic_partition( -// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) -// x = tf.dynamic_stitch(condition_indices, partitioned_data) -// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain -// # unchanged. -// ``` +// Creates a dataset that emits the records from one or more TFRecord files. // -//
-// -//
-func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { +// Arguments: +// filenames: A scalar or vector containing the name(s) of the file(s) to be +// read. +// compression_type: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// buffer_size: A scalar representing the number of bytes to buffer. A value of +// 0 means no buffering will be performed. +func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DynamicStitch", + Type: "TFRecordDataset", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(data), + filenames, compression_type, buffer_size, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the truth value of (x == y) element-wise. +// BatchToSpace for 4-D tensors of type T. // -// *NOTE*: `Equal` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Equal", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2. -type TensorArrayGatherV2Attr func(optionalAttr) - -// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value. -// If not specified, defaults to -func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// Deprecated. Use TensorArrayGatherV3 +// This is a legacy version of the more general BatchToSpaceND. // -// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 -func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorArrayGatherV2", - Input: []tf.Input{ - handle, indices, flow_in, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Interleave the values from the `data` tensors into a single tensor. +// Rearranges (permutes) data from batch into blocks of spatial data, followed by +// cropping. This is the reverse transformation of SpaceToBatch. More specifically, +// this op outputs a copy of the input tensor where values from the `batch` +// dimension are moved in spatial blocks to the `height` and `width` dimensions, +// followed by cropping along the `height` and `width` dimensions. // -// Builds a merged tensor such that +// Arguments: +// input: 4-D tensor with shape +// `[batch*block_size*block_size, height_pad/block_size, width_pad/block_size, +// depth]`. Note that the batch size of the input tensor must be divisible by +// `block_size * block_size`. +// crops: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies +// how many elements to crop from the intermediate result across the spatial +// dimensions as follows: // -// ```python -// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] +// crops = [[crop_top, crop_bottom], [crop_left, crop_right]] +// +// +// Returns 4-D with shape `[batch, height, width, depth]`, where: +// +// height = height_pad - crop_top - crop_bottom +// width = width_pad - crop_left - crop_right +// +// The attr `block_size` must be greater than one. It indicates the block size. +// +// Some examples: +// +// (1) For the following input of shape `[4, 1, 1, 1]` and block_size of 2: +// +// ``` +// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] // ``` // -// For example, if each `indices[m]` is scalar or vector, we have +// The output tensor has shape `[1, 2, 2, 1]` and value: // -// ```python -// # Scalar indices: -// merged[indices[m], ...] = data[m][...] +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// (2) For the following input of shape `[4, 1, 1, 3]` and block_size of 2: // -// # Vector indices: -// merged[indices[m][i], ...] = data[m][i, ...] +// ``` +// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] // ``` // -// Each `data[i].shape` must start with the corresponding `indices[i].shape`, -// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we -// must have `data[i].shape = indices[i].shape + constant`. In terms of this -// `constant`, the output shape is +// The output tensor has shape `[1, 2, 2, 3]` and value: // -// merged.shape = [max(indices)] + constant +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` // -// Values may be merged in parallel, so if an index appears in both `indices[m][i]` -// and `indices[n][j]`, the result may be invalid. This differs from the normal -// DynamicStitch operator that defines the behavior in that case. +// (3) For the following input of shape `[4, 2, 2, 1]` and block_size of 2: // -// For example: +// ``` +// x = [[[[1], [3]], [[9], [11]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` // -// ```python -// indices[0] = 6 -// indices[1] = [4, 1] -// indices[2] = [[5, 2], [0, 3]] -// data[0] = [61, 62] -// data[1] = [[41, 42], [11, 12]] -// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] -// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], -// [51, 52], [61, 62]] +// The output tensor has shape `[1, 4, 4, 1]` and value: +// +// ``` +// x = [[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]] // ``` // -// This method can be used to merge partitions created by `dynamic_partition` -// as illustrated on the following example: +// (4) For the following input of shape `[8, 1, 2, 1]` and block_size of 2: // -// ```python -// # Apply function (increments x_i) on elements for which a certain condition -// # apply (x_i != -1 in this example). -// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) -// condition_mask=tf.not_equal(x,tf.constant(-1.)) -// partitioned_data = tf.dynamic_partition( -// x, tf.cast(condition_mask, tf.int32) , 2) -// partitioned_data[1] = partitioned_data[1] + 1.0 -// condition_indices = tf.dynamic_partition( -// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) -// x = tf.dynamic_stitch(condition_indices, partitioned_data) -// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain -// # unchanged. +// ``` +// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], +// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] // ``` // -//
-// -//
-func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { +// The output tensor has shape `[2, 2, 4, 1]` and value: +// +// ``` +// x = [[[[1], [3]], [[5], [7]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +func BatchToSpace(scope *Scope, input tf.Output, crops tf.Output, block_size int64) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"block_size": block_size} opspec := tf.OpSpec{ - Type: "ParallelDynamicStitch", + Type: "BatchToSpace", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(data), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the gradient for the inverse of `x` wrt its input. -// -// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` -// is the corresponding input gradient. -func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InvGrad", - Input: []tf.Input{ - y, dy, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// List of the given size with empty elements. -// -// element_shape: the shape of the future elements of the list -// num_elements: the number of elements to reserve -// handle: the output list -// element_dtype: the desired type of elements in the list. -func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - opspec := tf.OpSpec{ - Type: "TensorListReserve", - Input: []tf.Input{ - element_shape, num_elements, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// PriorityQueueV2Attr is an optional argument to PriorityQueueV2. -type PriorityQueueV2Attr func(optionalAttr) - -// PriorityQueueV2ComponentTypes sets the optional component_types attribute to value. -// -// value: The type of each component in a value. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func PriorityQueueV2ComponentTypes(value []tf.DataType) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["component_types"] = value - } -} - -// PriorityQueueV2Capacity sets the optional capacity attribute to value. -// -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func PriorityQueueV2Capacity(value int64) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// PriorityQueueV2Container sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func PriorityQueueV2Container(value string) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// PriorityQueueV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func PriorityQueueV2SharedName(value string) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that produces elements sorted by the first component value. -// -// Note that the PriorityQueue requires the first component of any element -// to be a scalar int64, in addition to the other elements declared by -// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue -// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra -// entry in their input (resp. output) lists. -// -// Arguments: -// shapes: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. -// -// Returns The handle to the queue. -func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV2Attr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shapes": shapes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "PriorityQueueV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// UnstageAttr is an optional argument to Unstage. -type UnstageAttr func(optionalAttr) - -// UnstageCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func UnstageCapacity(value int64) UnstageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// UnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func UnstageMemoryLimit(value int64) UnstageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// UnstageContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func UnstageContainer(value string) UnstageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// UnstageSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func UnstageSharedName(value string) UnstageAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op is similar to a lightweight Dequeue. -// -// The basic functionality is similar to dequeue with many fewer -// capabilities and options. This Op is optimized for performance. -func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Unstage", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("Unstage", err) - return - } - return values -} - -// QueueEnqueueV2Attr is an optional argument to QueueEnqueueV2. -type QueueEnqueueV2Attr func(optionalAttr) - -// QueueEnqueueV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue is full, this operation will block for up to -// timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueEnqueueV2TimeoutMs(value int64) QueueEnqueueV2Attr { - return func(m optionalAttr) { - m["timeout_ms"] = value - } -} - -// Enqueues a tuple of one or more tensors in the given queue. -// -// The components input has k elements, which correspond to the components of -// tuples stored in the given queue. -// -// N.B. If the queue is full, this operation will block until the given -// element has been enqueued (or 'timeout_ms' elapses, if specified). -// -// Arguments: -// handle: The handle to a queue. -// components: One or more tensors from which the enqueued tensors should be taken. -// -// Returns the created operation. -func QueueEnqueueV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QueueEnqueueV2", - Input: []tf.Input{ - handle, tf.OutputList(components), - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2. -type QueueDequeueManyV2Attr func(optionalAttr) - -// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue has fewer than n elements, this operation -// will block for up to timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr { - return func(m optionalAttr) { - m["timeout_ms"] = value - } -} - -// Dequeues `n` tuples of one or more tensors from the given queue. -// -// If the queue is closed and there are fewer than `n` elements, then an -// OutOfRange error is returned. -// -// This operation concatenates queue-element component tensors along the -// 0th dimension to make a single component tensor. All of the components -// in the dequeued tuple will have size `n` in the 0th dimension. -// -// This operation has `k` outputs, where `k` is the number of components in -// the tuples stored in the given queue, and output `i` is the ith -// component of the dequeued tuple. -// -// N.B. If the queue is empty, this operation will block until `n` elements -// have been dequeued (or 'timeout_ms' elapses, if specified). -// -// Arguments: -// handle: The handle to a queue. -// n: The number of tuples to dequeue. -// component_types: The type of each component in a tuple. -// -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"component_types": component_types} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QueueDequeueManyV2", - Input: []tf.Input{ - handle, n, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueManyV2", err) - return - } - return components -} - -// EncodeBase64Attr is an optional argument to EncodeBase64. -type EncodeBase64Attr func(optionalAttr) - -// EncodeBase64Pad sets the optional pad attribute to value. -// -// value: Bool whether padding is applied at the ends. -// If not specified, defaults to false -func EncodeBase64Pad(value bool) EncodeBase64Attr { - return func(m optionalAttr) { - m["pad"] = value - } -} - -// Encode strings into web-safe base64 format. -// -// Refer to the following article for more information on base64 format: -// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the -// end so that the encoded has length multiple of 4. See Padding section of the -// link above. -// -// Web-safe means that the encoder uses - and _ instead of + and /. -// -// Arguments: -// input: Strings to be encoded. -// -// Returns Input strings encoded in base64. -func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EncodeBase64", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Deprecated. Use TensorArrayCloseV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 -// -// Returns the created operation. -func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArrayCloseV2", - Input: []tf.Input{ - handle, - }, - } - return scope.AddOperation(opspec) -} - -// Forwards the value of an available tensor from `inputs` to `output`. -// -// `Merge` waits for at least one of the tensors in `inputs` to become available. -// It is usually combined with `Switch` to implement branching. -// -// `Merge` forwards the first tensor to become available to `output`, and sets -// `value_index` to its index in `inputs`. -// -// Arguments: -// inputs: The input tensors, exactly one of which will become available. -// -// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. -func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Merge", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// QueueCloseV2Attr is an optional argument to QueueCloseV2. -type QueueCloseV2Attr func(optionalAttr) - -// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. -// -// value: If true, all pending enqueue requests that are -// blocked on the given queue will be canceled. -// If not specified, defaults to false -func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { - return func(m optionalAttr) { - m["cancel_pending_enqueues"] = value - } -} - -// Closes the given queue. -// -// This operation signals that no more elements will be enqueued in the -// given queue. Subsequent Enqueue(Many) operations will fail. -// Subsequent Dequeue(Many) operations will continue to succeed if -// sufficient elements remain in the queue. Subsequent Dequeue(Many) -// operations that would block will fail immediately. -// -// Arguments: -// handle: The handle to a queue. -// -// Returns the created operation. -func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QueueCloseV2", - Input: []tf.Input{ - handle, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Computes inverse hyperbolic tangent of x element-wise. -func Atanh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Atanh", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns true if queue is closed. -// -// This operation returns true if the queue is closed and false if the queue -// is open. -// -// Arguments: -// handle: The handle to a queue. -func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "QueueIsClosedV2", - Input: []tf.Input{ - handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the batched diagonal part of a batched tensor. -// -// This operation returns a tensor with the `diagonal` part -// of the batched `input`. The `diagonal` part is computed as follows: -// -// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a -// tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where: -// -// `diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`. -// -// The input must be at least a matrix. -// -// For example: -// -// ``` -// # 'input' is [[[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]], -// [[5, 0, 0, 0] -// [0, 6, 0, 0] -// [0, 0, 7, 0] -// [0, 0, 0, 8]]] -// -// and input.shape = (2, 4, 4) -// -// tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]] -// -// which has shape (2, 4) -// ``` -// -// Arguments: -// input: Rank `k` tensor where `k >= 2`. -// -// Returns The extracted diagonal(s) having shape -// `diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]`. -func MatrixDiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixDiagPart", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the absolute value of a tensor. -// -// Given a tensor `x`, this operation returns a tensor containing the absolute -// value of each element in `x`. For example, if x is an input element and y is -// an output element, this operation computes \\(y = |x|\\). -func Abs(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Abs", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StackV2Attr is an optional argument to StackV2. -type StackV2Attr func(optionalAttr) - -// StackV2StackName sets the optional stack_name attribute to value. -// -// value: Overrides the name used for the temporary stack resource. Default -// value is the name of the 'Stack' op (which is guaranteed unique). -// If not specified, defaults to "" -func StackV2StackName(value string) StackV2Attr { - return func(m optionalAttr) { - m["stack_name"] = value - } -} - -// A stack that produces elements in first-in last-out order. -// -// Arguments: -// max_size: The maximum size of the stack if non-negative. If negative, the stack -// size is unlimited. -// elem_type: The type of the elements on the stack. -// -// Returns The handle to the stack. -func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"elem_type": elem_type} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StackV2", - Input: []tf.Input{ - max_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FusedBatchNormGradV2Attr is an optional argument to FusedBatchNormGradV2. -type FusedBatchNormGradV2Attr func(optionalAttr) - -// FusedBatchNormGradV2Epsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormGradV2Epsilon(value float32) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormGradV2DataFormat sets the optional data_format attribute to value. -// -// value: The data format for y_backprop, x, x_backprop. -// Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormGradV2DataFormat(value string) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormGradV2IsTraining sets the optional is_training attribute to value. -// -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormGradV2IsTraining(value bool) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// Gradient for batch normalization. -// -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. -// -// Arguments: -// y_backprop: A 4D Tensor for the gradient with respect to y. -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch -// mean to be reused in gradient computation. When is_training is -// False, a 1D Tensor for the population mean to be reused in both -// 1st and 2nd order gradient computation. -// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch -// variance (inverted variance in the cuDNN case) to be reused in -// gradient computation. When is_training is False, a 1D Tensor -// for the population variance to be reused in both 1st and 2nd -// order gradient computation. -// -// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input -// in FusedBatchNorm. -func FusedBatchNormGradV2(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradV2Attr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FusedBatchNormGradV2", - Input: []tf.Input{ - y_backprop, x, scale, reserve_space_1, reserve_space_2, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - -// DecodeCompressedAttr is an optional argument to DecodeCompressed. -type DecodeCompressedAttr func(optionalAttr) - -// DecodeCompressedCompressionType sets the optional compression_type attribute to value. -// -// value: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// If not specified, defaults to "" -func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { - return func(m optionalAttr) { - m["compression_type"] = value - } -} - -// Decompress strings. -// -// This op decompresses each element of the `bytes` input `Tensor`, which -// is assumed to be compressed using the given `compression_type`. -// -// The `output` is a string `Tensor` of the same shape as `bytes`, -// each element containing the decompressed data from the corresponding -// element in `bytes`. -// -// Arguments: -// bytes: A Tensor of string which is compressed. -// -// Returns A Tensor with the same shape as input `bytes`, uncompressed -// from bytes. -func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeCompressed", - Input: []tf.Input{ - bytes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CudnnRNNAttr is an optional argument to CudnnRNN. -type CudnnRNNAttr func(optionalAttr) - -// CudnnRNNRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNRnnMode(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value - } -} - -// CudnnRNNInputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNInputMode(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["input_mode"] = value - } -} - -// CudnnRNNDirection sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNDirection(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["direction"] = value - } -} - -// CudnnRNNDropout sets the optional dropout attribute to value. -// If not specified, defaults to 0 -func CudnnRNNDropout(value float32) CudnnRNNAttr { - return func(m optionalAttr) { - m["dropout"] = value - } -} - -// CudnnRNNSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNSeed(value int64) CudnnRNNAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// CudnnRNNSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNSeed2(value int64) CudnnRNNAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// CudnnRNNIsTraining sets the optional is_training attribute to value. -// If not specified, defaults to true -func CudnnRNNIsTraining(value bool) CudnnRNNAttr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// A RNN backed by cuDNN. -// -// Computes the RNN from the input and initial states, with respect to the params -// buffer. -// -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicate whether there is a linear projection between the input and -// The actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. -// dir = (direction == bidirectional) ? 2 : 1 -// dropout: dropout probability. When set to 0., dropout is disabled. -// seed: the 1st part of a seed to initialize dropout. -// seed2: the 2nd part of a seed to initialize dropout. -// input: a 3-D tensor with the shape of [seq_length, batch_size, input_size]. -// input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size, -// num_units]. -// input_c: For LSTM, a 3-D tensor with the shape of -// [num_layer * dir, batch, num_units]. For other models, it is ignored. -// params: a 1-D tensor that contains the weights and biases in an opaque layout. -// The size must be created through CudnnRNNParamsSize, and initialized -// separately. Note that they might not be compatible across different -// generations. So it is a good idea to save and restore -// output: a 3-D tensor with the shape of [seq_length, batch_size, -// dir * num_units]. -// output_h: the same shape has input_h. -// output_c: the same shape as input_c for LSTM. An empty tensor for other models. -// is_training: Indicates whether this operation is used for inferenece or -// training. -// reserve_space: an opaque tensor that can be used in backprop calculation. It -// is only produced if is_training is false. -func CudnnRNN(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, optional ...CudnnRNNAttr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CudnnRNN", - Input: []tf.Input{ - input, input_h, input_c, params, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`. -// -// Each comparison returns a boolean `true` (if `input_value > threshold`) -// or and `false` otherwise. -// -// This operation is useful for Locality-Sensitive-Hashing (LSH) and other -// algorithms that use hashing approximations of cosine and `L2` distances; -// codes can be generated from an input via: -// -// ```python -// codebook_size = 50 -// codebook_bits = codebook_size * 32 -// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], -// dtype=x.dtype, -// initializer=tf.orthogonal_initializer()) -// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) -// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 -// # now codes has shape x.shape[:-1] + [codebook_size] -// ``` -// -// **NOTE**: Currently, the innermost dimension of the tensor must be divisible -// by 8. -// -// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is -// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. -// -// Arguments: -// input: Values to compare against `threshold` and bitpack. -// threshold: Threshold to compare against. -// -// Returns The bitpacked comparisons. -func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "CompareAndBitpack", - Input: []tf.Input{ - input, threshold, + input, crops, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Push an element onto the tensor_array. +// Makes a new iterator from the given `dataset` and stores it in `iterator`. // -// Arguments: -// handle: The handle to a TensorArray. -// index: The position to write to inside the TensorArray. -// value: The tensor to write to the TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. +// This operation may be executed multiple times. Each execution will reset the +// iterator in `iterator` to the first element of `dataset`. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Returns the created operation. +func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArrayWriteV3", + Type: "MakeIterator", Input: []tf.Input{ - handle, index, value, flow_in, + dataset, iterator, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Scatter the data from the input value into specific TensorArray elements. +// Makes the summary of accumulated stats for the batch. // -// `indices` must be a vector, its length must match the first dim of `value`. +// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example. // // Arguments: -// handle: The handle to a TensorArray. -// indices: The locations at which to write the tensor elements. -// value: The concatenated tensor to write to the TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. +// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer. +// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients. +// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians. +// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column). +// max_splits: int; the maximum number of splits possible in the whole tree. +// num_buckets: int; equals to the maximum possible value of bucketized feature. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArrayScatterV3(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians. +func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "TensorArrayScatterV3", + Type: "BoostedTreesMakeStatsSummary", Input: []tf.Input{ - handle, indices, value, flow_in, + node_ids, gradients, hessians, tf.OutputList(bucketized_features_list), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// EmptyAttr is an optional argument to Empty. -type EmptyAttr func(optionalAttr) - -// EmptyInit sets the optional init attribute to value. +// Adjust the contrast of one or more images. // -// value: If True, initialize the returned tensor with the default value of dtype. Otherwise, the implementation is free not to initializethe tensor's content. -// If not specified, defaults to false -func EmptyInit(value bool) EmptyAttr { - return func(m optionalAttr) { - m["init"] = value - } -} - -// Creates a tensor with the given shape. +// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are +// interpreted as `[height, width, channels]`. The other dimensions only +// represent a collection of images, such as `[batch, height, width, channels].` // -// This operation creates a tensor of `shape` and `dtype`. +// Contrast is adjusted independently for each channel of each image. // -// Arguments: -// shape: 1-D. Represents the shape of the output tensor. +// For each channel, the Op first computes the mean of the image pixels in the +// channel and then adjusts each component of each pixel to +// `(x - mean) * contrast_factor + mean`. // +// Arguments: +// images: Images to adjust. At least 3-D. +// contrast_factor: A float multiplier for adjusting contrast. // -// Returns A `Tensor` of type `T`. -func Empty(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...EmptyAttr) (output tf.Output) { +// Returns The contrast-adjusted image or images. +func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Empty", + Type: "AdjustContrastv2", Input: []tf.Input{ - shape, + images, contrast_factor, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// TensorArrayConcatV3Attr is an optional argument to TensorArrayConcatV3. -type TensorArrayConcatV3Attr func(optionalAttr) - -// TensorArrayConcatV3ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. -// -// value: The expected shape of an element, if known, -// excluding the first dimension. Used to validate the shapes of -// TensorArray elements. If this shape is not fully specified, concatenating -// zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayConcatV3ElementShapeExcept0(value tf.Shape) TensorArrayConcatV3Attr { - return func(m optionalAttr) { - m["element_shape_except0"] = value - } -} - -// Concat the elements from the TensorArray into value `value`. -// -// Takes `T` elements of shapes -// -// ``` -// (n0 x d0 x d1 x ...), (n1 x d0 x d1 x ...), ..., (n(T-1) x d0 x d1 x ...) -// ``` -// -// and concatenates them into a Tensor of shape: -// -// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)``` -// -// All elements must have the same shape (excepting the first dimension). -// -// Arguments: -// handle: The handle to a TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. -// -// Returns All of the elements in the TensorArray, concatenated along the first -// axis.A vector of the row sizes of the original T elements in the -// value output. In the example above, this would be the values: -// `(n1, n2, ..., n(T-1))`. -func TensorArrayConcatV3(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV3Attr) (value tf.Output, lengths tf.Output) { +// Gets the next output from the given iterator. +func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorArrayConcatV3", + Type: "IteratorGetNext", Input: []tf.Input{ - handle, flow_in, + iterator, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// ParameterizedTruncatedNormalAttr is an optional argument to ParameterizedTruncatedNormal. -type ParameterizedTruncatedNormalAttr func(optionalAttr) - -// ParameterizedTruncatedNormalSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func ParameterizedTruncatedNormalSeed(value int64) ParameterizedTruncatedNormalAttr { - return func(m optionalAttr) { - m["seed"] = value + if scope.Err() != nil { + return } -} - -// ParameterizedTruncatedNormalSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func ParameterizedTruncatedNormalSeed2(value int64) ParameterizedTruncatedNormalAttr { - return func(m optionalAttr) { - m["seed2"] = value + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("IteratorGetNext", err) + return } + return components } -// Outputs random values from a normal distribution. The parameters may each be a -// -// scalar which applies to the entire output, or a vector of length shape[0] which -// stores the parameters for each batch. +// Outputs the single element from the given dataset. // // Arguments: -// shape: The shape of the output tensor. Batches are indexed by the 0th dimension. -// means: The mean parameter of each batch. -// stdevs: The standard deviation parameter of each batch. Must be greater than 0. -// minvals: The minimum cutoff. May be -infinity. -// maxvals: The maximum cutoff. May be +infinity, and must be more than the minval -// for each batch. +// dataset: A handle to a dataset that contains a single element. // -// Returns A matrix of shape num_batches x samples_per_batch, filled with random -// truncated normal values using the parameters for each row. -func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output, stdevs tf.Output, minvals tf.Output, maxvals tf.Output, optional ...ParameterizedTruncatedNormalAttr) (output tf.Output) { +// +// +// Returns The components of the single element of `input`. +func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ParameterizedTruncatedNormal", + Type: "DatasetToSingleElement", Input: []tf.Input{ - shape, means, stdevs, minvals, maxvals, + dataset, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("DatasetToSingleElement", err) + return + } + return components } -// Sets the index-th position of the list to contain the given tensor. +// Converts the given `resource_handle` representing an iterator to a string. // -// input_handle: the list -// index: the position in the list to which the tensor will be assigned -// item: the element to be assigned to that position -// output_handle: the new list, with the element in the proper position +// Arguments: +// resource_handle: A handle to an iterator resource. // -func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, item tf.Output) (output_handle tf.Output) { +// Returns A string representation of the given handle. +func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListSetItem", + Type: "IteratorToStringHandle", Input: []tf.Input{ - input_handle, index, item, + resource_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns a diagonal tensor with a given diagonal values. -// -// Given a `diagonal`, this operation returns a tensor with the `diagonal` and -// everything else padded with zeros. The diagonal is computed as follows: +// IteratorFromStringHandleAttr is an optional argument to IteratorFromStringHandle. +type IteratorFromStringHandleAttr func(optionalAttr) + +// IteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. // -// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of -// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: +// value: If specified, defines the type of each tuple component in an +// element produced by the resulting iterator. +// If not specified, defaults to <> // -// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. +// REQUIRES: len(value) >= 0 +func IteratorFromStringHandleOutputTypes(value []tf.DataType) IteratorFromStringHandleAttr { + return func(m optionalAttr) { + m["output_types"] = value + } +} + +// IteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. // -// For example: +// value: If specified, defines the shape of each tuple component in an +// element produced by the resulting iterator. +// If not specified, defaults to <> // -// ``` -// # 'diagonal' is [1, 2, 3, 4] -// tf.diag(diagonal) ==> [[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]] -// ``` +// REQUIRES: len(value) >= 0 +func IteratorFromStringHandleOutputShapes(value []tf.Shape) IteratorFromStringHandleAttr { + return func(m optionalAttr) { + m["output_shapes"] = value + } +} + +// Converts the given string representing a handle to an iterator to a resource. // // Arguments: -// diagonal: Rank k tensor where k is at most 1. -func Diag(scope *Scope, diagonal tf.Output) (output tf.Output) { +// string_handle: A string representation of the given handle. +// +// Returns A handle to an iterator resource. +func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...IteratorFromStringHandleAttr) (resource_handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Diag", + Type: "IteratorFromStringHandle", Input: []tf.Input{ - diagonal, + string_handle, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Split the data from the input value into TensorArray elements. -// -// Assuming that `lengths` takes on values -// -// ```(n0, n1, ..., n(T-1))``` -// -// and that `value` has shape +// Gather slices from `params` axis `axis` according to `indices`. // -// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)```, +// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +// Produces an output tensor with shape `params.shape[:axis] + indices.shape + +// params.shape[axis + 1:]` where: // -// this splits values into a TensorArray with T tensors. +// ```python +// # Scalar indices (output is rank(params) - 1). +// output[a_0, ..., a_n, b_0, ..., b_n] = +// params[a_0, ..., a_n, indices, b_0, ..., b_n] // -// TensorArray index t will be the subtensor of values with starting position +// # Vector indices (output is rank(params)). +// output[a_0, ..., a_n, i, b_0, ..., b_n] = +// params[a_0, ..., a_n, indices[i], b_0, ..., b_n] // -// ```(n0 + n1 + ... + n(t-1), 0, 0, ...)``` +// # Higher rank indices (output is rank(params) + rank(indices) - 1). +// output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = +// params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] +// ``` // -// and having size +//
+// +//
// -// ```nt x d0 x d1 x ...``` +// Note that on CPU, if an out of bound index is found, an error is returned. +// On GPU, if an out of bound index is found, a 0 is stored in the +// corresponding output value. // // Arguments: -// handle: The handle to a TensorArray. -// value: The concatenated tensor to write to the TensorArray. -// lengths: The vector of lengths, how to split the rows of value into the -// TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. +// params: The tensor from which to gather values. Must be at least rank +// `axis + 1`. +// indices: Index tensor. Must be in range `[0, params.shape[axis])`. +// axis: The axis in `params` to gather `indices` from. Defaults to the first +// dimension. Supports negative indexes. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Returns Values from `params` gathered from indices given by `indices`, with +// shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`. +func GatherV2(scope *Scope, params tf.Output, indices tf.Output, axis tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArraySplitV3", + Type: "GatherV2", Input: []tf.Input{ - handle, value, lengths, flow_in, + params, indices, axis, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SerializeSparseAttr is an optional argument to SerializeSparse. -type SerializeSparseAttr func(optionalAttr) - -// SerializeSparseOutType sets the optional out_type attribute to value. -// -// value: The `dtype` to use for serialization; the supported types are `string` -// (default) and `variant`. -// If not specified, defaults to DT_STRING -func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Serialize a `SparseTensor` into a `[3]` `Tensor` object. +// Converts the given `resource_handle` representing an iterator to a variant tensor. // // Arguments: -// sparse_indices: 2-D. The `indices` of the `SparseTensor`. -// sparse_values: 1-D. The `values` of the `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the `SparseTensor`. -func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) { +// resource_handle: A handle to an iterator resource. +// +// Returns A variant tensor storing the state of the iterator contained in the +// resource. +func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SerializeSparse", + Type: "SerializeIterator", Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, + resource_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. -type RandomShuffleQueueV2Attr func(optionalAttr) +// FIFOQueueV2Attr is an optional argument to FIFOQueueV2. +type FIFOQueueV2Attr func(optionalAttr) -// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. +// FIFOQueueV2Shapes sets the optional shapes attribute to value. // // value: The shape of each component in a value. The length of this attr must // be either 0 or the same as the length of component_types. If the length of @@ -27392,386 +26365,646 @@ type RandomShuffleQueueV2Attr func(optionalAttr) // If not specified, defaults to <> // // REQUIRES: len(value) >= 0 -func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { +func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr { return func(m optionalAttr) { m["shapes"] = value } } -// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. +// FIFOQueueV2Capacity sets the optional capacity attribute to value. // // value: The upper bound on the number of elements in this queue. // Negative numbers mean no limit. // If not specified, defaults to -1 -func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { +func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr { return func(m optionalAttr) { m["capacity"] = value } } -// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. +// FIFOQueueV2Container sets the optional container attribute to value. // -// value: Dequeue will block unless there would be this -// many elements after the dequeue or the queue is closed. This -// ensures a minimum level of mixing of elements. -// If not specified, defaults to 0 -func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func FIFOQueueV2Container(value string) FIFOQueueV2Attr { return func(m optionalAttr) { - m["min_after_dequeue"] = value + m["container"] = value } } -// RandomShuffleQueueV2Seed sets the optional seed attribute to value. +// FIFOQueueV2SharedName sets the optional shared_name attribute to value. // -// value: If either seed or seed2 is set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. +// If not specified, defaults to "" +func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr { return func(m optionalAttr) { - m["seed"] = value + m["shared_name"] = value } } -// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. +// A queue that produces elements in first-in first-out order. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { +// Arguments: +// component_types: The type of each component in a value. +// +// Returns The handle to the queue. +func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FIFOQueueV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Produces a summary of any statistics recorded by the given statistics manager. +func StatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StatsAggregatorSummary", + Input: []tf.Input{ + iterator, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Compute the pairwise cross product. +// +// `a` and `b` must be the same shape; they can either be simple 3-element vectors, +// or any shape where the innermost dimension is 3. In the latter case, each pair +// of corresponding 3-element vectors is cross-multiplied independently. +// +// Arguments: +// a: A tensor containing 3-element vectors. +// b: Another tensor, of same type and shape as `a`. +// +// Returns Pairwise cross product of the vectors in `a` and `b`. +func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Cross", + Input: []tf.Input{ + a, b, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Performs a padding as a preprocess during a convolution. +// +// Similar to FusedResizeAndPadConv2d, this op allows for an optimized +// implementation where the spatial padding transformation stage is fused with the +// im2col lookup, but in this case without the bilinear filtering required for +// resizing. Fusing the padding prevents the need to write out the intermediate +// results as whole tensors, reducing memory pressure, and we can get some latency +// gains by merging the transformation calculations. +// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' +// order is used instead. +// Internally this op uses a single per-graph scratch buffer, which means that it +// will block if multiple versions are being run in parallel. This is because this +// operator is primarily an optimization to minimize memory usage. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. +// +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. Must be in the same order as the dimension specified with format. +// padding: The type of padding algorithm to use. +func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} + opspec := tf.OpSpec{ + Type: "FusedPadConv2D", + Input: []tf.Input{ + input, paddings, filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. +type Conv2DBackpropInputAttr func(optionalAttr) + +// Conv2DBackpropInputUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DBackpropInputUseCudnnOnGpu(value bool) Conv2DBackpropInputAttr { return func(m optionalAttr) { - m["seed2"] = value + m["use_cudnn_on_gpu"] = value } } -// RandomShuffleQueueV2Container sets the optional container attribute to value. +// Conv2DBackpropInputDataFormat sets the optional data_format attribute to value. // -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { return func(m optionalAttr) { - m["container"] = value + m["data_format"] = value } } -// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. +// Conv2DBackpropInputDilations sets the optional dilations attribute to value. // -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["dilations"] = value } } -// A queue that randomizes the order of elements. +// Computes the gradients of convolution with respect to the input. // // Arguments: -// component_types: The type of each component in a value. +// input_sizes: An integer vector representing the shape of `input`, +// where `input` is a 4-D `[batch, height, width, channels]` tensor. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. Must be in the same order as the dimension specified with +// format. +// padding: The type of padding algorithm to use. // -// Returns The handle to the queue. -func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { +// Returns 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient +// w.r.t. the input of the convolution. +func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropInputAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomShuffleQueueV2", - + Type: "Conv2DBackpropInput", + Input: []tf.Input{ + input_sizes, filter, out_backprop, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Draw bounding boxes on a batch of images. +// Interleave the values from the `data` tensors into a single tensor. // -// Outputs a copy of `images` but draws on top of the pixels zero or more bounding -// boxes specified by the locations in `boxes`. The coordinates of the each -// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. +// Builds a merged tensor such that // -// For example, if an image is 100 x 200 pixels (height x width) and the bounding -// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of -// the bounding box will be `(40, 10)` to `(100, 50)` (in (x,y) coordinates). +// ```python +// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] +// ``` // -// Parts of the bounding box may fall outside the image. +// For example, if each `indices[m]` is scalar or vector, we have // -// Arguments: -// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. -// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding -// boxes. +// ```python +// # Scalar indices: +// merged[indices[m], ...] = data[m][...] // -// Returns 4-D with the same shape as `images`. The batch of input images with -// bounding boxes drawn on the images. -func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { +// # Vector indices: +// merged[indices[m][i], ...] = data[m][i, ...] +// ``` +// +// Each `data[i].shape` must start with the corresponding `indices[i].shape`, +// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we +// must have `data[i].shape = indices[i].shape + constant`. In terms of this +// `constant`, the output shape is +// +// merged.shape = [max(indices)] + constant +// +// Values are merged in order, so if an index appears in both `indices[m][i]` and +// `indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the +// merged result. If you do not need this guarantee, ParallelDynamicStitch might +// perform better on some devices. +// +// For example: +// +// ```python +// indices[0] = 6 +// indices[1] = [4, 1] +// indices[2] = [[5, 2], [0, 3]] +// data[0] = [61, 62] +// data[1] = [[41, 42], [11, 12]] +// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] +// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], +// [51, 52], [61, 62]] +// ``` +// +// This method can be used to merge partitions created by `dynamic_partition` +// as illustrated on the following example: +// +// ```python +// # Apply function (increments x_i) on elements for which a certain condition +// # apply (x_i != -1 in this example). +// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) +// condition_mask=tf.not_equal(x,tf.constant(-1.)) +// partitioned_data = tf.dynamic_partition( +// x, tf.cast(condition_mask, tf.int32) , 2) +// partitioned_data[1] = partitioned_data[1] + 1.0 +// condition_indices = tf.dynamic_partition( +// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) +// x = tf.dynamic_stitch(condition_indices, partitioned_data) +// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain +// # unchanged. +// ``` +// +//
+// +//
+func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DrawBoundingBoxes", + Type: "DynamicStitch", Input: []tf.Input{ - images, boxes, + tf.OutputList(indices), tf.OutputList(data), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler. -type LearnedUnigramCandidateSamplerAttr func(optionalAttr) - -// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// Returns the truth value of (x == y) element-wise. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed"] = value +// *NOTE*: `Equal` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Equal", + Input: []tf.Input{ + x, y, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. -// -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr { +// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2. +type TensorArrayGatherV2Attr func(optionalAttr) + +// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value. +// If not specified, defaults to +func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { return func(m optionalAttr) { - m["seed2"] = value + m["element_shape"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. -// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. -// -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. -// -// Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). +// Deprecated. Use TensorArrayGatherV3 // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 +func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LearnedUnigramCandidateSampler", + Type: "TensorArrayGatherV2", Input: []tf.Input{ - true_classes, + handle, indices, flow_in, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Computes gradients for the scaled exponential linear (Selu) operation. +// Interleave the values from the `data` tensors into a single tensor. +// +// Builds a merged tensor such that +// +// ```python +// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] +// ``` +// +// For example, if each `indices[m]` is scalar or vector, we have +// +// ```python +// # Scalar indices: +// merged[indices[m], ...] = data[m][...] +// +// # Vector indices: +// merged[indices[m][i], ...] = data[m][i, ...] +// ``` +// +// Each `data[i].shape` must start with the corresponding `indices[i].shape`, +// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we +// must have `data[i].shape = indices[i].shape + constant`. In terms of this +// `constant`, the output shape is +// +// merged.shape = [max(indices)] + constant +// +// Values may be merged in parallel, so if an index appears in both `indices[m][i]` +// and `indices[n][j]`, the result may be invalid. This differs from the normal +// DynamicStitch operator that defines the behavior in that case. +// +// For example: +// +// ```python +// indices[0] = 6 +// indices[1] = [4, 1] +// indices[2] = [[5, 2], [0, 3]] +// data[0] = [61, 62] +// data[1] = [[41, 42], [11, 12]] +// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] +// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], +// [51, 52], [61, 62]] +// ``` // -// Arguments: -// gradients: The backpropagated gradients to the corresponding Selu operation. -// outputs: The outputs of the corresponding Selu operation. +// This method can be used to merge partitions created by `dynamic_partition` +// as illustrated on the following example: // -// Returns The gradients: `gradients * (outputs + scale * alpha)` -// if outputs < 0, `scale * gradients` otherwise. -func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { +// ```python +// # Apply function (increments x_i) on elements for which a certain condition +// # apply (x_i != -1 in this example). +// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) +// condition_mask=tf.not_equal(x,tf.constant(-1.)) +// partitioned_data = tf.dynamic_partition( +// x, tf.cast(condition_mask, tf.int32) , 2) +// partitioned_data[1] = partitioned_data[1] + 1.0 +// condition_indices = tf.dynamic_partition( +// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) +// x = tf.dynamic_stitch(condition_indices, partitioned_data) +// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain +// # unchanged. +// ``` +// +//
+// +//
+func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SeluGrad", + Type: "ParallelDynamicStitch", Input: []tf.Input{ - gradients, outputs, + tf.OutputList(indices), tf.OutputList(data), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Get the current size of the TensorArray. -// -// Arguments: -// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). -// flow_in: A float scalar that enforces proper chaining of operations. +// Computes the gradient for the inverse of `x` wrt its input. // -// Returns The current size of the TensorArray. -func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { +// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +// is the corresponding input gradient. +func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArraySizeV3", + Type: "InvGrad", Input: []tf.Input{ - handle, flow_in, + y, dy, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Deprecated. Use TensorArrayGradV3 +// List of the given size with empty elements. // -// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3 -func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// element_shape: the shape of the future elements of the list +// num_elements: the number of elements to reserve +// handle: the output list +// element_dtype: the desired type of elements in the list. +func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "TensorArrayWriteV2", + Type: "TensorListReserve", Input: []tf.Input{ - handle, index, value, flow_in, + element_shape, num_elements, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// SparseReduceMaxAttr is an optional argument to SparseReduceMax. -type SparseReduceMaxAttr func(optionalAttr) +// PriorityQueueV2Attr is an optional argument to PriorityQueueV2. +type PriorityQueueV2Attr func(optionalAttr) -// SparseReduceMaxKeepDims sets the optional keep_dims attribute to value. +// PriorityQueueV2ComponentTypes sets the optional component_types attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceMaxKeepDims(value bool) SparseReduceMaxAttr { +// value: The type of each component in a value. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func PriorityQueueV2ComponentTypes(value []tf.DataType) PriorityQueueV2Attr { return func(m optionalAttr) { - m["keep_dims"] = value + m["component_types"] = value } } -// Computes the max of elements across dimensions of a SparseTensor. +// PriorityQueueV2Capacity sets the optional capacity attribute to value. // -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_max()`. In particular, this Op also returns a dense `Tensor` -// instead of a sparse one. +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func PriorityQueueV2Capacity(value int64) PriorityQueueV2Attr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// PriorityQueueV2Container sets the optional container attribute to value. // -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func PriorityQueueV2Container(value string) PriorityQueueV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// PriorityQueueV2SharedName sets the optional shared_name attribute to value. // -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. +// If not specified, defaults to "" +func PriorityQueueV2SharedName(value string) PriorityQueueV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A queue that produces elements sorted by the first component value. +// +// Note that the PriorityQueue requires the first component of any element +// to be a scalar int64, in addition to the other elements declared by +// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue +// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra +// entry in their input (resp. output) lists. // // Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// shapes: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. // -// Returns `R-K`-D. The reduced Tensor. -func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxAttr) (output tf.Output) { +// Returns The handle to the queue. +func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"shapes": shapes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "SparseReduceMax", - Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, - }, + Type: "PriorityQueueV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// AsStringAttr is an optional argument to AsString. -type AsStringAttr func(optionalAttr) +// UnstageAttr is an optional argument to Unstage. +type UnstageAttr func(optionalAttr) -// AsStringPrecision sets the optional precision attribute to value. +// UnstageCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: The post-decimal precision to use for floating point numbers. -// Only used if precision > -1. -// If not specified, defaults to -1 -func AsStringPrecision(value int64) AsStringAttr { +// REQUIRES: value >= 0 +func UnstageCapacity(value int64) UnstageAttr { return func(m optionalAttr) { - m["precision"] = value + m["capacity"] = value } } -// AsStringScientific sets the optional scientific attribute to value. +// UnstageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: Use scientific notation for floating point numbers. -// If not specified, defaults to false -func AsStringScientific(value bool) AsStringAttr { +// REQUIRES: value >= 0 +func UnstageMemoryLimit(value int64) UnstageAttr { return func(m optionalAttr) { - m["scientific"] = value + m["memory_limit"] = value } } -// AsStringShortest sets the optional shortest attribute to value. -// -// value: Use shortest representation (either scientific or standard) for -// floating point numbers. -// If not specified, defaults to false -func AsStringShortest(value bool) AsStringAttr { +// UnstageContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func UnstageContainer(value string) UnstageAttr { return func(m optionalAttr) { - m["shortest"] = value + m["container"] = value } } -// AsStringWidth sets the optional width attribute to value. -// -// value: Pad pre-decimal numbers to this width. -// Applies to both floating point and integer numbers. -// Only used if width > -1. -// If not specified, defaults to -1 -func AsStringWidth(value int64) AsStringAttr { +// UnstageSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func UnstageSharedName(value string) UnstageAttr { return func(m optionalAttr) { - m["width"] = value + m["shared_name"] = value } } -// AsStringFill sets the optional fill attribute to value. +// Op is similar to a lightweight Dequeue. // -// value: The value to pad if width > -1. If empty, pads with spaces. -// Another typical value is '0'. String cannot be longer than 1 character. -// If not specified, defaults to "" -func AsStringFill(value string) AsStringAttr { +// The basic functionality is similar to dequeue with many fewer +// capabilities and options. This Op is optimized for performance. +func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Unstage", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("Unstage", err) + return + } + return values +} + +// QueueEnqueueV2Attr is an optional argument to QueueEnqueueV2. +type QueueEnqueueV2Attr func(optionalAttr) + +// QueueEnqueueV2TimeoutMs sets the optional timeout_ms attribute to value. +// +// value: If the queue is full, this operation will block for up to +// timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueEnqueueV2TimeoutMs(value int64) QueueEnqueueV2Attr { return func(m optionalAttr) { - m["fill"] = value + m["timeout_ms"] = value } } -// Converts each entry in the given tensor to strings. Supports many numeric +// Enqueues a tuple of one or more tensors in the given queue. +// +// The components input has k elements, which correspond to the components of +// tuples stored in the given queue. +// +// N.B. If the queue is full, this operation will block until the given +// element has been enqueued (or 'timeout_ms' elapses, if specified). +// +// Arguments: +// handle: The handle to a queue. +// components: One or more tensors from which the enqueued tensors should be taken. // -// types and boolean. -func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) { +// Returns the created operation. +func QueueEnqueueV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -27780,347 +27013,302 @@ func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output t a(attrs) } opspec := tf.OpSpec{ - Type: "AsString", + Type: "QueueEnqueueV2", Input: []tf.Input{ - input, + handle, tf.OutputList(components), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Deprecated. Use TensorArrayScatterV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 -func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArrayScatterV2", - Input: []tf.Input{ - handle, indices, value, flow_in, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2. +type QueueDequeueManyV2Attr func(optionalAttr) -// Creates a tree ensemble model and returns a handle to it. -// -// Arguments: -// tree_ensemble_handle: Handle to the tree ensemble resource to be created. -// stamp_token: Token to use as the initial value of the resource stamp. -// tree_ensemble_serialized: Serialized proto of the tree ensemble. +// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value. // -// Returns the created operation. -func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BoostedTreesCreateEnsemble", - Input: []tf.Input{ - tree_ensemble_handle, stamp_token, tree_ensemble_serialized, - }, +// value: If the queue has fewer than n elements, this operation +// will block for up to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr { + return func(m optionalAttr) { + m["timeout_ms"] = value } - return scope.AddOperation(opspec) } -// Applies sparse addition to `input` using individual values or slices -// -// from `updates` according to indices `indices`. The updates are non-aliasing: -// `input` is only modified in-place if no other operations will use it. -// Otherwise, a copy of `input` is made. This operation has a gradient with -// respect to both `input` and `updates`. -// -// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. -// -// `indices` must be integer tensor, containing indices into `input`. -// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. -// -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or `(P-K)`-dimensional slices -// (if `K < P`) along the `K`th dimension of `input`. -// -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: -// -// ``` -// [d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]]. -// ``` -// -// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 -// elements. In Python, that addition would look like this: +// Dequeues `n` tuples of one or more tensors from the given queue. // -// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) -// with tf.Session() as sess: -// print(sess.run(output)) +// If the queue is closed and there are fewer than `n` elements, then an +// OutOfRange error is returned. // -// The resulting value `output` would look like this: +// This operation concatenates queue-element component tensors along the +// 0th dimension to make a single component tensor. All of the components +// in the dequeued tuple will have size `n` in the 0th dimension. // -// [1, 13, 3, 14, 14, 6, 7, 20] +// This operation has `k` outputs, where `k` is the number of components in +// the tuples stored in the given queue, and output `i` is the ith +// component of the dequeued tuple. // -// See @{tf.scatter_nd} for more details about how to make updates to slices. +// N.B. If the queue is empty, this operation will block until `n` elements +// have been dequeued (or 'timeout_ms' elapses, if specified). // // Arguments: -// input: A Tensor. -// indices: A Tensor. Must be one of the following types: `int32`, `int64`. -// A tensor of indices into `input`. -// updates: A Tensor. Must have the same type as ref. A tensor of updated values -// to add to `input`. +// handle: The handle to a queue. +// n: The number of tuples to dequeue. +// component_types: The type of each component in a tuple. // -// Returns A `Tensor` with the same shape as `input`, containing values of `input` -// updated with `updates`. -func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { +// Returns One or more tensors that were dequeued as a tuple. +func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ScatterNdNonAliasingAdd", + Type: "QueueDequeueManyV2", Input: []tf.Input{ - input, indices, updates, + handle, n, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("QueueDequeueManyV2", err) + return + } + return components } -// FractionalMaxPoolAttr is an optional argument to FractionalMaxPool. -type FractionalMaxPoolAttr func(optionalAttr) +// EncodeBase64Attr is an optional argument to EncodeBase64. +type EncodeBase64Attr func(optionalAttr) -// FractionalMaxPoolPseudoRandom sets the optional pseudo_random attribute to value. +// EncodeBase64Pad sets the optional pad attribute to value. // -// value: When set to True, generates the pooling sequence in a -// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin -// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for -// difference between pseudorandom and random. +// value: Bool whether padding is applied at the ends. // If not specified, defaults to false -func FractionalMaxPoolPseudoRandom(value bool) FractionalMaxPoolAttr { +func EncodeBase64Pad(value bool) EncodeBase64Attr { return func(m optionalAttr) { - m["pseudo_random"] = value + m["pad"] = value } } -// FractionalMaxPoolOverlapping sets the optional overlapping attribute to value. +// Encode strings into web-safe base64 format. // -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. For example: +// Refer to the following article for more information on base64 format: +// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the +// end so that the encoded has length multiple of 4. See Padding section of the +// link above. // -// `index 0 1 2 3 4` +// Web-safe means that the encoder uses - and _ instead of + and /. // -// `value 20 5 16 3 7` +// Arguments: +// input: Strings to be encoded. // -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [20, 16] for fractional max pooling. -// If not specified, defaults to false -func FractionalMaxPoolOverlapping(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["overlapping"] = value +// Returns Input strings encoded in base64. +func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { + if scope.Err() != nil { + return } -} - -// FractionalMaxPoolDeterministic sets the optional deterministic attribute to value. -// -// value: When set to True, a fixed pooling region will be used when -// iterating over a FractionalMaxPool node in the computation graph. Mainly used -// in unit test to make FractionalMaxPool deterministic. -// If not specified, defaults to false -func FractionalMaxPoolDeterministic(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["deterministic"] = value + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } -} - -// FractionalMaxPoolSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func FractionalMaxPoolSeed(value int64) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["seed"] = value + opspec := tf.OpSpec{ + Type: "EncodeBase64", + Input: []tf.Input{ + input, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// FractionalMaxPoolSeed2 sets the optional seed2 attribute to value. +// Deprecated. Use TensorArrayCloseV3 // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func FractionalMaxPoolSeed2(value int64) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["seed2"] = value +// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 +// +// Returns the created operation. +func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return } + opspec := tf.OpSpec{ + Type: "TensorArrayCloseV2", + Input: []tf.Input{ + handle, + }, + } + return scope.AddOperation(opspec) } -// Performs fractional max pooling on the input. -// -// Fractional max pooling is slightly different than regular max pooling. In -// regular max pooling, you downsize an input set by taking the maximum value of -// smaller N x N subsections of the set (often 2x2), and try to reduce the set by -// a factor of N, where N is an integer. Fractional max pooling, as you might -// expect from the word "fractional", means that the overall reduction ratio N -// does not have to be an integer. +// Forwards the value of an available tensor from `inputs` to `output`. // -// The sizes of the pooling regions are generated randomly but are fairly uniform. -// For example, let's look at the height dimension, and the constraints on the -// list of rows that will be pool boundaries. +// `Merge` waits for at least one of the tensors in `inputs` to become available. +// It is usually combined with `Switch` to implement branching. // -// First we define the following: +// `Merge` forwards the first tensor to become available to `output`, and sets +// `value_index` to its index in `inputs`. // -// 1. input_row_length : the number of rows from the input set -// 2. output_row_length : which will be smaller than the input -// 3. alpha = input_row_length / output_row_length : our reduction ratio -// 4. K = floor(alpha) -// 5. row_pooling_sequence : this is the result list of pool boundary rows +// Arguments: +// inputs: The input tensors, exactly one of which will become available. // -// Then, row_pooling_sequence should satisfy: +// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. +func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Merge", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// QueueCloseV2Attr is an optional argument to QueueCloseV2. +type QueueCloseV2Attr func(optionalAttr) + +// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. // -// 1. a[0] = 0 : the first value of the sequence is 0 -// 2. a[end] = input_row_length : the last value of the sequence is the size -// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size -// 4. length(row_pooling_sequence) = output_row_length+1 +// value: If true, all pending enqueue requests that are +// blocked on the given queue will be canceled. +// If not specified, defaults to false +func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { + return func(m optionalAttr) { + m["cancel_pending_enqueues"] = value + } +} + +// Closes the given queue. // -// For more details on fractional max pooling, see this paper: -// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) +// This operation signals that no more elements will be enqueued in the +// given queue. Subsequent Enqueue(Many) operations will fail. +// Subsequent Dequeue(Many) operations will continue to succeed if +// sufficient elements remain in the queue. Subsequent Dequeue(Many) +// operations that would block will fail immediately. // // Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// pooling_ratio: Pooling ratio for each dimension of `value`, currently only -// supports row and col dimension and should be >= 1.0. For example, a valid -// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements -// must be 1.0 because we don't allow pooling on batch and channels -// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions -// respectively. +// handle: The handle to a queue. // -// Returns output tensor after fractional max pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. -func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalMaxPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { +// Returns the created operation. +func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FractionalMaxPool", + Type: "QueueCloseV2", Input: []tf.Input{ - value, + handle, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// Deprecated. Use TensorArraySizeV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3 -func TensorArraySizeV2(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { +// Computes inverse hyperbolic tangent of x element-wise. +func Atanh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArraySizeV2", + Type: "Atanh", Input: []tf.Input{ - handle, flow_in, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Conv2DAttr is an optional argument to Conv2D. -type Conv2DAttr func(optionalAttr) - -// Conv2DUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DUseCudnnOnGpu(value bool) Conv2DAttr { - return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value +// Returns true if queue is closed. +// +// This operation returns true if the queue is closed and false if the queue +// is open. +// +// Arguments: +// handle: The handle to a queue. +func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "QueueIsClosedV2", + Input: []tf.Input{ + handle, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Conv2DDataFormat sets the optional data_format attribute to value. +// Computes the absolute value of a tensor. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func Conv2DDataFormat(value string) Conv2DAttr { - return func(m optionalAttr) { - m["data_format"] = value +// Given a tensor `x`, this operation returns a tensor containing the absolute +// value of each element in `x`. For example, if x is an input element and y is +// an output element, this operation computes \\(y = |x|\\). +func Abs(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Abs", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Conv2DDilations sets the optional dilations attribute to value. +// StackV2Attr is an optional argument to StackV2. +type StackV2Attr func(optionalAttr) + +// StackV2StackName sets the optional stack_name attribute to value. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv2DDilations(value []int64) Conv2DAttr { +// value: Overrides the name used for the temporary stack resource. Default +// value is the name of the 'Stack' op (which is guaranteed unique). +// If not specified, defaults to "" +func StackV2StackName(value string) StackV2Attr { return func(m optionalAttr) { - m["dilations"] = value + m["stack_name"] = value } } -// Computes a 2-D convolution given 4-D `input` and `filter` tensors. -// -// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` -// and a filter / kernel tensor of shape -// `[filter_height, filter_width, in_channels, out_channels]`, this op -// performs the following: -// -// 1. Flattens the filter to a 2-D matrix with shape -// `[filter_height * filter_width * in_channels, output_channels]`. -// 2. Extracts image patches from the input tensor to form a *virtual* -// tensor of shape `[batch, out_height, out_width, -// filter_height * filter_width * in_channels]`. -// 3. For each patch, right-multiplies the filter matrix and the image patch -// vector. -// -// In detail, with the default NHWC format, -// -// output[b, i, j, k] = -// sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * -// filter[di, dj, q, k] -// -// Must have `strides[0] = strides[3] = 1`. For the most common case of the same -// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. +// A stack that produces elements in first-in last-out order. // // Arguments: -// input: A 4-D tensor. The dimension order is interpreted according to the value -// of `data_format`, see below for details. -// filter: A 4-D tensor of shape -// `[filter_height, filter_width, in_channels, out_channels]` -// strides: 1-D tensor of length 4. The stride of the sliding window for each -// dimension of `input`. The dimension order is determined by the value of -// `data_format`, see below for details. -// padding: The type of padding algorithm to use. +// max_size: The maximum size of the stack if non-negative. If negative, the stack +// size is unlimited. +// elem_type: The type of the elements on the stack. // -// Returns A 4-D tensor. The dimension order is determined by the value of -// `data_format`, see below for details. -func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv2DAttr) (output tf.Output) { +// Returns The handle to the stack. +func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{"elem_type": elem_type} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv2D", + Type: "StackV2", Input: []tf.Input{ - input, filter, + max_size, }, Attrs: attrs, } @@ -28128,67 +27316,110 @@ func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, pa return op.Output(0) } -// StageAttr is an optional argument to Stage. -type StageAttr func(optionalAttr) +// FusedBatchNormGradV2Attr is an optional argument to FusedBatchNormGradV2. +type FusedBatchNormGradV2Attr func(optionalAttr) -// StageCapacity sets the optional capacity attribute to value. -// -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. -// If not specified, defaults to 0 +// FusedBatchNormGradV2Epsilon sets the optional epsilon attribute to value. // -// REQUIRES: value >= 0 -func StageCapacity(value int64) StageAttr { +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormGradV2Epsilon(value float32) FusedBatchNormGradV2Attr { return func(m optionalAttr) { - m["capacity"] = value + m["epsilon"] = value } } -// StageMemoryLimit sets the optional memory_limit attribute to value. +// FusedBatchNormGradV2DataFormat sets the optional data_format attribute to value. // -// value: The maximum number of bytes allowed for Tensors in the Staging Area. -// If > 0, inserts will block until sufficient space is available. -// If not specified, defaults to 0 +// value: The data format for y_backprop, x, x_backprop. +// Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormGradV2DataFormat(value string) FusedBatchNormGradV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// FusedBatchNormGradV2IsTraining sets the optional is_training attribute to value. // -// REQUIRES: value >= 0 -func StageMemoryLimit(value int64) StageAttr { +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormGradV2IsTraining(value bool) FusedBatchNormGradV2Attr { return func(m optionalAttr) { - m["memory_limit"] = value + m["is_training"] = value + } +} + +// Gradient for batch normalization. +// +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// +// Arguments: +// y_backprop: A 4D Tensor for the gradient with respect to y. +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch +// mean to be reused in gradient computation. When is_training is +// False, a 1D Tensor for the population mean to be reused in both +// 1st and 2nd order gradient computation. +// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch +// variance (inverted variance in the cuDNN case) to be reused in +// gradient computation. When is_training is False, a 1D Tensor +// for the population variance to be reused in both 1st and 2nd +// order gradient computation. +// +// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input +// in FusedBatchNorm. +func FusedBatchNormGradV2(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradV2Attr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FusedBatchNormGradV2", + Input: []tf.Input{ + y_backprop, x, scale, reserve_space_1, reserve_space_2, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// StageContainer sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. -// If not specified, defaults to "" -func StageContainer(value string) StageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} +// DecodeCompressedAttr is an optional argument to DecodeCompressed. +type DecodeCompressedAttr func(optionalAttr) -// StageSharedName sets the optional shared_name attribute to value. +// DecodeCompressedCompressionType sets the optional compression_type attribute to value. // -// value: It is necessary to match this name to the matching Unstage Op. +// value: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". // If not specified, defaults to "" -func StageSharedName(value string) StageAttr { +func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["compression_type"] = value } } -// Stage values similar to a lightweight Enqueue. +// Decompress strings. // -// The basic functionality of this Op is similar to a queue with many -// fewer capabilities and options. This Op is optimized for performance. +// This op decompresses each element of the `bytes` input `Tensor`, which +// is assumed to be compressed using the given `compression_type`. +// +// The `output` is a string `Tensor` of the same shape as `bytes`, +// each element containing the decompressed data from the corresponding +// element in `bytes`. // // Arguments: -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. +// bytes: A Tensor of string which is compressed. // -// Returns the created operation. -func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Operation) { +// Returns A Tensor with the same shape as input `bytes`, uncompressed +// from bytes. +func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -28197,283 +27428,499 @@ func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Opera a(attrs) } opspec := tf.OpSpec{ - Type: "Stage", + Type: "DecodeCompressed", Input: []tf.Input{ - tf.OutputList(values), + bytes, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// StagePeekAttr is an optional argument to StagePeek. -type StagePeekAttr func(optionalAttr) +// CudnnRNNAttr is an optional argument to CudnnRNN. +type CudnnRNNAttr func(optionalAttr) -// StagePeekCapacity sets the optional capacity attribute to value. +// CudnnRNNRnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNRnnMode(value string) CudnnRNNAttr { + return func(m optionalAttr) { + m["rnn_mode"] = value + } +} + +// CudnnRNNInputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNInputMode(value string) CudnnRNNAttr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} + +// CudnnRNNDirection sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNDirection(value string) CudnnRNNAttr { + return func(m optionalAttr) { + m["direction"] = value + } +} + +// CudnnRNNDropout sets the optional dropout attribute to value. // If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StagePeekCapacity(value int64) StagePeekAttr { +func CudnnRNNDropout(value float32) CudnnRNNAttr { return func(m optionalAttr) { - m["capacity"] = value + m["dropout"] = value } } -// StagePeekMemoryLimit sets the optional memory_limit attribute to value. +// CudnnRNNSeed sets the optional seed attribute to value. // If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StagePeekMemoryLimit(value int64) StagePeekAttr { +func CudnnRNNSeed(value int64) CudnnRNNAttr { return func(m optionalAttr) { - m["memory_limit"] = value + m["seed"] = value } } -// StagePeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StagePeekContainer(value string) StagePeekAttr { +// CudnnRNNSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNSeed2(value int64) CudnnRNNAttr { return func(m optionalAttr) { - m["container"] = value + m["seed2"] = value } } -// StagePeekSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StagePeekSharedName(value string) StagePeekAttr { +// CudnnRNNIsTraining sets the optional is_training attribute to value. +// If not specified, defaults to true +func CudnnRNNIsTraining(value bool) CudnnRNNAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["is_training"] = value } } -// Op peeks at the values at the specified index. If the +// A RNN backed by cuDNN. // -// underlying container does not contain sufficient elements -// this op will block until it does. This Op is optimized for -// performance. -func StagePeek(scope *Scope, index tf.Output, dtypes []tf.DataType, optional ...StagePeekAttr) (values []tf.Output) { +// Computes the RNN from the input and initial states, with respect to the params +// buffer. +// +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicate whether there is a linear projection between the input and +// The actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. +// dir = (direction == bidirectional) ? 2 : 1 +// dropout: dropout probability. When set to 0., dropout is disabled. +// seed: the 1st part of a seed to initialize dropout. +// seed2: the 2nd part of a seed to initialize dropout. +// input: a 3-D tensor with the shape of [seq_length, batch_size, input_size]. +// input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size, +// num_units]. +// input_c: For LSTM, a 3-D tensor with the shape of +// [num_layer * dir, batch, num_units]. For other models, it is ignored. +// params: a 1-D tensor that contains the weights and biases in an opaque layout. +// The size must be created through CudnnRNNParamsSize, and initialized +// separately. Note that they might not be compatible across different +// generations. So it is a good idea to save and restore +// output: a 3-D tensor with the shape of [seq_length, batch_size, +// dir * num_units]. +// output_h: the same shape has input_h. +// output_c: the same shape as input_c for LSTM. An empty tensor for other models. +// is_training: Indicates whether this operation is used for inferenece or +// training. +// reserve_space: an opaque tensor that can be used in backprop calculation. It +// is only produced if is_training is false. +func CudnnRNN(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, optional ...CudnnRNNAttr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StagePeek", + Type: "CudnnRNN", Input: []tf.Input{ - index, + input, input_h, input_c, params, }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`. +// +// Each comparison returns a boolean `true` (if `input_value > threshold`) +// or and `false` otherwise. +// +// This operation is useful for Locality-Sensitive-Hashing (LSH) and other +// algorithms that use hashing approximations of cosine and `L2` distances; +// codes can be generated from an input via: +// +// ```python +// codebook_size = 50 +// codebook_bits = codebook_size * 32 +// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], +// dtype=x.dtype, +// initializer=tf.orthogonal_initializer()) +// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) +// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 +// # now codes has shape x.shape[:-1] + [codebook_size] +// ``` +// +// **NOTE**: Currently, the innermost dimension of the tensor must be divisible +// by 8. +// +// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is +// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. +// +// Arguments: +// input: Values to compare against `threshold` and bitpack. +// threshold: Threshold to compare against. +// +// Returns The bitpacked comparisons. +func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("StagePeek", err) - return + opspec := tf.OpSpec{ + Type: "CompareAndBitpack", + Input: []tf.Input{ + input, threshold, + }, } - return values + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapStageAttr is an optional argument to MapStage. -type MapStageAttr func(optionalAttr) - -// MapStageCapacity sets the optional capacity attribute to value. +// Push an element onto the tensor_array. // -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. -// If not specified, defaults to 0 +// Arguments: +// handle: The handle to a TensorArray. +// index: The position to write to inside the TensorArray. +// value: The tensor to write to the TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. // -// REQUIRES: value >= 0 -func MapStageCapacity(value int64) MapStageAttr { - return func(m optionalAttr) { - m["capacity"] = value +// Returns A float scalar that enforces proper chaining of operations. +func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return } -} - -// MapStageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapStageMemoryLimit(value int64) MapStageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value + opspec := tf.OpSpec{ + Type: "TensorArrayWriteV3", + Input: []tf.Input{ + handle, index, value, flow_in, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapStageContainer sets the optional container attribute to value. +// Scatter the data from the input value into specific TensorArray elements. // -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. -// If not specified, defaults to "" -func MapStageContainer(value string) MapStageAttr { - return func(m optionalAttr) { - m["container"] = value +// `indices` must be a vector, its length must match the first dim of `value`. +// +// Arguments: +// handle: The handle to a TensorArray. +// indices: The locations at which to write the tensor elements. +// value: The concatenated tensor to write to the TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// +// Returns A float scalar that enforces proper chaining of operations. +func TensorArrayScatterV3(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return } + opspec := tf.OpSpec{ + Type: "TensorArrayScatterV3", + Input: []tf.Input{ + handle, indices, value, flow_in, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapStageSharedName sets the optional shared_name attribute to value. +// EmptyAttr is an optional argument to Empty. +type EmptyAttr func(optionalAttr) + +// EmptyInit sets the optional init attribute to value. // -// value: It is necessary to match this name to the matching Unstage Op. -// If not specified, defaults to "" -func MapStageSharedName(value string) MapStageAttr { +// value: If True, initialize the returned tensor with the default value of dtype. Otherwise, the implementation is free not to initializethe tensor's content. +// If not specified, defaults to false +func EmptyInit(value bool) EmptyAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["init"] = value } } -// Stage (key, values) in the underlying container which behaves like a hashtable. +// Creates a tensor with the given shape. // -// Arguments: -// key: int64 +// This operation creates a tensor of `shape` and `dtype`. // -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. +// Arguments: +// shape: 1-D. Represents the shape of the output tensor. // // -// Returns the created operation. -func MapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...MapStageAttr) (o *tf.Operation) { +// Returns A `Tensor` of type `T`. +func Empty(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...EmptyAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapStage", + Type: "Empty", Input: []tf.Input{ - key, indices, tf.OutputList(values), + shape, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapUnstageAttr is an optional argument to MapUnstage. -type MapUnstageAttr func(optionalAttr) +// TensorArrayConcatV3Attr is an optional argument to TensorArrayConcatV3. +type TensorArrayConcatV3Attr func(optionalAttr) -// MapUnstageCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// TensorArrayConcatV3ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. // -// REQUIRES: value >= 0 -func MapUnstageCapacity(value int64) MapUnstageAttr { +// value: The expected shape of an element, if known, +// excluding the first dimension. Used to validate the shapes of +// TensorArray elements. If this shape is not fully specified, concatenating +// zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayConcatV3ElementShapeExcept0(value tf.Shape) TensorArrayConcatV3Attr { return func(m optionalAttr) { - m["capacity"] = value + m["element_shape_except0"] = value } } -// MapUnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Concat the elements from the TensorArray into value `value`. // -// REQUIRES: value >= 0 -func MapUnstageMemoryLimit(value int64) MapUnstageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value +// Takes `T` elements of shapes +// +// ``` +// (n0 x d0 x d1 x ...), (n1 x d0 x d1 x ...), ..., (n(T-1) x d0 x d1 x ...) +// ``` +// +// and concatenates them into a Tensor of shape: +// +// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)``` +// +// All elements must have the same shape (excepting the first dimension). +// +// Arguments: +// handle: The handle to a TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. +// +// Returns All of the elements in the TensorArray, concatenated along the first +// axis.A vector of the row sizes of the original T elements in the +// value output. In the example above, this would be the values: +// `(n1, n2, ..., n(T-1))`. +func TensorArrayConcatV3(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV3Attr) (value tf.Output, lengths tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorArrayConcatV3", + Input: []tf.Input{ + handle, flow_in, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// MapUnstageContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapUnstageContainer(value string) MapUnstageAttr { - return func(m optionalAttr) { - m["container"] = value +// Split the data from the input value into TensorArray elements. +// +// Assuming that `lengths` takes on values +// +// ```(n0, n1, ..., n(T-1))``` +// +// and that `value` has shape +// +// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)```, +// +// this splits values into a TensorArray with T tensors. +// +// TensorArray index t will be the subtensor of values with starting position +// +// ```(n0 + n1 + ... + n(t-1), 0, 0, ...)``` +// +// and having size +// +// ```nt x d0 x d1 x ...``` +// +// Arguments: +// handle: The handle to a TensorArray. +// value: The concatenated tensor to write to the TensorArray. +// lengths: The vector of lengths, how to split the rows of value into the +// TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// +// Returns A float scalar that enforces proper chaining of operations. +func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArraySplitV3", + Input: []tf.Input{ + handle, value, lengths, flow_in, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapUnstageSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapUnstageSharedName(value string) MapUnstageAttr { +// SerializeSparseAttr is an optional argument to SerializeSparse. +type SerializeSparseAttr func(optionalAttr) + +// SerializeSparseOutType sets the optional out_type attribute to value. +// +// value: The `dtype` to use for serialization; the supported types are `string` +// (default) and `variant`. +// If not specified, defaults to DT_STRING +func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["out_type"] = value } } -// Op removes and returns the values associated with the key +// Serialize a `SparseTensor` into a `[3]` `Tensor` object. // -// from the underlying container. If the underlying container -// does not contain this key, the op will block until it does. -func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) { +// Arguments: +// sparse_indices: 2-D. The `indices` of the `SparseTensor`. +// sparse_values: 1-D. The `values` of the `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the `SparseTensor`. +func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapUnstage", + Type: "SerializeSparse", Input: []tf.Input{ - key, indices, + sparse_indices, sparse_values, sparse_shape, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapUnstage", err) - return - } - return values + return op.Output(0) } -// MapIncompleteSizeAttr is an optional argument to MapIncompleteSize. -type MapIncompleteSizeAttr func(optionalAttr) +// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. +type RandomShuffleQueueV2Attr func(optionalAttr) -// MapIncompleteSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. // -// REQUIRES: value >= 0 -func MapIncompleteSizeCapacity(value int64) MapIncompleteSizeAttr { +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["shapes"] = value + } +} + +// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. +// +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["capacity"] = value } } -// MapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. +// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. +// +// value: Dequeue will block unless there would be this +// many elements after the dequeue or the queue is closed. This +// ensures a minimum level of mixing of elements. +// If not specified, defaults to 0 +func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["min_after_dequeue"] = value + } +} + +// RandomShuffleQueueV2Seed sets the optional seed attribute to value. +// +// value: If either seed or seed2 is set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, a random seed is used. // If not specified, defaults to 0 +func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. // -// REQUIRES: value >= 0 -func MapIncompleteSizeMemoryLimit(value int64) MapIncompleteSizeAttr { +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { return func(m optionalAttr) { - m["memory_limit"] = value + m["seed2"] = value } } -// MapIncompleteSizeContainer sets the optional container attribute to value. +// RandomShuffleQueueV2Container sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. // If not specified, defaults to "" -func MapIncompleteSizeContainer(value string) MapIncompleteSizeAttr { +func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["container"] = value } } -// MapIncompleteSizeSharedName sets the optional shared_name attribute to value. +// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. // If not specified, defaults to "" -func MapIncompleteSizeSharedName(value string) MapIncompleteSizeAttr { +func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["shared_name"] = value } } -// Op returns the number of incomplete elements in the underlying container. -func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncompleteSizeAttr) (size tf.Output) { +// A queue that randomizes the order of elements. +// +// Arguments: +// component_types: The type of each component in a value. +// +// Returns The handle to the queue. +func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"component_types": component_types} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapIncompleteSize", + Type: "RandomShuffleQueueV2", Attrs: attrs, } @@ -28481,149 +27928,211 @@ func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncomp return op.Output(0) } -// OrderedMapUnstageAttr is an optional argument to OrderedMapUnstage. -type OrderedMapUnstageAttr func(optionalAttr) - -// OrderedMapUnstageCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Draw bounding boxes on a batch of images. // -// REQUIRES: value >= 0 -func OrderedMapUnstageCapacity(value int64) OrderedMapUnstageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapUnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Outputs a copy of `images` but draws on top of the pixels zero or more bounding +// boxes specified by the locations in `boxes`. The coordinates of the each +// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. // -// REQUIRES: value >= 0 -func OrderedMapUnstageMemoryLimit(value int64) OrderedMapUnstageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value +// For example, if an image is 100 x 200 pixels (height x width) and the bounding +// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of +// the bounding box will be `(40, 10)` to `(100, 50)` (in (x,y) coordinates). +// +// Parts of the bounding box may fall outside the image. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. +// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding +// boxes. +// +// Returns 4-D with the same shape as `images`. The batch of input images with +// bounding boxes drawn on the images. +func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DrawBoundingBoxes", + Input: []tf.Input{ + images, boxes, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// OrderedMapUnstageContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageContainer(value string) OrderedMapUnstageAttr { +// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler. +type LearnedUnigramCandidateSamplerAttr func(optionalAttr) + +// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["container"] = value + m["seed"] = value } } -// OrderedMapUnstageSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageSharedName(value string) OrderedMapUnstageAttr { +// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["seed2"] = value } } -// Op removes and returns the values associated with the key +// Generates labels for candidate sampling with a learned unigram distribution. // -// from the underlying container. If the underlying container -// does not contain this key, the op will block until it does. -func OrderedMapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageAttr) (values []tf.Output) { +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. +// +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. +// +// Arguments: +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). +// +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapUnstage", + Type: "LearnedUnigramCandidateSampler", Input: []tf.Input{ - key, indices, + true_classes, }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes gradients for the scaled exponential linear (Selu) operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Selu operation. +// outputs: The outputs of the corresponding Selu operation. +// +// Returns The gradients: `gradients * (outputs + scale * alpha)` +// if outputs < 0, `scale * gradients` otherwise. +func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("OrderedMapUnstage", err) - return + opspec := tf.OpSpec{ + Type: "SeluGrad", + Input: []tf.Input{ + gradients, outputs, + }, } - return values + op := scope.AddOperation(opspec) + return op.Output(0) } -// OrderedMapSizeAttr is an optional argument to OrderedMapSize. -type OrderedMapSizeAttr func(optionalAttr) - -// OrderedMapSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Get the current size of the TensorArray. // -// REQUIRES: value >= 0 -func OrderedMapSizeCapacity(value int64) OrderedMapSizeAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Arguments: +// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). +// flow_in: A float scalar that enforces proper chaining of operations. // -// REQUIRES: value >= 0 -func OrderedMapSizeMemoryLimit(value int64) OrderedMapSizeAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapSizeContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapSizeContainer(value string) OrderedMapSizeAttr { - return func(m optionalAttr) { - m["container"] = value +// Returns The current size of the TensorArray. +func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { + if scope.Err() != nil { + return } -} - -// OrderedMapSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapSizeSharedName(value string) OrderedMapSizeAttr { - return func(m optionalAttr) { - m["shared_name"] = value + opspec := tf.OpSpec{ + Type: "TensorArraySizeV3", + Input: []tf.Input{ + handle, flow_in, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Op returns the number of elements in the underlying container. -func OrderedMapSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapSizeAttr) (size tf.Output) { +// Deprecated. Use TensorArrayGradV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3 +func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "OrderedMapSize", - - Attrs: attrs, + Type: "TensorArrayWriteV2", + Input: []tf.Input{ + handle, index, value, flow_in, + }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ShapeNAttr is an optional argument to ShapeN. -type ShapeNAttr func(optionalAttr) +// SparseReduceMaxAttr is an optional argument to SparseReduceMax. +type SparseReduceMaxAttr func(optionalAttr) -// ShapeNOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func ShapeNOutType(value tf.DataType) ShapeNAttr { +// SparseReduceMaxKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceMaxKeepDims(value bool) SparseReduceMaxAttr { return func(m optionalAttr) { - m["out_type"] = value + m["keep_dims"] = value } } -// Returns shape of tensors. +// Computes the max of elements across dimensions of a SparseTensor. +// +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_max()`. In particular, this Op also returns a dense `Tensor` +// instead of a sparse one. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. // -// This operation returns N 1-D integer tensors representing shape of `input[i]s`. -func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) { +// Returns `R-K`-D. The reduced Tensor. +func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -28632,205 +28141,316 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t a(attrs) } opspec := tf.OpSpec{ - Type: "ShapeN", + Type: "SparseReduceMax", Input: []tf.Input{ - tf.OutputList(input), + input_indices, input_values, input_shape, reduction_axes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("ShapeN", err) - return - } - return output + return op.Output(0) } -// CudnnRNNParamsToCanonicalAttr is an optional argument to CudnnRNNParamsToCanonical. -type CudnnRNNParamsToCanonicalAttr func(optionalAttr) - -// CudnnRNNParamsToCanonicalRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNParamsToCanonicalRnnMode(value string) CudnnRNNParamsToCanonicalAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value - } -} +// AsStringAttr is an optional argument to AsString. +type AsStringAttr func(optionalAttr) -// CudnnRNNParamsToCanonicalInputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNParamsToCanonicalInputMode(value string) CudnnRNNParamsToCanonicalAttr { +// AsStringPrecision sets the optional precision attribute to value. +// +// value: The post-decimal precision to use for floating point numbers. +// Only used if precision > -1. +// If not specified, defaults to -1 +func AsStringPrecision(value int64) AsStringAttr { return func(m optionalAttr) { - m["input_mode"] = value + m["precision"] = value } } -// CudnnRNNParamsToCanonicalDirection sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNParamsToCanonicalDirection(value string) CudnnRNNParamsToCanonicalAttr { +// AsStringScientific sets the optional scientific attribute to value. +// +// value: Use scientific notation for floating point numbers. +// If not specified, defaults to false +func AsStringScientific(value bool) AsStringAttr { return func(m optionalAttr) { - m["direction"] = value + m["scientific"] = value } } -// CudnnRNNParamsToCanonicalDropout sets the optional dropout attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsToCanonicalDropout(value float32) CudnnRNNParamsToCanonicalAttr { +// AsStringShortest sets the optional shortest attribute to value. +// +// value: Use shortest representation (either scientific or standard) for +// floating point numbers. +// If not specified, defaults to false +func AsStringShortest(value bool) AsStringAttr { return func(m optionalAttr) { - m["dropout"] = value + m["shortest"] = value } } -// CudnnRNNParamsToCanonicalSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsToCanonicalSeed(value int64) CudnnRNNParamsToCanonicalAttr { +// AsStringWidth sets the optional width attribute to value. +// +// value: Pad pre-decimal numbers to this width. +// Applies to both floating point and integer numbers. +// Only used if width > -1. +// If not specified, defaults to -1 +func AsStringWidth(value int64) AsStringAttr { return func(m optionalAttr) { - m["seed"] = value + m["width"] = value } } -// CudnnRNNParamsToCanonicalSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsToCanonicalSeed2(value int64) CudnnRNNParamsToCanonicalAttr { +// AsStringFill sets the optional fill attribute to value. +// +// value: The value to pad if width > -1. If empty, pads with spaces. +// Another typical value is '0'. String cannot be longer than 1 character. +// If not specified, defaults to "" +func AsStringFill(value string) AsStringAttr { return func(m optionalAttr) { - m["seed2"] = value + m["fill"] = value } } -// Retrieves CudnnRNN params in canonical form. -// -// Retrieves a set of weights from the opaque params buffer that can be saved and -// restored in a way compatible with future runs. -// -// Note that the params buffer may not be compatible across different GPUs. So any -// save and restoration should be converted to and from the canonical weights and -// biases. +// Converts each entry in the given tensor to strings. Supports many numeric // -// num_layers: Specifies the number of layers in the RNN model. -// num_units: Specifies the size of the hidden state. -// input_size: Specifies the size of the input state. -// num_params: number of parameter sets for all layers. -// Each layer may contain multiple parameter sets, with each set consisting of -// a weight matrix and a bias vector. -// weights: the canonical form of weights that can be used for saving -// and restoration. They are more likely to be compatible across different -// generations. -// biases: the canonical form of biases that can be used for saving -// and restoration. They are more likely to be compatible across different -// generations. -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicate whether there is a linear projection between the input and -// The actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. -// dir = (direction == bidirectional) ? 2 : 1 -// dropout: dropout probability. When set to 0., dropout is disabled. -// seed: the 1st part of a seed to initialize dropout. -// seed2: the 2nd part of a seed to initialize dropout. -func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, params tf.Output, num_params int64, optional ...CudnnRNNParamsToCanonicalAttr) (weights []tf.Output, biases []tf.Output) { +// types and boolean. +func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_params": num_params} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "CudnnRNNParamsToCanonical", + Type: "AsString", Input: []tf.Input{ - num_layers, num_units, input_size, params, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated. Use TensorArrayScatterV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 +func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if weights, idx, err = makeOutputList(op, idx, "weights"); err != nil { - scope.UpdateErr("CudnnRNNParamsToCanonical", err) + opspec := tf.OpSpec{ + Type: "TensorArrayScatterV2", + Input: []tf.Input{ + handle, indices, value, flow_in, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a tree ensemble model and returns a handle to it. +// +// Arguments: +// tree_ensemble_handle: Handle to the tree ensemble resource to be created. +// stamp_token: Token to use as the initial value of the resource stamp. +// tree_ensemble_serialized: Serialized proto of the tree ensemble. +// +// Returns the created operation. +func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { + if scope.Err() != nil { return } - if biases, idx, err = makeOutputList(op, idx, "biases"); err != nil { - scope.UpdateErr("CudnnRNNParamsToCanonical", err) + opspec := tf.OpSpec{ + Type: "BoostedTreesCreateEnsemble", + Input: []tf.Input{ + tree_ensemble_handle, stamp_token, tree_ensemble_serialized, + }, + } + return scope.AddOperation(opspec) +} + +// Applies sparse addition to `input` using individual values or slices +// +// from `updates` according to indices `indices`. The updates are non-aliasing: +// `input` is only modified in-place if no other operations will use it. +// Otherwise, a copy of `input` is made. This operation has a gradient with +// respect to both `input` and `updates`. +// +// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. +// +// `indices` must be integer tensor, containing indices into `input`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. +// +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or `(P-K)`-dimensional slices +// (if `K < P`) along the `K`th dimension of `input`. +// +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// +// ``` +// [d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]]. +// ``` +// +// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 +// elements. In Python, that addition would look like this: +// +// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) +// with tf.Session() as sess: +// print(sess.run(output)) +// +// The resulting value `output` would look like this: +// +// [1, 13, 3, 14, 14, 6, 7, 20] +// +// See @{tf.scatter_nd} for more details about how to make updates to slices. +// +// Arguments: +// input: A Tensor. +// indices: A Tensor. Must be one of the following types: `int32`, `int64`. +// A tensor of indices into `input`. +// updates: A Tensor. Must have the same type as ref. A tensor of updated values +// to add to `input`. +// +// Returns A `Tensor` with the same shape as `input`, containing values of `input` +// updated with `updates`. +func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { + if scope.Err() != nil { return } - return weights, biases + opspec := tf.OpSpec{ + Type: "ScatterNdNonAliasingAdd", + Input: []tf.Input{ + input, indices, updates, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FractionalMaxPoolAttr is an optional argument to FractionalMaxPool. +type FractionalMaxPoolAttr func(optionalAttr) + +// FractionalMaxPoolPseudoRandom sets the optional pseudo_random attribute to value. +// +// value: When set to True, generates the pooling sequence in a +// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin +// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for +// difference between pseudorandom and random. +// If not specified, defaults to false +func FractionalMaxPoolPseudoRandom(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["pseudo_random"] = value + } +} + +// FractionalMaxPoolOverlapping sets the optional overlapping attribute to value. +// +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: +// +// `index 0 1 2 3 4` +// +// `value 20 5 16 3 7` +// +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [20, 16] for fractional max pooling. +// If not specified, defaults to false +func FractionalMaxPoolOverlapping(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["overlapping"] = value + } } -// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler. -type UniformCandidateSamplerAttr func(optionalAttr) +// FractionalMaxPoolDeterministic sets the optional deterministic attribute to value. +// +// value: When set to True, a fixed pooling region will be used when +// iterating over a FractionalMaxPool node in the computation graph. Mainly used +// in unit test to make FractionalMaxPool deterministic. +// If not specified, defaults to false +func FractionalMaxPoolDeterministic(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["deterministic"] = value + } +} -// UniformCandidateSamplerSeed sets the optional seed attribute to value. +// FractionalMaxPoolSeed sets the optional seed attribute to value. // // value: If either seed or seed2 are set to be non-zero, the random number // generator is seeded by the given seed. Otherwise, it is seeded by a // random seed. // If not specified, defaults to 0 -func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr { +func FractionalMaxPoolSeed(value int64) FractionalMaxPoolAttr { return func(m optionalAttr) { m["seed"] = value } } -// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// FractionalMaxPoolSeed2 sets the optional seed2 attribute to value. // // value: An second seed to avoid seed collision. // If not specified, defaults to 0 -func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr { +func FractionalMaxPoolSeed2(value int64) FractionalMaxPoolAttr { return func(m optionalAttr) { m["seed2"] = value } } -// Generates labels for candidate sampling with a uniform distribution. +// Performs fractional max pooling on the input. // -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. +// Fractional max pooling is slightly different than regular max pooling. In +// regular max pooling, you downsize an input set by taking the maximum value of +// smaller N x N subsections of the set (often 2x2), and try to reduce the set by +// a factor of N, where N is an integer. Fractional max pooling, as you might +// expect from the word "fractional", means that the overall reduction ratio N +// does not have to be an integer. // -// For each batch, this op picks a single set of sampled candidate labels. +// The sizes of the pooling regions are generated randomly but are fairly uniform. +// For example, let's look at the height dimension, and the constraints on the +// list of rows that will be pool boundaries. // -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. +// First we define the following: +// +// 1. input_row_length : the number of rows from the input set +// 2. output_row_length : which will be smaller than the input +// 3. alpha = input_row_length / output_row_length : our reduction ratio +// 4. K = floor(alpha) +// 5. row_pooling_sequence : this is the result list of pool boundary rows +// +// Then, row_pooling_sequence should satisfy: +// +// 1. a[0] = 0 : the first value of the sequence is 0 +// 2. a[end] = input_row_length : the last value of the sequence is the size +// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size +// 4. length(row_pooling_sequence) = output_row_length+1 +// +// For more details on fractional max pooling, see this paper: +// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) // // Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). +// value: 4-D with shape `[batch, height, width, channels]`. +// pooling_ratio: Pooling ratio for each dimension of `value`, currently only +// supports row and col dimension and should be >= 1.0. For example, a valid +// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements +// must be 1.0 because we don't allow pooling on batch and channels +// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions +// respectively. // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// Returns output tensor after fractional max pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. +func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalMaxPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "UniformCandidateSampler", + Type: "FractionalMaxPool", Input: []tf.Input{ - true_classes, + value, }, Attrs: attrs, } @@ -28838,113 +28458,178 @@ func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int6 return op.Output(0), op.Output(1), op.Output(2) } -// CTCLossAttr is an optional argument to CTCLoss. -type CTCLossAttr func(optionalAttr) - -// CTCLossPreprocessCollapseRepeated sets the optional preprocess_collapse_repeated attribute to value. +// Deprecated. Use TensorArraySizeV3 // -// value: Scalar, if true then repeated labels are -// collapsed prior to the CTC calculation. -// If not specified, defaults to false -func CTCLossPreprocessCollapseRepeated(value bool) CTCLossAttr { +// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3 +func TensorArraySizeV2(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArraySizeV2", + Input: []tf.Input{ + handle, flow_in, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Conv2DAttr is an optional argument to Conv2D. +type Conv2DAttr func(optionalAttr) + +// Conv2DUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DUseCudnnOnGpu(value bool) Conv2DAttr { return func(m optionalAttr) { - m["preprocess_collapse_repeated"] = value + m["use_cudnn_on_gpu"] = value } } -// CTCLossCtcMergeRepeated sets the optional ctc_merge_repeated attribute to value. +// Conv2DDataFormat sets the optional data_format attribute to value. // -// value: Scalar. If set to false, *during* CTC calculation -// repeated non-blank labels will not be merged and are interpreted as -// individual labels. This is a simplified version of CTC. -// If not specified, defaults to true -func CTCLossCtcMergeRepeated(value bool) CTCLossAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func Conv2DDataFormat(value string) Conv2DAttr { return func(m optionalAttr) { - m["ctc_merge_repeated"] = value + m["data_format"] = value } } -// CTCLossIgnoreLongerOutputsThanInputs sets the optional ignore_longer_outputs_than_inputs attribute to value. +// Conv2DDilations sets the optional dilations attribute to value. // -// value: Scalar. If set to true, during CTC -// calculation, items that have longer output sequences than input sequences -// are skipped: they don't contribute to the loss term and have zero-gradient. -// If not specified, defaults to false -func CTCLossIgnoreLongerOutputsThanInputs(value bool) CTCLossAttr { +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { - m["ignore_longer_outputs_than_inputs"] = value + m["dilations"] = value } } -// Calculates the CTC Loss (log probability) for each batch entry. Also calculates +// Computes a 2-D convolution given 4-D `input` and `filter` tensors. // -// the gradient. This class performs the softmax operation for you, so inputs -// should be e.g. linear projections of outputs by an LSTM. +// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` +// and a filter / kernel tensor of shape +// `[filter_height, filter_width, in_channels, out_channels]`, this op +// performs the following: +// +// 1. Flattens the filter to a 2-D matrix with shape +// `[filter_height * filter_width * in_channels, output_channels]`. +// 2. Extracts image patches from the input tensor to form a *virtual* +// tensor of shape `[batch, out_height, out_width, +// filter_height * filter_width * in_channels]`. +// 3. For each patch, right-multiplies the filter matrix and the image patch +// vector. +// +// In detail, with the default NHWC format, +// +// output[b, i, j, k] = +// sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * +// filter[di, dj, q, k] +// +// Must have `strides[0] = strides[3] = 1`. For the most common case of the same +// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. // // Arguments: -// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. -// labels_indices: The indices of a `SparseTensor`. -// `labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for -// `(batch b, time t)`. -// labels_values: The values (labels) associated with the given batch and time. -// sequence_length: A vector containing sequence lengths (batch). +// input: A 4-D tensor. The dimension order is interpreted according to the value +// of `data_format`, see below for details. +// filter: A 4-D tensor of shape +// `[filter_height, filter_width, in_channels, out_channels]` +// strides: 1-D tensor of length 4. The stride of the sliding window for each +// dimension of `input`. The dimension order is determined by the value of +// `data_format`, see below for details. +// padding: The type of padding algorithm to use. // -// Returns A vector (batch) containing log-probabilities.The gradient of `loss`. 3-D, shape: -// `(max_time x batch_size x num_classes)`. -func CTCLoss(scope *Scope, inputs tf.Output, labels_indices tf.Output, labels_values tf.Output, sequence_length tf.Output, optional ...CTCLossAttr) (loss tf.Output, gradient tf.Output) { +// Returns A 4-D tensor. The dimension order is determined by the value of +// `data_format`, see below for details. +func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv2DAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "CTCLoss", + Type: "Conv2D", Input: []tf.Input{ - inputs, labels_indices, labels_values, sequence_length, + input, filter, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// CTCGreedyDecoderAttr is an optional argument to CTCGreedyDecoder. -type CTCGreedyDecoderAttr func(optionalAttr) +// StageAttr is an optional argument to Stage. +type StageAttr func(optionalAttr) -// CTCGreedyDecoderMergeRepeated sets the optional merge_repeated attribute to value. +// StageCapacity sets the optional capacity attribute to value. // -// value: If True, merge repeated classes in output. -// If not specified, defaults to false -func CTCGreedyDecoderMergeRepeated(value bool) CTCGreedyDecoderAttr { +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func StageCapacity(value int64) StageAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// StageMemoryLimit sets the optional memory_limit attribute to value. +// +// value: The maximum number of bytes allowed for Tensors in the Staging Area. +// If > 0, inserts will block until sufficient space is available. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func StageMemoryLimit(value int64) StageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// StageContainer sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. +// If not specified, defaults to "" +func StageContainer(value string) StageAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// StageSharedName sets the optional shared_name attribute to value. +// +// value: It is necessary to match this name to the matching Unstage Op. +// If not specified, defaults to "" +func StageSharedName(value string) StageAttr { return func(m optionalAttr) { - m["merge_repeated"] = value + m["shared_name"] = value } } -// Performs greedy decoding on the logits given in inputs. -// -// A note about the attribute merge_repeated: if enabled, when -// consecutive logits' maximum indices are the same, only the first of -// these is emitted. Labeling the blank '*', the sequence "A B B * B B" -// becomes "A B B" if merge_repeated = True and "A B B B B" if -// merge_repeated = False. +// Stage values similar to a lightweight Enqueue. // -// Regardless of the value of merge_repeated, if the maximum index of a given -// time and batch corresponds to the blank, index `(num_classes - 1)`, no new -// element is emitted. +// The basic functionality of this Op is similar to a queue with many +// fewer capabilities and options. This Op is optimized for performance. // // Arguments: -// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. -// sequence_length: A vector containing sequence lengths, size `(batch_size)`. +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. // -// Returns Indices matrix, size `(total_decoded_outputs x 2)`, -// of a `SparseTensor`. The rows store: [batch, time].Values vector, size: `(total_decoded_outputs)`, -// of a `SparseTensor`. The vector stores the decoded classes.Shape vector, size `(2)`, of the decoded SparseTensor. -// Values are: `[batch_size, max_decoded_length]`.Matrix, size `(batch_size x 1)`, containing sequence -// log-probabilities. -func CTCGreedyDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, optional ...CTCGreedyDecoderAttr) (decoded_indices tf.Output, decoded_values tf.Output, decoded_shape tf.Output, log_probability tf.Output) { +// Returns the created operation. +func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -28953,589 +28638,554 @@ func CTCGreedyDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "CTCGreedyDecoder", + Type: "Stage", Input: []tf.Input{ - inputs, sequence_length, + tf.OutputList(values), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) + return scope.AddOperation(opspec) } -// Forwards `data` to the output port determined by `pred`. -// -// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise, -// the data goes to `output_false`. -// -// See also `RefSwitch` and `Merge`. -// -// Arguments: -// data: The tensor to be forwarded to the appropriate output. -// pred: A scalar that specifies which output port will receive data. -// -// Returns If `pred` is false, data will be forwarded to this output.If `pred` is true, data will be forwarded to this output. -func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Output, output_true tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Switch", - Input: []tf.Input{ - data, pred, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} +// StagePeekAttr is an optional argument to StagePeek. +type StagePeekAttr func(optionalAttr) -// Add all input tensors element wise. +// StagePeekCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// Arguments: -// inputs: Must all be the same size and shape. -func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AddN", - Input: []tf.Input{ - tf.OutputList(inputs), - }, +// REQUIRES: value >= 0 +func StagePeekCapacity(value int64) StagePeekAttr { + return func(m optionalAttr) { + m["capacity"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// TryRpcAttr is an optional argument to TryRpc. -type TryRpcAttr func(optionalAttr) - -// TryRpcProtocol sets the optional protocol attribute to value. +// StagePeekMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: RPC protocol to use. Empty string means use the default protocol. -// Options include 'grpc'. -// If not specified, defaults to "" -func TryRpcProtocol(value string) TryRpcAttr { +// REQUIRES: value >= 0 +func StagePeekMemoryLimit(value int64) StagePeekAttr { return func(m optionalAttr) { - m["protocol"] = value + m["memory_limit"] = value } } -// TryRpcFailFast sets the optional fail_fast attribute to value. -// -// value: `boolean`. If `true` (default), then failures to connect -// (i.e., the server does not immediately respond) cause an RPC failure. -// If not specified, defaults to true -func TryRpcFailFast(value bool) TryRpcAttr { +// StagePeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func StagePeekContainer(value string) StagePeekAttr { return func(m optionalAttr) { - m["fail_fast"] = value + m["container"] = value } } -// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value. -// -// value: `int`. If `0` (default), then the kernel will run the RPC -// request and only time out if the RPC deadline passes or the session times out. -// If this value is greater than `0`, then the op will raise an exception if -// the RPC takes longer than `timeout_in_ms`. -// If not specified, defaults to 0 -func TryRpcTimeoutInMs(value int64) TryRpcAttr { +// StagePeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func StagePeekSharedName(value string) StagePeekAttr { return func(m optionalAttr) { - m["timeout_in_ms"] = value + m["shared_name"] = value } } -// Perform batches of RPC requests. -// -// This op asynchronously performs either a single RPC request, or a batch -// of requests. RPC requests are defined by three main parameters: -// -// - `address` (the host+port or BNS address of the request) -// - `method` (the method name for the request) -// - `request` (the serialized proto string, or vector of strings, -// of the RPC request argument). -// -// For example, if you have an RPC service running on port localhost:2345, -// and its interface is configured with the following proto declaration: -// -// ``` -// service MyService { -// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { -// } -// }; -// ``` -// -// then call this op with arguments: -// -// ``` -// address = "localhost:2345" -// method = "MyService/MyMethod" -// ``` -// -// The `request` tensor is a string tensor representing serialized `MyRequestProto` -// strings; and the output string tensor `response` will have the same shape -// and contain (upon successful completion) corresponding serialized -// `MyResponseProto` strings. -// -// For example, to send a single, empty, `MyRequestProto`, call -// this op with `request = ""`. To send 5 **parallel** empty requests, -// call this op with `request = ["", "", "", "", ""]`. -// -// More generally, one can create a batch of `MyRequestProto` serialized protos -// from regular batched tensors using the `encode_proto` op, and convert -// the response `MyResponseProto` serialized protos to batched tensors -// using the `decode_proto` op. -// -// **NOTE** Working with serialized proto strings is faster than instantiating -// actual proto objects in memory, so no performance degradation is expected -// compared to writing custom kernels for this workflow. -// -// Unlike the standard `Rpc` op, if the connection fails or the remote worker -// returns an error status, this op does **not** reraise the exception. -// Instead, the `status_code` and `status_message` entry for the corresponding RPC -// call is set with the error returned from the RPC call. The `response` tensor -// will contain valid response values for those minibatch entries whose RPCs did -// not fail; the rest of the entries will have empty strings. -// -// Arguments: -// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `method` and `request`. -// method: `0-D` or `1-D`. The method address on the RPC server. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `address` and `request`. -// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `address` and `method`. +// Op peeks at the values at the specified index. If the // -// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages -// returned from the RPC calls. -func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) { +// underlying container does not contain sufficient elements +// this op will block until it does. This Op is optimized for +// performance. +func StagePeek(scope *Scope, index tf.Output, dtypes []tf.DataType, optional ...StagePeekAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TryRpc", + Type: "StagePeek", Input: []tf.Input{ - address, method, request, + index, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("StagePeek", err) + return + } + return values } -// EnterAttr is an optional argument to Enter. -type EnterAttr func(optionalAttr) +// MapStageAttr is an optional argument to MapStage. +type MapStageAttr func(optionalAttr) -// EnterIsConstant sets the optional is_constant attribute to value. +// MapStageCapacity sets the optional capacity attribute to value. // -// value: If true, the output is constant within the child frame. -// If not specified, defaults to false -func EnterIsConstant(value bool) EnterAttr { +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapStageCapacity(value int64) MapStageAttr { return func(m optionalAttr) { - m["is_constant"] = value + m["capacity"] = value } } -// EnterParallelIterations sets the optional parallel_iterations attribute to value. +// MapStageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: The number of iterations allowed to run in parallel. -// If not specified, defaults to 10 -func EnterParallelIterations(value int64) EnterAttr { +// REQUIRES: value >= 0 +func MapStageMemoryLimit(value int64) MapStageAttr { return func(m optionalAttr) { - m["parallel_iterations"] = value + m["memory_limit"] = value } } -// Creates or finds a child frame, and makes `data` available to the child frame. +// MapStageContainer sets the optional container attribute to value. // -// This op is used together with `Exit` to create loops in the graph. -// The unique `frame_name` is used by the `Executor` to identify frames. If -// `is_constant` is true, `output` is a constant in the child frame; otherwise -// it may be changed in the child frame. At most `parallel_iterations` iterations -// are run in parallel in the child frame. +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. +// If not specified, defaults to "" +func MapStageContainer(value string) MapStageAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapStageSharedName sets the optional shared_name attribute to value. +// +// value: It is necessary to match this name to the matching Unstage Op. +// If not specified, defaults to "" +func MapStageSharedName(value string) MapStageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Stage (key, values) in the underlying container which behaves like a hashtable. // // Arguments: -// data: The tensor to be made available to the child frame. -// frame_name: The name of the child frame. +// key: int64 // -// Returns The same tensor as `data`. -func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) { +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. +// +// +// Returns the created operation. +func MapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...MapStageAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"frame_name": frame_name} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Enter", + Type: "MapStage", Input: []tf.Input{ - data, + key, indices, tf.OutputList(values), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Produce a string tensor that encodes the state of a Reader. +// MapUnstageAttr is an optional argument to MapUnstage. +type MapUnstageAttr func(optionalAttr) + +// MapUnstageCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// Not all Readers support being serialized, so this can produce an -// Unimplemented error. +// REQUIRES: value >= 0 +func MapUnstageCapacity(value int64) MapUnstageAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapUnstageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// Arguments: -// reader_handle: Handle to a Reader. -func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) { - if scope.Err() != nil { - return +// REQUIRES: value >= 0 +func MapUnstageMemoryLimit(value int64) MapUnstageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value } - opspec := tf.OpSpec{ - Type: "ReaderSerializeStateV2", - Input: []tf.Input{ - reader_handle, - }, +} + +// MapUnstageContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapUnstageContainer(value string) MapUnstageAttr { + return func(m optionalAttr) { + m["container"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Exits the current frame to its parent frame. -// -// Exit makes its input `data` available to the parent frame. -// -// Arguments: -// data: The tensor to be made available to the parent frame. +// MapUnstageSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapUnstageSharedName(value string) MapUnstageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes and returns the values associated with the key // -// Returns The same tensor as `data`. -func Exit(scope *Scope, data tf.Output) (output tf.Output) { +// from the underlying container. If the underlying container +// does not contain this key, the op will block until it does. +func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Exit", + Type: "MapUnstage", Input: []tf.Input{ - data, + key, indices, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns a copy of the input tensor. -func Snapshot(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "Snapshot", - Input: []tf.Input{ - input, - }, + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapUnstage", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return values } -// AbortAttr is an optional argument to Abort. -type AbortAttr func(optionalAttr) +// MapIncompleteSizeAttr is an optional argument to MapIncompleteSize. +type MapIncompleteSizeAttr func(optionalAttr) -// AbortErrorMsg sets the optional error_msg attribute to value. +// MapIncompleteSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: A string which is the message associated with the exception. +// REQUIRES: value >= 0 +func MapIncompleteSizeCapacity(value int64) MapIncompleteSizeAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapIncompleteSizeMemoryLimit(value int64) MapIncompleteSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapIncompleteSizeContainer sets the optional container attribute to value. // If not specified, defaults to "" -func AbortErrorMsg(value string) AbortAttr { +func MapIncompleteSizeContainer(value string) MapIncompleteSizeAttr { return func(m optionalAttr) { - m["error_msg"] = value + m["container"] = value } } -// AbortExitWithoutError sets the optional exit_without_error attribute to value. -// If not specified, defaults to false -func AbortExitWithoutError(value bool) AbortAttr { +// MapIncompleteSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapIncompleteSizeSharedName(value string) MapIncompleteSizeAttr { return func(m optionalAttr) { - m["exit_without_error"] = value + m["shared_name"] = value } } -// Raise a exception to abort the process when called. -// -// If exit_without_error is true, the process will exit normally, -// otherwise it will exit with a SIGABORT signal. -// -// Returns nothing but an exception. -// -// Returns the created operation. -func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) { +// Op returns the number of incomplete elements in the underlying container. +func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncompleteSizeAttr) (size tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Abort", + Type: "MapIncompleteSize", Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler. -type FixedUnigramCandidateSamplerAttr func(optionalAttr) +// OrderedMapUnstageAttr is an optional argument to OrderedMapUnstage. +type OrderedMapUnstageAttr func(optionalAttr) -// FixedUnigramCandidateSamplerVocabFile sets the optional vocab_file attribute to value. +// OrderedMapUnstageCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: Each valid line in this file (which should have a CSV-like format) -// corresponds to a valid word ID. IDs are in sequential order, starting from -// num_reserved_ids. The last entry in each line is expected to be a value -// corresponding to the count or relative probability. Exactly one of vocab_file -// and unigrams needs to be passed to this op. -// If not specified, defaults to "" -func FixedUnigramCandidateSamplerVocabFile(value string) FixedUnigramCandidateSamplerAttr { +// REQUIRES: value >= 0 +func OrderedMapUnstageCapacity(value int64) OrderedMapUnstageAttr { return func(m optionalAttr) { - m["vocab_file"] = value + m["capacity"] = value } } -// FixedUnigramCandidateSamplerDistortion sets the optional distortion attribute to value. +// OrderedMapUnstageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: The distortion is used to skew the unigram probability distribution. -// Each weight is first raised to the distortion's power before adding to the -// internal unigram distribution. As a result, distortion = 1.0 gives regular -// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives -// a uniform distribution. -// If not specified, defaults to 1 -func FixedUnigramCandidateSamplerDistortion(value float32) FixedUnigramCandidateSamplerAttr { +// REQUIRES: value >= 0 +func OrderedMapUnstageMemoryLimit(value int64) OrderedMapUnstageAttr { return func(m optionalAttr) { - m["distortion"] = value + m["memory_limit"] = value } } -// FixedUnigramCandidateSamplerNumReservedIds sets the optional num_reserved_ids attribute to value. -// -// value: Optionally some reserved IDs can be added in the range [0, -// ..., num_reserved_ids) by the users. One use case is that a special unknown -// word token is used as ID 0. These IDs will have a sampling probability of 0. -// If not specified, defaults to 0 -func FixedUnigramCandidateSamplerNumReservedIds(value int64) FixedUnigramCandidateSamplerAttr { +// OrderedMapUnstageContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapUnstageContainer(value string) OrderedMapUnstageAttr { return func(m optionalAttr) { - m["num_reserved_ids"] = value + m["container"] = value + } +} + +// OrderedMapUnstageSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapUnstageSharedName(value string) OrderedMapUnstageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes and returns the values associated with the key +// +// from the underlying container. If the underlying container +// does not contain this key, the op will block until it does. +func OrderedMapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageAttr) (values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "OrderedMapUnstage", + Input: []tf.Input{ + key, indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("OrderedMapUnstage", err) + return } + return values } -// FixedUnigramCandidateSamplerNumShards sets the optional num_shards attribute to value. -// -// value: A sampler can be used to sample from a subset of the original range -// in order to speed up the whole computation through parallelism. This parameter -// (together with 'shard') indicates the number of partitions that are being -// used in the overall computation. -// If not specified, defaults to 1 +// OrderedMapSizeAttr is an optional argument to OrderedMapSize. +type OrderedMapSizeAttr func(optionalAttr) + +// OrderedMapSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// REQUIRES: value >= 1 -func FixedUnigramCandidateSamplerNumShards(value int64) FixedUnigramCandidateSamplerAttr { +// REQUIRES: value >= 0 +func OrderedMapSizeCapacity(value int64) OrderedMapSizeAttr { return func(m optionalAttr) { - m["num_shards"] = value + m["capacity"] = value } } -// FixedUnigramCandidateSamplerShard sets the optional shard attribute to value. -// -// value: A sampler can be used to sample from a subset of the original range -// in order to speed up the whole computation through parallelism. This parameter -// (together with 'num_shards') indicates the particular partition number of a -// sampler op, when partitioning is being used. +// OrderedMapSizeMemoryLimit sets the optional memory_limit attribute to value. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func FixedUnigramCandidateSamplerShard(value int64) FixedUnigramCandidateSamplerAttr { +func OrderedMapSizeMemoryLimit(value int64) OrderedMapSizeAttr { return func(m optionalAttr) { - m["shard"] = value + m["memory_limit"] = value } } -// FixedUnigramCandidateSamplerUnigrams sets the optional unigrams attribute to value. -// -// value: A list of unigram counts or probabilities, one per ID in sequential -// order. Exactly one of vocab_file and unigrams should be passed to this op. -// If not specified, defaults to <> -func FixedUnigramCandidateSamplerUnigrams(value []float32) FixedUnigramCandidateSamplerAttr { +// OrderedMapSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapSizeContainer(value string) OrderedMapSizeAttr { return func(m optionalAttr) { - m["unigrams"] = value + m["container"] = value } } -// FixedUnigramCandidateSamplerSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func FixedUnigramCandidateSamplerSeed(value int64) FixedUnigramCandidateSamplerAttr { +// OrderedMapSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapSizeSharedName(value string) OrderedMapSizeAttr { return func(m optionalAttr) { - m["seed"] = value + m["shared_name"] = value } } -// FixedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. -// -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSamplerAttr { +// Op returns the number of elements in the underlying container. +func OrderedMapSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapSizeAttr) (size tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "OrderedMapSize", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ShapeNAttr is an optional argument to ShapeN. +type ShapeNAttr func(optionalAttr) + +// ShapeNOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func ShapeNOutType(value tf.DataType) ShapeNAttr { return func(m optionalAttr) { - m["seed2"] = value + m["out_type"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. -// -// A unigram sampler could use a fixed unigram distribution read from a -// file or passed in as an in-memory array instead of building up the distribution -// from data on the fly. There is also an option to skew the distribution by -// applying a distortion power to the weights. -// -// The vocabulary file should be in CSV-like format, with the last field -// being the weight associated with the word. -// -// For each batch, this op picks a single set of sampled candidate labels. -// -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. -// -// Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). +// Returns shape of tensors. // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...FixedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// This operation returns N 1-D integer tensors representing shape of `input[i]s`. +func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FixedUnigramCandidateSampler", + Type: "ShapeN", Input: []tf.Input{ - true_classes, + tf.OutputList(input), }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("ShapeN", err) + return + } + return output } -// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. -type WholeFileReaderV2Attr func(optionalAttr) +// CudnnRNNParamsToCanonicalAttr is an optional argument to CudnnRNNParamsToCanonical. +type CudnnRNNParamsToCanonicalAttr func(optionalAttr) -// WholeFileReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { +// CudnnRNNParamsToCanonicalRnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNParamsToCanonicalRnnMode(value string) CudnnRNNParamsToCanonicalAttr { return func(m optionalAttr) { - m["container"] = value + m["rnn_mode"] = value } } -// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { +// CudnnRNNParamsToCanonicalInputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNParamsToCanonicalInputMode(value string) CudnnRNNParamsToCanonicalAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["input_mode"] = value } } -// A Reader that outputs the entire contents of a file as a value. -// -// To use, enqueue filenames in a Queue. The output of ReaderRead will -// be a filename (key) and the contents of that file (value). -// -// Returns The handle to reference the Reader. -func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { - if scope.Err() != nil { - return +// CudnnRNNParamsToCanonicalDirection sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNParamsToCanonicalDirection(value string) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["direction"] = value } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) +} + +// CudnnRNNParamsToCanonicalDropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsToCanonicalDropout(value float32) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["dropout"] = value } - opspec := tf.OpSpec{ - Type: "WholeFileReaderV2", +} - Attrs: attrs, +// CudnnRNNParamsToCanonicalSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsToCanonicalSeed(value int64) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["seed"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Transforms a tf.Example proto (as a string) into typed tensors. +// CudnnRNNParamsToCanonicalSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsToCanonicalSeed2(value int64) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Retrieves CudnnRNN params in canonical form. // -// Arguments: -// serialized: A vector containing a batch of binary serialized Example protos. -// dense_defaults: A list of Tensors (some may be empty), whose length matches -// the length of `dense_keys`. dense_defaults[j] provides default values -// when the example's feature_map lacks dense_key[j]. If an empty Tensor is -// provided for dense_defaults[j], then the Feature dense_keys[j] is required. -// The input type is inferred from dense_defaults[j], even when it's empty. -// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, -// then the shape of dense_defaults[j] must match that of dense_shapes[j]. -// If dense_shapes[j] has an undefined major dimension (variable strides dense -// feature), dense_defaults[j] must contain a single element: -// the padding element. -// num_sparse: The number of sparse features to be parsed from the example. This -// must match the lengths of `sparse_keys` and `sparse_types`. -// sparse_keys: A list of `num_sparse` strings. -// The keys expected in the Examples' features associated with sparse values. -// dense_keys: The keys expected in the Examples' features associated with dense -// values. -// sparse_types: A list of `num_sparse` types; the data types of data in each -// Feature given in sparse_keys. -// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// dense_shapes: The shapes of data in each Feature given in dense_keys. -// The length of this list must match the length of `dense_keys`. The -// number of elements in the Feature corresponding to dense_key[j] must -// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == -// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] -// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, -// ..., DN), the shape of the output Tensor dense_values[j] will be (M, -// D1, .., DN), where M is the number of blocks of elements of length -// D1 * .... * DN, in the input. -func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { +// Retrieves a set of weights from the opaque params buffer that can be saved and +// restored in a way compatible with future runs. +// +// Note that the params buffer may not be compatible across different GPUs. So any +// save and restoration should be converted to and from the canonical weights and +// biases. +// +// num_layers: Specifies the number of layers in the RNN model. +// num_units: Specifies the size of the hidden state. +// input_size: Specifies the size of the input state. +// num_params: number of parameter sets for all layers. +// Each layer may contain multiple parameter sets, with each set consisting of +// a weight matrix and a bias vector. +// weights: the canonical form of weights that can be used for saving +// and restoration. They are more likely to be compatible across different +// generations. +// biases: the canonical form of biases that can be used for saving +// and restoration. They are more likely to be compatible across different +// generations. +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicate whether there is a linear projection between the input and +// The actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. +// dir = (direction == bidirectional) ? 2 : 1 +// dropout: dropout probability. When set to 0., dropout is disabled. +// seed: the 1st part of a seed to initialize dropout. +// seed2: the 2nd part of a seed to initialize dropout. +func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, params tf.Output, num_params int64, optional ...CudnnRNNParamsToCanonicalAttr) (weights []tf.Output, biases []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes} + attrs := map[string]interface{}{"num_params": num_params} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ParseSingleExample", + Type: "CudnnRNNParamsToCanonical", Input: []tf.Input{ - serialized, tf.OutputList(dense_defaults), + num_layers, num_units, input_size, params, }, Attrs: attrs, } @@ -29545,76 +29195,83 @@ func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf. } var idx int var err error - if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleExample", err) + if weights, idx, err = makeOutputList(op, idx, "weights"); err != nil { + scope.UpdateErr("CudnnRNNParamsToCanonical", err) return } - if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { - scope.UpdateErr("ParseSingleExample", err) + if biases, idx, err = makeOutputList(op, idx, "biases"); err != nil { + scope.UpdateErr("CudnnRNNParamsToCanonical", err) return } - return sparse_indices, sparse_values, sparse_shapes, dense_values + return weights, biases } -// Deserializes a serialized tree ensemble config and replaces current tree -// -// ensemble. -// -// Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. -// stamp_token: Token to use as the new value of the resource stamp. -// tree_ensemble_serialized: Serialized proto of the ensemble. +// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler. +type UniformCandidateSamplerAttr func(optionalAttr) + +// UniformCandidateSamplerSeed sets the optional seed attribute to value. // -// Returns the created operation. -func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed"] = value } - opspec := tf.OpSpec{ - Type: "BoostedTreesDeserializeEnsemble", - Input: []tf.Input{ - tree_ensemble_handle, stamp_token, tree_ensemble_serialized, - }, +} + +// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed2"] = value } - return scope.AddOperation(opspec) } -// Runs multiple additive regression ensemble predictors on input instances and +// Generates labels for candidate sampling with a uniform distribution. // -// computes the update to cached logits. It is designed to be used during training. -// It traverses the trees starting from cached tree id and cached node id and -// calculates the updates to be pushed to the cache. +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. // -// Arguments: +// For each batch, this op picks a single set of sampled candidate labels. // -// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting -// tree of prediction. -// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting -// node of prediction. -// bucketized_features: A list of rank 1 Tensors containing bucket id for each -// feature. -// logits_dimension: scalar, dimension of the logits, to be used for partial logits -// shape. +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. // -// Returns Rank 2 Tensor containing logits update (with respect to cached -// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids. -func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) { +// Arguments: +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). +// +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"logits_dimension": logits_dimension} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BoostedTreesTrainingPredict", + Type: "UniformCandidateSampler", Input: []tf.Input{ - tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features), + true_classes, }, Attrs: attrs, } @@ -29622,546 +29279,666 @@ func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, c return op.Output(0), op.Output(1), op.Output(2) } -// Elementwise computes the bitwise AND of `x` and `y`. +// CTCLossAttr is an optional argument to CTCLoss. +type CTCLossAttr func(optionalAttr) + +// CTCLossPreprocessCollapseRepeated sets the optional preprocess_collapse_repeated attribute to value. // -// The result will have those bits set, that are set in both `x` and `y`. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return +// value: Scalar, if true then repeated labels are +// collapsed prior to the CTC calculation. +// If not specified, defaults to false +func CTCLossPreprocessCollapseRepeated(value bool) CTCLossAttr { + return func(m optionalAttr) { + m["preprocess_collapse_repeated"] = value } - opspec := tf.OpSpec{ - Type: "BitwiseAnd", - Input: []tf.Input{ - x, y, - }, +} + +// CTCLossCtcMergeRepeated sets the optional ctc_merge_repeated attribute to value. +// +// value: Scalar. If set to false, *during* CTC calculation +// repeated non-blank labels will not be merged and are interpreted as +// individual labels. This is a simplified version of CTC. +// If not specified, defaults to true +func CTCLossCtcMergeRepeated(value bool) CTCLossAttr { + return func(m optionalAttr) { + m["ctc_merge_repeated"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Elementwise computes the bitwise left-shift of `x` and `y`. +// CTCLossIgnoreLongerOutputsThanInputs sets the optional ignore_longer_outputs_than_inputs attribute to value. // -// If `y` is negative, or greater than or equal to the width of `x` in bits the -// result is implementation defined. -func LeftShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// value: Scalar. If set to true, during CTC +// calculation, items that have longer output sequences than input sequences +// are skipped: they don't contribute to the loss term and have zero-gradient. +// If not specified, defaults to false +func CTCLossIgnoreLongerOutputsThanInputs(value bool) CTCLossAttr { + return func(m optionalAttr) { + m["ignore_longer_outputs_than_inputs"] = value + } +} + +// Calculates the CTC Loss (log probability) for each batch entry. Also calculates +// +// the gradient. This class performs the softmax operation for you, so inputs +// should be e.g. linear projections of outputs by an LSTM. +// +// Arguments: +// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +// labels_indices: The indices of a `SparseTensor`. +// `labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for +// `(batch b, time t)`. +// labels_values: The values (labels) associated with the given batch and time. +// sequence_length: A vector containing sequence lengths (batch). +// +// Returns A vector (batch) containing log-probabilities.The gradient of `loss`. 3-D, shape: +// `(max_time x batch_size x num_classes)`. +func CTCLoss(scope *Scope, inputs tf.Output, labels_indices tf.Output, labels_values tf.Output, sequence_length tf.Output, optional ...CTCLossAttr) (loss tf.Output, gradient tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LeftShift", + Type: "CTCLoss", Input: []tf.Input{ - x, y, + inputs, labels_indices, labels_values, sequence_length, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// TensorListStackAttr is an optional argument to TensorListStack. -type TensorListStackAttr func(optionalAttr) +// CTCGreedyDecoderAttr is an optional argument to CTCGreedyDecoder. +type CTCGreedyDecoderAttr func(optionalAttr) -// TensorListStackNumElements sets the optional num_elements attribute to value. -// If not specified, defaults to -1 -func TensorListStackNumElements(value int64) TensorListStackAttr { +// CTCGreedyDecoderMergeRepeated sets the optional merge_repeated attribute to value. +// +// value: If True, merge repeated classes in output. +// If not specified, defaults to false +func CTCGreedyDecoderMergeRepeated(value bool) CTCGreedyDecoderAttr { return func(m optionalAttr) { - m["num_elements"] = value + m["merge_repeated"] = value } } -// Stacks all tensors in the list. +// Performs greedy decoding on the logits given in inputs. // -// Requires that all tensors have the same shape. +// A note about the attribute merge_repeated: if enabled, when +// consecutive logits' maximum indices are the same, only the first of +// these is emitted. Labeling the blank '*', the sequence "A B B * B B" +// becomes "A B B" if merge_repeated = True and "A B B B B" if +// merge_repeated = False. // -// input_handle: the input list -// tensor: the gathered result -// num_elements: optional. If not -1, the number of elements in the list. +// Regardless of the value of merge_repeated, if the maximum index of a given +// time and batch corresponds to the blank, index `(num_classes - 1)`, no new +// element is emitted. // -func TensorListStack(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListStackAttr) (tensor tf.Output) { +// Arguments: +// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +// sequence_length: A vector containing sequence lengths, size `(batch_size)`. +// +// Returns Indices matrix, size `(total_decoded_outputs x 2)`, +// of a `SparseTensor`. The rows store: [batch, time].Values vector, size: `(total_decoded_outputs)`, +// of a `SparseTensor`. The vector stores the decoded classes.Shape vector, size `(2)`, of the decoded SparseTensor. +// Values are: `[batch_size, max_decoded_length]`.Matrix, size `(batch_size x 1)`, containing sequence +// log-probabilities. +func CTCGreedyDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, optional ...CTCGreedyDecoderAttr) (decoded_indices tf.Output, decoded_values tf.Output, decoded_shape tf.Output, log_probability tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorListStack", + Type: "CTCGreedyDecoder", Input: []tf.Input{ - input_handle, + inputs, sequence_length, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Elementwise computes the bitwise right-shift of `x` and `y`. +// Forwards `data` to the output port determined by `pred`. +// +// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise, +// the data goes to `output_false`. +// +// See also `RefSwitch` and `Merge`. // -// Performs a logical shift for unsigned integer types, and an arithmetic shift -// for signed integer types. +// Arguments: +// data: The tensor to be forwarded to the appropriate output. +// pred: A scalar that specifies which output port will receive data. // -// If `y` is negative, or greater than or equal to than the width of `x` in bits -// the result is implementation defined. -func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Returns If `pred` is false, data will be forwarded to this output.If `pred` is true, data will be forwarded to this output. +func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Output, output_true tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RightShift", + Type: "Switch", Input: []tf.Input{ - x, y, + data, pred, }, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Adjust the hue of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. -// -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A delta is then applied all the hue values, -// and then remapped back to RGB colorspace. +// Add all input tensors element wise. // // Arguments: -// images: Images to adjust. At least 3-D. -// delta: A float delta to add to the hue. -// -// Returns The hue-adjusted image or images. -func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { +// inputs: Must all be the same size and shape. +func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "AdjustHue", + Type: "AddN", Input: []tf.Input{ - images, delta, + tf.OutputList(inputs), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// BatchAttr is an optional argument to Batch. -type BatchAttr func(optionalAttr) +// TryRpcAttr is an optional argument to TryRpc. +type TryRpcAttr func(optionalAttr) -// BatchMaxEnqueuedBatches sets the optional max_enqueued_batches attribute to value. -// If not specified, defaults to 10 -func BatchMaxEnqueuedBatches(value int64) BatchAttr { +// TryRpcProtocol sets the optional protocol attribute to value. +// +// value: RPC protocol to use. Empty string means use the default protocol. +// Options include 'grpc'. +// If not specified, defaults to "" +func TryRpcProtocol(value string) TryRpcAttr { return func(m optionalAttr) { - m["max_enqueued_batches"] = value + m["protocol"] = value } } -// BatchAllowedBatchSizes sets the optional allowed_batch_sizes attribute to value. -// If not specified, defaults to <> -func BatchAllowedBatchSizes(value []int64) BatchAttr { +// TryRpcFailFast sets the optional fail_fast attribute to value. +// +// value: `boolean`. If `true` (default), then failures to connect +// (i.e., the server does not immediately respond) cause an RPC failure. +// If not specified, defaults to true +func TryRpcFailFast(value bool) TryRpcAttr { return func(m optionalAttr) { - m["allowed_batch_sizes"] = value + m["fail_fast"] = value } } -// BatchContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func BatchContainer(value string) BatchAttr { +// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value. +// +// value: `int`. If `0` (default), then the kernel will run the RPC +// request and only time out if the RPC deadline passes or the session times out. +// If this value is greater than `0`, then the op will raise an exception if +// the RPC takes longer than `timeout_in_ms`. +// If not specified, defaults to 0 +func TryRpcTimeoutInMs(value int64) TryRpcAttr { return func(m optionalAttr) { - m["container"] = value + m["timeout_in_ms"] = value } } -// BatchSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func BatchSharedName(value string) BatchAttr { - return func(m optionalAttr) { - m["shared_name"] = value +// Perform batches of RPC requests. +// +// This op asynchronously performs either a single RPC request, or a batch +// of requests. RPC requests are defined by three main parameters: +// +// - `address` (the host+port or BNS address of the request) +// - `method` (the method name for the request) +// - `request` (the serialized proto string, or vector of strings, +// of the RPC request argument). +// +// For example, if you have an RPC service running on port localhost:2345, +// and its interface is configured with the following proto declaration: +// +// ``` +// service MyService { +// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { +// } +// }; +// ``` +// +// then call this op with arguments: +// +// ``` +// address = "localhost:2345" +// method = "MyService/MyMethod" +// ``` +// +// The `request` tensor is a string tensor representing serialized `MyRequestProto` +// strings; and the output string tensor `response` will have the same shape +// and contain (upon successful completion) corresponding serialized +// `MyResponseProto` strings. +// +// For example, to send a single, empty, `MyRequestProto`, call +// this op with `request = ""`. To send 5 **parallel** empty requests, +// call this op with `request = ["", "", "", "", ""]`. +// +// More generally, one can create a batch of `MyRequestProto` serialized protos +// from regular batched tensors using the `encode_proto` op, and convert +// the response `MyResponseProto` serialized protos to batched tensors +// using the `decode_proto` op. +// +// **NOTE** Working with serialized proto strings is faster than instantiating +// actual proto objects in memory, so no performance degradation is expected +// compared to writing custom kernels for this workflow. +// +// Unlike the standard `Rpc` op, if the connection fails or the remote worker +// returns an error status, this op does **not** reraise the exception. +// Instead, the `status_code` and `status_message` entry for the corresponding RPC +// call is set with the error returned from the RPC call. The `response` tensor +// will contain valid response values for those minibatch entries whose RPCs did +// not fail; the rest of the entries will have empty strings. +// +// Arguments: +// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `method` and `request`. +// method: `0-D` or `1-D`. The method address on the RPC server. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `address` and `request`. +// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `address` and `method`. +// +// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages +// returned from the RPC calls. +func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TryRpc", + Input: []tf.Input{ + address, method, request, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// BatchBatchingQueue sets the optional batching_queue attribute to value. -// If not specified, defaults to "" -func BatchBatchingQueue(value string) BatchAttr { +// EnterAttr is an optional argument to Enter. +type EnterAttr func(optionalAttr) + +// EnterIsConstant sets the optional is_constant attribute to value. +// +// value: If true, the output is constant within the child frame. +// If not specified, defaults to false +func EnterIsConstant(value bool) EnterAttr { return func(m optionalAttr) { - m["batching_queue"] = value + m["is_constant"] = value } } -// Batches all input tensors nondeterministically. -// -// When many instances of this Op are being run concurrently with the same -// container/shared_name in the same device, some will output zero-shaped Tensors -// and others will output Tensors of size up to max_batch_size. -// -// All Tensors in in_tensors are batched together (so, for example, labels and -// features should be batched with a single instance of this operation. +// EnterParallelIterations sets the optional parallel_iterations attribute to value. // -// Each invocation of batch emits an `id` scalar which will be used to identify -// this particular invocation when doing unbatch or its gradient. +// value: The number of iterations allowed to run in parallel. +// If not specified, defaults to 10 +func EnterParallelIterations(value int64) EnterAttr { + return func(m optionalAttr) { + m["parallel_iterations"] = value + } +} + +// Creates or finds a child frame, and makes `data` available to the child frame. // -// Each op which emits a non-empty batch will also emit a non-empty batch_index -// Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, -// start, and length of elements of each set of Tensors present in batched_tensors. +// This op is used together with `Exit` to create loops in the graph. +// The unique `frame_name` is used by the `Executor` to identify frames. If +// `is_constant` is true, `output` is a constant in the child frame; otherwise +// it may be changed in the child frame. At most `parallel_iterations` iterations +// are run in parallel in the child frame. // -// Batched tensors are concatenated along the first dimension, and all tensors in -// in_tensors must have the first dimension of the same size. +// Arguments: +// data: The tensor to be made available to the child frame. +// frame_name: The name of the child frame. // -// in_tensors: The tensors to be batched. -// num_batch_threads: Number of scheduling threads for processing batches of work. -// Determines the number of batches processed in parallel. -// max_batch_size: Batch sizes will never be bigger than this. -// batch_timeout_micros: Maximum number of microseconds to wait before outputting -// an incomplete batch. -// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does -// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad -// batches up to one of those sizes. The entries must increase monotonically, and -// the final entry must equal max_batch_size. -// grad_timeout_micros: The timeout to use for the gradient. See Unbatch. -// batched_tensors: Either empty tensors or a batch of concatenated Tensors. -// batch_index: If out_tensors is non-empty, has information to invert it. -// container: Controls the scope of sharing of this batch. -// id: always contains a scalar with a unique ID for this invocation of Batch. -// shared_name: Concurrently running instances of batch in the same device with the -// same container and shared_name will batch their elements together. If left -// empty, the op name will be used as the shared name. -// T: the types of tensors to be batched. -func Batch(scope *Scope, in_tensors []tf.Output, num_batch_threads int64, max_batch_size int64, batch_timeout_micros int64, grad_timeout_micros int64, optional ...BatchAttr) (batched_tensors []tf.Output, batch_index tf.Output, id tf.Output) { +// Returns The same tensor as `data`. +func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_batch_threads": num_batch_threads, "max_batch_size": max_batch_size, "batch_timeout_micros": batch_timeout_micros, "grad_timeout_micros": grad_timeout_micros} + attrs := map[string]interface{}{"frame_name": frame_name} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Batch", + Type: "Enter", + Input: []tf.Input{ + data, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Produce a string tensor that encodes the state of a Reader. +// +// Not all Readers support being serialized, so this can produce an +// Unimplemented error. +// +// Arguments: +// reader_handle: Handle to a Reader. +func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderSerializeStateV2", Input: []tf.Input{ - tf.OutputList(in_tensors), + reader_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Exits the current frame to its parent frame. +// +// Exit makes its input `data` available to the parent frame. +// +// Arguments: +// data: The tensor to be made available to the parent frame. +// +// Returns The same tensor as `data`. +func Exit(scope *Scope, data tf.Output) (output tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if batched_tensors, idx, err = makeOutputList(op, idx, "batched_tensors"); err != nil { - scope.UpdateErr("Batch", err) - return + opspec := tf.OpSpec{ + Type: "Exit", + Input: []tf.Input{ + data, + }, } - batch_index = op.Output(idx) - id = op.Output(idx) - return batched_tensors, batch_index, id + op := scope.AddOperation(opspec) + return op.Output(0) } -// UnbatchAttr is an optional argument to Unbatch. -type UnbatchAttr func(optionalAttr) - -// UnbatchContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func UnbatchContainer(value string) UnbatchAttr { - return func(m optionalAttr) { - m["container"] = value +// Returns a copy of the input tensor. +func Snapshot(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return } -} - -// UnbatchSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func UnbatchSharedName(value string) UnbatchAttr { - return func(m optionalAttr) { - m["shared_name"] = value + opspec := tf.OpSpec{ + Type: "Snapshot", + Input: []tf.Input{ + input, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Reverses the operation of Batch for a single output Tensor. +// Returns a tensor of zeros with the same shape and type as x. // -// An instance of Unbatch either receives an empty batched_tensor, in which case it -// asynchronously waits until the values become available from a concurrently -// running instance of Unbatch with the same container and shared_name, or receives -// a non-empty batched_tensor in which case it finalizes all other concurrently -// running instances and outputs its own element from the batch. +// Arguments: +// x: a tensor of type T. // -// batched_tensor: The possibly transformed output of Batch. The size of the first -// dimension should remain unchanged by the transformations for the operation to -// work. -// batch_index: The matching batch_index obtained from Batch. -// id: The id scalar emitted by Batch. -// unbatched_tensor: The Tensor corresponding to this execution. -// timeout_micros: Maximum amount of time (in microseconds) to wait to receive the -// batched input tensor associated with a given invocation of the op. -// container: Container to control resource sharing. -// shared_name: Instances of Unbatch with the same container and shared_name are -// assumed to possibly belong to the same batch. If left empty, the op name will -// be used as the shared name. -func Unbatch(scope *Scope, batched_tensor tf.Output, batch_index tf.Output, id tf.Output, timeout_micros int64, optional ...UnbatchAttr) (unbatched_tensor tf.Output) { +// Returns a tensor of the same shape and type as x but filled with zeros. +func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"timeout_micros": timeout_micros} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Unbatch", + Type: "ZerosLike", Input: []tf.Input{ - batched_tensor, batch_index, id, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. -type AvgPool3DGradAttr func(optionalAttr) +// AbortAttr is an optional argument to Abort. +type AbortAttr func(optionalAttr) -// AvgPool3DGradDataFormat sets the optional data_format attribute to value. +// AbortErrorMsg sets the optional error_msg attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { +// value: A string which is the message associated with the exception. +// If not specified, defaults to "" +func AbortErrorMsg(value string) AbortAttr { return func(m optionalAttr) { - m["data_format"] = value + m["error_msg"] = value } } -// Computes gradients of average pooling function. +// AbortExitWithoutError sets the optional exit_without_error attribute to value. +// If not specified, defaults to false +func AbortExitWithoutError(value bool) AbortAttr { + return func(m optionalAttr) { + m["exit_without_error"] = value + } +} + +// Raise a exception to abort the process when called. // -// Arguments: -// orig_input_shape: The original input dimensions. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. +// If exit_without_error is true, the process will exit normally, +// otherwise it will exit with a SIGABORT signal. // -// Returns The backprop for input. -func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { +// Returns nothing but an exception. +// +// Returns the created operation. +func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPool3DGrad", - Input: []tf.Input{ - orig_input_shape, grad, - }, + Type: "Abort", + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample. -type ParseSingleSequenceExampleAttr func(optionalAttr) +// FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler. +type FixedUnigramCandidateSamplerAttr func(optionalAttr) -// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. +// FixedUnigramCandidateSamplerVocabFile sets the optional vocab_file attribute to value. // -// value: A list of Ncontext_sparse types; the data types of data in -// each context Feature given in context_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> +// value: Each valid line in this file (which should have a CSV-like format) +// corresponds to a valid word ID. IDs are in sequential order, starting from +// num_reserved_ids. The last entry in each line is expected to be a value +// corresponding to the count or relative probability. Exactly one of vocab_file +// and unigrams needs to be passed to this op. +// If not specified, defaults to "" +func FixedUnigramCandidateSamplerVocabFile(value string) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["vocab_file"] = value + } +} + +// FixedUnigramCandidateSamplerDistortion sets the optional distortion attribute to value. // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { +// value: The distortion is used to skew the unigram probability distribution. +// Each weight is first raised to the distortion's power before adding to the +// internal unigram distribution. As a result, distortion = 1.0 gives regular +// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives +// a uniform distribution. +// If not specified, defaults to 1 +func FixedUnigramCandidateSamplerDistortion(value float32) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["context_sparse_types"] = value + m["distortion"] = value } } -// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. -// If not specified, defaults to <> +// FixedUnigramCandidateSamplerNumReservedIds sets the optional num_reserved_ids attribute to value. // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { +// value: Optionally some reserved IDs can be added in the range [0, +// ..., num_reserved_ids) by the users. One use case is that a special unknown +// word token is used as ID 0. These IDs will have a sampling probability of 0. +// If not specified, defaults to 0 +func FixedUnigramCandidateSamplerNumReservedIds(value int64) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["feature_list_dense_types"] = value + m["num_reserved_ids"] = value } } -// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. +// FixedUnigramCandidateSamplerNumShards sets the optional num_shards attribute to value. // -// value: A list of Ncontext_dense shapes; the shapes of data in -// each context Feature given in context_dense_keys. -// The number of elements in the Feature corresponding to context_dense_key[j] -// must always equal context_dense_shapes[j].NumEntries(). -// The shape of context_dense_values[j] will match context_dense_shapes[j]. -// If not specified, defaults to <> +// value: A sampler can be used to sample from a subset of the original range +// in order to speed up the whole computation through parallelism. This parameter +// (together with 'shard') indicates the number of partitions that are being +// used in the overall computation. +// If not specified, defaults to 1 // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { +// REQUIRES: value >= 1 +func FixedUnigramCandidateSamplerNumShards(value int64) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["context_dense_shapes"] = value + m["num_shards"] = value } } -// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. +// FixedUnigramCandidateSamplerShard sets the optional shard attribute to value. // -// value: A list of Nfeature_list_sparse types; the data types -// of data in each FeatureList given in feature_list_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> +// value: A sampler can be used to sample from a subset of the original range +// in order to speed up the whole computation through parallelism. This parameter +// (together with 'num_shards') indicates the particular partition number of a +// sampler op, when partitioning is being used. +// If not specified, defaults to 0 // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { +// REQUIRES: value >= 0 +func FixedUnigramCandidateSamplerShard(value int64) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["feature_list_sparse_types"] = value + m["shard"] = value } } -// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. +// FixedUnigramCandidateSamplerUnigrams sets the optional unigrams attribute to value. // -// value: A list of Nfeature_list_dense shapes; the shapes of -// data in each FeatureList given in feature_list_dense_keys. -// The shape of each Feature in the FeatureList corresponding to -// feature_list_dense_key[j] must always equal -// feature_list_dense_shapes[j].NumEntries(). +// value: A list of unigram counts or probabilities, one per ID in sequential +// order. Exactly one of vocab_file and unigrams should be passed to this op. // If not specified, defaults to <> +func FixedUnigramCandidateSamplerUnigrams(value []float32) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["unigrams"] = value + } +} + +// FixedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func FixedUnigramCandidateSamplerSeed(value int64) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// FixedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["feature_list_dense_shapes"] = value + m["seed2"] = value } } -// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. +// Generates labels for candidate sampling with a learned unigram distribution. +// +// A unigram sampler could use a fixed unigram distribution read from a +// file or passed in as an in-memory array instead of building up the distribution +// from data on the fly. There is also an option to skew the distribution by +// applying a distortion power to the weights. +// +// The vocabulary file should be in CSV-like format, with the last field +// being the weight associated with the word. +// +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. // // Arguments: -// serialized: A scalar containing a binary serialized SequenceExample proto. -// feature_list_dense_missing_assumed_empty: A vector listing the -// FeatureList keys which may be missing from the SequenceExample. If the -// associated FeatureList is missing, it is treated as empty. By default, -// any FeatureList not listed in this vector must exist in the SequenceExample. -// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). -// The keys expected in the Examples' features associated with context_sparse -// values. -// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' context features associated with -// dense values. -// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors -// (scalars). The keys expected in the FeatureLists associated with sparse -// values. -// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' feature_lists associated -// with lists of dense values. -// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). -// context_dense_defaults[j] provides default values -// when the SequenceExample's context map lacks context_dense_key[j]. -// If an empty Tensor is provided for context_dense_defaults[j], -// then the Feature context_dense_keys[j] is required. -// The input type is inferred from context_dense_defaults[j], even when it's -// empty. If context_dense_defaults[j] is not empty, its shape must match -// context_dense_shapes[j]. -// debug_name: A scalar containing the name of the serialized proto. -// May contain, for example, table key (descriptive) name for the -// corresponding serialized proto. This is purely useful for debugging -// purposes, and the presence of values here has no effect on the output. -// May also be an empty scalar if no name is available. -func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) { +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). +// +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...FixedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ParseSingleSequenceExample", + Type: "FixedUnigramCandidateSampler", Input: []tf.Input{ - serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name, + true_classes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values + return op.Output(0), op.Output(1), op.Output(2) } -// UnbatchGradAttr is an optional argument to UnbatchGrad. -type UnbatchGradAttr func(optionalAttr) +// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. +type WholeFileReaderV2Attr func(optionalAttr) -// UnbatchGradContainer sets the optional container attribute to value. +// WholeFileReaderV2Container sets the optional container attribute to value. +// +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. // If not specified, defaults to "" -func UnbatchGradContainer(value string) UnbatchGradAttr { +func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { return func(m optionalAttr) { m["container"] = value } } -// UnbatchGradSharedName sets the optional shared_name attribute to value. +// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. // If not specified, defaults to "" -func UnbatchGradSharedName(value string) UnbatchGradAttr { +func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { return func(m optionalAttr) { m["shared_name"] = value } } -// Gradient of Unbatch. +// A Reader that outputs the entire contents of a file as a value. // -// Acts like Batch but using the given batch_index index of batching things as they -// become available. This ensures that the gradients are propagated back in the -// same session which did the forward pass. +// To use, enqueue filenames in a Queue. The output of ReaderRead will +// be a filename (key) and the contents of that file (value). // -// original_input: The input to the Unbatch operation this is the gradient of. -// batch_index: The batch_index given to the Unbatch operation this is the gradient -// of. -// grad: The downstream gradient. -// id: The id scalar emitted by Batch. -// batched_grad: The return value, either an empty tensor or the batched gradient. -// container: Container to control resource sharing. -// shared_name: Instances of UnbatchGrad with the same container and shared_name -// are assumed to possibly belong to the same batch. If left empty, the op name -// will be used as the shared name. -func UnbatchGrad(scope *Scope, original_input tf.Output, batch_index tf.Output, grad tf.Output, id tf.Output, optional ...UnbatchGradAttr) (batched_grad tf.Output) { +// Returns The handle to reference the Reader. +func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } @@ -30170,232 +29947,325 @@ func UnbatchGrad(scope *Scope, original_input tf.Output, batch_index tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "UnbatchGrad", - Input: []tf.Input{ - original_input, batch_index, grad, id, - }, + Type: "WholeFileReaderV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodeWavAttr is an optional argument to DecodeWav. -type DecodeWavAttr func(optionalAttr) - -// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. +// Transforms a tf.Example proto (as a string) into typed tensors. // -// value: Number of sample channels wanted. -// If not specified, defaults to -1 -func DecodeWavDesiredChannels(value int64) DecodeWavAttr { - return func(m optionalAttr) { - m["desired_channels"] = value +// Arguments: +// serialized: A vector containing a batch of binary serialized Example protos. +// dense_defaults: A list of Tensors (some may be empty), whose length matches +// the length of `dense_keys`. dense_defaults[j] provides default values +// when the example's feature_map lacks dense_key[j]. If an empty Tensor is +// provided for dense_defaults[j], then the Feature dense_keys[j] is required. +// The input type is inferred from dense_defaults[j], even when it's empty. +// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, +// then the shape of dense_defaults[j] must match that of dense_shapes[j]. +// If dense_shapes[j] has an undefined major dimension (variable strides dense +// feature), dense_defaults[j] must contain a single element: +// the padding element. +// num_sparse: The number of sparse features to be parsed from the example. This +// must match the lengths of `sparse_keys` and `sparse_types`. +// sparse_keys: A list of `num_sparse` strings. +// The keys expected in the Examples' features associated with sparse values. +// dense_keys: The keys expected in the Examples' features associated with dense +// values. +// sparse_types: A list of `num_sparse` types; the data types of data in each +// Feature given in sparse_keys. +// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// dense_shapes: The shapes of data in each Feature given in dense_keys. +// The length of this list must match the length of `dense_keys`. The +// number of elements in the Feature corresponding to dense_key[j] must +// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == +// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] +// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, +// ..., DN), the shape of the output Tensor dense_values[j] will be (M, +// D1, .., DN), where M is the number of blocks of elements of length +// D1 * .... * DN, in the input. +func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { + if scope.Err() != nil { + return } -} - -// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. -// -// value: Length of audio requested. -// If not specified, defaults to -1 -func DecodeWavDesiredSamples(value int64) DecodeWavAttr { - return func(m optionalAttr) { - m["desired_samples"] = value + attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes} + opspec := tf.OpSpec{ + Type: "ParseSingleExample", + Input: []tf.Input{ + serialized, tf.OutputList(dense_defaults), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return } + return sparse_indices, sparse_values, sparse_shapes, dense_values } -// Decode a 16-bit PCM WAV file to a float tensor. -// -// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. -// -// When desired_channels is set, if the input contains fewer channels than this -// then the last channel will be duplicated to give the requested number, else if -// the input has more channels than requested then the additional channels will be -// ignored. -// -// If desired_samples is set, then the audio will be cropped or padded with zeroes -// to the requested length. +// Deserializes a serialized tree ensemble config and replaces current tree // -// The first output contains a Tensor with the content of the audio samples. The -// lowest dimension will be the number of channels, and the second will be the -// number of samples. For example, a ten-sample-long stereo WAV file should give an -// output shape of [10, 2]. +// ensemble. // // Arguments: -// contents: The WAV-encoded audio, usually from a file. +// tree_ensemble_handle: Handle to the tree ensemble. +// stamp_token: Token to use as the new value of the resource stamp. +// tree_ensemble_serialized: Serialized proto of the ensemble. // -// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. -func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { +// Returns the created operation. +func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) + opspec := tf.OpSpec{ + Type: "BoostedTreesDeserializeEnsemble", + Input: []tf.Input{ + tree_ensemble_handle, stamp_token, tree_ensemble_serialized, + }, + } + return scope.AddOperation(opspec) +} + +// Elementwise computes the bitwise AND of `x` and `y`. +// +// The result will have those bits set, that are set in both `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return } opspec := tf.OpSpec{ - Type: "DecodeWav", + Type: "BitwiseAnd", Input: []tf.Input{ - contents, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Concatenates a list of `N` tensors along the first dimension. -// -// The input tensors are all required to have size 1 in the first dimension. -// -// For example: -// -// ``` -// # 'x' is [[1, 4]] -// # 'y' is [[2, 5]] -// # 'z' is [[3, 6]] -// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. -// ``` -// -// The difference between concat and parallel_concat is that concat requires all -// of the inputs be computed before the operation will begin but doesn't require -// that the input shapes be known during graph construction. Parallel concat -// will copy pieces of the input into the output as they become available, in -// some situations this can provide a performance benefit. -// -// Arguments: -// values: Tensors to be concatenated. All must have size 1 in the first dimension -// and same shape. -// shape: the final shape of the result; should be equal to the shapes of any input -// but with the number of input values in the first dimension. +// Elementwise computes the bitwise left-shift of `x` and `y`. // -// Returns The concatenated tensor. -func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) { +// If `y` is negative, or greater than or equal to the width of `x` in bits the +// result is implementation defined. +func LeftShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "ParallelConcat", + Type: "LeftShift", Input: []tf.Input{ - tf.OutputList(values), + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Subtracts `v` into specified rows of `x`. +// TensorListStackAttr is an optional argument to TensorListStack. +type TensorListStackAttr func(optionalAttr) + +// TensorListStackNumElements sets the optional num_elements attribute to value. +// If not specified, defaults to -1 +func TensorListStackNumElements(value int64) TensorListStackAttr { + return func(m optionalAttr) { + m["num_elements"] = value + } +} + +// Stacks all tensors in the list. // -// Computes y = x; y[i, :] -= v; return y. +// Requires that all tensors have the same shape. // -// Arguments: -// x: A `Tensor` of type T. -// i: A vector. Indices into the left-most dimension of `x`. -// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. +// input_handle: the input list +// tensor: the gathered result +// num_elements: optional. If not -1, the number of elements in the list. // -// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. -func InplaceSub(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { +func TensorListStack(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListStackAttr) (tensor tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"element_dtype": element_dtype} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "InplaceSub", + Type: "TensorListStack", Input: []tf.Input{ - x, i, v, + input_handle, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts a flat index or array of flat indices into a tuple of -// -// coordinate arrays. -// -// @compatibility(numpy) -// Equivalent to np.unravel_index -// @end_compatibility +// Elementwise computes the bitwise right-shift of `x` and `y`. // -// Arguments: -// indices: An 0-D or 1-D `int` Tensor whose elements are indices into the -// flattened version of an array of dimensions dims. -// dims: An 1-D `int` Tensor. The shape of the array to use for unraveling -// indices. +// Performs a logical shift for unsigned integer types, and an arithmetic shift +// for signed integer types. // -// Returns An 2-D (or 1-D if indices is 0-D) tensor where each row has the -// same shape as the indices array. -func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Output) { +// If `y` is negative, or greater than or equal to than the width of `x` in bits +// the result is implementation defined. +func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "UnravelIndex", + Type: "RightShift", Input: []tf.Input{ - indices, dims, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Compute the lower regularized incomplete Gamma function `Q(a, x)`. -// -// The lower regularized incomplete Gamma function is defined as: -// -// -// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) +// Adjust the hue of one or more images. // -// where +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpretted as channels, and must be three. // -// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\) +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A delta is then applied all the hue values, +// and then remapped back to RGB colorspace. // -// is the lower incomplete Gamma function. +// Arguments: +// images: Images to adjust. At least 3-D. +// delta: A float delta to add to the hue. // -// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete -// Gamma function. -func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { +// Returns The hue-adjusted image or images. +func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Igamma", + Type: "AdjustHue", Input: []tf.Input{ - a, x, + images, delta, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes offsets of concat inputs within its output. +// BatchAttr is an optional argument to Batch. +type BatchAttr func(optionalAttr) + +// BatchMaxEnqueuedBatches sets the optional max_enqueued_batches attribute to value. +// If not specified, defaults to 10 +func BatchMaxEnqueuedBatches(value int64) BatchAttr { + return func(m optionalAttr) { + m["max_enqueued_batches"] = value + } +} + +// BatchAllowedBatchSizes sets the optional allowed_batch_sizes attribute to value. +// If not specified, defaults to <> +func BatchAllowedBatchSizes(value []int64) BatchAttr { + return func(m optionalAttr) { + m["allowed_batch_sizes"] = value + } +} + +// BatchContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func BatchContainer(value string) BatchAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// BatchSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func BatchSharedName(value string) BatchAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// BatchBatchingQueue sets the optional batching_queue attribute to value. +// If not specified, defaults to "" +func BatchBatchingQueue(value string) BatchAttr { + return func(m optionalAttr) { + m["batching_queue"] = value + } +} + +// Batches all input tensors nondeterministically. // -// For example: +// When many instances of this Op are being run concurrently with the same +// container/shared_name in the same device, some will output zero-shaped Tensors +// and others will output Tensors of size up to max_batch_size. // -// ``` -// # 'x' is [2, 2, 7] -// # 'y' is [2, 3, 7] -// # 'z' is [2, 5, 7] -// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] -// ``` +// All Tensors in in_tensors are batched together (so, for example, labels and +// features should be batched with a single instance of this operation. // -// This is typically used by gradient computations for a concat operation. +// Each invocation of batch emits an `id` scalar which will be used to identify +// this particular invocation when doing unbatch or its gradient. // -// Arguments: -// concat_dim: The dimension along which to concatenate. -// shape: The `N` int32 vectors representing shape of tensors being concatenated. +// Each op which emits a non-empty batch will also emit a non-empty batch_index +// Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, +// start, and length of elements of each set of Tensors present in batched_tensors. // -// Returns The `N` int32 vectors representing the starting offset -// of input tensors within the concatenated output. -func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) { +// Batched tensors are concatenated along the first dimension, and all tensors in +// in_tensors must have the first dimension of the same size. +// +// in_tensors: The tensors to be batched. +// num_batch_threads: Number of scheduling threads for processing batches of work. +// Determines the number of batches processed in parallel. +// max_batch_size: Batch sizes will never be bigger than this. +// batch_timeout_micros: Maximum number of microseconds to wait before outputting +// an incomplete batch. +// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does +// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad +// batches up to one of those sizes. The entries must increase monotonically, and +// the final entry must equal max_batch_size. +// grad_timeout_micros: The timeout to use for the gradient. See Unbatch. +// batched_tensors: Either empty tensors or a batch of concatenated Tensors. +// batch_index: If out_tensors is non-empty, has information to invert it. +// container: Controls the scope of sharing of this batch. +// id: always contains a scalar with a unique ID for this invocation of Batch. +// shared_name: Concurrently running instances of batch in the same device with the +// same container and shared_name will batch their elements together. If left +// empty, the op name will be used as the shared name. +// T: the types of tensors to be batched. +func Batch(scope *Scope, in_tensors []tf.Output, num_batch_threads int64, max_batch_size int64, batch_timeout_micros int64, grad_timeout_micros int64, optional ...BatchAttr) (batched_tensors []tf.Output, batch_index tf.Output, id tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_batch_threads": num_batch_threads, "max_batch_size": max_batch_size, "batch_timeout_micros": batch_timeout_micros, "grad_timeout_micros": grad_timeout_micros} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ConcatOffset", + Type: "Batch", Input: []tf.Input{ - concat_dim, tf.OutputList(shape), + tf.OutputList(in_tensors), }, + Attrs: attrs, } op := scope.AddOperation(opspec) if scope.Err() != nil { @@ -30403,195 +30273,321 @@ func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset } var idx int var err error - if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil { - scope.UpdateErr("ConcatOffset", err) + if batched_tensors, idx, err = makeOutputList(op, idx, "batched_tensors"); err != nil { + scope.UpdateErr("Batch", err) return } - return offset + batch_index = op.Output(idx) + id = op.Output(idx) + return batched_tensors, batch_index, id } -// Splits a tensor into `num_split` tensors along one dimension. +// UnbatchAttr is an optional argument to Unbatch. +type UnbatchAttr func(optionalAttr) + +// UnbatchContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func UnbatchContainer(value string) UnbatchAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// UnbatchSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func UnbatchSharedName(value string) UnbatchAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Reverses the operation of Batch for a single output Tensor. // -// Arguments: -// axis: 0-D. The dimension along which to split. Must be in the range -// `[-rank(value), rank(value))`. -// value: The tensor to split. -// num_split: The number of ways to split. Must evenly divide -// `value.shape[split_dim]`. +// An instance of Unbatch either receives an empty batched_tensor, in which case it +// asynchronously waits until the values become available from a concurrently +// running instance of Unbatch with the same container and shared_name, or receives +// a non-empty batched_tensor in which case it finalizes all other concurrently +// running instances and outputs its own element from the batch. // -// Returns They are identically shaped tensors, whose shape matches that of `value` -// except along `axis`, where their sizes are -// `values.shape[split_dim] / num_split`. -func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (output []tf.Output) { +// batched_tensor: The possibly transformed output of Batch. The size of the first +// dimension should remain unchanged by the transformations for the operation to +// work. +// batch_index: The matching batch_index obtained from Batch. +// id: The id scalar emitted by Batch. +// unbatched_tensor: The Tensor corresponding to this execution. +// timeout_micros: Maximum amount of time (in microseconds) to wait to receive the +// batched input tensor associated with a given invocation of the op. +// container: Container to control resource sharing. +// shared_name: Instances of Unbatch with the same container and shared_name are +// assumed to possibly belong to the same batch. If left empty, the op name will +// be used as the shared name. +func Unbatch(scope *Scope, batched_tensor tf.Output, batch_index tf.Output, id tf.Output, timeout_micros int64, optional ...UnbatchAttr) (unbatched_tensor tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_split": num_split} + attrs := map[string]interface{}{"timeout_micros": timeout_micros} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Split", + Type: "Unbatch", Input: []tf.Input{ - axis, value, + batched_tensor, batch_index, id, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("Split", err) - return + return op.Output(0) +} + +// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. +type AvgPool3DGradAttr func(optionalAttr) + +// AvgPool3DGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { + return func(m optionalAttr) { + m["data_format"] = value } - return output } -// Splits a tensor into `num_split` tensors along one dimension. +// Computes gradients of average pooling function. // // Arguments: -// value: The tensor to split. -// size_splits: list containing the sizes of each output tensor along the split -// dimension. Must sum to the dimension of value along split_dim. -// Can contain one -1 indicating that dimension is to be inferred. -// axis: 0-D. The dimension along which to split. Must be in the range -// `[-rank(value), rank(value))`. -// +// orig_input_shape: The original input dimensions. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. // -// Returns Tensors whose shape matches that of `value` -// except along `axis`, where their sizes are -// `size_splits[i]`. -func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) { +// Returns The backprop for input. +func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_split": num_split} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SplitV", + Type: "AvgPool3DGrad", Input: []tf.Input{ - value, size_splits, axis, + orig_input_shape, grad, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return + return op.Output(0) +} + +// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample. +type ParseSingleSequenceExampleAttr func(optionalAttr) + +// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. +// +// value: A list of Ncontext_sparse types; the data types of data in +// each context Feature given in context_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["context_sparse_types"] = value } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("SplitV", err) - return +} + +// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_types"] = value } - return output } -// Gives a guarantee to the TF runtime that the input tensor is a constant. +// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. // -// The runtime is then free to make optimizations based on this. +// value: A list of Ncontext_dense shapes; the shapes of data in +// each context Feature given in context_dense_keys. +// The number of elements in the Feature corresponding to context_dense_key[j] +// must always equal context_dense_shapes[j].NumEntries(). +// The shape of context_dense_values[j] will match context_dense_shapes[j]. +// If not specified, defaults to <> // -// Only accepts value typed tensors as inputs and rejects resource variable handles -// as input. +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["context_dense_shapes"] = value + } +} + +// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. // -// Returns the input tensor without modification. -func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return +// value: A list of Nfeature_list_sparse types; the data types +// of data in each FeatureList given in feature_list_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_sparse_types"] = value } - opspec := tf.OpSpec{ - Type: "GuaranteeConst", - Input: []tf.Input{ - input, - }, +} + +// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. +// +// value: A list of Nfeature_list_dense shapes; the shapes of +// data in each FeatureList given in feature_list_dense_keys. +// The shape of each Feature in the FeatureList corresponding to +// feature_list_dense_key[j] must always equal +// feature_list_dense_shapes[j].NumEntries(). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_shapes"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Returns a tensor of zeros with the same shape and type as x. +// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. // // Arguments: -// x: a tensor of type T. -// -// Returns a tensor of the same shape and type as x but filled with zeros. -func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { +// serialized: A scalar containing a binary serialized SequenceExample proto. +// feature_list_dense_missing_assumed_empty: A vector listing the +// FeatureList keys which may be missing from the SequenceExample. If the +// associated FeatureList is missing, it is treated as empty. By default, +// any FeatureList not listed in this vector must exist in the SequenceExample. +// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). +// The keys expected in the Examples' features associated with context_sparse +// values. +// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' context features associated with +// dense values. +// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors +// (scalars). The keys expected in the FeatureLists associated with sparse +// values. +// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' feature_lists associated +// with lists of dense values. +// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). +// context_dense_defaults[j] provides default values +// when the SequenceExample's context map lacks context_dense_key[j]. +// If an empty Tensor is provided for context_dense_defaults[j], +// then the Feature context_dense_keys[j] is required. +// The input type is inferred from context_dense_defaults[j], even when it's +// empty. If context_dense_defaults[j] is not empty, its shape must match +// context_dense_shapes[j]. +// debug_name: A scalar containing the name of the serialized proto. +// May contain, for example, table key (descriptive) name for the +// corresponding serialized proto. This is purely useful for debugging +// purposes, and the presence of values here has no effect on the output. +// May also be an empty scalar if no name is available. +func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ZerosLike", + Type: "ParseSingleSequenceExample", Input: []tf.Input{ - x, + serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm. -type QuantizedInstanceNormAttr func(optionalAttr) - -// QuantizedInstanceNormOutputRangeGiven sets the optional output_range_given attribute to value. -// -// value: If True, `given_y_min` and `given_y_min` -// and `given_y_max` are used as the output range. Otherwise, -// the implementation computes the output range. -// If not specified, defaults to false -func QuantizedInstanceNormOutputRangeGiven(value bool) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["output_range_given"] = value + if scope.Err() != nil { + return } -} - -// QuantizedInstanceNormGivenYMin sets the optional given_y_min attribute to value. -// -// value: Output in `y_min` if `output_range_given` is True. -// If not specified, defaults to 0 -func QuantizedInstanceNormGivenYMin(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["given_y_min"] = value + var idx int + var err error + if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return } -} - -// QuantizedInstanceNormGivenYMax sets the optional given_y_max attribute to value. -// -// value: Output in `y_max` if `output_range_given` is True. -// If not specified, defaults to 0 -func QuantizedInstanceNormGivenYMax(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["given_y_max"] = value + if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return } + if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values } -// QuantizedInstanceNormVarianceEpsilon sets the optional variance_epsilon attribute to value. -// -// value: A small float number to avoid dividing by 0. -// If not specified, defaults to 1e-05 -func QuantizedInstanceNormVarianceEpsilon(value float32) QuantizedInstanceNormAttr { +// UnbatchGradAttr is an optional argument to UnbatchGrad. +type UnbatchGradAttr func(optionalAttr) + +// UnbatchGradContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func UnbatchGradContainer(value string) UnbatchGradAttr { return func(m optionalAttr) { - m["variance_epsilon"] = value + m["container"] = value } } -// QuantizedInstanceNormMinSeparation sets the optional min_separation attribute to value. -// -// value: Minimum value of `y_max - y_min` -// If not specified, defaults to 0.001 -func QuantizedInstanceNormMinSeparation(value float32) QuantizedInstanceNormAttr { +// UnbatchGradSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func UnbatchGradSharedName(value string) UnbatchGradAttr { return func(m optionalAttr) { - m["min_separation"] = value + m["shared_name"] = value } } -// Quantized Instance normalization. +// Gradient of Unbatch. // -// Arguments: -// x: A 4D input Tensor. -// x_min: The value represented by the lowest quantized input. -// x_max: The value represented by the highest quantized input. +// Acts like Batch but using the given batch_index index of batching things as they +// become available. This ensures that the gradients are propagated back in the +// same session which did the forward pass. // -// Returns A 4D Tensor.The value represented by the lowest quantized output.The value represented by the highest quantized output. -func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf.Output, optional ...QuantizedInstanceNormAttr) (y tf.Output, y_min tf.Output, y_max tf.Output) { +// original_input: The input to the Unbatch operation this is the gradient of. +// batch_index: The batch_index given to the Unbatch operation this is the gradient +// of. +// grad: The downstream gradient. +// id: The id scalar emitted by Batch. +// batched_grad: The return value, either an empty tensor or the batched gradient. +// container: Container to control resource sharing. +// shared_name: Instances of UnbatchGrad with the same container and shared_name +// are assumed to possibly belong to the same batch. If left empty, the op name +// will be used as the shared name. +func UnbatchGrad(scope *Scope, original_input tf.Output, batch_index tf.Output, grad tf.Output, id tf.Output, optional ...UnbatchGradAttr) (batched_grad tf.Output) { if scope.Err() != nil { return } @@ -30600,112 +30596,116 @@ func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf. a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedInstanceNorm", + Type: "UnbatchGrad", Input: []tf.Input{ - x, x_min, x_max, + original_input, batch_index, grad, id, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Returns the diagonal part of the tensor. +// DecodeWavAttr is an optional argument to DecodeWav. +type DecodeWavAttr func(optionalAttr) + +// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. // -// This operation returns a tensor with the `diagonal` part -// of the `input`. The `diagonal` part is computed as follows: +// value: Number of sample channels wanted. +// If not specified, defaults to -1 +func DecodeWavDesiredChannels(value int64) DecodeWavAttr { + return func(m optionalAttr) { + m["desired_channels"] = value + } +} + +// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. // -// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a -// tensor of rank `k` with dimensions `[D1,..., Dk]` where: +// value: Length of audio requested. +// If not specified, defaults to -1 +func DecodeWavDesiredSamples(value int64) DecodeWavAttr { + return func(m optionalAttr) { + m["desired_samples"] = value + } +} + +// Decode a 16-bit PCM WAV file to a float tensor. // -// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`. +// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. // -// For example: +// When desired_channels is set, if the input contains fewer channels than this +// then the last channel will be duplicated to give the requested number, else if +// the input has more channels than requested then the additional channels will be +// ignored. // -// ``` -// # 'input' is [[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]] +// If desired_samples is set, then the audio will be cropped or padded with zeroes +// to the requested length. // -// tf.diag_part(input) ==> [1, 2, 3, 4] -// ``` +// The first output contains a Tensor with the content of the audio samples. The +// lowest dimension will be the number of channels, and the second will be the +// number of samples. For example, a ten-sample-long stereo WAV file should give an +// output shape of [10, 2]. // // Arguments: -// input: Rank k tensor where k is even and not zero. +// contents: The WAV-encoded audio, usually from a file. // -// Returns The extracted diagonal. -func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { +// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. +func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "DiagPart", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the element-wise max of two SparseTensors. -// -// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. -// -// Arguments: -// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, in the canonical lexicographic ordering. -// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. -// a_shape: 1-D. Shape of the input SparseTensor. -// b_indices: counterpart to `a_indices` for the other operand. -// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. -// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. -// -// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. -func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { - if scope.Err() != nil { - return + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } opspec := tf.OpSpec{ - Type: "SparseSparseMaximum", + Type: "DecodeWav", Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, + contents, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0), op.Output(1) } -// Returns a batched matrix tensor with new batched diagonal values. +// Concatenates a list of `N` tensors along the first dimension. // -// Given `input` and `diagonal`, this operation returns a tensor with the -// same shape and values as `input`, except for the main diagonal of the -// innermost matrices. These will be overwritten by the values in `diagonal`. +// The input tensors are all required to have size 1 in the first dimension. // -// The output is computed as follows: +// For example: // -// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has -// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a -// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: +// ``` +// # 'x' is [[1, 4]] +// # 'y' is [[2, 5]] +// # 'z' is [[3, 6]] +// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. +// ``` // -// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. -// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. +// The difference between concat and parallel_concat is that concat requires all +// of the inputs be computed before the operation will begin but doesn't require +// that the input shapes be known during graph construction. Parallel concat +// will copy pieces of the input into the output as they become available, in +// some situations this can provide a performance benefit. // // Arguments: -// input: Rank `k+1`, where `k >= 1`. -// diagonal: Rank `k`, where `k >= 1`. +// values: Tensors to be concatenated. All must have size 1 in the first dimension +// and same shape. +// shape: the final shape of the result; should be equal to the shapes of any input +// but with the number of input values in the first dimension. // -// Returns Rank `k+1`, with `output.shape = input.shape`. -func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) { +// Returns The concatenated tensor. +func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "MatrixSetDiag", + Type: "ParallelConcat", Input: []tf.Input{ - input, diagonal, + tf.OutputList(values), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index 8fcad61f4c6eec597d2b14fb8c9b4fa59987a829..25ec71870315917351d68db6a16d25fe037d543b 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -65,6 +65,11 @@ func (op *Operation) Output(i int) Output { return Output{op, i} } +// NumInputs returns the number of inputs of op. +func (op *Operation) NumInputs() int { + return int(C.TF_OperationNumInputs(op.c)) +} + // Output represents one of the outputs of an operation in the graph. Has a // DataType (and eventually a Shape). May be passed as an input argument to a // function for adding operations to a graph, or to a Session's Run() method to @@ -123,6 +128,67 @@ func (p Output) c() C.TF_Output { func (p Output) canBeAnInput() {} +// Consumers returns the inputs that consume this output. +func (p Output) Consumers() []Consumer { + max := int(C.TF_OperationOutputNumConsumers(p.c())) + if max == 0 { + return nil + } + inputs := make([]C.TF_Input, max) + n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max)) + inputs = inputs[:int(n)] + + var consumers []Consumer + for _, consumer := range inputs { + consumers = append(consumers, Consumer{ + Index: int(consumer.index), + Op: &Operation{ + c: consumer.oper, + g: p.Op.g, + }, + }) + } + + return consumers +} + +// Consumer identifies a specific input of an operation that consumes the output +// of another operation. +type Consumer struct { + // Op is the Operation that is consuming the output of another operation. + Op *Operation + + // Index is the index of the input within Op that the output of another + // operation is connected to. + Index int +} + +func (p Consumer) c() C.TF_Input { + if p.Op == nil { + // Attempt to provide a more useful panic message than "nil + // pointer dereference". + panic("nil-Operation. Consumer objects should only be created by a call to Output.Consumers") + } + return C.TF_Input{oper: p.Op.c, index: C.int(p.Index)} +} + +// DataType returns the type of the input. +func (p Consumer) DataType() DataType { + return DataType(C.TF_OperationInputType(p.c())) +} + +// Producer returns the Output that is connected to this Consumer. +func (p Consumer) Producer() Output { + output := C.TF_OperationInput(p.c()) + return Output{ + Op: &Operation{ + c: output.oper, + g: p.Op.g, + }, + Index: int(output.index), + } +} + // Input is the interface for specifying inputs to an operation being added to // a Graph. // diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 40c951ab8c13f43e2063b9f9cfadcd44a6da72fe..06b65bdfb7eb814a2bead191374029cc0fdf025e 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -166,6 +166,68 @@ func TestOutputDataTypeAndShape(t *testing.T) { } } +func TestOperationInputs(t *testing.T) { + g := NewGraph() + x, err := Placeholder(g, "x", Float) + if err != nil { + t.Fatal(err) + } + y, err := Placeholder(g, "y", Float) + if err != nil { + t.Fatal(err) + } + add, err := Add(g, "add", x, y) + if err != nil { + t.Fatal(err) + } + addOp := add.Op + + if out := addOp.NumInputs(); out != 2 { + t.Fatalf("Got %d inputs, wanted 2", out) + } +} + +func TestOperationConsumers(t *testing.T) { + g := NewGraph() + x, err := Placeholder(g, "x", Float) + if err != nil { + t.Fatal(err) + } + a, err := Neg(g, "a", x) + if err != nil { + t.Fatal(err) + } + b, err := Neg(g, "b", x) + if err != nil { + t.Fatal(err) + } + + consumers := []*Operation{a.Op, b.Op} + + xConsumers := x.Consumers() + if out := len(xConsumers); out != 2 { + t.Fatalf("Got %d consumers, wanted 2", out) + } + + for i, consumer := range xConsumers { + got := consumer.Op.Name() + want := consumers[i].Name() + if got != want { + t.Fatalf("%d. Got op name %q, wanted %q", i, got, want) + } + + got = consumer.Producer().Op.Name() + want = x.Op.Name() + if got != want { + t.Fatalf("%d. Got op name %q, wanted %q", i, got, want) + } + } + + if len(b.Consumers()) != 0 { + t.Fatalf("expected %+v to have no consumers", b) + } +} + func forceGC() { var mem runtime.MemStats runtime.ReadMemStats(&mem) diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 2d25c04dc9b1d0bc2ae831f98c0879e73a6bfafa..f3338f6595793df82380f4ce63058ba4285c91dd 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -131,13 +131,9 @@ func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) } runtime.SetFinalizer(t, (*Tensor).finalize) raw := tensorData(t.c) - n, err := r.Read(raw) - if err != nil { + if _, err := io.ReadFull(r, raw); err != nil { return nil, err } - if uintptr(n) != nbytes { - return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n) - } return t, nil } diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go index 793c36dd4db28fc5fdb713095c6d1d6713367a7a..dc533cd3e1c7198f902b2db850e8daff50f4cdeb 100644 --- a/tensorflow/go/tensor_test.go +++ b/tensorflow/go/tensor_test.go @@ -18,6 +18,7 @@ package tensorflow import ( "bytes" + "io" "reflect" "testing" ) @@ -226,6 +227,54 @@ func TestTensorSerializationErrors(t *testing.T) { } } +func TestReadTensorReadAll(t *testing.T) { + // Get the bytes of a tensor. + a := []float32{1.1, 1.2, 1.3} + ats, err := NewTensor(a) + if err != nil { + t.Fatal(err) + } + abuf := new(bytes.Buffer) + if _, err := ats.WriteContentsTo(abuf); err != nil { + t.Fatal(err) + } + + // Get the bytes of another tensor. + b := []float32{1.1, 1.2, 1.3} + bts, err := NewTensor(b) + if err != nil { + t.Fatal(err) + } + bbuf := new(bytes.Buffer) + if _, err := bts.WriteContentsTo(bbuf); err != nil { + t.Fatal(err) + } + + // Check that ReadTensor reads all bytes of both tensors, when the situation + // requires one than reads. + abbuf := io.MultiReader(abuf, bbuf) + abts, err := ReadTensor(Float, []int64{2, 3}, abbuf) + if err != nil { + t.Fatal(err) + } + abtsf32 := abts.Value().([][]float32) + expected := [][]float32{a, b} + + if len(abtsf32) != 2 { + t.Fatalf("first dimension %d is not 2", len(abtsf32)) + } + for i := 0; i < 2; i++ { + if len(abtsf32[i]) != 3 { + t.Fatalf("second dimension %d is not 3", len(abtsf32[i])) + } + for j := 0; j < 3; j++ { + if abtsf32[i][j] != expected[i][j] { + t.Errorf("value at %d %d not equal %f %f", i, j, abtsf32[i][j], expected[i][j]) + } + } + } +} + func benchmarkNewTensor(b *testing.B, v interface{}) { for i := 0; i < b.N; i++ { if t, err := NewTensor(v); err != nil || t == nil { diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 7c3a39cd3084ea46f6803c312c2cf4a5cb797fe4..73e210fae07d603feffefb6948b82910cf683043 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -56,13 +56,15 @@ java_library( srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]), javacopts = JAVACOPTS, resources = glob(["src/gen/resources/META-INF/services/javax.annotation.processing.Processor"]), + deps = [ + "@com_google_guava", + "@com_squareup_javapoet", + ], ) filegroup( name = "java_op_sources", - srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [ - ":java_op_gen_sources", - ], + srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [":java_op_gen_sources"], visibility = [ "//tensorflow/java:__pkg__", ], @@ -72,6 +74,7 @@ tf_java_op_gen_srcjar( name = "java_op_gen_sources", api_def_srcs = [ "//tensorflow/core/api_def:base_api_def", + "//tensorflow/core/api_def:java_api_def", ], base_package = "org.tensorflow.op", gen_tool = ":java_op_gen_tool", @@ -87,6 +90,9 @@ tf_cc_binary( linkstatic = 1, deps = [ ":java_op_gen_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", "//tensorflow/core:ops", ], ) @@ -111,6 +117,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", + "//tensorflow/core:protos_all_cc", + "@com_googlesource_code_re2//:re2", ], ) @@ -303,6 +311,7 @@ tf_cc_test( ], deps = [ ":java_op_gen_lib", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], @@ -330,7 +339,7 @@ tf_cc_binary( "//tensorflow:debug": [], # Disable all custom linker options in debug mode "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by LINKER_EXPORTED_SYMBOLS - LINKER_EXPORTED_SYMBOLS, + "$(location {})".format(LINKER_EXPORTED_SYMBOLS), ], "//tensorflow:windows": [], "//tensorflow:windows_msvc": [], diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md index 2f1ce253b2facb6d86d5c44b60668823f660ae7e..c7382ff23138cd8121718d0b7552da0f0a2d78af 100644 --- a/tensorflow/java/README.md +++ b/tensorflow/java/README.md @@ -1,7 +1,7 @@ # TensorFlow for Java > *WARNING*: The TensorFlow Java API is not currently covered by the TensorFlow -> [API stability guarantees](https://www.tensorflow.org/programmers_guide/version_semantics). +> [API stability guarantees](https://www.tensorflow.org/guide/version_semantics). > > For using TensorFlow on Android refer instead to > [contrib/android](https://www.tensorflow.org/code/tensorflow/contrib/android), @@ -23,8 +23,7 @@ native libraries will need to be built from source. 2. Setup the environment to build TensorFlow from source code ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux) - or [Mac OS - X](https://www.tensorflow.org/install/install_sources#PrepareMac)). + or [macOS](https://www.tensorflow.org/install/install_sources#PrepareMac)). If you'd like to skip reading those details and do not care about GPU support, try the following: diff --git a/tensorflow/java/maven/.gitignore b/tensorflow/java/maven/.gitignore index ff080515d5e730b308bf78f7e28244c6c799cdc3..657e2a60bc57c0cf259c000476c75ae58d75fff2 100644 --- a/tensorflow/java/maven/.gitignore +++ b/tensorflow/java/maven/.gitignore @@ -11,4 +11,10 @@ tensorflow/src tensorflow/target proto/src proto/target +hadoop/src +hadoop/target +spark-connector/src +spark-connector/target +spark-connector/dependency-reduced-pom.xml +spark-connector/spark-warehouse pom.xml.versionsBackup diff --git a/tensorflow/java/maven/README.md b/tensorflow/java/maven/README.md index c7e8f0380629f492ade9ba47cdcb4bc286ac82bc..3e030dcd09c886983540b95640230eae3a6f2c0f 100644 --- a/tensorflow/java/maven/README.md +++ b/tensorflow/java/maven/README.md @@ -53,6 +53,12 @@ There are seven artifacts and thus `pom.xml`s involved in this release: 7. [`parentpom`](https://maven.apache.org/pom/index.html): Common settings shared by all of the above. +8. `hadoop`: The TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop. + The source code for this package is available in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/hadoop) + +9. `spark-connector`: A Scala library for loading and storing TensorFlow TFRecord + using Apache Spark DataFrames. The source code for this package is available + in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector) ## Updating the release diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..0642be06fa148933902ab450c5cf2f771e268828 --- /dev/null +++ b/tensorflow/java/maven/hadoop/pom.xml @@ -0,0 +1,24 @@ + + + 4.0.0 + TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop + hadoop + jar + + + https://github.com/tensorflow/ecosystem.git + git@github.com:tensorflow/ecosystem.git + scm:git:https://github.com/tensorflow/ecosystem.git + + + https://github.com/tensorflow/ecosystem/ + + org.tensorflow + parentpom + 1.9.0-rc0 + ../ + + \ No newline at end of file diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index 08cc860f5795a4cf20f4ab2d09d2c2d37a52faf6..a7fa9ea5cc78f9d83cfb105f09837e958c60d5b4 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.8.0 + 1.9.0-rc1 ../ libtensorflow diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index fcc7eacc33b7bab366159425405b4bf5b0216cf1..83aae29f1ea0f893c40597a1be6f77668d8206e9 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.8.0 + 1.9.0-rc1 ../ libtensorflow_jni diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index 3d22d86a4970def52bf9a4a452a8131e1357341a..50bd8ee5f9e6d268976540ca8180380447bc8f18 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.8.0 + 1.9.0-rc1 ../ libtensorflow_jni_gpu diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index 0a09a5ea7cb96776b8296f68f599c333559a0729..b4746794ea9e417bb0bb9253ca356976a48eb1e8 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ 4.0.0 org.tensorflow parentpom - 1.8.0 + 1.9.0-rc1 pom https://www.tensorflow.org @@ -32,6 +32,8 @@ libtensorflow_jni_gpu tensorflow proto + hadoop + spark-connector + 4.0.0 + TensorFlow TFRecord connector for Apache Spark DataFrames + spark-connector + jar + + + https://github.com/tensorflow/ecosystem.git + git@github.com:tensorflow/ecosystem.git + scm:git:https://github.com/tensorflow/ecosystem.git + + + https://github.com/tensorflow/ecosystem/ + + org.tensorflow + parentpom + 1.9.0-rc0 + ../ + + \ No newline at end of file diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index 0df1f2814906e548855522335f710e9702f8bb2a..157c4b8e82d6b8062ce8c9c98432cfe97a20d190 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.8.0 + 1.9.0-rc1 ../ tensorflow diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 62575f6683089b1e4a62cedb3639570e1711b8a6..f5f54bf4d31af159624c668f1abb106f68944737 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -26,12 +26,12 @@ namespace java { // An enumeration of different modifiers commonly used in Java enum Modifier { - PACKAGE = 0, - PUBLIC = (1 << 0), + PACKAGE = 0, + PUBLIC = (1 << 0), PROTECTED = (1 << 1), - PRIVATE = (1 << 2), - STATIC = (1 << 3), - FINAL = (1 << 4), + PRIVATE = (1 << 2), + STATIC = (1 << 3), + FINAL = (1 << 4), }; class Annotation; @@ -75,12 +75,8 @@ class Type { // Reflection API does return Type(Type::PRIMITIVE, "void"); } - static Type Generic(const string& name) { - return Type(Type::GENERIC, name); - } - static Type Wildcard() { - return Type(Type::GENERIC, ""); - } + static Type Generic(const string& name) { return Type(Type::GENERIC, name); } + static Type Wildcard() { return Type(Type::GENERIC, ""); } static Type Class(const string& name, const string& package = "") { return Type(Type::CLASS, name, package); } @@ -226,9 +222,7 @@ class Method { // A definition of a documentation bloc for a Java element (JavaDoc) class Javadoc { public: - static Javadoc Create(const string& brief = "") { - return Javadoc(brief); - } + static Javadoc Create(const string& brief = "") { return Javadoc(brief); } const string& brief() const { return brief_; } const string& details() const { return details_; } Javadoc& details(const string& details) { diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index 6c35cd9595a0aafd2f6db35dadf899c19047e1f1..0d9e0883af262ee1f262a5e1308cb9df8763488d 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -41,7 +41,7 @@ const char kUsageHeader[] = "using an appropriate annotation processor.\n\n" "The '--base_package' overrides the default parent package under which " "the generated subpackage and classes are to be located.\n\n" - "Finally, the `--api_dirs` argument takes a list of comma-seperated " + "Finally, the `--api_dirs` argument takes a list of comma-separated " "directories of API definitions can be provided to override default\n" "values found in the ops definitions. Directories are ordered by priority " "(the last having precedence over the first).\n\n"; @@ -55,10 +55,12 @@ int main(int argc, char* argv[]) { tensorflow::string api_dirs_str; std::vector flag_list = { tensorflow::Flag("output_dir", &output_dir, - "Root directory into which output files are generated"), - tensorflow::Flag("base_package", &base_package, + "Root directory into which output files are generated"), + tensorflow::Flag( + "base_package", &base_package, "Package parent to the generated subpackage and classes"), - tensorflow::Flag("api_dirs", &api_dirs_str, + tensorflow::Flag( + "api_dirs", &api_dirs_str, "List of directories that contains the ops api definitions")}; tensorflow::string usage = tensorflow::java::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 284f675c9466a9b571bfb75a151795d465845fb6..d5bd99bdd9d71f73288661380ec45e76c797fa75 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -13,43 +13,44 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include #include +#include #include #include +#include +#include +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/java/src/gen/cc/java_defs.h" -#include "tensorflow/java/src/gen/cc/source_writer.h" #include "tensorflow/java/src/gen/cc/op_generator.h" #include "tensorflow/java/src/gen/cc/op_specs.h" +#include "tensorflow/java/src/gen/cc/source_writer.h" namespace tensorflow { namespace java { namespace { -const char* kLicense = - "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" - "\n" - "Licensed under the Apache License, Version 2.0 (the \"License\");\n" - "you may not use this file except in compliance with the License.\n" - "You may obtain a copy of the License at\n" - "\n" - " http://www.apache.org/licenses/LICENSE-2.0\n" - "\n" - "Unless required by applicable law or agreed to in writing, software\n" - "distributed under the License is distributed on an \"AS IS\" BASIS,\n" - "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" - "See the License for the specific language governing permissions and\n" - "limitations under the License.\n" - "=======================================================================*/\n"; +constexpr const char kLicense[] = + "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" + "\n" + "Licensed under the Apache License, Version 2.0 (the \"License\");\n" + "you may not use this file except in compliance with the License.\n" + "You may obtain a copy of the License at\n" + "\n" + " http://www.apache.org/licenses/LICENSE-2.0\n" + "\n" + "Unless required by applicable law or agreed to in writing, software\n" + "distributed under the License is distributed on an \"AS IS\" BASIS,\n" + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + "See the License for the specific language governing permissions and\n" + "limitations under the License.\n" + "=======================================================================*/" + "\n"; // There is three different modes to render an op class, depending on the // number and type of outputs it has: @@ -64,20 +65,16 @@ const char* kLicense = // allowing an instance to be passed directly as a list input to // another operation // -enum RenderMode { - DEFAULT, - OPERAND, - LIST_OPERAND -}; +enum RenderMode { DEFAULT, OPERAND, LIST_OPERAND }; void AddArgument(const Variable& var, const string& description, - Method* method_out, Javadoc* javadoc_out) { + Method* method_out, Javadoc* javadoc_out) { method_out->add_argument(var); javadoc_out->add_param_tag(var.name(), description); } void CollectOpDependencies(const OpSpec& op, RenderMode mode, - std::list* out) { + std::list* out) { out->push_back(Type::Class("Operation", "org.tensorflow")); out->push_back(Type::Class("OperationBuilder", "org.tensorflow")); out->push_back(Type::Class("Scope", "org.tensorflow.op")); @@ -110,7 +107,7 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, } void WriteSetAttrDirective(const AttributeSpec& attr, bool optional, - SourceWriter* writer) { + SourceWriter* writer) { string var_name = optional ? "opts." + attr.var().name() : attr.var().name(); if (attr.iterable()) { string array_name = attr.var().name() + "Array"; @@ -143,11 +140,11 @@ void WriteSetAttrDirective(const AttributeSpec& attr, bool optional, } void RenderFactoryMethods(const OpSpec& op, const Type& op_class, - SourceWriter* writer) { + SourceWriter* writer) { Method factory = Method::Create("create", op_class); - Javadoc factory_doc = Javadoc::Create( - "Factory method to create a class to wrap a new " + op_class.name() - + " operation to the graph."); + Javadoc factory_doc = + Javadoc::Create("Factory method to create a class to wrap a new " + + op_class.name() + " operation to the graph."); Variable scope = Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); AddArgument(scope, "current graph scope", &factory, &factory_doc); @@ -159,23 +156,23 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class, } if (!op.optional_attributes().empty()) { AddArgument(Variable::Varargs("options", Type::Class("Options")), - "carries optional attributes values", &factory, &factory_doc); + "carries optional attributes values", &factory, &factory_doc); } factory_doc.add_tag("return", "a new instance of " + op_class.name()); - writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc); - writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" - + op.graph_op_name() + "\", scope.makeOpName(\"" - + op_class.name() + "\"));"); + writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc); + writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" + + op.graph_op_name() + "\", scope.makeOpName(\"" + + op_class.name() + "\"));"); writer->EndLine(); for (const ArgumentSpec& input : op.inputs()) { if (input.iterable()) { - writer->Append("opBuilder.addInputList(Operands.asOutputs(" - + input.var().name() + "));"); + writer->Append("opBuilder.addInputList(Operands.asOutputs(" + + input.var().name() + "));"); writer->EndLine(); } else { - writer->Append("opBuilder.addInput(" + input.var().name() - + ".asOutput());"); + writer->Append("opBuilder.addInput(" + input.var().name() + + ".asOutput());"); writer->EndLine(); } } @@ -200,7 +197,7 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class, } void RenderConstructor(const OpSpec& op, const Type& op_class, - SourceWriter* writer) { + SourceWriter* writer) { Variable operation = Variable::Create("operation", Type::Class("Operation", "org.tensorflow")); Method constructor = Method::ConstructorFor(op_class).add_argument(operation); @@ -214,15 +211,14 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, writer->BeginMethod(constructor, PRIVATE) .Append("super(operation);") .EndLine(); - if (op.outputs().size() > 0) { - writer->Append("int outputIdx = 0;") - .EndLine(); + if (!op.outputs().empty()) { + writer->Append("int outputIdx = 0;").EndLine(); for (const ArgumentSpec& output : op.outputs()) { if (output.iterable()) { string var_length = output.var().name() + "Length"; writer->Append("int " + var_length) - .Append(" = operation.outputListLength(\"" + output.op_def_name() - + "\");") + .Append(" = operation.outputListLength(\"" + output.op_def_name() + + "\");") .EndLine() .Append(output.var().name() + " = Arrays.asList("); if (!output.type().wildcard()) { @@ -235,8 +231,8 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, .Append("outputIdx += " + var_length + ";") .EndLine(); } else { - writer->Append(output.var().name() - + " = operation.output(outputIdx++);") + writer + ->Append(output.var().name() + " = operation.output(outputIdx++);") .EndLine(); } } @@ -246,13 +242,12 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { for (const AttributeSpec& attr : op.optional_attributes()) { - Method setter = - Method::Create(attr.var().name(), Type::Class("Options")); + Method setter = Method::Create(attr.var().name(), Type::Class("Options")); Javadoc setter_doc = Javadoc::Create(); AddArgument(attr.var(), attr.description(), &setter, &setter_doc); - writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc) - .Append("return new Options()." + attr.var().name() + "(" - + attr.var().name() + ");") + writer->BeginMethod(setter, PUBLIC | STATIC, &setter_doc) + .Append("return new Options()." + attr.var().name() + "(" + + attr.var().name() + ");") .EndLine() .EndMethod(); } @@ -267,15 +262,16 @@ void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { } void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, - SourceWriter* writer) { + SourceWriter* writer) { ArgumentSpec output = op.outputs().front(); if (mode == OPERAND) { bool cast2obj = output.type().wildcard(); - Type return_type = Type::Class("Output", "org.tensorflow") - .add_parameter(cast2obj ? Type::Class("Object") : output.type()); + Type return_type = + Type::Class("Output", "org.tensorflow") + .add_parameter(cast2obj ? Type::Class("Object") : output.type()); Method as_output = Method::Create("asOutput", return_type) - .add_annotation(Annotation::Create("Override")); + .add_annotation(Annotation::Create("Override")); if (cast2obj) { as_output.add_annotation( Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); @@ -286,9 +282,7 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, } else { writer->Append("return "); } - writer->Append(output.var().name() + ";") - .EndLine() - .EndMethod(); + writer->Append(output.var().name() + ";").EndLine().EndMethod(); } else if (mode == LIST_OPERAND) { Type operand = Type::Interface("Operand", "org.tensorflow"); @@ -297,12 +291,13 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, } else { operand.add_parameter(output.type()); } - Type return_type = Type::Interface("Iterator", "java.util") - .add_parameter(operand); - Method iterator = Method::Create("iterator", return_type) - .add_annotation(Annotation::Create("Override")) - .add_annotation(Annotation::Create("SuppressWarnings") - .attributes("{\"rawtypes\", \"unchecked\"}")); + Type return_type = + Type::Interface("Iterator", "java.util").add_parameter(operand); + Method iterator = + Method::Create("iterator", return_type) + .add_annotation(Annotation::Create("Override")) + .add_annotation(Annotation::Create("SuppressWarnings") + .attributes("{\"rawtypes\", \"unchecked\"}")); // cast the output list using a raw List writer->BeginMethod(iterator, PUBLIC) .Append("return (" + return_type.name() + ") ") @@ -313,10 +308,10 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, } void RenderOptionsClass(const OpSpec& op, const Type& op_class, - SourceWriter* writer) { + SourceWriter* writer) { Type options_class = Type::Class("Options"); - Javadoc options_doc = Javadoc::Create( - "Optional attributes for {@link " + op_class.canonical_name() + "}"); + Javadoc options_doc = Javadoc::Create("Optional attributes for {@link " + + op_class.canonical_name() + "}"); writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); for (const AttributeSpec& attr : op.optional_attributes()) { Method setter = Method::Create(attr.var().name(), options_class); @@ -339,24 +334,27 @@ void RenderOptionsClass(const OpSpec& op, const Type& op_class, } inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) { - return Type::Class(endpoint.name(), + return Type::Class( + endpoint.name(), base_package + "." + str_util::Lowercase(endpoint.package())); } void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, - const string& base_package, const string& output_dir, Env* env) { - Type op_class(ClassOf(endpoint, base_package) - .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); + const string& base_package, const string& output_dir, + Env* env) { + Type op_class( + ClassOf(endpoint, base_package) + .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); Javadoc op_javadoc(endpoint.javadoc()); // op interfaces RenderMode mode = DEFAULT; if (op.outputs().size() == 1) { const ArgumentSpec& output = op.outputs().front(); - Type operand_type(output.type().wildcard() ? - Type::Class("Object") : output.type()); + Type operand_type(output.type().wildcard() ? Type::Class("Object") + : output.type()); Type operand_inf(Type::Interface("Operand", "org.tensorflow") - .add_parameter(operand_type)); + .add_parameter(operand_type)); if (output.iterable()) { mode = LIST_OPERAND; op_class.add_supertype(Type::IterableOf(operand_inf)); @@ -368,25 +366,24 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, // op generic parameters std::set generics; for (const ArgumentSpec& output : op.outputs()) { - if (output.type().kind() == Type::GENERIC && !output.type().wildcard() - && generics.find(output.type().name()) == generics.end()) { + if (output.type().kind() == Type::GENERIC && !output.type().wildcard() && + generics.find(output.type().name()) == generics.end()) { op_class.add_parameter(output.type()); - op_javadoc.add_param_tag("<" + output.type().name() + ">", + op_javadoc.add_param_tag( + "<" + output.type().name() + ">", "data type for {@code " + output.var().name() + "()} output"); generics.insert(output.type().name()); } } // op annotations - op_class.add_annotation( - Annotation::Create("Generated", "javax.annotation") - .attributes("value = \"TensorFlow Java Op Generator\"")); if (endpoint.deprecated()) { op_class.add_annotation(Annotation::Create("Deprecated")); string explanation; if (!op.endpoints().front().deprecated()) { - explanation = "use {@link " + - ClassOf(op.endpoints().front(), base_package).canonical_name() - + "} instead"; + explanation = + "use {@link " + + ClassOf(op.endpoints().front(), base_package).canonical_name() + + "} instead"; } else { explanation = op.deprecation_explanation(); } @@ -394,21 +391,25 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } if (!op.hidden()) { // expose the op in the Ops Graph API only if it is visible - op_class.add_annotation( - Annotation::Create("Operator", "org.tensorflow.op.annotation") - .attributes("group = \"" + endpoint.package() + "\"")); + Annotation oper_annot = + Annotation::Create("Operator", "org.tensorflow.op.annotation"); + if (endpoint.package() != kDefaultEndpointPackage) { + oper_annot.attributes("group = \"" + endpoint.package() + "\""); + } + op_class.add_annotation(oper_annot); } // create op class file - const string op_dir_name = io::JoinPath(output_dir, - str_util::StringReplace(op_class.package(), ".", "/", true)); + const string op_dir_name = io::JoinPath( + output_dir, str_util::StringReplace(op_class.package(), ".", "/", true)); if (!env->FileExists(op_dir_name).ok()) { TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name)) << op_dir_name; } const string op_file_name = op_class.name() + ".java"; std::unique_ptr op_file; - TF_CHECK_OK(env->NewWritableFile( - io::JoinPath(op_dir_name, op_file_name), &op_file)) << op_file_name; + TF_CHECK_OK( + env->NewWritableFile(io::JoinPath(op_dir_name, op_file_name), &op_file)) + << op_file_name; // render endpoint source code SourceFileWriter writer(op_file.get()); @@ -416,7 +417,10 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, CollectOpDependencies(op, mode, &dependencies); writer.Write(kLicense) .EndLine() - .BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc); + .Write("// This class has been generated, DO NOT EDIT!") + .EndLine() + .EndLine() + .BeginType(op_class, PUBLIC | FINAL, &dependencies, &op_javadoc); if (!op.optional_attributes().empty()) { RenderOptionsClass(op, op_class, &writer); } @@ -448,7 +452,7 @@ bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) { } // namespace Status OpGenerator::Run(const OpList& op_list, const string& base_package, - const string& output_dir) { + const string& output_dir) { ApiDefMap api_map(op_list); if (!api_dirs_.empty()) { // Only load api files that correspond to the requested "op_list" diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index cfe842070a77947003e8fedc1897a418f1403241..759d800ecfb5bec10b7bf8454baf5fc4c389e990 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/api_def.pb.h" -#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/java/src/gen/cc/op_specs.h" @@ -37,14 +37,15 @@ namespace java { class OpGenerator { public: explicit OpGenerator(const std::vector& api_dirs, - Env* env = Env::Default()) : api_dirs_(api_dirs), env_(env) {} + Env* env = Env::Default()) + : api_dirs_(api_dirs), env_(env) {} // Generates wrappers for the given list of 'ops'. // // Output files are generated in //, // where 'op_package' is derived from ops endpoints. Status Run(const OpList& op_list, const string& base_package, - const string& output_dir); + const string& output_dir); private: const std::vector api_dirs_; diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index 4bcfc7fe011423df71a899d18815d3558e01b35f..63e99fbb04fd6ba34f2bbd2bc3fe7644a31ddf7f 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -97,6 +97,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, *iterable_out = true; visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int())); } + Type type = Type::Wildcard(); if (arg_def.type() != DataType::DT_INVALID) { // resolve type from DataType @@ -376,7 +377,7 @@ EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, package = name_tokens.at(0); name = name_tokens.at(1); } else { - package = "core"; // generate unclassified ops in the 'core' package + package = kDefaultEndpointPackage; name = name_tokens.at(0); } return EndpointSpec(package, diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 034cf636ed071a9dccac643d0f89988b070a1efc..3b53c730df23c6f81f968f09b9d145a8efa1030a 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -27,6 +27,8 @@ limitations under the License. namespace tensorflow { namespace java { +constexpr const char kDefaultEndpointPackage[] = "core"; + class EndpointSpec { public: // A specification for an operation endpoint diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index 56806cbb6dc5da94dc672828c2ff2dfe001b71c9..8e5fba7e32f096504f2aace6e9943b6f7281be31 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/java/src/gen/cc/source_writer.h" @@ -123,7 +124,7 @@ SourceWriter& SourceWriter::EndBlock() { } SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers, - const Javadoc* javadoc) { + const Javadoc* javadoc) { GenericNamespace* generic_namespace = PushGenericNamespace(modifiers); if (!method.constructor()) { generic_namespace->Visit(method.return_type()); @@ -165,7 +166,8 @@ SourceWriter& SourceWriter::EndMethod() { } SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers, - const std::list* extra_dependencies, const Javadoc* javadoc) { + const std::list* extra_dependencies, + const Javadoc* javadoc) { if (!type.package().empty()) { Append("package ").Append(type.package()).Append(";").EndLine(); } @@ -186,7 +188,7 @@ SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers, } SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers, - const Javadoc* javadoc) { + const Javadoc* javadoc) { GenericNamespace* generic_namespace = PushGenericNamespace(modifiers); generic_namespace->Visit(type); EndLine(); @@ -226,7 +228,7 @@ SourceWriter& SourceWriter::EndType() { } SourceWriter& SourceWriter::WriteField(const Variable& field, int modifiers, - const Javadoc* javadoc) { + const Javadoc* javadoc) { // If present, write field javadoc only as one brief line if (javadoc != nullptr && !javadoc->brief().empty()) { Append("/** ").Append(javadoc->brief()).Append(" */").EndLine(); @@ -345,8 +347,8 @@ void SourceWriter::TypeVisitor::Visit(const Type& type) { void SourceWriter::GenericNamespace::DoVisit(const Type& type) { // ignore non-generic parameters, wildcards and generics already declared - if (type.kind() == Type::GENERIC && !type.wildcard() - && generic_names_.find(type.name()) == generic_names_.end()) { + if (type.kind() == Type::GENERIC && !type.wildcard() && + generic_names_.find(type.name()) == generic_names_.end()) { declared_types_.push_back(&type); generic_names_.insert(type.name()); } diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index 1f0febe9a3135a0b9cc07a406db97b3ccc0182ee..de0113bd5b7092bfae0e2dd15fa1c7a26c491c9e 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -93,7 +93,7 @@ class SourceWriter { // This method appends a new opening brace to the current data and indent the // next lines according to Google Java Style Guide. The block can optionally // be preceded by an expression (e.g. Append("if(true)").BeginBlock();) - SourceWriter& BeginBlock(const string& expr = ""); + SourceWriter& BeginBlock(const string& expression = ""); // Ends the current block of source code. // @@ -108,7 +108,7 @@ class SourceWriter { // in parameter to define the access scope of this method and, optionally, // a Javadoc. SourceWriter& BeginMethod(const Method& method, int modifiers, - const Javadoc* javadoc = nullptr); + const Javadoc* javadoc = nullptr); // Ends the current method. // @@ -125,9 +125,9 @@ class SourceWriter { // // If not null, all types found in the 'extra_dependencies' list will be // imported before declaring the new type. - SourceWriter& BeginType(const Type& clazz, int modifiers, - const std::list* extra_dependencies = nullptr, - const Javadoc* javadoc = nullptr); + SourceWriter& BeginType(const Type& type, int modifiers, + const std::list* extra_dependencies = nullptr, + const Javadoc* javadoc = nullptr); // Begins to write a new inner type. // @@ -136,7 +136,7 @@ class SourceWriter { // in parameter to define the accesses and the scope of this type and, // optionally, a Javadoc. SourceWriter& BeginInnerType(const Type& type, int modifiers, - const Javadoc* javadoc = nullptr); + const Javadoc* javadoc = nullptr); // Ends the current type. // @@ -150,7 +150,7 @@ class SourceWriter { // or BeginInnerType()). Modifiers are also be passed in parameter to define // the accesses and the scope of this field and, optionally, a Javadoc. SourceWriter& WriteField(const Variable& field, int modifiers, - const Javadoc* javadoc = nullptr); + const Javadoc* javadoc = nullptr); protected: virtual void DoAppend(const StringPiece& str) = 0; diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index b9a5fee9bea1660dbfdf1a7b6920133f0bb85577..fb8fc64dffa309e4df5ef0e9a0c631fab73c397b 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -245,12 +245,17 @@ TEST(StreamTest, Types) { SourceBufferWriter writer; Type generic = Type::Generic("T").add_supertype(Type::Class("Number")); - writer.AppendType(Type::Int()).Append(", ") - .AppendType(Type::Class("String")).Append(", ") - .AppendType(generic).Append(", ") - .AppendType(Type::ListOf(generic)).Append(", ") - .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ") - .AppendType(Type::ListOf(Type::Wildcard())); + writer.AppendType(Type::Int()) + .Append(", ") + .AppendType(Type::Class("String")) + .Append(", ") + .AppendType(generic) + .Append(", ") + .AppendType(Type::ListOf(generic)) + .Append(", ") + .AppendType(Type::ListOf(Type::IterableOf(generic))) + .Append(", ") + .AppendType(Type::ListOf(Type::Wildcard())); const char* expected = "int, String, T, List, List>, List"; @@ -314,7 +319,7 @@ TEST(WriteType, AnnotatedAndDocumentedClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); Javadoc clazz_doc = Javadoc::Create("Javadoc test") - .details("This is a\nmultiline description."); + .details("This is a\nmultiline description."); clazz.add_annotation(Annotation::Create("Bean")); clazz.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); @@ -380,10 +385,10 @@ TEST(WriteType, ParameterizedClassFields) { Javadoc field3_doc = Javadoc::Create("This variable is documented"); writer.BeginType(clazz, PUBLIC) - .WriteField(field1, STATIC | PUBLIC | FINAL) - .WriteField(field2, PRIVATE) - .WriteField(field3, PRIVATE, &field3_doc) - .EndType(); + .WriteField(field1, STATIC | PUBLIC | FINAL) + .WriteField(field2, PRIVATE) + .WriteField(field3, PRIVATE, &field3_doc) + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -402,9 +407,9 @@ TEST(WriteType, SimpleInnerClass) { Type inner_class = Type::Class("InnerTest"); writer.BeginType(clazz, PUBLIC) - .BeginInnerType(inner_class, PUBLIC) - .EndType() - .EndType(); + .BeginInnerType(inner_class, PUBLIC) + .EndType() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -425,9 +430,9 @@ TEST(WriteType, StaticParameterizedInnerClass) { inner_class.add_parameter(type_t); writer.BeginType(clazz, PUBLIC) - .BeginInnerType(inner_class, PUBLIC | STATIC) - .EndType() - .EndType(); + .BeginInnerType(inner_class, PUBLIC | STATIC) + .EndType() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -445,8 +450,9 @@ TEST(WriteMethod, SimpleMethod) { Method method = Method::Create("doNothing", Type::Void()); writer.BeginType(clazz, PUBLIC) - .BeginMethod(method, PUBLIC).EndMethod() - .EndType(); + .BeginMethod(method, PUBLIC) + .EndMethod() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -462,15 +468,17 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - Javadoc method_doc = Javadoc::Create("Javadoc test") - .details("This method has a\nmultiline description."); + Javadoc method_doc = + Javadoc::Create("Javadoc test") + .details("This method has a\nmultiline description."); method.add_annotation(Annotation::Create("Override")); method.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); writer.BeginType(clazz, PUBLIC) - .BeginMethod(method, PUBLIC, &method_doc).EndMethod() - .EndType(); + .BeginMethod(method, PUBLIC, &method_doc) + .EndMethod() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -497,20 +505,23 @@ TEST(WriteMethod, DocumentedMethodWithArguments) { Method method = Method::Create("boolToInt", Type::Int()); method.add_argument(Variable::Create("b", Type::Boolean())); method.add_argument(reverse); - Javadoc method_doc = Javadoc::Create("Converts a boolean to an int") - .details("This method will convert\na boolean to an int") - .add_param_tag(reverse.name(), "if true, value is reversed") - .add_tag("return", "int value for this boolean"); + Javadoc method_doc = + Javadoc::Create("Converts a boolean to an int") + .details("This method will convert\na boolean to an int") + .add_param_tag(reverse.name(), "if true, value is reversed") + .add_tag("return", "int value for this boolean"); writer.BeginType(clazz, PUBLIC) - .BeginMethod(method, PUBLIC, &method_doc) - .Append("if (b && !reverse)") - .BeginBlock() - .Append("return 1;").EndLine() - .EndBlock() - .Append("return 0;").EndLine() - .EndMethod() - .EndType(); + .BeginMethod(method, PUBLIC, &method_doc) + .Append("if (b && !reverse)") + .BeginBlock() + .Append("return 1;") + .EndLine() + .EndBlock() + .Append("return 0;") + .EndLine() + .EndMethod() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -543,10 +554,11 @@ TEST(WriteMethod, ParameterizedMethod) { Method method = Method::Create("doNothing", type_t); writer.BeginType(clazz, PUBLIC) - .BeginMethod(method, PUBLIC) - .Append("return null;").EndLine() - .EndMethod() - .EndType(); + .BeginMethod(method, PUBLIC) + .Append("return null;") + .EndLine() + .EndMethod() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -567,10 +579,11 @@ TEST(WriteMethod, StaticParameterizedMethod) { Method method = Method::Create("doNothing", type_t); writer.BeginType(clazz, PUBLIC) - .BeginMethod(method, PUBLIC | STATIC) - .Append("return null;").EndLine() - .EndMethod() - .EndType(); + .BeginMethod(method, PUBLIC | STATIC) + .Append("return null;") + .EndLine() + .EndMethod() + .EndType(); const char* expected = "package org.tensorflow;\n\n" diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 11fda4fc22aeec9c2d94b5e884c11ceb2a66d29e..796d6a62dcf8551d8d68d9ff62077e7f09db4401 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -15,19 +15,44 @@ limitations under the License. package org.tensorflow.processor; +import com.google.common.base.CaseFormat; +import com.google.common.base.Strings; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeVariableName; import java.io.IOException; -import java.io.PrintWriter; +import java.util.Collection; import java.util.Collections; -import java.util.HashSet; +import java.util.HashMap; +import java.util.Map; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.Filer; import javax.annotation.processing.Messager; import javax.annotation.processing.ProcessingEnvironment; import javax.annotation.processing.RoundEnvironment; import javax.lang.model.SourceVersion; +import javax.lang.model.element.AnnotationMirror; +import javax.lang.model.element.AnnotationValue; import javax.lang.model.element.Element; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.Modifier; import javax.lang.model.element.TypeElement; +import javax.lang.model.element.TypeParameterElement; +import javax.lang.model.element.VariableElement; +import javax.lang.model.type.TypeMirror; +import javax.lang.model.type.TypeVariable; +import javax.lang.model.util.ElementFilter; +import javax.lang.model.util.Elements; import javax.tools.Diagnostic.Kind; /** @@ -55,6 +80,7 @@ public final class OperatorProcessor extends AbstractProcessor { super.init(processingEnv); messager = processingEnv.getMessager(); filer = processingEnv.getFiler(); + elements = processingEnv.getElementUtils(); } @Override @@ -98,42 +124,77 @@ public final class OperatorProcessor extends AbstractProcessor { } // Collect all classes tagged with our annotation. - Set opClasses = new HashSet(); - if (!collectOpClasses(roundEnv, opClasses, annotation)) { + Multimap groupedMethods = HashMultimap.create(); + if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) { return true; } // Nothing to do when there are no tagged classes. - if (opClasses.isEmpty()) { + if (groupedMethods.isEmpty()) { return true; } - // TODO:(kbsriram) validate operator classes and generate Op API. - writeApi(); + // Validate operator classes and generate Op API. + writeApi(groupedMethods); + hasRun = true; return true; } @Override public Set getSupportedAnnotationTypes() { - return Collections.singleton(String.format("%s.annotation.Operator", OP_PACKAGE)); + return Collections.singleton("org.tensorflow.op.annotation.Operator"); + } + + private static final Pattern JAVADOC_TAG_PATTERN = + Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); + private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); + private static final TypeName T_OPERATOR = + ClassName.get("org.tensorflow.op.annotation", "Operator"); + private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); + private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph"); + private static final TypeName T_STRING = ClassName.get(String.class); + + private Filer filer; + private Messager messager; + private Elements elements; + private boolean hasRun = false; + + private void error(Element e, String message, Object... args) { + if (args != null && args.length > 0) { + message = String.format(message, args); + } + messager.printMessage(Kind.ERROR, message, e); } - private void writeApi() { - // Generate an empty class for now and get the build working correctly. This will be changed to - // generate the actual API once we've done with build-related changes. - // TODO:(kbsriram) - try (PrintWriter writer = - new PrintWriter(filer.createSourceFile(String.format("%s.Ops", OP_PACKAGE)).openWriter())) { - writer.println(String.format("package %s;", OP_PACKAGE)); - writer.println("public class Ops{}"); + private void write(TypeSpec spec) { + try { + JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer); } catch (IOException e) { - error(null, "Unexpected failure generating API: %s", e.getMessage()); + throw new AssertionError(e); + } + } + + private void writeApi(Multimap groupedMethods) { + Map groups = new HashMap<>(); + + // Generate a API class for each group collected other than the default one (= empty string) + for (Map.Entry> entry : groupedMethods.asMap().entrySet()) { + if (!entry.getKey().isEmpty()) { + TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue()); + write(groupClass); + groups.put(entry.getKey(), ClassName.get("org.tensorflow.op", groupClass.name)); + } } + // Generate the top API class, adding any methods added to the default group + TypeSpec topClass = buildTopClass(groups, groupedMethods.get("")); + write(topClass); } - private boolean collectOpClasses( - RoundEnvironment roundEnv, Set opClasses, TypeElement annotation) { + private boolean collectOpsMethods( + RoundEnvironment roundEnv, + Multimap groupedMethods, + TypeElement annotation) { boolean result = true; for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) { // @Operator can only apply to types, so e must be a TypeElement. @@ -145,20 +206,251 @@ public final class OperatorProcessor extends AbstractProcessor { result = false; continue; } - opClasses.add((TypeElement) e); + TypeElement opClass = (TypeElement) e; + // Skip deprecated operations for now, as we do not guarantee API stability yet + if (opClass.getAnnotation(Deprecated.class) == null) { + collectOpMethods(groupedMethods, opClass, annotation); + } } return result; } - private void error(Element e, String message, Object... args) { - if (args != null && args.length > 0) { - message = String.format(message, args); + private void collectOpMethods( + Multimap groupedMethods, TypeElement opClass, TypeElement annotation) { + AnnotationMirror am = getAnnotationMirror(opClass, annotation); + String groupName = getAnnotationElementValueAsString("group", am); + String methodName = getAnnotationElementValueAsString("name", am); + ClassName opClassName = ClassName.get(opClass); + if (Strings.isNullOrEmpty(methodName)) { + methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); + } + // Build a method for each @Operator found in the class path. There should be one method per + // operation factory called + // "create", which takes in parameter a scope and, optionally, a list of arguments + for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) { + if (opMethod.getModifiers().contains(Modifier.STATIC) + && opMethod.getSimpleName().contentEquals("create")) { + MethodSpec method = buildOpMethod(methodName, opClassName, opMethod); + groupedMethods.put(groupName, method); + } } - messager.printMessage(Kind.ERROR, message, e); } - private Filer filer; - private Messager messager; - private boolean hasRun = false; - private static final String OP_PACKAGE = "org.tensorflow.op"; + private MethodSpec buildOpMethod( + String methodName, ClassName opClassName, ExecutableElement factoryMethod) { + MethodSpec.Builder builder = + MethodSpec.methodBuilder(methodName) + .addModifiers(Modifier.PUBLIC) + .returns(TypeName.get(factoryMethod.getReturnType())) + .varargs(factoryMethod.isVarArgs()) + .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); + + for (TypeParameterElement tp : factoryMethod.getTypeParameters()) { + TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType()); + builder.addTypeVariable(tvn); + } + for (TypeMirror thrownType : factoryMethod.getThrownTypes()) { + builder.addException(TypeName.get(thrownType)); + } + StringBuilder call = new StringBuilder("return $T.create(scope"); + boolean first = true; + for (VariableElement param : factoryMethod.getParameters()) { + ParameterSpec p = ParameterSpec.get(param); + if (first) { + first = false; + continue; + } + call.append(", "); + call.append(p.name); + builder.addParameter(p); + } + call.append(")"); + builder.addStatement(call.toString(), opClassName); + return builder.build(); + } + + private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) { + StringBuilder javadoc = new StringBuilder(); + javadoc + .append("Adds an {@link ") + .append(opClassName.simpleName()) + .append("} operation to the graph\n\n"); + + // Add all javadoc tags found in the operator factory method but the first one, which should be + // in all cases the + // 'scope' parameter that is implicitly passed by this API + Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod)); + boolean firstParam = true; + + while (tagMatcher.find()) { + String tag = tagMatcher.group(); + if (tag.startsWith("@param") && firstParam) { + firstParam = false; + } else { + javadoc.append(tag).append('\n'); + } + } + javadoc.append("@see {@link ").append(opClassName).append("}\n"); + + return javadoc.toString(); + } + + private static TypeSpec buildGroupClass(String group, Collection methods) { + MethodSpec.Builder ctorBuilder = + MethodSpec.constructorBuilder() + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope"); + + TypeSpec.Builder builder = + TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + + "@see {@link $T}\n", + group, + T_GRAPH, + T_OPS) + .addMethods(methods) + .addMethod(ctorBuilder.build()); + + builder.addField( + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); + + return builder.build(); + } + + private static TypeSpec buildTopClass( + Map groupToClass, Collection methods) { + MethodSpec.Builder ctorBuilder = + MethodSpec.constructorBuilder() + .addModifiers(Modifier.PRIVATE) + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope", T_SCOPE); + + for (Map.Entry entry : groupToClass.entrySet()) { + ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue()); + } + + TypeSpec.Builder opsBuilder = + TypeSpec.classBuilder("Ops") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for building a {@link $T} with operation wrappers\n

\n" + + "Any operation wrapper found in the classpath properly annotated as an" + + "{@link $T @Operator} is exposed\n" + + "by this API or one of its subgroup.\n

Example usage:\n

{@code\n"
+                    + "try (Graph g = new Graph()) {\n"
+                    + "  Ops ops = new Ops(g);\n"
+                    + "  // Operations are typed classes with convenience\n"
+                    + "  // builders in Ops.\n"
+                    + "  Constant three = ops.constant(3);\n"
+                    + "  // Single-result operations implement the Operand\n"
+                    + "  // interface, so this works too.\n"
+                    + "  Operand four = ops.constant(4);\n"
+                    + "  // Most builders are found within a group, and accept\n"
+                    + "  // Operand types as operands\n"
+                    + "  Operand nine = ops.math().add(four, ops.constant(5));\n"
+                    + "  // Multi-result operations however offer methods to\n"
+                    + "  // select a particular result for use.\n"
+                    + "  Operand result = \n"
+                    + "      ops.math().add(ops.array().unique(s, a).y(), b);\n"
+                    + "  // Optional attributes\n"
+                    + "  ops.math().matMul(a, b, MatMul.transposeA(true));\n"
+                    + "  // Naming operators\n"
+                    + "  ops.withName(“foo”).constant(5); // name “foo”\n"
+                    + "  // Names can exist in a hierarchy\n"
+                    + "  Ops sub = ops.withSubScope(“sub”);\n"
+                    + "  sub.withName(“bar”).constant(4); // “sub/bar”\n"
+                    + "}\n"
+                    + "}
\n", + T_GRAPH, + T_OPERATOR) + .addMethods(methods) + .addMethod(ctorBuilder.build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withSubScope") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "childScopeName") + .returns(T_OPS) + .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) + .addJavadoc( + "Returns an API that adds operations to the graph with the provided name prefix.\n" + + "\n@see {@link $T#withSubScope(String)}\n", + T_SCOPE) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withName") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "opName") + .returns(T_OPS) + .addStatement("return new Ops(scope.withName(opName))") + .addJavadoc( + "Returns an API that uses the provided name for an op.\n\n" + + "@see {@link $T#withName(String)}\n", + T_SCOPE) + .build()); + + opsBuilder.addField( + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("scope") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(T_SCOPE) + .addStatement("return scope") + .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) + .build()); + + for (Map.Entry entry : groupToClass.entrySet()) { + opsBuilder.addField( + FieldSpec.builder(entry.getValue(), entry.getKey()) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder(entry.getKey()) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(entry.getValue()) + .addStatement("return $L", entry.getKey()) + .addJavadoc( + "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) + .build()); + } + + opsBuilder.addMethod( + MethodSpec.methodBuilder("create") + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(T_GRAPH, "graph") + .returns(T_OPS) + .addStatement("return new Ops(new $T(graph))", T_SCOPE) + .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") + .build()); + + return opsBuilder.build(); + } + + private static AnnotationMirror getAnnotationMirror(Element element, TypeElement annotation) { + for (AnnotationMirror am : element.getAnnotationMirrors()) { + if (am.getAnnotationType().asElement().equals(annotation)) { + return am; + } + } + throw new IllegalArgumentException( + "Annotation " + + annotation.getSimpleName() + + " not present on element " + + element.getSimpleName()); + } + + private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) { + for (Map.Entry entry : + am.getElementValues().entrySet()) { + if (entry.getKey().getSimpleName().contentEquals(elementName)) { + return entry.getValue().getValue().toString(); + } + } + return ""; + } } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index d4fd3db5f7325ae891832ff7b658f5d3ea0789a6..7d19696749bbbb944e591daf596562f13f6dc103 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -143,6 +143,82 @@ public final class Graph implements AutoCloseable { } } + /** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + *

+ * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function + * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}. + *

+ * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all + * shapes in {@code y}. + * + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return the partial derivatives {@code dy} with the size of {@code x} + */ + public Output[] addGradients(Output[] y, Output[] x, Output[] dx) { + Output[] dy = new Output[x.length]; + final long[] yHandles = new long[y.length]; + final int[] yIndices = new int[y.length]; + final long[] xHandles = new long[x.length]; + final int[] xIndices = new int[x.length]; + long[] dxHandles = null; + int[] dxIndices = null; + + try (Reference ref = ref()) { + for (int i = 0; i < y.length; ++i) { + yHandles[i] = y[i].op().getUnsafeNativeHandle(); + yIndices[i] = y[i].index(); + } + for (int i = 0; i < x.length; ++i) { + xHandles[i] = x[i].op().getUnsafeNativeHandle(); + xIndices[i] = x[i].index(); + } + if (dx != null && dx.length > 0) { + dxHandles = new long[dx.length]; + dxIndices = new int[dx.length]; + + for (int i = 0; i < dx.length; ++i) { + dxHandles[i] = dx[i].op().getUnsafeNativeHandle(); + dxIndices[i] = dx[i].index(); + } + } + // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles + // of the gradient operations while the second holds the index of their output + // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain + // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...] + long[] dyHandlesAndIndices = + addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices); + int ndy = dyHandlesAndIndices.length >> 1; + if (ndy != dy.length) { + throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length + + " were expected"); + } + for (int i = 0, j = ndy; i < ndy; ++i, ++j) { + Operation op = new Operation(this, dyHandlesAndIndices[i]); + dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]); + } + } + return dy; + } + + /** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code dy/dx_1, dy/dx_2...} + *

+ * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is + * a single output and {@code dx} is null. + * + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @return the partial derivatives {@code dy} with the size of {@code x} + */ + public Output[] addGradients(Output y, Output[] x) { + return addGradients(new Output[]{y}, x, null); + } + private final Object nativeHandleLock = new Object(); private long nativeHandle; private int refcount = 0; @@ -254,6 +330,9 @@ public final class Graph implements AutoCloseable { private static native byte[] toGraphDef(long handle); + private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices, + long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices); + static { TensorFlow.init(); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java new file mode 100644 index 0000000000000000000000000000000000000000..f4671c8af941dd732859080238fa48e0a22672b6 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Operands; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + *

+ * If {@code Options.dx()} values are set, they are as the initial symbolic partial derivatives of some loss + * function {@code L} w.r.t. {@code y}. {@code Options.dx()} must have the size of {@code y}. + *

+ * If {@code Options.dx()} is not set, the implementation will use dx of {@code OnesLike} for all + * shapes in {@code y}. + *

+ * The partial derivatives are returned in output {@code dy}, with the size of {@code x}. + *

+ * Example of usage: + *

{@code
+ * Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
+ * 
+ * Constant alpha = ops.constant(1.0f, Float.class);
+ * ApplyGradientDescent.create(scope, w, alpha, gradients.dy(0));
+ * ApplyGradientDescent.create(scope, b, alpha, gradients.dy(1));
+ * }
+ */ +@Operator +public class Gradients implements Op, Iterable> { + + /** + * Optional attributes for {@link Gradients} + */ + public static class Options { + + /** + * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return this option builder + */ + public Options dx(Iterable> dx) { + this.dx = dx; + return this; + } + + private Iterable> dx; + + private Options() { + } + } + + /** + * Adds gradients computation ops to the graph according to scope. + * + * @param scope current graph scope + * @param y outputs of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param options carries optional attributes values + * @return a new instance of {@code Gradients} + */ + public static Gradients create(Scope scope, Iterable> y, Iterable> x, Options... options) { + Output[] dx = null; + if (options != null) { + for (Options opts : options) { + if (opts.dx != null) { + dx = Operands.asOutputs(opts.dx); + } + } + } + Output[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx); + return new Gradients(Arrays.asList(gradOutputs)); + } + + /** + * Adds gradients computation ops to the graph according to scope. + * + * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is + * a single output. + * + * @param scope current graph scope + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param options carries optional attributes values + * @return a new instance of {@code Gradients} + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static Gradients create(Scope scope, Operand y, Iterable> x, Options... options) { + return create(scope, (Iterable) Arrays.asList(y), x, options); + } + + /** + * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return builder to add more options to this operation + */ + public Options dx(Iterable> dx) { + return new Options().dx(dx); + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator> iterator() { + return (Iterator) dy.iterator(); + } + + /** + * Partial derivatives of {@code y}s w.r.t. {@code x}s, with the size of {@code x} + */ + public List> dy() { + return dy; + } + + /** + * Returns a symbolic handle to one of the gradient operation output + *

+ * Warning: Does not check that the type of the tensor matches T. It is recommended to call + * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code + * gradients.dy(0)} + * + * @param The expected element type of the tensors produced by this output. + * @param index The index of the output among the gradients added by this operation + */ + @SuppressWarnings("unchecked") + public Output dy(int index) { + return (Output) dy.get(index); + } + + private List> dy; + + private Gradients(List> dy) { + this.dy = dy; + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/package-info.java b/tensorflow/java/src/main/java/org/tensorflow/package-info.java index 521c5c610c1f775cf9174664f5b786786ce1181d..f353ee31459806eb2db98d23ac030c15258a77fb 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/package-info.java +++ b/tensorflow/java/src/main/java/org/tensorflow/package-info.java @@ -17,7 +17,7 @@ limitations under the License. * Defines classes to build, save, load and execute TensorFlow models. * *

WARNING: The API is currently experimental and is not covered by TensorFlow API stability + * href="https://www.tensorflow.org/guide/version_semantics">API stability * guarantees. See README.md for installation * instructions. diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc index 0fef15527586555e7d3fc2c76403c6e5888fb236..dac6a345e917b618f7f1234c27959069650b51b7 100644 --- a/tensorflow/java/src/main/native/graph_jni.cc +++ b/tensorflow/java/src/main/native/graph_jni.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/java/src/main/native/graph_jni.h" #include +#include #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" namespace { @@ -130,3 +132,55 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) { TF_DeleteBuffer(buf); return ret; } + +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle, + jlongArray y_handles, jintArray y_indices, + jlongArray x_handles, jintArray x_indices, + jlongArray dx_handles, jintArray dx_indices) { + + TF_Graph* g = requireHandle(env, handle); + if (g == nullptr) return nullptr; + + const jint ny = env->GetArrayLength(y_handles); + const jint nx = env->GetArrayLength(x_handles); + + std::unique_ptr y(new TF_Output[ny]); + std::unique_ptr x(new TF_Output[nx]); + std::unique_ptr dx(nullptr); + std::unique_ptr dy(new TF_Output[nx]); + + resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny); + resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx); + if (dx_handles != nullptr) { + if (env->GetArrayLength(dx_handles) != ny) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d dx handles", ny, + env->GetArrayLength(dx_handles)); + } + dx.reset(new TF_Output[ny]); + resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny); + } + if (env->ExceptionCheck()) return nullptr; + + TF_Status* status = TF_NewStatus(); + TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get()); + + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return nullptr; + } + TF_DeleteStatus(status); + + // returned array contains both op handles and output indices, in pair + jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1); + jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr); + for (int i = 0, j = nx; i < nx; ++i, ++j) { + TF_Output dy_output = dy.get()[i]; + dy_elems[i] = reinterpret_cast(dy_output.oper); + dy_elems[j] = static_cast(dy_output.index); + } + env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0); + + return dy_handles_and_indices; +} diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h index dd2e038332f7d39e6460d6cfef40a9df7e348758..4f87e8d5a79d3ac46f7813ba4344bbfda069b557 100644 --- a/tensorflow/java/src/main/native/graph_jni.h +++ b/tensorflow/java/src/main/native/graph_jni.h @@ -73,6 +73,15 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *, jclass, jlong); +/* + * Class: org_tensorflow_Graph + * Method: name + * Signature: (J[J[I[J[I[J[I)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *, + jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray, + jintArray); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/java/src/main/native/session_jni.cc b/tensorflow/java/src/main/native/session_jni.cc index 2cd542d3c9be536a42037e9ef533ed629dd3ac9f..cb54daf13795c24e11566845892da6b5c4896cf5 100644 --- a/tensorflow/java/src/main/native/session_jni.cc +++ b/tensorflow/java/src/main/native/session_jni.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" #include "tensorflow/java/src/main/native/session_jni.h" @@ -55,37 +56,6 @@ void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array, env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT); } -void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, - jintArray src_index, TF_Output* dst, jint n) { - if (env->ExceptionCheck()) return; - jint len = env->GetArrayLength(src_op); - if (len != n) { - throwException(env, kIllegalArgumentException, - "expected %d, got %d %s Operations", n, len, type); - return; - } - len = env->GetArrayLength(src_index); - if (len != n) { - throwException(env, kIllegalArgumentException, - "expected %d, got %d %s Operation output indices", n, len, - type); - return; - } - jlong* op_handles = env->GetLongArrayElements(src_op, nullptr); - jint* indices = env->GetIntArrayElements(src_index, nullptr); - for (int i = 0; i < n; ++i) { - if (op_handles[i] == 0) { - throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type, - i, n); - break; - } - dst[i] = TF_Output{reinterpret_cast(op_handles[i]), - static_cast(indices[i])}; - } - env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT); - env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT); -} - void TF_MaybeDeleteBuffer(TF_Buffer* buf) { if (buf == nullptr) return; TF_DeleteBuffer(buf); diff --git a/tensorflow/java/src/main/native/utils_jni.cc b/tensorflow/java/src/main/native/utils_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..069ac05a1c39408dc02f5bbf9a7fc50fd095cc96 --- /dev/null +++ b/tensorflow/java/src/main/native/utils_jni.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/java/src/main/native/utils_jni.h" + +#include "tensorflow/java/src/main/native/exception_jni.h" + +void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, + jintArray src_index, TF_Output* dst, jint n) { + if (env->ExceptionCheck()) return; + jint len = env->GetArrayLength(src_op); + if (len != n) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d %s Operations", n, len, type); + return; + } + len = env->GetArrayLength(src_index); + if (len != n) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d %s Operation output indices", n, len, + type); + return; + } + jlong* op_handles = env->GetLongArrayElements(src_op, nullptr); + jint* indices = env->GetIntArrayElements(src_index, nullptr); + for (int i = 0; i < n; ++i) { + if (op_handles[i] == 0) { + throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type, + i, n); + break; + } + dst[i] = TF_Output{reinterpret_cast(op_handles[i]), + static_cast(indices[i])}; + } + env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT); + env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT); +} + + + + diff --git a/tensorflow/java/src/main/native/utils_jni.h b/tensorflow/java/src/main/native/utils_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..352298e7de1d07cebc1a287774c9bef85c9a6ae4 --- /dev/null +++ b/tensorflow/java/src/main/native/utils_jni.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_UTILS_JNI_H_ +#define TENSORFLOW_JAVA_UTILS_JNI_H_ + +#include + +#include "tensorflow/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, + jintArray src_index, TF_Output* dst, jint n); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif /* TENSORFLOW_JAVA_UTILS_JNI_H_ */ diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index c540299bdcfcd7bc5969caf82b29144bad24201f..c2e52c22c6dc58a3002b536e64c4607b675804f7 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import java.util.HashSet; import java.util.Iterator; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -129,4 +130,106 @@ public class GraphTest { // expected exception. } } + + @Test + public void addGradientsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x1 = TestUtil.placeholder(g, "x1", Float.class); + Output x2 = TestUtil.placeholder(g, "x2", Float.class); + Output y0 = TestUtil.square(g, "y0", x1); + Output y1 = TestUtil.square(g, "y1", y0); + Output y2 = TestUtil.addN(g, y0, x2); + + Output[] grads0 = g.addGradients(y1, toArray(x1)); + assertNotNull(grads0); + assertEquals(1, grads0.length); + assertEquals(DataType.FLOAT, grads0[0].dataType()); + + Output[] grads1 = g.addGradients(y2, toArray(x1, x2)); + assertNotNull(grads1); + assertEquals(2, grads1.length); + assertEquals(DataType.FLOAT, grads1[0].dataType()); + assertEquals(DataType.FLOAT, grads1[1].dataType()); + + try (Tensor c1 = Tensors.create(3.0f); + Tensor c2 = Tensors.create(2.0f); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList<>( + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run())) { + + assertEquals(3, outputs.size()); + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientSumsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x = TestUtil.placeholder(g, "x", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Output[] grad = g.addGradients(toArray(y0, y1), toArray(x), null); + assertNotNull(grad); + assertEquals(1, grad.length); + assertEquals(DataType.FLOAT, grad[0].dataType()); + + try (Tensor c = Tensors.create(3.0f); + Tensor output = s.runner() + .feed(x, c) + .fetch(grad[0]) + .run() + .get(0)) { + + assertEquals(114.0f, output.floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientsWithInitialValuesToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x = TestUtil.placeholder(g, "x", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Output[] grad0 = g.addGradients(y1, toArray(y0)); + assertNotNull(grad0); + assertEquals(1, grad0.length); + assertEquals(DataType.FLOAT, grad0[0].dataType()); + + Output[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0])); + assertNotNull(grad1); + assertEquals(1, grad1.length); + assertEquals(DataType.FLOAT, grad1[0].dataType()); + + try (Tensor c = Tensors.create(3.0f); + Tensor output = s.runner() + .feed(x, c) + .fetch(grad1[0]) + .run() + .get(0)) { + + assertEquals(108.0f, output.floatValue(), 0.0f); + } + } + } + + private static Output[] toArray(Output... outputs) { + return outputs; + } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java index e8cc76c2a6458193161a98e17483fe73de107b77..7d5980bcdedebedcd2fa4722e85abc1d598fb4fd 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java @@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import java.util.ArrayList; -import java.util.Collection; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,8 +34,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); try (Tensor x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().feed("X", x).fetch("Y").run())) { + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -53,8 +51,8 @@ public class SessionTest { Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); try (Tensor x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().feed(feed, x).fetch(fetch).run())) { + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -112,7 +110,7 @@ public class SessionTest { .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList> outputs = new AutoCloseableList>(result.outputs); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList>(result.outputs); assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -135,8 +133,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.constant(g, "c1", 2718); TestUtil.constant(g, "c2", 31415); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().fetch("c2").fetch("c1").run()); + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().fetch("c2").fetch("c1").run()); assertEquals(2, outputs.size()); assertEquals(31415, outputs.get(0).intValue()); assertEquals(2718, outputs.get(1).intValue()); @@ -164,28 +162,6 @@ public class SessionTest { Session s = new Session(g, singleThreadConfigProto())) {} } - private static final class AutoCloseableList extends ArrayList - implements AutoCloseable { - AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } - } - private static byte[] fullTraceRunOptions() { // Ideally this would use the generated Java sources for protocol buffers // and end up with something like the snippet below. However, generating diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index c973b5a3d8b2be8ee21710d65732bc1e5c3b520a..4e848864167982c750b390a77a1ab7f5d0d40fe9 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -16,9 +16,34 @@ limitations under the License. package org.tensorflow; import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Collection; /** Static utility functions. */ public class TestUtil { + + public static final class AutoCloseableList extends ArrayList + implements AutoCloseable { + AutoCloseableList(Collection c) { + super(c); + } + + @Override + public void close() { + Exception toThrow = null; + for (AutoCloseable c : this) { + try { + c.close(); + } catch (Exception e) { + toThrow = e; + } + } + if (toThrow != null) { + throw new RuntimeException(toThrow); + } + } + } + public static Output constant(Graph g, String name, Object value) { try (Tensor t = Tensor.create(value)) { return g.opBuilder("Const", name) @@ -36,7 +61,7 @@ public class TestUtil { .output(0); } - public static Output addN(Graph g, Output... inputs) { + public static Output addN(Graph g, Output... inputs) { return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0); } @@ -58,6 +83,13 @@ public class TestUtil { .setAttr("num_split", numSplit) .build(); } + + public static Output square(Graph g, String name, Output value) { + return g.opBuilder("Square", name) + .addInput(value) + .build() + .output(0); + } public static void transpose_A_times_X(Graph g, int[][] a) { Output aa = constant(g, "A", a); diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f714d1fb21c753a03ad760522f1fa43a44048550..ebfcfff4a5263ec8af31b461d274a8a6f9b6ec34 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4,14 +4,16 @@ # Public targets: # ":platform" - Low-level and platform-specific Python code. -package(default_visibility = [ +visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", "//tensorflow:internal", "//tensorflow/contrib/lite/toco/python:__pkg__", "//tensorflow_models:__subpackages__", # TODO(aselle): to pass open source test. "//bazel_pip/tensorflow/contrib/lite/toco/python:__pkg__", -]) +] + +package(default_visibility = visibility) licenses(["notice"]) # Apache 2.0 @@ -55,12 +57,12 @@ py_library( "//tensorflow/contrib/lite/toco/python:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/python/debug:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/python/tools:__pkg__", # TODO(b/34059704): remove when fixed - "//tensorflow/tools/api/generator:__pkg__", "//tensorflow/tools/quantization:__pkg__", # TODO(b/34059704): remove when fixed ], deps = [ ":no_contrib", "//tensorflow/contrib:contrib_py", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -71,6 +73,7 @@ py_library( visibility = [ "//tensorflow:__pkg__", "//tensorflow/python/tools:__pkg__", + "//tensorflow/tools/api/generator:__pkg__", ], deps = [ ":array_ops", @@ -79,6 +82,7 @@ py_library( ":check_ops", ":client", ":client_testlib", + ":collective_ops", ":confusion_matrix", ":control_flow_ops", ":cudnn_rnn_ops_gen", @@ -123,13 +127,14 @@ py_library( ":util", ":weights_broadcast_ops", "//tensorflow/core:protos_all_py", + "//tensorflow/python/compat", "//tensorflow/python/data", - "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/linalg", "//tensorflow/python/ops/losses", + "//tensorflow/python/ops/parallel_for", "//tensorflow/python/profiler", "//tensorflow/python/saved_model", "//third_party/py/numpy", @@ -255,7 +260,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/py/numpy:headers", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -268,7 +273,7 @@ cc_library( ":safe_ptr", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -276,6 +281,9 @@ cc_library( name = "ndarray_tensor_bridge", srcs = ["lib/core/ndarray_tensor_bridge.cc"], hdrs = ["lib/core/ndarray_tensor_bridge.h"], + visibility = visibility + [ + "//learning/deepmind/courier:__subpackages__", + ], deps = [ ":bfloat16_lib", ":numpy_lib", @@ -292,7 +300,7 @@ cc_library( deps = [ "//tensorflow/c:c_api", "//tensorflow/core:lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -313,9 +321,9 @@ cc_library( hdrs = ["util/util.h"], deps = [ ":safe_ptr", - "//tensorflow/core:framework", "//tensorflow/core:lib", - "//util/python:python_headers", + "//tensorflow/core:lib_internal", + "//third_party/python_runtime:headers", ], ) @@ -337,7 +345,7 @@ cc_library( "//tensorflow/core:script_ops_op_lib", "//tensorflow/python/eager:pywrap_tfe_lib", "//third_party/py/numpy:headers", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -348,7 +356,7 @@ cc_library( deps = [ "//tensorflow/c:c_api", "//tensorflow/c/eager:c_api", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -356,6 +364,9 @@ cc_library( name = "ndarray_tensor", srcs = ["lib/core/ndarray_tensor.cc"], hdrs = ["lib/core/ndarray_tensor.h"], + visibility = visibility + [ + "//learning/deepmind/courier:__subpackages__", + ], deps = [ ":bfloat16_lib", ":ndarray_tensor_bridge", @@ -378,7 +389,7 @@ cc_library( ":safe_ptr", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -389,7 +400,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:script_ops_op_lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -688,12 +699,22 @@ py_library( ], ) +py_library( + name = "error_interpolation", + srcs = [ + "framework/error_interpolation.py", + ], + srcs_version = "PY2AND3", + deps = [], +) + py_library( name = "function", srcs = ["framework/function.py"], srcs_version = "PY2AND3", deps = [ ":array_ops", + ":cond_v2_impl", ":dtypes", ":framework_ops", ":graph_to_function_def", @@ -710,11 +731,44 @@ py_library( srcs = ["framework/graph_to_function_def.py"], srcs_version = "PY2AND3", deps = [ + ":cond_v2_impl", + ":op_def_registry", + "//tensorflow/core:protos_all_py", + ], +) + +py_library( + name = "function_def_to_graph", + srcs = ["framework/function_def_to_graph.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework", + ":function", ":op_def_registry", + ":tensor_shape", + ":versions", "//tensorflow/core:protos_all_py", ], ) +py_test( + name = "function_def_to_graph_test", + size = "small", + srcs = ["framework/function_def_to_graph_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":array_ops", + ":client_testlib", + ":dtypes", + ":framework_ops", + ":function_def_to_graph", + ":graph_to_function_def", + ":math_ops", + ":test_ops", + ], +) + py_library( name = "graph_util", srcs = [ @@ -956,6 +1010,18 @@ py_test( ], ) +py_test( + name = "framework_error_interpolation_test", + size = "small", + srcs = ["framework/error_interpolation_test.py"], + main = "framework/error_interpolation_test.py", + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":error_interpolation", + ], +) + py_test( name = "framework_subscribe_test", size = "small", @@ -1016,7 +1082,9 @@ py_test( tf_gen_op_wrapper_private_py( name = "functional_ops_gen", - visibility = ["//learning/brain/python/ops:__pkg__"], + visibility = [ + "//learning/brain/python/ops:__pkg__", + ], ) py_library( @@ -1435,6 +1503,14 @@ tf_gen_op_wrapper_private_py( ], ) +tf_gen_op_wrapper_private_py( + name = "collective_ops_gen", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:collective_ops_op_lib", + ], +) + tf_gen_op_wrapper_private_py( name = "control_flow_ops_gen", visibility = [ @@ -1555,6 +1631,9 @@ tf_gen_op_wrapper_private_py( tf_gen_op_wrapper_private_py( name = "resource_variable_ops_gen", + visibility = [ + "//tensorflow/compiler/tf2xla:internal", + ], ) tf_gen_op_wrapper_private_py( @@ -1736,9 +1815,33 @@ py_test( ], ) +py_library( + name = "collective_ops", + srcs = ["ops/collective_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":collective_ops_gen", + ":framework_for_generated_wrappers", + ], +) + +py_test( + name = "collective_ops_test", + size = "small", + srcs = ["ops/collective_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":collective_ops", + ":framework_for_generated_wrappers", + "//third_party/py/numpy", + ], +) + py_library( name = "control_flow_grad", - srcs = ["ops/control_flow_grad.py"], + srcs = + ["ops/control_flow_grad.py"], srcs_version = "PY2AND3", deps = [ ":control_flow_ops", @@ -1758,6 +1861,7 @@ py_library( "tensor_shape", ":array_ops", ":array_ops_gen", + ":cond_v2_impl", ":constant_op", ":control_flow_ops_gen", ":control_flow_util", @@ -1786,6 +1890,37 @@ py_library( ], ) +py_library( + name = "cond_v2", + srcs = [ + "ops/cond_v2.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cond_v2_impl", + ":function", + ":function_def_to_graph", + ":gradients", + ], +) + +py_library( + name = "cond_v2_impl", + srcs = [ + "ops/cond_v2_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":c_api_util", + ":framework_ops", + ":functional_ops_gen", + ":pywrap_tensorflow", + ":util", + "//tensorflow/core:protos_all_py", + ], +) + py_library( name = "ctc_ops", srcs = ["ops/ctc_ops.py"], @@ -1852,6 +1987,8 @@ py_library( ":math_ops", ":platform", ":resource_variable_ops", + ":sparse_ops", + ":tensor_shape", ":variables", ], ) @@ -1868,6 +2005,7 @@ py_library( ":array_grad", ":array_ops", ":bitwise_ops", + ":cond_v2_impl", ":control_flow_grad", ":control_flow_ops", ":control_flow_util", @@ -1884,6 +2022,7 @@ py_library( ":math_grad", ":math_ops", ":platform", + ":random_grad", ":resource_variable_ops", ":spectral_grad", ":util", @@ -2262,6 +2401,19 @@ py_library( ], ) +py_library( + name = "random_grad", + srcs = ["ops/random_grad.py"], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":dtypes", + ":framework_ops", + ":math_ops", + ":random_ops_gen", + ], +) + py_library( name = "random_ops", srcs = ["ops/random_ops.py"], @@ -2322,6 +2474,7 @@ py_library( srcs = ["ops/script_ops.py"], srcs_version = "PY2AND3", deps = [ + ":array_ops", ":framework_for_generated_wrappers", ":script_ops_gen", "//third_party/py/numpy", @@ -2461,6 +2614,7 @@ py_library( ":check_ops", ":confusion_matrix", ":control_flow_ops", + ":distribute", ":framework", ":framework_for_generated_wrappers", ":math_ops", @@ -2666,7 +2820,6 @@ py_library( ":util", ":variables", "//tensorflow/python/eager:context", - "//tensorflow/python/estimator:util", "@six_archive//:six", ], ) @@ -3269,6 +3422,19 @@ py_library( ], ) +py_test( + name = "lock_util_test", + size = "small", + srcs = ["util/lock_util_test.py"], + main = "util/lock_util_test.py", + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":util", + "@absl_py//absl/testing:parameterized", + ], +) + tf_proto_library( name = "protos_all", srcs = glob( @@ -3404,7 +3570,7 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//third_party/py/numpy:headers", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -3456,6 +3622,7 @@ tf_py_wrap_cc( "util/transform_graph.i", "util/util.i", ], + # add win_def_file win_def_file = select({ "//tensorflow:windows": ":pywrap_tensorflow_filtered_def_file", "//conditions:default": None, @@ -3475,6 +3642,7 @@ tf_py_wrap_cc( ":py_record_writer_lib", ":python_op_gen", ":tf_session_helper", + "//third_party/python_runtime:headers", "//tensorflow/c:c_api", "//tensorflow/c:checkpoint_reader", "//tensorflow/c:python_api", @@ -3497,7 +3665,6 @@ tf_py_wrap_cc( "//tensorflow/core/profiler/internal:print_model_analysis", "//tensorflow/tools/graph_transforms:transform_graph_lib", "//tensorflow/python/eager:pywrap_tfe_lib", - "//util/python:python_headers", ] + (tf_additional_lib_deps() + tf_additional_plugin_deps() + tf_additional_verbs_deps() + @@ -3586,6 +3753,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":c_api_util", + ":error_interpolation", ":errors", ":framework", ":framework_for_generated_wrappers", @@ -3786,7 +3954,7 @@ tf_cuda_library( tf_py_test( name = "session_test", - size = "small", + size = "medium", srcs = ["client/session_test.py"], additional_deps = [ ":array_ops", @@ -3968,6 +4136,19 @@ py_test( ], ) +py_test( + name = "tf_record_test", + size = "small", + srcs = ["lib/io/tf_record_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":errors", + ":lib", + ":util", + ], +) + cuda_py_test( name = "adam_test", size = "small", @@ -4275,7 +4456,7 @@ py_test( py_test( name = "warm_starting_util_test", - size = "small", + size = "medium", srcs = ["training/warm_starting_util_test.py"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index cf707fb2c731c0db57c2335d3ffd49b292c811cc..a2ab63bb48799d5b93882bb87ab40b02dbb96621 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -79,7 +79,6 @@ from tensorflow.python.ops import initializers_ns as initializers # Bring in subpackages. from tensorflow.python import data from tensorflow.python import keras -from tensorflow.python.estimator import estimator_lib as estimator from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.layers import layers from tensorflow.python.ops import bitwise_ops as bitwise diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 5507d011bb0746c84b868ca7efcc3e4f8d2e146a..e037925961f2bfc8b8906fa81c2d7908ea590a62 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -361,7 +361,7 @@ class _ListFetchMapper(_FetchMapper): for m, vi in zip(self._mappers, self._value_indices): results.append(m.build_results([values[j] for j in vi])) # Return a value of the original type of the fetches. - if self._fetch_type == list: + if issubclass(self._fetch_type, list): return results elif self._fetch_type == tuple: return tuple(results) @@ -619,21 +619,12 @@ class BaseSession(SessionInterface): self._config = None self._add_shapes = False - # pylint: disable=protected-access - # We cache _USE_C_API's value because some test cases will create a session - # with _USE_C_API = False but set it back to True before calling close(). - self._created_with_new_api = ops._USE_C_API - # pylint: enable=protected-access - self._session = None opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) try: - if self._created_with_new_api: - # pylint: disable=protected-access - self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) - # pylint: enable=protected-access - else: - self._session = tf_session.TF_NewDeprecatedSession(opts) + # pylint: disable=protected-access + self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) + # pylint: enable=protected-access finally: tf_session.TF_DeleteSessionOptions(opts) @@ -660,11 +651,7 @@ class BaseSession(SessionInterface): Returns: A list of devices in the session. """ - if self._created_with_new_api: - raw_device_list = tf_session.TF_SessionListDevices(self._session) - else: - raw_device_list = tf_session.TF_DeprecatedSessionListDevices( - self._session) + raw_device_list = tf_session.TF_SessionListDevices(self._session) device_list = [] size = tf_session.TF_DeviceListCount(raw_device_list) for i in range(size): @@ -684,16 +671,9 @@ class BaseSession(SessionInterface): tf.errors.OpError: Or one of its subclasses if an error occurs while closing the TensorFlow session. """ - if self._created_with_new_api: - if self._session and not self._closed: - self._closed = True - tf_session.TF_CloseSession(self._session) - - else: - with self._extend_lock: - if self._opened and not self._closed: - self._closed = True - tf_session.TF_CloseDeprecatedSession(self._session) + if self._session and not self._closed: + self._closed = True + tf_session.TF_CloseSession(self._session) def __del__(self): # cleanly ignore all exceptions @@ -703,10 +683,7 @@ class BaseSession(SessionInterface): pass if self._session is not None: try: - if self._created_with_new_api: - tf_session.TF_DeleteSession(self._session) - else: - tf_session.TF_DeleteDeprecatedSession(self._session) + tf_session.TF_DeleteSession(self._session) except AttributeError: # At shutdown, `c_api_util` or `tf_session` may have been garbage # collected, causing the above method calls to fail. In this case, @@ -1005,12 +982,9 @@ class BaseSession(SessionInterface): try: subfeed_t = self.graph.as_graph_element( subfeed, allow_tensor=True, allow_operation=False) - if self._created_with_new_api: - # pylint: disable=protected-access - feed_list.append(subfeed_t._as_tf_output()) - # pylint: enable=protected-access - else: - feed_list.append(compat.as_bytes(subfeed_t.name)) + # pylint: disable=protected-access + feed_list.append(subfeed_t._as_tf_output()) + # pylint: enable=protected-access except Exception as e: e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message) e.args = (e.message,) @@ -1023,22 +997,13 @@ class BaseSession(SessionInterface): # Set up a graph with feeds and fetches for partial run. def _setup_fn(session, feed_list, fetch_list, target_list): self._extend_graph() - if self._created_with_new_api: - return tf_session.TF_SessionPRunSetup_wrapper( - session, feed_list, fetch_list, target_list) - else: - with errors.raise_exception_on_not_ok_status() as status: - return tf_session.TF_PRunSetup(session, feed_list, fetch_list, - target_list, status) + return tf_session.TF_SessionPRunSetup_wrapper( + session, feed_list, fetch_list, target_list) - if self._created_with_new_api: - # pylint: disable=protected-access - final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()] - final_targets = [op._c_op for op in fetch_handler.targets()] - # pylint: enable=protected-access - else: - final_fetches = _name_list(fetch_handler.fetches()) - final_targets = _name_list(fetch_handler.targets()) + # pylint: disable=protected-access + final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()] + final_targets = [op._c_op for op in fetch_handler.targets()] + # pylint: enable=protected-access return self._do_call(_setup_fn, self._session, feed_list, final_fetches, final_targets) @@ -1196,14 +1161,10 @@ class BaseSession(SessionInterface): # Create a fetch handler to take care of the structure of fetches. fetch_handler = _FetchHandler(self._graph, fetches, {}) - if self._created_with_new_api: - # pylint: disable=protected-access - fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()] - target_list = [op._c_op for op in fetch_handler.targets()] - # pylint: enable=protected-access - else: - fetch_list = _name_list(fetch_handler.fetches()) - target_list = _name_list(fetch_handler.targets()) + # pylint: disable=protected-access + fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()] + target_list = [op._c_op for op in fetch_handler.targets()] + # pylint: enable=protected-access def _callable_template_with_options_and_metadata(fetch_list, target_list, @@ -1289,16 +1250,11 @@ class BaseSession(SessionInterface): Raises: tf.errors.OpError: Or one of its subclasses on error. """ - if self._created_with_new_api: - # pylint: disable=protected-access - feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items()) - fetches = [t._as_tf_output() for t in fetch_list] - targets = [op._c_op for op in target_list] - # pylint: enable=protected-access - else: - feeds = dict((compat.as_bytes(t.name), v) for t, v in feed_dict.items()) - fetches = _name_list(fetch_list) - targets = _name_list(target_list) + # pylint: disable=protected-access + feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items()) + fetches = [t._as_tf_output() for t in fetch_list] + targets = [op._c_op for op in target_list] + # pylint: enable=protected-access def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata): # Ensure any changes to the graph are reflected in the runtime. @@ -1335,22 +1291,8 @@ class BaseSession(SessionInterface): raise type(e)(node_def, op, message) def _extend_graph(self): - if self._created_with_new_api: - with self._graph._lock: # pylint: disable=protected-access - tf_session.ExtendSession(self._session) - else: - # Ensure any changes to the graph are reflected in the runtime. - with self._extend_lock: - if self._graph.version > self._current_version: - # pylint: disable=protected-access - graph_def, self._current_version = self._graph._as_graph_def( - from_version=self._current_version, add_shapes=self._add_shapes) - # pylint: enable=protected-access - - with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_ExtendGraph(self._session, - graph_def.SerializeToString(), status) - self._opened = True + with self._graph._session_run_lock(): # pylint: disable=protected-access + tf_session.ExtendSession(self._session) # The threshold to run garbage collection to delete dead tensors. _DEAD_HANDLES_THRESHOLD = 10 @@ -1403,24 +1345,13 @@ class BaseSession(SessionInterface): def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata): - if self._created_with_new_api: - return tf_session.TF_SessionRun_wrapper( - self._session, options, feed_dict, fetch_list, target_list, - run_metadata) - else: - with errors.raise_exception_on_not_ok_status() as status: - return tf_session.TF_Run( - self._session, options, feed_dict, fetch_list, target_list, - status, run_metadata) + return tf_session.TF_SessionRun_wrapper( + self._session, options, feed_dict, fetch_list, target_list, + run_metadata) def _call_tf_sessionprun(self, handle, feed_dict, fetch_list): - if self._created_with_new_api: - return tf_session.TF_SessionPRun_wrapper( - self._session, handle, feed_dict, fetch_list) - else: - with errors.raise_exception_on_not_ok_status() as status: - return tf_session.TF_PRun( - self._session, handle, feed_dict, fetch_list, status) + return tf_session.TF_SessionPRun_wrapper( + self._session, handle, feed_dict, fetch_list) # pylint: disable=protected-access class _Callable(object): @@ -1433,25 +1364,29 @@ class BaseSession(SessionInterface): compat.as_bytes(callable_options.SerializeToString())) try: with errors.raise_exception_on_not_ok_status() as status: - if session._created_with_new_api: - self._handle = tf_session.TF_SessionMakeCallable( - session._session, options_ptr, status) - else: - self._handle = tf_session.TF_DeprecatedSessionMakeCallable( - session._session, options_ptr, status) + self._handle = tf_session.TF_SessionMakeCallable( + session._session, options_ptr, status) finally: tf_session.TF_DeleteBuffer(options_ptr) - def __call__(self, *args): + def __call__(self, *args, **kwargs): # TODO(b/74355905): Support argument and return value nested structures, # and tensor-like objects such as SparseTensors. - with errors.raise_exception_on_not_ok_status() as status: - if self._session._created_with_new_api: - return tf_session.TF_SessionRunCallable( - self._session._session, self._handle, args, status, None) - else: - return tf_session.TF_DeprecatedSessionRunCallable( - self._session._session, self._handle, args, status, None) + run_metadata = kwargs.get('run_metadata', None) + try: + run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None + # TODO(mrry): Switch to raising an exception from the SWIG wrapper. + with errors.raise_exception_on_not_ok_status() as status: + ret = tf_session.TF_SessionRunCallable( + self._session._session, self._handle, args, status, + run_metadata_ptr) + if run_metadata: + proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) + run_metadata.ParseFromString(compat.as_bytes(proto_data)) + finally: + if run_metadata_ptr: + tf_session.TF_DeleteBuffer(run_metadata_ptr) + return ret def __del__(self): # NOTE(mrry): It is possible that `self._session.__del__()` could be @@ -1459,12 +1394,8 @@ class BaseSession(SessionInterface): # will be `None`. if self._handle is not None and self._session._session is not None: with errors.raise_exception_on_not_ok_status() as status: - if self._session._created_with_new_api: - tf_session.TF_SessionReleaseCallable( - self._session._session, self._handle, status) - else: - tf_session.TF_DeprecatedSessionReleaseCallable( - self._session._session, self._handle, status) + tf_session.TF_SessionReleaseCallable( + self._session._session, self._handle, status) # pylint: enable=protected-access # TODO(b/74355905): Reimplement `Session.make_callable()` using this method diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index e9a7d9ac1dc146d4c73f3e22bb6c3a9168776a10..b72e029d1ccb688f5992f6cc8695969be5e5e2e3 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import collections +import random import os import sys import threading @@ -1040,40 +1041,72 @@ class SessionTest(test_util.TensorFlowTestCase): for t in threads: t.join() - def testParallelRunAndBuild(self): + @staticmethod + def _build_graph(): + time.sleep(random.random() * 0.1) + # Do some graph construction. Try to exercise non-trivial paths. + graph = ops.get_default_graph() + gdef = None + for _ in range(10): + x = array_ops.placeholder(dtype=dtypes.float32) + with ops.colocate_with(x): + y = array_ops.placeholder(dtype=dtypes.float32) + with ops.device('/cpu:0'): + z = control_flow_ops.while_loop( + lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) + with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): + gradients_impl.gradients(z, [x, y]) + if gdef is None: + gdef = graph.as_graph_def() + else: + importer.import_graph_def(gdef, name='import') + + def testParallelRunAndSingleBuild(self): with session.Session() as sess: c = constant_op.constant(5.0) stop = threading.Event() def run_loop(): while not stop.is_set(): + time.sleep(random.random() * 0.1) self.assertEqual(sess.run(c), 5.0) - threads = [self.checkedThread(target=run_loop) for _ in range(100)] + threads = [self.checkedThread(target=run_loop) for _ in range(10)] for t in threads: t.start() - # Do some graph construction. Try to exercise non-trivial paths. - graph = ops.get_default_graph() - gdef = None - for _ in range(10): - x = array_ops.placeholder(dtype=dtypes.float32) - with ops.colocate_with(x): - y = array_ops.placeholder(dtype=dtypes.float32) - with ops.device('/cpu:0'): - z = control_flow_ops.while_loop( - lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) - with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): - gradients_impl.gradients(z, [x, y]) - if gdef is None: - gdef = graph.as_graph_def() - else: - importer.import_graph_def(gdef, name='import') + SessionTest._build_graph() stop.set() for t in threads: t.join() + def testParallelRunAndParallelBuild(self): + with session.Session() as sess: + c = constant_op.constant(5.0) + stop = threading.Event() + + def run_loop(): + while not stop.is_set(): + time.sleep(random.random() * 0.1) + self.assertEqual(sess.run(c), 5.0) + + run_threads = [self.checkedThread(target=run_loop) for _ in range(10)] + for t in run_threads: + t.start() + + build_threads = [self.checkedThread(target=SessionTest._build_graph) + for _ in range(10)] + for t in build_threads: + t.start() + for t in build_threads: + t.join() + + # Let the run_threads run until the build threads are finished. + stop.set() + for t in run_threads: + t.join() + def testRunFeedDict(self): with session.Session() as s: x = array_ops.zeros([2]) @@ -1364,6 +1397,20 @@ class SessionTest(test_util.TensorFlowTestCase): for _ in range(5): self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32))) + def testOptimizedMakeCallableWithRunMetadata(self): + with session.Session() as sess: + ph = array_ops.placeholder(dtypes.float32) + a = math_ops.add(ph, 1.0) + callable_opts = config_pb2.CallableOptions() + callable_opts.feed.append(ph.name) + callable_opts.fetch.append(a.name) + callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE + callable_fn = sess._make_callable_from_options(callable_opts) + run_metadata = config_pb2.RunMetadata() + self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32), + run_metadata=run_metadata)) + self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) + def testFeedError(self): with session.Session() as sess: feed_t = array_ops.placeholder(dtype=dtypes.float32) @@ -1565,10 +1612,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) def testFeedShapeCompatibility(self): - # TODO(nolivia): C API doesn't yet handle marking nodes as not feedable. - if ops._USE_C_API: - return - with session.Session() as sess: some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0]) new_shape = constant_op.constant([2, 2]) @@ -1577,7 +1620,10 @@ class SessionTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'): sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]}) - with self.assertRaisesRegexp(ValueError, 'may not be fed'): + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + 'Input to reshape is a tensor with 4 values, ' + 'but the requested shape has 21'): sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]}) def testInferShapesFalse(self): diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 1db1432d6521bb5f48558081916158792010b1c5..985cb904360ac293461936bf67fb1b1de2c77b4a 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -135,7 +135,7 @@ tensorflow::ImportNumpy(); // Convert TF_DeviceListMemoryBytes and TF_Dim int64_t output to Python integers %typemap(out) int64_t { - $result = PyInt_FromLong($1); + $result = PyLong_FromLongLong($1); } // We use TF_OperationGetControlInputs_wrapper instead of @@ -610,7 +610,7 @@ def TF_Reset(target, containers=None, config=None): } for (size_t i = 0; i < $1.size(); ++i) { - PyList_SET_ITEM($result, i, PyInt_FromLong($1[i])); + PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i])); } } @@ -673,7 +673,7 @@ def TF_Reset(target, containers=None, config=None): } for (size_t i = 0; i < $1.size(); ++i) { - PyList_SET_ITEM($result, i, PyInt_FromLong($1[i])); + PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i])); } } diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..58ceafca0638a90c2e66ddea0e4bbb1547455f48 --- /dev/null +++ b/tensorflow/python/compat/BUILD @@ -0,0 +1,22 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +py_library( + name = "compat", + srcs = ["compat.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], +) + +tf_py_test( + name = "compat_test", + size = "small", + srcs = ["compat_test.py"], + additional_deps = [ + ":compat", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py new file mode 100644 index 0000000000000000000000000000000000000000..68a6421c2c56c9f007cbd8aee3111c4abfde691c --- /dev/null +++ b/tensorflow/python/compat/compat.py @@ -0,0 +1,125 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for API compatibility between TensorFlow release versions. + +See +@{$guide/version_compat#backward_and_partial_forward_compatibility} +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +from tensorflow.python.util import tf_contextlib + +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 1) + + +def forward_compatible(year, month, day): + """Return true if the forward compatibility window has expired. + + Forward-compatibility refers to scenarios where the producer of a TensorFlow + model (a GraphDef or SavedModel) is compiled against a version of the + TensorFlow library newer than what the consumer was compiled against. The + "producer" is typically a Python program that constructs and trains a model + while the "consumer" is typically another program that loads and serves the + model. + + TensorFlow has been supporting a 3 week forward-compatibility window for + programs compiled from source at HEAD. + + For example, consider the case where a new operation `MyNewAwesomeAdd` is + created with the intent of replacing the implementation of an existing Python + wrapper - `tf.add`. The Python wrapper implementation should change from + something like: + + ```python + def add(inputs, name=None): + return gen_math_ops.add(inputs, name) + ``` + + to: + + ```python + from tensorflow.python.compat import compat + + def add(inputs, name=None): + if compat.forward_compatible(year, month, day): + # Can use the awesome new implementation. + return gen_math_ops.my_new_awesome_add(inputs, name) + # To maintain forward compatibiltiy, use the old implementation. + return gen_math_ops.add(inputs, name) + ``` + + Where `year`, `month`, and `day` specify the date beyond which binaries + that consume a model are expected to have been updated to include the + new operations. This date is typically at least 3 weeks beyond the date + the code that adds the new operation is committed. + + Args: + year: A year (e.g., 2018). + month: A month (1 <= month <= 12) in year. + day: A day (1 <= day <= 31, or 30, or 29, or 28) in month. + + Returns: + True if the caller can expect that serialized TensorFlow graphs produced + can be consumed by programs that are compiled with the TensorFlow library + source code after (year, month, day). + """ + return _FORWARD_COMPATIBILITY_HORIZON > datetime.date(year, month, day) + + +@tf_contextlib.contextmanager +def forward_compatibility_horizon(year, month, day): + """Context manager for testing forward compatibility of generated graphs. + + To ensure forward compatibility of generated graphs (see `forward_compatible`) + with older binaries, new features can be gated with: + + ```python + if compat.forward_compatible(year=2018, month=08, date=01): + generate_graph_with_new_features() + else: + generate_graph_so_older_binaries_can_consume_it() + ``` + + However, when adding new features, one may want to unittest it before + the forward compatibility window expires. This context manager enables + such tests. For example: + + ```python + from tensorflow.python.compat import compat + + def testMyNewFeature(self): + with compat.forward_compatibility_horizon(2018, 08, 02): + # Test that generate_graph_with_new_features() has an effect + ``` + + Args : + year: A year (e.g. 2018). + month: A month (1 <= month <= 12) in year. + day: A day (1 <= day <= 31, or 30, or 29, or 28) in month. + + Yields: + Nothing. + """ + global _FORWARD_COMPATIBILITY_HORIZON + try: + old_compat_date = _FORWARD_COMPATIBILITY_HORIZON + _FORWARD_COMPATIBILITY_HORIZON = datetime.date(year, month, day) + yield + finally: + _FORWARD_COMPATIBILITY_HORIZON = old_compat_date diff --git a/tensorflow/python/compat/compat_test.py b/tensorflow/python/compat/compat_test.py new file mode 100644 index 0000000000000000000000000000000000000000..946abbb300d66e7be5ea317e365bc75cbcf6941c --- /dev/null +++ b/tensorflow/python/compat/compat_test.py @@ -0,0 +1,70 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for forward and backwards compatibility utilties.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +from tensorflow.python.compat import compat +from tensorflow.python.platform import test + + +class CompatTest(test.TestCase): + + def _compatibility_date(self): + date = compat._FORWARD_COMPATIBILITY_HORIZON # pylint: disable=protected-access + return (date.year, date.month, date.day) + + def _n_days_after(self, n): + date = compat._FORWARD_COMPATIBILITY_HORIZON + datetime.timedelta(days=n) # pylint: disable=protected-access + return (date.year, date.month, date.day) + + def test_basic(self): + compatibility_date = self._compatibility_date() + one_day_before = self._n_days_after(-1) + self.assertTrue(compat.forward_compatible(*one_day_before)) + self.assertFalse(compat.forward_compatible(*compatibility_date)) + + def test_decorator(self): + compatibility_date = self._compatibility_date() + one_day_after = self._n_days_after(1) + with compat.forward_compatibility_horizon(*one_day_after): + self.assertTrue(compat.forward_compatible(*compatibility_date)) + self.assertFalse(compat.forward_compatible(*one_day_after)) + + # After exiting context manager, value should be reset. + self.assertFalse(compat.forward_compatible(*compatibility_date)) + + def test_decorator_with_failure(self): + compatibility_date = self._compatibility_date() + one_day_after = self._n_days_after(1) + + class DummyError(Exception): + pass + + try: + with compat.forward_compatibility_horizon(*one_day_after): + raise DummyError() + except DummyError: + pass # silence DummyError + + # After exiting context manager, value should be reset. + self.assertFalse(compat.forward_compatible(*compatibility_date)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py index 7efe0948e7729c398f972977b51426d80b8cd83e..3b9bf2469e6d41fd0e8c5199af677e60bedf93f9 100644 --- a/tensorflow/python/data/__init__.py +++ b/tensorflow/python/data/__init__.py @@ -14,7 +14,7 @@ # ============================================================================== """`tf.data.Dataset` API for input pipelines. -See the @{$datasets$Importing Data} Programmer's Guide for an overview. +See @{$guide/datasets$Importing Data} for an overview. """ from __future__ import absolute_import diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index ed0c11e6c117dcbb810fd3acfc484128ed3519fa..3bde62fa1d8a71c0d6f2bbfbff29bb842a9248f0 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -15,6 +15,7 @@ tf_py_test( size = "small", srcs = ["batch_dataset_op_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -72,6 +73,17 @@ tf_py_test( ], ) +tf_py_test( + name = "dataset_ops_test", + size = "small", + srcs = ["dataset_ops_test.py"], + additional_deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + tf_py_test( name = "filter_dataset_op_test", size = "small", @@ -167,6 +179,7 @@ tf_py_test( size = "small", srcs = ["prefetch_dataset_op_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dataset_ops_gen", diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py index bd80b9dbf561de16168b05facf0086dadcda6444..89de55dd4f9fdc612663c839b926684d27d48c54 100644 --- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math +import time +from absl.testing import parameterized import numpy as np +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -35,73 +37,83 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class BatchDatasetTest(test.TestCase): +class BatchDatasetTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('even', 28, 14, False), + ('uneven_with_remainder', 28, 15, False), + ('uneven_without_remainder', 28, 15, True), + ('empty', 0, 14, False), + ) + def testBatchDataset(self, count, batch_size, drop_remainder): + """Tests the batch dataset logic for various input configurations. + + Args: + count: the number of input elements + batch_size: the batch size + drop_remainder: whether a smaller batch size should be produced if batch + size does not divide number of inputs evenly + """ - def testBatchDataset(self): - """Test an dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> BatchDataset(batch_size). components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) - count = array_ops.placeholder(dtypes.int64, shape=[]) - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + count_t = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size_t = array_ops.placeholder(dtypes.int64, shape=[]) + drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[]) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) iterator = ( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count).batch(batch_size).make_initializable_iterator()) + .repeat(count).batch(batch_size, + drop_remainder).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - self.assertEqual([[None] + list(c.shape[1:]) for c in components], + if drop_remainder: + dim0 = batch_size + else: + dim0 = None + self.assertEqual([[dim0] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) with self.test_session() as sess: - # Batch of a finite input, where the batch_size divides the - # total number of elements. - sess.run(init_op, feed_dict={count: 28, batch_size: 14}) - num_batches = (28 * 7) // 14 - for i in range(num_batches): + sess.run( + init_op, + feed_dict={ + count_t: count, + batch_size_t: batch_size, + drop_remainder_t: drop_remainder + }) + num_full_batches = (count * 7) // batch_size + for i in range(num_full_batches): result = sess.run(get_next) for component, result_component in zip(components, result): - for j in range(14): - self.assertAllEqual(component[(i * 14 + j) % 7]**2, + for j in range(batch_size): + self.assertAllEqual(component[(i * batch_size + j) % 7]**2, result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Batch of a finite input, where the batch_size does not - # divide the total number of elements. - sess.run(init_op, feed_dict={count: 14, batch_size: 8}) - - # We expect (num_batches - 1) full-sized batches. - num_batches = int(math.ceil((14 * 7) / 8)) - for i in range(num_batches - 1): + if not drop_remainder and (count * 7) % batch_size > 0: result = sess.run(get_next) for component, result_component in zip(components, result): - for j in range(8): - self.assertAllEqual(component[(i * 8 + j) % 7]**2, - result_component[j]) - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range((14 * 7) % 8): - self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, - result_component[j]) + for j in range((count * 7) % batch_size): + self.assertAllEqual( + component[(num_full_batches * batch_size + j) % 7]**2, + result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - # Batch of an empty input should fail straight away. - sess.run(init_op, feed_dict={count: 0, batch_size: 8}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + def testBatchDatasetInvalidBatchSize(self): + iterator = (dataset_ops.Dataset.range(10).batch(0).make_one_shot_iterator()) + get_next = iterator.get_next() - # Empty batch should be an initialization time error. + with self.test_session() as sess: with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + sess.run(get_next) def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) @@ -210,66 +222,108 @@ class BatchDatasetTest(test.TestCase): r'First element had shape \[3\] and element 2 had shape \[4\].'): sess.run(next_element) - def testPaddedBatchDataset(self): - seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) - padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) + +def _random_seq_lens(count): + return np.random.randint(20, size=(count,)).astype(np.int32) + + +class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('default_padding', _random_seq_lens(32), 4, [-1], False), + ('constant_padding', _random_seq_lens(32), 4, [25], False), + ('uneven_with_remainder', _random_seq_lens(34), 4, [-1], False), + ('uneven_without_remainder', _random_seq_lens(34), 4, [-1], True), + ) + def testPaddedBatchDataset(self, seq_lens, batch_size, padded_shapes, + drop_remainder): + """Tests the padded batch dataset logic for various input configurations. + + Args: + seq_lens: the input sequence lengths + batch_size: the batch size + padded_shapes: the padded shapes to use + drop_remainder: whether a smaller batch size should be produced if batch + size does not divide number of inputs evenly + """ + + seq_lens_t = array_ops.placeholder(dtypes.int32, shape=[None]) + batch_size_t = array_ops.placeholder(dtypes.int64, shape=[]) + padded_shapes_t = array_ops.placeholder(dtypes.int64, shape=[1]) + drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[]) iterator = ( - dataset_ops.Dataset.from_tensor_slices(seq_lens) + dataset_ops.Dataset.from_tensor_slices(seq_lens_t) .map(lambda x: array_ops.fill([x], x)).padded_batch( - 4, padded_shapes=padded_shape).make_initializable_iterator()) + batch_size=batch_size_t, + drop_remainder=drop_remainder_t, + padded_shapes=padded_shapes_t).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() with self.test_session() as sess: - # Test with random sequence lengths, and max padding. - random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) sess.run( - init_op, feed_dict={ - padded_shape: [-1], - seq_lens: random_seq_lens + init_op, + feed_dict={ + seq_lens_t: seq_lens, + batch_size_t: batch_size, + padded_shapes_t: padded_shapes, + drop_remainder_t: drop_remainder, }) - for i in range(8): + + num_full_batches = len(seq_lens) // batch_size + + for i in range(num_full_batches): result = sess.run(get_next) - padded_len = np.max(result) - self.assertEqual((4, padded_len), result.shape) - for j in range(4): - seq_len = random_seq_lens[(i * 4) + j] + padded_len = padded_shapes[0] + if padded_len is None or padded_len == -1: + padded_len = np.max(result) if result.size > 0 else 0 + self.assertEqual((batch_size, padded_len), result.shape) + for j in range(batch_size): + seq_len = seq_lens[(i * batch_size) + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) - self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + self.assertAllEqual(result[j, seq_len:], + [0] * (padded_len - seq_len)) - # Test with random sequence lengths, and constant padding. - sess.run( - init_op, feed_dict={ - padded_shape: [25], - seq_lens: random_seq_lens - }) - for i in range(8): + if not drop_remainder and len(seq_lens) % batch_size > 0: result = sess.run(get_next) - self.assertEqual((4, 25), result.shape) - for j in range(4): - seq_len = random_seq_lens[(i * 4) + j] + padded_len = np.max(result) if result.size > 0 else 0 + self.assertEqual((len(seq_lens) % batch_size, padded_len), + result.shape) + for j in range(len(seq_lens) % batch_size): + seq_len = seq_lens[num_full_batches * batch_size + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) - self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len)) + self.assertAllEqual(result[j, seq_len:], + [0] * (padded_len - seq_len)) + with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - # Test correct handling of empty tensors. - sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]}) + def testPaddedBatchShortPadding(self): + iterator = ( + dataset_ops.Dataset.from_tensor_slices([6, 5, 5, 5, 5]) + .map(lambda x: array_ops.fill([x], x)).padded_batch( + batch_size=4, padded_shapes=[5]).make_one_shot_iterator()) + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaises(errors.DataLossError): + sess.run(get_next) + + def testPaddedBatchEmptyTensors(self): + iterator = ( + dataset_ops.Dataset.from_tensor_slices([0, 0, 0, 0]) + .map(lambda x: array_ops.fill([x], x)).padded_batch( + batch_size=4, padded_shapes=[-1]).make_one_shot_iterator()) + get_next = iterator.get_next() + + with self.test_session() as sess: result = sess.run(get_next) self.assertAllEqual([[], [], [], []], result) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - # Test error handling with constant sequence lengths, and - # too-short padding. - sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]}) - with self.assertRaises(errors.DataLossError): - result = sess.run(get_next) - def testPaddedBatchDatasetNonDefaultPadding(self): seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) @@ -371,6 +425,94 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(TypeError): _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10) + def testPaddedBatchShapeError(self): + with self.assertRaisesRegexp( + ValueError, r'The padded shape \(1,\) is not compatible with the ' + r'corresponding input component shape \(\).'): + _ = dataset_ops.Dataset.range(10).padded_batch(5, padded_shapes=[1]) + + with self.assertRaisesRegexp( + ValueError, r'The padded shape \(1,\) is not compatible with the ' + r'corresponding input component shape \(3,\).'): + _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch( + 5, padded_shapes=[1]) + + with self.assertRaisesRegexp( + ValueError, r'Padded shape .* must be a 1-D tensor ' + r'of tf.int64 values, but its shape was \(2, 2\).'): + _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch( + 5, padded_shapes=[[1, 1], [1, 1]]) + + with self.assertRaisesRegexp( + TypeError, r'Padded shape .* must be a 1-D tensor ' + r'of tf.int64 values, but its element type was float32.'): + _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch( + 5, padded_shapes=constant_op.constant([1., 2., 3.])) + + with self.assertRaisesRegexp( + ValueError, r'The padded shape \(1,\) is not compatible with the ' + r'corresponding input component shape \(\).'): + shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64) + _ = dataset_ops.Dataset.range(10).padded_batch( + 5, padded_shapes=shape_as_tensor) + + with self.assertRaisesRegexp( + ValueError, r'The padded shape \(\?, \?\) is not compatible with the ' + r'corresponding input component shape \(\).'): + shape_as_tensor = array_ops.placeholder(dtypes.int64, shape=[2]) + _ = dataset_ops.Dataset.range(10).padded_batch( + 5, padded_shapes=shape_as_tensor) + + +class BatchDatasetBenchmark(test.Benchmark): + + def benchmarkBatchSparse(self): + non_zeros_per_row_values = [0, 1, 5, 10, 100] + batch_size_values = [1, 32, 64, 128, 1024] + + sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64) + batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[]) + + dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat( + ).batch(batch_size_placeholder) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + for non_zeros_per_row in non_zeros_per_row_values: + + sparse_value = sparse_tensor.SparseTensorValue( + indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis], + values=np.arange(non_zeros_per_row, dtype=np.int64), + dense_shape=[1000]) + + for batch_size in batch_size_values: + + with session.Session() as sess: + sess.run(iterator.initializer, feed_dict={ + sparse_placeholder: sparse_value, + batch_size_placeholder: batch_size}) + # Run five steps to warm up the session caches before taking the + # first measurement. + for _ in range(5): + sess.run(next_element.indices.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.indices.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100.0 + + print('Batch sparse dataset non-zeros per row: %d batch_size: %d ' + 'wall time: %f' + % (non_zeros_per_row, batch_size, median_wall_time)) + self.report_benchmark( + iters=10000, wall_time=median_wall_time, + name='benchmark_batch_sparse_dataset_nnz_%d_batch_size_%d' % ( + non_zeros_per_row, batch_size)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py index 296a76ec887ae7c31cb9d0bd2afd6d1fe827d95c..fb55ae140058349753731b0c257acb3cf3def0a3 100644 --- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py @@ -259,9 +259,7 @@ class DatasetConstructorTest(test.TestCase): sess.run(init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) self.assertAllEqual([4, 5, 6], sess.run(get_next)) - # NOTE(mrry): Type name in message differs between Python 2 (`long`) and - # 3 (`int`). - with self.assertRaisesOpError(r"invalid literal for"): + with self.assertRaisesOpError("The expected type was int64"): sess.run(get_next) self.assertAllEqual([7, 8, 9], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): @@ -290,6 +288,34 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testFromGeneratorStructureError(self): + def generator(): + yield 1, 2 + yield 3, 4 + yield 5 + yield 6, 7, 8 + yield 9, 10 + + iterator = (dataset_ops.Dataset.from_generator( + generator, output_types=(dtypes.int64, dtypes.int64)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + self.assertEqual((1, 2), sess.run(get_next)) + self.assertEqual((3, 4), sess.run(get_next)) + with self.assertRaisesOpError( + r"The expected structure was \(tf\.int64, tf\.int64\)"): + sess.run(get_next) + with self.assertRaisesOpError( + r"The expected structure was \(tf\.int64, tf\.int64\)"): + sess.run(get_next) + self.assertEqual((9, 10), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testFromGeneratorHeterogeneous(self): def generator(): yield 1 diff --git a/tensorflow/python/keras/_impl/keras/datasets/__init__.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py similarity index 58% rename from tensorflow/python/keras/_impl/keras/datasets/__init__.py rename to tensorflow/python/data/kernel_tests/dataset_ops_test.py index 60db3766fbce859269cecb92a537084ef18c0da5..2c4c11e132d1fc9b8969540994a097098279dd9e 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/__init__.py +++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py @@ -12,17 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras datasets: utilities for downloading and pre-processing common datasets. +"""Tests for the input pipeline ops.""" -""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets import boston_housing -from tensorflow.python.keras._impl.keras.datasets import cifar10 -from tensorflow.python.keras._impl.keras.datasets import cifar100 -from tensorflow.python.keras._impl.keras.datasets import fashion_mnist -from tensorflow.python.keras._impl.keras.datasets import imdb -from tensorflow.python.keras._impl.keras.datasets import mnist -from tensorflow.python.keras._impl.keras.datasets import reuters +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class DatasetOpsTest(test.TestCase): + + def testAsSerializedGraph(self): + dataset = dataset_ops.Dataset.range(10) + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue(any([node.op != "RangeDataset" for node in graph.node])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 1ad0b9de5e76e3edd66303ab4666108f43a27428..0ecd821e9e473522b0cf4bd7bbceb071ecf5bb9e 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from collections import namedtuple import threading import time +import warnings import numpy as np @@ -638,6 +639,33 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testWarnOnLookupTable(self): + def collecting_function(x): + _ = lookup_ops.HashTable( + lookup_ops.KeyValueTensorInitializer([], []), 0.0, name="t1") + return x + + warnings.simplefilter("always") + with warnings.catch_warnings(record=True) as w: + _ = dataset_ops.Dataset.range(10).map(collecting_function) + # NOTE(mrry): Python 3 prints other warnings in addition to the one we are + # testing, so we search for the expected warning. + self.assertGreaterEqual(len(w), 1) + found_warning = False + for warning in w: + if ("Creating lookup tables inside a function passed to Dataset.map() is " + "not supported." in str(warning)): + found_warning = True + break + self.assertTrue(found_warning) + + def testNestedDatasetError(self): + dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]) + with self.assertRaisesRegexp( + NotImplementedError, r"The Dataset.map\(\) transformation does not " + "currently support nested datasets as outputs."): + _ = dataset.map(dataset_ops.Dataset.from_tensor_slices) + class MapDatasetBenchmark(test.Benchmark): diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py index 646324cb95df6fc1fa0a901ebdccc8d4ef74a66c..63a0830272dca254866c1609fec3677ab28749d5 100644 --- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -24,35 +26,33 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class PrefetchDatasetTest(test.TestCase): +class PrefetchDatasetTest(test.TestCase, parameterized.TestCase): - def testBufferSize(self): - buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) + @parameterized.parameters((-1), (0), (5)) + def testBufferSize(self, buffer_size): + buffer_size_t = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(10).prefetch( - buffer_size=buffer_size).make_initializable_iterator() + buffer_size=buffer_size_t).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() with self.test_session() as sess: - sess.run(init_op, feed_dict={buffer_size: 5}) + sess.run(init_op, feed_dict={buffer_size_t: buffer_size}) for m in range(10): self.assertEqual(m, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testInvalidBufferSize(self): - buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) + @parameterized.parameters((-2), (-42)) + def testInvalidBufferSize(self, buffer_size): + buffer_size_t = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(10).prefetch( - buffer_size=buffer_size).make_initializable_iterator() + buffer_size=buffer_size_t).make_initializable_iterator() init_op = iterator.initializer with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"): with self.test_session() as sess: - sess.run(init_op, feed_dict={buffer_size: 0}) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"): - with self.test_session() as sess: - sess.run(init_op, feed_dict={buffer_size: -5}) + sess.run(init_op, feed_dict={buffer_size_t: buffer_size}) if __name__ == "__main__": diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py index 1ddedfda4e1c9d6b6949f796be1870f167435763..e99f0a203b4d8b83fc6a95163e23b74300f6f6b8 100644 --- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py @@ -24,6 +24,7 @@ import zlib from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -38,6 +39,13 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat +try: + import psutil # pylint: disable=g-import-not-at-top + psutil_import_succeeded = True +except ImportError: + psutil_import_succeeded = False + + class TextLineDatasetTest(test.TestCase): def _lineText(self, f, l): @@ -162,6 +170,34 @@ class TextLineDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) + def testIteratorResourceCleanup(self): + filename = os.path.join(self.get_temp_dir(), "text.txt") + with open(filename, "wt") as f: + for i in range(3): + f.write("%d\n" % (i,)) + with context.eager_mode(): + first_iterator = iter(readers.TextLineDataset(filename)) + self.assertEqual(b"0", next(first_iterator).numpy()) + second_iterator = iter(readers.TextLineDataset(filename)) + self.assertEqual(b"0", next(second_iterator).numpy()) + # Eager kernel caching is based on op attributes, which includes the + # Dataset's output shape. Create a different kernel to test that they + # don't create resources with the same names. + different_kernel_iterator = iter( + readers.TextLineDataset(filename).repeat().batch(16)) + self.assertEqual([16], next(different_kernel_iterator).shape) + # Remove our references to the Python Iterator objects, which (assuming no + # reference cycles) is enough to trigger DestroyResourceOp and close the + # partially-read files. + del first_iterator + del second_iterator + del different_kernel_iterator + if not psutil_import_succeeded: + self.skipTest( + "psutil is required to check that we've closed our files.") + open_files = psutil.Process().open_files() + self.assertNotIn(filename, [open_file.path for open_file in open_files]) + class FixedLengthRecordReaderTest(test.TestCase): diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 6a3f6bf40c22bf59715bfbd12e0080b704eb526f..d2a8c0f3137aa25d2e5327cd4e61c04298656e4d 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -19,6 +19,7 @@ from __future__ import print_function import abc import threading +import warnings import numpy as np import six @@ -32,6 +33,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -57,6 +59,15 @@ class Dataset(object): def __init__(self): pass + def _as_serialized_graph(self): + """Produces serialized graph representation of the dataset. + + Returns: + A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a + serialized graph. + """ + return gen_dataset_ops.dataset_to_graph(self._as_variant_tensor()) + @abc.abstractmethod def _as_variant_tensor(self): """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset. @@ -97,12 +108,7 @@ class Dataset(object): if shared_name is None: shared_name = "" iterator_resource = gen_dataset_ops.iterator( - container="", - shared_name=shared_name, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + container="", shared_name=shared_name, **flat_structure(self)) with ops.colocate_with(iterator_resource): initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(), iterator_resource) @@ -160,13 +166,8 @@ class Dataset(object): return iterator_ops.Iterator( gen_dataset_ops.one_shot_iterator( - dataset_factory=_make_dataset, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, - self.output_classes))), None, - self.output_types, self.output_shapes, self.output_classes) + dataset_factory=_make_dataset, **flat_structure(self)), + None, self.output_types, self.output_shapes, self.output_classes) @abc.abstractproperty def output_classes(self): @@ -212,6 +213,13 @@ class Dataset(object): def from_tensors(tensors): """Creates a `Dataset` with a single element, comprising the given tensors. + Note that if `tensors` contains a NumPy array, and eager execution is not + enabled, the values will be embedded in the graph as one or more + @{tf.constant} operations. For large datasets (> 1 GB), this can waste + memory and run into byte limits of graph serialization. If tensors contains + one or more large NumPy arrays, consider the alternative described in + @{$guide/datasets#consuming_numpy_arrays$this guide}. + Args: tensors: A nested structure of tensors. @@ -224,6 +232,13 @@ class Dataset(object): def from_tensor_slices(tensors): """Creates a `Dataset` whose elements are slices of the given tensors. + Note that if `tensors` contains a NumPy array, and eager execution is not + enabled, the values will be embedded in the graph as one or more + @{tf.constant} operations. For large datasets (> 1 GB), this can waste + memory and run into byte limits of graph serialization. If tensors contains + one or more large NumPy arrays, consider the alternative described in + @{$guide/datasets#consuming_numpy_arrays$this guide}. + Args: tensors: A nested structure of tensors, each having the same size in the 0th dimension. @@ -398,13 +413,23 @@ class Dataset(object): # Use the same _convert function from the py_func() implementation to # convert the returned values to arrays early, so that we can inspect # their values. - # pylint: disable=protected-access - ret_arrays = [ - script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype) - for ret, dtype in zip( - nest.flatten_up_to(output_types, values), flattened_types) - ] - # pylint: enable=protected-access + try: + flattened_values = nest.flatten_up_to(output_types, values) + except (TypeError, ValueError): + raise TypeError( + "`generator` yielded an element that did not match the expected " + "structure. The expected structure was %s, but the yielded " + "element was %s." % (output_types, values)) + ret_arrays = [] + for ret, dtype in zip(flattened_values, flattened_types): + try: + ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access + ret, dtype=dtype.as_numpy_dtype)) + except (TypeError, ValueError): + raise TypeError( + "`generator` yielded an element that could not be converted to " + "the expected type. The expected type was %s, but the yielded " + "element was %s." % (dtype.name, ret)) # Additional type and shape checking to ensure that the components # of the generated element match the `output_types` and `output_shapes` @@ -740,7 +765,6 @@ class Dataset(object): d = d.shard(FLAGS.num_workers, FLAGS.worker_index) d = d.repeat(FLAGS.num_epochs) d = d.shuffle(FLAGS.shuffle_buffer_size) - d = d.repeat() d = d.interleave(tf.data.TFRecordDataset, cycle_length=FLAGS.num_readers, block_length=1) d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) @@ -782,35 +806,50 @@ class Dataset(object): return self._enumerate().filter(filter_fn).map(lambda _, elem: elem) - def batch(self, batch_size): + def batch(self, batch_size, drop_remainder=False): """Combines consecutive elements of this dataset into batches. - NOTE: If the number of elements (`N`) in this dataset is not an exact - multiple of `batch_size`, the final batch contain smaller tensors with - shape `N % batch_size` in the batch dimension. If your program depends on - the batches having the same shape, consider using the - @{tf.contrib.data.batch_and_drop_remainder} transformation instead. + The tensors in the resulting element will have an additional outer + dimension, which will be `batch_size` (or `N % batch_size` for the last + element if `batch_size` does not divide the number of input elements `N` + evenly and `drop_remainder` is `False`). If your program depends on the + batches having the same outer dimension, you should set the `drop_remainder` + argument to `True` to prevent the smaller batch from being produced. Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in the case its has fewer than + `batch_size` elements; the default behavior is not to drop the smaller + batch. Returns: Dataset: A `Dataset`. """ - return BatchDataset(self, batch_size) + return BatchDataset(self, batch_size, drop_remainder) - def padded_batch(self, batch_size, padded_shapes, padding_values=None): + def padded_batch(self, + batch_size, + padded_shapes, + padding_values=None, + drop_remainder=False): """Combines consecutive elements of this dataset into padded batches. This transformation combines multiple consecutive elements of the input - dataset into a single element. Like @{tf.data.Dataset.batch}, the tensors - in the resulting element have an additional outer dimension, which will be - `batch_size` for all but the last element, and `N % batch_size` for the - last element (where `N` is the number of elements in this dataset). Unlike - @{tf.data.Dataset.batch}, the elements may have different shapes for some - of their components, and this transformation will pad each component to - the respective shape in `padding_shapes`. The `padding_shapes` argument + dataset into a single element. + + Like @{tf.data.Dataset.batch}, the tensors in the resulting element will + have an additional outer dimension, which will be `batch_size` (or + `N % batch_size` for the last element if `batch_size` does not divide the + number of input elements `N` evenly and `drop_remainder` is `False`). If + your program depends on the batches having the same outer dimension, you + should set the `drop_remainder` argument to `True` to prevent the smaller + batch from being produced. + + Unlike @{tf.data.Dataset.batch}, the input elements to be batched may have + different shapes, and this transformation will pad each component to the + respective shape in `padding_shapes`. The `padding_shapes` argument determines the resulting shape for each dimension of each component in an output element: @@ -820,12 +859,6 @@ class Dataset(object): will be padded out to the maximum length of all elements in that dimension. - NOTE: If the number of elements (`N`) in this dataset is not an exact - multiple of `batch_size`, the final batch contain smaller tensors with - shape `N % batch_size` in the batch dimension. If your program depends on - the batches having the same shape, consider using the - @{tf.contrib.data.padded_batch_and_drop_remainder} transformation instead. - See also @{tf.contrib.data.dense_to_sparse_batch}, which combines elements that may have different shapes into a @{tf.SparseTensor}. @@ -843,14 +876,95 @@ class Dataset(object): `tf.Tensor`, representing the padding values to use for the respective components. Defaults are `0` for numeric types and the empty string for string types. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in the case its has fewer than + `batch_size` elements; the default behavior is not to drop the smaller + batch. Returns: Dataset: A `Dataset`. """ - return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values) + return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values, + drop_remainder) def map(self, map_func, num_parallel_calls=None): - """Maps `map_func` across this dataset. + """Maps `map_func` across the elements of this dataset. + + This transformation applies `map_func` to each element of this dataset, and + returns a new dataset containing the transformed elements, in the same + order as they appeared in the input. + + For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { 1, 2, 3, 4, 5 } + + a.map(lambda x: x + 1) = { 2, 3, 4, 5, 6 } + ``` + + The input signature of `map_func` is determined by the structure of each + element in this dataset. For example: + + ```python + # Each element is a `tf.Tensor` object. + a = { 1, 2, 3, 4, 5 } + # `map_func` takes a single argument of type `tf.Tensor` with the same + # shape and dtype. + result = a.map(lambda x: ...) + + # Each element is a tuple containing two `tf.Tensor` objects. + b = { (1, "foo"), (2, "bar"), (3, "baz") } + # `map_func` takes two arguments of type `tf.Tensor`. + result = b.map(lambda x_int, y_str: ...) + + # Each element is a dictionary mapping strings to `tf.Tensor` objects. + c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} } + # `map_func` takes a single argument of type `dict` with the same keys as + # the elements. + result = c.map(lambda d: ...) + ``` + + The value or values returned by `map_func` determine the structure of each + element in the returned dataset. + + ```python + # `map_func` returns a scalar `tf.Tensor` of type `tf.float32`. + def f(...): + return tf.constant(37.0) + result = dataset.map(f) + result.output_classes == tf.Tensor + result.output_types == tf.float32 + result.output_shapes == [] # scalar + + # `map_func` returns two `tf.Tensor` objects. + def g(...): + return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"]) + result = dataset.map(g) + result.output_classes == (tf.Tensor, tf.Tensor) + result.output_types == (tf.float32, tf.string) + result.output_shapes == ([], [3]) + + # Python primitives, lists, and NumPy arrays are implicitly converted to + # `tf.Tensor`. + def h(...): + return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0] dtype=np.float64) + result = dataset.map(h) + result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor) + result.output_types == (tf.float32, tf.string, tf.float64) + result.output_shapes == ([], [3], [2]) + + # `map_func` can return nested structures. + def i(...): + return {"a": 37.0, "b": [42, 16]}, "foo" + result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor) + result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string) + result.output_shapes == ({"a": [], "b": [2]}, []) + ``` + + In addition to `tf.Tensor` objects, `map_func` can accept as arguments and + return `tf.SparseTensor` objects. Args: map_func: A function mapping a nested structure of tensors (having @@ -959,7 +1073,8 @@ class Dataset(object): scalar `tf.bool` tensor. Returns: - Dataset: A `Dataset`. + Dataset: The `Dataset` containing the elements of this dataset for which + `predicate` is `True`. """ return FilterDataset(self, predicate) @@ -1110,6 +1225,309 @@ class SparseTensorSliceDataset(Dataset): return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64) +class _NestedDatasetComponent(object): + """The structure of a `Dataset` nested in a component of another `Dataset`. + + A `StructuredFunctionWrapper` around a function that returns a `Dataset` as + one of its components will have a `NestedDatasetComponent` in the + corresponding position in the `output_classes`, `output_shapes`, and + `output_types` properties. + + NOTE(mrry): This class is not currently exposed via the public API. Support + for nested datasets can be enabled on a function-by-function basis by setting + `experimental_nested_dataset_support=True` in the `StructuredFunctionWrapper` + initializer. + + TODO(b/110122868): Add this class, or something equivalent, to the public API. + We are considering revising the public API for accessing Dataset structure + (`output_classes` etc.) based on experience with nested datasets and other + custom component types. + """ + + def __init__(self, + dataset=None, + output_shapes=None, + output_types=None, + output_classes=None): + if dataset is None: + if (output_classes is None or output_shapes is None or + output_types is None): + raise ValueError( + "Either `dataset`, or all of `output_classes`, " + "`output_shapes`, and `output_types` must be specified.") + self._output_classes = output_classes + self._output_shapes = output_shapes + self._output_types = output_types + else: + if not (output_classes is None and output_shapes is None and + output_types is None): + raise ValueError( + "Either `dataset`, or all of `output_classes`, " + "`output_shapes`, and `output_types` must be specified.") + self._output_classes = dataset.output_classes + self._output_shapes = dataset.output_shapes + self._output_types = dataset.output_types + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +class _VariantDataset(Dataset): + """A Dataset wrapper around a @{tf.variant}-typed function argument.""" + + def __init__(self, dataset_variant, structure): + super(_VariantDataset, self).__init__() + self._dataset_variant = dataset_variant + self._structure = structure + + def _as_variant_tensor(self): + return self._dataset_variant + + @property + def output_classes(self): + return self._structure.output_classes + + @property + def output_shapes(self): + return self._structure.output_shapes + + @property + def output_types(self): + return self._structure.output_types + + +class StructuredFunctionWrapper(object): + """A wrapper for `Defun` that supports structured arguments and return values. + """ + + def __init__(self, func, transformation_name, dataset=None, + input_classes=None, input_shapes=None, input_types=None, + add_to_graph=True, experimental_nested_dataset_support=False): + """Creates a new `StructuredFunctionWrapper` for the given function. + + Args: + func: A function from a nested structure to another nested structure. + transformation_name: Human-readable name of the transformation in which + this function is being instantiated, for error messages. + dataset: (Optional.) A @{tf.data.Dataset}. If given, the structure of this + dataset will be assumed as the structure for `func` arguments; otherwise + `input_classes`, `input_shapes`, and `input_types` must be defined. + input_classes: (Optional.) A nested structure of `type`. If given, this + argument defines the Python types for `func` arguments. + input_shapes: (Optional.) A nested structure of @{tf.TensorShape}. If + given, this argument defines the shapes and structure for `func` + arguments. + input_types: (Optional.) A nested structure of @{tf.DType}. If given, this + argument defines the element types and structure for `func` arguments. + add_to_graph: (Optional.) If `True`, the function will be added to the + default graph. + experimental_nested_dataset_support: (Optional.) If `True`, the function + will support @{tf.data.Dataset} objects as arguments and return values. + + Raises: + ValueError: If an invalid combination of `dataset`, `input_classes`, + `input_shapes`, and `input_types` is passed. + """ + if dataset is None: + if input_classes is None or input_shapes is None or input_types is None: + raise ValueError("Either `dataset`, or all of `input_classes`, " + "`input_shapes`, and `input_types` must be specified.") + self._input_shapes = input_shapes + self._input_types = input_types + self._input_classes = input_classes + else: + if not (input_classes is None and input_shapes is None and + input_types is None): + raise ValueError("Either `dataset`, or all of `input_classes`, " + "`input_shapes`, and `input_types` must be specified.") + self._input_shapes = dataset.output_shapes + self._input_types = dataset.output_types + self._input_classes = dataset.output_classes + + self._transformation_name = transformation_name + + # TODO(b/110122868): Enable this support for all `tf.data` functions. + self._nested_dataset_support = experimental_nested_dataset_support + + @function.Defun(*self._defun_args()) + def tf_data_structured_function_wrapper(*args): + """Wrapper for passing nested structures to and from tf.data functions.""" + flat_args = [] + for arg, arg_class, arg_shape, arg_type in zip( + args, + nest.flatten(self._input_classes), + nest.flatten(self._input_shapes), + nest.flatten(self._input_types)): + # TODO(b/110122868): Add a registration mechanism for new component + # types. + if arg_class is sparse_tensor_lib.SparseTensor: + arg = sparse.deserialize_sparse_tensors( + arg, arg_type, arg_shape, arg_class) + arg.indices.set_shape([None, arg_shape.ndims]) + arg.dense_shape.set_shape([arg_shape.ndims]) + elif isinstance(arg_class, _NestedDatasetComponent): + assert self._nested_dataset_support + arg = _VariantDataset(arg, arg_class) + else: + arg.set_shape(arg_shape) + flat_args.append(arg) + nested_args = nest.pack_sequence_as(self._input_classes, flat_args) + if not _should_unpack_args(nested_args): + nested_args = (nested_args,) + + ret = func(*nested_args) + # If `func` returns a list of tensors, `nest.flatten()` and + # `ops.convert_to_tensor()` would conspire to attempt to stack + # those tensors into a single tensor, because the customized + # version of `nest.flatten()` does not recurse into lists. Since + # it is more likely that the list arose from returning the + # result of an operation (such as `tf.py_func()`) that returns a + # list of not-necessarily-stackable tensors, we treat the + # returned value is a `tuple` instead. A user wishing to pack + # the return value into a single tensor can use an explicit + # `tf.stack()` before returning. + if isinstance(ret, list): + ret = tuple(ret) + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + flat_ret = [] + flat_classes = [] + flat_shapes = [] + flat_types = [] + for t in nest.flatten(ret): + # TODO(b/110122868): Add a registration mechanism for new component + # types. + if sparse_tensor_lib.is_sparse(t): + t = sparse_tensor_lib.SparseTensor.from_value(t) + flat_ret.append(sparse.serialize_sparse_tensors(t)) + flat_classes.append(sparse_tensor_lib.SparseTensor) + flat_shapes.append(t.get_shape()) + flat_types.append(t.dtype) + elif isinstance(t, Dataset): + if not self._nested_dataset_support: + raise NotImplementedError( + "The %s transformation does not currently support nested " + "datasets as outputs." % self._transformation_name) + + flat_ret.append(t._as_variant_tensor()) # pylint: disable=protected-access + component = _NestedDatasetComponent(t) + flat_classes.append(component) + flat_shapes.append(component) + flat_types.append(component) + else: + t = ops.convert_to_tensor(t) + flat_ret.append(t) + flat_classes.append(ops.Tensor) + flat_shapes.append(t.get_shape()) + flat_types.append(t.dtype) + + ret = nest.pack_sequence_as(ret, flat_ret) + self._output_classes = nest.pack_sequence_as(ret, flat_classes) + self._output_shapes = nest.pack_sequence_as(ret, flat_shapes) + self._output_types = nest.pack_sequence_as(ret, flat_types) + + _warn_if_collections(transformation_name) + + return flat_ret + + self._function = tf_data_structured_function_wrapper + if add_to_graph: + self._function.add_to_graph(ops.get_default_graph()) + else: + # Use the private method that will execute + # `tf_data_structured_function_wrapper` but delay adding it to the graph + # in case (e.g.) we need to rerun the function. + self._function._create_definition_if_needed() # pylint: disable=protected-access + + def _defun_args(self): + """Returns a flat list of @{tf.DType} for the input element structure.""" + ret = [] + for input_type, input_class in zip(nest.flatten(self._input_types), + nest.flatten(self._input_classes)): + # TODO(b/110122868): Add a registration mechanism for new component types. + if input_class is sparse_tensor_lib.SparseTensor: + ret.append(dtypes.variant) + elif isinstance(input_class, _NestedDatasetComponent): + if not self._nested_dataset_support: + raise NotImplementedError( + "The %s transformation does not currently support nested " + "datasets as inputs." % self._transformation_name) + ret.append(dtypes.variant) + else: + assert isinstance(input_type, dtypes.DType) + ret.append(input_type) + return ret + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + @property + def function(self): + return self._function + + +def flat_structure(dataset): + """Helper for setting `output_shapes` and `output_types` attrs of Dataset ops. + + Most Dataset op constructors expect `output_shapes` and `output_types` + arguments that represent the flattened structure of an element. This helper + function generates these attrs as a keyword argument dictionary, allowing + `Dataset._as_variant_tensor()` implementations to pass + `**flat_structure(self)` to the op constructor. + + Args: + dataset: A @{tf.data.Dataset}. + + Returns: + A dictionary of keyword arguments that can be passed to many Dataset op + constructors. + """ + output_classes = [] + output_shapes = [] + output_types = [] + for output_class, output_shape, output_type in zip( + nest.flatten(dataset.output_classes), nest.flatten(dataset.output_shapes), + nest.flatten(dataset.output_types)): + if isinstance(output_class, _NestedDatasetComponent): + output_classes.append(output_class.output_classes) + output_shapes.append(output_shape.output_shapes) + output_types.append(output_type.output_types) + else: + output_classes.append(output_class) + output_shapes.append(output_shape) + output_types.append(output_type) + + output_classes = nest.pack_sequence_as(dataset.output_classes, output_classes) + output_shapes = nest.pack_sequence_as(dataset.output_shapes, output_shapes) + output_types = nest.pack_sequence_as(dataset.output_types, output_types) + + return { + "output_shapes": + nest.flatten(sparse.as_dense_shapes(output_shapes, output_classes)), + "output_types": + nest.flatten(sparse.as_dense_types(output_types, output_classes)), + } + + class _GeneratorDataset(Dataset): """A `Dataset` that generates elements by invoking a function.""" @@ -1142,137 +1560,26 @@ class _GeneratorDataset(Dataset): init_args_types = nest.pack_sequence_as( init_args, [t.dtype for t in nest.flatten(init_args)]) - @function.Defun(*nest.flatten( - sparse.as_dense_types(init_args_types, init_args_classes))) - def tf_init_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - dense_shapes = sparse.as_dense_shapes(init_args_shapes, init_args_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(init_args_classes, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, init_args_types, init_args_shapes, init_args_classes) - if _should_unpack_args(nested_args): - ret = init_func(*nested_args) - else: - ret = init_func(nested_args) - - # If `init_func` returns a list of tensors, `nest.flatten()` and - # `ops.convert_to_tensor()` would conspire to attempt to stack - # those tensors into a single tensor, because the customized - # version of `nest.flatten()` does not recurse into lists. Since - # it is more likely that the list arose from returning the - # result of an operation (such as `tf.py_func()`) that returns a - # list of not-necessarily-stackable tensors, we treat the - # returned value is a `tuple` instead. A user wishing to pack - # the return value into a single tensor can use an explicit - # `tf.stack()` before returning. - if isinstance(ret, list): - ret = tuple(ret) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._state_classes = sparse.get_classes(ret) - self._state_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._state_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._init_func = tf_init_func - self._init_func.add_to_graph(ops.get_default_graph()) - - # These members will be initialized by `tf_next_func`. - self._output_classes = None - self._output_shapes = None - self._output_types = None - - @function.Defun(*nest.flatten( - sparse.as_dense_types(self._state_types, self._state_classes))) - def tf_next_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(self._state_shapes, - self._state_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(self._state_classes, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, self._state_types, self._state_shapes, - self._state_classes) - if _should_unpack_args(nested_args): - ret = next_func(*nested_args) - else: - ret = next_func(nested_args) - - # If `next_func` returns a list of tensors, `nest.flatten()` and - # `ops.convert_to_tensor()` would conspire to attempt to stack - # those tensors into a single tensor, because the customized - # version of `nest.flatten()` does not recurse into lists. Since - # it is more likely that the list arose from returning the - # result of an operation (such as `tf.py_func()`) that returns a - # list of not-necessarily-stackable tensors, we treat the - # returned value is a `tuple` instead. A user wishing to pack - # the return value into a single tensor can use an explicit - # `tf.stack()` before returning. - if isinstance(ret, list): - ret = tuple(ret) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._output_classes = sparse.get_classes(ret) - self._output_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._output_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._next_func = tf_next_func - self._next_func.add_to_graph(ops.get_default_graph()) - - @function.Defun(*nest.flatten( - sparse.as_dense_types(self._state_types, self._state_classes))) - def tf_finalize_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the state. - dense_shapes = sparse.as_dense_shapes(self._state_shapes, - self._state_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(self._state_classes, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, self._state_types, self._state_shapes, - self._state_classes) - if _should_unpack_args(nested_args): - return finalize_func(*nested_args) - else: - return finalize_func(nested_args) - - self._finalize_func = tf_finalize_func - self._finalize_func.add_to_graph(ops.get_default_graph()) + wrapped_init_func = StructuredFunctionWrapper( + init_func, "GeneratorDataset", input_classes=init_args_classes, + input_shapes=init_args_shapes, input_types=init_args_types) + self._state_classes = wrapped_init_func.output_classes + self._state_shapes = wrapped_init_func.output_shapes + self._state_types = wrapped_init_func.output_types + self._init_func = wrapped_init_func.function + + wrapped_next_func = StructuredFunctionWrapper( + next_func, "GeneratorDataset", input_classes=self._state_classes, + input_shapes=self._state_shapes, input_types=self._state_types) + self._output_classes = wrapped_next_func.output_classes + self._output_shapes = wrapped_next_func.output_shapes + self._output_types = wrapped_next_func.output_types + self._next_func = wrapped_next_func.function + + wrapped_finalize_func = StructuredFunctionWrapper( + finalize_func, "GeneratorDataset", input_classes=self._state_classes, + input_shapes=self._state_shapes, input_types=self._state_types) + self._finalize_func = wrapped_finalize_func.function def _as_variant_tensor(self): return gen_dataset_ops.generator_dataset( @@ -1282,10 +1589,7 @@ class _GeneratorDataset(Dataset): init_func=self._init_func, next_func=self._next_func, finalize_func=self._finalize_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): @@ -1322,16 +1626,7 @@ class ZipDataset(Dataset): # pylint: disable=protected-access return gen_dataset_ops.zip_dataset( [ds._as_variant_tensor() for ds in nest.flatten(self._datasets)], - output_shapes=[ - s - for ds in nest.flatten(self._datasets) - for s in nest.flatten(ds.output_shapes) - ], - output_types=[ - t - for ds in nest.flatten(self._datasets) - for t in nest.flatten(ds.output_types) - ]) + **flat_structure(self)) # pylint: enable=protected-access @property @@ -1376,10 +1671,7 @@ class ConcatenateDataset(Dataset): return gen_dataset_ops.concatenate_dataset( self._input_dataset._as_variant_tensor(), self._dataset_to_concatenate._as_variant_tensor(), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **flat_structure(self)) # pylint: enable=protected-access @property @@ -1417,10 +1709,7 @@ class RepeatDataset(Dataset): return gen_dataset_ops.repeat_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access count=self._count, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): @@ -1444,6 +1733,7 @@ class RangeDataset(Dataset): self._parse_args(*args) def _parse_args(self, *args): + """Parse arguments according to the same rules as the `range()` builtin.""" if len(args) == 1: self._start = self._build_tensor(0, "start") self._stop = self._build_tensor(args[0], "stop") @@ -1467,10 +1757,7 @@ class RangeDataset(Dataset): start=self._start, stop=self._stop, step=self._step, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): @@ -1499,10 +1786,7 @@ class CacheDataset(Dataset): return gen_dataset_ops.cache_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access filename=self._filename, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): @@ -1562,10 +1846,7 @@ class ShuffleDataset(Dataset): seed=self._seed, seed2=self._seed2, reshuffle_each_iteration=self._reshuffle_each_iteration, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): @@ -1593,10 +1874,7 @@ class TakeDataset(Dataset): return gen_dataset_ops.take_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access count=self._count, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): @@ -1624,10 +1902,7 @@ class SkipDataset(Dataset): return gen_dataset_ops.skip_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access count=self._count, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): @@ -1645,21 +1920,28 @@ class SkipDataset(Dataset): class BatchDataset(Dataset): """A `Dataset` that batches contiguous elements from its input.""" - def __init__(self, input_dataset, batch_size): + def __init__(self, input_dataset, batch_size, drop_remainder): """See `Dataset.batch()` for details.""" super(BatchDataset, self).__init__() self._input_dataset = input_dataset self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") + self._drop_remainder = ops.convert_to_tensor( + drop_remainder, dtype=dtypes.bool, name="drop_remainder") def _as_variant_tensor(self): - return gen_dataset_ops.batch_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - batch_size=self._batch_size, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018. + if smart_cond.smart_constant_value(self._drop_remainder) is False: + return gen_dataset_ops.batch_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + batch_size=self._batch_size, + **flat_structure(self)) + else: + return gen_dataset_ops.batch_dataset_v2( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + batch_size=self._batch_size, + drop_remainder=self._drop_remainder, + **flat_structure(self)) @property def output_classes(self): @@ -1669,7 +1951,9 @@ class BatchDataset(Dataset): def output_shapes(self): input_shapes = self._input_dataset.output_shapes return nest.pack_sequence_as(input_shapes, [ - tensor_shape.vector(None).concatenate(s) + tensor_shape.vector( + tensor_util.constant_value(self._batch_size) if smart_cond. + smart_constant_value(self._drop_remainder) else None).concatenate(s) for s in nest.flatten(self._input_dataset.output_shapes) ]) @@ -1678,20 +1962,77 @@ class BatchDataset(Dataset): return self._input_dataset.output_types -def _partial_shape_to_tensor(shape_like): +def _is_padded_shape_compatible_with(padded_shape, input_component_shape): + """Returns `True` if `input_component_shape` can be padded to `padded_shape`. + + Args: + padded_shape: A `tf.TensorShape`. + input_component_shape: A `tf.TensorShape`. + + Returns: + `True` if `input_component_shape` can be padded to `padded_shape`, otherwise + `False`. + """ + + if padded_shape.dims is None or input_component_shape.dims is None: + return True + if len(padded_shape.dims) != len(input_component_shape.dims): + return False + for padded_dim, input_dim in zip( + padded_shape.dims, input_component_shape.dims): + if (padded_dim.value is not None and input_dim.value is not None + and padded_dim.value < input_dim.value): + return False + return True + + +def _padded_shape_to_tensor(padded_shape, input_component_shape): + """Converts `padded_shape` to a `tf.Tensor` representing that shape. + + Args: + padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python + sequence, or a 1-D `tf.Tensor` of `tf.int64` elements. + input_component_shape: A `tf.TensorShape`, with which `padded_shape` must + be compatible. + + Returns: + A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`. + + Raises: + ValueError: If `padded_shape` is not a shape or not compatible with + `input_component_shape`. + TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor. + """ try: - # First attempt to convert the input to a shape, and return the - # "canonical" tensor representation, which uses `-1` in place of - # `None`. - shape_like = tensor_shape.as_shape(shape_like) - return ops.convert_to_tensor( - [dim if dim is not None else -1 for dim in shape_like.as_list()], - dtype=dtypes.int64) + # Try to convert the `padded_shape` to a `tf.TensorShape` + padded_shape_as_shape = tensor_shape.as_shape(padded_shape) + # We will return the "canonical" tensor representation, which uses + # `-1` in place of `None`. + ret = ops.convert_to_tensor( + [dim if dim is not None else -1 + for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64) except (TypeError, ValueError): # The argument was not trivially convertible to a # `tf.TensorShape`, so fall back on the conversion to tensor # machinery. - return ops.convert_to_tensor(shape_like, dtype=dtypes.int64) + ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64) + if ret.shape.dims is not None and len(ret.shape.dims) != 1: + raise ValueError( + "Padded shape %s must be a 1-D tensor of tf.int64 values, but its " + "shape was %s." % (padded_shape, ret.shape)) + if ret.dtype != dtypes.int64: + raise TypeError( + "Padded shape %s must be a 1-D tensor of tf.int64 values, but its " + "element type was %s." % (padded_shape, ret.dtype.name)) + padded_shape_as_shape = tensor_util.constant_value_as_shape(ret) + + if not _is_padded_shape_compatible_with(padded_shape_as_shape, + input_component_shape): + raise ValueError("The padded shape %s is not compatible with the " + "corresponding input component shape %s." + % (padded_shape_as_shape, input_component_shape)) + + return ret def _padding_value_to_tensor(value, output_type): @@ -1718,7 +2059,7 @@ def _padding_value_to_tensor(value, output_type): def _default_padding(input_dataset): - + """Returns default padding tensors in a structure matching `input_dataset`.""" def make_zero(t): if t.base_dtype == dtypes.string: return "" @@ -1733,7 +2074,8 @@ def _default_padding(input_dataset): class PaddedBatchDataset(Dataset): """A `Dataset` that batches and pads contiguous elements from its input.""" - def __init__(self, input_dataset, batch_size, padded_shapes, padding_values): + def __init__(self, input_dataset, batch_size, padded_shapes, padding_values, + drop_remainder): """See `Dataset.batch()` for details.""" super(PaddedBatchDataset, self).__init__() if sparse.any_sparse(input_dataset.output_classes): @@ -1746,23 +2088,51 @@ class PaddedBatchDataset(Dataset): padding_values = ( padding_values if padding_values is not None else _default_padding(input_dataset)) - self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes) + + flat_padded_shapes = nest.flatten_up_to(input_dataset.output_shapes, + padded_shapes) + + flat_padded_shapes_as_tensors = [] + + for input_component_shape, padded_shape in zip( + nest.flatten(input_dataset.output_shapes), flat_padded_shapes): + flat_padded_shapes_as_tensors.append( + _padded_shape_to_tensor(padded_shape, input_component_shape)) + + self._padded_shapes = nest.pack_sequence_as(input_dataset.output_shapes, + flat_padded_shapes_as_tensors) + self._padding_values = nest.map_structure_up_to( input_dataset.output_shapes, _padding_value_to_tensor, padding_values, input_dataset.output_types) + self._drop_remainder = ops.convert_to_tensor( + drop_remainder, dtype=dtypes.bool, name="drop_remainder") def _as_variant_tensor(self): - return gen_dataset_ops.padded_batch_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - batch_size=self._batch_size, - padded_shapes=[ - ops.convert_to_tensor(s, dtype=dtypes.int64) - for s in nest.flatten(self._padded_shapes) - ], - padding_values=nest.flatten(self._padding_values), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018. + if smart_cond.smart_constant_value(self._drop_remainder) is False: + return gen_dataset_ops.padded_batch_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + batch_size=self._batch_size, + padded_shapes=[ + ops.convert_to_tensor(s, dtype=dtypes.int64) + for s in nest.flatten(self._padded_shapes) + ], + padding_values=nest.flatten(self._padding_values), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + else: + return gen_dataset_ops.padded_batch_dataset_v2( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + batch_size=self._batch_size, + padded_shapes=[ + ops.convert_to_tensor(s, dtype=dtypes.int64) + for s in nest.flatten(self._padded_shapes) + ], + padding_values=nest.flatten(self._padding_values), + drop_remainder=self._drop_remainder, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) @property def output_classes(self): @@ -1772,8 +2142,10 @@ class PaddedBatchDataset(Dataset): def output_shapes(self): def _padded_shape_to_batch_shape(s): - return tensor_shape.vector(None).concatenate( - tensor_util.constant_value_as_shape(s)) + return tensor_shape.vector( + tensor_util.constant_value(self._batch_size) if smart_cond. + smart_constant_value(self._drop_remainder) else None).concatenate( + tensor_util.constant_value_as_shape(s)) return nest.map_structure(_padded_shape_to_batch_shape, self._padded_shapes) @@ -1787,6 +2159,24 @@ def _should_unpack_args(args): return type(args) is tuple # pylint: disable=unidiomatic-typecheck +def _warn_if_collections(transformation_name): + """Prints warning message if the current graph uses common graph collections. + + NOTE(mrry): Currently a warning is only generated for lookup tables. Any + variables created will be automatically hoisted out to the outermost scope + using `init_scope()`. Some collections (such as for control-flow contexts) + are benign and should not generate a warning. + + Args: + transformation_name: A human-readable name for the transformation. + """ + if ops.get_default_graph().get_collection(ops.GraphKeys.TABLE_INITIALIZERS): + warnings.warn("Creating lookup tables inside a function passed to %s is not" + " supported. Create each table outside the function, and " + "capture it inside the function to use it." + % transformation_name) + + class MapDataset(Dataset): """A `Dataset` that maps a function over elements in its input.""" @@ -1795,64 +2185,12 @@ class MapDataset(Dataset): super(MapDataset, self).__init__() self._input_dataset = input_dataset - self._output_classes = None - self._output_shapes = None - self._output_types = None - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_map_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - if _should_unpack_args(nested_args): - ret = map_func(*nested_args) - else: - ret = map_func(nested_args) - - # If `map_func` returns a list of tensors, `nest.flatten()` and - # `ops.convert_to_tensor()` would conspire to attempt to stack - # those tensors into a single tensor, because the customized - # version of `nest.flatten()` does not recurse into lists. Since - # it is more likely that the list arose from returning the - # result of an operation (such as `tf.py_func()`) that returns a - # list of not-necessarily-stackable tensors, we treat the - # returned value is a `tuple` instead. A user wishing to pack - # the return value into a single tensor can use an explicit - # `tf.stack()` before returning. - if isinstance(ret, list): - ret = tuple(ret) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._output_classes = sparse.get_classes(ret) - self._output_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._output_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._map_func = tf_map_func - self._map_func.add_to_graph(ops.get_default_graph()) + wrapped_func = StructuredFunctionWrapper( + map_func, "Dataset.map()", input_dataset) + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types + self._map_func = wrapped_func.function def _as_variant_tensor(self): input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access @@ -1860,10 +2198,7 @@ class MapDataset(Dataset): input_t, self._map_func.captured_inputs, f=self._map_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): @@ -1896,10 +2231,7 @@ class ParallelMapDataset(MapDataset): self._map_func.captured_inputs, f=self._map_func, num_parallel_calls=self._num_parallel_calls, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **flat_structure(self)) # pylint: enable=protected-access @@ -1911,47 +2243,22 @@ class FlatMapDataset(Dataset): super(FlatMapDataset, self).__init__() self._input_dataset = input_dataset - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_map_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - if _should_unpack_args(nested_args): - dataset = map_func(*nested_args) - else: - dataset = map_func(nested_args) - - if not isinstance(dataset, Dataset): - raise TypeError("`map_func` must return a `Dataset` object.") - - self._output_classes = dataset.output_classes - self._output_types = dataset.output_types - self._output_shapes = dataset.output_shapes - - return dataset._as_variant_tensor() # pylint: disable=protected-access - - self._map_func = tf_map_func - self._map_func.add_to_graph(ops.get_default_graph()) + wrapped_func = StructuredFunctionWrapper( + map_func, self._transformation_name(), input_dataset, + experimental_nested_dataset_support=True) + if not isinstance(wrapped_func.output_classes, _NestedDatasetComponent): + raise TypeError("`map_func` must return a `Dataset` object.") + self._output_classes = wrapped_func.output_classes.output_classes + self._output_types = wrapped_func.output_types.output_types + self._output_shapes = wrapped_func.output_shapes.output_shapes + self._map_func = wrapped_func.function def _as_variant_tensor(self): return gen_dataset_ops.flat_map_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._map_func.captured_inputs, f=self._map_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): @@ -1965,6 +2272,9 @@ class FlatMapDataset(Dataset): def output_types(self): return self._output_types + def _transformation_name(self): + return "Dataset.flat_map()" + class InterleaveDataset(FlatMapDataset): """A `Dataset` that maps a function over its input and interleaves the result. @@ -1985,10 +2295,10 @@ class InterleaveDataset(FlatMapDataset): self._cycle_length, self._block_length, f=self._map_func, # pylint: disable=protected-access - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **flat_structure(self)) + + def _transformation_name(self): + return "Dataset.interleave()" class FilterDataset(Dataset): @@ -1998,46 +2308,20 @@ class FilterDataset(Dataset): """See `Dataset.filter()` for details.""" super(FilterDataset, self).__init__() self._input_dataset = input_dataset - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_predicate(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - if _should_unpack_args(nested_args): - ret = predicate(*nested_args) - else: - ret = predicate(nested_args) - - ret = ops.convert_to_tensor(ret, dtype=dtypes.bool) - if not (ret.dtype == dtypes.bool and - ret.shape.is_compatible_with(tensor_shape.scalar())): - raise ValueError("`predicate` must return a scalar boolean tensor.") - - return ret - - self._predicate = tf_predicate - self._predicate.add_to_graph(ops.get_default_graph()) + wrapped_func = StructuredFunctionWrapper( + predicate, "Dataset.filter()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.bool and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError("`predicate` must return a scalar boolean tensor.") + self._predicate = wrapped_func.function def _as_variant_tensor(self): return gen_dataset_ops.filter_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access other_arguments=self._predicate.captured_inputs, predicate=self._predicate, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): @@ -2068,10 +2352,7 @@ class PrefetchDataset(Dataset): return gen_dataset_ops.prefetch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access buffer_size=self._buffer_size, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index fd164277b6fd7509403d84c50a62638df7968a03..b6dba4e3ca3874b8e9bc3b7ea92fb91fe41759d8 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -471,9 +471,7 @@ class EagerIterator(object): sparse.as_dense_types(self._output_types, self._output_classes)) self._flat_output_shapes = nest.flatten( sparse.as_dense_shapes(self._output_shapes, self._output_classes)) - self._resource = gen_dataset_ops.iterator( - shared_name="", - container=_generate_shared_name("eageriterator"), + self._resource = gen_dataset_ops.anonymous_iterator( output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) gen_dataset_ops.make_iterator(ds_variant, self._resource) diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py index a73a8b5cdc494d7a14c1a2bcb6aa766dbf819403..066e09969c0ba8f054ada42a40960c7513945963 100644 --- a/tensorflow/python/data/ops/readers.py +++ b/tensorflow/python/data/ops/readers.py @@ -19,8 +19,6 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import convert -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -150,12 +148,12 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset): self._buffer_output_elements, self._prefetch_input_elements, f=self._map_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access + def _transformation_name(self): + return "tf.contrib.data.parallel_interleave()" + @tf_export("data.TFRecordDataset") class TFRecordDataset(dataset_ops.Dataset): diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD index 0fc32d51b9fe581a54519139f3bf12118f8f4028..5fcc62b60b696e05d7674c0bf46f57e71d6cc007 100644 --- a/tensorflow/python/data/util/BUILD +++ b/tensorflow/python/data/util/BUILD @@ -70,6 +70,7 @@ py_library( "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_shape", ], ) diff --git a/tensorflow/python/data/util/convert.py b/tensorflow/python/data/util/convert.py index eeb1d700f3c67a1a2ab627aa8a291755bc2127e4..746b3d66de082d59e8c1e316c51e2a9ab7670e6d 100644 --- a/tensorflow/python/data/util/convert.py +++ b/tensorflow/python/data/util/convert.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape def optional_param_to_tensor(argument_name, @@ -32,3 +33,40 @@ def optional_param_to_tensor(argument_name, else: return constant_op.constant( argument_default, dtype=argument_dtype, name=argument_name) + + +def partial_shape_to_tensor(shape_like): + """Returns a @{tf.Tensor} that represents the given shape. + + Args: + shape_like: A value that can be converted to a @{tf.TensorShape} or a + @{tf.Tensor}. + + Returns: + A 1-D `tf.Tensor` of `tf.int64` elements representing the given shape, where + `-1` is substituted for any unknown dimensions. + """ + try: + # First attempt to convert the input to a shape, and return the + # "canonical" tensor representation, which uses `-1` in place of + # `None`. + shape_like = tensor_shape.as_shape(shape_like) + return ops.convert_to_tensor( + [dim if dim is not None else -1 for dim in shape_like.as_list()], + dtype=dtypes.int64) + except (TypeError, ValueError): + # The argument was not trivially convertible to a + # `tf.TensorShape`, so fall back on the conversion to tensor + # machinery. + ret = ops.convert_to_tensor(shape_like, preferred_dtype=dtypes.int64) + if ret.shape.dims is not None and len(ret.shape.dims) != 1: + raise ValueError("The given shape %s must be a 1-D tensor of tf.int64 " + "values, but the shape was %s." + % (shape_like, ret.shape)) + if ret.dtype != dtypes.int64: + raise TypeError("The given shape %s must be a 1-D tensor of tf.int64 " + "values, but the element type was %s." + % (shape_like, ret.dtype.name)) + + return ret + diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py index 2cb6488070eb422f6c8d56ca5d712cbdf09fa883..6a67093e48c988b01b8137a544078d570aabf74f 100644 --- a/tensorflow/python/data/util/convert_test.py +++ b/tensorflow/python/data/util/convert_test.py @@ -19,7 +19,9 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.util import convert +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -48,6 +50,77 @@ class ConvertTest(test.TestCase): with self.test_session() as sess: self.assertEqual(compat.as_bytes("value"), sess.run(resp)) + def testPartialShapeToTensorKnownDimension(self): + with self.test_session() as sess: + self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor( + tensor_shape.TensorShape([1])))) + self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,)))) + self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor([1]))) + self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor( + constant_op.constant([1], dtype=dtypes.int64)))) + + def testPartialShapeToTensorUnknownDimension(self): + with self.test_session() as sess: + self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( + tensor_shape.TensorShape([None])))) + self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( + (None,)))) + self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( + [None]))) + self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( + [-1]))) + self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( + constant_op.constant([-1], dtype=dtypes.int64)))) + + with self.assertRaisesRegexp( + ValueError, r"The given shape .* must be a 1-D tensor of tf.int64 " + r"values, but the shape was \(2, 2\)."): + convert.partial_shape_to_tensor(constant_op.constant( + [[1, 1], [1, 1]], dtype=dtypes.int64)) + + with self.assertRaisesRegexp( + TypeError, r"The given shape .* must be a 1-D tensor of tf.int64 " + r"values, but the element type was float32."): + convert.partial_shape_to_tensor(constant_op.constant([1., 1.])) + + def testPartialShapeToTensorMultipleDimensions(self): + with self.test_session() as sess: + self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor( + tensor_shape.TensorShape([3, 6])))) + self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor( + (3, 6)))) + self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor( + [3, 6]))) + self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor( + constant_op.constant([3, 6], dtype=dtypes.int64)))) + + self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor( + tensor_shape.TensorShape([3, None])))) + self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor( + (3, None)))) + self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor( + [3, None]))) + self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor( + constant_op.constant([3, -1], dtype=dtypes.int64)))) + + self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor( + tensor_shape.TensorShape([None, None])))) + self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor( + (None, None)))) + self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor( + [None, None]))) + self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor( + constant_op.constant([-1, -1], dtype=dtypes.int64)))) + + def testPartialShapeToTensorScalar(self): + with self.test_session() as sess: + self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor( + tensor_shape.TensorShape([])))) + self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(()))) + self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor([]))) + self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor( + constant_op.constant([], dtype=dtypes.int64)))) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index 7ee3d92cadd5d7081f05f9e8c6cb7a70c8c661dc..32e08021dc80d11baaead68ea062b6dab7a8dfdd 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -17,19 +17,16 @@ """## Functions for working with arbitrarily nested sequences of elements. NOTE(mrry): This fork of the `tensorflow.python.util.nest` module -makes three changes: +makes two changes: -1. It adds support for dictionaries as a level of nesting in nested structures. -2. It removes support for lists as a level of nesting in nested structures. -3. It adds support for `SparseTensorValue` as an atomic element. +1. It removes support for lists as a level of nesting in nested structures. +2. It adds support for `SparseTensorValue` as an atomic element. -The motivation for this change is threefold: +The motivation for this change is twofold: -1. Many input-processing functions (e.g. `tf.parse_example()`) return - dictionaries, and we would like to support them natively in datasets. -2. It seems more natural for lists to be treated (e.g. in Dataset constructors) +1. It seems more natural for lists to be treated (e.g. in Dataset constructors) as tensors, rather than lists of (lists of...) tensors. -3. This is needed because `SparseTensorValue` is implemented as a `namedtuple` +2. This is needed because `SparseTensorValue` is implemented as a `namedtuple` that would normally be flattened and we want to be able to create sparse tensor from `SparseTensorValue's similarly to creating tensors from numpy arrays. @@ -43,6 +40,7 @@ import collections as _collections import six as _six +from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow from tensorflow.python.framework import sparse_tensor as _sparse_tensor @@ -99,15 +97,6 @@ def _yield_value(iterable): yield value -def _yield_flat_nest(nest): - for n in _yield_value(nest): - if is_sequence(n): - for ni in _yield_flat_nest(n): - yield ni - else: - yield n - - def is_sequence(seq): """Returns a true if `seq` is a Sequence or dict (except strings/lists). @@ -123,9 +112,7 @@ def is_sequence(seq): True if the sequence is a not a string or list and is a collections.Sequence. """ - return (isinstance(seq, (_collections.Sequence, dict)) and - not isinstance(seq, _sparse_tensor.SparseTensorValue) and - not isinstance(seq, (list, _six.string_types))) + return _pywrap_tensorflow.IsSequenceForData(seq) def flatten(nest): @@ -140,7 +127,7 @@ def flatten(nest): Returns: A Python list, the flattened version of the input. """ - return list(_yield_flat_nest(nest)) if is_sequence(nest) else [nest] + return _pywrap_tensorflow.FlattenForData(nest) def _recursive_assert_same_structure(nest1, nest2, check_types): @@ -536,4 +523,3 @@ def map_structure_up_to(shallow_tree, func, *inputs): results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] return pack_sequence_as(structure=shallow_tree, flat_sequence=results) - diff --git a/tensorflow/python/data/util/random_seed_test.py b/tensorflow/python/data/util/random_seed_test.py index 33227e82afe6fe1c748693d107d4e9844abb8e09..a809151e6ef57de8a39806b8164f818d94b8a783 100644 --- a/tensorflow/python/data/util/random_seed_test.py +++ b/tensorflow/python/data/util/random_seed_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import test class RandomSeedTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRandomSeed(self): zero_t = constant_op.constant(0, dtype=dtypes.int64, name='zero') one_t = constant_op.constant(1, dtype=dtypes.int64, name='one') diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 183994ddaa72b5961f62f34fb43935b0859a3b25..c025dc8aa58a500ace3e28ba4528abd4f4c38ba7 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -5,7 +5,7 @@ # # ":debug_py": Public Python methods and classes of tfdbg. # For API documentation, see https://www.tensorflow.org/api_docs/python/tfdbg -# For a user interface walkthrough, see https://www.tensorflow.org/programmers_guide/debugger +# For a user interface walkthrough, see https://www.tensorflow.org/guide/debugger # ":grpc_debug_server": Server interface for grpc:// debug URLs. package( @@ -167,6 +167,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:platform", + "//third_party/py/numpy", "@six_archive//:six", ], ) @@ -453,6 +454,17 @@ py_binary( ], ) +py_binary( + name = "debug_keras", + srcs = ["examples/debug_keras.py"], + srcs_version = "PY2AND3", + deps = [ + ":debug_py", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + py_test( name = "common_test", size = "small", @@ -572,6 +584,7 @@ py_test( ":source_utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:client", + "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", @@ -801,6 +814,7 @@ py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//tensorflow/python:platform_test", + "//third_party/py/numpy", ], ) @@ -1003,6 +1017,7 @@ cuda_py_test( "no_oss", # Test flaky due to port collisions. "no_windows", "noasan", # Times out due to size of test (b/73731462). + "optonly", # Test flaky (b/80130873) "oss_serial", ], ) @@ -1082,6 +1097,7 @@ py_test( "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variables", + "//third_party/py/numpy", ], ) @@ -1092,6 +1108,7 @@ sh_test( data = [ ":debug_errors", ":debug_fibonacci", + ":debug_keras", ":debug_mnist", ":debug_tflearn_iris", ":offline_analyzer", diff --git a/tensorflow/python/debug/README.md b/tensorflow/python/debug/README.md index 269bbb19bdb898d1d81d0b9c618a284a437e68b9..9c16af4d79754cee5d77158d5c2466412c6b9e68 100644 --- a/tensorflow/python/debug/README.md +++ b/tensorflow/python/debug/README.md @@ -28,7 +28,7 @@ models: * Easy access through session wrappers * Easy integration with common high-level APIs, such as - [TensorFlow Estimators](https://www.tensorflow.org/programmers_guide/estimators) and + [TensorFlow Estimators](https://www.tensorflow.org/guide/estimators) and [Keras](https://keras.io/) * Inspection of runtime tensor values and node connections * Conditional breaking after runs that generate tensors satisfying given @@ -43,7 +43,7 @@ models: ## How to use TFDBG? -* For a walkthrough of TFDBG command-line interface, see https://www.tensorflow.org/programmers_guide/debugger. +* For a walkthrough of TFDBG command-line interface, see https://www.tensorflow.org/guide/debugger. * For information on the web GUI of TFDBG (TensorBoard Debugger Plugin), see [this README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md). * For programmatic use of the API of TFDBG, see https://www.tensorflow.org/api_docs/python/tfdbg. diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py index dea019fef58015fbd7982a81319dcabe4e5f4930..6a368682de5db12e128f010bfe0c9bbf9cf3b997 100644 --- a/tensorflow/python/debug/cli/cli_shared.py +++ b/tensorflow/python/debug/cli/cli_shared.py @@ -451,42 +451,48 @@ def get_error_intro(tf_error): sample commands for debugging. """ - op_name = tf_error.op.name + if hasattr(tf_error, "op") and hasattr(tf_error.op, "name"): + op_name = tf_error.op.name + else: + op_name = None intro_lines = [ "--------------------------------------", RL("!!! An error occurred during the run !!!", "blink"), "", - "You may use the following commands to debug:", ] out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines) - out.extend( - _recommend_command("ni -a -d -t %s" % op_name, - "Inspect information about the failing op.", - create_link=True)) - out.extend( - _recommend_command("li -r %s" % op_name, - "List inputs to the failing op, recursively.", - create_link=True)) - - out.extend( - _recommend_command( - "lt", - "List all tensors dumped during the failing run() call.", - create_link=True)) + if op_name is not None: + out.extend(debugger_cli_common.RichTextLines( + ["You may use the following commands to debug:"])) + out.extend( + _recommend_command("ni -a -d -t %s" % op_name, + "Inspect information about the failing op.", + create_link=True)) + out.extend( + _recommend_command("li -r %s" % op_name, + "List inputs to the failing op, recursively.", + create_link=True)) + + out.extend( + _recommend_command( + "lt", + "List all tensors dumped during the failing run() call.", + create_link=True)) + else: + out.extend(debugger_cli_common.RichTextLines([ + "WARNING: Cannot determine the name of the op that caused the error."])) more_lines = [ "", - "Op name: " + op_name, + "Op name: %s" % op_name, "Error type: " + str(type(tf_error)), "", "Details:", str(tf_error), "", - "WARNING: Using client GraphDef due to the error, instead of " - "executor GraphDefs.", "--------------------------------------", "", ] diff --git a/tensorflow/python/debug/cli/cli_shared_test.py b/tensorflow/python/debug/cli/cli_shared_test.py index 3d7939490dfe08118ee4972541c4166b2a536608..07b364db9f2aab9c11ecb769a94f36e0809d70a0 100644 --- a/tensorflow/python/debug/cli/cli_shared_test.py +++ b/tensorflow/python/debug/cli/cli_shared_test.py @@ -372,6 +372,11 @@ class GetErrorIntroTest(test_util.TensorFlowTestCase): self.assertEqual("Details:", error_intro.lines[14]) self.assertStartsWith(error_intro.lines[15], "foo description") + def testGetErrorIntroForNoOpName(self): + tf_error = errors.OpError(None, None, "Fake OpError", -1) + error_intro = cli_shared.get_error_intro(tf_error) + self.assertIn("Cannot determine the name of the op", error_intro.lines[3]) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/cli/debugger_cli_common.py b/tensorflow/python/debug/cli/debugger_cli_common.py index 12e79ab07a4655c7d41f41d2e71906273e154a08..02563fde845e7951046a8bcd65899ef5e1fcc35f 100644 --- a/tensorflow/python/debug/cli/debugger_cli_common.py +++ b/tensorflow/python/debug/cli/debugger_cli_common.py @@ -23,9 +23,11 @@ import re import sre_constants import traceback +import numpy as np import six from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python import pywrap_tensorflow_internal from tensorflow.python.platform import gfile HELP_INDENT = " " @@ -131,6 +133,25 @@ def rich_text_lines_from_rich_line_list(rich_text_list, annotations=None): return RichTextLines(lines, font_attr_segs, annotations=annotations) +def get_tensorflow_version_lines(include_dependency_versions=False): + """Generate RichTextLines with TensorFlow version info. + + Args: + include_dependency_versions: Include the version of TensorFlow's key + dependencies, such as numpy. + + Returns: + A formatted, multi-line `RichTextLines` object. + """ + lines = ["TensorFlow version: %s" % pywrap_tensorflow_internal.__version__] + lines.append("") + if include_dependency_versions: + lines.append("Dependency version(s):") + lines.append(" numpy: %s" % np.__version__) + lines.append("") + return RichTextLines(lines) + + class RichTextLines(object): """Rich multi-line text. @@ -538,6 +559,8 @@ class CommandHandlerRegistry(object): HELP_COMMAND = "help" HELP_COMMAND_ALIASES = ["h"] + VERSION_COMMAND = "version" + VERSION_COMMAND_ALIASES = ["ver"] def __init__(self): # A dictionary from command prefix to handler. @@ -562,6 +585,13 @@ class CommandHandlerRegistry(object): "Print this help message.", prefix_aliases=self.HELP_COMMAND_ALIASES) + # Register a default handler for the command "version". + self.register_command_handler( + self.VERSION_COMMAND, + self._version_handler, + "Print the versions of TensorFlow and its key dependencies.", + prefix_aliases=self.VERSION_COMMAND_ALIASES) + def register_command_handler(self, prefix, handler, @@ -763,6 +793,11 @@ class CommandHandlerRegistry(object): else: return RichTextLines(["ERROR: help takes only 0 or 1 input argument."]) + def _version_handler(self, args, screen_info=None): + del args # Unused currently. + del screen_info # Unused currently. + return get_tensorflow_version_lines(include_dependency_versions=True) + def _resolve_prefix(self, token): """Resolve command prefix from the prefix itself or its alias. diff --git a/tensorflow/python/debug/cli/debugger_cli_common_test.py b/tensorflow/python/debug/cli/debugger_cli_common_test.py index 1b7a5962fe7dc4e19446c3e3b0aeab672eb30f1f..aba95e5820b1d8c6b3811fc69328317ce2c3ac64 100644 --- a/tensorflow/python/debug/cli/debugger_cli_common_test.py +++ b/tensorflow/python/debug/cli/debugger_cli_common_test.py @@ -21,6 +21,9 @@ import os import stat import tempfile +import numpy as np + +from tensorflow.python import pywrap_tensorflow_internal from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.framework import test_util from tensorflow.python.platform import gfile @@ -547,7 +550,10 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase): " Show screen width in number of columns.", "", "", "help", " Aliases: h", "", " Print this help message.", "", "", "noop", " Aliases: n, NOOP", "", - " No operation.", " I.e., do nothing.", "", ""], + " No operation.", " I.e., do nothing.", "", "", + "version", " Aliases: ver", "", + " Print the versions of TensorFlow and its key " + "dependencies.", "", ""], output.lines) # Get help for one specific command prefix. @@ -575,7 +581,9 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase): self.assertEqual(help_intro.lines + [ "help", " Aliases: h", "", " Print this help message.", "", "", "noop", " Aliases: n, NOOP", "", " No operation.", - " I.e., do nothing.", "", "" + " I.e., do nothing.", "", "", + "version", " Aliases: ver", "", + " Print the versions of TensorFlow and its key dependencies.", "", "" ], output.lines) @@ -1147,5 +1155,22 @@ class MenuTest(test_util.TensorFlowTestCase): self.assertEqual((40, 50, ["bold"]), output.font_attr_segs[0][2]) +class GetTensorFlowVersionLinesTest(test_util.TensorFlowTestCase): + + def testGetVersionWithoutDependencies(self): + out = debugger_cli_common.get_tensorflow_version_lines() + self.assertEqual(2, len(out.lines)) + self.assertEqual( + "TensorFlow version: %s" % pywrap_tensorflow_internal.__version__, + out.lines[0]) + + def testGetVersionWithDependencies(self): + out = debugger_cli_common.get_tensorflow_version_lines(True) + self.assertIn( + "TensorFlow version: %s" % pywrap_tensorflow_internal.__version__, + out.lines) + self.assertIn(" numpy: %s" % np.__version__, out.lines) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/examples/README.md b/tensorflow/python/debug/examples/README.md index cb4d484092fe39698de1ff11e4d50d4879960e0c..3b431e04dc3565037dc018991bea68ab019e8af0 100644 --- a/tensorflow/python/debug/examples/README.md +++ b/tensorflow/python/debug/examples/README.md @@ -3,7 +3,7 @@ Hi, there! The documentation of **TensorFlow Debugger (tfdbg)** has moved. See the source version at -[this new location](../../../docs_src/programmers_guide/debugger.md). +[this new location](../../../docs_src/guide/debugger.md). See the public website version at -[https://www.tensorflow.org/programmers_guide/debugger](https://www.tensorflow.org/programmers_guide/debugger). +[https://www.tensorflow.org/guide/debugger](https://www.tensorflow.org/guide/debugger). diff --git a/tensorflow/python/debug/examples/debug_keras.py b/tensorflow/python/debug/examples/debug_keras.py new file mode 100644 index 0000000000000000000000000000000000000000..3272d85ade957b254b2c1a0977156179cd71bb9d --- /dev/null +++ b/tensorflow/python/debug/examples/debug_keras.py @@ -0,0 +1,89 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tfdbg example: debugging tf.keras models training on tf.data.Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import numpy as np +import tensorflow as tf + +from tensorflow.python import debug as tf_debug + + +def main(_): + # Create a dummy dataset. + num_examples = 8 + steps_per_epoch = 2 + input_dims = 3 + output_dims = 1 + xs = np.zeros([num_examples, input_dims]) + ys = np.zeros([num_examples, output_dims]) + dataset = tf.data.Dataset.from_tensor_slices( + (xs, ys)).repeat(num_examples).batch(int(num_examples / steps_per_epoch)) + + sess = tf.Session() + if FLAGS.debug: + # Use the command-line interface (CLI) of tfdbg. + sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type) + elif FLAGS.tensorboard_debug_address: + # Use the TensorBoard Debugger Plugin (GUI of tfdbg). + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, FLAGS.tensorboard_debug_address) + tf.keras.backend.set_session(sess) + + # Create a dummy model. + model = tf.keras.Sequential([ + tf.keras.layers.Dense(1, input_shape=[input_dims])]) + model.compile(loss="mse", optimizer="sgd") + + # Train the model using the dummy dataset created above. + model.fit(dataset, epochs=FLAGS.epochs, steps_per_epoch=steps_per_epoch) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--debug", + type="bool", + nargs="?", + const=True, + default=False, + help="Use debugger to track down bad values during training. " + "Mutually exclusive with the --tensorboard_debug_address flag.") + parser.add_argument( + "--ui_type", + type=str, + default="curses", + help="Command-line user interface type (curses | readline).") + parser.add_argument( + "--tensorboard_debug_address", + type=str, + default=None, + help="Connect to the TensorBoard Debugger Plugin backend specified by " + "the gRPC address (e.g., localhost:1234). Mutually exclusive with the " + "--debug flag.") + parser.add_argument( + "--epochs", + type=int, + default=2, + help="Number of epochs to train the model for.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/examples/examples_test.sh b/tensorflow/python/debug/examples/examples_test.sh index 2df6c0b6a2701022e3fed6648208b9708197bebc..2d35b2d8bb10d17decfa404afd5004d3409c06e5 100755 --- a/tensorflow/python/debug/examples/examples_test.sh +++ b/tensorflow/python/debug/examples/examples_test.sh @@ -48,12 +48,14 @@ if [[ -z "${PYTHON_BIN_PATH}" ]]; then DEBUG_ERRORS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_errors" DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist" DEBUG_TFLEARN_IRIS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_tflearn_iris" + DEBUG_KERAS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_keras" OFFLINE_ANALYZER_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/offline_analyzer" else DEBUG_FIBONACCI_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_fibonacci" DEBUG_ERRORS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_errors" DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_mnist" DEBUG_TFLEARN_IRIS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_tflearn_iris" + DEBUG_KERAS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_keras" OFFLINE_ANALYZER_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.cli.offline_analyzer" fi @@ -69,6 +71,12 @@ run exit EOF +cat << EOF | ${DEBUG_ERRORS_BIN} --error=uninitialized_variable --debug --ui_type=readline +run +ni -a -d -t v/read +exit +EOF + cat << EOF | ${DEBUG_MNIST_BIN} --debug --max_steps=1 --fake_data --ui_type=readline run -t 1 run --node_name_filter hidden --op_type_filter MatMul @@ -90,6 +98,11 @@ if [[ -d "${CUSTOM_DUMP_ROOT}" ]]; then exit 1 fi +# Test debugging of tf.keras. +cat << EOF | "${DEBUG_KERAS_BIN}" --debug --ui_type=readline +run -f has_inf_or_nan +EOF + # Test offline_analyzer. echo echo "Testing offline_analyzer" diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index 8a65ad087b3002d8ad93f3a64f48715d26ff62d8..7c96c2878c78d5650f3d1907065cc17c4eb71f5c 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -748,7 +748,7 @@ class DebugDumpDir(object): return sum(len(self._dump_tensor_data[device_name]) for device_name in self._dump_tensor_data) - def _load_partition_graphs(self, partition_graphs, validate): + def _load_partition_graphs(self, client_partition_graphs, validate): """Load and process partition graphs. Load the graphs; parse the input and control input structure; obtain the @@ -757,8 +757,10 @@ class DebugDumpDir(object): tensor dumps. Args: - partition_graphs: A repeated field of GraphDefs representing the - partition graphs executed by the TensorFlow runtime. + client_partition_graphs: A repeated field of GraphDefs representing the + partition graphs executed by the TensorFlow runtime, from the Python + client. These partition graphs are used only if partition graphs + cannot be loaded from the dump directory on the file system. validate: (`bool`) Whether the dump files are to be validated against the partition graphs. @@ -769,24 +771,23 @@ class DebugDumpDir(object): self._debug_graphs = {} self._node_devices = {} - if partition_graphs: - partition_graphs_and_device_names = [ - (partition_graph, None) for partition_graph in partition_graphs] - else: - partition_graphs_and_device_names = [] - for device_name in self._device_names: - partition_graph = None - if device_name in self._dump_graph_file_paths: - partition_graph = _load_graph_def_from_event_file( - self._dump_graph_file_paths[device_name]) - else: - partition_graph = self._find_partition_graph(partition_graphs, - device_name) - if partition_graph: - partition_graphs_and_device_names.append((partition_graph, - device_name)) - else: - logging.warn("Failed to load partition graphs from disk.") + partition_graphs_and_device_names = [] + for device_name in self._device_names: + partition_graph = None + if device_name in self._dump_graph_file_paths: + partition_graph = _load_graph_def_from_event_file( + self._dump_graph_file_paths[device_name]) + else: + logging.warn( + "Failed to load partition graphs for device %s from disk. " + "As a fallback, the client graphs will be used. This " + "may cause mismatches in device names." % device_name) + partition_graph = self._find_partition_graph(client_partition_graphs, + device_name) + + if partition_graph: + partition_graphs_and_device_names.append((partition_graph, + device_name)) for partition_graph, maybe_device_name in partition_graphs_and_device_names: debug_graph = debug_graphs.DebugGraph(partition_graph, diff --git a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py index bd00f738610627a4b3bc7c61476164188a7b460c..676097fde95e2e5a685e8e43f8f38d3e62e7084a 100644 --- a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py +++ b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py @@ -44,7 +44,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase): def _no_rewrite_session_config(self): rewriter_config = rewriter_config_pb2.RewriterConfig( - dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, + min_graph_nodes=-1) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) diff --git a/tensorflow/python/debug/lib/grpc_debug_test_server.py b/tensorflow/python/debug/lib/grpc_debug_test_server.py index 917004694845c752d1f6bf88cc2a203eb8f9ba73..a7be20948df0d88c0861007c926186a469ffa19e 100644 --- a/tensorflow/python/debug/lib/grpc_debug_test_server.py +++ b/tensorflow/python/debug/lib/grpc_debug_test_server.py @@ -245,7 +245,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): self._origin_id_to_strings = [] self._graph_tracebacks = [] self._graph_versions = [] - self._source_files = None + self._source_files = [] def _initialize_toggle_watch_state(self, toggle_watches): self._toggle_watches = toggle_watches @@ -274,7 +274,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): self._origin_id_to_strings = [] self._graph_tracebacks = [] self._graph_versions = [] - self._source_files = None + self._source_files = [] def SendTracebacks(self, request, context): self._call_types.append(request.call_type) @@ -286,7 +286,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): return debug_service_pb2.EventReply() def SendSourceFiles(self, request, context): - self._source_files = request + self._source_files.append(request) return debug_service_pb2.EventReply() def query_op_traceback(self, op_name): @@ -351,9 +351,10 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): if not self._source_files: raise ValueError( "This debug server has not received any source file contents yet.") - for source_file_proto in self._source_files.source_files: - if source_file_proto.file_path == file_path: - return source_file_proto.lines[lineno - 1] + for source_files in self._source_files: + for source_file_proto in source_files.source_files: + if source_file_proto.file_path == file_path: + return source_file_proto.lines[lineno - 1] raise ValueError( "Source file at path %s has not been received by the debug server", file_path) diff --git a/tensorflow/python/debug/lib/source_remote.py b/tensorflow/python/debug/lib/source_remote.py index 4b6b2b995ecd13cffddaa38bd2ec673e6b824574..4afae41bc9a672c2c991f8fd2c3e1e6eecac193f 100644 --- a/tensorflow/python/debug/lib/source_remote.py +++ b/tensorflow/python/debug/lib/source_remote.py @@ -28,6 +28,7 @@ from tensorflow.python.debug.lib import common from tensorflow.python.debug.lib import debug_service_pb2_grpc from tensorflow.python.debug.lib import source_utils from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging from tensorflow.python.profiler import tfprof_logger @@ -95,6 +96,11 @@ def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string): return non_tf_files +def grpc_message_length_bytes(): + """Maximum gRPC message length in bytes.""" + return 4 * 1024 * 1024 + + def _send_call_tracebacks(destinations, origin_stack, is_eager_execution=False, @@ -155,17 +161,28 @@ def _send_call_tracebacks(destinations, source_file_paths.update(_source_file_paths_outside_tensorflow_py_library( [call_traceback.origin_stack], call_traceback.origin_id_to_string)) - debugged_source_files = debug_pb2.DebuggedSourceFiles() + debugged_source_files = [] for file_path in source_file_paths: + source_files = debug_pb2.DebuggedSourceFiles() _load_debugged_source_file( - file_path, debugged_source_files.source_files.add()) + file_path, source_files.source_files.add()) + debugged_source_files.append(source_files) for destination in destinations: channel = grpc.insecure_channel(destination) stub = debug_service_pb2_grpc.EventListenerStub(channel) stub.SendTracebacks(call_traceback) if send_source: - stub.SendSourceFiles(debugged_source_files) + for path, source_files in zip( + source_file_paths, debugged_source_files): + if source_files.ByteSize() < grpc_message_length_bytes(): + stub.SendSourceFiles(source_files) + else: + tf_logging.warn( + "The content of the source file at %s is not sent to " + "gRPC debug server %s, because the message size exceeds " + "gRPC message length limit (%d bytes)." % ( + path, destination, grpc_message_length_bytes())) def send_graph_tracebacks(destinations, diff --git a/tensorflow/python/debug/lib/source_remote_test.py b/tensorflow/python/debug/lib/source_remote_test.py index 27bafa45e1207513e46fd2ae0f92d5bfa686ffd5..29add425e946aadfe941c73e9f9cef4aef3c8a9c 100644 --- a/tensorflow/python/debug/lib/source_remote_test.py +++ b/tensorflow/python/debug/lib/source_remote_test.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +from tensorflow.python.platform import test from tensorflow.python.util import tf_inspect @@ -155,6 +156,51 @@ class SendTracebacksTest(test_util.TensorFlowTestCase): self.assertEqual(["dummy_run_key"], server.query_call_keys()) self.assertEqual([sess.graph.version], server.query_graph_versions()) + def testSourceFileSizeExceedsGrpcMessageLengthLimit(self): + """In case source file size exceeds the grpc message length limit. + + it ought not to have been sent to the server. + """ + this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit" + + # Patch the method to simulate a very small message length limit. + with test.mock.patch.object( + source_remote, "grpc_message_length_bytes", return_value=2): + with session.Session() as sess: + a = variables.Variable(21.0, name="two/a") + a_lineno = line_number_above() + b = variables.Variable(2.0, name="two/b") + b_lineno = line_number_above() + x = math_ops.add(a, b, name="two/x") + x_lineno = line_number_above() + + send_traceback = traceback.extract_stack() + send_lineno = line_number_above() + source_remote.send_graph_tracebacks( + [self._server_address, self._server_address_2], + "dummy_run_key", send_traceback, sess.graph) + + servers = [self._server, self._server_2] + for server in servers: + # Even though the source file content is not sent, the traceback + # should have been sent. + tb = server.query_op_traceback("two/a") + self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb) + tb = server.query_op_traceback("two/b") + self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb) + tb = server.query_op_traceback("two/x") + self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb) + + self.assertIn( + (self._curr_file_path, send_lineno, this_func_name), + server.query_origin_stack()[-1]) + + tf_trace_file_path = ( + self._findFirstTraceInsideTensorFlowPyLibrary(x.op)) + # Verify that the source content is not sent to the server. + with self.assertRaises(ValueError): + self._server.query_source_file_line(tf_trace_file_path, 0) + def testSendEagerTracebacksToSingleDebugServer(self): this_func_name = "testSendEagerTracebacksToSingleDebugServer" send_traceback = traceback.extract_stack() diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index c530204bbf6959f56a72c6e67add91f1e575f067..b9524ce649c7d6d888affacc22cfadd41dbe2e40 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -392,6 +392,9 @@ class BaseDebugWrapperSession(session.SessionInterface): self._default_session_context_manager = None + # A cache for callables created from CallableOptions. + self._cached_callables_from_options = dict() + @property def graph(self): return self._sess.graph @@ -414,7 +417,8 @@ class BaseDebugWrapperSession(session.SessionInterface): options=None, run_metadata=None, callable_runner=None, - callable_runner_args=None): + callable_runner_args=None, + callable_options=None): """Wrapper around Session.run() that inserts tensor watch options. Args: @@ -424,7 +428,12 @@ class BaseDebugWrapperSession(session.SessionInterface): run_metadata: Same as the `run_metadata` arg to regular `Session.run()`. callable_runner: A `callable` returned by `Session.make_callable()`. If not `None`, `fetches` and `feed_dict` must both be `None`. - callable_runner_args: An optional list of arguments to `callable_runner`. + Mutually exclusive with `callable_options`. + callable_runner_args: An optional list of arguments to `callable_runner` + or for `callable_options`. + callable_options: An instance of `config_pb2.CallableOptions`, to be + used with `Session._make_callable_from_options()`. Mutually exclusive + with `callable_runner`. Returns: Simply forwards the output of the wrapped `Session.run()` call. @@ -433,13 +442,17 @@ class BaseDebugWrapperSession(session.SessionInterface): ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner` is not `None` and either or both of `fetches` and `feed_dict` is `None`. """ - if not callable_runner: + if callable_runner and callable_options: + raise ValueError( + "callable_runner and callable_options are mutually exclusive, but " + "are both specified in this call to BaseDebugWrapperSession.run().") + + if not (callable_runner or callable_options): self.increment_run_call_count() - else: - if fetches or feed_dict: - raise ValueError( - "callable_runner and fetches/feed_dict are mutually exclusive, but " - "are used simultaneously.") + elif callable_runner and (fetches or feed_dict): + raise ValueError( + "callable_runner and fetches/feed_dict are mutually exclusive, " + "but are used simultaneously.") empty_fetches = not nest.flatten(fetches) if empty_fetches: @@ -449,6 +462,11 @@ class BaseDebugWrapperSession(session.SessionInterface): if self._is_disabled_thread() or empty_fetches: if callable_runner: return callable_runner(*callable_runner_args) + elif callable_options: + # pylint:disable=protected-access + return self._sess._make_callable_from_options( + callable_options)(*callable_runner_args) + # pylint:enable=protected-access else: return self._sess.run(fetches, feed_dict=feed_dict, @@ -464,19 +482,30 @@ class BaseDebugWrapperSession(session.SessionInterface): if run_start_resp.action == OnRunStartAction.DEBUG_RUN: # Decorate RunOption to fill in debugger tensor watch specifications. - decorated_run_options = options or config_pb2.RunOptions() + decorated_run_options = None + if callable_options: + callable_options_id = id(callable_options) + if callable_options_id not in self._cached_callables_from_options: + # Make a copy of callable_options to avoid mutating it. + new_callable_options = config_pb2.CallableOptions() + new_callable_options.CopyFrom(callable_options) + decorated_run_options = new_callable_options.run_options + else: + decorated_run_options = options or config_pb2.RunOptions() + run_metadata = run_metadata or config_pb2.RunMetadata() - self._decorate_run_options_for_debug( - decorated_run_options, - run_start_resp.debug_urls, - debug_ops=run_start_resp.debug_ops, - node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, - op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, - tensor_dtype_regex_whitelist=( - run_start_resp.tensor_dtype_regex_whitelist), - tolerate_debug_op_creation_failures=( - run_start_resp.tolerate_debug_op_creation_failures)) + if decorated_run_options: + self._decorate_run_options_for_debug( + decorated_run_options, + run_start_resp.debug_urls, + debug_ops=run_start_resp.debug_ops, + node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, + op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, + tensor_dtype_regex_whitelist=( + run_start_resp.tensor_dtype_regex_whitelist), + tolerate_debug_op_creation_failures=( + run_start_resp.tolerate_debug_op_creation_failures)) # Invoke the run() method of the wrapped Session. Catch any TensorFlow # runtime errors. @@ -486,6 +515,19 @@ class BaseDebugWrapperSession(session.SessionInterface): retvals = callable_runner(*callable_runner_args, options=decorated_run_options, run_metadata=run_metadata) + elif callable_options: + # pylint:disable=protected-access + if callable_options_id in self._cached_callables_from_options: + callable_object = self._cached_callables_from_options[ + callable_options_id] + else: + callable_object = self._sess._make_callable_from_options( + new_callable_options) + self._cached_callables_from_options[ + callable_options_id] = callable_object + # pylint:enable=protected-access + retvals = callable_object( + *callable_runner_args, run_metadata=run_metadata) else: retvals = self._sess.run(fetches, feed_dict=feed_dict, @@ -590,7 +632,14 @@ class BaseDebugWrapperSession(session.SessionInterface): run_metadata=kwargs.get("run_metadata", None), callable_runner=runner, callable_runner_args=runner_args) + return wrapped_runner + def _make_callable_from_options(self, callable_options): + def wrapped_runner(*feed_values, **kwargs): + return self.run(None, + run_metadata=kwargs.get("run_metadata", None), + callable_options=callable_options, + callable_runner_args=feed_values) return wrapped_runner @property diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py index 1f9c8fa5a96b4d6826fae0870608e0e737c7cd88..85944fa61118114cc73f9288f3f974f0a5a8a839 100644 --- a/tensorflow/python/debug/wrappers/grpc_wrapper.py +++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py @@ -215,7 +215,8 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): options=None, run_metadata=None, callable_runner=None, - callable_runner_args=None): + callable_runner_args=None, + callable_options=None): if self._send_traceback_and_source_code: self._sent_graph_version = publish_traceback( self._grpc_debug_server_urls, self.graph, feed_dict, fetches, @@ -226,4 +227,5 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): options=options, run_metadata=run_metadata, callable_runner=callable_runner, - callable_runner_args=callable_runner_args) + callable_runner_args=callable_runner_args, + callable_options=callable_options) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index c8625655e51a43a222addedd4beecdd3515d7fb6..668ffb57f10a69ce7e11e889fe613afbd618e823 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -290,6 +290,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): if self._run_call_count == 1: # Show logo at the onset of the first run. help_intro.extend(cli_shared.get_tfdbg_logo()) + help_intro.extend(debugger_cli_common.get_tensorflow_version_lines()) help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:")) help_intro.extend(self._run_info) @@ -466,6 +467,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): if self._run_call_count == 1: output.extend(cli_shared.get_tfdbg_logo()) + output.extend(debugger_cli_common.get_tensorflow_version_lines()) output.extend(self._run_info) if (not self._is_run_start and @@ -594,7 +596,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): # Register tab completion for the filter names. curses_cli.register_tab_comp_context(["run", "r"], list(self._tensor_filters.keys())) - if self._feed_dict: + if self._feed_dict and hasattr(self._feed_dict, "keys"): # Register tab completion for feed_dict keys. feed_keys = [common.get_graph_element_name(key) for key in self._feed_dict.keys()] diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py index b06fa26a935b42709575f8e400e0bda951ffbbc7..05c9eaa4d27319ecf5e12fdeb0a973246c61704a 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py @@ -21,7 +21,10 @@ import os import shutil import tempfile +import numpy as np + from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.cli import cli_shared from tensorflow.python.debug.cli import debugger_cli_common @@ -149,7 +152,13 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): dtypes.float32, shape=([5, 5]), name="sparse_placeholder") self.sparse_add = sparse_ops.sparse_add(self.sparse_ph, self.sparse_ph) - self.sess = session.Session() + rewriter_config = rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) + config_proto = config_pb2.ConfigProto(graph_options=graph_options) + self.sess = session.Session(config=config_proto) # Initialize variable. self.sess.run(variables.global_variables_initializer()) @@ -393,6 +402,113 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): self.assertAllClose(42.0, tensor_runner(41.0, 1.0)) self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"])) + def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self): + variable_1 = variables.Variable( + 10.5, dtype=dtypes.float32, name="variable_1") + a = math_ops.add(variable_1, variable_1, "callable_a") + math_ops.add(a, a, "callable_b") + self.sess.run(variable_1.initializer) + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + for _ in range(2): + callable_output = sess_callable() + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual( + ["callable_a", "callable_b", "variable_1", "variable_1/read"], + node_names) + + def testDebuggingMakeCallableFromOptionsWithOneFeedWorks(self): + ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1") + a = math_ops.add(ph1, ph1, "callable_a") + math_ops.add(a, a, "callable_b") + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.feed.append("callable_ph1") + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + ph1_value = np.array([10.5, -10.5], dtype=np.float32) + + for _ in range(2): + callable_output = sess_callable(ph1_value) + self.assertAllClose( + np.array([42.0, -42.0], dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual(["callable_a", "callable_b"], node_names) + + def testDebuggingMakeCallableFromOptionsWithTwoFeedsWorks(self): + ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1") + ph2 = array_ops.placeholder(dtypes.float32, name="callable_ph2") + a = math_ops.add(ph1, ph2, "callable_a") + math_ops.add(a, a, "callable_b") + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.feed.append("callable_ph1") + callable_options.feed.append("callable_ph2") + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + ph1_value = np.array(5.0, dtype=np.float32) + ph2_value = np.array(16.0, dtype=np.float32) + + for _ in range(2): + callable_output = sess_callable(ph1_value, ph2_value) + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual(["callable_a", "callable_b"], node_names) + + def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self): + variable_1 = variables.Variable( + 10.5, dtype=dtypes.float32, name="variable_1") + a = math_ops.add(variable_1, variable_1, "callable_a") + math_ops.add(a, a, "callable_b") + self.sess.run(variable_1.initializer) + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"], ["run"]], self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.fetch.append("callable_b") + callable_options.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE + + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + run_metadata = config_pb2.RunMetadata() + # Call the callable with a custom run_metadata. + callable_output = sess_callable(run_metadata=run_metadata) + # Verify that step_stats is populated in the custom run_metadata. + self.assertTrue(run_metadata.step_stats) + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(1, len(debug_dumps)) + debug_dump = debug_dumps[0] + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual( + ["callable_a", "callable_b", "variable_1", "variable_1/read"], + node_names) + def testRuntimeErrorShouldBeCaught(self): wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( [["run"], ["run"]], self.sess, dump_root=self._tmp_dir) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 5530193d4e1dd8f351b60c06502b2406c2c75d33..6ede8e4f4d9c549faae3223d400d25b7712bbc74 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -26,12 +26,14 @@ cc_library( "//tensorflow/c/eager:tape", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/python:cpp_python_util", "//tensorflow/python:ndarray_tensor", "//tensorflow/python:ndarray_tensor_bridge", "//tensorflow/python:numpy_lib", "//tensorflow/python:py_seq_tensor", "//tensorflow/python:safe_ptr", - "//util/python:python_headers", + "//third_party/py/numpy:headers", + "//third_party/python_runtime:headers", ], ) @@ -390,3 +392,20 @@ py_library( srcs = ["imperative_grad.py"], srcs_version = "PY2AND3", ) + +cuda_py_test( + name = "memory_test", + size = "medium", + srcs = ["memory_test.py"], + additional_deps = [ + "//tensorflow/python/eager:backprop", + "//tensorflow/python/keras", + "//tensorflow/python/eager:test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], + tags = [ + "optonly", # The test is too slow in non-opt mode + ], +) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 4cdf0a41adf12c6f642222bc5bd98b0e10aefcc5..3e3c82e56a8c957839e420550bfb073d400b4a77 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -20,7 +20,6 @@ from __future__ import print_function import functools import operator -import threading import six @@ -39,6 +38,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export @@ -93,8 +93,8 @@ class _MockOp(object): ) -def _magic_gradient_function(op_name, attr_tuple, num_inputs, - inputs, outputs, out_grads): +def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs, + out_grads): """Calls the gradient function of the op. Args: @@ -116,8 +116,7 @@ def _magic_gradient_function(op_name, attr_tuple, num_inputs, return grad_fn(mock_op, *out_grads) -_gradient_functions = {} -_gradient_functions_lock = threading.Lock() +pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function) _tracing = False @@ -141,22 +140,6 @@ _grad_fn_accepts_none_for_indices = { } -def _get_backward_fn(op_name, attrs, num_inputs, op_inputs, op_outputs): - - def grad_fn(*orig_outputs): - result = _magic_gradient_function(op_name, attrs, num_inputs, - op_inputs, op_outputs, orig_outputs) - if _tracing: - print("Gradient for", op_name, "inputs", op_inputs, "output_grads", - orig_outputs, "gradients", result) - return nest.flatten(result) - - return grad_fn - - -pywrap_tensorflow.TFE_Py_RegisterBackwardFunctionGetter(_get_backward_fn) - - def _record_gradient(op_name, inputs, attrs, results, name): return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs, results, name) @@ -213,27 +196,25 @@ def implicit_val_and_grad(f): # TODO(cais): Remove calls to tf.constant() once the gradients functions # accept lists and np.ndarrays. - def grad_fn(*args): + def grad_fn(*args, **kwds): """Computes the gradient of the wrapped function.""" this_tape = tape.push_new_tape() try: - end_node = f(*args) + end_node = f(*args, **kwds) if end_node is None: raise ValueError("Cannot differentiate a function that returns None; " "did you forget to return a value from {}?".format( f.__name__)) finally: tape.pop_tape(this_tape) - # Sorting variables by id, which is monotonically increasing in construction - # order. This ensures unique order across executions. - # TODO(josh11b): Move the sort to the C++ implementation in pywrap_tfe_src.cc. - variables = list(sorted(this_tape.watched_variables(), - key=lambda v: v.handle._id)) # pylint: disable=protected-access - sources = [x.handle for x in variables] - - if not sources: + # Note: variables are returned in construction order. This ensures unique + # order across executions. + variables = this_tape.watched_variables() + if not variables: raise ValueError("No trainable variables were accessed while the " "function was being computed.") + + sources = [v.handle for v in variables] grad = imperative_grad.imperative_grad(_default_vspace, this_tape, nest.flatten(end_node), @@ -624,7 +605,9 @@ def _zeros(shape, dtype): # TODO(apassos): need to save enough information about variant tensors to do # a zeros return None - cache_key = shape, dtype, device + # pylint: disable=protected-access + cache_key = shape, dtype, device, context.context()._eager_context.mode + # pylint: enable=protected-access cached = _zeros_cache.get(cache_key) if cached is None: cached = _fast_fill(0, shape, dtype) @@ -680,8 +663,8 @@ class GradientTape(object): ```python x = tf.constant(3.0) - with tfe.GradientTape() as g: - with tfe.GradientTape() as gg: + with tf.GradientTape() as g: + with tf.GradientTape() as gg: gg.watch(x) y = x * x dy_dx = gg.gradient(y, x) # Will compute to 6.0 @@ -722,21 +705,21 @@ class GradientTape(object): def __enter__(self): """Enters a context inside which operations are recorded on this tape.""" - self._start_recording() + self._push_tape() return self def __exit__(self, typ, value, traceback): """Exits the recording context, no further operations are traced.""" if self._recording: - self._stop_recording() + self._pop_tape() - def _start_recording(self): + def _push_tape(self): if self._recording: raise ValueError("Tape is already recording.") self._tape = tape.push_new_tape(persistent=self._persistent) self._recording = True - def _stop_recording(self): + def _pop_tape(self): if not self._recording: raise ValueError("Tape is not recording.") tape.pop_tape(self._tape) @@ -751,12 +734,75 @@ class GradientTape(object): for t in nest.flatten(tensor): tape.watch(_handle_or_self(t)) + @tf_contextlib.contextmanager + def stop_recording(self): + """Temporarily stops recording operations on this tape. + + Operations executed while this context manager is active will not be + recorded on the tape. This is useful for reducing the memory used by tracing + all computations. + + For example: + + ``` + with tf.GradientTape(persistent=True) as t: + loss = compute_loss(model) + with t.stop_recording(): + # The gradient computation below is not traced, saving memory. + grads = t.gradient(loss, model.variables) + ``` + + Yields: + None + Raises: + RuntimeError: if the tape is not currently recording. + """ + if self._tape is None: + raise RuntimeError( + "Trying to stop recording a tape which is not recording.") + self._pop_tape() + try: + yield + finally: + self._push_tape() + + def reset(self): + """Clears all information stored in this tape. + + Equivalent to exiting and reentering the tape context manager with a new + tape. For example, the two following code blocks are equivalent: + ``` + with tf.GradientTape() as t: + loss = loss_fn() + with tf.GradientTape() as t: + loss += other_loss_fn() + t.gradient(loss, ...) # Only differentiates other_loss_fn, not loss_fn + + + # The following is equivalent to the above + with tf.GradientTape() as t: + loss = loss_fn() + t.reset() + loss += other_loss_fn() + t.gradient(loss, ...) # Only differentiates other_loss_fn, not loss_fn + ``` + + This is useful if you don't want to exit the context manager for the tape, + or can't because the desired reset point is inside a control flow construct: + + ``` + with tf.GradientTape() as t: + loss = ... + if loss > k: + t.reset() + ``` + """ + self._pop_tape() + self._push_tape() + def watched_variables(self): - # Sorting variables by id, which is monotonically increasing in construction - # order. This ensures unique order across executions. - # TODO(josh11b): Move the sort to the C++ implementation in pywrap_tfe_src.cc. - return list(sorted(self._tape.watched_variables(), - key=lambda v: v.handle._id)) # pylint: disable=protected-access + """Returns variables watched by this tape in order of construction.""" + return self._tape.watched_variables() def gradient(self, target, sources, output_gradients=None): """Computes the gradient using operations recorded in context of this tape. @@ -782,7 +828,7 @@ class GradientTape(object): "non-persistent tapes.") if self._recording: if not self._persistent: - self._stop_recording() + self._pop_tape() else: logging.log_first_n(logging.WARN, "Calling GradientTape.gradient on a persistent " diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index d4b3c8bb5fe95e96fbb2bcb7c57c8f8ada62cad5..ebbd3cd98e892fddb556fc95a4292e05d16fc167 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -46,7 +46,7 @@ from tensorflow.python.training import training class BackpropTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAggregateGradients(self): def fn(x): @@ -221,6 +221,21 @@ class BackpropTest(test.TestCase): self.assertTrue(ordered_variables[0] is v0) self.assertTrue(ordered_variables[1] is v1) + def testTapeStopRecording(self): + with backprop.GradientTape() as t: + x = constant_op.constant(1.0) + with t.stop_recording(): + y = x * x + self.assertEqual(t.gradient(y, x), None) + + def testTapeReset(self): + with backprop.GradientTape() as t: + v = resource_variable_ops.ResourceVariable(1.0) + loss = v * v + t.reset() + loss += v * v + self.assertAllEqual(t.gradient(loss, v), 2.0) + @test_util.assert_no_new_tensors def testGradientNone(self): @@ -236,7 +251,7 @@ class BackpropTest(test.TestCase): g, = backprop.gradients_function(loss, [0])(logits, labels) self.assertAllEqual(g.numpy(), [[-0.5, 0.5]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientWithinTapeBlock(self): v1 = resource_variable_ops.ResourceVariable(1.) self.evaluate(v1.initializer) @@ -250,7 +265,7 @@ class BackpropTest(test.TestCase): grad = t.gradient(loss, v1) self.assertAllEqual(self.evaluate(grad), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestedSelfContexts(self): v1 = resource_variable_ops.ResourceVariable(1.) self.evaluate(v1.initializer) @@ -420,7 +435,7 @@ class BackpropTest(test.TestCase): self.assertEqual(backprop.implicit_grad(f)()[0][0], None) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeRepeatedSource(self): with backprop.GradientTape(persistent=False) as g: x = constant_op.constant(3.0) @@ -430,7 +445,7 @@ class BackpropTest(test.TestCase): self.assertEqual(self.evaluate(grad), [2.0, 2.0]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPersistentGradientTapeRepeatedSource(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -444,7 +459,7 @@ class BackpropTest(test.TestCase): self.assertEqual(self.evaluate(grad), [3.0, 11.0]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeStructure(self): with backprop.GradientTape(persistent=True) as g: # Using different constant values because constant tensors are @@ -467,7 +482,7 @@ class BackpropTest(test.TestCase): [1.0, {'x2': 2.0, 'x3': 3.0}]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTape(self): with backprop.GradientTape() as g: x = constant_op.constant(3.0) @@ -482,7 +497,7 @@ class BackpropTest(test.TestCase): grad = g.gradient(y, [x])[0] self.assertEqual(self.evaluate(grad), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeWithCond(self): x = constant_op.constant(3.0) @@ -503,7 +518,7 @@ class BackpropTest(test.TestCase): dy = g.gradient(y, [x])[0] self.assertEqual(self.evaluate(dy), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeWithWhileLoop(self): i = constant_op.constant(1) x = constant_op.constant(2.) @@ -538,7 +553,7 @@ class BackpropTest(test.TestCase): g.gradient(y, [x]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPersistentTape(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -552,7 +567,7 @@ class BackpropTest(test.TestCase): del g @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testHigherOrderGradient(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -569,7 +584,7 @@ class BackpropTest(test.TestCase): del g @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPersistentNestedTape(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -590,7 +605,7 @@ class BackpropTest(test.TestCase): del g @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeVariable(self): v = resource_variable_ops.ResourceVariable(1.0, name='v') self.evaluate(v.initializer) @@ -599,6 +614,18 @@ class BackpropTest(test.TestCase): grad = g.gradient(y, [v])[0] self.assertAllEqual(self.evaluate(grad), 2.0) + @test_util.assert_no_new_tensors + @test_util.run_in_graph_and_eager_modes + def testNestedGradients(self): + x = constant_op.constant(3.0) + with backprop.GradientTape() as g: + g.watch(x) + y = x * x + z = y * y + dz_dx, dz_dy = g.gradient(z, [x, y]) + self.assertEqual(self.evaluate(dz_dx), 108.0) + self.assertEqual(self.evaluate(dz_dy), 18.0) + @test_util.assert_no_new_tensors def testEmptyParamsForValueAndGradFunction(self): def fn(a, b): @@ -873,6 +900,33 @@ class BackpropTest(test.TestCase): 'did you forget to return a value from fn?'): val_and_grads_fn(x, y) + def testZerosCacheDoesntLeakAcrossModes(self): + with ops.Graph().as_default(): + t = random_ops.random_normal(shape=[100, 2]) + x = random_ops.random_normal(shape=[100, 4]) + dy = random_ops.random_normal(shape=[100, 4]) + with backprop.GradientTape() as gradient_tape: + gradient_tape.watch(x) + x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1) + y1 = x1 ** 2. + y = array_ops.concat([y1, t], axis=1) + + dx = gradient_tape.gradient(y, x, output_gradients=dy) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(dx) + + t = random_ops.random_normal(shape=[100, 2]) + x = random_ops.random_normal(shape=[100, 4]) + dy = random_ops.random_normal(shape=[100, 4]) + with backprop.GradientTape() as gradient_tape: + gradient_tape.watch(x) + x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1) + y1 = x1 ** 2. + y = array_ops.concat([y1, t], axis=1) + + dx = gradient_tape.gradient(y, x, output_gradients=dy) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 9e146f021e813886b42ca72b07122b485901a24b..85b9491903de2ea6ffe1c5ac7ef76efdfda2818b 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -143,7 +143,11 @@ class Context(object): # TODO(agarwal): create and link in some documentation for `execution_mode`. # pylint: disable=redefined-outer-name - def __init__(self, config=None, device_policy=None, execution_mode=None): + def __init__(self, + config=None, + device_policy=None, + execution_mode=None, + server_def=None): """Creates a new Context. Args: @@ -192,6 +196,7 @@ class Context(object): if execution_mode is None: execution_mode = SYNC self._execution_mode = execution_mode + self._server_def = server_def # pylint: enable=redefined-outer-name @@ -231,6 +236,9 @@ class Context(object): opts, self._device_policy) if self._execution_mode == ASYNC: pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True) + if self._server_def is not None: + server_def_str = self._server_def.SerializeToString() + pywrap_tensorflow.TFE_ContextOptionsSetServerDef(opts, server_def_str) self._context_handle = pywrap_tensorflow.TFE_NewContext(opts) finally: pywrap_tensorflow.TFE_DeleteContextOptions(opts) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 120b298171b6272fdd63f76a538fa2c92a58616d..7a7e8cd219858e74cb30f22c194fe86d1a4b5e83 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +import functools import numpy as np @@ -35,6 +36,7 @@ from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import compat @@ -46,8 +48,11 @@ def capture_value(tensor_map, value, dtype, name): """Capture a value from outside the function, to pass in as an extra arg.""" captured_value = tensor_map.get(ops.tensor_id(value), None) if captured_value is None: - captured_value = graph_placeholder( - dtype=dtype or value.dtype, shape=value.shape, name=name) + # Note: setting ops.control_dependencies(None) ensures we always put + # capturing placeholders outside of any control flow context. + with ops.control_dependencies(None): + captured_value = graph_placeholder( + dtype=dtype or value.dtype, shape=value.shape, name=name) if captured_value.dtype == dtypes_module.resource: if ops._USE_C_SHAPES: # pylint: disable=protected-access if isinstance(value, ops.EagerTensor): @@ -222,11 +227,25 @@ def _inference_name(n): return "__inference_%s_%s" % (n, ops.uid()) +def _register(fn): + """Registers the function `fn`.""" + context.context().add_function(fn) + + +_xla_compile_attr = "_XlaCompile" + + # TODO(apassos) get rid of this by splitting framework.function._DefinedFunction # so it doesn't have the definition-generating logic and is just a container for # an already-defined function. class _EagerDefinedFunction(object): - """Function object with the interface of tf _DefinedFunction.""" + """Callable with the interface of `framework.function._DefinedFunction.` + + `_EagerDefinedFunction` encapsulates a function definition and its properties, + and it provides a method for calling the encapsulated function. Some Ops + take functions as attributes, which have type `func`; an instance of this + class may be provided as the value of these `func` attributes. + """ def __init__(self, name, graph, operations, inputs, outputs, attrs): """Initializes an eager defined function. @@ -257,6 +276,7 @@ class _EagerDefinedFunction(object): # It might be worth creating a convenient way to re-use status. pywrap_tensorflow.TF_FunctionSetAttrValueProto( fn, compat.as_str(name), serialized) + self._xla_compile = _xla_compile_attr in attrs # TODO(apassos) avoid creating a FunctionDef (specially to grab the # signature, but also in general it's nice not to depend on it. @@ -268,12 +288,92 @@ class _EagerDefinedFunction(object): if context.executing_eagerly(): _register(fn) self.definition = function_def - self.name = function_def.signature.name + self.name = compat.as_bytes(function_def.signature.name) self.signature = function_def.signature + self._num_outputs = len(self.signature.output_arg) + self._output_types = [o.type for o in self.signature.output_arg] self.grad_func_name = None self.python_grad_func = None self._c_func = c_api_util.ScopedTFFunction(fn) self._grad_func = None + self._graph = graph + self._stateful_ops = tuple(op for op in operations if op.op_def.is_stateful) + + def add_to_graph(self, g): + # pylint: disable=protected-access + if self.name not in g._functions: + g._add_function(self) + for f in self._graph._functions.values(): + if f.name not in g._functions: + g._add_function(f) + # pylint: enable=protected-access + + @property + def stateful_ops(self): + return self._stateful_ops + + def call(self, ctx, args, output_shapes): + """Calls this function with `args` as inputs. + + Function execution respects device annotations only if the function won't + be compiled with xla. + + Args: + ctx: a Context object + args: a list of arguments to supply this function with. + output_shapes: shapes to which outputs should be set; ignored when + executing eagerly. + + Returns: + The outputs of the function call. + """ + + executing_eagerly = ctx.executing_eagerly() + + xla_compile = self._xla_compile or (executing_eagerly and + ctx.device_spec.device_type == "TPU") + + if xla_compile: + # XLA compilation relies upon a custom kernel creator to run functions. + signature = self.signature + if executing_eagerly: + outputs = execute.execute( + str(signature.name), + num_outputs=self._num_outputs, + inputs=args, + attrs=None, + ctx=ctx) + else: + g = ops.get_default_graph() + self.add_to_graph(g) + op = g.create_op( + signature.name, + [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args], + tuple(dtypes_module.DType(x.type) for x in signature.output_arg), + op_def=signature, + name="FunctionCall", + compute_shapes=False) + outputs = op.outputs + if not outputs: + return op + outputs = [outputs] if isinstance( + outputs, (ops.Tensor, type(None))) else list(outputs) + else: + # TODO(akshayka): Either remove this if the FunctionLibraryRuntime + # creates `PartitionedCallOp` kernels by default, or remove the previous + # branch if a TPU kernel is registered for `PartitionedCall`. + outputs = functional_ops.partitioned_call( + args=args, + f=self, + tout=self._output_types, + executing_eagerly=executing_eagerly) + + if executing_eagerly: + return outputs + else: + for i, shape in enumerate(output_shapes): + outputs[i].set_shape(shape) + return outputs def _map_sequence_obj_to_idx(sequence): @@ -297,8 +397,12 @@ def _flatten(sequence): return outputs +# TODO(akshayka): Perhaps rename to something more appropriate. class GraphModeFunction(object): - """Callable object representing a graph-mode function. + """Callable object encapsulating a function definition and its gradient. + + `GraphModeFunction` is a callable that encapsulates a function definition and + is differentiable under `tf.GradientTape` objects. """ def __init__(self, @@ -308,7 +412,7 @@ class GraphModeFunction(object): graph, operations, outputs, - func_outputs, + python_func_outputs, output_shapes, variables=None, attrs=None): @@ -327,9 +431,10 @@ class GraphModeFunction(object): definition. outputs: a flat list of the Tensors in the graph used as outputs to the function - func_outputs: a possibly nested python object which will be returned by - this function. The Tensors in this structure will be replaced by their - corresponding values in outputs. + python_func_outputs: a possibly nested python object which will be + returned by this function. The Tensors in this structure will be + replaced by their corresponding values in outputs. Note that this + structure might contain Python `None`s. output_shapes: List of shapes of all tensors in outputs variables: (optional) List of variables to watch during function execution. @@ -351,9 +456,10 @@ class GraphModeFunction(object): self._function_def = defined_function self._num_outputs = len(defined_function.signature.output_arg) self._ops = operations - self._func_outputs = func_outputs - self._returns = [func_outputs] if isinstance( - func_outputs, (ops.Tensor, type(None))) else _flatten(func_outputs) + self._python_func_outputs = python_func_outputs + self._python_returns = [python_func_outputs] if isinstance( + python_func_outputs, + (ops.Tensor, type(None))) else _flatten(python_func_outputs) self._output_shapes = output_shapes self._variables = variables if variables is not None else [] @@ -368,7 +474,7 @@ class GraphModeFunction(object): c_captured_tensors = set() existing_op_len = len(self._graph.get_operations()) - filtered_outputs = [x for x in self._returns if x is not None] + filtered_outputs = [x for x in self._python_returns if x is not None] self._out_grad_placeholders = [ graph_placeholder(x.dtype, x.shape) for x in filtered_outputs] in_gradients = gradients_impl.gradients( @@ -377,7 +483,7 @@ class GraphModeFunction(object): grad_ys=self._out_grad_placeholders) for op in self._graph.get_operations()[existing_op_len:]: if op.type in ["Variable", "VariableV2", "VarHandleOp"]: - raise ValueError("tfe.defun cannot capture variables created without " + raise ValueError("defun cannot capture variables created without " "using tf.get_variable. Op: %s" % op) c_known_ops.add(op) for i in op.inputs: @@ -409,40 +515,32 @@ class GraphModeFunction(object): backward_outputs, in_gradients, output_shapes, attrs=self._attrs) def _backprop_call(self, args): - """Calls the wrapped function and records the result on a tape.""" + """Calls the wrapped function and records the result on a tape. + + (Only records results on a tape if the function has outputs) + + Args: + args: The tensor inputs to the function. + Returns: + The call output. + """ all_args = args + self._extra_inputs - signature = self._forward_fdef.signature ctx = context.context() - if ctx.executing_eagerly(): - outputs = execute.execute( - str(signature.name), - num_outputs=len(signature.output_arg), - inputs=all_args, - attrs=None, - ctx=ctx) - else: - g = ops.get_default_graph() - g._add_function(self._forward_fdef) # pylint: disable=protected-access - op = g.create_op( - signature.name, - [ops.internal_convert_to_tensor(x, ctx=ctx) for x in all_args], - tuple(dtypes_module.DType(x.type) for x in signature.output_arg), - op_def=signature, - name="FunctionCall", - compute_shapes=False) - outputs = op.outputs - outputs = [outputs] if isinstance( - outputs, (ops.Tensor, type(None))) else list(outputs) - for i, s in enumerate(self._output_shapes): - outputs[i].set_shape(s) - real_outputs = outputs[:len(self._returns)] - side_outputs = outputs[len(self._returns):] + outputs = self._forward_fdef.call(ctx, all_args, self._output_shapes) + if isinstance(outputs, ops.Operation) or outputs is None: + return outputs + + # `real_outputs` are the actual outputs of the inference graph function; + # `side_outputs` are the intermediate Tensors that were added as outputs to + # the forward graph function so that we can compute its gradient. + real_outputs = outputs[:self._num_outputs] + side_outputs = outputs[self._num_outputs:] def backward_function(*args): return self._backward_function(*(list(args) + side_outputs)) # pylint: disable=not-callable tape.record_operation( - signature.name, + self._forward_fdef.signature.name, real_outputs, (args + self._extra_inputs), backward_function) @@ -453,8 +551,8 @@ class GraphModeFunction(object): def output_shapes(self): """The function's output shapes.""" # TODO(ebrevdo): Should we only keep the output shapes associated - # with len(self._returns) outputs? - outputs_list = nest.flatten(self._func_outputs) + # with len(self._python_returns) outputs? + outputs_list = nest.flatten(self._python_func_outputs) j = 0 for i, o in enumerate(outputs_list): if o is not None: @@ -468,12 +566,12 @@ class GraphModeFunction(object): else: outputs_list[i] = self._output_shapes[j] j += 1 - return nest.pack_sequence_as(self._func_outputs, outputs_list) + return nest.pack_sequence_as(self._python_func_outputs, outputs_list) @property def output_dtypes(self): return nest.map_structure( - lambda x: x.dtype if x is not None else None, self._func_outputs) + lambda x: x.dtype if x is not None else None, self._python_func_outputs) @property def captured_inputs(self): @@ -484,17 +582,10 @@ class GraphModeFunction(object): """Returns the name of the function in Eager-compatible format.""" return self._function_def.name.encode("utf-8") - def add_to_graph(self, g): - if self._function_def.name not in g._functions: # pylint: disable=protected-access - g._add_function(self._function_def) # pylint: disable=protected-access - for f in self._graph._functions.values(): # pylint: disable=protected-access - if f.name not in g._functions: # pylint: disable=protected-access - g._add_function(f) # pylint: disable=protected-access - def __call__(self, *args): """Executes the passed function in eager mode.""" for v in self._variables: - if v._trainable: # pylint: disable=protected-access + if v.trainable: tape.watch_variable(v) tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] @@ -505,32 +596,9 @@ class GraphModeFunction(object): return self._backprop_call(tensor_inputs) ctx = context.context() - if ctx.executing_eagerly(): - result = execute.execute( - str(self._func_name), - num_outputs=self._num_outputs, - inputs=tensor_inputs + self._extra_inputs, - attrs=None, - ctx=ctx) - else: - g = ops.get_default_graph() - self.add_to_graph(g) - signature = self._function_def.definition.signature - args = list(tensor_inputs) + self._extra_inputs - op = g.create_op( - signature.name, - [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args], - tuple(dtypes_module.DType(x.type) for x in signature.output_arg), - op_def=signature, - name="FunctionCall", - compute_shapes=False) - result = op.outputs - if not result: - return op - for i, s in enumerate(self._output_shapes): - result[i].set_shape(s) - - return self._build_call_outputs(result) + args = tensor_inputs + self._extra_inputs + outputs = self._function_def.call(ctx, args, self._output_shapes) + return self._build_call_outputs(outputs) def _build_call_outputs(self, result): """Maps the fdef output list to actual output structure. @@ -540,11 +608,12 @@ class GraphModeFunction(object): Returns: The actual call output. """ - if self._func_outputs is None: - return None + if self._python_func_outputs is None: + return result + # Use `nest.flatten` instead of `_flatten` in order to preserve any - # IndexedSlices in `self._func_outputs`. - outputs_list = nest.flatten(self._func_outputs) + # IndexedSlices in `self._python_func_outputs`. + outputs_list = nest.flatten(self._python_func_outputs) j = 0 for i, o in enumerate(outputs_list): if o is not None: @@ -564,7 +633,7 @@ class GraphModeFunction(object): else: outputs_list[i] = result[j] j += 1 - ret = nest.pack_sequence_as(self._func_outputs, outputs_list) + ret = nest.pack_sequence_as(self._python_func_outputs, outputs_list) return ret @@ -580,7 +649,11 @@ def _get_defun_inputs(args): return nest.pack_sequence_as(args, ret) -def _defun_internal(name, func, compiled, args, kwds): +def _deterministic_dict_values(kwds): + return tuple(kwds[key] for key in sorted(kwds)) + + +def _trace_and_define_function(name, func, compiled, args, kwds): """Defines and returns graph-mode version of func.""" graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access with context.graph_mode(): @@ -597,7 +670,8 @@ def _defun_internal(name, func, compiled, args, kwds): tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection( collection) with tmp_graph.as_default(), AutomaticControlDependencies() as a: - func_inputs = _get_defun_inputs(args) + func_args = _get_defun_inputs(args) + func_kwds = _get_defun_inputs(kwds) def convert(x): if x is None: @@ -608,7 +682,7 @@ def _defun_internal(name, func, compiled, args, kwds): this_tape = tape.push_new_tape() try: - func_outputs = func(*func_inputs, **kwds) + func_outputs = func(*func_args, **func_kwds) func_outputs = nest.map_structure(convert, func_outputs) finally: tape.pop_tape(this_tape) @@ -630,10 +704,13 @@ def _defun_internal(name, func, compiled, args, kwds): extra_placeholders = [] output_shapes = tuple( x.shape if isinstance(x, ops.Tensor) else None - for x in outputs_list) + for x in func_def_outputs) - flat_inputs = [x for x in nest.flatten(func_inputs) - if isinstance(x, ops.Tensor)] + func_kwds_values = _deterministic_dict_values(func_kwds) + flat_inputs = [ + x for x in nest.flatten(func_args) + nest.flatten(func_kwds_values) + if isinstance(x, ops.Tensor) + ] all_inputs = flat_inputs + list(extra_placeholders) all_ignored_ops = frozenset(x.op for x in all_inputs) fname = _inference_name(name) @@ -648,7 +725,7 @@ def _defun_internal(name, func, compiled, args, kwds): attrs = {} if compiled: - attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True) + attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True) return GraphModeFunction( fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs, @@ -688,42 +765,89 @@ def _cache_key(x): return x -def _register(fn): - """Registers the function `fn`.""" - context.context().add_function(fn) +class _PolymorphicFunction(object): + """Wrapper class for the graph functions defined for a Python function. + See the documentation for `defun` for more information on the semantics of + defined functions. + """ -# TODO(apassos): better error messages for non-hashable arguments. -def named_defun(func, name, compiled=False): - """Defines a function with a given name. + def __init__(self, python_function, name, compiled=False): + """Initializes a polymorphic function. - See the documentation for `defun` for more information on the semantics of - this function. + Args: + python_function: the function to be wrapped. + name: the name given to it. + compiled: if True, the framework will attempt to compile func with XLA. + """ - Args: - func: the function to be wrapped. - name: the name given to it. - compiled: if true, the framework will attempt to compile func with XLA. + self._python_function = python_function + self._name = name + self._compiled = compiled + self._arguments_to_functions = {} + self._variables = [] + + def __get__(self, instance, owner): + """Makes it possible to defun instance methods.""" + del owner + # `instance` here is the instance that this `_PolymorphicFunction` was + # accessed through; e.g., for + # + # class Foo(object): + # + # @function.defun + # def bar(self): + # ... + # + # foo = Foo() + # foo.bar() # `foo.bar` is a `_PolymorphicFunction` instance + # + # then `instance` will be `foo` (and `owner` will be `Foo`). + return functools.partial(self.__call__, instance) - Returns: - the wrapped function. - """ - arguments_to_functions = {} + def _maybe_define_function(self, *args, **kwds): + """Gets a function for these inputs, defining it if necessary. + + Args: + *args: args for the Python function; used to compute the signature + **kwds: kwds for the Python function; used to compute the signature - def decorated(*args, **kwds): - """Decorated version of func.""" - # Macroexpand on non-Tensor arguments - cache_key = tuple(_cache_key(x) for x in args) - if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): - raise ValueError("Tensor keyword arguments are not supported.") - cache_key = (cache_key, tuple(kwds.items())) + Returns: + A graph function corresponding to the input signature implied by args and + kwds, as well as the inputs that the object should be called with. + """ - if cache_key not in arguments_to_functions: - arguments_to_functions[cache_key] = _defun_internal( - name, func, compiled, args, kwds) - return arguments_to_functions[cache_key](*args) + # TODO(apassos): Better error messages for non-hashable arguments. + kwd_values = _deterministic_dict_values(kwds) + inputs = args + kwd_values + signature = tuple(_cache_key(x) for x in inputs) + # The graph, or whether we're executing eagerly, should be a part of the + # signature so we don't improperly capture tensors such as variables. + signature += tuple([context.executing_eagerly() or ops.get_default_graph()]) + + if signature not in self._arguments_to_functions: + graph_function = _trace_and_define_function( + self._name, self._python_function, self._compiled, args, kwds) + self._arguments_to_functions[signature] = graph_function + self._variables.extend( + [v for v in graph_function.variables if v not in self._variables]) + return graph_function, inputs + else: + return self._arguments_to_functions[signature], inputs - return decorated + def __call__(self, *args, **kwds): + """Calls a graph function specialized for this input signature.""" + graph_function, inputs = self._maybe_define_function(*args, **kwds) + return graph_function(*inputs) + + def call_python_function(self, *args, **kwargs): + """Directly calls the wrapped python function.""" + return self._python_function(*args, **kwargs) + + @property + def variables(self): + """Returns a list of variables used in any of the defined functions.""" + return self._variables # TODO(akshayka): Remove the `compiled` flag and create a separate @@ -734,22 +858,33 @@ def defun(func=None, compiled=False): `defun` (short for "define function") trace-compiles a Python function composed of TensorFlow operations into a callable that executes a @{tf.Graph} - containing those operations. When eager execution is enabled, the ability to - create graphs from Python functions makes it possible to incrementally trade - off debugability and interactivity for performance. Functions compiled with - `defun` cannot be inspected with `pdb` and `print` statements; however, - executing a graph generated by `defun` sometimes takes less time and memory - than eagerly executing the corresponding Python function, since specifying - computations as graphs allows for optimizations like automatic buffer reuse - and parallelization among ops. Note that executing a `defun`-compiled function + containing those operations. The callable produced by `defun` contains only + the subgraph of TensorFlow operations that were executed when the Python + function was called with a particular input signature, defined as a list + of the shapes and dtypes of the Python function's Tensor-valued arguments and + the values of its non-Tensor Python objects. In particular, `defun` is _not_ a + compiler for arbitrary Python code. + + When eager execution is enabled, the ability to create graphs from Python + functions makes it possible to incrementally trade off debugability and + interactivity for performance. Functions compiled with `defun` cannot be + inspected with `pdb` and `print` statements; however, executing a graph + generated by `defun` sometimes takes less time and memory than eagerly + executing the corresponding Python function, since specifying computations as + graphs allows for optimizations like automatic buffer reuse and + parallelization among ops. Note that executing a `defun`-compiled function incurs a small constant overhead, so eagerly executing sufficiently small Python functions might take less time than executing their corresponding `defun`-generated graphs. - For a Python function to be compatible with `defun`, the values of its keyword - arguments cannot be Tensors and all of its arguments, including its keyword - arguments, must be hashable Python objects or lists thereof. Additionally, it - must return zero or more @{tf.Tensor} objects. + For a Python function to be compatible with `defun`, all of its arguments must + be hashable Python objects or lists thereof. Additionally, it must return zero + or more @{tf.Tensor} objects. + + Executing a graph generated by `defun` respects device annotations (i.e., + all `with tf.device` directives present in a Python function will also be + present in its corresponding graph), but it is not yet possible to execute the + generated graphs across multiple machines. _Example Usage_ @@ -777,7 +912,7 @@ def defun(func=None, compiled=False): def h(): return f(x, y) - assert h().numpy() == f(x, y) + assert (h().numpy() == f(x, y).numpy()).all() # `defun` automatically lifts variables out of the graphs it creates, # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and @@ -785,6 +920,7 @@ def defun(func=None, compiled=False): class MyModel(tf.keras.Model): def __init__(self, keep_probability=0.2): + super(MyModel, self).__init__() self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) self.keep_probability = keep_probability @@ -804,7 +940,7 @@ def defun(func=None, compiled=False): # `defun`-compiled functions are differentiable. optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) with tf.GradientTape() as tape: - outputs = model(inputs) + outputs = model(x) gradient = tape.gradient(outputs, model.trainable_variables) optimizer.apply_gradients((grad, var) for grad, var in zip(gradient, model.trainable_variables)) @@ -821,36 +957,47 @@ def defun(func=None, compiled=False): _Tracing and Input Signatures_. The signature of inputs supplied to `F` is defined to be a tuple of the shapes - and dtypes of Tensor-typed arguments and the values of non-Tensor arguments - and keyword arguments. Every time `F` is invoked, the signature of its inputs - are inferred. The first time `F(*args, **kwargs)` is invoked with a particular - signature, `f(*args, **kwargs)` is executed and all the TensorFlow operations - that `f` executes, along with the Tensors that flow between them, are recorded - in a TensorFlow graph. `F` caches this graph and binds it to the inputs' - signature; every subsequent invocation of `F` with inputs conforming to this - signature will immediately retrieve the cached graph and pass it to the - TensorFlow runtime for execution. - - Be aware that because `F` only logs TensorFlow operations, all non-TensorFlow - operations that `f` executes will only shape the _construction_ of the graphs - that `F` executes: They won't be executed when the graphs themselves are - executed. For example, whereas the Python function + and dtypes of Tensor-typed arguments and the values of non-Tensor arguments, + where "arguments" includes both args and kwargs. Every time `F` is invoked, + the signature of its inputs are inferred. The first time `F(*args, **kwargs)` + is invoked with a particular signature, `f(*args, **kwargs)` is executed and + all the TensorFlow operations that `f` executes, along with the Tensors that + flow between them, are recorded in a TensorFlow graph. `F` caches this graph + and binds it to the inputs' signature; every subsequent invocation of `F` with + inputs conforming to this signature will immediately retrieve the cached graph + and pass it to the TensorFlow runtime for execution. + + Be aware that because `F` only logs TensorFlow operations, all the other + Python code that `f` executes will only shape the _construction_ of the graphs + that `F` executes: the Python code won't be executed when the graphs + themselves are executed, though it will be executed every time the Python + function is traced (and a given Python function might be traced multiple + times, once for each input signature it is invoked with). For example, whereas + the Python function ```python import tensorflow as tf import numpy as np - matrix = tf.eye(5) - # `matrix` is assumed to be a Tensor + tf.enable_eager_execution() + def add_noise(): - return matrix + np.random.randn(matrix.shape[0], matrix.shape[1]) + return tf.eye(5) + np.random.randn(5, 5) ``` will return a different output everytime it is invoked, the compiled function `compiled = tf.contrib.eager.defun(add_noise)` will return the same value every time it is called, since a particular random offset generated by NumPy will be inserted into the graph as a TensorFlow constant. The solution is to - replace the call to `np.random.randn` with `tf.random_normal(matrix.shape)`. + replace the call to `np.random.randn` with `tf.random_normal((5, 5))`. + + _Python Side-Effects_ + A corollary of the previous discussion on tracing is the following: If a + Python function `f` has Python side-effects, then executing `f` multiple times + will not necessarily be semantically equivalent to executing `F = + tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact + that `defun` only captures the subgraph of TensorFlow operations that is + constructed when `f` is called in a graph-building context. _Python Control Flow_. The structure of many machine learning computations depend upon whether one is @@ -862,6 +1009,8 @@ def defun(func=None, compiled=False): ```python import tensorflow as tf + tf.enable_eager_execution() + @tf.contrib.eager.defun def lossy_matmul(W, x, training=True): outputs = tf.matmul(W, x) @@ -869,6 +1018,9 @@ def defun(func=None, compiled=False): outputs = tf.nn.dropout(outputs, keep_probability=0.2) return outputs + W = tf.random_normal((3, 5)) + x = tf.random_normal((5, 1)) + # Executes a graph that applies dropout. lossy_outputs = lossy_matmul(W, x, training=True) @@ -919,14 +1071,14 @@ def defun(func=None, compiled=False): # `fn` is a Python function, so x is created, initialized, and destroyed upon # every invocation - assert(fn().numpy() == fn().numpy() == 1.0) + assert fn().numpy() == fn().numpy() == 1.0 compiled = tf.contrib.eager.defun(fn) # Compiling `fn` with `defun` hoists all variables outside of the generated # graph, so initialization happens exactly once. - assert(compiled().numpy() == 1.0) - assert(compiled().numpy() == 2.0) + assert compiled().numpy() == 1.0 + assert compiled().numpy() == 2.0 ``` Finally, because each input signature is bound to a unique graph, if your @@ -972,7 +1124,7 @@ def defun(func=None, compiled=False): except AttributeError: name = "function" return tf_decorator.make_decorator( - function, named_defun(function, name, compiled=compiled)) + function, _PolymorphicFunction(function, name, compiled=compiled)) # This code path is for the `foo = tfe.defun(foo, ...)` use case if func is not None: @@ -1029,15 +1181,8 @@ def make_defun_op(func, *args, **kwds): A wrapper object which can be queried for its output properties, and which can be called directly the way a `@defun` wrapped function can. - - Raises: - ValueError: if any of the keyword arguments to `func` are `EagerTensor` - objects (not yet supported). """ - name = func.__name__ - if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): - raise ValueError("Tensor keyword arguments are not supported.") - return _defun_internal(name, func, False, args, kwds) + return _trace_and_define_function(func.__name__, func, False, args, kwds) class AutomaticControlDependencies(object): @@ -1207,6 +1352,9 @@ class AutomaticControlDependencies(object): # test that it works. Support while loops. Support init_scope escaping from # this. for op in new_operations: + # TODO(apassos) make this code safely support while loops. + if isinstance(op._control_flow_context, control_flow_ops.WhileContext): # pylint: disable=protected-access + continue control_inputs = set() # Ensure stateful ops run if (op.type not in self._graph._registered_ops # pylint: disable=protected-access diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index f53d6c26083cad8efd291a064393561c4bebfcfb..1de25811b4ee2cbee03229e9351baf41517c6bf9 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -19,11 +19,12 @@ from __future__ import print_function import collections +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import tape -from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function as tf_function @@ -34,12 +35,15 @@ from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent +from tensorflow.python.util import compat @test_util.with_c_shapes @@ -90,6 +94,32 @@ class FunctionTest(test.TestCase): self.assertAllEqual(step(), 2.0) + def testGraphGradientVariable(self): + with ops.Graph().as_default(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def f(): + return 2.0 * v + + node = f() + grads, = gradients_impl.gradients(node, v) + v.initializer.run() + self.assertAllEqual(grads.eval(), 2.0) + self.assertEqual(grads.shape, v.shape) + + def testGraphEagerIsolation(self): + + @function.defun + def f(): + v = resource_variable_ops.ResourceVariable(1.0) + return v.read_value() + + self.assertAllEqual(f(), 1.0) + + with ops.Graph().as_default(): + self.assertEqual(f().shape, ()) + def testBasicDefunOpGraphMode(self): matmul = function.defun(math_ops.matmul) @@ -166,6 +196,15 @@ class FunctionTest(test.TestCase): self.assertEqual(fn_op.output_shapes, None) self.assertAllEqual(fn_op(x, x), None) + def testDefunCapturedInt32(self): + x = constant_op.constant(1, dtype=dtypes.int32) + + @function.defun + def add_int32s(): + return x + x + + self.assertEqual(2, int(add_int32s())) + def testDefunReadVariable(self): v = resource_variable_ops.ResourceVariable(1.0) @@ -177,13 +216,14 @@ class FunctionTest(test.TestCase): def testDefunAssignAddVariable(self): v = resource_variable_ops.ResourceVariable(1.0) + x = constant_op.constant(2.0) @function.defun - def f(): - v.assign_add(2.0) + def test_assign_add(): + v.assign_add(x) return v.read_value() - self.assertEqual(3.0, float(f())) + self.assertEqual(3.0, float(test_assign_add())) def testDefunShapeInferenceWithCapturedResourceVariable(self): v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) @@ -196,6 +236,21 @@ class FunctionTest(test.TestCase): compiled = function.defun(f) compiled() + def testVariableInLoopInFunction(self): + + @function.defun + def test_function(): + + def loop_test(_): + return False + + def loop_body(_): + return variable_scope.get_variable('a', shape=()) + + return control_flow_ops.while_loop(loop_test, loop_body, [0.0]) + + self.assertEqual(test_function().shape, []) + def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self): with context.graph_mode(): v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) @@ -349,6 +404,23 @@ class FunctionTest(test.TestCase): g(constant_op.constant(1.0)) + def testNestedDefunWithNoOutputAndTapedInput(self): + three = resource_variable_ops.ResourceVariable(3.0, name='v') + + @function.defun + def f(x): + # This function intentionally takes a taped variable as input, + # but does not return any values + math_ops.add(x, three) + + @function.defun + def g(x): + tape.watch_variable(x) + y = math_ops.add(x, three) + f(y) + + g(three) + def testGradientTensorConversionWithDefun(self): three = resource_variable_ops.ResourceVariable(3.0, name='v') @@ -381,24 +453,33 @@ class FunctionTest(test.TestCase): self.assertAllEqual(f(constant_op.constant(1.0)), 2.0) - def testGradientOfGatherWithDefun(self): + def testGatherResourceWithDefun(self): with ops.device('cpu:0'): v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) - def sum_gather(): - return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) + def sum_gather(): + return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) + + defined = function.defun(sum_gather) + self.assertAllEqual(sum_gather(), defined()) + + def testGradientOfGatherWithDefun(self): + v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) + + def sum_gather(): + return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) - grad_fn = backprop.implicit_grad(sum_gather) - gradient = grad_fn() - defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather)) - defun_gradient = defun_grad_fn() - self.assertEqual(len(gradient), len(defun_gradient)) + grad_fn = backprop.implicit_grad(sum_gather) + gradient = grad_fn() + defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather)) + defun_gradient = defun_grad_fn() + self.assertEqual(len(gradient), len(defun_gradient)) - gradient = gradient[0][0] - defun_gradient = defun_gradient[0][0] - self.assertAllEqual(gradient.values, defun_gradient.values) - self.assertAllEqual(gradient.indices, defun_gradient.indices) - self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape) + gradient = gradient[0][0] + defun_gradient = defun_gradient[0][0] + self.assertAllEqual(gradient.values, defun_gradient.values) + self.assertAllEqual(gradient.indices, defun_gradient.indices) + self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape) def testReturningIndexedSlicesWithDefun(self): @@ -462,6 +543,30 @@ class FunctionTest(test.TestCase): y = f(x, x).cpu() self.assertAllEqual(y, [2.]) + @test_util.run_in_graph_and_eager_modes + def testFunctionWithResourcesOnDifferentDevices(self): + # TODO(akshayka): Remove the `skipTest` once we can whitelist ops as + # safe to be invoked with resources on different devices. + self.skipTest('The Placer disallows ops with resource inputs ' + 'on different devices.') + + with ops.device('/cpu:0'): + v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) + + with ops.device('/gpu:0'): + v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) + + def sum_gather(): + cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2])) + gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2])) + return cpu_result, gpu_result + + defined = function.defun(sum_gather) + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + expected = self.evaluate(sum_gather()) + self.assertAllEqual(expected, self.evaluate(defined())) + def testFunctionHandlesInputsOnDifferentDevices(self): if not context.context().num_gpus(): self.skipTest('No GPUs found') @@ -495,6 +600,60 @@ class FunctionTest(test.TestCase): g = backprop.gradients_function(wrapper, [0])(constant_op.constant(0.0)) self.assertAllEqual(g[0], 1.) + @function.defun + def foo(a): + return None, a * a + + x = constant_op.constant(5.0) + with backprop.GradientTape() as tp: + tp.watch(x) + none, r = foo(x) + g = tp.gradient(r, x) + + self.assertIs(none, None) + self.assertAllEqual(r, 25.0) + self.assertAllEqual(g, 2 * 5.0) + + def testNestedDifferentiableFunction(self): + @function.defun + def inner_fn(a, b): + return a * math_ops.add(a, b) + + @function.defun + def outer_fn(x): + return inner_fn(x, 1.0) + + x = constant_op.constant(5.0) + with backprop.GradientTape() as tp: + tp.watch(x) + result = outer_fn(x) + grad = tp.gradient(result, x) + + self.assertAllEqual(grad, 2 * 5.0 + 1.0) + + def testNestedDifferentiableFunctionNoneOutputs(self): + @function.defun + def foo(a, b): + return None, a * math_ops.add(a, b), None, 2*a + + @function.defun + def bar(x): + return foo(x, 1.0) + + x = constant_op.constant(5.0) + with backprop.GradientTape(persistent=True) as tp: + tp.watch(x) + none1, r1, none2, r2 = bar(x) + g1 = tp.gradient(r1, x) + g2 = tp.gradient(r2, x) + + self.assertAllEqual(r1, 30.0) + self.assertAllEqual(r2, 10.0) + self.assertIs(none1, None) + self.assertIs(none2, None) + self.assertAllEqual(g1, 2 * 5.0 + 1.0) + self.assertAllEqual(g2, 2.0) + def testNoneOutput(self): @function.defun @@ -517,15 +676,15 @@ class FunctionTest(test.TestCase): self.assertAllEqual(3, add_one(constant_op.constant(2))) def testVariableCaptureInNestedFunctions(self): - v = resource_variable_ops.ResourceVariable(1) + v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32) @function.defun - def read(): + def inner_read(): return v.read_value() @function.defun def outer(): - return read() + return inner_read() self.assertEqual(1, int(outer())) @@ -616,6 +775,146 @@ class FunctionTest(test.TestCase): y = model(x) self.assertAllEqual([[[[4.0]]]], y.numpy()) + @test_util.run_in_graph_and_eager_modes( + config=config_pb2.ConfigProto(device_count={'CPU': 3})) + def testDeviceAnnotationsRespected(self): + @function.defun + def multi_device_fn(): + with ops.device('/cpu:0'): + s1 = iterator_ops.Iterator.from_structure( + (dtypes.float32,)).string_handle() + with ops.device('/cpu:1'): + s2 = iterator_ops.Iterator.from_structure( + (dtypes.float32,)).string_handle() + with ops.device('/cpu:2'): + s3 = iterator_ops.Iterator.from_structure( + (dtypes.float32,)).string_handle() + return s1, s2, s3 + + outputs = multi_device_fn() + self.assertTrue(compat.as_bytes('CPU:0') in self.evaluate(outputs[0])) + self.assertTrue(compat.as_bytes('CPU:1') in self.evaluate(outputs[1])) + self.assertTrue(compat.as_bytes('CPU:2') in self.evaluate(outputs[2])) + + def testVariablesAreTracked(self): + v = resource_variable_ops.ResourceVariable(1.0) + + def foo(x): + return v * x + + defined = function.defun(foo) + + x = constant_op.constant([1.0]) + self.assertAllEqual(defined.variables, []) + _ = defined(x) + self.assertAllEqual(defined.variables, [v]) + + x = constant_op.constant([1.0, 2.0]) + _ = defined(x) # ensure the variables list remains the same + self.assertAllEqual(defined.variables, [v]) + + def testTensorKeywordArguments(self): + + def foo(a, b): + del a + return b + + defined = function.defun(foo) + a = constant_op.constant(2.0) + b = constant_op.constant([1.0, 2.0]) + one = defined(a, b) + self.assertEqual(len(defined._arguments_to_functions), 1) + + two = defined(a=a, b=b) + self.assertEqual(len(defined._arguments_to_functions), 1) + + three = defined(b=b, a=a) + self.assertEqual(len(defined._arguments_to_functions), 1) + + four = defined(a, b=b) + self.assertEqual(len(defined._arguments_to_functions), 1) + + # The next call corresponds to a new input signature, hence + # we expect another function to be defined. + five = defined(b, a) + self.assertEqual(len(defined._arguments_to_functions), 2) + + six = defined(a=b, b=a) + self.assertEqual(len(defined._arguments_to_functions), 2) + + seven = defined(b=a, a=b) + self.assertEqual(len(defined._arguments_to_functions), 2) + + self.assertAllEqual(one, [1.0, 2.0]) + self.assertAllEqual(two, [1.0, 2.0]) + self.assertAllEqual(three, [1.0, 2.0]) + self.assertAllEqual(four, [1.0, 2.0]) + self.assertAllEqual(five, 2.0) + self.assertAllEqual(six, 2.0) + self.assertAllEqual(seven, 2.0) + + def testGradientWithKeywordArguments(self): + matmul = function.defun(math_ops.matmul) + + def sq(x): + return matmul(a=x, b=x, transpose_a=True) + + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + grad_t, = backprop.gradients_function(sq, [0])(t) + self.assertAllEqual(grad_t, [[6, 6], [14, 14]]) + + with backprop.GradientTape(persistent=True) as gtape: + gtape.watch(t) + one = matmul(t, b=t, transpose_a=True) + two = matmul(b=t, a=t, transpose_a=True) + three = matmul(a=t, b=t, transpose_a=True) + + for output in [one, two, three]: + self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]]) + + def testGradientInFunctionWithKeywordArguments(self): + + @function.defun + def f(x): + return backprop.gradients_function(lambda y: y * y, [0])(x)[0] + + self.assertAllEqual(f(x=constant_op.constant(1.0)), 2.0) + + def testDecoratingInstanceMethod(self): + + class Foo(object): + + def one(self, tensor): + return tensor + + @function.defun + def two(self, tensor): + return self.one(tensor) + + foo = Foo() + t = constant_op.constant(1.0) + out = foo.two(t) + self.assertEqual(float(out), 1.0) + + def testPythonCallWithSideEffects(self): + state = [] + + @function.defun + def side_effecting_function(): + state.append(0) + + side_effecting_function() + self.assertAllEqual(state, [0]) + + # The second invocation should call the graph function, which shouldn't + # trigger the list append. + side_effecting_function() + self.assertAllEqual(state, [0]) + + # Whereas calling the python function directly should create a side-effect. + side_effecting_function.call_python_function() + self.assertAllEqual(state, [0, 0]) + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): @@ -832,4 +1131,6 @@ class AutomaticControlDependenciesTest(test.TestCase): if __name__ == '__main__': + ops.enable_eager_execution( + config=config_pb2.ConfigProto(device_count={'CPU': 3})) test.main() diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index d9ffcbd2036b9e312967012597ceea22e607d2a7..848adf4fd3b2c93e7b5afb3ec2911857663c29bb 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -110,13 +110,25 @@ class _VariableCapturingScope(object): """ # TODO(apassos) ignoring the regularizer and partitioner here; figure out # how to deal with these. - def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name - partitioner=None, validate_shape=True, - use_resource=None): + def _custom_getter( # pylint: disable=missing-docstring + getter=None, + name=None, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + reuse=None, + trainable=True, + collections=None, + caching_device=None, # pylint: disable=redefined-outer-name + partitioner=None, + validate_shape=True, + use_resource=None, + aggregation=variable_scope.VariableAggregation.NONE, + synchronization=variable_scope.VariableSynchronization.AUTO): del getter, regularizer, partitioner, validate_shape, use_resource, dtype - del collections, initializer, trainable, reuse, caching_device, shape, + del collections, initializer, trainable, reuse, caching_device, shape + del aggregation, synchronization assert name in self.variables v = self.variables[name] return v.variable @@ -136,13 +148,24 @@ class _VariableCapturingScope(object): """ # TODO(apassos) ignoring the regularizer and partitioner here; figure out # how to deal with these. - def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name - partitioner=None, validate_shape=True, - use_resource=None): + def _custom_getter( # pylint: disable=missing-docstring + getter=None, + name=None, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + reuse=None, + trainable=True, + collections=None, + caching_device=None, # pylint: disable=redefined-outer-name + partitioner=None, + validate_shape=True, + use_resource=None, + aggregation=variable_scope.VariableAggregation.NONE, + synchronization=variable_scope.VariableSynchronization.AUTO): del getter, regularizer, collections, caching_device, partitioner - del use_resource, validate_shape + del use_resource, validate_shape, aggregation, synchronization if name in self.tf_variables: if reuse: return self.tf_variables[name].initialized_value() @@ -202,7 +225,7 @@ class _InitializingFunctionObject(object): v.handle).numpy() for v in self._call_fn.variables] if all(x for x in initialized): for v in self._call_fn.variables: - if v._trainable: # pylint: disable=protected-access + if v.trainable: tape.watch_variable(v) return self._call_fn(*args) elif all(not x for x in initialized): diff --git a/tensorflow/python/eager/memory_test.py b/tensorflow/python/eager/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..74c6cbdd319a3a0476adbff08fc6e70fee65df5c --- /dev/null +++ b/tensorflow/python/eager/memory_test.py @@ -0,0 +1,108 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for memory leaks in eager execution. + +It is possible that this test suite will eventually become flaky due to taking +too long to run (since the tests iterate many times), but for now they are +helpful for finding memory leaks since not all PyObject leaks are found by +introspection (test_util decorators). Please be careful adding new tests here. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import keras +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops + +# memory_profiler might not be available in the OSS version of TensorFlow. +try: + import memory_profiler # pylint:disable=g-import-not-at-top +except ImportError: + memory_profiler = None + + +class SingleLayerNet(keras.Model): + """Simple keras model used to ensure that there are no leaks.""" + + def __init__(self): + super(SingleLayerNet, self).__init__() + self.fc1 = keras.layers.Dense(5) + + def call(self, x): + return self.fc1(x) + + +class MemoryTest(test.TestCase): + + def assertNotIncreasingMemory(self, + f, + num_iters=100000, + increase_threshold_absolute_mb=10): + """Assert memory usage doesn't increase beyond given threshold for f.""" + + with context.eager_mode(): + # Warm up. + f() + + initial = memory_profiler.memory_usage(-1)[0] + + for _ in xrange(num_iters): + f() + + increase = memory_profiler.memory_usage(-1)[0] - initial + + assert increase < increase_threshold_absolute_mb, ( + "Increase is too high. Initial memory usage: %f MB. Increase: %f MB. " + "Maximum allowed increase: %f") % (initial, increase, + increase_threshold_absolute_mb) + + def testMemoryLeakInSimpleModelForwardOnly(self): + if memory_profiler is None: + self.skipTest("memory_profiler required to run this test") + + inputs = array_ops.zeros([32, 100], dtypes.float32) + net = SingleLayerNet() + + def f(): + with backprop.GradientTape(): + net(inputs) + + self.assertNotIncreasingMemory(f) + + def testMemoryLeakInSimpleModelForwardAndBackward(self): + if memory_profiler is None: + self.skipTest("memory_profiler required to run this test") + + inputs = array_ops.zeros([32, 100], dtypes.float32) + net = SingleLayerNet() + + def f(): + with backprop.GradientTape() as tape: + result = net(inputs) + + tape.gradient(result, net.variables) + + del tape + + self.assertNotIncreasingMemory(f) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index b3aadd55ce7805f008b22d2b0f88cddc82e7da7a..ea604647faede0e5b86a17938d0a7c8a7621dec1 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -27,8 +27,15 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" +// forward declare +struct EagerTensor; + namespace { +// An instance of _EagerTensorProfiler that will receive callbacks about +// events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all. +PyObject* eager_tensor_profiler = nullptr; + TFE_Context* GetContext(PyObject* ctx) { TFE_Context* context = reinterpret_cast(PyCapsule_GetPointer(ctx, nullptr)); @@ -253,8 +260,45 @@ typedef struct EagerTensor { // to use a TF_Status object. However note that accesses to `status` are not // thread-safe. TF_Status* status; + + PyObject* weakreflist; /* List of weak references */ } EagerTensor; +namespace { + +// Returns true on success - successfully invoked or no profiler registered. +// Returns false if some error occurred. +bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) { + if (eager_tensor_profiler != nullptr) { +#if PY_MAJOR_VERSION < 3 + PyObject* created_method_name = PyString_InternFromString("created"); +#else + PyObject* created_method_name = PyUnicode_InternFromString("created"); +#endif + if (created_method_name == nullptr) { + return false; + } + PyObject* result = PyObject_CallMethodObjArgs( + eager_tensor_profiler, created_method_name, created_tensor, NULL); + if (result == nullptr) { + LOG(ERROR) << "Invoking created() on EagerTensor profiler failed"; + // While we can potentially continue because the error is related to + // profiling, we choose to return an error because: + // - If profiling is used, the user likely wants to stop execution on + // profiling errors. + // - Error in profiling code might have left some state in an invalid + // form that can lead to an error later on. Better to fail fast. + Py_DECREF(created_method_name); + return false; + } + Py_DECREF(created_method_name); + Py_DECREF(result); + } + return true; +} + +} // namespace + // tp_init for EagerTensor. int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { self->id = get_uid(); @@ -266,6 +310,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { Py_INCREF(Py_None); self->tensor_shape = Py_None; self->status = TF_NewStatus(); + self->weakreflist = nullptr; PyObject* value; PyObject* context = nullptr; PyObject* device = nullptr; @@ -299,7 +344,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { GetContext(context), handle.get(), handle_dtype, static_cast(desired_dtype), self->status)); if (TF_GetCode(self->status) != TF_OK) { - PyErr_SetString(PyExc_ValueError, + PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat( "Error while casting from DataType ", handle_dtype, " to ", desired_dtype, ". ", TF_Message(self->status)) @@ -344,11 +389,22 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { if (handle == nullptr) return -1; } self->handle = handle.release(); + + if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) { + return -1; + } + return 0; } // tp_dealloc for EagerTensor. void EagerTensor_dealloc(EagerTensor* self) { + // Clear weak references to self. + // Needs to happen before any actual destruction. + if (self->weakreflist != nullptr) { + PyObject_ClearWeakRefs((PyObject*)self); + } + TF_DeleteStatus(self->status); Py_DECREF(self->handle_data); Py_DECREF(self->keras_mask); @@ -574,43 +630,43 @@ static PyTypeObject _EagerTensorType = { // clang-format off PyVarObject_HEAD_INIT(nullptr, 0) // clang-format on - "EagerTensor", /* tp_name */ - sizeof(EagerTensor), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)EagerTensor_dealloc, /* tp_dealloc */ - nullptr, /* tp_print */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_compare */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - EagerTensor_methods, /* tp_methods */ - nullptr, /* tp_members */ - EagerTensor_getseters, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)EagerTensor_init, /* tp_init */ - nullptr, /* tp_alloc */ - nullptr, /* tp_new */ + "EagerTensor", /* tp_name */ + sizeof(EagerTensor), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)EagerTensor_dealloc, /* tp_dealloc */ + nullptr, /* tp_print */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_compare */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + EagerTensor_methods, /* tp_methods */ + nullptr, /* tp_members */ + EagerTensor_getseters, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)EagerTensor_init, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr, /* tp_new */ }; #endif @@ -641,6 +697,11 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) { t->tensor_shape = Py_None; t->handle = handle; t->status = TF_NewStatus(); + t->weakreflist = nullptr; + + if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) { + return nullptr; + } } return reinterpret_cast(t); } @@ -720,6 +781,18 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { return reinterpret_cast(EagerTensorType); } +PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) { + Py_XDECREF(eager_tensor_profiler); + + if (profiler == Py_None) { + eager_tensor_profiler = nullptr; + } else { + eager_tensor_profiler = profiler; + Py_INCREF(eager_tensor_profiler); + } + Py_RETURN_NONE; +} + PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) { if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) { PyErr_SetString(PyExc_TypeError, @@ -792,3 +865,37 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) { return EagerTensorFromHandle(handle); } + +PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) { + if (!EagerTensor_CheckExact(tensor)) { + PyErr_SetString( + PyExc_TypeError, + tensorflow::strings::StrCat("Expected an EagerTensors but got type \"", + Py_TYPE(tensor)->tp_name, "\"") + .c_str()); + return nullptr; + } + TFE_TensorHandle* handle = EagerTensor_Handle(tensor); + + auto status = tensorflow::make_safe(TF_NewStatus()); + TFE_TensorDebugInfo* debug_info = + TFE_TensorHandleTensorDebugInfo(handle, status.get()); + if (TF_GetCode(status.get()) != TF_OK) { + PyErr_SetString( + PyExc_RuntimeError, + tensorflow::strings::StrCat("Error retrieving tensor's device shape: ", + TF_Message(status.get())) + .c_str()); + return nullptr; + } + + int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info); + PyObject* shape = PyTuple_New(rank); + for (int i = 0; i < rank; ++i) { + tensorflow::int64 dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i); + PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size)); + } + TFE_DeleteTensorDebugInfo(debug_info); + + return shape; +} diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 691b613e48b217c595fe0f3249c493facf756d47..a916a75f00cafc077c422cc6aee6828d07e6188d 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -16,10 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_ #define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_ +#include + #include "tensorflow/c/eager/c_api.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" -#include typedef tensorflow::gtl::InlinedVector TFE_InputTensorHandles; @@ -66,14 +67,15 @@ PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e); // This function is not thread-safe. PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e); -// Registers e as the backward_function_getter. -// The registered function creates a backward function (a function that can -// return the gradient of the inputs an op given the gradient of it's outputs). -// The registered function will be passed the following arguments: -// op_name, attrs, num_inputs, op_inputs, op_outputs +// Registers e as the gradient_function. +// The registered function takes +// (op_name, attrs, num_inputs, inputs, outputs, output_gradients) and returns +// the input gradients. This function will not correctly be able to generate +// gradients for functional ops - the gradients for those ops are calculated +// through a different codepath (see function.py for additional information). // // This function is not thread-safe. -PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e); +PyObject* TFE_Py_RegisterGradientFunction(PyObject* e); // Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using // `exception` if not nullptr, else using the class registered via @@ -113,6 +115,15 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o); // newly created type, or nullptr on error. PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); +// Sets `profiler` as the current profiler to receive callbacks about events +// on eager tensors. Currently, the only reported event is creation. +// `profiler` is expected to have a `created(self, eager_tensor)` method that +// takes the created tensor as its single argument. +// Previous profiler, if any, is unset and will not receive any more +// callbacks. +// To unset the profiler, pass Py_None as the value of `profiler`. +PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler); + // Creates a new tape and adds it to the active set. `persistent` must be a // PyBool_Type, i.e either Py_True or Py_False PyObject* TFE_Py_TapeSetNew(PyObject* persistent); @@ -120,6 +131,9 @@ PyObject* TFE_Py_TapeSetNew(PyObject* persistent); // Removes the passed tape from the set of active tapes. void TFE_Py_TapeSetRemove(PyObject* tape); +// Adds the passed tape to the set of active tapes. +void TFE_Py_TapeSetAdd(PyObject* tape); + // Returns true if the tape stack is empty. PyObject* TFE_Py_TapeSetIsEmpty(); @@ -183,7 +197,8 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, PyObject* results, PyObject* name); -// Returns the set of variables watched by the given tape. +// Returns all variables watched by the given tape in the order those variables +// were created. PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape); // Returns an EagerTensor of dimension [len(`tensors`)] containing @@ -198,4 +213,8 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape); // tensors in `tensors`. PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim); +// Returns the shape of this tensor's on-device representation. +// The shape is represented as a Python tuple of integers. +PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor); + #endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_ diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 48a5b21dc7fba9f73775cf87d0ba1ce0c7c03def..57b4dab51cc766042dfa895b197b3e3de037269d 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/python/eager/pywrap_tensor.h" #include "tensorflow/python/lib/core/safe_ptr.h" +#include "tensorflow/python/util/util.h" using tensorflow::string; using tensorflow::strings::Printf; @@ -45,12 +46,14 @@ struct InputInfo { bool is_list = false; }; +// Takes in output gradients, returns input gradients. +typedef std::function PyBackwardFunction; + using AttrToInputsMap = tensorflow::gtl::FlatMap>; -tensorflow::mutex all_attr_to_input_maps_lock( - tensorflow::LINKER_INITIALIZED); +tensorflow::mutex all_attr_to_input_maps_lock(tensorflow::LINKER_INITIALIZED); tensorflow::gtl::FlatMap* GetAllAttrToInputsMaps() { static auto* all_attr_to_input_maps = new tensorflow::gtl::FlatMap; @@ -174,6 +177,8 @@ bool IsInteger(PyObject* py_value) { #endif } +// This function considers a Dimension._value of None to be valid, and sets the +// value to be -1 in that case. bool ParseDimensionValue(const string& key, PyObject* py_value, TF_Status* status, int64_t* value) { if (IsInteger(py_value)) { @@ -191,18 +196,29 @@ bool ParseDimensionValue(const string& key, PyObject* py_value, return false; } + if (dimension_value.get() == Py_None) { + *value = -1; + return true; + } + return ParseInt64Value(key, dimension_value.get(), status, value); } bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status, - const char** value) { + tensorflow::StringPiece* value) { if (PyBytes_Check(py_value)) { - *value = PyBytes_AsString(py_value); + Py_ssize_t size = 0; + char* buf = nullptr; + if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false; + *value = tensorflow::StringPiece(buf, size); return true; } #if PY_MAJOR_VERSION >= 3 if (PyUnicode_Check(py_value)) { - *value = PyUnicode_AsUTF8(py_value); + Py_ssize_t size = 0; + char* buf = PyUnicode_AsUTF8AndSize(py_value, &size); + if (buf == nullptr) return false; + *value = tensorflow::StringPiece(buf, size); return true; } #endif @@ -265,8 +281,16 @@ bool SetOpAttrList( } if (type == TF_ATTR_STRING) { - PARSE_LIST(const char*, ParseStringValue); - TFE_OpSetAttrStringList(op, key, values.get(), num_values); + std::unique_ptr values(new const void*[num_values]); + std::unique_ptr lengths(new size_t[num_values]); + for (int i = 0; i < num_values; ++i) { + tensorflow::StringPiece value; + tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); + if (!ParseStringValue(key, py_value.get(), status, &value)) return false; + values[i] = value.data(); + lengths[i] = value.size(); + } + TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); } else if (type == TF_ATTR_INT) { PARSE_LIST(int64_t, ParseInt64Value); TFE_OpSetAttrIntList(op, key, values.get(), num_values); @@ -369,12 +393,15 @@ void SetOpAttrListDefault( TF_Status* status) { if (type == TF_ATTR_STRING) { int num_values = attr.default_value().list().s_size(); - std::unique_ptr values(new const char*[num_values]); + std::unique_ptr values(new const void*[num_values]); + std::unique_ptr lengths(new size_t[num_values]); (*attr_list_sizes)[key] = num_values; for (int i = 0; i < num_values; i++) { - values[i] = attr.default_value().list().s(i).data(); + const string& v = attr.default_value().list().s(i); + values[i] = v.data(); + lengths[i] = v.size(); } - TFE_OpSetAttrStringList(op, key, values.get(), num_values); + TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); } else if (type == TF_ATTR_INT) { int num_values = attr.default_value().list().i_size(); std::unique_ptr values(new int64_t[num_values]); @@ -460,9 +487,9 @@ bool SetOpAttrScalar( tensorflow::gtl::FlatMap* attr_list_sizes, TF_Status* status) { if (type == TF_ATTR_STRING) { - const char* value; + tensorflow::StringPiece value; if (!ParseStringValue(key, py_value, status, &value)) return false; - TFE_OpSetAttrString(op, key, value); + TFE_OpSetAttrString(op, key, value.data(), value.size()); } else if (type == TF_ATTR_INT) { int64_t value; if (!ParseInt64Value(key, py_value, status, &value)) return false; @@ -523,7 +550,7 @@ bool SetOpAttrScalar( // (which is what the various "defun" or "Defun" decorators do). // And in the future also allow an object that can encapsulate // the function name and its attribute values. - const char* func_name = nullptr; + tensorflow::StringPiece func_name; if (!ParseStringValue(key, py_value, status, &func_name)) { PyObject* name_attr = PyObject_GetAttrString(py_value, "name"); if (name_attr == nullptr || @@ -539,7 +566,8 @@ bool SetOpAttrScalar( return false; } } - TFE_Op* func = TFE_NewOp(ctx, func_name, status); + TFE_Op* func = TFE_NewOp( + ctx, string(func_name.data(), func_name.size()).c_str(), status); if (TF_GetCode(status) != TF_OK) return false; TFE_OpSetAttrFunction(op, key, func); TFE_DeleteOp(func); @@ -634,8 +662,8 @@ PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr; // Python subclass of Exception that is created to signal fallback. PyObject* fallback_exception_class = nullptr; -// Python function that returns a backward_function. -PyObject* backward_function_getter = nullptr; +// Python function that returns input gradients given output gradients. +PyObject* gradient_function = nullptr; PyTypeObject* resource_variable_type = nullptr; @@ -728,26 +756,26 @@ PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) { } } -PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e) { - if (backward_function_getter != nullptr) { - Py_DECREF(backward_function_getter); +PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) { + if (gradient_function != nullptr) { + Py_DECREF(gradient_function); } if (!PyCallable_Check(e)) { - backward_function_getter = nullptr; + gradient_function = nullptr; PyErr_SetString(PyExc_TypeError, "TFE_Py_RegisterBackwardFunctionGetter: " "Registered object should be function."); return nullptr; } else { Py_INCREF(e); - backward_function_getter = e; + gradient_function = e; Py_RETURN_NONE; } } void RaiseFallbackException(const char* message) { if (fallback_exception_class != nullptr) { - PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message)); + PyErr_SetString(fallback_exception_class, message); return; } @@ -765,8 +793,9 @@ int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) { if (exception == nullptr) { tensorflow::mutex_lock l(exception_class_mutex); if (exception_class != nullptr) { - PyErr_SetObject(exception_class, - Py_BuildValue("si", msg, TF_GetCode(status))); + tensorflow::Safe_PyObjectPtr val( + Py_BuildValue("si", msg, TF_GetCode(status))); + PyErr_SetObject(exception_class, val.get()); return -1; } else { exception = PyExc_RuntimeError; @@ -784,7 +813,8 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status, if (exception == nullptr) { tensorflow::mutex_lock l(exception_class_mutex); if (exception_class != nullptr) { - PyErr_SetObject(exception_class, Py_BuildValue("si", msg, status.code())); + tensorflow::Safe_PyObjectPtr val(Py_BuildValue("si", msg, status.code())); + PyErr_SetObject(exception_class, val.get()); return -1; } else { exception = PyExc_RuntimeError; @@ -862,41 +892,70 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) { } class GradientTape - : public tensorflow::eager::GradientTape { + : public tensorflow::eager::GradientTape { public: explicit GradientTape(bool persistent) - : tensorflow::eager::GradientTape(persistent) {} + : tensorflow::eager::GradientTape( + persistent) {} virtual ~GradientTape() { - for (PyObject* v : watched_variables_) { - Py_DECREF(v); + for (const IdAndVariable& v : watched_variables_) { + Py_DECREF(v.variable); } } void WatchVariable(PyObject* v) { - auto insert_result = watched_variables_.insert(v); - if (insert_result.second) { - // Only increment the reference count if we aren't already watching this - // variable. - Py_INCREF(v); - } - PyObject* handle = PyObject_GetAttrString(v, "handle"); + tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle")); if (handle == nullptr) { return; } - tensorflow::int64 id = FastTensorId(handle); - Py_DECREF(handle); + tensorflow::int64 id = FastTensorId(handle.get()); + if (!PyErr_Occurred()) { this->Watch(id); } + + tensorflow::mutex_lock l(watched_variables_mu_); + auto insert_result = watched_variables_.emplace(id, v); + + if (insert_result.second) { + // Only increment the reference count if we aren't already watching this + // variable. + Py_INCREF(v); + } } - const std::unordered_set WatchedVariables() { - return watched_variables_; + PyObject* GetVariablesAsPyTuple() { + tensorflow::mutex_lock l(watched_variables_mu_); + PyObject* result = PyTuple_New(watched_variables_.size()); + Py_ssize_t pos = 0; + for (const IdAndVariable& id_and_variable : watched_variables_) { + PyTuple_SET_ITEM(result, pos++, id_and_variable.variable); + Py_INCREF(id_and_variable.variable); + } + return result; } private: - std::unordered_set watched_variables_; + // We store an IdAndVariable in the map since the map needs to be locked + // during insert, but should not call back into python during insert to avoid + // deadlocking with the GIL. + struct IdAndVariable { + tensorflow::int64 id; + PyObject* variable; + + IdAndVariable(tensorflow::int64 id, PyObject* variable) + : id(id), variable(variable) {} + }; + struct CompareById { + bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const { + return lhs.id < rhs.id; + } + }; + + tensorflow::mutex watched_variables_mu_; + std::set watched_variables_ + GUARDED_BY(watched_variables_mu_); }; typedef struct { @@ -1009,6 +1068,14 @@ PyObject* TFE_Py_TapeSetNew(PyObject* persistent) { return reinterpret_cast(tape); } +void TFE_Py_TapeSetAdd(PyObject* tape) { + Py_INCREF(tape); + if (!GetTapeSet()->insert(reinterpret_cast(tape)).second) { + // Already exists in the tape set. + Py_DECREF(tape); + } +} + PyObject* TFE_Py_TapeSetIsEmpty() { if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) { Py_RETURN_TRUE; @@ -1180,13 +1247,7 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) { } PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { - const std::unordered_set& watched_variables = - reinterpret_cast(tape)->tape->WatchedVariables(); - PyObject* result = PySet_New(nullptr); - for (PyObject* variable : watched_variables) { - PySet_Add(result, variable); - } - return result; + return reinterpret_cast(tape)->tape->GetVariablesAsPyTuple(); } namespace { @@ -1210,11 +1271,13 @@ void TapeSetRecordOperation( PyObject* op_type, PyObject* output_tensors, const std::vector& input_ids, const std::vector& input_dtypes, - PyObject* backward_function) { + const std::function& backward_function_getter, + const std::function& backward_function_killer) { std::vector output_info; PyObject* seq = PySequence_Fast(output_tensors, "expected a sequence of integer tensor ids"); int len = PySequence_Size(output_tensors); + if (PyErr_Occurred()) return; output_info.reserve(len); for (int i = 0; i < len; ++i) { output_info.push_back( @@ -1243,10 +1306,10 @@ void TapeSetRecordOperation( } for (TFE_Py_Tape* tape : SafeTapeSet()) { - Py_INCREF(backward_function); - tape->tape->RecordOperation( - op_type_str, output_info, input_ids, input_dtypes, backward_function, - [backward_function]() { Py_DECREF(backward_function); }); + auto* function = backward_function_getter(); + tape->tape->RecordOperation(op_type_str, output_info, input_ids, + input_dtypes, function, + backward_function_killer); } } } // namespace @@ -1263,8 +1326,21 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, std::vector input_dtypes = MakeTensorDtypeList(input_tensors); if (PyErr_Occurred()) return; - TapeSetRecordOperation(op_type, output_tensors, input_ids, input_dtypes, - backward_function); + + TapeSetRecordOperation( + op_type, output_tensors, input_ids, input_dtypes, + [backward_function]() { + Py_INCREF(backward_function); + PyBackwardFunction* function = + new PyBackwardFunction([backward_function](PyObject* out_grads) { + return PyObject_CallObject(backward_function, out_grads); + }); + return function; + }, + [backward_function](PyBackwardFunction* py_backward_function) { + Py_DECREF(backward_function); + delete py_backward_function; + }); } void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { @@ -1273,7 +1349,8 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { } } -class PyVSpace : public tensorflow::eager::VSpace { +class PyVSpace + : public tensorflow::eager::VSpace { public: explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {} @@ -1333,6 +1410,8 @@ class PyVSpace : public tensorflow::eager::VSpace { return result; } + void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); } + PyObject* Zeros(tensorflow::TensorShape shape, tensorflow::DataType dtype) const final { PyObject* py_shape = PyTuple_New(shape.dims()); @@ -1364,7 +1443,7 @@ class PyVSpace : public tensorflow::eager::VSpace { } tensorflow::Status CallBackwardFunction( - PyObject* backward_function, + PyBackwardFunction* backward_function, tensorflow::gtl::ArraySlice output_gradients, std::vector* result) const final { PyObject* grads = PyTuple_New(output_gradients.size()); @@ -1377,8 +1456,7 @@ class PyVSpace : public tensorflow::eager::VSpace { reinterpret_cast(output_gradients[i])); } } - PyObject* py_result = PyEval_CallObject( - reinterpret_cast(backward_function), grads); + PyObject* py_result = (*backward_function)(grads); Py_DECREF(grads); if (py_result == nullptr) { return tensorflow::errors::Internal("gradient function threw exceptions"); @@ -1407,10 +1485,6 @@ class PyVSpace : public tensorflow::eager::VSpace { return tensorflow::Status::OK(); } - void ReleaseBackwardFunction(PyObject* backward_function) const final { - Py_DECREF(backward_function); - } - void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } private: @@ -1569,12 +1643,12 @@ bool CheckInputsOk(PyObject* seq, int start_index, for (Py_ssize_t j = 0; j < PySequence_Fast_GET_SIZE(item); j++) { PyObject* inner_item = PySequence_Fast_GET_ITEM(item, j); if (!CheckOneInput(inner_item)) { - VLOG(1) - << "Falling back to slow path for Op \"" << op_def.name() - << "\", Input \"" << op_def.input_arg(i).name() << "\", Index " - << j - << " since we expected an EagerTensor/ResourceVariable, but got " - << inner_item->ob_type->tp_name; + VLOG(1) << "Falling back to slow path for Op \"" << op_def.name() + << "\", Input \"" << op_def.input_arg(i).name() + << "\", Index " << j + << " since we expected an EagerTensor/ResourceVariable, " + "but got " + << inner_item->ob_type->tp_name; return false; } } @@ -1781,28 +1855,53 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, } PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs)); - PyObject* callback_args = - Py_BuildValue("OOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs); - PyObject* backward_function = - PyObject_CallObject(backward_function_getter, callback_args); - Py_DECREF(callback_args); - if (backward_function == nullptr) return nullptr; - - TapeSetRecordOperation(op_name, results, input_ids, input_dtypes, - backward_function); + TapeSetRecordOperation( + op_name, results, input_ids, input_dtypes, + [op_name, attrs, num_inputs, op_inputs, op_outputs]() { + Py_INCREF(op_name); + Py_INCREF(attrs); + Py_INCREF(num_inputs); + Py_INCREF(op_inputs); + Py_INCREF(op_outputs); + PyBackwardFunction* function = + new PyBackwardFunction([op_name, attrs, num_inputs, op_inputs, + op_outputs](PyObject* output_grads) { + tensorflow::Safe_PyObjectPtr callback_args( + Py_BuildValue("OOOOOO", op_name, attrs, num_inputs, op_inputs, + op_outputs, output_grads)); + + tensorflow::Safe_PyObjectPtr result( + PyObject_CallObject(gradient_function, callback_args.get())); + + if (PyErr_Occurred()) return static_cast(nullptr); + + return tensorflow::swig::Flatten(result.get()); + }); + return function; + }, + [op_name, attrs, num_inputs, op_inputs, + op_outputs](PyBackwardFunction* backward_function) { + Py_DECREF(op_name); + Py_DECREF(attrs); + Py_DECREF(num_inputs); + Py_DECREF(op_inputs); + Py_DECREF(op_outputs); + + delete backward_function; + }); - Py_DECREF(backward_function); + Py_DECREF(num_inputs); Py_RETURN_NONE; } void MaybeWatchVariable(PyObject* input) { DCHECK(CheckResourceVariable(input)); - DCHECK(PyObject_HasAttrString(input, "_trainable")); + DCHECK(PyObject_HasAttrString(input, "trainable")); tensorflow::Safe_PyObjectPtr trainable( - PyObject_GetAttrString(input, "_trainable")); + PyObject_GetAttrString(input, "trainable")); if (trainable.get() == Py_False) return; TFE_Py_TapeSetWatchVariable(input); } @@ -1852,8 +1951,10 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, Py_INCREF(output->get()); // stay alive after since tuple steals. PyTuple_SET_ITEM(outputs.get(), 0, output->get()); - if (!RecordGradient(GetPythonObjectFromString("ReadVariableOp"), - inputs.get(), Py_None, outputs.get(), Py_None)) { + tensorflow::Safe_PyObjectPtr op_string( + GetPythonObjectFromString("ReadVariableOp")); + if (!RecordGradient(op_string.get(), inputs.get(), Py_None, outputs.get(), + Py_None)) { return false; } } @@ -1863,8 +1964,8 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, // Supports only 2 cases at the moment: // i) input is an EagerTensor -// ii) input is a ResourceVariable - in this case, the is_variable param is set -// to true. +// ii) input is a ResourceVariable - in this case, the is_variable param is +// set to true. // // NOTE: dtype_hint_getter must *always* return a PyObject that can be // decref'd. So if no hint is found, Py_RETURN_NONE (which correctly diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index ad82266beca05d9f508a702124390fd934161ffd..caa217b70cabfdc3fdec3528ea1e7ca553072fbe 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -39,6 +39,11 @@ def push_new_tape(persistent=False): return Tape(tape) +def push_tape(tape): + """Pushes an existing tape onto the tape stack.""" + pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access + + def watch(tensor): """Marks this tensor to be watched by all tapes in the stack. diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index b044b30231603b0265aa1ef0320e9f1cfb303724..626a4eb1eee9bda6c910c9dfa9cfff27b04444c1 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -292,6 +292,11 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): def testUnicode(self): self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf") + def testFloatTensor(self): + self.assertEqual(dtypes.float64, _create_tensor(np.float64()).dtype) + self.assertEqual(dtypes.float32, _create_tensor(np.float32()).dtype) + self.assertEqual(dtypes.float32, _create_tensor(0.0).dtype) + def testSliceDimOutOfRange(self): t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32) t2 = _create_tensor([1, 2], dtype=dtypes.int32) diff --git a/tensorflow/python/eager/test.py b/tensorflow/python/eager/test.py index f6a46e7eb3d03982f07bf4162d94c6038217bf61..33ee797678ed73c52ebb17723f688cec4feca402 100644 --- a/tensorflow/python/eager/test.py +++ b/tensorflow/python/eager/test.py @@ -23,6 +23,7 @@ from tensorflow.python.platform import test as _test from tensorflow.python.platform.test import * # pylint: disable=wildcard-import +# TODO(akshayka): Do away with this file. def main(argv=None): _ops.enable_eager_execution() _test.main(argv) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 0754041f9eb50b429d02a06f9f0357c3431d3df5..8ee38d35cc152e6c281e83d7fd49540ddaee2a7e 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -1,8 +1,4 @@ -package( - default_visibility = [ - "//tensorflow:internal", - ], -) +package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 @@ -10,8 +6,15 @@ load("//tensorflow:tensorflow.bzl", "py_test") py_library( name = "estimator_py", - srcs = ["estimator_lib.py"], + srcs = [ + "__init__.py", + "estimator_lib.py", + ], srcs_version = "PY2AND3", + visibility = [ + "//tensorflow:__pkg__", + "//tensorflow:internal", + ], deps = [ ":baseline", ":boosted_trees", @@ -27,7 +30,7 @@ py_library( ":parsing_utils", ":run_config", ":training", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -37,10 +40,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":gc", - "//tensorflow/python:errors", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:util", ], @@ -54,10 +54,7 @@ py_test( deps = [ ":estimator", ":exporter", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -66,8 +63,7 @@ py_library( srcs = ["gc.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -78,10 +74,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":gc", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -91,12 +84,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":export_output", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:tag_constants", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -109,12 +97,7 @@ py_test( deps = [ ":export_output", ":model_fn", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -126,11 +109,7 @@ py_library( ":estimator", ":exporter", ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -149,13 +128,7 @@ py_test( ":inputs", ":run_config", ":training", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -164,7 +137,7 @@ py_library( srcs = ["run_config.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -176,8 +149,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -190,14 +162,7 @@ py_library( ":head", ":model_fn", ":optimizers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -221,26 +186,7 @@ py_test( ":numpy_io", ":pandas_io", ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -253,20 +199,7 @@ py_library( ":estimator", ":head", ":model_fn", - "//tensorflow/python:array_ops", - "//tensorflow/python:boosted_trees_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -275,21 +208,13 @@ py_test( size = "medium", srcs = ["canned/boosted_trees_test.py"], srcs_version = "PY2AND3", + tags = [ + "optonly", + ], deps = [ ":boosted_trees", - "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:platform_test", - "//tensorflow/python:resources", - "//tensorflow/python:training", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/feature_column", + ":inputs", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -302,14 +227,7 @@ py_library( ":head", ":model_fn", ":optimizers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:summary", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -326,22 +244,7 @@ py_library( ":model_fn", ":numpy_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -364,16 +267,7 @@ py_test( ":numpy_io", ":pandas_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -389,19 +283,7 @@ py_library( ":linear", ":model_fn", ":optimizers", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -424,17 +306,7 @@ py_test( ":numpy_io", ":pandas_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:nn", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -446,7 +318,20 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", + ], +) + +py_test( + name = "util_test", + srcs = ["util_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/67510291 + deps = [ + ":util", + "//tensorflow:tensorflow_py_no_contrib", + "//third_party/py/numpy", + "@six_archive//:six", ], ) @@ -461,21 +346,7 @@ py_library( ":model_fn", ":run_config", ":util", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:metrics", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/data", - "//tensorflow/python/saved_model:builder", - "//tensorflow/python/saved_model:constants", - "//tensorflow/python/saved_model:tag_constants", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -494,29 +365,7 @@ py_test( ":model_fn", ":numpy_io", ":run_config", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:lib", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:saver_test_utils", - "//tensorflow/python:session", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variables", - "//tensorflow/python/data", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:tag_constants", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -529,9 +378,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python:parsing_ops", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -542,10 +389,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":parsing_utils", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:parsing_ops", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -554,9 +398,7 @@ py_library( srcs = ["export/export_output.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python/saved_model:signature_def_utils", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -568,13 +410,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":export_output", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -587,7 +423,7 @@ py_library( deps = [ ":export_export", ":export_output", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -598,13 +434,8 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:util", + ":util", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -617,17 +448,8 @@ py_test( deps = [ ":export_export", ":export_output", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:signature_def_utils", + ":util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -640,24 +462,7 @@ py_library( ":metric_keys", ":model_fn", ":prediction_keys", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:nn", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:weights_broadcast_ops", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -676,22 +481,7 @@ py_test( ":model_fn", ":numpy_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:training", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -704,7 +494,7 @@ py_library( deps = [ ":numpy_io", ":pandas_io", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -716,11 +506,7 @@ py_library( ":estimator", ":head", ":optimizers", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -738,25 +524,7 @@ py_library( ":numpy_io", ":pandas_io", ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -774,7 +542,7 @@ py_test( deps = [ ":linear", ":linear_testing_utils", - "//tensorflow/python:client_testlib", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -803,9 +571,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":numpy_io", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -814,7 +580,7 @@ py_library( srcs = ["canned/optimizers.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -826,8 +592,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":optimizers", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -845,9 +610,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":pandas_io", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -867,15 +630,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -889,7 +644,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":inputs_queues", - "//tensorflow/python:client_testlib", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -900,10 +655,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":inputs_queues", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -916,32 +668,7 @@ py_library( ":export_export", ":model_fn", ":run_config", - "//tensorflow/python:check_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:summary", - "//tensorflow/python:tensor_util", - "//tensorflow/python:training", - "//tensorflow/python:training_util", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", - "//tensorflow/python/keras:backend", - "//tensorflow/python/keras:engine", - "//tensorflow/python/keras:layers", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -950,21 +677,47 @@ py_test( size = "large", srcs = ["keras_test.py"], srcs_version = "PY2AND3", - tags = ["notsan"], + tags = [ + "no_windows", + "notsan", + ], deps = [ ":keras", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:run_config", - "//tensorflow/python/keras", - "//tensorflow/python/keras:backend", - "//tensorflow/python/keras:engine", "//third_party/py/numpy", ], ) + +py_library( + name = "expect_numpy_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect numpy to already be installed on the system, e.g. via + # `pip install numpy` + visibility = ["//visibility:public"], +) + +py_library( + name = "expect_pandas_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect pandas to already be installed on the system, e.g. via + # `pip install pandas` + visibility = ["//visibility:public"], +) + +py_library( + name = "expect_six_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect six to already be installed on the system, e.g. via + # `pip install six` + visibility = ["//visibility:public"], +) + +py_library( + name = "expect_tensorflow_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect tensorflow to already be installed on the system, e.g. via + # `pip install tensorflow` or `pip install tensorflow_gpu` + visibility = ["//visibility:public"], +) diff --git a/tensorflow/python/estimator/__init__.py b/tensorflow/python/estimator/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..8cf8df567f0e36604b5c3f6fe992b572d6632954 100644 --- a/tensorflow/python/estimator/__init__.py +++ b/tensorflow/python/estimator/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Import Estimator APIs. + +Note: This file is imported by the create_estimator_api genrule. It must +transitively import all Estimator modules/packages for their @estimator_export +annotations to generate the public Estimator python API. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.python.estimator.estimator_lib diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..aa5a29e6dd148c39ebb098cb99cb1907d9c5a9d9 --- /dev/null +++ b/tensorflow/python/estimator/api/BUILD @@ -0,0 +1,18 @@ +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/tools/api/generator:api_gen.bzl", "gen_api_init_files") +load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") + +gen_api_init_files( + name = "estimator_python_api_gen", + api_name = "estimator", + output_files = ESTIMATOR_API_INIT_FILES, + package = "tensorflow.python.estimator", + package_dep = "//tensorflow/python/estimator:estimator_py", +) diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py index 980c0573726945bcc80863319da98a220c86bd91..20c7a69b7cb071365e5442b512c1a858a7e0b246 100644 --- a/tensorflow/python/estimator/canned/baseline.py +++ b/tensorflow/python/estimator/canned/baseline.py @@ -24,10 +24,10 @@ Example: classifier = BaselineClassifier(n_classes=3) # Input builders -def input_fn_train: # returns x, y (where y represents label's class index). +def input_fn_train(): # returns x, y (where y represents label's class index). pass -def input_fn_eval: # returns x, y (where y represents label's class index). +def input_fn_eval(): # returns x, y (where y represents label's class index). pass # Fit model. @@ -59,7 +59,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.training import training_util -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # The default learning rate of 0.3 is a historical artifact of the initial # implementation, but seems a reasonable choice. @@ -174,7 +174,7 @@ def _baseline_model_fn(features, labels, mode, head, optimizer, train_op_fn=train_op_fn) -@tf_export('estimator.BaselineClassifier') +@estimator_export('estimator.BaselineClassifier') class BaselineClassifier(estimator.Estimator): """A classifier that can establish a simple baseline. @@ -215,6 +215,13 @@ class BaselineClassifier(estimator.Estimator): * if `weight_column` is not `None`, a feature with `key=weight_column` whose value is a `Tensor`. + + @compatibility(eager) + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. + @end_compatibility """ def __init__(self, @@ -277,7 +284,7 @@ class BaselineClassifier(estimator.Estimator): config=config) -@tf_export('estimator.BaselineRegressor') +@estimator_export('estimator.BaselineRegressor') class BaselineRegressor(estimator.Estimator): """A regressor that can establish a simple baseline. @@ -313,6 +320,13 @@ class BaselineRegressor(estimator.Estimator): * if `weight_column` is not `None`, a feature with `key=weight_column` whose value is a `Tensor`. + + @compatibility(eager) + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. + @end_compatibility """ def __init__(self, diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index 6e4a19f0befff187f44f98bec4fba10b48eb367c..a22e9745c1929a29394add8ade835b2aa5fbd13b 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import collections +import functools from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn @@ -39,17 +40,18 @@ from tensorflow.python.summary import summary from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # TODO(nponomareva): Reveal pruning params here. _TreeHParams = collections.namedtuple('TreeHParams', [ 'n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity', - 'min_node_weight' + 'min_node_weight', 'center_bias' ]) _HOLD_FOR_MULTI_CLASS_SUPPORT = object() _HOLD_FOR_MULTI_DIM_SUPPORT = object() _DUMMY_NUM_BUCKETS = -1 +_DUMMY_NODE_ID = -1 def _get_transformed_features(features, sorted_feature_columns): @@ -96,14 +98,18 @@ def _get_transformed_features(features, sorted_feature_columns): return result_features -def _local_variable(tensor, name=None): +def _local_variable(initial_value, name=None): """Stores a tensor as a local Variable for faster read.""" - return variable_scope.variable( - initial_value=tensor, + result = variable_scope.variable( + initial_value=initial_value, trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], validate_shape=False, name=name) + if isinstance(initial_value, ops.Tensor): + # Match the resulting variable's shape if the initial_value is a Tensor. + result.set_shape(initial_value.shape) + return result def _group_features_by_num_buckets(sorted_feature_columns): @@ -164,9 +170,10 @@ def _group_features_by_num_buckets(sorted_feature_columns): # pylint:enable=protected-access # Replace the dummy key with the real max num of buckets for all bucketized # columns. - bucket_size_to_feature_ids_dict[ - max_buckets_for_bucketized] = bucket_size_to_feature_ids_dict[ - _DUMMY_NUM_BUCKETS] + if max_buckets_for_bucketized not in bucket_size_to_feature_ids_dict: + bucket_size_to_feature_ids_dict[max_buckets_for_bucketized] = [] + bucket_size_to_feature_ids_dict[max_buckets_for_bucketized].extend( + bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS]) del bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS] feature_ids_list = list(bucket_size_to_feature_ids_dict.values()) @@ -264,20 +271,28 @@ class _CacheTrainingStatesUsingHashTable(object): # bitcast the ids to int32. self._table_ref = lookup_ops.mutable_dense_hash_table_v2( empty_key=empty_key, value_dtype=dtypes.float32, value_shape=[3]) - self._example_ids = example_ids + self._example_ids = ops.convert_to_tensor(example_ids) + if self._example_ids.shape.ndims not in (None, 1): + raise ValueError('example_id should have rank 1, but got %s' % + self._example_ids) self._logits_dimension = logits_dimension def lookup(self): """Returns cached_tree_ids, cached_node_ids, cached_logits.""" cached_tree_ids, cached_node_ids, cached_logits = array_ops.split( lookup_ops.lookup_table_find_v2( - self._table_ref, self._example_ids, default_value=[0.0, 0.0, 0.0]), + self._table_ref, + self._example_ids, + default_value=[0.0, _DUMMY_NODE_ID, 0.0]), [1, 1, self._logits_dimension], axis=1) cached_tree_ids = array_ops.squeeze( array_ops.bitcast(cached_tree_ids, dtypes.int32)) cached_node_ids = array_ops.squeeze( array_ops.bitcast(cached_node_ids, dtypes.int32)) + if self._example_ids.shape.ndims is not None: + cached_logits.set_shape( + [self._example_ids.shape[0], self._logits_dimension]) return (cached_tree_ids, cached_node_ids, cached_logits) def insert(self, tree_ids, node_ids, logits): @@ -319,7 +334,7 @@ class _CacheTrainingStatesUsingVariables(object): array_ops.zeros([batch_size], dtype=dtypes.int32), name='tree_ids_cache') self._node_ids = _local_variable( - array_ops.zeros([batch_size], dtype=dtypes.int32), + _DUMMY_NODE_ID*array_ops.ones([batch_size], dtype=dtypes.int32), name='node_ids_cache') self._logits = _local_variable( array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32), @@ -414,8 +429,8 @@ def _bt_model_fn( ValueError: mode or params are invalid, or features has the wrong type. """ is_single_machine = (config.num_worker_replicas <= 1) - sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name) + center_bias = tree_hparams.center_bias if train_in_memory: assert n_batches_per_layer == 1, ( 'When train_in_memory is enabled, input_fn should return the entire ' @@ -458,6 +473,9 @@ def _bt_model_fn( # Create Ensemble resources. tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name) + # Variable that determines whether bias centering is needed. + center_bias_var = variable_scope.variable( + initial_value=center_bias, name='center_bias_needed', trainable=False) # Create logits. if mode != model_fn.ModeKeys.TRAIN: logits = boosted_trees_ops.predict( @@ -478,6 +496,7 @@ def _bt_model_fn( # TODO(soroush): Do partial updates if this becomes a bottleneck. ensemble_reload = local_tree_ensemble.deserialize( *tree_ensemble.serialize()) + if training_state_cache: cached_tree_ids, cached_node_ids, cached_logits = ( training_state_cache.lookup()) @@ -486,9 +505,10 @@ def _bt_model_fn( batch_size = array_ops.shape(labels)[0] cached_tree_ids, cached_node_ids, cached_logits = ( array_ops.zeros([batch_size], dtype=dtypes.int32), - array_ops.zeros([batch_size], dtype=dtypes.int32), + _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32), array_ops.zeros( [batch_size, head.logits_dimension], dtype=dtypes.float32)) + with ops.control_dependencies([ensemble_reload]): (stamp_token, num_trees, num_finalized_trees, num_attempted_layers, last_layer_nodes_range) = local_tree_ensemble.get_states() @@ -502,13 +522,20 @@ def _bt_model_fn( cached_node_ids=cached_node_ids, bucketized_features=input_feature_list, logits_dimension=head.logits_dimension) + logits = cached_logits + partial_logits # Create training graph. def _train_op_fn(loss): """Run one training iteration.""" if training_state_cache: - train_op.append(training_state_cache.insert(tree_ids, node_ids, logits)) + # Cache logits only after center_bias is complete, if it's in progress. + train_op.append( + control_flow_ops.cond( + center_bias_var, control_flow_ops.no_op, + lambda: training_state_cache.insert(tree_ids, node_ids, logits)) + ) + if closed_form_grad_and_hess_fn: gradients, hessians = closed_form_grad_and_hess_fn(logits, labels) else: @@ -532,8 +559,7 @@ def _bt_model_fn( ] stats_summaries_list.append(summaries) - accumulators = [] - + # ========= Helper methods for both in and not in memory. ============== def grow_tree_from_stats_summaries(stats_summaries_list, feature_ids_list): """Updates ensemble based on the best gains from stats summaries.""" @@ -580,55 +606,122 @@ def _bt_model_fn( pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING) return grow_op + def _center_bias_fn(mean_gradients, mean_hessians): + """Updates the ensembles and cache (if needed) with logits prior.""" + continue_centering = boosted_trees_ops.center_bias( + tree_ensemble.resource_handle, + mean_gradients=mean_gradients, + mean_hessians=mean_hessians, + l1=tree_hparams.l1, + l2=tree_hparams.l2 + ) + return center_bias_var.assign(continue_centering) + + # ========= End of helper methods. ============== + if train_in_memory and is_single_machine: train_op.append(distribute_lib.increment_var(global_step)) + + mean_gradients = array_ops.expand_dims( + math_ops.reduce_mean(gradients, 0), 0) + mean_heassians = array_ops.expand_dims( + math_ops.reduce_mean(hessians, 0), 0) + train_op.append( - grow_tree_from_stats_summaries(stats_summaries_list, - feature_ids_list)) + control_flow_ops.cond( + center_bias_var, + lambda: _center_bias_fn(mean_gradients, mean_heassians), + functools.partial(grow_tree_from_stats_summaries, + stats_summaries_list, feature_ids_list))) else: - dependencies = [] - for i, feature_ids in enumerate(feature_ids_list): - stats_summaries = stats_summaries_list[i] - accumulator = data_flow_ops.ConditionalAccumulator( + def center_bias_not_in_mem(): + """Accumulates the data and updates the logits bias, when ready.""" + bias_dependencies = [] + + bias_accumulator = data_flow_ops.ConditionalAccumulator( dtype=dtypes.float32, - # The stats consist of grads and hessians (the last dimension). - shape=[len(feature_ids), max_splits, bucket_size_list[i], 2], - shared_name='numeric_stats_summary_accumulator_' + str(i)) - accumulators.append(accumulator) - - apply_grad = accumulator.apply_grad( - array_ops.stack(stats_summaries, axis=0), stamp_token) - dependencies.append(apply_grad) - - def grow_tree_from_accumulated_summaries_fn(): - """Updates the tree with the best layer from accumulated summaries.""" - # Take out the accumulated summaries from the accumulator and grow. - stats_summaries_list = [] - - stats_summaries_list = [ - array_ops.unstack(accumulator.take_grad(1), axis=0) - for accumulator in accumulators - ] - - grow_op = grow_tree_from_stats_summaries(stats_summaries_list, - feature_ids_list) - return grow_op - - with ops.control_dependencies(dependencies): - train_op.append(distribute_lib.increment_var(global_step)) - if config.is_chief: - min_accumulated = math_ops.reduce_min( - array_ops.stack( - [acc.num_accumulated() for acc in accumulators])) - - train_op.append( - control_flow_ops.cond( - math_ops.greater_equal(min_accumulated, - n_batches_per_layer), - grow_tree_from_accumulated_summaries_fn, - control_flow_ops.no_op, - name='wait_until_n_batches_accumulated')) + # The stats consist of grads and hessians means only. + # TODO(nponomareva): this will change for a multiclass + shape=[2, 1], + shared_name='bias_accumulator') + + grads_and_hess = array_ops.stack([gradients, hessians], axis=0) + grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1) + + apply_grad = bias_accumulator.apply_grad(grads_and_hess, stamp_token) + bias_dependencies.append(apply_grad) + + def center_bias_from_accumulator(): + accumulated = array_ops.unstack( + bias_accumulator.take_grad(1), axis=0) + return _center_bias_fn( + array_ops.expand_dims(accumulated[0], 0), + array_ops.expand_dims(accumulated[1], 0)) + + with ops.control_dependencies(bias_dependencies): + if config.is_chief: + center_bias_op = control_flow_ops.cond( + math_ops.greater_equal(bias_accumulator.num_accumulated(), + n_batches_per_layer), + center_bias_from_accumulator, + control_flow_ops.no_op, + name='wait_until_n_batches_for_bias_accumulated') + + return center_bias_op + + def grow_not_in_mem(): + """Accumulates the data and grows a layer when ready.""" + + accumulators = [] + dependencies = [] + for i, feature_ids in enumerate(feature_ids_list): + stats_summaries = stats_summaries_list[i] + accumulator = data_flow_ops.ConditionalAccumulator( + dtype=dtypes.float32, + # The stats consist of grads and hessians (the last dimension). + shape=[len(feature_ids), max_splits, bucket_size_list[i], 2], + shared_name='numeric_stats_summary_accumulator_' + str(i)) + accumulators.append(accumulator) + + apply_grad = accumulator.apply_grad( + array_ops.stack(stats_summaries, axis=0), stamp_token) + dependencies.append(apply_grad) + + def grow_tree_from_accumulated_summaries_fn(): + """Updates tree with the best layer from accumulated summaries.""" + # Take out the accumulated summaries from the accumulator and grow. + stats_summaries_list = [] + + stats_summaries_list = [ + array_ops.unstack(accumulator.take_grad(1), axis=0) + for accumulator in accumulators + ] + + grow_op = grow_tree_from_stats_summaries(stats_summaries_list, + feature_ids_list) + return grow_op + + with ops.control_dependencies(dependencies): + if config.is_chief: + min_accumulated = math_ops.reduce_min( + array_ops.stack( + [acc.num_accumulated() for acc in accumulators])) + + grow_model = control_flow_ops.cond( + math_ops.greater_equal(min_accumulated, n_batches_per_layer), + grow_tree_from_accumulated_summaries_fn, + control_flow_ops.no_op, + name='wait_until_n_batches_accumulated') + + return grow_model + + update_model = control_flow_ops.cond( + center_bias_var, center_bias_not_in_mem, grow_not_in_mem) + train_op.append(update_model) + with ops.control_dependencies([update_model]): + increment_global = distribute_lib.increment_var(global_step) + train_op.append(increment_global) return control_flow_ops.group(train_op, name='train_op') @@ -668,14 +761,18 @@ def _create_classification_head_and_closed_form(n_classes, weight_column, label_vocabulary): """Creates a head for classifier and the closed form gradients/hessians.""" head = _create_classification_head(n_classes, weight_column, label_vocabulary) - if n_classes == 2 and weight_column is None and label_vocabulary is None: + if (n_classes == 2 and head.logits_dimension == 1 and weight_column is None + and label_vocabulary is None): # Use the closed-form gradients/hessians for 2 class. def _grad_and_hess_for_logloss(logits, labels): + """A closed form gradient and hessian for logistic loss.""" # TODO(youngheek): add weights handling. predictions = math_ops.reciprocal(math_ops.exp(-logits) + 1.0) normalizer = math_ops.reciprocal( math_ops.cast(array_ops.size(predictions), dtypes.float32)) labels = math_ops.cast(labels, dtypes.float32) + labels = head_lib._check_dense_labels_match_logits_and_reshape( # pylint: disable=protected-access + labels, logits, head.logits_dimension) gradients = (predictions - labels) * normalizer hessians = predictions * (1.0 - predictions) * normalizer return gradients, hessians @@ -698,9 +795,17 @@ def _create_regression_head(label_dimension, weight_column=None): # pylint: enable=protected-access -@tf_export('estimator.BoostedTreesClassifier') +@estimator_export('estimator.BoostedTreesClassifier') class BoostedTreesClassifier(estimator.Estimator): - """A Classifier for Tensorflow Boosted Trees models.""" + """A Classifier for Tensorflow Boosted Trees models. + + @compatibility(eager) + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. + @end_compatibility + """ def __init__(self, feature_columns, @@ -716,7 +821,8 @@ class BoostedTreesClassifier(estimator.Estimator): l2_regularization=0., tree_complexity=0., min_node_weight=0., - config=None): + config=None, + center_bias=False): """Initializes a `BoostedTreesClassifier` instance. Example: @@ -784,6 +890,13 @@ class BoostedTreesClassifier(estimator.Estimator): split to be considered. The value will be compared with sum(leaf_hessian)/(batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + Raises: ValueError: when wrong arguments are given or unsupported functionalities @@ -798,7 +911,7 @@ class BoostedTreesClassifier(estimator.Estimator): # HParams for the model. tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias) def _model_fn(features, labels, mode, config): return _bt_model_fn( # pylint: disable=protected-access @@ -816,9 +929,17 @@ class BoostedTreesClassifier(estimator.Estimator): model_fn=_model_fn, model_dir=model_dir, config=config) -@tf_export('estimator.BoostedTreesRegressor') +@estimator_export('estimator.BoostedTreesRegressor') class BoostedTreesRegressor(estimator.Estimator): - """A Regressor for Tensorflow Boosted Trees models.""" + """A Regressor for Tensorflow Boosted Trees models. + + @compatibility(eager) + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. + @end_compatibility + """ def __init__(self, feature_columns, @@ -833,7 +954,8 @@ class BoostedTreesRegressor(estimator.Estimator): l2_regularization=0., tree_complexity=0., min_node_weight=0., - config=None): + config=None, + center_bias=False): """Initializes a `BoostedTreesRegressor` instance. Example: @@ -894,6 +1016,12 @@ class BoostedTreesRegressor(estimator.Estimator): split to be considered. The value will be compared with sum(leaf_hessian)/(batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. Raises: ValueError: when wrong arguments are given or unsupported functionalities @@ -907,7 +1035,7 @@ class BoostedTreesRegressor(estimator.Estimator): # HParams for the model. tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias) def _model_fn(features, labels, mode, config): return _bt_model_fn( # pylint: disable=protected-access diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 13595d4c83566950a94a42889570494b7d53c784..f807641057990971407f69ff0ba4d3513302e452 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import session_run_hook NUM_FEATURES = 3 @@ -60,7 +61,7 @@ def _make_train_input_fn(is_classification): """Makes train input_fn for classification/regression.""" def _input_fn(): - features_dict = dict(FEATURES_DICT) + features_dict = dict(FEATURES_DICT) # copies the dict to add an entry. features_dict[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS) labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS return features_dict, labels @@ -72,7 +73,7 @@ def _make_train_input_fn_dataset(is_classification, batch=None, repeat=None): """Makes input_fn using Dataset.""" def _input_fn(): - features_dict = dict(FEATURES_DICT) + features_dict = dict(FEATURES_DICT) # copies the dict to add an entry. features_dict[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS) labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS if batch: @@ -121,6 +122,39 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): return ensemble_proto + def testFirstCheckpointWorksFine(self): + """Tests that eval/pred doesn't crash with the very first checkpoint. + + The step-0 checkpoint will have only an empty ensemble, and a separate eval + job might read from it and crash. + This test ensures that prediction/evaluation works fine with it. + """ + input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.BoostedTreesClassifier( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5) + + class BailOutWithoutTraining(session_run_hook.SessionRunHook): + + def before_run(self, run_context): + raise StopIteration('to bail out.') + + est.train(input_fn, steps=100, # must stop at 0 anyway. + hooks=[BailOutWithoutTraining()]) + self._assert_checkpoint( + est.model_dir, global_step=0, finalized_trees=0, attempted_layers=0) + # Empty ensemble returns 0 logits, so that all output labels are 0. + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 0.6) + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [0], [0], [0], [0]], + [pred['class_ids'] for pred in predictions]) + def testTrainAndEvaluateBinaryClassifier(self): input_fn = _make_train_input_fn(is_classification=True) @@ -160,6 +194,26 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertAllClose([[0], [1], [1], [0], [0]], [pred['class_ids'] for pred in predictions]) + def testTrainClassifierWithRankOneLabel(self): + """Tests that label with rank-1 tensor is also accepted by classifier.""" + def _input_fn_with_rank_one_label(): + return FEATURES_DICT, [0., 1., 1., 0., 0.] + + est = boosted_trees.BoostedTreesClassifier( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5) + + # It will stop after 5 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(_input_fn_with_rank_one_label, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + eval_res = est.evaluate(input_fn=_input_fn_with_rank_one_label, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + def testTrainClassifierWithLabelVocabulary(self): apple, banana = 'apple', 'banana' def _input_fn_with_label_vocab(): @@ -262,6 +316,26 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], [pred['predictions'] for pred in predictions]) + def testTrainRegressorWithRankOneLabel(self): + """Tests that label with rank-1 tensor is also accepted by regressor.""" + def _input_fn_with_rank_one_label(): + return FEATURES_DICT, [1.5, 0.3, 0.2, 2., 5.] + + est = boosted_trees.BoostedTreesRegressor( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5) + + # It will stop after 5 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(_input_fn_with_rank_one_label, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + eval_res = est.evaluate(input_fn=_input_fn_with_rank_one_label, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.478283) + def testTrainRegressorWithDataset(self): train_input_fn = _make_train_input_fn_dataset(is_classification=False) predict_input_fn = numpy_io.numpy_input_fn( @@ -426,6 +500,50 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertEqual(2, ensemble.trees[0].nodes[0].bucketized_split.feature_id) self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold) + def testTrainEvaluateAndPredictWithOnlyIndicatorColumn(self): + categorical = feature_column.categorical_column_with_vocabulary_list( + key='categorical', vocabulary_list=('bad', 'good', 'ok')) + feature_indicator = feature_column.indicator_column(categorical) + + labels = np.array([[0.], [5.7], [5.7], [0.], [0.]], dtype=np.float32) + # Our categorical feature defines the labels perfectly + input_fn = numpy_io.numpy_input_fn( + x={ + 'categorical': np.array(['bad', 'good', 'good', 'ok', 'bad']), + }, + y=labels, + batch_size=5, + shuffle=False) + + # Train depth 1 tree. + est = boosted_trees.BoostedTreesRegressor( + feature_columns=[feature_indicator], + n_batches_per_layer=1, + n_trees=1, + learning_rate=1.0, + max_depth=1) + + num_steps = 1 + est.train(input_fn, steps=num_steps) + ensemble = self._assert_checkpoint_and_return_model( + est.model_dir, global_step=1, finalized_trees=1, attempted_layers=1) + + # We learnt perfectly. + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['loss'], 0) + + predictions = list(est.predict(input_fn)) + self.assertAllClose( + labels, + [pred['predictions'] for pred in predictions]) + + self.assertEqual(3, len(ensemble.trees[0].nodes)) + + # Check that the split happened on 'good' value, which will be encoded as + # feature with index 1 (0 - 'bad', 2 - 'ok') + self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id) + self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold) + class ModelFnTests(test_util.TensorFlowTestCase): """Tests bt_model_fn including unexposed internal functionalities.""" @@ -436,14 +554,6 @@ class ModelFnTests(test_util.TensorFlowTestCase): feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), BUCKET_BOUNDARIES) for i in range(NUM_FEATURES) } - self._tree_hparams = boosted_trees._TreeHParams( # pylint:disable=protected-access - n_trees=2, - max_depth=2, - learning_rate=0.1, - l1=0., - l2=0.01, - tree_complexity=0., - min_node_weight=0.) def _get_expected_ensembles_for_classification(self): first_round = """ @@ -672,28 +782,43 @@ class ModelFnTests(test_util.TensorFlowTestCase): """ return (first_round, second_round, third_round) - def _get_expected_ensembles_for_regression(self): + def _get_expected_ensembles_for_classification_with_bias(self): first_round = """ + trees { + nodes { + leaf { + scalar: -0.405086 + } + } + } + tree_weights: 1.0 + tree_metadata { + } + """ + second_round = """ trees { nodes { bucketized_split { - feature_id: 1 - threshold: 1 + feature_id: 2 + threshold: 2 left_id: 1 right_id: 2 } metadata { - gain: 1.169714 + gain: 0.407711 + original_leaf { + scalar: -0.405086 + } } } nodes { leaf { - scalar: 0.241322 + scalar: -0.556054 } } nodes { leaf { - scalar: 0.083951 + scalar: -0.301233 } } } @@ -709,30 +834,32 @@ class ModelFnTests(test_util.TensorFlowTestCase): last_layer_node_end: 3 } """ - second_round = """ + third_round = """ trees { nodes { bucketized_split { - feature_id: 1 - threshold: 1 + feature_id: 2 + threshold: 2 left_id: 1 right_id: 2 } metadata { - gain: 1.169714 + gain: 0.407711 + original_leaf { + scalar: -0.405086 + } } } nodes { bucketized_split { feature_id: 0 - threshold: 1 + threshold: 3 left_id: 3 right_id: 4 } metadata { - gain: 2.673407 original_leaf { - scalar: 0.241322 + scalar: -0.556054 } } } @@ -744,37 +871,36 @@ class ModelFnTests(test_util.TensorFlowTestCase): right_id: 6 } metadata { - gain: 0.324102 + gain: 0.09876 original_leaf { - scalar: 0.083951 + scalar: -0.301233 } } } nodes { leaf { - scalar: 0.563167 + scalar: -0.698072 } } nodes { leaf { - scalar: 0.247047 + scalar: -0.556054 } } nodes { leaf { - scalar: 0.095273 + scalar: -0.106016 } } nodes { leaf { - scalar: 0.222102 + scalar: -0.27349 } } } trees { nodes { leaf { - scalar: 0.0 } } } @@ -785,98 +911,95 @@ class ModelFnTests(test_util.TensorFlowTestCase): is_finalized: true } tree_metadata { - num_layers_grown: 0 - is_finalized: false } growing_metadata { num_trees_attempted: 1 num_layers_attempted: 2 - last_layer_node_start: 0 last_layer_node_end: 1 } """ - third_round = """ + forth_round = """ trees { nodes { bucketized_split { - feature_id: 1 - threshold: 1 + feature_id: 2 + threshold: 2 left_id: 1 right_id: 2 } metadata { - gain: 1.169714 + gain: 0.4077113 + original_leaf { + scalar: -0.405086 + } } } nodes { bucketized_split { - feature_id: 0 - threshold: 1 + threshold: 3 left_id: 3 right_id: 4 } metadata { - gain: 2.673407 original_leaf { - scalar: 0.241322 + scalar: -0.556054 } } } nodes { bucketized_split { - feature_id: 0 threshold: 0 left_id: 5 right_id: 6 } metadata { - gain: 0.324102 + gain: 0.09876 original_leaf { - scalar: 0.083951 + scalar: -0.301233 } } } nodes { leaf { - scalar: 0.563167 + scalar: -0.698072 } } nodes { leaf { - scalar: 0.247047 + scalar: -0.556054 } } nodes { leaf { - scalar: 0.095273 + scalar: -0.106016 } } nodes { leaf { - scalar: 0.222102 + scalar: -0.27349 } } } trees { nodes { bucketized_split { - feature_id: 1 - threshold: 0 + feature_id: 2 + threshold: 2 left_id: 1 right_id: 2 } metadata { - gain: 0.981026 + gain: 0.289927 } } nodes { leaf { - scalar: 0.005166 + scalar: -0.134588 } } nodes { leaf { - scalar: 0.180281 + scalar: 0.083838 } } } @@ -888,7 +1011,6 @@ class ModelFnTests(test_util.TensorFlowTestCase): } tree_metadata { num_layers_grown: 1 - is_finalized: false } growing_metadata { num_trees_attempted: 2 @@ -897,102 +1019,671 @@ class ModelFnTests(test_util.TensorFlowTestCase): last_layer_node_end: 3 } """ - return (first_round, second_round, third_round) - - def _get_train_op_and_ensemble(self, head, config, is_classification, - train_in_memory): - """Calls bt_model_fn() and returns the train_op and ensemble_serialzed.""" - features, labels = _make_train_input_fn(is_classification)() - estimator_spec = boosted_trees._bt_model_fn( # pylint:disable=protected-access - features=features, - labels=labels, - mode=model_fn.ModeKeys.TRAIN, - head=head, - feature_columns=self._feature_columns, - tree_hparams=self._tree_hparams, - example_id_column_name=EXAMPLE_ID_COLUMN, - n_batches_per_layer=1, - config=config, - train_in_memory=train_in_memory) - resources.initialize_resources(resources.shared_resources()).run() - variables.global_variables_initializer().run() - variables.local_variables_initializer().run() - - # Gets the train_op and serialized proto of the ensemble. - shared_resources = resources.shared_resources() - self.assertEqual(1, len(shared_resources)) - train_op = estimator_spec.train_op - with ops.control_dependencies([train_op]): - _, ensemble_serialized = ( - gen_boosted_trees_ops.boosted_trees_serialize_ensemble( - shared_resources[0].handle)) - return train_op, ensemble_serialized - - def testTrainClassifierInMemory(self): - ops.reset_default_graph() - expected_first, expected_second, expected_third = ( - self._get_expected_ensembles_for_classification()) - with self.test_session() as sess: - # Train with train_in_memory mode. - with sess.graph.as_default(): - train_op, ensemble_serialized = self._get_train_op_and_ensemble( - boosted_trees._create_classification_head(n_classes=2), - run_config.RunConfig(), - is_classification=True, - train_in_memory=True) - _, serialized = sess.run([train_op, ensemble_serialized]) - # Validate the trained ensemble. - ensemble_proto = boosted_trees_pb2.TreeEnsemble() - ensemble_proto.ParseFromString(serialized) - self.assertProtoEquals(expected_first, ensemble_proto) - - # Run one more time and validate the trained ensemble. - _, serialized = sess.run([train_op, ensemble_serialized]) - ensemble_proto = boosted_trees_pb2.TreeEnsemble() - ensemble_proto.ParseFromString(serialized) - self.assertProtoEquals(expected_second, ensemble_proto) - - # Third round training and validation. - _, serialized = sess.run([train_op, ensemble_serialized]) - ensemble_proto = boosted_trees_pb2.TreeEnsemble() - ensemble_proto.ParseFromString(serialized) - self.assertProtoEquals(expected_third, ensemble_proto) - - def testTrainClassifierNonInMemory(self): - ops.reset_default_graph() - expected_first, expected_second, expected_third = ( - self._get_expected_ensembles_for_classification()) - with self.test_session() as sess: - # Train without train_in_memory mode. - with sess.graph.as_default(): - train_op, ensemble_serialized = self._get_train_op_and_ensemble( - boosted_trees._create_classification_head(n_classes=2), - run_config.RunConfig(), - is_classification=True, - train_in_memory=False) - _, serialized = sess.run([train_op, ensemble_serialized]) - # Validate the trained ensemble. - ensemble_proto = boosted_trees_pb2.TreeEnsemble() - ensemble_proto.ParseFromString(serialized) - self.assertProtoEquals(expected_first, ensemble_proto) - - # Run one more time and validate the trained ensemble. - _, serialized = sess.run([train_op, ensemble_serialized]) - ensemble_proto = boosted_trees_pb2.TreeEnsemble() - ensemble_proto.ParseFromString(serialized) - self.assertProtoEquals(expected_second, ensemble_proto) - - # Third round training and validation. - _, serialized = sess.run([train_op, ensemble_serialized]) - ensemble_proto = boosted_trees_pb2.TreeEnsemble() - ensemble_proto.ParseFromString(serialized) - self.assertProtoEquals(expected_third, ensemble_proto) + return (first_round, second_round, third_round, forth_round) - def testTrainRegressorInMemory(self): - ops.reset_default_graph() - expected_first, expected_second, expected_third = ( - self._get_expected_ensembles_for_regression()) - with self.test_session() as sess: + def _get_expected_ensembles_for_regression(self): + first_round = """ + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 1.169714 + } + } + nodes { + leaf { + scalar: 0.241322 + } + } + nodes { + leaf { + scalar: 0.083951 + } + } + } + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 1 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + second_round = """ + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 1.169714 + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 1 + left_id: 3 + right_id: 4 + } + metadata { + gain: 2.673407 + original_leaf { + scalar: 0.241322 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 0 + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.324102 + original_leaf { + scalar: 0.083951 + } + } + } + nodes { + leaf { + scalar: 0.563167 + } + } + nodes { + leaf { + scalar: 0.247047 + } + } + nodes { + leaf { + scalar: 0.095273 + } + } + nodes { + leaf { + scalar: 0.222102 + } + } + } + trees { + nodes { + leaf { + scalar: 0.0 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + num_layers_grown: 0 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + last_layer_node_start: 0 + last_layer_node_end: 1 + } + """ + third_round = """ + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 1.169714 + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 1 + left_id: 3 + right_id: 4 + } + metadata { + gain: 2.673407 + original_leaf { + scalar: 0.241322 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 0 + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.324102 + original_leaf { + scalar: 0.083951 + } + } + } + nodes { + leaf { + scalar: 0.563167 + } + } + nodes { + leaf { + scalar: 0.247047 + } + } + nodes { + leaf { + scalar: 0.095273 + } + } + nodes { + leaf { + scalar: 0.222102 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 0 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.981026 + } + } + nodes { + leaf { + scalar: 0.005166 + } + } + nodes { + leaf { + scalar: 0.180281 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + num_layers_grown: 1 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 3 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + return (first_round, second_round, third_round) + + def _get_expected_ensembles_for_regression_with_bias(self): + first_round = """ + trees { + nodes { + leaf { + scalar: 1.799974 + } + } + } + tree_weights: 1.0 + tree_metadata { + } + """ + second_round = """ + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 1.190442 + original_leaf { + scalar: 1.799974 + } + } + } + nodes { + leaf { + scalar: 1.862786 + } + } + nodes { + leaf { + scalar: 1.706149 + } + } + } + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 1 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + third_round = """ + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 1.190442 + original_leaf { + scalar: 1.799974 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 1 + left_id: 3 + right_id: 4 + } + metadata { + gain: 2.683594 + original_leaf { + scalar: 1.862786 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 0 + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.322693 + original_leaf { + scalar: 1.706149 + } + } + } + nodes { + leaf { + scalar: 2.024487 + } + } + nodes { + leaf { + scalar: 1.710319 + } + } + nodes { + leaf { + scalar: 1.559208 + } + } + nodes { + leaf { + scalar: 1.686037 + } + } + } + trees { + nodes { + leaf { + scalar: 0.0 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + num_layers_grown: 0 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + last_layer_node_start: 0 + last_layer_node_end: 1 + } + """ + forth_round = """ + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 1.190442 + original_leaf { + scalar: 1.799974 + } + } + } + nodes { + bucketized_split { + threshold: 1 + left_id: 3 + right_id: 4 + } + metadata { + gain: 2.683594 + original_leaf { + scalar: 1.8627863 + } + } + } + nodes { + bucketized_split { + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.322693 + original_leaf { + scalar: 1.706149 + } + } + } + nodes { + leaf { + scalar: 2.024487 + } + } + nodes { + leaf { + scalar: 1.710319 + } + } + nodes { + leaf { + scalar: 1.5592078 + } + } + nodes { + leaf { + scalar: 1.686037 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.972589 + } + } + nodes { + leaf { + scalar: -0.137592 + } + } + nodes { + leaf { + scalar: 0.034926 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + num_layers_grown: 1 + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 3 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + return (first_round, second_round, third_round, forth_round) + + def _get_train_op_and_ensemble(self, + head, + config, + is_classification, + train_in_memory, + center_bias=False): + """Calls bt_model_fn() and returns the train_op and ensemble_serialzed.""" + features, labels = _make_train_input_fn(is_classification)() + + tree_hparams = boosted_trees._TreeHParams( # pylint:disable=protected-access + n_trees=2, + max_depth=2, + learning_rate=0.1, + l1=0., + l2=0.01, + tree_complexity=0., + min_node_weight=0., + center_bias=center_bias) + + estimator_spec = boosted_trees._bt_model_fn( # pylint:disable=protected-access + features=features, + labels=labels, + mode=model_fn.ModeKeys.TRAIN, + head=head, + feature_columns=self._feature_columns, + tree_hparams=tree_hparams, + example_id_column_name=EXAMPLE_ID_COLUMN, + n_batches_per_layer=1, + config=config, + train_in_memory=train_in_memory) + resources.initialize_resources(resources.shared_resources()).run() + variables.global_variables_initializer().run() + variables.local_variables_initializer().run() + + # Gets the train_op and serialized proto of the ensemble. + shared_resources = resources.shared_resources() + self.assertEqual(1, len(shared_resources)) + train_op = estimator_spec.train_op + with ops.control_dependencies([train_op]): + _, ensemble_serialized = ( + gen_boosted_trees_ops.boosted_trees_serialize_ensemble( + shared_resources[0].handle)) + return train_op, ensemble_serialized + + def testTrainClassifierInMemory(self): + ops.reset_default_graph() + expected_first, expected_second, expected_third = ( + self._get_expected_ensembles_for_classification()) + with self.test_session() as sess: + # Train with train_in_memory mode. + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_classification_head(n_classes=2), + run_config.RunConfig(), + is_classification=True, + train_in_memory=True) + _, serialized = sess.run([train_op, ensemble_serialized]) + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + def testTrainClassifierWithCenterBiasInMemory(self): + ops.reset_default_graph() + + # When bias centering is on, we expect the very first node to have the + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_classification_with_bias()) + + with self.test_session() as sess: + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_classification_head(n_classes=2), + run_config.RunConfig(), + is_classification=True, + train_in_memory=True, + center_bias=True) + + # 4 iterations to center bias. + for _ in range(4): + _, serialized = sess.run([train_op, ensemble_serialized]) + + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + + self.assertProtoEquals(expected_forth, ensemble_proto) + + def testTrainClassifierNonInMemory(self): + ops.reset_default_graph() + expected_first, expected_second, expected_third = ( + self._get_expected_ensembles_for_classification()) + with self.test_session() as sess: + # Train without train_in_memory mode. + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_classification_head(n_classes=2), + run_config.RunConfig(), + is_classification=True, + train_in_memory=False) + _, serialized = sess.run([train_op, ensemble_serialized]) + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + def testTrainClassifierWithCenterBiasNonInMemory(self): + ops.reset_default_graph() + + # When bias centering is on, we expect the very first node to have the + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_classification_with_bias()) + + with self.test_session() as sess: + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_classification_head(n_classes=2), + run_config.RunConfig(), + is_classification=True, + train_in_memory=False, + center_bias=True) + # 4 iterations to center bias. + for _ in range(4): + _, serialized = sess.run([train_op, ensemble_serialized]) + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_forth, ensemble_proto) + + def testTrainRegressorInMemory(self): + ops.reset_default_graph() + expected_first, expected_second, expected_third = ( + self._get_expected_ensembles_for_regression()) + with self.test_session() as sess: # Train with train_in_memory mode. with sess.graph.as_default(): train_op, ensemble_serialized = self._get_train_op_and_ensemble( @@ -1018,6 +1709,46 @@ class ModelFnTests(test_util.TensorFlowTestCase): ensemble_proto.ParseFromString(serialized) self.assertProtoEquals(expected_third, ensemble_proto) + def testTrainRegressorInMemoryWithCenterBias(self): + ops.reset_default_graph() + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_regression_with_bias()) + with self.test_session() as sess: + # Train with train_in_memory mode. + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_regression_head(label_dimension=1), + run_config.RunConfig(), + is_classification=False, + train_in_memory=True, + center_bias=True) + # 3 iterations to center bias. + for _ in range(3): + _, serialized = sess.run([train_op, ensemble_serialized]) + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_forth, ensemble_proto) + def testTrainRegressorNonInMemory(self): ops.reset_default_graph() expected_first, expected_second, expected_third = ( @@ -1048,6 +1779,46 @@ class ModelFnTests(test_util.TensorFlowTestCase): ensemble_proto.ParseFromString(serialized) self.assertProtoEquals(expected_third, ensemble_proto) + def testTrainRegressorNotInMemoryWithCenterBias(self): + ops.reset_default_graph() + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_regression_with_bias()) + with self.test_session() as sess: + # Train with train_in_memory mode. + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_regression_head(label_dimension=1), + run_config.RunConfig(), + is_classification=False, + train_in_memory=False, + center_bias=True) + # 3 iterations to center the bias (because we are using regularization). + for _ in range(3): + _, serialized = sess.run([train_op, ensemble_serialized]) + + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_forth, ensemble_proto) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 1feac36f356cc5b2615217b7ca69a79d2a781ca6..c08cf61220716730fa495c6e327b91e8f3c69cd5 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -26,13 +26,14 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import optimizers from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import normalization from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # The default learning rate of 0.05 is a historical artifact of the initial # implementation, but seems a reasonable choice. @@ -45,7 +46,7 @@ def _add_hidden_layer_summary(value, tag): def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn, - dropout, input_layer_partitioner): + dropout, input_layer_partitioner, batch_norm): """Function builder for a dnn logit_fn. Args: @@ -58,6 +59,7 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn, dropout: When not `None`, the probability we will drop out a given coordinate. input_layer_partitioner: Partitioner for input layer. + batch_norm: Whether to use batch normalization after each hidden layer. Returns: A logit_fn (see below). @@ -83,6 +85,7 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn, A `Tensor` representing the logits, or a list of `Tensor`'s representing multiple logits in the MultiHead case. """ + is_training = mode == model_fn.ModeKeys.TRAIN with variable_scope.variable_scope( 'input_from_feature_columns', values=tuple(six.itervalues(features)), @@ -98,8 +101,20 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn, activation=activation_fn, kernel_initializer=init_ops.glorot_uniform_initializer(), name=hidden_layer_scope) - if dropout is not None and mode == model_fn.ModeKeys.TRAIN: + if dropout is not None and is_training: net = core_layers.dropout(net, rate=dropout, training=True) + if batch_norm: + # TODO(hjm): In future, if this becomes popular, we can enable + # customization of the batch normalization params by accepting a + # list of `BatchNormalization` instances as `batch_norm`. + net = normalization.batch_normalization( + net, + # The default momentum 0.99 actually crashes on certain + # problem, so here we use 0.999, which is the default of + # tf.contrib.layers.batch_norm. + momentum=0.999, + training=is_training, + name='batchnorm_%d' % layer_id) _add_hidden_layer_summary(net, hidden_layer_scope.name) with variable_scope.variable_scope('logits', values=(net,)) as logits_scope: @@ -127,7 +142,8 @@ def _dnn_model_fn(features, dropout=None, input_layer_partitioner=None, config=None, - tpu_estimator_spec=False): + tpu_estimator_spec=False, + batch_norm=False): """Deep Neural Net model_fn. Args: @@ -150,6 +166,7 @@ def _dnn_model_fn(features, config: `RunConfig` object to configure the runtime settings. tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or or `model_fn.EstimatorSpec` instance. + batch_norm: Whether to use batch normalization after each hidden layer. Returns: An `EstimatorSpec` instance. @@ -182,7 +199,8 @@ def _dnn_model_fn(features, feature_columns=feature_columns, activation_fn=activation_fn, dropout=dropout, - input_layer_partitioner=input_layer_partitioner) + input_layer_partitioner=input_layer_partitioner, + batch_norm=batch_norm) logits = logit_fn(features=features, mode=mode) if tpu_estimator_spec: @@ -201,7 +219,7 @@ def _dnn_model_fn(features, logits=logits) -@tf_export('estimator.DNNClassifier') +@estimator_export('estimator.DNNClassifier') class DNNClassifier(estimator.Estimator): """A classifier for TensorFlow DNN models. @@ -230,6 +248,17 @@ class DNNClassifier(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNClassifier( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = DNNClassifier( feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], @@ -266,7 +295,10 @@ class DNNClassifier(estimator.Estimator): Loss is calculated by using softmax cross entropy. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -285,6 +317,7 @@ class DNNClassifier(estimator.Estimator): config=None, warm_start_from=None, loss_reduction=losses.Reduction.SUM, + batch_norm=False, ): """Initializes a `DNNClassifier` instance. @@ -314,8 +347,9 @@ class DNNClassifier(estimator.Estimator): encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given @@ -330,6 +364,7 @@ class DNNClassifier(estimator.Estimator): names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + batch_norm: Whether to use batch normalization after each hidden layer. """ head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access n_classes, weight_column, label_vocabulary, loss_reduction) @@ -346,14 +381,15 @@ class DNNClassifier(estimator.Estimator): activation_fn=activation_fn, dropout=dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm) super(DNNClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, warm_start_from=warm_start_from) -@tf_export('estimator.DNNRegressor') +@estimator_export('estimator.DNNRegressor') class DNNRegressor(estimator.Estimator): """A regressor for TensorFlow DNN models. @@ -382,6 +418,17 @@ class DNNRegressor(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNRegressor( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = DNNRegressor( feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], @@ -418,7 +465,10 @@ class DNNRegressor(estimator.Estimator): Loss is calculated by using mean squared error. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -436,6 +486,7 @@ class DNNRegressor(estimator.Estimator): config=None, warm_start_from=None, loss_reduction=losses.Reduction.SUM, + batch_norm=False, ): """Initializes a `DNNRegressor` instance. @@ -459,8 +510,9 @@ class DNNRegressor(estimator.Estimator): used as a key to fetch weight tensor from the `features`. If it is a `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then weight_column.normalizer_fn is applied on it to get weight tensor. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given @@ -475,6 +527,7 @@ class DNNRegressor(estimator.Estimator): names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + batch_norm: Whether to use batch normalization after each hidden layer. """ def _model_fn(features, labels, mode, config): @@ -492,7 +545,8 @@ class DNNRegressor(estimator.Estimator): activation_fn=activation_fn, dropout=dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm) super(DNNRegressor, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 95efc0a028bc90911106a8947dcfc199ddd29444..efa7812452427a6cdd7854b50b7d95a9a003abbb 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -37,7 +37,7 @@ from tensorflow.python.summary import summary from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training_util -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # The default learning rates are a historical artifact of the initial # implementation. @@ -88,7 +88,9 @@ def _dnn_linear_combined_model_fn(features, dnn_activation_fn=nn.relu, dnn_dropout=None, input_layer_partitioner=None, - config=None): + config=None, + batch_norm=False, + linear_sparse_combiner='sum'): """Deep Neural Net and Linear combined model_fn. Args: @@ -115,7 +117,10 @@ def _dnn_linear_combined_model_fn(features, coordinate. input_layer_partitioner: Partitioner for input layer. config: `RunConfig` object to configure the runtime settings. - + batch_norm: Whether to use batch normalization after each hidden layer. + linear_sparse_combiner: A string specifying how to reduce the linear model + if a categorical column is multivalent. One of "mean", "sqrtn", and + "sum". Returns: An `EstimatorSpec` instance. @@ -164,7 +169,8 @@ def _dnn_linear_combined_model_fn(features, feature_columns=dnn_feature_columns, activation_fn=dnn_activation_fn, dropout=dnn_dropout, - input_layer_partitioner=input_layer_partitioner) + input_layer_partitioner=input_layer_partitioner, + batch_norm=batch_norm) dnn_logits = dnn_logit_fn(features=features, mode=mode) linear_parent_scope = 'linear' @@ -182,7 +188,8 @@ def _dnn_linear_combined_model_fn(features, partitioner=input_layer_partitioner) as scope: logit_fn = linear._linear_logit_fn_builder( # pylint: disable=protected-access units=head.logits_dimension, - feature_columns=linear_feature_columns) + feature_columns=linear_feature_columns, + sparse_combiner=linear_sparse_combiner) linear_logits = logit_fn(features=features) _add_layer_summary(linear_logits, scope.name) @@ -225,7 +232,7 @@ def _dnn_linear_combined_model_fn(features, logits=logits) -@tf_export('estimator.DNNLinearCombinedClassifier') +@estimator_export('estimator.DNNLinearCombinedClassifier') class DNNLinearCombinedClassifier(estimator.Estimator): """An estimator for TensorFlow Linear and DNN joined classification models. @@ -257,12 +264,19 @@ class DNNLinearCombinedClassifier(estimator.Estimator): # warm-start settings warm_start_from="/path/to/checkpoint/dir") - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -292,7 +306,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator): Loss is calculated by using softmax cross entropy. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -311,7 +328,9 @@ class DNNLinearCombinedClassifier(estimator.Estimator): input_layer_partitioner=None, config=None, warm_start_from=None, - loss_reduction=losses.Reduction.SUM): + loss_reduction=losses.Reduction.SUM, + batch_norm=False, + linear_sparse_combiner='sum'): """Initializes a DNNLinearCombinedClassifier instance. Args: @@ -322,12 +341,16 @@ class DNNLinearCombinedClassifier(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. + the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, @@ -360,6 +383,12 @@ class DNNLinearCombinedClassifier(estimator.Estimator): names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + batch_norm: Whether to use batch normalization after each hidden layer. + linear_sparse_combiner: A string specifying how to reduce the linear model + if a categorical column is multivalent. One of "mean", "sqrtn", and + "sum" -- these are effectively different ways to do example-level + normalization, which can be useful for bag-of-words features. For more + details, see @{tf.feature_column.linear_model$linear_model}. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are @@ -399,14 +428,16 @@ class DNNLinearCombinedClassifier(estimator.Estimator): dnn_activation_fn=dnn_activation_fn, dnn_dropout=dnn_dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm, + linear_sparse_combiner=linear_sparse_combiner) super(DNNLinearCombinedClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, warm_start_from=warm_start_from) -@tf_export('estimator.DNNLinearCombinedRegressor') +@estimator_export('estimator.DNNLinearCombinedRegressor') class DNNLinearCombinedRegressor(estimator.Estimator): """An estimator for TensorFlow Linear and DNN joined models for regression. @@ -438,12 +469,19 @@ class DNNLinearCombinedRegressor(estimator.Estimator): # warm-start settings warm_start_from="/path/to/checkpoint/dir") - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -473,7 +511,10 @@ class DNNLinearCombinedRegressor(estimator.Estimator): Loss is calculated by using mean squared error. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -491,7 +532,9 @@ class DNNLinearCombinedRegressor(estimator.Estimator): input_layer_partitioner=None, config=None, warm_start_from=None, - loss_reduction=losses.Reduction.SUM): + loss_reduction=losses.Reduction.SUM, + batch_norm=False, + linear_sparse_combiner='sum'): """Initializes a DNNLinearCombinedRegressor instance. Args: @@ -502,12 +545,16 @@ class DNNLinearCombinedRegressor(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. + the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, @@ -534,6 +581,12 @@ class DNNLinearCombinedRegressor(estimator.Estimator): names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + batch_norm: Whether to use batch normalization after each hidden layer. + linear_sparse_combiner: A string specifying how to reduce the linear model + if a categorical column is multivalent. One of "mean", "sqrtn", and + "sum" -- these are effectively different ways to do example-level + normalization, which can be useful for bag-of-words features. For more + details, see @{tf.feature_column.linear_model$linear_model}. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are @@ -564,7 +617,9 @@ class DNNLinearCombinedRegressor(estimator.Estimator): dnn_activation_fn=dnn_activation_fn, dnn_dropout=dnn_dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm, + linear_sparse_combiner=linear_sparse_combiner) super(DNNLinearCombinedRegressor, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py index d275695eb319117cf94aefd7038ab5ee685e05a9..d16318659ba8fac70486e88fff07d71e060eac9b 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py @@ -100,7 +100,8 @@ def _linear_regressor_fn(feature_columns, weight_column=None, optimizer='Ftrl', config=None, - partitioner=None): + partitioner=None, + sparse_combiner='sum'): return dnn_linear_combined.DNNLinearCombinedRegressor( model_dir=model_dir, linear_feature_columns=feature_columns, @@ -108,7 +109,8 @@ def _linear_regressor_fn(feature_columns, label_dimension=label_dimension, weight_column=weight_column, input_layer_partitioner=partitioner, - config=config) + config=config, + linear_sparse_combiner=sparse_combiner) class LinearOnlyRegressorPartitionerTest( @@ -163,7 +165,8 @@ def _linear_classifier_fn(feature_columns, label_vocabulary=None, optimizer='Ftrl', config=None, - partitioner=None): + partitioner=None, + sparse_combiner='sum'): return dnn_linear_combined.DNNLinearCombinedClassifier( model_dir=model_dir, linear_feature_columns=feature_columns, @@ -172,7 +175,8 @@ def _linear_classifier_fn(feature_columns, weight_column=weight_column, label_vocabulary=label_vocabulary, input_layer_partitioner=partitioner, - config=config) + config=config, + linear_sparse_combiner=sparse_combiner) class LinearOnlyClassifierTrainingTest( diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py index 06a648777f8f730b4c739a69528090c5821f2681..ba1782125905fd14ec9b89a29c891062824028f3 100644 --- a/tensorflow/python/estimator/canned/dnn_testing_utils.py +++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py @@ -65,6 +65,11 @@ from tensorflow.python.training import training_util LEARNING_RATE_NAME = 'dnn/regression_head/dnn/learning_rate' HIDDEN_WEIGHTS_NAME_PATTERN = 'dnn/hiddenlayer_%d/kernel' HIDDEN_BIASES_NAME_PATTERN = 'dnn/hiddenlayer_%d/bias' +BATCH_NORM_BETA_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/beta' +BATCH_NORM_GAMMA_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/gamma' +BATCH_NORM_MEAN_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/moving_mean' +BATCH_NORM_VARIANCE_NAME_PATTERN = ( + 'dnn/hiddenlayer_%d/batchnorm_%d/moving_variance') LOGITS_WEIGHTS_NAME = 'dnn/logits/kernel' LOGITS_BIASES_NAME = 'dnn/logits/bias' OCCUPATION_EMBEDDING_NAME = ('dnn/input_from_feature_columns/input_layer/' @@ -89,7 +94,10 @@ def assert_close(expected, actual, rtol=1e-04, message='', name='assert_close'): name=scope) -def create_checkpoint(weights_and_biases, global_step, model_dir): +def create_checkpoint(weights_and_biases, + global_step, + model_dir, + batch_norm_vars=None): """Create checkpoint file with provided model weights. Args: @@ -98,12 +106,20 @@ def create_checkpoint(weights_and_biases, global_step, model_dir): model_dir: Directory into which checkpoint is saved. """ weights, biases = zip(*weights_and_biases) + if batch_norm_vars: + assert len(batch_norm_vars) == len(weights_and_biases) - 1 + (bn_betas, bn_gammas, bn_means, bn_variances) = zip(*batch_norm_vars) model_weights = {} # Hidden layer weights. for i in range(0, len(weights) - 1): model_weights[HIDDEN_WEIGHTS_NAME_PATTERN % i] = weights[i] model_weights[HIDDEN_BIASES_NAME_PATTERN % i] = biases[i] + if batch_norm_vars: + model_weights[BATCH_NORM_BETA_NAME_PATTERN % (i, i)] = bn_betas[i] + model_weights[BATCH_NORM_GAMMA_NAME_PATTERN % (i, i)] = bn_gammas[i] + model_weights[BATCH_NORM_MEAN_NAME_PATTERN % (i, i)] = bn_means[i] + model_weights[BATCH_NORM_VARIANCE_NAME_PATTERN % (i, i)] = bn_variances[i] # Output layer weights. model_weights[LOGITS_WEIGHTS_NAME] = weights[-1] @@ -503,8 +519,13 @@ class BaseDNNLogitFnTest(object): writer_cache.FileWriterCache.clear() shutil.rmtree(self._model_dir) - def _test_logits(self, mode, hidden_units, logits_dimension, inputs, - expected_logits): + def _test_logits(self, + mode, + hidden_units, + logits_dimension, + inputs, + expected_logits, + batch_norm=False): """Tests that the expected logits are calculated.""" with ops.Graph().as_default(): # Global step needed for MonitoredSession, which is in turn used to @@ -525,7 +546,8 @@ class BaseDNNLogitFnTest(object): ], activation_fn=nn.relu, dropout=None, - input_layer_partitioner=input_layer_partitioner) + input_layer_partitioner=input_layer_partitioner, + batch_norm=batch_norm) logits = logit_fn( features={'age': constant_op.constant(inputs)}, mode=mode) with monitored_session.MonitoredTrainingSession( @@ -556,6 +578,69 @@ class BaseDNNLogitFnTest(object): inputs=[[10.]], expected_logits=[[-2.08]]) + def test_one_dim_logits_with_batch_norm(self): + """Tests one-dimensional logits. + + input_layer = [[10]] + hidden_layer_0 = [[relu(0.6*10 +1), relu(0.5*10 -1)]] = [[7, 4]] + hidden_layer_0 = [[relu(0.6*20 +1), relu(0.5*20 -1)]] = [[13, 9]] + + batch_norm_0, training (epsilon = 0.001): + mean1 = 1/2*(7+13) = 10, + variance1 = 1/2*(3^2+3^2) = 9 + x11 = (7-10)/sqrt(9+0.001) = -0.999944449, + x21 = (13-10)/sqrt(9+0.001) = 0.999944449, + + mean2 = 1/2*(4+9) = 6.5, + variance2 = 1/2*(2.5^2+.2.5^2) = 6.25 + x12 = (4-6.5)/sqrt(6.25+0.001) = -0.99992001, + x22 = (9-6.5)/sqrt(6.25+0.001) = 0.99992001, + + logits = [[-1*(-0.999944449) + 2*(-0.99992001) + 0.3], + [-1*0.999944449 + 2*0.99992001 + 0.3]] + = [[-0.699895571],[1.299895571]] + + batch_norm_0, not training (epsilon = 0.001): + moving_mean1 = 0, moving_variance1 = 1 + x11 = (7-0)/sqrt(1+0.001) = 6.996502623, + x21 = (13-0)/sqrt(1+0.001) = 12.993504871, + moving_mean2 = 0, moving_variance2 = 1 + x12 = (4-0)/sqrt(1+0.001) = 3.998001499, + x22 = (9-0)/sqrt(1+0.001) = 8.995503372, + + logits = [[-1*6.996502623 + 2*3.998001499 + 0.3], + [-1*12.993504871 + 2*8.995503372 + 0.3]] + = [[1.299500375],[5.297501873]] + """ + base_global_step = 100 + create_checkpoint( + ( + ([[.6, .5]], [1., -1.]), + ([[-1.], [2.]], [.3]), + ), + base_global_step, + self._model_dir, + batch_norm_vars=([[0, 0], # beta. + [1, 1], # gamma. + [0, 0], # moving mean. + [1, 1], # moving variance. + ],)) + self._test_logits( + model_fn.ModeKeys.TRAIN, + hidden_units=[2], + logits_dimension=1, + inputs=[[10.], [20.]], + expected_logits=[[-0.699895571], [1.299895571]], + batch_norm=True) + for mode in [model_fn.ModeKeys.EVAL, model_fn.ModeKeys.PREDICT]: + self._test_logits( + mode, + hidden_units=[2], + logits_dimension=1, + inputs=[[10.], [20.]], + expected_logits=[[1.299500375], [5.297501873]], + batch_norm=True) + def test_multi_dim_logits(self): """Tests multi-dimensional logits. @@ -706,7 +791,8 @@ class BaseDNNLogitFnTest(object): ], activation_fn=nn.relu, dropout=None, - input_layer_partitioner=input_layer_partitioner) + input_layer_partitioner=input_layer_partitioner, + batch_norm=False) logits = logit_fn( features={ 'age': constant_op.constant(inputs[0]), diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index 04fe4d97e40d60f7e5a5c9c2e9b40a08678f35d1..b74ef1015cc564c20370e17e94e3a09d460c4f85 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -873,6 +873,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = _append_update_ops(train_op) # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -1244,6 +1245,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = _append_update_ops(train_op) # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -1506,6 +1508,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = _append_update_ops(train_op) # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -1533,6 +1536,14 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): train_op=train_op) +def _append_update_ops(train_op): + """Returns `train_op` appending `UPDATE_OPS` collection if present.""" + update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS) + if update_ops: + return control_flow_ops.group(train_op, *update_ops) + return train_op + + def _assert_range(labels, n_classes, message=None): with ops.name_scope(None, 'assert_range', (labels,)): assert_less = check_ops.assert_less_equal( diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index ecca3e8b0d82864c5fda6b94cc75db0521d5e8d3..08ce5ca8e833fdd88f9c45b668f0914fcc70acd0 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -969,6 +970,35 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)), train_result) + def test_train_with_update_ops(self): + n_classes = 3 + head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes) + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32), + labels=np.array(((1,), (1,)), dtype=np.int64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_summaries_with_head_name(self): n_classes = 3 head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( @@ -2102,6 +2132,34 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertAllClose(expected_loss, loss) self.assertEqual(expected_train_result, train_result) + def test_train_with_update_ops(self): + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array(((45,), (-41,),), dtype=np.float32), + labels=np.array(((1,), (1,),), dtype=np.float64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_summaries_with_head_name(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( name='some_binary_head') @@ -3278,6 +3336,34 @@ class RegressionHead(test.TestCase): self.assertAllClose(expected_loss, loss) self.assertEqual(expected_train_result, train_result) + def test_train_with_update_ops(self): + head = head_lib._regression_head() + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array(((45,), (41,),), dtype=np.float32), + labels=np.array(((43.,), (44.,),), dtype=np.float64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_summaries_with_head_name(self): head = head_lib._regression_head(name='some_regression_head') self.assertEqual(1, head.logits_dimension) diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index 81657f0c01644524f1f706a0d42dd67e1345273e..58a71603488198373bc4d1fd716538c2cee4d86f 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import ftrl -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # The default learning rate of 0.2 is a historical artifact of the initial @@ -66,13 +66,15 @@ def _compute_fraction_of_zero(cols_to_vars): return nn.zero_fraction(array_ops.concat(all_weight_vars, axis=0)) -def _linear_logit_fn_builder(units, feature_columns): +def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'): """Function builder for a linear logit_fn. Args: units: An int indicating the dimension of the logit layer. feature_columns: An iterable containing all the feature columns used by the model. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum". Returns: A logit_fn (see below). @@ -95,6 +97,7 @@ def _linear_logit_fn_builder(units, feature_columns): features=features, feature_columns=feature_columns, units=units, + sparse_combiner=sparse_combiner, cols_to_vars=cols_to_vars) bias = cols_to_vars.pop('bias') if units > 1: @@ -111,7 +114,7 @@ def _linear_logit_fn_builder(units, feature_columns): def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, - partitioner, config): + partitioner, config, sparse_combiner='sum'): """A model_fn for linear models that use a gradient-based optimizer. Args: @@ -126,6 +129,8 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, optimizer to use for training. If `None`, will use a FTRL optimizer. partitioner: Partitioner for variables. config: `RunConfig` object to configure the runtime settings. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum". Returns: An `EstimatorSpec` instance. @@ -153,7 +158,8 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, partitioner=partitioner): logit_fn = _linear_logit_fn_builder( - units=head.logits_dimension, feature_columns=feature_columns) + units=head.logits_dimension, feature_columns=feature_columns, + sparse_combiner=sparse_combiner) logits = logit_fn(features=features) return head.create_estimator_spec( @@ -164,7 +170,7 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, logits=logits) -@tf_export('estimator.LinearClassifier') +@estimator_export('estimator.LinearClassifier') class LinearClassifier(estimator.Estimator): """Linear classifier model. @@ -193,6 +199,17 @@ class LinearClassifier(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearClassifier( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = LinearClassifier( feature_columns=[categorical_column_a, @@ -227,7 +244,10 @@ class LinearClassifier(estimator.Estimator): Loss is calculated by using softmax cross entropy. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -241,7 +261,8 @@ class LinearClassifier(estimator.Estimator): config=None, partitioner=None, warm_start_from=None, - loss_reduction=losses.Reduction.SUM): + loss_reduction=losses.Reduction.SUM, + sparse_combiner='sum'): """Construct a `LinearClassifier` estimator object. Args: @@ -269,8 +290,9 @@ class LinearClassifier(estimator.Estimator): encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. warm_start_from: A string filepath to a checkpoint to warm-start from, or @@ -280,6 +302,11 @@ class LinearClassifier(estimator.Estimator): and Tensor names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum" -- these are + effectively different ways to do example-level normalization, which can + be useful for bag-of-words features. for more details, see + @{tf.feature_column.linear_model$linear_model}. Returns: A `LinearClassifier` estimator. @@ -308,7 +335,8 @@ class LinearClassifier(estimator.Estimator): feature_columns=tuple(feature_columns or []), optimizer=optimizer, partitioner=partitioner, - config=config) + config=config, + sparse_combiner=sparse_combiner) super(LinearClassifier, self).__init__( model_fn=_model_fn, @@ -317,7 +345,7 @@ class LinearClassifier(estimator.Estimator): warm_start_from=warm_start_from) -@tf_export('estimator.LinearRegressor') +@estimator_export('estimator.LinearRegressor') class LinearRegressor(estimator.Estimator): """An estimator for TensorFlow Linear regression problems. @@ -332,10 +360,31 @@ class LinearRegressor(estimator.Estimator): categorical_feature_a_x_categorical_feature_b = crossed_column(...) + # Estimator using the default optimizer. estimator = LinearRegressor( feature_columns=[categorical_column_a, categorical_feature_a_x_categorical_feature_b]) + # Or estimator using the FTRL optimizer with regularization. + estimator = LinearRegressor( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=tf.train.FtrlOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001 + )) + + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearRegressor( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = LinearRegressor( feature_columns=[categorical_column_a, @@ -370,7 +419,10 @@ class LinearRegressor(estimator.Estimator): Loss is calculated by using mean squared error. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -383,7 +435,8 @@ class LinearRegressor(estimator.Estimator): config=None, partitioner=None, warm_start_from=None, - loss_reduction=losses.Reduction.SUM): + loss_reduction=losses.Reduction.SUM, + sparse_combiner='sum'): """Initializes a `LinearRegressor` instance. Args: @@ -403,8 +456,9 @@ class LinearRegressor(estimator.Estimator): used as a key to fetch weight tensor from the `features`. If it is a `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then weight_column.normalizer_fn is applied on it to get weight tensor. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. warm_start_from: A string filepath to a checkpoint to warm-start from, or @@ -414,6 +468,11 @@ class LinearRegressor(estimator.Estimator): and Tensor names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum" -- these are + effectively different ways to do example-level normalization, which can + be useful for bag-of-words features. for more details, see + @{tf.feature_column.linear_model$linear_model}. """ head = head_lib._regression_head( # pylint: disable=protected-access label_dimension=label_dimension, weight_column=weight_column, @@ -429,7 +488,8 @@ class LinearRegressor(estimator.Estimator): feature_columns=tuple(feature_columns or []), optimizer=optimizer, partitioner=partitioner, - config=config) + config=config, + sparse_combiner=sparse_combiner) super(LinearRegressor, self).__init__( model_fn=_model_fn, diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py index 0e6436b42143f4b136165d47c41e143dacb4d476..9e9c2f7c4b0a79718da43769d983f49adbe537ca 100644 --- a/tensorflow/python/estimator/canned/linear_testing_utils.py +++ b/tensorflow/python/estimator/canned/linear_testing_utils.py @@ -29,6 +29,7 @@ import six from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.client import session as tf_session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import run_config from tensorflow.python.estimator.canned import linear @@ -484,6 +485,69 @@ class BaseLinearRegressorPredictTest(object): # x0 * weight0 + x1 * weight1 + bias = 2. * 10. + 3. * 20 + .2 = 80.2 self.assertAllClose([[80.2]], predicted_scores) + def testSparseCombiner(self): + w_a = 2.0 + w_b = 3.0 + w_c = 5.0 + bias = 5.0 + with ops.Graph().as_default(): + variables_lib.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME) + variables_lib.Variable([bias], name=BIAS_NAME) + variables_lib.Variable(1, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + def _input_fn(): + return dataset_ops.Dataset.from_tensors({ + 'language': sparse_tensor.SparseTensor( + values=['a', 'c', 'b', 'c'], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + }) + + feature_columns = ( + feature_column_lib.categorical_column_with_vocabulary_list( + 'language', vocabulary_list=['a', 'b', 'c']),) + + # Check prediction for each sparse_combiner. + # With sparse_combiner = 'sum', we have + # logits_1 = w_a + w_c + bias + # = 2.0 + 5.0 + 5.0 = 12.0 + # logits_2 = w_b + w_c + bias + # = 3.0 + 5.0 + 5.0 = 13.0 + linear_regressor = self._linear_regressor_fn( + feature_columns=feature_columns, + model_dir=self._model_dir) + predictions = linear_regressor.predict(input_fn=_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + self.assertAllClose([[12.0], [13.0]], predicted_scores) + + # With sparse_combiner = 'mean', we have + # logits_1 = 1/2 * (w_a + w_c) + bias + # = 1/2 * (2.0 + 5.0) + 5.0 = 8.5 + # logits_2 = 1/2 * (w_b + w_c) + bias + # = 1/2 * (3.0 + 5.0) + 5.0 = 9.0 + linear_regressor = self._linear_regressor_fn( + feature_columns=feature_columns, + model_dir=self._model_dir, + sparse_combiner='mean') + predictions = linear_regressor.predict(input_fn=_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + self.assertAllClose([[8.5], [9.0]], predicted_scores) + + # With sparse_combiner = 'sqrtn', we have + # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias + # = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974 + # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias + # = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685 + linear_regressor = self._linear_regressor_fn( + feature_columns=feature_columns, + model_dir=self._model_dir, + sparse_combiner='sqrtn') + predictions = linear_regressor.predict(input_fn=_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + self.assertAllClose([[9.94974], [10.65685]], predicted_scores) + class BaseLinearRegressorIntegrationTest(object): @@ -1636,6 +1700,69 @@ class BaseLinearClassifierPredictTest(object): for i in range(n_classes)], label_output_fn=lambda x: ('class_vocab_%s' % x).encode()) + def testSparseCombiner(self): + w_a = 2.0 + w_b = 3.0 + w_c = 5.0 + bias = 5.0 + with ops.Graph().as_default(): + variables_lib.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME) + variables_lib.Variable([bias], name=BIAS_NAME) + variables_lib.Variable(1, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + def _input_fn(): + return dataset_ops.Dataset.from_tensors({ + 'language': sparse_tensor.SparseTensor( + values=['a', 'c', 'b', 'c'], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + }) + + feature_columns = ( + feature_column_lib.categorical_column_with_vocabulary_list( + 'language', vocabulary_list=['a', 'b', 'c']),) + + # Check prediction for each sparse_combiner. + # With sparse_combiner = 'sum', we have + # logits_1 = w_a + w_c + bias + # = 2.0 + 5.0 + 5.0 = 12.0 + # logits_2 = w_b + w_c + bias + # = 3.0 + 5.0 + 5.0 = 13.0 + linear_classifier = self._linear_classifier_fn( + feature_columns=feature_columns, + model_dir=self._model_dir) + predictions = linear_classifier.predict(input_fn=_input_fn) + predicted_scores = list([x['logits'] for x in predictions]) + self.assertAllClose([[12.0], [13.0]], predicted_scores) + + # With sparse_combiner = 'mean', we have + # logits_1 = 1/2 * (w_a + w_c) + bias + # = 1/2 * (2.0 + 5.0) + 5.0 = 8.5 + # logits_2 = 1/2 * (w_b + w_c) + bias + # = 1/2 * (3.0 + 5.0) + 5.0 = 9.0 + linear_classifier = self._linear_classifier_fn( + feature_columns=feature_columns, + model_dir=self._model_dir, + sparse_combiner='mean') + predictions = linear_classifier.predict(input_fn=_input_fn) + predicted_scores = list([x['logits'] for x in predictions]) + self.assertAllClose([[8.5], [9.0]], predicted_scores) + + # With sparse_combiner = 'sqrtn', we have + # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias + # = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974 + # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias + # = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685 + linear_classifier = self._linear_classifier_fn( + feature_columns=feature_columns, + model_dir=self._model_dir, + sparse_combiner='sqrtn') + predictions = linear_classifier.predict(input_fn=_input_fn) + predicted_scores = list([x['logits'] for x in predictions]) + self.assertAllClose([[9.94974], [10.65685]], predicted_scores) + class BaseLinearClassifierIntegrationTest(object): diff --git a/tensorflow/python/estimator/canned/optimizers.py b/tensorflow/python/estimator/canned/optimizers.py index f72c5ca5cbb2721d967ad9ef9dfa896f7ccce240..8f51cc3a80dd9b91eb24a83577b7d0614615e008 100644 --- a/tensorflow/python/estimator/canned/optimizers.py +++ b/tensorflow/python/estimator/canned/optimizers.py @@ -72,6 +72,8 @@ def get_optimizer_instance(opt, learning_rate=None): raise ValueError( 'Unsupported optimizer name: {}. Supported names are: {}'.format( opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES))))) + if callable(opt): + opt = opt() if not isinstance(opt, optimizer_lib.Optimizer): raise ValueError( 'The given object is not an Optimizer instance. Given: {}'.format(opt)) diff --git a/tensorflow/python/estimator/canned/optimizers_test.py b/tensorflow/python/estimator/canned/optimizers_test.py index ee28756155afd5ae3421475c3d41542db9411345..eadabdbc496334270cd792f5b8d5ff39a446bcf7 100644 --- a/tensorflow/python/estimator/canned/optimizers_test.py +++ b/tensorflow/python/estimator/canned/optimizers_test.py @@ -28,6 +28,13 @@ from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import rmsprop +class _TestOptimizer(optimizer_lib.Optimizer): + + def __init__(self): + super(_TestOptimizer, self).__init__( + use_locking=False, name='TestOptimizer') + + class GetOptimizerInstance(test.TestCase): def test_unsupported_name(self): @@ -66,12 +73,6 @@ class GetOptimizerInstance(test.TestCase): self.assertAlmostEqual(0.1, opt._learning_rate) def test_object(self): - class _TestOptimizer(optimizer_lib.Optimizer): - - def __init__(self): - super(_TestOptimizer, self).__init__( - use_locking=False, name='TestOptimizer') - opt = optimizers.get_optimizer_instance(_TestOptimizer()) self.assertIsInstance(opt, _TestOptimizer) @@ -80,6 +81,23 @@ class GetOptimizerInstance(test.TestCase): ValueError, 'The given object is not an Optimizer instance'): optimizers.get_optimizer_instance((1, 2, 3)) + def test_callable(self): + def _optimizer_fn(): + return _TestOptimizer() + opt = optimizers.get_optimizer_instance(_optimizer_fn) + self.assertIsInstance(opt, _TestOptimizer) + + def test_lambda(self): + opt = optimizers.get_optimizer_instance(lambda: _TestOptimizer()) # pylint: disable=unnecessary-lambda + self.assertIsInstance(opt, _TestOptimizer) + + def test_callable_returns_invalid(self): + def _optimizer_fn(): + return (1, 2, 3) + with self.assertRaisesRegexp( + ValueError, 'The given object is not an Optimizer instance'): + optimizers.get_optimizer_instance(_optimizer_fn) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/canned/parsing_utils.py b/tensorflow/python/estimator/canned/parsing_utils.py index 74e5e5a1bed80229c68daa3ff33ee7af4004bf47..1ae0f1e9f7781be84e71790146a90cf99a5e9831 100644 --- a/tensorflow/python/estimator/canned/parsing_utils.py +++ b/tensorflow/python/estimator/canned/parsing_utils.py @@ -23,10 +23,10 @@ import six from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import dtypes from tensorflow.python.ops import parsing_ops -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export -@tf_export('estimator.classifier_parse_example_spec') +@estimator_export('estimator.classifier_parse_example_spec') def classifier_parse_example_spec(feature_columns, label_key, label_dtype=dtypes.int64, @@ -166,7 +166,7 @@ def classifier_parse_example_spec(feature_columns, return parsing_spec -@tf_export('estimator.regressor_parse_example_spec') +@estimator_export('estimator.regressor_parse_example_spec') def regressor_parse_example_spec(feature_columns, label_key, label_dtype=dtypes.float32, diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index a98600b2610f8eaee6eac94d692c743801f21b4c..350a95eea1f1112ea270156855409d7a1b264bfb 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -32,15 +32,18 @@ from tensorflow.core.framework import summary_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as tf_session -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export as export_helpers from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import metrics as metrics_lib @@ -64,14 +67,14 @@ from tensorflow.python.util import compat from tensorflow.python.util import compat_internal from tensorflow.python.util import function_utils from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export _VALID_MODEL_FN_ARGS = set( ['features', 'labels', 'mode', 'params', 'self', 'config']) -@tf_export('estimator.Estimator') +@estimator_export('estimator.Estimator') class Estimator(object): """Estimator class to train and evaluate TensorFlow models. @@ -101,6 +104,15 @@ class Estimator(object): None of `Estimator`'s methods can be overridden in subclasses (its constructor enforces this). Subclasses should use `model_fn` to configure the base class, and may add methods implementing specialized functionality. + + @compatbility(eager) + Calling methods of `Estimator` will work while eager execution is enabled. + However, the `model_fn` and `input_fn` is not executed eagerly, `Estimator` + will switch to graph model before calling all user-provided functions (incl. + hooks), so their code has to be compatible with graph mode execution. Note + that `input_fn` code using `tf.data` generally works in both graph and eager + modes. + @end_compatibility """ def __init__(self, model_fn, model_dir=None, config=None, params=None, @@ -212,8 +224,8 @@ class Estimator(object): else: self._session_config = self._config.session_config - self._device_fn = self._config.device_fn or \ - _get_replica_device_setter(self._config) + self._device_fn = ( + self._config.device_fn or _get_replica_device_setter(self._config)) if model_fn is None: raise ValueError('model_fn must be provided to Estimator.') @@ -302,7 +314,7 @@ class Estimator(object): Args: input_fn: A function that provides input data for training as minibatches. - See @{$get_started/premade_estimators#create_input_functions} for more + See @{$premade_estimators#create_input_functions} for more information. The function should construct and return one of the following: @@ -398,7 +410,7 @@ class Estimator(object): Args: input_fn: A function that constructs the input data for evaluation. - See @{$get_started/premade_estimators#create_input_functions} for more + See @{$premade_estimators#create_input_functions} for more information. The function should construct and return one of the following: @@ -437,11 +449,25 @@ class Estimator(object): hooks = _check_hooks_type(hooks) hooks.extend(self._convert_eval_steps_to_hooks(steps)) - return self._evaluate_model( - input_fn=input_fn, - hooks=hooks, - checkpoint_path=checkpoint_path, - name=name) + # Check that model has been trained (if nothing has been set explicitly). + if not checkpoint_path: + latest_path = saver.latest_checkpoint(self._model_dir) + if not latest_path: + logging.info('Could not find trained model in model_dir: {}, running ' + 'initialization to evaluate.'.format(self._model_dir)) + checkpoint_path = latest_path + + with ops.Graph().as_default(): + (scaffold, update_op, + eval_dict, all_hooks) = self._evaluate_build_graph( + input_fn, hooks, checkpoint_path) + return self._evaluate_run( + checkpoint_path=checkpoint_path, + scaffold=scaffold, + update_op=update_op, + eval_dict=eval_dict, + all_hooks=all_hooks, + output_dir=self.eval_dir(name)) def _convert_eval_steps_to_hooks(self, steps): if steps is None: @@ -463,7 +489,7 @@ class Estimator(object): input_fn: A function that constructs the features. Prediction continues until `input_fn` raises an end-of-input exception (`OutOfRangeError` or `StopIteration`). - See @{$get_started/premade_estimators#create_input_functions} for more + See @{$premade_estimators#create_input_functions} for more information. The function should construct and return one of the following: @@ -550,7 +576,9 @@ class Estimator(object): allowed_overrides = set([ '_call_input_fn', '_create_global_step', '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks', - '_tf_api_names', '_validate_features_in_predict_input' + '_tf_api_names', '_estimator_api_names', '_estimator_api_constants', + '_validate_features_in_predict_input', + '_call_model_fn', '_add_meta_graph_for_mode' ]) estimator_members = set([m for m in Estimator.__dict__.keys() if not m.startswith('__')]) @@ -814,10 +842,15 @@ class Estimator(object): gfile.Rename(temp_export_dir, export_dir) return export_dir - def _add_meta_graph_for_mode( - self, builder, input_receiver_fn_map, checkpoint_path, - strip_default_attrs, save_variables=True, - mode=model_fn_lib.ModeKeys.PREDICT): + def _add_meta_graph_for_mode(self, + builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=True, + mode=model_fn_lib.ModeKeys.PREDICT, + export_tags=None, + check_variables=True): # pylint: disable=line-too-long """Loads variables and adds them along with a MetaGraphDef for saving. @@ -836,9 +869,18 @@ class Estimator(object): True for the first call to this function, and the SavedModelBuilder will raise an error if that is not the case. mode: tf.estimator.ModeKeys value indicating which mode will be exported. + export_tags: The set of tags with which to save `MetaGraphDef`. If None, + a default set will be selected to matched the passed mode. + check_variables: bool, whether to check the checkpoint has all variables. + + Raises: + ValueError: if `save_variables` is `True` and `check_variable` is `False`. """ # pylint: enable=line-too-long + if export_tags is None: + export_tags = model_fn_lib.EXPORT_TAG_MAP[mode] input_receiver_fn = input_receiver_fn_map[mode] + with ops.Graph().as_default() as g: self._create_and_assert_global_step(g) random_seed.set_random_seed(self._config.tf_random_seed) @@ -863,25 +905,30 @@ class Estimator(object): with tf_session.Session(config=self._session_config) as session: - export_tags = model_fn_lib.EXPORT_TAG_MAP[mode] - local_init_op = ( estimator_spec.scaffold.local_init_op or monitored_session.Scaffold.default_local_init_op()) - saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( - sharded=True) - - try: - saver_for_restore.restore(session, checkpoint_path) - except errors.NotFoundError as e: - msg = ('Could not load all requested variables from the checkpoint. ' - 'Please make sure your model_fn does not expect variables ' - 'that were not saved in the checkpoint.\n\n' - 'Encountered error with mode `{}` while restoring checkpoint ' - 'from: `{}`. Full Traceback:\n\n{}').format( - mode, checkpoint_path, e) - raise ValueError(msg) + # This saver will be used both for restoring variables now, + # and in saving out the metagraph below. This ensures that any + # Custom Savers stored with the Scaffold are passed through to the + # SavedModel for restore later. + graph_saver = estimator_spec.scaffold.saver or saver.Saver(sharded=True) + + if save_variables and not check_variables: + raise ValueError('If `save_variables` is `True, `check_variables`' + 'must not be `False`.') + if check_variables: + try: + graph_saver.restore(session, checkpoint_path) + except errors.NotFoundError as e: + msg = ('Could not load all requested variables from checkpoint. ' + 'Please make sure your model_fn does not expect variables ' + 'that were not saved in the checkpoint.\n\n' + 'Encountered error with mode `{}` while restoring ' + 'checkpoint from: `{}`. Full Traceback:\n\n{}').format( + mode, checkpoint_path, e) + raise ValueError(msg) # We add the train op explicitly for now, so that we don't have to # change the Builder public interface. Note that this is a no-op @@ -894,7 +941,8 @@ class Estimator(object): assets_collection=ops.get_collection( ops.GraphKeys.ASSET_FILEPATHS), strip_default_attrs=strip_default_attrs, - legacy_init_op=local_init_op) + legacy_init_op=local_init_op, + saver=graph_saver) if save_variables: builder.add_meta_graph_and_variables( @@ -942,17 +990,9 @@ class Estimator(object): def _get_features_from_input_fn(self, input_fn, mode): """Extracts the `features` from return values of `input_fn`.""" result = self._call_input_fn(input_fn, mode) - input_hooks = [] - if isinstance(result, dataset_ops.Dataset): - iterator = result.make_initializable_iterator() - input_hooks.append(_DatasetInitializerHook(iterator)) - result = iterator.get_next() - if isinstance(result, (list, tuple)): - # Unconditionally drop the label (the second element of result). - result = result[0] - + result, _, hooks = estimator_util.parse_input_fn_result(result) self._validate_features_in_predict_input(result) - return result, input_hooks + return result, hooks def _validate_features_in_predict_input(self, result): if not _has_dataset_or_queue_runner(result): @@ -962,25 +1002,13 @@ class Estimator(object): def _get_features_and_labels_from_input_fn(self, input_fn, mode): """Extracts the `features` and labels from return values of `input_fn`.""" - input_hooks = [] if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN: result = self._distribution.distribute_dataset( lambda: self._call_input_fn(input_fn, mode)) - iterator = result.make_initializable_iterator() - input_hooks.append(_DatasetInitializerHook(iterator)) - result = iterator.get_next() else: result = self._call_input_fn(input_fn, mode) - if isinstance(result, dataset_ops.Dataset): - iterator = result.make_initializable_iterator() - input_hooks.append(_DatasetInitializerHook(iterator)) - result = iterator.get_next() - if isinstance(result, (list, tuple)): - if len(result) != 2: - raise ValueError( - 'input_fn should return (features, labels) as a len 2 tuple.') - return result[0], result[1], input_hooks - return result, None, input_hooks + + return estimator_util.parse_input_fn_result(result) def _extract_batch_length(self, preds_evaluated): """Extracts batch length of predictions.""" @@ -1045,9 +1073,15 @@ class Estimator(object): mode: ModeKeys Returns: - Either features or (features, labels) where features and labels are: - features - `Tensor` or dictionary of string feature name to `Tensor`. - labels - `Tensor` or dictionary of `Tensor` with labels. + The return value of the passed input_fn, which should be one of: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. Raises: ValueError: if input_fn takes invalid arguments. @@ -1109,6 +1143,18 @@ class Estimator(object): return self._train_model_default(input_fn, hooks, saving_listeners) def _train_model_default(self, input_fn, hooks, saving_listeners): + """Initiate training with input_fn, without DistributionStrategies. + + Args: + input_fn: A function that provides input data for training as minibatches. + hooks: List of `SessionRunHook` subclass instances. Used for callbacks + inside the training loop. + saving_listeners: list of `CheckpointSaverListener` objects. Used for + callbacks that run immediately before or after checkpoint savings. + + Returns: + Loss from training + """ worker_hooks = [] with ops.Graph().as_default() as g, g.device(self._device_fn): random_seed.set_random_seed(self._config.tf_random_seed) @@ -1125,29 +1171,86 @@ class Estimator(object): saving_listeners) def _train_model_distributed(self, input_fn, hooks, saving_listeners): + """Initiate training with input_fn, using DistributionStrategies. + + Args: + input_fn: A function that provides input data for training as minibatches. + hooks: List of `SessionRunHook` subclass instances. Used for callbacks + inside the training loop. + saving_listeners: list of `CheckpointSaverListener` objects. Used for + callbacks that run immediately before or after checkpoint savings. + + Returns: + Loss from training + """ self._distribution.configure(self._session_config) + + # TODO(sourabhbajaj): Remove this hack once we migrate the other strategies + # to use the new API + is_tpu_strategy = self._distribution.__class__.__name__ == 'TPUStrategy' + worker_hooks = [] with ops.Graph().as_default() as g: with self._distribution.scope(): random_seed.set_random_seed(self._config.tf_random_seed) - features, labels, input_hooks = ( - self._get_features_and_labels_from_input_fn( - input_fn, model_fn_lib.ModeKeys.TRAIN)) - worker_hooks.extend(input_hooks) - global_step_tensor = self._create_and_assert_global_step(g) - # The default destination for the global_step_tensor fetch call is the - # CPU. - global_step_read_tensor = self._distribution.fetch(global_step_tensor) - # we want to add to the global collection in the main thread not the - # tower threads. - ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY, - global_step_read_tensor) - grouped_estimator_spec = self._distribution.call_for_each_tower( - self._call_model_fn, - features, - labels, # although this will be None it seems - model_fn_lib.ModeKeys.TRAIN, - self.config) + + if is_tpu_strategy: + # Create the iterator for run_on_dataset function + # TODO(sourabhbajaj): refactor this out to call a function on the + # strategy + dataset = self._distribution.distribute_dataset( + lambda: self._call_input_fn(input_fn, # pylint: disable=g-long-lambda + model_fn_lib.ModeKeys.TRAIN)) + iterator = dataset.make_initializable_iterator() + worker_hooks.append( + estimator_util._DatasetInitializerHook(iterator)) # pylint: disable=protected-access + + global_step_tensor = self._create_and_assert_global_step(g) + # we want to add to the global collection in the main thread not the + # tower threads. + ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY, + self._distribution.read_var(global_step_tensor)) + + # Create a step_fn from the train_op of grouped_estimator_spec + def step_fn(ctx, inputs): + """A single step that is passed to run_on_dataset.""" + features, labels = inputs + estimator_spec = self._distribution.call_for_each_tower( + self._call_model_fn, + features, + labels, + model_fn_lib.ModeKeys.TRAIN, + self.config) + ctx.last_step_outputs = estimator_spec.loss + ctx.non_tensor_outputs = {'estimator_spec': estimator_spec} + with ops.control_dependencies([estimator_spec.train_op]): + return array_ops.identity(estimator_spec.loss) + + # Create new train_op post graph rewrites + # TODO(sourabhbajaj): Make sure train_steps and tpu_iterations + # work correctly. Currently hardcoded at 2 + initial_training_loss = constant_op.constant(1e7) + distributed_train_op, tpu_result, ctx = \ + self._distribution._run_steps_on_dataset( # pylint: disable=protected-access + step_fn, iterator, iterations=2, + initial_loop_values=initial_training_loss) + grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec'] + else: + features, labels, input_hooks = ( + self._get_features_and_labels_from_input_fn( + input_fn, model_fn_lib.ModeKeys.TRAIN)) + worker_hooks.extend(input_hooks) + global_step_tensor = self._create_and_assert_global_step(g) + # we want to add to the global collection in the main thread not the + # tower threads. + ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY, + self._distribution.read_var(global_step_tensor)) + grouped_estimator_spec = self._distribution.call_for_each_tower( + self._call_model_fn, + features, + labels, # although this will be None it seems + model_fn_lib.ModeKeys.TRAIN, + self.config) # TODO(anjalisridhar): Figure out how to resolve the following scaffold # parameters: init_feed_dict, init_fn. @@ -1175,10 +1278,16 @@ class Estimator(object): else: init_op = None + def _unwrap_and_concat(value): + value = nest.flatten(self._distribution.unwrap(value)) + if len(value) != 1: + return array_ops.concat(value) + return value[0] + ready_op = self._distribution.call_for_each_tower( create_per_tower_ready_op, grouped_estimator_spec.scaffold) if ready_op is not None: - ready_op = self._distribution.group(ready_op) + ready_op = _unwrap_and_concat(ready_op) else: ready_op = None @@ -1186,8 +1295,7 @@ class Estimator(object): create_per_tower_ready_for_local_init_op, grouped_estimator_spec.scaffold) if ready_for_local_init_op is not None: - ready_for_local_init_op = self._distribution.group( - ready_for_local_init_op) + ready_for_local_init_op = _unwrap_and_concat(ready_for_local_init_op) else: ready_for_local_init_op = None @@ -1228,18 +1336,33 @@ class Estimator(object): training_chief_hooks = get_hooks_from_the_first_device( grouped_estimator_spec.training_chief_hooks) + # TODO(sourabhbajaj): Merge the two code paths once we can + # handle per device variables correctly in reduce and can output + # the loss scaler. + if is_tpu_strategy: + loss = self._distribution.unwrap( + self._distribution.reduce(distribute_lib.get_loss_reduction(), + tpu_result)[0])[0] + worker_hooks.append( + estimator_util.StrategyInitFinalizeHook( + self._distribution.get_initialization_ops, + self._distribution.get_finalize_ops)) + else: + loss = self._distribution.unwrap( + self._distribution.reduce(distribute_lib.get_loss_reduction(), + grouped_estimator_spec.loss, + destinations='/device:CPU:0'))[0] + distributed_train_op = grouped_estimator_spec.train_op + estimator_spec = model_fn_lib.EstimatorSpec( mode=grouped_estimator_spec.mode, - loss=self._distribution.unwrap( - self._distribution.reduce(distribute_lib.get_loss_reduction(), - grouped_estimator_spec.loss, - destinations='/device:CPU:0'))[0], - train_op=self._distribution.group(grouped_estimator_spec.train_op), + loss=loss, + train_op=self._distribution.group(distributed_train_op), training_hooks=training_hooks, training_chief_hooks=training_chief_hooks, scaffold=scaffold) return self._train_with_estimator_spec(estimator_spec, worker_hooks, - hooks, global_step_read_tensor, + hooks, global_step_tensor, saving_listeners) def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks, @@ -1326,66 +1449,67 @@ class Estimator(object): _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss]) return loss - def _evaluate_model(self, - input_fn, - hooks=None, - checkpoint_path=None, - name=''): - """Evaluates the model using the training.evaluation library.""" - # Check that model has been trained (if nothing has been set explicitly). - if not checkpoint_path: - latest_path = saver.latest_checkpoint(self._model_dir) - if not latest_path: - logging.info('Could not find trained model in model_dir: {}, running ' - 'initialization to evaluate.'.format(self._model_dir)) - checkpoint_path = latest_path - - with ops.Graph().as_default() as g: - random_seed.set_random_seed(self._config.tf_random_seed) - global_step_tensor = self._create_and_assert_global_step(g) - features, labels, input_hooks = ( - self._get_features_and_labels_from_input_fn( - input_fn, model_fn_lib.ModeKeys.EVAL)) - estimator_spec = self._call_model_fn( - features, labels, model_fn_lib.ModeKeys.EVAL, self.config) - - # Call to warm_start has to be after model_fn is called. - self._maybe_warm_start(checkpoint_path) + def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None): + """Builds the graph and related hooks to run evaluation.""" + random_seed.set_random_seed(self._config.tf_random_seed) + global_step_tensor = self._create_and_assert_global_step( + ops.get_default_graph()) + features, labels, input_hooks = ( + self._get_features_and_labels_from_input_fn(input_fn, + model_fn_lib.ModeKeys.EVAL)) + estimator_spec = self._call_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.config) + + # Call to warm_start has to be after model_fn is called. + self._maybe_warm_start(checkpoint_path) + + if model_fn_lib.LOSS_METRIC_KEY in estimator_spec.eval_metric_ops: + raise ValueError( + 'Metric with name "%s" is not allowed, because Estimator ' % + (model_fn_lib.LOSS_METRIC_KEY) + + 'already defines a default metric with the same name.') + estimator_spec.eval_metric_ops[ + model_fn_lib.LOSS_METRIC_KEY] = metrics_lib.mean(estimator_spec.loss) - if model_fn_lib.LOSS_METRIC_KEY in estimator_spec.eval_metric_ops: - raise ValueError( - 'Metric with name "%s" is not allowed, because Estimator ' % ( - model_fn_lib.LOSS_METRIC_KEY) + - 'already defines a default metric with the same name.') - estimator_spec.eval_metric_ops[ - model_fn_lib.LOSS_METRIC_KEY] = metrics_lib.mean(estimator_spec.loss) + update_op, eval_dict = _extract_metric_update_ops( + estimator_spec.eval_metric_ops) - update_op, eval_dict = _extract_metric_update_ops( - estimator_spec.eval_metric_ops) + if ops.GraphKeys.GLOBAL_STEP in eval_dict: + raise ValueError( + 'Metric with name `global_step` is not allowed, because Estimator ' + 'already defines a default metric with the same name.') + eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor - if ops.GraphKeys.GLOBAL_STEP in eval_dict: - raise ValueError( - 'Metric with name `global_step` is not allowed, because Estimator ' - 'already defines a default metric with the same name.') - eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor + all_hooks = list(input_hooks) + all_hooks.extend(hooks) + all_hooks.extend(list(estimator_spec.evaluation_hooks or [])) - all_hooks = list(input_hooks) - all_hooks.extend(hooks) - all_hooks.extend(list(estimator_spec.evaluation_hooks or [])) + return estimator_spec.scaffold, update_op, eval_dict, all_hooks - eval_results = evaluation._evaluate_once( # pylint: disable=protected-access + def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict, + all_hooks, output_dir): + """Run evaluation.""" + eval_results = evaluation._evaluate_once( # pylint: disable=protected-access + checkpoint_path=checkpoint_path, + master=self._config.evaluation_master, + scaffold=scaffold, + eval_ops=update_op, + final_ops=eval_dict, + hooks=all_hooks, + config=self._session_config) + + current_global_step = eval_results[ops.GraphKeys.GLOBAL_STEP] + + _write_dict_to_summary( + output_dir=output_dir, + dictionary=eval_results, + current_global_step=current_global_step) + + if checkpoint_path: + _write_checkpoint_path_to_summary( + output_dir=output_dir, checkpoint_path=checkpoint_path, - master=self._config.evaluation_master, - scaffold=estimator_spec.scaffold, - eval_ops=update_op, - final_ops=eval_dict, - hooks=all_hooks, - config=self._session_config) - - _write_dict_to_summary( - output_dir=self.eval_dir(name), - dictionary=eval_results, - current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP]) + current_global_step=current_global_step) return eval_results @@ -1584,6 +1708,30 @@ def _write_dict_to_summary(output_dir, summary_writer.flush() +def _write_checkpoint_path_to_summary(output_dir, checkpoint_path, + current_global_step): + """Writes `checkpoint_path` into summary file in the given output directory. + + Args: + output_dir: `str`, directory to write the summary file in. + checkpoint_path: `str`, checkpoint file path to be written to summary file. + current_global_step: `int`, the current global step. + """ + + checkpoint_path_tag = 'checkpoint_path' + + logging.info('Saving \'%s\' summary for global step %d: %s', + checkpoint_path_tag, current_global_step, checkpoint_path) + summary_proto = summary_pb2.Summary() + summary_proto.value.add( + tag=checkpoint_path_tag, + tensor=tensor_util.make_tensor_proto( + checkpoint_path, dtype=dtypes.string)) + summary_writer = writer_cache.FileWriterCache.get(output_dir) + summary_writer.add_summary(summary_proto, current_global_step) + summary_writer.flush() + + def _has_dataset_or_queue_runner(maybe_tensor): """Returns True if TF dataset or QueueRunner has been used.""" # Check TF dataset first. Here, we use a simple algorithm to check the top @@ -1596,22 +1744,11 @@ def _has_dataset_or_queue_runner(maybe_tensor): return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS) -class _DatasetInitializerHook(training.SessionRunHook): - - def __init__(self, iterator): - self._iterator = iterator - - def begin(self): - self._initializer = self._iterator.initializer - - def after_create_session(self, session, coord): - del coord - session.run(self._initializer) - VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name +estimator_export('estimator.VocabInfo')(VocabInfo) -@tf_export('estimator.WarmStartSettings') +@estimator_export('estimator.WarmStartSettings') class WarmStartSettings( collections.namedtuple('WarmStartSettings', [ 'ckpt_to_initialize_from', @@ -1738,10 +1875,19 @@ class WarmStartSettings( ckpt_to_initialize_from: [Required] A string specifying the directory with checkpoint file(s) or path to checkpoint from which to warm-start the model parameters. - vars_to_warm_start: [Optional] A regular expression that captures which - variables to warm-start (see tf.get_collection). Defaults to `'.*'`, - which warm-starts all variables. If `None` is explicitly given, only - variables specified in `var_name_to_vocab_info` will be warm-started. + vars_to_warm_start: [Optional] One of the following: + + - A regular expression (string) that captures which variables to + warm-start (see tf.get_collection). This expression will only consider + variables in the TRAINABLE_VARIABLES collection. + - A list of Variables to warm-start. + - A list of strings, each representing a full variable name to warm-start. + - `None`, in which case only variables specified in + `var_name_to_vocab_info` will be warm-started. + + Defaults to `'.*'`, which warm-starts all variables in the + TRAINABLE_VARIABLES collection. Note that this excludes variables such as + accumulators and moving statistics from batch norm. var_name_to_vocab_info: [Optional] Dict of variable names (strings) to VocabInfo. The variable names should be "full" variables, not the names of the partitions. If not explicitly provided, the variable is assumed to @@ -1802,5 +1948,3 @@ def _get_default_warm_start_settings(warm_start_from): else: raise ValueError('warm_start_from must be a string or a WarmStartSettings, ' 'instead got {}'.format(type(warm_start_from))) - - diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 1b701899487ede96a90a6be299323fe278e5027f..2a0e4e761755e272a316ce2d326b0c0a51ecbaba 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -38,7 +38,9 @@ from tensorflow.python.estimator.export import export_output from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.layers import layers from tensorflow.python.lib.io import file_io @@ -60,6 +62,7 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.summary import summary_iterator @@ -81,21 +84,27 @@ def dummy_model_fn(features, labels, params): _, _, _ = features, labels, params -def check_eventfile_for_keyword(keyword, dir_): - """Checks event files for the keyword.""" +def summaries_with_matching_keyword(keyword, dir_): + """Yields summary protos matching given keyword from event file.""" writer_cache.FileWriterCache.clear() - # Get last Event written. event_paths = glob.glob(os.path.join(dir_, 'events*')) - last_event = None - for last_event in summary_iterator.summary_iterator(event_paths[-1]): - if last_event.summary is not None: - for value in last_event.summary.value: + for event in summary_iterator.summary_iterator(event_paths[-1]): + if event.summary is not None: + for value in event.summary.value: if keyword in value.tag: - return True + yield event.summary + + +def check_eventfile_for_keyword(keyword, dir_): + """Checks event files for the keyword.""" + return any(summaries_with_matching_keyword(keyword, dir_)) + - return False +def get_mock_saver(): + real_saver = saver.Saver() + return test.mock.Mock(wraps=real_saver, saver_def=real_saver.saver_def) class EstimatorInheritanceConstraintTest(test.TestCase): @@ -814,6 +823,7 @@ class EstimatorTrainTest(test.TestCase): def test_saving_listeners_are_used(self): listener = test.mock.Mock(spec=training.CheckpointSaverListener) + listener.after_save.return_value = None est = estimator.Estimator( model_fn=model_fn_global_step_incrementer, config=run_config.RunConfig(save_checkpoints_steps=10)) @@ -1287,14 +1297,37 @@ class EstimatorEvaluateTest(test.TestCase): dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint()) self.assertEqual(5, scores['global_step']) + def test_wrong_shape_throws_reasonable_error(self): + """Make sure we are helpful when model_fns change. See b/110263146.""" + def _get_model_fn(val=1): + def _model_fn(features, labels, mode): + del features, labels # unused + variables.Variable(val, name='weight') + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions=constant_op.constant([[1.]]), + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1)) + return _model_fn + + model_fn_1 = _get_model_fn() + model_fn_2 = _get_model_fn(val=[1]) + + est1 = estimator.Estimator(model_fn=model_fn_1) + est1.train(dummy_input_fn, steps=5) + est2 = estimator.Estimator( + model_fn=model_fn_2, model_dir=est1.model_dir) + + expected_msg = 'Restoring from checkpoint failed.*a mismatch between' + with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg): + est2.train(dummy_input_fn, steps=1,) + def test_scaffold_is_used(self): def _model_fn_scaffold(features, labels, mode): _, _ = features, labels variables.Variable(1., name='weight') - real_saver = saver.Saver() - self.mock_saver = test.mock.Mock( - wraps=real_saver, saver_def=real_saver.saver_def) + self.mock_saver = get_mock_saver() return model_fn_lib.EstimatorSpec( mode=mode, predictions=constant_op.constant([[1.]]), @@ -1397,6 +1430,19 @@ class EstimatorEvaluateTest(test.TestCase): check_eventfile_for_keyword(key, est.eval_dir()), '{} should be part of reported summaries.'.format(key)) + # Verify that evaluated checkpoint path is written to event file. + checkpoint_path_tag = 'checkpoint_path' + self.assertTrue( + check_eventfile_for_keyword(checkpoint_path_tag, est.eval_dir()), + '{} should be part of reported summaries.'.format(checkpoint_path_tag)) + + expected_tensor_proto = tensor_util.make_tensor_proto( + est.latest_checkpoint(), dtype=dtypes.string) + summaries = summaries_with_matching_keyword(checkpoint_path_tag, + est.eval_dir()) + self.assertProtoEquals(expected_tensor_proto, + next(summaries).value[0].tensor) + class EstimatorPredictTest(test.TestCase): @@ -1803,9 +1849,7 @@ class EstimatorPredictTest(test.TestCase): def _model_fn_scaffold(features, labels, mode): _, _ = features, labels variables.Variable(1., name='weight') - real_saver = saver.Saver() - self.mock_saver = test.mock.Mock( - wraps=real_saver, saver_def=real_saver.saver_def) + self.mock_saver = get_mock_saver() return model_fn_lib.EstimatorSpec( mode=mode, predictions=constant_op.constant([[1.]]), @@ -2299,8 +2343,8 @@ class EstimatorExportTest(test.TestCase): graph_ops = [x.name for x in graph.get_operations()] self.assertTrue('input_example_tensor' in graph_ops) self.assertTrue('ParseExample/ParseExample' in graph_ops) - # Note that the SavedModel builder replaced the Saver with a new one - self.assertTrue('save_1/LookupTableImportV2' in graph_ops) + # The original saver is used to restore variables + self.assertTrue('save/LookupTableImportV2' in graph_ops) # Clean up. gfile.DeleteRecursively(tmpdir) @@ -2465,9 +2509,7 @@ class EstimatorExportTest(test.TestCase): def _model_fn_scaffold(features, labels, mode): _, _ = features, labels variables.Variable(1., name='weight') - real_saver = saver.Saver() - self.mock_saver = test.mock.Mock( - wraps=real_saver, saver_def=real_saver.saver_def) + self.mock_saver = get_mock_saver() scores = constant_op.constant([3.]) return model_fn_lib.EstimatorSpec( mode=mode, @@ -2490,19 +2532,24 @@ class EstimatorExportTest(test.TestCase): est.export_savedmodel(export_dir_base, serving_input_receiver_fn) self.assertTrue(self.mock_saver.restore.called) + self.assertTrue(self.mock_saver.export_meta_graph.called) + self.assertTrue(self.mock_saver.save.called) def test_scaffold_is_used_for_saver_multiple_modes(self): tmpdir = tempfile.mkdtemp() + savers = {'predict_saver': None, 'train_saver': None} def _model_fn_scaffold(features, labels, mode): _, _ = features, labels variables.Variable(1., name='weight') - real_saver = saver.Saver() - self.mock_saver = test.mock.Mock( - wraps=real_saver, saver_def=real_saver.saver_def) + scores = constant_op.constant([3.]) if mode == model_fn_lib.ModeKeys.PREDICT: - scaffold = training.Scaffold(saver=self.mock_saver) + savers['predict_saver'] = get_mock_saver() + scaffold = training.Scaffold(saver=savers['predict_saver']) + elif mode == model_fn_lib.ModeKeys.TRAIN: + savers['train_saver'] = get_mock_saver() + scaffold = training.Scaffold(saver=savers['train_saver']) else: scaffold = training.Scaffold() return model_fn_lib.EstimatorSpec( @@ -2526,7 +2573,13 @@ class EstimatorExportTest(test.TestCase): compat.as_bytes(tmpdir), compat.as_bytes('export')) est._export_all_saved_models(export_dir_base, input_receiver_fn_map) - self.assertTrue(self.mock_saver.restore.called) + self.assertTrue(savers['train_saver'].restore.called) + self.assertEqual(savers['train_saver'].export_meta_graph.call_count, 1) + self.assertEqual(savers['train_saver'].save.call_count, 1) + + self.assertTrue(savers['predict_saver'].restore.called) + self.assertEqual(savers['predict_saver'].export_meta_graph.call_count, 1) + self.assertEqual(savers['predict_saver'].save.call_count, 0) def test_scaffold_is_used_for_local_init(self): tmpdir = tempfile.mkdtemp() @@ -2803,6 +2856,45 @@ class EstimatorExportTest(test.TestCase): # Clean up. gfile.DeleteRecursively(tmpdir) + def test_export_savedmodel_no_export_outputs(self): + """Ensure that an EstimatorSpec without outputs defined can be exported.""" + + def _model_fn(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn) + est.train(input_fn=dummy_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('no_export_outputs')) + export_dir = est.export_savedmodel( + export_dir_base, _get_serving_input_receiver_fn()) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + meta_graph = loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('weight' in graph_ops) + + sig_def = meta_graph.signature_def + self.assertEqual(len(sig_def), 1) + sig_outputs = sig_def[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs + self.assertEqual(sig_outputs['output'].name, 'Const:0') + class EstimatorHookOrderingTest(test.TestCase): @@ -2847,7 +2939,7 @@ class EstimatorHookOrderingTest(test.TestCase): class EstimatorIntegrationTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_complete_flow_with_a_simple_linear_model(self): def _model_fn(features, labels, mode): diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 48ae8cd49791c27a1e9674ed1be19d543d690b35..ca26341445e86ad554ac2e7cbf643c7775dd9825 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.util import compat -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export _SINGLE_FEATURE_DEFAULT_NAME = 'feature' _SINGLE_RECEIVER_DEFAULT_NAME = 'input' @@ -93,7 +93,7 @@ def _check_tensor_key(name, error_label='feature'): raise ValueError('{} keys must be strings: {}.'.format(error_label, name)) -@tf_export('estimator.export.ServingInputReceiver') +@estimator_export('estimator.export.ServingInputReceiver') class ServingInputReceiver( collections.namedtuple( 'ServingInputReceiver', @@ -161,7 +161,7 @@ class ServingInputReceiver( receiver_tensors_alternatives=receiver_tensors_alternatives) -@tf_export('estimator.export.TensorServingInputReceiver') +@estimator_export('estimator.export.TensorServingInputReceiver') class TensorServingInputReceiver( collections.namedtuple( 'TensorServingInputReceiver', @@ -263,7 +263,7 @@ class SupervisedInputReceiver( receiver_tensors=receiver_tensors) -@tf_export('estimator.export.build_parsing_serving_input_receiver_fn') +@estimator_export('estimator.export.build_parsing_serving_input_receiver_fn') def build_parsing_serving_input_receiver_fn(feature_spec, default_batch_size=None): """Build a serving_input_receiver_fn expecting fed tf.Examples. @@ -313,7 +313,7 @@ def _placeholders_from_receiver_tensors_dict(input_vals, } -@tf_export('estimator.export.build_raw_serving_input_receiver_fn') +@estimator_export('estimator.export.build_raw_serving_input_receiver_fn') def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """Build a serving_input_receiver_fn expecting feature Tensors. @@ -333,11 +333,7 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """A serving_input_receiver_fn that expects features to be fed directly.""" receiver_tensors = _placeholders_from_receiver_tensors_dict( features, default_batch_size) - - # TODO(b/34885899): remove the unnecessary copy - # The features provided are simply the placeholders, but we defensively copy - # the dict because it may be mutated. - return ServingInputReceiver(receiver_tensors, receiver_tensors.copy()) + return ServingInputReceiver(receiver_tensors, receiver_tensors) return serving_input_receiver_fn @@ -404,6 +400,42 @@ def build_raw_supervised_input_receiver_fn(features, return supervised_input_receiver_fn +def build_supervised_input_receiver_fn_from_input_fn(input_fn, **input_fn_args): + """Get a function that returns a SupervisedInputReceiver matching an input_fn. + + Note that this function calls the input_fn in a local graph in order to + extract features and labels. Placeholders are then created from those + features and labels in the default graph. + + Args: + input_fn: An Estimator input_fn, which is a function that returns one of: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + **input_fn_args: set of kwargs to be passed to the input_fn. Note that + these will not be checked or validated here, and any errors raised by + the input_fn will be thrown to the top. + + Returns: + A function taking no arguments that, when called, returns a + SupervisedInputReceiver. This function can be passed in as part of the + input_receiver_map when exporting SavedModels from Estimator with multiple + modes. + """ + # Wrap the input_fn call in a graph to prevent sullying the default namespace + with ops.Graph().as_default(): + result = input_fn(**input_fn_args) + features, labels, _ = util.parse_input_fn_result(result) + # Placeholders are created back in the default graph. + return build_raw_supervised_input_receiver_fn(features, labels) + + ### Below utilities are specific to SavedModel exports. diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index d387ea2940e7a450afe28b884c52113355c70fe6..6c26d299851eaea74f1e564d0fac217f238d76a2 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -26,10 +26,10 @@ import six from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.saved_model import signature_def_utils -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export -@tf_export('estimator.export.ExportOutput') +@estimator_export('estimator.export.ExportOutput') class ExportOutput(object): """Represents an output of a model that can be served. @@ -100,7 +100,7 @@ class ExportOutput(object): return output_dict -@tf_export('estimator.export.ClassificationOutput') +@estimator_export('estimator.export.ClassificationOutput') class ClassificationOutput(ExportOutput): """Represents the output of a classification head. @@ -169,7 +169,7 @@ class ClassificationOutput(ExportOutput): examples, self.classes, self.scores) -@tf_export('estimator.export.RegressionOutput') +@estimator_export('estimator.export.RegressionOutput') class RegressionOutput(ExportOutput): """Represents the output of a regression head.""" @@ -202,7 +202,7 @@ class RegressionOutput(ExportOutput): return signature_def_utils.regression_signature_def(examples, self.value) -@tf_export('estimator.export.PredictOutput') +@estimator_export('estimator.export.PredictOutput') class PredictOutput(ExportOutput): """Represents the output of a generic prediction head. diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index 0af587f2a850dff3ca2dc744e157ed5fbb329735..a7074712c25532a1d2156a11d2314150d9efabc1 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -459,6 +459,41 @@ class ExportTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): export.build_raw_supervised_input_receiver_fn(features, labels) + def test_build_supervised_input_receiver_fn_from_input_fn(self): + def dummy_input_fn(): + return ({"x": constant_op.constant([[1], [1]]), + "y": constant_op.constant(["hello", "goodbye"])}, + constant_op.constant([[1], [1]])) + + input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn( + dummy_input_fn) + + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual(set(["x", "y"]), + set(input_receiver.features.keys())) + self.assertIsInstance(input_receiver.labels, ops.Tensor) + self.assertEqual(set(["x", "y", "label"]), + set(input_receiver.receiver_tensors.keys())) + + def test_build_supervised_input_receiver_fn_from_input_fn_args(self): + def dummy_input_fn(feature_key="x"): + return ({feature_key: constant_op.constant([[1], [1]]), + "y": constant_op.constant(["hello", "goodbye"])}, + {"my_label": constant_op.constant([[1], [1]])}) + + input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn( + dummy_input_fn, feature_key="z") + + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual(set(["z", "y"]), + set(input_receiver.features.keys())) + self.assertEqual(set(["my_label"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["z", "y", "my_label"]), + set(input_receiver.receiver_tensors.keys())) + def test_build_all_signature_defs_without_receiver_alternatives(self): receiver_tensor = array_ops.placeholder(dtypes.string) output_1 = constant_op.constant([1.]) diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py index ced793067194b026474f3fbddf28de9c4108f67a..b18212cfcda8f817f909672007c5b000db718232 100644 --- a/tensorflow/python/estimator/exporter.py +++ b/tensorflow/python/estimator/exporter.py @@ -28,10 +28,10 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging from tensorflow.python.summary import summary_iterator -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export -@tf_export('estimator.Exporter') +@estimator_export('estimator.Exporter') class Exporter(object): """A class representing a type of model export.""" @@ -156,7 +156,7 @@ def _loss_smaller(best_eval_result, current_eval_result): return best_eval_result[default_key] > current_eval_result[default_key] -def _verify_compre_fn_args(compare_fn): +def _verify_compare_fn_args(compare_fn): """Verifies compare_fn arguments.""" args = set(util.fn_args(compare_fn)) if 'best_eval_result' not in args: @@ -172,7 +172,7 @@ def _verify_compre_fn_args(compare_fn): (compare_fn, non_valid_args)) -@tf_export('estimator.BestExporter') +@estimator_export('estimator.BestExporter') class BestExporter(Exporter): """This class exports the serving graph and checkpoints of the best models. @@ -265,7 +265,7 @@ class BestExporter(Exporter): self._compare_fn = compare_fn if self._compare_fn is None: raise ValueError('`compare_fn` must not be None.') - _verify_compre_fn_args(self._compare_fn) + _verify_compare_fn_args(self._compare_fn) self._saved_model_exporter = _SavedModelExporter( name, serving_input_receiver_fn, assets_extra, as_text) @@ -287,11 +287,11 @@ class BestExporter(Exporter): is_the_final_export): export_result = None - if self._model_dir != estimator.model_dir() and self._event_file_pattern: + if self._model_dir != estimator.model_dir and self._event_file_pattern: # Loads best metric from event files. tf_logging.info('Loading best metric from event files.') - self._model_dir = estimator.model_dir() + self._model_dir = estimator.model_dir full_event_file_pattern = os.path.join(self._model_dir, self._event_file_pattern) self._best_eval_result = self._get_best_eval_result( @@ -360,13 +360,14 @@ class BestExporter(Exporter): for value in event.summary.value: if value.HasField('simple_value'): event_eval_result[value.tag] = value.simple_value - if best_eval_result is None or self._compare_fn( - best_eval_result, event_eval_result): - best_eval_result = event_eval_result + if event_eval_result: + if best_eval_result is None or self._compare_fn( + best_eval_result, event_eval_result): + best_eval_result = event_eval_result return best_eval_result -@tf_export('estimator.FinalExporter') +@estimator_export('estimator.FinalExporter') class FinalExporter(Exporter): """This class exports the serving graph and checkpoints in the end. @@ -417,7 +418,7 @@ class FinalExporter(Exporter): is_the_final_export) -@tf_export('estimator.LatestExporter') +@estimator_export('estimator.LatestExporter') class LatestExporter(Exporter): """This class regularly exports the serving graph and checkpoints. diff --git a/tensorflow/python/estimator/exporter_test.py b/tensorflow/python/estimator/exporter_test.py index 053c5490711cf538b4b83599b7938a43d8eaec34..c4b006955c4128d6f40dcb30215beda761abbd49 100644 --- a/tensorflow/python/estimator/exporter_test.py +++ b/tensorflow/python/estimator/exporter_test.py @@ -62,7 +62,7 @@ class BestExporterTest(test.TestCase): exports_to_keep=5) estimator = test.mock.Mock(spec=estimator_lib.Estimator) estimator.export_savedmodel.return_value = "export_result_path" - estimator.model_dir.return_value = export_dir_base + estimator.model_dir = export_dir_base export_result = exporter.export(estimator, export_dir_base, "checkpoint_path", {}, False) @@ -94,7 +94,7 @@ class BestExporterTest(test.TestCase): exports_to_keep=1) estimator = test.mock.Mock(spec=estimator_lib.Estimator) estimator.export_savedmodel.return_value = "export_result_path" - estimator.model_dir.return_value = export_dir_base + estimator.model_dir = export_dir_base export_result = exporter.export(estimator, export_dir_base, "checkpoint_path", {"loss": 0.5}, False) @@ -133,7 +133,7 @@ class BestExporterTest(test.TestCase): exports_to_keep=1) estimator = test.mock.Mock(spec=estimator_lib.Estimator) - estimator.model_dir.return_value = export_dir_base + estimator.model_dir = export_dir_base estimator.export_savedmodel.return_value = "export_result_path" export_result = exporter.export(estimator, export_dir_base, @@ -148,6 +148,40 @@ class BestExporterTest(test.TestCase): "checkpoint_path", {"loss": 20}, False) self.assertEqual(None, export_result) + def test_best_exporter_with_empty_event(self): + + def _serving_input_receiver_fn(): + pass + + export_dir_base = tempfile.mkdtemp() + gfile.MkDir(export_dir_base) + gfile.MkDir(export_dir_base + "/export") + gfile.MkDir(export_dir_base + "/eval") + + eval_dir_base = os.path.join(export_dir_base, "eval_continuous") + estimator_lib._write_dict_to_summary(eval_dir_base, {}, 1) + estimator_lib._write_dict_to_summary(eval_dir_base, {"loss": 60}, 2) + + exporter = exporter_lib.BestExporter( + name="best_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + event_file_pattern="eval_continuous/*.tfevents.*", + assets_extra={"from/path": "to/path"}, + as_text=False, + exports_to_keep=1) + + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + estimator.model_dir = export_dir_base + estimator.export_savedmodel.return_value = "export_result_path" + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"loss": 100}, False) + self.assertEqual(None, export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"loss": 10}, False) + self.assertEqual("export_result_path", export_result) + def test_garbage_collect_exports(self): export_dir_base = tempfile.mkdtemp() gfile.MkDir(export_dir_base) @@ -172,7 +206,7 @@ class BestExporterTest(test.TestCase): serving_input_receiver_fn=_serving_input_receiver_fn, exports_to_keep=2) estimator = test.mock.Mock(spec=estimator_lib.Estimator) - estimator.model_dir.return_value = export_dir_base + estimator.model_dir = export_dir_base # Garbage collect all but the most recent 2 exports, # where recency is determined based on the timestamp directory names. exporter.export(estimator, export_dir_base, None, None, False) diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py index a6f471291008e3c27dea1aeea5865e334f76e5c8..a6cefdece21fa8ce944095cb5d3395f2b67142bd 100644 --- a/tensorflow/python/estimator/inputs/numpy_io.py +++ b/tensorflow/python/estimator/inputs/numpy_io.py @@ -24,7 +24,7 @@ import numpy as np from six import string_types from tensorflow.python.estimator.inputs.queues import feeding_functions -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # Key name to pack the target into dict of `features`. See # `_get_unique_target_key` for details. @@ -87,7 +87,7 @@ def _validate_and_convert_features(x): return ordered_dict_data -@tf_export('estimator.inputs.numpy_input_fn') +@estimator_export('estimator.inputs.numpy_input_fn') def numpy_input_fn(x, y=None, batch_size=128, @@ -136,11 +136,13 @@ def numpy_input_fn(x, values in `x` have same shape). ValueError: if duplicate keys are in both `x` and `y` when `y` is a dict. ValueError: if x or y is an empty dict. - TypeError: `x` is not a dict or array, or if `shuffle` is not bool. + TypeError: `x` is not a dict or array. + ValueError: if 'shuffle' is not provided or a bool. """ if not isinstance(shuffle, bool): - raise TypeError('shuffle must be explicitly set as boolean; ' - 'got {}'.format(shuffle)) + raise ValueError('shuffle must be provided and explicitly set as boolean ' + '(it is recommended to set it as True for training); ' + 'got {}'.format(shuffle)) def input_fn(): """Numpy input function.""" diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py index 92d057e25da785cf5ee310ca1c80f67a5fbdb43a..81b201cc5c5f3d6b8211030d17006f89a545793e 100644 --- a/tensorflow/python/estimator/inputs/numpy_io_test.py +++ b/tensorflow/python/estimator/inputs/numpy_io_test.py @@ -286,8 +286,9 @@ class NumpyIoTest(test.TestCase): x = np.arange(32, 36) y = np.arange(4) with self.test_session(): - with self.assertRaisesRegexp(TypeError, - 'shuffle must be explicitly set as boolean'): + with self.assertRaisesRegexp(ValueError, + 'shuffle must be provided and explicitly ' + 'set as boolean'): # Default shuffle is None. numpy_io.numpy_input_fn(x, y) diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py index bd06843021f47f81fc0c22d0fcee43530dc10098..616bcb410f8119e170e991f8320c5b6448ee85c9 100644 --- a/tensorflow/python/estimator/inputs/pandas_io.py +++ b/tensorflow/python/estimator/inputs/pandas_io.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six +import uuid import numpy as np from tensorflow.python.estimator.inputs.queues import feeding_functions -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export try: # pylint: disable=g-import-not-at-top @@ -35,7 +37,23 @@ except ImportError: HAS_PANDAS = False -@tf_export('estimator.inputs.pandas_input_fn') +def _get_unique_target_key(features, target_column_name): + """Returns a key that does not exist in the input DataFrame `features`. + + Args: + features: DataFrame + target_column_name: Name of the target column as a `str` + + Returns: + A unique key that can be used to insert the target into + features. + """ + if target_column_name in features: + target_column_name += '_' + str(uuid.uuid4()) + return target_column_name + + +@estimator_export('estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, batch_size=128, @@ -50,7 +68,7 @@ def pandas_input_fn(x, Args: x: pandas `DataFrame` object. - y: pandas `Series` object. `None` if absent. + y: pandas `Series` object or `DataFrame`. `None` if absent. batch_size: int, size of batches to return. num_epochs: int, number of epochs to iterate over data. If not `None`, read attempts that would exceed this value will raise `OutOfRangeError`. @@ -60,7 +78,8 @@ def pandas_input_fn(x, num_threads: Integer, number of threads used for reading and enqueueing. In order to have predicted and repeatable order of reading and enqueueing, such as in prediction and evaluation mode, `num_threads` should be 1. - target_column: str, name to give the target column `y`. + target_column: str, name to give the target column `y`. This parameter + is not used when `y` is a `DataFrame`. Returns: Function, that has signature of ()->(dict of `features`, `target`) @@ -68,15 +87,19 @@ def pandas_input_fn(x, Raises: ValueError: if `x` already contains a column with the same name as `y`, or if the indexes of `x` and `y` don't match. - TypeError: `shuffle` is not bool. + ValueError: if 'shuffle' is not provided or a bool. """ if not HAS_PANDAS: raise TypeError( 'pandas_input_fn should not be called without pandas installed') if not isinstance(shuffle, bool): - raise TypeError('shuffle must be explicitly set as boolean; ' - 'got {}'.format(shuffle)) + raise ValueError('shuffle must be provided and explicitly set as boolean ' + '(it is recommended to set it as True for training); ' + 'got {}'.format(shuffle)) + + if not isinstance(target_column, six.string_types): + raise TypeError('target_column must be a string type') x = x.copy() if y is not None: @@ -87,7 +110,13 @@ def pandas_input_fn(x, if not np.array_equal(x.index, y.index): raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n' 'Index for y: %s\n' % (x.index, y.index)) - x[target_column] = y + if isinstance(y, pd.DataFrame): + y_columns = [(column, _get_unique_target_key(x, column)) + for column in list(y)] + target_column = [v for _, v in y_columns] + x[target_column] = y + else: + x[target_column] = y # TODO(mdan): These are memory copies. We probably don't need 4x slack space. # The sizes below are consistent with what I've seen elsewhere. @@ -117,7 +146,12 @@ def pandas_input_fn(x, features = features[1:] features = dict(zip(list(x.columns), features)) if y is not None: - target = features.pop(target_column) + if isinstance(target_column, list): + keys = [k for k, _ in y_columns] + values = [features.pop(column) for column in target_column] + target = {k: v for k, v in zip(keys, values)} + else: + target = features.pop(target_column) return features, target return features return input_fn diff --git a/tensorflow/python/estimator/inputs/pandas_io_test.py b/tensorflow/python/estimator/inputs/pandas_io_test.py index e5912a3b28e78c6fc9d8b259a81b2575e6868c6f..6f13bc95d2d315ad1aabfd89d5d479d65fe08502 100644 --- a/tensorflow/python/estimator/inputs/pandas_io_test.py +++ b/tensorflow/python/estimator/inputs/pandas_io_test.py @@ -47,6 +47,16 @@ class PandasIoTest(test.TestCase): y = pd.Series(np.arange(-32, -28), index=index) return x, y + def makeTestDataFrameWithYAsDataFrame(self): + index = np.arange(100, 104) + a = np.arange(4) + b = np.arange(32, 36) + a_label = np.arange(10, 14) + b_label = np.arange(50, 54) + x = pd.DataFrame({'a': a, 'b': b}, index=index) + y = pd.DataFrame({'a_target': a_label, 'b_target': b_label}, index=index) + return x, y + def callInputFnOnce(self, input_fn, session): results = input_fn() coord = coordinator.Coordinator() @@ -65,13 +75,27 @@ class PandasIoTest(test.TestCase): pandas_io.pandas_input_fn( x, y_noindex, batch_size=2, shuffle=False, num_epochs=1) + def testPandasInputFn_RaisesWhenTargetColumnIsAList(self): + if not HAS_PANDAS: + return + + x, y = self.makeTestDataFrame() + + with self.assertRaisesRegexp(TypeError, + 'target_column must be a string type'): + pandas_io.pandas_input_fn(x, y, batch_size=2, + shuffle=False, + num_epochs=1, + target_column=['one', 'two']) + def testPandasInputFn_NonBoolShuffle(self): if not HAS_PANDAS: return x, _ = self.makeTestDataFrame() y_noindex = pd.Series(np.arange(-32, -28)) - with self.assertRaisesRegexp(TypeError, - 'shuffle must be explicitly set as boolean'): + with self.assertRaisesRegexp(ValueError, + 'shuffle must be provided and explicitly ' + 'set as boolean'): # Default shuffle is None pandas_io.pandas_input_fn(x, y_noindex) @@ -89,6 +113,53 @@ class PandasIoTest(test.TestCase): self.assertAllEqual(features['b'], [32, 33]) self.assertAllEqual(target, [-32, -31]) + def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrameWithYAsDataFrame() + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + + features, targets = self.callInputFnOnce(input_fn, session) + + self.assertAllEqual(features['a'], [0, 1]) + self.assertAllEqual(features['b'], [32, 33]) + self.assertAllEqual(targets['a_target'], [10, 11]) + self.assertAllEqual(targets['b_target'], [50, 51]) + + def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrameWithYAsDataFrame() + y = y.rename(columns={'a_target': 'a', 'b_target': 'b'}) + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + + features, targets = self.callInputFnOnce(input_fn, session) + + self.assertAllEqual(features['a'], [0, 1]) + self.assertAllEqual(features['b'], [32, 33]) + self.assertAllEqual(targets['a'], [10, 11]) + self.assertAllEqual(targets['b'], [50, 51]) + + def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrameWithYAsDataFrame() + y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'}) + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + + features, targets = self.callInputFnOnce(input_fn, session) + + self.assertAllEqual(features['a'], [0, 1]) + self.assertAllEqual(features['b'], [32, 33]) + self.assertAllEqual(targets['a'], [10, 11]) + self.assertAllEqual(targets['a_n'], [50, 51]) + def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self): if not HAS_PANDAS: return diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py index 8e2ec83020abc5193309303d0cdd56bd07ef3b5e..51a61adb216c9b019aa01bb7e55c71a8464c01b3 100644 --- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py +++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py @@ -250,7 +250,7 @@ class _PandasFeedFn(object): num_epochs=None): if len(placeholders) != len(dataframe.columns) + 1: raise ValueError("Expected {} placeholders; got {}.".format( - len(dataframe.columns), len(placeholders))) + len(dataframe.columns) + 1, len(placeholders))) self._index_placeholder = placeholders[0] self._col_placeholders = placeholders[1:] self._dataframe = dataframe diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index 5c79c964c8171ceecd00e9d03245773837394ae0..cb37f99704a8d01af6149bd3c8030b653981d0e2 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -20,7 +20,7 @@ from __future__ import division from __future__ import print_function import os - +import re from tensorflow.python.client import session from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import export as export_lib @@ -30,21 +30,24 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_util -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import models -from tensorflow.python.keras._impl.keras import optimizers -from tensorflow.python.keras._impl.keras.engine.base_layer import Layer -from tensorflow.python.keras._impl.keras.engine.network import Network -from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import models +from tensorflow.python.keras import optimizers +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.engine.network import Network +from tensorflow.python.keras.utils.generic_utils import CustomObjectScope from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_module from tensorflow.python.ops import variables as variables_module from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import data_structures + _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -68,7 +71,7 @@ def _convert_tensor(x): return x -def _any_variable_initalized(): +def _any_variable_initialized(): """Check if any variable has been initialized in the Keras model. Returns: @@ -121,8 +124,8 @@ def _create_ordered_io(keras_model, estimator_io, is_input=True): 'It needs to match one ' 'of the following: %s' % ('input' if is_input else 'output', key, ', '.join(keras_io_names))) - tensors = [_convert_tensor(estimator_io[io_name]) - for io_name in keras_io_names] + tensors = [_convert_tensor(estimator_io[io_name]) + for io_name in keras_io_names] return tensors else: # Plain array. @@ -136,8 +139,9 @@ def _in_place_subclassed_model_reset(model): To "instantiate" an identical model in a new TF graph, we reuse the original model object, but we clear its state. - After calling this function on a model intance, you can use the model instance - as if it were a model clone (in particular you can use it in a new graph). + After calling this function on a model instance, you can use the model + instance as if it were a model clone (in particular you can use it in a new + graph). This method clears the state of the input model. It is thus destructive. However the original state can be restored fully by calling @@ -220,7 +224,6 @@ def _in_place_subclassed_model_reset(model): for name in attributes_to_cache: attributes_cache[name] = getattr(model, name) model._original_attributes_cache = attributes_cache - # Reset built state model.built = False model.inputs = None @@ -240,8 +243,17 @@ def _in_place_subclassed_model_state_restoration(model): # Restore layers and build attributes if (hasattr(model, '_original_attributes_cache') and model._original_attributes_cache is not None): - model._layers = [] + # Models have sticky attribute assignment, so we want to be careful to add + # back the previous attributes and track Layers by their original names + # without adding dependencies on "utility" attributes which Models exempt + # when they're constructed. + model._layers = data_structures.NoDependency([]) for name, value in model._original_attributes_cache.items(): + if not isinstance(value, checkpointable.CheckpointableBase): + # If this value is not already checkpointable, it's probably that way + # for a reason; we don't want to start tracking data structures that the + # original Model didn't. + value = data_structures.NoDependency(value) setattr(model, name, value) model._original_attributes_cache = None else: @@ -340,8 +352,19 @@ def _create_keras_model_fn(keras_model, custom_objects=None): """model_fn for keras Estimator.""" model = _clone_and_build_model(mode, keras_model, custom_objects, features, labels) + model_output_names = [] + # We need to make sure that the output names of the last layer in the model + # is the same for each of the cloned models. This is required for mirrored + # strategy when we call regroup. + if distribute_lib.has_distribution_strategy(): + for name in model.output_names: + name = re.compile(r'_\d$').sub('', name) + model_output_names.append(name) + else: + model_output_names = model.output_names + # Get inputs to EstimatorSpec - predictions = dict(zip(model.output_names, model.outputs)) + predictions = dict(zip(model_output_names, model.outputs)) loss = None train_op = None @@ -433,7 +456,6 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects, saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt')) -@tf_export('keras.estimator.model_to_estimator') def model_to_estimator(keras_model=None, keras_model_path=None, custom_objects=None, @@ -442,13 +464,17 @@ def model_to_estimator(keras_model=None, """Constructs an `Estimator` instance from given keras model. For usage example, please see - @{$programmers_guide/estimators$creating_estimators_from_keras_models}. + @{$guide/estimators$creating_estimators_from_keras_models}. Args: - keras_model: Keras model in memory. - keras_model_path: Directory to a keras model on disk. + keras_model: A compiled Keras model object. This argument is mutually + exclusive with `keras_model_path`. + keras_model_path: Path to a compiled Keras model saved on disk, in HDF5 + format, which can be generated with the `save()` method of a Keras model. + This argument is mutually exclusive with `keras_model`. custom_objects: Dictionary for custom objects. - model_dir: Directory to save Estimator model parameters, graph and etc. + model_dir: Directory to save Estimator model parameters, graph, summary + files for TensorBoard, etc. config: Configuration object. Returns: @@ -460,7 +486,7 @@ def model_to_estimator(keras_model=None, ValueError: if the keras_model_path is a GCS URI. ValueError: if keras_model has not been compiled. """ - if (not keras_model) and (not keras_model_path): + if not (keras_model or keras_model_path): raise ValueError( 'Either `keras_model` or `keras_model_path` needs to be provided.') if keras_model and keras_model_path: @@ -482,8 +508,9 @@ def model_to_estimator(keras_model=None, if not hasattr(keras_model, 'optimizer') or not keras_model.optimizer: raise ValueError( - 'The given keras model has not been compiled yet. Please compile first ' - 'before calling `model_to_estimator`.') + 'The given keras model has not been compiled yet. ' + 'Please compile the model with `model.compile()` ' + 'before calling `model_to_estimator()`.') if isinstance(config, dict): config = run_config_lib.RunConfig(**config) @@ -493,7 +520,7 @@ def model_to_estimator(keras_model=None, keras_model_fn, model_dir=model_dir, config=config) # Check if we need to call get_weights: - if _any_variable_initalized(): + if _any_variable_initialized(): keras_weights = keras_model.get_weights() # Warn if config passed to estimator tries to update GPUOptions. If a # session has already been created, the GPUOptions passed to the first diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py index a89f7f7db3b0ba9da59bffabcbe32e24448e3d7a..5e094ae92bcf88a48d7afe3fb88bbced4971b587 100644 --- a/tensorflow/python/estimator/keras_test.py +++ b/tensorflow/python/estimator/keras_test.py @@ -25,16 +25,16 @@ import tempfile import numpy as np from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import keras from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import testing_utils -from tensorflow.python.keras._impl.keras.applications import mobilenet -from tensorflow.python.keras._impl.keras.optimizers import SGD +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.applications import mobilenet +from tensorflow.python.keras.optimizers import SGD +from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -146,13 +146,13 @@ def randomize_io_type(array, name): def multi_inputs_multi_outputs_model(): a = keras.layers.Input(shape=(16,), name='input_a') b = keras.layers.Input(shape=(16,), name='input_b') - m = keras.layers.Input(shape=(8,), dtype='bool', name='input_m') + m = keras.layers.Input(shape=(8,), dtype='string', name='input_m') dense = keras.layers.Dense(8, name='dense_1') a_2 = dense(a) - # Apply a mask - s_2 = keras.layers.Lambda(lambda k: - K.switch(k[0], k[1], K.zeros_like(k[1])))([m, a_2]) + # Read m + m_2 = keras.layers.Lambda(gen_parsing_ops.string_to_number)(m) + s_2 = keras.layers.Lambda(lambda k: k[0] * k[1])([m_2, a_2]) b_2 = dense(b) merged = keras.layers.concatenate([s_2, b_2], name='merge') c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged) @@ -372,13 +372,13 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): def train_input_fn(): input_dict = {'input_a': a_train, 'input_b': b_train, - 'input_m': input_m_train > 0} + 'input_m': input_m_train.astype(np.str)} output_dict = {'dense_2': c_train, 'dense_3': d_train} return input_dict, output_dict def eval_input_fn(): input_dict = {'input_a': a_test, 'input_b': b_test, - 'input_m': input_m_test > 0} + 'input_m': input_m_test.astype(np.str)} output_dict = {'dense_2': c_test, 'dense_3': d_test} return input_dict, output_dict diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 3edf9fe940b19c7a0b1a7c21a9674189faba5acb..a9fd8f8e1a4259fece1a5996343970900c853ce0 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -23,7 +23,7 @@ import collections import six -from tensorflow.python.estimator.export.export_output import ExportOutput +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -32,10 +32,10 @@ from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export -@tf_export('estimator.ModeKeys') +@estimator_export('estimator.ModeKeys') class ModeKeys(object): """Standard names for model modes. @@ -62,7 +62,7 @@ EXPORT_TAG_MAP = { } -@tf_export('estimator.EstimatorSpec') +@estimator_export('estimator.EstimatorSpec') class EstimatorSpec( collections.namedtuple('EstimatorSpec', [ 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops', @@ -99,7 +99,7 @@ class EstimatorSpec( ignored in eval and infer modes. Example: ```python - def my_model_fn(mode, features, labels): + def my_model_fn(features, labels, mode): predictions = ... loss = ... train_op = ... @@ -114,7 +114,7 @@ class EstimatorSpec( given mode. Example: ```python - def my_model_fn(mode, features, labels): + def my_model_fn(features, labels, mode): if (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL): loss = ... @@ -158,6 +158,8 @@ class EstimatorSpec( Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY. + If no entry is provided, a default `PredictOutput` mapping to + `predictions` will be created. training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to run on the chief worker during training. training_hooks: Iterable of `tf.train.SessionRunHook` objects to run @@ -232,29 +234,9 @@ class EstimatorSpec( _check_is_tensor_or_operation(metric_update, 'eval_metric_ops[{}]'.format(key)) - # Validate export_outputs. - if export_outputs is not None: - if not isinstance(export_outputs, dict): - raise TypeError('export_outputs must be dict, given: {}'.format( - export_outputs)) - for v in six.itervalues(export_outputs): - if not isinstance(v, ExportOutput): - raise TypeError( - 'Values in export_outputs must be ExportOutput objects. ' - 'Given: {}'.format(export_outputs)) - # Note export_outputs is allowed to be empty. - if len(export_outputs) == 1: - (key, value), = export_outputs.items() - if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: - export_outputs[ - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value - if len(export_outputs) > 1: - if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - not in export_outputs): - raise ValueError( - 'Multiple export_outputs were provided, but none of them is ' - 'specified as the default. Do this by naming one of them with ' - 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.') + # Validate the passed export outputs, or generate defaults. + if mode == ModeKeys.PREDICT: + export_outputs = _get_export_outputs(export_outputs, predictions) # Validate that all tensors and ops are from the default graph. default_graph = ops.get_default_graph() @@ -286,11 +268,11 @@ class EstimatorSpec( raise ValueError(error_message_template.format('train_op', train_op.name)) for key, value in list(six.iteritems(eval_metric_ops)): values = nest.flatten(value) - for value in values: - if value.graph is not default_graph: + for val in values: + if val.graph is not default_graph: raise ValueError(error_message_template.format( 'eval_metric_ops', - '{0}: {1}'.format(key, value.name))) + '{0}: {1}'.format(key, val.name))) # Validate hooks. training_chief_hooks = tuple(training_chief_hooks or []) @@ -334,6 +316,70 @@ class EstimatorSpec( return EstimatorSpec(*new_fields) +def _get_export_outputs(export_outputs, predictions): + """Validate export_outputs or create default export_outputs. + + Args: + export_outputs: Describes the output signatures to be exported to + `SavedModel` and used during serving. Should be a dict or None. + predictions: Predictions `Tensor` or dict of `Tensor`. + + Returns: + Valid export_outputs dict + + Raises: + TypeError: if export_outputs is not a dict or its values are not + ExportOutput instances. + """ + if export_outputs is None: + default_output = export_output_lib.PredictOutput(predictions) + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: default_output} + + if not isinstance(export_outputs, dict): + raise TypeError('export_outputs must be dict, given: {}'.format( + export_outputs)) + for v in six.itervalues(export_outputs): + if not isinstance(v, export_output_lib.ExportOutput): + raise TypeError( + 'Values in export_outputs must be ExportOutput objects. ' + 'Given: {}'.format(export_outputs)) + + _maybe_add_default_serving_output(export_outputs) + + return export_outputs + + +def _maybe_add_default_serving_output(export_outputs): + """Add a default serving output to the export_outputs if not present. + + Args: + export_outputs: Describes the output signatures to be exported to + `SavedModel` and used during serving. Should be a dict. + + Returns: + export_outputs dict with default serving signature added if necessary + + Raises: + ValueError: if multiple export_outputs were provided without a default + serving key. + """ + if len(export_outputs) == 1: + (key, value), = export_outputs.items() + if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value + if len(export_outputs) > 1: + if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + not in export_outputs): + raise ValueError( + 'Multiple export_outputs were provided, but none of them is ' + 'specified as the default. Do this by naming one of them with ' + 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.') + + return export_outputs + + class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [ 'mode', 'predictions', diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index b7eeeb437cb4a624cdee552be3032364b18a8290..08e41fd4146e9254fc8cc7da6bc809e80d053a5b 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -592,6 +592,27 @@ class EstimatorSpecInferTest(test.TestCase): predictions=predictions, export_outputs=export_outputs) + def testDefaultExportOutputCreated(self): + """Ensure that a default PredictOutput is created for export.""" + with ops.Graph().as_default(), self.test_session(): + predictions = constant_op.constant(1.) + self._assertDefaultExportOutputForPredictions(predictions) + + def testDefaultExportOutputCreatedDict(self): + """Ensure that a default PredictOutput is created for export for dicts.""" + with ops.Graph().as_default(), self.test_session(): + predictions = {'loss': constant_op.constant(1.), + 'score': constant_op.constant(10.)} + self._assertDefaultExportOutputForPredictions(predictions) + + def _assertDefaultExportOutputForPredictions(self, predictions): + spec = model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.PREDICT, predictions=predictions) + + expected = export_output.PredictOutput(predictions).outputs + serving_output = spec.export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + self.assertEqual(serving_output.outputs, expected) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index c7707be8397d950f4e5993b678c215128d3d8b9f..3d60c63b68968c98a00364948bd3de0581daadd4 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -25,11 +25,12 @@ import os import six from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import compat_internal from tensorflow.python.util import function_utils -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export _USE_DEFAULT = object() @@ -296,7 +297,7 @@ class TaskType(object): EVALUATOR = 'evaluator' -@tf_export('estimator.RunConfig') +@estimator_export('estimator.RunConfig') class RunConfig(object): """This class specifies the configurations for an `Estimator` run.""" @@ -484,6 +485,43 @@ class RunConfig(object): self._init_distributed_setting_from_environment_var(tf_config) + # Get session_config only for distributed mode (cluster_spec is present). + if not self._session_config and self._cluster_spec: + RunConfig._replace( + self, + allowed_properties_list=_DEFAULT_REPLACEABLE_LIST, + session_config=self._get_default_session_config()) + + def _get_default_session_config(self): + """Returns None or tf.ConfigProto instance with default device_filters set. + + Device filters are set such that chief/master and worker communicates with + only ps. session_config=None for evaluators or any other TaskType. + """ + + rewrite_opts = rewriter_config_pb2.RewriterConfig( + meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) + graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) + + device_filters = None + if self._task_type == TaskType.MASTER: + device_filters = ['/job:ps', '/job:master'] + elif self._task_type == TaskType.CHIEF: + device_filters = ['/job:ps', '/job:chief'] + elif self._task_type == TaskType.WORKER: + device_filters = ['/job:ps', '/job:worker/task:%d' % self._task_id] + elif self._task_type == TaskType.PS: + device_filters = ['/job:ps', '/job:worker', '/job:master'] + else: + # If the task_type is `EVALUATOR` or something other than the ones in + # TaskType then don't set any device filters. + return None + + return config_pb2.ConfigProto( + allow_soft_placement=True, + graph_options=graph_opts, + device_filters=device_filters) + def _init_distributed_setting_from_environment_var(self, tf_config): """Initialize distributed properties based on `tf_config`.""" diff --git a/tensorflow/python/estimator/run_config_test.py b/tensorflow/python/estimator/run_config_test.py index c8b12605e1aaad11e114e4ace63697b93f3b2b92..06df7cb9dd4ae3d167d622601e551079b64e80a2 100644 --- a/tensorflow/python/estimator/run_config_test.py +++ b/tensorflow/python/estimator/run_config_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import json from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.platform import test @@ -290,6 +291,7 @@ class RunConfigDistributedSettingTest(test.TestCase): expected_num_worker_replicas=1, expected_num_ps_replicas=0) self.assertEqual(0, run_config.global_id_in_cluster) + self.assertIsNone(run_config.session_config, None) def test_session_master_for_local(self): tf_config = {'session_master': '_my_master'} @@ -1119,5 +1121,115 @@ class RunConfigModelDirTest(test.TestCase): _create_run_config_with_cluster_spec(tf_config) +class RunConfigSessionConfigTest(test.TestCase): + + def _assert_equal_session_config(self, session_config, + expected_device_filters): + + rewrite_opts = rewriter_config_pb2.RewriterConfig( + meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) + graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) + expected_session_config = config_pb2.ConfigProto( + allow_soft_placement=True, + graph_options=graph_opts, + device_filters=expected_device_filters) + self.assertEqual(session_config, expected_session_config) + + def test_master_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.MASTER, + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:master']) + + def test_chief_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:chief']) + + def test_worker_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.WORKER, + 'index': 1 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:worker/task:1']) + + def test_ps_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.PS, + 'index': 1 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:worker', '/job:master']) + + def test_evaluator_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.EVALUATOR, + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self.assertIsNone(run_config.session_config) + + def test_other_type_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + 'other_type': ['host3:1', 'host4:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': 'other_type', + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self.assertIsNone(run_config.session_config) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 994115c9eaa4b69a1f4336a6a52922dacee64361..57301010920be90c63e00594d686df3a09466c91 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -35,7 +35,7 @@ from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import server_lib from tensorflow.python.training import session_run_hook from tensorflow.python.util import compat -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export _MAX_DELAY_SECS = 60 _DELAY_SECS_PER_WORKER = 5 @@ -115,7 +115,7 @@ def _is_google_env(): return tf_config.get(_ENVIRONMENT_KEY) == _ENVIRONMENT_GOOGLE_VALUE -@tf_export('estimator.TrainSpec') +@estimator_export('estimator.TrainSpec') class TrainSpec( collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])): """Configuration for the "train" part for the `train_and_evaluate` call. @@ -129,7 +129,7 @@ class TrainSpec( Args: input_fn: A function that provides input data for training as minibatches. - See @{$get_started/premade_estimators#create_input_functions} for more + See @{$premade_estimators#create_input_functions} for more information. The function should construct and return one of the following: * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a @@ -167,7 +167,7 @@ class TrainSpec( cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks) -@tf_export('estimator.EvalSpec') +@estimator_export('estimator.EvalSpec') class EvalSpec( collections.namedtuple('EvalSpec', [ 'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs', @@ -193,7 +193,7 @@ class EvalSpec( Args: input_fn: A function that constructs the input data for evaluation. - See @{$get_started/premade_estimators#create_input_functions} for more + See @{$premade_estimators#create_input_functions} for more information. The function should construct and return one of the following: * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a @@ -263,7 +263,7 @@ class EvalSpec( throttle_secs=throttle_secs) -@tf_export('estimator.train_and_evaluate') +@estimator_export('estimator.train_and_evaluate') def train_and_evaluate(estimator, train_spec, eval_spec): """Train and evaluate the `estimator`. @@ -278,10 +278,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): supported distributed training configuration is between-graph replication. Overfitting: In order to avoid overfitting, it is recommended to set up the - training `input_fn` to shuffle the training data properly. It is also - recommended to train the model a little longer, say multiple epochs, before - performing evaluation, as the input pipeline starts from scratch for each - training. It is particularly important for local training and evaluation. + training `input_fn` to shuffle the training data properly. Stop condition: In order to support both distributed and non-distributed configuration reliably, the only supported stop condition for model @@ -295,6 +292,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): model will be trained with three epochs of training data instead of one epoch. Example of local (non-distributed) training: + ```python # Set up feature columns. categorial_feature_a = categorial_column_with_hash_bucket(...) @@ -339,12 +337,14 @@ def train_and_evaluate(estimator, train_spec, eval_spec): Setting environment variable depends on the platform. For example, on Linux, it can be done as follows (`$` is the shell prompt): + ``` $ TF_CONFIG='' python train_model.py ``` For the content in `TF_CONFIG`, assume that the training cluster spec looks like: + ``` cluster = {"chief": ["host0:2222"], "worker": ["host1:2222", "host2:2222", "host3:2222"], @@ -352,6 +352,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): ``` Example of `TF_CONFIG` for chief training worker (must have one and only one): + ``` # This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. @@ -371,6 +372,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): Example of `TF_CONFIG` for non-chief training worker (optional, could be multiple): + ``` # This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. @@ -387,6 +389,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): for non-chief training workers. Example of `TF_CONFIG` for parameter server, aka ps (could be multiple): + ``` # This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. @@ -405,6 +408,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): Example of `TF_CONFIG` for evaluator task. Evaluator is a special task that is not part of the training cluster. There could be only one. It is used for model evaluation. + ``` # This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. @@ -424,6 +428,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec): eval_spec: A `EvalSpec` instance to specify the evaluation and export specification. + Returns: + A tuple of the result of the `evaluate` call to the `Estimator` and the + export results using the specified `ExportStrategy`. + Currently, the return value is undefined for distributed training mode. + Raises: ValueError: if environment variable `TF_CONFIG` is incorrectly set. """ @@ -439,7 +448,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): 'For distributed training, there can only be one `evaluator` task ' '(with task id 0). Given task id {}'.format(config.task_id)) - executor.run() + return executor.run() class _StopAtSecsHook(session_run_hook.SessionRunHook): @@ -458,6 +467,61 @@ class _StopAtSecsHook(session_run_hook.SessionRunHook): run_context.request_stop() +class _NewCheckpointListenerForEvaluate( + basic_session_run_hooks.CheckpointSaverListener): + """A saver listener to run evaluate with every checkpoint.""" + + def __init__(self, evaluator, eval_throttle_secs, continuous_eval_listener): + self._evaluator = evaluator + self._eval_throttle_secs = eval_throttle_secs + self._continuous_eval_listener = continuous_eval_listener + self.eval_result, self.export_results = None, None + + def begin(self): + self._timer = basic_session_run_hooks.SecondOrStepTimer( + every_secs=self._eval_throttle_secs) + self._is_first_run = True + + def after_save(self, session, global_step_value): + del session # unused; required by signature. + # skip first run model is not trained yet. + if self._is_first_run: + self._is_first_run = False + return + + if not self._continuous_eval_listener.before_eval(): + logging.info('Exiting training and evaluation loop, as requested by ' + '_ContinuousEvalListener.before_eval.') + return True + if self._timer.should_trigger_for_step(global_step_value): + self._evaluate(global_step_value) # updates self.eval_result + if not self._continuous_eval_listener.after_eval(self.eval_result): + logging.info('Exiting evaluation, as requested by ' + '_ContinuousEvalListener.after_eval.') + return True + else: + # TODO(ispir): add remaining time in the log. + logging.info('Skip the current checkpoint eval due to throttle secs ' + '({} secs).'.format(self._eval_throttle_secs)) + + def end(self, session, global_step_value): + # Evaluate if the last step has not been evaluated, yet. + if global_step_value != self._timer.last_triggered_step(): + if self._continuous_eval_listener.before_eval(): + self._evaluate(global_step_value) + self._continuous_eval_listener.after_eval(self.eval_result) + + def _evaluate(self, global_step_value): + self._timer.update_last_triggered_step(global_step_value) + self.eval_result, self.export_results = ( + self._evaluator.evaluate_and_export()) + if self.eval_result.status != _EvalStatus.EVALUATED: + # This is unexpected; should never happen. + # Training should always end with a new checkpoint. + raise RuntimeError('There was no new checkpoint after the training. ' + 'Eval status: {}'.format(self.eval_result.status)) + + class _TrainingExecutor(object): """The executor to run `Estimator` training and evaluation. @@ -510,6 +574,11 @@ class _TrainingExecutor(object): procedure is `run_foo'. This `run` method invoke the procedure base on the `RunConfig.task_type`. + Returns: + A tuple of the result of the `evaluate` call to the `Estimator` and the + export results using the specified `ExportStrategy`. + Currently undefined for distributed training mode. + Raises: ValueError: if the estimator.config is mis-configured. """ @@ -518,8 +587,7 @@ class _TrainingExecutor(object): if (not config.cluster_spec and config.task_type != run_config_lib.TaskType.EVALUATOR): logging.info('Running training and evaluation locally (non-distributed).') - self.run_local() - return + return self.run_local() # Distributed case. if not config.task_type: @@ -560,28 +628,6 @@ class _TrainingExecutor(object): def run_master(self): """Runs task master.""" - - class NewCheckpointListener( - basic_session_run_hooks.CheckpointSaverListener): - - def __init__(self, evaluator, eval_throttle_secs): - self._evaluator = evaluator - self._eval_throttle_secs = eval_throttle_secs - - def begin(self): - self._timer = basic_session_run_hooks.SecondOrStepTimer( - every_secs=self._eval_throttle_secs) - - def after_save(self, session, global_step_value): - del session # unused; required by signature. - - if self._timer.should_trigger_for_step(global_step_value): - self._timer.update_last_triggered_step(global_step_value) - self._evaluator.evaluate_and_export() - else: - logging.info('Skip the current checkpoint eval due to throttle secs ' - '({} secs).'.format(self._eval_throttle_secs)) - _assert_eval_spec(self._eval_spec) # Final export signal: For any eval result with global_step >= train @@ -601,16 +647,12 @@ class _TrainingExecutor(object): # When the underlying `Estimator` object saves a new checkpoint, we would # like this callback to be called so that evaluation and export can trigger. saving_listeners = [ - NewCheckpointListener(evaluator, self._eval_spec.throttle_secs) + _NewCheckpointListenerForEvaluate(evaluator, + self._eval_spec.throttle_secs, + _ContinuousEvalListener()) ] self._start_distributed_training(saving_listeners=saving_listeners) - if not evaluator.is_final_export_triggered: - logging.info('Training has already ended. But the last eval is skipped ' - 'due to eval throttle_secs. Now evaluating the final ' - 'checkpoint.') - evaluator.evaluate_and_export() - def run_evaluator(self): """Runs task evaluator.""" # TODO(xiejw): To allow execution framework to add continuous eval listener. @@ -624,64 +666,33 @@ class _TrainingExecutor(object): def run_local(self): """Runs training and evaluation locally (non-distributed).""" - - def _should_stop_local_train(global_step): - if self._train_spec.max_steps is None: - return False - if global_step >= self._train_spec.max_steps: - return True - return False - _assert_eval_spec(self._eval_spec) - if self._eval_spec.throttle_secs <= 0: - raise ValueError('eval_spec.throttle_secs should be positive, given: {}.' - 'It is used do determine how long each training ' - 'iteration should go when train and evaluate ' - 'locally.'.format(self._eval_spec.throttle_secs)) - - stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs) - train_hooks = ( - list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks)) + train_hooks = list(self._train_spec.hooks) + list(self._train_hooks) logging.info('Start train and evaluate loop. The evaluate will happen ' - 'after {} secs (eval_spec.throttle_secs) or training is ' - 'finished.'.format(self._eval_spec.throttle_secs)) + 'after every checkpoint. Checkpoint frequency is determined ' + 'based on RunConfig arguments: save_checkpoints_steps {} or ' + 'save_checkpoints_secs {}.'.format( + self._estimator.config.save_checkpoints_steps, + self._estimator.config.save_checkpoints_secs)) evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec, self._train_spec.max_steps) - while True: - self._estimator.train( - input_fn=self._train_spec.input_fn, - max_steps=self._train_spec.max_steps, - hooks=train_hooks) - - if not self._continuous_eval_listener.before_eval(): - logging.info('Exiting training and evaluation loop, as requested by ' - '_ContinuousEvalListener.before_eval.') - break - - # Final export signal: For any eval result with global_step >= train - # max_steps, the evaluator will send the final export signal. The - # _should_stop_local_train will then end the while True as the stopping - # condition is satisfied (both checks use the same global_step value, - # i.e., no race condition) - eval_result = evaluator.evaluate_and_export() - - if eval_result.status != _EvalStatus.EVALUATED: - # This is unexpected; should never happen. - # Training should always end with a new checkpoint. - raise RuntimeError('There was no new checkpoint after the training. ' - 'Eval status: {}'.format(eval_result.status)) - - if not self._continuous_eval_listener.after_eval(eval_result): - logging.info('Exiting evaluation, as requested by ' - '_ContinuousEvalListener.after_eval.') - break + listener_for_eval = _NewCheckpointListenerForEvaluate( + evaluator, self._eval_spec.throttle_secs, + self._continuous_eval_listener) + saving_listeners = [listener_for_eval] + + self._estimator.train( + input_fn=self._train_spec.input_fn, + max_steps=self._train_spec.max_steps, + hooks=train_hooks, + saving_listeners=saving_listeners) - if _should_stop_local_train( - eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]): - break + eval_result = listener_for_eval.eval_result or _EvalResult( + status=_EvalStatus.MISSING_CHECKPOINT) + return eval_result.metrics, listener_for_eval.export_results def _start_std_server(self, config): """Creates, starts, and returns a server_lib.Server.""" @@ -807,7 +818,7 @@ class _TrainingExecutor(object): # iteration of while loop will end the continuous eval as the stopping # condition is satisfied (both checks use the same global_step value, # i.e., no race condition) - eval_result = evaluator.evaluate_and_export() + eval_result, _ = evaluator.evaluate_and_export() if not self._continuous_eval_listener.after_eval(eval_result): logging.info('Exiting evaluation, as requested by ' @@ -846,7 +857,7 @@ class _TrainingExecutor(object): """Evaluate and (maybe) export the current model. Returns: - An `EvalResult` instance. + A tuple of `EvalResult` instance and the export results. Raises: RuntimeError: for any unexpected internal error. @@ -856,14 +867,14 @@ class _TrainingExecutor(object): if not latest_ckpt_path: self._log_err_msg('Estimator is not trained yet. Will start an ' 'evaluation when a checkpoint is ready.') - return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT) + return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT), [] if latest_ckpt_path == self._previous_ckpt_path: self._log_err_msg( 'No new checkpoint ready for evaluation. Skip the current ' 'evaluation pass as evaluation results are expected to be same ' 'for the same checkpoint.') - return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT) + return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT), [] metrics = self._estimator.evaluate( input_fn=self._eval_spec.input_fn, @@ -881,7 +892,8 @@ class _TrainingExecutor(object): is_the_final_export = ( eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >= self._max_training_steps if self._max_training_steps else False) - self._export_eval_result(eval_result, is_the_final_export) + export_results = self._export_eval_result(eval_result, + is_the_final_export) if is_the_final_export: logging.debug('Calling exporter with the `is_the_final_export=True`.') @@ -889,7 +901,7 @@ class _TrainingExecutor(object): self._last_warning_time = 0 self._previous_ckpt_path = latest_ckpt_path - return eval_result + return eval_result, export_results def _log_err_msg(self, message): """Prints warning `message` every 10 mins.""" @@ -904,15 +916,18 @@ class _TrainingExecutor(object): compat.as_str_any(self._estimator.model_dir), compat.as_str_any('export')) + export_results = [] for exporter in self._eval_spec.exporters: - exporter.export( - estimator=self._estimator, - export_path=os.path.join( - compat.as_str_any(export_dir_base), - compat.as_str_any(exporter.name)), - checkpoint_path=eval_result.checkpoint_path, - eval_result=eval_result.metrics, - is_the_final_export=is_the_final_export) + export_results.append( + exporter.export( + estimator=self._estimator, + export_path=os.path.join( + compat.as_str_any(export_dir_base), + compat.as_str_any(exporter.name)), + checkpoint_path=eval_result.checkpoint_path, + eval_result=eval_result.metrics, + is_the_final_export=is_the_final_export)) + return export_results class _EvalStatus(object): diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index 3b6f5e18cb50d84002dae842d15284cc57c0f972..6bee7cbe83a5e9b623ea16ebe48cce93e27534e2 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -29,17 +29,21 @@ import time import numpy as np +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import exporter as exporter_lib +from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.estimator import training from tensorflow.python.estimator.canned import dnn from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export as export_lib -from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging @@ -49,6 +53,7 @@ from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import monitored_session from tensorflow.python.training import server_lib from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util from tensorflow.python.util import compat _DEFAULT_EVAL_STEPS = 100 @@ -885,7 +890,8 @@ class TrainingExecutorRunMasterTest(test.TestCase): # `after_save`. del args, kwargs saving_listeners[0].begin() - saving_listeners[0].after_save(session=None, global_step_value=None) + saving_listeners[0].after_save(session=None, global_step_value=0) + saving_listeners[0].after_save(session=None, global_step_value=10) mock_est = test.mock.Mock( spec=estimator_lib.Estimator, model_dir='path/', train=estimator_train) @@ -930,7 +936,10 @@ class TrainingExecutorRunMasterTest(test.TestCase): del args, kwargs saving_listeners[0].begin() - # Call three times. + # Call four times. + mock_timer.should_trigger_for_step.return_value = True + saving_listeners[0].after_save(session=None, global_step_value=None) + mock_timer.should_trigger_for_step.return_value = True saving_listeners[0].after_save(session=None, global_step_value=None) @@ -979,14 +988,19 @@ class TrainingExecutorRunMasterTest(test.TestCase): del args, kwargs saving_listeners[0].begin() - # Call two times. + # Call tree times (one for first saving). mock_timer.should_trigger_for_step.return_value = True - saving_listeners[0].after_save(session=None, global_step_value=None) + saving_listeners[0].after_save(session=None, global_step_value=0) + + mock_timer.should_trigger_for_step.return_value = True + saving_listeners[0].after_save(session=None, global_step_value=125) - # The final ckpt is skipped by the timer. It will be picked up the final - # export check in the code. mock_timer.should_trigger_for_step.return_value = False - saving_listeners[0].after_save(session=None, global_step_value=None) + saving_listeners[0].after_save(session=None, global_step_value=250) + + # At the end evaluate should be called even if throttle secs prevents it. + mock_timer.should_trigger_for_step.return_value = False + saving_listeners[0].end(session=None, global_step_value=300) mock_est.train = estimator_train mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2'] @@ -1566,28 +1580,31 @@ class StopAtSecsHookTest(test.TestCase): class TrainingExecutorRunLocalTest(test.TestCase): """Tests run_local of _TrainingExecutor.""" + def _model_fn(self, features, labels, mode): + del labels + with ops.control_dependencies([features]): + train_op = state_ops.assign_add(training_util.get_global_step(), 1) + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant(0.), + train_op=train_op, + predictions=constant_op.constant([[10.]]), + eval_metric_ops={'mean_of_features': metrics_lib.mean(features)}) + + def _input_fn(self, repeat=True): + ds = dataset_ops.Dataset.from_tensors([1]) + if repeat: + return ds.repeat() + return ds + def unique_checkpoint_every_time_fn(self): return 'checkpoint_path_%s/' % random.random() - def test_send_stop_at_secs_to_train(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn - train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()]) - eval_spec = training.EvalSpec( - input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) - mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} - - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) - executor.run_local() - - stop_hook = mock_est.train.call_args[1]['hooks'][-1] - self.assertIsInstance(stop_hook, training._StopAtSecsHook) - self.assertEqual(eval_spec.throttle_secs, stop_hook._stop_after_secs) - - def test_runs_in_a_loop_until_max_steps(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + def test_runs_evaluate_with_every_new_checkpoint(self): + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) mock_est.times_export_was_called = 0 mock_est.times_final_export_was_called = 0 @@ -1604,42 +1621,30 @@ class TrainingExecutorRunLocalTest(test.TestCase): exporter.name = 'see_how_many_times_export_is_called' exporter.export = export - train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=22) eval_spec = training.EvalSpec( - input_fn=lambda: 1, - hooks=[_FakeHook()], - throttle_secs=100, + input_fn=lambda: self._input_fn(repeat=False), + throttle_secs=0, exporters=exporter) - # should be called 3 times. - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 100 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) executor.run_local() - self.assertEqual(3, mock_est.train.call_count) + self.assertEqual(1, mock_est.train.call_count) self.assertEqual(3, mock_est.evaluate.call_count) self.assertEqual(3, mock_est.times_export_was_called) self.assertEqual(1, mock_est.times_final_export_was_called) def test_runs_with_eval_listener_before_eval(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn - train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) - eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=100) - # should be called 2 times without the evallistener - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=12) + eval_spec = training.EvalSpec(input_fn=lambda: self._input_fn(repeat=False)) + mock_est.evaluate.side_effect = [{_GLOBAL_STEP_KEY: train_spec.max_steps}] class _Listener(training._ContinuousEvalListener): @@ -1658,67 +1663,61 @@ class TrainingExecutorRunLocalTest(test.TestCase): self.assertEqual(1, mock_est.train.call_count) self.assertEqual(0, mock_est.evaluate.call_count) - self.assertEqual(1, listener.call_count) def test_runs_with_eval_listener_after_eval(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) - train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) - eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=100) - # should be called 2 times without the evallistener - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=3000) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) class _Listener(training._ContinuousEvalListener): - def __init__(self, test_case): + def __init__(self): self.call_count = 0 - self._test_case = test_case def after_eval(self, eval_result): self.call_count += 1 - self._test_case.assertEqual( - train_spec.max_steps - 50, eval_result.metrics[_GLOBAL_STEP_KEY]) return False # Will stop the run_local after first eval. - listener = _Listener(test_case=self) + listener = _Listener() executor = training._TrainingExecutor( mock_est, train_spec, eval_spec, continuous_eval_listener=listener) - executor.run_local() + metrics, _ = executor.run_local() # pylint: disable=assignment-from-no-return self.assertEqual(1, mock_est.train.call_count) self.assertEqual(1, mock_est.evaluate.call_count) self.assertEqual(1, listener.call_count) + # Should be less than max_steps since listener did early stopping. + self.assertLess(metrics[_GLOBAL_STEP_KEY], train_spec.max_steps) def test_handles_no_new_checkpoint_found(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint.return_value = ( - 'no_new_checkpoints_after_the_first_train_step') + est = estimator_lib.Estimator( + model_fn=self._model_fn, + # disable saving checkpoint + config=run_config_lib.RunConfig( + save_checkpoints_steps=None, save_checkpoints_secs=None)) train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=300, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) - # It was going to be called 3 times. - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 100 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] + input_fn=lambda: self._input_fn(repeat=False), + hooks=[_FakeHook()], + throttle_secs=100) - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) - with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG): + executor = training._TrainingExecutor(est, train_spec, eval_spec) + with self.assertRaisesRegexp(ValueError, + 'There should be a CheckpointSaverHook'): executor.run_local() def test_final_export_is_true_in_the_end(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) mock_est.times_export_fn_was_called = 0 mock_est.times_the_final_export_was_true = 0 @@ -1734,37 +1733,29 @@ class TrainingExecutorRunLocalTest(test.TestCase): exporter.export = export train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=12, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, - hooks=[_FakeHook()], - throttle_secs=100, + input_fn=lambda: self._input_fn(repeat=False), + throttle_secs=0, exporters=exporter) - # should be called 3 times. - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 100 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) executor.run_local() - self.assertEqual(3, mock_est.train.call_count) - self.assertEqual(3, mock_est.evaluate.call_count) - self.assertEqual(3, mock_est.times_export_fn_was_called) + self.assertEqual(1, mock_est.train.call_count) + self.assertEqual(2, mock_est.evaluate.call_count) + self.assertEqual(2, mock_est.times_export_fn_was_called) self.assertEqual(1, mock_est.times_the_final_export_was_true) def test_train_and_evaluate_args(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint.return_value = 'checkpoint_path/' + est = estimator_lib.Estimator(model_fn=self._model_fn) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=300, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval') - mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} + input_fn=lambda: self._input_fn(repeat=False), + steps=2, + hooks=[_FakeHook()], + name='local_eval') executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) executor.run_local() @@ -1773,11 +1764,11 @@ class TrainingExecutorRunLocalTest(test.TestCase): name=eval_spec.name, input_fn=eval_spec.input_fn, steps=eval_spec.steps, - checkpoint_path='checkpoint_path/', + checkpoint_path=est.latest_checkpoint(), hooks=eval_spec.hooks) train_args = mock_est.train.call_args[1] - self.assertEqual(list(train_spec.hooks), list(train_args['hooks'][:-1])) + self.assertEqual(list(train_spec.hooks), list(train_args['hooks'])) self.assertEqual(train_spec.input_fn, train_args['input_fn']) self.assertEqual(train_spec.max_steps, train_args['max_steps']) @@ -1812,50 +1803,44 @@ class TrainingExecutorRunLocalTest(test.TestCase): if not isinstance(h, training._StopAtSecsHook) ]) - def test_errors_out_if_throttle_secs_is_zero(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=0) - - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) - with self.assertRaisesRegexp(ValueError, 'throttle_secs'): - executor.run_local() - def test_that_export_is_called_with_run_local(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) - mock_train_spec.max_steps = 200 - mock_est.evaluate.return_value = { - _GLOBAL_STEP_KEY: mock_train_spec.max_steps - } - # _validate_hooks would have made sure that train_spec.hooks is [], when - # None were passed. - mock_train_spec.hooks = [] + est = estimator_lib.Estimator(model_fn=self._model_fn) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=12) + mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} def export(estimator, *args, **kwargs): del args, kwargs estimator.export_was_called = True + return 'path_to_export' exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) exporter.name = 'see_whether_export_is_called' exporter.export = export eval_spec = training.EvalSpec( - input_fn=lambda: 1, + input_fn=lambda: self._input_fn(repeat=False), steps=2, start_delay_secs=0, throttle_secs=213, exporters=exporter) - executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) - executor.run_local() + executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) + # pylint: disable=assignment-from-no-return + _, export_results = executor.run_local() + # pylint: enable=assignment-from-no-return self.assertTrue(mock_est.export_was_called) + self.assertEqual(export_results, ['path_to_export']) def test_errors_out_if_evaluate_returns_empty_dict(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=2)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) mock_est.evaluate.return_value = {} executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) @@ -1863,19 +1848,26 @@ class TrainingExecutorRunLocalTest(test.TestCase): executor.run_local() def test_errors_out_if_evaluate_returns_non_dict(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=2)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) mock_est.evaluate.return_value = 123 - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR): executor.run_local() def test_errors_out_if_evaluate_returns_dict_without_global_step(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=2)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) mock_est.evaluate.return_value = {'loss': 123} executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) @@ -1883,6 +1875,23 @@ class TrainingExecutorRunLocalTest(test.TestCase): _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR): executor.run_local() + def test_train_and_evaluate_return_metrics(self): + est = estimator_lib.Estimator(model_fn=self._model_fn) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec( + input_fn=self._input_fn, max_steps=12, hooks=[_FakeHook()]) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), + steps=2, + hooks=[_FakeHook()], + name='local_eval') + + executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) + # pylint: disable=assignment-from-no-return + metrics, _ = executor.run_local() + # pylint: enable=assignment-from-no-return + self.assertEqual(metrics['global_step'], 12) + class TrainAndEvaluateRunTest(test.TestCase): @@ -2078,7 +2087,7 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): # max_steps should be larger than save_summary_steps max_steps = 10 - save_summary_steps = 2 + save_summary_steps = 9 data = np.linspace( 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) @@ -2086,24 +2095,20 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) # learn y = x - train_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, - y=y_data, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, - y=y_data, - batch_size=batch_size, - num_epochs=1, - shuffle=False) - - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, - batch_size=batch_size, - shuffle=False) + def train_input_fn(): + return dataset_ops.Dataset.from_tensor_slices(({ + 'x': x_data + }, y_data)).batch(batch_size).repeat().shuffle(1000) + + def eval_input_fn(): + return dataset_ops.Dataset.from_tensor_slices(({ + 'x': x_data + }, y_data)).batch(batch_size) + + def predict_input_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + 'x': x_data + }).batch(batch_size) feature_columns = [ feature_column.numeric_column('x', shape=(input_dimension,))] @@ -2119,9 +2124,11 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): max_steps=max_steps) eval_spec = training.EvalSpec( - name=eval_name, input_fn=eval_input_fn, steps=None, + name=eval_name, + input_fn=eval_input_fn, + steps=None, exporters=self._get_exporter(exporter_name, feature_columns), - throttle_secs=2) + throttle_secs=0) training.train_and_evaluate(est, train_spec, eval_spec) @@ -2130,15 +2137,12 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): # Examine the training events. Use a range to check global step to avoid # flakyness due to global step race condition. - training_loss, training_global_step = self._extract_loss_and_global_step( - est.model_dir) + training_loss, _ = self._extract_loss_and_global_step(est.model_dir) self.assertIsNotNone(training_loss) - self.assertTrue( - max_steps - save_summary_steps < training_global_step <= max_steps) # Examine the eval events. The global step should be accurate. eval_loss, eval_global_step = self._extract_loss_and_global_step( - event_folder=os.path.join(est.model_dir, 'eval_' + eval_name)) + event_folder=est.eval_dir(eval_name)) self.assertIsNotNone(eval_loss) self.assertEqual(max_steps, eval_global_step) diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index e4e1d37f74330c9bfd48adff95e6409793714729..d4a75478d53f5b3dc8e66df98a78b51a6d25aab8 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -22,8 +22,10 @@ from __future__ import print_function import os import time +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import training from tensorflow.python.util import compat from tensorflow.python.util import function_utils @@ -72,3 +74,80 @@ def get_timestamped_dir(dir_base): result_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)) raise RuntimeError('Failed to obtain a unique export directory name after ' '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) + + +def parse_input_fn_result(result): + """Gets features, labels, and hooks from the result of an Estimator input_fn. + + Args: + result: output of an input_fn to an estimator, which should be one of: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + Returns: + Tuple of features, labels, and input_hooks, where features are as described + above, labels are as described above or None, and input_hooks are a list + of SessionRunHooks to be included when running. + + Raises: + ValueError: if the result is a list or tuple of length != 2. + """ + input_hooks = [] + try: + # We can't just check whether this is a tf.data.Dataset instance here, + # as this is plausibly a PerDeviceDataset. Try treating as a dataset first. + iterator = result.make_initializable_iterator() + except AttributeError: + # Not a dataset or dataset-like-object. Move along. + pass + else: + input_hooks.append(_DatasetInitializerHook(iterator)) + result = iterator.get_next() + + if isinstance(result, (list, tuple)): + if len(result) != 2: + raise ValueError( + 'input_fn should return (features, labels) as a len 2 tuple.') + return result[0], result[1], input_hooks + return result, None, input_hooks + + +class _DatasetInitializerHook(training.SessionRunHook): + """Creates a SessionRunHook that initializes the passed iterator.""" + + def __init__(self, iterator): + self._iterator = iterator + + def begin(self): + self._initializer = self._iterator.initializer + + def after_create_session(self, session, coord): + del coord + session.run(self._initializer) + + +class StrategyInitFinalizeHook(training.SessionRunHook): + """Creates a SessionRunHook that initializes and shutsdown devices.""" + + def __init__(self, initialization_fn, finalize_fn): + self._initialization_fn = initialization_fn + self._finalize_fn = finalize_fn + + def begin(self): + self._init_ops = self._initialization_fn() + self._finalize_ops = self._finalize_fn() + + def after_create_session(self, session, coord): + logging.info('Initialize system') + session.run(self._init_ops, + options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) + + def end(self, session): + logging.info('Finalize system.') + session.run(self._finalize_ops) diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/estimator/util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e0610779023199aa659b119d366ba69ac2b15f --- /dev/null +++ b/tensorflow/python/estimator/util_test.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for util.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import util +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +class UtilTest(test.TestCase): + """Tests for miscellaneous Estimator utils.""" + + def test_parse_input_fn_result_tuple(self): + def _input_fn(): + features = constant_op.constant(np.arange(100)) + labels = constant_op.constant(np.arange(100, 200)) + return features, labels + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with self.test_session() as sess: + vals = sess.run([features, labels]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertAllEqual(vals[1], np.arange(100, 200)) + self.assertEqual(hooks, []) + + def test_parse_input_fn_result_dataset(self): + def _input_fn(): + features = np.expand_dims(np.arange(100), 0) + labels = np.expand_dims(np.arange(100, 200), 0) + return dataset_ops.Dataset.from_tensor_slices((features, labels)) + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with training.MonitoredSession(hooks=hooks) as sess: + vals = sess.run([features, labels]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertAllEqual(vals[1], np.arange(100, 200)) + self.assertIsInstance(hooks[0], util._DatasetInitializerHook) + + def test_parse_input_fn_result_features_only(self): + def _input_fn(): + return constant_op.constant(np.arange(100)) + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with self.test_session() as sess: + vals = sess.run([features]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertEqual(labels, None) + self.assertEqual(hooks, []) + + def test_parse_input_fn_result_features_only_dataset(self): + def _input_fn(): + features = np.expand_dims(np.arange(100), 0) + return dataset_ops.Dataset.from_tensor_slices(features) + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with training.MonitoredSession(hooks=hooks) as sess: + vals = sess.run([features]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertEqual(labels, None) + self.assertIsInstance(hooks[0], util._DatasetInitializerHook) + + def test_parse_input_fn_result_invalid(self): + def _input_fn(): + features = np.expand_dims(np.arange(100), 0) + labels = np.expand_dims(np.arange(100, 200), 0) + return dataset_ops.Dataset.from_tensor_slices((features, labels, labels)) + + with self.assertRaisesRegexp(ValueError, 'input_fn should return'): + util.parse_input_fn_result(_input_fn()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 3e154d73b0b6ff80215658661bed04c119e397f2..d091d2fe0ac688773b27d80f37fbf3083b8ffa1f 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -140,7 +140,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras.engine import training from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -172,7 +172,7 @@ def _internal_input_layer(features, scope=None): """See input_layer. `scope` is a name or variable scope to use.""" - feature_columns = _clean_feature_columns(feature_columns) + feature_columns = _normalize_feature_columns(feature_columns) for column in feature_columns: if not isinstance(column, _DenseColumn): raise ValueError( @@ -350,10 +350,23 @@ def linear_model(features, prediction itself for linear regression problems. Note on supported columns: `linear_model` treats categorical columns as - `indicator_column`s while `input_layer` explicitly requires wrapping each - of them with an `embedding_column` or an `indicator_column`. + `indicator_column`s. To be specific, assume the input as `SparseTensor` looks + like: - Example: + ```python + shape = [2, 2] + { + [0, 0]: "a" + [1, 0]: "b" + [1, 1]: "c" + } + ``` + `linear_model` assigns weights for the presence of "a", "b", "c' implicitly, + just like `indicator_column`, while `input_layer` explicitly requires wrapping + each of categorical columns with an `embedding_column` or an + `indicator_column`. + + Example of usage: ```python price = numeric_column('price') @@ -374,13 +387,44 @@ def linear_model(features, to your model. All items should be instances of classes derived from `_FeatureColumn`s. units: An integer, dimensionality of the output space. Default value is 1. - sparse_combiner: A string specifying how to reduce if a sparse column is - multivalent. Currently "mean", "sqrtn" and "sum" are supported, with "sum" - the default. "sqrtn" often achieves good accuracy, in particular with - bag-of-words columns. It combines each sparse columns independently. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. Except `numeric_column`, almost all columns passed to + `linear_model` are considered as categorical columns. It combines each + categorical column independently. Currently "mean", "sqrtn" and "sum" are + supported, with "sum" the default for linear model. "sqrtn" often achieves + good accuracy, in particular with bag-of-words columns. * "sum": do not normalize features in the column * "mean": do l1 normalization on features in the column * "sqrtn": do l2 normalization on features in the column + For example, for two features represented as the categorical columns: + + ```python + # Feature 1 + + shape = [2, 2] + { + [0, 0]: "a" + [0, 1]: "b" + [1, 0]: "c" + } + + # Feature 2 + + shape = [2, 3] + { + [0, 0]: "d" + [1, 0]: "e" + [1, 1]: "f" + [1, 2]: "g" + } + ``` + with `sparse_combiner` as "mean", the linear model outputs conceptly are: + ``` + y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0 + y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1 + ``` + where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight + assigned to the presence of `x` in the input features. weight_collections: A list of collection names to which the Variable will be added. Note that, variables will also be added to collections `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. @@ -408,13 +452,15 @@ def linear_model(features, ValueError: if an item in `feature_columns` is neither a `_DenseColumn` nor `_CategoricalColumn`. """ + with variable_scope.variable_scope(None, 'linear_model') as vs: + model_name = _strip_leading_slashes(vs.name) linear_model_layer = _LinearModel( feature_columns=feature_columns, units=units, sparse_combiner=sparse_combiner, weight_collections=weight_collections, trainable=trainable, - name='linear_model') + name=model_name) retval = linear_model_layer(features) # pylint: disable=not-callable if cols_to_vars is not None: cols_to_vars.update(linear_model_layer.cols_to_vars()) @@ -422,13 +468,25 @@ def linear_model(features, def _add_to_collections(var, weight_collections): - # TODO(rohanj): Explore adding a _get_variable_list method on `Variable` - # so that we don't have to do this check. - if isinstance(var, variables.PartitionedVariable): - for constituent_var in list(var): - ops.add_to_collections(weight_collections, constituent_var) - else: - ops.add_to_collections(weight_collections, var) + """Adds a var to the list of weight_collections provided. + + Handles the case for partitioned and non-partitioned variables. + + Args: + var: A variable or Partitioned Variable. + weight_collections: List of collections to add variable to. + """ + for weight_collection in weight_collections: + # The layer self.add_variable call already adds it to GLOBAL_VARIABLES. + if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES: + continue + # TODO(rohanj): Explore adding a _get_variable_list method on `Variable` + # so that we don't have to do this check. + if isinstance(var, variables.PartitionedVariable): + for constituent_var in list(var): + ops.add_to_collection(weight_collection, constituent_var) + else: + ops.add_to_collection(weight_collection, var) class _FCLinearWrapper(base.Layer): @@ -536,8 +594,11 @@ class _LinearModel(training.Model): name=None, **kwargs): super(_LinearModel, self).__init__(name=name, **kwargs) - self._feature_columns = _clean_feature_columns(feature_columns) + self._feature_columns = _normalize_feature_columns( + feature_columns) self._weight_collections = list(weight_collections or []) + if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections: + self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections: self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES) @@ -643,7 +704,7 @@ def _transform_features(features, feature_columns): Returns: A `dict` mapping `_FeatureColumn` to `Tensor` and `SparseTensor` values. """ - feature_columns = _clean_feature_columns(feature_columns) + feature_columns = _normalize_feature_columns(feature_columns) outputs = {} with ops.name_scope( None, default_name='transform_features', values=features.values()): @@ -911,7 +972,8 @@ def shared_embedding_columns( tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which to restore the column weights. Required if `ckpt_to_load_from` is not `None`. - max_norm: If not `None`, embedding values are l2-normalized to this value. + max_norm: If not `None`, each embedding is clipped if its l2-norm is + larger than this value, before combining. trainable: Whether or not the embedding is trainable. Default is True. Returns: @@ -925,7 +987,12 @@ def shared_embedding_columns( ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified. ValueError: if `initializer` is specified and is not callable. + RuntimeError: if eager execution is enabled. """ + if context.executing_eagerly(): + raise RuntimeError('shared_embedding_columns are not supported when eager ' + 'execution is enabled.') + if (dimension is None) or (dimension < 1): raise ValueError('Invalid dimension {}.'.format(dimension)) if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): @@ -970,16 +1037,6 @@ def shared_embedding_columns( shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) shared_embedding_collection_name += '_shared_embedding' - # Create the state (_SharedEmbeddingColumnLayer) here. - embedding_shape = num_buckets, dimension - - shared_embedding_column_layer = _EmbeddingColumnLayer( - embedding_shape=embedding_shape, - initializer=initializer, - weight_collections=[], - trainable=trainable, - name=shared_embedding_collection_name) - result = [] for column in categorical_columns: result.append( @@ -988,16 +1045,12 @@ def shared_embedding_columns( initializer=initializer, dimension=dimension, combiner=combiner, - var_scope_name=shared_embedding_collection_name, + shared_embedding_collection_name=shared_embedding_collection_name, ckpt_to_load_from=ckpt_to_load_from, tensor_name_in_ckpt=tensor_name_in_ckpt, max_norm=max_norm, trainable=trainable)) - for single_result in result: - single_result._set_layer(shared_embedding_column_layer) # pylint: disable=protected-access - single_result._set_all_columns(result) # pylint: disable=protected-access - return result @@ -1182,12 +1235,13 @@ def categorical_column_with_hash_bucket(key, Use this when your sparse features are in string or integer format, and you want to distribute your inputs into a finite number of buckets by hashing. - output_id = Hash(input_feature_string) % bucket_size + output_id = Hash(input_feature_string) % bucket_size for string type input. + For int type input, the value is converted to its string representation first + and then hashed by the same formula. For input dictionary `features`, `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int - and `''` for string. Note that these values are independent of the - `default_value` argument. + and `''` for string, which will be dropped by this feature column. Example: @@ -1249,8 +1303,7 @@ def categorical_column_with_vocabulary_file(key, For input dictionary `features`, `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int - and `''` for string. Note that these values are independent of the - `default_value` argument. + and `''` for string, which will be dropped by this feature column. Example with `num_oov_buckets`: File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state @@ -1366,8 +1419,7 @@ def categorical_column_with_vocabulary_list( For input dictionary `features`, `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int - and `''` for string. Note that these values are independent of the - `default_value` argument. + and `''` for string, which will be dropped by this feature column. Example with `num_oov_buckets`: In the following example, each input in `vocabulary_list` is assigned an ID @@ -1480,8 +1532,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None): For input dictionary `features`, `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int - and `''` for string. Note that these values are independent of the - `default_value` argument. + and `''` for string, which will be dropped by this feature column. In the following examples, each input in the range `[0, 1000000)` is assigned the same value. All other inputs are assigned `default_value` 0. Note that a @@ -1538,8 +1589,14 @@ def categorical_column_with_identity(key, num_buckets, default_value=None): def indicator_column(categorical_column): """Represents multi-hot representation of given categorical column. - Used to wrap any `categorical_column_*` (e.g., to feed to DNN). Use - `embedding_column` if the inputs are sparse. + - For DNN model, `indicator_column` can be used to wrap any + `categorical_column_*` (e.g., to feed to DNN). Consider to Use + `embedding_column` if the number of buckets/unique(values) are large. + + - For Wide (aka linear) model, `indicator_column` is the internal + representation for categorical column when passing categorical column + directly (as any element in feature_columns) to `linear_model`. See + `linear_model` for details. ```python name = indicator_column(categorical_column_with_vocabulary_list( @@ -1782,9 +1839,7 @@ class _EmbeddingColumnLayer(base.Layer): Args: embedding_shape: Shape of the embedding variable used for lookup. initializer: A variable initializer function to be used in embedding - variable initialization. If not specified, defaults to - `tf.truncated_normal_initializer` with mean `0.0` and standard deviation - `1/sqrt(dimension)`. + variable initialization. weight_collections: A list of collection names to which the Variable will be added. Note that, variables will also be added to collections `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. @@ -1799,6 +1854,15 @@ class _EmbeddingColumnLayer(base.Layer): self._initializer = initializer self._weight_collections = weight_collections + def set_weight_collections(self, weight_collections): + """Sets the weight collections for the layer. + + Args: + weight_collections: A list of collection names to which the Variable will + be added. + """ + self._weight_collections = weight_collections + def build(self, _): self._embedding_weight_var = self.add_variable( name='embedding_weights', @@ -1806,11 +1870,8 @@ class _EmbeddingColumnLayer(base.Layer): dtype=dtypes.float32, initializer=self._initializer, trainable=self.trainable) - # self.add_variable already appends to GLOBAL_VARIABLES collection. if self._weight_collections and not context.executing_eagerly(): - for weight_collection in self._weight_collections: - if weight_collection != ops.GraphKeys.GLOBAL_VARIABLES: - _add_to_collections(self._embedding_weight_var, [weight_collection]) + _add_to_collections(self._embedding_weight_var, self._weight_collections) self.built = True def call(self, _): @@ -1949,7 +2010,7 @@ def _create_weighted_sum(column, weight_collections, trainable, weight_var=None): - """Creates a weighted sum for a dense or sparse column for linear_model.""" + """Creates a weighted sum for a dense/categorical column for linear_model.""" if isinstance(column, _CategoricalColumn): return _create_categorical_column_weighted_sum( column=column, @@ -2048,7 +2109,34 @@ def _create_categorical_column_weighted_sum(column, weight_collections, trainable, weight_var=None): - """Create a weighted sum of a categorical column for linear_model.""" + # pylint: disable=g-doc-return-or-yield,g-doc-args + """Create a weighted sum of a categorical column for linear_model. + + Note to maintainer: As implementation details, the weighted sum is + implemented via embedding_lookup_sparse toward efficiency. Mathematically, + they are the same. + + To be specific, conceptually, categorical column can be treated as multi-hot + vector. Say: + + ```python + x = [0 0 1] # categorical column input + w = [a b c] # weights + ``` + The weighted sum is `c` in this case, which is same as `w[2]`. + + Another example is + + ```python + x = [0 1 1] # categorical column input + w = [a b c] # weights + ``` + The weighted sum is `b + c` in this case, which is same as `w[2] + w[3]`. + + For both cases, we can implement weighted sum via embedding_lookup with + sparse_combiner = "sum". + """ + sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access builder, weight_collections=weight_collections, @@ -2070,7 +2158,7 @@ def _create_categorical_column_weighted_sum(column, initializer=init_ops.zeros_initializer(), trainable=trainable, collections=weight_collections) - return _safe_embedding_lookup_sparse( + return embedding_ops.safe_embedding_lookup_sparse( weight, id_tensor, sparse_weights=weight_tensor, @@ -2163,7 +2251,7 @@ class _LazyBuilder(object): self._feature_tensors[key] = feature_tensor return feature_tensor - if isinstance(key, str): + if isinstance(key, six.string_types): raise ValueError('Feature {} is not in features dictionary.'.format(key)) if not isinstance(key, _FeatureColumn): @@ -2242,7 +2330,7 @@ def _shape_offsets(shape): # TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py -def _to_sparse_input(input_tensor, ignore_value=None): +def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None): """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. If `input_tensor` is already a `SparseTensor`, just return it. @@ -2286,8 +2374,22 @@ def _to_sparse_input(input_tensor, ignore_value=None): input_tensor, out_type=dtypes.int64, name='dense_shape')) -def _clean_feature_columns(feature_columns): - """Verifies and normalizes `feature_columns` input.""" +def _normalize_feature_columns(feature_columns): + """Normalizes the `feature_columns` input. + + This method converts the `feature_columns` to list type as best as it can. In + addition, verifies the type and other parts of feature_columns, required by + downstream library. + + Args: + feature_columns: The raw feature columns, usually passed by users. + + Returns: + The normalized feature column list. + + Raises: + ValueError: for any invalid inputs, such as empty, duplicated names, etc. + """ if isinstance(feature_columns, _FeatureColumn): feature_columns = [feature_columns] @@ -2413,6 +2515,7 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn, def _get_sparse_tensors(self, inputs, weight_collections=None, trainable=None): + """Converts dense inputs to SparseTensor so downstream code can use it.""" input_tensor = inputs.get(self) batch_size = array_ops.shape(input_tensor)[0] # By construction, source_column is always one-dimensional. @@ -2491,7 +2594,7 @@ class _EmbeddingColumn( }) # Return embedding lookup result. - return _safe_embedding_lookup_sparse( + return embedding_ops.safe_embedding_lookup_sparse( embedding_weights=embedding_weights, sparse_ids=sparse_ids, sparse_weights=sparse_weights, @@ -2546,12 +2649,12 @@ def _get_graph_for_variable(var): class _SharedEmbeddingColumn( - _DenseColumn, + _DenseColumn, _SequenceDenseColumn, collections.namedtuple( '_SharedEmbeddingColumn', ('categorical_column', 'dimension', 'combiner', 'initializer', - 'var_scope_name', 'ckpt_to_load_from', 'tensor_name_in_ckpt', - 'max_norm', 'trainable'))): + 'shared_embedding_collection_name', 'ckpt_to_load_from', + 'tensor_name_in_ckpt', 'max_norm', 'trainable'))): """See `embedding_column`.""" @property @@ -2562,7 +2665,7 @@ class _SharedEmbeddingColumn( @property def _var_scope_name(self): - return self.var_scope_name + return self.shared_embedding_collection_name @property def _parse_example_spec(self): @@ -2571,29 +2674,17 @@ class _SharedEmbeddingColumn( def _transform_feature(self, inputs): return inputs.get(self.categorical_column) - def _set_layer(self, layer): - self._layer = layer - - def _set_all_columns(self, all_columns): - self._all_columns = all_columns - - def _reset_config(self): - config = self._layer.get_config() - config['embedding_shape'] = ( - self.categorical_column._num_buckets, # pylint: disable=protected-access - self.dimension) - config['initializer'] = self.initializer - self._layer = self._layer.__class__.from_config(config) - for column in self._all_columns: - column._set_layer(self._layer) # pylint: disable=protected-access - @property def _variable_shape(self): if not hasattr(self, '_shape'): self._shape = tensor_shape.vector(self.dimension) return self._shape - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + def _get_dense_tensor_internal(self, + inputs, + weight_collections=None, + trainable=None): + """Private method that follows the signature of _get_dense_tensor.""" # This method is called from a variable_scope with name _var_scope_name, # which is shared among all shared embeddings. Open a name_scope here, so # that the ops for different columns have distinct names. @@ -2604,17 +2695,38 @@ class _SharedEmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor - embedding_weights = self._layer( - None, scope=variable_scope.get_variable_scope()) - # If we're in graph mode and this is called with a different graph, - # then we should reset. - if not context.executing_eagerly() and ( - ops.get_default_graph() != - _get_graph_for_variable(embedding_weights)): - self._reset_config() - embedding_weights = self._layer( - None, scope=variable_scope.get_variable_scope()) - + embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access + shared_embedding_collection = ops.get_collection( + self.shared_embedding_collection_name) + if shared_embedding_collection: + if len(shared_embedding_collection) > 1: + raise ValueError( + 'Collection {} can only contain one variable. ' + 'Suggested fix A: Choose a unique name for this collection. ' + 'Suggested fix B: Do not add any variables to this collection. ' + 'The feature_column library already adds a variable under the ' + 'hood.'.format(shared_embedding_collection)) + embedding_weights = shared_embedding_collection[0] + if embedding_weights.get_shape() != embedding_shape: + raise ValueError( + 'Shared embedding collection {} contains variable {} of ' + 'unexpected shape {}. Expected shape is {}. ' + 'Suggested fix A: Choose a unique name for this collection. ' + 'Suggested fix B: Do not add any variables to this collection. ' + 'The feature_column library already adds a variable under the ' + 'hood.'.format(self.shared_embedding_collection_name, + embedding_weights.name, + embedding_weights.get_shape(), embedding_shape)) + else: + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable and trainable, + collections=weight_collections) + ops.add_to_collection(self.shared_embedding_collection_name, + embedding_weights) if self.ckpt_to_load_from is not None: to_restore = embedding_weights if isinstance(to_restore, variables.PartitionedVariable): @@ -2624,7 +2736,7 @@ class _SharedEmbeddingColumn( }) # Return embedding lookup result. - return _safe_embedding_lookup_sparse( + return embedding_ops.safe_embedding_lookup_sparse( embedding_weights=embedding_weights, sparse_ids=sparse_ids, sparse_weights=sparse_weights, @@ -2632,6 +2744,44 @@ class _SharedEmbeddingColumn( name='%s_weights' % self.name, max_norm=self.max_norm) + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + if isinstance(self.categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must not be of type _SequenceCategoricalColumn. ' + 'Suggested fix A: If you wish to use input_layer, use a ' + 'non-sequence categorical_column_with_*. ' + 'Suggested fix B: If you wish to create sequence input, use ' + 'sequence_input_layer instead of input_layer. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + return self._get_dense_tensor_internal( + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable) + + def _get_sequence_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None): + if not isinstance(self.categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must be of type _SequenceCategoricalColumn ' + 'to use sequence_input_layer. ' + 'Suggested fix: Use one of sequence_categorical_column_with_*. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable) + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access + sequence_length = _sequence_length_from_sparse_tensor( + sparse_tensors.id_tensor) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + def _create_tuple(shape, value): """Returns a tuple with given shape and filled with value.""" @@ -2753,7 +2903,7 @@ class _HashedCategoricalColumn( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - input_tensor = _to_sparse_input(inputs.get(self.key)) + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): raise ValueError('SparseColumn input must be a SparseTensor.') @@ -2804,7 +2954,7 @@ class _VocabularyFileCategoricalColumn( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - input_tensor = _to_sparse_input(inputs.get(self.key)) + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) if self.dtype.is_integer != input_tensor.dtype.is_integer: raise ValueError( @@ -2856,7 +3006,7 @@ class _VocabularyListCategoricalColumn( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - input_tensor = _to_sparse_input(inputs.get(self.key)) + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) if self.dtype.is_integer != input_tensor.dtype.is_integer: raise ValueError( @@ -2908,7 +3058,7 @@ class _IdentityCategoricalColumn( return {self.key: parsing_ops.VarLenFeature(dtypes.int64)} def _transform_feature(self, inputs): - input_tensor = _to_sparse_input(inputs.get(self.key)) + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) if not input_tensor.dtype.is_integer: raise ValueError( @@ -2990,7 +3140,8 @@ class _WeightedCategoricalColumn( self.dtype, weight_tensor.dtype)) if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor): # The weight tensor can be a regular Tensor. In this case, sparsify it. - weight_tensor = _to_sparse_input(weight_tensor, ignore_value=0.0) + weight_tensor = _to_sparse_input_and_drop_ignore_values( + weight_tensor, ignore_value=0.0) if not weight_tensor.dtype.is_floating: weight_tensor = math_ops.to_float(weight_tensor) return (inputs.get(self.categorical_column), weight_tensor) @@ -3077,161 +3228,6 @@ def _collect_leaf_level_keys(cross): return leaf_level_keys -# TODO(zakaria): Move this to embedding_ops and make it public. -def _safe_embedding_lookup_sparse(embedding_weights, - sparse_ids, - sparse_weights=None, - combiner='mean', - default_id=None, - name=None, - partition_strategy='div', - max_norm=None): - """Lookup embedding results, accounting for invalid IDs and empty features. - - The partitioned embedding in `embedding_weights` must all be the same shape - except for the first dimension. The first dimension is allowed to vary as the - vocabulary size is not necessarily a multiple of `P`. `embedding_weights` - may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a - partitioner. - - Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs - with non-positive weight. For an entry with no features, the embedding vector - for `default_id` is returned, or the 0-vector if `default_id` is not supplied. - - The ids and weights may be multi-dimensional. Embeddings are always aggregated - along the last dimension. - - Args: - embedding_weights: A list of `P` float `Tensor`s or values representing - partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable` - created by partitioning along dimension 0. The total unpartitioned - shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the - vocab size and `e_1, ..., e_m` are the embedding dimensions. - sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the - ids. `d_0` is typically batch size. - sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing - float weights corresponding to `sparse_ids`, or `None` if all weights - are be assumed to be 1.0. - combiner: A string specifying how to combine embedding results for each - entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" - the default. - default_id: The id to use for an entry with no features. - name: A name for this operation (optional). - partition_strategy: A string specifying the partitioning strategy. - Currently `"div"` and `"mod"` are supported. Default is `"div"`. - max_norm: If not `None`, all embeddings are l2-normalized to max_norm before - combining. - - - Returns: - Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`. - - Raises: - ValueError: if `embedding_weights` is empty. - """ - if embedding_weights is None: - raise ValueError('Missing embedding_weights %s.' % embedding_weights) - if isinstance(embedding_weights, variables.PartitionedVariable): - embedding_weights = list(embedding_weights) # get underlying Variables. - if not isinstance(embedding_weights, list): - embedding_weights = [embedding_weights] - if len(embedding_weights) < 1: - raise ValueError('Missing embedding_weights %s.' % embedding_weights) - - dtype = sparse_weights.dtype if sparse_weights is not None else None - embedding_weights = [ - ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights - ] - - with ops.name_scope(name, 'embedding_lookup', - embedding_weights + [sparse_ids, - sparse_weights]) as scope: - # Reshape higher-rank sparse ids and weights to linear segment ids. - original_shape = sparse_ids.dense_shape - original_rank_dim = sparse_ids.dense_shape.get_shape()[0] - original_rank = ( - array_ops.size(original_shape) - if original_rank_dim.value is None - else original_rank_dim.value) - sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [ - math_ops.reduce_prod( - array_ops.slice(original_shape, [0], [original_rank - 1])), - array_ops.gather(original_shape, original_rank - 1)]) - if sparse_weights is not None: - sparse_weights = sparse_tensor_lib.SparseTensor( - sparse_ids.indices, - sparse_weights.values, sparse_ids.dense_shape) - - # Prune invalid ids and weights. - sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights) - if combiner != 'sum': - sparse_ids, sparse_weights = _prune_invalid_weights( - sparse_ids, sparse_weights) - - # Fill in dummy values for empty features, if necessary. - sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids, - default_id or - 0) - if sparse_weights is not None: - sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0) - - result = embedding_ops.embedding_lookup_sparse( - embedding_weights, - sparse_ids, - sparse_weights, - combiner=combiner, - partition_strategy=partition_strategy, - name=None if default_id is None else scope, - max_norm=max_norm) - - if default_id is None: - # Broadcast is_row_empty to the same shape as embedding_lookup_result, - # for use in Select. - is_row_empty = array_ops.tile( - array_ops.reshape(is_row_empty, [-1, 1]), - array_ops.stack([1, array_ops.shape(result)[1]])) - - result = array_ops.where(is_row_empty, - array_ops.zeros_like(result), - result, - name=scope) - - # Reshape back from linear ids back into higher-dimensional dense result. - final_result = array_ops.reshape( - result, - array_ops.concat([ - array_ops.slice( - math_ops.cast(original_shape, dtypes.int32), [0], - [original_rank - 1]), - array_ops.slice(array_ops.shape(result), [1], [-1]) - ], 0)) - final_result.set_shape(tensor_shape.unknown_shape( - (original_rank_dim - 1).value).concatenate(result.get_shape()[1:])) - return final_result - - -def _prune_invalid_ids(sparse_ids, sparse_weights): - """Prune invalid IDs (< 0) from the input ids and weights.""" - is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) - if sparse_weights is not None: - is_id_valid = math_ops.logical_and( - is_id_valid, - array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) - sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) - if sparse_weights is not None: - sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) - return sparse_ids, sparse_weights - - -def _prune_invalid_weights(sparse_ids, sparse_weights): - """Prune invalid weights (< 0) from the input ids and weights.""" - if sparse_weights is not None: - is_weights_valid = math_ops.greater(sparse_weights.values, 0) - sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) - sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) - return sparse_ids, sparse_weights - - class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn, collections.namedtuple('_IndicatorColumn', ['categorical_column'])): @@ -3268,10 +3264,14 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn, sp_ids=id_tensor, sp_values=weight_tensor, vocab_size=int(self._variable_shape[-1])) - # Remove (?, -1) index + # Remove (?, -1) index. weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0], weighted_column.dense_shape) - return sparse_ops.sparse_tensor_to_dense(weighted_column) + # Use scatter_nd to merge duplicated indices if existed, + # instead of sparse_tensor_to_dense. + return array_ops.scatter_nd(weighted_column.indices, + weighted_column.values, + weighted_column.dense_shape) dense_id_tensor = sparse_ops.sparse_tensor_to_dense( id_tensor, default_value=-1) diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index f9206f4f38d7631ccdb57d41d8c52f9f0edede6d..5bb47bfa47cf8fe0311d63f325198bcb7ecd5f9c 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -137,6 +137,9 @@ class LazyColumnTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'bbb is not in features dictionary'): builder.get('bbb') + with self.assertRaisesRegexp(ValueError, + 'bbb is not in features dictionary'): + builder.get(u'bbb') def test_not_supported_feature_column(self): @@ -1254,14 +1257,14 @@ class CrossedColumnTest(test.TestCase): }, (crossed,)) -def get_linear_model_bias(): - with variable_scope.variable_scope('linear_model', reuse=True): +def get_linear_model_bias(name='linear_model'): + with variable_scope.variable_scope(name, reuse=True): return variable_scope.get_variable('bias_weights') -def get_linear_model_column_var(column): +def get_linear_model_column_var(column, name='linear_model'): return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, - 'linear_model/' + column.name)[0] + name + '/' + column.name)[0] def get_keras_linear_model_predictions(features, @@ -1925,6 +1928,27 @@ class LinearModelTest(test.TestCase): with self.assertRaisesOpError('Feature .* cannot have rank 0'): sess.run(net, feed_dict={features['price']: np.array(1)}) + def test_multiple_linear_models(self): + price = fc.numeric_column('price') + with ops.Graph().as_default(): + features1 = {'price': [[1.], [5.]]} + features2 = {'price': [[2.], [10.]]} + predictions1 = fc.linear_model(features1, [price]) + predictions2 = fc.linear_model(features2, [price]) + bias1 = get_linear_model_bias(name='linear_model') + bias2 = get_linear_model_bias(name='linear_model_1') + price_var1 = get_linear_model_column_var(price, name='linear_model') + price_var2 = get_linear_model_column_var(price, name='linear_model_1') + with _initialized_session() as sess: + self.assertAllClose([0.], bias1.eval()) + sess.run(price_var1.assign([[10.]])) + sess.run(bias1.assign([5.])) + self.assertAllClose([[15.], [55.]], predictions1.eval()) + self.assertAllClose([0.], bias2.eval()) + sess.run(price_var2.assign([[10.]])) + sess.run(bias2.assign([5.])) + self.assertAllClose([[25.], [105.]], predictions2.eval()) + class _LinearModelTest(test.TestCase): @@ -2583,7 +2607,7 @@ class _LinearModelTest(test.TestCase): class InputLayerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_retrieving_input(self): features = {'a': [0.]} input_layer = InputLayer(fc.numeric_column('a')) @@ -4556,12 +4580,12 @@ class IndicatorColumnTest(test.TestCase): weights = fc.weighted_categorical_column(ids, 'weights') indicator = fc.indicator_column(weights) features = { - 'ids': constant_op.constant([['c', 'b', 'a']]), - 'weights': constant_op.constant([[2., 4., 6.]]) + 'ids': constant_op.constant([['c', 'b', 'a', 'c']]), + 'weights': constant_op.constant([[2., 4., 6., 1.]]) } indicator_tensor = _transform_features(features, [indicator])[indicator] with _initialized_session(): - self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval()) + self.assertAllEqual([[6., 4., 3.]], indicator_tensor.eval()) def test_transform_with_missing_value_in_weighted_column(self): # Github issue 12583 @@ -5326,9 +5350,9 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertIsNone(embedding_column_a.ckpt_to_load_from) self.assertIsNone(embedding_column_b.ckpt_to_load_from) self.assertEqual('aaa_bbb_shared_embedding', - embedding_column_a.var_scope_name) + embedding_column_a.shared_embedding_collection_name) self.assertEqual('aaa_bbb_shared_embedding', - embedding_column_b.var_scope_name) + embedding_column_b.shared_embedding_collection_name) self.assertIsNone(embedding_column_a.tensor_name_in_ckpt) self.assertIsNone(embedding_column_b.tensor_name_in_ckpt) self.assertIsNone(embedding_column_a.max_norm) @@ -5375,9 +5399,9 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertEqual('my_combiner', embedding_column_a.combiner) self.assertEqual('my_combiner', embedding_column_b.combiner) self.assertEqual('shared_embedding_collection_name', - embedding_column_a.var_scope_name) + embedding_column_a.shared_embedding_collection_name) self.assertEqual('shared_embedding_collection_name', - embedding_column_b.var_scope_name) + embedding_column_b.shared_embedding_collection_name) self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) @@ -5428,7 +5452,7 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertEqual(embedding_dimension, embedding_column_a.dimension) self.assertEqual('my_combiner', embedding_column_a.combiner) self.assertEqual('shared_embedding_collection_name', - embedding_column_a.var_scope_name) + embedding_column_a.shared_embedding_collection_name) self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) self.assertEqual(42., embedding_column_a.max_norm) @@ -5612,6 +5636,72 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval()) self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval()) + def test_get_dense_tensor_weight_collections(self): + # Inputs. + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array([ + [2, -1, -1], # example 0, ids [2] + [0, 1, -1] + ]) # example 1, ids [0, 1] + input_b = np.array([ + [0, -1, -1], # example 0, ids [0] + [-1, -1, -1] + ]) # example 1, ids [] + input_features = {'aaa': input_a, 'bbb': input_b} + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups_a = ( + # example 0: + (7., 11.), # ids [2], embedding = [7, 11] + # example 1: + (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + ) + expected_lookups_b = ( + # example 0: + (1., 2.), # ids [0], embedding = [1, 2] + # example 1: + (0., 0.), # ids [], embedding = [0, 0] + ) + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + fc.input_layer( + input_features, [embedding_column_a, embedding_column_b], + weight_collections=('my_vars',)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple(v.name for v in global_vars)) + my_vars = ops.get_collection('my_vars') + self.assertItemsEqual( + ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple(v.name for v in my_vars)) + def test_get_dense_tensor_placeholder_inputs(self): # Inputs. vocabulary_size = 3 diff --git a/tensorflow/python/framework/c_api_util.py b/tensorflow/python/framework/c_api_util.py index aff289f7be08e2dccde02c67202b844b2ebf15ab..f68f30a0a966b9e6031f73bd59634b25fbcb6689 100644 --- a/tensorflow/python/framework/c_api_util.py +++ b/tensorflow/python/framework/c_api_util.py @@ -134,6 +134,9 @@ class ApiDefMap(object): return self._op_per_name[op_name] raise ValueError("No entry found for " + op_name + ".") + def op_names(self): + return self._op_per_name.keys() + @tf_contextlib.contextmanager def tf_buffer(data=None): diff --git a/tensorflow/python/framework/c_api_util_test.py b/tensorflow/python/framework/c_api_util_test.py index e0bc9ee531669e0824319b13ea340f3966b63838..169abf1bb46d2a6d37d79346a6bb5503b347373f 100644 --- a/tensorflow/python/framework/c_api_util_test.py +++ b/tensorflow/python/framework/c_api_util_test.py @@ -25,6 +25,10 @@ from tensorflow.python.platform import googletest class ApiDefMapTest(test_util.TensorFlowTestCase): + def testApiDefMapOpNames(self): + api_def_map = c_api_util.ApiDefMap() + self.assertIn("Add", api_def_map.op_names()) + def testApiDefMapGet(self): api_def_map = c_api_util.ApiDefMap() op_def = api_def_map.get_op_def("Add") diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 7f9ef53457ae060600067b946e686487f55adda1..c3f70df7d8056ee33d2e0d875b19e07f18220545 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -120,11 +120,7 @@ class DType(object): @property def is_numpy_compatible(self): - numpy_incompatible = [ - types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE, - types_pb2.DT_RESOURCE_REF - ] - return self._type_enum not in numpy_incompatible + return self._type_enum not in _NUMPY_INCOMPATIBLE @property def as_numpy_dtype(self): @@ -162,7 +158,7 @@ class DType(object): @property def is_quantized(self): """Returns whether this is a quantized data type.""" - return self.base_dtype in [qint8, quint8, qint16, quint16, qint32] + return self.base_dtype in _QUANTIZED_DTYPES_NO_REF @property def is_unsigned(self): @@ -401,6 +397,11 @@ quint16_ref = DType(types_pb2.DT_QUINT16_REF) qint32_ref = DType(types_pb2.DT_QINT32_REF) bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF) +_NUMPY_INCOMPATIBLE = frozenset([ + types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE, + types_pb2.DT_RESOURCE_REF +]) + # Maintain an intern table so that we don't have to create a large # number of small objects. _INTERN_TABLE = { @@ -645,10 +646,10 @@ _TF_TO_NP = { _np_bfloat16, } -QUANTIZED_DTYPES = frozenset([ - qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref, - quint16_ref, qint32_ref -]) +_QUANTIZED_DTYPES_NO_REF = frozenset([qint8, quint8, qint16, quint16, qint32]) +_QUANTIZED_DTYPES_REF = frozenset( + [qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref]) +QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF.union(_QUANTIZED_DTYPES_NO_REF) tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES") _PYTHON_TO_TF = { @@ -662,10 +663,9 @@ def as_dtype(type_value): """Converts the given `type_value` to a `DType`. Args: - type_value: A value that can be converted to a `tf.DType` - object. This may currently be a `tf.DType` object, a - [`DataType` - enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto), + type_value: A value that can be converted to a `tf.DType` object. This may + currently be a `tf.DType` object, a [`DataType` + enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto), a string type name, or a `numpy.dtype`. Returns: diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..9ccae761471e24ddb1d4d6acd89ebcc9650d1320 --- /dev/null +++ b/tensorflow/python/framework/error_interpolation.py @@ -0,0 +1,92 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Function for interpolating formatted errors from the TensorFlow runtime. + +Exposes the function `interpolate` to interpolate messages with tags of the form +^^type:name:format^^. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import itertools +import re +import string + +import six + +_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?" +_FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+" +_TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format( + name=_NAME_REGEX, fmt=_FORMAT_REGEX) +_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX) +_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX) + +_ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"]) + + +def _parse_message(message): + """Parses the message. + + Splits the message into separators and tags. Tags are named tuples + representing the string ^^type:name:format^^ and they are separated by + separators. For example, in + "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789", there are two tags and + three separators. The separators are the numeric characters. + + Args: + message: String to parse + + Returns: + (list of separator strings, list of _ParseTags). + + For example, if message is "123^^node:Foo:${file}^^456" then this function + returns (["123", "456"], [_ParseTag("node", "Foo", "${file}")]) + """ + seps = [] + tags = [] + pos = 0 + while pos < len(message): + match = re.match(_INTERPOLATION_PATTERN, message[pos:]) + if match: + seps.append(match.group(1)) + tags.append(_ParseTag(match.group(3), match.group(4), match.group(5))) + pos += match.end() + else: + break + seps.append(message[pos:]) + return seps, tags + + +# TODO(jtkeeling): Modify to actually interpolate format strings rather than +# echoing them. +def interpolate(error_message): + """Interpolates an error message. + + The error message can contain tags of the form ^^type:name:format^^ which will + be replaced. + + Args: + error_message: A string to interpolate. + + Returns: + The string with tags of the form ^^type:name:format^^ interpolated. + """ + seps, tags = _parse_message(error_message) + subs = [string.Template(tag.format).safe_substitute({}) for tag in tags] + return "".join( + itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue=""))) diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ad448deb622cb6a3d24e502d7238d3f614d5af4d --- /dev/null +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -0,0 +1,49 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.errors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import error_interpolation +from tensorflow.python.platform import test + + +class InterpolateTest(test.TestCase): + + def testNothingToDo(self): + normal_string = "This is just a normal string" + interpolated_string = error_interpolation.interpolate(normal_string) + self.assertEqual(interpolated_string, normal_string) + + def testOneTag(self): + one_tag_string = "^^node:Foo:${file}^^" + interpolated_string = error_interpolation.interpolate(one_tag_string) + self.assertEqual(interpolated_string, "${file}") + + def testTwoTagsNoSeps(self): + two_tags_no_seps = "^^node:Foo:${file}^^^^node:Bar:${line}^^" + interpolated_string = error_interpolation.interpolate(two_tags_no_seps) + self.assertEqual(interpolated_string, "${file}${line}") + + def testTwoTagsWithSeps(self): + two_tags_with_seps = "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789" + interpolated_string = error_interpolation.interpolate(two_tags_with_seps) + self.assertEqual(interpolated_string, "123${file}456${line}789") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 94c37d65c3f3d8e562f4ffb4c098b283eb640d78..6525607faea62a461ee38fa0393ac29b809bb9b6 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -23,6 +23,7 @@ from __future__ import print_function import collections import hashlib +import sys from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 @@ -33,12 +34,17 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import compat +from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +# This is to avoid a circular dependency with cond_v2_impl. +cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access + class Defun(object): """Decorator used to define TensorFlow functions. @@ -68,9 +74,10 @@ class Defun(object): during the first call to the function. Subsequent function calls will refer to the same set of variables. - Definitions of functions are frozen in a graph as soon as the graph is used to - create a session. Therefore, nodes using the function must be created in the - graph before the corresponding session is created. + Definitions of functions in a graph are frozen as soon as the graph is used to + create a session. However, new functions and new calls to existing functions + may be added to the graph, with the new functions themselves becoming + immediately frozen. Example, but also see the [How To on functions](link_needed). @@ -258,12 +265,10 @@ class _DefinedFunction(object): # another reference to _definition.signature self._op_def = None - self._args = [] assert isinstance(input_types, (list, tuple)) - for i in range(len(input_types)): - argname = argnames[i] if i < len(argnames) else ("arg%d" % i) - argtype = input_types[i] - self._args.append((argname, argtype)) + self._arg_types = input_types + self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i) + for i in range(len(input_types))] @property def name(self): @@ -336,42 +341,11 @@ class _DefinedFunction(object): if self._definition is not None or self._c_func is not None: return - # Create the func_def object. - temp_graph = _FuncGraph(capture_by_value=self._capture_by_value) - with temp_graph.as_default(), ops.device(self._caller_device): - # List of placeholders for the function_def. - inputs = [] - for (argname, argtype) in self._args: - argholder = array_ops.placeholder(argtype, name=argname) - inputs.append(argholder) - # Call func and gather the output tensors. - with vs.variable_scope("", custom_getter=temp_graph.getvar): - outputs = self._func(*inputs) - - # There is no way of distinguishing between a function not returning - # anything and a function returning None in Python. - # We need to allow the former and ideally want to forbid the latter as - # it is most likely user error. - # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to - # allow users to explicitly mark the function as not returning anything. - # For now, we allow a single None return and interpret it as a function - # with no output. - if outputs is None: - outputs = [] - else: - # If func only returned one value, make it a tuple. - if not isinstance(outputs, (list, tuple)): - outputs = (outputs,) - if any([_ is None for _ in outputs]): - raise ValueError("Function can not return None.") - # Ensures each output is a Tensor in the function graph. - outputs = [ops.convert_to_tensor(t) for t in outputs] - outputs = [ - temp_graph.capture(t) if t.graph is not temp_graph else t - for t in outputs - ] + temp_graph = func_graph_from_py_func( + self._func, self._arg_names, self._arg_types, self._func_name, + self._capture_by_value, self._caller_device) + self._extra_inputs = temp_graph.extra_inputs - inputs.extend(temp_graph.extra_args) # pylint: disable=protected-access self._sub_functions = temp_graph._functions # pylint: enable=protected-access @@ -390,8 +364,8 @@ class _DefinedFunction(object): self._definition = graph_to_function_def.graph_to_function_def( temp_graph, temp_graph.get_operations(), - inputs, - outputs, + temp_graph.inputs, + temp_graph.outputs, out_names=self._out_names) for k in kwargs_attr: @@ -421,8 +395,8 @@ class _DefinedFunction(object): base_func_name, self._func_name is None, # append_hash_to_fn_name None, # opers - [t._as_tf_output() for t in inputs], - [t._as_tf_output() for t in outputs], + [t._as_tf_output() for t in temp_graph.inputs], + [t._as_tf_output() for t in temp_graph.outputs], output_names, None, # opts description) @@ -653,18 +627,70 @@ class _FuncGraph(ops.Graph): function argument and the caller passes in the captured tensor. """ - def __init__(self, capture_by_value, *args, **kwargs): + def __init__(self, name, capture_by_value, *args, **kwargs): super(_FuncGraph, self).__init__(*args, **kwargs) self._capture_by_value = capture_by_value self._building_function = True self._outer_graph = ops.get_default_graph() self._vscope = vs.get_variable_scope() self._old_custom_getter = self._vscope.custom_getter + + # The name of the function. + self.name = name + # Placeholder tensors representing the inputs to this function. The tensors + # are in this _FuncGraph. + self.inputs = [] + # Tensors that will be returned this function. The tensors are in this + # _FuncGraph. + self.outputs = [] + # Maps external tensor -> internal tensor (e.g. input placeholder). self._captured = {} + # The external tensors that have been captured as inputs and must be passed + # to this function (empty if capturing by value, otherwise these are the + # keys of _captured). self.extra_inputs = [] + # Input placeholders that been added for captured values (empty if capturing + # by value). self.extra_args = [] + # Captured variables. + # TODO(skyewm): is this needed? self.extra_vars = [] + # pylint: disable=g-doc-return-or-yield + + @tf_contextlib.contextmanager + def container(self, container_name): + """Returns a context manager that specifies the resource container to use. + + Overridden from @{tf.Graph} to update both the init_scope container + and the present inner container. This is necessary to make sure setting + containers applies correctly both to created variables and to stateful + ops. + + Args: + container_name: container name string. + + Returns: + A context manager for defining resource containers for stateful ops, + yields the container name. + """ + original_container = self._container + # pylint: disable=protected-access + with ops.init_scope(): + original_init_container = ops.get_default_graph()._container + try: + self._container = container_name + with ops.init_scope(): + ops.get_default_graph()._container = container_name + yield self._container + finally: + self._container = original_container + with ops.init_scope(): + ops.get_default_graph()._container = original_init_container + # pylint: enable=protected-access + + # pylint: enable=g-doc-return-or-yield + def getvar( self, getter, @@ -733,8 +759,14 @@ class _FuncGraph(ops.Graph): tensor.dtype, shape=tensor.get_shape(), name=name) # pylint: disable=protected-access if ops._USE_C_SHAPES: - handle_data = c_api.GetResourceHandleShapeAndType(tensor.graph._c_graph, - tensor._as_tf_output()) + if isinstance(tensor, ops.EagerTensor): + handle_data = tensor._handle_data + if handle_data: + handle_data = handle_data.SerializeToString() + else: + handle_data = c_api.GetResourceHandleShapeAndType( + tensor.graph._c_graph, tensor._as_tf_output()) + if handle_data: c_api.SetResourceHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(), @@ -742,6 +774,7 @@ class _FuncGraph(ops.Graph): else: ph._handle_data = tensor._handle_data # pylint: enable=protected-access + self.inputs.append(ph) self._captured[tensor] = ph self.extra_args.append(ph) if _is_guaranteed_const(tensor): @@ -780,6 +813,79 @@ class _FuncGraph(ops.Graph): return captured_op +def func_graph_from_py_func(func, arg_names, arg_types, name=None, + capture_by_value=False, device=None, + colocation_stack=None, container=None, + collections_ref=None): + """Returns a _FuncGraph generated from `func`. + + Args: + func: A Python callable which constructs a TF function body. The arguments + must correspond to `arg_types`. Returns a value or list/tuple of values. + No returned value can be None. + arg_names: A sequence of strings for the function argument names. + arg_types: A sequence of the function's argument types. + name: The function name. If None, the name is derived from `func`. + capture_by_value: boolean. If True, captured values will be copied into the + function body. + device: device name or function. + colocation_stack: A colocation stack (list) the _FuncGraph should use. + container: A container name the _FuncGraph should start with. + collections_ref: A reference to a collections dict the _FuncGraph should + use internally. + + Returns: + A _FuncGraph. + + Raises: + ValueError: if func returns None. + """ + if not name: + name = _get_func_name(func) + func_graph = _FuncGraph(name, capture_by_value) + + with func_graph.as_default(), ops.device(device): + # pylint: disable=protected-access + if collections_ref is not None: + func_graph._collections = collections_ref + if container is not None: + func_graph._container = container + if colocation_stack is not None: + func_graph._colocation_stack = colocation_stack + # pylint: enable=protected-access + + # Create placeholders for the function arguments. + for (argname, argtype) in zip(arg_names, arg_types): + argholder = array_ops.placeholder(argtype, name=argname) + func_graph.inputs.append(argholder) + # Call func and gather the output tensors. + with vs.variable_scope("", custom_getter=func_graph.getvar): + outputs = func(*func_graph.inputs) + + # There is no way of distinguishing between a function not returning + # anything and a function returning None in Python. + # We need to allow the former and ideally want to forbid the latter as + # it is most likely user error. + # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to + # allow users to explicitly mark the function as not returning anything. + # For now, we allow a single None return and interpret it as a function + # with no output. + if outputs is None: + outputs = [] + else: + # If func only returned one value, make it a tuple. + if not isinstance(outputs, (list, tuple)): + outputs = (outputs,) + if any([_ is None for _ in outputs]): + raise ValueError("Function can not return None.") + # Ensures each output is a Tensor in the function graph. + outputs = [ops.convert_to_tensor(t) for t in outputs] + outputs = [func_graph.capture(t) if t.graph is not func_graph else t + for t in outputs] + func_graph.outputs = outputs + return func_graph + + def _is_guaranteed_const(tensor): """Determines whether `tensor` is guaranteed to be a constant. @@ -1123,3 +1229,13 @@ _DTYPE_TO_STR = { dtypes.qint32: "qi32", dtypes.bfloat16: "b16" } + + +def function_def_from_tf_function(c_func): + """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto.""" + with c_api_util.tf_buffer() as buf: + c_api.TF_FunctionToFunctionDef(c_func, buf) + data = c_api.TF_GetBuffer(buf) + fdef = function_pb2.FunctionDef() + fdef.ParseFromString(compat.as_bytes(data)) + return fdef diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..46c9c4c14adc7d4adeb11b45210cb296acb55086 --- /dev/null +++ b/tensorflow/python/framework/function_def_to_graph.py @@ -0,0 +1,195 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Utlity to convert FunctionDef to GraphDef and Graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.core.framework import versions_pb2 +from tensorflow.python.framework import function +from tensorflow.python.framework import importer +from tensorflow.python.framework import op_def_registry +from tensorflow.python.framework import versions +from tensorflow.python.ops import cond_v2_impl + +# This is to avoid a circular dependency with cond_v2_impl. +cond_v2_impl._function_def_to_graph = sys.modules[__name__] # pylint: disable=protected-access + + +def function_def_to_graph(fdef, input_shapes=None): + """Converts a FunctionDef to a function._FuncGraph (sub-class Graph). + + The returned _FuncGraph's `name`, `inputs` and `outputs` fields will be set. + The input tensors are represented as placeholders. + + Note: `_FuncGraph.inputs` and `_FuncGraph._captured` are not set and may be + set by the caller. + + Args: + fdef: FunctionDef. + input_shapes: Optional. A list of TensorShape objects of the shapes of + function inputs. If specified, its length must match length of + `fdef.signature.input_arg`. If a shape is None, the corresponding input + placeholder will have unknown shape. + + Returns: + A _FuncGraph. + """ + func_graph = function._FuncGraph(fdef.signature.name, capture_by_value=False) # pylint: disable=protected-access + graph_def, nested_to_flat_tensor_name = function_def_to_graph_def( + fdef, input_shapes) + + with func_graph.as_default(): + # Add all function nodes to the graph. + importer.import_graph_def(graph_def, name="") + + # Initialize fields specific to _FuncGraph. + + # inputs + input_tensor_names = [ + nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg + ] + func_graph.inputs = [ + func_graph.get_tensor_by_name(name) for name in input_tensor_names + ] + + # outputs + output_tensor_names = [ + nested_to_flat_tensor_name[fdef.ret[arg.name]] + for arg in fdef.signature.output_arg + ] + func_graph.outputs = [ + func_graph.get_tensor_by_name(name) for name in output_tensor_names + ] + + return func_graph + + +def function_def_to_graph_def(fdef, input_shapes=None): + """Convert a FunctionDef to a GraphDef. + + Steps: + 1. Creates placeholder nodes corresponding to inputs in + `FunctionDef.signature.input_arg`. + 2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`. + 3. Renames inputs of all nodes to use the convention of GraphDef instead of + FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming + in FunctionDefs is different from GraphDefs. + + Args: + fdef: FunctionDef. + input_shapes: Optional. A list of TensorShape objects of the shapes of + function inputs. If specified, its length must match length of + `fdef.signature.input_arg`. If a shape is None, the corresponding input + placeholder will have unknown shape. + + Returns: + A tuple of (GraphDef, dict). The dict contains a mapping + from nested tensor names (in FunctionDef) to flattened names (in GraphDef). + + Raises: + ValueError: If the length of input_shapes does not match the number of + input_args or if the FunctionDef is invalid. + """ + graph_def = graph_pb2.GraphDef() + graph_def.versions.CopyFrom( + versions_pb2.VersionDef( + producer=versions.GRAPH_DEF_VERSION, + min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)) + + if input_shapes and len(input_shapes) != len(fdef.signature.input_arg): + raise ValueError("Length of input_shapes must match the number of " + + "input_args. len(input_shapes): {} len(input_arg): {}". + format(len(input_shapes), len(fdef.signature.input_arg))) + + # 1. Create placeholders for input nodes. + for i, arg_def in enumerate(fdef.signature.input_arg): + node_def = graph_def.node.add() + node_def.name = arg_def.name + node_def.op = "Placeholder" + node_def.attr["dtype"].type = arg_def.type + if input_shapes and input_shapes[i] is not None: + node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto()) + + # 2. Copy all body NodeDefs to the GraphDef. + graph_def.node.extend(fdef.node_def) + + # 3. Perform the renaming. + + # Build the tensor name mapping then flatten the tensor names. + # See comment on `FunctionDef.node_def` on how the tensor naming in + # FunctionDefs is different from GraphDefs. + nested_to_flat_tensor_name = {} + + for arg_def in fdef.signature.input_arg: + nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name) + + for node_def in fdef.node_def: + op_def = op_def_registry.get_registered_ops().get(node_def.op) + if not op_def: + # TODO(b/80470245): Support functions which refer other functions. + raise NotImplementedError( + "No op registered for {},".format(node_def.op) + + " it may be a function. function_def_to_graph_def " + + "currently does not support converting functions with " + + "references to other graph functions.") + + for attr in op_def.attr: + if attr.type in ("func", "list(func)"): + # TODO(b/80470245): Support functions which refer other functions. + raise NotImplementedError("Unsupported attr {} ".format(attr.name) + + " with type {}".format(attr.type) + + " in op {}. ".format(op_def.name) + + "function_def_to_graph_def currently does " + + "not support converting functions with " + + "references to other graph functions.") + + # Iterate over output_args in op_def to build the map. + # Index of the output tensor in the flattened list of *all* output + # tensors of the op. + flattened_index = 0 + for arg_def in op_def.output_arg: + num_args = _get_num_args(arg_def, node_def) + for i in range(num_args): + # Map tensor names from "node_name:output_arg_name:index" to + # "node_name:flattened_index". + nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i) + flat_name = "{}:{}".format(node_def.name, flattened_index) + nested_to_flat_tensor_name[nested_name] = flat_name + flattened_index += 1 + + # Update inputs of all nodes in graph. + for node_def in graph_def.node: + for i in range(len(node_def.input)): + node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]] + + return graph_def, nested_to_flat_tensor_name + + +# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange. +def _get_num_args(arg_def, node_def): + if arg_def.number_attr: + return node_def.attr[arg_def.number_attr].i + elif arg_def.type_list_attr: + return len(node_def.attr[arg_def.type_list_attr].list.type) + elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: + return 1 + else: + raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0f4e6ef54fb02cc6ba52c9de2ccabea982fd2323 --- /dev/null +++ b/tensorflow/python/framework/function_def_to_graph_test.py @@ -0,0 +1,184 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.function_def_to_graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.framework import graph_to_function_def +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class FunctionDefToGraphTest(test.TestCase): + + def _build_function_def(self): + with ops.Graph().as_default() as g: + # Inputs + x = array_ops.placeholder(dtypes.float32, name="x") + y = array_ops.placeholder(dtypes.float32, name="y") + + # Outputs + sum_squares = math_ops.add_n( + [math_ops.pow(x, 2), math_ops.pow(y, 2)], name="sum_squares") + sum_cubes = math_ops.add_n( + [math_ops.pow(x, 3), math_ops.pow(y, 3)], name="sum_cubes") + fdef = graph_to_function_def.graph_to_function_def( + g, + g.get_operations(), + [x, y], # Inputs + [sum_squares, sum_cubes]) # Outputs. + fdef.signature.name = "_whats_in_a_name" + return fdef + + def testInputsAndOutputs(self): + fdef = self._build_function_def() + g = function_def_to_graph.function_def_to_graph(fdef) + self.assertEqual(g.name, "_whats_in_a_name") + with self.test_session(graph=g) as sess: + inputs = sess.run(g.inputs, feed_dict={"x:0": 2, "y:0": 3}) + self.assertSequenceEqual(inputs, [2.0, 3.0]) + outputs = sess.run(g.outputs, feed_dict={"x:0": 2, "y:0": 3}) + self.assertSequenceEqual(outputs, [13.0, 35.0]) + + def testShapes(self): + fdef = self._build_function_def() + + g = function_def_to_graph.function_def_to_graph(fdef) + self.assertIsNone(g.inputs[0].shape.dims) # Unknown dims. + self.assertIsNone(g.inputs[1].shape.dims) # Unknown dims. + self.assertIsNone(g.outputs[0].shape.dims) # Unknown dims. + self.assertIsNone(g.outputs[1].shape.dims) # Unknown dims. + + g = function_def_to_graph.function_def_to_graph( + fdef, input_shapes=[tensor_shape.vector(5), + tensor_shape.vector(5)]) + self.assertSequenceEqual(g.inputs[0].shape.dims, [5]) + self.assertSequenceEqual(g.inputs[1].shape.dims, [5]) + self.assertSequenceEqual(g.outputs[0].shape.dims, [5]) + self.assertSequenceEqual(g.outputs[1].shape.dims, [5]) + + g = function_def_to_graph.function_def_to_graph( + fdef, input_shapes=[None, tensor_shape.matrix(5, 7)]) + print(g.as_graph_def()) + self.assertIsNone(g.inputs[0].shape.dims) + self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7]) + self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7]) + self.assertSequenceEqual(g.outputs[1].shape.dims, [5, 7]) + + # Should raise a ValueError if the length of input_shapes does not match + # the number of input args in FunctionDef.signature.input_arg. + with self.assertRaises(ValueError): + g = function_def_to_graph.function_def_to_graph( + fdef, input_shapes=[tensor_shape.matrix(5, 7)]) + + +class FunctionDefToGraphDefTest(test.TestCase): + + def _build_function_def(self): + with ops.Graph().as_default() as g: + # Inputs: x y z + # |\ | / + # | \ | / + # | foo_1 list_output + # | / \ / \ + # | d_1 e_1 a:1 a:0 + # | \ | / | + # | \ | / | + # | foo_2 | + # | / \ | + # Outputs: x d_2 e_2 a:0 + + x = array_ops.placeholder(dtypes.float32, name="x") + y = array_ops.placeholder(dtypes.int32, name="y") + z = array_ops.placeholder(dtypes.int32, name="z") + + d_1, e_1 = test_ops._op_def_lib.apply_op( + "Foo1", name="foo_1", a=x, b=y, c=z) + + list_output0, list_output1 = test_ops.list_output( + T=[dtypes.int32, dtypes.int32], name="list_output") + + d_2, e_2 = test_ops.foo1(a=d_1, b=e_1, c=list_output1, name="foo_2") + + fdef = graph_to_function_def.graph_to_function_def( + g, + g.get_operations(), + [x, y, z], # Inputs + [x, d_2, e_2, list_output0]) # Outputs. + + # Assert that the FunctionDef was correctly built. + assert len(fdef.node_def) == 3 # 2 Foo1 nodes and 1 ListOutput node. + assert fdef.node_def[0].op == "Foo1" + assert fdef.node_def[0].input == ["x", "y", "z"] + assert fdef.node_def[1].op == "ListOutput" + assert not fdef.node_def[1].input + assert fdef.node_def[2].op == "Foo1" + assert fdef.node_def[2].input == [ + "foo_1:d:0", "foo_1:e:0", "list_output:a:1" + ] + return fdef + + def testTensorNames(self): + fdef = self._build_function_def() + g, tensor_name_map = function_def_to_graph.function_def_to_graph_def(fdef) + + # Verify that inputs of body nodes are correctly renamed. + # foo_1 + self.assertSequenceEqual(g.node[3].input, ["x:0", "y:0", "z:0"]) + # foo_2 + self.assertSequenceEqual(g.node[5].input, + ["foo_1:0", "foo_1:1", "list_output:1"]) + + # Verify that the `tensor_name_map` has the correct mapping. + self.assertDictEqual( + tensor_name_map, { + "x": "x:0", + "y": "y:0", + "z": "z:0", + "foo_1:d:0": "foo_1:0", + "foo_1:e:0": "foo_1:1", + "list_output:a:0": "list_output:0", + "list_output:a:1": "list_output:1", + "foo_2:d:0": "foo_2:0", + "foo_2:e:0": "foo_2:1", + }) + + def testShapes(self): + fdef = self._build_function_def() + g, _ = function_def_to_graph.function_def_to_graph_def( + fdef, + input_shapes=[tensor_shape.scalar(), + tensor_shape.vector(5), None]) + self.assertEqual("shape" in g.node[0].attr, True) + self.assertSequenceEqual( + tensor_shape.TensorShape(g.node[0].attr["shape"].shape).as_list(), []) + self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False) + self.assertEqual("shape" in g.node[1].attr, True) + self.assertSequenceEqual( + tensor_shape.TensorShape(g.node[1].attr["shape"].shape).as_list(), [5]) + self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False) + self.assertFalse("shape" in g.node[2].attr) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 88f6a3667617515113b29eed4ed7731b9f19e4f4..15e41ba91f9ae121d3d4ea48e3e71eace7cd9a3e 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util +from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops @@ -1764,6 +1765,44 @@ class DevicePlacementTest(test.TestCase): for node in divide_fdef[0].node_def: self.assertAllEqual(node.device, "/device:CPU:1") + def _testNestedDeviceWithSameFunction(self, func_name): + + def MatmulWrap(a, b): + + @function.Defun( + func_name=func_name, *[dtypes.int32] * 2) + def Matmul(a, b): + return math_ops.matmul(a, b) + + return Matmul(a, b) + + with ops.Graph().as_default(), ops.device("CPU:0"): + c = MatmulWrap(1, 2) + + with ops.device("CPU:1"): + MatmulWrap(c, 3) + + gdef = ops.get_default_graph().as_graph_def() + + devices = [] + for node in gdef.library.function[0].node_def: + devices.append(node.device) + for node in gdef.library.function[1].node_def: + devices.append(node.device) + + self.assertAllEqual(sorted(devices), ["/device:CPU:0", "/device:CPU:1"]) + + def testFunctionWithName(self): + with self.assertRaises(InvalidArgumentError) as cm: + self._testNestedDeviceWithSameFunction("MatmulTest") + self.assertEqual( + cm.exception.message, + "Cannot add function \'MatmulTest\' because a different " + "function with the same name already exists.") + + def testFunctionWithoutName(self): + self._testNestedDeviceWithSameFunction(None) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 5112bea48b5033e2cd16a555d65993b575f475eb..699d2b70d176db7718a6e480f9f7b08a65ae6a8e 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -17,78 +17,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import contextlib -import copy -from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import graph_pb2 -from tensorflow.core.framework import types_pb2 from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.tf_export import tf_export -# TODO(josh11b): SWIG the code from node_def_util instead of duplicating -# the logic here. -def _GetNodeAttr(node_def, attr_name): - if attr_name not in node_def.attr: - raise ValueError('Expected one attr with name %r in %s.' % (attr_name, - str(node_def))) - return node_def.attr[attr_name] - - -def _ArgToTypesNoRef(node_def, arg_def): - if arg_def.number_attr: - repeats = _GetNodeAttr(node_def, arg_def.number_attr).i - if arg_def.type_attr: - dtype = _GetNodeAttr(node_def, arg_def.type_attr).type - else: - assert arg_def.type != types_pb2.DT_INVALID - dtype = arg_def.type - return [dtype] * repeats - elif arg_def.type_attr: - return [_GetNodeAttr(node_def, arg_def.type_attr).type] - elif arg_def.type_list_attr: - return _GetNodeAttr(node_def, arg_def.type_list_attr).list.type - else: - assert arg_def.type != types_pb2.DT_INVALID - return [arg_def.type] - - -def _SingleArgToTypes(node_def, arg_def): - types = _ArgToTypesNoRef(node_def, arg_def) - if arg_def.is_ref: - return [dtypes.as_dtype(dt)._as_ref.as_datatype_enum for dt in types] # pylint: disable=protected-access - return types - - -def _ArgsToTypes(node_def, arg_list): - types = [] - for arg_def in arg_list: - types.extend(_SingleArgToTypes(node_def, arg_def)) - return types - - -def _InputTypes(node_def, op_dict): - op_def = op_dict[node_def.op] - return _ArgsToTypes(node_def, op_def.input_arg) - - -def _OutputTypes(node_def, op_dict): - op_def = op_dict[node_def.op] - return _ArgsToTypes(node_def, op_def.output_arg) - - def _IsControlInput(input_name): # Expected format: '^operation_name' (control input). return input_name.startswith('^') @@ -128,18 +71,6 @@ def _ParseTensorName(tensor_name): raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,)) -def _CanonicalInputName(input_name): - input_name = compat.as_str(input_name) - if _IsControlInput(input_name): - return input_name - input_op_name, output_index = _ParseTensorName(input_name) - return '%s:%d' % (input_op_name, output_index) - - -def _InvalidNodeMessage(node, message): - return 'graph_def is invalid at node %r: %s.' % (node.name, message) - - @contextlib.contextmanager def _MaybeDevice(device): """Applies the given device only if device is not None or empty.""" @@ -460,351 +391,70 @@ def import_graph_def(graph_def, _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def) graph = ops.get_default_graph() - - if graph._c_graph: # pylint: disable=protected-access - with ops.name_scope(name, 'import', input_map.values()) as scope: - # Save unique prefix generated by name_scope - if scope: - assert scope.endswith('/') - prefix = scope[:-1] - else: - prefix = '' - - # Generate any input map tensors inside name scope - input_map = _ConvertInputMapValues(name, input_map) - - scoped_options = c_api_util.ScopedTFImportGraphDefOptions() - options = scoped_options.options - _PopulateTFImportGraphDefOptions(options, prefix, input_map, - return_elements) - - # _ProcessNewOps mutates the new operations. _lock ensures a Session.run - # call cannot occur between creating the TF_Operations in the - # TF_GraphImportGraphDefWithResults call and mutating the them in - # _ProcessNewOps. - with graph._lock: # pylint: disable=protected-access - with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: - try: - results = c_api.TF_GraphImportGraphDefWithResults( - graph._c_graph, serialized, options) # pylint: disable=protected-access - results = c_api_util.ScopedTFImportGraphDefResults(results) - except errors.InvalidArgumentError as e: - # Convert to ValueError for backwards compatibility. - raise ValueError(str(e)) - - # Create _DefinedFunctions for any imported functions. - # - # We do this by creating _DefinedFunctions directly from `graph_def`, and - # adding them to `graph`. Adding an existing function to a TF_Graph is a - # no-op, so this only has the effect of updating the Python state (usually - # _DefinedFunction.add_to_graph also adds the function to the TF_Graph). - # - # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph - # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph - # TODO(b/74620627): move this after _ProcessNewOps outside the lock once - # _USE_C_SHAPES is removed. - if graph_def.library and graph_def.library.function: - # pylint: disable=protected-access - functions = function._from_library(graph_def.library) - for f in functions: - f.add_to_graph(graph) - # pylint: enable=protected-access - - _ProcessNewOps(graph) - - # Treat input mappings that don't appear in the graph as an error, because - # they are likely to be due to a typo. - missing_unused_input_keys = ( - c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( - results.results)) - if missing_unused_input_keys: - missing_unused_input_keys = [ - compat.as_str(s) for s in missing_unused_input_keys - ] - raise ValueError( - 'Attempted to map inputs that were not found in graph_def: [%s]' % - ', '.join(missing_unused_input_keys)) - - if return_elements is None: - return None + with ops.name_scope(name, 'import', input_map.values()) as scope: + # Save unique prefix generated by name_scope + if scope: + assert scope.endswith('/') + prefix = scope[:-1] else: - return _GatherReturnElements(return_elements, graph, results.results) - - else: - g = graph - - # Use a canonical representation for all tensor names. - input_map = {_CanonicalInputName(k): v for k, v in input_map.items()} - used_input_keys = set() - name_to_op = {} - - # Add any functions defined in `graph_def` to `g` + prefix = '' + + # Generate any input map tensors inside name scope + input_map = _ConvertInputMapValues(name, input_map) + + scoped_options = c_api_util.ScopedTFImportGraphDefOptions() + options = scoped_options.options + _PopulateTFImportGraphDefOptions(options, prefix, input_map, + return_elements) + + # _ProcessNewOps mutates the new operations. _mutation_lock ensures a + # Session.run call cannot occur between creating the TF_Operations in the + # TF_GraphImportGraphDefWithResults call and mutating the them in + # _ProcessNewOps. + with graph._mutation_lock(): # pylint: disable=protected-access + with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: + try: + results = c_api.TF_GraphImportGraphDefWithResults( + graph._c_graph, serialized, options) # pylint: disable=protected-access + results = c_api_util.ScopedTFImportGraphDefResults(results) + except errors.InvalidArgumentError as e: + # Convert to ValueError for backwards compatibility. + raise ValueError(str(e)) + + # Create _DefinedFunctions for any imported functions. + # + # We do this by creating _DefinedFunctions directly from `graph_def`, and + # adding them to `graph`. Adding an existing function to a TF_Graph is a + # no-op, so this only has the effect of updating the Python state (usually + # _DefinedFunction.add_to_graph also adds the function to the TF_Graph). + # + # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph + # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph + # TODO(b/74620627): move this after _ProcessNewOps outside the lock once + # _USE_C_SHAPES is removed. if graph_def.library and graph_def.library.function: - # Copy op_dict so we don't clobber the original - op_dict = copy.copy(op_dict) # pylint: disable=protected-access - # Note that we do not prepend `name` to the function name. The reasoning - # is that function names are similar to op definition names, which - # currently do not have a scoped name or namespace scheme. functions = function._from_library(graph_def.library) for f in functions: - f.add_to_graph(g) - op_dict[f.name] = f.definition.signature + f.add_to_graph(graph) # pylint: enable=protected-access - # LINT.IfChange - with ops.name_scope(name, 'import', input_map.values()) as scope: - # TODO(ashankar): Should this just copy over or should it do some - # more nuanced merging? For example, the graph may already have some - # marked "bad versions" and we don't want to lose those because of - # what's in graph_def.versions? The C++ ImporGraphDef does something - # more nuanced. - g.graph_def_versions.CopyFrom(graph_def.versions) - - input_map = _ConvertInputMapValues(name, input_map) - - # NOTE(mrry): We do this in two passes, because there may be a cycle in - # `graph_def`. - - # 1. Add operations without their inputs. - for node in graph_def.node: - # Check to see if this op's name matches a previously seen op - if node.name in name_to_op: - raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name) - if node.op not in op_dict: - raise ValueError( - 'No op named %s in defined operations. If the Graph you are ' - 'importing uses custom ops or any parts of tf.contrib, you ' - 'should explicitly import the libraries defining those ops ' - 'before loading the Graph. Note that tf.contrib is lazily loaded ' - 'when accessed, so simply referencing (e.g.) ' - '`tf.contrib.resampler` will cause those ops to be made ' - 'available.' % node.op) - op_def = op_dict[node.op] - - output_types = _OutputTypes(node, op_dict) - name_to_op[node.name] = g.create_op( - node.op, [], output_types, name=node.name, attrs=node.attr, - compute_shapes=False, compute_device=False, - op_def=op_def) - - # Maps from a node to the ops it is colocated with, if colocation - # is specified in the attributes. - colocation_pairs = collections.defaultdict(list) - - # 2. Add inputs to the operations. - for node in graph_def.node: - op = name_to_op[node.name] - input_types = _InputTypes(node, op_dict) - apply_device_function = True - - # Rewrite the colocation attributes in the graph, since the - # names of new ops may have changed. - for key, value in op.node_def.attr.items(): - if key == '_class': - class_values = value.list - new_class_values = [] - for class_value in class_values.s: - if class_value.startswith(b'loc:@'): - op_to_bind_to = class_value[5:].decode() - # Find the op by its original name. - if op_to_bind_to not in name_to_op: - raise ValueError('Specified colocation to an op that ' - 'does not exist during import: %s in %s' % ( - op_to_bind_to, node.name)) - original_op = name_to_op[op_to_bind_to] - new_class_values.append(compat.as_bytes( - 'loc:@' + original_op.name)) - if op_to_bind_to != node.name: - # Keep track of this mapping for a later phase. - colocation_pairs[op].append(original_op) - # Don't apply this op's device function, - # the colocation constraint will ensure - # the proper device gets assigned at runtime. - apply_device_function = False - - else: - new_class_values.append(class_value) - value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue( - s=new_class_values)) - - # NOTE(mrry): We cannot use zip here because control inputs do not - # appear in the list of input_types. - for i, input_name in enumerate( - [_CanonicalInputName(x) for x in node.input]): - - if _IsControlInput(input_name): - # (a) Input is a control input that should be taken from an op - # in "graph_def". - try: - source_op = name_to_op[input_name[1:]] - except KeyError: - raise ValueError( - _InvalidNodeMessage( - node, - 'Control input %r not found in graph_def.' - % (input_name,))) - # pylint: disable=protected-access - op._add_control_input(source_op) - # pylint: enable=protected-access - - else: - try: - input_type = input_types[i] - except IndexError: - raise ValueError(_InvalidNodeMessage( - node, 'More inputs specified (%r) than the op expects.' - % (input_name,))) - - if input_name in input_map: - # (b) Input should be replaced by a tensor from the caller. - source_tensor = input_map[input_name] - used_input_keys.add(input_name) - - else: - # (c) Input should be taken from an op in `graph_def`. - operation_name, output_index = _ParseTensorName(input_name) - try: - source_op = name_to_op[operation_name] - source_tensor = list(source_op.values())[output_index] - except (KeyError, IndexError): - raise ValueError( - _InvalidNodeMessage( - node, - 'Input tensor %r not found in graph_def.' - % (input_name,))) - - try: - # pylint: disable=protected-access - op._add_input(source_tensor, dtype=input_type) - # pylint: enable=protected-access - except TypeError as te: - raise ValueError(_InvalidNodeMessage( - node, 'Input tensor %r %s' % (input_name, te))) - - # pylint: disable=protected-access - if op._input_types != input_types: - raise ValueError( - _InvalidNodeMessage( - node, - 'Input types mismatch (expected %r but got %r)' - % (', '.join(dtypes.as_dtype(x).name for x in input_types), - ', '.join(x.name for x in op._input_types)))) - # pylint: enable=protected-access - - # Execute shape inference for this op. - # NOTE(mrry): If the graph contains a cycle, the full shape - # information may not be available for this op's inputs. - ops.set_shape_and_handle_data_for_outputs(op) - # For nodes with _output_shapes set, set the output shapes. - if '_output_shapes' in op.node_def.attr: - for i, output in enumerate(op.outputs): - dims = op.node_def.attr['_output_shapes'].list.shape[i] - output_shape = tensor_shape.TensorShape( - None if dims.unknown_rank else - [dim.size if dim.size >= 0 else None for dim in dims.dim]) - - try: - output.set_shape(output_shape) - except ValueError as e: - # If the output shape is incompatible with what is inferred - # by the graph for a very specific whitelist of ops, then we - # ignore this output shape. This can happen if there is a - # bug in the shape function for some operation, and the - # serialized graph def has the incorrect shape set when - # running on a newer binary with the fixed shape function. - # This is an escape hatch that allows us to correct shape - # functions that are not critical to correct execution but - # would cause graphs to fail if imported after correcting. - # - # This can be removed after 2017/03/08. - if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue', - 'FIFOQueue', 'PriorityQueue', 'QueueSize', - 'Stack', 'Barrier', 'BarrierReadySize', - 'BarrierIncompleteSize', 'HashTable', - 'MutableHashTable', - 'MutableHashTableOfTensors', 'Mutex', - 'CuckooTable', 'IndexTable', - 'WholeFileReader', 'TextLineReader', - 'FixedLengthRecordReader', - 'TFRecordReader', 'IdentityReader', - 'LMDBReader', - 'RefSwitch', 'RefEnter', 'RefNextIteration', - 'RefMerge', 'RefIdentity']: - pass - elif op.type in [ - 'ConditionalAccumulator', 'SparseConditionalAccumulator', - 'Table' - ]: - # This can be removed after 2017/04/24. - pass - else: - raise e - - del op.node_def.attr['_output_shapes'] - - # NOTE(mrry): We do this after configuring the inputs, because - # the result of the device functions may depend on the inputs. - if apply_device_function: - with _MaybeDevice(node.device): - g._apply_device_functions(op) # pylint: disable=protected-access - - # The following loop populates the device field of ops that are - # colocated with another op. This is implied by the colocation - # attribute, but we propagate the device field for completeness. - for op, coloc_op_list in colocation_pairs.items(): - coloc_device = None - # Find any device in the list of colocated ops that have a - # device, if it exists. We assume that if multiple ops - # have devices, they refer to the same device. Otherwise, a - # runtime error will occur since the colocation property - # cannot be guaranteed. - # - # One possible improvement is to try to check for compatibility - # of all devices in this list at import time here, which would - # require implementing a compatibility function for device specs - # in python. - for coloc_op in coloc_op_list: - if coloc_op.device: - coloc_device = pydev.DeviceSpec.from_string(coloc_op.device) - break - if coloc_device: - op._set_device(coloc_device) # pylint: disable=protected-access - - # Treat input mappings that don't appear in the graph as an error, - # because they are likely to be due to a typo. - def _IsImportedNodeOutput(tensor_name): - operation_name, output_index = _ParseTensorName(tensor_name) - try: - return output_index < len(name_to_op[operation_name].outputs) - except KeyError: - return False - absent_input_keys = [ - k for k in frozenset(input_map.keys()).difference(used_input_keys) - if not _IsImportedNodeOutput(k)] - if absent_input_keys: - raise ValueError( - 'Attempted to map inputs that were not found in graph_def: [%s]' - % ', '.join(absent_input_keys)) - - if return_elements is None: - return None - else: - ret = [] - for name in return_elements: - name = compat.as_str(name) - if ':' in name: - try: - operation_name, output_index = _ParseTensorName(name) - ret.append(name_to_op[operation_name].outputs[output_index]) - except (ValueError, KeyError, IndexError): - raise ValueError( - 'Requested return_element %r not found in graph_def.' % name) - else: - try: - ret.append(name_to_op[name]) - except KeyError: - raise ValueError( - 'Requested return_element %r not found in graph_def.' % name) - return ret - # LINT.ThenChange(//tensorflow/core/graph/graph_constructor.cc) + _ProcessNewOps(graph) + + # Treat input mappings that don't appear in the graph as an error, because + # they are likely to be due to a typo. + missing_unused_input_keys = ( + c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( + results.results)) + if missing_unused_input_keys: + missing_unused_input_keys = [ + compat.as_str(s) for s in missing_unused_input_keys + ] + raise ValueError( + 'Attempted to map inputs that were not found in graph_def: [%s]' % + ', '.join(missing_unused_input_keys)) + + if return_elements is None: + return None + else: + return _GatherReturnElements(return_elements, graph, results.results) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 1b34bf3ceb469df58d6d0e423f5dc8b49f11516a..cf0b1e36fb3f02c85873a0da81dc056d2fbd5f6a 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -20,7 +20,6 @@ from __future__ import print_function import collections import copy -import functools import linecache import os import re @@ -56,16 +55,16 @@ from tensorflow.python.platform import app from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils +from tensorflow.python.util import lock_util from tensorflow.python.util import tf_contextlib +from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.tf_export import tf_export -# Temporary global switch determining if we should enable the work-in-progress -# calls to the C API. Currently disabled by default but can be manually enabled -# in code or via the environment variable. This will be removed once all -# functionality is supported and there's no performance penalty with it enabled. -_USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "1") is not "0" -_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "0") is not "0" +# Temporary global switches determining if we should enable the work-in-progress +# calls to the C API. These will be removed once all functionality is supported. +_USE_C_API = True +_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "1") != "0" def tensor_id(tensor): @@ -291,15 +290,8 @@ class Tensor(_TensorLike): self._value_index = value_index self._dtype = dtypes.as_dtype(dtype) - if _USE_C_API: - # This will be set by set_shape_and_handle_data_for_outputs. - self._shape_val = None - else: - # The Python code requires all tensors start with a shape to support shape - # inference on imported while loops. This isn't necessary with the C API - # enabled because the C API provides the shapes for imported nodes. - # TODO(skyewm): remove when _USE_C_API is removed. - self._shape_val = tensor_shape.unknown_shape() + # This will be set by self.shape(). + self._shape_val = None # List of operations that use this Tensor as input. We maintain this list # to easily navigate a computation graph. @@ -387,7 +379,6 @@ class Tensor(_TensorLike): if _USE_C_SHAPES: self._shape_val = self._c_api_shape() else: - assert _USE_C_API # Call set_shape_and_handle_data_for_outputs in topological order on all # ops that are needed to compute self.op's shape. We do this instead of # having set_shape_and_handle_data_for_outputs recursively call @@ -511,8 +502,6 @@ class Tensor(_TensorLike): else: self._shape_val = self.shape.merge_with(shape) - if not self._op._graph._c_graph: return - # Update C shape even if _USE_C_SHAPES = False, since we still want # set_shape to be reflected in the C API graph for when we run it. if not isinstance(shape, tensor_shape.TensorShape): @@ -548,33 +537,14 @@ class Tensor(_TensorLike): Returns: A list of `Operation`s. """ - if self._op._c_op: # pylint: disable=protected-access - consumer_names = c_api.TF_OperationOutputConsumers_wrapper( - self._as_tf_output()) - # pylint: disable=protected-access - return [ - self.graph._get_operation_by_name_unsafe(name) - for name in consumer_names - ] - # pylint: enable=protected-access - else: - return self._consumers - - def _add_consumer(self, consumer): - """Add a consumer to this tensor. - - Args: - consumer: an Operation. - - Raises: - TypeError: if the consumer is not an Operation. - """ + consumer_names = c_api.TF_OperationOutputConsumers_wrapper( + self._as_tf_output()) # pylint: disable=protected-access - assert not self._op._c_op, "Tensor._add_consumer doesn't work with C API" + return [ + self.graph._get_operation_by_name_unsafe(name) + for name in consumer_names + ] # pylint: enable=protected-access - if not isinstance(consumer, Operation): - raise TypeError("Consumer must be an Operation: %s" % consumer) - self._consumers.append(consumer) def _as_node_def_input(self): """Return a value to use for the NodeDef "input" attribute. @@ -597,7 +567,6 @@ class Tensor(_TensorLike): def _as_tf_output(self): # pylint: disable=protected-access - assert self.op._c_op return c_api_util.tf_output(self.op._c_op, self.value_index) # pylint: enable=protected-access @@ -1725,18 +1694,8 @@ class Operation(object): "a Tensor, or IndexedSlices: %s" % c) control_input_ops.append(control_op) - # Don't set private fields with C API enabled to catch users who need to - # switch to public API. - # TODO(skyewm): delete these fields once we remove _USE_C_API - if not self._graph._c_graph: - self._inputs_val = list(inputs) # Defensive copy. - self._input_types_val = input_types - self._control_inputs_val = control_input_ops - self._node_def_val = copy.deepcopy(node_def) - self._op_def_val = op_def - else: - # This will be set by self.inputs. - self._inputs_val = None + # This will be set by self.inputs. + self._inputs_val = None self._id_value = self._graph._next_id() # pylint: disable=protected-access self._original_op = original_op @@ -1745,10 +1704,8 @@ class Operation(object): # Initialize self._c_op. if c_op: - # TODO(skyewm): remove this assert when we remove USE_C_API - assert self._graph._c_graph # pylint: disable=protected-access self._c_op = c_op - elif self._graph._c_graph: # pylint: disable=protected-access + else: if op_def is None: op_def = self._graph._get_op_def(node_def.op) # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs. @@ -1757,30 +1714,19 @@ class Operation(object): op_def, inputs, node_def.attr) self._c_op = _create_c_op(self._graph, node_def, grouped_inputs, control_input_ops) - else: - self._c_op = None - - # Mark that we consume the inputs. This is unnecessary and unsupported with - # the C API enabled, since the C API tracks the tensor consumers instead. - if not self._c_op: - for input_tensor in self._inputs_val: - input_tensor._add_consumer(self) # pylint: disable=protected-access # Initialize self._outputs. - if self._c_op: - num_outputs = c_api.TF_OperationNumOutputs(self._c_op) - output_types = [ - c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i)) - for i in range(num_outputs)] - assert output_types is not None - elif output_types is None: - output_types = [] - self._output_types_val = output_types + num_outputs = c_api.TF_OperationNumOutputs(self._c_op) + output_types = [ + c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i)) + for i in range(num_outputs)] self._outputs = [ Tensor(self, i, output_type) for i, output_type in enumerate(output_types) ] + self._graph._add_op(self) # pylint: disable=protected-access + if not c_op: self._control_flow_post_processing() @@ -1794,7 +1740,6 @@ class Operation(object): control_flow_util.CheckInputFromValidContext(self, input_tensor.op) if self._control_flow_context is not None: self._control_flow_context.AddOp(self) - self._recompute_node_def() def _reconstruct_sequence_inputs(self, op_def, inputs, attrs): """Regroups a flat list of input tensors into scalar and sequence inputs. @@ -1875,10 +1820,7 @@ class Operation(object): @property def name(self): """The full name of this operation.""" - if self._c_op: - return c_api.TF_OperationName(self._c_op) - else: - return self._node_def_val.name + return c_api.TF_OperationName(self._c_op) @property def _id(self): @@ -1894,10 +1836,7 @@ class Operation(object): assigned, or an empty string if it has not been assigned to a device. """ - if self._c_op: - return c_api.TF_OperationDevice(self._c_op) - else: - return self._node_def_val.device + return c_api.TF_OperationDevice(self._c_op) @property def _output_types(self): @@ -1910,28 +1849,21 @@ class Operation(object): The length of this list indicates the number of output endpoints of the operation. """ - if self._c_op: - num_outputs = c_api.TF_OperationNumOutputs(self._c_op) - output_types = [ - c_api.TF_OperationOutputType(self._tf_output(i)) - for i in xrange(num_outputs) - ] - # TODO(iga): Remove this assert after converting to C API by default. - # Just being a bit paranoid here. - assert self._output_types_val == output_types - # In all the tests we have output_types that are passed into - # Operation.__init__ are a list of ints (which is illegal according - # to the docstring), but input_types are instances of DType. - # This extra assert is to catch if we ever use DType for output_types. - if output_types: - assert isinstance(output_types[0], int) - return output_types - else: - return self._output_types_val + num_outputs = c_api.TF_OperationNumOutputs(self._c_op) + output_types = [ + c_api.TF_OperationOutputType(self._tf_output(i)) + for i in xrange(num_outputs) + ] + # In all the tests we have output_types that are passed into + # Operation.__init__ are a list of ints (which is illegal according + # to the docstring), but input_types are instances of DType. + # This extra assert is to catch if we ever use DType for output_types. + if output_types: + assert isinstance(output_types[0], int) + return output_types def _tf_output(self, output_idx): """Create and return a new TF_Output for output_idx'th output of this op.""" - assert self._c_op tf_output = c_api.TF_Output() tf_output.oper = self._c_op tf_output.index = output_idx @@ -1939,7 +1871,6 @@ class Operation(object): def _tf_input(self, input_idx): """Create and return a new TF_Input for input_idx'th input of this op.""" - assert self._c_op tf_input = c_api.TF_Input() tf_input.oper = self._c_op tf_input.index = input_idx @@ -1951,47 +1882,12 @@ class Operation(object): Args: device: string or device.. The device to set. """ - if self._c_op: - c_api.SetRequestedDevice( - self._graph._c_graph, # pylint: disable=protected-access - self._c_op, # pylint: disable=protected-access - compat.as_str(_device_string(device))) - else: - self._node_def_val.device = _device_string(device) - - def _add_input(self, tensor, dtype=None): - """Add a new input to this operation. - - Args: - tensor: the Tensor to add as an input. - dtype: tf.DType: type of the input; defaults to - the tensor's dtype. + c_api.SetRequestedDevice( + self._graph._c_graph, # pylint: disable=protected-access + self._c_op, # pylint: disable=protected-access + compat.as_str(_device_string(device))) - Raises: - TypeError: if tensor is not a Tensor, - or if input tensor type is not convertible to dtype. - ValueError: if the Tensor is from a different graph. - """ - assert not self._c_op, ( - "Operation._add_input doesn't work with C API") - if not isinstance(tensor, Tensor): - raise TypeError("tensor must be a Tensor: %s" % tensor) - _assert_same_graph(self, tensor) - if dtype is None: - dtype = tensor.dtype - else: - dtype = dtypes.as_dtype(dtype) - if not dtype.is_compatible_with(tensor.dtype): - raise TypeError( - "Cannot convert a tensor of type %s to an input of type %s" % - (tensor.dtype.name, dtype.name)) - self._inputs_val.append(tensor) - self._input_types_val.append(dtype) - tensor._add_consumer(self) # pylint: disable=protected-access - self._recompute_node_def() - - # TODO(skyewm): Remove `update_dtype` when we enable the C API. - def _update_input(self, index, tensor, update_dtype=True): + def _update_input(self, index, tensor): """Update the input to this operation at the given index. NOTE: This is for TF internal use only. Please don't use it. @@ -1999,7 +1895,6 @@ class Operation(object): Args: index: the index of the input to update. tensor: the Tensor to be used as the input at the given index. - update_dtype: If `False`, the type for this input is not updated. Raises: TypeError: if tensor is not a Tensor, @@ -2016,20 +1911,12 @@ class Operation(object): if not _USE_C_SHAPES: set_shape_and_handle_data_for_outputs(self) - if self._c_op: - # Reset cached inputs. - self._inputs_val = None - c_api.UpdateEdge( - self._graph._c_graph, # pylint: disable=protected-access - tensor._as_tf_output(), # pylint: disable=protected-access - self._tf_input(index)) - else: - self._inputs_val[index].consumers().remove(self) - self._inputs_val[index] = tensor - if update_dtype: - self._input_types_val[index] = tensor.dtype - tensor._add_consumer(self) # pylint: disable=protected-access - self._recompute_node_def() + # Reset cached inputs. + self._inputs_val = None + c_api.UpdateEdge( + self._graph._c_graph, # pylint: disable=protected-access + tensor._as_tf_output(), # pylint: disable=protected-access + self._tf_input(index)) def _add_control_inputs(self, ops): """Add a list of new control inputs to this operation. @@ -2041,19 +1928,10 @@ class Operation(object): TypeError: if ops is not a list of Operations. ValueError: if any op in ops is from a different graph. """ - if self._c_op: - for op in ops: - if not isinstance(op, Operation): - raise TypeError("op must be an Operation: %s" % op) - c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access - else: - if ops: - for op in ops: - if not isinstance(op, Operation): - raise TypeError("op must be an Operation: %s" % op) - _assert_same_graph(self, op) - self._control_inputs_val.append(op) - self._recompute_node_def() + for op in ops: + if not isinstance(op, Operation): + raise TypeError("op must be an Operation: %s" % op) + c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access def _add_control_input(self, op): """Add a new control input to this operation. @@ -2065,33 +1943,13 @@ class Operation(object): TypeError: if op is not an Operation. ValueError: if op is from a different graph. """ - if self._c_op: - if not isinstance(op, Operation): - raise TypeError("op must be an Operation: %s" % op) - c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access - else: - self._add_control_inputs([op]) + if not isinstance(op, Operation): + raise TypeError("op must be an Operation: %s" % op) + c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access def _remove_all_control_inputs(self): """Removes any control inputs to this operation.""" - if self._c_op: - c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access - else: - del self.control_inputs[:] - - # Methods below are used when building the NodeDef and Graph proto. - def _recompute_node_def(self): - # TODO(skyewm): remove this function when we switch to C API - if self._c_op: return - - del self._node_def_val.input[:] - # pylint: disable=protected-access - self._node_def_val.input.extend( - [t._as_node_def_input() for t in self._inputs_val]) - # pylint: enable=protected-access - if self._control_inputs_val: - self._node_def_val.input.extend( - ["^%s" % op.name for op in self._control_inputs_val]) + c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access def __str__(self): return str(self.node_def) @@ -2132,19 +1990,16 @@ class Operation(object): @property def inputs(self): """The list of `Tensor` objects representing the data inputs of this op.""" - if self._c_op: - if self._inputs_val is None: - tf_outputs = c_api.GetOperationInputs(self._c_op) - # pylint: disable=protected-access - retval = [ - self.graph._get_tensor_by_tf_output(tf_output) - for tf_output in tf_outputs - ] - # pylint: enable=protected-access - self._inputs_val = Operation._InputList(retval) - return self._inputs_val - else: - return Operation._InputList(self._inputs_val) + if self._inputs_val is None: + tf_outputs = c_api.GetOperationInputs(self._c_op) + # pylint: disable=protected-access + retval = [ + self.graph._get_tensor_by_tf_output(tf_output) + for tf_output in tf_outputs + ] + # pylint: enable=protected-access + self._inputs_val = Operation._InputList(retval) + return self._inputs_val @property def _inputs(self): @@ -2158,15 +2013,12 @@ class Operation(object): @property def _input_types(self): - if self._c_op: - num_inputs = c_api.TF_OperationNumInputs(self._c_op) - input_types = [ - dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i))) - for i in xrange(num_inputs) - ] - return input_types - else: - return self._input_types_val + num_inputs = c_api.TF_OperationNumInputs(self._c_op) + input_types = [ + dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i))) + for i in xrange(num_inputs) + ] + return input_types @_input_types.setter def _input_types(self, value): @@ -2186,16 +2038,13 @@ class Operation(object): A list of `Operation` objects. """ - if self._c_op: - control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op) - # pylint: disable=protected-access - return [ - self.graph._get_operation_by_name_unsafe( - c_api.TF_OperationName(c_op)) for c_op in control_c_ops - ] - # pylint: enable=protected-access - else: - return self._control_inputs_val + control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op) + # pylint: disable=protected-access + return [ + self.graph._get_operation_by_name_unsafe( + c_api.TF_OperationName(c_op)) for c_op in control_c_ops + ] + # pylint: enable=protected-access @property def _control_outputs(self): @@ -2208,18 +2057,13 @@ class Operation(object): A list of `Operation` objects. """ - if self._c_op: - control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op) - # pylint: disable=protected-access - return [ - self.graph._get_operation_by_name_unsafe( - c_api.TF_OperationName(c_op)) for c_op in control_c_ops - ] - # pylint: enable=protected-access - else: - # TODO(apassos) this should be less inefficient. - return [o for o in self._graph.get_operations() - if self in o.control_inputs] + control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op) + # pylint: disable=protected-access + return [ + self.graph._get_operation_by_name_unsafe( + c_api.TF_OperationName(c_op)) for c_op in control_c_ops + ] + # pylint: enable=protected-access @property def _control_inputs(self): @@ -2243,11 +2087,7 @@ class Operation(object): @property def type(self): """The type of the op (e.g. `"MatMul"`).""" - if self._c_op: - op_type = c_api.TF_OperationOpType(self._c_op) - return op_type - else: - return self._node_def_val.op + return c_api.TF_OperationOpType(self._c_op) @property def graph(self): @@ -2265,15 +2105,12 @@ class Operation(object): protocol buffer. """ # pylint: enable=line-too-long - if self._c_op: - with c_api_util.tf_buffer() as buf: - c_api.TF_OperationToNodeDef(self._c_op, buf) - data = c_api.TF_GetBuffer(buf) - node_def = node_def_pb2.NodeDef() - node_def.ParseFromString(compat.as_bytes(data)) - return node_def - else: - return self._node_def_val + with c_api_util.tf_buffer() as buf: + c_api.TF_OperationToNodeDef(self._c_op, buf) + data = c_api.TF_GetBuffer(buf) + node_def = node_def_pb2.NodeDef() + node_def.ParseFromString(compat.as_bytes(data)) + return node_def @property def _node_def(self): @@ -2292,10 +2129,7 @@ class Operation(object): protocol buffer. """ # pylint: enable=line-too-long - if self._c_op: - return self._graph._get_op_def(self.type) - else: - return self._op_def_val + return self._graph._get_op_def(self.type) @property def _op_def(self): @@ -2321,17 +2155,14 @@ class Operation(object): def _set_attr(self, attr_name, attr_value): """Private method used to set an attribute in the node_def.""" - if self._c_op: - buf = c_api.TF_NewBufferFromString( - compat.as_bytes(attr_value.SerializeToString())) - try: - # pylint: disable=protected-access - c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf) - # pylint: enable=protected-access - finally: - c_api.TF_DeleteBuffer(buf) - else: - self._node_def_val.attr[attr_name].CopyFrom(attr_value) + buf = c_api.TF_NewBufferFromString( + compat.as_bytes(attr_value.SerializeToString())) + try: + # pylint: disable=protected-access + c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf) + # pylint: enable=protected-access + finally: + c_api.TF_DeleteBuffer(buf) def get_attr(self, name): """Returns the value of the attr of this op with the given `name`. @@ -2346,21 +2177,15 @@ class Operation(object): ValueError: If this op does not have an attr with the given `name`. """ fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] - if self._c_op: - try: - with c_api_util.tf_buffer() as buf: - c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf) - data = c_api.TF_GetBuffer(buf) - except errors.InvalidArgumentError as e: - # Convert to ValueError for backwards compatibility. - raise ValueError(str(e)) - x = attr_value_pb2.AttrValue() - x.ParseFromString(data) - else: - if name not in self._node_def_val.attr: - raise ValueError( - "No attr named '" + name + "' in " + str(self._node_def_val)) - x = self._node_def_val.attr[name] + try: + with c_api_util.tf_buffer() as buf: + c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf) + data = c_api.TF_GetBuffer(buf) + except errors.InvalidArgumentError as e: + # Convert to ValueError for backwards compatibility. + raise ValueError(str(e)) + x = attr_value_pb2.AttrValue() + x.ParseFromString(data) # Treat an empty oneof value as an empty list. if not x.WhichOneof("value"): @@ -2580,9 +2405,9 @@ def _set_shape_and_handle_data_for_outputs_c_api(op): def set_shape_and_handle_data_for_outputs(op): """Set the shapes and resource handle data for op's outputs. - When _USE_C_API = True, this is lazily called when a tensor's shape is first - requested. Usually this should work automatically, but some edge cases may - require manually calling this first to make sure Tensor._shape_val and + When _USE_C_SHAPES = False, this is lazily called when a tensor's shape is + first requested. Usually this should work automatically, but some edge cases + may require manually calling this first to make sure Tensor._shape_val and Tensor._handle_data are set (e.g. manually overriding _handle_data, copying a Tensor). """ @@ -2775,6 +2600,10 @@ def _name_from_scope_name(name): return name[:-1] if (name and name[-1] == "/") else name +_MUTATION_LOCK_GROUP = 0 +_SESSION_RUN_LOCK_GROUP = 1 + + @tf_export("Graph") class Graph(object): """A TensorFlow computation, represented as a dataflow graph. @@ -2824,20 +2653,21 @@ class Graph(object): def __init__(self): """Creates a new, empty Graph.""" - # Protects core state that can be returned via public accessors, as well as - # synchronizes Session.run calls with methods that create and mutate ops - # (e.g. Graph.create_op()). This synchronization is necessary because it's - # illegal to modify an operation after it's been run. Thread-safety is - # provided on a best-effort basis to support buggy programs, and is not - # guaranteed by the public `tf.Graph` API. - # - # The lock must be reentrant because create_op can be called recursively due - # to control flow. Without a reentrant lock, many methods would also need a - # "locked" version or parameter (including generated code). + # Protects core state that can be returned via public accessors. + # Thread-safety is provided on a best-effort basis to support buggy + # programs, and is not guaranteed by the public `tf.Graph` API. # # NOTE(mrry): This does not protect the various stacks. A warning will # be reported if these are used from multiple threads self._lock = threading.RLock() + # The group lock synchronizes Session.run calls with methods that create + # and mutate ops (e.g. Graph.create_op()). This synchronization is + # necessary because it's illegal to modify an operation after it's been run. + # The group lock allows any number of threads to mutate ops at the same time + # but if any modification is going on, all Session.run calls have to wait. + # Similarly, if one or more Session.run calls are going on, all mutate ops + # have to wait until all Session.run calls have finished. + self._group_lock = lock_util.GroupLock(num_groups=2) self._nodes_by_id = dict() # GUARDED_BY(self._lock) self._next_id_counter = 0 # GUARDED_BY(self._lock) self._nodes_by_name = dict() # GUARDED_BY(self._lock) @@ -3086,15 +2916,12 @@ class Graph(object): A `VersionDef`. """ # pylint: enable=line-too-long - if self._c_graph: - with c_api_util.tf_buffer() as buf: - c_api.TF_GraphVersions(self._c_graph, buf) - data = c_api.TF_GetBuffer(buf) - version_def = versions_pb2.VersionDef() - version_def.ParseFromString(compat.as_bytes(data)) - return version_def - else: - return self._graph_def_versions + with c_api_util.tf_buffer() as buf: + c_api.TF_GraphVersions(self._c_graph, buf) + data = c_api.TF_GetBuffer(buf) + version_def = versions_pb2.VersionDef() + version_def.ParseFromString(compat.as_bytes(data)) + return version_def @property def seed(self): @@ -3188,40 +3015,22 @@ class Graph(object): """ # pylint: enable=line-too-long - if self._c_graph: - with self._lock: - with c_api_util.tf_buffer() as buf: - c_api.TF_GraphToGraphDef(self._c_graph, buf) - data = c_api.TF_GetBuffer(buf) - graph = graph_pb2.GraphDef() - graph.ParseFromString(compat.as_bytes(data)) - # Strip the experimental library field iff it's empty. - if not graph.library.function: - graph.ClearField("library") - - if add_shapes: - for node in graph.node: - op = self._nodes_by_name[node.name] - if op.outputs: - node.attr["_output_shapes"].list.shape.extend( - [output.get_shape().as_proto() for output in op.outputs]) - else: - with self._lock: - graph = graph_pb2.GraphDef() - graph.versions.CopyFrom(self._graph_def_versions) - bytesize = 0 - for op_id in sorted(self._nodes_by_id): - op = self._nodes_by_id[op_id] - if from_version is None or op_id > from_version: - graph.node.extend([op.node_def]) - if op.outputs and add_shapes: - assert "_output_shapes" not in graph.node[-1].attr - graph.node[-1].attr["_output_shapes"].list.shape.extend( - [output.get_shape().as_proto() for output in op.outputs]) - bytesize += op.node_def.ByteSize() - if bytesize >= (1 << 31) or bytesize < 0: - raise ValueError("GraphDef cannot be larger than 2GB.") - self._copy_functions_to_graph_def(graph, bytesize) + with self._lock: + with c_api_util.tf_buffer() as buf: + c_api.TF_GraphToGraphDef(self._c_graph, buf) + data = c_api.TF_GetBuffer(buf) + graph = graph_pb2.GraphDef() + graph.ParseFromString(compat.as_bytes(data)) + # Strip the experimental library field iff it's empty. + if not graph.library.function: + graph.ClearField("library") + + if add_shapes: + for node in graph.node: + op = self._nodes_by_name[node.name] + if op.outputs: + node.attr["_output_shapes"].list.shape.extend( + [output.get_shape().as_proto() for output in op.outputs]) return graph, self._version def as_graph_def(self, from_version=None, add_shapes=False): @@ -3295,34 +3104,16 @@ class Graph(object): # Add function to graph # pylint: disable=protected-access - if self._c_graph: - # Handle functions created without using the C API. TODO(apassos,skyewm) - # remove this when all functions are generated using the C API by default - # as this will be unnecessary. - if not function._c_func: - serialized = function.definition.SerializeToString() - c_func = c_api.TF_FunctionImportFunctionDef(serialized) - function._c_func = c_api_util.ScopedTFFunction(c_func) - gradient = (function._grad_func._c_func.func if function._grad_func - else None) - c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient) - else: - # If there is already a function with the same name, raise an error - # if bodies are different. Else, do nothing. The C API version above - # has the same behavior. - previous = self._functions.get(name, None) - if previous: - # This check is not ideal as we can have a hash collision with only - # 32 bits in the hash, but the non C API mode is being deprecated. - # Don't bother changing it now. - if previous._hash_str == function._hash_str: - return - else: - raise ValueError("Cannot add function (%s, hash %s) to graph (%s). " - "Another function (%s, hash %s) is already defined " - "with that name (%s)" % ( - function, function._hash_str, self, - previous, previous._hash_str, name)) + # Handle functions created without using the C API. TODO(apassos,skyewm) + # remove this when all functions are generated using the C API by default + # as this will be unnecessary. + if not function._c_func: + serialized = function.definition.SerializeToString() + c_func = c_api.TF_FunctionImportFunctionDef(serialized) + function._c_func = c_api_util.ScopedTFFunction(c_func) + gradient = (function._grad_func._c_func.func if function._grad_func + else None) + c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient) # pylint: enable=protected-access self._functions[name] = function @@ -3337,6 +3128,9 @@ class Graph(object): return self._building_function # Helper functions to create operations. + @deprecated_args(None, + "Shapes are always computed; don't use the compute_shapes " + "as it has no effect.", "compute_shapes") def create_op( self, op_type, @@ -3373,8 +3167,8 @@ class Graph(object): proto). op_def: (Optional.) The `OpDef` proto that describes the `op_type` that the operation will have. - compute_shapes: (Optional.) If True, shape inference will be performed - to compute the shapes of the outputs. + compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always + computed). compute_device: (Optional.) If True, device functions will be executed to compute the device property of the Operation. @@ -3384,8 +3178,9 @@ class Graph(object): Returns: An `Operation` object. - """ + del compute_shapes + self._check_not_finalized() for idx, a in enumerate(inputs): if not isinstance(a, Tensor): @@ -3403,9 +3198,9 @@ class Graph(object): input_ops = set([t.op for t in inputs]) control_inputs = self._control_dependencies_for_inputs(input_ops) - # _create_op_helper mutates the new Operation. _lock ensures a Session.run - # call cannot occur between creating and mutating the op. - with self._lock: + # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a + # Session.run call cannot occur between creating and mutating the op. + with self._mutation_lock(): ret = Operation( node_def, self, @@ -3415,18 +3210,7 @@ class Graph(object): input_types=input_types, original_op=self._default_original_op, op_def=op_def) - - # Note: shapes are lazily computed with the C API enabled. - # - # TODO(skyewm): unlike in the original Python implementation, the C API - # always computes shape information (even for function calls, which the - # original Python shape inference code doesn't handle). Deprecate the - # compute_shapes argument. - if not _USE_C_API and compute_shapes: - set_shape_and_handle_data_for_outputs(ret) - - self._create_op_helper(ret, compute_shapes=compute_shapes, - compute_device=compute_device) + self._create_op_helper(ret, compute_device=compute_device) return ret def _create_op_from_tf_operation(self, c_op, compute_device=True): @@ -3461,11 +3245,8 @@ class Graph(object): self._create_op_helper(ret, compute_device=compute_device) return ret - def _create_op_helper(self, op, compute_shapes=True, compute_device=True): + def _create_op_helper(self, op, compute_device=True): """Common logic for creating an op in this graph.""" - # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed. - self._add_op(op) - # Apply any additional attributes requested. Do not overwrite any existing # attributes. for key, value in self._attr_scope_map.items(): @@ -3532,8 +3313,7 @@ class Graph(object): # (2) "is_stateful" is set in OpDef # (3) "container" attribute is in OpDef # (4) "container" attribute is None - # TODO(skyewm): remove op.op_def check when _USE_C_API is removed. - if self._container and op.op_def and op.op_def.is_stateful: + if self._container and op.op_def.is_stateful: try: container_attr = op.get_attr("container") except ValueError: @@ -3820,17 +3600,14 @@ class Graph(object): def _get_op_def(self, type): # pylint: disable=redefined-builtin """Returns the `OpDef` proto for `type`. `type` is a string.""" - if self._c_graph: - with c_api_util.tf_buffer() as buf: - # pylint: disable=protected-access - c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf) - # pylint: enable=protected-access - data = c_api.TF_GetBuffer(buf) - op_def = op_def_pb2.OpDef() - op_def.ParseFromString(compat.as_bytes(data)) - return op_def - else: - return self._registered_ops[type] + with c_api_util.tf_buffer() as buf: + # pylint: disable=protected-access + c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf) + # pylint: enable=protected-access + data = c_api.TF_GetBuffer(buf) + op_def = op_def_pb2.OpDef() + op_def.ParseFromString(compat.as_bytes(data)) + return op_def def as_default(self): """Returns a context manager that makes this `Graph` the default graph. @@ -3862,6 +3639,9 @@ class Graph(object): assert c.graph is g ``` + If eager execution is enabled ops created under this context manager will be + added to the graph instead of executed eagerly. + Returns: A context manager for using this graph as the default graph. """ @@ -3883,7 +3663,6 @@ class Graph(object): contains many standard names for collections. value: The value to add to the collection. """ # pylint: disable=g-doc-exception - _assert_collection_is_ok(name) self._check_not_finalized() with self._lock: if name not in self._collections: @@ -3930,7 +3709,6 @@ class Graph(object): The list of values in the collection with the given `name`, or an empty list if no value has been added to that collection. """ # pylint: disable=g-doc-exception - _assert_collection_is_ok(name) with self._lock: coll_list = self._collections.get(name, None) if coll_list is None: @@ -3960,7 +3738,6 @@ class Graph(object): list contains the values in the order under which they were collected. """ # pylint: disable=g-doc-exception - _assert_collection_is_ok(name) with self._lock: collection = self._collections.get(name, None) if collection is None: @@ -4956,6 +4733,20 @@ class Graph(object): else: self._graph_control_dependencies_stack = control_dependencies + def _mutation_lock(self): + """Returns a lock to guard code that creates & mutates ops. + + See the comment for self._group_lock for more info. + """ + return self._group_lock.group(_MUTATION_LOCK_GROUP) + + def _session_run_lock(self): + """Returns a lock to guard code for Session.run. + + See the comment for self._group_lock for more info. + """ + return self._group_lock.group(_SESSION_RUN_LOCK_GROUP) + # TODO(agarwal): currently device directives in an outer eager scope will not # apply to inner graph mode code. Fix that. @@ -5278,35 +5069,15 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access @tf_contextlib.contextmanager def get_controller(self, default): try: - if context.executing_eagerly(): - # A Graph alone on the context stack would keep init_scope-wrapped - # operations graph building when entered (assuming init_scope is called - # in a graph building context). Instead, we push a context which first - # enables eager execution and then re-enters the Graph. - context.context().context_switches.push( - default.building_function, - functools.partial( - _enter_context_and_graph, - context.eager_mode, - default.as_default)) - else: - # This Graph is being used from a graph building context. A lack of - # context switch implies that the context is graph building. - context.context().context_switches.push(default.building_function, - default.as_default) - with super(_DefaultGraphStack, self).get_controller(default) as g: + context.context().context_switches.push( + default.building_function, default.as_default) + with super(_DefaultGraphStack, self).get_controller( + default) as g, context.graph_mode(): yield g finally: context.context().context_switches.pop() -@tf_contextlib.contextmanager -def _enter_context_and_graph(context_fn, graph_fn): - """Combines two context managers.""" - with context_fn(), graph_fn(): - yield - - _default_graph_stack = _DefaultGraphStack() @@ -5355,6 +5126,7 @@ def init_scope(): # Names that end with trailing slashes are treated by `name_scope` as # absolute. scope = scope + '/' + inner_device_stack = default_graph._device_function_stack # pylint: disable=protected-access outer_context = None if not _default_graph_stack.stack: @@ -5383,13 +5155,28 @@ def init_scope(): raise RuntimeError("All graphs are building functions, and no " "eager context was previously active.") - with outer_context(), name_scope(scope), control_dependencies( - None), tape.stop_recording(): - yield + outer_graph = None + outer_device_stack = None + try: + with outer_context(), name_scope(scope), control_dependencies( + None), tape.stop_recording(): + if not context.executing_eagerly(): + # The device stack is preserved when lifting into a graph. Eager + # execution doesn't implement device stacks and in particular it + # doesn't support device functions, so in general it's not possible + # to do the same when lifting into the eager context. + outer_graph = get_default_graph() + outer_device_stack = outer_graph._device_function_stack # pylint: disable=protected-access + outer_graph._device_function_stack = inner_device_stack # pylint: disable=protected-access + yield + finally: + if outer_graph is not None: + outer_graph._device_function_stack = outer_device_stack # pylint: disable=protected-access @tf_export("enable_eager_execution") -def enable_eager_execution(config=None, device_policy=None, +def enable_eager_execution(config=None, + device_policy=None, execution_mode=None): """Enables eager execution for the lifetime of this program. @@ -5449,6 +5236,31 @@ def enable_eager_execution(config=None, device_policy=None, TensorFlow graph, or if options provided conflict with a previous call to this function. """ + return enable_eager_execution_internal( + config, device_policy, execution_mode, None) + + +def enable_eager_execution_internal(config=None, + device_policy=None, + execution_mode=None, + server_def=None): + """Enables eager execution for the lifetime of this program. + + Most of the doc string for enable_eager_execution is relevant here as well. + Args: + config: See enable_eager_execution doc string + device_policy: See enable_eager_execution doc string + execution_mode: See enable_eager_execution doc string + server_def: (Optional.) A tensorflow::ServerDef proto. + Enables execution on remote devices. GrpcServers need to be started by + creating an identical server_def to this, and setting the appropriate + task_indexes, so that the servers can communicate. It will then be + possible to execute operations on remote devices. + + Raises: + ValueError + + """ if config is not None and not isinstance(config, config_pb2.ConfigProto): raise TypeError( "config must be a tf.ConfigProto, but got %s" % type(config)) @@ -5476,7 +5288,8 @@ def enable_eager_execution(config=None, device_policy=None, context._context = context.Context( config=config, device_policy=device_policy, - execution_mode=execution_mode) + execution_mode=execution_mode, + server_def=server_def) elif ((config is not None and config is not context._context._config) or (device_policy is not None and device_policy is not context._context._device_policy) or @@ -5564,6 +5377,10 @@ def get_default_graph(): """ return _default_graph_stack.get_default() +def has_default_graph(): + """Returns True if there is a default graph.""" + return len(_default_graph_stack.stack) >= 1 + def get_name_scope(): """Returns the current name scope in the default_graph. @@ -5831,7 +5648,8 @@ def add_to_collection(name, value): value: The value to add to the collection. @compatibility(eager) - Collections are not supported when eager execution is enabled. + Collections are only supported in eager when variables are created inside an + EagerVariableStore (e.g. as part of a layer or template). @end_compatibility """ get_default_graph().add_to_collection(name, value) @@ -5849,7 +5667,8 @@ def add_to_collections(names, value): value: The value to add to the collections. @compatibility(eager) - Collections are not supported when eager execution is enabled. + Collections are only supported in eager when variables are created inside an + EagerVariableStore (e.g. as part of a layer or template). @end_compatibility """ get_default_graph().add_to_collections(names, value) @@ -6142,14 +5961,6 @@ def get_from_proto_function(collection_name): return None -def _assert_collection_is_ok(collection_name): - if context.executing_eagerly(): - if collection_name in GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access - raise ValueError( - "variable collections are not supported when eager execution is enabled." - ) - - def _operation_conversion_error(op, dtype=None, name=None, as_ref=False): """Produce a nice error if someone converts an Operation to a Tensor.""" raise TypeError(("Can't convert Operation '%s' to Tensor " diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 6321e99671403f47b39b9816591337545139089c..150100d771bb41d3693d39dc6fa19baa40da4c04 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -270,7 +270,6 @@ class OperationTest(test_util.TensorFlowTestCase): op1 = ops.Operation( ops._NodeDef("RefOutputFloatOutput", "op1"), g, [], [dtypes.float32_ref, dtypes.float32]) - g._add_op(op1) self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) self.assertEquals([], list(op1.inputs)) ref_t, nonref_t = op1.values() @@ -279,14 +278,12 @@ class OperationTest(test_util.TensorFlowTestCase): ops._NodeDef("RefInputFloatInput", "op2"), g, [ref_t, nonref_t], [], input_types=[dtypes.float32_ref, dtypes.float32]) - g._add_op(op2) self.assertProtoEquals( "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", op2.node_def) self.assertEquals([ref_t, nonref_t], list(op2.inputs)) op3 = ops.Operation( ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []) - g._add_op(op3) self.assertProtoEquals( "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", op3.node_def) @@ -1693,7 +1690,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): # e should be dominated by c. self.assertEqual(e.op.control_inputs, []) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEager(self): def future(): future.calls += 1 @@ -1878,7 +1875,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): class OpScopeTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNames(self): with ops.name_scope("foo") as foo: self.assertEqual("foo/", foo) @@ -1909,7 +1906,7 @@ class OpScopeTest(test_util.TensorFlowTestCase): with ops.name_scope("a//b/c") as foo10: self.assertEqual("a//b/c/", foo10) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerDefaultScopeName(self): with ops.name_scope(None, "default") as scope: self.assertEqual(scope, "default/") @@ -2042,6 +2039,21 @@ class InitScopeTest(test_util.TensorFlowTestCase): self.assertEqual(len(g1.get_operations()), 0) self.assertEqual(len(g0.get_operations()), 1) + def testPreservesDevices(self): + g0 = ops.Graph() + with g0.as_default(), ops.device("CPU:0"): + g1 = ops.Graph() + g1._building_function = True # pylint: disable=protected-access + with g1.as_default(), ops.device("GPU:0"): + with ops.init_scope(): + # init_scope should preserve device set under `g1`. + on_gpu = constant_op.constant(1.0) + self.assertEqual(on_gpu.device, "/device:GPU:0") + still_on_gpu = constant_op.constant(1.0) + self.assertEqual(still_on_gpu.device, "/device:GPU:0") + on_cpu = constant_op.constant(1.0) + self.assertEqual(on_cpu.device, "/device:CPU:0") + def testComposes(self): g0 = ops.Graph() g1 = ops.Graph() @@ -2209,12 +2221,25 @@ class InitScopeTest(test_util.TensorFlowTestCase): self.assertEqual(ops.get_name_scope(), "inner") self.assertEqual(ops.get_name_scope(), "") - def testEagerGraphContextsExecuteEagerly(self): + def testEnteringGraphFromEagerIsSticky(self): with context.eager_mode(): + g = ops.Graph() + with g.as_default(): + with ops.init_scope(): + self.assertFalse(context.executing_eagerly()) + self.assertEqual(g, ops.get_default_graph()) + + def testMixGraphEager(self): + with context.eager_mode(): + c = constant_op.constant(1.0) with ops.Graph().as_default(): - with context.graph_mode(): - with ops.init_scope(): - self.assertTrue(context.executing_eagerly()) + with self.assertRaisesRegexp( + RuntimeError, "Attempting to capture an EagerTensor"): + math_ops.add(c, c) + c2 = constant_op.constant(2.0) + with self.assertRaisesRegexp( + TypeError, "contains objects other than 'EagerTensor'"): + math_ops.add(c2, c2) def testPreservesNameScopeInEagerExecution(self): with context.eager_mode(): @@ -2248,6 +2273,11 @@ class GraphTest(test_util.TensorFlowTestCase): with g0.as_default(): ops.reset_default_graph() + def testGraphContextManagerCancelsEager(self): + with context.eager_mode(): + with ops.Graph().as_default(): + self.assertFalse(context.executing_eagerly()) + def testGraphContextManager(self): g0 = ops.Graph() with g0.as_default() as g1: diff --git a/tensorflow/python/framework/random_seed_test.py b/tensorflow/python/framework/random_seed_test.py index 194492268631abfa911bd45f13a302c09a2c8bda..6696bffc6c553f3fcf458f52cb9cd386e2711ff4 100644 --- a/tensorflow/python/framework/random_seed_test.py +++ b/tensorflow/python/framework/random_seed_test.py @@ -26,7 +26,7 @@ from tensorflow.python.platform import test class RandomSeedTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRandomSeed(self): test_cases = [ # Each test case is a tuple with input to get_seed: diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 1fe81e5f17a7de0a113596d920d63e5d9474c7c1..6a5c6468f77382b2b7e62a6a49d4fb637fed4dc0 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections +from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -225,6 +226,7 @@ class SparseTensor(_TensorLike): SparseTensorValue = collections.namedtuple( "SparseTensorValue", ["indices", "values", "dense_shape"]) tf_export("SparseTensorValue")(SparseTensorValue) +pywrap_tensorflow.RegisterSparseTensorValueClass(SparseTensorValue) @tf_export("convert_to_tensor_or_sparse_tensor") diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 0dd29460ed93aadf61ef1f1b2dbf1d7802ca4877..c9be3d50056b2838e8cf39c3a17e1cff14e67ea0 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -961,9 +961,12 @@ def unknown_shape(ndims=None): return TensorShape([Dimension(None)] * ndims) +_SCALAR_SHAPE = TensorShape([]) + + def scalar(): """Returns a shape representing a scalar.""" - return TensorShape([]) + return _SCALAR_SHAPE def vector(length): diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 35fff80c61b98e7603d3b7b5df3cabdb59059a72..d6edc1364369e1b4d06093879571cdb4e9ffe409 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -941,7 +941,7 @@ class ConstantValueTest(test.TestCase): class ConstantValueAsShapeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConstant(self): np_val = np.random.rand(3).astype(np.int32) tf_val = constant_op.constant(np_val) @@ -954,13 +954,13 @@ class ConstantValueAsShapeTest(test.TestCase): tensor_shape.TensorShape([]), tensor_util.constant_value_as_shape(tf_val)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShape(self): tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])) c_val = tensor_util.constant_value_as_shape(tf_val) self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), c_val) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMinusOneBecomesNone(self): tf_val = constant_op.constant([-1, 1, -1], shape=[3]) c_val = tensor_util.constant_value_as_shape(tf_val) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 5e02e7e3ec605e41deffad5e3faf1c9f8a2679ef..2bc2a189fa8e825613ca834e2c06ea916074d455 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -27,6 +27,7 @@ import random import re import tempfile import threading +import unittest import numpy as np import six @@ -61,13 +62,13 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect from tensorflow.python.util.protobuf import compare from tensorflow.python.util.tf_export import tf_export @@ -321,32 +322,6 @@ def NCHWToNHWC(input_tensor): return [input_tensor[a] for a in new_axes[ndims]] -# TODO(skyewm): remove this eventually -# pylint: disable=protected-access -def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs): - prev_value = ops._USE_C_API - ops._USE_C_API = use_c_api - try: - # Reset the default graph so it has the C API enabled. We call - # reset_default_graph() instead of creating a new default Graph context to - # make this robust to tests that call reset_default_graph(), which requires - # that the current default graph isn't nested. - ops.reset_default_graph() - fn(*args, **kwargs) - finally: - ops._USE_C_API = prev_value - # Make sure default graph reflects prev_value in case next test doesn't call - # reset_default_graph(). - ops.reset_default_graph() - - -# pylint: disable=protected-access - - -def c_api_and_cuda_enabled(): - return ops._USE_C_API and IsGoogleCudaEnabled() - - def skip_if(condition): """Skips the decorated function if condition is or evaluates to True. @@ -372,46 +347,6 @@ def skip_if(condition): return real_skip_if -# TODO(skyewm): remove this eventually -def disable_c_api(fn): - """Decorator for disabling the C API on a test. - - Note this disables the C API after running the test class's setup/teardown - methods. - - Args: - fn: the function to be wrapped - - Returns: - The wrapped function - """ - - def wrapper(*args, **kwargs): - _use_c_api_wrapper(fn, False, *args, **kwargs) - - return wrapper - - -# TODO(skyewm): remove this eventually -def enable_c_api(fn): - """Decorator for enabling the C API on a test. - - Note this enables the C API after running the test class's setup/teardown - methods. - - Args: - fn: the function to be wrapped - - Returns: - The wrapped function - """ - - def wrapper(*args, **kwargs): - _use_c_api_wrapper(fn, True, *args, **kwargs) - - return wrapper - - def enable_c_shapes(fn): """Decorator for enabling C shapes on a test. @@ -425,46 +360,19 @@ def enable_c_shapes(fn): The wrapped function """ + # pylint: disable=protected-access def wrapper(*args, **kwargs): prev_value = ops._USE_C_SHAPES - # Only use C shapes if the C API is already enabled. - ops._USE_C_SHAPES = ops._USE_C_API + ops._USE_C_SHAPES = True try: fn(*args, **kwargs) finally: ops._USE_C_SHAPES = prev_value + # pylint: enable=protected-access return wrapper -# This decorator is a hacky way to run all the test methods in a decorated -# class with and without C API enabled. -# TODO(iga): Remove this and its uses once we switch to using C API by default. -def with_c_api(cls): - """Adds methods that call original methods but with C API enabled. - - Note this enables the C API in new methods after running the test class's - setup method. This can be a problem if some objects are created in it - before the C API is enabled. - - Args: - cls: class to decorate - - Returns: - cls with new test methods added - """ - # If the C API is already enabled, don't do anything. Some tests break if the - # same test is run twice, so this allows us to turn on the C API by default - # without breaking these tests. - if ops._USE_C_API: - return cls - - for name, value in cls.__dict__.copy().items(): - if callable(value) and name.startswith("test"): - setattr(cls, name + "WithCApi", enable_c_api(value)) - return cls - - def with_c_shapes(cls): """Adds methods that call original methods but with C API shapes enabled. @@ -507,8 +415,28 @@ def assert_no_new_pyobjects_executing_eagerly(f): f(self, **kwargs) gc.collect() previous_count = len(gc.get_objects()) + collection_sizes_before = { + collection: len(ops.get_collection(collection)) + for collection in ops.get_default_graph().collections} for _ in range(3): f(self, **kwargs) + # Note that gc.get_objects misses anything that isn't subject to garbage + # collection (C types). Collections are a common source of leaks, so we + # test for collection sizes explicitly. + for collection_key in ops.get_default_graph().collections: + collection = ops.get_collection(collection_key) + size_before = collection_sizes_before.get(collection_key, 0) + if len(collection) > size_before: + raise AssertionError( + ("Collection %s increased in size from " + "%d to %d (current items %s).") + % (collection_key, size_before, len(collection), collection)) + # Make sure our collection checks don't show up as leaked memory by + # removing references to temporary variables. + del collection + del collection_key + del size_before + del collection_sizes_before gc.collect() # There should be no new Python objects hanging around. new_count = len(gc.get_objects()) @@ -556,12 +484,16 @@ def assert_no_new_tensors(f): tensors_before = set( id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) - outside_graph_key = ops.get_default_graph()._graph_key - with ops.Graph().as_default(): + if context.executing_eagerly(): + f(self, **kwargs) + ops.reset_default_graph() + else: # Run the test in a new graph so that collections get cleared when it's # done, but inherit the graph key so optimizers behave. - ops.get_default_graph()._graph_key = outside_graph_key - f(self, **kwargs) + outside_graph_key = ops.get_default_graph()._graph_key + with ops.Graph().as_default(): + ops.get_default_graph()._graph_key = outside_graph_key + f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. backprop._zeros_cache.flush() @@ -639,7 +571,16 @@ def assert_no_garbage_created(f): return decorator -def run_in_graph_and_eager_modes(__unused__=None, +def run_all_in_graph_and_eager_modes(cls): + """Execute all test methods in the given class with and without eager.""" + base_decorator = run_in_graph_and_eager_modes + for name, value in cls.__dict__.copy().items(): + if callable(value) and name.startswith("test"): + setattr(cls, name, base_decorator(value)) + return cls + + +def run_in_graph_and_eager_modes(func=None, config=None, use_gpu=True, reset_test=True, @@ -657,7 +598,7 @@ def run_in_graph_and_eager_modes(__unused__=None, ```python class MyTests(tf.test.TestCase): - @run_in_graph_and_eager_modes() + @run_in_graph_and_eager_modes def test_foo(self): x = tf.constant([1, 2]) y = tf.constant([3, 4]) @@ -674,7 +615,9 @@ def run_in_graph_and_eager_modes(__unused__=None, Args: - __unused__: Prevents silently skipping tests. + func: function to be annotated. If `func` is None, this method returns a + decorator the can be applied to a function. If `func` is not None this + returns the decorator applied to `func`. config: An optional config_pb2.ConfigProto to use to configure the session when executing graphs. use_gpu: If True, attempt to run as many operations as possible on GPU. @@ -696,20 +639,19 @@ def run_in_graph_and_eager_modes(__unused__=None, eager execution enabled. """ - assert not __unused__, "Add () after run_in_graph_and_eager_modes." - def decorator(f): - def decorated(self, **kwargs): - with context.graph_mode(): - with self.test_session(use_gpu=use_gpu): - f(self, **kwargs) + if tf_inspect.isclass(f): + raise ValueError( + "`run_test_in_graph_and_eager_modes` only supports test methods. " + "Did you mean to use `run_all_tests_in_graph_and_eager_modes`?") - if reset_test: - # This decorator runs the wrapped test twice. - # Reset the test environment between runs. - self.tearDown() - self._tempdir = None - self.setUp() + def decorated(self, **kwargs): + try: + with context.graph_mode(): + with self.test_session(use_gpu=use_gpu, config=config): + f(self, **kwargs) + except unittest.case.SkipTest: + pass def run_eagerly(self, **kwargs): if not use_gpu: @@ -719,15 +661,25 @@ def run_in_graph_and_eager_modes(__unused__=None, f(self, **kwargs) if assert_no_eager_garbage: + ops.reset_default_graph() run_eagerly = assert_no_new_tensors( assert_no_garbage_created(run_eagerly)) with context.eager_mode(): - with ops.Graph().as_default(): - run_eagerly(self, **kwargs) + if reset_test: + # This decorator runs the wrapped test twice. + # Reset the test environment between runs. + self.tearDown() + self._tempdir = None + self.setUp() + + run_eagerly(self, **kwargs) return decorated + if func is not None: + return decorator(func) + return decorator @@ -910,14 +862,13 @@ class TensorFlowTestCase(googletest.TestCase): def _eval_tensor(self, tensor): if tensor is None: return None - elif isinstance(tensor, ops.EagerTensor): - return tensor.numpy() - elif isinstance(tensor, resource_variable_ops.ResourceVariable): - return tensor.read_value().numpy() elif callable(tensor): return self._eval_helper(tensor()) else: - raise ValueError("Unsupported type %s." % type(tensor)) + try: + return tensor.numpy() + except AttributeError as e: + six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e) def _eval_helper(self, tensors): if tensors is None: @@ -1019,7 +970,9 @@ class TensorFlowTestCase(googletest.TestCase): rewriter_config_pb2.RewriterConfig.OFF) return config - if graph is None: + if context.executing_eagerly(): + yield None + elif graph is None: if self._cached_session is None: self._cached_session = session.Session( graph=None, config=prepare_config(config)) @@ -1320,11 +1273,11 @@ class TensorFlowTestCase(googletest.TestCase): b, rtol=rtol, atol=atol, - msg="Mismatched value: a%s is different from b%s." % (path_str, - path_str)) + msg=("Mismatched value: a%s is different from b%s. %s" % + (path_str, path_str, msg))) except TypeError as e: - msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a), - path_str, type(b)) + msg = ("Error: a%s has %s, but b%s has %s. %s" % + (path_str, type(a), path_str, type(b), msg)) e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) raise diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 0f53762f6fab16f3fef2f5511a11acf53467e250..122c14c8473f133f6a3bed1e6297394eaa1b845c 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -569,7 +569,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(a_np_rand, b_np_rand) self.assertEqual(a_rand, b_rand) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_callable_evaluate(self): def model(): return resource_variable_ops.ResourceVariable( @@ -578,7 +578,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): with context.eager_mode(): self.assertEqual(2, self.evaluate(model)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_nested_tensors_evaluate(self): expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}} nested = {"a": constant_op.constant(1), @@ -588,6 +588,27 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(expected, self.evaluate(nested)) + def test_run_in_graph_and_eager_modes(self): + l = [] + def inc(self, with_brackets): + del self # self argument is required by run_in_graph_and_eager_modes. + mode = "eager" if context.executing_eagerly() else "graph" + with_brackets = "with_brackets" if with_brackets else "without_brackets" + l.append((with_brackets, mode)) + + f = test_util.run_in_graph_and_eager_modes(inc) + f(self, with_brackets=False) + f = test_util.run_in_graph_and_eager_modes()(inc) + f(self, with_brackets=True) + + self.assertEqual(len(l), 4) + self.assertEqual(set(l), { + ("with_brackets", "graph"), + ("with_brackets", "eager"), + ("without_brackets", "graph"), + ("without_brackets", "eager"), + }) + def test_get_node_def_from_graph(self): graph_def = graph_pb2.GraphDef() node_foo = graph_def.node.add() @@ -595,6 +616,55 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo) self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def)) + def test_run_in_eager_and_graph_modes_test_class(self): + msg = "`run_test_in_graph_and_eager_modes` only supports test methods.*" + with self.assertRaisesRegexp(ValueError, msg): + @test_util.run_in_graph_and_eager_modes() + class Foo(object): + pass + del Foo # Make pylint unused happy. + + def test_run_in_eager_and_graph_modes_skip_graph_runs_eager(self): + modes = [] + def _test(self): + if not context.executing_eagerly(): + self.skipTest("Skipping in graph mode") + modes.append("eager" if context.executing_eagerly() else "graph") + test_util.run_in_graph_and_eager_modes(_test)(self) + self.assertEqual(modes, ["eager"]) + + def test_run_in_eager_and_graph_modes_skip_eager_runs_graph(self): + modes = [] + def _test(self): + if context.executing_eagerly(): + self.skipTest("Skipping in eager mode") + modes.append("eager" if context.executing_eagerly() else "graph") + test_util.run_in_graph_and_eager_modes(_test)(self) + self.assertEqual(modes, ["graph"]) + + def test_run_in_graph_and_eager_modes_setup_in_same_mode(self): + modes = [] + mode_name = lambda: "eager" if context.executing_eagerly() else "graph" + + class ExampleTest(test_util.TensorFlowTestCase): + + def runTest(self): + pass + + def setUp(self): + modes.append("setup_" + mode_name()) + + @test_util.run_in_graph_and_eager_modes + def testBody(self): + modes.append("run_" + mode_name()) + + e = ExampleTest() + e.setUp() + e.testBody() + + self.assertEqual(modes[0:2], ["setup_graph", "run_graph"]) + self.assertEqual(modes[2:], ["setup_eager", "run_eager"]) + class GarbageCollectionTest(test_util.TensorFlowTestCase): @@ -619,6 +689,7 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase): ReferenceCycleTest().test_has_no_cycle() + @test_util.run_in_graph_and_eager_modes def test_no_leaked_tensor_decorator(self): class LeakedTensorTest(object): @@ -628,11 +699,11 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase): @test_util.assert_no_new_tensors def test_has_leak(self): - self.a = constant_op.constant([3.]) + self.a = constant_op.constant([3.], name="leak") @test_util.assert_no_new_tensors def test_has_no_leak(self): - constant_op.constant([3.]) + constant_op.constant([3.], name="no-leak") with self.assertRaisesRegexp(AssertionError, "Tensors not deallocated"): LeakedTensorTest().test_has_leak() diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 2d6925d1a825808ce133eb0404b5bd4925861723..7d07c77c797668c858014cc31cf713050627d72f 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -158,6 +158,7 @@ def _get_config(layout_optimizer=True): layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, # do not remove duplicated nodes arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) + rewrite_options.min_graph_nodes = -1 graph_options = config_pb2.GraphOptions( rewrite_options=rewrite_options, build_cost_model=1) config = config_pb2.ConfigProto(graph_options=graph_options) @@ -1389,7 +1390,7 @@ class LayoutOptimizerTest(test.TestCase): expected_num_transposes = 3 self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes) - self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes) + self._assert_trans_nchw_to_nhwc('map/while/Add_1-0-2', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testLoopWithVecAnd4D(self): @@ -1413,7 +1414,7 @@ class LayoutOptimizerTest(test.TestCase): expected_num_transposes = 2 self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes) - self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes) + self._assert_trans_nchw_to_nhwc('map/while/Add_1-0-2', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testBinaryOpSecondPort(self): @@ -1443,7 +1444,8 @@ class LayoutOptimizerTest(test.TestCase): def testGradient(self): meta_graph = _simple_metagraph() rewrite_options = rewriter_config_pb2.RewriterConfig( - layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON, + min_graph_nodes=-1) optimized_graph = tf_optimizer.OptimizeGraph( rewrite_options, meta_graph, cluster=_get_cluster()) @@ -1457,7 +1459,8 @@ class LayoutOptimizerTest(test.TestCase): def testDepthwise(self): meta_graph = _simple_metagraph(depthwise=True) rewrite_options = rewriter_config_pb2.RewriterConfig( - layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON, + min_graph_nodes=-1) optimized_graph = tf_optimizer.OptimizeGraph( rewrite_options, meta_graph, cluster=_get_cluster()) diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py index 7ed4b128e495c484d294ece40541427f21856cf1..b658edff2dffac9856432c575b9af0d2f0b1986b 100644 --- a/tensorflow/python/grappler/memory_optimizer_test.py +++ b/tensorflow/python/grappler/memory_optimizer_test.py @@ -76,7 +76,8 @@ class MemoryOptimizerSwapTest(test.TestCase): disable_model_pruning=True, meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE, constant_folding=rewriter_config_pb2.RewriterConfig.OFF, - memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL) + memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL, + min_graph_nodes=-1) graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) self.assertEqual(len(graph.node), graph_size + 2) @@ -133,6 +134,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase): dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + min_graph_nodes=-1, memory_optimization=rewriter_config_pb2.RewriterConfig. RECOMPUTATION_HEURISTICS), original_metagraph) self.assertGreater( @@ -158,6 +160,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase): dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + min_graph_nodes=-1, memory_optimization=rewriter_config_pb2.RewriterConfig. RECOMPUTATION_HEURISTICS, # Checks that name scope "gradients/" also match sub-scope. @@ -297,6 +300,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase): if 'Recomputed/' in node.name])) rewritten_graph_def = tf_optimizer.OptimizeGraph( rewriter_config_pb2.RewriterConfig( + min_graph_nodes=-1, memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL), metagraph) self.assertEqual( diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py index 1c0f072dd32d38f048cfa48d38b45264951d095e..5a9afe725753749ea42d53382731ab14a3cf24f5 100644 --- a/tensorflow/python/grappler/tf_optimizer_test.py +++ b/tensorflow/python/grappler/tf_optimizer_test.py @@ -47,6 +47,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): rewriter_config = rewriter_config_pb2.RewriterConfig() rewriter_config.optimizers.append('constfold') + rewriter_config.min_graph_nodes = -1 graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) @@ -68,6 +69,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): # Optimize the graph. mg = meta_graph.create_meta_graph_def(graph=g) rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.min_graph_nodes = -1 optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) # Check that the nodes referenced in various collections have been preserved @@ -109,6 +111,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): # Optimize the graph. mg = meta_graph.create_meta_graph_def(graph=g) rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.min_graph_nodes = -1 optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) mg.graph_def.CopyFrom(optimized_graph) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index b4213f0836e25d72cd61d9c94ebdb63649a4bc38..8b6b28bc776fa500a93d0a3fb3bf91081ba86967 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -19,74 +19,38 @@ py_library( name = "keras", srcs = [ "__init__.py", - "_impl/keras/__init__.py", - "_impl/keras/applications/__init__.py", - "_impl/keras/applications/densenet.py", - "_impl/keras/applications/imagenet_utils.py", - "_impl/keras/applications/inception_resnet_v2.py", - "_impl/keras/applications/inception_v3.py", - "_impl/keras/applications/mobilenet.py", - "_impl/keras/applications/nasnet.py", - "_impl/keras/applications/resnet50.py", - "_impl/keras/applications/vgg16.py", - "_impl/keras/applications/vgg19.py", - "_impl/keras/applications/xception.py", - "_impl/keras/datasets/__init__.py", - "_impl/keras/datasets/boston_housing.py", - "_impl/keras/datasets/cifar.py", - "_impl/keras/datasets/cifar10.py", - "_impl/keras/datasets/cifar100.py", - "_impl/keras/datasets/fashion_mnist.py", - "_impl/keras/datasets/imdb.py", - "_impl/keras/datasets/mnist.py", - "_impl/keras/datasets/reuters.py", - "_impl/keras/preprocessing/__init__.py", - "_impl/keras/preprocessing/image.py", - "_impl/keras/preprocessing/sequence.py", - "_impl/keras/preprocessing/text.py", - "_impl/keras/testing_utils.py", - "_impl/keras/utils/__init__.py", - "_impl/keras/utils/multi_gpu_utils.py", - "_impl/keras/utils/np_utils.py", - "_impl/keras/utils/vis_utils.py", - "_impl/keras/wrappers/__init__.py", - "_impl/keras/wrappers/scikit_learn.py", - "activations/__init__.py", "applications/__init__.py", - "applications/densenet/__init__.py", - "applications/inception_resnet_v2/__init__.py", - "applications/inception_v3/__init__.py", - "applications/mobilenet/__init__.py", - "applications/nasnet/__init__.py", - "applications/resnet50/__init__.py", - "applications/vgg16/__init__.py", - "applications/vgg19/__init__.py", - "applications/xception/__init__.py", - "backend/__init__.py", - "callbacks/__init__.py", - "constraints/__init__.py", + "applications/densenet.py", + "applications/imagenet_utils.py", + "applications/inception_resnet_v2.py", + "applications/inception_v3.py", + "applications/mobilenet.py", + "applications/nasnet.py", + "applications/resnet50.py", + "applications/vgg16.py", + "applications/vgg19.py", + "applications/xception.py", "datasets/__init__.py", - "datasets/boston_housing/__init__.py", - "datasets/cifar10/__init__.py", - "datasets/cifar100/__init__.py", - "datasets/fashion_mnist/__init__.py", - "datasets/imdb/__init__.py", - "datasets/mnist/__init__.py", - "datasets/reuters/__init__.py", - "initializers/__init__.py", - "layers/__init__.py", - "losses/__init__.py", - "metrics/__init__.py", - "models/__init__.py", - "optimizers/__init__.py", + "datasets/boston_housing.py", + "datasets/cifar.py", + "datasets/cifar10.py", + "datasets/cifar100.py", + "datasets/fashion_mnist.py", + "datasets/imdb.py", + "datasets/mnist.py", + "datasets/reuters.py", + "estimator/__init__.py", "preprocessing/__init__.py", - "preprocessing/image/__init__.py", - "preprocessing/sequence/__init__.py", - "preprocessing/text/__init__.py", - "regularizers/__init__.py", + "preprocessing/image.py", + "preprocessing/sequence.py", + "preprocessing/text.py", + "testing_utils.py", "utils/__init__.py", + "utils/multi_gpu_utils.py", + "utils/np_utils.py", + "utils/vis_utils.py", "wrappers/__init__.py", - "wrappers/scikit_learn/__init__.py", + "wrappers/scikit_learn.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], @@ -104,7 +68,7 @@ py_library( py_library( name = "backend", - srcs = ["_impl/keras/backend.py"], + srcs = ["backend.py"], srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", @@ -145,33 +109,34 @@ py_library( py_library( name = "engine", srcs = [ - "_impl/keras/activations.py", - "_impl/keras/callbacks.py", - "_impl/keras/constraints.py", - "_impl/keras/engine/__init__.py", - "_impl/keras/engine/base_layer.py", - "_impl/keras/engine/input_layer.py", - "_impl/keras/engine/network.py", - "_impl/keras/engine/saving.py", - "_impl/keras/engine/sequential.py", - "_impl/keras/engine/training.py", - "_impl/keras/engine/training_arrays.py", - "_impl/keras/engine/training_eager.py", - "_impl/keras/engine/training_generator.py", - "_impl/keras/engine/training_utils.py", - "_impl/keras/initializers.py", - "_impl/keras/losses.py", - "_impl/keras/metrics.py", - "_impl/keras/models.py", - "_impl/keras/optimizers.py", - "_impl/keras/regularizers.py", - "_impl/keras/utils/data_utils.py", - "_impl/keras/utils/io_utils.py", + "activations.py", + "callbacks.py", + "constraints.py", + "engine/__init__.py", + "engine/base_layer.py", + "engine/input_layer.py", + "engine/network.py", + "engine/saving.py", + "engine/sequential.py", + "engine/training.py", + "engine/training_arrays.py", + "engine/training_eager.py", + "engine/training_generator.py", + "engine/training_utils.py", + "initializers.py", + "losses.py", + "metrics.py", + "models.py", + "optimizers.py", + "regularizers.py", + "utils/data_utils.py", + "utils/io_utils.py", ], srcs_version = "PY2AND3", deps = [ ":backend", "//tensorflow/python/data", + "//tensorflow/python/training/checkpointable:data_structures", "@six_archive//:six", ], ) @@ -179,25 +144,25 @@ py_library( py_library( name = "layers", srcs = [ - "_impl/keras/layers/__init__.py", - "_impl/keras/layers/advanced_activations.py", - "_impl/keras/layers/convolutional.py", - "_impl/keras/layers/convolutional_recurrent.py", - "_impl/keras/layers/core.py", - "_impl/keras/layers/cudnn_recurrent.py", - "_impl/keras/layers/embeddings.py", - "_impl/keras/layers/local.py", - "_impl/keras/layers/merge.py", - "_impl/keras/layers/noise.py", - "_impl/keras/layers/normalization.py", - "_impl/keras/layers/pooling.py", - "_impl/keras/layers/recurrent.py", - "_impl/keras/layers/serialization.py", - "_impl/keras/layers/wrappers.py", - "_impl/keras/utils/conv_utils.py", - "_impl/keras/utils/generic_utils.py", - "_impl/keras/utils/layer_utils.py", - "_impl/keras/utils/tf_utils.py", + "layers/__init__.py", + "layers/advanced_activations.py", + "layers/convolutional.py", + "layers/convolutional_recurrent.py", + "layers/core.py", + "layers/cudnn_recurrent.py", + "layers/embeddings.py", + "layers/local.py", + "layers/merge.py", + "layers/noise.py", + "layers/normalization.py", + "layers/pooling.py", + "layers/recurrent.py", + "layers/serialization.py", + "layers/wrappers.py", + "utils/conv_utils.py", + "utils/generic_utils.py", + "utils/layer_utils.py", + "utils/tf_utils.py", ], srcs_version = "PY2AND3", deps = [ @@ -224,7 +189,7 @@ py_library( py_test( name = "integration_test", size = "medium", - srcs = ["_impl/keras/integration_test.py"], + srcs = ["integration_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -239,7 +204,7 @@ py_test( py_test( name = "activations_test", size = "small", - srcs = ["_impl/keras/activations_test.py"], + srcs = ["activations_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -251,7 +216,7 @@ py_test( py_test( name = "constraints_test", size = "small", - srcs = ["_impl/keras/constraints_test.py"], + srcs = ["constraints_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -263,7 +228,7 @@ py_test( py_test( name = "initializers_test", size = "small", - srcs = ["_impl/keras/initializers_test.py"], + srcs = ["initializers_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -276,7 +241,7 @@ py_test( py_test( name = "regularizers_test", size = "small", - srcs = ["_impl/keras/regularizers_test.py"], + srcs = ["regularizers_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -287,7 +252,7 @@ py_test( py_test( name = "optimizers_test", size = "medium", - srcs = ["_impl/keras/optimizers_test.py"], + srcs = ["optimizers_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -301,7 +266,7 @@ py_test( py_test( name = "losses_test", size = "small", - srcs = ["_impl/keras/losses_test.py"], + srcs = ["losses_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -313,7 +278,7 @@ py_test( py_test( name = "metrics_test", size = "medium", - srcs = ["_impl/keras/metrics_test.py"], + srcs = ["metrics_test.py"], srcs_version = "PY2AND3", tags = [ "manual", @@ -330,7 +295,7 @@ py_test( py_test( name = "densenet_test", size = "large", - srcs = ["_impl/keras/applications/densenet_test.py"], + srcs = ["applications/densenet_test.py"], srcs_version = "PY2AND3", tags = ["nomsan"], # times out, http://b/78650237 deps = [ @@ -343,7 +308,7 @@ py_test( py_test( name = "inception_resnet_v2_test", size = "medium", - srcs = ["_impl/keras/applications/inception_resnet_v2_test.py"], + srcs = ["applications/inception_resnet_v2_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -355,7 +320,7 @@ py_test( py_test( name = "inception_v3_test", size = "medium", - srcs = ["_impl/keras/applications/inception_v3_test.py"], + srcs = ["applications/inception_v3_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -367,7 +332,7 @@ py_test( py_test( name = "mobilenet_test", size = "medium", - srcs = ["_impl/keras/applications/mobilenet_test.py"], + srcs = ["applications/mobilenet_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -379,7 +344,7 @@ py_test( py_test( name = "nasnet_test", size = "large", - srcs = ["_impl/keras/applications/nasnet_test.py"], + srcs = ["applications/nasnet_test.py"], srcs_version = "PY2AND3", tags = ["nomsan"], # times out, http://b/78573625 deps = [ @@ -392,7 +357,7 @@ py_test( py_test( name = "resnet50_test", size = "medium", - srcs = ["_impl/keras/applications/resnet50_test.py"], + srcs = ["applications/resnet50_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -403,7 +368,7 @@ py_test( py_test( name = "vgg16_test", size = "small", - srcs = ["_impl/keras/applications/vgg16_test.py"], + srcs = ["applications/vgg16_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -414,7 +379,7 @@ py_test( py_test( name = "vgg19_test", size = "small", - srcs = ["_impl/keras/applications/vgg19_test.py"], + srcs = ["applications/vgg19_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -425,7 +390,7 @@ py_test( py_test( name = "xception_test", size = "medium", - srcs = ["_impl/keras/applications/xception_test.py"], + srcs = ["applications/xception_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -437,7 +402,7 @@ py_test( py_test( name = "advanced_activations_test", size = "small", - srcs = ["_impl/keras/layers/advanced_activations_test.py"], + srcs = ["layers/advanced_activations_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -448,7 +413,7 @@ py_test( py_test( name = "convolutional_recurrent_test", size = "large", - srcs = ["_impl/keras/layers/convolutional_recurrent_test.py"], + srcs = ["layers/convolutional_recurrent_test.py"], shard_count = 2, srcs_version = "PY2AND3", deps = [ @@ -461,7 +426,7 @@ py_test( py_test( name = "convolutional_test", size = "large", - srcs = ["_impl/keras/layers/convolutional_test.py"], + srcs = ["layers/convolutional_test.py"], srcs_version = "PY2AND3", tags = [ "manual", @@ -478,7 +443,7 @@ py_test( cuda_py_test( name = "cudnn_recurrent_test", size = "large", - srcs = ["_impl/keras/layers/cudnn_recurrent_test.py"], + srcs = ["layers/cudnn_recurrent_test.py"], additional_deps = [ ":keras", "@absl_py//absl/testing:parameterized", @@ -491,7 +456,7 @@ cuda_py_test( py_test( name = "pooling_test", size = "small", - srcs = ["_impl/keras/layers/pooling_test.py"], + srcs = ["layers/pooling_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -502,7 +467,7 @@ py_test( py_test( name = "core_test", size = "medium", - srcs = ["_impl/keras/layers/core_test.py"], + srcs = ["layers/core_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -514,7 +479,7 @@ py_test( py_test( name = "embeddings_test", size = "small", - srcs = ["_impl/keras/layers/embeddings_test.py"], + srcs = ["layers/embeddings_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -525,7 +490,7 @@ py_test( py_test( name = "local_test", size = "medium", - srcs = ["_impl/keras/layers/local_test.py"], + srcs = ["layers/local_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -537,7 +502,7 @@ py_test( py_test( name = "merge_test", size = "small", - srcs = ["_impl/keras/layers/merge_test.py"], + srcs = ["layers/merge_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -549,7 +514,7 @@ py_test( py_test( name = "noise_test", size = "small", - srcs = ["_impl/keras/layers/noise_test.py"], + srcs = ["layers/noise_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -560,7 +525,7 @@ py_test( py_test( name = "normalization_test", size = "medium", - srcs = ["_impl/keras/layers/normalization_test.py"], + srcs = ["layers/normalization_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -573,7 +538,7 @@ py_test( py_test( name = "simplernn_test", size = "medium", - srcs = ["_impl/keras/layers/simplernn_test.py"], + srcs = ["layers/simplernn_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -585,8 +550,8 @@ py_test( py_test( name = "gru_test", - size = "medium", - srcs = ["_impl/keras/layers/gru_test.py"], + size = "large", + srcs = ["layers/gru_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], # http://b/62136390 deps = [ @@ -599,7 +564,7 @@ py_test( py_test( name = "lstm_test", size = "medium", - srcs = ["_impl/keras/layers/lstm_test.py"], + srcs = ["layers/lstm_test.py"], shard_count = 4, srcs_version = "PY2AND3", tags = [ @@ -616,7 +581,7 @@ py_test( py_test( name = "recurrent_test", size = "medium", - srcs = ["_impl/keras/layers/recurrent_test.py"], + srcs = ["layers/recurrent_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -628,7 +593,7 @@ py_test( py_test( name = "serialization_test", size = "small", - srcs = ["_impl/keras/layers/serialization_test.py"], + srcs = ["layers/serialization_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -639,7 +604,7 @@ py_test( py_test( name = "wrappers_test", size = "medium", - srcs = ["_impl/keras/layers/wrappers_test.py"], + srcs = ["layers/wrappers_test.py"], shard_count = 4, srcs_version = "PY2AND3", tags = [ @@ -656,7 +621,7 @@ py_test( py_test( name = "scikit_learn_test", size = "small", - srcs = ["_impl/keras/wrappers/scikit_learn_test.py"], + srcs = ["wrappers/scikit_learn_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -669,7 +634,7 @@ py_test( py_test( name = "data_utils_test", size = "large", - srcs = ["_impl/keras/utils/data_utils_test.py"], + srcs = ["utils/data_utils_test.py"], srcs_version = "PY2AND3", tags = [ "no_oss", @@ -688,7 +653,7 @@ py_test( py_test( name = "generic_utils_test", size = "small", - srcs = ["_impl/keras/utils/generic_utils_test.py"], + srcs = ["utils/generic_utils_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -699,7 +664,7 @@ py_test( py_test( name = "io_utils_test", size = "small", - srcs = ["_impl/keras/utils/io_utils_test.py"], + srcs = ["utils/io_utils_test.py"], srcs_version = "PY2AND3", tags = [ "no_windows", # TODO: needs investigation on Windows @@ -715,7 +680,7 @@ py_test( py_test( name = "np_utils_test", size = "small", - srcs = ["_impl/keras/utils/np_utils_test.py"], + srcs = ["utils/np_utils_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -726,7 +691,7 @@ py_test( cuda_py_test( name = "multi_gpu_utils_test", - srcs = ["_impl/keras/utils/multi_gpu_utils_test.py"], + srcs = ["utils/multi_gpu_utils_test.py"], additional_deps = [ ":keras", "//third_party/py/numpy", @@ -741,7 +706,7 @@ cuda_py_test( py_test( name = "imagenet_utils_test", size = "small", - srcs = ["_impl/keras/applications/imagenet_utils_test.py"], + srcs = ["applications/imagenet_utils_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -753,7 +718,7 @@ py_test( py_test( name = "image_test", size = "medium", - srcs = ["_impl/keras/preprocessing/image_test.py"], + srcs = ["preprocessing/image_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -765,7 +730,7 @@ py_test( py_test( name = "sequence_test", size = "small", - srcs = ["_impl/keras/preprocessing/sequence_test.py"], + srcs = ["preprocessing/sequence_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -777,7 +742,7 @@ py_test( py_test( name = "text_test", size = "small", - srcs = ["_impl/keras/preprocessing/text_test.py"], + srcs = ["preprocessing/text_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -789,7 +754,7 @@ py_test( py_test( name = "callbacks_test", size = "medium", - srcs = ["_impl/keras/callbacks_test.py"], + srcs = ["callbacks_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -802,7 +767,7 @@ py_test( py_test( name = "training_test", size = "medium", - srcs = ["_impl/keras/engine/training_test.py"], + srcs = ["engine/training_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -815,7 +780,7 @@ py_test( py_test( name = "training_eager_test", size = "medium", - srcs = ["_impl/keras/engine/training_eager_test.py"], + srcs = ["engine/training_eager_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -828,7 +793,7 @@ py_test( py_test( name = "model_subclassing_test", size = "medium", - srcs = ["_impl/keras/model_subclassing_test.py"], + srcs = ["model_subclassing_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -841,7 +806,7 @@ py_test( py_test( name = "topology_test", size = "small", - srcs = ["_impl/keras/engine/topology_test.py"], + srcs = ["engine/topology_test.py"], srcs_version = "PY2AND3", tags = [ "no-internal-py3", @@ -856,7 +821,7 @@ py_test( py_test( name = "saving_test", size = "medium", - srcs = ["_impl/keras/engine/saving_test.py"], + srcs = ["engine/saving_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -869,7 +834,7 @@ py_test( py_test( name = "sequential_test", size = "small", - srcs = ["_impl/keras/engine/sequential_test.py"], + srcs = ["engine/sequential_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", @@ -881,7 +846,7 @@ py_test( py_test( name = "models_test", size = "small", - srcs = ["_impl/keras/models_test.py"], + srcs = ["models_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], # b/67509773 deps = [ @@ -894,21 +859,22 @@ py_test( py_test( name = "backend_test", - size = "small", - srcs = ["_impl/keras/backend_test.py"], + size = "medium", + srcs = ["backend_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", "//tensorflow/python:client_testlib", "//tensorflow/python:util", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) py_library( name = "testing_utils", srcs = [ - "_impl/keras/testing_utils.py", + "testing_utils.py", ], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py index f76cfa66082c3a323f3ed8f22684de277f273a67..198c66d9e184c82423e529540b92ad447b947cf8 100644 --- a/tensorflow/python/keras/__init__.py +++ b/tensorflow/python/keras/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,13 +21,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=wildcard-import from tensorflow.python.keras import activations from tensorflow.python.keras import applications from tensorflow.python.keras import backend from tensorflow.python.keras import callbacks from tensorflow.python.keras import constraints from tensorflow.python.keras import datasets +from tensorflow.python.keras import estimator from tensorflow.python.keras import initializers from tensorflow.python.keras import layers from tensorflow.python.keras import losses @@ -39,11 +38,16 @@ from tensorflow.python.keras import preprocessing from tensorflow.python.keras import regularizers from tensorflow.python.keras import utils from tensorflow.python.keras import wrappers -from tensorflow.python.keras._impl.keras import __version__ from tensorflow.python.keras.layers import Input from tensorflow.python.keras.models import Model from tensorflow.python.keras.models import Sequential +from tensorflow.python.util.tf_export import tf_export + +__version__ = '2.1.6-tf' + +tf_export('keras.__version__').export_constant(__name__, '__version__') + del absolute_import del division del print_function diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py deleted file mode 100644 index 9bb140bfb86649c7b3b3263478033e9632776480..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/_impl/keras/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""The Keras API. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras import activations -from tensorflow.python.keras._impl.keras import applications -from tensorflow.python.keras._impl.keras import backend -from tensorflow.python.keras._impl.keras import callbacks -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import datasets -from tensorflow.python.keras._impl.keras import engine -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import layers -from tensorflow.python.keras._impl.keras import losses -from tensorflow.python.keras._impl.keras import metrics -from tensorflow.python.keras._impl.keras import models -from tensorflow.python.keras._impl.keras import optimizers -from tensorflow.python.keras._impl.keras import preprocessing -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras import utils -from tensorflow.python.keras._impl.keras import wrappers -from tensorflow.python.keras._impl.keras.layers import Input -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.models import Sequential - -__version__ = '2.1.6-tf' diff --git a/tensorflow/python/keras/_impl/keras/applications/__init__.py b/tensorflow/python/keras/_impl/keras/applications/__init__.py deleted file mode 100644 index 206a769b377483c65a78b76fe44055eb50bdc7c4..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/_impl/keras/applications/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras Applications: models with automatic loading of pre-trained weights. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201 -from tensorflow.python.keras._impl.keras.applications.inception_resnet_v2 import InceptionResNetV2 -from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 -from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet -from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetLarge -from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetMobile -from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50 -from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16 -from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19 -from tensorflow.python.keras._impl.keras.applications.xception import Xception diff --git a/tensorflow/python/keras/_impl/keras/layers/__init__.py b/tensorflow/python/keras/_impl/keras/layers/__init__.py deleted file mode 100644 index d7bc859280eeedfb41d2c78d4042a181484f3d20..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/_impl/keras/layers/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras layers module. -""" -# pylint: disable=wildcard-import -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.engine import Input -from tensorflow.python.keras._impl.keras.engine import InputLayer -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.layers.advanced_activations import * -from tensorflow.python.keras._impl.keras.layers.convolutional import * -from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import * -from tensorflow.python.keras._impl.keras.layers.core import * -from tensorflow.python.keras._impl.keras.layers.cudnn_recurrent import * -from tensorflow.python.keras._impl.keras.layers.embeddings import * -from tensorflow.python.keras._impl.keras.layers.local import * -from tensorflow.python.keras._impl.keras.layers.merge import * -from tensorflow.python.keras._impl.keras.layers.noise import * -from tensorflow.python.keras._impl.keras.layers.normalization import * -from tensorflow.python.keras._impl.keras.layers.pooling import * -from tensorflow.python.keras._impl.keras.layers.recurrent import * -from tensorflow.python.keras._impl.keras.layers.serialization import deserialize -from tensorflow.python.keras._impl.keras.layers.serialization import serialize -from tensorflow.python.keras._impl.keras.layers.wrappers import * diff --git a/tensorflow/python/keras/_impl/keras/utils/__init__.py b/tensorflow/python/keras/_impl/keras/utils/__init__.py deleted file mode 100644 index 0c9f19a0c8dcf3bf929e102b31679a03b27728f7..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/_impl/keras/utils/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras utilities. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file -from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope -from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix -from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary -from tensorflow.python.keras._impl.keras.utils.multi_gpu_utils import multi_gpu_model -from tensorflow.python.keras._impl.keras.utils.np_utils import normalize -from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model - diff --git a/tensorflow/python/keras/_impl/keras/activations.py b/tensorflow/python/keras/activations.py similarity index 62% rename from tensorflow/python/keras/_impl/keras/activations.py rename to tensorflow/python/keras/activations.py index 8def7ec49375c7ce23e8f2a24a4c3615d05ca9bb..f608dea430f0573503713f0cbc60f8921e6df51e 100644 --- a/tensorflow/python/keras/_impl/keras/activations.py +++ b/tensorflow/python/keras/activations.py @@ -20,8 +20,8 @@ from __future__ import print_function import six -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.util.tf_export import tf_export @@ -32,7 +32,7 @@ def softmax(x, axis=-1): """Softmax activation function. Arguments: - x : Tensor. + x : Input tensor. axis: Integer, axis along which the softmax normalization is applied. Returns: @@ -49,28 +49,52 @@ def softmax(x, axis=-1): s = math_ops.reduce_sum(e, axis=axis, keepdims=True) return e / s else: - raise ValueError('Cannot apply softmax to a tensor that is 1D') + raise ValueError('Cannot apply softmax to a tensor that is 1D. ' + 'Received input: %s' % (x,)) @tf_export('keras.activations.elu') def elu(x, alpha=1.0): + """Exponential linear unit. + + Arguments: + x: Input tensor. + alpha: A scalar, slope of negative section. + + Returns: + The exponential linear activation: `x` if `x > 0` and + `alpha * (exp(x)-1)` if `x < 0`. + + Reference: + - [Fast and Accurate Deep Network Learning by Exponential + Linear Units (ELUs)](https://arxiv.org/abs/1511.07289) + """ return K.elu(x, alpha) @tf_export('keras.activations.selu') def selu(x): - """Scaled Exponential Linear Unit. (Klambauer et al., 2017). + """Scaled Exponential Linear Unit (SELU). + + SELU is equal to: `scale * elu(x, alpha)`, where alpha and scale + are pre-defined constants. The values of `alpha` and `scale` are + chosen so that the mean and variance of the inputs are preserved + between two consecutive layers as long as the weights are initialized + correctly (see `lecun_normal` initialization) and the number of inputs + is "large enough" (see references for more information). Arguments: x: A tensor or variable to compute the activation function for. Returns: - Tensor with the same shape and dtype as `x`. + The scaled exponential unit activation: `scale * elu(x, alpha)`. # Note - To be used together with the initialization "lecun_normal". - To be used together with the dropout variant "AlphaDropout". + References: + - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) """ alpha = 1.6732632423543772848170429916717 scale = 1.0507009873554804934193349852946 @@ -79,16 +103,44 @@ def selu(x): @tf_export('keras.activations.softplus') def softplus(x): + """Softplus activation function. + + Arguments: + x: Input tensor. + + Returns: + The softplus activation: `log(exp(x) + 1)`. + """ return nn.softplus(x) @tf_export('keras.activations.softsign') def softsign(x): + """Softsign activation function. + + Arguments: + x: Input tensor. + + Returns: + The softplus activation: `x / (abs(x) + 1)`. + """ return nn.softsign(x) @tf_export('keras.activations.relu') def relu(x, alpha=0., max_value=None): + """Rectified Linear Unit. + + Arguments: + x: Input tensor. + alpha: Slope of the negative part. Defaults to zero. + max_value: Maximum value for the output. + + Returns: + The (leaky) rectified linear unit activation: `x` if `x > 0`, + `alpha * x` if `x < 0`. If `max_value` is defined, the result + is truncated to this value. + """ return K.relu(x, alpha=alpha, max_value=max_value) @@ -104,6 +156,19 @@ def sigmoid(x): @tf_export('keras.activations.hard_sigmoid') def hard_sigmoid(x): + """Hard sigmoid activation function. + + Faster to compute than sigmoid activation. + + Arguments: + x: Input tensor. + + Returns: + Hard sigmoid activation: + - `0` if `x < -2.5` + - `1` if `x > 2.5` + - `0.2 * x + 0.5` if `-2.5 <= x <= 2.5`. + """ return K.hard_sigmoid(x) diff --git a/tensorflow/python/keras/activations/__init__.py b/tensorflow/python/keras/activations/__init__.py deleted file mode 100644 index d04838c218d6643a703723a1d163c88547c14da7..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/activations/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras built-in activation functions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Activation functions. -from tensorflow.python.keras._impl.keras.activations import elu -from tensorflow.python.keras._impl.keras.activations import hard_sigmoid -from tensorflow.python.keras._impl.keras.activations import linear -from tensorflow.python.keras._impl.keras.activations import relu -from tensorflow.python.keras._impl.keras.activations import selu -from tensorflow.python.keras._impl.keras.activations import sigmoid -from tensorflow.python.keras._impl.keras.activations import softmax -from tensorflow.python.keras._impl.keras.activations import softplus -from tensorflow.python.keras._impl.keras.activations import softsign -from tensorflow.python.keras._impl.keras.activations import tanh - -# Auxiliary utils. -# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.activations import deserialize -from tensorflow.python.keras._impl.keras.activations import serialize -from tensorflow.python.keras._impl.keras.activations import get - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/activations_test.py b/tensorflow/python/keras/activations_test.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/activations_test.py rename to tensorflow/python/keras/activations_test.py index fb0bb5f1269d112e3f268ce211a2ddeb24b417bf..5cff1f8f9cb06569029150e44a4c2adfb370229d 100644 --- a/tensorflow/python/keras/_impl/keras/activations_test.py +++ b/tensorflow/python/keras/activations_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py index fccedf919a7b261bb30f332172b1388db9da1939..062135266dd8b11c489b7dff83b46ae29a0d21e6 100644 --- a/tensorflow/python/keras/applications/__init__.py +++ b/tensorflow/python/keras/applications/__init__.py @@ -18,15 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras.applications import densenet -from tensorflow.python.keras.applications import inception_resnet_v2 -from tensorflow.python.keras.applications import inception_v3 -from tensorflow.python.keras.applications import mobilenet -from tensorflow.python.keras.applications import nasnet -from tensorflow.python.keras.applications import resnet50 -from tensorflow.python.keras.applications import vgg16 -from tensorflow.python.keras.applications import vgg19 -from tensorflow.python.keras.applications import xception from tensorflow.python.keras.applications.densenet import DenseNet121 from tensorflow.python.keras.applications.densenet import DenseNet169 from tensorflow.python.keras.applications.densenet import DenseNet201 diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py similarity index 90% rename from tensorflow/python/keras/_impl/keras/applications/densenet.py rename to tensorflow/python/keras/applications/densenet.py index ca83e8691237216e799f2ca738dcb6822506e2cb..8df6d086111c4b179d2f0c7b5c1130a6cd95aaab 100644 --- a/tensorflow/python/keras/_impl/keras/applications/densenet.py +++ b/tensorflow/python/keras/applications/densenet.py @@ -27,24 +27,24 @@ from __future__ import print_function import os -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.applications import imagenet_utils -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs -from tensorflow.python.keras._impl.keras.layers import Activation -from tensorflow.python.keras._impl.keras.layers import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers import BatchNormalization -from tensorflow.python.keras._impl.keras.layers import Concatenate -from tensorflow.python.keras._impl.keras.layers import Conv2D -from tensorflow.python.keras._impl.keras.layers import Dense -from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers import Input -from tensorflow.python.keras._impl.keras.layers import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.applications import imagenet_utils +from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras.layers import Activation +from tensorflow.python.keras.layers import AveragePooling2D +from tensorflow.python.keras.layers import BatchNormalization +from tensorflow.python.keras.layers import Concatenate +from tensorflow.python.keras.layers import Conv2D +from tensorflow.python.keras.layers import Dense +from tensorflow.python.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras.layers import Input +from tensorflow.python.keras.layers import MaxPooling2D +from tensorflow.python.keras.layers import ZeroPadding2D +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.util.tf_export import tf_export @@ -238,7 +238,7 @@ def DenseNet(blocks, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet_test.py b/tensorflow/python/keras/applications/densenet_test.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/applications/densenet_test.py rename to tensorflow/python/keras/applications/densenet_test.py index 3b92287a1e77a944c069a6c234e11e4a79ad7d32..8b6aa281ad0e2d0798952b7489c89892709cda29 100644 --- a/tensorflow/python/keras/_impl/keras/applications/densenet_test.py +++ b/tensorflow/python/keras/applications/densenet_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py b/tensorflow/python/keras/applications/imagenet_utils.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py rename to tensorflow/python/keras/applications/imagenet_utils.py index d928a7afdc639485d443be382420cac09ba9abd6..0d8ccca1b5c2a6c05f0d933a8f0fe176ea62c2a3 100644 --- a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py +++ b/tensorflow/python/keras/applications/imagenet_utils.py @@ -23,8 +23,8 @@ import json import numpy as np from tensorflow.python.framework import constant_op -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils_test.py b/tensorflow/python/keras/applications/imagenet_utils_test.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/applications/imagenet_utils_test.py rename to tensorflow/python/keras/applications/imagenet_utils_test.py index d843dace59f1c88744217fbaee605d2ac859ec55..349339309017f3e9e3a9922d95188f1954ed8634 100644 --- a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils_test.py +++ b/tensorflow/python/keras/applications/imagenet_utils_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input +from tensorflow.python import keras +from tensorflow.python.keras.applications.imagenet_utils import preprocess_input from tensorflow.python.platform import test @@ -197,4 +197,3 @@ class ImageNetUtilsTest(test.TestCase): if __name__ == '__main__': test.main() - diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py similarity index 91% rename from tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py rename to tensorflow/python/keras/applications/inception_resnet_v2.py index 17e407dd58460e6d6802a3e137a96faf38a6f576..14e3b6aa60dbfa7e62e04849d35633eed162a416 100644 --- a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py +++ b/tensorflow/python/keras/applications/inception_resnet_v2.py @@ -27,24 +27,24 @@ from __future__ import print_function import os -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.applications import imagenet_utils -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs -from tensorflow.python.keras._impl.keras.layers import Activation -from tensorflow.python.keras._impl.keras.layers import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers import BatchNormalization -from tensorflow.python.keras._impl.keras.layers import Concatenate -from tensorflow.python.keras._impl.keras.layers import Conv2D -from tensorflow.python.keras._impl.keras.layers import Dense -from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers import Input -from tensorflow.python.keras._impl.keras.layers import Lambda -from tensorflow.python.keras._impl.keras.layers import MaxPooling2D -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.applications import imagenet_utils +from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras.layers import Activation +from tensorflow.python.keras.layers import AveragePooling2D +from tensorflow.python.keras.layers import BatchNormalization +from tensorflow.python.keras.layers import Concatenate +from tensorflow.python.keras.layers import Conv2D +from tensorflow.python.keras.layers import Dense +from tensorflow.python.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras.layers import Input +from tensorflow.python.keras.layers import Lambda +from tensorflow.python.keras.layers import MaxPooling2D +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -354,7 +354,7 @@ def InceptionResNetV2(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor` if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2_test.py b/tensorflow/python/keras/applications/inception_resnet_v2_test.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2_test.py rename to tensorflow/python/keras/applications/inception_resnet_v2_test.py index de71e9615a09ecdf07a51fff0b3ee3b1d8ca50ca..0a12f885052ae9530e82190f7580c8288860c9a8 100644 --- a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2_test.py +++ b/tensorflow/python/keras/applications/inception_resnet_v2_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py similarity index 91% rename from tensorflow/python/keras/_impl/keras/applications/inception_v3.py rename to tensorflow/python/keras/applications/inception_v3.py index 2897c6058eb445ceacc34084b53dc89f556e3e9c..b5e28c781f71e67b8d835b50070b49add2d7930a 100644 --- a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py +++ b/tensorflow/python/keras/applications/inception_v3.py @@ -32,23 +32,23 @@ from __future__ import print_function import os -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import layers -from tensorflow.python.keras._impl.keras.applications import imagenet_utils -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs -from tensorflow.python.keras._impl.keras.layers import Activation -from tensorflow.python.keras._impl.keras.layers import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers import BatchNormalization -from tensorflow.python.keras._impl.keras.layers import Conv2D -from tensorflow.python.keras._impl.keras.layers import Dense -from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers import Input -from tensorflow.python.keras._impl.keras.layers import MaxPooling2D -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import layers +from tensorflow.python.keras.applications import imagenet_utils +from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras.layers import Activation +from tensorflow.python.keras.layers import AveragePooling2D +from tensorflow.python.keras.layers import BatchNormalization +from tensorflow.python.keras.layers import Conv2D +from tensorflow.python.keras.layers import Dense +from tensorflow.python.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras.layers import Input +from tensorflow.python.keras.layers import MaxPooling2D +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -375,7 +375,7 @@ def InceptionV3(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input # Create model. diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_v3_test.py b/tensorflow/python/keras/applications/inception_v3_test.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/applications/inception_v3_test.py rename to tensorflow/python/keras/applications/inception_v3_test.py index 20e11fa019134423cc7c0499e7507680e13cb86d..a3fcdd55644af5a2211b58169d87ab4fba996b19 100644 --- a/tensorflow/python/keras/_impl/keras/applications/inception_v3_test.py +++ b/tensorflow/python/keras/applications/inception_v3_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py similarity index 92% rename from tensorflow/python/keras/_impl/keras/applications/mobilenet.py rename to tensorflow/python/keras/applications/mobilenet.py index 18a0612e13838b77f43a9eb39b1b1ad0ee7e9359..e56c695a288026d12de6bc0bdb65706c71eefe14 100644 --- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/applications/mobilenet.py @@ -71,28 +71,28 @@ from __future__ import print_function import os -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.applications import imagenet_utils -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs -from tensorflow.python.keras._impl.keras.layers import Activation -from tensorflow.python.keras._impl.keras.layers import BatchNormalization -from tensorflow.python.keras._impl.keras.layers import Conv2D -from tensorflow.python.keras._impl.keras.layers import DepthwiseConv2D -from tensorflow.python.keras._impl.keras.layers import Dropout -from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers import Input -from tensorflow.python.keras._impl.keras.layers import Reshape -from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.utils import conv_utils -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.applications import imagenet_utils +from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.layers import Activation +from tensorflow.python.keras.layers import BatchNormalization +from tensorflow.python.keras.layers import Conv2D +from tensorflow.python.keras.layers import DepthwiseConv2D +from tensorflow.python.keras.layers import Dropout +from tensorflow.python.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras.layers import Input +from tensorflow.python.keras.layers import Reshape +from tensorflow.python.keras.layers import ZeroPadding2D +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -317,7 +317,7 @@ def MobileNet(input_shape=None, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet_test.py b/tensorflow/python/keras/applications/mobilenet_test.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/applications/mobilenet_test.py rename to tensorflow/python/keras/applications/mobilenet_test.py index 601d417e496b8230a2ad846eab204763ff5564b8..5661ed7856ad6e307cf3e388ea3db98c69db983f 100644 --- a/tensorflow/python/keras/_impl/keras/applications/mobilenet_test.py +++ b/tensorflow/python/keras/applications/mobilenet_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py similarity index 94% rename from tensorflow/python/keras/_impl/keras/applications/nasnet.py rename to tensorflow/python/keras/applications/nasnet.py index f3412d71be525d704e8e0d5f21f3c3941f59a066..ff79b3a057b8fd6ab3b0edf652a5bede0e2d7b87 100644 --- a/tensorflow/python/keras/_impl/keras/applications/nasnet.py +++ b/tensorflow/python/keras/applications/nasnet.py @@ -45,27 +45,27 @@ from __future__ import print_function import os -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input -from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs -from tensorflow.python.keras._impl.keras.layers import Activation -from tensorflow.python.keras._impl.keras.layers import add -from tensorflow.python.keras._impl.keras.layers import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers import BatchNormalization -from tensorflow.python.keras._impl.keras.layers import concatenate -from tensorflow.python.keras._impl.keras.layers import Conv2D -from tensorflow.python.keras._impl.keras.layers import Cropping2D -from tensorflow.python.keras._impl.keras.layers import Dense -from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers import Input -from tensorflow.python.keras._impl.keras.layers import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers import SeparableConv2D -from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras.applications.inception_v3 import preprocess_input +from tensorflow.python.keras.layers import Activation +from tensorflow.python.keras.layers import add +from tensorflow.python.keras.layers import AveragePooling2D +from tensorflow.python.keras.layers import BatchNormalization +from tensorflow.python.keras.layers import concatenate +from tensorflow.python.keras.layers import Conv2D +from tensorflow.python.keras.layers import Cropping2D +from tensorflow.python.keras.layers import Dense +from tensorflow.python.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras.layers import Input +from tensorflow.python.keras.layers import MaxPooling2D +from tensorflow.python.keras.layers import SeparableConv2D +from tensorflow.python.keras.layers import ZeroPadding2D +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -290,7 +290,7 @@ def NASNet(input_shape=None, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py b/tensorflow/python/keras/applications/nasnet_test.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/applications/nasnet_test.py rename to tensorflow/python/keras/applications/nasnet_test.py index aa1dec670cb995e47bdcf88bd69594c532781b18..f96c3aa51c17ff3a123ad1a22ceff6c23f69d311 100644 --- a/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py +++ b/tensorflow/python/keras/applications/nasnet_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/applications/resnet50.py similarity index 87% rename from tensorflow/python/keras/_impl/keras/applications/resnet50.py rename to tensorflow/python/keras/applications/resnet50.py index c3a92bea8920cad3297fee3efc50158813e72361..6afc08681214c5dbb0577623d30e27e9988c6a57 100644 --- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py +++ b/tensorflow/python/keras/applications/resnet50.py @@ -29,26 +29,25 @@ from __future__ import print_function import os -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import layers -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs -from tensorflow.python.keras._impl.keras.layers import Activation -from tensorflow.python.keras._impl.keras.layers import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers import BatchNormalization -from tensorflow.python.keras._impl.keras.layers import Conv2D -from tensorflow.python.keras._impl.keras.layers import Dense -from tensorflow.python.keras._impl.keras.layers import Flatten -from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers import Input -from tensorflow.python.keras._impl.keras.layers import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.utils import layer_utils -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import layers +from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras.applications.imagenet_utils import preprocess_input +from tensorflow.python.keras.layers import Activation +from tensorflow.python.keras.layers import AveragePooling2D +from tensorflow.python.keras.layers import BatchNormalization +from tensorflow.python.keras.layers import Conv2D +from tensorflow.python.keras.layers import Dense +from tensorflow.python.keras.layers import Flatten +from tensorflow.python.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras.layers import Input +from tensorflow.python.keras.layers import MaxPooling2D +from tensorflow.python.keras.layers import ZeroPadding2D +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -277,7 +276,7 @@ def ResNet50(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input # Create model. diff --git a/tensorflow/python/keras/applications/resnet50/__init__.py b/tensorflow/python/keras/applications/resnet50/__init__.py deleted file mode 100644 index 530805d150bfe32c5b81d7d7d3f92e203b83b602..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/resnet50/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""ResNet50 Keras application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.applications.resnet50 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.resnet50 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50 - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50_test.py b/tensorflow/python/keras/applications/resnet50_test.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/applications/resnet50_test.py rename to tensorflow/python/keras/applications/resnet50_test.py index 07f9ffd73f55ee39351af71223e7919b08ca66e1..22a3f055805f48bb27ad75db664b142d7916b654 100644 --- a/tensorflow/python/keras/_impl/keras/applications/resnet50_test.py +++ b/tensorflow/python/keras/applications/resnet50_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py similarity index 87% rename from tensorflow/python/keras/_impl/keras/applications/vgg16.py rename to tensorflow/python/keras/applications/vgg16.py index 25a15475eaa4038fcf7364d519e13a0d5d7839da..cef0230da96ed4b9c992e57839ebb2071383e3b1 100644 --- a/tensorflow/python/keras/_impl/keras/applications/vgg16.py +++ b/tensorflow/python/keras/applications/vgg16.py @@ -28,21 +28,20 @@ from __future__ import print_function import os -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs -from tensorflow.python.keras._impl.keras.layers import Conv2D -from tensorflow.python.keras._impl.keras.layers import Dense -from tensorflow.python.keras._impl.keras.layers import Flatten -from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers import Input -from tensorflow.python.keras._impl.keras.layers import MaxPooling2D -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.utils import layer_utils -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras.applications.imagenet_utils import preprocess_input +from tensorflow.python.keras.layers import Conv2D +from tensorflow.python.keras.layers import Dense +from tensorflow.python.keras.layers import Flatten +from tensorflow.python.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras.layers import Input +from tensorflow.python.keras.layers import MaxPooling2D +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -202,7 +201,7 @@ def VGG16(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input # Create model. diff --git a/tensorflow/python/keras/applications/vgg16/__init__.py b/tensorflow/python/keras/applications/vgg16/__init__.py deleted file mode 100644 index 118361604bbc7e0a88ed34243c0d5ea98856a301..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/vgg16/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""VGG16 Keras application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.applications.vgg16 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg16 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16 - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg16_test.py b/tensorflow/python/keras/applications/vgg16_test.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/applications/vgg16_test.py rename to tensorflow/python/keras/applications/vgg16_test.py index e6eba83678def582c1a9fb477399790dbded8a15..cad65765f3d18c5a458c802a6b1aed688468d444 100644 --- a/tensorflow/python/keras/_impl/keras/applications/vgg16_test.py +++ b/tensorflow/python/keras/applications/vgg16_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py similarity index 87% rename from tensorflow/python/keras/_impl/keras/applications/vgg19.py rename to tensorflow/python/keras/applications/vgg19.py index b09d0068b79738ffe157c486560d7a4fe90dc0a6..c4031f551003eda076380d1ae5208ee0876f5750 100644 --- a/tensorflow/python/keras/_impl/keras/applications/vgg19.py +++ b/tensorflow/python/keras/applications/vgg19.py @@ -28,21 +28,20 @@ from __future__ import print_function import os -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs -from tensorflow.python.keras._impl.keras.layers import Conv2D -from tensorflow.python.keras._impl.keras.layers import Dense -from tensorflow.python.keras._impl.keras.layers import Flatten -from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers import Input -from tensorflow.python.keras._impl.keras.layers import MaxPooling2D -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.utils import layer_utils -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras.applications.imagenet_utils import preprocess_input +from tensorflow.python.keras.layers import Conv2D +from tensorflow.python.keras.layers import Dense +from tensorflow.python.keras.layers import Flatten +from tensorflow.python.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras.layers import Input +from tensorflow.python.keras.layers import MaxPooling2D +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -211,7 +210,7 @@ def VGG19(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input # Create model. diff --git a/tensorflow/python/keras/applications/vgg19/__init__.py b/tensorflow/python/keras/applications/vgg19/__init__.py deleted file mode 100644 index cda52628f3c10d65fdbe70b2f86cc12c771870a9..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/vgg19/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""VGG19 Keras application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.applications.vgg19 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg19 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19 - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg19_test.py b/tensorflow/python/keras/applications/vgg19_test.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/applications/vgg19_test.py rename to tensorflow/python/keras/applications/vgg19_test.py index 25100a2993f8a650b9ec441bf0c2c528f13364a4..61dccc0c5cc315cc0e5c0284cf829ac2034c69d2 100644 --- a/tensorflow/python/keras/_impl/keras/applications/vgg19_test.py +++ b/tensorflow/python/keras/applications/vgg19_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py similarity index 90% rename from tensorflow/python/keras/_impl/keras/applications/xception.py rename to tensorflow/python/keras/applications/xception.py index 971063a16d1f5ba0e25189f1ef2f6c24eb5f8d61..01397cfac2563273ba1215003df1afab293b6b20 100644 --- a/tensorflow/python/keras/_impl/keras/applications/xception.py +++ b/tensorflow/python/keras/applications/xception.py @@ -39,23 +39,23 @@ from __future__ import print_function import os -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import layers -from tensorflow.python.keras._impl.keras.applications import imagenet_utils -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs -from tensorflow.python.keras._impl.keras.layers import Activation -from tensorflow.python.keras._impl.keras.layers import BatchNormalization -from tensorflow.python.keras._impl.keras.layers import Conv2D -from tensorflow.python.keras._impl.keras.layers import Dense -from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers import Input -from tensorflow.python.keras._impl.keras.layers import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers import SeparableConv2D -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import layers +from tensorflow.python.keras.applications import imagenet_utils +from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras.layers import Activation +from tensorflow.python.keras.layers import BatchNormalization +from tensorflow.python.keras.layers import Conv2D +from tensorflow.python.keras.layers import Dense +from tensorflow.python.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras.layers import Input +from tensorflow.python.keras.layers import MaxPooling2D +from tensorflow.python.keras.layers import SeparableConv2D +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -302,7 +302,7 @@ def Xception(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input # Create model. diff --git a/tensorflow/python/keras/applications/xception/__init__.py b/tensorflow/python/keras/applications/xception/__init__.py deleted file mode 100644 index ae9cd9cd18c5ccc5ec37c8cd1bf36f8aabd9929c..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/xception/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Xception Keras application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.applications.xception import decode_predictions -from tensorflow.python.keras._impl.keras.applications.xception import preprocess_input -from tensorflow.python.keras._impl.keras.applications.xception import Xception - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/applications/xception_test.py b/tensorflow/python/keras/applications/xception_test.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/applications/xception_test.py rename to tensorflow/python/keras/applications/xception_test.py index 7ebdc30010aa48362046b3c0c281fe1f2be64a84..7e2efd0017836ae671d88b561385b6e61be9fa0b 100644 --- a/tensorflow/python/keras/_impl/keras/applications/xception_test.py +++ b/tensorflow/python/keras/applications/xception_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/backend.py similarity index 95% rename from tensorflow/python/keras/_impl/keras/backend.py rename to tensorflow/python/keras/backend.py index af3d1fa33d3431e7b13d1910a8581393e7b912c6..824513dce07fc31edc6f8eca512efd99a1a258cc 100644 --- a/tensorflow/python/keras/_impl/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function import collections +import itertools import json import os import weakref @@ -2794,10 +2795,15 @@ class Function(object): if not isinstance(self.fetches, list): self.fetches = [self.fetches] # The main use case of `fetches` being passed to a model is the ability - # to run custom updates (since the outputs of fetches are never returned). + # to run custom updates # This requires us to wrap fetches in `identity` ops. self.fetches = [array_ops.identity(x) for x in self.fetches] self.session_kwargs = session_kwargs + # This mapping keeps track of the function that should receive the + # output from a fetch in `fetches`: { fetch: function(fetch_output) } + # A Callback can use this to register a function with access to the + # output values for a fetch it added. + self.fetch_callbacks = dict() if session_kwargs: raise ValueError('Some keys in session_kwargs are not supported at this ' @@ -2807,6 +2813,7 @@ class Function(object): self._feed_arrays = None self._feed_symbols = None self._symbol_vals = None + self._fetches = None self._session = None def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session): @@ -2852,8 +2859,14 @@ class Function(object): self._feed_arrays = feed_arrays self._feed_symbols = feed_symbols self._symbol_vals = symbol_vals + self._fetches = list(self.fetches) self._session = session + def _call_fetch_callbacks(self, fetches_output): + for fetch, output in zip(self._fetches, fetches_output): + if fetch in self.fetch_callbacks: + self.fetch_callbacks[fetch](output) + def __call__(self, inputs): if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` should be a list or tuple.') @@ -2880,21 +2893,24 @@ class Function(object): feed_arrays.append(tensor) # We need to do array conversion and type casting at this level, since # `callable_fn` only supports exact matches. - array_vals.append(np.asarray(value, dtype=tensor.dtype.base_dtype.name)) + tensor_type = dtypes_module.as_dtype(tensor.dtype) + array_vals.append(np.asarray(value, + dtype=tensor_type.as_numpy_dtype)) + if self.feed_dict: for key in sorted(self.feed_dict.keys()): array_vals.append( np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name)) # Refresh callable if anything has changed. - if (self._callable_fn is None or - feed_arrays != self._feed_arrays or + if (self._callable_fn is None or feed_arrays != self._feed_arrays or symbol_vals != self._symbol_vals or - feed_symbols != self._feed_symbols or + feed_symbols != self._feed_symbols or self.fetches != self._fetches or session != self._session): self._make_callable(feed_arrays, feed_symbols, symbol_vals, session) fetched = self._callable_fn(*array_vals) + self._call_fetch_callbacks(fetched[-len(self._fetches):]) return fetched[:len(self.outputs)] @@ -2973,30 +2989,29 @@ def rnn(step_function, Arguments: step_function: RNN step function. - Parameters; - input; tensor with shape `(samples, ...)` (no time dimension), + Args; + input; Tensor with shape `(samples, ...)` (no time dimension), representing input for the batch of samples at a certain time step. - states; list of tensors. + states; List of tensors. Returns; - output; tensor with shape `(samples, output_dim)` + output; Tensor with shape `(samples, output_dim)` (no time dimension). - new_states; list of tensors, same length and shapes + new_states; List of tensors, same length and shapes as 'states'. The first state in the list must be the output tensor at the previous timestep. - inputs: tensor of temporal data of shape `(samples, time, ...)` + inputs: Tensor of temporal data of shape `(samples, time, ...)` (at least 3D). - initial_states: tensor with shape (samples, output_dim) + initial_states: Tensor with shape `(samples, output_dim)` (no time dimension), containing the initial values for the states used in the step function. - go_backwards: boolean. If True, do the iteration over the time + go_backwards: Boolean. If True, do the iteration over the time dimension in reverse order and return the reversed sequence. - mask: binary tensor with shape `(samples, time, 1)`, + mask: Binary tensor with shape `(samples, time, 1)`, with a zero for every element that is masked. - constants: a list of constant values passed at each step. - unroll: whether to unroll the RNN or to use a symbolic loop - (`while_loop` or `scan` depending on backend). + constants: List of constant values passed at each step. + unroll: Whether to unroll the RNN or to use a symbolic `while_loop`. input_length: If specified, assume time dimension is of this length. Returns: @@ -3158,10 +3173,16 @@ def rnn(step_function, array_ops.stack( [1, array_ops.shape(output)[1]])) output = array_ops.where(tiled_mask_t, output, states[0]) - new_states = [ - array_ops.where(tiled_mask_t, new_states[i], states[i]) - for i in range(len(states)) - ] + + masked_states = [] + for i in range(len(states)): + states_dim = array_ops.shape(new_states[i])[1] + stacked_states_dim = array_ops.stack([1, states_dim]) + tiled_mask = array_ops.tile(mask_t, stacked_states_dim) + masked_state = array_ops.where(tiled_mask, new_states[i], states[i]) + masked_states.append(masked_state) + new_states = masked_states + output_ta_t = output_ta_t.write(time, output) return (time + 1, output_ta_t) + tuple(new_states) else: @@ -3637,12 +3658,12 @@ def _preprocess_conv1d_input(x, data_format): Returns: A tensor. """ - tf_data_format = 'NHWC' # to pass TF Conv2dNative operations + tf_data_format = 'NWC' # to pass TF Conv2dNative operations if data_format == 'channels_first': if not _has_nchw_support(): x = array_ops.transpose(x, (0, 2, 1)) # NCW -> NWC else: - tf_data_format = 'NCHW' + tf_data_format = 'NCW' return x, tf_data_format @@ -3741,10 +3762,8 @@ def conv1d(x, x = temporal_padding(x, (left_pad, 0)) padding = 'valid' padding = _preprocess_padding(padding) - if data_format == 'channels_last': - tf_data_format = 'NWC' - else: - tf_data_format = 'NCW' + + x, tf_data_format = _preprocess_conv1d_input(x, data_format) x = nn.convolution( input=x, filter=kernel, @@ -3752,6 +3771,8 @@ def conv1d(x, strides=(strides,), padding=padding, data_format=tf_data_format) + if data_format == 'channels_first' and tf_data_format == 'NWC': + x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW return x @@ -3892,11 +3913,16 @@ def separable_conv1d(x, if data_format not in {'channels_first', 'channels_last'}: raise ValueError('Unknown data_format: ' + str(data_format)) + if isinstance(strides, int): + strides = (strides,) + if isinstance(dilation_rate, int): + dilation_rate = (dilation_rate,) + x, tf_data_format = _preprocess_conv1d_input(x, data_format) padding = _preprocess_padding(padding) if not isinstance(strides, tuple): strides = tuple(strides) - if tf_data_format == 'NHWC': + if tf_data_format == 'NWC': spatial_start_dim = 1 strides = (1,) + strides * 2 + (1,) else: @@ -3918,7 +3944,7 @@ def separable_conv1d(x, x = array_ops.squeeze(x, [spatial_start_dim]) - if data_format == 'channels_first' and tf_data_format == 'NHWC': + if data_format == 'channels_first' and tf_data_format == 'NWC': x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW return x @@ -4238,45 +4264,115 @@ def pool3d(x, return x -def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): - """Apply 1D conv with un-shared weights. - - Arguments: - inputs: 3D tensor with shape: (batch_size, steps, input_dim) - kernel: the unshared weight for convolution, - with shape (output_length, feature_dim, filters) - kernel_size: a tuple of a single integer, - specifying the length of the 1D convolution window - strides: a tuple of a single integer, - specifying the stride length of the convolution - data_format: the data format, channels_first or channels_last - - Returns: - the tensor after 1d conv with un-shared weights, with shape (batch_size, - output_length, filters) +def local_conv(inputs, + kernel, + kernel_size, + strides, + output_shape, + data_format=None): + """Apply N-D convolution with un-shared weights. + + Arguments: + inputs: (N+2)-D tensor with shape + (batch_size, channels_in, d_in1, ..., d_inN) + if data_format='channels_first', or + (batch_size, d_in1, ..., d_inN, channels_in) + if data_format='channels_last'. + kernel: the unshared weight for N-D convolution, + with shape (output_items, feature_dim, channels_out), where + feature_dim = np.prod(kernel_size) * channels_in, + output_items = np.prod(output_shape). + kernel_size: a tuple of N integers, specifying the + spatial dimensions of the N-D convolution window. + strides: a tuple of N integers, specifying the strides + of the convolution along the spatial dimensions. + output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial + dimensionality of the output. + data_format: string, "channels_first" or "channels_last". + + Returns: + An (N+2)-D tensor with shape: + (batch_size, channels_out) + output_shape + if data_format='channels_first', or: + (batch_size,) + output_shape + (channels_out,) + if data_format='channels_last'. Raises: - ValueError: if `data_format` is neither `channels_last` or - `channels_first`. + ValueError: if `data_format` is neither + `channels_last` nor `channels_first`. """ if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: raise ValueError('Unknown data_format: ' + str(data_format)) - stride = strides[0] kernel_shape = int_shape(kernel) - output_length = kernel_shape[0] feature_dim = kernel_shape[1] + channels_out = kernel_shape[-1] + ndims = len(output_shape) + spatial_dimensions = list(range(ndims)) xs = [] - for i in range(output_length): - slice_length = slice(i * stride, i * stride + kernel_size[0]) - xs.append(reshape(inputs[:, slice_length, :], (1, -1, feature_dim))) + output_axes_ticks = [range(axis_max) for axis_max in output_shape] + for position in itertools.product(*output_axes_ticks): + slices = [slice(None)] + + if data_format == 'channels_first': + slices.append(slice(None)) + + slices.extend([slice(position[d] * strides[d], + position[d] * strides[d] + kernel_size[d]) + for d in spatial_dimensions]) + + if data_format == 'channels_last': + slices.append(slice(None)) + + xs.append(reshape(inputs[slices], (1, -1, feature_dim))) + x_aggregate = concatenate(xs, axis=0) - # Shape: `(output_length, batch_size, filters)`. output = batch_dot(x_aggregate, kernel) - return permute_dimensions(output, (1, 0, 2)) + output = reshape(output, output_shape + (-1, channels_out)) + + if data_format == 'channels_first': + permutation = [ndims, ndims + 1] + spatial_dimensions + else: + permutation = [ndims] + spatial_dimensions + [ndims + 1] + + return permute_dimensions(output, permutation) + + +def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): + """Apply 1D conv with un-shared weights. + + Arguments: + inputs: 3D tensor with shape: + (batch_size, steps, input_dim) + if data_format is "channels_last" or + (batch_size, input_dim, steps) + if data_format is "channels_first". + kernel: the unshared weight for convolution, + with shape (output_length, feature_dim, filters). + kernel_size: a tuple of a single integer, + specifying the length of the 1D convolution window. + strides: a tuple of a single integer, + specifying the stride length of the convolution. + data_format: the data format, channels_first or channels_last. + + Returns: + A 3d tensor with shape: + (batch_size, output_length, filters) + if data_format='channels_first' + or 3D tensor with shape: + (batch_size, filters, output_length) + if data_format='channels_last'. + """ + output_shape = (kernel.shape[0],) + return local_conv(inputs, + kernel, + kernel_size, + strides, + output_shape, + data_format) def local_conv2d(inputs, @@ -4289,64 +4385,34 @@ def local_conv2d(inputs, Arguments: inputs: 4D tensor with shape: - (batch_size, filters, new_rows, new_cols) - if data_format='channels_first' - or 4D tensor with shape: - (batch_size, new_rows, new_cols, filters) - if data_format='channels_last'. + (batch_size, filters, new_rows, new_cols) + if data_format='channels_first' + or 4D tensor with shape: + (batch_size, new_rows, new_cols, filters) + if data_format='channels_last'. kernel: the unshared weight for convolution, - with shape (output_items, feature_dim, filters) + with shape (output_items, feature_dim, filters). kernel_size: a tuple of 2 integers, specifying the - width and height of the 2D convolution window. + width and height of the 2D convolution window. strides: a tuple of 2 integers, specifying the strides - of the convolution along the width and height. - output_shape: a tuple with (output_row, output_col) - data_format: the data format, channels_first or channels_last + of the convolution along the width and height. + output_shape: a tuple with (output_row, output_col). + data_format: the data format, channels_first or channels_last. Returns: - A 4d tensor with shape: + A 4D tensor with shape: (batch_size, filters, new_rows, new_cols) if data_format='channels_first' or 4D tensor with shape: (batch_size, new_rows, new_cols, filters) if data_format='channels_last'. - - Raises: - ValueError: if `data_format` is neither - `channels_last` or `channels_first`. """ - if data_format is None: - data_format = image_data_format() - if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format: ' + str(data_format)) - - stride_row, stride_col = strides - output_row, output_col = output_shape - kernel_shape = int_shape(kernel) - feature_dim = kernel_shape[1] - filters = kernel_shape[2] - - xs = [] - for i in range(output_row): - for j in range(output_col): - slice_row = slice(i * stride_row, i * stride_row + kernel_size[0]) - slice_col = slice(j * stride_col, j * stride_col + kernel_size[1]) - if data_format == 'channels_first': - xs.append( - reshape(inputs[:, :, slice_row, slice_col], (1, -1, feature_dim))) - else: - xs.append( - reshape(inputs[:, slice_row, slice_col, :], (1, -1, feature_dim))) - - x_aggregate = concatenate(xs, axis=0) - output = batch_dot(x_aggregate, kernel) - output = reshape(output, (output_row, output_col, -1, filters)) - - if data_format == 'channels_first': - output = permute_dimensions(output, (2, 3, 0, 1)) - else: - output = permute_dimensions(output, (2, 0, 1, 3)) - return output + return local_conv(inputs, + kernel, + kernel_size, + strides, + output_shape, + data_format) @tf_export('keras.backend.bias_add') @@ -4704,8 +4770,13 @@ def foldr(fn, elems, initializer=None, name=None): # Load Keras default configuration from config file if present. -_keras_base_dir = os.path.expanduser('~') -_keras_dir = os.path.join(_keras_base_dir, '.keras') +# Set Keras base dir path given KERAS_HOME env variable, if applicable. +# Otherwise either ~/.keras or /tmp. +if 'KERAS_HOME' in os.environ: + _keras_dir = os.environ.get('KERAS_HOME') +else: + _keras_base_dir = os.path.expanduser('~') + _keras_dir = os.path.join(_keras_base_dir, '.keras') _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json')) if os.path.exists(_config_path): try: diff --git a/tensorflow/python/keras/backend/__init__.py b/tensorflow/python/keras/backend/__init__.py deleted file mode 100644 index 10ef5a75852deb6595bced2703d7c5f29b0efac3..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/backend/__init__.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras backend API.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=redefined-builtin -from tensorflow.python.keras._impl.keras.backend import abs -from tensorflow.python.keras._impl.keras.backend import all -from tensorflow.python.keras._impl.keras.backend import any -from tensorflow.python.keras._impl.keras.backend import arange -from tensorflow.python.keras._impl.keras.backend import argmax -from tensorflow.python.keras._impl.keras.backend import argmin -from tensorflow.python.keras._impl.keras.backend import backend -from tensorflow.python.keras._impl.keras.backend import batch_dot -from tensorflow.python.keras._impl.keras.backend import batch_flatten -from tensorflow.python.keras._impl.keras.backend import batch_get_value -from tensorflow.python.keras._impl.keras.backend import batch_normalization -from tensorflow.python.keras._impl.keras.backend import batch_set_value -from tensorflow.python.keras._impl.keras.backend import bias_add -from tensorflow.python.keras._impl.keras.backend import binary_crossentropy -from tensorflow.python.keras._impl.keras.backend import cast -from tensorflow.python.keras._impl.keras.backend import cast_to_floatx -from tensorflow.python.keras._impl.keras.backend import categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import clear_session -from tensorflow.python.keras._impl.keras.backend import clip -from tensorflow.python.keras._impl.keras.backend import concatenate -from tensorflow.python.keras._impl.keras.backend import constant -from tensorflow.python.keras._impl.keras.backend import conv1d -from tensorflow.python.keras._impl.keras.backend import conv2d -from tensorflow.python.keras._impl.keras.backend import conv2d_transpose -from tensorflow.python.keras._impl.keras.backend import conv3d -from tensorflow.python.keras._impl.keras.backend import cos -from tensorflow.python.keras._impl.keras.backend import count_params -from tensorflow.python.keras._impl.keras.backend import ctc_batch_cost -from tensorflow.python.keras._impl.keras.backend import ctc_decode -from tensorflow.python.keras._impl.keras.backend import ctc_label_dense_to_sparse -from tensorflow.python.keras._impl.keras.backend import dot -from tensorflow.python.keras._impl.keras.backend import dropout -from tensorflow.python.keras._impl.keras.backend import dtype -from tensorflow.python.keras._impl.keras.backend import elu -from tensorflow.python.keras._impl.keras.backend import epsilon -from tensorflow.python.keras._impl.keras.backend import equal -from tensorflow.python.keras._impl.keras.backend import eval -from tensorflow.python.keras._impl.keras.backend import exp -from tensorflow.python.keras._impl.keras.backend import expand_dims -from tensorflow.python.keras._impl.keras.backend import eye -from tensorflow.python.keras._impl.keras.backend import flatten -from tensorflow.python.keras._impl.keras.backend import floatx -from tensorflow.python.keras._impl.keras.backend import foldl -from tensorflow.python.keras._impl.keras.backend import foldr -from tensorflow.python.keras._impl.keras.backend import function -from tensorflow.python.keras._impl.keras.backend import gather -from tensorflow.python.keras._impl.keras.backend import get_session -from tensorflow.python.keras._impl.keras.backend import get_uid -from tensorflow.python.keras._impl.keras.backend import get_value -from tensorflow.python.keras._impl.keras.backend import gradients -from tensorflow.python.keras._impl.keras.backend import greater -from tensorflow.python.keras._impl.keras.backend import greater_equal -from tensorflow.python.keras._impl.keras.backend import hard_sigmoid -from tensorflow.python.keras._impl.keras.backend import image_data_format -from tensorflow.python.keras._impl.keras.backend import in_test_phase -from tensorflow.python.keras._impl.keras.backend import in_top_k -from tensorflow.python.keras._impl.keras.backend import in_train_phase -from tensorflow.python.keras._impl.keras.backend import int_shape -from tensorflow.python.keras._impl.keras.backend import is_sparse -from tensorflow.python.keras._impl.keras.backend import l2_normalize -from tensorflow.python.keras._impl.keras.backend import learning_phase -from tensorflow.python.keras._impl.keras.backend import less -from tensorflow.python.keras._impl.keras.backend import less_equal -from tensorflow.python.keras._impl.keras.backend import log -from tensorflow.python.keras._impl.keras.backend import manual_variable_initialization -from tensorflow.python.keras._impl.keras.backend import map_fn -from tensorflow.python.keras._impl.keras.backend import max -from tensorflow.python.keras._impl.keras.backend import maximum -from tensorflow.python.keras._impl.keras.backend import mean -from tensorflow.python.keras._impl.keras.backend import min -from tensorflow.python.keras._impl.keras.backend import minimum -from tensorflow.python.keras._impl.keras.backend import moving_average_update -from tensorflow.python.keras._impl.keras.backend import name_scope -from tensorflow.python.keras._impl.keras.backend import ndim -from tensorflow.python.keras._impl.keras.backend import normalize_batch_in_training -from tensorflow.python.keras._impl.keras.backend import not_equal -from tensorflow.python.keras._impl.keras.backend import one_hot -from tensorflow.python.keras._impl.keras.backend import ones -from tensorflow.python.keras._impl.keras.backend import ones_like -from tensorflow.python.keras._impl.keras.backend import permute_dimensions -from tensorflow.python.keras._impl.keras.backend import placeholder -from tensorflow.python.keras._impl.keras.backend import pool2d -from tensorflow.python.keras._impl.keras.backend import pool3d -from tensorflow.python.keras._impl.keras.backend import pow -from tensorflow.python.keras._impl.keras.backend import print_tensor -from tensorflow.python.keras._impl.keras.backend import prod -from tensorflow.python.keras._impl.keras.backend import random_binomial -from tensorflow.python.keras._impl.keras.backend import random_normal -from tensorflow.python.keras._impl.keras.backend import random_normal_variable -from tensorflow.python.keras._impl.keras.backend import random_uniform -from tensorflow.python.keras._impl.keras.backend import random_uniform_variable -from tensorflow.python.keras._impl.keras.backend import relu -from tensorflow.python.keras._impl.keras.backend import repeat -from tensorflow.python.keras._impl.keras.backend import repeat_elements -from tensorflow.python.keras._impl.keras.backend import reset_uids -from tensorflow.python.keras._impl.keras.backend import reshape -from tensorflow.python.keras._impl.keras.backend import resize_images -from tensorflow.python.keras._impl.keras.backend import resize_volumes -from tensorflow.python.keras._impl.keras.backend import reverse -from tensorflow.python.keras._impl.keras.backend import rnn -from tensorflow.python.keras._impl.keras.backend import round -from tensorflow.python.keras._impl.keras.backend import separable_conv2d -from tensorflow.python.keras._impl.keras.backend import set_epsilon -from tensorflow.python.keras._impl.keras.backend import set_floatx -from tensorflow.python.keras._impl.keras.backend import set_image_data_format -from tensorflow.python.keras._impl.keras.backend import set_learning_phase -from tensorflow.python.keras._impl.keras.backend import set_session -from tensorflow.python.keras._impl.keras.backend import set_value -from tensorflow.python.keras._impl.keras.backend import shape -from tensorflow.python.keras._impl.keras.backend import sigmoid -from tensorflow.python.keras._impl.keras.backend import sign -from tensorflow.python.keras._impl.keras.backend import sin -from tensorflow.python.keras._impl.keras.backend import softmax -from tensorflow.python.keras._impl.keras.backend import softplus -from tensorflow.python.keras._impl.keras.backend import softsign -from tensorflow.python.keras._impl.keras.backend import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import spatial_2d_padding -from tensorflow.python.keras._impl.keras.backend import spatial_3d_padding -from tensorflow.python.keras._impl.keras.backend import sqrt -from tensorflow.python.keras._impl.keras.backend import square -from tensorflow.python.keras._impl.keras.backend import squeeze -from tensorflow.python.keras._impl.keras.backend import stack -from tensorflow.python.keras._impl.keras.backend import std -from tensorflow.python.keras._impl.keras.backend import stop_gradient -from tensorflow.python.keras._impl.keras.backend import sum -from tensorflow.python.keras._impl.keras.backend import switch -from tensorflow.python.keras._impl.keras.backend import tanh -from tensorflow.python.keras._impl.keras.backend import temporal_padding -from tensorflow.python.keras._impl.keras.backend import to_dense -from tensorflow.python.keras._impl.keras.backend import transpose -from tensorflow.python.keras._impl.keras.backend import truncated_normal -from tensorflow.python.keras._impl.keras.backend import update -from tensorflow.python.keras._impl.keras.backend import update_add -from tensorflow.python.keras._impl.keras.backend import update_sub -from tensorflow.python.keras._impl.keras.backend import var -from tensorflow.python.keras._impl.keras.backend import variable -from tensorflow.python.keras._impl.keras.backend import zeros -from tensorflow.python.keras._impl.keras.backend import zeros_like - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/backend_test.py similarity index 81% rename from tensorflow/python/keras/_impl/keras/backend_test.py rename to tensorflow/python/keras/backend_test.py index b2243473aa823d08f6510a15232894530670c81b..36478ea089a871667908d70e33422aef8444a3e4 100644 --- a/tensorflow/python/keras/_impl/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -17,11 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np import scipy.sparse +from tensorflow.python import keras +from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor -from tensorflow.python.keras._impl import keras from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.util import tf_inspect @@ -274,6 +276,36 @@ class BackendUtilsTest(test.TestCase): self.assertEqual( keras.backend.get_session().run(fetches=[x, y]), [30., 40.]) + def test_function_fetch_callbacks(self): + + class CallbackStub(object): + + def __init__(self): + self.times_called = 0 + self.callback_result = 0 + + def _fetch_callback(self, result): + self.times_called += 1 + self.callback_result = result + + with self.test_session(): + callback = CallbackStub() + x_placeholder = keras.backend.placeholder(shape=()) + y_placeholder = keras.backend.placeholder(shape=()) + + callback_op = x_placeholder * y_placeholder + + f = keras.backend.function( + inputs=[x_placeholder, y_placeholder], + outputs=[x_placeholder + y_placeholder]) + f.fetches.append(callback_op) + f.fetch_callbacks[callback_op] = callback._fetch_callback + + _ = f([10., 20.]) + + self.assertEqual(callback.times_called, 1) + self.assertEqual(callback.callback_result, 200) + class BackendVariableTest(test.TestCase): @@ -661,7 +693,7 @@ class BackendShapeOpsTest(test.TestCase): np_kwargs={'data_format': 'channels_first'}) -class BackendNNOpsTest(test.TestCase): +class BackendNNOpsTest(test.TestCase, parameterized.TestCase): def test_bias_add(self): with self.test_session(): @@ -810,6 +842,118 @@ class BackendNNOpsTest(test.TestCase): padding='same', data_format='channels_last') self.assertEqual(y.get_shape().as_list(), [10, 5, 5]) + def test_local_conv_channels_dim(self): + filters = 3 + batch_size = 2 + + for input_shape in [(3, 5), (2, 3, 5), (2, 5, 3, 4)]: + channels_in = input_shape[0] + input_spatial_shape = input_shape[1:] + dim = len(input_spatial_shape) + + inputs = np.random.normal(0, 1, (batch_size,) + input_shape) + inputs_cf = keras.backend.variable(inputs) + + for kernel_size in [1, 2]: + for stride in [1, 2]: + kernel_sizes = (kernel_size,) * dim + strides = (stride,) * dim + + output_shape = tuple([(i - kernel_size + stride) // stride + for i in input_spatial_shape]) + + kernel_shape = (np.prod(output_shape), + np.prod(kernel_sizes) * channels_in, + filters) + + kernel = np.random.normal( + 0, + 1, + output_shape + (channels_in, np.prod(kernel_sizes), filters) + ) + + kernel_cf = np.reshape(kernel, kernel_shape) + kernel_cf = keras.backend.variable(kernel_cf) + + conv_cf = keras.backend.local_conv(inputs_cf, + kernel_cf, + kernel_sizes, + strides, + output_shape, + 'channels_first') + + inputs_cl = np.transpose(inputs, [0, 2] + list(range(3, dim + 2)) + + [1]) + inputs_cl = keras.backend.variable(inputs_cl) + + kernel_cl = np.reshape( + np.transpose(kernel, list(range(dim)) + [dim + 1, dim, dim + 2]), + kernel_shape + ) + kernel_cl = keras.backend.variable(kernel_cl) + + conv_cl = keras.backend.local_conv(inputs_cl, + kernel_cl, + kernel_sizes, + strides, + output_shape, + 'channels_last') + with self.test_session(): + conv_cf = keras.backend.eval(conv_cf) + conv_cl = keras.backend.eval(conv_cl) + + self.assertAllCloseAccordingToType( + conv_cf, + np.transpose(conv_cl, + [0, dim + 1] + list(range(1, dim + 1))), + atol=1e-5 + ) + + @parameterized.named_parameters( + ('local_conv1d', (5, 6), (3,), (1,), (3,)), + ('local_conv2d', (4, 5, 6), (3, 3), (1, 1), (2, 3))) + def test_local_conv_1d_and_2d(self, + input_shape, + kernel_sizes, + strides, + output_shape): + filters = 3 + batch_size = 2 + + inputs = np.random.normal(0, 1, (batch_size,) + input_shape) + inputs = keras.backend.variable(inputs) + + kernel = np.random.normal(0, 1, (np.prod(output_shape), + np.prod(kernel_sizes) * input_shape[-1], + filters)) + kernel = keras.backend.variable(kernel) + + local_conv = keras.backend.local_conv(inputs, + kernel, + kernel_sizes, + strides, + output_shape, + 'channels_last') + if len(output_shape) == 1: + local_conv_dim = keras.backend.local_conv1d(inputs, + kernel, + kernel_sizes, + strides, + 'channels_last') + else: + local_conv_dim = keras.backend.local_conv2d(inputs, + kernel, + kernel_sizes, + strides, + output_shape, + 'channels_last') + + with self.test_session(): + local_conv = keras.backend.eval(local_conv) + local_conv_dim = keras.backend.eval(local_conv_dim) + + self.assertAllCloseAccordingToType(local_conv, local_conv_dim) + def test_conv2d(self): val = np.random.random((10, 4, 10, 10)) x = keras.backend.variable(val) @@ -963,7 +1107,7 @@ class BackendNNOpsTest(test.TestCase): {'go_backwards': False, 'mask': mask, 'unroll': True}, ] with self.test_session(): - for (i, kwargs) in enumerate(kwargs_list): + for i, kwargs in enumerate(kwargs_list): last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs, initial_states, **kwargs) @@ -1010,6 +1154,115 @@ class BackendNNOpsTest(test.TestCase): for b_s, b_u_s in zip(state_list[2], state_list[3]): self.assertAllClose(b_s, b_u_s, atol=1e-04) + def test_rnn_additional_states(self): + # implement a simple RNN + num_samples = 4 + input_dim = 5 + output_dim = 3 + timesteps = 6 + + input_val = np.random.random( + (num_samples, timesteps, input_dim)).astype(np.float32) + init_state_val = np.random.random( + (num_samples, output_dim)).astype(np.float32) + w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32) + w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32) + np_mask = np.random.randint(2, size=(num_samples, timesteps)) + + def rnn_step_fn(): + w_i = keras.backend.variable(w_i_val) + w_o = keras.backend.variable(w_o_val) + + def step_function(x, states): + assert len(states) == 2 + prev_output = states[0] + output = keras.backend.dot(x, w_i) + keras.backend.dot(prev_output, w_o) + return output, [output, + keras.backend.concatenate([output, output], axis=-1)] + + return step_function + + # test default setup + last_output_list = [[], [], [], [], [], []] + outputs_list = [[], [], [], [], [], []] + state_list = [[], [], [], [], [], []] + additional_state_list = [[], [], [], [], [], []] + + rnn_fn = rnn_step_fn() + inputs = keras.backend.variable(input_val) + initial_states = [keras.backend.variable(init_state_val), + np.concatenate([init_state_val, init_state_val], axis=-1)] + mask = keras.backend.variable(np_mask) + + kwargs_list = [ + {'go_backwards': False, 'mask': None}, + {'go_backwards': False, 'mask': None, 'unroll': True}, + {'go_backwards': True, 'mask': None}, + {'go_backwards': True, 'mask': None, 'unroll': True}, + {'go_backwards': False, 'mask': mask}, + {'go_backwards': False, 'mask': mask, 'unroll': True}, + ] + with self.test_session(): + for i, kwargs in enumerate(kwargs_list): + last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs, + initial_states, + **kwargs) + # check static shape inference + self.assertEqual(last_output.get_shape().as_list(), + [num_samples, output_dim]) + self.assertEqual(outputs.get_shape().as_list(), + [num_samples, timesteps, output_dim]) + # for state in new_states: + # self.assertEquals(state.get_shape().as_list(), + # [num_samples, output_dim]) + self.assertEqual(new_states[0].get_shape().as_list(), + [num_samples, output_dim]) + self.assertEqual(new_states[1].get_shape().as_list(), + [num_samples, 2 * output_dim]) + + last_output_list[i].append(keras.backend.eval(last_output)) + outputs_list[i].append(keras.backend.eval(outputs)) + self.assertEqual(len(new_states), 2) + state_list[i].append(keras.backend.eval(new_states[0])) + additional_state_list[i].append(keras.backend.eval(new_states[1])) + + def assert_list_pairwise(z_list, atol=1e-05): + for (z1, z2) in zip(z_list[1:], z_list[:-1]): + self.assertAllClose(z1, z2, atol=atol) + + assert_list_pairwise(last_output_list[0], atol=1e-04) + assert_list_pairwise(outputs_list[0], atol=1e-04) + assert_list_pairwise(state_list[0], atol=1e-04) + assert_list_pairwise(additional_state_list[0], atol=1e-04) + assert_list_pairwise(last_output_list[2], atol=1e-04) + assert_list_pairwise(outputs_list[2], atol=1e-04) + assert_list_pairwise(state_list[2], atol=1e-04) + assert_list_pairwise(additional_state_list[2], atol=1e-04) + + for l, u_l in zip(last_output_list[0], last_output_list[1]): + self.assertAllClose(l, u_l, atol=1e-04) + + for o, u_o in zip(outputs_list[0], outputs_list[1]): + self.assertAllClose(o, u_o, atol=1e-04) + + for s, u_s in zip(state_list[0], state_list[1]): + self.assertAllClose(s, u_s, atol=1e-04) + + for s, u_s in zip(additional_state_list[0], additional_state_list[1]): + self.assertAllClose(s, u_s, atol=1e-04) + + for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]): + self.assertAllClose(b_l, b_u_l, atol=1e-04) + + for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]): + self.assertAllClose(b_o, b_u_o, atol=1e-04) + + for b_s, b_u_s in zip(state_list[2], state_list[3]): + self.assertAllClose(b_s, b_u_s, atol=1e-04) + + for s, u_s in zip(additional_state_list[2], additional_state_list[3]): + self.assertAllClose(s, u_s, atol=1e-04) + def test_normalize_batch_in_training(self): val = np.random.random((10, 3, 10, 10)) x = keras.backend.variable(val) @@ -1165,6 +1418,13 @@ class TestRandomOps(test.TestCase): self.assertAllClose(np.max(y), 2., atol=0.1) self.assertAllClose(np.min(y), -2., atol=0.1) + def test_string_input(self): + seq = keras.Sequential([ + keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string), + keras.layers.Lambda(lambda x: x[0]) + ]) + preds = seq.predict([['tensorflow eager']]) + self.assertEqual(preds.shape, (1,)) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/callbacks.py b/tensorflow/python/keras/callbacks.py similarity index 93% rename from tensorflow/python/keras/_impl/keras/callbacks.py rename to tensorflow/python/keras/callbacks.py index 79864a5c67819492011e1f3b2dcbe50bcca82ac0..3ae06d7ab870f7125a123de51fab95d543efe56c 100644 --- a/tensorflow/python/keras/_impl/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -24,14 +24,15 @@ from collections import Iterable from collections import OrderedDict import csv import json +import math import os import time import numpy as np import six -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as tf_summary @@ -424,7 +425,7 @@ class ModelCheckpoint(Callback): if mode not in ['auto', 'min', 'max']: logging.warning('ModelCheckpoint mode %s is unknown, ' - 'fallback to auto mode.', (mode), RuntimeWarning) + 'fallback to auto mode.', mode) mode = 'auto' if mode == 'min': @@ -451,7 +452,7 @@ class ModelCheckpoint(Callback): current = logs.get(self.monitor) if current is None: logging.warning('Can save best model only with %s available, ' - 'skipping.', self.monitor, RuntimeWarning) + 'skipping.', self.monitor) else: if self.monitor_op(current, self.best): if self.verbose > 0: @@ -496,6 +497,9 @@ class EarlyStopping(Callback): monitored has stopped increasing; in `auto` mode, the direction is automatically inferred from the name of the monitored quantity. + baseline: baseline value for the monitored quantity. + Training will stop if the model doesn't show improvement over the + baseline. """ def __init__(self, @@ -503,19 +507,21 @@ class EarlyStopping(Callback): min_delta=0, patience=0, verbose=0, - mode='auto'): + mode='auto', + baseline=None): super(EarlyStopping, self).__init__() self.monitor = monitor self.patience = patience self.verbose = verbose - self.min_delta = min_delta + self.baseline = baseline + self.min_delta = abs(min_delta) self.wait = 0 self.stopped_epoch = 0 if mode not in ['auto', 'min', 'max']: logging.warning('EarlyStopping mode %s is unknown, ' - 'fallback to auto mode.', mode, RuntimeWarning) + 'fallback to auto mode.', mode) mode = 'auto' if mode == 'min': @@ -537,14 +543,17 @@ class EarlyStopping(Callback): # Allow instances to be re-used self.wait = 0 self.stopped_epoch = 0 - self.best = np.Inf if self.monitor_op == np.less else -np.Inf + if self.baseline is not None: + self.best = self.baseline + else: + self.best = np.Inf if self.monitor_op == np.less else -np.Inf def on_epoch_end(self, epoch, logs=None): current = logs.get(self.monitor) if current is None: logging.warning('Early stopping conditioned on metric `%s` ' 'which is not available. Available metrics are: %s', - self.monitor, ','.join(list(logs.keys())), RuntimeWarning) + self.monitor, ','.join(list(logs.keys()))) return if self.monitor_op(current - self.min_delta, self.best): self.best = current @@ -635,7 +644,11 @@ class LearningRateScheduler(Callback): def on_epoch_begin(self, epoch, logs=None): if not hasattr(self.model.optimizer, 'lr'): raise ValueError('Optimizer must have a "lr" attribute.') - lr = self.schedule(epoch) + try: # new API + lr = float(K.get_value(self.model.optimizer.lr)) + lr = self.schedule(epoch, lr) + except TypeError: # Support for old API for backward compatibility + lr = self.schedule(epoch) if not isinstance(lr, (float, np.float32, np.float64)): raise ValueError('The output of the "schedule" function ' 'should be float.') @@ -711,8 +724,13 @@ class TensorBoard(Callback): self.write_grads = write_grads self.write_images = write_images self.batch_size = batch_size + self._current_batch = 0 + # abstracted writer class to be able to stub for testing + self._writer_class = tf_summary.FileWriter def set_model(self, model): + """Sets Keras model and creates summary ops.""" + self.model = model self.sess = K.get_session() if self.histogram_freq and self.merged is None: @@ -720,15 +738,6 @@ class TensorBoard(Callback): for weight in layer.weights: mapped_weight_name = weight.name.replace(':', '_') tf_summary.histogram(mapped_weight_name, weight) - if self.write_grads: - grads = model.optimizer.get_gradients(model.total_loss, weight) - - def is_indexed_slices(grad): - return type(grad).__name__ == 'IndexedSlices' - - grads = [grad.values if is_indexed_slices(grad) else grad - for grad in grads] - tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads) if self.write_images: w_img = array_ops.squeeze(weight) shape = K.int_shape(w_img) @@ -755,59 +764,58 @@ class TensorBoard(Callback): assert len(shape) == 4 and shape[-1] in [1, 3, 4] tf_summary.image(mapped_weight_name, w_img) + if self.write_grads: + for weight in layer.trainable_weights: + mapped_weight_name = weight.name.replace(':', '_') + grads = model.optimizer.get_gradients(model.total_loss, weight) + + def is_indexed_slices(grad): + return type(grad).__name__ == 'IndexedSlices' + + grads = [grad.values if is_indexed_slices(grad) else grad + for grad in grads] + tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads) + if hasattr(layer, 'output'): tf_summary.histogram('{}_out'.format(layer.name), layer.output) self.merged = tf_summary.merge_all() if self.write_graph: - self.writer = tf_summary.FileWriter(self.log_dir, self.sess.graph) + self.writer = self._writer_class(self.log_dir, self.sess.graph) else: - self.writer = tf_summary.FileWriter(self.log_dir) + self.writer = self._writer_class(self.log_dir) - def on_epoch_end(self, epoch, logs=None): - logs = logs or {} + def _fetch_callback(self, summary): + self.writer.add_summary( + summary, self._epoch + self._current_batch / self._batches_per_epoch) + self._current_batch += 1 + + def on_epoch_begin(self, epoch, logs=None): + """Add histogram op to Model test_function callbacks, reset batch count.""" if not self.validation_data and self.histogram_freq: raise ValueError('If printing histograms, validation_data must be ' 'provided, and cannot be a generator.') - if self.validation_data and self.histogram_freq: - if epoch % self.histogram_freq == 0: - - val_data = self.validation_data - tensors = ( - self.model.inputs + self.model.targets + self.model.sample_weights) - - if self.model.uses_learning_phase: - tensors += [K.learning_phase()] - - assert len(val_data) == len(tensors) - val_size = val_data[0].shape[0] - i = 0 - while i < val_size: - step = min(self.batch_size, val_size - i) - batch_val = [] - batch_val.append(val_data[0][i:i + step] - if val_data[0] is not None else None) - batch_val.append(val_data[1][i:i + step] - if val_data[1] is not None else None) - batch_val.append(val_data[2][i:i + step] - if val_data[2] is not None else None) - if self.model.uses_learning_phase: - # do not slice the learning phase - batch_val = [x[i:i + step] if x is not None else None - for x in val_data[:-1]] - batch_val.append(val_data[-1]) - else: - batch_val = [x[i:i + step] if x is not None else None - for x in val_data] - feed_dict = {} - for key, val in zip(tensors, batch_val): - if val is not None: - feed_dict[key] = val - result = self.sess.run([self.merged], feed_dict=feed_dict) - summary_str = result[0] - self.writer.add_summary(summary_str, epoch) - i += self.batch_size + if self.histogram_freq and epoch % self.histogram_freq == 0: + self._epoch = epoch + self._current_batch = 0 + self._batches_per_epoch = math.ceil( + self.validation_data[0].shape[0] / self.batch_size) + if self.merged not in self.model.test_function.fetches: + self.model.test_function.fetches.append(self.merged) + self.model.test_function.fetch_callbacks[ + self.merged] = self._fetch_callback + + def on_epoch_end(self, epoch, logs=None): + """Checks if summary ops should run next epoch, logs scalar summaries.""" + + logs = logs or {} + + if self.histogram_freq and self.histogram_freq > 1: + if self.merged in self.model.test_function.fetches: + self.model.test_function.fetches.remove(self.merged) + if self.merged in self.model.test_function.fetch_callbacks: + self.model.test_function.fetch_callbacks.pop(self.merged) for name, value in logs.items(): if name in ['batch', 'size']: @@ -898,7 +906,7 @@ class ReduceLROnPlateau(Callback): """ if self.mode not in ['auto', 'min', 'max']: logging.warning('Learning Rate Plateau Reducing mode %s is unknown, ' - 'fallback to auto mode.', self.mode, RuntimeWarning) + 'fallback to auto mode.', self.mode) self.mode = 'auto' if (self.mode == 'min' or (self.mode == 'auto' and 'acc' not in self.monitor)): @@ -920,7 +928,7 @@ class ReduceLROnPlateau(Callback): if current is None: logging.warning('Reduce LR on plateau conditioned on metric `%s` ' 'which is not available. Available metrics are: %s', - self.monitor, ','.join(list(logs.keys())), RuntimeWarning) + self.monitor, ','.join(list(logs.keys()))) else: if self.in_cooldown(): diff --git a/tensorflow/python/keras/callbacks/__init__.py b/tensorflow/python/keras/callbacks/__init__.py deleted file mode 100644 index 2d884790ddb9ccf49649c6af4cfd40cddbc38cb3..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/callbacks/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras callback classes.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.callbacks import BaseLogger -from tensorflow.python.keras._impl.keras.callbacks import Callback -from tensorflow.python.keras._impl.keras.callbacks import CSVLogger -from tensorflow.python.keras._impl.keras.callbacks import EarlyStopping -from tensorflow.python.keras._impl.keras.callbacks import History -from tensorflow.python.keras._impl.keras.callbacks import LambdaCallback -from tensorflow.python.keras._impl.keras.callbacks import LearningRateScheduler -from tensorflow.python.keras._impl.keras.callbacks import ModelCheckpoint -from tensorflow.python.keras._impl.keras.callbacks import ProgbarLogger -from tensorflow.python.keras._impl.keras.callbacks import ReduceLROnPlateau -from tensorflow.python.keras._impl.keras.callbacks import RemoteMonitor -from tensorflow.python.keras._impl.keras.callbacks import TensorBoard -from tensorflow.python.keras._impl.keras.callbacks import TerminateOnNaN - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py similarity index 87% rename from tensorflow/python/keras/_impl/keras/callbacks_test.py rename to tensorflow/python/keras/callbacks_test.py index 1d9d48dd2d27c08660302af231de31a0724e0664..d56f2f5bfc7d7045a4c1d2bde764fe1143764922 100644 --- a/tensorflow/python/keras/_impl/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -27,8 +27,9 @@ import unittest import numpy as np -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.core.framework import summary_pb2 +from tensorflow.python import keras +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary.writer import writer_cache @@ -273,16 +274,43 @@ class KerasCallbacksTest(test.TestCase): 1, activation='sigmoid'),)) model.compile( optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy']) - stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience) weights = model.get_weights() + stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience) hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) assert len(hist.epoch) >= patience # This should allow training to go for at least `patience` epochs model.set_weights(weights) hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) - assert len(hist.epoch) >= patience + assert len(hist.epoch) >= patience + + def test_EarlyStopping_with_baseline(self): + with self.test_session(): + np.random.seed(1337) + baseline = 0.5 + (data, labels), _ = testing_utils.get_test_data( + train_samples=100, + test_samples=50, + input_shape=(1,), + num_classes=NUM_CLASSES) + model = keras.models.Sequential((keras.layers.Dense( + 1, input_dim=1, activation='relu'), keras.layers.Dense( + 1, activation='sigmoid'),)) + model.compile( + optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy']) + + stopper = keras.callbacks.EarlyStopping(monitor='acc', + baseline=baseline) + hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) + assert len(hist.epoch) == 1 + + patience = 3 + stopper = keras.callbacks.EarlyStopping(monitor='acc', + patience=patience, + baseline=baseline) + hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) + assert len(hist.epoch) >= patience def test_RemoteMonitor(self): if requests is None: @@ -321,8 +349,26 @@ class KerasCallbacksTest(test.TestCase): callbacks=cbks, epochs=5, verbose=0) - assert (float(keras.backend.get_value(model.optimizer.lr)) - 0.2 - ) < keras.backend.epsilon() + assert ( + float(keras.backend.get_value( + model.optimizer.lr)) - 0.2) < keras.backend.epsilon() + + cbks = [keras.callbacks.LearningRateScheduler(lambda x, lr: lr / 2)] + model.compile( + loss='categorical_crossentropy', + optimizer='sgd', + metrics=['accuracy']) + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=2, + verbose=0) + assert ( + float(keras.backend.get_value( + model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon() def test_ReduceLROnPlateau(self): with self.test_session(): @@ -635,6 +681,8 @@ class KerasCallbacksTest(test.TestCase): model.add( keras.layers.Dense( NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu')) + # non_trainable_weights: moving_variance, moving_mean + model.add(keras.layers.BatchNormalization()) model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax')) model.compile( loss='categorical_crossentropy', @@ -854,6 +902,80 @@ class KerasCallbacksTest(test.TestCase): callbacks=callbacks_factory(histogram_freq=1)) assert os.path.isdir(filepath) + def test_Tensorboard_histogram_summaries_in_test_function(self): + + class FileWriterStub(object): + + def __init__(self, logdir, graph=None): + self.logdir = logdir + self.graph = graph + self.steps_seen = [] + + def add_summary(self, summary, global_step): + summary_obj = summary_pb2.Summary() + + # ensure a valid Summary proto is being sent + if isinstance(summary, bytes): + summary_obj.ParseFromString(summary) + else: + assert isinstance(summary, summary_pb2.Summary) + summary_obj = summary + + # keep track of steps seen for the merged_summary op, + # which contains the histogram summaries + if len(summary_obj.value) > 1: + self.steps_seen.append(global_step) + + def flush(self): + pass + + def close(self): + pass + + np.random.seed(1337) + tmpdir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, tmpdir) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=TRAIN_SAMPLES, + test_samples=TEST_SAMPLES, + input_shape=(INPUT_DIM,), + num_classes=NUM_CLASSES) + y_test = keras.utils.to_categorical(y_test) + y_train = keras.utils.to_categorical(y_train) + + with self.test_session(): + model = keras.models.Sequential() + model.add( + keras.layers.Dense( + NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu')) + # non_trainable_weights: moving_variance, moving_mean + model.add(keras.layers.BatchNormalization()) + model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax')) + model.compile( + loss='categorical_crossentropy', + optimizer='sgd', + metrics=['accuracy']) + tsb = keras.callbacks.TensorBoard( + log_dir=tmpdir, + histogram_freq=1, + write_images=True, + write_grads=True, + batch_size=5) + tsb._writer_class = FileWriterStub + cbks = [tsb] + + # fit with validation data + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=3, + verbose=0) + + self.assertAllEqual(tsb.writer.steps_seen, [0, 0.5, 1, 1.5, 2, 2.5]) + @unittest.skipIf( os.name == 'nt', 'use_multiprocessing=True does not work on windows properly.') diff --git a/tensorflow/python/keras/_impl/keras/constraints.py b/tensorflow/python/keras/constraints.py similarity index 96% rename from tensorflow/python/keras/_impl/keras/constraints.py rename to tensorflow/python/keras/constraints.py index abe95d8e0ca68b2e62f9574fba9ae912a9179fff..bf3a3a728aafc8071d8ddb7e3acf4f7282ed4c16 100644 --- a/tensorflow/python/keras/_impl/keras/constraints.py +++ b/tensorflow/python/keras/constraints.py @@ -21,9 +21,9 @@ from __future__ import print_function import six -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/constraints/__init__.py b/tensorflow/python/keras/constraints/__init__.py deleted file mode 100644 index 152606d8ebbcadf57d971d508e15283da65e4aa3..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/constraints/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras built-in constraints functions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Constraints functions / callable classes. -from tensorflow.python.keras._impl.keras.constraints import Constraint -from tensorflow.python.keras._impl.keras.constraints import max_norm -from tensorflow.python.keras._impl.keras.constraints import MaxNorm -from tensorflow.python.keras._impl.keras.constraints import min_max_norm -from tensorflow.python.keras._impl.keras.constraints import MinMaxNorm -from tensorflow.python.keras._impl.keras.constraints import non_neg -from tensorflow.python.keras._impl.keras.constraints import NonNeg -from tensorflow.python.keras._impl.keras.constraints import unit_norm -from tensorflow.python.keras._impl.keras.constraints import UnitNorm - -# Auxiliary utils. -# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.constraints import deserialize -from tensorflow.python.keras._impl.keras.constraints import serialize -from tensorflow.python.keras._impl.keras.constraints import get - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/constraints_test.py b/tensorflow/python/keras/constraints_test.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/constraints_test.py rename to tensorflow/python/keras/constraints_test.py index 87905693caa900a2cc565cef4bcea3fa30a4bc6c..84e2db10332c82f566a35d5ebba0c340e502fcd5 100644 --- a/tensorflow/python/keras/_impl/keras/constraints_test.py +++ b/tensorflow/python/keras/constraints_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py b/tensorflow/python/keras/datasets/boston_housing.py similarity index 88% rename from tensorflow/python/keras/_impl/keras/datasets/boston_housing.py rename to tensorflow/python/keras/datasets/boston_housing.py index 13fa9aed2b8da124af4e9f68c779e08d3094cb5d..eeb7cbc44a72a5c624f8d1d1d9dbfab1fcd1b225 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py +++ b/tensorflow/python/keras/datasets/boston_housing.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.util.tf_export import tf_export @@ -39,15 +39,15 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113): Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. """ assert 0 <= test_split < 1 + origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' path = get_file( path, - origin='https://s3.amazonaws.com/keras-datasets/boston_housing.npz', + origin=origin_folder + 'boston_housing.npz', file_hash= 'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5') - f = np.load(path) - x = f['x'] - y = f['y'] - f.close() + with np.load(path) as f: + x = f['x'] + y = f['y'] np.random.seed(seed) indices = np.arange(len(x)) diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar.py b/tensorflow/python/keras/datasets/cifar.py similarity index 100% rename from tensorflow/python/keras/_impl/keras/datasets/cifar.py rename to tensorflow/python/keras/datasets/cifar.py diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py b/tensorflow/python/keras/datasets/cifar10.py similarity index 90% rename from tensorflow/python/keras/_impl/keras/datasets/cifar10.py rename to tensorflow/python/keras/datasets/cifar10.py index 6b772433822474c06efcce1701226a4a67abe361..d627160875c007971c695891d1dab34b8bf1ba39 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py +++ b/tensorflow/python/keras/datasets/cifar10.py @@ -22,9 +22,9 @@ import os import numpy as np -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.datasets.cifar import load_batch -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.datasets.cifar import load_batch +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py b/tensorflow/python/keras/datasets/cifar100.py similarity index 90% rename from tensorflow/python/keras/_impl/keras/datasets/cifar100.py rename to tensorflow/python/keras/datasets/cifar100.py index 28d74116a50979abab207dbec88e384210dfc070..e9a6d634a5308ab8c749e8861e0e4a33ac56d464 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py +++ b/tensorflow/python/keras/datasets/cifar100.py @@ -22,9 +22,9 @@ import os import numpy as np -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.datasets.cifar import load_batch -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.datasets.cifar import load_batch +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py b/tensorflow/python/keras/datasets/fashion_mnist.py similarity index 85% rename from tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py rename to tensorflow/python/keras/datasets/fashion_mnist.py index 508e95f719a02977960b80c283495ced642293c5..3f4c6c7413e01313fda051a5603f223f9f7c4d27 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py +++ b/tensorflow/python/keras/datasets/fashion_mnist.py @@ -23,7 +23,7 @@ import os import numpy as np -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.util.tf_export import tf_export @@ -33,9 +33,15 @@ def load_data(): Returns: Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. + + License: + The copyright for Fashion-MNIST is held by Zalando SE. + Fashion-MNIST is licensed under the [MIT license]( + https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE). + """ dirname = os.path.join('datasets', 'fashion-mnist') - base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' + base = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' files = [ 'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz' diff --git a/tensorflow/python/keras/_impl/keras/datasets/imdb.py b/tensorflow/python/keras/datasets/imdb.py similarity index 92% rename from tensorflow/python/keras/_impl/keras/datasets/imdb.py rename to tensorflow/python/keras/datasets/imdb.py index 7467bb24646227705972262381aa5cf1de809f1c..b73b024162ac3fde4c430c34ff4f0f7b1174abe6 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/imdb.py +++ b/tensorflow/python/keras/datasets/imdb.py @@ -22,8 +22,8 @@ import json import numpy as np -from tensorflow.python.keras._impl.keras.preprocessing.sequence import _remove_long_seq -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras.preprocessing.sequence import _remove_long_seq +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -77,9 +77,10 @@ def load_data(path='imdb.npz', if kwargs: raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) + origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' path = get_file( path, - origin='https://s3.amazonaws.com/text-datasets/imdb.npz', + origin=origin_folder + 'imdb.npz', file_hash='599dadb1135973df5b59232a0e9a887c') with np.load(path) as f: x_train, labels_train = f['x_train'], f['y_train'] @@ -140,9 +141,10 @@ def get_word_index(path='imdb_word_index.json'): Returns: The word index dictionary. """ + origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' path = get_file( path, - origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.json', + origin=origin_folder + 'imdb_word_index.json', file_hash='bfafd718b763782e994055a2d397834f') with open(path) as f: return json.load(f) diff --git a/tensorflow/python/keras/datasets/imdb/__init__.py b/tensorflow/python/keras/datasets/imdb/__init__.py deleted file mode 100644 index 1c6396d2d32b88eaa900a5af4e62c7484fceab63..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/datasets/imdb/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""IMDB movie review sentiment classification dataset.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.datasets.imdb import get_word_index -from tensorflow.python.keras._impl.keras.datasets.imdb import load_data - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/datasets/mnist.py b/tensorflow/python/keras/datasets/mnist.py similarity index 65% rename from tensorflow/python/keras/_impl/keras/datasets/mnist.py rename to tensorflow/python/keras/datasets/mnist.py index e30691373e9aafad61b101476e21d6860527ce98..a96b581960f3d5f60994fe92a1424e793d7e39c7 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/mnist.py +++ b/tensorflow/python/keras/datasets/mnist.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.util.tf_export import tf_export @@ -34,13 +34,21 @@ def load_data(path='mnist.npz'): Returns: Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. + + License: + Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset, + which is a derivative work from original NIST datasets. + MNIST dataset is made available under the terms of the + [Creative Commons Attribution-Share Alike 3.0 license.]( + https://creativecommons.org/licenses/by-sa/3.0/) """ + origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' path = get_file( path, - origin='https://s3.amazonaws.com/img-datasets/mnist.npz', + origin=origin_folder + 'mnist.npz', file_hash='8a61469f7ea1b51cbae51d4f78837e45') - f = np.load(path) - x_train, y_train = f['x_train'], f['y_train'] - x_test, y_test = f['x_test'], f['y_test'] - f.close() - return (x_train, y_train), (x_test, y_test) + with np.load(path) as f: + x_train, y_train = f['x_train'], f['y_train'] + x_test, y_test = f['x_test'], f['y_test'] + + return (x_train, y_train), (x_test, y_test) diff --git a/tensorflow/python/keras/datasets/mnist/__init__.py b/tensorflow/python/keras/datasets/mnist/__init__.py deleted file mode 100644 index 364255f3387b59a419c010db9b93cdfbcba36186..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/datasets/mnist/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""MNIST handwritten digits classification dataset.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.datasets.mnist import load_data - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/datasets/reuters.py b/tensorflow/python/keras/datasets/reuters.py similarity index 90% rename from tensorflow/python/keras/_impl/keras/datasets/reuters.py rename to tensorflow/python/keras/datasets/reuters.py index b711696b5eecf9ba07a66cef25c1811c182b3b60..cb796bb06cf09157cc510b55e3981d518fd8b433 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/reuters.py +++ b/tensorflow/python/keras/datasets/reuters.py @@ -22,8 +22,8 @@ import json import numpy as np -from tensorflow.python.keras._impl.keras.preprocessing.sequence import _remove_long_seq -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras.preprocessing.sequence import _remove_long_seq +from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -75,9 +75,10 @@ def load_data(path='reuters.npz', if kwargs: raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) + origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' path = get_file( path, - origin='https://s3.amazonaws.com/text-datasets/reuters.npz', + origin=origin_folder + 'reuters.npz', file_hash='87aedbeb0cb229e378797a632c1997b6') with np.load(path) as f: xs, labels = f['x'], f['y'] @@ -124,11 +125,10 @@ def get_word_index(path='reuters_word_index.json'): Returns: The word index dictionary. """ + origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' path = get_file( path, - origin='https://s3.amazonaws.com/text-datasets/reuters_word_index.json', + origin=origin_folder + 'reuters_word_index.json', file_hash='4d44cc38712099c9e383dc6e5f11a921') - f = open(path) - data = json.load(f) - f.close() - return data + with open(path) as f: + return json.load(f) diff --git a/tensorflow/python/keras/datasets/reuters/__init__.py b/tensorflow/python/keras/datasets/reuters/__init__.py deleted file mode 100644 index bb6791a344ad0c372ac60cd4a332f5632841dd46..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/datasets/reuters/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Reuters newswire topic classification dataset.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.datasets.reuters import get_word_index -from tensorflow.python.keras._impl.keras.datasets.reuters import load_data - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/engine/__init__.py b/tensorflow/python/keras/engine/__init__.py similarity index 62% rename from tensorflow/python/keras/_impl/keras/engine/__init__.py rename to tensorflow/python/keras/engine/__init__.py index 1bc533ab8f7ba37948d82bc69fe1c9bfe00d6834..26aed34766f9e1e2094db7a4c8b66ff057dacc4b 100644 --- a/tensorflow/python/keras/_impl/keras/engine/__init__.py +++ b/tensorflow/python/keras/engine/__init__.py @@ -18,10 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.engine.base_layer import InputSpec -from tensorflow.python.keras._impl.keras.engine.base_layer import Layer -from tensorflow.python.keras._impl.keras.engine.input_layer import Input -from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer -from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs -from tensorflow.python.keras._impl.keras.engine.network import Network -from tensorflow.python.keras._impl.keras.engine.training import Model +# TODO(fchollet): Remove hourglass imports once external code is done importing +# non-public APIs. +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer +from tensorflow.python.keras.utils.layer_utils import get_source_inputs + +del absolute_import +del division +del print_function diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py similarity index 96% rename from tensorflow/python/keras/_impl/keras/engine/base_layer.py rename to tensorflow/python/keras/engine/base_layer.py index e5e096d1f662547d47672bb7a8618e40be7ba8fe..361778570bc7e87bc0642a2c52d43762c6828eb4 100644 --- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import collections +import enum # pylint: disable=g-bad-import-order import inspect # Necessary supplement to tf_inspect to deal with variadic args. import numpy as np @@ -29,15 +30,15 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.keras._impl.keras import backend -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.utils import generic_utils -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import backend +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import tf_utils # A module that only depends on `keras.layers` import these from here. -from tensorflow.python.keras._impl.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import -from tensorflow.python.keras._impl.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import +from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import +from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope as vs @@ -50,6 +51,20 @@ from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export +class CallConvention(enum.Enum): + """Calling conventions for passing `Layer` inputs to `Layer.call`.""" + # The Layer takes inputs as its first argument, named "inputs" for + # compatibility with the signature of Layer.__call__. This is the mode assumed + # for Layers which are not subclassed Models. + EXPLICIT_INPUTS_ARGUMENT = 1 + # The Layer takes a single positional argument, not named "inputs". It's + # treated like an "inputs" argument. + SINGLE_POSITIONAL_ARGUMENT = 2 + # The Layer has multiple positional arguments to which its inputs should be + # bound. + POSITIONAL_ARGUMENTS_ARE_INPUTS = 3 + + @tf_export('keras.layers.Layer') class Layer(checkpointable.CheckpointableBase): """Base layer class. @@ -101,6 +116,7 @@ class Layer(checkpointable.CheckpointableBase): constraints on inputs that can be accepted by the layer. """ + @checkpointable.no_automatic_dependency_tracking def __init__(self, trainable=True, name=None, dtype=None, **kwargs): # These properties should be set by the user via keyword arguments. # note that 'dtype', 'input_shape' and 'batch_input_shape' @@ -149,7 +165,7 @@ class Layer(checkpointable.CheckpointableBase): self._call_fn_args = function_utils.fn_args(self.call) self._compute_previous_mask = ('mask' in self._call_fn_args or hasattr(self, 'compute_mask')) - self._uses_inputs_arg = True + self._call_convention = CallConvention.EXPLICIT_INPUTS_ARGUMENT # These lists will be filled via successive calls # to self._add_inbound_node(). @@ -202,7 +218,7 @@ class Layer(checkpointable.CheckpointableBase): @activity_regularizer.setter def activity_regularizer(self, regularizer): """Optional regularizer function for the output of this layer.""" - self._activity_regularizer = regularizer + self._activity_regularizer = self._no_dependency(regularizer) @property def trainable_weights(self): @@ -436,7 +452,7 @@ class Layer(checkpointable.CheckpointableBase): def _name_scope(self): return self.name - def build(self, _): + def build(self, input_shape): """Creates the variables of the layer.""" self.built = True @@ -643,7 +659,8 @@ class Layer(checkpointable.CheckpointableBase): self._compute_previous_mask): previous_mask = collect_previous_mask(inputs) if not hasattr(self, '_call_fn_args'): - self._call_fn_args = function_utils.fn_args(self.call) + self._call_fn_args = self._no_dependency( + function_utils.fn_args(self.call)) if ('mask' in self._call_fn_args and 'mask' not in kwargs and not generic_utils.is_all_none(previous_mask)): # The previous layer generated a mask, and mask was not explicitly pass @@ -793,12 +810,22 @@ class Layer(checkpointable.CheckpointableBase): pass # C type such as dict. Masking not supported in this case. def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs): - if args and getattr(self, '_uses_inputs_arg', True): - raise TypeError( - 'This Layer takes an `inputs` argument to call(), and only the ' - '`inputs` argument may be specified as a positional argument. ' - 'Pass everything else as a keyword argument (those arguments will' - ' not be tracked as inputs to the Layer).') + call_convention = getattr(self, '_call_convention', + CallConvention.EXPLICIT_INPUTS_ARGUMENT) + if args: + if call_convention == CallConvention.EXPLICIT_INPUTS_ARGUMENT: + raise TypeError( + 'This Layer takes an `inputs` argument to call(), and only the ' + '`inputs` argument may be specified as a positional argument. ' + 'Pass everything else as a keyword argument (those arguments will' + ' not be tracked as inputs to the Layer).') + elif call_convention == CallConvention.SINGLE_POSITIONAL_ARGUMENT: + raise TypeError( + 'This Layer takes a single positional argument to call(), which is ' + 'by convention the inputs argument, and only this argument may be ' + 'specified as a positional argument. Pass everything else as a ' + 'keyword argument (those arguments will not be tracked as inputs ' + 'to the Layer).') # If the layer returns tensors from its inputs, unmodified, # we copy them to avoid loss of tensor metadata. @@ -834,7 +861,11 @@ class Layer(checkpointable.CheckpointableBase): A tuple of (inputs, non_input_kwargs). These may be the same objects as were passed in (call_args and call_kwargs). """ - if getattr(self, '_uses_inputs_arg', True): + call_convention = getattr(self, '_call_convention', + CallConvention.EXPLICIT_INPUTS_ARGUMENT) + if (call_convention in ( + CallConvention.EXPLICIT_INPUTS_ARGUMENT, + CallConvention.SINGLE_POSITIONAL_ARGUMENT)): assert len(call_args) == 1 # TypeError raised earlier in __call__. return call_args[0], call_kwargs else: diff --git a/tensorflow/python/keras/_impl/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py similarity index 94% rename from tensorflow/python/keras/_impl/keras/engine/input_layer.py rename to tensorflow/python/keras/engine/input_layer.py index bd9dcbe3c576851123dfcabe3e36379019627ac5..8a4018a0df50b8d4c9df5900ffddfcdc093f161f 100644 --- a/tensorflow/python/keras/_impl/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -21,8 +21,8 @@ from __future__ import print_function from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.util.tf_export import tf_export @@ -119,6 +119,12 @@ class InputLayer(base_layer.Layer): self.is_placeholder = False self._batch_input_shape = tuple(input_tensor.get_shape().as_list()) + if context.executing_eagerly(): + raise ValueError('You should not pass an input tensor when executing ' + 'in eager mode. For example, instead of creating an ' + 'InputLayer, you should instantiate your model and ' + 'directly call it on your input.') + # Create an input node to add to self.outbound_node # and set output_tensors' _keras_history. input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access @@ -209,7 +215,7 @@ def Input( # pylint: disable=invalid-name if dtype is None: dtype = K.floatx() - if not shape and tensor is None: + if shape is None and tensor is None: raise ValueError('Please provide to Input either a `shape`' ' or a `tensor` argument. Note that ' '`shape` does not include the batch ' diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/engine/network.py similarity index 87% rename from tensorflow/python/keras/_impl/keras/engine/network.py rename to tensorflow/python/keras/engine/network.py index 87a670e501a5f6d2271b20d1108a9fc6aedd5b3f..a4d96de74fc90e31d52f9a67e845a84f9ceb5034 100644 --- a/tensorflow/python/keras/_impl/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import copy +import functools import json import os import weakref @@ -32,15 +33,18 @@ from tensorflow.python.eager import context from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import backend -from tensorflow.python.keras._impl.keras.engine import base_layer -from tensorflow.python.keras._impl.keras.engine import saving -from tensorflow.python.keras._impl.keras.utils import generic_utils -from tensorflow.python.keras._impl.keras.utils import tf_utils -from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite -from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary +from tensorflow.python.keras import backend +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import saving +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect @@ -77,6 +81,20 @@ class Network(base_layer.Layer): # Subclassed network self._init_subclassed_network(**kwargs) + # Several Network methods have "no_automatic_dependency_tracking" + # annotations. Since Network does automatic dependency tracking on attribute + # assignment, including for common data structures such as lists, by default + # we'd have quite a few empty dependencies which users don't care about (or + # would need some way to ignore dependencies automatically, which is confusing + # when applied to user code). Some attributes, such as _layers, would cause + # structural issues (_layers being the place where Layers assigned to tracked + # attributes are stored). + # + # Aside from these aesthetic and structural issues, useless dependencies on + # empty lists shouldn't cause issues; adding or removing them will not break + # checkpoints, but may cause "all Python objects matched" assertions to fail + # (in which case less strict assertions may be substituted if necessary). + @checkpointable.no_automatic_dependency_tracking def _base_init(self, name=None): # The following are implemented as property functions: # self.trainable_weights @@ -93,6 +111,11 @@ class Network(base_layer.Layer): self.trainable = True self._is_compiled = False self._expects_training_arg = False + # A list of "extra" variables assigned to attributes of this class, included + # in self.weights and self.variables. Always empty for graph networks (but + # included in base_init to avoid excessive special casing when retrieving + # the value). + self._extra_variables = [] self.supports_masking = False if not hasattr(self, 'optimizer'): @@ -126,8 +149,9 @@ class Network(base_layer.Layer): # restore operations when graph building. self._in_progress_restore_finalizer = None + @checkpointable.no_automatic_dependency_tracking def _init_graph_network(self, inputs, outputs, name=None): - self._uses_inputs_arg = True + self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT # Normalize and set self.inputs, self.outputs. if isinstance(inputs, (list, tuple)): self.inputs = list(inputs) # Tensor or list of tensors. @@ -146,14 +170,14 @@ class Network(base_layer.Layer): raise TypeError('When eager execution is enabled, ' 'inputs must come from a call to ' '`tf.keras.Input` (called after ' - 'tfe.enable_eager_execution()). ' + 'tf.enable_eager_execution()). ' 'Received invalid input: ' + str(tensor)) for tensor in self.outputs: if not isinstance(tensor, base_layer.DeferredTensor): # pylint: disable=protected-access raise TypeError('When eager execution is enabled, ' 'outputs must come from a call to ' 'a layer (called after ' - 'tfe.enable_eager_execution()). ' + 'tf.enable_eager_execution()). ' 'Received invalid output: ' + str(tensor)) # Check for redundancy in inputs. if len(set(self.inputs)) != len(self.inputs): @@ -284,22 +308,59 @@ class Network(base_layer.Layer): for layer in self._output_layers: self.output_names.append(layer.name) + @checkpointable.no_automatic_dependency_tracking def _init_subclassed_network(self, name=None): self._base_init(name=name) self._is_graph_network = False - call_args = tf_inspect.getargspec(self.call).args - if 'training' in call_args: + call_argspec = tf_inspect.getargspec(self.call) + if 'training' in call_argspec.args: self._expects_training_arg = True else: self._expects_training_arg = False - if 'inputs' in call_args: - self._uses_inputs_arg = True - else: - self._uses_inputs_arg = False + self._call_convention = self._determine_call_convention(call_argspec) self.outputs = None self.inputs = None self.built = False + def _determine_call_convention(self, call_argspec): + """Decides how `self.call()` is invoked. See base_layer.CallConvention.""" + if call_argspec.varargs: + may_take_single_argument = False + else: + try: + # Note: tf_inspect doesn't raise a TypeError when regular inspect would, + # so we need to keep in mind that "getcallargs" may have returned + # something even though we under-specified positional arguments. + all_args = tf_inspect.getcallargs(self.call, None) + self_args = set() + for arg_name, obj in all_args.items(): + if obj is self: + self_args.add(arg_name) + may_take_single_argument = True + except TypeError: + may_take_single_argument = False + if may_take_single_argument: + # A single positional argument (plus "self") is considered equivalent to + # an "inputs" argument. + all_positional_args = len(call_argspec.args) + if call_argspec.defaults is not None: + all_positional_args -= len(call_argspec.defaults) + non_self_positional_args = all_positional_args + for positional_arg_name in call_argspec.args[:all_positional_args]: + if positional_arg_name in self_args: + non_self_positional_args -= 1 + if non_self_positional_args == 1: + if 'inputs' in call_argspec.args[all_positional_args:]: + raise TypeError( + "Model.call() takes a single positional argument (to which " + "inputs are passed by convention) and a separate 'inputs' " + "argument. Unable to determine which arguments are inputs.") + return base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT + if 'inputs' in call_argspec.args: + return base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT + else: + return base_layer.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS + def _track_layers(self, layers): """Add Checkpointable dependencies on a list of Layers.""" weight_layer_index = 0 @@ -317,11 +378,35 @@ class Network(base_layer.Layer): self._track_checkpointable( layer, name='layer-%d' % layer_index, overwrite=True) + def _no_dependency(self, value): + """Override to allow `Layer` to disable dependency tracking. + + `CheckpointableBase` defines this method, whose semantics are "if a subclass + does dependency tracking, this method exempts `value`." Layer uses + `_no_dependency` to exempt some of its attribute assignments (conditional on + attribute assignment causing tracking in the subclass). + + Args: + value: An object which will be assigned to an object attribute, whose + value should not be tracked. + + Returns: + A wrapped object which, when assigned to an attribute, will not be + tracked (`value` will be stored in the attribute). + """ + return data_structures.NoDependency(value) + def __setattr__(self, name, value): - no_dependency = isinstance(value, checkpointable.NoDependency) - if no_dependency: - value = value.value - if isinstance(value, (base_layer.Layer, Network)): + if not getattr(self, '_setattr_tracking', True): + super(Network, self).__setattr__(name, value) + return + no_dependency = isinstance(value, data_structures.NoDependency) + value = data_structures.sticky_attribute_assignment( + checkpointable=self, value=value, name=name) + if isinstance(value, ( + base_layer.Layer, + Network, + data_structures.CheckpointableDataStructure)): try: is_graph_network = self._is_graph_network except AttributeError: @@ -329,7 +414,9 @@ class Network(base_layer.Layer): 'forgot to call `super(YourClass, self).__init__()`.' ' Always start with this line.') if not is_graph_network: - if value not in self._layers: + # We need to check object identity to avoid de-duplicating empty + # container types which compare equal. + if not any((layer is value for layer in self._layers)): self._layers.append(value) if hasattr(value, '_use_resource_variables'): # In subclassed models, legacy layers (tf.layers) must always use @@ -337,17 +424,22 @@ class Network(base_layer.Layer): value._use_resource_variables = True if (not no_dependency and isinstance(value, checkpointable.CheckpointableBase)): - # Layer (and therefore Network/Model) inherit from CheckpointableBase - # rather than Checkpointable, which means there is no Checkpointable - # __setattr__ override (it would be a performance issue for functional - # layers). Therefore Model tracks Checkpointable objects itself. - self._track_checkpointable( - checkpointable=value, name=name, overwrite=True) + if ( # For subclassed models only, users may add extra weights/variables + # simply by assigning them to attributes. + not self._is_graph_network + and isinstance(value, variables.Variable)): + self._extra_variables.append(value) super(Network, self).__setattr__(name, value) def add_variable(self, name, shape, dtype=None, initializer=None, regularizer=None, trainable=True, constraint=None): - raise NotImplementedError('`add_variable` is not supported on Networks.') + if self._is_graph_network: + raise NotImplementedError('`add_variable` is not supported on Networks.') + else: + raise NotImplementedError( + '`add_variable` is not supported on Networks. However, you may ' + 'assign variables to attributes and they will show up in the weights ' + 'and variables properties.') def add_loss(self, *args, **kwargs): if context.executing_eagerly(): @@ -434,7 +526,8 @@ class Network(base_layer.Layer): @property def layers(self): - return self._layers + return checkpointable_layer_utils.filter_empty_layer_containers( + self._layers) def get_layer(self, name=None, index=None): """Retrieves a layer based on either its name (unique) or index. @@ -469,6 +562,28 @@ class Network(base_layer.Layer): return layer raise ValueError('No such layer: ' + name) + @property + def _unfiltered_updates(self): + if context.executing_eagerly(): + return [] + updates = [] + for layer in self.layers: + if isinstance(layer, Network): + updates += layer._unfiltered_updates + else: + updates += layer.updates + return updates + + @property + def _unfiltered_losses(self): + losses = [] + for layer in self.layers: + if isinstance(layer, Network): + losses += layer._unfiltered_losses + else: + losses += layer.losses + return losses + @property def updates(self): """Retrieves the network's updates. @@ -478,6 +593,8 @@ class Network(base_layer.Layer): (e.g. will not include updates that were created by layers of this model outside of the model). + When the network has no registered inputs, all updates are returned. + Effectively, `network.updates` behaves like `layer.updates`. Concrete example: @@ -523,22 +640,20 @@ class Network(base_layer.Layer): if not self.trainable and not self.stateful: return [] - updates = [] - for layer in self.layers: - updates += layer.updates + updates = self._unfiltered_updates # `updates` might contain irrelevant updates, so it needs to be filtered # with respect to inputs the model has been called on. - if self.inputs: - relevant_inputs = self.inputs[:] - else: - relevant_inputs = [] - for i in range(1, len(self._inbound_nodes)): + relevant_inputs = [] + for i in range(0, len(self._inbound_nodes)): inputs = self.get_input_at(i) if isinstance(inputs, list): relevant_inputs += inputs else: relevant_inputs.append(inputs) + if not relevant_inputs: + return updates + reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, updates) relevant_conditional_updates = [x for x in updates if x in reachable] unconditional_updates = [ @@ -557,25 +672,25 @@ class Network(base_layer.Layer): (e.g. will not include losses that depend on tensors that aren't inputs to this model). + When the network has no registered inputs, all losses are returned. + Returns: A list of loss tensors. """ - losses = [] - for layer in self.layers: - losses += layer.losses + losses = self._unfiltered_losses if context.executing_eagerly(): return losses - if self.inputs: - relevant_inputs = self.inputs[:] - else: - relevant_inputs = [] - for i in range(1, len(self._inbound_nodes)): + relevant_inputs = [] + for i in range(0, len(self._inbound_nodes)): inputs = self.get_input_at(i) if isinstance(inputs, list): relevant_inputs += inputs else: relevant_inputs.append(inputs) + if not relevant_inputs: + return losses + reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, losses) relevant_conditional_losses = [x for x in losses if x in reachable] unconditional_losses = [ @@ -585,24 +700,17 @@ class Network(base_layer.Layer): @property def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for layer in self.layers: - weights += layer.trainable_weights - return weights + return checkpointable_layer_utils.gather_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) @property def non_trainable_weights(self): - weights = [] - for layer in self.layers: - weights += layer.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for layer in self.layers: - trainable_weights += layer.trainable_weights - return trainable_weights + weights - return weights + return checkpointable_layer_utils.gather_non_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) @property def input_spec(self): @@ -1082,7 +1190,7 @@ class Network(base_layer.Layer): layer_name = layer_data['name'] # Instantiate layer. - from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top layer = deserialize_layer(layer_data, custom_objects=custom_objects) created_layers[layer_name] = layer @@ -1166,7 +1274,7 @@ class Network(base_layer.Layer): if not self._is_graph_network: raise NotImplementedError - from tensorflow.python.keras._impl.keras.models import save_model # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.models import save_model # pylint: disable=g-import-not-at-top save_model(self, filepath, overwrite, include_optimizer) def save_weights(self, filepath, overwrite=True, save_format=None): @@ -1250,7 +1358,11 @@ class Network(base_layer.Layer): with h5py.File(filepath, 'w') as f: saving.save_weights_to_hdf5_group(f, self.layers) else: - self._checkpointable_saver.save(filepath) + if context.executing_eagerly(): + session = None + else: + session = backend.get_session() + self._checkpointable_saver.save(filepath, session=session) def load_weights(self, filepath, by_name=False): """Loads all layer weights, either from a TensorFlow or an HDF5 weight file. @@ -1310,7 +1422,8 @@ class Network(base_layer.Layer): 'loading TensorFlow-formatted weights (got by_name=True to ' 'load_weights).') if not context.executing_eagerly(): - finalizer = status.run_restore_ops + session = backend.get_session() + finalizer = functools.partial(status.run_restore_ops, session=session) if self.built: finalizer() else: @@ -1348,7 +1461,7 @@ class Network(base_layer.Layer): Returns: Model config with Keras version information added. """ - from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top + from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top config = self.get_config() model_config = { @@ -1407,7 +1520,8 @@ class Network(base_layer.Layer): ImportError: if yaml module is not found. """ if yaml is None: - raise ImportError('Requires yaml module installed.') + raise ImportError( + 'Requires yaml module installed (`pip install pyyaml`).') return yaml.dump(self._updated_config(), **kwargs) def summary(self, line_length=None, positions=None, print_fn=None): @@ -1424,52 +1538,19 @@ class Network(base_layer.Layer): It will be called on each line of the summary. You can set it to a custom function in order to capture the string summary. - """ - print_layer_summary(self, - line_length=line_length, - positions=positions, - print_fn=print_fn) - - -def get_source_inputs(tensor, layer=None, node_index=None): - """Returns the list of input tensors necessary to compute `tensor`. - - Output will always be a list of tensors - (potentially with 1 element). - Arguments: - tensor: The tensor to start from. - layer: Origin layer of the tensor. Will be - determined via tensor._keras_history if not provided. - node_index: Origin node index of the tensor. - - Returns: - List of input tensors. - """ - if not hasattr(tensor, '_keras_history'): - return tensor - - if layer is None or node_index: - layer, node_index, _ = tensor._keras_history - if not layer._inbound_nodes: - return [tensor] - else: - node = layer._inbound_nodes[node_index] - if not node.inbound_layers: - # Reached an Input layer, stop recursion. - return node.input_tensors - else: - source_tensors = [] - for i in range(len(node.inbound_layers)): - x = node.input_tensors[i] - layer = node.inbound_layers[i] - node_index = node.node_indices[i] - previous_sources = get_source_inputs(x, layer, node_index) - # Avoid input redundancy. - for x in previous_sources: - if x not in source_tensors: - source_tensors.append(x) - return source_tensors + Raises: + ValueError: if `summary()` is called before the model is built. + """ + if not self.built: + raise ValueError('This model has never been called, thus its weights ' + 'have not yet been created, so no summary can be ' + 'displayed. Build the model first ' + '(e.g. by calling it on some data).') + layer_utils.print_summary(self, + line_length=line_length, + positions=positions, + print_fn=print_fn) def _is_hdf5_filepath(filepath): diff --git a/tensorflow/python/keras/_impl/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py similarity index 87% rename from tensorflow/python/keras/_impl/keras/engine/saving.py rename to tensorflow/python/keras/engine/saving.py index 6a3ae3b20c11ebcd9d9454e0f65659cb5c8d6bdd..d5ccd44604b6b84ea0ceb4fa1c270b2c7dddc147 100644 --- a/tensorflow/python/keras/_impl/keras/engine/saving.py +++ b/tensorflow/python/keras/engine/saving.py @@ -25,10 +25,10 @@ import os import numpy as np from six.moves import zip # pylint: disable=redefined-builtin -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import optimizers -from tensorflow.python.keras._impl.keras.utils import conv_utils -from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import optimizers +from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import serialization from tensorflow.python.util.tf_export import tf_export @@ -77,7 +77,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True): if h5py is None: raise ImportError('`save_model` requires h5py.') - from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top + from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top if not isinstance(filepath, h5py.File): # If file exists and should not be overwritten. @@ -106,7 +106,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True): model_layers = model.layers save_weights_to_hdf5_group(model_weights_group, model_layers) - if include_optimizer and hasattr(model, 'optimizer'): + if include_optimizer and model.optimizer: if isinstance(model.optimizer, optimizers.TFOptimizer): logging.warning( 'TensorFlow optimizers do not ' @@ -302,7 +302,7 @@ def model_from_config(config, custom_objects=None): raise TypeError('`model_from_config` expects a dictionary, not a list. ' 'Maybe you meant to use ' '`Sequential.from_config(config)`?') - from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top return deserialize(config, custom_objects=custom_objects) @@ -323,9 +323,9 @@ def model_from_yaml(yaml_string, custom_objects=None): ImportError: if yaml module is not found. """ if yaml is None: - raise ImportError('Requires yaml module installed.') + raise ImportError('Requires yaml module installed (`pip install pyyaml`).') config = yaml.load(yaml_string) - from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top return deserialize(config, custom_objects=custom_objects) @@ -343,7 +343,7 @@ def model_from_json(json_string, custom_objects=None): A Keras model instance (uncompiled). """ config = json.loads(json_string) - from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top return deserialize(config, custom_objects=custom_objects) @@ -351,7 +351,10 @@ def preprocess_weights_for_loading(layer, weights, original_keras_version=None, original_backend=None): - """Converts layers weights from Keras 1 format to Keras 2. + """Preprocess layer weights between different Keras formats. + + Converts layers weights from Keras 1 format to Keras 2 and also weights of + CuDNN layers in Keras 2. Arguments: layer: Layer instance. @@ -363,7 +366,18 @@ def preprocess_weights_for_loading(layer, Returns: A list of weights values (Numpy arrays). """ - if layer.__class__.__name__ == 'Bidirectional': + def convert_nested_bidirectional(weights): + """Converts layers nested in `Bidirectional` wrapper. + + This function uses `preprocess_weights_for_loading()` for converting + layers. + + Arguments: + weights: List of weights values (Numpy arrays). + + Returns: + A list of weights values (Numpy arrays). + """ num_weights_per_layer = len(weights) // 2 forward_weights = preprocess_weights_for_loading( layer.forward_layer, weights[:num_weights_per_layer], @@ -371,7 +385,69 @@ def preprocess_weights_for_loading(layer, backward_weights = preprocess_weights_for_loading( layer.backward_layer, weights[num_weights_per_layer:], original_keras_version, original_backend) - weights = forward_weights + backward_weights + return forward_weights + backward_weights + + def convert_nested_time_distributed(weights): + """Converts layers nested in `TimeDistributed` wrapper. + + This function uses `preprocess_weights_for_loading()` for converting nested + layers. + + Arguments: + weights: List of weights values (Numpy arrays). + + Returns: + A list of weights values (Numpy arrays). + """ + return preprocess_weights_for_loading( + layer.layer, weights, original_keras_version, original_backend) + + def convert_nested_model(weights): + """Converts layers nested in `Model` or `Sequential`. + + This function uses `preprocess_weights_for_loading()` for converting nested + layers. + + Arguments: + weights: List of weights values (Numpy arrays). + + Returns: + A list of weights values (Numpy arrays). + """ + new_weights = [] + # trainable weights + for sublayer in layer.layers: + num_weights = len(sublayer.trainable_weights) + if num_weights > 0: + new_weights.extend(preprocess_weights_for_loading( + layer=sublayer, + weights=weights[:num_weights], + original_keras_version=original_keras_version, + original_backend=original_backend)) + weights = weights[num_weights:] + + # non-trainable weights + for sublayer in layer.layers: + num_weights = len([l for l in sublayer.weights + if l not in sublayer.trainable_weights]) + if num_weights > 0: + new_weights.extend(preprocess_weights_for_loading( + layer=sublayer, + weights=weights[:num_weights], + original_keras_version=original_keras_version, + original_backend=original_backend)) + weights = weights[num_weights:] + return new_weights + + # Convert layers nested in Bidirectional/Model/Sequential. + # Both transformation should be ran for both Keras 1->2 conversion + # and for conversion of CuDNN layers. + if layer.__class__.__name__ == 'Bidirectional': + weights = convert_nested_bidirectional(weights) + if layer.__class__.__name__ == 'TimeDistributed': + weights = convert_nested_time_distributed(weights) + elif layer.__class__.__name__ in ['Model', 'Sequential']: + weights = convert_nested_model(weights) if original_keras_version == '1': if layer.__class__.__name__ == 'TimeDistributed': @@ -446,35 +522,6 @@ def preprocess_weights_for_loading(layer, recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) weights = [kernel, recurrent_kernel, bias] - if layer.__class__.__name__ in ['Model', 'Sequential']: - new_weights = [] - # trainable weights - for sublayer in layer.layers: - num_weights = len(sublayer.trainable_weights) - if num_weights > 0: - new_weights.extend( - preprocess_weights_for_loading( - layer=sublayer, - weights=weights[:num_weights], - original_keras_version=original_keras_version, - original_backend=original_backend)) - weights = weights[num_weights:] - - # non-trainable weights - for sublayer in layer.layers: - num_weights = len([ - l for l in sublayer.weights if l not in sublayer.trainable_weights - ]) - if num_weights > 0: - new_weights.extend( - preprocess_weights_for_loading( - layer=sublayer, - weights=weights[:num_weights], - original_keras_version=original_keras_version, - original_backend=original_backend)) - weights = weights[num_weights:] - weights = new_weights - conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] if layer.__class__.__name__ in conv_layers: if original_backend == 'theano': @@ -486,6 +533,7 @@ def preprocess_weights_for_loading(layer, if layer.__class__.__name__ == 'ConvLSTM2D': weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) + # convert CuDNN layers return _convert_rnn_weights(layer, weights) @@ -624,7 +672,7 @@ def _convert_rnn_weights(layer, weights): kernels = transform_kernels(weights[0], transpose_input(from_cudnn), n_gates) recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) - biases = weights[2].reshape((2, -1) if from_cudnn else -1) + biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1) return [kernels, recurrent_kernels, biases] if bias_shape == (2 * units * n_gates,): @@ -663,7 +711,7 @@ def save_weights_to_hdf5_group(f, layers): f: HDF5 group. layers: List of layer instances. """ - from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top + from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top save_attributes_to_hdf5_group( f, 'layer_names', [layer.name.encode('utf8') for layer in layers]) @@ -806,7 +854,16 @@ def load_weights_from_hdf5_group_by_name(f, layers): str(len(weight_values)) + ' element(s).') # Set values. for i in range(len(weight_values)): - weight_value_tuples.append((symbolic_weights[i], weight_values[i])) + if K.int_shape(symbolic_weights[i]) != weight_values[i].shape: + raise ValueError('Layer #' + str(k) +' (named "' + layer.name + + '"), weight ' + str(symbolic_weights[i]) + + ' has shape {}'.format(K.int_shape( + symbolic_weights[i])) + + ', but the saved weight has shape ' + + str(weight_values[i].shape) + '.') + + else: + weight_value_tuples.append((symbolic_weights[i], weight_values[i])) K.batch_set_value(weight_value_tuples) diff --git a/tensorflow/python/keras/_impl/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py similarity index 80% rename from tensorflow/python/keras/_impl/keras/engine/saving_test.py rename to tensorflow/python/keras/engine/saving_test.py index acd104b4fb642c89d9d81795f024aa2e2fc8c6c6..030328f2a66f0ec406ac271aecfbf2dbebf22f5f 100644 --- a/tensorflow/python/keras/_impl/keras/engine/saving_test.py +++ b/tensorflow/python/keras/engine/saving_test.py @@ -21,17 +21,17 @@ from __future__ import print_function import os import shutil import tempfile - from absl.testing import parameterized import numpy as np +from tensorflow.python import keras from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras.engine import saving +from tensorflow.python.keras.engine import training from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -248,6 +248,82 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): self.assertAllClose(y, ref_y) + def test_sequential_weight_loading_group_name_with_incorrect_length(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + h5_path = os.path.join(temp_dir, 'test.h5') + + num_hidden = 5 + input_dim = 3 + num_classes = 2 + with self.test_session(): + ref_model = keras.models.Sequential() + ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, + name='d1')) + ref_model.add(keras.layers.Dense(num_classes, name='d2')) + ref_model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + + f_ref_model = h5py.File(h5_path, 'w') + saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) + + f_model = h5py.File(h5_path, 'r') + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, use_bias=False, + input_dim=input_dim, name='d1')) + model.add(keras.layers.Dense(num_classes, name='d2')) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + with self.assertRaisesRegexp(ValueError, + r'Layer #0 \(named \"d1\"\) expects 1 ' + r'weight\(s\), but the saved weights have 2 ' + r'element\(s\)\.'): + saving.load_weights_from_hdf5_group_by_name(f_model, model.layers) + + def test_sequential_weight_loading_group_name_with_incorrect_shape(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + h5_path = os.path.join(temp_dir, 'test.h5') + + num_hidden = 5 + input_dim = 3 + num_classes = 2 + with self.test_session(): + ref_model = keras.models.Sequential() + ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, + name='d1')) + ref_model.add(keras.layers.Dense(num_classes, name='d2')) + ref_model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + + f_ref_model = h5py.File(h5_path, 'w') + saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) + + f_model = h5py.File(h5_path, 'r') + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden + 5, input_dim=input_dim, + name='d1')) + model.add(keras.layers.Dense(num_classes, name='d2')) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + with self.assertRaisesRegexp(ValueError, + r'Layer #0 \(named "d1"\), weight ' + r' has ' + r'shape \(3, 10\), but the saved weight has ' + r'shape \(3, 5\)\.'): + saving.load_weights_from_hdf5_group_by_name(f_model, model.layers) + class TestWholeModelSaving(test.TestCase): @@ -288,6 +364,30 @@ class TestWholeModelSaving(test.TestCase): out2 = new_model.predict(x) self.assertAllClose(out, out2, atol=1e-05) + def test_sequential_model_saving_without_compile(self): + if h5py is None: + self.skipTest('h5py required to run this test') + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + + x = np.random.random((1, 3)) + out = model.predict(x) + fd, fname = tempfile.mkstemp('.h5') + + # Save the model without any compilation or training. + keras.models.save_model(model, fname) + + new_model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + out2 = new_model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + def test_sequential_model_saving_2(self): if h5py is None: self.skipTest('h5py required to run this test') @@ -404,26 +504,27 @@ class TestWholeModelSaving(test.TestCase): os.remove(fname) def test_saving_lambda_numpy_array_arguments(self): - if h5py is None: - self.skipTest('h5py required to run this test') + with self.test_session(): + if h5py is None: + self.skipTest('h5py required to run this test') - mean = np.random.random((4, 2, 3)) - std = np.abs(np.random.random((4, 2, 3))) + 1e-5 - inputs = keras.layers.Input(shape=(4, 2, 3)) - output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std, - arguments={'mu': mean, 'std': std})(inputs) - model = keras.models.Model(inputs, output) - model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + mean = np.random.random((4, 2, 3)) + std = np.abs(np.random.random((4, 2, 3))) + 1e-5 + inputs = keras.layers.Input(shape=(4, 2, 3)) + output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std, + arguments={'mu': mean, 'std': std})(inputs) + model = keras.models.Model(inputs, output) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) - self.assertAllClose(mean, model.layers[1].arguments['mu']) - self.assertAllClose(std, model.layers[1].arguments['std']) + self.assertAllClose(mean, model.layers[1].arguments['mu']) + self.assertAllClose(std, model.layers[1].arguments['std']) def test_saving_model_with_long_layer_names(self): if h5py is None: @@ -562,7 +663,7 @@ class SubclassedModel(training.Model): class TestWeightSavingAndLoadingTFFormat(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_tensorflow_format_overwrite(self): with self.test_session() as session: model = SubclassedModel() @@ -580,6 +681,25 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): # Indirectly tests that the user is prompted model.save_weights(prefix, save_format='tensorflow', overwrite=False) + def test_no_default_session(self): + with ops.Graph().as_default(): + self.assertFalse(ops.get_default_session()) + data = np.random.random((1000, 32)).astype(np.float32) + labels = np.random.random((1000, 10)).astype(np.float32) + + model = keras.models.Sequential([ + keras.layers.Dense(10, activation='softmax'), + keras.layers.Dense(10, activation='softmax')]) + + model.compile(optimizer=training_module.RMSPropOptimizer(0.001), + loss='categorical_crossentropy', + metrics=['accuracy']) + + model.fit(data, labels) + fname = os.path.join(self.get_temp_dir(), 'weights', 'ckpt') + model.save_weights(fname) + model.load_weights(fname) + def test_no_graph_pollution(self): with context.graph_mode(): graph = ops.Graph() @@ -632,7 +752,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): restore_on_create_y = self.evaluate(restore_on_create_y_tensor) self.assertAllClose(ref_y, restore_on_create_y) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_graph_model(self): def _make_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -642,7 +762,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): self._weight_loading_test_template(_make_graph_model) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_subclassed_model(self): self._weight_loading_test_template(SubclassedModel) @@ -676,7 +796,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): y = self.evaluate(model(x)) self.assertAllClose(ref_y, y) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_graph_model_added_layer(self): def _save_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -696,7 +816,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): _save_graph_model, _restore_graph_model, _restore_init_fn) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_graph_model_added_no_weight_layer(self): def _save_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -717,7 +837,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): _save_graph_model, _restore_graph_model, _restore_init_fn) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_subclassed_model_added_layer(self): class SubclassedModelRestore(training.Model): diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py similarity index 91% rename from tensorflow/python/keras/_impl/keras/engine/sequential.py rename to tensorflow/python/keras/engine/sequential.py index 8626626ca1a232de175af355e317f7df704fe148..371504a503168e7443895bb22a57126b274da226 100644 --- a/tensorflow/python/keras/_impl/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -21,14 +21,15 @@ from __future__ import print_function import copy -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import layers as layer_module -from tensorflow.python.keras._impl.keras.engine import base_layer -from tensorflow.python.keras._impl.keras.engine import network -from tensorflow.python.keras._impl.keras.engine.input_layer import Input -from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer -from tensorflow.python.keras._impl.keras.engine.training import Model +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import layers as layer_module +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer +from tensorflow.python.keras.engine.training import Model +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util.tf_export import tf_export @@ -108,6 +109,7 @@ class Sequential(Model): return self._layers[1:] return self._layers + @checkpointable.no_automatic_dependency_tracking def add(self, layer): """Adds a layer instance on top of the layer stack. @@ -146,8 +148,6 @@ class Sequential(Model): first_layer = layer.layers[0] while isinstance(first_layer, (Model, Sequential)): first_layer = first_layer.layers[0] - batch_shape = first_layer._batch_input_shape - dtype = first_layer.dtype if hasattr(first_layer, '_batch_input_shape'): batch_shape = first_layer._batch_input_shape @@ -179,7 +179,7 @@ class Sequential(Model): 'use the functional API.') self.outputs = [layer._inbound_nodes[-1].output_tensors[0]] - self.inputs = network.get_source_inputs(self.outputs[0]) + self.inputs = layer_utils.get_source_inputs(self.outputs[0]) elif self.outputs: output_tensor = layer(self.outputs[0]) if isinstance(output_tensor, list): @@ -193,6 +193,7 @@ class Sequential(Model): else: self._layers.append(layer) + @checkpointable.no_automatic_dependency_tracking def pop(self): """Removes the last layer in the model. @@ -212,6 +213,7 @@ class Sequential(Model): self.outputs = [self.layers[-1].output] self.build() + @checkpointable.no_automatic_dependency_tracking def build(self, input_shape=None): if input_shape and not self.inputs: batch_shape = tuple(input_shape) @@ -222,11 +224,16 @@ class Sequential(Model): for layer in self._layers: x = layer(x) self.outputs = [x] + # Make sure that the model's input shape will be preserved during + # serialization. + if self._layers: + self._layers[0]._batch_input_shape = batch_shape if self.inputs: self._init_graph_network(self.inputs, self.outputs, name=self.name) self.built = True - self._track_layers(self._layers) + if self._layers: + self._track_layers(self._layers) def predict_proba(self, x, batch_size=32, verbose=0): """Generates class probability predictions for the input samples. diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py similarity index 86% rename from tensorflow/python/keras/_impl/keras/engine/sequential_test.py rename to tensorflow/python/keras/engine/sequential_test.py index a90ad131a51e3c9edfbe147f1910fd3d16f964d6..0f54e29cee38bd12d691b03ae98d3e578b7ff907 100644 --- a/tensorflow/python/keras/_impl/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -20,10 +20,10 @@ from __future__ import print_function import numpy as np +from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras from tensorflow.python.ops import array_ops from tensorflow.python.platform import test from tensorflow.python.training import rmsprop @@ -33,7 +33,7 @@ class TestSequential(test.TestCase): """Most Sequential model API tests are covered in `training_test.py`. """ - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_basic_methods(self): model = keras.models.Sequential() model.add(keras.layers.Dense(1, input_dim=2)) @@ -44,7 +44,7 @@ class TestSequential(test.TestCase): self.assertEqual(len(model.weights), 2 * 2) self.assertEqual(model.get_layer(name='dp').name, 'dp') - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_sequential_pop(self): num_hidden = 5 input_dim = 3 @@ -77,7 +77,7 @@ class TestSequential(test.TestCase): with self.assertRaises(TypeError): model.pop() - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_sequential_deferred_build_with_np_arrays(self): num_hidden = 5 input_dim = 3 @@ -102,7 +102,7 @@ class TestSequential(test.TestCase): [None, num_classes]) self.assertEqual(len(model.weights), 2 * 2) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_sequential_deferred_build_with_dataset_iterators(self): if not context.executing_eagerly(): # TODO(psv/fchollet): Add support for this use case in graph mode. @@ -136,7 +136,7 @@ class TestSequential(test.TestCase): [None, num_classes]) self.assertEqual(len(model.weights), 2 * 2) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_invalid_use_cases(self): # Added objects must be layer instances with self.assertRaises(TypeError): @@ -160,7 +160,7 @@ class TestSequential(test.TestCase): model.add(keras.layers.Dense(1, input_dim=1)) model.add(MyLayer()) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_nested_sequential_trainability(self): input_dim = 20 num_units = 10 @@ -209,6 +209,30 @@ class TestSequential(test.TestCase): x2 = model.predict(val_a) assert np.abs(np.sum(x1 - x2)) > 1e-5 + def test_sequential_deferred_build_serialization(self): + num_hidden = 5 + input_dim = 3 + batch_size = 5 + num_classes = 2 + + model = keras.models.Sequential() + # We don't specify the input shape. + model.add(keras.layers.Dense(num_hidden)) + model.add(keras.layers.Dense(num_classes)) + model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + self.assertFalse(model.built) + + x = np.random.random((batch_size, input_dim)) + y = np.random.random((batch_size, num_classes)) + model.train_on_batch(x, y) + self.assertTrue(model.built) + + config = model.get_config() + new_model = keras.models.Sequential.from_config(config) + self.assertTrue(new_model.built) + self.assertEqual(len(model.layers), 2) + self.assertEqual(len(model.weights), 4) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py similarity index 94% rename from tensorflow/python/keras/_impl/keras/engine/topology_test.py rename to tensorflow/python/keras/engine/topology_test.py index 635c446879a24a277cb4d3fc9c1c26850af78e85..3eb69bd7f3d42f5cd8d6cc6d2d32cc9eb808d9a4 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -20,12 +20,14 @@ from __future__ import print_function import numpy as np +from tensorflow.python import keras from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import input_layer as input_layer_lib +from tensorflow.python.keras.engine import network as network_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -62,7 +64,7 @@ class TopologyConstructionTest(test.TestCase): inputs=True) return inputs + 1 - x1 = keras.Input(shape=(1,)) + x1 = input_layer_lib.Input(shape=(1,)) layer = MyLayer() _ = layer.apply(x1) @@ -70,7 +72,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_updates_for(x1)), 1) self.assertEqual(len(layer.get_updates_for(None)), 1) - x2 = keras.Input(shape=(1,)) + x2 = input_layer_lib.Input(shape=(1,)) y2 = layer.apply(x2) self.assertEqual(len(layer.updates), 3) @@ -78,17 +80,17 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_updates_for(x2)), 1) self.assertEqual(len(layer.get_updates_for(None)), 1) - network = keras.engine.Network(x2, y2) + network = network_lib.Network(x2, y2) self.assertEqual(len(network.updates), 2) self.assertEqual(len(network.get_updates_for(x1)), 0) self.assertEqual(len(network.get_updates_for(x2)), 1) self.assertEqual(len(network.get_updates_for(None)), 1) - x3 = keras.Input(shape=(1,)) + x3 = input_layer_lib.Input(shape=(1,)) _ = layer.apply(x3) self.assertEqual(len(network.updates), 2) - x4 = keras.Input(shape=(1,)) + x4 = input_layer_lib.Input(shape=(1,)) _ = network(x4) self.assertEqual(len(network.updates), 3) self.assertEqual(len(network.get_updates_for(x2)), 1) @@ -104,7 +106,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(network.get_updates_for(x4)), 2) def test_get_updates_bn(self): - x1 = keras.Input(shape=(1,)) + x1 = input_layer_lib.Input(shape=(1,)) layer = keras.layers.BatchNormalization() _ = layer.apply(x1) @@ -134,7 +136,7 @@ class TopologyConstructionTest(test.TestCase): inputs=True) return inputs + 1 - x1 = keras.Input(shape=(1,)) + x1 = input_layer_lib.Input(shape=(1,)) layer = MyLayer() _ = layer.apply(x1) @@ -142,7 +144,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_losses_for(x1)), 1) self.assertEqual(len(layer.get_losses_for(None)), 1) - x2 = keras.Input(shape=(1,)) + x2 = input_layer_lib.Input(shape=(1,)) y2 = layer.apply(x2) self.assertEqual(len(layer.losses), 3) @@ -150,17 +152,17 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_losses_for(x2)), 1) self.assertEqual(len(layer.get_losses_for(None)), 1) - network = keras.engine.Network(x2, y2) + network = network_lib.Network(x2, y2) self.assertEqual(len(network.losses), 2) self.assertEqual(len(network.get_losses_for(x1)), 0) self.assertEqual(len(network.get_losses_for(x2)), 1) self.assertEqual(len(network.get_losses_for(None)), 1) - x3 = keras.Input(shape=(1,)) + x3 = input_layer_lib.Input(shape=(1,)) _ = layer.apply(x3) self.assertEqual(len(network.losses), 2) - x4 = keras.Input(shape=(1,)) + x4 = input_layer_lib.Input(shape=(1,)) _ = network(x4) self.assertEqual(len(network.losses), 3) self.assertEqual(len(network.get_losses_for(x2)), 1) @@ -177,8 +179,8 @@ class TopologyConstructionTest(test.TestCase): def testTopologicalAttributes(self): # test layer attributes / methods related to cross-layer connectivity. - a = keras.Input(shape=(32,), name='input_a') - b = keras.Input(shape=(32,), name='input_b') + a = input_layer_lib.Input(shape=(32,), name='input_a') + b = input_layer_lib.Input(shape=(32,), name='input_b') # test input, output, input_shape, output_shape test_layer = keras.layers.Dense(16, name='test_layer') @@ -219,15 +221,15 @@ class TopologyConstructionTest(test.TestCase): _ = new_dense.input_shape with self.assertRaises(AttributeError): new_dense = keras.layers.Dense(16) - a = keras.Input(shape=(3, 32)) - a = keras.Input(shape=(5, 32)) + a = input_layer_lib.Input(shape=(3, 32)) + a = input_layer_lib.Input(shape=(5, 32)) a_2 = dense(a) b_2 = dense(b) _ = new_dense.input_shape with self.assertRaises(AttributeError): new_dense = keras.layers.Dense(16) - a = keras.Input(shape=(3, 32)) - a = keras.Input(shape=(5, 32)) + a = input_layer_lib.Input(shape=(3, 32)) + a = input_layer_lib.Input(shape=(5, 32)) a_2 = dense(a) b_2 = dense(b) _ = new_dense.output_shape @@ -239,7 +241,7 @@ class TopologyConstructionTest(test.TestCase): def call(self, inputs): return [inputs**2, inputs**3] - x = keras.Input(shape=(32,)) + x = input_layer_lib.Input(shape=(32,)) test_layer = PowersLayer() p1, p2 = test_layer(x) # pylint: disable=not-callable @@ -256,8 +258,8 @@ class TopologyConstructionTest(test.TestCase): assert len(inputs) == 2 return inputs[0] + inputs[1] - a = keras.Input(shape=(32,)) - b = keras.Input(shape=(32,)) + a = input_layer_lib.Input(shape=(32,)) + b = input_layer_lib.Input(shape=(32,)) test_layer = AddLayer() y = test_layer([a, b]) # pylint: disable=not-callable @@ -268,10 +270,10 @@ class TopologyConstructionTest(test.TestCase): def testBasicNetwork(self): # minimum viable network - x = keras.Input(shape=(32,)) + x = input_layer_lib.Input(shape=(32,)) dense = keras.layers.Dense(2) y = dense(x) - network = keras.engine.Network(x, y, name='dense_network') + network = network_lib.Network(x, y, name='dense_network') # test basic attributes self.assertEqual(network.name, 'dense_network') @@ -282,7 +284,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights) # test callability on Input - x_2 = keras.Input(shape=(32,)) + x_2 = input_layer_lib.Input(shape=(32,)) y_2 = network(x_2) self.assertEqual(y_2.get_shape().as_list(), [None, 2]) @@ -506,7 +508,7 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)]) # test get_source_inputs - self.assertListEqual(keras.engine.network.get_source_inputs(c), [a, b]) + self.assertListEqual(keras.engine.get_source_inputs(c), [a, b]) # serialization / deserialization json_config = model.to_json() @@ -778,12 +780,12 @@ class TopologyConstructionTest(test.TestCase): self.evaluate(getattr(b, '_keras_mask'))) self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) else: - x = keras.Input(shape=(32,)) + x = input_layer_lib.Input(shape=(32,)) y = MaskedLayer()(x) # pylint: disable=not-callable - network = keras.engine.Network(x, y) + network = network_lib.Network(x, y) # test callability on Input - x_2 = keras.Input(shape=(32,)) + x_2 = input_layer_lib.Input(shape=(32,)) y_2 = network(x_2) self.assertEqual(y_2.get_shape().as_list(), [None, 32]) @@ -797,14 +799,14 @@ class TopologyConstructionTest(test.TestCase): def reg(x): return math_ops.reduce_sum(x) - net_a_input = keras.Input((2,)) + net_a_input = input_layer_lib.Input((2,)) net_a = net_a_input net_a = keras.layers.Dense(2, kernel_initializer='ones', use_bias=False, activity_regularizer=reg)(net_a) model_a = keras.Model([net_a_input], [net_a]) - net_b_input = keras.Input((2,)) + net_b_input = input_layer_lib.Input((2,)) net_b = model_a(net_b_input) model_b = keras.Model([net_b_input], [net_b]) @@ -817,7 +819,7 @@ class TopologyConstructionTest(test.TestCase): with self.test_session(): x_val = np.random.random((10, 5)) - x = keras.Input(shape=(5,)) + x = input_layer_lib.Input(shape=(5,)) a = keras.layers.Dense(5, name='A') b = keras.layers.Dense(5, name='B') output = a(b(a(b(x)))) @@ -837,7 +839,7 @@ class TopologyConstructionTest(test.TestCase): def test_layer_sharing_at_heterogenous_depth_with_concat(self): with self.test_session(): input_shape = (16, 9, 3) - input_layer = keras.Input(shape=input_shape) + input_layer = input_layer_lib.Input(shape=input_shape) a = keras.layers.Dense(3, name='dense_A') b = keras.layers.Dense(3, name='dense_B') @@ -924,7 +926,7 @@ class DeferredModeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testSimpleNetworkBuilding(self): - inputs = keras.engine.Input(shape=(32,)) + inputs = input_layer_lib.Input(shape=(32,)) if context.executing_eagerly(): self.assertIsInstance(inputs, base_layer.DeferredTensor) self.assertEqual(inputs.dtype.name, 'float32') @@ -937,8 +939,8 @@ class DeferredModeTest(test.TestCase): self.assertEqual(x.shape.as_list(), [None, 2]) outputs = keras.layers.Dense(4)(x) - network = keras.engine.Network(inputs, outputs) - self.assertIsInstance(network, keras.engine.Network) + network = network_lib.Network(inputs, outputs) + self.assertIsInstance(network, network_lib.Network) if context.executing_eagerly(): # It should be possible to call such a network on EagerTensors. @@ -949,8 +951,8 @@ class DeferredModeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testMultiIONetworkbuilding(self): - input_a = keras.engine.Input(shape=(32,)) - input_b = keras.engine.Input(shape=(16,)) + input_a = input_layer_lib.Input(shape=(32,)) + input_b = input_layer_lib.Input(shape=(16,)) a = keras.layers.Dense(16)(input_a) class AddLayer(keras.layers.Layer): @@ -964,7 +966,7 @@ class DeferredModeTest(test.TestCase): c = AddLayer()([a, input_b]) # pylint: disable=not-callable c = keras.layers.Dense(2)(c) - network = keras.engine.Network([input_a, input_b], [a, c]) + network = network_lib.Network([input_a, input_b], [a, c]) if context.executing_eagerly(): a_val = constant_op.constant( np.random.random((10, 32)).astype('float32')) diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/engine/training.py similarity index 93% rename from tensorflow/python/keras/_impl/keras/engine/training.py rename to tensorflow/python/keras/engine/training.py index 16d1b160e43ff7383bc58533cf59fd5c5917d538..8e632651fa7553fbc7ce31aa42e9963b606d20f9 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -24,24 +24,25 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import losses -from tensorflow.python.keras._impl.keras import metrics as metrics_module -from tensorflow.python.keras._impl.keras import optimizers -from tensorflow.python.keras._impl.keras.engine import training_arrays -from tensorflow.python.keras._impl.keras.engine import training_eager -from tensorflow.python.keras._impl.keras.engine import training_generator -from tensorflow.python.keras._impl.keras.engine import training_utils -from tensorflow.python.keras._impl.keras.engine.base_layer import DeferredTensor -from tensorflow.python.keras._impl.keras.engine.base_layer import Layer -from tensorflow.python.keras._impl.keras.engine.network import Network -from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import losses +from tensorflow.python.keras import metrics as metrics_module +from tensorflow.python.keras import optimizers +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import training_arrays +from tensorflow.python.keras.engine import training_eager +from tensorflow.python.keras.engine import training_generator +from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.engine.network import Network +from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer as tf_optimizer_module +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util.tf_export import tf_export @@ -112,7 +113,10 @@ class Model(Network): super(Model, self).__init__(*args, **kwargs) # Create a cache for iterator get_next op. self._iterator_get_next = weakref.WeakKeyDictionary() + # Create a cache for dataset - uninitialized iterators + self._dataset_iterator_cache = weakref.WeakKeyDictionary() + @checkpointable.no_automatic_dependency_tracking def compile(self, optimizer, loss=None, @@ -176,6 +180,11 @@ class Model(Network): raise ValueError('Only TF native optimizers are supported in Eager mode.') self.optimizer = optimizers.get(optimizer) + # We've disabled automatic dependency tracking for this method, but do want + # to add a checkpoint dependency on the optimizer if it's checkpointable. + if isinstance(self.optimizer, checkpointable.CheckpointableBase): + self._track_checkpointable( + self.optimizer, name='optimizer', overwrite=True) self.loss = loss self.metrics = metrics or [] self.loss_weights = loss_weights @@ -408,11 +417,13 @@ class Model(Network): else: if sample_weight_mode == 'temporal': sample_weights.append(array_ops.placeholder_with_default( - [[1.]], shape=[None, None], name=name + '_sample_weights')) + constant_op.constant([[1.]], dtype=K.floatx()), + shape=[None, None], name=name + '_sample_weights')) sample_weight_modes.append('temporal') else: sample_weights.append(array_ops.placeholder_with_default( - [1.], shape=[None], name=name + '_sample_weights')) + constant_op.constant([1.], dtype=K.floatx()), + shape=[None], name=name + '_sample_weights')) sample_weight_modes.append(None) self.sample_weight_modes = sample_weight_modes self._feed_sample_weight_modes = [] @@ -521,7 +532,7 @@ class Model(Network): # Keep track of state updates created by # stateful metrics (i.e. metrics layers). - if isinstance(metric_fn, Layer) and metric_fn.stateful: + if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: self.stateful_metric_names.append(metric_name) self.stateful_metric_functions.append(metric_fn) self.metrics_updates += metric_fn.updates @@ -670,12 +681,12 @@ class Model(Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset iterator. + - A `tf.data` dataset or a dataset iterator. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset iterator, - `y` should not be specified + tensor targets, or inversely). If `x` is a dataset or a + dataset iterator, `y` should not be specified (since targets will be obtained from the iterator). sample_weight: An optional sample-weight array passed by the user to weight the importance of each sample in `x`. @@ -706,11 +717,16 @@ class Model(Network): RuntimeError: If the model was never compiled. """ if isinstance(x, dataset_ops.Dataset): - raise ValueError('You passed a `Dataset` instance to your model (%s), ' - 'which is not supported. Instead, pass an `Iterator`, ' - 'which you can obtain e.g. via ' - '`dataset.make_one_shot_iterator()` (the exact method ' - 'to use will depend on your specific dataset).' % x) + if context.executing_eagerly(): + x = x.make_one_shot_iterator() + else: + if x in self._dataset_iterator_cache: + x = self._dataset_iterator_cache[x] + else: + iterator = x.make_initializable_iterator() + self._dataset_iterator_cache[x] = iterator + x = iterator + K.get_session().run(x.initializer) # Validates `steps` argument based on x's type. if check_steps: @@ -719,7 +735,7 @@ class Model(Network): is_x_eager_iterator = isinstance(x, iterator_ops.EagerIterator) is_x_iterator = isinstance(x, iterator_ops.Iterator) - # Validate user inputs when data is given as a dataset iterator. + # Validate user inputs when data is given as a dataset or dataset iterator. if is_x_iterator or is_x_eager_iterator: training_utils.validate_iterator_input(x, y, sample_weight, validation_split) @@ -839,7 +855,8 @@ class Model(Network): # in the case where all inputs are value arrays. if context.executing_eagerly(): - # In eager mode, do not do shape validation. + # In eager mode, do not do shape validation + # since the network has no input nodes (placeholders) to be fed. feed_input_names = self.input_names feed_input_shapes = None elif not self._is_graph_network: @@ -931,6 +948,7 @@ class Model(Network): str(x[0].shape[0]) + ' samples') return x, y, sample_weights + @checkpointable.no_automatic_dependency_tracking def _set_inputs(self, inputs, training=None): """Set model's input and output specs based on the input data received. @@ -951,11 +969,17 @@ class Model(Network): whether to build the model's graph in inference mode (False), training mode (True), or using the Keras learning phase (None). """ - if not getattr(self, '_uses_inputs_arg', True): + call_convention = getattr( + self, + '_call_convention', + base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT) + if call_convention not in ( + base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT, + base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT): raise NotImplementedError( - 'Subclassed Models without "inputs" in their call() signatures do ' - 'not yet support shape inference. File a feature request if this ' - 'limitation bothers you.') + 'Subclassed Models without "inputs" (or single positional arguments) ' + 'in their call() signatures do not yet support shape inference. File ' + 'a feature request if this limitation bothers you.') if self.__class__.__name__ == 'Sequential': # Note: we can't test whether the model is `Sequential` via `isinstance` # since `Sequential` depends on `Model`. @@ -973,6 +997,7 @@ class Model(Network): else: self._symbolic_set_inputs(inputs, training=training) + @checkpointable.no_automatic_dependency_tracking def _eager_set_inputs(self, inputs): """Set model's input and output specs based on the input data received. @@ -995,14 +1020,16 @@ class Model(Network): # to keep track of number of inputs and outputs and their ndim. if isinstance(inputs, (list, tuple)): if tensor_util.is_tensor(inputs[0]): - dummy_output_values = self.call(inputs) + dummy_output_values = self.call( + training_utils.cast_if_floating_dtype(inputs)) else: dummy_output_values = self.call( [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs]) dummy_input_values = list(inputs) else: if tensor_util.is_tensor(inputs): - dummy_output_values = self.call(inputs) + dummy_output_values = self.call( + training_utils.cast_if_floating_dtype(inputs)) else: dummy_output_values = self.call( ops.convert_to_tensor(inputs, dtype=K.floatx())) @@ -1012,17 +1039,18 @@ class Model(Network): else: dummy_output_values = [dummy_output_values] self.outputs = [ - DeferredTensor(shape=(None for _ in v.shape), - dtype=v.dtype) for v in dummy_output_values] + base_layer.DeferredTensor(shape=(None for _ in v.shape), + dtype=v.dtype) for v in dummy_output_values] self.inputs = [ - DeferredTensor(shape=(None for _ in v.shape), - dtype=v.dtype) for v in dummy_input_values] + base_layer.DeferredTensor(shape=(None for _ in v.shape), + dtype=v.dtype) for v in dummy_input_values] self.input_names = [ 'input_%d' % (i + 1) for i in range(len(dummy_input_values))] self.output_names = [ 'output_%d' % (i + 1) for i in range(len(dummy_output_values))] self.built = True + @checkpointable.no_automatic_dependency_tracking def _symbolic_set_inputs(self, inputs, outputs=None, training=None): """Set model's inputs and output specs based. @@ -1130,19 +1158,19 @@ class Model(Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset iterator. + - A `tf.data` dataset or a dataset iterator. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset iterator, - `y` should not be specified + tensor targets, or inversely). If `x` is a dataset or dataset + iterator, `y` should not be specified (since targets will be obtained from the iterator). batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your data is in the - form of symbolic tensors or dataset iterators (since they generate - batches). + form of symbolic tensors, datasets, or dataset iterators + (since they generate batches). epochs: Integer. Number of epochs to train the model. An epoch is an iteration over the entire `x` and `y` data provided. @@ -1164,7 +1192,7 @@ class Model(Network): on this data at the end of each epoch. The validation data is selected from the last samples in the `x` and `y` data provided, before shuffling. This argument is - not supported when `x` is a dataset iterator. + not supported when `x` is a dataset or a dataset iterator. validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. @@ -1172,7 +1200,7 @@ class Model(Network): `validation_data` could be: - tuple `(x_val, y_val)` of Numpy arrays or tensors - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays - - dataset iterator + - dataset or a dataset iterator shuffle: Boolean (whether to shuffle the training data before each epoch) or str (for 'batch'). 'batch' is a special option for dealing with the @@ -1195,7 +1223,7 @@ class Model(Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset iterator. + supported when `x` is a dataset or a dataset iterator. initial_epoch: Integer. Epoch at which to start training (useful for resuming a previous training run). @@ -1252,7 +1280,8 @@ class Model(Network): # Prepare validation data. if validation_data: if (isinstance(validation_data, iterator_ops.Iterator) or - isinstance(validation_data, iterator_ops.EagerIterator)): + isinstance(validation_data, iterator_ops.EagerIterator) or + isinstance(validation_data, dataset_ops.Dataset)): val_x = validation_data val_y = None val_sample_weight = None @@ -1266,8 +1295,9 @@ class Model(Network): 'When passing a `validation_data` argument, ' 'it must contain either 2 items (x_val, y_val), ' 'or 3 items (x_val, y_val, val_sample_weights), ' - 'or alternatively it could be a dataset iterator. However we ' - 'received `validation_data=%s`' % validation_data) + 'or alternatively it could be a dataset or a ' + 'dataset or a dataset iterator. ' + 'However we received `validation_data=%s`' % validation_data) # Validate and standardize validation data. val_x, val_y, val_sample_weights = self._standardize_user_data( @@ -1351,19 +1381,19 @@ class Model(Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset iterator. + - A `tf.data` dataset or a dataset iterator. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset iterator, - `y` should not be specified - (since targets will be obtained from the iterator). + tensor targets, or inversely). + If `x` is a dataset or a dataset iterator, `y` should not be specified + (since targets will be obtained from the iterator/dataset). batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` is your data is in the - form of symbolic tensors or dataset iterators (since they generate - batches). + form of symbolic tensors, datasets, or dataset iterators + (since they generate batches). verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar. sample_weight: Optional Numpy array of weights for @@ -1377,7 +1407,7 @@ class Model(Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset iterator. + supported when `x` is a dataset or a dataset iterator. steps: Integer or `None`. Total number of steps (batches of samples) before declaring the evaluation round finished. @@ -1426,13 +1456,13 @@ class Model(Network): (in case the model has multiple inputs). - A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs). - - A `tf.data` dataset iterator. + - A `tf.data` dataset or a dataset iterator. batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` is your data is in the - form of symbolic tensors or dataset iterators (since they generate - batches). + form of symbolic tensors, dataset, or dataset iterators + (since they generate batches). verbose: Verbosity mode, 0 or 1. steps: Total number of steps (batches of samples) before declaring the prediction round finished. @@ -1473,12 +1503,12 @@ class Model(Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset iterator. + - A `tf.data` dataset or a dataset iterator. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset iterator, - `y` should not be specified + tensor targets, or inversely). If `x` is a dataset or a + dataset iterator, `y` should not be specified (since targets will be obtained from the iterator). sample_weight: Optional array of the same length as x, containing weights to apply to the model's loss for each sample. @@ -1487,8 +1517,7 @@ class Model(Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile(). This argument is not - supported when `x` is a dataset iterator. - + supported when `x` is a dataset or a dataset iterator. class_weight: Optional dictionary mapping class indices (integers) to a weight (float) to apply to the model's loss for the samples @@ -1537,12 +1566,12 @@ class Model(Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset iterator. + - A `tf.data` dataset or a dataset iterator. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset iterator, - `y` should not be specified + tensor targets, or inversely). If `x` is a dataset or a + dataset iterator, `y` should not be specified (since targets will be obtained from the iterator). sample_weight: Optional array of the same length as x, containing weights to apply to the model's loss for each sample. @@ -1551,7 +1580,7 @@ class Model(Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile(). This argument is not - supported when `x` is a dataset iterator. + supported when `x` is a dataset or a dataset iterator. Returns: Scalar test loss (if the model has a single output and no metrics) @@ -1590,7 +1619,7 @@ class Model(Network): (in case the model has multiple inputs). - A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs). - - A `tf.data` dataset iterator. + - A `tf.data` dataset or a dataset iterator. Returns: Numpy array(s) of predictions. @@ -1602,7 +1631,10 @@ class Model(Network): # Validate and standardize user data. inputs, _, _ = self._standardize_user_data(x) if context.executing_eagerly(): - if not isinstance(inputs, iterator_ops.EagerIterator): + if (isinstance(x, iterator_ops.EagerIterator) or + (isinstance(x, dataset_ops.Dataset) and context.executing_eagerly())): + inputs = training_utils.cast_if_floating_dtype(inputs) + else: inputs = [ ops.convert_to_tensor(val, dtype=K.floatx()) for val in inputs ] diff --git a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py similarity index 95% rename from tensorflow/python/keras/_impl/keras/engine/training_arrays.py rename to tensorflow/python/keras/engine/training_arrays.py index 84f93da89839c35e24432c48f316626af0eab26d..e82f5c03320094348213ac3d22cc13709c6af08c 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -24,12 +24,12 @@ import copy import numpy as np from tensorflow.python.framework import errors -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import callbacks as cbks -from tensorflow.python.keras._impl.keras.engine import training_utils -from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import callbacks as cbks +from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.utils.generic_utils import make_batches +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.platform import tf_logging as logging try: @@ -124,6 +124,12 @@ def fit_loop(model, callback_metrics = copy.copy(out_labels) + [ 'val_' + n for n in out_labels ] + if callbacks is not None and any( + [isinstance(callback, cbks.TensorBoard) for callback in callbacks]): + # need to create the test_function before start of the first epoch + # because TensorBoard callback on_epoch_begin adds summary to the + # list of fetches of the test_function + model._make_test_function() else: callback_metrics = copy.copy(out_labels) @@ -185,6 +191,7 @@ def fit_loop(model, callbacks.on_epoch_begin(epoch) epoch_logs = {} if steps_per_epoch is not None: + # Step-wise fit loop. for step_index in range(steps_per_epoch): batch_logs = {} batch_logs['batch'] = step_index @@ -215,7 +222,6 @@ def fit_loop(model, val_inputs, val_targets, sample_weights=val_sample_weights, - batch_size=batch_size, steps=validation_steps, verbose=0) if not isinstance(val_outs, list): @@ -224,6 +230,7 @@ def fit_loop(model, for l, o in zip(out_labels, val_outs): epoch_logs['val_' + l] = o else: + # Sample-wise fit loop. if shuffle == 'batch': index_array = training_utils.batch_shuffle(index_array, batch_size) elif shuffle: diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/engine/training_eager.py rename to tensorflow/python/keras/engine/training_eager.py index 0a98fc2452937771e5a23a7ad672b56c9a793618..e8838cd3bca7b3afba80504f9e705943474423c5 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -28,12 +28,12 @@ from tensorflow.python.eager.backprop import GradientTape from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util -from tensorflow.python.keras._impl.keras import backend -from tensorflow.python.keras._impl.keras import callbacks as cbks -from tensorflow.python.keras._impl.keras import losses -from tensorflow.python.keras._impl.keras import metrics as metrics_module -from tensorflow.python.keras._impl.keras.engine import training_utils -from tensorflow.python.keras._impl.keras.utils import generic_utils +from tensorflow.python.keras import backend +from tensorflow.python.keras import callbacks as cbks +from tensorflow.python.keras import losses +from tensorflow.python.keras import metrics as metrics_module +from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -255,6 +255,8 @@ def iterator_fit_loop(model, # Validate and standardize data. x, y, sample_weights = model._standardize_user_data( x, y, class_weight=class_weight) + x = training_utils.cast_if_floating_dtype(x) + y = training_utils.cast_if_floating_dtype(y) if sample_weights: sample_weights = [ ops.convert_to_tensor(val, dtype=backend.floatx()) @@ -471,6 +473,8 @@ def iterator_test_loop(model, inputs, steps, verbose=0): # Validate and standardize data. x, y, sample_weights = model._standardize_user_data(x, y) + x = training_utils.cast_if_floating_dtype(x) + y = training_utils.cast_if_floating_dtype(y) # Calculate model output, loss values. loss_outs, loss, loss_metrics = _model_loss( @@ -639,6 +643,7 @@ def iterator_predict_loop(model, inputs, steps, verbose=0): # Validate and standardize data. x, _, _ = model._standardize_user_data(x) + x = training_utils.cast_if_floating_dtype(x) if model._expects_training_arg: batch_outs = model.call(x[0] if len(x) == 1 else x, training=False) @@ -814,7 +819,10 @@ def train_on_batch(model, inputs, targets, sample_weights=None): Returns: total loss and the loss associated with each output. """ - if len(inputs) and not tensor_util.is_tensor(inputs[0]): + if len(inputs) and tensor_util.is_tensor(inputs[0]): + inputs = training_utils.cast_if_floating_dtype(inputs) + targets = training_utils.cast_if_floating_dtype(targets) + else: inputs = [ ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs ] @@ -849,7 +857,10 @@ def test_on_batch(model, inputs, targets, sample_weights=None): Returns: total loss, loss and metrics associated with each output. """ - if len(inputs) and not tensor_util.is_tensor(inputs[0]): + if len(inputs) and tensor_util.is_tensor(inputs[0]): + inputs = training_utils.cast_if_floating_dtype(inputs) + targets = training_utils.cast_if_floating_dtype(targets) + else: inputs = [ ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs ] diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/engine/training_eager_test.py rename to tensorflow/python/keras/engine/training_eager_test.py index 2031a8a3dc9731ae8f56497c447b55905ff9af78..bdb30351290644e2f7e8135c047ef6732054a08a 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/engine/training_eager_test.py @@ -21,10 +21,10 @@ from __future__ import print_function import numpy as np from tensorflow.python.data.ops import dataset_ops +from tensorflow.python import keras from tensorflow.python.framework import ops from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test from tensorflow.python.training.rmsprop import RMSPropOptimizer @@ -413,7 +413,7 @@ class TrainingTest(test.TestCase): y = np.random.random((10, 4)) def iterator(): - while 1: + while True: yield x, y model.fit_generator(iterator(), steps_per_epoch=3, epochs=1) @@ -647,7 +647,7 @@ class LossWeightingTest(test.TestCase): class CorrectnessTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_loss_correctness(self): # Test that training loss is the same in eager and graph # (by comparing it to a reference value in a deterministic case) @@ -668,7 +668,7 @@ class CorrectnessTest(test.TestCase): self.assertEqual( np.around(history.history['loss'][-1], decimals=4), 0.6173) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_metrics_correctness(self): model = keras.Sequential() model.add(keras.layers.Dense(3, @@ -689,7 +689,7 @@ class CorrectnessTest(test.TestCase): outs = model.evaluate(x, y) self.assertEqual(outs[1], 0.) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_loss_correctness_with_iterator(self): # Test that training loss is the same in eager and graph # (by comparing it to a reference value in a deterministic case) @@ -712,7 +712,7 @@ class CorrectnessTest(test.TestCase): history = model.fit(iterator, epochs=1, steps_per_epoch=10) self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_metrics_correctness_with_iterator(self): model = keras.Sequential() model.add( diff --git a/tensorflow/python/keras/_impl/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/engine/training_generator.py rename to tensorflow/python/keras/engine/training_generator.py index 0de8297795877ca50565ef7ee4a2a346750312f9..d81b384f0e1810614bd98e3861b4324f0f8a4dca 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_generator.py +++ b/tensorflow/python/keras/engine/training_generator.py @@ -21,12 +21,12 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import callbacks as cbks -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import callbacks as cbks +from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer +from tensorflow.python.keras.utils.data_utils import Sequence +from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.platform import tf_logging as logging diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py similarity index 91% rename from tensorflow/python/keras/_impl/keras/engine/training_test.py rename to tensorflow/python/keras/engine/training_test.py index 4b01fbb165ace0a9dbc562ec3ff4ad4938484d1b..d9e548f01f86fd96c3abd7b3cdaf5106653393fd 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -23,14 +23,14 @@ import unittest import numpy as np +from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils -from tensorflow.python.keras._impl.keras.engine.training_utils import weighted_masked_objective -from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine.training_utils import weighted_masked_objective +from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.ops import array_ops from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging @@ -129,8 +129,10 @@ class TrainingTest(test.TestCase): { 'input_a': input_a_np, 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}, + }, { + 'dense': output_d_np, + 'dropout': output_e_np + }, epochs=1, batch_size=5, verbose=0) @@ -138,8 +140,10 @@ class TrainingTest(test.TestCase): { 'input_a': input_a_np, 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}, + }, { + 'dense': output_d_np, + 'dropout': output_e_np + }, epochs=1, batch_size=5, verbose=1) @@ -147,8 +151,10 @@ class TrainingTest(test.TestCase): { 'input_a': input_a_np, 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}, + }, { + 'dense': output_d_np, + 'dropout': output_e_np + }, validation_data=({ 'input_a': input_a_np, 'input_b': input_b_np @@ -162,8 +168,10 @@ class TrainingTest(test.TestCase): model.train_on_batch({ 'input_a': input_a_np, 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}) + }, { + 'dense': output_d_np, + 'dropout': output_e_np + }) # Test with lists for loss, metrics loss = ['mae', 'mse'] @@ -285,16 +293,20 @@ class TrainingTest(test.TestCase): { 'input_a': input_a_np, 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}, + }, { + 'dense': output_d_np, + 'dropout': output_e_np + }, batch_size=5, verbose=0) model.evaluate( { 'input_a': input_a_np, 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}, + }, { + 'dense': output_d_np, + 'dropout': output_e_np + }, batch_size=5, verbose=1) @@ -349,9 +361,11 @@ class TrainingTest(test.TestCase): with self.test_session(): test_inputs = [ - scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)] + scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2) + ] test_outputs = [ - scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)] + scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5) + ] in1 = keras.layers.Input(shape=(3,)) in2 = keras.layers.Input(shape=(3,)) out1 = keras.layers.Dropout(0.5, name='dropout')(in1) @@ -1682,7 +1696,7 @@ class TestTrainingWithDataTensors(test.TestCase): model.train_on_batch([input_a_np, input_b_np], [output_a_np, output_b_np]) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_metric_names_are_identical_in_graph_and_eager(self): a = keras.layers.Input(shape=(3,), name='input_a') b = keras.layers.Input(shape=(3,), name='input_b') @@ -1709,7 +1723,7 @@ class TestTrainingWithDataTensors(test.TestCase): class TestTrainingWithDatasetIterators(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_training_and_eval_methods_on_iterators_single_io(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') @@ -1721,8 +1735,8 @@ class TestTrainingWithDatasetIterators(test.TestCase): metrics = ['mae'] model.compile(optimizer, loss, metrics=metrics) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) + inputs = np.zeros((10, 3)) + targets = np.zeros((10, 4)) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) dataset = dataset.batch(10) @@ -1742,7 +1756,7 @@ class TestTrainingWithDatasetIterators(test.TestCase): # Test with validation split with self.assertRaisesRegexp( ValueError, '`validation_split` argument is not supported ' - 'when input `x` is a dataset iterator'): + 'when input `x` is a dataset or a dataset iterator'): model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=0, validation_split=0.5, validation_steps=2) @@ -1751,7 +1765,7 @@ class TestTrainingWithDatasetIterators(test.TestCase): sample_weight = np.random.random((10,)) with self.assertRaisesRegexp( ValueError, '`sample_weight` argument is not supported ' - 'when input `x` is a dataset iterator'): + 'when input `x` is a dataset or a dataset iterator'): model.fit( iterator, epochs=1, @@ -1760,10 +1774,6 @@ class TestTrainingWithDatasetIterators(test.TestCase): sample_weight=sample_weight) # Test invalid usage - with self.assertRaisesRegexp(ValueError, - 'Instead, pass an `Iterator`'): - model.fit(dataset, - epochs=1, steps_per_epoch=2, verbose=0) with self.assertRaisesRegexp(ValueError, 'you should not specify a target'): model.fit(iterator, iterator, @@ -1790,8 +1800,8 @@ class TestTrainingWithDatasetIterators(test.TestCase): metrics = ['mae'] model.compile(optimizer, loss, metrics=metrics) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) + inputs = np.zeros((10, 3)) + targets = np.zeros((10, 4)) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) dataset = dataset.batch(10) @@ -1803,7 +1813,7 @@ class TestTrainingWithDatasetIterators(test.TestCase): ops.get_default_graph().finalize() model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_iterators_running_out_of_data(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') @@ -1815,8 +1825,8 @@ class TestTrainingWithDatasetIterators(test.TestCase): metrics = ['mae'] model.compile(optimizer, loss, metrics=metrics) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) + inputs = np.zeros((10, 3)) + targets = np.zeros((10, 4)) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(2) dataset = dataset.batch(10) @@ -1829,5 +1839,129 @@ class TestTrainingWithDatasetIterators(test.TestCase): 'dataset iterator ran out of data') +class TestTrainingWithDataset(test.TestCase): + + def test_calling_model_on_same_dataset(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((10, 3)) + targets = np.zeros((10, 4)) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + # Call fit with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + # Finalize the graph to make sure new ops aren't added when calling on the + # same dataset + ops.get_default_graph().finalize() + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + + @tf_test_util.run_in_graph_and_eager_modes + def test_training_and_eval_methods_on_dataset(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((10, 3)) + targets = np.zeros((10, 4)) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + model.train_on_batch(dataset) + model.predict_on_batch(dataset) + + # Test with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + + # Test with validation split + with self.assertRaisesRegexp( + ValueError, '`validation_split` argument is not supported ' + 'when input `x` is a dataset or a dataset iterator'): + model.fit(dataset, + epochs=1, steps_per_epoch=2, verbose=0, + validation_split=0.5, validation_steps=2) + + # Test with sample weight. + sample_weight = np.random.random((10,)) + with self.assertRaisesRegexp( + ValueError, '`sample_weight` argument is not supported ' + 'when input `x` is a dataset or a dataset iterator'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + sample_weight=sample_weight) + + # Test invalid usage + with self.assertRaisesRegexp(ValueError, + 'you should not specify a target'): + model.fit(dataset, dataset, + epochs=1, steps_per_epoch=2, verbose=0) + + with self.assertRaisesRegexp( + ValueError, 'you should specify the `steps_per_epoch` argument'): + model.fit(dataset, epochs=1, verbose=0) + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.evaluate(dataset, verbose=0) + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.predict(dataset, verbose=0) + + def test_dataset_input_shape_validation(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3)) + targets = np.zeros((10, 4)) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have 2 dimensions'): + model.train_on_batch(dataset) + + # Wrong input shape + inputs = np.zeros((10, 5)) + targets = np.zeros((10, 4)) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): + model.train_on_batch(dataset) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py similarity index 93% rename from tensorflow/python/keras/_impl/keras/engine/training_utils.py rename to tensorflow/python/keras/engine/training_utils.py index 04d80c891ff2a145b39fd52e23a1bc027f1ff722..728a2b493b9f076cc2942766d2677c1f24fb3c15 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -25,9 +25,9 @@ import numpy as np from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import tensor_util -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import losses -from tensorflow.python.keras._impl.keras import metrics as metrics_module +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import losses +from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.ops import math_ops @@ -166,10 +166,16 @@ def standardize_input_data(data, # Check shapes compatibility. if shapes: for i in range(len(names)): - if shapes[i] is not None and not tensor_util.is_tensor(data[i]): - data_shape = data[i].shape + if shapes[i] is not None: + if tensor_util.is_tensor(data[i]): + tensorshape = data[i].get_shape() + if not tensorshape: + continue + data_shape = tuple(tensorshape.as_list()) + else: + data_shape = data[i].shape shape = shapes[i] - if data[i].ndim != len(shape): + if len(data_shape) != len(shape): raise ValueError('Error when checking ' + exception_prefix + ': expected ' + names[i] + ' to have ' + str(len(shape)) + ' dimensions, but got array ' @@ -178,7 +184,7 @@ def standardize_input_data(data, data_shape = data_shape[1:] shape = shape[1:] for dim, ref_dim in zip(data_shape, shape): - if ref_dim != dim and ref_dim: + if ref_dim != dim and ref_dim is not None and dim is not None: raise ValueError( 'Error when checking ' + exception_prefix + ': expected ' + names[i] + ' to have shape ' + str(shape) + @@ -547,6 +553,10 @@ def standardize_weights(y, def has_symbolic_tensors(ls): if context.executing_eagerly(): return False + return has_tensors(ls) + + +def has_tensors(ls): if isinstance(ls, (list, tuple)): return any(tensor_util.is_tensor(v) for v in ls) return tensor_util.is_tensor(ls) @@ -632,19 +642,20 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None): provided by user. """ if y is not None: - raise ValueError('You passed a dataset iterator (%s) as input `x` to ' - 'your model. In that case, you should not specify ' - 'a target (`y`) argument, since the dataset iterator ' - 'generates both input data and target data. ' + raise ValueError('You passed a dataset or dataset iterator (%s) as ' + 'input `x` to your model. In that case, you should ' + 'not specify a target (`y`) argument, since the dataset ' + 'or dataset iterator generates both input data and ' + 'target data. ' 'Received: %s' % (x, y)) if sample_weight is not None: - raise ValueError('`sample_weight` argument is not supported when input' - ' `x` is a dataset iterator. ' + raise ValueError('`sample_weight` argument is not supported when input ' + '`x` is a dataset or a dataset iterator. ' 'Received: x=%s, sample_weight=%s' % (x, sample_weight)) if validation_split is not None and validation_split != 0.0: raise ValueError( '`validation_split` argument is not supported when ' - 'input `x` is a dataset iterator. ' + 'input `x` is a dataset or a dataset iterator. ' 'Received: x=%s, validation_split=%f' % (x, validation_split)) @@ -685,3 +696,29 @@ def check_steps_argument(input_data, steps, steps_name): input_type=input_type_str, steps_name=steps_name)) return True return False + + +def cast_if_floating_dtype(x): + """Casts the given data tensors to the default floating point type. + + Casts only if the input is already a floating point type. + Args: + x: tensor or list/tuple of tensors. + + Returns: + Converted input. + + Raises: + RuntimeError: if data isn't tensors. + """ + if not has_tensors(x): + raise RuntimeError( + 'Please provide tensors for casting, got: {x}'.format(x=x)) + + if isinstance(x, (list, tuple)): + return [ + math_ops.cast(val, dtype=K.floatx()) + if tensor_util.is_tensor(val) and val.dtype.is_floating else val + for val in x + ] + return math_ops.cast(x, dtype=K.floatx()) if x.dtype.is_floating else x diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/python/keras/estimator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b244beb5b58cf339a4687216b87418c88b953c17 --- /dev/null +++ b/tensorflow/python/keras/estimator/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras estimator API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util.tf_export import tf_export + +# Keras has undeclared dependency on tensorflow/estimator:estimator_py. +# As long as you depend //third_party/py/tensorflow:tensorflow target +# everything will work as normal. + +try: + from tensorflow.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top + model_to_estimator = tf_export('keras.estimator.model_to_estimator')( + keras_lib.model_to_estimator) +except Exception: # pylint: disable=broad-except + + # pylint: disable=unused-argument + def stub_model_to_estimator(keras_model=None, + keras_model_path=None, + custom_objects=None, + model_dir=None, + config=None): + raise NotImplementedError( + 'tf.keras.estimator.model_to_estimator function not available in your ' + 'installation.') + # pylint: enable=unused-argument + + model_to_estimator = tf_export('keras.estimator.model_to_estimator')( + stub_model_to_estimator) + diff --git a/tensorflow/python/keras/_impl/keras/initializers.py b/tensorflow/python/keras/initializers.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/initializers.py rename to tensorflow/python/keras/initializers.py index ecb71d00e2c78ced6095aaa3a0180b454b04917a..b9b2e9ad598fabe8cbfbbcbd57d4d71ddf630df7 100644 --- a/tensorflow/python/keras/_impl/keras/initializers.py +++ b/tensorflow/python/keras/initializers.py @@ -20,8 +20,8 @@ from __future__ import print_function import six -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.ops.init_ops import Constant from tensorflow.python.ops.init_ops import Identity from tensorflow.python.ops.init_ops import Initializer # pylint: disable=unused-import diff --git a/tensorflow/python/keras/initializers/__init__.py b/tensorflow/python/keras/initializers/__init__.py deleted file mode 100644 index 6b1fcfd2d9585d19ae3fd9705e128b19b1ec40e7..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/initializers/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras built-in initializers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Initializer functions / callable classes. -from tensorflow.python.keras._impl.keras.initializers import Constant -from tensorflow.python.keras._impl.keras.initializers import Identity -from tensorflow.python.keras._impl.keras.initializers import Initializer -from tensorflow.python.keras._impl.keras.initializers import Ones -from tensorflow.python.keras._impl.keras.initializers import Orthogonal -from tensorflow.python.keras._impl.keras.initializers import RandomNormal -from tensorflow.python.keras._impl.keras.initializers import RandomUniform -from tensorflow.python.keras._impl.keras.initializers import TruncatedNormal -from tensorflow.python.keras._impl.keras.initializers import VarianceScaling -from tensorflow.python.keras._impl.keras.initializers import Zeros - -# Functional interface. -# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.initializers import glorot_normal -from tensorflow.python.keras._impl.keras.initializers import glorot_uniform -from tensorflow.python.keras._impl.keras.initializers import he_normal -from tensorflow.python.keras._impl.keras.initializers import he_uniform -from tensorflow.python.keras._impl.keras.initializers import lecun_normal -from tensorflow.python.keras._impl.keras.initializers import lecun_uniform - -# Auxiliary utils. -from tensorflow.python.keras._impl.keras.initializers import deserialize -from tensorflow.python.keras._impl.keras.initializers import serialize -from tensorflow.python.keras._impl.keras.initializers import get - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py similarity index 87% rename from tensorflow/python/keras/_impl/keras/initializers_test.py rename to tensorflow/python/keras/initializers_test.py index 7b4e6b4d5b115bc788469bf1afe2a43f8dd86f04..c519e194bdc21692025f259533b8b75e2dc48c09 100644 --- a/tensorflow/python/keras/_impl/keras/initializers_test.py +++ b/tensorflow/python/keras/initializers_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.ops import init_ops from tensorflow.python.platform import test @@ -71,7 +71,7 @@ class KerasInitializersTest(test.TestCase): stddev=1, seed=126), tensor_shape, - target_mean=0., target_std=None, target_max=2) + target_mean=0., target_max=2, target_min=-2) def test_constant(self): tensor_shape = (5, 6, 4) @@ -83,49 +83,49 @@ class KerasInitializersTest(test.TestCase): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, _ = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(3. / fan_in) + std = np.sqrt(1. / fan_in) self._runner(keras.initializers.lecun_uniform(seed=123), tensor_shape, - target_mean=0., target_max=scale, target_min=-scale) + target_mean=0., target_std=std) def test_glorot_uniform(self): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, fan_out = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(6. / (fan_in + fan_out)) + std = np.sqrt(2. / (fan_in + fan_out)) self._runner(keras.initializers.glorot_uniform(seed=123), tensor_shape, - target_mean=0., target_max=scale, target_min=-scale) + target_mean=0., target_std=std) def test_he_uniform(self): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, _ = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(6. / fan_in) + std = np.sqrt(2. / fan_in) self._runner(keras.initializers.he_uniform(seed=123), tensor_shape, - target_mean=0., target_max=scale, target_min=-scale) + target_mean=0., target_std=std) def test_lecun_normal(self): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, _ = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(1. / fan_in) + std = np.sqrt(1. / fan_in) self._runner(keras.initializers.lecun_normal(seed=123), tensor_shape, - target_mean=0., target_std=None, target_max=2 * scale) + target_mean=0., target_std=std) def test_glorot_normal(self): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, fan_out = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(2. / (fan_in + fan_out)) + std = np.sqrt(2. / (fan_in + fan_out)) self._runner(keras.initializers.glorot_normal(seed=123), tensor_shape, - target_mean=0., target_std=None, target_max=2 * scale) + target_mean=0., target_std=std) def test_he_normal(self): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, _ = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(2. / fan_in) + std = np.sqrt(2. / fan_in) self._runner(keras.initializers.he_normal(seed=123), tensor_shape, - target_mean=0., target_std=None, target_max=2 * scale) + target_mean=0., target_std=std) def test_orthogonal(self): tensor_shape = (20, 20) diff --git a/tensorflow/python/keras/_impl/keras/integration_test.py b/tensorflow/python/keras/integration_test.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/integration_test.py rename to tensorflow/python/keras/integration_test.py index 43aff67ef93c8ec495beafdd17c5557b6398671f..2a05699407cc608c1ed0dd97d230beeb6e99e0ef 100644 --- a/tensorflow/python/keras/_impl/keras/integration_test.py +++ b/tensorflow/python/keras/integration_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python import keras +from tensorflow.python.keras import testing_utils from tensorflow.python.layers import core as tf_core_layers from tensorflow.python.ops import nn from tensorflow.python.platform import test @@ -29,6 +29,9 @@ from tensorflow.python.platform import test class KerasIntegrationTest(test.TestCase): + def test_version(self): + self.assertTrue(keras.__version__.endswith('-tf')) + def test_vector_classification_sequential(self): with self.test_session(): np.random.seed(1337) diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index c7be8b918c11235b7316e125cd7a9796851ad083..e3a686f45d92dde8ea90d496b3cb5099f6b84b58 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -20,141 +20,150 @@ from __future__ import print_function # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.engine import Input -from tensorflow.python.keras._impl.keras.engine import InputLayer -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer # Advanced activations. -from tensorflow.python.keras._impl.keras.layers.advanced_activations import LeakyReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import Softmax +from tensorflow.python.keras.layers.advanced_activations import LeakyReLU +from tensorflow.python.keras.layers.advanced_activations import PReLU +from tensorflow.python.keras.layers.advanced_activations import ELU +from tensorflow.python.keras.layers.advanced_activations import ReLU +from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU +from tensorflow.python.keras.layers.advanced_activations import Softmax # Convolution layers. -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv1D -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D +from tensorflow.python.keras.layers.convolutional import Conv1D +from tensorflow.python.keras.layers.convolutional import Conv2D +from tensorflow.python.keras.layers.convolutional import Conv3D +from tensorflow.python.keras.layers.convolutional import Conv2DTranspose +from tensorflow.python.keras.layers.convolutional import Conv3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConv1D +from tensorflow.python.keras.layers.convolutional import SeparableConv2D # Convolution layer aliases. -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution1D -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D -from tensorflow.python.keras._impl.keras.layers.convolutional import DepthwiseConv2D +from tensorflow.python.keras.layers.convolutional import Convolution1D +from tensorflow.python.keras.layers.convolutional import Convolution2D +from tensorflow.python.keras.layers.convolutional import Convolution3D +from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose +from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConvolution1D +from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D +from tensorflow.python.keras.layers.convolutional import DepthwiseConv2D # Image processing layers. -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling2D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling3D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding1D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding2D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping3D +from tensorflow.python.keras.layers.convolutional import UpSampling1D +from tensorflow.python.keras.layers.convolutional import UpSampling2D +from tensorflow.python.keras.layers.convolutional import UpSampling3D +from tensorflow.python.keras.layers.convolutional import ZeroPadding1D +from tensorflow.python.keras.layers.convolutional import ZeroPadding2D +from tensorflow.python.keras.layers.convolutional import ZeroPadding3D +from tensorflow.python.keras.layers.convolutional import Cropping1D +from tensorflow.python.keras.layers.convolutional import Cropping2D +from tensorflow.python.keras.layers.convolutional import Cropping3D # Core layers. -from tensorflow.python.keras._impl.keras.layers.core import Masking -from tensorflow.python.keras._impl.keras.layers.core import Dropout -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout1D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout2D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout3D -from tensorflow.python.keras._impl.keras.layers.core import Activation -from tensorflow.python.keras._impl.keras.layers.core import Reshape -from tensorflow.python.keras._impl.keras.layers.core import Permute -from tensorflow.python.keras._impl.keras.layers.core import Flatten -from tensorflow.python.keras._impl.keras.layers.core import RepeatVector -from tensorflow.python.keras._impl.keras.layers.core import Lambda -from tensorflow.python.keras._impl.keras.layers.core import Dense -from tensorflow.python.keras._impl.keras.layers.core import ActivityRegularization +from tensorflow.python.keras.layers.core import Masking +from tensorflow.python.keras.layers.core import Dropout +from tensorflow.python.keras.layers.core import SpatialDropout1D +from tensorflow.python.keras.layers.core import SpatialDropout2D +from tensorflow.python.keras.layers.core import SpatialDropout3D +from tensorflow.python.keras.layers.core import Activation +from tensorflow.python.keras.layers.core import Reshape +from tensorflow.python.keras.layers.core import Permute +from tensorflow.python.keras.layers.core import Flatten +from tensorflow.python.keras.layers.core import RepeatVector +from tensorflow.python.keras.layers.core import Lambda +from tensorflow.python.keras.layers.core import Dense +from tensorflow.python.keras.layers.core import ActivityRegularization # Embedding layers. -from tensorflow.python.keras._impl.keras.layers.embeddings import Embedding +from tensorflow.python.keras.layers.embeddings import Embedding # Locally-connected layers. -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected1D -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected2D +from tensorflow.python.keras.layers.local import LocallyConnected1D +from tensorflow.python.keras.layers.local import LocallyConnected2D # Merge layers. -from tensorflow.python.keras._impl.keras.layers.merge import Add -from tensorflow.python.keras._impl.keras.layers.merge import Multiply -from tensorflow.python.keras._impl.keras.layers.merge import Average -from tensorflow.python.keras._impl.keras.layers.merge import Maximum -from tensorflow.python.keras._impl.keras.layers.merge import Concatenate -from tensorflow.python.keras._impl.keras.layers.merge import Dot -from tensorflow.python.keras._impl.keras.layers.merge import add -from tensorflow.python.keras._impl.keras.layers.merge import multiply -from tensorflow.python.keras._impl.keras.layers.merge import average -from tensorflow.python.keras._impl.keras.layers.merge import maximum -from tensorflow.python.keras._impl.keras.layers.merge import concatenate -from tensorflow.python.keras._impl.keras.layers.merge import dot +from tensorflow.python.keras.layers.merge import Add +from tensorflow.python.keras.layers.merge import Subtract +from tensorflow.python.keras.layers.merge import Multiply +from tensorflow.python.keras.layers.merge import Average +from tensorflow.python.keras.layers.merge import Maximum +from tensorflow.python.keras.layers.merge import Minimum +from tensorflow.python.keras.layers.merge import Concatenate +from tensorflow.python.keras.layers.merge import Dot +from tensorflow.python.keras.layers.merge import add +from tensorflow.python.keras.layers.merge import subtract +from tensorflow.python.keras.layers.merge import multiply +from tensorflow.python.keras.layers.merge import average +from tensorflow.python.keras.layers.merge import maximum +from tensorflow.python.keras.layers.merge import minimum +from tensorflow.python.keras.layers.merge import concatenate +from tensorflow.python.keras.layers.merge import dot # Noise layers. -from tensorflow.python.keras._impl.keras.layers.noise import AlphaDropout -from tensorflow.python.keras._impl.keras.layers.noise import GaussianNoise -from tensorflow.python.keras._impl.keras.layers.noise import GaussianDropout +from tensorflow.python.keras.layers.noise import AlphaDropout +from tensorflow.python.keras.layers.noise import GaussianNoise +from tensorflow.python.keras.layers.noise import GaussianDropout # Normalization layers. -from tensorflow.python.keras._impl.keras.layers.normalization import BatchNormalization +from tensorflow.python.keras.layers.normalization import BatchNormalization # Pooling layers. -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling3D +from tensorflow.python.keras.layers.pooling import MaxPooling1D +from tensorflow.python.keras.layers.pooling import MaxPooling2D +from tensorflow.python.keras.layers.pooling import MaxPooling3D +from tensorflow.python.keras.layers.pooling import AveragePooling1D +from tensorflow.python.keras.layers.pooling import AveragePooling2D +from tensorflow.python.keras.layers.pooling import AveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D # Pooling layer aliases. -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D +from tensorflow.python.keras.layers.pooling import MaxPool1D +from tensorflow.python.keras.layers.pooling import MaxPool2D +from tensorflow.python.keras.layers.pooling import MaxPool3D +from tensorflow.python.keras.layers.pooling import AvgPool1D +from tensorflow.python.keras.layers.pooling import AvgPool2D +from tensorflow.python.keras.layers.pooling import AvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D # Recurrent layers. -from tensorflow.python.keras._impl.keras.layers.recurrent import RNN -from tensorflow.python.keras._impl.keras.layers.recurrent import StackedRNNCells -from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNNCell -from tensorflow.python.keras._impl.keras.layers.recurrent import GRUCell -from tensorflow.python.keras._impl.keras.layers.recurrent import LSTMCell -from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN -from tensorflow.python.keras._impl.keras.layers.recurrent import GRU -from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM +from tensorflow.python.keras.layers.recurrent import RNN +from tensorflow.python.keras.layers.recurrent import StackedRNNCells +from tensorflow.python.keras.layers.recurrent import SimpleRNNCell +from tensorflow.python.keras.layers.recurrent import GRUCell +from tensorflow.python.keras.layers.recurrent import LSTMCell +from tensorflow.python.keras.layers.recurrent import SimpleRNN +from tensorflow.python.keras.layers.recurrent import GRU +from tensorflow.python.keras.layers.recurrent import LSTM # Convolutional-recurrent layers. -from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import ConvLSTM2D +from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D # CuDNN recurrent layers. -from tensorflow.python.keras._impl.keras.layers.cudnn_recurrent import CuDNNLSTM -from tensorflow.python.keras._impl.keras.layers.cudnn_recurrent import CuDNNGRU +from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNLSTM +from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNGRU # Wrapper functions -from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper -from tensorflow.python.keras._impl.keras.layers.wrappers import Bidirectional -from tensorflow.python.keras._impl.keras.layers.wrappers import TimeDistributed +from tensorflow.python.keras.layers.wrappers import Wrapper +from tensorflow.python.keras.layers.wrappers import Bidirectional +from tensorflow.python.keras.layers.wrappers import TimeDistributed + +# Serialization functions +from tensorflow.python.keras.layers.serialization import deserialize +from tensorflow.python.keras.layers.serialization import serialize del absolute_import del division diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py similarity index 84% rename from tensorflow/python/keras/_impl/keras/layers/advanced_activations.py rename to tensorflow/python/keras/layers/advanced_activations.py index 89931db3c0786b6869379e0d140e8a19e5e46d5f..eba10da6f3ce1367f4cb0180d16efdc5913fcddc 100644 --- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/layers/advanced_activations.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras import activations -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import activations +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -278,3 +278,40 @@ class Softmax(Layer): @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): return input_shape + + +@tf_export('keras.layers.ReLU') +class ReLU(Layer): + """Rectified Linear Unit activation function. + + Input shape: + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. + + Output shape: + Same shape as the input. + + Arguments: + max_value: float >= 0. Maximum activation value. + """ + + def __init__(self, max_value=None, **kwargs): + super(ReLU, self).__init__(**kwargs) + self.support_masking = True + self.max_value = K.cast_to_floatx(max_value) + if self.max_value < 0.: + raise ValueError('max_value of Relu layer ' + 'cannot be negative value: ' + str(max_value)) + + def call(self, inputs): + return activations.relu(inputs, max_value=self.max_value) + + def get_config(self): + config = {'max_value': self.max_value} + base_config = super(ReLU, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @tf_utils.shape_type_conversion + def compute_output_shape(self, input_shape): + return input_shape diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py similarity index 78% rename from tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py rename to tensorflow/python/keras/layers/advanced_activations_test.py index 343b7949accf3f0c9ddc5245910aa5faad8335c6..9e1f15b1bc508d8be0a2c0190d07eb1c2bed95c4 100644 --- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py +++ b/tensorflow/python/keras/layers/advanced_activations_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python import keras +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test @@ -62,6 +62,20 @@ class AdvancedActivationsTest(test.TestCase): kwargs={'axis': 1}, input_shape=(2, 3, 4)) + def test_relu(self): + with self.test_session(): + testing_utils.layer_test(keras.layers.ReLU, + kwargs={'max_value': 10}, + input_shape=(2, 3, 4)) + + def test_relu_with_invalid_arg(self): + with self.assertRaisesRegexp( + ValueError, 'max_value of Relu layer cannot be negative value: -10'): + with self.test_session(): + testing_utils.layer_test(keras.layers.ReLU, + kwargs={'max_value': -10}, + input_shape=(2, 3, 4)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py similarity index 96% rename from tensorflow/python/keras/_impl/keras/layers/convolutional.py rename to tensorflow/python/keras/layers/convolutional.py index e47aaf9caccd4717dfac168105193485534ab6af..a57ac121ed7486a9beb64e6dd7ed3b132ca258df 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -21,24 +21,24 @@ from __future__ import print_function from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import activations -from tensorflow.python.keras._impl.keras import backend -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras import activations +from tensorflow.python.keras import backend +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer # imports for backwards namespace compatibility # pylint: disable=unused-import -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D +from tensorflow.python.keras.layers.pooling import AveragePooling1D +from tensorflow.python.keras.layers.pooling import AveragePooling2D +from tensorflow.python.keras.layers.pooling import AveragePooling3D +from tensorflow.python.keras.layers.pooling import MaxPooling1D +from tensorflow.python.keras.layers.pooling import MaxPooling2D +from tensorflow.python.keras.layers.pooling import MaxPooling3D # pylint: enable=unused-import -from tensorflow.python.keras._impl.keras.utils import conv_utils -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops @@ -151,21 +151,23 @@ class Conv(Layer): input_dim = int(input_shape[channel_axis]) kernel_shape = self.kernel_size + (input_dim, self.filters) - self.kernel = self.add_variable(name='kernel', - shape=kernel_shape, - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint, - trainable=True, - dtype=self.dtype) + self.kernel = self.add_weight( + name='kernel', + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + trainable=True, + dtype=self.dtype) if self.use_bias: - self.bias = self.add_variable(name='bias', - shape=(self.filters,), - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - constraint=self.bias_constraint, - trainable=True, - dtype=self.dtype) + self.bias = self.add_weight( + name='bias', + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=True, + dtype=self.dtype) else: self.bias = None self.input_spec = InputSpec(ndim=self.rank + 2, @@ -380,11 +382,11 @@ class Conv2D(Conv): filters: Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. + height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. + specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying @@ -611,11 +613,11 @@ class Conv2DTranspose(Conv2D): filters: Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. + height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. + specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying @@ -720,21 +722,23 @@ class Conv2DTranspose(Conv2D): self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim}) kernel_shape = self.kernel_size + (self.filters, input_dim) - self.kernel = self.add_variable(name='kernel', - shape=kernel_shape, - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint, - trainable=True, - dtype=self.dtype) + self.kernel = self.add_weight( + name='kernel', + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + trainable=True, + dtype=self.dtype) if self.use_bias: - self.bias = self.add_variable(name='bias', - shape=(self.filters,), - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - constraint=self.bias_constraint, - trainable=True, - dtype=self.dtype) + self.bias = self.add_weight( + name='bias', + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=True, + dtype=self.dtype) else: self.bias = None self.built = True @@ -961,7 +965,7 @@ class Conv3DTranspose(Conv3D): kernel_shape = self.kernel_size + (self.filters, input_dim) self.input_spec = InputSpec(ndim=5, axes={channel_axis: input_dim}) - self.kernel = self.add_variable( + self.kernel = self.add_weight( 'kernel', shape=kernel_shape, initializer=self.kernel_initializer, @@ -970,7 +974,7 @@ class Conv3DTranspose(Conv3D): trainable=True, dtype=self.dtype) if self.use_bias: - self.bias = self.add_variable( + self.bias = self.add_weight( 'bias', shape=(self.filters,), initializer=self.bias_initializer, @@ -1191,6 +1195,7 @@ class SeparableConv(Conv): dilation_rate=dilation_rate, activation=activations.get(activation), use_bias=use_bias, + bias_initializer=initializers.get(bias_initializer), bias_regularizer=regularizers.get(bias_regularizer), activity_regularizer=regularizers.get(activity_regularizer), bias_constraint=bias_constraint, @@ -1222,7 +1227,7 @@ class SeparableConv(Conv): pointwise_kernel_shape = ( 1,) * self.rank + (self.depth_multiplier * input_dim, self.filters) - self.depthwise_kernel = self.add_variable( + self.depthwise_kernel = self.add_weight( name='depthwise_kernel', shape=depthwise_kernel_shape, initializer=self.depthwise_initializer, @@ -1230,7 +1235,7 @@ class SeparableConv(Conv): constraint=self.depthwise_constraint, trainable=True, dtype=self.dtype) - self.pointwise_kernel = self.add_variable( + self.pointwise_kernel = self.add_weight( name='pointwise_kernel', shape=pointwise_kernel_shape, initializer=self.pointwise_initializer, @@ -1239,13 +1244,14 @@ class SeparableConv(Conv): trainable=True, dtype=self.dtype) if self.use_bias: - self.bias = self.add_variable(name='bias', - shape=(self.filters,), - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - constraint=self.bias_constraint, - trainable=True, - dtype=self.dtype) + self.bias = self.add_weight( + name='bias', + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=True, + dtype=self.dtype) else: self.bias = None self.built = True @@ -1447,11 +1453,11 @@ class SeparableConv2D(SeparableConv): filters: Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. + height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. + specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying @@ -1591,11 +1597,11 @@ class DepthwiseConv2D(Conv2D): Arguments: kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. + height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. + specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying @@ -1724,7 +1730,7 @@ class DepthwiseConv2D(Conv2D): dilation_rate=self.dilation_rate, data_format=self.data_format) - if self.bias: + if self.use_bias: outputs = backend.bias_add( outputs, self.bias, @@ -2002,7 +2008,7 @@ class ZeroPadding2D(Layer): Arguments: padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. - If int: the same symmetric padding - is applied to width and height. + is applied to height and width. - If tuple of 2 ints: interpreted as two different symmetric padding values for height and width: @@ -2101,7 +2107,7 @@ class ZeroPadding3D(Layer): Arguments: padding: int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints. - If int: the same symmetric padding - is applied to width and height. + is applied to height and width. - If tuple of 3 ints: interpreted as two different symmetric padding values for height and width: @@ -2261,12 +2267,12 @@ class Cropping1D(Layer): class Cropping2D(Layer): """Cropping layer for 2D input (e.g. picture). - It crops along spatial dimensions, i.e. width and height. + It crops along spatial dimensions, i.e. height and width. Arguments: cropping: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. - If int: the same symmetric cropping - is applied to width and height. + is applied to height and width. - If tuple of 2 ints: interpreted as two different symmetric cropping values for height and width: diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py rename to tensorflow/python/keras/layers/convolutional_recurrent.py index 9cad08274e58d6462f03edd932c782c7b3fbbbaa..84d794cada86b15755c28592d4c8093a4d3ef87e 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -21,19 +21,19 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl.keras import activations -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.layers.recurrent import _generate_dropout_mask -from tensorflow.python.keras._impl.keras.layers.recurrent import _standardize_args -from tensorflow.python.keras._impl.keras.layers.recurrent import RNN -from tensorflow.python.keras._impl.keras.utils import conv_utils -from tensorflow.python.keras._impl.keras.utils import generic_utils -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import activations +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.layers.recurrent import _generate_dropout_mask +from tensorflow.python.keras.layers.recurrent import _standardize_args +from tensorflow.python.keras.layers.recurrent import RNN +from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py b/tensorflow/python/keras/layers/convolutional_recurrent_test.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py rename to tensorflow/python/keras/layers/convolutional_recurrent_test.py index 827a7ffbdae676ef1263a19f490527507952418f..4b8f6f2a14e490c976d23463283bc4b81333ff92 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python import keras +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/layers/convolutional_test.py rename to tensorflow/python/keras/layers/convolutional_test.py index 12b42676759d499c910707cb1b78e788e3c443fd..f904744422a4b1296e8f5e8a34373fd0344dc643 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/layers/convolutional_test.py @@ -22,10 +22,10 @@ import copy import numpy as np +from tensorflow.python import keras from tensorflow.python.eager import context from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test @@ -45,7 +45,7 @@ class Convolution1DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, length, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv1d(self): kwargs = { 'filters': 2, @@ -117,7 +117,7 @@ class Conv2DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv2d(self): kwargs = { 'filters': 2, @@ -192,7 +192,7 @@ class Conv2DTransposeTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv2dtranspose(self): kwargs = { 'filters': 2, @@ -258,7 +258,7 @@ class Conv3DTransposeTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, depth, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv3dtranspose(self): kwargs = { 'filters': 2, @@ -322,7 +322,7 @@ class SeparableConv1DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, length, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_separable_conv1d(self): kwargs = { 'filters': 2, @@ -398,7 +398,7 @@ class SeparableConv2DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_separable_conv2d(self): kwargs = { 'filters': 2, @@ -477,7 +477,7 @@ class Conv3DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, depth, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv3d(self): kwargs = { 'filters': 2, @@ -529,7 +529,7 @@ class Conv3DTest(test.TestCase): class ZeroPaddingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_zero_padding_1d(self): num_samples = 2 input_dim = 2 @@ -581,7 +581,7 @@ class ZeroPaddingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.ZeroPadding1D(padding=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_zero_padding_2d(self): num_samples = 2 stack_size = 2 @@ -660,7 +660,7 @@ class ZeroPaddingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.ZeroPadding2D(padding=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_zero_padding_3d(self): num_samples = 2 stack_size = 2 @@ -702,13 +702,13 @@ class ZeroPaddingTest(test.TestCase): class UpSamplingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_upsampling_1d(self): with self.test_session(use_gpu=True): testing_utils.layer_test( keras.layers.UpSampling1D, kwargs={'size': 2}, input_shape=(3, 5, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_upsampling_2d(self): num_samples = 2 stack_size = 2 @@ -758,7 +758,7 @@ class UpSamplingTest(test.TestCase): np.testing.assert_allclose(np_output, expected_out) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_upsampling_3d(self): num_samples = 2 stack_size = 2 @@ -818,7 +818,7 @@ class UpSamplingTest(test.TestCase): class CroppingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_cropping_1d(self): num_samples = 2 time_length = 4 @@ -837,7 +837,7 @@ class CroppingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.Cropping1D(cropping=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_cropping_2d(self): num_samples = 2 stack_size = 2 @@ -905,7 +905,7 @@ class CroppingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.Cropping2D(cropping=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_cropping_3d(self): num_samples = 2 stack_size = 2 @@ -995,6 +995,7 @@ class DepthwiseConv2DTest(test.TestCase): 'bias_regularizer': 'l2', 'activity_regularizer': 'l2', 'depthwise_constraint': 'unit_norm', + 'use_bias': True, 'strides': (2, 2), } self._run_test(kwargs, 'depth_multiplier', [1]) diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/layers/core.py similarity index 93% rename from tensorflow/python/keras/_impl/keras/layers/core.py rename to tensorflow/python/keras/layers/core.py index 30327781dffc679da0814d9ef550e51c06d1cada..2bf6229ccba808360e73a333bdec3dac624d81ce 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -19,23 +19,25 @@ from __future__ import division from __future__ import print_function import copy +import sys import types as python_types +import warnings import numpy as np from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import activations -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.utils import conv_utils -from tensorflow.python.keras._impl.keras.utils import generic_utils -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import activations +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -714,6 +716,7 @@ class Lambda(Layer): return self.mask def get_config(self): + module = self.function.__module__ if isinstance(self.function, python_types.LambdaType): function = generic_utils.func_dump(self.function) function_type = 'lambda' @@ -721,21 +724,26 @@ class Lambda(Layer): function = self.function.__name__ function_type = 'function' + output_shape_module = None if isinstance(self._output_shape, python_types.LambdaType): output_shape = generic_utils.func_dump(self._output_shape) output_shape_type = 'lambda' + output_shape_module = self._output_shape.__module__ elif callable(self._output_shape): output_shape = self._output_shape.__name__ output_shape_type = 'function' + output_shape_module = self._output_shape.__module__ else: output_shape = self._output_shape output_shape_type = 'raw' config = { 'function': function, + 'module': module, 'function_type': function_type, 'output_shape': output_shape, 'output_shape_type': output_shape_type, + 'output_shape_module': output_shape_module, 'arguments': self.arguments } base_config = super(Lambda, self).get_config() @@ -745,8 +753,16 @@ class Lambda(Layer): def from_config(cls, config, custom_objects=None): config = config.copy() globs = globals() + module = config.pop('module', None) + if module in sys.modules: + globs.update(sys.modules[module].__dict__) + elif module is not None: + # Note: we don't know the name of the function if it's a lambda. + warnings.warn('{} is not loaded, but a Lambda layer uses it. ' + 'It may cause errors.'.format(module) + , UserWarning) if custom_objects: - globs = dict(list(globs.items()) + list(custom_objects.items())) + globs.update(custom_objects) function_type = config.pop('function_type') if function_type == 'function': # Simple lookup in custom objects @@ -760,6 +776,14 @@ class Lambda(Layer): else: raise TypeError('Unknown function type:', function_type) + output_shape_module = config.pop('output_shape_module', None) + if output_shape_module in sys.modules: + globs.update(sys.modules[output_shape_module].__dict__) + elif output_shape_module is not None: + # Note: we don't know the name of the function if it's a lambda. + warnings.warn('{} is not loaded, but a Lambda layer uses it. ' + 'It may cause errors.'.format(output_shape_module) + , UserWarning) output_shape_type = config.pop('output_shape_type') if output_shape_type == 'function': # Simple lookup in custom objects @@ -882,21 +906,23 @@ class Dense(Layer): 'should be defined. Found `None`.') self.input_spec = InputSpec(min_ndim=2, axes={-1: input_shape[-1].value}) - self.kernel = self.add_variable('kernel', - shape=[input_shape[-1].value, self.units], - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint, - dtype=self.dtype, - trainable=True) + self.kernel = self.add_weight( + 'kernel', + shape=[input_shape[-1].value, self.units], + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + dtype=self.dtype, + trainable=True) if self.use_bias: - self.bias = self.add_variable('bias', - shape=[self.units,], - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - constraint=self.bias_constraint, - dtype=self.dtype, - trainable=True) + self.bias = self.add_weight( + 'bias', + shape=[self.units,], + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + dtype=self.dtype, + trainable=True) else: self.bias = None self.built = True diff --git a/tensorflow/python/keras/_impl/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py similarity index 94% rename from tensorflow/python/keras/_impl/keras/layers/core_test.py rename to tensorflow/python/keras/layers/core_test.py index 9b360b65d6336d27e868b4b343b411b8ab7db917..226403c5927ed22394b708178679d1efa11dd790 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np +from tensorflow.python import keras from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -51,7 +51,7 @@ class CoreLayersTest(test.TestCase): dropout = keras.layers.Dropout(0.5) self.assertEqual(True, dropout.supports_masking) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_spatial_dropout(self): testing_utils.layer_test( keras.layers.SpatialDropout1D, @@ -78,7 +78,7 @@ class CoreLayersTest(test.TestCase): kwargs={'rate': 0.5, 'data_format': 'channels_first'}, input_shape=(2, 3, 4, 4, 5)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_activation(self): # with string argument testing_utils.layer_test( @@ -92,7 +92,7 @@ class CoreLayersTest(test.TestCase): kwargs={'activation': keras.backend.relu}, input_shape=(3, 2)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_reshape(self): testing_utils.layer_test( keras.layers.Reshape, @@ -114,12 +114,12 @@ class CoreLayersTest(test.TestCase): kwargs={'target_shape': (-1, 1)}, input_shape=(None, None, 2)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_permute(self): testing_utils.layer_test( keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_flatten(self): testing_utils.layer_test( keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4)) @@ -134,7 +134,7 @@ class CoreLayersTest(test.TestCase): np.transpose(inputs, (0, 2, 3, 1)), (-1, 5 * 5 * 3)) self.assertAllClose(outputs, target_outputs) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_repeat_vector(self): testing_utils.layer_test( keras.layers.RepeatVector, kwargs={'n': 3}, input_shape=(3, 2)) @@ -173,7 +173,7 @@ class CoreLayersTest(test.TestCase): config = ld.get_config() ld = keras.layers.Lambda.from_config(config) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dense(self): testing_utils.layer_test( keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2)) diff --git a/tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent.py rename to tensorflow/python/keras/layers/cudnn_recurrent.py index ffb90457a85bb801d766e144f45c044e0a7e3bb0..cf2b0c476c7229a288f4b4f7b31de09388ade40f 100644 --- a/tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent.py +++ b/tensorflow/python/keras/layers/cudnn_recurrent.py @@ -20,12 +20,13 @@ from __future__ import print_function import collections -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.layers.recurrent import RNN +from tensorflow.python.framework import constant_op +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.layers.recurrent import RNN from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_cudnn_rnn_ops from tensorflow.python.ops import state_ops @@ -71,10 +72,11 @@ class _CuDNNRNN(RNN): self.constants_spec = None self._states = None self._num_constants = None + self._vector_shape = constant_op.constant([-1]) def _canonical_to_params(self, weights, biases): - weights = [array_ops.reshape(x, (-1,)) for x in weights] - biases = [array_ops.reshape(x, (-1,)) for x in biases] + weights = [array_ops.reshape(x, self._vector_shape) for x in weights] + biases = [array_ops.reshape(x, self._vector_shape) for x in biases] return array_ops.concat(weights + biases, axis=0) def call(self, inputs, mask=None, training=None, initial_state=None): diff --git a/tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py similarity index 76% rename from tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent_test.py rename to tensorflow/python/keras/layers/cudnn_recurrent_test.py index ad25eb226c82c5dd7129d94b348da44f3d90e0fa..8fd970239f205031954c728474abdf10ea80e99e 100644 --- a/tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent_test.py +++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py @@ -18,19 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import tempfile from absl.testing import parameterized import numpy as np +from tensorflow.python import keras from tensorflow.python.framework import test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test from tensorflow.python.training.rmsprop import RMSPropOptimizer class CuDNNTest(test.TestCase, parameterized.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_cudnn_rnn_basics(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): @@ -58,7 +60,7 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): 'go_backwards': go_backwards}, input_shape=(num_samples, timesteps, input_size)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trainability(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): @@ -217,27 +219,14 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): out5 = model.predict(np.ones((num_samples, timesteps))) self.assertNotEqual(out4.max(), out5.max()) - # TODO(psv): Add generic cross product helper function for parametrized tests. @parameterized.named_parameters( - ('cudnnlstm_to_lstm_unidirectional_impl_1', 'LSTM', False, False, 1), - ('cudnnlstm_to_lstm_bidirectional_impl_1', 'LSTM', False, True, 1), - ('lstm_to_cudnnlstm_unidirectional_impl_1', 'LSTM', True, False, 1), - ('lstm_to_cudnnlstm_bidirectional_impl_1', 'LSTM', True, True, 1), - ('cudnngru_to_gru_unidirectional_impl_1', 'GRU', False, False, 1), - ('cudnngru_to_gru_bidirectional_impl_1', 'GRU', False, True, 1), - ('gru_to_cudnngru_unidirectional_impl_1', 'GRU', True, False, 1), - ('gru_to_cudnngru_bidirectional_impl_1', 'GRU', True, True, 1), - ('cudnnlstm_to_lstm_unidirectional_impl_2', 'LSTM', False, False, 2), - ('cudnnlstm_to_lstm_bidirectional_impl_2', 'LSTM', False, True, 2), - ('lstm_to_cudnnlstm_unidirectional_impl_2', 'LSTM', True, False, 2), - ('lstm_to_cudnnlstm_bidirectional_impl_2', 'LSTM', True, True, 2), - ('cudnngru_to_gru_unidirectional_impl_2', 'GRU', False, False, 2), - ('cudnngru_to_gru_bidirectional_impl_2', 'GRU', False, True, 2), - ('gru_to_cudnngru_unidirectional_impl_2', 'GRU', True, False, 2), - ('gru_to_cudnngru_bidirectional_impl_2', 'GRU', True, True, 2), - ) + *testing_utils.generate_combinations_with_testcase_name( + rnn_type=['LSTM', 'GRU'], to_cudnn=[True, False], + bidirectional=[True, False], implementation=[1, 2], + model_nest_level=[1, 2], model_type=['seq', 'func'])) def test_load_weights_between_noncudnn_rnn(self, rnn_type, to_cudnn, - bidirectional, implementation): + bidirectional, implementation, + model_nest_level, model_type): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): input_size = 10 @@ -261,14 +250,6 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): cudnn_rnn_layer_class = keras.layers.CuDNNGRU rnn_layer_kwargs['reset_after'] = True - def convert_weights(source_layer, target_layer): - weights = source_layer.get_weights() - weights = keras.engine.saving.preprocess_weights_for_loading( - target_layer, weights) - target_layer.set_weights(weights) - - input_layer = keras.layers.InputLayer(input_shape) - layer = rnn_layer_class(units, **rnn_layer_kwargs) if bidirectional: layer = keras.layers.Bidirectional(layer) @@ -277,18 +258,96 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): if bidirectional: cudnn_layer = keras.layers.Bidirectional(cudnn_layer) - model = keras.models.Sequential([input_layer, layer]) - cudnn_model = keras.models.Sequential([input_layer, cudnn_layer]) + model = self._make_nested_model(input_shape, layer, model_nest_level, + model_type) + cudnn_model = self._make_nested_model(input_shape, cudnn_layer, + model_nest_level, model_type) + + if to_cudnn: + self._convert_model_weights(model, cudnn_model) + else: + self._convert_model_weights(cudnn_model, model) + + self.assertAllClose(model.predict(inputs), cudnn_model.predict(inputs), + atol=1e-4) + + def _make_nested_model(self, input_shape, layer, level=1, model_type='func'): + # example: make_nested_seq_model((1,), Dense(10), level=2).summary() + def make_nested_seq_model(input_shape, layer, level=1): + model = layer + for i in range(1, level + 1): + layers = [keras.layers.InputLayer(input_shape), + model] if (i == 1) else [model] + model = keras.models.Sequential(layers) + return model + + # example: make_nested_func_model((1,), Dense(10), level=2).summary() + def make_nested_func_model(input_shape, layer, level=1): + model_input = keras.layers.Input(input_shape) + model = layer + for _ in range(level): + model = keras.models.Model(model_input, model(model_input)) + return model + + if model_type == 'func': + return make_nested_func_model(input_shape, layer, level) + elif model_type == 'seq': + return make_nested_seq_model(input_shape, layer, level) + + def _convert_model_weights(self, source_model, target_model): + _, fname = tempfile.mkstemp('.h5') + source_model.save_weights(fname) + target_model.load_weights(fname) + os.remove(fname) + + @parameterized.named_parameters( + *testing_utils.generate_combinations_with_testcase_name( + rnn_type=['LSTM', 'GRU'], to_cudnn=[True, False])) + def test_load_weights_between_noncudnn_rnn_time_distributed(self, rnn_type, + to_cudnn): + # Similar test as test_load_weights_between_noncudnn_rnn() but has different + # rank of input due to usage of TimeDistributed. Issue: #10356. + if test.is_gpu_available(cuda_only=True): + with self.test_session(use_gpu=True): + input_size = 10 + steps = 6 + timesteps = 6 + input_shape = (timesteps, steps, input_size) + units = 2 + num_samples = 32 + inputs = np.random.random((num_samples, timesteps, steps, input_size)) + + rnn_layer_kwargs = { + 'recurrent_activation': 'sigmoid', + # ensure biases are non-zero and properly converted + 'bias_initializer': 'random_uniform', + } + if rnn_type == 'LSTM': + rnn_layer_class = keras.layers.LSTM + cudnn_rnn_layer_class = keras.layers.CuDNNLSTM + else: + rnn_layer_class = keras.layers.GRU + cudnn_rnn_layer_class = keras.layers.CuDNNGRU + rnn_layer_kwargs['reset_after'] = True + + layer = rnn_layer_class(units, **rnn_layer_kwargs) + layer = keras.layers.TimeDistributed(layer) + + cudnn_layer = cudnn_rnn_layer_class(units) + cudnn_layer = keras.layers.TimeDistributed(cudnn_layer) + + model = self._make_nested_model(input_shape, layer) + cudnn_model = self._make_nested_model(input_shape, cudnn_layer) if to_cudnn: - convert_weights(layer, cudnn_layer) + self._convert_model_weights(model, cudnn_model) else: - convert_weights(cudnn_layer, layer) + self._convert_model_weights(cudnn_model, model) - self.assertAllClose( - model.predict(inputs), cudnn_model.predict(inputs), atol=1e-4) + self.assertAllClose(model.predict(inputs), cudnn_model.predict(inputs), + atol=1e-4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_cudnnrnn_bidirectional(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py similarity index 94% rename from tensorflow/python/keras/_impl/keras/layers/embeddings.py rename to tensorflow/python/keras/layers/embeddings.py index f7398845d400b1fd4fedd532eb1520dff30d47a0..910fff720f6312041a25922cf5c63dfa8f83ec76 100644 --- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -18,12 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py similarity index 96% rename from tensorflow/python/keras/_impl/keras/layers/embeddings_test.py rename to tensorflow/python/keras/layers/embeddings_test.py index 6ebf5dc94adb423abae7ec9e6910fb86439410f1..fff1c5ef9882f0c479d119ddb0bf68e919c016b4 100644 --- a/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py +++ b/tensorflow/python/keras/layers/embeddings_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np +from tensorflow.python import keras from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py similarity index 95% rename from tensorflow/python/keras/_impl/keras/layers/gru_test.py rename to tensorflow/python/keras/layers/gru_test.py index 48e7e14f5ab73b534ab0d1c765ad2572b2930b2b..57f660b6d5a70b950918a3f6d75c87ecccf76f82 100644 --- a/tensorflow/python/keras/_impl/keras/layers/gru_test.py +++ b/tensorflow/python/keras/layers/gru_test.py @@ -20,16 +20,16 @@ from __future__ import print_function import numpy as np +from tensorflow.python import keras from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test from tensorflow.python.training.rmsprop import RMSPropOptimizer class GRULayerTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_return_sequences_GRU(self): num_samples = 2 timesteps = 3 @@ -41,7 +41,7 @@ class GRULayerTest(test.TestCase): 'return_sequences': True}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dynamic_behavior_GRU(self): num_samples = 2 timesteps = 3 @@ -55,7 +55,7 @@ class GRULayerTest(test.TestCase): y = np.random.random((num_samples, units)) model.train_on_batch(x, y) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dropout_GRU(self): num_samples = 2 timesteps = 3 @@ -68,7 +68,7 @@ class GRULayerTest(test.TestCase): 'recurrent_dropout': 0.1}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_implementation_mode_GRU(self): num_samples = 2 timesteps = 3 diff --git a/tensorflow/python/keras/_impl/keras/layers/local.py b/tensorflow/python/keras/layers/local.py similarity index 86% rename from tensorflow/python/keras/_impl/keras/layers/local.py rename to tensorflow/python/keras/layers/local.py index caae820fb3a8eba76c3fbbca734908514b076982..0ebafe07cc45698200d0b1fa858a436c7a08820e 100644 --- a/tensorflow/python/keras/_impl/keras/layers/local.py +++ b/tensorflow/python/keras/layers/local.py @@ -18,15 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras import activations -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.utils import conv_utils -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import activations +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.util.tf_export import tf_export @@ -62,6 +62,16 @@ class LocallyConnected1D(Layer): any `dilation_rate` value != 1. padding: Currently only supports `"valid"` (case-insensitive). `"same"` may be supported in the future. + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, length)`. + It defaults to the `image_data_format` value found in your + Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be "channels_last". activation: Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: `a(x) = x`). @@ -122,13 +132,17 @@ class LocallyConnected1D(Layer): @tf_utils.shape_type_conversion def build(self, input_shape): - input_dim = input_shape[2] + if self.data_format == 'channels_first': + input_dim, input_length = input_shape[1], input_shape[2] + else: + input_dim, input_length = input_shape[2], input_shape[1] + if input_dim is None: raise ValueError('Axis 2 of input should be fully-defined. ' 'Found shape:', input_shape) - output_length = conv_utils.conv_output_length( - input_shape[1], self.kernel_size[0], self.padding, self.strides[0]) - self.kernel_shape = (output_length, self.kernel_size[0] * input_dim, + self.output_length = conv_utils.conv_output_length( + input_length, self.kernel_size[0], self.padding, self.strides[0]) + self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim, self.filters) self.kernel = self.add_weight( shape=self.kernel_shape, @@ -138,28 +152,43 @@ class LocallyConnected1D(Layer): constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - shape=(output_length, self.filters), + shape=(self.output_length, self.filters), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, constraint=self.bias_constraint) else: self.bias = None - self.input_spec = InputSpec(ndim=3, axes={2: input_dim}) + + if self.data_format == 'channels_first': + self.input_spec = InputSpec(ndim=3, axes={1: input_dim}) + else: + self.input_spec = InputSpec(ndim=3, axes={-1: input_dim}) self.built = True @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): - length = conv_utils.conv_output_length(input_shape[1], self.kernel_size[0], + if self.data_format == 'channels_first': + input_length = input_shape[2] + else: + input_length = input_shape[1] + + length = conv_utils.conv_output_length(input_length, self.kernel_size[0], self.padding, self.strides[0]) - return (input_shape[0], length, self.filters) + + if self.data_format == 'channels_first': + return (input_shape[0], self.filters, length) + elif self.data_format == 'channels_last': + return (input_shape[0], length, self.filters) def call(self, inputs): - output = K.local_conv1d(inputs, self.kernel, self.kernel_size, self.strides) + output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides, + (self.output_length,), self.data_format) + if self.use_bias: - output = K.bias_add(output, self.bias) - if self.activation is not None: - output = self.activation(output) + output = K.bias_add(output, self.bias, data_format=self.data_format) + + output = self.activation(output) return output def get_config(self): @@ -172,6 +201,8 @@ class LocallyConnected1D(Layer): self.strides, 'padding': self.padding, + 'data_format': + self.data_format, 'activation': activations.serialize(self.activation), 'use_bias': @@ -370,9 +401,8 @@ class LocallyConnected2D(Layer): return (input_shape[0], rows, cols, self.filters) def call(self, inputs): - output = K.local_conv2d(inputs, self.kernel, self.kernel_size, self.strides, - (self.output_row, self.output_col), - self.data_format) + output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides, + (self.output_row, self.output_col), self.data_format) if self.use_bias: output = K.bias_add(output, self.bias, data_format=self.data_format) diff --git a/tensorflow/python/keras/_impl/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py similarity index 61% rename from tensorflow/python/keras/_impl/keras/layers/local_test.py rename to tensorflow/python/keras/layers/local_test.py index 93741d24b9a74cf9e8a83069f7c4235b1f489818..9639e0251f5a56e4130b13c0185792fe11da2532 100644 --- a/tensorflow/python/keras/_impl/keras/layers/local_test.py +++ b/tensorflow/python/keras/layers/local_test.py @@ -20,15 +20,15 @@ from __future__ import print_function import numpy as np +from tensorflow.python import keras from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test class LocallyConnectedLayersTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_locallyconnected_1d(self): num_samples = 2 num_steps = 8 @@ -40,16 +40,17 @@ class LocallyConnectedLayersTest(test.TestCase): for strides in [1]: if padding == 'same' and strides != 1: continue - - testing_utils.layer_test( - keras.layers.LocallyConnected1D, - kwargs={ - 'filters': filters, - 'kernel_size': filter_length, - 'padding': padding, - 'strides': strides - }, - input_shape=(num_samples, num_steps, input_dim)) + for data_format in ['channels_first', 'channels_last']: + testing_utils.layer_test( + keras.layers.LocallyConnected1D, + kwargs={ + 'filters': filters, + 'kernel_size': filter_length, + 'padding': padding, + 'strides': strides, + 'data_format': data_format + }, + input_shape=(num_samples, num_steps, input_dim)) def test_locallyconnected_1d_regularization(self): num_samples = 2 @@ -57,37 +58,41 @@ class LocallyConnectedLayersTest(test.TestCase): input_dim = 5 filter_length = 3 filters = 4 - kwargs = { - 'filters': filters, - 'kernel_size': filter_length, - 'kernel_regularizer': 'l2', - 'bias_regularizer': 'l2', - 'activity_regularizer': 'l2', - } - - with self.test_session(): - layer = keras.layers.LocallyConnected1D(**kwargs) - layer.build((num_samples, num_steps, input_dim)) - self.assertEqual(len(layer.losses), 2) - layer( - keras.backend.variable(np.ones((num_samples, num_steps, input_dim)))) - self.assertEqual(len(layer.losses), 3) - - k_constraint = keras.constraints.max_norm(0.01) - b_constraint = keras.constraints.max_norm(0.01) - kwargs = { - 'filters': filters, - 'kernel_size': filter_length, - 'kernel_constraint': k_constraint, - 'bias_constraint': b_constraint, - } - with self.test_session(): - layer = keras.layers.LocallyConnected1D(**kwargs) - layer.build((num_samples, num_steps, input_dim)) - self.assertEqual(layer.kernel.constraint, k_constraint) - self.assertEqual(layer.bias.constraint, b_constraint) - - @tf_test_util.run_in_graph_and_eager_modes() + for data_format in ['channels_first', 'channels_last']: + kwargs = { + 'filters': filters, + 'kernel_size': filter_length, + 'kernel_regularizer': 'l2', + 'bias_regularizer': 'l2', + 'activity_regularizer': 'l2', + 'data_format': data_format + } + + with self.test_session(): + layer = keras.layers.LocallyConnected1D(**kwargs) + layer.build((num_samples, num_steps, input_dim)) + self.assertEqual(len(layer.losses), 2) + layer( + keras.backend.variable(np.ones((num_samples, + num_steps, + input_dim)))) + self.assertEqual(len(layer.losses), 3) + + k_constraint = keras.constraints.max_norm(0.01) + b_constraint = keras.constraints.max_norm(0.01) + kwargs = { + 'filters': filters, + 'kernel_size': filter_length, + 'kernel_constraint': k_constraint, + 'bias_constraint': b_constraint, + } + with self.test_session(): + layer = keras.layers.LocallyConnected1D(**kwargs) + layer.build((num_samples, num_steps, input_dim)) + self.assertEqual(layer.kernel.constraint, k_constraint) + self.assertEqual(layer.bias.constraint, b_constraint) + + @tf_test_util.run_in_graph_and_eager_modes def test_locallyconnected_2d(self): num_samples = 8 filters = 3 @@ -113,6 +118,7 @@ class LocallyConnectedLayersTest(test.TestCase): }, input_shape=(num_samples, num_row, num_col, stack_size)) + @tf_test_util.run_in_graph_and_eager_modes def test_locallyconnected_2d_channels_first(self): num_samples = 8 filters = 3 @@ -120,15 +126,14 @@ class LocallyConnectedLayersTest(test.TestCase): num_row = 6 num_col = 10 - with self.test_session(): - testing_utils.layer_test( - keras.layers.LocallyConnected2D, - kwargs={ - 'filters': filters, - 'kernel_size': 3, - 'data_format': 'channels_first' - }, - input_shape=(num_samples, num_row, num_col, stack_size)) + testing_utils.layer_test( + keras.layers.LocallyConnected2D, + kwargs={ + 'filters': filters, + 'kernel_size': 3, + 'data_format': 'channels_first' + }, + input_shape=(num_samples, num_row, num_col, stack_size)) def test_locallyconnected_2d_regularization(self): num_samples = 8 diff --git a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/layers/lstm_test.py rename to tensorflow/python/keras/layers/lstm_test.py index 11a5e0aeaacfa7520361ae41ac3d40607e8a9050..ae381f595565cf0d060320354cb32585c1067f72 100644 --- a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py +++ b/tensorflow/python/keras/layers/lstm_test.py @@ -20,16 +20,16 @@ from __future__ import print_function import numpy as np +from tensorflow.python import keras from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test from tensorflow.python.training.rmsprop import RMSPropOptimizer class LSTMLayerTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_return_sequences_LSTM(self): num_samples = 2 timesteps = 3 @@ -56,7 +56,7 @@ class LSTMLayerTest(test.TestCase): outputs = model.layers[-1].output self.assertEquals(outputs.get_shape().as_list(), [None, timesteps, units]) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dynamic_behavior_LSTM(self): num_samples = 2 timesteps = 3 @@ -70,7 +70,7 @@ class LSTMLayerTest(test.TestCase): y = np.random.random((num_samples, units)) model.train_on_batch(x, y) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dropout_LSTM(self): num_samples = 2 timesteps = 3 @@ -83,7 +83,7 @@ class LSTMLayerTest(test.TestCase): 'recurrent_dropout': 0.1}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_implementation_mode_LSTM(self): num_samples = 2 timesteps = 3 diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/layers/merge.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/layers/merge.py rename to tensorflow/python/keras/layers/merge.py index 2b6cf7c8a94ff40ea35e2bbfe13e6c26024857b4..f295af3fe04d87d260e4f6a98762dcfb90883531 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge.py +++ b/tensorflow/python/keras/layers/merge.py @@ -20,9 +20,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.engine.base_layer import Layer -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -250,6 +250,7 @@ class Add(_Merge): return output +@tf_export('keras.layers.Subtract') class Subtract(_Merge): """Layer that subtracts two inputs. @@ -336,6 +337,7 @@ class Maximum(_Merge): return output +@tf_export('keras.layers.Minimum') class Minimum(_Merge): """Layer that computes the minimum (element-wise) a list of inputs. @@ -446,8 +448,8 @@ class Concatenate(_Merge): class Dot(_Merge): """Layer that computes a dot product between samples in two tensors. - E.g. if applied to two tensors `a` and `b` of shape `(batch_size, n)`, - the output will be a tensor of shape `(batch_size, 1)` + E.g. if applied to a list of two tensors `a` and `b` of shape + `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)` where each entry `i` will be the dot product between `a[i]` and `b[i]`. @@ -586,6 +588,7 @@ def add(inputs, **kwargs): return Add(**kwargs)(inputs) +@tf_export('keras.layers.subtract') def subtract(inputs, **kwargs): """Functional interface to the `Subtract` layer. @@ -656,6 +659,7 @@ def maximum(inputs, **kwargs): return Maximum(**kwargs)(inputs) +@tf_export('keras.layers.minimum') def minimum(inputs, **kwargs): """Functional interface to the `Minimum` layer. diff --git a/tensorflow/python/keras/_impl/keras/layers/merge_test.py b/tensorflow/python/keras/layers/merge_test.py similarity index 95% rename from tensorflow/python/keras/_impl/keras/layers/merge_test.py rename to tensorflow/python/keras/layers/merge_test.py index b2fe06f93e33ed63d6a2aa29522ecb552f582440..39bc98d039624d50788e1b7995dc5fba300a5276 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge_test.py +++ b/tensorflow/python/keras/layers/merge_test.py @@ -20,15 +20,15 @@ from __future__ import print_function import numpy as np +from tensorflow.python import keras from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras from tensorflow.python.ops import array_ops from tensorflow.python.platform import test class MergeLayersTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_add(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -76,7 +76,7 @@ class MergeLayersTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.add([i1]) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_multiply(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -92,7 +92,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, x1 * x2 * x3, atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_average(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -106,7 +106,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_maximum(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -120,7 +120,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_minimum(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -134,7 +134,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_concatenate(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -169,7 +169,7 @@ class MergeLayersTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'called on a list'): keras.layers.concatenate([i1], axis=-1) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_dot(self): i1 = keras.layers.Input(shape=(4,)) i2 = keras.layers.Input(shape=(4,)) @@ -215,7 +215,7 @@ class MergeLayersTest(test.TestCase): dot = keras.layers.Dot(1) dot.compute_output_shape(1) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_subtract(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/layers/noise.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/layers/noise.py rename to tensorflow/python/keras/layers/noise.py index addac5b137430d8f74efa126423cb39b15382502..cb7cee3ebc3ebd2413836b876f2aaf21985f1d9c 100644 --- a/tensorflow/python/keras/_impl/keras/layers/noise.py +++ b/tensorflow/python/keras/layers/noise.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/noise_test.py b/tensorflow/python/keras/layers/noise_test.py similarity index 90% rename from tensorflow/python/keras/_impl/keras/layers/noise_test.py rename to tensorflow/python/keras/layers/noise_test.py index af4f031ec95bb56b72c1f1018e0e529d8ff55564..aa2be62390b0dcf0656a533cba9bdbe9ceee09dd 100644 --- a/tensorflow/python/keras/_impl/keras/layers/noise_test.py +++ b/tensorflow/python/keras/layers/noise_test.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python import keras from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test @@ -40,7 +40,7 @@ class NoiseLayersTest(test.TestCase): kwargs={'rate': 0.5}, input_shape=(3, 2, 3)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_AlphaDropout(self): testing_utils.layer_test( keras.layers.AlphaDropout, diff --git a/tensorflow/python/keras/_impl/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py similarity index 96% rename from tensorflow/python/keras/_impl/keras/layers/normalization.py rename to tensorflow/python/keras/layers/normalization.py index c16fc07fb4ecda66bd8bcc70dce5d753c73f5dd9..8b894ca6b1c256210bb9ded33ae36da2fc4c001a 100644 --- a/tensorflow/python/keras/_impl/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -22,18 +22,19 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util.tf_export import tf_export @@ -182,8 +183,9 @@ class BatchNormalization(Layer): def _add_tower_local_variable(self, *args, **kwargs): tower_context = distribute_lib.get_tower_context() - with tower_context.tower_local_var_scope('mean'): - return self.add_variable(*args, **kwargs) + with tower_context.tower_local_var_scope( + variable_scope.VariableAggregation.MEAN): + return self.add_weight(*args, **kwargs) def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) @@ -276,7 +278,7 @@ class BatchNormalization(Layer): self.axis[idx] = x + 1 # Account for added dimension if self.scale: - self.gamma = self.add_variable( + self.gamma = self.add_weight( name='gamma', shape=param_shape, dtype=param_dtype, @@ -291,7 +293,7 @@ class BatchNormalization(Layer): 1.0, dtype=param_dtype, shape=param_shape) if self.center: - self.beta = self.add_variable( + self.beta = self.add_weight( name='beta', shape=param_shape, dtype=param_dtype, @@ -364,11 +366,12 @@ class BatchNormalization(Layer): def _assign_moving_average(self, variable, value, momentum): with ops.name_scope(None, 'AssignMovingAvg', [variable, value, momentum]) as scope: - decay = ops.convert_to_tensor(1.0 - momentum, name='decay') - if decay.dtype != variable.dtype.base_dtype: - decay = math_ops.cast(decay, variable.dtype.base_dtype) - update_delta = (variable - value) * decay - return state_ops.assign_sub(variable, update_delta, name=scope) + with ops.colocate_with(variable): + decay = ops.convert_to_tensor(1.0 - momentum, name='decay') + if decay.dtype != variable.dtype.base_dtype: + decay = math_ops.cast(decay, variable.dtype.base_dtype) + update_delta = (variable - value) * decay + return state_ops.assign_sub(variable, update_delta, name=scope) def _fused_batch_norm(self, inputs, training): """Returns the output of fused batch norm.""" @@ -574,28 +577,26 @@ class BatchNormalization(Layer): lambda: variance, lambda: moving_variance) + if self.virtual_batch_size is not None: + # This isn't strictly correct since in ghost batch norm, you are + # supposed to sequentially update the moving_mean and moving_variance + # with each sub-batch. However, since the moving statistics are only + # used during evaluation, it is more efficient to just update in one + # step and should not make a significant difference in the result. + new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True) + new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True) + else: + new_mean, new_variance = mean, variance + if self.renorm: r, d, new_mean, new_variance = self._renorm_correction_and_moments( - mean, variance, training) + new_mean, new_variance, training) # When training, the normalized values (say, x) will be transformed as # x * gamma + beta without renorm, and (x * r + d) * gamma + beta # = x * (r * gamma) + (d * gamma + beta) with renorm. r = _broadcast(array_ops.stop_gradient(r, name='renorm_r')) d = _broadcast(array_ops.stop_gradient(d, name='renorm_d')) scale, offset = _compose_transforms(r, d, scale, offset) - else: - new_mean, new_variance = mean, variance - - if self.virtual_batch_size is not None: - # This isn't strictly correct since in ghost batch norm, you are - # supposed to sequentially update the moving_mean and moving_variance - # with each sub-batch. However, since the moving statistics are only - # used during evaluation, it is more efficient to just update in one - # step and should not make a significant difference in the result. - new_mean = math_ops.reduce_mean(new_mean, - axis=1, keepdims=True) - new_variance = math_ops.reduce_mean(new_variance, - axis=1, keepdims=True) def _do_update(var, value): if in_eager_mode and not self.trainable: diff --git a/tensorflow/python/keras/_impl/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/layers/normalization_test.py rename to tensorflow/python/keras/layers/normalization_test.py index 84f0b2776c9980e0bdc00c173b275604ce16697a..b22f3bd1529812f6b5f63efe5cf6b6133db97f07 100644 --- a/tensorflow/python/keras/_impl/keras/layers/normalization_test.py +++ b/tensorflow/python/keras/layers/normalization_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python import keras +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/layers/pooling.py rename to tensorflow/python/keras/layers/pooling.py index 86bc8a680a529a9ea17592a42207fab58adeebce..912e8bd619db8b35a54853c0752382479567fd04 100644 --- a/tensorflow/python/keras/_impl/keras/layers/pooling.py +++ b/tensorflow/python/keras/layers/pooling.py @@ -19,10 +19,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import backend -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.utils import conv_utils +from tensorflow.python.keras import backend +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import conv_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/pooling_test.py b/tensorflow/python/keras/layers/pooling_test.py similarity index 91% rename from tensorflow/python/keras/_impl/keras/layers/pooling_test.py rename to tensorflow/python/keras/layers/pooling_test.py index 2c08b647ea0fafb7519240b0c81e8fa77f034f7f..2cd9939e66ff869dac5058d2dd00d8d495e40f55 100644 --- a/tensorflow/python/keras/_impl/keras/layers/pooling_test.py +++ b/tensorflow/python/keras/layers/pooling_test.py @@ -18,23 +18,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python import keras from tensorflow.python.eager import context from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test class GlobalPoolingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_globalpooling_1d(self): testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D, input_shape=(3, 4, 5)) testing_utils.layer_test( keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_globalpooling_2d(self): testing_utils.layer_test( keras.layers.pooling.GlobalMaxPooling2D, @@ -53,7 +53,7 @@ class GlobalPoolingTest(test.TestCase): kwargs={'data_format': 'channels_last'}, input_shape=(3, 5, 6, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_globalpooling_3d(self): testing_utils.layer_test( keras.layers.pooling.GlobalMaxPooling3D, @@ -75,7 +75,7 @@ class GlobalPoolingTest(test.TestCase): class Pooling2DTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_maxpooling_2d(self): pool_size = (3, 3) for strides in [(1, 1), (2, 2)]: @@ -88,7 +88,7 @@ class Pooling2DTest(test.TestCase): }, input_shape=(3, 5, 6, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_averagepooling_2d(self): testing_utils.layer_test( keras.layers.AveragePooling2D, @@ -122,7 +122,7 @@ class Pooling2DTest(test.TestCase): class Pooling3DTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_maxpooling_3d(self): pool_size = (3, 3, 3) testing_utils.layer_test( @@ -141,7 +141,7 @@ class Pooling3DTest(test.TestCase): }, input_shape=(3, 4, 11, 12, 10)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_averagepooling_3d(self): pool_size = (3, 3, 3) testing_utils.layer_test( @@ -163,7 +163,7 @@ class Pooling3DTest(test.TestCase): class Pooling1DTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_maxpooling_1d(self): for padding in ['valid', 'same']: for stride in [1, 2]: @@ -173,7 +173,7 @@ class Pooling1DTest(test.TestCase): 'padding': padding}, input_shape=(3, 5, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_averagepooling_1d(self): for padding in ['valid', 'same']: for stride in [1, 2]: diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/layers/recurrent.py rename to tensorflow/python/keras/layers/recurrent.py index 93150b97fa87f5418f541fb211a8671f3a275883..32d25c5a650d3b66d944eee945cafa2d6f54d405 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -24,15 +24,15 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import activations -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.utils import generic_utils -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import activations +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -153,7 +153,7 @@ class StackedRNNCells(Layer): @classmethod def from_config(cls, config, custom_objects=None): - from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top cells = [] for cell_config in config.pop('cells'): cells.append( @@ -734,7 +734,7 @@ class RNN(Layer): @classmethod def from_config(cls, config, custom_objects=None): - from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects) num_constants = config.pop('num_constants', None) layer = cls(cell, **config) diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/layers/recurrent_test.py rename to tensorflow/python/keras/layers/recurrent_test.py index 4c68c18825a47d87806a7a09d4054f974d569e00..802374d2d28d792c1e32bf5095b928f569144b49 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/layers/recurrent_test.py @@ -23,7 +23,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops diff --git a/tensorflow/python/keras/_impl/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py similarity index 58% rename from tensorflow/python/keras/_impl/keras/layers/serialization.py rename to tensorflow/python/keras/layers/serialization.py index 8151ad7fdddefe08e7af0563bdf27ab335d7d1f8..7c45e08b5c48084cc57569a4d1102a0a7c5b29e1 100644 --- a/tensorflow/python/keras/_impl/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -20,22 +20,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.engine import Input -from tensorflow.python.keras._impl.keras.engine import InputLayer -from tensorflow.python.keras._impl.keras.layers.advanced_activations import * -from tensorflow.python.keras._impl.keras.layers.convolutional import * -from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import * -from tensorflow.python.keras._impl.keras.layers.core import * -from tensorflow.python.keras._impl.keras.layers.cudnn_recurrent import * -from tensorflow.python.keras._impl.keras.layers.embeddings import * -from tensorflow.python.keras._impl.keras.layers.local import * -from tensorflow.python.keras._impl.keras.layers.merge import * -from tensorflow.python.keras._impl.keras.layers.noise import * -from tensorflow.python.keras._impl.keras.layers.normalization import * -from tensorflow.python.keras._impl.keras.layers.pooling import * -from tensorflow.python.keras._impl.keras.layers.recurrent import * -from tensorflow.python.keras._impl.keras.layers.wrappers import * -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer +from tensorflow.python.keras.layers.advanced_activations import * +from tensorflow.python.keras.layers.convolutional import * +from tensorflow.python.keras.layers.convolutional_recurrent import * +from tensorflow.python.keras.layers.core import * +from tensorflow.python.keras.layers.cudnn_recurrent import * +from tensorflow.python.keras.layers.embeddings import * +from tensorflow.python.keras.layers.local import * +from tensorflow.python.keras.layers.merge import * +from tensorflow.python.keras.layers.noise import * +from tensorflow.python.keras.layers.normalization import * +from tensorflow.python.keras.layers.pooling import * +from tensorflow.python.keras.layers.recurrent import * +from tensorflow.python.keras.layers.wrappers import * +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object def serialize(layer): @@ -53,7 +53,7 @@ def deserialize(config, custom_objects=None): Returns: Layer instance (may be Model, Sequential, Layer...) """ - from tensorflow.python.keras._impl.keras import models # pylint: disable=g-import-not-at-top + from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top globs = globals() # All layers. globs['Model'] = models.Model globs['Sequential'] = models.Sequential diff --git a/tensorflow/python/keras/_impl/keras/layers/serialization_test.py b/tensorflow/python/keras/layers/serialization_test.py similarity index 96% rename from tensorflow/python/keras/_impl/keras/layers/serialization_test.py rename to tensorflow/python/keras/layers/serialization_test.py index 787160d1e71f570479144c5afd45cd41f38f0e91..5872185ef7c30aa50e8ca5aac32cc1804369017c 100644 --- a/tensorflow/python/keras/_impl/keras/layers/serialization_test.py +++ b/tensorflow/python/keras/layers/serialization_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py similarity index 95% rename from tensorflow/python/keras/_impl/keras/layers/simplernn_test.py rename to tensorflow/python/keras/layers/simplernn_test.py index 8c7189cd4718450a85c015e08ab3a58cc5d86531..18fefbe84f6f46f2043c6586ecbc85ea76c55ea0 100644 --- a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py +++ b/tensorflow/python/keras/layers/simplernn_test.py @@ -20,16 +20,16 @@ from __future__ import print_function import numpy as np +from tensorflow.python import keras from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test from tensorflow.python.training.rmsprop import RMSPropOptimizer class SimpleRNNLayerTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_return_sequences_SimpleRNN(self): num_samples = 2 timesteps = 3 @@ -41,7 +41,7 @@ class SimpleRNNLayerTest(test.TestCase): 'return_sequences': True}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dynamic_behavior_SimpleRNN(self): num_samples = 2 timesteps = 3 @@ -55,7 +55,7 @@ class SimpleRNNLayerTest(test.TestCase): y = np.random.random((num_samples, units)) model.train_on_batch(x, y) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dropout_SimpleRNN(self): num_samples = 2 timesteps = 3 @@ -68,7 +68,7 @@ class SimpleRNNLayerTest(test.TestCase): 'recurrent_dropout': 0.1}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_implementation_mode_SimpleRNN(self): num_samples = 2 timesteps = 3 diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py similarity index 92% rename from tensorflow/python/keras/_impl/keras/layers/wrappers.py rename to tensorflow/python/keras/layers/wrappers.py index 7fe57458fbeff451e706342946f93abd43a7f772..e61acf8e771eb8de1c466ffa5e1c4c7f543f77ef 100644 --- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -22,12 +22,12 @@ from __future__ import print_function import copy from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.layers.recurrent import _standardize_args -from tensorflow.python.keras._impl.keras.utils import generic_utils -from tensorflow.python.keras._impl.keras.utils import tf_utils +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.layers.recurrent import _standardize_args +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.util.tf_export import tf_export @@ -45,7 +45,9 @@ class Wrapper(Layer): """ def __init__(self, layer, **kwargs): + assert isinstance(layer, Layer) self.layer = layer + self._track_checkpointable(layer, name='layer') # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when # the inner layer has update ops that depend on its inputs (as opposed # to the inputs to the Wrapper layer). @@ -104,7 +106,7 @@ class Wrapper(Layer): @classmethod def from_config(cls, config, custom_objects=None): - from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top layer = deserialize_layer( config.pop('layer'), custom_objects=custom_objects) return cls(layer, **config) @@ -154,9 +156,16 @@ class TimeDistributed(Wrapper): Arguments: layer: a layer instance. + + Raises: + ValueError: If not initialized with a `Layer` instance. """ def __init__(self, layer, **kwargs): + if not isinstance(layer, Layer): + raise ValueError( + 'Please initialize `TimeDistributed` layer with a ' + '`Layer` instance. You passed: {input}'.format(input=layer)) super(TimeDistributed, self).__init__(layer, **kwargs) self.supports_masking = True @@ -166,7 +175,10 @@ class TimeDistributed(Wrapper): self.input_spec = InputSpec(shape=input_shape) child_input_shape = [input_shape[0]] + input_shape[2:] if not self.layer.built: - self.layer.build(child_input_shape) + # The base layer class calls a conversion function on the input shape to + # convert it to a TensorShape. The conversion function requires a + # tuple which is why we cast the shape. + self.layer.build(tuple(child_input_shape)) self.layer.built = True super(TimeDistributed, self).build() self.built = True @@ -249,7 +261,8 @@ class Bidirectional(Wrapper): they will be returned as a list. Raises: - ValueError: In case of invalid `merge_mode` argument. + ValueError: If not initialized with a `Layer` instance or + In case of invalid `merge_mode` argument. Examples: @@ -265,6 +278,10 @@ class Bidirectional(Wrapper): """ def __init__(self, layer, merge_mode='concat', weights=None, **kwargs): + if not isinstance(layer, Layer): + raise ValueError( + 'Please initialize `Bidirectional` layer with a ' + '`Layer` instance. You passed: {input}'.format(input=layer)) if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]: raise ValueError('Invalid merge mode. ' 'Merge mode should be one of ' diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py similarity index 88% rename from tensorflow/python/keras/_impl/keras/layers/wrappers_test.py rename to tensorflow/python/keras/layers/wrappers_test.py index 05b272a470df305e3ba531a149c55e9bab3298da..c8f0d216e6f7a3bb715286bd6e7975a5dc1ac1cc 100644 --- a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -22,9 +22,11 @@ import copy import numpy as np +from tensorflow.python import keras +from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras._impl import keras from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_util from tensorflow.python.training.rmsprop import RMSPropOptimizer @@ -69,7 +71,7 @@ class _RNNCellWithConstants(keras.layers.Layer): class TimeDistributedTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_timedistributed_dense(self): model = keras.models.Sequential() model.add( @@ -85,6 +87,10 @@ class TimeDistributedTest(test.TestCase): # test config model.get_config() + checkpointed_objects = set(checkpointable_util.list_objects(model)) + for v in model.variables: + self.assertIn(v, checkpointed_objects) + def test_timedistributed_static_batch_size(self): model = keras.models.Sequential() model.add( @@ -97,6 +103,13 @@ class TimeDistributedTest(test.TestCase): epochs=1, batch_size=10) + def test_timedistributed_invalid_init(self): + x = constant_op.constant(np.zeros((1, 1)).astype('float32')) + with self.assertRaisesRegexp( + ValueError, + 'Please initialize `TimeDistributed` layer with a `Layer` instance.'): + keras.layers.TimeDistributed(x) + def test_timedistributed_conv2d(self): with self.test_session(): model = keras.models.Sequential() @@ -220,6 +233,13 @@ class BidirectionalTest(test.TestCase): model = keras.models.model_from_json(model.to_json()) model.summary() + def test_bidirectional_invalid_init(self): + x = constant_op.constant(np.zeros((1, 1)).astype('float32')) + with self.assertRaisesRegexp( + ValueError, + 'Please initialize `Bidirectional` layer with a `Layer` instance.'): + keras.layers.Bidirectional(x) + def test_bidirectional_weight_loading(self): rnn = keras.layers.SimpleRNN samples = 2 @@ -424,6 +444,42 @@ class BidirectionalTest(test.TestCase): layer.trainable = True assert len(layer.trainable_weights) == 6 + def test_Bidirectional_updates(self): + with self.test_session(): + x = keras.layers.Input(shape=(3, 2)) + x_reachable_update = x * x + layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3)) + _ = layer(x) + assert not layer.updates + assert not layer.get_updates_for(None) + assert not layer.get_updates_for(x) + layer.forward_layer.add_update(x_reachable_update, inputs=x) + layer.forward_layer.add_update(1, inputs=None) + layer.backward_layer.add_update(x_reachable_update, inputs=x) + layer.backward_layer.add_update(1, inputs=None) + assert len(layer.updates) == 4 + assert len(layer.get_updates_for(None)) == 2 + assert len(layer.get_updates_for(x)) == 2 + + def test_Bidirectional_losses(self): + with self.test_session(): + x = keras.layers.Input(shape=(3, 2)) + x_reachable_loss = x * x + layer = keras.layers.Bidirectional( + keras.layers.SimpleRNN( + 3, kernel_regularizer='l1', bias_regularizer='l1')) + _ = layer(x) + assert len(layer.losses) == 4 + assert len(layer.get_losses_for(None)) == 4 + assert not layer.get_losses_for(x) + layer.forward_layer.add_loss(x_reachable_loss, inputs=x) + layer.forward_layer.add_loss(1, inputs=None) + layer.backward_layer.add_loss(x_reachable_loss, inputs=x) + layer.backward_layer.add_loss(1, inputs=None) + assert len(layer.losses) == 8 + assert len(layer.get_losses_for(None)) == 6 + assert len(layer.get_losses_for(x)) == 2 + def test_Bidirectional_with_constants(self): with self.test_session(): # Test basic case. diff --git a/tensorflow/python/keras/_impl/keras/losses.py b/tensorflow/python/keras/losses.py similarity index 81% rename from tensorflow/python/keras/_impl/keras/losses.py rename to tensorflow/python/keras/losses.py index 1d634d38013164659f7360fce45704c19083f475..9f548bfe0408d5c053c25b9ae14810d582b83e1e 100644 --- a/tensorflow/python/keras/_impl/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -21,28 +21,40 @@ from __future__ import print_function import six -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.util.tf_export import tf_export @tf_export('keras.metrics.mean_squared_error', - 'keras.losses.mean_squared_error') + 'keras.metrics.mse', + 'keras.metrics.MSE', + 'keras.losses.mean_squared_error', + 'keras.losses.mse', + 'keras.losses.MSE') def mean_squared_error(y_true, y_pred): return K.mean(math_ops.square(y_pred - y_true), axis=-1) @tf_export('keras.metrics.mean_absolute_error', - 'keras.losses.mean_absolute_error') + 'keras.metrics.mae', + 'keras.metrics.MAE', + 'keras.losses.mean_absolute_error', + 'keras.losses.mae', + 'keras.losses.MAE') def mean_absolute_error(y_true, y_pred): return K.mean(math_ops.abs(y_pred - y_true), axis=-1) @tf_export('keras.metrics.mean_absolute_percentage_error', - 'keras.losses.mean_absolute_percentage_error') + 'keras.metrics.mape', + 'keras.metrics.MAPE', + 'keras.losses.mean_absolute_percentage_error', + 'keras.losses.mape', + 'keras.losses.MAPE') def mean_absolute_percentage_error(y_true, y_pred): diff = math_ops.abs( (y_true - y_pred) / K.clip(math_ops.abs(y_true), K.epsilon(), None)) @@ -50,7 +62,11 @@ def mean_absolute_percentage_error(y_true, y_pred): @tf_export('keras.metrics.mean_squared_logarithmic_error', - 'keras.losses.mean_squared_logarithmic_error') + 'keras.metrics.msle', + 'keras.metrics.MSLE', + 'keras.losses.mean_squared_logarithmic_error', + 'keras.losses.msle', + 'keras.losses.MSLE') def mean_squared_logarithmic_error(y_true, y_pred): first_log = math_ops.log(K.clip(y_pred, K.epsilon(), None) + 1.) second_log = math_ops.log(K.clip(y_true, K.epsilon(), None) + 1.) @@ -117,7 +133,11 @@ def binary_crossentropy(y_true, y_pred): @tf_export('keras.metrics.kullback_leibler_divergence', - 'keras.losses.kullback_leibler_divergence') + 'keras.metrics.kld', + 'keras.metrics.KLD', + 'keras.losses.kullback_leibler_divergence', + 'keras.losses.kld', + 'keras.losses.KLD') def kullback_leibler_divergence(y_true, y_pred): y_true = K.clip(y_true, K.epsilon(), 1) y_pred = K.clip(y_pred, K.epsilon(), 1) @@ -129,7 +149,10 @@ def poisson(y_true, y_pred): return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1) -@tf_export('keras.metrics.cosine_proximity', 'keras.losses.cosine_proximity') +@tf_export('keras.metrics.cosine_proximity', + 'keras.metrics.cosine', + 'keras.losses.cosine_proximity', + 'keras.losses.cosine') def cosine_proximity(y_true, y_pred): y_true = nn.l2_normalize(y_true, axis=-1) y_pred = nn.l2_normalize(y_pred, axis=-1) diff --git a/tensorflow/python/keras/losses/__init__.py b/tensorflow/python/keras/losses/__init__.py deleted file mode 100644 index 66721b694f5fd5fae7ca521ff56d4c6c6bce79b5..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/losses/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras built-in loss functions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Loss functions. -from tensorflow.python.keras._impl.keras.losses import binary_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_hinge -from tensorflow.python.keras._impl.keras.losses import cosine_proximity -from tensorflow.python.keras._impl.keras.losses import hinge -from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.losses import logcosh -from tensorflow.python.keras._impl.keras.losses import mean_absolute_error -from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.losses import poisson -from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import squared_hinge - -# Auxiliary utils. -# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.losses import deserialize -from tensorflow.python.keras._impl.keras.losses import serialize -from tensorflow.python.keras._impl.keras.losses import get - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/losses_test.py b/tensorflow/python/keras/losses_test.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/losses_test.py rename to tensorflow/python/keras/losses_test.py index 1884c0fdca79801ecd7d8cd21dae8b745ed0f6b6..3098a6d071a77ec26a132f445ab16949e90339f2 100644 --- a/tensorflow/python/keras/_impl/keras/losses_test.py +++ b/tensorflow/python/keras/losses_test.py @@ -23,7 +23,7 @@ import shutil import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test try: diff --git a/tensorflow/python/keras/_impl/keras/metrics.py b/tensorflow/python/keras/metrics.py similarity index 72% rename from tensorflow/python/keras/_impl/keras/metrics.py rename to tensorflow/python/keras/metrics.py index 747c3e65157ded6b0d227c6d6667b9092d0eed44..e03d7dfe93585efd06f4701a8d20f61fc314d564 100644 --- a/tensorflow/python/keras/_impl/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -21,22 +21,22 @@ from __future__ import print_function import six -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.losses import binary_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import cosine_proximity -from tensorflow.python.keras._impl.keras.losses import hinge -from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.losses import logcosh -from tensorflow.python.keras._impl.keras.losses import mean_absolute_error -from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.losses import poisson -from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import squared_hinge -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.losses import binary_crossentropy +from tensorflow.python.keras.losses import categorical_crossentropy +from tensorflow.python.keras.losses import cosine_proximity +from tensorflow.python.keras.losses import hinge +from tensorflow.python.keras.losses import kullback_leibler_divergence +from tensorflow.python.keras.losses import logcosh +from tensorflow.python.keras.losses import mean_absolute_error +from tensorflow.python.keras.losses import mean_absolute_percentage_error +from tensorflow.python.keras.losses import mean_squared_error +from tensorflow.python.keras.losses import mean_squared_logarithmic_error +from tensorflow.python.keras.losses import poisson +from tensorflow.python.keras.losses import sparse_categorical_crossentropy +from tensorflow.python.keras.losses import squared_hinge +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/metrics/__init__.py b/tensorflow/python/keras/metrics/__init__.py deleted file mode 100644 index 59faf037bce0f087d244a2faaeb52713bdc3b772..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/metrics/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras built-in metrics functions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Metrics functions. -from tensorflow.python.keras._impl.keras.metrics import binary_accuracy -from tensorflow.python.keras._impl.keras.metrics import binary_crossentropy -from tensorflow.python.keras._impl.keras.metrics import categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import cosine_proximity -from tensorflow.python.keras._impl.keras.metrics import hinge -from tensorflow.python.keras._impl.keras.metrics import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_error -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.metrics import poisson -from tensorflow.python.keras._impl.keras.metrics import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import sparse_top_k_categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import squared_hinge -from tensorflow.python.keras._impl.keras.metrics import top_k_categorical_accuracy - -# Auxiliary utils. -# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.metrics import deserialize -from tensorflow.python.keras._impl.keras.metrics import serialize -from tensorflow.python.keras._impl.keras.metrics import get - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/metrics_test.py rename to tensorflow/python/keras/metrics_test.py index 819bf602566fd2737eee447cec463bfb842d1a2a..15e793f5fcf0b416978095da370fbdaabd1490a6 100644 --- a/tensorflow/python/keras/_impl/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py similarity index 83% rename from tensorflow/python/keras/_impl/keras/model_subclassing_test.py rename to tensorflow/python/keras/model_subclassing_test.py index 9efeef360cf9dbca8399eedb9c21f299a1404bef..3ac4852eff6910a9861ae959f990978cea33d595 100644 --- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -23,15 +23,15 @@ import os import numpy as np import six +from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util -from tensorflow.python.keras._impl import keras from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import data_structures from tensorflow.python.training.rmsprop import RMSPropOptimizer try: @@ -56,8 +56,8 @@ class SimpleTestModel(keras.Model): if self.use_bn: self.bn = keras.layers.BatchNormalization(axis=-1) - def call(self, inputs): - x = self.dense1(inputs) + def call(self, x): + x = self.dense1(x) if self.use_dp: x = self.dp(x) if self.use_bn: @@ -173,7 +173,7 @@ def get_nested_model_3(input_dim, num_classes): class ModelSubclassingTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_io_workflow_with_np_arrays(self): num_classes = 2 num_samples = 100 @@ -192,7 +192,7 @@ class ModelSubclassingTest(test.TestCase): model.fit(x, y, epochs=2, batch_size=32, verbose=0) _ = model.evaluate(x, y, verbose=0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_multi_io_workflow_with_np_arrays(self): num_classes = (2, 3) num_samples = 1000 @@ -251,7 +251,7 @@ class ModelSubclassingTest(test.TestCase): model.fit([x1, x2], [y1, y2], epochs=2, steps_per_epoch=10, verbose=0) _ = model.evaluate(steps=10, verbose=0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_io_workflow_with_dataset_iterators(self): num_classes = 2 num_samples = 10 @@ -325,7 +325,7 @@ class ModelSubclassingTest(test.TestCase): self.assertEqual(len(model.inputs), 2) self.assertEqual(len(model.outputs), 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_updates(self): # test that updates get run during training num_samples = 100 @@ -352,7 +352,74 @@ class ModelSubclassingTest(test.TestCase): y_new = model.predict(x) self.assertGreater(np.sum(np.abs(y_ref - y_new)), 0.1) - @test_util.run_in_graph_and_eager_modes() + def test_updates_and_losses_for_nested_models_in_subclassed_model(self): + + # Case 1: deferred-build sequential nested in subclass. + class TestModel1(keras.Model): + + def __init__(self): + super(TestModel1, self).__init__() + self.fc = keras.layers.Dense(10, input_shape=(784,), + activity_regularizer='l1') + self.bn = keras.Sequential([keras.layers.BatchNormalization(axis=1)]) + + def call(self, x): + return self.bn(self.fc(x)) + + with self.test_session(): + model = TestModel1() + + x = array_ops.ones(shape=[100, 784], dtype='float32') + model(x) + self.assertEqual(len(model.get_updates_for(x)), 2) + self.assertEqual(len(model.get_losses_for(x)), 1) + + # Case 2: placeholder-sequential nested in subclass. + class TestModel2(keras.Model): + + def __init__(self): + super(TestModel2, self).__init__() + self.fc = keras.layers.Dense(10, input_shape=(784,), + activity_regularizer='l1') + self.bn = keras.Sequential( + [keras.layers.BatchNormalization(axis=1, input_shape=(10,))]) + + def call(self, x): + return self.bn(self.fc(x)) + + with self.test_session(): + model = TestModel2() + + x = array_ops.ones(shape=[100, 784], dtype='float32') + model(x) + self.assertEqual(len(model.get_updates_for(x)), 2) + self.assertEqual(len(model.get_losses_for(x)), 1) + + # Case 3: functional-API model nested in subclass. + inputs = keras.Input((10,)) + outputs = keras.layers.BatchNormalization(axis=1)(inputs) + bn = keras.Model(inputs, outputs) + + class TestModel3(keras.Model): + + def __init__(self): + super(TestModel3, self).__init__() + self.fc = keras.layers.Dense(10, input_shape=(784,), + activity_regularizer='l1') + self.bn = bn + + def call(self, x): + return self.bn(self.fc(x)) + + with self.test_session(): + model = TestModel3() + + x = array_ops.ones(shape=[100, 784], dtype='float32') + model(x) + self.assertEqual(len(model.get_updates_for(x)), 2) + self.assertEqual(len(model.get_losses_for(x)), 1) + + @test_util.run_in_graph_and_eager_modes def test_training_and_inference_behavior(self): # test that dropout is applied in training and not inference @@ -380,7 +447,7 @@ class ModelSubclassingTest(test.TestCase): loss = model.train_on_batch(x, y) self.assertGreater(loss, 0.1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_training_methods(self): # test fit, train_on_batch # on different input types: list, dict @@ -433,14 +500,14 @@ class ModelSubclassingTest(test.TestCase): model = MultiIOTestModel(num_classes=num_classes, use_bn=True) model.predict_on_batch([x1, x2]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trainable_mutation(self): # test that you can change `trainable` on a model or layer, and that # it freezes the model state during training # TODO(fchollet): add test after we unify BN behavior in eager and symbolic. pass - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_saving(self): num_classes = (2, 3) @@ -482,7 +549,7 @@ class ModelSubclassingTest(test.TestCase): self.assertAllClose(y_ref_1, y1, atol=1e-5) self.assertAllClose(y_ref_2, y2, atol=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_summary(self): class ToString(object): @@ -508,7 +575,7 @@ class ModelSubclassingTest(test.TestCase): model.summary(print_fn=print_fn) self.assertTrue('Trainable params: 587' in print_fn.contents) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_subclass_nested_in_subclass(self): num_classes = 2 num_samples = 100 @@ -531,7 +598,7 @@ class ModelSubclassingTest(test.TestCase): self.assertEqual(len(model.trainable_weights), 6 + len(model.test_net.trainable_weights)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_graph_nested_in_subclass(self): num_classes = 2 num_samples = 100 @@ -554,7 +621,7 @@ class ModelSubclassingTest(test.TestCase): self.assertEqual(len(model.trainable_weights), 6 + len(model.test_net.trainable_weights)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_subclass_nested_in_graph(self): num_classes = 2 num_samples = 100 @@ -576,7 +643,7 @@ class ModelSubclassingTest(test.TestCase): len(model.non_trainable_weights), 4) self.assertEqual(len(model.trainable_weights), 12) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_support_for_manual_training_arg(self): # In most cases, the `training` argument is left unspecified, in which # case it defaults to value corresponding to the Model method being used @@ -612,8 +679,8 @@ class ModelSubclassingTest(test.TestCase): def __init__(self): super(Foo, self).__init__() self.isdep = keras.layers.Dense(1) - self.notdep = checkpointable.NoDependency(keras.layers.Dense(2)) - self.notdep_var = checkpointable.NoDependency( + self.notdep = data_structures.NoDependency(keras.layers.Dense(2)) + self.notdep_var = data_structures.NoDependency( resource_variable_ops.ResourceVariable(1., name='notdep_var')) m = Foo() @@ -622,6 +689,51 @@ class ModelSubclassingTest(test.TestCase): self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref) self.assertEqual('notdep_var:0', m.notdep_var.name) + def test_extra_variable(self): + + class ExtraVar(keras.Model): + + def __init__(self): + super(ExtraVar, self).__init__() + self.dense = keras.layers.Dense(1) + self.var = resource_variable_ops.ResourceVariable(1.) + self.not_trainable_var = resource_variable_ops.ResourceVariable( + 2., trainable=False) + + def call(self, inputs): + return self.dense(inputs + self.var) + + m = ExtraVar() + self.assertTrue(m.trainable) + self.assertEqual([m.dense], m.layers) + self.assertEqual([m.var, m.not_trainable_var], m.variables) + self.assertEqual([m.var], m.trainable_variables) + self.assertEqual([m.not_trainable_var], m.non_trainable_variables) + m.trainable = False + self.assertEqual([m.var, m.not_trainable_var], m.variables) + self.assertEqual([], m.trainable_variables) + self.assertEqual([m.var, m.not_trainable_var], m.non_trainable_variables) + m.trainable = True + + m(array_ops.ones([1, 1])) + + self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.variables) + self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.weights) + + self.assertEqual([m.dense.kernel, m.dense.bias, m.var, m.not_trainable_var], + m.variables) + self.assertEqual([m.dense.kernel, m.dense.bias, m.var], + m.trainable_variables) + self.assertEqual([m.not_trainable_var], m.non_trainable_variables) + + m.dense.trainable = False + self.assertEqual( + [m.var, m.dense.kernel, m.dense.bias, m.not_trainable_var], + m.variables) + self.assertEqual([m.var], m.trainable_variables) + self.assertEqual([m.dense.kernel, m.dense.bias, m.not_trainable_var], + m.non_trainable_variables) + class CustomCallModel(keras.Model): @@ -640,7 +752,7 @@ class CustomCallModel(keras.Model): class CustomCallSignatureTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_no_inputs_in_signature(self): model = CustomCallModel() first = array_ops.ones([2, 3]) @@ -654,7 +766,7 @@ class CustomCallSignatureTests(test.TestCase): output = model(first, second=second, training=False) self.assertAllClose(expected_output, self.evaluate(output)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_inputs_in_signature(self): class HasInputsAndOtherPositional(keras.Model): @@ -671,7 +783,7 @@ class CustomCallSignatureTests(test.TestCase): x1, x2 = keras.Input((1, 1)), keras.Input((1, 1)) model(x1, x2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_kwargs_in_signature(self): class HasKwargs(keras.Model): @@ -685,7 +797,7 @@ class CustomCallSignatureTests(test.TestCase): if not context.executing_eagerly(): six.assertCountEqual(self, [arg], model.inputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_args_in_signature(self): class HasArgs(keras.Model): diff --git a/tensorflow/python/keras/_impl/keras/models.py b/tensorflow/python/keras/models.py similarity index 94% rename from tensorflow/python/keras/_impl/keras/models.py rename to tensorflow/python/keras/models.py index 9602e7ba39b290f33c7ca9d0d1b5b35838667531..21217fdca14eabaa425903d5370731eb94fdeec6 100644 --- a/tensorflow/python/keras/_impl/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -19,14 +19,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.engine import saving -from tensorflow.python.keras._impl.keras.engine import sequential -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.engine.input_layer import Input -from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer -from tensorflow.python.keras._impl.keras.utils import generic_utils -from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine import saving +from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils.generic_utils import has_arg # API entries importable from `keras.models`: diff --git a/tensorflow/python/keras/models/__init__.py b/tensorflow/python/keras/models/__init__.py deleted file mode 100644 index 2fb4ac0960d38f28a1c9c897a0f1aedf57e048ac..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/models/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras models API.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.models import load_model -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.models import model_from_config -from tensorflow.python.keras._impl.keras.models import model_from_json -from tensorflow.python.keras._impl.keras.models import model_from_yaml -from tensorflow.python.keras._impl.keras.models import save_model -from tensorflow.python.keras._impl.keras.models import Sequential - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/models_test.py b/tensorflow/python/keras/models_test.py similarity index 78% rename from tensorflow/python/keras/_impl/keras/models_test.py rename to tensorflow/python/keras/models_test.py index 5978ddd987c63b9d87a31be6837172f08512ef73..ad3819e6e730b48e294b340d39fddeb6d7f2d6bf 100644 --- a/tensorflow/python/keras/_impl/keras/models_test.py +++ b/tensorflow/python/keras/models_test.py @@ -18,10 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras +from tensorflow.python.framework import test_util from tensorflow.python.platform import test +from tensorflow.python.training import adam class TestModelCloning(test.TestCase): @@ -123,5 +127,36 @@ class TestModelCloning(test.TestCase): keras.models._clone_sequential_model(seq_model, input_tensors=y) +class CheckpointingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def test_optimizer_dependency(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_shape=(4,))) + opt = adam.AdamOptimizer(0.01) + model.compile(optimizer=opt, loss='mse') + model.fit(x=np.array([[1., 2., 3., 4.]]), y=[1.], epochs=2) + save_prefix = os.path.join(self.get_temp_dir(), 'ckpt') + beta1_power, _ = opt._get_beta_accumulators() + self.evaluate(beta1_power.assign(12.)) + model.save_weights(save_prefix) + self.evaluate(beta1_power.assign(13.)) + model.load_weights(save_prefix) + self.assertEqual(12., self.evaluate(beta1_power)) + +class TestModelBackend(test.TestCase): + + def test_model_backend_float64_use_cases(self): + # Test case for GitHub issue 19318 + floatx = keras.backend.floatx() + keras.backend.set_floatx('float64') + + x = keras.Input((5,)) + y = keras.layers.Dense(1)(x) + model = keras.models.Model(x, y) + model.compile('rmsprop', 'mse') + + keras.backend.set_floatx(floatx) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/optimizers.py similarity index 94% rename from tensorflow/python/keras/_impl/keras/optimizers.py rename to tensorflow/python/keras/optimizers.py index 9f383deb725ac69bf2f17f3627010c4e1f567ef0..0b440185ca7ccfc4fadf5419e6ceb4c64a554e1d 100644 --- a/tensorflow/python/keras/_impl/keras/optimizers.py +++ b/tensorflow/python/keras/optimizers.py @@ -19,56 +19,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - import six from six.moves import zip # pylint: disable=redefined-builtin -from tensorflow.python.framework import dtypes as dtypes_module -from tensorflow.python.framework import ops -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.ops import control_flow_ops +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.ops import clip_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util.tf_export import tf_export -def clip_norm(g, c, n): - """Clip a tensor by norm. - - Arguments: - g: gradient tensor to clip. - c: clipping threshold. - n: norm of gradient tensor. - - Returns: - Clipped gradient tensor. - """ - if c > 0: - condition = n >= c - then_expression = lambda: math_ops.scalar_mul(c / n, g) - else_expression = lambda: g - - # saving the shape to avoid converting sparse tensor to dense - if isinstance(g, ops.Tensor): - g_shape = copy.copy(g.get_shape()) - elif isinstance(g, ops.IndexedSlices): - g_shape = copy.copy(g.dense_shape) - if condition.dtype != dtypes_module.bool: - condition = math_ops.cast(condition, 'bool') - g = control_flow_ops.cond(condition, then_expression, else_expression) - if isinstance(g, ops.Tensor): - g.set_shape(g_shape) - elif isinstance(g, ops.IndexedSlices): - g._dense_shape = g_shape # pylint: disable=protected-access - return g - - @tf_export('keras.optimizers.Optimizer') class Optimizer(object): """Abstract optimizer base class. @@ -90,6 +56,9 @@ class Optimizer(object): if k not in allowed_kwargs: raise TypeError('Unexpected keyword argument ' 'passed to optimizer: ' + str(k)) + # checks that clipnorm >= 0 and clipvalue >= 0 + if kwargs[k] < 0: + raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k])) self.__dict__.update(kwargs) self.updates = [] self.weights = [] @@ -118,12 +87,13 @@ class Optimizer(object): 'gradient defined (i.e. are differentiable). ' 'Common ops without gradient: ' 'K.argmax, K.round, K.eval.') - if hasattr(self, 'clipnorm') and self.clipnorm > 0: - norm = K.sqrt( - sum([math_ops.reduce_sum(math_ops.square(g)) for g in grads])) - grads = [clip_norm(g, self.clipnorm, norm) for g in grads] - if hasattr(self, 'clipvalue') and self.clipvalue > 0: - grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads] + if hasattr(self, 'clipnorm'): + grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads] + if hasattr(self, 'clipvalue'): + grads = [ + clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue) + for g in grads + ] return grads def set_weights(self, weights): @@ -718,12 +688,13 @@ class Nadam(Optimizer): return dict(list(base_config.items()) + list(config.items())) -class TFOptimizer(Optimizer): +class TFOptimizer(Optimizer, checkpointable.CheckpointableBase): """Wrapper class for native TensorFlow optimizers. """ def __init__(self, optimizer): # pylint: disable=super-init-not-called self.optimizer = optimizer + self._track_checkpointable(optimizer, name='optimizer') with K.name_scope(self.__class__.__name__): self.iterations = K.variable(0, dtype='int64', name='iterations') diff --git a/tensorflow/python/keras/optimizers/__init__.py b/tensorflow/python/keras/optimizers/__init__.py deleted file mode 100644 index 44f47bc47f4a0e31aaf2ac8f67cfdbef410d8c44..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/optimizers/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras built-in optimizers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Optimizer classes. -from tensorflow.python.keras._impl.keras.optimizers import Adadelta -from tensorflow.python.keras._impl.keras.optimizers import Adagrad -from tensorflow.python.keras._impl.keras.optimizers import Adam -from tensorflow.python.keras._impl.keras.optimizers import Adamax -from tensorflow.python.keras._impl.keras.optimizers import Nadam -from tensorflow.python.keras._impl.keras.optimizers import Optimizer -from tensorflow.python.keras._impl.keras.optimizers import RMSprop -from tensorflow.python.keras._impl.keras.optimizers import SGD - -# Auxiliary utils. -# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.optimizers import deserialize -from tensorflow.python.keras._impl.keras.optimizers import serialize -from tensorflow.python.keras._impl.keras.optimizers import get - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py similarity index 94% rename from tensorflow/python/keras/_impl/keras/optimizers_test.py rename to tensorflow/python/keras/optimizers_test.py index 57636afbf089f27c00cc56c46fdb3ea50f89cc6b..55fc3fdcf47b4e5589e2253fffdc97d33f5b481b 100644 --- a/tensorflow/python/keras/_impl/keras/optimizers_test.py +++ b/tensorflow/python/keras/optimizers_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python import keras +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test from tensorflow.python.training.adam import AdamOptimizer @@ -145,6 +145,12 @@ class KerasOptimizersTest(test.TestCase): with self.assertRaises(NotImplementedError): optimizer.from_config(None) + def test_negative_clipvalue_or_clipnorm(self): + with self.assertRaises(ValueError): + _ = keras.optimizers.SGD(lr=0.01, clipvalue=-0.5) + with self.assertRaises(ValueError): + _ = keras.optimizers.Adam(clipnorm=-2.0) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/preprocessing/__init__.py b/tensorflow/python/keras/preprocessing/__init__.py index 8fa3911a7a8833f4b296519c84662cf39ea2dc88..e6704eeaa1f953be68e7ccdbc7e8bd60c62a61d8 100644 --- a/tensorflow/python/keras/preprocessing/__init__.py +++ b/tensorflow/python/keras/preprocessing/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== """Keras data preprocessing utils.""" - from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image.py b/tensorflow/python/keras/preprocessing/image.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/preprocessing/image.py rename to tensorflow/python/keras/preprocessing/image.py index 5dfbf0fca5e15c71495e4ace1418ff0de070e86f..aa425df6a8bdb29b90a6d7000d126b771247c19f 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image.py +++ b/tensorflow/python/keras/preprocessing/image.py @@ -29,8 +29,8 @@ import re import threading import numpy as np -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils.data_utils import Sequence from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/preprocessing/image/__init__.py b/tensorflow/python/keras/preprocessing/image/__init__.py deleted file mode 100644 index 6aba5fc8252e1acf604a89a4e66c2a7db080aa73..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/preprocessing/image/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras data preprocessing utils for image data.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.preprocessing.image import apply_transform -from tensorflow.python.keras._impl.keras.preprocessing.image import array_to_img -from tensorflow.python.keras._impl.keras.preprocessing.image import DirectoryIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import flip_axis -from tensorflow.python.keras._impl.keras.preprocessing.image import ImageDataGenerator -from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array -from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator -from tensorflow.python.keras._impl.keras.preprocessing.image import load_img -from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import random_brightness -from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_zoom - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py b/tensorflow/python/keras/preprocessing/image_test.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/preprocessing/image_test.py rename to tensorflow/python/keras/preprocessing/image_test.py index d2e8ac10ae5399db7b67c00a4a2b0adcdace046f..275808a6155b26159259584653cb48697af9f318 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py +++ b/tensorflow/python/keras/preprocessing/image_test.py @@ -24,7 +24,7 @@ import tempfile import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test try: diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py b/tensorflow/python/keras/preprocessing/sequence.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/preprocessing/sequence.py rename to tensorflow/python/keras/preprocessing/sequence.py index 49bb0b957a9422e3c1e862b4c8e8d6d6572b2480..e0924f837a79dbdf31bee09667b43f70a1273b4b 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py +++ b/tensorflow/python/keras/preprocessing/sequence.py @@ -23,7 +23,7 @@ import random import numpy as np from six.moves import range # pylint: disable=redefined-builtin -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence +from tensorflow.python.keras.utils.data_utils import Sequence from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/preprocessing/sequence/__init__.py b/tensorflow/python/keras/preprocessing/sequence/__init__.py deleted file mode 100644 index b7a7149cc40654c878e3c0db1fc78d8912abf498..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/preprocessing/sequence/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras data preprocessing utils for sequence data.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table -from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences -from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams -from tensorflow.python.keras._impl.keras.preprocessing.sequence import TimeseriesGenerator - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py b/tensorflow/python/keras/preprocessing/sequence_test.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py rename to tensorflow/python/keras/preprocessing/sequence_test.py index 0e7045f517d44e8d73b08bac7ce499f79d2bf80e..ab6a09106b5f3c8bc340a25ebe3fc82be3f71cd2 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py +++ b/tensorflow/python/keras/preprocessing/sequence_test.py @@ -22,7 +22,7 @@ from math import ceil import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text.py b/tensorflow/python/keras/preprocessing/text.py similarity index 100% rename from tensorflow/python/keras/_impl/keras/preprocessing/text.py rename to tensorflow/python/keras/preprocessing/text.py diff --git a/tensorflow/python/keras/preprocessing/text/__init__.py b/tensorflow/python/keras/preprocessing/text/__init__.py deleted file mode 100644 index 000ad68a0c01e9067f8852836ba5d502deb3fcd4..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/preprocessing/text/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras data preprocessing utils for text data.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.preprocessing.text import hashing_trick -from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot -from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence -from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py b/tensorflow/python/keras/preprocessing/text_test.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/preprocessing/text_test.py rename to tensorflow/python/keras/preprocessing/text_test.py index 6cdc0a70cca86392c0de8d00e58be0c0ecfd6519..566fd3bb1a36392fdf30da4f4a46dd076acdd1e0 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py +++ b/tensorflow/python/keras/preprocessing/text_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/regularizers.py b/tensorflow/python/keras/regularizers.py similarity index 92% rename from tensorflow/python/keras/_impl/keras/regularizers.py rename to tensorflow/python/keras/regularizers.py index 74c37d370ea630ca3c3e5e0945828f63928572e1..28b6ad4c65a2919323b81c89de6e5a3d4b5d3ff3 100644 --- a/tensorflow/python/keras/_impl/keras/regularizers.py +++ b/tensorflow/python/keras/regularizers.py @@ -20,9 +20,9 @@ from __future__ import print_function import six -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/regularizers/__init__.py b/tensorflow/python/keras/regularizers/__init__.py deleted file mode 100644 index 3e707ccab577b5e28febd83d91f84d7b1c0d5d82..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/regularizers/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras built-in regularizers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Regularizer functions / callable classes. -from tensorflow.python.keras._impl.keras.regularizers import L1L2 -from tensorflow.python.keras._impl.keras.regularizers import Regularizer - -# Functional interface. -# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.regularizers import l1 -from tensorflow.python.keras._impl.keras.regularizers import l2 -from tensorflow.python.keras._impl.keras.regularizers import l1_l2 - -# Auxiliary utils. -from tensorflow.python.keras._impl.keras.regularizers import deserialize -from tensorflow.python.keras._impl.keras.regularizers import serialize -from tensorflow.python.keras._impl.keras.regularizers import get - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/regularizers_test.py b/tensorflow/python/keras/regularizers_test.py similarity index 96% rename from tensorflow/python/keras/_impl/keras/regularizers_test.py rename to tensorflow/python/keras/regularizers_test.py index c4f04833ba51d85c6e174cca0f546133253bddee..e2075785d8061a44da1fbf1b435a15ec6a652e11 100644 --- a/tensorflow/python/keras/_impl/keras/regularizers_test.py +++ b/tensorflow/python/keras/regularizers_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python import keras +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py similarity index 74% rename from tensorflow/python/keras/_impl/keras/testing_utils.py rename to tensorflow/python/keras/testing_utils.py index b8172064c37e5f3fa1cddd68b31e0c201896483c..17aba7d86c236d9bb30d3a3376b3aac40b69e77d 100644 --- a/tensorflow/python/keras/_impl/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -18,10 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import OrderedDict import numpy as np +from tensorflow.python import keras from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl import keras from tensorflow.python.training.rmsprop import RMSPropOptimizer from tensorflow.python.util import tf_inspect @@ -183,3 +184,76 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, # for further checks in the caller function return actual_output + + +def _combine_named_parameters(**kwargs): + """Generate combinations based on its keyword arguments. + + Two sets of returned combinations can be concatenated using +. Their product + can be computed using `times()`. + + Args: + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. + + Returns: + a list of dictionaries for each combination. Keys in the dictionaries are + the keyword argument names. Each key has one value - one of the + corresponding keyword argument values. + """ + if not kwargs: + return [OrderedDict()] + + sort_by_key = lambda k: k[0][0] + kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) + first = list(kwargs.items())[0] + + rest = dict(list(kwargs.items())[1:]) + rest_combined = _combine_named_parameters(**rest) + + key = first[0] + values = first[1] + if not isinstance(values, list): + values = [values] + + combinations = [ + OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) + for v in values + for combined in rest_combined + ] + return combinations + + +def generate_combinations_with_testcase_name(**kwargs): + """Generate combinations based on its keyword arguments using combine(). + + This function calls combine() and appends a testcase name to the list of + dictionaries returned. The 'testcase_name' key is a required for named + parameterized tests. + + Args: + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. + + Returns: + a list of dictionaries for each combination. Keys in the dictionaries are + the keyword argument names. Each key has one value - one of the + corresponding keyword argument values. + """ + combinations = _combine_named_parameters(**kwargs) + named_combinations = [] + for combination in combinations: + assert isinstance(combination, OrderedDict) + name = ''.join([ + '_{}_{}'.format( + ''.join(filter(str.isalnum, key)), + ''.join(filter(str.isalnum, str(value)))) + for key, value in combination.items() + ]) + named_combinations.append( + OrderedDict( + list(combination.items()) + [('testcase_name', + '_test{}'.format(name))])) + + return named_combinations + diff --git a/tensorflow/python/keras/utils/__init__.py b/tensorflow/python/keras/utils/__init__.py index 9d924c8c905d69d7081d63983bc4a898f7c71033..69337b6a8d52abd4caf2ada518fde51c407f8103 100644 --- a/tensorflow/python/keras/utils/__init__.py +++ b/tensorflow/python/keras/utils/__init__.py @@ -18,23 +18,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file -from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer -from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope -from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix -from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.python.keras._impl.keras.utils.multi_gpu_utils import multi_gpu_model -from tensorflow.python.keras._impl.keras.utils.np_utils import normalize -from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model +from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer +from tensorflow.python.keras.utils.data_utils import Sequence +from tensorflow.python.keras.utils.data_utils import SequenceEnqueuer +from tensorflow.python.keras.utils.generic_utils import custom_object_scope +from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import get_custom_objects +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.io_utils import HDF5Matrix +from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model +from tensorflow.python.keras.utils.np_utils import normalize +from tensorflow.python.keras.utils.np_utils import to_categorical +from tensorflow.python.keras.utils.vis_utils import plot_model del absolute_import del division diff --git a/tensorflow/python/keras/_impl/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/utils/conv_utils.py rename to tensorflow/python/keras/utils/conv_utils.py index 8882a3a46bcb9de7283a67f001e67ed8644a0cf7..5419e7ae0583abcf2e09d0bcc5b9526f2a9969bf 100644 --- a/tensorflow/python/keras/_impl/keras/utils/conv_utils.py +++ b/tensorflow/python/keras/utils/conv_utils.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import range # pylint: disable=redefined-builtin -from tensorflow.python.keras._impl.keras import backend +from tensorflow.python.keras import backend def convert_data_format(data_format, ndim): diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/utils/data_utils.py rename to tensorflow/python/keras/utils/data_utils.py index 4c49544c6a63c4e5a0b79d31b074ad352c512bfa..c1ee34ae467b7037bafa53ea1a9b4b8596917df4 100644 --- a/tensorflow/python/keras/_impl/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -39,7 +39,7 @@ from six.moves.urllib.error import HTTPError from six.moves.urllib.error import URLError from six.moves.urllib.request import urlopen -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.util.tf_export import tf_export @@ -324,12 +324,12 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535): class Sequence(object): """Base object for fitting to a sequence of data, such as a dataset. - Every `Sequence` must implements the `__getitem__` and the `__len__` methods. + Every `Sequence` must implement the `__getitem__` and the `__len__` methods. If you want to modify your dataset between epochs you may implement `on_epoch_end`. The method `__getitem__` should return a complete batch. - # Notes + Notes: `Sequence` are a safer way to do multiprocessing. This structure guarantees that the network will only train once diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py b/tensorflow/python/keras/utils/data_utils_test.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/utils/data_utils_test.py rename to tensorflow/python/keras/utils/data_utils_test.py index 677e98e871d4a148b13c1aa22696917ed8dc90f9..395df7e0e786d510e785c3ed099905a91e09a149 100644 --- a/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py +++ b/tensorflow/python/keras/utils/data_utils_test.py @@ -29,7 +29,7 @@ import numpy as np from six.moves.urllib.parse import urljoin from six.moves.urllib.request import pathname2url -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py similarity index 100% rename from tensorflow/python/keras/_impl/keras/utils/generic_utils.py rename to tensorflow/python/keras/utils/generic_utils.py diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/utils/generic_utils_test.py rename to tensorflow/python/keras/utils/generic_utils_test.py index d57692f4f41753fc38ead2ace7e989b499bc23ff..87bc19eb37d15d35bb8ad0f5d086404f9c4f55ca 100644 --- a/tensorflow/python/keras/_impl/keras/utils/generic_utils_test.py +++ b/tensorflow/python/keras/utils/generic_utils_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/utils/io_utils.py b/tensorflow/python/keras/utils/io_utils.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/utils/io_utils.py rename to tensorflow/python/keras/utils/io_utils.py index f82e3277de70a631c93f0ef3c240f41ddb3390a7..62674a9c77fc410a551d2ac79c22ecf959b16fc3 100644 --- a/tensorflow/python/keras/_impl/keras/utils/io_utils.py +++ b/tensorflow/python/keras/utils/io_utils.py @@ -102,13 +102,12 @@ class HDF5Matrix(object): idx = (self.start + key).tolist() else: raise IndexError - elif isinstance(key, list): + else: + # Assume list/iterable if max(key) + self.start < self.end: idx = [x + self.start for x in key] else: raise IndexError - else: - raise IndexError if self.normalizer is not None: return self.normalizer(self.data[idx]) else: diff --git a/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py b/tensorflow/python/keras/utils/io_utils_test.py similarity index 79% rename from tensorflow/python/keras/_impl/keras/utils/io_utils_test.py rename to tensorflow/python/keras/utils/io_utils_test.py index cfeba188d3cadfa08efbd07fcbd46776b691e06f..81bb661edd8d815f8565285ad5dc8126f4f52e98 100644 --- a/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py +++ b/tensorflow/python/keras/utils/io_utils_test.py @@ -22,8 +22,9 @@ import os import shutil import numpy as np +import six -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test try: @@ -95,6 +96,29 @@ class TestIOUtils(test.TestCase): self.assertEqual(out_eval.shape, ()) self.assertGreater(out_eval, 0) + # test slicing for shortened array + self.assertEqual(len(x_train[0:]), len(x_train)) + + # test __getitem__ invalid use cases + with self.assertRaises(IndexError): + _ = x_train[1000] + with self.assertRaises(IndexError): + _ = x_train[1000: 1001] + with self.assertRaises(IndexError): + _ = x_train[[1000, 1001]] + with self.assertRaises(IndexError): + _ = x_train[six.moves.range(1000, 1001)] + with self.assertRaises(IndexError): + _ = x_train[np.array([1000])] + with self.assertRaises(TypeError): + _ = x_train[None] + + # test normalizer + normalizer = lambda x: x + 1 + normalized_x_train = keras.utils.io_utils.HDF5Matrix( + h5_path, 'my_data', start=0, end=150, normalizer=normalizer) + self.assertAllClose(normalized_x_train[0][0], x_train[0][0] + 1) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py similarity index 73% rename from tensorflow/python/keras/_impl/keras/utils/layer_utils.py rename to tensorflow/python/keras/utils/layer_utils.py index 902972ecbb8fd69a9252b7e19e32bee5e33e4f97..1f28c59ea41a96461a7faba2c41f5e65e6af0180 100644 --- a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -21,11 +21,52 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.utils.conv_utils import convert_kernel +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils.conv_utils import convert_kernel from tensorflow.python.util.tf_export import tf_export +def get_source_inputs(tensor, layer=None, node_index=None): + """Returns the list of input tensors necessary to compute `tensor`. + + Output will always be a list of tensors + (potentially with 1 element). + + Arguments: + tensor: The tensor to start from. + layer: Origin layer of the tensor. Will be + determined via tensor._keras_history if not provided. + node_index: Origin node index of the tensor. + + Returns: + List of input tensors. + """ + if not hasattr(tensor, '_keras_history'): + return tensor + + if layer is None or node_index: + layer, node_index, _ = tensor._keras_history + if not layer._inbound_nodes: + return [tensor] + else: + node = layer._inbound_nodes[node_index] + if not node.inbound_layers: + # Reached an Input layer, stop recursion. + return node.input_tensors + else: + source_tensors = [] + for i in range(len(node.inbound_layers)): + x = node.input_tensors[i] + layer = node.inbound_layers[i] + node_index = node.node_indices[i] + previous_sources = get_source_inputs(x, layer, node_index) + # Avoid input redundancy. + for x in previous_sources: + if x not in source_tensors: + source_tensors.append(x) + return source_tensors + + def count_params(weights): """Count the total number of scalars composing the weights. @@ -201,6 +242,61 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): print_fn('_' * line_length) +def gather_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected trainable weights/variables. + """ + if not trainable: + return [] + weights = [] + for layer in sub_layers: + weights += layer.trainable_weights + trainable_extra_variables = [ + v for v in extra_variables if v.trainable] + return weights + trainable_extra_variables + + +def gather_non_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the non-trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected non-trainable weights/variables. + """ + trainable_extra_variables = [] + non_trainable_extra_variables = [] + for v in extra_variables: + if v.trainable: + trainable_extra_variables.append(v) + else: + non_trainable_extra_variables.append(v) + weights = [] + for layer in sub_layers: + weights += layer.non_trainable_weights + if not trainable: + trainable_weights = [] + for layer in sub_layers: + trainable_weights += layer.trainable_weights + return (trainable_weights + trainable_extra_variables + + weights + non_trainable_extra_variables) + return weights + non_trainable_extra_variables + + @tf_export('keras.utils.convert_all_kernels_in_model') def convert_all_kernels_in_model(model): """Converts all convolution kernels in a model from Theano to TensorFlow. diff --git a/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py similarity index 95% rename from tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py rename to tensorflow/python/keras/utils/multi_gpu_utils.py index 48c25377270ad68a23832736a5f6499999ead14f..e1c49bc85221aa94241ed746c2063aadf881f3cd 100644 --- a/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py +++ b/tensorflow/python/keras/utils/multi_gpu_utils.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.engine.training import Model +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine.training import Model from tensorflow.python.ops import array_ops from tensorflow.python.util.tf_export import tf_export @@ -150,8 +150,8 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False): ValueError: if the `gpus` argument does not match available devices. """ # pylint: disable=g-import-not-at-top - from tensorflow.python.keras._impl.keras.layers.core import Lambda - from tensorflow.python.keras._impl.keras.layers.merge import concatenate + from tensorflow.python.keras.layers.core import Lambda + from tensorflow.python.keras.layers.merge import concatenate if isinstance(gpus, (list, tuple)): if len(gpus) <= 1: @@ -196,7 +196,7 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False): batch_size = shape[:1] input_shape = shape[1:] step = batch_size // parts - if i == num_gpus - 1: + if i == parts - 1: size = batch_size - step * i else: size = step @@ -207,7 +207,7 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False): # Relocate the model definition under CPU device scope if needed if cpu_relocation: - from tensorflow.python.keras._impl.keras.models import clone_model # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.models import clone_model # pylint: disable=g-import-not-at-top with ops.device('/cpu:0'): model = clone_model(model) diff --git a/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py similarity index 99% rename from tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils_test.py rename to tensorflow/python/keras/utils/multi_gpu_utils_test.py index 0a38d6b5228fe791ce14adc7e37e0b7a6926fadf..77792d14f53d009c0bfc17273c034c37039106bf 100644 --- a/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils_test.py +++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.python import data -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py similarity index 100% rename from tensorflow/python/keras/_impl/keras/utils/np_utils.py rename to tensorflow/python/keras/utils/np_utils.py diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils_test.py b/tensorflow/python/keras/utils/np_utils_test.py similarity index 97% rename from tensorflow/python/keras/_impl/keras/utils/np_utils_test.py rename to tensorflow/python/keras/utils/np_utils_test.py index 1e974c2ef2aee3b6a83ad777673505f8c75b2b58..d77e76ff3ecb70bb19c0485e6e32940554e893d5 100644 --- a/tensorflow/python/keras/_impl/keras/utils/np_utils_test.py +++ b/tensorflow/python/keras/utils/np_utils_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras +from tensorflow.python import keras from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/_impl/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py similarity index 100% rename from tensorflow/python/keras/_impl/keras/utils/tf_utils.py rename to tensorflow/python/keras/utils/tf_utils.py diff --git a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py b/tensorflow/python/keras/utils/vis_utils.py similarity index 96% rename from tensorflow/python/keras/_impl/keras/utils/vis_utils.py rename to tensorflow/python/keras/utils/vis_utils.py index 4761cece82c727e4962d0374f8efb80dfaeac3c6..7a454ac8314acdfa3c3e61c080acdd9efdf3acdc 100644 --- a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/utils/vis_utils.py @@ -65,8 +65,8 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): Returns: A `pydot.Dot` instance representing the Keras model. """ - from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper - from tensorflow.python.keras._impl.keras.models import Sequential + from tensorflow.python.keras.layers.wrappers import Wrapper + from tensorflow.python.keras.models import Sequential _check_pydot() dot = pydot.Dot() @@ -77,7 +77,6 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): if isinstance(model, Sequential): if not model.built: model.build() - model = model.model layers = model.layers # Create graph nodes. diff --git a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py b/tensorflow/python/keras/wrappers/scikit_learn.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py rename to tensorflow/python/keras/wrappers/scikit_learn.py index 2884dc84cc5d99511947e6f0f97b0bf8a505221f..4462d94ecdb10c6f7306de1f552151e209394bac 100644 --- a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py +++ b/tensorflow/python/keras/wrappers/scikit_learn.py @@ -23,9 +23,9 @@ import types import numpy as np -from tensorflow.python.keras._impl.keras.models import Sequential -from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg -from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical +from tensorflow.python.keras.models import Sequential +from tensorflow.python.keras.utils.generic_utils import has_arg +from tensorflow.python.keras.utils.np_utils import to_categorical from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/wrappers/scikit_learn/__init__.py b/tensorflow/python/keras/wrappers/scikit_learn/__init__.py deleted file mode 100644 index a46f859273ea0117e29a403057f9f81bc758dd52..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/wrappers/scikit_learn/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras scikit-learn API wrapper.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasClassifier -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasRegressor - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn_test.py b/tensorflow/python/keras/wrappers/scikit_learn_test.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/wrappers/scikit_learn_test.py rename to tensorflow/python/keras/wrappers/scikit_learn_test.py index b20a84ee88b5b2b70ca2f718fbe86ffd6e949461..c322efdedf10e961f7590591ba0048e42492aaff 100644 --- a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn_test.py +++ b/tensorflow/python/keras/wrappers/scikit_learn_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras._impl import keras -from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python import keras +from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test INPUT_DIM = 5 diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 83b353600ad47c59759fff81c64fe5e8c2ed7926..6bfd1936e38da0b03bb6a9baba7d899957283349 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "sycl_py_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") # CPU only tests should use tf_py_test, GPU tests use cuda_py_test # Please avoid the py_tests and cuda_py_tests (plural) while we @@ -892,6 +893,7 @@ tf_py_test( "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", + "//tensorflow/python:sparse_grad", "//tensorflow/python:sparse_ops", ], ) @@ -2334,6 +2336,9 @@ cuda_py_test( "//tensorflow/python:nn_ops", ], shard_count = 2, + tags = [ + "no_gpu", # Flaky: b/80127739 + ], ) cuda_py_test( @@ -2750,6 +2755,7 @@ cuda_py_test( "//tensorflow/python:embedding_ops", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:init_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:partitioned_variables", @@ -3026,3 +3032,79 @@ tf_py_test( "//tensorflow/python/eager:tape", ], ) + +# Custom op tests +tf_custom_op_library( + name = "ackermann_op.so", + srcs = ["ackermann_op.cc"], +) + +tf_py_test( + name = "ackermann_test", + size = "small", + srcs = ["ackermann_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:platform", + ], + data = [":ackermann_op.so"], + tags = ["no_pip"], +) + +tf_custom_op_library( + name = "duplicate_op.so", + srcs = ["duplicate_op.cc"], +) + +tf_py_test( + name = "duplicate_op_test", + size = "small", + srcs = ["duplicate_op_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + ], + data = [":duplicate_op.so"], + tags = ["no_pip"], +) + +tf_custom_op_library( + name = "invalid_op.so", + srcs = ["invalid_op.cc"], +) + +tf_py_test( + name = "invalid_op_test", + size = "small", + srcs = ["invalid_op_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:platform", + ], + data = [":invalid_op.so"], + tags = ["no_pip"], +) + +tf_py_test( + name = "cond_v2_test", + size = "small", + srcs = ["cond_v2_test.py"], + additional_deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:cond_v2", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:training", + ], + grpc_enabled = True, +) diff --git a/tensorflow/python/kernel_tests/accumulate_n_eager_test.py b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py index dc11b7deceb9040584aca1f629f4d003aef39428..5f516f2c7e6af2d5b77deeebf1d71d3d0fa6be39 100644 --- a/tensorflow/python/kernel_tests/accumulate_n_eager_test.py +++ b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py @@ -43,10 +43,9 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): np.random.seed(12345) x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] tf_x = ops.convert_n_to_tensor(x) - with self.test_session(use_gpu=True): - self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).numpy()) - self.assertAllClose(x[0] * 5, - math_ops.accumulate_n([tf_x[0]] * 5).numpy()) + self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x)) + self.assertAllClose(x[0] * 5, + math_ops.accumulate_n([tf_x[0]] * 5)) def testGrad(self): np.random.seed(42) diff --git a/tensorflow/user_ops/ackermann_op.cc b/tensorflow/python/kernel_tests/ackermann_op.cc similarity index 100% rename from tensorflow/user_ops/ackermann_op.cc rename to tensorflow/python/kernel_tests/ackermann_op.cc diff --git a/tensorflow/user_ops/ackermann_test.py b/tensorflow/python/kernel_tests/ackermann_test.py similarity index 76% rename from tensorflow/user_ops/ackermann_test.py rename to tensorflow/python/kernel_tests/ackermann_test.py index 257de498088d1f8a71898e490b8951beb7975b7a..5e0d87c783109b5ec8055e4c975157f3da07bcd4 100644 --- a/tensorflow/user_ops/ackermann_test.py +++ b/tensorflow/python/kernel_tests/ackermann_test.py @@ -17,17 +17,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os.path +import os -import tensorflow as tf +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test -class AckermannTest(tf.test.TestCase): +class AckermannTest(test.TestCase): def testBasic(self): - library_filename = os.path.join(tf.resource_loader.get_data_files_path(), + library_filename = os.path.join(resource_loader.get_data_files_path(), 'ackermann_op.so') - ackermann = tf.load_op_library(library_filename) + ackermann = load_library.load_op_library(library_filename) self.assertEqual(len(ackermann.OP_LIST.op), 1) self.assertEqual(ackermann.OP_LIST.op[0].name, 'Ackermann') @@ -37,4 +39,4 @@ class AckermannTest(tf.test.TestCase): if __name__ == '__main__': - tf.test.main() + test.main() diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 0c29714522251ede4be6627ef48de9f370a26c7c..40567571e6d259eff3f013c67d1d1f9504fcb9e4 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -998,17 +998,15 @@ class SliceAssignTest(test_util.TensorFlowTestCase): v = resource_variable_ops.ResourceVariable(init_val) with self.test_session() as sess: sess.run(v.initializer) - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "l-value dtype int32 does not match r-value dtype int64"): + with self.assertRaises(ValueError): sess.run(v[:].assign(too_large_val)) - with self.assertRaises(errors.InvalidArgumentError): + with self.assertRaises(ValueError): sess.run(v[:].assign(too_small_val)) class ShapeSizeRankTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDenseShape(self): t_value = [[0, 42], [24, 0]] self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(t_value))) @@ -1020,7 +1018,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase): self.assertEqual(4, self.evaluate(array_ops.size(t))) self.assertEqual(2, self.evaluate(array_ops.rank(t))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSparseShape(self): sp_value = sparse_tensor.SparseTensorValue( indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2)) @@ -1033,7 +1031,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase): self.assertEqual(4, self.evaluate(array_ops.size(sp))) self.assertEqual(2, self.evaluate(array_ops.rank(sp))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSizeDtype(self): tensor = [1] self.assertEqual(dtypes.int32, self.evaluate(array_ops.size(tensor)).dtype) @@ -1125,7 +1123,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): class ConcatSliceResourceTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConcatSlice(self): r1 = test_ops.stub_resource_handle_op(container="a", shared_name="b") r2 = test_ops.stub_resource_handle_op(container="a", shared_name="c") diff --git a/tensorflow/python/kernel_tests/as_string_op_test.py b/tensorflow/python/kernel_tests/as_string_op_test.py index 9d54add2644fb9ba6931357dbaa96368952b7486..94ed8ebd31f5874024bb6b0988073ece15d39d87 100644 --- a/tensorflow/python/kernel_tests/as_string_op_test.py +++ b/tensorflow/python/kernel_tests/as_string_op_test.py @@ -130,6 +130,16 @@ class AsStringOpTest(test.TestCase): result = output.eval(feed_dict={input_: int_inputs_}) self.assertAllEqual(s(result), ["%d" % x for x in int_inputs_]) + def testHalfInt(self): + s = lambda strs: [x.decode("ascii") for x in strs] + + with self.test_session(): + input_ = array_ops.placeholder(dtypes.int16) + int_inputs_ = [np.iinfo(np.int16).min, np.iinfo(np.int16).max] + output = string_ops.as_string(input_) + result = output.eval(feed_dict={input_: int_inputs_}) + self.assertAllEqual(s(result), ["%d" % x for x in int_inputs_]) + def testBool(self): bool_inputs_ = [False, True] s = lambda strs: [x.decode("ascii") for x in strs] diff --git a/tensorflow/python/kernel_tests/atrous_convolution_test.py b/tensorflow/python/kernel_tests/atrous_convolution_test.py index 0ef08581c9f931b991ef0c1218dc503345e248c2..b98e5fd3866cde007c6c00ae0cf04b1f1c46c6f2 100644 --- a/tensorflow/python/kernel_tests/atrous_convolution_test.py +++ b/tensorflow/python/kernel_tests/atrous_convolution_test.py @@ -124,7 +124,7 @@ class AtrousConvolutionTest(test.TestCase): x, w, "VALID", dilation_rate=[2, 2], data_format="NCHW") self.assertEqual(y.shape.as_list(), [1, 20, None, None]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolution2D(self): with self._delay_checks() as add_check: for padding in ["SAME", "VALID"]: @@ -139,7 +139,7 @@ class AtrousConvolutionTest(test.TestCase): dilation_rate=dilation_rate, ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolution3D(self): with self._delay_checks() as add_check: for padding in ["SAME", "VALID"]: @@ -158,7 +158,7 @@ class AtrousConvolutionTest(test.TestCase): dilation_rate=dilation_rate, ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolution1D(self): with self._delay_checks() as add_check: for padding in ["SAME", "VALID"]: @@ -173,7 +173,7 @@ class AtrousConvolutionTest(test.TestCase): dilation_rate=[rate], ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolutionNC(self): if test.is_gpu_available(cuda_only=True): # "NCW" and "NCHW" formats are currently supported only on CUDA. @@ -197,7 +197,7 @@ class AtrousConvolutionTest(test.TestCase): data_format="NCHW", ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousSequence(self): """Tests optimization of sequence of atrous convolutions. diff --git a/tensorflow/python/kernel_tests/betainc_op_test.py b/tensorflow/python/kernel_tests/betainc_op_test.py index 08b03f851803a34dd050721e47471bafd1cd6cac..16fdedac4136d7e53eb66ba060a92b9fd7d58307 100644 --- a/tensorflow/python/kernel_tests/betainc_op_test.py +++ b/tensorflow/python/kernel_tests/betainc_op_test.py @@ -172,7 +172,7 @@ class BetaincTest(test.TestCase): tf_gout_t = math_ops.betainc(tf_ga_s, tf_gb_s, tf_gx_s) err = gradient_checker.compute_gradient_error( [tf_gx_s], [gx_s.shape], tf_gout_t, gx_s.shape) - print("betainc gradient err = %g " % err) + tf_logging.info("betainc gradient err = %g " % err) self.assertLess(err, err_tolerance) # Test broadcast gradient @@ -181,7 +181,7 @@ class BetaincTest(test.TestCase): tf_gout_t = math_ops.betainc(tf_ga_s, tf_gb_s, tf_gx_s) err = gradient_checker.compute_gradient_error( [tf_gx_s], [()], tf_gout_t, ga_s.shape) - print("betainc gradient err = %g " % err) + tf_logging.info("betainc gradient err = %g " % err) self.assertLess(err, err_tolerance) diff --git a/tensorflow/python/kernel_tests/boosted_trees/BUILD b/tensorflow/python/kernel_tests/boosted_trees/BUILD index 30e6289420b36a75589ef25150480e48f8245ec2..4f92ab0795d231f973f988d9c5b8b39166357c4c 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/BUILD +++ b/tensorflow/python/kernel_tests/boosted_trees/BUILD @@ -52,7 +52,7 @@ tf_py_test( tf_py_test( name = "stats_ops_test", - size = "small", + size = "medium", srcs = ["stats_ops_test.py"], additional_deps = [ "//tensorflow/python:boosted_trees_ops", diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py index 54f33f336015cc9cb50658941b8e157cc1b94df9..4e31b1ea2a796a2e83696d278cf1b4784d177150 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py @@ -792,6 +792,28 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase): class PredictionOpsTest(test_util.TensorFlowTestCase): """Tests prediction ops for inference.""" + def testPredictionOnEmptyEnsemble(self): + """Tests that prediction on a empty ensemble does not fail.""" + with self.test_session() as session: + # Create an empty ensemble. + tree_ensemble = boosted_trees_ops.TreeEnsemble( + 'ensemble', serialized_proto='') + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + feature_0_values = [36, 32] + feature_1_values = [11, 27] + expected_logits = [[0.0], [0.0]] + + # Prediction should work fine. + predict_op = boosted_trees_ops.predict( + tree_ensemble_handle, + bucketized_features=[feature_0_values, feature_1_values], + logits_dimension=1) + + logits = session.run(predict_op) + self.assertAllClose(expected_logits, logits) + def testPredictionMultipleTree(self): """Tests the predictions work when we have multiple trees.""" with self.test_session() as session: @@ -888,12 +910,12 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): feature_1_values = [11, 27] # Example 1: tree 0: 1.14, tree 1: 5.0, tree 2: 5.0 = > - # logit = 0.1*5.0+0.2*5.0+1*5 + # logit = 0.1*1.14+0.2*5.0+1*5 # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7 = > # logit= 0.1*1.14+0.2*7.0-1*7.0 expected_logits = [[6.114], [-5.486]] - # Do with parallelization, e.g. EVAL + # Prediction should work fine. predict_op = boosted_trees_ops.predict( tree_ensemble_handle, bucketized_features=[feature_0_values, feature_1_values], @@ -902,14 +924,147 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): logits = session.run(predict_op) self.assertAllClose(expected_logits, logits) - # Do without parallelization, e.g. INFER - the result is the same - predict_op = boosted_trees_ops.predict( + +class FeatureContribsOpsTest(test_util.TensorFlowTestCase): + """Tests feature contribs ops for model understanding.""" + + def testContribsMultipleTree(self): + """Tests that the contribs work when we have multiple trees.""" + with self.test_session() as session: + tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() + text_format.Merge( + """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 28 + left_id: 1 + right_id: 2 + } + metadata { + gain: 7.62 + original_leaf: {scalar: 2.1} + } + } + nodes { + leaf { + scalar: 1.14 + } + } + nodes { + leaf { + scalar: 8.79 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 26 + left_id: 1 + right_id: 2 + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 50 + left_id: 3 + right_id: 4 + } + metadata { + original_leaf: {scalar: 5.5} + } + } + nodes { + leaf { + scalar: 7.0 + } + } + nodes { + leaf { + scalar: 5.0 + } + } + nodes { + leaf { + scalar: 6.0 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 0 + threshold: 34 + left_id: 1 + right_id: 2 + } + } + nodes { + leaf { + scalar: -7.0 + } + } + nodes { + leaf { + scalar: 5.0 + } + } + } + tree_weights: 0.1 + tree_weights: 0.2 + tree_weights: 1.0 + tree_metadata: { + num_layers_grown: 1} + tree_metadata: { + num_layers_grown: 2} + tree_metadata: { + num_layers_grown: 1} + """, tree_ensemble_config) + + tree_ensemble = boosted_trees_ops.TreeEnsemble( + 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + feature_0_values = [36, 32] + feature_1_values = [13, -29] # Unused. Feature is not in above ensemble. + feature_2_values = [11, 27] + + # Expected logits are computed by traversing the logit path and + # subtracting child logits from parent logits. + bias = 2.1 * 0.1 # Root node of tree_0. + expected_feature_ids = ((2, 2, 0, 0), (2, 2, 0)) + # example_0 : (bias, 0.1 * 1.14, 0.2 * 5.5 + .114, 0.2 * 5. + .114, + # 1.0 * 5.0 + 0.2 * 5. + .114) + # example_1 : (bias, 0.1 * 1.14, 0.2 * 7 + .114, + # 1.0 * -7. + 0.2 * 7 + .114) + expected_logits_paths = ((bias, 0.114, 1.214, 1.114, 6.114), + (bias, 0.114, 1.514, -5.486)) + + bucketized_features = [ + feature_0_values, feature_1_values, feature_2_values + ] + + debug_op = boosted_trees_ops.example_debug_outputs( tree_ensemble_handle, - bucketized_features=[feature_0_values, feature_1_values], + bucketized_features=bucketized_features, logits_dimension=1) - logits = session.run(predict_op) - self.assertAllClose(expected_logits, logits) + serialized_examples_debug_outputs = session.run(debug_op) + feature_ids = [] + logits_paths = [] + for example in serialized_examples_debug_outputs: + example_debug_outputs = boosted_trees_pb2.DebugOutput() + example_debug_outputs.ParseFromString(example) + feature_ids.append(example_debug_outputs.feature_ids) + logits_paths.append(example_debug_outputs.logits_path) + + self.assertAllClose(feature_ids, expected_feature_ids) + self.assertAllClose(logits_paths, expected_logits_paths) if __name__ == '__main__': diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py index 5cceb98cff26ec11137c3f8ff73e8d2e05d009e1..568e695fd590a6fd6e915acb493adb8926e92eaf 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py @@ -17,7 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import boosted_trees_ops from tensorflow.python.platform import googletest @@ -388,6 +391,41 @@ class StatsOpsTest(test_util.TensorFlowTestCase): ], result.eval()) + def _verify_precision(self, length): + with self.test_session(): + max_splits = 1 + num_buckets = 1 + node_ids = array_ops.fill([length], 0) + + gradients = constant_op.constant( + 2.0 / length, dtype=dtypes.float32, shape=[length, 1]) + hessians = constant_op.constant( + 0.2 / length, dtype=dtypes.float32, shape=[length, 1]) + + bucketized_features = array_ops.zeros([length], dtype=dtypes.int32) + + result = boosted_trees_ops.make_stats_summary( + node_ids, gradients, hessians, [bucketized_features], max_splits, + num_buckets) # shape=[max_splits, num_buckets, num_features, 2] + + self.assertAllClose([[[[2., 0.2]]]], result.eval()) + + def testMakeStatsSummaryNumericalPrecisionSmallBatch(self): + """Tests numeric precision.""" + self._verify_precision(length=2000) + + def testMakeStatsSummaryNumericalPrecisionMediumBatch(self): + """Tests numeric precision.""" + self._verify_precision(length=100000) + + def testMakeStatsSummaryNumericalPrecisionLargeBatch(self): + """Tests numeric precision.""" + self._verify_precision(length=1000000) + + def testMakeStatsSummaryNumericalPrecisionMegaBatch(self): + """Tests numeric precision.""" + self._verify_precision(length=50000000) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py index 13b804875e94a9f8acc9c441ba2525876a3ef58f..d55240297a8b972ea926186c2fa38da5da780612 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py @@ -139,6 +139,49 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual(new_stamp, 1) self.assertProtoEquals(expected_result, tree_ensemble) + def testBiasCenteringOnEmptyEnsemble(self): + """Test growing with bias centering on an empty ensemble.""" + with self.test_session() as session: + # Create empty ensemble. + tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble') + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + gradients = np.array([[5.]], dtype=np.float32) + hessians = np.array([[24.]], dtype=np.float32) + + # Grow tree ensemble. + grow_op = boosted_trees_ops.center_bias( + tree_ensemble_handle, + mean_gradients=gradients, + mean_hessians=hessians, + l1=0.0, + l2=1.0 + ) + session.run(grow_op) + + new_stamp, serialized = session.run(tree_ensemble.serialize()) + + tree_ensemble = boosted_trees_pb2.TreeEnsemble() + tree_ensemble.ParseFromString(serialized) + + expected_result = """ + trees { + nodes { + leaf { + scalar: -0.2 + } + } + } + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 0 + is_finalized: false + } + """ + self.assertEqual(new_stamp, 1) + self.assertProtoEquals(expected_result, tree_ensemble) + def testGrowExistingEnsembleTreeNotFinalized(self): """Test growing an existing ensemble with the last tree not finalized.""" with self.test_session() as session: @@ -666,7 +709,6 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): num_layers_attempted: 1 last_layer_node_start: 1 last_layer_node_end: 3 - } """, tree_ensemble_config) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 5a83ec8d302b4c26aef7abfa7465eb9fd0cca019..bda6ca5ca91ab1f55c4586f604a116a9b3fed874 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -34,45 +34,45 @@ from tensorflow.python.platform import test class AssertProperIterableTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_tensor_raises(self): tensor = constant_op.constant(1) with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(tensor) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_sparse_tensor_raises(self): ten = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(ten) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_ndarray_raises(self): array = np.array([1, 2, 3]) with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(array) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_string_raises(self): mystr = "hello" with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(mystr) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_non_iterable_object_raises(self): non_iterable = 1234 with self.assertRaisesRegexp(TypeError, "to be iterable"): check_ops.assert_proper_iterable(non_iterable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_list_does_not_raise(self): list_of_stuff = [ constant_op.constant([11, 22]), constant_op.constant([1, 2]) ] check_ops.assert_proper_iterable(list_of_stuff) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_generator_does_not_raise(self): generator_of_stuff = (constant_op.constant([11, 22]), constant_op.constant( [1, 2])) @@ -81,20 +81,27 @@ class AssertProperIterableTest(test.TestCase): class AssertEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): small = constant_op.constant([1, 2], name="small") with ops.control_dependencies([check_ops.assert_equal(small, small)]): out = array_ops.identity(small) self.evaluate(out) + @test_util.run_in_graph_and_eager_modes + def test_scalar_comparison(self): + const_true = constant_op.constant(True, name="true") + const_false = constant_op.constant(False, name="false") + with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): + check_ops.assert_equal(const_true, const_false, message="fail") + def test_returns_none_with_eager(self): with context.eager_mode(): small = constant_op.constant([1, 2], name="small") x = check_ops.assert_equal(small, small) assert x is None - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater(self): # Static check static_small = constant_op.constant([1, 2], name="small") @@ -172,7 +179,7 @@ First 2 elements of y: check_ops.assert_equal(big, small, message="big does not equal small", summarize=2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less(self): # Static check static_small = constant_op.constant([3, 1], name="small") @@ -189,7 +196,7 @@ First 2 elements of y: with self.assertRaisesOpError("small.*big"): out.eval(feed_dict={small: [3, 1], big: [4, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal_and_broadcastable_shapes(self): small = constant_op.constant([[1, 2], [1, 2]], name="small") small_2 = constant_op.constant([1, 2], name="small_2") @@ -197,7 +204,7 @@ First 2 elements of y: out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="small") small_2 = constant_op.constant([1, 1], name="small_2") @@ -212,13 +219,13 @@ First 2 elements of y: out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_not_equal_and_broadcastable_shapes(self): cond = constant_op.constant([True, False], name="small") with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): check_ops.assert_equal(cond, False, message="fail") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -229,7 +236,7 @@ First 2 elements of y: class AssertNoneEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_not_equal(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([10, 20], name="small") @@ -238,7 +245,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal(self): small = constant_op.constant([3, 1], name="small") with self.assertRaisesOpError("x != y did not hold"): @@ -247,7 +254,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3], name="big") @@ -256,7 +263,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_not_equal_but_non_broadcastable_shapes(self): with self.test_session(): small = constant_op.constant([1, 1, 1], name="small") @@ -273,7 +280,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): with self.test_session(): larry = constant_op.constant([]) @@ -293,7 +300,7 @@ class AssertNoneEqualTest(test.TestCase): class AssertAllCloseTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): x = constant_op.constant(1., name="x") y = constant_op.constant(1., name="y") @@ -302,7 +309,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_32_bit_due_to_default_rtol(self): eps = np.finfo(np.float32).eps # Default rtol/atol is 10*eps @@ -313,7 +320,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_32_bit_due_to_default_atol(self): eps = np.finfo(np.float32).eps # Default rtol/atol is 10*eps @@ -324,7 +331,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_64_bit_due_to_default_rtol(self): eps = np.finfo(np.float64).eps # Default rtol/atol is 10*eps @@ -335,7 +342,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_64_bit_due_to_default_atol(self): eps = np.finfo(np.float64).eps # Default rtol/atol is 10*eps @@ -346,7 +353,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_due_to_custom_rtol(self): x = constant_op.constant(1., name="x") y = constant_op.constant(1.1, name="y") @@ -356,7 +363,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_due_to_custom_atol(self): x = constant_op.constant(0., name="x") y = constant_op.constant(0.1, name="y", dtype=np.float32) @@ -366,7 +373,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -374,7 +381,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(larry) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_atol_violated(self): x = constant_op.constant(10., name="x") y = constant_op.constant(10.2, name="y") @@ -385,7 +392,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_default_rtol_violated(self): x = constant_op.constant(0.1, name="x") y = constant_op.constant(0.0, name="y") @@ -405,7 +412,7 @@ class AssertAllCloseTest(test.TestCase): class AssertLessTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal(self): small = constant_op.constant([1, 2], name="small") with self.assertRaisesOpError("failure message.*\n*.* x < y did not hold"): @@ -415,7 +422,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -424,7 +431,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less(self): small = constant_op.constant([3, 1], name="small") big = constant_op.constant([4, 2], name="big") @@ -432,7 +439,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 2], name="big") @@ -440,7 +447,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="small") big = constant_op.constant([3, 2], name="big") @@ -455,7 +462,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -473,7 +480,7 @@ class AssertLessTest(test.TestCase): class AssertLessEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): small = constant_op.constant([1, 2], name="small") with ops.control_dependencies( @@ -481,7 +488,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -492,7 +499,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less_equal(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 2], name="big") @@ -500,7 +507,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less_equal_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 1], name="big") @@ -508,7 +515,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less_equal_but_non_broadcastable_shapes(self): small = constant_op.constant([3, 1], name="small") big = constant_op.constant([1, 1, 1], name="big") @@ -524,7 +531,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -536,7 +543,7 @@ class AssertLessEqualTest(test.TestCase): class AssertGreaterTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal(self): small = constant_op.constant([1, 2], name="small") with self.assertRaisesOpError("fail"): @@ -546,7 +553,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -555,7 +562,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(big) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater(self): small = constant_op.constant([3, 1], name="small") big = constant_op.constant([4, 2], name="big") @@ -563,7 +570,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 2], name="big") @@ -571,7 +578,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="small") big = constant_op.constant([3, 2], name="big") @@ -586,7 +593,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -597,7 +604,7 @@ class AssertGreaterTest(test.TestCase): class AssertGreaterEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): small = constant_op.constant([1, 2], name="small") with ops.control_dependencies( @@ -605,7 +612,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -616,7 +623,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater_equal(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 2], name="big") @@ -625,7 +632,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater_equal_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 1], name="big") @@ -634,7 +641,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less_equal_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="big") big = constant_op.constant([3, 1], name="small") @@ -650,7 +657,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -662,14 +669,14 @@ class AssertGreaterEqualTest(test.TestCase): class AssertNegativeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_negative(self): frank = constant_op.constant([-1, -2], name="frank") with ops.control_dependencies([check_ops.assert_negative(frank)]): out = array_ops.identity(frank) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_positive(self): doug = constant_op.constant([1, 2], name="doug") with self.assertRaisesOpError("fail"): @@ -679,7 +686,7 @@ class AssertNegativeTest(test.TestCase): out = array_ops.identity(doug) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_zero(self): claire = constant_op.constant([0], name="claire") with self.assertRaisesOpError("x < 0 did not hold"): @@ -687,7 +694,7 @@ class AssertNegativeTest(test.TestCase): out = array_ops.identity(claire) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is negative when it satisfies: # For every element x_i in x, x_i < 0 @@ -701,7 +708,7 @@ class AssertNegativeTest(test.TestCase): class AssertPositiveTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_negative(self): freddie = constant_op.constant([-1, -2], name="freddie") with self.assertRaisesOpError("fail"): @@ -711,14 +718,14 @@ class AssertPositiveTest(test.TestCase): out = array_ops.identity(freddie) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_positive(self): remmy = constant_op.constant([1, 2], name="remmy") with ops.control_dependencies([check_ops.assert_positive(remmy)]): out = array_ops.identity(remmy) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_zero(self): meechum = constant_op.constant([0], name="meechum") with self.assertRaisesOpError("x > 0 did not hold"): @@ -726,7 +733,7 @@ class AssertPositiveTest(test.TestCase): out = array_ops.identity(meechum) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is positive when it satisfies: # For every element x_i in x, x_i > 0 @@ -740,7 +747,7 @@ class AssertPositiveTest(test.TestCase): class AssertRankTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 1 @@ -761,7 +768,7 @@ class AssertRankTest(test.TestCase): with self.assertRaisesOpError("fail.*my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 0 @@ -777,7 +784,7 @@ class AssertRankTest(test.TestCase): [check_ops.assert_rank(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 0 @@ -795,7 +802,7 @@ class AssertRankTest(test.TestCase): with self.assertRaisesOpError("my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 1 @@ -811,7 +818,7 @@ class AssertRankTest(test.TestCase): [check_ops.assert_rank(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 2 @@ -829,7 +836,7 @@ class AssertRankTest(test.TestCase): with self.assertRaisesOpError("my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_scalar_static(self): tensor = constant_op.constant([1, 2], name="my_tensor") with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"): @@ -845,7 +852,7 @@ class AssertRankTest(test.TestCase): [check_ops.assert_rank(tensor, rank_tensor)]): array_ops.identity(tensor).eval(feed_dict={rank_tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_integer_static(self): tensor = constant_op.constant([1, 2], name="my_tensor") with self.assertRaisesRegexp(TypeError, @@ -866,7 +873,7 @@ class AssertRankTest(test.TestCase): class AssertRankInTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self): tensor_rank0 = constant_op.constant(42, name="my_tensor") with self.assertRaisesRegexp( @@ -883,7 +890,7 @@ class AssertRankInTest(test.TestCase): with self.assertRaisesOpError("fail.*my_tensor.*rank"): array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self): tensor_rank0 = constant_op.constant(42, name="my_tensor") for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): @@ -899,7 +906,7 @@ class AssertRankInTest(test.TestCase): check_ops.assert_rank_in(tensor_rank0, desired_ranks)]): array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self): tensor_rank1 = constant_op.constant([42, 43], name="my_tensor") for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): @@ -917,7 +924,7 @@ class AssertRankInTest(test.TestCase): tensor_rank1: (42.0, 43.0) }) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self): tensor_rank1 = constant_op.constant((42, 43), name="my_tensor") with self.assertRaisesRegexp(ValueError, "rank"): @@ -935,7 +942,7 @@ class AssertRankInTest(test.TestCase): tensor_rank1: (42.0, 43.0) }) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_scalar_static(self): tensor = constant_op.constant((42, 43), name="my_tensor") desired_ranks = ( @@ -959,7 +966,7 @@ class AssertRankInTest(test.TestCase): desired_ranks[1]: [2, 1], }) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_integer_static(self): tensor = constant_op.constant((42, 43), name="my_tensor") with self.assertRaisesRegexp(TypeError, @@ -980,7 +987,7 @@ class AssertRankInTest(test.TestCase): class AssertRankAtLeastTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 1 @@ -998,7 +1005,7 @@ class AssertRankAtLeastTest(test.TestCase): with self.assertRaisesOpError("my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 0 @@ -1014,7 +1021,7 @@ class AssertRankAtLeastTest(test.TestCase): [check_ops.assert_rank_at_least(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_ten_doesnt_raise_raise_if_rank_too_large_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 0 @@ -1030,7 +1037,7 @@ class AssertRankAtLeastTest(test.TestCase): [check_ops.assert_rank_at_least(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 1 @@ -1046,7 +1053,7 @@ class AssertRankAtLeastTest(test.TestCase): [check_ops.assert_rank_at_least(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 2 @@ -1067,7 +1074,7 @@ class AssertRankAtLeastTest(test.TestCase): class AssertNonNegativeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_negative(self): zoe = constant_op.constant([-1, -2], name="zoe") with self.assertRaisesOpError("x >= 0 did not hold"): @@ -1075,14 +1082,14 @@ class AssertNonNegativeTest(test.TestCase): out = array_ops.identity(zoe) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_zero_and_positive(self): lucas = constant_op.constant([0, 2], name="lucas") with ops.control_dependencies([check_ops.assert_non_negative(lucas)]): out = array_ops.identity(lucas) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is non-negative when it satisfies: # For every element x_i in x, x_i >= 0 @@ -1096,14 +1103,14 @@ class AssertNonNegativeTest(test.TestCase): class AssertNonPositiveTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_zero_and_negative(self): tom = constant_op.constant([0, -2], name="tom") with ops.control_dependencies([check_ops.assert_non_positive(tom)]): out = array_ops.identity(tom) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_positive(self): rachel = constant_op.constant([0, 2], name="rachel") with self.assertRaisesOpError("x <= 0 did not hold"): @@ -1111,7 +1118,7 @@ class AssertNonPositiveTest(test.TestCase): out = array_ops.identity(rachel) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is non-positive when it satisfies: # For every element x_i in x, x_i <= 0 @@ -1125,14 +1132,14 @@ class AssertNonPositiveTest(test.TestCase): class AssertIntegerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_integer(self): integers = constant_op.constant([1, 2], name="integers") with ops.control_dependencies([check_ops.assert_integer(integers)]): out = array_ops.identity(integers) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_float(self): floats = constant_op.constant([1.0, 2.0], name="floats") with self.assertRaisesRegexp(TypeError, "Expected.*integer"): @@ -1141,7 +1148,7 @@ class AssertIntegerTest(test.TestCase): class AssertTypeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_correct_type(self): integers = constant_op.constant([1, 2], dtype=dtypes.int64) with ops.control_dependencies([ @@ -1149,7 +1156,7 @@ class AssertTypeTest(test.TestCase): out = array_ops.identity(integers) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_wrong_type(self): floats = constant_op.constant([1.0, 2.0], dtype=dtypes.float16) with self.assertRaisesRegexp(TypeError, "must be of type.*float32"): @@ -1158,74 +1165,74 @@ class AssertTypeTest(test.TestCase): class IsStrictlyIncreasingTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_constant_tensor_is_not_strictly_increasing(self): self.assertFalse(self.evaluate(check_ops.is_strictly_increasing([1, 1, 1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_decreasing_tensor_is_not_strictly_increasing(self): self.assertFalse(self.evaluate( check_ops.is_strictly_increasing([1, 0, -1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_2d_decreasing_tensor_is_not_strictly_increasing(self): self.assertFalse( self.evaluate(check_ops.is_strictly_increasing([[1, 3], [2, 4]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_tensor_is_increasing(self): self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([1, 2, 3]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_rank_two_tensor(self): self.assertTrue( self.evaluate(check_ops.is_strictly_increasing([[-1, 2], [3, 4]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_tensor_with_one_element_is_strictly_increasing(self): self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_is_strictly_increasing(self): self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([]))) class IsNonDecreasingTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_constant_tensor_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1, 1, 1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_decreasing_tensor_is_not_non_decreasing(self): self.assertFalse(self.evaluate(check_ops.is_non_decreasing([3, 2, 1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_2d_decreasing_tensor_is_not_non_decreasing(self): self.assertFalse(self.evaluate( check_ops.is_non_decreasing([[1, 3], [2, 4]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_rank_one_tensor_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1, 2, 3]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_rank_two_tensor(self): self.assertTrue(self.evaluate( check_ops.is_non_decreasing([[-1, 2], [3, 3]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_tensor_with_one_element_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([]))) class FloatDTypeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_assert_same_float_dtype(self): self.assertIs(dtypes.float32, check_ops.assert_same_float_dtype(None, None)) @@ -1279,7 +1286,7 @@ class FloatDTypeTest(test.TestCase): class AssertScalarTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_assert_scalar(self): check_ops.assert_scalar(constant_op.constant(3)) check_ops.assert_scalar(constant_op.constant("foo")) diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py index e08123b0417912c479476d8147d832d1715b8882..fb52d10475fa47f37b1ee7de97b49878b5d13341 100644 --- a/tensorflow/python/kernel_tests/clip_ops_test.py +++ b/tensorflow/python/kernel_tests/clip_ops_test.py @@ -18,9 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.platform import test @@ -414,6 +417,16 @@ class ClipTest(test.TestCase): self.assertAllClose(np_ans, tf_ans) + def testClipByValueEmptyTensor(self): + # Test case for GitHub issue 19337 + zero = array_ops.placeholder(dtype=dtypes.float32, shape=None) + x = clip_ops.clip_by_value(zero, zero, zero) + y = clip_ops.clip_by_value(zero, 1.0, 1.0) + z = clip_ops.clip_by_value(zero, zero, 1.0) + w = clip_ops.clip_by_value(zero, 1.0, zero) + with self.test_session(use_gpu=True) as sess: + sess.run([x, y, z, w], feed_dict={zero: np.zeros((7, 0))}) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..759db5d5f43a144150918446e6ce206b3095904f --- /dev/null +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -0,0 +1,536 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for cond_v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2 +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver +from tensorflow.python.util import compat + + +class NewCondTest(test.TestCase): + + def _testCond(self, true_fn, false_fn, train_vals): + with self.test_session() as sess: + pred = array_ops.placeholder(dtypes.bool, name="pred") + + expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected") + actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual") + + expected_grad = gradients_impl.gradients(expected, train_vals) + actual_grad = gradients_impl.gradients(actual, train_vals) + + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: True}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: False}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + def testBasic(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * 2.0 + + def false_fn(): + return y * 3.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testBasic2(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * y * 2.0 + + def false_fn(): + return 2.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testNoInputs(self): + with self.test_session() as sess: + pred = array_ops.placeholder(dtypes.bool, name="pred") + + def true_fn(): + return constant_op.constant(1.0) + + def false_fn(): + return constant_op.constant(2.0) + + out = cond_v2.cond_v2(pred, true_fn, false_fn) + + self.assertEqual(sess.run(out, {pred: True}), [1.0]) + self.assertEqual(sess.run(out, {pred: False}), [2.0]) + + def _createCond(self, name): + pred = constant_op.constant(True, name="pred") + x = constant_op.constant(1.0, name="x") + + def true_fn(): + return x + + def false_fn(): + return x + 1 + + return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)[0].op + + def testDefaultName(self): + with ops.Graph().as_default(): + cond = self._createCond(None) + self.assertEqual(cond.name, "cond") + self.assertIn("cond_true", ops.get_default_graph()._functions) + self.assertIn("cond_false", ops.get_default_graph()._functions) + + with ops.Graph().as_default(): + with ops.name_scope("foo"): + cond = self._createCond("") + self.assertEqual(cond.name, "foo/cond") + self.assertIn("foo_cond_true", ops.get_default_graph()._functions) + self.assertIn("foo_cond_false", ops.get_default_graph()._functions) + + cond2 = self._createCond(None) + self.assertEqual(cond2.name, "foo/cond_1") + self.assertIn("foo_cond_1_true", ops.get_default_graph()._functions) + self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions) + + def testSecondDerivative(self): + with self.test_session() as sess: + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + cond_grad = gradients_impl.gradients(cond, [x]) + cond_grad_grad = gradients_impl.gradients(cond_grad, [x]) + + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + def testGradientOfDeserializedCond(self): + with ops.Graph().as_default(): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + ops.add_to_collection("x", x) + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + ops.add_to_collection("pred", pred) + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + for c in cond: + ops.add_to_collection("cond", c) + meta_graph = saver.export_meta_graph() + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + saver.import_meta_graph(meta_graph) + x = ops.get_collection("x")[0] + pred = ops.get_collection("pred")[0] + cond = ops.get_collection("cond") + cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad") + cond_grad_grad = gradients_impl.gradients( + cond_grad, [x], name="cond_grad_grad") + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + def testLowering(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + out_cond = self._createCond("cond") + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond, options=run_options, run_metadata=run_metadata) + + # If lowering was enabled, there should be a `Switch` node + switch_found = any( + any(node.op == "Switch" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertTrue(switch_found, + "A `Switch` op should exist if the graph was lowered.") + + # If lowering was enabled, there should be no `If` node + if_found = any( + any(node.op == "If" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertFalse(if_found, + "An `If` op was found, but it should be lowered.") + + def testLoweringDisabledInXLA(self): + with self.test_session(graph=ops.Graph()) as sess: + # Build the cond_v2 in an XLA context + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + out_cond = self._createCond("cond") + xla_context.Exit() + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond, options=run_options, run_metadata=run_metadata) + + # Lowering disabled in XLA, there should be no `Switch` node + switch_found = any( + any(node.op == "Switch" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertFalse( + switch_found, + "A `Switch` op exists, but the graph should not be lowered.") + + # Lowering disabled in XLA, there should still be an `If` node + if_found = any( + any(node.op == "If" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertTrue( + if_found, + "An `If` op was not found, but the graph should not be lowered.") + + +class CondV2CollectionTest(test.TestCase): + + def testCollectionIntValueAccessInCond(self): + """Read values from graph collections inside of cond_v2.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = 2 + y = 5 + ops.add_to_collection("x", x) + ops.add_to_collection("y", y) + def fn(): + x_const = constant_op.constant(ops.get_collection("x")[0]) + y_const = constant_op.constant(ops.get_collection("y")[0]) + return math_ops.add(x_const, y_const) + + cnd = cond_v2.cond_v2(True, fn, fn) + self.assertEquals(cnd[0].eval(), 7) + + def testCollectionTensorValueAccessInCond(self): + """Read tensors from collections inside of cond_v2 & use them.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = constant_op.constant(2) + y = constant_op.constant(5) + ops.add_to_collection("x", x) + ops.add_to_collection("y", y) + + def fn(): + x_read = ops.get_collection("x")[0] + y_read = ops.get_collection("y")[0] + return math_ops.add(x_read, y_read) + + cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn) + self.assertEquals(cnd[0].eval(), 7) + + def testCollectionIntValueWriteInCond(self): + """Make sure Int writes to collections work inside of cond_v2.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = constant_op.constant(2) + y = constant_op.constant(5) + def true_fn(): + z = math_ops.add(x, y) + ops.add_to_collection("z", 7) + return math_ops.mul(x, z) + + def false_fn(): + z = math_ops.add(x, y) + return math_ops.mul(x, z) + + cnd = cond_v2.cond_v2( + True, true_fn, + false_fn) + self.assertEquals(cnd[0].eval(), 14) + + read_z_collection = ops.get_collection("z") + self.assertEquals(read_z_collection, [7]) + + +class CondV2ContainerTest(test.TestCase): + + def testContainer(self): + """Set containers outside & inside of cond_v2. + + Make sure the containers are set correctly for both variable creation + (tested by variables.Variable) and for stateful ops (tested by FIFOQueue) + """ + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + v0 = variables.Variable([0]) + q0 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + def container(node): + return node.op.get_attr("container") + + self.assertEqual(compat.as_bytes(""), container(v0)) + self.assertEqual(compat.as_bytes(""), container(q0.queue_ref)) + + def true_fn(): + # When this branch is created in cond below, + # the container should begin with 'l1' + v1 = variables.Variable([1]) + q1 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + with ops.container("l2t"): + v2 = variables.Variable([2]) + q2 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + v3 = variables.Variable([1]) + q3 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v1)) + self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref)) + self.assertEqual(compat.as_bytes("l2t"), container(v2)) + self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref)) + self.assertEqual(compat.as_bytes("l1"), container(v3)) + self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref)) + + return constant_op.constant(2.0) + + def false_fn(): + # When this branch is created in cond below, + # the container should begin with 'l1' + v1 = variables.Variable([1]) + q1 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + with ops.container("l2f"): + v2 = variables.Variable([2]) + q2 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + v3 = variables.Variable([1]) + q3 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v1)) + self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref)) + self.assertEqual(compat.as_bytes("l2f"), container(v2)) + self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref)) + self.assertEqual(compat.as_bytes("l1"), container(v3)) + self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref)) + + return constant_op.constant(6.0) + + with ops.container("l1"): + cnd_true = cond_v2.cond_v2(True, true_fn, false_fn) + self.assertEquals(cnd_true[0].eval(), 2) + + cnd_false = cond_v2.cond_v2(False, true_fn, false_fn) + self.assertEquals(cnd_false[0].eval(), 6) + + v4 = variables.Variable([3]) + q4 = data_flow_ops.FIFOQueue(1, dtypes.float32) + v5 = variables.Variable([4]) + q5 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v4)) + self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref)) + self.assertEqual(compat.as_bytes(""), container(v5)) + self.assertEqual(compat.as_bytes(""), container(q5.queue_ref)) + + +class CondV2ColocationGroupAndDeviceTest(test.TestCase): + + def testColocateWithBeforeCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + a = constant_op.constant([2.0], name="a") + b = constant_op.constant([2.0], name="b") + + def fn(): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + + def fn2(): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + with ops.colocate_with(b.op): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + def testColocateWithInAndOutOfCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + a = constant_op.constant([2.0], name="a") + b = constant_op.constant([2.0], name="b") + + def fn2(): + with ops.colocate_with(b.op): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + d = constant_op.constant([2.0], name="d") + self.assertEqual([b"loc:@a"], d.op.colocation_groups()) + + def testColocateWithInCondGraphPartitioning(self): + with ops.Graph().as_default() as g: + with self.test_session( + graph=g, + config=config_pb2.ConfigProto(device_count={"CPU": 2}) + ) as sess: + + with ops.device("/device:CPU:0"): + a = constant_op.constant([2.0], name="a") + with ops.device("/device:CPU:1"): + b = constant_op.constant([2.0], name="b") + + def fn(): + with ops.colocate_with(b.op): + c = math_ops.add(a, a, name="c") + return c + out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond_2, options=run_options, run_metadata=run_metadata) + + # We expect there to be two partitions because of the + # colocate_with. We are only running the cond, which has a data + # dependency on `a` but not on `b`. So, without the colocate_with + # we would expect execution on just one device. + self.assertTrue(len(run_metadata.partition_graphs) >= 2) + + def testDeviceBeforeCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + def fn(): + c = constant_op.constant(3.0) + self.assertEqual("/device:CPU:0", c.op.device) + return c + + with ops.device("/device:CPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + + def fn2(): + c = constant_op.constant(3.0) + self.assertEqual("/device:GPU:0", c.op.device) + return c + + with ops.device("/device:GPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + def testDeviceInAndOutOfCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + def fn2(): + with ops.device("/device:GPU:0"): + c = constant_op.constant(3.0) + self.assertEqual("/device:GPU:0", c.op.device) + return c + + with ops.device("/device:CPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + d = constant_op.constant(4.0) + self.assertEqual("/device:CPU:0", d.op.device) + + def testDeviceInCondGraphPartitioning(self): + with ops.Graph().as_default() as g: + with self.test_session( + graph=g, + config=config_pb2.ConfigProto(device_count={"CPU": 2}) + ) as sess: + + def fn(): + with ops.device("/device:CPU:1"): + c = math_ops.add(a, a, name="c") + return c + + with ops.device("/device:CPU:0"): + a = constant_op.constant([2.0], name="a") + out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond_2, options=run_options, run_metadata=run_metadata) + + self.assertTrue(len(run_metadata.partition_graphs) >= 2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py index 79e419867d70071280b7c88b6bfa820b935b24cd..ae6875340e776fc6808be3f4afeb59644245c886 100644 --- a/tensorflow/python/kernel_tests/confusion_matrix_test.py +++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import test class ConfusionMatrixTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testExample(self): """This is a test of the example provided in pydoc.""" with self.test_session(): diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py index 8e9d75667d49bf9e377ccb9290a3a91786b5a1cb..a0d5557b925162b254e34e9fc0971393ec119059 100644 --- a/tensorflow/python/kernel_tests/constant_op_eager_test.py +++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py @@ -32,6 +32,9 @@ from tensorflow.python.util import compat # TODO(josh11b): add tests with lists/tuples, Shape. +# TODO(ashankar): Collapse with tests in constant_op_test.py and use something +# like the test_util.run_in_graph_and_eager_modes decorator to confirm +# equivalence between graph and eager execution. class ConstantTest(test.TestCase): def _testCpu(self, x): @@ -280,6 +283,34 @@ class ConstantTest(test.TestCase): with self.assertRaisesRegexp(ValueError, None): constant_op.constant([[1, 2], [3], [4, 5]]) + # TODO(ashankar): This test fails with graph construction since + # tensor_util.make_tensor_proto (invoked from constant_op.constant) + # does not handle iterables (it relies on numpy conversion). + # For consistency, should graph construction handle Python objects + # that implement the sequence protocol (but not numpy conversion), + # or should eager execution fail on such sequences? + def testCustomSequence(self): + + # This is inspired by how many objects in pandas are implemented: + # - They implement the Python sequence protocol + # - But may raise a KeyError on __getitem__(self, 0) + # See https://github.com/tensorflow/tensorflow/issues/20347 + class MySeq(object): + + def __getitem__(self, key): + if key != 1 and key != 3: + raise KeyError(key) + return key + + def __len__(self): + return 2 + + def __iter__(self): + l = list([1, 3]) + return l.__iter__() + + self.assertAllEqual([1, 3], self.evaluate(constant_op.constant(MySeq()))) + class AsTensorTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index a291bef0ad6f16184ff29f665457a53b77447d54..474d06b8f3a4276c65711d74ba0d1db6fb06cbf9 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -312,8 +312,8 @@ class Conv2DTest(test.TestCase): expected_values = self.evaluate(expected_results) computed_values = self.evaluate(computed_results) for e_value, c_value in zip(expected_values, computed_values): - print("expected = ", e_value) - print("actual = ", c_value) + tf_logging.info("expected = ", e_value) + tf_logging.info("actual = ", c_value) self.assertAllClose( e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4) @@ -337,15 +337,15 @@ class Conv2DTest(test.TestCase): for i in range(len(tensors)): conv = tensors[i] value = values[i] - print("expected = ", expected) - print("actual = ", value) + tf_logging.info("expected = ", expected) + tf_logging.info("actual = ", value) tol = 1e-5 if value.dtype == np.float16: tol = 1e-3 self.assertAllClose(expected, np.ravel(value), atol=tol, rtol=tol) self.assertShapeEqual(value, conv) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D1x1Filter(self): expected_output = [ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, @@ -358,7 +358,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Filter2x1Dilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 4, 4, 1], @@ -367,7 +367,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 1], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmpty(self): expected_output = [] self._VerifyValues( @@ -377,7 +377,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmptyDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[0, 2, 3, 3], @@ -386,7 +386,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 1], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0] @@ -397,7 +397,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 2, 3, 3], @@ -406,7 +406,7 @@ class Conv2DTest(test.TestCase): dilations=[1, 2], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D1x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. expected_output = [ @@ -420,7 +420,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D1x2FilterDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 2, 3, 3], @@ -429,7 +429,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 1], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterStride2(self): expected_output = [2271.0, 2367.0, 2463.0] self._VerifyValues( @@ -439,7 +439,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterStride2Same(self): expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0] self._VerifyValues( @@ -449,7 +449,7 @@ class Conv2DTest(test.TestCase): padding="SAME", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterStride1x2(self): expected_output = [58.0, 78.0, 98.0, 118.0, 138.0, 158.0] self._VerifyValues( @@ -459,7 +459,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSmallerThanStrideValid(self): expected_output = [65, 95, 275, 305] self._VerifyValues( @@ -469,7 +469,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSmallerThanStrideSame(self): self._VerifyValues( tensor_in_sizes=[1, 3, 3, 1], @@ -492,7 +492,7 @@ class Conv2DTest(test.TestCase): padding="SAME", expected=[44, 28, 41, 16]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSize(self): self._VerifyValues( tensor_in_sizes=[1, 2, 2, 1], @@ -501,7 +501,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=[50, 60]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSizeDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 3, 3, 1], @@ -547,8 +547,8 @@ class Conv2DTest(test.TestCase): # "values" consists of two tensors for two backprops value = self.evaluate(conv) self.assertShapeEqual(value, conv) - print("expected = ", expected) - print("actual = ", value) + tf_logging.info("expected = ", expected) + tf_logging.info("actual = ", value) self.assertArrayNear(expected, value.flatten(), err) def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes, @@ -587,9 +587,9 @@ class Conv2DTest(test.TestCase): values.append(_GetVal(data_format, use_gpu)) for i in range(1, len(values)): - self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4) + self.assertAllClose(values[0], values[i], rtol=1e-2, atol=1e-2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth1ValidBackpropInput(self): expected_output = [1.0, 4.0, 4.0, 3.0, 10.0, 8.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -604,7 +604,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmptyBackpropInput(self): expected_output = [] for (data_format, use_gpu) in GetTestConfigs(): @@ -619,7 +619,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropInput(self): expected_output = [ 14.0, 32.0, 50.0, 100.0, 163.0, 226.0, 167.0, 212.0, 257.0, 122.0, @@ -639,7 +639,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropInputStride1x2(self): expected_output = [ 1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 7.0, 12.0, 11.0, 18.0, 15.0, 24.0, 12.0, @@ -657,7 +657,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DStrideTwoFilterOneSameBackpropInput(self): expected_output = [ 1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 0.0, 0.0, @@ -675,7 +675,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSizeBackpropInput(self): expected_output = [5.0, 11.0, 17.0, 23.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -723,8 +723,8 @@ class Conv2DTest(test.TestCase): data_format=data_format) value = self.evaluate(conv) self.assertShapeEqual(value, conv) - print("expected = ", expected) - print("actual = ", value) + tf_logging.info("expected = ", expected) + tf_logging.info("actual = ", value) self.assertArrayNear(expected, value.flatten(), 1e-5) def _CompareBackFilter(self, input_sizes, filter_sizes, output_sizes, @@ -759,7 +759,7 @@ class Conv2DTest(test.TestCase): for i in range(1, len(values)): self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth1ValidBackpropFilter(self): expected = [5.0, 8.0, 14.0, 17.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -773,7 +773,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmptyBackpropFilter(self): expected = [] for (data_format, use_gpu) in GetTestConfigs(): @@ -787,7 +787,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DBackpropFilterWithEmptyInput(self): expected = [0, 0, 0, 0] for (data_format, use_gpu) in GetTestConfigs(): @@ -801,7 +801,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropFilter(self): expected = [ 17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0, 32.0, 43.0, 54.0, @@ -820,7 +820,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropFilterStride1x2(self): expected = [161.0, 182.0, 287.0, 308.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -834,7 +834,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DStrideTwoFilterOneSameBackpropFilter(self): expected_output = [78.] for (data_format, use_gpu) in GetTestConfigs(): @@ -848,7 +848,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSizeBackpropFilter(self): expected_output = [1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 4.0, 8.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -912,8 +912,8 @@ class Conv2DTest(test.TestCase): value_2 = sess.run(conv_2) self.assertShapeEqual(value, conv) self.assertShapeEqual(value_2, conv_2) - print("expected = ", value_2) - print("actual = ", value) + tf_logging.info("expected = ", value_2) + tf_logging.info("actual = ", value) self.assertArrayNear(value_2.flatten(), value.flatten(), err) # Testing for backprops @@ -965,8 +965,8 @@ class Conv2DTest(test.TestCase): value_2 = sess.run(conv_2) self.assertShapeEqual(value, conv) self.assertShapeEqual(value_2, conv_2) - print("expected = ", value_2) - print("actual = ", value) + tf_logging.info("expected = ", value_2) + tf_logging.info("actual = ", value) self.assertArrayNear(value_2.flatten(), value.flatten(), err) def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): @@ -1178,7 +1178,7 @@ class Conv2DTest(test.TestCase): # since fp16 numerical gradients are too imprecise. err = np.fabs(jacob_t - reference_jacob_t).max() - print("conv_2d gradient error = ", err) + tf_logging.info("conv_2d gradient error = ", err) self.assertLess(err, 0.002) def testInputGradientValidPaddingStrideOne(self): @@ -1546,7 +1546,7 @@ class DepthwiseConv2DTest(test.TestCase): conv = nn_impl.depthwise_conv2d( t1, t2, strides=[1, stride, stride, 1], padding=padding) value = sess.run(conv) - print("value = ", value) + tf_logging.info("value = ", value) self.assertArrayNear(expected, np.ravel(value), 1e-5) self.assertShapeEqual(value, conv) @@ -1668,7 +1668,7 @@ class SeparableConv2DTest(test.TestCase): conv = array_ops.transpose(conv, [0, 2, 3, 1]) value = sess.run(conv) - print("value = ", value) + tf_logging.info("value = ", value) self.assertArrayNear(expected, np.ravel(value), 1e-5) self.assertShapeEqual(value, conv) @@ -1826,7 +1826,7 @@ class Conv2DBenchmark(test.Benchmark): wall_time = time.time() - start self.report_benchmark( name="conv_stack_iter_%d" % iter_index, wall_time=wall_time) - print("conv_stack_iter_%d: %.4f" % (iter_index, wall_time)) + tf_logging.info("conv_stack_iter_%d: %.4f" % (iter_index, wall_time)) def GetInceptionFwdTest(input_size, filter_size, stride, padding, @@ -1897,19 +1897,19 @@ if __name__ == "__main__": for index, (input_size_, filter_size_, output_size_, stride_, padding_) in enumerate(GetShrunkInceptionShapes()): setattr(Conv2DTest, "testInceptionFwd_" + str(index), - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))) setattr( Conv2DTest, "testInceptionFwdDilatedConv_" + str(index), - test_util.run_in_graph_and_eager_modes()(GetInceptionFwdDilatedConvTest( + test_util.run_in_graph_and_eager_modes(GetInceptionFwdDilatedConvTest( input_size_, filter_size_, stride_, padding_))) setattr(Conv2DTest, "testInceptionBackInput_" + str(index), - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackInputTest(input_size_, filter_size_, output_size_, stride_, padding_))) setattr(Conv2DTest, "testInceptionBackFilter_" + str(index), - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackFilterTest(input_size_, filter_size_, output_size_, [stride_, stride_], padding_))) @@ -1924,17 +1924,17 @@ if __name__ == "__main__": fshape = [1, 1, 1, 256] oshape = [1, 400, 400, 256] setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))) setattr(Conv2DTest, "testInceptionFwdDilatedConv_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionFwdDilatedConvTest(ishape, fshape, 1, "SAME"))) setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME", gpu_only=True))) setattr(Conv2DTest, "testInceptionBackFilter_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME", gpu_only=True))) test.main() diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 87da89831c8ded9b8382c7bb251948b6d202300e..b61232cdedecacf0cc0f9b1661486a52afc86c2e 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gradient_checker @@ -95,7 +96,8 @@ class UnaryOpTest(test.TestCase): np_ans = np_func(x) with self.test_session(use_gpu=False): inx = ops.convert_to_tensor(x) - if x.dtype in (np.float32, np.float64): + if x.dtype in (np.float32, np.float64, + dtypes_lib.bfloat16.as_numpy_dtype): y = 1.1 * tf_func(inx) np_ans *= 1.1 else: @@ -104,6 +106,8 @@ class UnaryOpTest(test.TestCase): self.assertShapeEqual(np_ans, y) if x.dtype == np.float16: self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3) + elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype: + self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2) else: self.assertAllClose(np_ans, tf_cpu) @@ -152,7 +156,7 @@ class UnaryOpTest(test.TestCase): def _compareGpu(self, x, np_func, tf_func): np_ans = np_func(x) - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): result = tf_func(ops.convert_to_tensor(x)) tf_gpu = result.eval() if x.dtype == np.float16: @@ -164,7 +168,7 @@ class UnaryOpTest(test.TestCase): def _compareSparseGpu(self, x, np_func, tf_func, tol): x_sp, x_sp_vals = _sparsify(x) res_np = np_func(x_sp_vals) - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): self._check(tf_func(x_sp), res_np, x_sp, tol) def _compareBoth(self, x, np_func, tf_func): @@ -240,6 +244,12 @@ class UnaryOpTest(test.TestCase): math_ops.lgamma) self._compareBoth(x, np.vectorize(math.erf), math_ops.erf) self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc) + try: + from scipy import special # pylint: disable=g-import-not-at-top + self._compareBoth(x, special.i0e, math_ops.bessel_i0e) + self._compareBoth(x, special.i1e, math_ops.bessel_i1e) + except ImportError as e: + tf_logging.warn("Cannot test special functions: %s" % str(e)) self._compareBothSparse(x, np.abs, math_ops.abs) self._compareBothSparse(x, np.negative, math_ops.negative) @@ -285,6 +295,12 @@ class UnaryOpTest(test.TestCase): self._compareBoth(x, np.arcsin, math_ops.asin) self._compareBoth(x, np.arccos, math_ops.acos) self._compareBoth(x, np.arctan, math_ops.atan) + try: + from scipy import special # pylint: disable=g-import-not-at-top + self._compareBoth(x, special.i0e, math_ops.bessel_i0e) + self._compareBoth(x, special.i1e, math_ops.bessel_i1e) + except ImportError as e: + tf_logging.warn("Cannot test special functions: %s" % str(e)) self._compareBothSparse(x, np.abs, math_ops.abs) self._compareBothSparse(x, np.negative, math_ops.negative) @@ -333,6 +349,12 @@ class UnaryOpTest(test.TestCase): self._compareBoth(k, np.arcsin, math_ops.asin) self._compareBoth(k, np.arccos, math_ops.acos) self._compareBoth(k, np.tan, math_ops.tan) + try: + from scipy import special # pylint: disable=g-import-not-at-top + self._compareBoth(x, special.i0e, math_ops.bessel_i0e) + self._compareBoth(x, special.i1e, math_ops.bessel_i1e) + except ImportError as e: + tf_logging.warn("Cannot test special functions: %s" % str(e)) self._compareBothSparse(x, np.abs, math_ops.abs) self._compareBothSparse(x, np.negative, math_ops.negative) @@ -369,6 +391,12 @@ class UnaryOpTest(test.TestCase): math_ops.lgamma) self._compareBoth(x, np.vectorize(math.erf), math_ops.erf) self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc) + try: + from scipy import special # pylint: disable=g-import-not-at-top + self._compareBoth(x, special.i0e, math_ops.bessel_i0e) + self._compareBoth(x, special.i1e, math_ops.bessel_i1e) + except ImportError as e: + tf_logging.warn("Cannot test special functions: %s" % str(e)) self._compareBothSparse(x, np.abs, math_ops.abs) self._compareBothSparse(x, np.negative, math_ops.negative) @@ -630,7 +658,7 @@ class BinaryOpTest(test.TestCase): def _compareGpu(self, x, y, np_func, tf_func): np_ans = np_func(x, y) - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): inx = ops.convert_to_tensor(x) iny = ops.convert_to_tensor(y) out = tf_func(inx, iny) @@ -643,12 +671,11 @@ class BinaryOpTest(test.TestCase): self._compareCpu(x, y, np_func, tf_func, also_compare_variables) if x.dtype in (np.float16, np.float32, np.float64, np.complex64, np.complex128): - if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.igamma, - math_ops.igammac, math_ops.zeta, math_ops.polygamma): + if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.zeta, + math_ops.polygamma): self._compareGradientX(x, y, np_func, tf_func) self._compareGradientY(x, y, np_func, tf_func) - if tf_func in (math_ops.igamma, math_ops.igammac, math_ops.zeta, - math_ops.polygamma): + if tf_func in (math_ops.zeta, math_ops.polygamma): # These methods only support gradients in the second parameter self._compareGradientY(x, y, np_func, tf_func) self._compareGpu(x, y, np_func, tf_func) @@ -1203,7 +1230,7 @@ class BinaryOpTest(test.TestCase): class ComparisonOpTest(test.TestCase): def _compareScalar(self, func, x, y, dtype): - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): out = func( ops.convert_to_tensor(np.array([x]).astype(dtype)), ops.convert_to_tensor(np.array([y]).astype(dtype))) @@ -1236,7 +1263,7 @@ class ComparisonOpTest(test.TestCase): def _compare(self, x, y, np_func, tf_func): np_ans = np_func(x, y) - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): out = tf_func(ops.convert_to_tensor(x), ops.convert_to_tensor(y)) tf_ans = out.eval() self.assertAllEqual(np_ans, tf_ans) @@ -1337,7 +1364,8 @@ class LogicalOpTest(test.TestCase): def _compareBinary(self, x, y, np_func, tf_func, use_gpu=False): np_ans = np_func(x, y) - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): inx = ops.convert_to_tensor(x) iny = ops.convert_to_tensor(y) out = tf_func(inx, iny) @@ -1348,7 +1376,8 @@ class LogicalOpTest(test.TestCase): def _not(self, x, use_gpu=False): np_ans = np.logical_not(x) - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): out = math_ops.logical_not(ops.convert_to_tensor(x)) tf_val = out.eval() self.assertEqual(out.dtype, dtypes_lib.bool) @@ -1433,7 +1462,8 @@ class SelectOpTest(test.TestCase): def _compare(self, c, x, y, use_gpu): np_ans = np.where(c, x, y) - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): out = array_ops.where(c, x, y) tf_ans = out.eval() self.assertAllEqual(np_ans, tf_ans) @@ -1576,7 +1606,8 @@ class BatchSelectOpTest(test.TestCase): np_ans = np.dstack( [x_i if c_i else y_i for c_i, x_i, y_i in zip(c, x, y)]).transpose( [2, 0, 1]) - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): out = array_ops.where(c, x, y) tf_ans = out.eval() self.assertAllEqual(np_ans, tf_ans) @@ -1681,7 +1712,9 @@ class MinMaxOpTest(test.TestCase): def _compare(self, x, y, use_gpu): np_min, np_max = np.minimum(x, y), np.maximum(x, y) - with self.test_session(use_gpu=use_gpu) as sess: + with self.test_session( + use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()) as sess: inx = ops.convert_to_tensor(x) iny = ops.convert_to_tensor(y) omin, omax = math_ops.minimum(inx, iny), math_ops.maximum(inx, iny) @@ -1843,7 +1876,9 @@ class IsFiniteInfNanTest(test.TestCase): def _compare(self, x, use_gpu): np_finite, np_inf, np_nan = np.isfinite(x), np.isinf(x), np.isnan(x) - with self.test_session(use_gpu=use_gpu) as sess: + with self.test_session( + use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()) as sess: inx = ops.convert_to_tensor(x) ofinite, oinf, onan = math_ops.is_finite(inx), math_ops.is_inf( inx), math_ops.is_nan(inx) @@ -1884,7 +1919,7 @@ class IsFiniteInfNanTest(test.TestCase): x = np.full((size,), value, dtype=dtype) np_y = np.sqrt(x) np_nan = np.isnan(np_y) - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): tf_y = math_ops.sqrt(x) tf_nan = math_ops.is_nan(tf_y) if value < 0: @@ -1939,7 +1974,8 @@ class ComplexMakeRealImagTest(test.TestCase): def _compareMake(self, real, imag, use_gpu): np_ans = real + (1j) * imag - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): real = ops.convert_to_tensor(real) imag = ops.convert_to_tensor(imag) tf_ans = math_ops.complex(real, imag) @@ -1958,7 +1994,8 @@ class ComplexMakeRealImagTest(test.TestCase): def _compareRealImag(self, cplx, use_gpu): np_real, np_imag = np.real(cplx), np.imag(cplx) np_zeros = np_real * 0 - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): inx = ops.convert_to_tensor(cplx) tf_real = math_ops.real(inx) tf_imag = math_ops.imag(inx) @@ -1985,7 +2022,9 @@ class ComplexMakeRealImagTest(test.TestCase): def _compareAngle(self, cplx, use_gpu): np_angle = np.angle(cplx) - with self.test_session(use_gpu=use_gpu) as sess: + with self.test_session( + use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()) as sess: inx = ops.convert_to_tensor(cplx) tf_angle = math_ops.angle(inx) tf_angle_val = sess.run(tf_angle) @@ -2019,7 +2058,8 @@ class ComplexMakeRealImagTest(test.TestCase): def _compareConj(self, cplx, use_gpu): np_ans = np.conj(cplx) - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): inx = ops.convert_to_tensor(cplx) tf_conj = math_ops.conj(inx) tf_ans = tf_conj.eval() diff --git a/tensorflow/python/kernel_tests/dct_ops_test.py b/tensorflow/python/kernel_tests/dct_ops_test.py index 93b2ff4561bcc8fd13855cde444c4b6237d7949b..97d7e2d8f90a620b693e2c81adc616d399e13bd6 100644 --- a/tensorflow/python/kernel_tests/dct_ops_test.py +++ b/tensorflow/python/kernel_tests/dct_ops_test.py @@ -40,50 +40,92 @@ def try_import(name): # pylint: disable=invalid-name fftpack = try_import("scipy.fftpack") +def _np_dct2(signals, norm=None): + """Computes the DCT-II manually with NumPy.""" + # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1 + dct_size = signals.shape[-1] + dct = np.zeros_like(signals) + for k in range(dct_size): + phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) + dct[..., k] = np.sum(signals * phi, axis=-1) + # SciPy's `dct` has a scaling factor of 2.0 which we follow. + # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src + if norm == "ortho": + # The orthonormal scaling includes a factor of 0.5 which we combine with + # the overall scaling of 2.0 to cancel. + dct[..., 0] *= np.sqrt(1.0 / dct_size) + dct[..., 1:] *= np.sqrt(2.0 / dct_size) + else: + dct *= 2.0 + return dct + + +def _np_dct3(signals, norm=None): + """Computes the DCT-III manually with NumPy.""" + # SciPy's `dct` has a scaling factor of 2.0 which we follow. + # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src + dct_size = signals.shape[-1] + signals = np.array(signals) # make a copy so we can modify + if norm == "ortho": + signals[..., 0] *= np.sqrt(4.0 / dct_size) + signals[..., 1:] *= np.sqrt(2.0 / dct_size) + else: + signals *= 2.0 + dct = np.zeros_like(signals) + # X_k = 0.5 * x_0 + + # sum_{n=1}^{N-1} x_n * cos(\frac{pi}{N} * n * (k + 0.5)) k=0,...,N-1 + half_x0 = 0.5 * signals[..., 0] + for k in range(dct_size): + phi = np.cos(np.pi * np.arange(1, dct_size) * (k + 0.5) / dct_size) + dct[..., k] = half_x0 + np.sum(signals[..., 1:] * phi, axis=-1) + return dct + + +NP_DCT = {2: _np_dct2, 3: _np_dct3} +NP_IDCT = {2: _np_dct3, 3: _np_dct2} + + class DCTOpsTest(test.TestCase): - def _np_dct2(self, signals, norm=None): - """Computes the DCT-II manually with NumPy.""" - # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1 - dct_size = signals.shape[-1] - dct = np.zeros_like(signals) - for k in range(dct_size): - phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) - dct[..., k] = np.sum(signals * phi, axis=-1) - # SciPy's `dct` has a scaling factor of 2.0 which we follow. - # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src - if norm == "ortho": - # The orthonormal scaling includes a factor of 0.5 which we combine with - # the overall scaling of 2.0 to cancel. - dct[..., 0] *= np.sqrt(1.0 / dct_size) - dct[..., 1:] *= np.sqrt(2.0 / dct_size) - else: - dct *= 2.0 - return dct - - def _compare(self, signals, norm, atol=5e-4, rtol=5e-4): - """Compares the DCT to SciPy (if available) and a NumPy implementation.""" - np_dct = self._np_dct2(signals, norm) - tf_dct = spectral_ops.dct(signals, type=2, norm=norm).eval() + def _compare(self, signals, norm, dct_type, atol=5e-4, rtol=5e-4): + """Compares (I)DCT to SciPy (if available) and a NumPy implementation.""" + np_dct = NP_DCT[dct_type](signals, norm) + tf_dct = spectral_ops.dct(signals, type=dct_type, norm=norm).eval() self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol) + np_idct = NP_IDCT[dct_type](signals, norm) + tf_idct = spectral_ops.idct(signals, type=dct_type, norm=norm).eval() + self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol) if fftpack: - scipy_dct = fftpack.dct(signals, type=2, norm=norm) + scipy_dct = fftpack.dct(signals, type=dct_type, norm=norm) self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol) + scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm) + self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol) + # Verify inverse(forward(s)) == s, up to a normalization factor. + tf_idct_dct = spectral_ops.idct( + tf_dct, type=dct_type, norm=norm).eval() + tf_dct_idct = spectral_ops.dct( + tf_idct, type=dct_type, norm=norm).eval() + if norm is None: + tf_idct_dct *= 0.5 / signals.shape[-1] + tf_dct_idct *= 0.5 / signals.shape[-1] + self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol) + self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol) def test_random(self): """Test randomly generated batches of data.""" with spectral_ops_test_util.fft_kernel_label_map(): with self.test_session(use_gpu=True): - for shape in ([2, 20], [1], [2], [3], [10], [2, 20], [2, 3, 25]): + for shape in ([1], [2], [3], [10], [2, 20], [2, 3, 25]): signals = np.random.rand(*shape).astype(np.float32) for norm in (None, "ortho"): - self._compare(signals, norm) + self._compare(signals, norm, 2) + self._compare(signals, norm, 3) def test_error(self): signals = np.random.rand(10) # Unsupported type. with self.assertRaises(ValueError): - spectral_ops.dct(signals, type=3) + spectral_ops.dct(signals, type=1) # Unknown normalization. with self.assertRaises(ValueError): spectral_ops.dct(signals, norm="bad") diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py index 5e223b18281ed9c06a3f72a16b6d22290851f37b..7134e02c348b47048cff5b0c205d1dd613c31a81 100644 --- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py +++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py @@ -356,7 +356,7 @@ class DepthwiseConv2DTest(test.TestCase): with self.test_session(graph=graph, use_gpu=use_gpu) as sess: tolerance = { dtypes.float16: 4e-0, - dtypes.float32: 5e-4, + dtypes.float32: 8e-4, dtypes.float64: 1e-12, }[data_type] diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD index cf2e8832fd5225e4d4be617a97b355bb410084c2..14532965d8c2c62139b3cd922acb9f90c0691d53 100644 --- a/tensorflow/python/kernel_tests/distributions/BUILD +++ b/tensorflow/python/kernel_tests/distributions/BUILD @@ -93,6 +93,7 @@ cuda_py_test( size = "small", srcs = ["categorical_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python/ops/distributions", "//third_party/py/numpy", "//tensorflow/python:array_ops", @@ -134,6 +135,10 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = [ + "noguitar", # b/110489471 + "notap", # b/110489471 + ], ) cuda_py_test( diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py index 09812db8166567403dc966ac9cb4304be0740e50..9ad77a54cbc730296508e4fe74248d2413029151 100644 --- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py +++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py @@ -22,8 +22,10 @@ import importlib import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import bernoulli from tensorflow.python.ops.distributions import kullback_leibler @@ -56,59 +58,65 @@ def entropy(p): class BernoulliTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes def testP(self): p = [0.2, 0.4] dist = bernoulli.Bernoulli(probs=p) with self.test_session(): - self.assertAllClose(p, dist.probs.eval()) + self.assertAllClose(p, self.evaluate(dist.probs)) + @test_util.run_in_graph_and_eager_modes def testLogits(self): logits = [-42., 42.] dist = bernoulli.Bernoulli(logits=logits) with self.test_session(): - self.assertAllClose(logits, dist.logits.eval()) + self.assertAllClose(logits, self.evaluate(dist.logits)) if not special: return with self.test_session(): - self.assertAllClose(special.expit(logits), dist.probs.eval()) + self.assertAllClose(special.expit(logits), self.evaluate(dist.probs)) p = [0.01, 0.99, 0.42] dist = bernoulli.Bernoulli(probs=p) with self.test_session(): - self.assertAllClose(special.logit(p), dist.logits.eval()) + self.assertAllClose(special.logit(p), self.evaluate(dist.logits)) + @test_util.run_in_graph_and_eager_modes def testInvalidP(self): invalid_ps = [1.01, 2.] for p in invalid_ps: with self.test_session(): with self.assertRaisesOpError("probs has components greater than 1"): dist = bernoulli.Bernoulli(probs=p, validate_args=True) - dist.probs.eval() + self.evaluate(dist.probs) invalid_ps = [-0.01, -3.] for p in invalid_ps: with self.test_session(): with self.assertRaisesOpError("Condition x >= 0"): dist = bernoulli.Bernoulli(probs=p, validate_args=True) - dist.probs.eval() + self.evaluate(dist.probs) valid_ps = [0.0, 0.5, 1.0] for p in valid_ps: with self.test_session(): dist = bernoulli.Bernoulli(probs=p) - self.assertEqual(p, dist.probs.eval()) # Should not fail + self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail + @test_util.run_in_graph_and_eager_modes def testShapes(self): with self.test_session(): for batch_shape in ([], [1], [2, 3, 4]): dist = make_bernoulli(batch_shape) self.assertAllEqual(batch_shape, dist.batch_shape.as_list()) - self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval()) + self.assertAllEqual(batch_shape, + self.evaluate(dist.batch_shape_tensor())) self.assertAllEqual([], dist.event_shape.as_list()) - self.assertAllEqual([], dist.event_shape_tensor().eval()) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + @test_util.run_in_graph_and_eager_modes def testDtype(self): dist = make_bernoulli([]) self.assertEqual(dist.dtype, dtypes.int32) @@ -126,6 +134,7 @@ class BernoulliTest(test.TestCase): self.assertEqual(dist64.dtype, dist64.sample(5).dtype) self.assertEqual(dist64.dtype, dist64.mode().dtype) + @test_util.run_in_graph_and_eager_modes def _testPmf(self, **kwargs): dist = bernoulli.Bernoulli(**kwargs) with self.test_session(): @@ -147,8 +156,9 @@ class BernoulliTest(test.TestCase): # pylint: enable=bad-continuation for x, expected_pmf in zip(xs, expected_pmfs): - self.assertAllClose(dist.prob(x).eval(), expected_pmf) - self.assertAllClose(dist.log_prob(x).eval(), np.log(expected_pmf)) + self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf) + self.assertAllClose( + self.evaluate(dist.log_prob(x)), np.log(expected_pmf)) def testPmfCorrectBroadcastDynamicShape(self): with self.test_session(): @@ -165,15 +175,17 @@ class BernoulliTest(test.TestCase): p: [0.2, 0.3, 0.4] }), [[0.2, 0.7, 0.4]]) + @test_util.run_in_graph_and_eager_modes def testPmfInvalid(self): p = [0.1, 0.2, 0.7] with self.test_session(): dist = bernoulli.Bernoulli(probs=p, validate_args=True) with self.assertRaisesOpError("must be non-negative."): - dist.prob([1, 1, -1]).eval() + self.evaluate(dist.prob([1, 1, -1])) with self.assertRaisesOpError("Elements cannot exceed 1."): - dist.prob([2, 0, 1]).eval() + self.evaluate(dist.prob([2, 0, 1])) + @test_util.run_in_graph_and_eager_modes def testPmfWithP(self): p = [[0.2, 0.4], [0.3, 0.6]] self._testPmf(probs=p) @@ -203,7 +215,7 @@ class BernoulliTest(test.TestCase): with self.test_session(): dist = bernoulli.Bernoulli(probs=0.5) - self.assertEqual(2, len(dist.log_prob([[1], [1]]).eval().shape)) + self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape)) with self.test_session(): dist = bernoulli.Bernoulli(probs=0.5) @@ -215,25 +227,31 @@ class BernoulliTest(test.TestCase): dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]]) self.assertEqual((2, 1), dist.log_prob(1).get_shape()) + @test_util.run_in_graph_and_eager_modes def testBoundaryConditions(self): with self.test_session(): dist = bernoulli.Bernoulli(probs=1.0) - self.assertAllClose(np.nan, dist.log_prob(0).eval()) - self.assertAllClose([np.nan], [dist.log_prob(1).eval()]) + self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0))) + self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))]) + @test_util.run_in_graph_and_eager_modes def testEntropyNoBatch(self): p = 0.2 dist = bernoulli.Bernoulli(probs=p) with self.test_session(): - self.assertAllClose(dist.entropy().eval(), entropy(p)) + self.assertAllClose(self.evaluate(dist.entropy()), entropy(p)) + @test_util.run_in_graph_and_eager_modes def testEntropyWithBatch(self): p = [[0.1, 0.7], [0.2, 0.6]] dist = bernoulli.Bernoulli(probs=p, validate_args=False) with self.test_session(): - self.assertAllClose(dist.entropy().eval(), [[entropy(0.1), entropy(0.7)], - [entropy(0.2), entropy(0.6)]]) + self.assertAllClose( + self.evaluate(dist.entropy()), + [[entropy(0.1), entropy(0.7)], [entropy(0.2), + entropy(0.6)]]) + @test_util.run_in_graph_and_eager_modes def testSampleN(self): with self.test_session(): p = [0.2, 0.6] @@ -242,7 +260,7 @@ class BernoulliTest(test.TestCase): samples = dist.sample(n) samples.set_shape([n, 2]) self.assertEqual(samples.dtype, dtypes.int32) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertTrue(np.all(sample_values >= 0)) self.assertTrue(np.all(sample_values <= 1)) # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) / @@ -255,6 +273,16 @@ class BernoulliTest(test.TestCase): dist = bernoulli.Bernoulli(np.log([.2, .4])) self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list()) + @test_util.run_in_graph_and_eager_modes + def testNotReparameterized(self): + p = constant_op.constant([0.2, 0.6]) + with backprop.GradientTape() as tape: + tape.watch(p) + dist = bernoulli.Bernoulli(probs=p) + samples = dist.sample(100) + grad_p = tape.gradient(samples, p) + self.assertIsNone(grad_p) + def testSampleActsLikeSampleN(self): with self.test_session() as sess: p = [0.2, 0.6] @@ -262,51 +290,54 @@ class BernoulliTest(test.TestCase): n = 1000 seed = 42 self.assertAllEqual( - dist.sample(n, seed).eval(), dist.sample(n, seed).eval()) + self.evaluate(dist.sample(n, seed)), + self.evaluate(dist.sample(n, seed))) n = array_ops.placeholder(dtypes.int32) - sample, sample = sess.run([dist.sample(n, seed), dist.sample(n, seed)], - feed_dict={n: 1000}) - self.assertAllEqual(sample, sample) + sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)], + feed_dict={n: 1000}) + self.assertAllEqual(sample1, sample2) + @test_util.run_in_graph_and_eager_modes def testMean(self): with self.test_session(): p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32) dist = bernoulli.Bernoulli(probs=p) - self.assertAllEqual(dist.mean().eval(), p) + self.assertAllEqual(self.evaluate(dist.mean()), p) + @test_util.run_in_graph_and_eager_modes def testVarianceAndStd(self): var = lambda p: p * (1. - p) with self.test_session(): p = [[0.2, 0.7], [0.5, 0.4]] dist = bernoulli.Bernoulli(probs=p) self.assertAllClose( - dist.variance().eval(), + self.evaluate(dist.variance()), np.array( [[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32)) self.assertAllClose( - dist.stddev().eval(), + self.evaluate(dist.stddev()), np.array( [[np.sqrt(var(0.2)), np.sqrt(var(0.7))], [np.sqrt(var(0.5)), np.sqrt(var(0.4))]], dtype=np.float32)) + @test_util.run_in_graph_and_eager_modes def testBernoulliBernoulliKL(self): - with self.test_session() as sess: - batch_size = 6 - a_p = np.array([0.5] * batch_size, dtype=np.float32) - b_p = np.array([0.4] * batch_size, dtype=np.float32) + batch_size = 6 + a_p = np.array([0.5] * batch_size, dtype=np.float32) + b_p = np.array([0.4] * batch_size, dtype=np.float32) - a = bernoulli.Bernoulli(probs=a_p) - b = bernoulli.Bernoulli(probs=b_p) + a = bernoulli.Bernoulli(probs=a_p) + b = bernoulli.Bernoulli(probs=b_p) - kl = kullback_leibler.kl_divergence(a, b) - kl_val = sess.run(kl) + kl = kullback_leibler.kl_divergence(a, b) + kl_val = self.evaluate(kl) - kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log( - (1. - a_p) / (1. - b_p))) + kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log( + (1. - a_p) / (1. - b_p))) - self.assertEqual(kl.get_shape(), (batch_size,)) - self.assertAllClose(kl_val, kl_expected) + self.assertEqual(kl.get_shape(), (batch_size,)) + self.assertAllClose(kl_val, kl_expected) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py index ab5041a6eb477ce231acbd1e6041c354ee17409b..36f3ffc333f74e3f6e672b6ba1591bf8de08a010 100644 --- a/tensorflow/python/kernel_tests/distributions/beta_test.py +++ b/tensorflow/python/kernel_tests/distributions/beta_test.py @@ -21,9 +21,11 @@ import importlib import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import beta as beta_lib @@ -45,6 +47,7 @@ special = try_import("scipy.special") stats = try_import("scipy.stats") +@test_util.run_all_in_graph_and_eager_modes class BetaTest(test.TestCase): def testSimpleShapes(self): @@ -52,8 +55,8 @@ class BetaTest(test.TestCase): a = np.random.rand(3) b = np.random.rand(3) dist = beta_lib.Beta(a, b) - self.assertAllEqual([], dist.event_shape_tensor().eval()) - self.assertAllEqual([3], dist.batch_shape_tensor().eval()) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor())) self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) @@ -62,8 +65,8 @@ class BetaTest(test.TestCase): a = np.random.rand(3, 2, 2) b = np.random.rand(3, 2, 2) dist = beta_lib.Beta(a, b) - self.assertAllEqual([], dist.event_shape_tensor().eval()) - self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval()) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) self.assertEqual( tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) @@ -73,8 +76,8 @@ class BetaTest(test.TestCase): a = np.random.rand(3, 2, 2) b = np.random.rand(2, 2) dist = beta_lib.Beta(a, b) - self.assertAllEqual([], dist.event_shape_tensor().eval()) - self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval()) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) self.assertEqual( tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) @@ -85,7 +88,7 @@ class BetaTest(test.TestCase): with self.test_session(): dist = beta_lib.Beta(a, b) self.assertEqual([1, 3], dist.concentration1.get_shape()) - self.assertAllClose(a, dist.concentration1.eval()) + self.assertAllClose(a, self.evaluate(dist.concentration1)) def testBetaProperty(self): a = [[1., 2, 3]] @@ -93,24 +96,24 @@ class BetaTest(test.TestCase): with self.test_session(): dist = beta_lib.Beta(a, b) self.assertEqual([1, 3], dist.concentration0.get_shape()) - self.assertAllClose(b, dist.concentration0.eval()) + self.assertAllClose(b, self.evaluate(dist.concentration0)) def testPdfXProper(self): a = [[1., 2, 3]] b = [[2., 4, 3]] with self.test_session(): dist = beta_lib.Beta(a, b, validate_args=True) - dist.prob([.1, .3, .6]).eval() - dist.prob([.2, .3, .5]).eval() + self.evaluate(dist.prob([.1, .3, .6])) + self.evaluate(dist.prob([.2, .3, .5])) # Either condition can trigger. with self.assertRaisesOpError("sample must be positive"): - dist.prob([-1., 0.1, 0.5]).eval() + self.evaluate(dist.prob([-1., 0.1, 0.5])) with self.assertRaisesOpError("sample must be positive"): - dist.prob([0., 0.1, 0.5]).eval() + self.evaluate(dist.prob([0., 0.1, 0.5])) with self.assertRaisesOpError("sample must be less than `1`"): - dist.prob([.1, .2, 1.2]).eval() + self.evaluate(dist.prob([.1, .2, 1.2])) with self.assertRaisesOpError("sample must be less than `1`"): - dist.prob([.1, .2, 1.0]).eval() + self.evaluate(dist.prob([.1, .2, 1.0])) def testPdfTwoBatches(self): with self.test_session(): @@ -119,7 +122,7 @@ class BetaTest(test.TestCase): x = [.5, .5] dist = beta_lib.Beta(a, b) pdf = dist.prob(x) - self.assertAllClose([1., 3. / 2], pdf.eval()) + self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) self.assertEqual((2,), pdf.get_shape()) def testPdfTwoBatchesNontrivialX(self): @@ -129,7 +132,7 @@ class BetaTest(test.TestCase): x = [.3, .7] dist = beta_lib.Beta(a, b) pdf = dist.prob(x) - self.assertAllClose([1, 63. / 50], pdf.eval()) + self.assertAllClose([1, 63. / 50], self.evaluate(pdf)) self.assertEqual((2,), pdf.get_shape()) def testPdfUniformZeroBatch(self): @@ -140,7 +143,7 @@ class BetaTest(test.TestCase): x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) dist = beta_lib.Beta(a, b) pdf = dist.prob(x) - self.assertAllClose([1.] * 5, pdf.eval()) + self.assertAllClose([1.] * 5, self.evaluate(pdf)) self.assertEqual((5,), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenSameRank(self): @@ -150,7 +153,7 @@ class BetaTest(test.TestCase): x = [[.5, .5], [.3, .7]] dist = beta_lib.Beta(a, b) pdf = dist.prob(x) - self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], pdf.eval()) + self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf)) self.assertEqual((2, 2), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): @@ -159,7 +162,7 @@ class BetaTest(test.TestCase): b = [1., 2] x = [[.5, .5], [.2, .8]] pdf = beta_lib.Beta(a, b).prob(x) - self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], pdf.eval()) + self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf)) self.assertEqual((2, 2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenSameRank(self): @@ -168,7 +171,7 @@ class BetaTest(test.TestCase): b = [[1., 2], [2., 3]] x = [[.5, .5]] pdf = beta_lib.Beta(a, b).prob(x) - self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], pdf.eval()) + self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) self.assertEqual((2, 2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenLowerRank(self): @@ -177,7 +180,7 @@ class BetaTest(test.TestCase): b = [[1., 2], [2., 3]] x = [.5, .5] pdf = beta_lib.Beta(a, b).prob(x) - self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], pdf.eval()) + self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) self.assertEqual((2, 2), pdf.get_shape()) def testBetaMean(self): @@ -189,7 +192,7 @@ class BetaTest(test.TestCase): if not stats: return expected_mean = stats.beta.mean(a, b) - self.assertAllClose(expected_mean, dist.mean().eval()) + self.assertAllClose(expected_mean, self.evaluate(dist.mean())) def testBetaVariance(self): with session.Session(): @@ -200,7 +203,7 @@ class BetaTest(test.TestCase): if not stats: return expected_variance = stats.beta.var(a, b) - self.assertAllClose(expected_variance, dist.variance().eval()) + self.assertAllClose(expected_variance, self.evaluate(dist.variance())) def testBetaMode(self): with session.Session(): @@ -209,7 +212,7 @@ class BetaTest(test.TestCase): expected_mode = (a - 1) / (a + b - 2) dist = beta_lib.Beta(a, b) self.assertEqual(dist.mode().get_shape(), (3,)) - self.assertAllClose(expected_mode, dist.mode().eval()) + self.assertAllClose(expected_mode, self.evaluate(dist.mode())) def testBetaModeInvalid(self): with session.Session(): @@ -217,13 +220,13 @@ class BetaTest(test.TestCase): b = np.array([2., 4, 1.2]) dist = beta_lib.Beta(a, b, allow_nan_stats=False) with self.assertRaisesOpError("Condition x < y.*"): - dist.mode().eval() + self.evaluate(dist.mode()) a = np.array([2., 2, 3]) b = np.array([1., 4, 1.2]) dist = beta_lib.Beta(a, b, allow_nan_stats=False) with self.assertRaisesOpError("Condition x < y.*"): - dist.mode().eval() + self.evaluate(dist.mode()) def testBetaModeEnableAllowNanStats(self): with session.Session(): @@ -234,7 +237,7 @@ class BetaTest(test.TestCase): expected_mode = (a - 1) / (a + b - 2) expected_mode[0] = np.nan self.assertEqual((3,), dist.mode().get_shape()) - self.assertAllClose(expected_mode, dist.mode().eval()) + self.assertAllClose(expected_mode, self.evaluate(dist.mode())) a = np.array([2., 2, 3]) b = np.array([1., 4, 1.2]) @@ -243,7 +246,7 @@ class BetaTest(test.TestCase): expected_mode = (a - 1) / (a + b - 2) expected_mode[0] = np.nan self.assertEqual((3,), dist.mode().get_shape()) - self.assertAllClose(expected_mode, dist.mode().eval()) + self.assertAllClose(expected_mode, self.evaluate(dist.mode())) def testBetaEntropy(self): with session.Session(): @@ -254,7 +257,7 @@ class BetaTest(test.TestCase): if not stats: return expected_entropy = stats.beta.entropy(a, b) - self.assertAllClose(expected_entropy, dist.entropy().eval()) + self.assertAllClose(expected_entropy, self.evaluate(dist.entropy())) def testBetaSample(self): with self.test_session(): @@ -263,7 +266,7 @@ class BetaTest(test.TestCase): beta = beta_lib.Beta(a, b) n = constant_op.constant(100000) samples = beta.sample(n) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(sample_values.shape, (100000,)) self.assertFalse(np.any(sample_values < 0.0)) if not stats: @@ -280,6 +283,18 @@ class BetaTest(test.TestCase): self.assertAllClose( np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1) + def testBetaFullyReparameterized(self): + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + with backprop.GradientTape() as tape: + tape.watch(a) + tape.watch(b) + beta = beta_lib.Beta(a, b) + samples = beta.sample(100) + grad_a, grad_b = tape.gradient(samples, [a, b]) + self.assertIsNotNone(grad_a) + self.assertIsNotNone(grad_b) + # Test that sampling with the same seed twice gives the same results. def testBetaSampleMultipleTimes(self): with self.test_session(): @@ -291,13 +306,13 @@ class BetaTest(test.TestCase): beta1 = beta_lib.Beta(concentration1=a_val, concentration0=b_val, name="beta1") - samples1 = beta1.sample(n_val, seed=123456).eval() + samples1 = self.evaluate(beta1.sample(n_val, seed=123456)) random_seed.set_random_seed(654321) beta2 = beta_lib.Beta(concentration1=a_val, concentration0=b_val, name="beta2") - samples2 = beta2.sample(n_val, seed=123456).eval() + samples2 = self.evaluate(beta2.sample(n_val, seed=123456)) self.assertAllClose(samples1, samples2) @@ -308,7 +323,7 @@ class BetaTest(test.TestCase): beta = beta_lib.Beta(a, b) n = constant_op.constant(100000) samples = beta.sample(n) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) self.assertFalse(np.any(sample_values < 0.0)) if not stats: @@ -325,7 +340,7 @@ class BetaTest(test.TestCase): a = 10. * np.random.random(shape).astype(dt) b = 10. * np.random.random(shape).astype(dt) x = np.random.random(shape).astype(dt) - actual = beta_lib.Beta(a, b).cdf(x).eval() + actual = self.evaluate(beta_lib.Beta(a, b).cdf(x)) self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) if not stats: @@ -339,7 +354,7 @@ class BetaTest(test.TestCase): a = 10. * np.random.random(shape).astype(dt) b = 10. * np.random.random(shape).astype(dt) x = np.random.random(shape).astype(dt) - actual = math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)).eval() + actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x))) self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) if not stats: @@ -350,46 +365,47 @@ class BetaTest(test.TestCase): with self.test_session(): a, b = -4.2, -9.1 dist = beta_lib.BetaWithSoftplusConcentration(a, b) - self.assertAllClose(nn_ops.softplus(a).eval(), dist.concentration1.eval()) - self.assertAllClose(nn_ops.softplus(b).eval(), dist.concentration0.eval()) + self.assertAllClose( + self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1)) + self.assertAllClose( + self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0)) def testBetaBetaKL(self): - with self.test_session() as sess: - for shape in [(10,), (4, 5)]: - a1 = 6.0 * np.random.random(size=shape) + 1e-4 - b1 = 6.0 * np.random.random(size=shape) + 1e-4 - a2 = 6.0 * np.random.random(size=shape) + 1e-4 - b2 = 6.0 * np.random.random(size=shape) + 1e-4 - # Take inverse softplus of values to test BetaWithSoftplusConcentration - a1_sp = np.log(np.exp(a1) - 1.0) - b1_sp = np.log(np.exp(b1) - 1.0) - a2_sp = np.log(np.exp(a2) - 1.0) - b2_sp = np.log(np.exp(b2) - 1.0) - - d1 = beta_lib.Beta(concentration1=a1, concentration0=b1) - d2 = beta_lib.Beta(concentration1=a2, concentration0=b2) - d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp, - concentration0=b1_sp) - d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp, - concentration0=b2_sp) - - if not special: - return - kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) + - (a1 - a2) * special.digamma(a1) + - (b1 - b2) * special.digamma(b1) + - (a2 - a1 + b2 - b1) * special.digamma(a1 + b1)) - - for dist1 in [d1, d1_sp]: - for dist2 in [d2, d2_sp]: - kl = kullback_leibler.kl_divergence(dist1, dist2) - kl_val = sess.run(kl) - self.assertEqual(kl.get_shape(), shape) - self.assertAllClose(kl_val, kl_expected) - - # Make sure KL(d1||d1) is 0 - kl_same = sess.run(kullback_leibler.kl_divergence(d1, d1)) - self.assertAllClose(kl_same, np.zeros_like(kl_expected)) + for shape in [(10,), (4, 5)]: + a1 = 6.0 * np.random.random(size=shape) + 1e-4 + b1 = 6.0 * np.random.random(size=shape) + 1e-4 + a2 = 6.0 * np.random.random(size=shape) + 1e-4 + b2 = 6.0 * np.random.random(size=shape) + 1e-4 + # Take inverse softplus of values to test BetaWithSoftplusConcentration + a1_sp = np.log(np.exp(a1) - 1.0) + b1_sp = np.log(np.exp(b1) - 1.0) + a2_sp = np.log(np.exp(a2) - 1.0) + b2_sp = np.log(np.exp(b2) - 1.0) + + d1 = beta_lib.Beta(concentration1=a1, concentration0=b1) + d2 = beta_lib.Beta(concentration1=a2, concentration0=b2) + d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp, + concentration0=b1_sp) + d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp, + concentration0=b2_sp) + + if not special: + return + kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) + + (a1 - a2) * special.digamma(a1) + + (b1 - b2) * special.digamma(b1) + + (a2 - a1 + b2 - b1) * special.digamma(a1 + b1)) + + for dist1 in [d1, d1_sp]: + for dist2 in [d2, d2_sp]: + kl = kullback_leibler.kl_divergence(dist1, dist2) + kl_val = self.evaluate(kl) + self.assertEqual(kl.get_shape(), shape) + self.assertAllClose(kl_val, kl_expected) + + # Make sure KL(d1||d1) is 0 + kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1)) + self.assertAllClose(kl_same, np.zeros_like(kl_expected)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py index 33db014279de2625380ec367b3fc5a96b5f9c4d6..8b11556330acc7dab68715ddc69563107a313ee6 100644 --- a/tensorflow/python/kernel_tests/distributions/bijector_test.py +++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py @@ -24,12 +24,14 @@ import numpy as np import six from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class BaseBijectorTest(test.TestCase): """Tests properties of the Bijector base-class.""" @@ -47,42 +49,38 @@ class BaseBijectorTest(test.TestCase): def __init__(self): super(_BareBonesBijector, self).__init__(forward_min_event_ndims=0) - with self.test_session() as sess: - bij = _BareBonesBijector() - self.assertEqual([], bij.graph_parents) - self.assertEqual(False, bij.is_constant_jacobian) - self.assertEqual(False, bij.validate_args) - self.assertEqual(None, bij.dtype) - self.assertEqual("bare_bones_bijector", bij.name) - - for shape in [[], [1, 2], [1, 2, 3]]: - [ - forward_event_shape_, - inverse_event_shape_, - ] = sess.run([ - bij.inverse_event_shape_tensor(shape), - bij.forward_event_shape_tensor(shape), - ]) - self.assertAllEqual(shape, forward_event_shape_) - self.assertAllEqual(shape, bij.forward_event_shape(shape)) - self.assertAllEqual(shape, inverse_event_shape_) - self.assertAllEqual(shape, bij.inverse_event_shape(shape)) - - with self.assertRaisesRegexp( - NotImplementedError, "inverse not implemented"): - bij.inverse(0) - - with self.assertRaisesRegexp( - NotImplementedError, "forward not implemented"): - bij.forward(0) - - with self.assertRaisesRegexp( - NotImplementedError, "inverse_log_det_jacobian not implemented"): - bij.inverse_log_det_jacobian(0, event_ndims=0) - - with self.assertRaisesRegexp( - NotImplementedError, "forward_log_det_jacobian not implemented"): - bij.forward_log_det_jacobian(0, event_ndims=0) + bij = _BareBonesBijector() + self.assertEqual([], bij.graph_parents) + self.assertEqual(False, bij.is_constant_jacobian) + self.assertEqual(False, bij.validate_args) + self.assertEqual(None, bij.dtype) + self.assertEqual("bare_bones_bijector", bij.name) + + for shape in [[], [1, 2], [1, 2, 3]]: + forward_event_shape_ = self.evaluate( + bij.inverse_event_shape_tensor(shape)) + inverse_event_shape_ = self.evaluate( + bij.forward_event_shape_tensor(shape)) + self.assertAllEqual(shape, forward_event_shape_) + self.assertAllEqual(shape, bij.forward_event_shape(shape)) + self.assertAllEqual(shape, inverse_event_shape_) + self.assertAllEqual(shape, bij.inverse_event_shape(shape)) + + with self.assertRaisesRegexp( + NotImplementedError, "inverse not implemented"): + bij.inverse(0) + + with self.assertRaisesRegexp( + NotImplementedError, "forward not implemented"): + bij.forward(0) + + with self.assertRaisesRegexp( + NotImplementedError, "inverse_log_det_jacobian not implemented"): + bij.inverse_log_det_jacobian(0, event_ndims=0) + + with self.assertRaisesRegexp( + NotImplementedError, "forward_log_det_jacobian not implemented"): + bij.forward_log_det_jacobian(0, event_ndims=0) class IntentionallyMissingError(Exception): @@ -92,9 +90,10 @@ class IntentionallyMissingError(Exception): class BrokenBijector(bijector.Bijector): """Forward and inverse are not inverses of each other.""" - def __init__(self, forward_missing=False, inverse_missing=False): + def __init__( + self, forward_missing=False, inverse_missing=False, validate_args=False): super(BrokenBijector, self).__init__( - validate_args=False, forward_min_event_ndims=0, name="broken") + validate_args=validate_args, forward_min_event_ndims=0, name="broken") self._forward_missing = forward_missing self._inverse_missing = inverse_missing @@ -118,6 +117,33 @@ class BrokenBijector(bijector.Bijector): raise IntentionallyMissingError return math_ops.log(2.) +class BijectorTestEventNdims(test.TestCase): + + def testBijectorNonIntegerEventNdims(self): + bij = BrokenBijector() + with self.assertRaisesRegexp(ValueError, "Expected integer"): + bij.forward_log_det_jacobian(1., event_ndims=1.5) + with self.assertRaisesRegexp(ValueError, "Expected integer"): + bij.inverse_log_det_jacobian(1., event_ndims=1.5) + + def testBijectorArrayEventNdims(self): + bij = BrokenBijector() + with self.assertRaisesRegexp(ValueError, "Expected scalar"): + bij.forward_log_det_jacobian(1., event_ndims=(1, 2)) + with self.assertRaisesRegexp(ValueError, "Expected scalar"): + bij.inverse_log_det_jacobian(1., event_ndims=(1, 2)) + + def testBijectorDynamicEventNdims(self): + bij = BrokenBijector(validate_args=True) + event_ndims = array_ops.placeholder(dtype=np.int32, shape=None) + with self.test_session(): + with self.assertRaisesOpError("Expected scalar"): + bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({ + event_ndims: (1, 2)}) + with self.assertRaisesOpError("Expected scalar"): + bij.inverse_log_det_jacobian(1., event_ndims=event_ndims).eval({ + event_ndims: (1, 2)}) + @six.add_metaclass(abc.ABCMeta) class BijectorCachingTestBase(object): diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py index ca2358fe99934e110ba743c6085d1f25ff0f5e5e..d8939433ce68ffa561e8e2200826f88dbe283ac2 100644 --- a/tensorflow/python/kernel_tests/distributions/categorical_test.py +++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py @@ -18,8 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util @@ -40,7 +42,7 @@ def make_categorical(batch_shape, num_classes, dtype=dtypes.int32): return categorical.Categorical(logits, dtype=dtype) -class CategoricalTest(test.TestCase): +class CategoricalTest(test.TestCase, parameterized.TestCase): def testP(self): p = [0.2, 0.8] @@ -131,7 +133,7 @@ class CategoricalTest(test.TestCase): with self.test_session(): self.assertAllClose(dist.prob(0).eval(), 0.2) - def testCDFWithDynamicEventShape(self): + def testCDFWithDynamicEventShapeKnownNdims(self): """Test that dynamically-sized events with unknown shape work.""" batch_size = 2 histograms = array_ops.placeholder(dtype=dtypes.float32, @@ -167,6 +169,21 @@ class CategoricalTest(test.TestCase): self.assertAllClose(actual_cdf_one, expected_cdf_one) self.assertAllClose(actual_cdf_two, expected_cdf_two) + @parameterized.named_parameters( + ("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]), + ("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88])) + def testCDFWithDynamicEventShapeUnknownNdims( + self, events, histograms, expected_cdf): + """Test that dynamically-sized events with unknown shape work.""" + event_ph = array_ops.placeholder_with_default(events, shape=None) + histograms_ph = array_ops.placeholder_with_default(histograms, shape=None) + dist = categorical.Categorical(probs=histograms_ph) + cdf_op = dist.cdf(event_ph) + + actual_cdf = self.evaluate(cdf_op) + self.assertAllClose(actual_cdf, expected_cdf) + def testCDFWithBatch(self): histograms = [[0.1, 0.2, 0.3, 0.25, 0.15], [0.0, 0.75, 0.2, 0.05, 0.0]] @@ -360,6 +377,15 @@ class CategoricalTest(test.TestCase): self.assertAllClose( [0.4**2 + 0.6**2], [prob_val[:, :, :, 1].mean()], atol=1e-2) + def testNotReparameterized(self): + p = constant_op.constant([0.3, 0.3, 0.4]) + with backprop.GradientTape() as tape: + tape.watch(p) + dist = categorical.Categorical(p) + samples = dist.sample(100) + grad_p = tape.gradient(samples, p) + self.assertIsNone(grad_p) + def testLogPMFBroadcasting(self): with self.test_session(): # 1 x 2 x 2 diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py index 7922fb0606c6f4b475b25da716d5f9a169e213b5..1b9edcc85a7581de1cb1bd93fdbb9d47b8d1b84a 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py @@ -17,6 +17,9 @@ from __future__ import division from __future__ import print_function import numpy as np + +from tensorflow.python.eager import backprop +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -250,10 +253,10 @@ class DirichletMultinomialTest(test.TestCase): dist.variance(), dist.stddev(), ]) - self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04) - self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.05) - self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.05) - self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02) + self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.) + self.assertAllClose(sample_cov_, analytic_cov, atol=0.05, rtol=0.) + self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.) + self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) def testCovariance(self): # Shape [2] @@ -442,7 +445,7 @@ class DirichletMultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4, 3, 2], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.15) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20) self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape()) self.assertAllClose( actual_covariance_, sample_covariance_, atol=0., rtol=0.20) @@ -470,10 +473,25 @@ class DirichletMultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.05) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20) self.assertAllEqual([4, 4], sample_covariance.get_shape()) self.assertAllClose( - actual_covariance_, sample_covariance_, atol=0., rtol=0.15) + actual_covariance_, sample_covariance_, atol=0., rtol=0.20) + + def testNotReparameterized(self): + total_count = constant_op.constant(5.0) + concentration = constant_op.constant([0.1, 0.1, 0.1]) + with backprop.GradientTape() as tape: + tape.watch(total_count) + tape.watch(concentration) + dist = ds.DirichletMultinomial( + total_count=total_count, + concentration=concentration) + samples = dist.sample(100) + grad_total_count, grad_concentration = tape.gradient( + samples, [total_count, concentration]) + self.assertIsNone(grad_total_count) + self.assertIsNone(grad_concentration) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py index a2f1de5aaf3a75c1cfac820cc4494af34d082250..67ed0447ede39d7f0738c8caf3cc665bcfe5fd0b 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py @@ -20,11 +20,14 @@ import importlib import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import dirichlet as dirichlet_lib +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -38,17 +41,19 @@ def try_import(name): # pylint: disable=invalid-name return module +special = try_import("scipy.special") stats = try_import("scipy.stats") +@test_util.run_all_in_graph_and_eager_modes class DirichletTest(test.TestCase): def testSimpleShapes(self): with self.test_session(): alpha = np.random.rand(3) dist = dirichlet_lib.Dirichlet(alpha) - self.assertEqual(3, dist.event_shape_tensor().eval()) - self.assertAllEqual([], dist.batch_shape_tensor().eval()) + self.assertEqual(3, self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor())) self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape) self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape) @@ -56,8 +61,8 @@ class DirichletTest(test.TestCase): with self.test_session(): alpha = np.random.rand(3, 2, 2) dist = dirichlet_lib.Dirichlet(alpha) - self.assertEqual(2, dist.event_shape_tensor().eval()) - self.assertAllEqual([3, 2], dist.batch_shape_tensor().eval()) + self.assertEqual(2, self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor())) self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape) self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape) @@ -66,22 +71,22 @@ class DirichletTest(test.TestCase): with self.test_session(): dist = dirichlet_lib.Dirichlet(alpha) self.assertEqual([1, 3], dist.concentration.get_shape()) - self.assertAllClose(alpha, dist.concentration.eval()) + self.assertAllClose(alpha, self.evaluate(dist.concentration)) def testPdfXProper(self): alpha = [[1., 2, 3]] with self.test_session(): dist = dirichlet_lib.Dirichlet(alpha, validate_args=True) - dist.prob([.1, .3, .6]).eval() - dist.prob([.2, .3, .5]).eval() + self.evaluate(dist.prob([.1, .3, .6])) + self.evaluate(dist.prob([.2, .3, .5])) # Either condition can trigger. with self.assertRaisesOpError("samples must be positive"): - dist.prob([-1., 1.5, 0.5]).eval() + self.evaluate(dist.prob([-1., 1.5, 0.5])) with self.assertRaisesOpError("samples must be positive"): - dist.prob([0., .1, .9]).eval() + self.evaluate(dist.prob([0., .1, .9])) with self.assertRaisesOpError( "sample last-dimension must sum to `1`"): - dist.prob([.1, .2, .8]).eval() + self.evaluate(dist.prob([.1, .2, .8])) def testPdfZeroBatches(self): with self.test_session(): @@ -89,7 +94,7 @@ class DirichletTest(test.TestCase): x = [.5, .5] dist = dirichlet_lib.Dirichlet(alpha) pdf = dist.prob(x) - self.assertAllClose(1., pdf.eval()) + self.assertAllClose(1., self.evaluate(pdf)) self.assertEqual((), pdf.get_shape()) def testPdfZeroBatchesNontrivialX(self): @@ -98,7 +103,7 @@ class DirichletTest(test.TestCase): x = [.3, .7] dist = dirichlet_lib.Dirichlet(alpha) pdf = dist.prob(x) - self.assertAllClose(7. / 5, pdf.eval()) + self.assertAllClose(7. / 5, self.evaluate(pdf)) self.assertEqual((), pdf.get_shape()) def testPdfUniformZeroBatches(self): @@ -108,7 +113,7 @@ class DirichletTest(test.TestCase): x = [[.2, .5, .3], [.3, .4, .3]] dist = dirichlet_lib.Dirichlet(alpha) pdf = dist.prob(x) - self.assertAllClose([2., 2.], pdf.eval()) + self.assertAllClose([2., 2.], self.evaluate(pdf)) self.assertEqual((2), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenSameRank(self): @@ -117,7 +122,7 @@ class DirichletTest(test.TestCase): x = [[.5, .5], [.3, .7]] dist = dirichlet_lib.Dirichlet(alpha) pdf = dist.prob(x) - self.assertAllClose([1., 7. / 5], pdf.eval()) + self.assertAllClose([1., 7. / 5], self.evaluate(pdf)) self.assertEqual((2), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): @@ -125,7 +130,7 @@ class DirichletTest(test.TestCase): alpha = [1., 2] x = [[.5, .5], [.2, .8]] pdf = dirichlet_lib.Dirichlet(alpha).prob(x) - self.assertAllClose([1., 8. / 5], pdf.eval()) + self.assertAllClose([1., 8. / 5], self.evaluate(pdf)) self.assertEqual((2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenSameRank(self): @@ -133,7 +138,7 @@ class DirichletTest(test.TestCase): alpha = [[1., 2], [2., 3]] x = [[.5, .5]] pdf = dirichlet_lib.Dirichlet(alpha).prob(x) - self.assertAllClose([1., 3. / 2], pdf.eval()) + self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) self.assertEqual((2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenLowerRank(self): @@ -141,7 +146,7 @@ class DirichletTest(test.TestCase): alpha = [[1., 2], [2., 3]] x = [.5, .5] pdf = dirichlet_lib.Dirichlet(alpha).prob(x) - self.assertAllClose([1., 3. / 2], pdf.eval()) + self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) self.assertEqual((2), pdf.get_shape()) def testMean(self): @@ -152,43 +157,44 @@ class DirichletTest(test.TestCase): if not stats: return expected_mean = stats.dirichlet.mean(alpha) - self.assertAllClose(dirichlet.mean().eval(), expected_mean) + self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean) def testCovarianceFromSampling(self): alpha = np.array([[1., 2, 3], [2.5, 4, 0.01]], dtype=np.float32) - with self.test_session() as sess: - dist = dirichlet_lib.Dirichlet(alpha) # batch_shape=[2], event_shape=[3] - x = dist.sample(int(250e3), seed=1) - sample_mean = math_ops.reduce_mean(x, 0) - x_centered = x - sample_mean[None, ...] - sample_cov = math_ops.reduce_mean(math_ops.matmul( - x_centered[..., None], x_centered[..., None, :]), 0) - sample_var = array_ops.matrix_diag_part(sample_cov) - sample_stddev = math_ops.sqrt(sample_var) - [ - sample_mean_, - sample_cov_, - sample_var_, - sample_stddev_, - analytic_mean, - analytic_cov, - analytic_var, - analytic_stddev, - ] = sess.run([ - sample_mean, - sample_cov, - sample_var, - sample_stddev, - dist.mean(), - dist.covariance(), - dist.variance(), - dist.stddev(), - ]) - self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04) - self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.06) - self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03) - self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02) + dist = dirichlet_lib.Dirichlet(alpha) # batch_shape=[2], event_shape=[3] + x = dist.sample(int(250e3), seed=1) + sample_mean = math_ops.reduce_mean(x, 0) + x_centered = x - sample_mean[None, ...] + sample_cov = math_ops.reduce_mean(math_ops.matmul( + x_centered[..., None], x_centered[..., None, :]), 0) + sample_var = array_ops.matrix_diag_part(sample_cov) + sample_stddev = math_ops.sqrt(sample_var) + + [ + sample_mean_, + sample_cov_, + sample_var_, + sample_stddev_, + analytic_mean, + analytic_cov, + analytic_var, + analytic_stddev, + ] = self.evaluate([ + sample_mean, + sample_cov, + sample_var, + sample_stddev, + dist.mean(), + dist.covariance(), + dist.variance(), + dist.stddev(), + ]) + + self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.) + self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.) + self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.) + self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) def testVariance(self): with self.test_session(): @@ -201,7 +207,8 @@ class DirichletTest(test.TestCase): expected_covariance = np.diag(stats.dirichlet.var(alpha)) expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0]] / denominator - self.assertAllClose(dirichlet.covariance().eval(), expected_covariance) + self.assertAllClose( + self.evaluate(dirichlet.covariance()), expected_covariance) def testMode(self): with self.test_session(): @@ -209,7 +216,7 @@ class DirichletTest(test.TestCase): expected_mode = (alpha - 1) / (np.sum(alpha) - 3) dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) self.assertEqual(dirichlet.mode().get_shape(), [3]) - self.assertAllClose(dirichlet.mode().eval(), expected_mode) + self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) def testModeInvalid(self): with self.test_session(): @@ -217,7 +224,7 @@ class DirichletTest(test.TestCase): dirichlet = dirichlet_lib.Dirichlet(concentration=alpha, allow_nan_stats=False) with self.assertRaisesOpError("Condition x < y.*"): - dirichlet.mode().eval() + self.evaluate(dirichlet.mode()) def testModeEnableAllowNanStats(self): with self.test_session(): @@ -227,7 +234,7 @@ class DirichletTest(test.TestCase): expected_mode = np.zeros_like(alpha) + np.nan self.assertEqual(dirichlet.mode().get_shape(), [3]) - self.assertAllClose(dirichlet.mode().eval(), expected_mode) + self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) def testEntropy(self): with self.test_session(): @@ -237,7 +244,7 @@ class DirichletTest(test.TestCase): if not stats: return expected_entropy = stats.dirichlet.entropy(alpha) - self.assertAllClose(dirichlet.entropy().eval(), expected_entropy) + self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy) def testSample(self): with self.test_session(): @@ -245,7 +252,7 @@ class DirichletTest(test.TestCase): dirichlet = dirichlet_lib.Dirichlet(alpha) n = constant_op.constant(100000) samples = dirichlet.sample(n) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(sample_values.shape, (100000, 2)) self.assertTrue(np.all(sample_values > 0.0)) if not stats: @@ -258,6 +265,48 @@ class DirichletTest(test.TestCase): a=1., b=2.).cdf)[0], 0.01) + def testDirichletFullyReparameterized(self): + alpha = constant_op.constant([1.0, 2.0, 3.0]) + with backprop.GradientTape() as tape: + tape.watch(alpha) + dirichlet = dirichlet_lib.Dirichlet(alpha) + samples = dirichlet.sample(100) + grad_alpha = tape.gradient(samples, alpha) + self.assertIsNotNone(grad_alpha) + + def testDirichletDirichletKL(self): + conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5], + [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]]) + conc2 = np.array([[0.5, 1., 1.5, 2., 2.5, 3.]]) + + d1 = dirichlet_lib.Dirichlet(conc1) + d2 = dirichlet_lib.Dirichlet(conc2) + x = d1.sample(int(1e4), seed=0) + kl_sample = math_ops.reduce_mean(d1.log_prob(x) - d2.log_prob(x), 0) + kl_actual = kullback_leibler.kl_divergence(d1, d2) + + kl_sample_val = self.evaluate(kl_sample) + kl_actual_val = self.evaluate(kl_actual) + + self.assertEqual(conc1.shape[:-1], kl_actual.get_shape()) + + if not special: + return + + kl_expected = ( + special.gammaln(np.sum(conc1, -1)) + - special.gammaln(np.sum(conc2, -1)) + - np.sum(special.gammaln(conc1) - special.gammaln(conc2), -1) + + np.sum((conc1 - conc2) * (special.digamma(conc1) - special.digamma( + np.sum(conc1, -1, keepdims=True))), -1)) + + self.assertAllClose(kl_expected, kl_actual_val, atol=0., rtol=1e-6) + self.assertAllClose(kl_sample_val, kl_actual_val, atol=0., rtol=1e-1) + + # Make sure KL(d1||d1) is 0 + kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1)) + self.assertAllClose(kl_same, np.zeros_like(kl_expected)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py index 7afdf0f947605c6b982e8bf7defdd6224180e089..850da3e9697ab5f087761e9988094a3015636c36 100644 --- a/tensorflow/python/kernel_tests/distributions/exponential_test.py +++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py @@ -23,7 +23,9 @@ import importlib import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import exponential as exponential_lib from tensorflow.python.platform import test @@ -42,6 +44,7 @@ def try_import(name): # pylint: disable=invalid-name stats = try_import("scipy.stats") +@test_util.run_all_in_graph_and_eager_modes class ExponentialTest(test.TestCase): def testExponentialLogPDF(self): @@ -61,8 +64,8 @@ class ExponentialTest(test.TestCase): if not stats: return expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v) - self.assertAllClose(log_pdf.eval(), expected_log_pdf) - self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testExponentialCDF(self): with session.Session(): @@ -79,7 +82,7 @@ class ExponentialTest(test.TestCase): if not stats: return expected_cdf = stats.expon.cdf(x, scale=1 / lam_v) - self.assertAllClose(cdf.eval(), expected_cdf) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testExponentialMean(self): with session.Session(): @@ -89,7 +92,7 @@ class ExponentialTest(test.TestCase): if not stats: return expected_mean = stats.expon.mean(scale=1 / lam_v) - self.assertAllClose(exponential.mean().eval(), expected_mean) + self.assertAllClose(self.evaluate(exponential.mean()), expected_mean) def testExponentialVariance(self): with session.Session(): @@ -99,7 +102,8 @@ class ExponentialTest(test.TestCase): if not stats: return expected_variance = stats.expon.var(scale=1 / lam_v) - self.assertAllClose(exponential.variance().eval(), expected_variance) + self.assertAllClose( + self.evaluate(exponential.variance()), expected_variance) def testExponentialEntropy(self): with session.Session(): @@ -109,7 +113,8 @@ class ExponentialTest(test.TestCase): if not stats: return expected_entropy = stats.expon.entropy(scale=1 / lam_v) - self.assertAllClose(exponential.entropy().eval(), expected_entropy) + self.assertAllClose( + self.evaluate(exponential.entropy()), expected_entropy) def testExponentialSample(self): with self.test_session(): @@ -119,7 +124,7 @@ class ExponentialTest(test.TestCase): exponential = exponential_lib.Exponential(rate=lam) samples = exponential.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(sample_values.shape, (100000, 2)) self.assertFalse(np.any(sample_values < 0.0)) if not stats: @@ -142,7 +147,7 @@ class ExponentialTest(test.TestCase): samples = exponential.sample(n, seed=138) self.assertEqual(samples.get_shape(), (n, batch_size, 2)) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertFalse(np.any(sample_values < 0.0)) if not stats: @@ -159,12 +164,21 @@ class ExponentialTest(test.TestCase): stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) + def testFullyReparameterized(self): + lam = constant_op.constant([0.1, 1.0]) + with backprop.GradientTape() as tape: + tape.watch(lam) + exponential = exponential_lib.Exponential(rate=lam) + samples = exponential.sample(100) + grad_lam = tape.gradient(samples, lam) + self.assertIsNotNone(grad_lam) + def testExponentialWithSoftplusRate(self): with self.test_session(): lam = [-2.2, -3.4] exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam) - self.assertAllClose(nn_ops.softplus(lam).eval(), - exponential.rate.eval()) + self.assertAllClose( + self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py index 5e4813ac0762d2855d7fbe6754fe1466c29c06c9..297e20264c6d36f5b9098005393302337e3d1315 100644 --- a/tensorflow/python/kernel_tests/distributions/gamma_test.py +++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py @@ -21,9 +21,10 @@ import importlib import numpy as np -from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import gamma as gamma_lib @@ -45,6 +46,7 @@ special = try_import("scipy.special") stats = try_import("scipy.stats") +@test_util.run_all_in_graph_and_eager_modes class GammaTest(test.TestCase): def testGammaShape(self): @@ -53,9 +55,9 @@ class GammaTest(test.TestCase): beta = constant_op.constant(11.0) gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - self.assertEqual(gamma.batch_shape_tensor().eval(), (5,)) + self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,)) self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(gamma.event_shape_tensor().eval(), []) + self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), []) self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([])) def testGammaLogPDF(self): @@ -74,8 +76,8 @@ class GammaTest(test.TestCase): if not stats: return expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(log_pdf.eval(), expected_log_pdf) - self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testGammaLogPDFMultidimensional(self): with self.test_session(): @@ -87,10 +89,10 @@ class GammaTest(test.TestCase): x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) log_pdf = gamma.log_prob(x) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) self.assertEqual(log_pdf.get_shape(), (6, 2)) pdf = gamma.prob(x) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) self.assertEqual(pdf.get_shape(), (6, 2)) if not stats: return @@ -108,10 +110,10 @@ class GammaTest(test.TestCase): x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) log_pdf = gamma.log_prob(x) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) self.assertEqual(log_pdf.get_shape(), (6, 2)) pdf = gamma.prob(x) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) self.assertEqual(pdf.get_shape(), (6, 2)) if not stats: @@ -135,7 +137,7 @@ class GammaTest(test.TestCase): if not stats: return expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(cdf.eval(), expected_cdf) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testGammaMean(self): with self.test_session(): @@ -146,7 +148,7 @@ class GammaTest(test.TestCase): if not stats: return expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v) - self.assertAllClose(gamma.mean().eval(), expected_means) + self.assertAllClose(self.evaluate(gamma.mean()), expected_means) def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): with self.test_session(): @@ -155,7 +157,7 @@ class GammaTest(test.TestCase): gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) expected_modes = (alpha_v - 1) / beta_v self.assertEqual(gamma.mode().get_shape(), (3,)) - self.assertAllClose(gamma.mode().eval(), expected_modes) + self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): with self.test_session(): @@ -166,7 +168,7 @@ class GammaTest(test.TestCase): rate=beta_v, allow_nan_stats=False) with self.assertRaisesOpError("x < y"): - gamma.mode().eval() + self.evaluate(gamma.mode()) def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self): with self.test_session(): @@ -179,7 +181,7 @@ class GammaTest(test.TestCase): expected_modes = (alpha_v - 1) / beta_v expected_modes[0] = np.nan self.assertEqual(gamma.mode().get_shape(), (3,)) - self.assertAllClose(gamma.mode().eval(), expected_modes) + self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) def testGammaVariance(self): with self.test_session(): @@ -190,7 +192,7 @@ class GammaTest(test.TestCase): if not stats: return expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v) - self.assertAllClose(gamma.variance().eval(), expected_variances) + self.assertAllClose(self.evaluate(gamma.variance()), expected_variances) def testGammaStd(self): with self.test_session(): @@ -201,7 +203,7 @@ class GammaTest(test.TestCase): if not stats: return expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v) - self.assertAllClose(gamma.stddev().eval(), expected_stddev) + self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev) def testGammaEntropy(self): with self.test_session(): @@ -212,10 +214,10 @@ class GammaTest(test.TestCase): if not stats: return expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v) - self.assertAllClose(gamma.entropy().eval(), expected_entropy) + self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy) def testGammaSampleSmallAlpha(self): - with session.Session(): + with self.test_session(): alpha_v = 0.05 beta_v = 1.0 alpha = constant_op.constant(alpha_v) @@ -223,7 +225,7 @@ class GammaTest(test.TestCase): n = 100000 gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) samples = gamma.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n,)) self.assertEqual(sample_values.shape, (n,)) self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) @@ -240,7 +242,7 @@ class GammaTest(test.TestCase): atol=.15) def testGammaSample(self): - with session.Session(): + with self.test_session(): alpha_v = 4.0 beta_v = 3.0 alpha = constant_op.constant(alpha_v) @@ -248,7 +250,7 @@ class GammaTest(test.TestCase): n = 100000 gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) samples = gamma.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n,)) self.assertEqual(sample_values.shape, (n,)) self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) @@ -264,14 +266,26 @@ class GammaTest(test.TestCase): stats.gamma.var(alpha_v, scale=1 / beta_v), atol=.15) + def testGammaFullyReparameterized(self): + alpha = constant_op.constant(4.0) + beta = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(alpha) + tape.watch(beta) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + samples = gamma.sample(100) + grad_alpha, grad_beta = tape.gradient(samples, [alpha, beta]) + self.assertIsNotNone(grad_alpha) + self.assertIsNotNone(grad_beta) + def testGammaSampleMultiDimensional(self): - with session.Session(): + with self.test_session(): alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) n = 10000 samples = gamma.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n, 10, 100)) self.assertEqual(sample_values.shape, (n, 10, 100)) zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100 @@ -283,11 +297,11 @@ class GammaTest(test.TestCase): sample_values.mean(axis=0), stats.gamma.mean( alpha_bc, scale=1 / beta_bc), - rtol=.035) + atol=0., rtol=.05) self.assertAllClose( sample_values.var(axis=0), stats.gamma.var(alpha_bc, scale=1 / beta_bc), - atol=4.5) + atol=10.0, rtol=0.) fails = 0 trials = 0 for ai, a in enumerate(np.reshape(alpha_v, [-1])): @@ -306,12 +320,12 @@ class GammaTest(test.TestCase): return ks < 0.02 def testGammaPdfOfSampleMultiDims(self): - with session.Session() as sess: + with self.test_session(): gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]]) num = 50000 samples = gamma.sample(num, seed=137) pdfs = gamma.prob(samples) - sample_vals, pdf_vals = sess.run([samples, pdfs]) + sample_vals, pdf_vals = self.evaluate([samples, pdfs]) self.assertEqual(samples.get_shape(), (num, 2, 2)) self.assertEqual(pdfs.get_shape(), (num, 2, 2)) self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) @@ -345,18 +359,18 @@ class GammaTest(test.TestCase): with self.test_session(): alpha_v = constant_op.constant(0.0, name="alpha") beta_v = constant_op.constant(1.0, name="beta") - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - validate_args=True) - with self.assertRaisesOpError("alpha"): - gamma.mean().eval() + with self.assertRaisesOpError("x > 0"): + gamma = gamma_lib.Gamma(concentration=alpha_v, + rate=beta_v, + validate_args=True) + self.evaluate(gamma.mean()) alpha_v = constant_op.constant(1.0, name="alpha") beta_v = constant_op.constant(0.0, name="beta") - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - validate_args=True) - with self.assertRaisesOpError("beta"): - gamma.mean().eval() + with self.assertRaisesOpError("x > 0"): + gamma = gamma_lib.Gamma(concentration=alpha_v, + rate=beta_v, + validate_args=True) + self.evaluate(gamma.mean()) def testGammaWithSoftplusConcentrationRate(self): with self.test_session(): @@ -364,10 +378,10 @@ class GammaTest(test.TestCase): beta_v = constant_op.constant([1.0, -3.6], name="beta") gamma = gamma_lib.GammaWithSoftplusConcentrationRate( concentration=alpha_v, rate=beta_v) - self.assertAllEqual(nn_ops.softplus(alpha_v).eval(), - gamma.concentration.eval()) - self.assertAllEqual(nn_ops.softplus(beta_v).eval(), - gamma.rate.eval()) + self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)), + self.evaluate(gamma.concentration)) + self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)), + self.evaluate(gamma.rate)) def testGammaGammaKL(self): alpha0 = np.array([3.]) @@ -377,15 +391,15 @@ class GammaTest(test.TestCase): beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.]) # Build graph. - with self.test_session() as sess: + with self.test_session(): g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0) g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1) x = g0.sample(int(1e4), seed=0) kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0) kl_actual = kullback_leibler.kl_divergence(g0, g1) - # Execute graph. - [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual]) + # Execute graph. + [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual]) self.assertEqual(beta0.shape, kl_actual.get_shape()) @@ -399,7 +413,7 @@ class GammaTest(test.TestCase): + alpha0 * (beta1 / beta0 - 1.)) self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6) - self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2) + self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-1) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py index 55577386c450c7ac63f62c8a6dfd277af50e2387..24b243f647e495c47d57f914951263e3ee4ca7a5 100644 --- a/tensorflow/python/kernel_tests/distributions/laplace_test.py +++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py @@ -22,8 +22,10 @@ import importlib import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import laplace as laplace_lib from tensorflow.python.platform import test @@ -43,6 +45,7 @@ def try_import(name): # pylint: disable=invalid-name stats = try_import("scipy.stats") +@test_util.run_all_in_graph_and_eager_modes class LaplaceTest(test.TestCase): def testLaplaceShape(self): @@ -51,9 +54,9 @@ class LaplaceTest(test.TestCase): scale = constant_op.constant(11.0) laplace = laplace_lib.Laplace(loc=loc, scale=scale) - self.assertEqual(laplace.batch_shape_tensor().eval(), (5,)) + self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,)) self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(laplace.event_shape_tensor().eval(), []) + self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), []) self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([])) def testLaplaceLogPDF(self): @@ -70,11 +73,11 @@ class LaplaceTest(test.TestCase): if not stats: return expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) - self.assertAllClose(log_pdf.eval(), expected_log_pdf) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) pdf = laplace.prob(x) self.assertEqual(pdf.get_shape(), (6,)) - self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testLaplaceLogPDFMultidimensional(self): with self.test_session(): @@ -86,11 +89,11 @@ class LaplaceTest(test.TestCase): x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T laplace = laplace_lib.Laplace(loc=loc, scale=scale) log_pdf = laplace.log_prob(x) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) self.assertEqual(log_pdf.get_shape(), (6, 2)) pdf = laplace.prob(x) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) self.assertEqual(pdf.get_shape(), (6, 2)) if not stats: return @@ -108,11 +111,11 @@ class LaplaceTest(test.TestCase): x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T laplace = laplace_lib.Laplace(loc=loc, scale=scale) log_pdf = laplace.log_prob(x) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) self.assertEqual(log_pdf.get_shape(), (6, 2)) pdf = laplace.prob(x) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) self.assertEqual(pdf.get_shape(), (6, 2)) if not stats: return @@ -136,7 +139,7 @@ class LaplaceTest(test.TestCase): if not stats: return expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v) - self.assertAllClose(cdf.eval(), expected_cdf) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testLaplaceLogCDF(self): with self.test_session(): @@ -154,7 +157,7 @@ class LaplaceTest(test.TestCase): if not stats: return expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v) - self.assertAllClose(cdf.eval(), expected_cdf) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testLaplaceLogSurvivalFunction(self): with self.test_session(): @@ -172,7 +175,7 @@ class LaplaceTest(test.TestCase): if not stats: return expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v) - self.assertAllClose(sf.eval(), expected_sf) + self.assertAllClose(self.evaluate(sf), expected_sf) def testLaplaceMean(self): with self.test_session(): @@ -183,7 +186,7 @@ class LaplaceTest(test.TestCase): if not stats: return expected_means = stats.laplace.mean(loc_v, scale=scale_v) - self.assertAllClose(laplace.mean().eval(), expected_means) + self.assertAllClose(self.evaluate(laplace.mean()), expected_means) def testLaplaceMode(self): with self.test_session(): @@ -191,7 +194,7 @@ class LaplaceTest(test.TestCase): scale_v = np.array([1.0, 4.0, 5.0]) laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) self.assertEqual(laplace.mode().get_shape(), (3,)) - self.assertAllClose(laplace.mode().eval(), loc_v) + self.assertAllClose(self.evaluate(laplace.mode()), loc_v) def testLaplaceVariance(self): with self.test_session(): @@ -202,7 +205,7 @@ class LaplaceTest(test.TestCase): if not stats: return expected_variances = stats.laplace.var(loc_v, scale=scale_v) - self.assertAllClose(laplace.variance().eval(), expected_variances) + self.assertAllClose(self.evaluate(laplace.variance()), expected_variances) def testLaplaceStd(self): with self.test_session(): @@ -213,7 +216,7 @@ class LaplaceTest(test.TestCase): if not stats: return expected_stddev = stats.laplace.std(loc_v, scale=scale_v) - self.assertAllClose(laplace.stddev().eval(), expected_stddev) + self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev) def testLaplaceEntropy(self): with self.test_session(): @@ -224,7 +227,7 @@ class LaplaceTest(test.TestCase): if not stats: return expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v) - self.assertAllClose(laplace.entropy().eval(), expected_entropy) + self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy) def testLaplaceSample(self): with session.Session(): @@ -235,7 +238,7 @@ class LaplaceTest(test.TestCase): n = 100000 laplace = laplace_lib.Laplace(loc=loc, scale=scale) samples = laplace.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n,)) self.assertEqual(sample_values.shape, (n,)) if not stats: @@ -253,6 +256,18 @@ class LaplaceTest(test.TestCase): atol=0.) self.assertTrue(self._kstest(loc_v, scale_v, sample_values)) + def testLaplaceFullyReparameterized(self): + loc = constant_op.constant(4.0) + scale = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(loc) + tape.watch(scale) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + samples = laplace.sample(100) + grad_loc, grad_scale = tape.gradient(samples, [loc, scale]) + self.assertIsNotNone(grad_loc) + self.assertIsNotNone(grad_scale) + def testLaplaceSampleMultiDimensional(self): with session.Session(): loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 @@ -260,7 +275,7 @@ class LaplaceTest(test.TestCase): laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) n = 10000 samples = laplace.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n, 10, 100)) self.assertEqual(sample_values.shape, (n, 10, 100)) zeros = np.zeros_like(loc_v + scale_v) # 10 x 100 @@ -297,32 +312,31 @@ class LaplaceTest(test.TestCase): return ks < 0.02 def testLaplacePdfOfSampleMultiDims(self): - with session.Session() as sess: - laplace = laplace_lib.Laplace(loc=[7., 11.], scale=[[5.], [6.]]) - num = 50000 - samples = laplace.sample(num, seed=137) - pdfs = laplace.prob(samples) - sample_vals, pdf_vals = sess.run([samples, pdfs]) - self.assertEqual(samples.get_shape(), (num, 2, 2)) - self.assertEqual(pdfs.get_shape(), (num, 2, 2)) - self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) - self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02) - self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02) - self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02) - if not stats: - return - self.assertAllClose( - stats.laplace.mean( - [[7., 11.], [7., 11.]], scale=np.array([[5., 5.], [6., 6.]])), - sample_vals.mean(axis=0), - rtol=0.05, - atol=0.) - self.assertAllClose( - stats.laplace.var([[7., 11.], [7., 11.]], - scale=np.array([[5., 5.], [6., 6.]])), - sample_vals.var(axis=0), - rtol=0.05, - atol=0.) + laplace = laplace_lib.Laplace(loc=[7., 11.], scale=[[5.], [6.]]) + num = 50000 + samples = laplace.sample(num, seed=137) + pdfs = laplace.prob(samples) + sample_vals, pdf_vals = self.evaluate([samples, pdfs]) + self.assertEqual(samples.get_shape(), (num, 2, 2)) + self.assertEqual(pdfs.get_shape(), (num, 2, 2)) + self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) + self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02) + self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02) + self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02) + if not stats: + return + self.assertAllClose( + stats.laplace.mean( + [[7., 11.], [7., 11.]], scale=np.array([[5., 5.], [6., 6.]])), + sample_vals.mean(axis=0), + rtol=0.05, + atol=0.) + self.assertAllClose( + stats.laplace.var([[7., 11.], [7., 11.]], + scale=np.array([[5., 5.], [6., 6.]])), + sample_vals.var(axis=0), + rtol=0.05, + atol=0.) def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3): s_p = zip(sample_vals, pdf_vals) @@ -338,24 +352,27 @@ class LaplaceTest(test.TestCase): with self.test_session(): loc_v = constant_op.constant(0.0, name="loc") scale_v = constant_op.constant(-1.0, name="scale") - laplace = laplace_lib.Laplace( - loc=loc_v, scale=scale_v, validate_args=True) - with self.assertRaisesOpError("scale"): - laplace.mean().eval() + with self.assertRaisesOpError( + "Condition x > 0 did not hold element-wise"): + laplace = laplace_lib.Laplace( + loc=loc_v, scale=scale_v, validate_args=True) + self.evaluate(laplace.mean()) loc_v = constant_op.constant(1.0, name="loc") scale_v = constant_op.constant(0.0, name="scale") - laplace = laplace_lib.Laplace( - loc=loc_v, scale=scale_v, validate_args=True) - with self.assertRaisesOpError("scale"): - laplace.mean().eval() + with self.assertRaisesOpError( + "Condition x > 0 did not hold element-wise"): + laplace = laplace_lib.Laplace( + loc=loc_v, scale=scale_v, validate_args=True) + self.evaluate(laplace.mean()) def testLaplaceWithSoftplusScale(self): with self.test_session(): loc_v = constant_op.constant([0.0, 1.0], name="loc") scale_v = constant_op.constant([-1.0, 2.0], name="scale") laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v) - self.assertAllClose(nn_ops.softplus(scale_v).eval(), laplace.scale.eval()) - self.assertAllClose(loc_v.eval(), laplace.loc.eval()) + self.assertAllClose( + self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale)) + self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py index e24e8ade73a7ad762c877214f5ec3ee0848863fe..bfd40ba2b7a5d32e957507b36d44e1198bd3867f 100644 --- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py @@ -18,6 +18,8 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import backprop +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -310,10 +312,10 @@ class MultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4, 3, 2], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10) self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape()) self.assertAllClose( - actual_covariance_, sample_covariance_, atol=0., rtol=0.10) + actual_covariance_, sample_covariance_, atol=0., rtol=0.20) def testSampleUnbiasedScalarBatch(self): with self.test_session() as sess: @@ -338,10 +340,24 @@ class MultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10) self.assertAllEqual([4, 4], sample_covariance.get_shape()) self.assertAllClose( - actual_covariance_, sample_covariance_, atol=0., rtol=0.10) + actual_covariance_, sample_covariance_, atol=0., rtol=0.20) + + def testNotReparameterized(self): + total_count = constant_op.constant(5.0) + p = constant_op.constant([0.2, 0.6]) + with backprop.GradientTape() as tape: + tape.watch(total_count) + tape.watch(p) + dist = multinomial.Multinomial( + total_count=total_count, + probs=p) + samples = dist.sample(100) + grad_total_count, grad_p = tape.gradient(samples, [total_count, p]) + self.assertIsNone(grad_total_count) + self.assertIsNone(grad_p) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py index 07c7d6d11d0f3bcecfd1029295d3249c3ea8584b..7ff48c0c10f4d2cd18072a22cdcef0fefc530eae 100644 --- a/tensorflow/python/kernel_tests/distributions/normal_test.py +++ b/tensorflow/python/kernel_tests/distributions/normal_test.py @@ -23,10 +23,12 @@ import math import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import nn_ops @@ -54,7 +56,7 @@ class NormalTest(test.TestCase): self._rng = np.random.RandomState(123) def assertAllFinite(self, tensor): - is_finite = np.isfinite(tensor.eval()) + is_finite = np.isfinite(self.evaluate(tensor)) all_true = np.ones_like(is_finite, dtype=np.bool) self.assertAllEqual(all_true, is_finite) @@ -62,13 +64,13 @@ class NormalTest(test.TestCase): with self.test_session(): param_shapes = normal_lib.Normal.param_shapes(sample_shape) mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] - self.assertAllEqual(expected, mu_shape.eval()) - self.assertAllEqual(expected, sigma_shape.eval()) + self.assertAllEqual(expected, self.evaluate(mu_shape)) + self.assertAllEqual(expected, self.evaluate(sigma_shape)) mu = array_ops.zeros(mu_shape) sigma = array_ops.ones(sigma_shape) self.assertAllEqual( expected, - array_ops.shape(normal_lib.Normal(mu, sigma).sample()).eval()) + self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample()))) def _testParamStaticShapes(self, sample_shape, expected): param_shapes = normal_lib.Normal.param_static_shapes(sample_shape) @@ -76,25 +78,30 @@ class NormalTest(test.TestCase): self.assertEqual(expected, mu_shape) self.assertEqual(expected, sigma_shape) + @test_util.run_in_graph_and_eager_modes def testParamShapes(self): sample_shape = [10, 3, 4] self._testParamShapes(sample_shape, sample_shape) self._testParamShapes(constant_op.constant(sample_shape), sample_shape) + @test_util.run_in_graph_and_eager_modes def testParamStaticShapes(self): sample_shape = [10, 3, 4] self._testParamStaticShapes(sample_shape, sample_shape) self._testParamStaticShapes( tensor_shape.TensorShape(sample_shape), sample_shape) + @test_util.run_in_graph_and_eager_modes def testNormalWithSoftplusScale(self): with self.test_session(): mu = array_ops.zeros((10, 3)) rho = array_ops.ones((10, 3)) * -2. normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho) - self.assertAllEqual(mu.eval(), normal.loc.eval()) - self.assertAllEqual(nn_ops.softplus(rho).eval(), normal.scale.eval()) + self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc)) + self.assertAllEqual( + self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale)) + @test_util.run_in_graph_and_eager_modes def testNormalLogPDF(self): with self.test_session(): batch_size = 6 @@ -104,25 +111,31 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu, scale=sigma) log_pdf = normal.log_prob(x) - self.assertAllEqual(normal.batch_shape_tensor().eval(), - log_pdf.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), - log_pdf.eval().shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(log_pdf).shape) self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, log_pdf.eval().shape) + self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) pdf = normal.prob(x) - self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.eval().shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(pdf).shape) self.assertAllEqual(normal.batch_shape, pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, pdf.eval().shape) + self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape) if not stats: return - expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x) - self.assertAllClose(expected_log_pdf, log_pdf.eval()) - self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + expected_log_pdf = stats.norm(self.evaluate(mu), + self.evaluate(sigma)).logpdf(x) + self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf)) + self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf)) + @test_util.run_in_graph_and_eager_modes def testNormalLogPDFMultidimensional(self): with self.test_session(): batch_size = 6 @@ -133,29 +146,34 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu, scale=sigma) log_pdf = normal.log_prob(x) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) self.assertEqual(log_pdf.get_shape(), (6, 2)) - self.assertAllEqual(normal.batch_shape_tensor().eval(), - log_pdf.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), - log_pdf.eval().shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(log_pdf).shape) self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, log_pdf.eval().shape) + self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) pdf = normal.prob(x) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) self.assertEqual(pdf.get_shape(), (6, 2)) - self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf_values.shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), pdf_values.shape) self.assertAllEqual(normal.batch_shape, pdf.get_shape()) self.assertAllEqual(normal.batch_shape, pdf_values.shape) if not stats: return - expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x) + expected_log_pdf = stats.norm(self.evaluate(mu), + self.evaluate(sigma)).logpdf(x) self.assertAllClose(expected_log_pdf, log_pdf_values) self.assertAllClose(np.exp(expected_log_pdf), pdf_values) + @test_util.run_in_graph_and_eager_modes def testNormalCDF(self): with self.test_session(): batch_size = 50 @@ -165,15 +183,19 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu, scale=sigma) cdf = normal.cdf(x) - self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.eval().shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(cdf).shape) self.assertAllEqual(normal.batch_shape, cdf.get_shape()) - self.assertAllEqual(normal.batch_shape, cdf.eval().shape) + self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) if not stats: return expected_cdf = stats.norm(mu, sigma).cdf(x) - self.assertAllClose(expected_cdf, cdf.eval(), atol=0) + self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0) + @test_util.run_in_graph_and_eager_modes def testNormalSurvivalFunction(self): with self.test_session(): batch_size = 50 @@ -184,15 +206,19 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu, scale=sigma) sf = normal.survival_function(x) - self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.eval().shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(sf).shape) self.assertAllEqual(normal.batch_shape, sf.get_shape()) - self.assertAllEqual(normal.batch_shape, sf.eval().shape) + self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) if not stats: return expected_sf = stats.norm(mu, sigma).sf(x) - self.assertAllClose(expected_sf, sf.eval(), atol=0) + self.assertAllClose(expected_sf, self.evaluate(sf), atol=0) + @test_util.run_in_graph_and_eager_modes def testNormalLogCDF(self): with self.test_session(): batch_size = 50 @@ -203,15 +229,18 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu, scale=sigma) cdf = normal.log_cdf(x) - self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.eval().shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(cdf).shape) self.assertAllEqual(normal.batch_shape, cdf.get_shape()) - self.assertAllEqual(normal.batch_shape, cdf.eval().shape) + self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) if not stats: return expected_cdf = stats.norm(mu, sigma).logcdf(x) - self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5) + self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3) def testFiniteGradientAtDifficultPoints(self): for dtype in [np.float32, np.float64]: @@ -233,6 +262,7 @@ class NormalTest(test.TestCase): self.assertAllFinite(grads[0]) self.assertAllFinite(grads[1]) + @test_util.run_in_graph_and_eager_modes def testNormalLogSurvivalFunction(self): with self.test_session(): batch_size = 50 @@ -243,16 +273,20 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu, scale=sigma) sf = normal.log_survival_function(x) - self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.eval().shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(sf).shape) self.assertAllEqual(normal.batch_shape, sf.get_shape()) - self.assertAllEqual(normal.batch_shape, sf.eval().shape) + self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) if not stats: return expected_sf = stats.norm(mu, sigma).logsf(x) - self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5) + self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5) + @test_util.run_in_graph_and_eager_modes def testNormalEntropyWithScalarInputs(self): # Scipy.stats.norm cannot deal with the shapes in the other test. with self.test_session(): @@ -261,18 +295,20 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) entropy = normal.entropy() - self.assertAllEqual(normal.batch_shape_tensor().eval(), - entropy.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), - entropy.eval().shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(entropy).shape) self.assertAllEqual(normal.batch_shape, entropy.get_shape()) - self.assertAllEqual(normal.batch_shape, entropy.eval().shape) + self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) # scipy.stats.norm cannot deal with these shapes. if not stats: return expected_entropy = stats.norm(mu_v, sigma_v).entropy() - self.assertAllClose(expected_entropy, entropy.eval()) + self.assertAllClose(expected_entropy, self.evaluate(entropy)) + @test_util.run_in_graph_and_eager_modes def testNormalEntropy(self): with self.test_session(): mu_v = np.array([1.0, 1.0, 1.0]) @@ -284,14 +320,16 @@ class NormalTest(test.TestCase): expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast** 2) entropy = normal.entropy() - np.testing.assert_allclose(expected_entropy, entropy.eval()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), - entropy.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), - entropy.eval().shape) + np.testing.assert_allclose(expected_entropy, self.evaluate(entropy)) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(entropy).shape) self.assertAllEqual(normal.batch_shape, entropy.get_shape()) - self.assertAllEqual(normal.batch_shape, entropy.eval().shape) + self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) + @test_util.run_in_graph_and_eager_modes def testNormalMeanAndMode(self): with self.test_session(): # Mu will be broadcast to [7, 7, 7]. @@ -301,11 +339,12 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu, scale=sigma) self.assertAllEqual((3,), normal.mean().get_shape()) - self.assertAllEqual([7., 7, 7], normal.mean().eval()) + self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean())) self.assertAllEqual((3,), normal.mode().get_shape()) - self.assertAllEqual([7., 7, 7], normal.mode().eval()) + self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode())) + @test_util.run_in_graph_and_eager_modes def testNormalQuantile(self): with self.test_session(): batch_size = 52 @@ -319,15 +358,18 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu, scale=sigma) x = normal.quantile(p) - self.assertAllEqual(normal.batch_shape_tensor().eval(), x.get_shape()) - self.assertAllEqual(normal.batch_shape_tensor().eval(), x.eval().shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), x.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(x).shape) self.assertAllEqual(normal.batch_shape, x.get_shape()) - self.assertAllEqual(normal.batch_shape, x.eval().shape) + self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape) if not stats: return expected_x = stats.norm(mu, sigma).ppf(p) - self.assertAllClose(expected_x, x.eval(), atol=0.) + self.assertAllClose(expected_x, self.evaluate(x), atol=0.) def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype): g = ops.Graph() @@ -354,6 +396,7 @@ class NormalTest(test.TestCase): def testQuantileFiniteGradientAtDifficultPointsFloat64(self): self._baseQuantileFiniteGradientAtDifficultPoints(np.float64) + @test_util.run_in_graph_and_eager_modes def testNormalVariance(self): with self.test_session(): # sigma will be broadcast to [7, 7, 7] @@ -363,8 +406,9 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu, scale=sigma) self.assertAllEqual((3,), normal.variance().get_shape()) - self.assertAllEqual([49., 49, 49], normal.variance().eval()) + self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance())) + @test_util.run_in_graph_and_eager_modes def testNormalStandardDeviation(self): with self.test_session(): # sigma will be broadcast to [7, 7, 7] @@ -374,8 +418,9 @@ class NormalTest(test.TestCase): normal = normal_lib.Normal(loc=mu, scale=sigma) self.assertAllEqual((3,), normal.stddev().get_shape()) - self.assertAllEqual([7., 7, 7], normal.stddev().eval()) + self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev())) + @test_util.run_in_graph_and_eager_modes def testNormalSample(self): with self.test_session(): mu = constant_op.constant(3.0) @@ -385,7 +430,7 @@ class NormalTest(test.TestCase): n = constant_op.constant(100000) normal = normal_lib.Normal(loc=mu, scale=sigma) samples = normal.sample(n) - sample_values = samples.eval() + sample_values = self.evaluate(samples) # Note that the standard error for the sample mean is ~ sigma / sqrt(n). # The sample variance similarly is dependent on sigma and n. # Thus, the tolerances below are very sensitive to number of samples @@ -394,18 +439,34 @@ class NormalTest(test.TestCase): self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1) self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1) - expected_samples_shape = tensor_shape.TensorShape([n.eval()]).concatenate( - tensor_shape.TensorShape(normal.batch_shape_tensor().eval())) + expected_samples_shape = tensor_shape.TensorShape( + [self.evaluate(n)]).concatenate( + tensor_shape.TensorShape( + self.evaluate(normal.batch_shape_tensor()))) self.assertAllEqual(expected_samples_shape, samples.get_shape()) self.assertAllEqual(expected_samples_shape, sample_values.shape) - expected_samples_shape = (tensor_shape.TensorShape( - [n.eval()]).concatenate(normal.batch_shape)) + expected_samples_shape = ( + tensor_shape.TensorShape([self.evaluate(n)]).concatenate( + normal.batch_shape)) self.assertAllEqual(expected_samples_shape, samples.get_shape()) self.assertAllEqual(expected_samples_shape, sample_values.shape) + def testNormalFullyReparameterized(self): + mu = constant_op.constant(4.0) + sigma = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(mu) + tape.watch(sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) + samples = normal.sample(100) + grad_mu, grad_sigma = tape.gradient(samples, [mu, sigma]) + self.assertIsNotNone(grad_mu) + self.assertIsNotNone(grad_sigma) + + @test_util.run_in_graph_and_eager_modes def testNormalSampleMultiDimensional(self): with self.test_session(): batch_size = 2 @@ -417,7 +478,7 @@ class NormalTest(test.TestCase): n = constant_op.constant(100000) normal = normal_lib.Normal(loc=mu, scale=sigma) samples = normal.sample(n) - sample_values = samples.eval() + sample_values = self.evaluate(samples) # Note that the standard error for the sample mean is ~ sigma / sqrt(n). # The sample variance similarly is dependent on sigma and n. # Thus, the tolerances below are very sensitive to number of samples @@ -428,32 +489,37 @@ class NormalTest(test.TestCase): self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1) self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1) - expected_samples_shape = tensor_shape.TensorShape([n.eval()]).concatenate( - tensor_shape.TensorShape(normal.batch_shape_tensor().eval())) + expected_samples_shape = tensor_shape.TensorShape( + [self.evaluate(n)]).concatenate( + tensor_shape.TensorShape( + self.evaluate(normal.batch_shape_tensor()))) self.assertAllEqual(expected_samples_shape, samples.get_shape()) self.assertAllEqual(expected_samples_shape, sample_values.shape) - expected_samples_shape = (tensor_shape.TensorShape( - [n.eval()]).concatenate(normal.batch_shape)) + expected_samples_shape = ( + tensor_shape.TensorShape([self.evaluate(n)]).concatenate( + normal.batch_shape)) self.assertAllEqual(expected_samples_shape, samples.get_shape()) self.assertAllEqual(expected_samples_shape, sample_values.shape) + @test_util.run_in_graph_and_eager_modes def testNegativeSigmaFails(self): with self.test_session(): - normal = normal_lib.Normal( - loc=[1.], scale=[-5.], validate_args=True, name="G") with self.assertRaisesOpError("Condition x > 0 did not hold"): - normal.mean().eval() + normal = normal_lib.Normal( + loc=[1.], scale=[-5.], validate_args=True, name="G") + self.evaluate(normal.mean()) + @test_util.run_in_graph_and_eager_modes def testNormalShape(self): with self.test_session(): mu = constant_op.constant([-3.0] * 5) sigma = constant_op.constant(11.0) normal = normal_lib.Normal(loc=mu, scale=sigma) - self.assertEqual(normal.batch_shape_tensor().eval(), [5]) + self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5]) self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(normal.event_shape_tensor().eval(), []) + self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), []) self.assertEqual(normal.event_shape, tensor_shape.TensorShape([])) def testNormalShapeWithPlaceholders(self): @@ -465,31 +531,31 @@ class NormalTest(test.TestCase): # get_batch_shape should return an "" tensor. self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None)) self.assertEqual(normal.event_shape, ()) - self.assertAllEqual(normal.event_shape_tensor().eval(), []) + self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), []) self.assertAllEqual( sess.run(normal.batch_shape_tensor(), feed_dict={mu: 5.0, sigma: [1.0, 2.0]}), [2]) + @test_util.run_in_graph_and_eager_modes def testNormalNormalKL(self): - with self.test_session() as sess: - batch_size = 6 - mu_a = np.array([3.0] * batch_size) - sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5]) - mu_b = np.array([-3.0] * batch_size) - sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0]) + batch_size = 6 + mu_a = np.array([3.0] * batch_size) + sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5]) + mu_b = np.array([-3.0] * batch_size) + sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0]) - n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a) - n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b) + n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a) + n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b) - kl = kullback_leibler.kl_divergence(n_a, n_b) - kl_val = sess.run(kl) + kl = kullback_leibler.kl_divergence(n_a, n_b) + kl_val = self.evaluate(kl) - kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * ( - (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b))) + kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * ( + (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b))) - self.assertEqual(kl.get_shape(), (batch_size,)) - self.assertAllClose(kl_val, kl_expected) + self.assertEqual(kl.get_shape(), (batch_size,)) + self.assertAllClose(kl_val, kl_expected) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py index d5d50a180a1df6d7c56635b6b18509fbabc06d4d..a634194ce5293f4d7e7a68aa661080ed06493297 100644 --- a/tensorflow/python/kernel_tests/distributions/special_math_test.py +++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py @@ -89,7 +89,7 @@ class NdtriTest(test.TestCase): all_true = np.ones_like(is_finite, dtype=np.bool) self.assertAllEqual(all_true, is_finite) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNdtri(self): """Verifies that ndtri computation is correct.""" with self.test_session(): @@ -138,15 +138,16 @@ class NdtriTest(test.TestCase): lambda x: special_math.ndtri(x), p) # pylint: disable=unnecessary-lambda self.assertAllFinite(self.evaluate(grads[0])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNdtriFiniteGradientFloat32(self): self._baseNdtriFiniteGradientTest(np.float32) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNdtriFiniteGradientFloat64(self): self._baseNdtriFiniteGradientTest(np.float64) +@test_util.run_all_in_graph_and_eager_modes class NdtrTest(test.TestCase): _use_log = False # Grid min/max chosen to ensure 0 < cdf(x) < 1. diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py index f1150de58e0dae5da25f74f95fb391c340a01262..05590542efe2623e608f783233db68240331ba20 100644 --- a/tensorflow/python/kernel_tests/distributions/student_t_test.py +++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py @@ -23,8 +23,10 @@ import math import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import student_t @@ -44,6 +46,7 @@ def try_import(name): # pylint: disable=invalid-name stats = try_import("scipy.stats") +@test_util.run_all_in_graph_and_eager_modes class StudentTTest(test.TestCase): def testStudentPDFAndLogPDF(self): @@ -60,10 +63,10 @@ class StudentTTest(test.TestCase): log_pdf = student.log_prob(t) self.assertEquals(log_pdf.get_shape(), (6,)) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) pdf = student.prob(t) self.assertEquals(pdf.get_shape(), (6,)) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) if not stats: return @@ -88,10 +91,10 @@ class StudentTTest(test.TestCase): t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T student = student_t.StudentT(df, loc=mu, scale=sigma) log_pdf = student.log_prob(t) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) self.assertEqual(log_pdf.get_shape(), (6, 2)) pdf = student.prob(t) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) self.assertEqual(pdf.get_shape(), (6, 2)) if not stats: @@ -117,10 +120,10 @@ class StudentTTest(test.TestCase): log_cdf = student.log_cdf(t) self.assertEquals(log_cdf.get_shape(), (6,)) - log_cdf_values = log_cdf.eval() + log_cdf_values = self.evaluate(log_cdf) cdf = student.cdf(t) self.assertEquals(cdf.get_shape(), (6,)) - cdf_values = cdf.eval() + cdf_values = self.evaluate(cdf) if not stats: return @@ -140,7 +143,7 @@ class StudentTTest(test.TestCase): with self.test_session(): student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v) ent = student.entropy() - ent_values = ent.eval() + ent_values = self.evaluate(ent) # Help scipy broadcast to 3x3 ones = np.array([[1, 1, 1]]) @@ -167,14 +170,14 @@ class StudentTTest(test.TestCase): n = constant_op.constant(200000) student = student_t.StudentT(df=df, loc=mu, scale=sigma) samples = student.sample(n, seed=123456) - sample_values = samples.eval() + sample_values = self.evaluate(samples) n_val = 200000 self.assertEqual(sample_values.shape, (n_val,)) - self.assertAllClose(sample_values.mean(), mu_v, rtol=1e-2, atol=0) + self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0) self.assertAllClose( sample_values.var(), sigma_v**2 * df_v / (df_v - 2), - rtol=1e-2, + rtol=0.1, atol=0) self._checkKLApprox(df_v, mu_v, sigma_v, sample_values) @@ -189,12 +192,12 @@ class StudentTTest(test.TestCase): random_seed.set_random_seed(654321) student = student_t.StudentT( df=df, loc=mu, scale=sigma, name="student_t1") - samples1 = student.sample(n, seed=123456).eval() + samples1 = self.evaluate(student.sample(n, seed=123456)) random_seed.set_random_seed(654321) student2 = student_t.StudentT( df=df, loc=mu, scale=sigma, name="student_t2") - samples2 = student2.sample(n, seed=123456).eval() + samples2 = self.evaluate(student2.sample(n, seed=123456)) self.assertAllClose(samples1, samples2) @@ -205,7 +208,7 @@ class StudentTTest(test.TestCase): n = constant_op.constant(200000) student = student_t.StudentT(df=df, loc=1., scale=1.) samples = student.sample(n, seed=123456) - sample_values = samples.eval() + sample_values = self.evaluate(samples) n_val = 200000 self.assertEqual(sample_values.shape, (n_val, 4)) self.assertTrue(np.all(np.logical_not(np.isnan(sample_values)))) @@ -213,34 +216,34 @@ class StudentTTest(test.TestCase): def testStudentSampleMultiDimensional(self): with self.test_session(): batch_size = 7 - df = constant_op.constant([[3., 7.]] * batch_size) + df = constant_op.constant([[5., 7.]] * batch_size) mu = constant_op.constant([[3., -3.]] * batch_size) sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] * batch_size) - df_v = [3., 7.] + df_v = [5., 7.] mu_v = [3., -3.] sigma_v = [np.sqrt(10.), np.sqrt(15.)] n = constant_op.constant(200000) student = student_t.StudentT(df=df, loc=mu, scale=sigma) samples = student.sample(n, seed=123456) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (200000, batch_size, 2)) self.assertAllClose( - sample_values[:, 0, 0].mean(), mu_v[0], rtol=1e-2, atol=0) + sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0) self.assertAllClose( sample_values[:, 0, 0].var(), sigma_v[0]**2 * df_v[0] / (df_v[0] - 2), - rtol=1e-1, + rtol=0.2, atol=0) self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0]) self.assertAllClose( - sample_values[:, 0, 1].mean(), mu_v[1], rtol=1e-2, atol=0) + sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0) self.assertAllClose( sample_values[:, 0, 1].var(), sigma_v[1]**2 * df_v[1] / (df_v[1] - 2), - rtol=1e-1, + rtol=0.2, atol=0) - self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 1]) + self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1]) def _checkKLApprox(self, df, mu, sigma, samples): n = samples.size @@ -270,7 +273,7 @@ class StudentTTest(test.TestCase): self.assertEqual(student.entropy().get_shape(), (3,)) self.assertEqual(student.log_prob(2.).get_shape(), (3,)) self.assertEqual(student.prob(2.).get_shape(), (3,)) - self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,)) + self.assertEqual(student.sample(37).get_shape(), (37, 3,)) _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.)) _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.)) @@ -325,7 +328,7 @@ class StudentTTest(test.TestCase): with self.test_session(): mu = [1., 3.3, 4.4] student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.]) - mean = student.mean().eval() + mean = self.evaluate(student.mean()) self.assertAllClose([1., 3.3, 4.4], mean) def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self): @@ -335,7 +338,7 @@ class StudentTTest(test.TestCase): df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False) with self.assertRaisesOpError("x < y"): - student.mean().eval() + self.evaluate(student.mean()) def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self): with self.test_session(): @@ -344,7 +347,7 @@ class StudentTTest(test.TestCase): student = student_t.StudentT( df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True) - mean = student.mean().eval() + mean = self.evaluate(student.mean()) self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean) def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self): @@ -356,7 +359,7 @@ class StudentTTest(test.TestCase): sigma = [5., 4., 3., 2., 1.] student = student_t.StudentT( df=df, loc=mu, scale=sigma, allow_nan_stats=True) - var = student.variance().eval() + var = self.evaluate(student.variance()) ## scipy uses inf for variance when the mean is undefined. When mean is # undefined we say variance is undefined as well. So test the first # member of var, making sure it is NaN, then replace with inf and compare @@ -379,7 +382,7 @@ class StudentTTest(test.TestCase): mu = [0., 1., 3.3, 4.4] sigma = [4., 3., 2., 1.] student = student_t.StudentT(df=df, loc=mu, scale=sigma) - var = student.variance().eval() + var = self.evaluate(student.variance()) if not stats: return @@ -394,14 +397,14 @@ class StudentTTest(test.TestCase): student = student_t.StudentT( df=1., loc=0., scale=1., allow_nan_stats=False) with self.assertRaisesOpError("x < y"): - student.variance().eval() + self.evaluate(student.variance()) with self.test_session(): # df <= 1 ==> variance not defined student = student_t.StudentT( df=0.5, loc=0., scale=1., allow_nan_stats=False) with self.assertRaisesOpError("x < y"): - student.variance().eval() + self.evaluate(student.variance()) def testStd(self): with self.test_session(): @@ -411,7 +414,7 @@ class StudentTTest(test.TestCase): sigma = [5., 4., 3., 2., 1.] student = student_t.StudentT(df=df, loc=mu, scale=sigma) # Test broadcast of mu across shape of df/sigma - stddev = student.stddev().eval() + stddev = self.evaluate(student.stddev()) mu *= len(df) if not stats: @@ -428,59 +431,73 @@ class StudentTTest(test.TestCase): sigma = [5., 4., 3.] student = student_t.StudentT(df=df, loc=mu, scale=sigma) # Test broadcast of mu across shape of df/sigma - mode = student.mode().eval() + mode = self.evaluate(student.mode()) self.assertAllClose([-1., 0, 1], mode) def testPdfOfSample(self): - with self.test_session() as sess: - student = student_t.StudentT(df=3., loc=np.pi, scale=1.) - num = 20000 - samples = student.sample(num, seed=123456) - pdfs = student.prob(samples) - mean = student.mean() - mean_pdf = student.prob(student.mean()) - sample_vals, pdf_vals, mean_val, mean_pdf_val = sess.run( - [samples, pdfs, student.mean(), mean_pdf]) - self.assertEqual(samples.get_shape(), (num,)) - self.assertEqual(pdfs.get_shape(), (num,)) - self.assertEqual(mean.get_shape(), ()) - self.assertNear(np.pi, np.mean(sample_vals), err=0.02) - self.assertNear(np.pi, mean_val, err=1e-6) - # Verify integral over sample*pdf ~= 1. - self._assertIntegral(sample_vals, pdf_vals, err=2e-3) - if not stats: - return - self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6) + student = student_t.StudentT(df=3., loc=np.pi, scale=1.) + num = 20000 + samples = student.sample(num, seed=123456) + pdfs = student.prob(samples) + mean = student.mean() + mean_pdf = student.prob(student.mean()) + sample_vals, pdf_vals, mean_val, mean_pdf_val = self.evaluate( + [samples, pdfs, student.mean(), mean_pdf]) + self.assertEqual(samples.get_shape(), (num,)) + self.assertEqual(pdfs.get_shape(), (num,)) + self.assertEqual(mean.get_shape(), ()) + self.assertNear(np.pi, np.mean(sample_vals), err=0.1) + self.assertNear(np.pi, mean_val, err=1e-6) + # Verify integral over sample*pdf ~= 1. + # Tolerance increased since eager was getting a value of 1.002041. + self._assertIntegral(sample_vals, pdf_vals, err=5e-2) + if not stats: + return + self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6) + + def testFullyReparameterized(self): + df = constant_op.constant(2.0) + mu = constant_op.constant(1.0) + sigma = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(df) + tape.watch(mu) + tape.watch(sigma) + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + samples = student.sample(100) + grad_df, grad_mu, grad_sigma = tape.gradient(samples, [df, mu, sigma]) + self.assertIsNotNone(grad_df) + self.assertIsNotNone(grad_mu) + self.assertIsNotNone(grad_sigma) def testPdfOfSampleMultiDims(self): - with self.test_session() as sess: - student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.) - self.assertAllEqual([], student.event_shape) - self.assertAllEqual([], student.event_shape_tensor().eval()) - self.assertAllEqual([2, 2], student.batch_shape) - self.assertAllEqual([2, 2], student.batch_shape_tensor().eval()) - num = 50000 - samples = student.sample(num, seed=123456) - pdfs = student.prob(samples) - sample_vals, pdf_vals = sess.run([samples, pdfs]) - self.assertEqual(samples.get_shape(), (num, 2, 2)) - self.assertEqual(pdfs.get_shape(), (num, 2, 2)) - self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03) - self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03) - self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) - self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02) - self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02) - self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02) - if not stats: - return - self.assertNear( - stats.t.var(7., loc=0., scale=3.), # loc d.n. effect var - np.var(sample_vals[:, :, 0]), - err=.4) - self.assertNear( - stats.t.var(11., loc=0., scale=3.), # loc d.n. effect var - np.var(sample_vals[:, :, 1]), - err=.4) + student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.) + self.assertAllEqual([], student.event_shape) + self.assertAllEqual([], self.evaluate(student.event_shape_tensor())) + self.assertAllEqual([2, 2], student.batch_shape) + self.assertAllEqual([2, 2], self.evaluate(student.batch_shape_tensor())) + num = 50000 + samples = student.sample(num, seed=123456) + pdfs = student.prob(samples) + sample_vals, pdf_vals = self.evaluate([samples, pdfs]) + self.assertEqual(samples.get_shape(), (num, 2, 2)) + self.assertEqual(pdfs.get_shape(), (num, 2, 2)) + self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=0.1) + self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=0.1) + self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.05) + self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.05) + self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.05) + self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.05) + if not stats: + return + self.assertNear( + stats.t.var(7., loc=0., scale=3.), # loc d.n. effect var + np.var(sample_vals[:, :, 0]), + err=1.0) + self.assertNear( + stats.t.var(11., loc=0., scale=3.), # loc d.n. effect var + np.var(sample_vals[:, :, 1]), + err=1.0) def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3): s_p = zip(sample_vals, pdf_vals) @@ -494,10 +511,10 @@ class StudentTTest(test.TestCase): def testNegativeDofFails(self): with self.test_session(): - student = student_t.StudentT(df=[2, -5.], loc=0., scale=1., - validate_args=True, name="S") with self.assertRaisesOpError(r"Condition x > 0 did not hold"): - student.mean().eval() + student = student_t.StudentT( + df=[2, -5.], loc=0., scale=1., validate_args=True, name="S") + self.evaluate(student.mean()) def testStudentTWithAbsDfSoftplusScale(self): with self.test_session(): @@ -507,9 +524,11 @@ class StudentTTest(test.TestCase): student = student_t.StudentTWithAbsDfSoftplusScale( df=df, loc=mu, scale=sigma) self.assertAllClose( - math_ops.floor(math_ops.abs(df)).eval(), student.df.eval()) - self.assertAllClose(mu.eval(), student.loc.eval()) - self.assertAllClose(nn_ops.softplus(sigma).eval(), student.scale.eval()) + math_ops.floor(self.evaluate(math_ops.abs(df))), + self.evaluate(student.df)) + self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc)) + self.assertAllClose( + self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py index a8def95b147b6dd4825675769187733b8493b374..bc9c267b9a5eac6fd8c9c4290dcc4b56865ddb50 100644 --- a/tensorflow/python/kernel_tests/distributions/uniform_test.py +++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py @@ -22,9 +22,11 @@ import importlib import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import uniform as uniform_lib @@ -46,15 +48,17 @@ stats = try_import("scipy.stats") class UniformTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes def testUniformRange(self): with self.test_session(): a = 3.0 b = 10.0 uniform = uniform_lib.Uniform(low=a, high=b) - self.assertAllClose(a, uniform.low.eval()) - self.assertAllClose(b, uniform.high.eval()) - self.assertAllClose(b - a, uniform.range().eval()) + self.assertAllClose(a, self.evaluate(uniform.low)) + self.assertAllClose(b, self.evaluate(uniform.high)) + self.assertAllClose(b - a, self.evaluate(uniform.range())) + @test_util.run_in_graph_and_eager_modes def testUniformPDF(self): with self.test_session(): a = constant_op.constant([-3.0] * 5 + [15.0]) @@ -75,22 +79,24 @@ class UniformTest(test.TestCase): expected_pdf = _expected_pdf() pdf = uniform.prob(x) - self.assertAllClose(expected_pdf, pdf.eval()) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) log_pdf = uniform.log_prob(x) - self.assertAllClose(np.log(expected_pdf), log_pdf.eval()) + self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf)) + @test_util.run_in_graph_and_eager_modes def testUniformShape(self): with self.test_session(): a = constant_op.constant([-3.0] * 5) b = constant_op.constant(11.0) uniform = uniform_lib.Uniform(low=a, high=b) - self.assertEqual(uniform.batch_shape_tensor().eval(), (5,)) + self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,)) self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(uniform.event_shape_tensor().eval(), []) + self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), []) self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([])) + @test_util.run_in_graph_and_eager_modes def testUniformPDFWithScalarEndpoint(self): with self.test_session(): a = constant_op.constant([0.0, 5.0]) @@ -101,8 +107,9 @@ class UniformTest(test.TestCase): expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)]) pdf = uniform.prob(x) - self.assertAllClose(expected_pdf, pdf.eval()) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) + @test_util.run_in_graph_and_eager_modes def testUniformCDF(self): with self.test_session(): batch_size = 6 @@ -121,11 +128,12 @@ class UniformTest(test.TestCase): return cdf cdf = uniform.cdf(x) - self.assertAllClose(_expected_cdf(), cdf.eval()) + self.assertAllClose(_expected_cdf(), self.evaluate(cdf)) log_cdf = uniform.log_cdf(x) - self.assertAllClose(np.log(_expected_cdf()), log_cdf.eval()) + self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf)) + @test_util.run_in_graph_and_eager_modes def testUniformEntropy(self): with self.test_session(): a_v = np.array([1.0, 1.0, 1.0]) @@ -133,18 +141,20 @@ class UniformTest(test.TestCase): uniform = uniform_lib.Uniform(low=a_v, high=b_v) expected_entropy = np.log(b_v - a_v) - self.assertAllClose(expected_entropy, uniform.entropy().eval()) + self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy())) + @test_util.run_in_graph_and_eager_modes def testUniformAssertMaxGtMin(self): with self.test_session(): a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32) b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32) - uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True) with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, "x < y"): - uniform.low.eval() + uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True) + self.evaluate(uniform.low) + @test_util.run_in_graph_and_eager_modes def testUniformSample(self): with self.test_session(): a = constant_op.constant([3.0, 4.0]) @@ -156,17 +166,18 @@ class UniformTest(test.TestCase): uniform = uniform_lib.Uniform(low=a, high=b) samples = uniform.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(sample_values.shape, (100000, 2)) self.assertAllClose( - sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-2) + sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.) self.assertAllClose( - sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-2) + sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.) self.assertFalse( np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v)) self.assertFalse( np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v)) + @test_util.run_in_graph_and_eager_modes def _testUniformSampleMultiDimensional(self): # DISABLED: Please enable this test once b/issues/30149644 is resolved. with self.test_session(): @@ -183,7 +194,7 @@ class UniformTest(test.TestCase): samples = uniform.sample(n) self.assertEqual(samples.get_shape(), (n_v, batch_size, 2)) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertFalse( np.any(sample_values[:, 0, 0] < a_v[0]) or @@ -197,6 +208,7 @@ class UniformTest(test.TestCase): self.assertAllClose( sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2) + @test_util.run_in_graph_and_eager_modes def testUniformMean(self): with self.test_session(): a = 10.0 @@ -205,8 +217,9 @@ class UniformTest(test.TestCase): if not stats: return s_uniform = stats.uniform(loc=a, scale=b - a) - self.assertAllClose(uniform.mean().eval(), s_uniform.mean()) + self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean()) + @test_util.run_in_graph_and_eager_modes def testUniformVariance(self): with self.test_session(): a = 10.0 @@ -215,8 +228,9 @@ class UniformTest(test.TestCase): if not stats: return s_uniform = stats.uniform(loc=a, scale=b - a) - self.assertAllClose(uniform.variance().eval(), s_uniform.var()) + self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var()) + @test_util.run_in_graph_and_eager_modes def testUniformStd(self): with self.test_session(): a = 10.0 @@ -225,8 +239,9 @@ class UniformTest(test.TestCase): if not stats: return s_uniform = stats.uniform(loc=a, scale=b - a) - self.assertAllClose(uniform.stddev().eval(), s_uniform.std()) + self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std()) + @test_util.run_in_graph_and_eager_modes def testUniformNans(self): with self.test_session(): a = 10.0 @@ -235,23 +250,26 @@ class UniformTest(test.TestCase): no_nans = constant_op.constant(1.0) nans = constant_op.constant(0.0) / constant_op.constant(0.0) - self.assertTrue(math_ops.is_nan(nans).eval()) + self.assertTrue(self.evaluate(math_ops.is_nan(nans))) with_nans = array_ops.stack([no_nans, nans]) pdf = uniform.prob(with_nans) - is_nan = math_ops.is_nan(pdf).eval() + is_nan = self.evaluate(math_ops.is_nan(pdf)) self.assertFalse(is_nan[0]) self.assertTrue(is_nan[1]) + @test_util.run_in_graph_and_eager_modes def testUniformSamplePdf(self): with self.test_session(): a = 10.0 b = [11.0, 100.0] uniform = uniform_lib.Uniform(a, b) self.assertTrue( - math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0).eval()) + self.evaluate( + math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0))) + @test_util.run_in_graph_and_eager_modes def testUniformBroadcasting(self): with self.test_session(): a = 10.0 @@ -260,8 +278,9 @@ class UniformTest(test.TestCase): pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]]) expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]]) - self.assertAllClose(expected_pdf, pdf.eval()) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) + @test_util.run_in_graph_and_eager_modes def testUniformSampleWithShape(self): with self.test_session(): a = 10.0 @@ -275,12 +294,25 @@ class UniformTest(test.TestCase): [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]], ] # pylint: enable=bad-continuation - self.assertAllClose(expected_pdf, pdf.eval()) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) pdf = uniform.prob(uniform.sample()) expected_pdf = [1.0, 0.1] - self.assertAllClose(expected_pdf, pdf.eval()) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) + + def testFullyReparameterized(self): + a = constant_op.constant(0.1) + b = constant_op.constant(0.8) + with backprop.GradientTape() as tape: + tape.watch(a) + tape.watch(b) + uniform = uniform_lib.Uniform(a, b) + samples = uniform.sample(100) + grad_a, grad_b = tape.gradient(samples, [a, b]) + self.assertIsNotNone(grad_a) + self.assertIsNotNone(grad_b) + # Eager doesn't pass due to a type mismatch in one of the ops. def testUniformFloat64(self): uniform = uniform_lib.Uniform( low=np.float64(0.), high=np.float64(1.)) diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py index 94c5b9b1d58d82feefed5eafa7611bf0ebc07a24..9d38ffcb4a963efb71153f59d6269ba84a5d1379 100644 --- a/tensorflow/python/kernel_tests/distributions/util_test.py +++ b/tensorflow/python/kernel_tests/distributions/util_test.py @@ -22,9 +22,11 @@ import importlib import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl @@ -57,64 +59,6 @@ def _logit(x): class AssertCloseTest(test.TestCase): - def testAssertCloseIntegerDtype(self): - x = array_ops.placeholder(dtypes.int32) - y = x - z = array_ops.placeholder(dtypes.int32) - feed_dict = {x: [1, 5, 10, 15, 20], z: [2, 5, 10, 15, 20]} - with self.test_session(): - with ops.control_dependencies([du.assert_close(x, y)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with ops.control_dependencies([du.assert_close(y, x)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(x, z)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(y, z)]): - array_ops.identity(y).eval(feed_dict=feed_dict) - - def testAssertCloseNonIntegerDtype(self): - x = array_ops.placeholder(dtypes.float32) - y = x + 1e-8 - z = array_ops.placeholder(dtypes.float32) - feed_dict = {x: [1., 5, 10, 15, 20], z: [2., 5, 10, 15, 20]} - with self.test_session(): - with ops.control_dependencies([du.assert_close(x, y)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with ops.control_dependencies([du.assert_close(y, x)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(x, z)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(y, z)]): - array_ops.identity(y).eval(feed_dict=feed_dict) - - def testAssertCloseEpsilon(self): - x = [0., 5, 10, 15, 20] - # x != y - y = [0.1, 5, 10, 15, 20] - # x = z - z = [1e-8, 5, 10, 15, 20] - with self.test_session(): - with ops.control_dependencies([du.assert_close(x, z)]): - array_ops.identity(x).eval() - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(x, y)]): - array_ops.identity(x).eval() - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(y, z)]): - array_ops.identity(y).eval() - def testAssertIntegerForm(self): # This should only be detected as an integer. x = array_ops.placeholder(dtypes.float32) @@ -147,18 +91,21 @@ class AssertCloseTest(test.TestCase): class MaybeGetStaticTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes def testGetStaticInt(self): x = 2 self.assertEqual(x, du.maybe_get_static_value(x)) self.assertAllClose( np.array(2.), du.maybe_get_static_value(x, dtype=np.float64)) + @test_util.run_in_graph_and_eager_modes def testGetStaticNumpyArray(self): x = np.array(2, dtype=np.int32) self.assertEqual(x, du.maybe_get_static_value(x)) self.assertAllClose( np.array(2.), du.maybe_get_static_value(x, dtype=np.float64)) + @test_util.run_in_graph_and_eager_modes def testGetStaticConstant(self): x = constant_op.constant(2, dtype=dtypes.int32) self.assertEqual(np.array(2, dtype=np.int32), du.maybe_get_static_value(x)) @@ -173,6 +120,7 @@ class MaybeGetStaticTest(test.TestCase): class GetLogitsAndProbsTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes def testImproperArguments(self): with self.test_session(): with self.assertRaises(ValueError): @@ -181,6 +129,7 @@ class GetLogitsAndProbsTest(test.TestCase): with self.assertRaises(ValueError): du.get_logits_and_probs(logits=[0.1], probs=[0.1]) + @test_util.run_in_graph_and_eager_modes def testLogits(self): p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) logits = _logit(p) @@ -189,9 +138,10 @@ class GetLogitsAndProbsTest(test.TestCase): new_logits, new_p = du.get_logits_and_probs( logits=logits, validate_args=True) - self.assertAllClose(p, new_p.eval(), rtol=1e-5, atol=0.) - self.assertAllClose(logits, new_logits.eval(), rtol=1e-5, atol=0.) + self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.) + self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.) + @test_util.run_in_graph_and_eager_modes def testLogitsMultidimensional(self): p = np.array([0.2, 0.3, 0.5], dtype=np.float32) logits = np.log(p) @@ -200,9 +150,10 @@ class GetLogitsAndProbsTest(test.TestCase): new_logits, new_p = du.get_logits_and_probs( logits=logits, multidimensional=True, validate_args=True) - self.assertAllClose(new_p.eval(), p) - self.assertAllClose(new_logits.eval(), logits) + self.assertAllClose(self.evaluate(new_p), p) + self.assertAllClose(self.evaluate(new_logits), logits) + @test_util.run_in_graph_and_eager_modes def testProbability(self): p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) @@ -210,9 +161,10 @@ class GetLogitsAndProbsTest(test.TestCase): new_logits, new_p = du.get_logits_and_probs( probs=p, validate_args=True) - self.assertAllClose(_logit(p), new_logits.eval()) - self.assertAllClose(p, new_p.eval()) + self.assertAllClose(_logit(p), self.evaluate(new_logits)) + self.assertAllClose(p, self.evaluate(new_p)) + @test_util.run_in_graph_and_eager_modes def testProbabilityMultidimensional(self): p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) @@ -220,9 +172,10 @@ class GetLogitsAndProbsTest(test.TestCase): new_logits, new_p = du.get_logits_and_probs( probs=p, multidimensional=True, validate_args=True) - self.assertAllClose(np.log(p), new_logits.eval()) - self.assertAllClose(p, new_p.eval()) + self.assertAllClose(np.log(p), self.evaluate(new_logits)) + self.assertAllClose(p, self.evaluate(new_p)) + @test_util.run_in_graph_and_eager_modes def testProbabilityValidateArgs(self): p = [0.01, 0.2, 0.5, 0.7, .99] # Component less than 0. @@ -233,26 +186,27 @@ class GetLogitsAndProbsTest(test.TestCase): with self.test_session(): _, prob = du.get_logits_and_probs( probs=p, validate_args=True) - prob.eval() + self.evaluate(prob) with self.assertRaisesOpError("Condition x >= 0"): _, prob = du.get_logits_and_probs( probs=p2, validate_args=True) - prob.eval() + self.evaluate(prob) _, prob = du.get_logits_and_probs( probs=p2, validate_args=False) - prob.eval() + self.evaluate(prob) with self.assertRaisesOpError("probs has components greater than 1"): _, prob = du.get_logits_and_probs( probs=p3, validate_args=True) - prob.eval() + self.evaluate(prob) _, prob = du.get_logits_and_probs( probs=p3, validate_args=False) - prob.eval() + self.evaluate(prob) + @test_util.run_in_graph_and_eager_modes def testProbabilityValidateArgsMultidimensional(self): p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) # Component less than 0. Still sums to 1. @@ -265,35 +219,35 @@ class GetLogitsAndProbsTest(test.TestCase): with self.test_session(): _, prob = du.get_logits_and_probs( probs=p, multidimensional=True) - prob.eval() + self.evaluate(prob) with self.assertRaisesOpError("Condition x >= 0"): _, prob = du.get_logits_and_probs( probs=p2, multidimensional=True, validate_args=True) - prob.eval() + self.evaluate(prob) _, prob = du.get_logits_and_probs( probs=p2, multidimensional=True, validate_args=False) - prob.eval() + self.evaluate(prob) with self.assertRaisesOpError( "(probs has components greater than 1|probs does not sum to 1)"): _, prob = du.get_logits_and_probs( probs=p3, multidimensional=True, validate_args=True) - prob.eval() + self.evaluate(prob) _, prob = du.get_logits_and_probs( probs=p3, multidimensional=True, validate_args=False) - prob.eval() + self.evaluate(prob) with self.assertRaisesOpError("probs does not sum to 1"): _, prob = du.get_logits_and_probs( probs=p4, multidimensional=True, validate_args=True) - prob.eval() + self.evaluate(prob) _, prob = du.get_logits_and_probs( probs=p4, multidimensional=True, validate_args=False) - prob.eval() + self.evaluate(prob) def testProbsMultidimShape(self): with self.test_session(): @@ -354,6 +308,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase): param) checked_param.eval(feed_dict={param: np.ones([int(2**11+1)])}) + @test_util.run_in_graph_and_eager_modes def testUnsupportedDtype(self): with self.test_session(): with self.assertRaises(TypeError): @@ -396,6 +351,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase): x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.int32)}) +@test_util.run_all_in_graph_and_eager_modes class LogCombinationsTest(test.TestCase): def testLogCombinationsBinomial(self): @@ -412,7 +368,7 @@ class LogCombinationsTest(test.TestCase): counts = [[1., 1], [2., 3], [4., 8], [11, 4]] log_binom = du.log_combinations(n, counts) self.assertEqual([4], log_binom.get_shape()) - self.assertAllClose(log_combs, log_binom.eval()) + self.assertAllClose(log_combs, self.evaluate(log_binom)) def testLogCombinationsShape(self): # Shape [2, 2] @@ -537,14 +493,20 @@ class RotateTransposeTest(test.TestCase): x = np.array(x) return np.transpose(x, np.roll(np.arange(len(x.shape)), shift)) + @test_util.run_in_graph_and_eager_modes def testRollStatic(self): with self.test_session(): - with self.assertRaisesRegexp(ValueError, "None values not supported."): + if context.executing_eagerly(): + error_message = r"Attempt to convert a value \(None\)" + else: + error_message = "None values not supported." + with self.assertRaisesRegexp(ValueError, error_message): du.rotate_transpose(None, 1) for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))): for shift in np.arange(-5, 5): y = du.rotate_transpose(x, shift) - self.assertAllEqual(self._np_rotate_transpose(x, shift), y.eval()) + self.assertAllEqual( + self._np_rotate_transpose(x, shift), self.evaluate(y)) self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list()) def testRollDynamic(self): @@ -569,12 +531,10 @@ class PickVectorTest(test.TestCase): with self.test_session(): x = np.arange(10, 12) y = np.arange(15, 18) - self.assertAllEqual(x, - du.pick_vector( - math_ops.less(0, 5), x, y).eval()) - self.assertAllEqual(y, - du.pick_vector( - math_ops.less(5, 0), x, y).eval()) + self.assertAllEqual( + x, self.evaluate(du.pick_vector(math_ops.less(0, 5), x, y))) + self.assertAllEqual( + y, self.evaluate(du.pick_vector(math_ops.less(5, 0), x, y))) self.assertAllEqual(x, du.pick_vector( constant_op.constant(True), x, y)) # No eval. @@ -795,6 +755,30 @@ class FillTriangularTest(test.TestCase): self._run_test(self._rng.randn(2, 3, int(7*8/2)), upper=True) +class FillTriangularInverseTest(FillTriangularTest): + + def _run_test(self, x_, use_deferred_shape=False, **kwargs): + x_ = np.asarray(x_) + with self.test_session() as sess: + static_shape = None if use_deferred_shape else x_.shape + x_pl = array_ops.placeholder_with_default(x_, shape=static_shape) + zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.) + - array_ops.stop_gradient(x_pl * (x_pl - 1.))) + x = x_pl + zeros_like_x_pl + actual = du.fill_triangular(x, **kwargs) + inverse_actual = du.fill_triangular_inverse(actual, **kwargs) + + inverse_actual_ = sess.run( + inverse_actual, + feed_dict={x_pl: x_}) + + if use_deferred_shape: + self.assertEqual(None, inverse_actual.shape) + else: + self.assertAllEqual(x_.shape, inverse_actual.shape) + self.assertAllEqual(x_, inverse_actual_) + + class ReduceWeightedLogSumExp(test.TestCase): def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False): @@ -870,25 +854,25 @@ class ReduceWeightedLogSumExp(test.TestCase): [1, 1, 1]]) self.assertAllClose( - np.log(4), - du.reduce_weighted_logsumexp(x, w).eval()) + np.log(4), self.evaluate(du.reduce_weighted_logsumexp(x, w))) with np.errstate(divide="ignore"): self.assertAllClose( np.log([0, 2, 2]), - du.reduce_weighted_logsumexp(x, w, axis=0).eval()) + self.evaluate(du.reduce_weighted_logsumexp(x, w, axis=0))) self.assertAllClose( np.log([1, 3]), - du.reduce_weighted_logsumexp(x, w, axis=1).eval()) + self.evaluate(du.reduce_weighted_logsumexp(x, w, axis=1))) self.assertAllClose( np.log([[1], [3]]), - du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True).eval()) + self.evaluate( + du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True))) self.assertAllClose( np.log(4), - du.reduce_weighted_logsumexp(x, w, axis=[0, 1]).eval()) + self.evaluate(du.reduce_weighted_logsumexp(x, w, axis=[0, 1]))) class GenNewSeedTest(test.TestCase): @@ -986,7 +970,7 @@ class SoftplusTest(test.TestCase): # Note that this range contains both zero and inf. x = constant_op.constant(np.logspace(-8, 6).astype(np.float16)) y = du.softplus_inverse(x) - grads = gradients_impl.gradients(y, x)[0].eval() + grads = self.evaluate(gradients_impl.gradients(y, x)[0]) # Equivalent to `assertAllFalse` (if it existed). self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads)) @@ -996,11 +980,13 @@ class SoftplusTest(test.TestCase): # gradient and its approximations should be finite as well. x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16)) y = du.softplus_inverse(x) - grads = gradients_impl.gradients(y, x)[0].eval() + grads = self.evaluate(gradients_impl.gradients(y, x)[0]) # Equivalent to `assertAllTrue` (if it existed). self.assertAllEqual( np.ones_like(grads).astype(np.bool), np.isfinite(grads)) + +@test_util.run_all_in_graph_and_eager_modes class ArgumentsTest(test.TestCase): def testNoArguments(self): diff --git a/tensorflow/user_ops/duplicate_op.cc b/tensorflow/python/kernel_tests/duplicate_op.cc similarity index 100% rename from tensorflow/user_ops/duplicate_op.cc rename to tensorflow/python/kernel_tests/duplicate_op.cc diff --git a/tensorflow/user_ops/duplicate_op_test.py b/tensorflow/python/kernel_tests/duplicate_op_test.py similarity index 69% rename from tensorflow/user_ops/duplicate_op_test.py rename to tensorflow/python/kernel_tests/duplicate_op_test.py index b61e68d75e3ef253788da82cce56d113bc5e44f9..529d3dd0b3aa1f1013119ef4a90363dbd8d53cd0 100644 --- a/tensorflow/user_ops/duplicate_op_test.py +++ b/tensorflow/python/kernel_tests/duplicate_op_test.py @@ -17,23 +17,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os.path +import os -import tensorflow as tf +from tensorflow.python.framework import load_library +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test -class DuplicateOpTest(tf.test.TestCase): +class DuplicateOpTest(test.TestCase): def testBasic(self): - library_filename = os.path.join(tf.resource_loader.get_data_files_path(), + library_filename = os.path.join(resource_loader.get_data_files_path(), 'duplicate_op.so') - duplicate = tf.load_op_library(library_filename) + duplicate = load_library.load_op_library(library_filename) self.assertEqual(len(duplicate.OP_LIST.op), 0) with self.test_session(): - self.assertEqual(tf.add(1, 41).eval(), 42) + self.assertEqual(math_ops.add(1, 41).eval(), 42) if __name__ == '__main__': - tf.test.main() + test.main() diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py index 159cba5fa3d69be5e3e3b22a85138c29d03981cc..c4d4ce780be2fa5a2617874ddb608e41edf70c36 100644 --- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py @@ -27,7 +27,6 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gradients_impl import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import from tensorflow.python.platform import test -from tensorflow.python.framework import dtypes class DynamicStitchTestBase(object): diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index e53ca1dcaa520b6937aefa45e2740f1c94188b09..55d75cb4749d6f1a33d6cf7a993a336d1afcf992 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import itertools +import math import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -31,6 +32,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables @@ -736,6 +738,222 @@ class EmbeddingLookupSparseTest(test.TestCase): x, sp_ids, sp_weights, combiner="mean") +class SafeEmbeddingLookupSparseTest(test.TestCase): + + def _random_weights(self, vocab_size=4, embed_dim=4, num_shards=1): + assert vocab_size > 0 + assert embed_dim > 0 + assert num_shards > 0 + assert num_shards <= vocab_size + + embedding_weights = partitioned_variables.create_partitioned_variables( + shape=[vocab_size, embed_dim], + slicing=[num_shards, 1], + initializer=init_ops.truncated_normal_initializer( + mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)) + for w in embedding_weights: + w.initializer.run() + embedding_weights = [w.eval() for w in embedding_weights] + return embedding_weights + + def _ids_and_weights_2d(self): + # Each row demonstrates a test case: + # Row 0: multiple valid ids, 1 invalid id, weighted mean + # Row 1: all ids are invalid (leaving no valid ids after pruning) + # Row 2: no ids to begin with + # Row 3: single id + # Row 4: all ids have <=0 weight + indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]] + ids = [0, 1, -1, -1, 2, 0, 1] + weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] + shape = [5, 4] + + sparse_ids = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(ids, dtypes.int64), + constant_op.constant(shape, dtypes.int64)) + + sparse_weights = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(weights, dtypes.float32), + constant_op.constant(shape, dtypes.int64)) + + return sparse_ids, sparse_weights + + def _ids_and_weights_3d(self): + # Each (2-D) index demonstrates a test case: + # Index 0, 0: multiple valid ids, 1 invalid id, weighted mean + # Index 0, 1: all ids are invalid (leaving no valid ids after pruning) + # Index 0, 2: no ids to begin with + # Index 1, 0: single id + # Index 1, 1: all ids have <=0 weight + # Index 1, 2: no ids to begin with + indices = [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [1, 0, 0], [1, 1, 0], + [1, 1, 1]] + ids = [0, 1, -1, -1, 2, 0, 1] + weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] + shape = [2, 3, 4] + + sparse_ids = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(ids, dtypes.int64), + constant_op.constant(shape, dtypes.int64)) + + sparse_weights = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(weights, dtypes.float32), + constant_op.constant(shape, dtypes.int64)) + + return sparse_ids, sparse_weights + + def test_safe_embedding_lookup_sparse_return_zero_vector(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, sparse_weights = self._ids_and_weights_2d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights).eval()) + + self.assertAllClose( + embedding_lookup_result, + [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / + 3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4]) + + def test_safe_embedding_lookup_sparse_return_special_vector(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, sparse_weights = self._ids_and_weights_2d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights, default_id=3).eval()) + + self.assertAllClose( + embedding_lookup_result, + [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / + 3.0, embedding_weights[0][3], embedding_weights[0][3], + embedding_weights[0][2], embedding_weights[0][3]]) + + def test_safe_embedding_lookup_sparse_no_weights(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, _ = self._ids_and_weights_2d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None).eval()) + + self.assertAllClose( + embedding_lookup_result, + [(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, + [0] * 4, embedding_weights[0][2], ( + embedding_weights[0][0] + embedding_weights[0][1]) / 2.0]) + + def test_safe_embedding_lookup_sparse_partitioned(self): + with self.test_session(): + embedding_weights = self._random_weights(num_shards=3) + sparse_ids, _ = self._ids_and_weights_2d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None).eval()) + + embedding_weights = list(itertools.chain(*embedding_weights)) + self.assertAllClose(embedding_lookup_result, + [(embedding_weights[0] + embedding_weights[1]) / 2.0, + [0] * 4, [0] * 4, embedding_weights[2], + (embedding_weights[0] + embedding_weights[1]) / 2.0]) + + def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self): + with self.test_session(): + embedding_weights = self._random_weights(num_shards=3) + sparse_ids, sparse_weights = self._ids_and_weights_2d() + + embedding_weights[1] = embedding_weights[1].astype(np.float64) + self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse, + embedding_weights, sparse_ids) + embedding_weights = [ + constant_op.constant(w, dtype=dtypes.float64) + for w in embedding_weights + ] + self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse, + embedding_weights, sparse_ids, sparse_weights) + + def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, sparse_weights = self._ids_and_weights_3d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights).eval()) + + self.assertAllClose(embedding_lookup_result, [[ + (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0, + [0] * 4, [0] * 4 + ], [embedding_weights[0][2], [0] * 4, [0] * 4]]) + + def test_safe_embedding_lookup_sparse_3d_return_special_vector(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, sparse_weights = self._ids_and_weights_3d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights, default_id=3).eval()) + + self.assertAllClose( + embedding_lookup_result, + [[(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / + 3.0, embedding_weights[0][3], embedding_weights[0][3]], [ + embedding_weights[0][2], embedding_weights[0][3], + embedding_weights[0][3] + ]]) + + def test_safe_embedding_lookup_sparse_3d_no_weights(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, _ = self._ids_and_weights_3d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None).eval()) + + self.assertAllClose(embedding_lookup_result, [[( + embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, [ + 0 + ] * 4], [ + embedding_weights[0][2], + (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4 + ]]) + + def test_safe_embedding_lookup_sparse_3d_partitioned(self): + with self.test_session(): + embedding_weights = self._random_weights(num_shards=3) + sparse_ids, _ = self._ids_and_weights_3d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None).eval()) + + embedding_weights = list(itertools.chain(*embedding_weights)) + self.assertAllClose(embedding_lookup_result, [[ + (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4 + ], [ + embedding_weights[2], + (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4 + ]]) + + def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights( + self): + with self.test_session(): + embedding_weights = self._random_weights(num_shards=3) + sparse_ids, sparse_weights = self._ids_and_weights_3d() + + embedding_weights[1] = embedding_weights[1].astype(np.float64) + self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse, + embedding_weights, sparse_ids) + embedding_weights = [ + constant_op.constant(w, dtype=dtypes.float64) + for w in embedding_weights + ] + self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse, + embedding_weights, sparse_ids, sparse_weights) + + class DynamicStitchOpTest(test.TestCase): def testCint32Cpu(self): diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index ce73e7ad3e5f822363c697609dfa163b6f13751a..9e7b5283381dd7bc0725e1ab6fb9d7d13153f02d 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -125,12 +126,21 @@ class FIFOQueueTest(test.TestCase): q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run() self.assertEqual(4, q.size().eval()) + @test_util.run_in_graph_and_eager_modes def testMultipleDequeues(self): - with self.test_session() as session: - q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) - q.enqueue_many([[1, 2, 3]]).run() - a, b, c = session.run([q.dequeue(), q.dequeue(), q.dequeue()]) - self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue_many([[1, 2, 3]])) + a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) + self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + + @test_util.run_in_graph_and_eager_modes + def testQueuesDontShare(self): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue(1)) + q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q2.enqueue(2)) + self.assertAllEqual(self.evaluate(q2.dequeue()), 2) + self.assertAllEqual(self.evaluate(q.dequeue()), 1) def testEnqueueDictWithoutNames(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index facadc971ff516e4f9edea0c4f52ab0953ec5fce..5272a3631fa6a49ea913694b382f4331b46c8a29 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import @@ -56,7 +57,7 @@ def simple_scoped_fn(a, x): class FunctionalOpsTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldl_Simple(self): with self.test_session(): elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -72,7 +73,7 @@ class FunctionalOpsTest(test.TestCase): initializer=10) self.assertAllEqual(880, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldl_SingleInputMultiOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -83,7 +84,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(22, r_value[0]) self.assertAllEqual(20, r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldl_MultiInputSingleOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -111,7 +112,7 @@ class FunctionalOpsTest(test.TestCase): self.assertEqual(len(variables.trainable_variables()), 1) self.assertAllEqual(880, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldr_Simple(self): with self.test_session(): elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -127,7 +128,7 @@ class FunctionalOpsTest(test.TestCase): initializer=10) self.assertAllEqual(1282, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldr_SingleInputMultiOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -138,7 +139,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(22, r_value[0]) self.assertAllEqual(20, r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldr_MultiInputSingleOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -182,7 +183,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(720.0, self.evaluate(r)) # pylint: enable=unnecessary-lambda - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_Simple(self): with self.test_session(): nums = [1, 2, 3, 4, 5, 6] @@ -202,7 +203,7 @@ class FunctionalOpsTest(test.TestCase): values=constant_op.constant([0, 1, 2]), dense_shape=[2, 2])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMapOverScalarErrors(self): with self.assertRaisesRegexp(ValueError, "not scalars"): functional_ops.map_fn(lambda x: x, [1, 2]) @@ -251,7 +252,7 @@ class FunctionalOpsTest(test.TestCase): r = gradients_impl.gradients(y, elems)[0] self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_SimpleNotTensor(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -260,7 +261,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual( np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_SingleInputMultiOutput(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -275,7 +276,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual((nums + 3) * 2, received[0]) self.assertAllEqual(-(nums + 3) * 2, received[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_MultiOutputMismatchedDtype(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -287,7 +288,7 @@ class FunctionalOpsTest(test.TestCase): nums, dtype=[dtypes.int64, dtypes.int64]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_MultiInputSingleOutput(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -298,7 +299,7 @@ class FunctionalOpsTest(test.TestCase): received = self.evaluate(r) self.assertAllEqual(nums * nums + (-nums), received) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_MultiInputSameStructureOutput(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -313,7 +314,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(-nums, received[1]) self.assertAllEqual(nums, received[2]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_Simple(self): with self.test_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") @@ -328,7 +329,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) # pylint: enable=unnecessary-lambda - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_Reverse(self): with self.test_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") @@ -345,7 +346,7 @@ class FunctionalOpsTest(test.TestCase): self.evaluate(r)) # pylint: enable=unnecessary-lambda - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_SingleInputMultiOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -357,7 +358,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0]) self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_MultiInputSingleOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -367,7 +368,7 @@ class FunctionalOpsTest(test.TestCase): (elems + 1, -elems), initializer) self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_MultiInputSameTypeOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -377,7 +378,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(np.cumsum(elems), r_value[0]) self.assertAllEqual(np.cumsum(-elems), r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_MultiOutputMismatchedInitializer(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -408,7 +409,7 @@ class FunctionalOpsTest(test.TestCase): results = np.array([6, 16, 38, 84, 178, 368]) self.assertAllEqual(results, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScanFoldl_Nested(self): with self.test_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data") @@ -467,7 +468,7 @@ class FunctionalOpsTest(test.TestCase): variables.global_variables_initializer().run() sess.run(grad) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldShape(self): with self.test_session(): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) @@ -479,7 +480,7 @@ class FunctionalOpsTest(test.TestCase): y = functional_ops.foldl(fn, x, initializer=initializer) self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMapShape(self): with self.test_session(): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) @@ -491,7 +492,7 @@ class FunctionalOpsTest(test.TestCase): y = functional_ops.map_fn(lambda e: e, x) self.assertIs(None, y.get_shape().dims) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMapEmptyScalar(self): with self.test_session(): map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([])) @@ -507,7 +508,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual([0, 3, 2], map_return.get_shape().dims) self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScanShape(self): with self.test_session(): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) @@ -604,6 +605,25 @@ class FunctionalOpsTest(test.TestCase): mul = sess.run(remote_op) self.assertEqual(mul, [6]) + def testRemoteFunctionSameDeviceDirectSession(self): + + @function.Defun(dtypes.int32, dtypes.int32) + def _remote_fn(a, b): + return math_ops.multiply(a, b) + + with ops.device("/cpu:0"): + a = variables.Variable(2, dtype=dtypes.int32) + b = variables.Variable(3, dtype=dtypes.int32) + + with ops.device("/cpu:0"): + remote_op = functional_ops.remote_call( + args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0") + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + mul = sess.run(remote_op) + self.assertEqual(mul, [6]) + def testRemoteFunctionCPUGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -652,6 +672,24 @@ class FunctionalOpsTest(test.TestCase): mul = sess.run(remote_op) self.assertEqual(mul, 9.0) + def testRemoteFunctionGPUCPUStrings(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + @function.Defun(dtypes.string) + def _remote_fn(inp): + return array_ops.identity(inp) + + a = array_ops.constant("a") + + with ops.device("/gpu:0"): + remote_op = functional_ops.remote_call( + args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0") + + with self.test_session() as sess: + ret = sess.run(remote_op) + self.assertAllEqual(ret, [b"a"]) + def testRemoteFunctionCrossProcess(self): workers, _ = test_util.create_local_cluster(2, 1) @@ -1043,6 +1081,58 @@ class PartitionedCallTest(test.TestCase): self.assertTrue(compat.as_bytes("CPU:1") in outputs[1].eval()) self.assertTrue(compat.as_bytes("CPU:2") in outputs[2].eval()) + def testAssignAddResourceVariable(self): + + v = resource_variable_ops.ResourceVariable(1.0) + + @function.Defun() + def AssignAdd(): + v.assign_add(1.0) + + op = functional_ops.partitioned_call( + args=AssignAdd.captured_inputs, f=AssignAdd) + _ = self.evaluate(variables.global_variables_initializer()) + _ = self.evaluate(op) + value = self.evaluate(v.read_value()) + self.assertEqual(value, 2.0) + + def testFunctionWithResourcesOnDifferentDevices(self): + # TODO(akshayka): Remove the `skipTest` once we can whitelist ops as + # safe to be invoked with resources on different devices. + self.skipTest("The Placer disallows ops with resource inputs " + "on different devices.") + + with ops.device("/cpu:0"): + v_cpu_zero = resource_variable_ops.ResourceVariable( + [0.0, 1.0, 2.0], name="v_cpu_zero") + + with ops.device("/cpu:1"): + v_cpu_one = resource_variable_ops.ResourceVariable( + [0.0, 1.0, 2.0], name="v_cpu_one") + + with ops.device("/gpu:0"): + v_gpu = resource_variable_ops.ResourceVariable( + [0.0, 1.0, 2.0], name="v_gpu") + + def sum_gather(): + cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_zero, [1, 2])) + also_cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_one, [1, 2])) + gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2])) + return cpu_result, also_cpu_result, gpu_result + + defined = function.Defun()(sum_gather) + with self.test_session( + config=config_pb2.ConfigProto( + allow_soft_placement=False, + log_device_placement=True, + device_count={"CPU": 2})) as sess: + sess.run(variables.global_variables_initializer()) + expected = sess.run(sum_gather()) + result = sess.run( + functional_ops.partitioned_call( + args=defined.captured_inputs, f=defined)) + self.assertAllEqual(expected, result) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py index 91ebe8de9921268b2a3c5ad645585e1fe83c7419..58e2a8ac2a3b827647b1b1176f4b69e6a88b76c6 100644 --- a/tensorflow/python/kernel_tests/gather_nd_op_test.py +++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py @@ -197,7 +197,21 @@ class GatherNdTest(test.TestCase): self.assertEqual(None, shape.ndims) self.assertEqual(None, shape[0].value) - def testBadIndices(self): + def testBadIndicesCPU(self): + with self.test_session(use_gpu=False): + params = [0, 1, 2] + indices = [[[0], [7]]] # Make this one higher rank + gather_nd = array_ops.gather_nd(params, indices) + with self.assertRaisesOpError( + r"flat indices\[1, :\] = \[7\] does not index into param " + r"\(shape: \[3\]\)"): + gather_nd.eval() + + def _disabledTestBadIndicesGPU(self): + # TODO disabled due to different behavior on GPU and CPU + # On GPU the bad indices do not raise error but fetch 0 values + if not test.is_gpu_available(): + return with self.test_session(use_gpu=True): params = [0, 1, 2] indices = [[[0], [7]]] # Make this one higher rank @@ -207,7 +221,21 @@ class GatherNdTest(test.TestCase): r"\(shape: \[3\]\)"): gather_nd.eval() - def testBadIndicesWithSlices(self): + def testBadIndicesWithSlicesCPU(self): + with self.test_session(use_gpu=False): + params = [[0, 1, 2]] + indices = [[[0], [0], [1]]] # Make this one higher rank + gather_nd = array_ops.gather_nd(params, indices) + with self.assertRaisesOpError( + r"flat indices\[2, :\] = \[1\] does not index into param " + r"\(shape: \[1,3\]\)"): + gather_nd.eval() + + def _disabledTestBadIndicesWithSlicesGPU(self): + # TODO disabled due to different behavior on GPU and CPU + # On GPU the bad indices do not raise error but fetch 0 values + if not test.is_gpu_available(): + return with self.test_session(use_gpu=True): params = [[0, 1, 2]] indices = [[[0], [0], [1]]] # Make this one higher rank diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py index a2fcd751dfa94605d271587640815fae6ac1c360..033fa959359e20894c376341d2f9ad79d30a5878 100644 --- a/tensorflow/python/kernel_tests/gather_op_test.py +++ b/tensorflow/python/kernel_tests/gather_op_test.py @@ -27,7 +27,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.platform import test -_TEST_TYPES = (dtypes.float32, dtypes.complex64, dtypes.complex128) +_TEST_TYPES = (dtypes.int64, dtypes.float32, + dtypes.complex64, dtypes.complex128) class GatherTest(test.TestCase): @@ -122,6 +123,9 @@ class GatherTest(test.TestCase): gather, [tf_params, tf_indices, tf_axis], gather_grad) self.assertEqual(indices_grad, None) self.assertEqual(axis_grad, None) + if dtype.is_integer: + self.assertEqual(params_grad, None) + continue # For axis 0, we are able to create an efficient IndexedSlices for # the gradient. if axis == 0: @@ -177,7 +181,19 @@ class GatherTest(test.TestCase): gather_t = array_ops.gather(params, indices, axis=axis) self.assertEqual(None, gather_t.shape) - def testBadIndices(self): + def testBadIndicesCPU(self): + with self.test_session(use_gpu=False): + params = [[0, 1, 2], [3, 4, 5]] + with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"): + array_ops.gather(params, [[7]], axis=0).eval() + with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"): + array_ops.gather(params, [[7]], axis=1).eval() + + def _disabledTestBadIndicesGPU(self): + # TODO disabled due to different behavior on GPU and CPU + # On GPU the bad indices do not raise error but fetch 0 values + if not test.is_gpu_available(): + return with self.test_session(use_gpu=True): params = [[0, 1, 2], [3, 4, 5]] with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"): diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index a9b55854f1b4a3dfc49f05397ca32bc7b2ccb88e..927ca012ae6fc876364734c6f9bafd62ccc87467 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -362,6 +362,71 @@ class UniformUnitScalingInitializationTest(test.TestCase): dtype=dtypes.string) +class VarianceScalingInitializationTest(test.TestCase): + + def testTruncatedNormalDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer( + distribution='truncated_normal') + + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \ + as mock_truncated_normal: + x = init(shape).eval() + self.assertTrue(mock_truncated_normal.called) + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + + def testNormalDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer(distribution='normal') + + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \ + as mock_truncated_normal: + x = init(shape).eval() + self.assertTrue(mock_truncated_normal.called) + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + + def testUntruncatedNormalDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer( + distribution='untruncated_normal') + + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'random_normal', wraps=random_ops.random_normal) \ + as mock_random_normal: + x = init(shape).eval() + self.assertTrue(mock_random_normal.called) + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + + def testUniformDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer(distribution='uniform') + + with self.test_session(use_gpu=True): + x = init(shape).eval() + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + + # TODO(vrv): move to sequence_ops_test? class RangeTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/inplace_ops_test.py b/tensorflow/python/kernel_tests/inplace_ops_test.py index 0f95e13187fcd5cc199d871ea5efdca363b37cd0..6e894365af68877bd4f2ff4ae0f18db7c0829275 100644 --- a/tensorflow/python/kernel_tests/inplace_ops_test.py +++ b/tensorflow/python/kernel_tests/inplace_ops_test.py @@ -166,7 +166,8 @@ class InplaceOpsTest(test_util.TensorFlowTestCase): def testEmpty(self): for dtype in [ - dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64, dtypes.bool + dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64, dtypes.bool, + dtypes.uint8 ]: with self.test_session(use_gpu=True): test_shapes = [(), (1,), (2, 3), (0, 2), (2, 3, 5), (2, 0, 5)] @@ -187,11 +188,12 @@ class InplaceOpsTest(test_util.TensorFlowTestCase): self.assertEqual(val.dtype, dtype.as_numpy_dtype) self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype)) - val = inplace_ops.empty((1, 2), dtypes.string, init=True).eval() - self.assertEqual(val.tolist(), [[b"", b""]]) + with self.test_session(use_gpu=True): + val = inplace_ops.empty((1, 2), dtypes.string, init=True).eval() + self.assertEqual(val.tolist(), [[b"", b""]]) - val = inplace_ops.empty((1, 2), dtypes.string, init=False).eval() - self.assertEqual(val.tolist(), [[b"", b""]]) + val = inplace_ops.empty((1, 2), dtypes.string, init=False).eval() + self.assertEqual(val.tolist(), [[b"", b""]]) if __name__ == "__main__": diff --git a/tensorflow/user_ops/invalid_op.cc b/tensorflow/python/kernel_tests/invalid_op.cc similarity index 100% rename from tensorflow/user_ops/invalid_op.cc rename to tensorflow/python/kernel_tests/invalid_op.cc diff --git a/tensorflow/user_ops/invalid_op_test.py b/tensorflow/python/kernel_tests/invalid_op_test.py similarity index 67% rename from tensorflow/user_ops/invalid_op_test.py rename to tensorflow/python/kernel_tests/invalid_op_test.py index c90a00ce58bb4f6e1bd74c9f323e6cdc86397365..238299a895487b1cab7db053fd7f354d4a167ea9 100644 --- a/tensorflow/user_ops/invalid_op_test.py +++ b/tensorflow/python/kernel_tests/invalid_op_test.py @@ -17,19 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os.path +import os -import tensorflow as tf +from tensorflow.python.framework import errors +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test -class InvalidOpTest(tf.test.TestCase): +class InvalidOpTest(test.TestCase): def testBasic(self): - library_filename = os.path.join(tf.resource_loader.get_data_files_path(), + library_filename = os.path.join(resource_loader.get_data_files_path(), 'invalid_op.so') - with self.assertRaises(tf.errors.InvalidArgumentError): - tf.load_op_library(library_filename) + with self.assertRaises(errors.InvalidArgumentError): + load_library.load_op_library(library_filename) if __name__ == '__main__': - tf.test.main() + test.main() diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 91be80322c37792be02d1b625df6757c9d80b060..69d3aa401751f56ea338a5ac4b24d65e68dbddeb 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -107,6 +107,10 @@ cuda_py_test( "//tensorflow/python:random_ops", ], shard_count = 5, + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -124,6 +128,10 @@ cuda_py_test( "//tensorflow/python:random_ops", ], shard_count = 5, + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -140,6 +148,10 @@ cuda_py_test( "//tensorflow/python:platform_test", ], shard_count = 5, + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -177,6 +189,10 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -213,4 +229,8 @@ cuda_py_test( "//tensorflow/python:platform_test", ], shard_count = 5, + tags = [ + "noasan", + "optonly", + ], ) diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py index 2b80f01b73441185281a3e2ef4db003b150c1e12..3ede2aceaa51c2795029ba13b763fed3e2ddc441 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py @@ -80,7 +80,7 @@ class SquareLinearOperatorBlockDiagTest( build_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]), ] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) expected_blocks = ( build_info.__dict__["blocks"] if "blocks" in build_info.__dict__ @@ -91,26 +91,19 @@ class SquareLinearOperatorBlockDiagTest( for block_shape in expected_blocks ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in expected_blocks - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = self.evaluate(matrices) - operator = block_diag.LinearOperatorBlockDiag( - [linalg.LinearOperatorFullMatrix( - m_ph, is_square=True) for m_ph in matrices_ph], - is_square=True) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = block_diag.LinearOperatorBlockDiag( - [linalg.LinearOperatorFullMatrix( - m, is_square=True) for m in matrices]) - feed_dict = None - # Should be auto-set. - self.assertTrue(operator.is_square) + lin_op_matrices = [ + array_ops.placeholder_with_default( + matrix, shape=None) for matrix in matrices] + + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorFullMatrix( + l, is_square=True) for l in lin_op_matrices]) + + # Should be auto-set. + self.assertTrue(operator.is_square) # Broadcast the shapes. expected_shape = list(build_info.shape) @@ -123,7 +116,7 @@ class SquareLinearOperatorBlockDiagTest( block_diag_dense.set_shape( expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]]) - return operator, block_diag_dense, feed_dict + return operator, block_diag_dense def test_is_x_flags(self): # Matrix with two positive eigenvalues, 1, and 1. diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py index 5713d169696c78e996332b7a515a3ee2eedca839..7261d4bb3bc4aa24f51be21f9ac261549dca58d5 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py @@ -95,7 +95,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator( # real, the matrix will not be real. return [dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # For this test class, we are creating real spectrums. # We also want the spectrum to have eigenvalues bounded away from zero. @@ -107,22 +107,18 @@ class LinearOperatorCirculantTestSelfAdjointOperator( # zero, so the operator will still be self-adjoint. spectrum = math_ops.cast(spectrum, dtype) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant( - spectrum_ph, is_self_adjoint=True, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant( - spectrum, is_self_adjoint=True, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant( + lin_op_spectrum, is_self_adjoint=True, input_output_dtype=dtype) mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): with self.test_session(): @@ -149,7 +145,7 @@ class LinearOperatorCirculantTestHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.float32, dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # For this test class, we are creating Hermitian spectrums. # We also want the spectrum to have eigenvalues bounded away from zero. @@ -172,22 +168,18 @@ class LinearOperatorCirculantTestHermitianSpectrum( spectrum = math_ops.fft(h_c) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): with self.test_session(): @@ -213,7 +205,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # Will be well conditioned enough to get accurate solves. spectrum = linear_operator_test_util.random_sign_uniform( @@ -222,22 +214,18 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( minval=1., maxval=2.) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): with self.test_session(): @@ -432,7 +420,7 @@ class LinearOperatorCirculant2DTestHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.float32, dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # For this test class, we are creating Hermitian spectrums. # We also want the spectrum to have eigenvalues bounded away from zero. @@ -455,22 +443,18 @@ class LinearOperatorCirculant2DTestHermitianSpectrum( spectrum = math_ops.fft2d(h_c) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant2D( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant2D( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant2D( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat class LinearOperatorCirculant2DTestNonHermitianSpectrum( @@ -486,7 +470,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # Will be well conditioned enough to get accurate solves. spectrum = linear_operator_test_util.random_sign_uniform( @@ -495,22 +479,18 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( minval=1., maxval=2.) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant2D( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant2D( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant2D( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_real_hermitian_spectrum_gives_real_symmetric_operator(self): with self.test_session() as sess: diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py index f96b9ccdaacae7d8e0552ed3d74ce53808fed963..612a50bcec771f8511d20d19b312a797d531f109 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py @@ -44,7 +44,7 @@ class SquareLinearOperatorCompositionTest( self._rtol[dtypes.float32] = 1e-4 self._rtol[dtypes.complex64] = 1e-4 - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): sess = ops.get_default_session() shape = list(build_info.shape) @@ -56,33 +56,23 @@ class SquareLinearOperatorCompositionTest( for _ in range(num_operators) ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in range(num_operators) - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = sess.run(matrices) - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph], - is_square=True) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m) for m in matrices]) - feed_dict = None - # Should be auto-set. - self.assertTrue(operator.is_square) - - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated each matrix to a numpy array. + lin_op_matrices = [ + array_ops.placeholder_with_default( + matrix, shape=None) for matrix in matrices] + + operator = linalg.LinearOperatorComposition( + [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices], + is_square=True) + matmul_order_list = list(reversed(matrices)) - mat = ops.convert_to_tensor(matmul_order_list[0]) + mat = matmul_order_list[0] for other_mat in matmul_order_list[1:]: mat = math_ops.matmul(other_mat, mat) - return operator, mat, feed_dict + return operator, mat def test_is_x_flags(self): # Matrix with two positive eigenvalues, 1, and 1. @@ -148,7 +138,7 @@ class NonSquareLinearOperatorCompositionTest( self._rtol[dtypes.float32] = 1e-4 self._rtol[dtypes.complex64] = 1e-4 - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): sess = ops.get_default_session() shape = list(build_info.shape) @@ -170,30 +160,22 @@ class NonSquareLinearOperatorCompositionTest( shape_2, dtype=dtype) ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in range(num_operators) - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = sess.run(matrices) - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph]) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m) for m in matrices]) - feed_dict = None - - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated each matrix to a numpy array. + lin_op_matrices = [ + array_ops.placeholder_with_default( + matrix, shape=None) for matrix in matrices] + + operator = linalg.LinearOperatorComposition( + [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices]) + matmul_order_list = list(reversed(matrices)) - mat = ops.convert_to_tensor(matmul_order_list[0]) + mat = matmul_order_list[0] for other_mat in matmul_order_list[1:]: mat = math_ops.matmul(other_mat, mat) - return operator, mat, feed_dict + return operator, mat def test_static_shapes(self): operators = [ diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py index 0a0e31c716ecfa10ed93cff92fa908a240f8495e..83cc8c483f9aec6dd0ddf3f961a8180af7515e40 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py @@ -34,25 +34,21 @@ class LinearOperatorDiagTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) diag = linear_operator_test_util.random_sign_uniform( shape[:-1], minval=1., maxval=2., dtype=dtype) + + lin_op_diag = diag + if use_placeholder: - diag_ph = array_ops.placeholder(dtype=dtype) - # Evaluate the diag here because (i) you cannot feed a tensor, and (ii) - # diag is random and we want the same value used for both mat and - # feed_dict. - diag = diag.eval() - operator = linalg.LinearOperatorDiag(diag_ph) - feed_dict = {diag_ph: diag} - else: - operator = linalg.LinearOperatorDiag(diag) - feed_dict = None + lin_op_diag = array_ops.placeholder_with_default(diag, shape=None) + + operator = linalg.LinearOperatorDiag(lin_op_diag) - mat = array_ops.matrix_diag(diag) + matrix = array_ops.matrix_diag(diag) - return operator, mat, feed_dict + return operator, matrix def test_assert_positive_definite_raises_for_zero_eigenvalue(self): # Matrix with one positive eigenvalue and one zero eigenvalue. diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py index b3da623b5e8d8c99c6777e75e2d49f24dab1c96b..1a40a29ec6a040ca3d98e0b27492b1379d30cb4b 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -36,30 +35,20 @@ class SquareLinearOperatorFullMatrixTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) matrix = linear_operator_test_util.random_positive_definite_matrix( shape, dtype) + lin_op_matrix = matrix + if use_placeholder: - matrix_ph = array_ops.placeholder(dtype=dtype) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrix = matrix.eval() - operator = linalg.LinearOperatorFullMatrix(matrix_ph, is_square=True) - feed_dict = {matrix_ph: matrix} - else: - # is_square should be auto-detected here. - operator = linalg.LinearOperatorFullMatrix(matrix) - feed_dict = None + lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None) - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated matrix to a numpy array. - mat = ops.convert_to_tensor(matrix) + operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True) - return operator, mat, feed_dict + return operator, matrix def test_is_x_flags(self): # Matrix with two positive eigenvalues. @@ -136,32 +125,20 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest( def _dtypes_to_test(self): return [dtypes.float32, dtypes.float64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) matrix = linear_operator_test_util.random_positive_definite_matrix( shape, dtype, force_well_conditioned=True) + lin_op_matrix = matrix + if use_placeholder: - matrix_ph = array_ops.placeholder(dtype=dtype) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrix = matrix.eval() - # is_square is auto-set because of self_adjoint/pd. - operator = linalg.LinearOperatorFullMatrix( - matrix_ph, is_self_adjoint=True, is_positive_definite=True) - feed_dict = {matrix_ph: matrix} - else: - operator = linalg.LinearOperatorFullMatrix( - matrix, is_self_adjoint=True, is_positive_definite=True) - feed_dict = None - - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated matrix to a numpy array. - mat = ops.convert_to_tensor(matrix) - - return operator, mat, feed_dict + lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None) + + operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True) + + return operator, matrix def test_is_x_flags(self): # Matrix with two positive eigenvalues. @@ -210,26 +187,18 @@ class NonSquareLinearOperatorFullMatrixTest( linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) matrix = linear_operator_test_util.random_normal(shape, dtype=dtype) + + lin_op_matrix = matrix + if use_placeholder: - matrix_ph = array_ops.placeholder(dtype=dtype) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrix = matrix.eval() - operator = linalg.LinearOperatorFullMatrix(matrix_ph) - feed_dict = {matrix_ph: matrix} - else: - operator = linalg.LinearOperatorFullMatrix(matrix) - feed_dict = None + lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None) - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated matrix to a numpy array. - mat = ops.convert_to_tensor(matrix) + operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True) - return operator, mat, feed_dict + return operator, matrix def test_is_x_flags(self): matrix = [[3., 2., 1.], [1., 1., 1.]] diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py index 59f63f949e96991193412d3574603e58a75cb6e5..35dcf4417c313f5cbc00c8b66b4c5d1f2e157212 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py @@ -43,7 +43,7 @@ class LinearOperatorIdentityTest( # 16bit. return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) assert shape[-1] == shape[-2] @@ -54,13 +54,7 @@ class LinearOperatorIdentityTest( num_rows, batch_shape=batch_shape, dtype=dtype) mat = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype) - # Nothing to feed since LinearOperatorIdentity takes no Tensor args. - if use_placeholder: - feed_dict = {} - else: - feed_dict = None - - return operator, mat, feed_dict + return operator, mat def test_assert_positive_definite(self): with self.test_session(): @@ -261,7 +255,7 @@ class LinearOperatorScaledIdentityTest( # 16bit. return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) assert shape[-1] == shape[-2] @@ -274,24 +268,23 @@ class LinearOperatorScaledIdentityTest( multiplier = linear_operator_test_util.random_sign_uniform( shape=batch_shape, minval=1., maxval=2., dtype=dtype) - operator = linalg_lib.LinearOperatorScaledIdentity(num_rows, multiplier) # Nothing to feed since LinearOperatorScaledIdentity takes no Tensor args. + lin_op_multiplier = multiplier + if use_placeholder: - multiplier_ph = array_ops.placeholder(dtype=dtype) - multiplier = multiplier.eval() - operator = linalg_lib.LinearOperatorScaledIdentity( - num_rows, multiplier_ph) - feed_dict = {multiplier_ph: multiplier} - else: - feed_dict = None + lin_op_multiplier = array_ops.placeholder_with_default( + multiplier, shape=None) + + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows, lin_op_multiplier) multiplier_matrix = array_ops.expand_dims( array_ops.expand_dims(multiplier, -1), -1) - mat = multiplier_matrix * linalg_ops.eye( + matrix = multiplier_matrix * linalg_ops.eye( num_rows, batch_shape=batch_shape, dtype=dtype) - return operator, mat, feed_dict + return operator, matrix def test_assert_positive_definite_does_not_raise_when_positive(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py index 784c730bbc8179dd1302294b2d558e8a0c532c0c..e26b946151dd8ddb923e34352feb6b483f9752fc 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py @@ -101,7 +101,7 @@ class SquareLinearOperatorKroneckerTest( def _tests_to_skip(self): return ["det", "solve", "solve_with_broadcast"] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) expected_factors = build_info.__dict__["factors"] matrices = [ @@ -110,26 +110,15 @@ class SquareLinearOperatorKroneckerTest( for block_shape in expected_factors ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in expected_factors - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = self.evaluate(matrices) - operator = kronecker.LinearOperatorKronecker( - [linalg.LinearOperatorFullMatrix( - m_ph, is_square=True) for m_ph in matrices_ph], - is_square=True) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = kronecker.LinearOperatorKronecker( - [linalg.LinearOperatorFullMatrix( - m, is_square=True) for m in matrices]) - feed_dict = None - # Should be auto-set. - self.assertTrue(operator.is_square) + lin_op_matrices = [ + array_ops.placeholder_with_default(m, shape=None) for m in matrices] + + operator = kronecker.LinearOperatorKronecker( + [linalg.LinearOperatorFullMatrix( + l, is_square=True) for l in lin_op_matrices]) matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices) @@ -138,7 +127,7 @@ class SquareLinearOperatorKroneckerTest( if not use_placeholder: kronecker_dense.set_shape(shape) - return operator, kronecker_dense, feed_dict + return operator, kronecker_dense def test_is_x_flags(self): # Matrix with two positive eigenvalues, 1, and 1. diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py index 8095f6419ef0d9543339cf1f4ee9cd4783f852b9..34b35a4ffb878c63f851f2b31491e7bfa4057417 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py @@ -68,7 +68,7 @@ class BaseLinearOperatorLowRankUpdatetest(object): build_info((3, 4, 4)), build_info((2, 1, 4, 4))] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): # Recall A = L + UDV^H shape = list(build_info.shape) diag_shape = shape[:-1] @@ -80,17 +80,17 @@ class BaseLinearOperatorLowRankUpdatetest(object): # operator, with condition number as high as 1e4. base_diag = linear_operator_test_util.random_uniform( diag_shape, minval=1e-4, maxval=1., dtype=dtype) - base_diag_ph = array_ops.placeholder(dtype=dtype) + lin_op_base_diag = base_diag # U u = linear_operator_test_util.random_normal_correlated_columns( u_perturbation_shape, dtype=dtype) - u_ph = array_ops.placeholder(dtype=dtype) + lin_op_u = u # V v = linear_operator_test_util.random_normal_correlated_columns( u_perturbation_shape, dtype=dtype) - v_ph = array_ops.placeholder(dtype=dtype) + lin_op_v = v # D if self._is_diag_update_positive: @@ -99,42 +99,25 @@ class BaseLinearOperatorLowRankUpdatetest(object): else: diag_update = linear_operator_test_util.random_normal( diag_update_shape, stddev=1e-4, dtype=dtype) - diag_update_ph = array_ops.placeholder(dtype=dtype) + lin_op_diag_update = diag_update if use_placeholder: - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - base_diag = base_diag.eval() - u = u.eval() - v = v.eval() - diag_update = diag_update.eval() - - # In all cases, set base_operator to be positive definite. - base_operator = linalg.LinearOperatorDiag( - base_diag_ph, is_positive_definite=True) - - operator = linalg.LinearOperatorLowRankUpdate( - base_operator, - u=u_ph, - v=v_ph if self._use_v else None, - diag_update=diag_update_ph if self._use_diag_update else None, - is_diag_update_positive=self._is_diag_update_positive) - feed_dict = { - base_diag_ph: base_diag, - u_ph: u, - v_ph: v, - diag_update_ph: diag_update} - else: - base_operator = linalg.LinearOperatorDiag( - base_diag, is_positive_definite=True) - operator = linalg.LinearOperatorLowRankUpdate( - base_operator, - u, - v=v if self._use_v else None, - diag_update=diag_update if self._use_diag_update else None, - is_diag_update_positive=self._is_diag_update_positive) - feed_dict = None + lin_op_base_diag = array_ops.placeholder_with_default( + base_diag, shape=None) + lin_op_u = array_ops.placeholder_with_default(u, shape=None) + lin_op_v = array_ops.placeholder_with_default(v, shape=None) + lin_op_diag_update = array_ops.placeholder_with_default( + diag_update, shape=None) + + base_operator = linalg.LinearOperatorDiag( + lin_op_base_diag, is_positive_definite=True) + + operator = linalg.LinearOperatorLowRankUpdate( + base_operator, + lin_op_u, + v=lin_op_v if self._use_v else None, + diag_update=lin_op_diag_update if self._use_diag_update else None, + is_diag_update_positive=self._is_diag_update_positive) # The matrix representing L base_diag_mat = array_ops.matrix_diag(base_diag) @@ -146,28 +129,28 @@ class BaseLinearOperatorLowRankUpdatetest(object): if self._use_v and self._use_diag_update: # In this case, we have L + UDV^H and it isn't symmetric. expect_use_cholesky = False - mat = base_diag_mat + math_ops.matmul( + matrix = base_diag_mat + math_ops.matmul( u, math_ops.matmul(diag_update_mat, v, adjoint_b=True)) elif self._use_v: # In this case, we have L + UDV^H and it isn't symmetric. expect_use_cholesky = False - mat = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True) + matrix = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True) elif self._use_diag_update: # In this case, we have L + UDU^H, which is PD if D > 0, since L > 0. expect_use_cholesky = self._is_diag_update_positive - mat = base_diag_mat + math_ops.matmul( + matrix = base_diag_mat + math_ops.matmul( u, math_ops.matmul(diag_update_mat, u, adjoint_b=True)) else: # In this case, we have L + UU^H, which is PD since L > 0. expect_use_cholesky = True - mat = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True) + matrix = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True) if expect_use_cholesky: self.assertTrue(operator._use_cholesky) else: self.assertFalse(operator._use_cholesky) - return operator, mat, feed_dict + return operator, matrix class LinearOperatorLowRankUpdatetestWithDiagUseCholesky( diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py index a57d2f085e089fb913f09fdd9b07cf13aa7f3c35..167c6cacd1a5bbbaa70a7fdd236ddd70ea8cd4e8 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py @@ -38,28 +38,23 @@ class LinearOperatorLowerTriangularTest( # matrix_triangular_solve. return [dtypes.float32, dtypes.float64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) # Upper triangle will be nonzero, but ignored. # Use a diagonal that ensures this matrix is well conditioned. tril = linear_operator_test_util.random_tril_matrix( shape, dtype=dtype, force_well_conditioned=True, remove_upper=False) + lin_op_tril = tril + if use_placeholder: - tril_ph = array_ops.placeholder(dtype=dtype) - # Evaluate the tril here because (i) you cannot feed a tensor, and (ii) - # tril is random and we want the same value used for both mat and - # feed_dict. - tril = tril.eval() - operator = linalg.LinearOperatorLowerTriangular(tril_ph) - feed_dict = {tril_ph: tril} - else: - operator = linalg.LinearOperatorLowerTriangular(tril) - feed_dict = None + lin_op_tril = array_ops.placeholder_with_default(lin_op_tril, shape=None) + + operator = linalg.LinearOperatorLowerTriangular(lin_op_tril) - mat = array_ops.matrix_band_part(tril, -1, 0) + matrix = array_ops.matrix_band_part(tril, -1, 0) - return operator, mat, feed_dict + return operator, matrix def test_assert_non_singular(self): # Singlular matrix with one positive eigenvalue and one zero eigenvalue. diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py index 7d367a92750ae3562c93d2381eb895c94a866eaa..6f401358a2519a699488f0372323b5a41621c4cd 100644 --- a/tensorflow/python/kernel_tests/linalg_grad_test.py +++ b/tensorflow/python/kernel_tests/linalg_grad_test.py @@ -177,6 +177,12 @@ if __name__ == '__main__': MatrixUnaryFunctorGradientTest, 'MatrixDeterminantGradient', name, _GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_determinant, dtype, shape)) + _AddTest( + MatrixUnaryFunctorGradientTest, 'LogMatrixDeterminantGradient', + name, + _GetMatrixUnaryFunctorGradientTest( + lambda x: linalg_ops.log_matrix_determinant(x)[1], + dtype, shape)) # Tests for gradients of matrix_solve_ls for dtype in np.float32, np.float64: diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index 49855200c2427a88a4bd582c2ef786c38a6fa76a..bf82e08551e6a276b95bf77f7932c31d7a844a78 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -46,7 +46,7 @@ def scalar_shape(): @test_util.with_c_shapes class ListOpsTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushPop(self): l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, element_shape=scalar_shape()) @@ -54,14 +54,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(e), 1.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushPopGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testPushPop() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testStack(self): l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, element_shape=scalar_shape()) @@ -70,14 +70,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(t), [1.0, 2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testStackGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testStack() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorListFromTensor(self): t = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape()) @@ -87,14 +87,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual(self.evaluate(e), 1.0) self.assertAllEqual(self.evaluate(list_ops.tensor_list_length(l)), 0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFromTensorGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testTensorListFromTensor() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetSetItem(self): t = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape()) @@ -104,14 +104,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(t), [3.0, 2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetSetGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testGetSetItem() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUnknownShape(self): l = list_ops.empty_tensor_list( element_dtype=dtypes.float32, element_shape=-1) @@ -122,7 +122,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(e), 1.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCPUGPUCopy(self): if not context.num_gpus(): return @@ -140,7 +140,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_pop_back( l_cpu, element_dtype=dtypes.float32)[1]), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStack(self): with context.graph_mode(), self.test_session(): tl = list_ops.empty_tensor_list( @@ -152,7 +152,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)), [[1]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStackInLoop(self): with context.graph_mode(), self.test_session(): t1 = list_ops.empty_tensor_list( @@ -170,7 +170,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32) self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStackSwitchDtype(self): with context.graph_mode(), self.test_session(): list_ = list_ops.empty_tensor_list( @@ -192,7 +192,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) self.assertAllEqual(self.evaluate(s1), np_s1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStackInLoopSwitchDtype(self): with context.graph_mode(), self.test_session(): t1 = list_ops.empty_tensor_list( @@ -216,7 +216,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)]) self.assertAllEqual(self.evaluate(s1), np_s1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSerialize(self): # pylint: disable=g-import-not-at-top try: @@ -248,7 +248,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): worker_e = array_ops.identity(e) self.assertAllEqual(self.evaluate(worker_e), [2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushPopGradients(self): with backprop.GradientTape() as tape: l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, @@ -260,7 +260,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): e = 2 * e self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testStackFromTensorGradients(self): with backprop.GradientTape() as tape: c = constant_op.constant([1.0, 2.0]) @@ -272,7 +272,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): grad = tape.gradient(result, [c])[0] self.assertAllEqual(self.evaluate(grad), [2.0, 2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetSetGradients(self): with backprop.GradientTape() as tape: c = constant_op.constant([1.0, 2.0]) @@ -288,14 +288,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual(self.evaluate(grad_c), [0.0, 4.0]) self.assertAllEqual(self.evaluate(grad_c2), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSetOutOfBounds(self): c = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(list_ops.tensor_list_set_item(l, 20, 3.0)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResourceVariableScatterGather(self): c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32) l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) @@ -319,7 +319,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): [[1.0, 2.0]] * 4) self.assertAllEqual(self.evaluate(updated_v_stacked), expected) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConcat(self): c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32) l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) @@ -379,7 +379,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls, element_dtype=dtypes.float32)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushBackBatch(self): c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32) l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py index 28c85fa13ad100c38382d2b787ff965f9e3ca44e..e635a71c78484278b54bfc4de70e232834c37a0a 100644 --- a/tensorflow/python/kernel_tests/logging_ops_test.py +++ b/tensorflow/python/kernel_tests/logging_ops_test.py @@ -59,7 +59,7 @@ class LoggingOpsTest(test.TestCase): class PrintGradientTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPrintShape(self): inp = constant_op.constant(2.0, shape=[100, 32]) inp_printed = logging_ops.Print(inp, [inp]) diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 1123c20a165ba93bd380fa471a8be91f7005d7bb..87fc715783b972a20465827d697cf06637588154 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -118,6 +119,14 @@ class AbsoluteDifferenceLossTest(test.TestCase): with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerNoMemoryLeaked(self): + # This is a somewhat convoluted way of testing that nothing gets added to + # a global collection. + predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) + labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + losses.absolute_difference(labels, predictions) + class SoftmaxCrossEntropyLossTest(test.TestCase): @@ -246,6 +255,13 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') self.assertAlmostEqual(loss.eval(), 0.0, 3) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerNoMemoryLeaked(self): + logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32) + losses.sparse_softmax_cross_entropy(labels, logits) + def testAllCorrectInt64Labels(self): with self.test_session(): logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py index 361853448ce2c8477af6920257c58c1eba0fa952..944de217a175764aa2e43b4fc488d912041e279a 100644 --- a/tensorflow/python/kernel_tests/pad_op_test.py +++ b/tensorflow/python/kernel_tests/pad_op_test.py @@ -317,6 +317,11 @@ class PadOpTest(test.TestCase): [constant_op.constant(1, shape=[2]), [0, unknown]]) self.assertEqual([6, None], padded.get_shape().as_list()) + # Zero padding on a known dimension. + inp = array_ops.placeholder(dtypes.int32, [None, None, 20]) + padded = array_ops.pad(inp, [[0, 0], [0, unknown], [0, 0]]) + self.assertEqual([None, None, 20], padded.get_shape().as_list()) + def testScalars(self): paddings = np.zeros((0, 2), dtype=np.int32) inp = np.asarray(7) diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index a0c372db7d0a4e76c37c01e1ce24cd8fc9123f7a..e95c72971521452a239b78ff4ab9c25c3089f1da 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -947,7 +947,7 @@ class PoolingTest(test.TestCase): output_sizes, x_init_value=x_init_value, delta=1e-2) - print("%s gradient error = " % func_name, err) + tf_logging.info("%s gradient error = " % func_name, err) self.assertLess(err, err_tolerance) def _ConstructAndTestSecondGradient(self, @@ -1024,7 +1024,7 @@ class PoolingTest(test.TestCase): input_sizes, x_init_value=x_init_value, delta=1e-2) - print("%s second-order gradient error = " % func_name, err) + tf_logging.info("%s second-order gradient error = " % func_name, err) self.assertLess(err, err_tolerance) def _testMaxPoolGradValidPadding1_1(self, data_format, use_gpu): diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index b9f44d728a1d9843df1e836594f9caa7010d8a94..50154a45a8b58f270509e404737c8650cbd2c5ff 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gc import re import numpy as np @@ -26,6 +27,7 @@ from six.moves import queue from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.client import session as session_lib +from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import constant_op @@ -34,6 +36,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import script_ops @@ -432,25 +435,40 @@ class PyFuncTest(test.TestCase): # ----- Tests shared by py_func and eager_py_func ----- def testCleanup(self): - for _ in xrange(1000): - g = ops.Graph() - with g.as_default(): - c = constant_op.constant([1.], dtypes.float32) - _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32]) - _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32]) - self.assertTrue(script_ops._py_funcs.size() < 100) + # Delete everything created by previous tests to avoid side effects. + ops.reset_default_graph() + gc.collect() + initial_size = script_ops._py_funcs.size() + # Encapsulate the graph generation, so locals can be deleted. + def make_graphs(): + for _ in xrange(1000): + g = ops.Graph() + with g.as_default(): + c = constant_op.constant([1.], dtypes.float32) + _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32]) + _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32]) + # These ops have a reference to 'c' which has a reference to the graph. + # Checks if the functions are being deleted though the graph is referenced from them. + # (see #18292) + _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) + _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) + + # Call garbage collector to enforce deletion. + make_graphs() + ops.reset_default_graph() + gc.collect() + self.assertEqual(initial_size, script_ops._py_funcs.size()) # ----- Tests for eager_py_func ----- - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerSingleOutputInt32(self): a = array_ops.ones((3, 3), dtype=dtypes.int32) x = array_ops.ones((3, 1), dtype=dtypes.int32) output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32) - with self.test_session(): - ret = self.evaluate(output) - self.assertAllEqual(ret, [[3], [3], [3]]) + ret = self.evaluate(output) + self.assertAllEqual(ret, [[3], [3], [3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerSingleOutputFloat32(self): with test_util.device(use_gpu=True): a = array_ops.ones((3, 3), dtype=dtypes.float32) @@ -459,7 +477,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(output) self.assertAllClose(ret, [[3.0], [3.0], [3.0]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerArrayOutput(self): with test_util.device(use_gpu=True): a = array_ops.ones((3, 3), dtype=dtypes.float32) @@ -469,7 +487,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(output) self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerReturnNone(self): with test_util.device(use_gpu=True): def no_return_value(): @@ -482,7 +500,7 @@ class PyFuncTest(test.TestCase): else: self.assertIsNone(ret) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerPyFuncInDefun(self): with test_util.device(use_gpu=True): def wrapper(): @@ -494,7 +512,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(wrapped()) self.assertAllEqual(ret, [[3.0], [3.0], [3.0]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerExceptionHandling(self): with test_util.device(use_gpu=True): self._testExceptionHandling( @@ -513,11 +531,10 @@ class PyFuncTest(test.TestCase): self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerReturningVariableRaisesError(self): def return_variable(): - variable = resource_variable_ops.ResourceVariable(0.0) - return variable + return resource_variable_ops.ResourceVariable(0.0) with self.assertRaisesRegexp(errors.UnknownError, "Attempting to return a variable"): @@ -525,6 +542,99 @@ class PyFuncTest(test.TestCase): return_variable, inp=[], Tout=dtypes.float32) self.evaluate(output) + @test_util.run_in_graph_and_eager_modes + def testEagerGradientTape(self): + + def f(x): + return x**2 + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(x) + y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32) + dy_dx = tape.gradient(y, x) + self.assertEqual(self.evaluate(dy_dx), 6.0) + + def testEagerGradientGraph(self): + + def f(x): + return x**2 + + x = constant_op.constant(3.0) + y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32) + dy_dx = gradients_impl.gradients(y, x)[0] + self.assertEqual(self.evaluate(dy_dx), 6.0) + + @test_util.run_in_graph_and_eager_modes + def testEagerGradientTapeMultipleArgs(self): + + def f(x, y): + return x**2 + y**2 + + x = constant_op.constant(3.0) + y = constant_op.constant(4.0) + with backprop.GradientTape() as tape: + tape.watch(x) + tape.watch(y) + z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32) + + dz_dx, dz_dy = tape.gradient(z, [x, y]) + self.assertEqual(self.evaluate(dz_dx), 6.0) + self.assertEqual(self.evaluate(dz_dy), 8.0) + + def testEagerGradientGraphMultipleArgs(self): + + def f(x, y): + return x**2 + y**2 + + x = constant_op.constant(3.0) + y = constant_op.constant(4.0) + z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32) + + dz_dx, dz_dy = gradients_impl.gradients(z, [x, y]) + self.assertEqual(self.evaluate(dz_dx), 6.0) + self.assertEqual(self.evaluate(dz_dy), 8.0) + + def testEagerGradientGraphLogHuber(self): + + def log_huber(x, m): + if math_ops.abs(x) <= m: + return x**2 + else: + return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2)) + + x = array_ops.placeholder(dtypes.float32) + m = array_ops.placeholder(dtypes.float32) + + y = script_ops.eager_py_func( + func=log_huber, inp=[x, m], Tout=dtypes.float32) + dy_dx = gradients_impl.gradients(y, x)[0] + + with self.test_session() as sess: + # Takes the first branch of log_huber. + y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0}) + self.assertEqual(y, 1.0) + self.assertEqual(dy_dx, 2.0) + + def testEagerRespectsDevicePlacmentOfOp(self): + + def f(x): + return math_ops.square(x) + + def g(x): + return math_ops.add(x, x) + + with ops.device("/CPU:0"): + # Explicitly ask for the py_funcs to execute on CPU, even if + # a GPU is available. + x = array_ops.placeholder(dtypes.float32) + y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32) + z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(z, feed_dict={x: 3.0}) + self.assertEqual(output, 18.0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD index acd7566eec8e3fffd74db33234b03a0c87427a3e..3b3a28fc9a24104cc9032ab23dfc51e690d3ec94 100644 --- a/tensorflow/python/kernel_tests/random/BUILD +++ b/tensorflow/python/kernel_tests/random/BUILD @@ -107,6 +107,23 @@ cuda_py_test( tags = ["nozapfhahn"], ) +cuda_py_test( + name = "random_grad_test", + size = "small", + srcs = ["random_grad_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:random_grad", + "//tensorflow/python:random_ops", + ], +) + cuda_py_test( name = "random_poisson_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/random/multinomial_op_test.py b/tensorflow/python/kernel_tests/random/multinomial_op_test.py index 051c7d86bf2342f15b587fc350bfbede7fae2285..bd64d61af8e793e71a319b6ac1af95bd7dd16a3d 100644 --- a/tensorflow/python/kernel_tests/random/multinomial_op_test.py +++ b/tensorflow/python/kernel_tests/random/multinomial_op_test.py @@ -54,7 +54,7 @@ native_sampler = random_ops.multinomial class MultinomialTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSmallEntropy(self): random_seed.set_random_seed(1618) for output_dtype in [np.int32, np.int64]: diff --git a/tensorflow/python/kernel_tests/random/random_grad_test.py b/tensorflow/python/kernel_tests/random/random_grad_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d455b785bbf562fb41f30cab7e0bb723a7b894 --- /dev/null +++ b/tensorflow/python/kernel_tests/random/random_grad_test.py @@ -0,0 +1,240 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.random_grad.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_grad +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +class AddLeadingUnitDimensionsTest(test.TestCase): + + def testBasic(self): + ret = random_grad.add_leading_unit_dimensions(array_ops.ones([3, 2, 1]), 3) + self.assertAllEqual(ret.shape, [1, 1, 1, 3, 2, 1]) + + def testZeroExtraDimensions(self): + ret = random_grad.add_leading_unit_dimensions(array_ops.ones([3, 2, 1]), 0) + self.assertAllEqual(ret.shape, [3, 2, 1]) + + def testScalarInput(self): + ret = random_grad.add_leading_unit_dimensions(1.0, 2) + self.assertAllEqual(ret.shape, [1, 1]) + + def testUnknownShape(self): + x = array_ops.placeholder(dtypes.float32) + num_dimensions = array_ops.placeholder(dtypes.int32) + ret = random_grad.add_leading_unit_dimensions(x, num_dimensions) + with self.test_session() as sess: + ret_val = sess.run(ret, {x: np.ones([2, 2]), num_dimensions: 2}) + self.assertAllEqual(ret_val.shape, [1, 1, 2, 2]) + + +class RandomGammaGradTest(test.TestCase): + """Tests for derivative of a sample ~ Gamma(alpha, beta) wrt alpha and beta. + + The sample is an "implicit" function of alpha, beta and the independent random + noise u. The derivatives we are looking for are + d sample(alpha, beta, u) / dalpha (and dbeta). + + The derivative w.r.t. beta is computed by the standard automatic + differentiation, so we trust that it is computed correctly. + + The derivative w.r.t. alpha is computed by Eigen function, so we test it in + several ways. Unfortunately, the standard derivative checking by perturbing + the parameter is impossible here, because we cannot fix the value of u + in the random sampler. Instead, we compare the derivative for the given pair + of (sample, alpha) to the values computed in various ways, and also check + some statistical properties of the derivative. + """ + + def testGradientsShape(self): + shape = [2, 3] + alpha = array_ops.ones([2, 2]) + beta = array_ops.ones([1, 2]) + sample = random_ops.random_gamma(shape, alpha, beta) + grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta]) + self.assertAllEqual(grads_alpha.shape, alpha.shape) + self.assertAllEqual(grads_beta.shape, beta.shape) + + def testGradientsShapeWithOneSamplePerParameter(self): + shape = [] + alpha = array_ops.ones([2, 2]) + beta = array_ops.ones([1, 2]) + sample = random_ops.random_gamma(shape, alpha, beta) + grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta]) + self.assertAllEqual(grads_alpha.shape, alpha.shape) + self.assertAllEqual(grads_beta.shape, beta.shape) + + def testGradientsUnknownShape(self): + shape = array_ops.placeholder(dtypes.int32) + alpha = array_ops.placeholder(dtypes.float32) + beta = array_ops.placeholder(dtypes.float32) + sample = random_ops.random_gamma(shape, alpha, beta) + grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta]) + + alpha_val = np.ones([1, 2]) + beta_val = np.ones([2, 1]) + with self.test_session() as sess: + grads_alpha_val, grads_beta_val = sess.run( + [grads_alpha, grads_beta], + {alpha: alpha_val, beta: beta_val, shape: [2, 1]}) + self.assertAllEqual(grads_alpha_val.shape, alpha_val.shape) + self.assertAllEqual(grads_beta_val.shape, beta_val.shape) + + def _testCompareToExplicitDerivative(self, dtype): + """Compare to the explicit reparameterization derivative. + + Verifies that the computed derivative satisfies + dsample / dalpha = d igammainv(alpha, u) / dalpha, + where u = igamma(alpha, sample). + + Args: + dtype: TensorFlow dtype to perform the computations in. + """ + delta = 1e-3 + np_dtype = dtype.as_numpy_dtype + try: + from scipy import misc # pylint: disable=g-import-not-at-top + from scipy import special # pylint: disable=g-import-not-at-top + + alpha_val = np.logspace(-2, 3, dtype=np_dtype) + alpha = constant_op.constant(alpha_val) + sample = random_ops.random_gamma([], alpha, np_dtype(1.0), dtype=dtype) + actual = gradients_impl.gradients(sample, alpha)[0] + + (sample_val, actual_val) = self.evaluate((sample, actual)) + + u = special.gammainc(alpha_val, sample_val) + expected_val = misc.derivative( + lambda alpha_prime: special.gammaincinv(alpha_prime, u), + alpha_val, dx=delta * alpha_val) + + self.assertAllClose(actual_val, expected_val, rtol=1e-3, atol=1e-3) + except ImportError as e: + tf_logging.warn("Cannot use special functions in a test: %s" % str(e)) + + def testCompareToExplicitDerivativeFloat(self): + self._testCompareToExplicitDerivative(dtypes.float32) + + def testCompareToExplicitDerivativeDouble(self): + self._testCompareToExplicitDerivative(dtypes.float64) + + def _testCompareToImplicitDerivative(self, dtype): + """Compare to the implicit reparameterization derivative. + + Let's derive the formula we compare to. + + Start from the fact that CDF maps a random variable to the Uniform + random variable: + igamma(alpha, sample) = u, where u ~ Uniform(0, 1). + + Apply d / dalpha to both sides: + d igamma(alpha, sample) / dalpha + + d igamma(alpha, sample) / dsample * dsample/dalpha = 0 + d igamma(alpha, sample) / dalpha + + d igamma(alpha, sample) / dsample * dsample / dalpha = 0 + dsample/dalpha = - (d igamma(alpha, sample) / dalpha) + / d igamma(alpha, sample) / dsample + + This is the equation (8) of https://arxiv.org/abs/1805.08498 + + Args: + dtype: TensorFlow dtype to perform the computations in. + """ + np_dtype = dtype.as_numpy_dtype + alpha = constant_op.constant(np.logspace(-2, 3, dtype=np_dtype)) + sample = random_ops.random_gamma([], alpha, np_dtype(1.0), dtype=dtype) + actual = gradients_impl.gradients(sample, alpha)[0] + + sample_sg = array_ops.stop_gradient(sample) + cdf = math_ops.igamma(alpha, sample_sg) + dcdf_dalpha, dcdf_dsample = gradients_impl.gradients( + cdf, [alpha, sample_sg]) + # Numerically unstable due to division, do not try at home. + expected = -dcdf_dalpha / dcdf_dsample + + (actual_val, expected_val) = self.evaluate((actual, expected)) + + self.assertAllClose(actual_val, expected_val, rtol=1e-3, atol=1e-3) + + def testCompareToImplicitDerivativeFloat(self): + self._testCompareToImplicitDerivative(dtypes.float32) + + def testCompareToImplicitDerivativeDouble(self): + self._testCompareToImplicitDerivative(dtypes.float64) + + def testAverageAlphaGradient(self): + """Statistical test for the gradient. + + Using the equation (5) of https://arxiv.org/abs/1805.08498, we have + 1 = d/dalpha E_{sample ~ Gamma(alpha, 1)} sample + = E_{sample ~ Gamma(alpha, 1)} dsample/dalpha. + Here we verify that the rhs is fairly close to one. + The convergence speed is not great, so we use many samples and loose bounds. + """ + num_samples = 1000 + alpha = constant_op.constant([0.8, 1e1, 1e3], dtype=dtypes.float32) + sample = random_ops.random_gamma([num_samples], alpha) + # We need to average the gradients, which is equivalent to averaging the + # samples and then doing backprop. + mean_sample = math_ops.reduce_mean(sample, axis=0) + dsample_dalpha = gradients_impl.gradients(mean_sample, alpha)[0] + dsample_dalpha_val = self.evaluate(dsample_dalpha) + self.assertAllClose(dsample_dalpha_val, [1.0] * 3, atol=1e-1, rtol=1e-1) + + def testQuadraticLoss(self): + """Statistical test for the gradient. + + The equation (5) of https://arxiv.org/abs/1805.08498 says + d/dalpha E_{sample ~ Gamma(alpha, 1)} f(sample) + = E_{sample ~ Gamma(alpha, 1)} df(sample)/dalpha. + + Choose a quadratic loss function f(sample) = (sample - t)^2. + Then, the lhs can be computed analytically: + d/dalpha E_{sample ~ Gamma(alpha, 1)} f(sample) + = d/dalpha [ (alpha + alpha^2) - 2 * t * alpha + t^2 ] + = 1 + 2 * alpha - 2 * t. + + We compare the Monte-Carlo estimate of the expectation with the + true gradient. + """ + num_samples = 1000 + t = 0.3 + alpha = 0.5 + expected = 1 + 2 * alpha - 2 * t + + alpha = constant_op.constant(alpha) + sample = random_ops.random_gamma([num_samples], alpha, 1.0) + loss = math_ops.reduce_mean(math_ops.square(sample - t)) + dloss_dalpha = gradients_impl.gradients(loss, alpha)[0] + dloss_dalpha_val = self.evaluate(dloss_dalpha) + self.assertAllClose(expected, dloss_dalpha_val, atol=1e-1, rtol=1e-1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index 82a27eebeef16c9dacaf1b900f0398a56533cd2d..8e06e1abfb52244e8c1a9b4ed15a270f6048e028 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -25,8 +25,6 @@ import shutil import threading import zlib -import six - from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -77,6 +75,69 @@ _TEXT = b"""Gaily bedight, """ +class TFCompressionTestCase(test.TestCase): + + def setUp(self): + super(TFCompressionTestCase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + def _Record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _CreateFiles(self, options=None, prefix=""): + filenames = [] + for i in range(self._num_files): + name = prefix + "tfrecord.%d.txt" % i + records = [self._Record(i, j) for j in range(self._num_records)] + fn = self._WriteRecordsToFile(records, name, options) + filenames.append(fn) + return filenames + + def _WriteRecordsToFile(self, records, name="tfrecord", options=None): + fn = os.path.join(self.get_temp_dir(), name) + with tf_record.TFRecordWriter(fn, options=options) as writer: + for r in records: + writer.write(r) + return fn + + def _ZlibCompressFile(self, infile, name="tfrecord.z"): + # zlib compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = zlib.compress(f.read()) + + zfn = os.path.join(self.get_temp_dir(), name) + with open(zfn, "wb") as f: + f.write(cdata) + return zfn + + def _GzipCompressFile(self, infile, name="tfrecord.gz"): + # gzip compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = f.read() + + gzfn = os.path.join(self.get_temp_dir(), name) + with gzip.GzipFile(gzfn, "wb") as f: + f.write(cdata) + return gzfn + + def _ZlibDecompressFile(self, infile, name="tfrecord"): + with open(infile, "rb") as f: + cdata = zlib.decompress(f.read()) + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + def _GzipDecompressFile(self, infile, name="tfrecord"): + with gzip.GzipFile(infile, "rb") as f: + cdata = f.read() + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + class IdentityReaderTest(test.TestCase): def _ExpectRead(self, sess, key, value, expected): @@ -348,7 +409,7 @@ class TextLineReaderTest(test.TestCase): k, v = sess.run([key, value]) -class FixedLengthRecordReaderTest(test.TestCase): +class FixedLengthRecordReaderTest(TFCompressionTestCase): def setUp(self): super(FixedLengthRecordReaderTest, self).setUp() @@ -407,40 +468,18 @@ class FixedLengthRecordReaderTest(test.TestCase): # gap_bytes=hop_bytes-record_bytes def _CreateGzipFiles(self, num_records, gap_bytes): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) - filenames.append(fn) - with gzip.GzipFile(fn, "wb") as f: - f.write(b"H" * self._header_bytes) - if num_records > 0: - f.write(self._Record(i, 0)) - for j in range(1, num_records): - if gap_bytes > 0: - f.write(b"G" * gap_bytes) - f.write(self._Record(i, j)) - f.write(b"F" * self._footer_bytes) + filenames = self._CreateFiles(num_records, gap_bytes) + for fn in filenames: + # compress inplace. + self._GzipCompressFile(fn, fn) return filenames # gap_bytes=hop_bytes-record_bytes def _CreateZlibFiles(self, num_records, gap_bytes): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) - filenames.append(fn) - with open(fn + ".tmp", "wb") as f: - f.write(b"H" * self._header_bytes) - if num_records > 0: - f.write(self._Record(i, 0)) - for j in range(1, num_records): - if gap_bytes > 0: - f.write(b"G" * gap_bytes) - f.write(self._Record(i, j)) - f.write(b"F" * self._footer_bytes) - with open(fn + ".tmp", "rb") as f: - cdata = zlib.compress(f.read()) - with open(fn, "wb") as zf: - zf.write(cdata) + filenames = self._CreateFiles(num_records, gap_bytes) + for fn in filenames: + # compress inplace. + self._ZlibCompressFile(fn, fn) return filenames def _CreateGzipOverlappedRecordFiles(self, num_overlapped_records): @@ -477,10 +516,7 @@ class FixedLengthRecordReaderTest(test.TestCase): ]) f.write(compat.as_bytes(all_records_str)) f.write(b"F" * self._footer_bytes) - with open(fn + ".tmp", "rb") as f: - cdata = zlib.compress(f.read()) - with open(fn, "wb") as zf: - zf.write(cdata) + self._ZlibCompressFile(fn + ".tmp", fn) return filenames # gap_bytes=hop_bytes-record_bytes @@ -529,7 +565,6 @@ class FixedLengthRecordReaderTest(test.TestCase): for i in range(self._num_files): for j in range(num_overlapped_records): k, v = sess.run([key, value]) - print(v) self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k)) self.assertAllEqual(self._OverlappedRecord(i, j), v) @@ -579,25 +614,10 @@ class FixedLengthRecordReaderTest(test.TestCase): files, num_overlapped_records, encoding="ZLIB") -class TFRecordReaderTest(test.TestCase): +class TFRecordReaderTest(TFCompressionTestCase): def setUp(self): super(TFRecordReaderTest, self).setUp() - self._num_files = 2 - self._num_records = 7 - - def _Record(self, f, r): - return compat.as_bytes("Record %d of file %d" % (r, f)) - - def _CreateFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = tf_record.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._Record(i, j)) - return filenames def testOneEpoch(self): files = self._CreateFiles() @@ -647,107 +667,27 @@ class TFRecordReaderTest(test.TestCase): self.assertEqual(self._num_files * self._num_records, num_v) def testReadZlibFiles(self): - files = self._CreateFiles() - zlib_files = [] - for i, fn in enumerate(files): - with open(fn, "rb") as f: - cdata = zlib.compress(f.read()) - - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) - with open(zfn, "wb") as f: - f.write(cdata) - zlib_files.append(zfn) + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + files = self._CreateFiles(options) with self.test_session() as sess: - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) reader = io_ops.TFRecordReader(name="test_reader", options=options) queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) key, value = reader.read(queue) - queue.enqueue_many([zlib_files]).run() + queue.enqueue_many([files]).run() queue.close().run() for i in range(self._num_files): for j in range(self._num_records): k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % zlib_files[i])) + self.assertTrue(compat.as_text(k).startswith("%s:" % files[i])) self.assertAllEqual(self._Record(i, j), v) def testReadGzipFiles(self): - files = self._CreateFiles() - gzip_files = [] - for i, fn in enumerate(files): - with open(fn, "rb") as f: - cdata = f.read() - - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) - with gzip.GzipFile(zfn, "wb") as f: - f.write(cdata) - gzip_files.append(zfn) - - with self.test_session() as sess: - options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) - reader = io_ops.TFRecordReader(name="test_reader", options=options) - queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) - key, value = reader.read(queue) - - queue.enqueue_many([gzip_files]).run() - queue.close().run() - for i in range(self._num_files): - for j in range(self._num_records): - k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % gzip_files[i])) - self.assertAllEqual(self._Record(i, j), v) - - -class TFRecordWriterZlibTest(test.TestCase): - - def setUp(self): - super(TFRecordWriterZlibTest, self).setUp() - self._num_files = 2 - self._num_records = 7 - - def _Record(self, f, r): - return compat.as_bytes("Record %d of file %d" % (r, f)) - - def _CreateFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) - writer = tf_record.TFRecordWriter(fn, options=options) - for j in range(self._num_records): - writer.write(self._Record(i, j)) - writer.close() - del writer - - return filenames - - def _WriteRecordsToFile(self, records, name="tf_record"): - fn = os.path.join(self.get_temp_dir(), name) - writer = tf_record.TFRecordWriter(fn, options=None) - for r in records: - writer.write(r) - writer.close() - del writer - return fn - - def _ZlibCompressFile(self, infile, name="tfrecord.z"): - # zlib compress the file and write compressed contents to file. - with open(infile, "rb") as f: - cdata = zlib.compress(f.read()) - - zfn = os.path.join(self.get_temp_dir(), name) - with open(zfn, "wb") as f: - f.write(cdata) - return zfn + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + files = self._CreateFiles(options) - def testOneEpoch(self): - files = self._CreateFiles() with self.test_session() as sess: - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) reader = io_ops.TFRecordReader(name="test_reader", options=options) queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) key, value = reader.read(queue) @@ -760,196 +700,6 @@ class TFRecordWriterZlibTest(test.TestCase): self.assertTrue(compat.as_text(k).startswith("%s:" % files[i])) self.assertAllEqual(self._Record(i, j), v) - with self.assertRaisesOpError("is closed and has insufficient elements " - "\\(requested 1, current size 0\\)"): - k, v = sess.run([key, value]) - - def testZLibFlushRecord(self): - fn = self._WriteRecordsToFile([b"small record"], "small_record") - with open(fn, "rb") as h: - buff = h.read() - - # creating more blocks and trailing blocks shouldn't break reads - compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS) - - output = b"" - for c in buff: - if isinstance(c, int): - c = six.int2byte(c) - output += compressor.compress(c) - output += compressor.flush(zlib.Z_FULL_FLUSH) - - output += compressor.flush(zlib.Z_FULL_FLUSH) - output += compressor.flush(zlib.Z_FULL_FLUSH) - output += compressor.flush(zlib.Z_FINISH) - - # overwrite the original file with the compressed data - with open(fn, "wb") as h: - h.write(output) - - with self.test_session() as sess: - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) - reader = io_ops.TFRecordReader(name="test_reader", options=options) - queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=()) - key, value = reader.read(queue) - queue.enqueue(fn).run() - queue.close().run() - k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % fn)) - self.assertAllEqual(b"small record", v) - - def testZlibReadWrite(self): - """Verify that files produced are zlib compatible.""" - original = [b"foo", b"bar"] - fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord") - zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z") - - # read the compressed contents and verify. - actual = [] - for r in tf_record.tf_record_iterator( - zfn, - options=tf_record.TFRecordOptions( - tf_record.TFRecordCompressionType.ZLIB)): - actual.append(r) - self.assertEqual(actual, original) - - def testZlibReadWriteLarge(self): - """Verify that writing large contents also works.""" - - # Make it large (about 5MB) - original = [_TEXT * 10240] - fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord") - zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z") - - # read the compressed contents and verify. - actual = [] - for r in tf_record.tf_record_iterator( - zfn, - options=tf_record.TFRecordOptions( - tf_record.TFRecordCompressionType.ZLIB)): - actual.append(r) - self.assertEqual(actual, original) - - def testGzipReadWrite(self): - """Verify that files produced are gzip compatible.""" - original = [b"foo", b"bar"] - fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord") - - # gzip compress the file and write compressed contents to file. - with open(fn, "rb") as f: - cdata = f.read() - gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz") - with gzip.GzipFile(gzfn, "wb") as f: - f.write(cdata) - - actual = [] - for r in tf_record.tf_record_iterator( - gzfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)): - actual.append(r) - self.assertEqual(actual, original) - - -class TFRecordIteratorTest(test.TestCase): - - def setUp(self): - super(TFRecordIteratorTest, self).setUp() - self._num_records = 7 - - def _Record(self, r): - return compat.as_bytes("Record %d" % r) - - def _WriteCompressedRecordsToFile( - self, - records, - name="tfrecord.z", - compression_type=tf_record.TFRecordCompressionType.ZLIB): - fn = os.path.join(self.get_temp_dir(), name) - options = tf_record.TFRecordOptions(compression_type=compression_type) - writer = tf_record.TFRecordWriter(fn, options=options) - for r in records: - writer.write(r) - writer.close() - del writer - return fn - - def _ZlibDecompressFile(self, infile, name="tfrecord", wbits=zlib.MAX_WBITS): - with open(infile, "rb") as f: - cdata = zlib.decompress(f.read(), wbits) - zfn = os.path.join(self.get_temp_dir(), name) - with open(zfn, "wb") as f: - f.write(cdata) - return zfn - - def testIterator(self): - fn = self._WriteCompressedRecordsToFile( - [self._Record(i) for i in range(self._num_records)], - "compressed_records") - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) - reader = tf_record.tf_record_iterator(fn, options) - for i in range(self._num_records): - record = next(reader) - self.assertAllEqual(self._Record(i), record) - with self.assertRaises(StopIteration): - record = next(reader) - - def testWriteZlibRead(self): - """Verify compression with TFRecordWriter is zlib library compatible.""" - original = [b"foo", b"bar"] - fn = self._WriteCompressedRecordsToFile(original, - "write_zlib_read.tfrecord.z") - zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord") - actual = [] - for r in tf_record.tf_record_iterator(zfn): - actual.append(r) - self.assertEqual(actual, original) - - def testWriteZlibReadLarge(self): - """Verify compression for large records is zlib library compatible.""" - # Make it large (about 5MB) - original = [_TEXT * 10240] - fn = self._WriteCompressedRecordsToFile(original, - "write_zlib_read_large.tfrecord.z") - zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tf_record") - actual = [] - for r in tf_record.tf_record_iterator(zfn): - actual.append(r) - self.assertEqual(actual, original) - - def testWriteGzipRead(self): - original = [b"foo", b"bar"] - fn = self._WriteCompressedRecordsToFile( - original, - "write_gzip_read.tfrecord.gz", - compression_type=TFRecordCompressionType.GZIP) - - with gzip.GzipFile(fn, "rb") as f: - cdata = f.read() - zfn = os.path.join(self.get_temp_dir(), "tf_record") - with open(zfn, "wb") as f: - f.write(cdata) - - actual = [] - for r in tf_record.tf_record_iterator(zfn): - actual.append(r) - self.assertEqual(actual, original) - - def testBadFile(self): - """Verify that tf_record_iterator throws an exception on bad TFRecords.""" - fn = os.path.join(self.get_temp_dir(), "bad_file") - with tf_record.TFRecordWriter(fn) as writer: - writer.write(b"123") - fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated") - with open(fn, "rb") as f: - with open(fn_truncated, "wb") as f2: - # DataLossError requires that we've written the header, so this must - # be at least 12 bytes. - f2.write(f.read(14)) - with self.assertRaises(errors_impl.DataLossError): - for _ in tf_record.tf_record_iterator(fn_truncated): - pass - class AsyncReaderTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 073799cc465f765de411655fe50b470d680a9e39..0fb0b8895cbc847639999ad1bd23e7fb04c86034 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -106,12 +106,26 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): v = resource_variable_ops.ResourceVariable(False, name="bool_test") self.assertAllEqual(bool(v), False) + def testDifferentAssignGraph(self): + with ops.Graph().as_default(): + v = resource_variable_ops.ResourceVariable(1.0) + ops.reset_default_graph() + v.assign(2.0) # Note: this fails if we run convert_to_tensor on not the + # variable graph. + def testFetchHandle(self): with self.test_session(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1], name="foo") self.assertGreater(len(handle.eval()), 0) + def testCachedValueReadBeforeWrite(self): + with self.test_session() as sess: + v = resource_variable_ops.ResourceVariable(0.0, caching_device="cpu:0") + sess.run(v.initializer) + value, _ = sess.run([v, v.assign_add(1.0)]) + self.assertAllEqual(value, 0.0) + def testAssignVariableDtypeMismatchEager(self): with context.eager_mode(): handle = resource_variable_ops.var_handle_op( @@ -131,14 +145,18 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertIn("", str(handle)) self.assertIn("", repr(handle)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDtypeSurvivesIdentity(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) id_handle = array_ops.identity(handle) self.evaluate(resource_variable_ops.assign_variable_op( id_handle, constant_op.constant(0, dtype=dtypes.int32))) - @test_util.run_in_graph_and_eager_modes() + def testUnreadOpName(self): + v = resource_variable_ops.ResourceVariable(1.0) + self.assertNotEqual(v.name, v.assign_add(1.0).name) + + @test_util.run_in_graph_and_eager_modes def testCreateRead(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) self.evaluate(resource_variable_ops.assign_variable_op( @@ -147,7 +165,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)) self.assertAllEqual(1, value) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testManyAssigns(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) create = resource_variable_ops.assign_variable_op( @@ -165,7 +183,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(f, 1) self.assertEqual(s, 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignAdd(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) self.evaluate(resource_variable_ops.assign_variable_op( @@ -176,7 +194,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)) self.assertEqual(read, 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterAdd(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -189,7 +207,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterSub(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -202,7 +220,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[-1]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMul(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -215,7 +233,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[5]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterDiv(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -228,7 +246,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[2]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMin(self): with ops.device("cpu:0"): handle = resource_variable_ops.var_handle_op( @@ -265,7 +283,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): meta_graph_two = saver.export_meta_graph(graph=graph) self.assertEqual(meta_graph_def, meta_graph_two) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMax(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -278,7 +296,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[6]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterAddScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -291,7 +309,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterSubScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -304,7 +322,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[-1]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMulScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -317,7 +335,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[5]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterDivScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -330,7 +348,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[2]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMinScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -343,7 +361,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMaxScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -408,7 +426,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_update(ref, indices, updates) self.assertAllEqual(ref.read_value(), [True, True, True]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConstraintArg(self): constraint = lambda x: x v = resource_variable_ops.ResourceVariable( @@ -448,32 +466,32 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.assertRaises(errors.OutOfRangeError): state_ops.count_up_to(v, 1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFnDtype(self): v = resource_variable_ops.ResourceVariable( initial_value=lambda: 1, dtype=dtypes.float32, name="var0") self.assertEqual(dtypes.float32, v.value().dtype) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFnNoDtype(self): v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1, name="var2") self.assertEqual(dtypes.int32, v.value().dtype) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitializeAllVariables(self): v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32, name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(1.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOperatorOverload(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(2.0, self.evaluate(v + v)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignMethod(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -491,7 +509,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(assign_without_read) self.assertEqual(4.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoad(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -524,7 +542,26 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): sess.run(v.initialized_value()) - @test_util.run_in_graph_and_eager_modes() + def testTrainableInProto(self): + with ops.Graph().as_default(): + non_trainable_variable = resource_variable_ops.ResourceVariable( + trainable=False, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + False, + resource_variable_ops.ResourceVariable( + variable_def=non_trainable_variable.to_proto()) + .trainable) + trainable_variable = resource_variable_ops.ResourceVariable( + trainable=True, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + True, + resource_variable_ops.ResourceVariable( + variable_def=trainable_variable.to_proto()) + .trainable) + + @test_util.run_in_graph_and_eager_modes def testSparseRead(self): with self.test_session(): init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4)) @@ -546,7 +583,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEquals(v._handle, w._handle) self.assertEquals(v._graph_element, w._graph_element) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignAddMethod(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -564,7 +601,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(assign_without_read) self.assertEqual(4.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignSubMethod(self): v = resource_variable_ops.ResourceVariable(3.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -582,7 +619,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(assign_without_read) self.assertEqual(0.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDestroyResource(self): v = resource_variable_ops.ResourceVariable(3.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -671,7 +708,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) self.assertEqual(300.0, self.evaluate(w_read)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShape(self): v = resource_variable_ops.ResourceVariable( name="var4", initial_value=array_ops.ones(shape=[10, 20, 35])) @@ -789,13 +826,23 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_add(v, [1], [3]) self.assertAllEqual([1.0, 5.0], v.numpy()) + def testScatterNdAddStateOps(self): + with context.eager_mode(): + v = resource_variable_ops.ResourceVariable( + [1, 1, 1, 1, 1, 1, 1, 1], dtype=dtypes.float32, name="add") + indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) + updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) + expected = np.array([1, 12, 1, 11, 10, 1, 1, 13]) + state_ops.scatter_nd_add(v, indices, updates) + self.assertAllClose(expected, v.numpy()) + def testScatterUpdateCast(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update") state_ops.scatter_update(v, [1], [3]) self.assertAllEqual([1.0, 3.0], v.numpy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterUpdateInvalidArgs(self): v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update") # The exact error and message differ between graph construction (where the diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index fe5ad84c104502f0e09d3a963b406f49d6b97b71..957baf8c6089a6a033f54762fef290399d80cd09 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -81,6 +81,25 @@ class ScalarStateRNNCell(rnn_cell_impl.RNNCell): return (input_, state + 1) +class UnbalancedOutputRNNCell(rnn_cell_impl.RNNCell): + """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" + + @property + def output_size(self): + return tensor_shape.TensorShape(1), tensor_shape.TensorShape((2)) + + @property + def state_size(self): + return tensor_shape.TensorShape([]) + + def zero_state(self, batch_size, dtype): + return array_ops.zeros([], dtype=dtypes.int32) + + def call(self, input_, state, scope=None): + concatenated = array_ops.concat((input_, input_), axis=-1) + return (input_, concatenated), state + 1 + + class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell): """RNN Cell its state as a TensorArray.""" @@ -108,7 +127,7 @@ class RNNTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInvalidSequenceLengthShape(self): cell = Plus1RNNCell() if context.executing_eagerly(): @@ -122,7 +141,7 @@ class RNNTest(test.TestCase): dtype=dtypes.float32, sequence_length=[[4]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBatchSizeFromInput(self): cell = Plus1RNNCell() in_eager_mode = context.executing_eagerly() @@ -162,7 +181,7 @@ class RNNTest(test.TestCase): self.assertEqual(None, outputs.shape[0].value) self.assertEqual(None, state.shape[0].value) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScalarStateIsAccepted(self): cell = ScalarStateRNNCell() in_eager_mode = context.executing_eagerly() @@ -182,7 +201,29 @@ class RNNTest(test.TestCase): self.assertAllEqual([[[1], [2], [3], [4]]], outputs) self.assertAllEqual(4, state) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes + def testUnbalancedOutputIsAccepted(self): + cell = UnbalancedOutputRNNCell() + in_eager_mode = context.executing_eagerly() + + if in_eager_mode: + inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) + else: + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) + + with self.test_session() as sess: + outputs, state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32, sequence_length=[4]) + if not in_eager_mode: + outputs, state = sess.run( + [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) + + self.assertIsInstance(outputs, tuple) + self.assertAllEqual([[[1], [2], [3], [4]]], outputs[0]) + self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1]) + self.assertAllEqual(4, state) + + @test_util.run_in_graph_and_eager_modes def testTensorArrayStateIsAccepted(self): cell = TensorArrayStateRNNCell() in_eager_mode = context.executing_eagerly() @@ -215,7 +256,7 @@ class RNNTest(test.TestCase): cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output) self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCellsBuild(self): f32 = dtypes.float32 f64 = dtypes.float64 diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index 79fe927b8aff17165af40580a491560b885b89f3..f9b9c77bbf7e2a8afdbfbd0929a68856b8aae51c 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -144,7 +144,9 @@ class StatefulScatterNdTest(test.TestCase): self.assertAllClose(new, ref_var.eval()) def _VariableRankTests(self, np_scatter, tf_scatter): - for vtype in (np.float32, np.float64, np.complex64, np.complex128): + for vtype in (np.int32, + np.float32, np.float64, + np.complex64, np.complex128): for itype in (np.int32, np.int64): self._VariableRankTest(np_scatter, tf_scatter, vtype, itype) @@ -221,7 +223,7 @@ class StatefulScatterNdTest(test.TestCase): # self._VariableRankTests(_NumpyDiv, state_ops.scatter_nd_div) def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter): - for vtype in (np.float32, np.float64): + for vtype in (np.int32, np.float32, np.float64): for itype in (np.int32, np.int64): self._VariableRankTest( np_scatter, tf_scatter, vtype, itype, repeat_indices=True) @@ -367,7 +369,7 @@ class ScatterNdTest(test.TestCase): del input_ # input_ is not used in scatter_nd return array_ops.scatter_nd(indices, updates, shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInvalidShape(self): # TODO(apassos) figure out how to unify these errors with self.assertRaises(errors.InvalidArgumentError diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py index c70a4ffce7be71effe3ea10faa9754ab2b3842ce..1a0fa744aeec7b6df281835c266ebbd901f22fea 100644 --- a/tensorflow/python/kernel_tests/scatter_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_ops_test.py @@ -159,7 +159,13 @@ class ScatterTest(test.TestCase): # Clips small values to avoid division by zero. def clip_small_values(x): - return 1e-4 * np.sign(x) if np.abs(x) < 1e-4 else x + threshold = 1e-4 + sign = np.sign(x) + + if isinstance(x, np.int32): + threshold = 1 + sign = np.random.choice([-1, 1]) + return threshold * sign if np.abs(x) < threshold else x updates = np.vectorize(clip_small_values)(updates) old = _AsType(np.random.randn(*((first_dim,) + extra_shape)), vtype) @@ -181,7 +187,11 @@ class ScatterTest(test.TestCase): tf_scatter, repeat_indices=False, updates_are_scalar=False): - for vtype in (np.float32, np.float64): + vtypes = [np.float32, np.float64] + if tf_scatter != state_ops.scatter_div: + vtypes.append(np.int32) + + for vtype in vtypes: for itype in (np.int32, np.int64): self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices, updates_are_scalar) diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index 794be096b7309a18f9fe225642bcaafb5058df78..a82855dfeb5b8fcf215f545e53ae0f26638011da 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -264,7 +264,9 @@ class UnsortedSegmentTest(SegmentReductionHelper): # A subset of ops has been enabled for complex numbers self.complex_ops_list = [(np.add, None, - math_ops.unsorted_segment_sum, lambda t: 0)] + math_ops.unsorted_segment_sum, lambda t: 0), + (np.ndarray.__mul__, None, + math_ops.unsorted_segment_prod, lambda t: 1)] self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32, dtypes_lib.float64] self.all_dtypes = (self.differentiable_dtypes + diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index 7368251ab69574cc6cba703e605f108c6ab45649..34e34d9d1b2034d8679844f051358f020a44587a 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -642,6 +642,29 @@ class TileTest(test.TestCase): err = gradient_checker.compute_gradient_error(a, [4, 2], tiled, [4, 4]) self.assertLess(err, 1e-3) + def testGradientWithSparseGradWithRank1(self): + inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], + dtype=dtypes.float32) + outputs = array_ops.gather(array_ops.tile(inputs, [3]), + [1, 5, 9, 3, 7, 2, 2, 2]) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + + def testGradientWithSparseGradWithRank3(self): + inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], + dtype=dtypes.float32) + inputs = array_ops.reshape(inputs, [-1, 1, 1]) + outputs = array_ops.gather(array_ops.tile(inputs, [3, 4, 2]), + [1, 5, 9, 3, 7, 2, 2, 2]) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + def testShapeFunctionEdgeCases(self): # Unknown multiples shape. inp = constant_op.constant(0.0, shape=[4, 4, 4, 4]) diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py index 5fc9bef21816e3a12f0d274bab1fc82a83546422..402f67619b41a5f13c6603eb6665974a09a8f4fb 100644 --- a/tensorflow/python/kernel_tests/slice_op_test.py +++ b/tensorflow/python/kernel_tests/slice_op_test.py @@ -225,7 +225,7 @@ class SliceTest(test.TestCase): self.assertAllEqual(m1.get_shape().as_list(), [1, 2, 3]) m2 = array_ops.slice(z, [0, 0, 0], [constant_op.constant(1) + 0, 2, -1]) - self.assertAllEqual(m2.get_shape().as_list(), [None, 2, None]) + self.assertAllEqual(m2.get_shape().as_list(), [1, 2, 3]) def _testGradientSlice(self, input_shape, slice_begin, slice_size): diff --git a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py index 27b39a626fcc6b2705bf9e797b5293ed3f1c7820..3847cebc7dcabd66c26a4e4551e5856c6a927a33 100644 --- a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py @@ -300,6 +300,51 @@ class SerializeSparseTest(test.TestCase): sparse_ops.serialize_many_sparse, sparse_ops.deserialize_sparse, dtypes.variant) + def testVariantSerializeDeserializeScalar(self): + with self.test_session(use_gpu=False) as sess: + indices_value = np.array([[]], dtype=np.int64) + values_value = np.array([37], dtype=np.int32) + shape_value = np.array([], dtype=np.int64) + sparse_tensor = self._SparseTensorPlaceholder() + serialized = sparse_ops.serialize_sparse( + sparse_tensor, out_type=dtypes.variant) + deserialized = sparse_ops.deserialize_sparse( + serialized, dtype=dtypes.int32) + deserialized_value = sess.run( + deserialized, + feed_dict={ + sparse_tensor.indices: indices_value, + sparse_tensor.values: values_value, + sparse_tensor.dense_shape: shape_value + }) + self.assertAllEqual(deserialized_value.indices, indices_value) + self.assertAllEqual(deserialized_value.values, values_value) + self.assertAllEqual(deserialized_value.dense_shape, shape_value) + + def testVariantSerializeDeserializeScalarBatch(self): + with self.test_session(use_gpu=False) as sess: + indices_value = np.array([[]], dtype=np.int64) + values_value = np.array([37], dtype=np.int32) + shape_value = np.array([], dtype=np.int64) + sparse_tensor = self._SparseTensorPlaceholder() + serialized = sparse_ops.serialize_sparse( + sparse_tensor, out_type=dtypes.variant) + stacked = array_ops.stack([serialized, serialized]) + deserialized = sparse_ops.deserialize_sparse(stacked, dtype=dtypes.int32) + deserialized_value = sess.run( + deserialized, + feed_dict={ + sparse_tensor.indices: indices_value, + sparse_tensor.values: values_value, + sparse_tensor.dense_shape: shape_value + }) + self.assertAllEqual(deserialized_value.indices, + np.array([[0], [1]], dtype=np.int64)) + self.assertAllEqual(deserialized_value.values, + np.array([37, 37], dtype=np.int32)) + self.assertAllEqual(deserialized_value.dense_shape, + np.array([2], dtype=np.int64)) + def _testDeserializeFailsWrongTypeHelper(self, serialize_fn, deserialize_fn, diff --git a/tensorflow/python/kernel_tests/sparse_slice_op_test.py b/tensorflow/python/kernel_tests/sparse_slice_op_test.py index da116601f833cc6b471e383e030c5fbe93b52ac5..97f30daf4a9c9615e1b42a1ba94e693e166bbc1c 100644 --- a/tensorflow/python/kernel_tests/sparse_slice_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_slice_op_test.py @@ -21,13 +21,15 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import sparse_ops +import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import from tensorflow.python.platform import test class SparseSliceOpTest(test.TestCase): - def _SparseTensor_4x6(self): + def _SparseTensor_4x6(self, val_dtype=np.int64): # [0 | |2 | |4 |5 ] # [ |11| |13|14| ] # [20| | |23| |25] @@ -37,7 +39,7 @@ class SparseSliceOpTest(test.TestCase): [2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype( np.int64) val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype( - np.int64) + val_dtype) shape = np.array([4, 6]).astype(np.int64) return sparse_tensor.SparseTensor(ind, val, shape) @@ -244,6 +246,22 @@ class SparseSliceOpTest(test.TestCase): self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35]) self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1]) + def testGradients(self): + sp_input = self._SparseTensor_4x6(val_dtype=np.float32) + start_and_size = [([0, 0], [4, 2]), + ([0, 2], [5, 2]), + ([0, 4], [5, 3])] + + with self.test_session(use_gpu=False): + for start, size in start_and_size: + sp_output = sparse_ops.sparse_slice(sp_input, start, size) + nnz_in = len(sp_input.values.eval()) + nnz_out = len(sp_output.values.eval()) + + err = gradient_checker.compute_gradient_error( + [sp_input.values], [(nnz_in,)], sp_output.values, (nnz_out,)) + self.assertLess(err, 1e-3) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py index 8cfee3eb933afcea7a58d5632948b87b0c4c10df..419cd5ecdafab92910cd06fb18148796f70afb44 100644 --- a/tensorflow/python/kernel_tests/split_op_test.py +++ b/tensorflow/python/kernel_tests/split_op_test.py @@ -95,7 +95,7 @@ class SplitOpTest(test.TestCase): sess.run(array_ops.split(value, size_splits), {size_splits: [2, 2, 6]}) self.assertTrue("Cannot infer num from shape" in str(context.exception)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testExplicitNum(self): size_splits = array_ops.constant([2, 2, 6], dtype=dtypes.int32) value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] @@ -109,7 +109,7 @@ class SplitOpTest(test.TestCase): self.assertAllEqual(r[1], value[2:4]) self.assertAllEqual(r[2], value[4:]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testListOfScalarTensors(self): a = math_ops.to_int32(5) b = math_ops.to_int32(6) @@ -168,7 +168,7 @@ class SplitOpTest(test.TestCase): offset += size_splits[i] self.assertAllEqual(result[i], inp[slices]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSpecialCasesVariable(self): self._testSpecialCasesVariable() for dtype in _TEST_DTYPES: @@ -210,13 +210,13 @@ class SplitOpTest(test.TestCase): self.assertAllEqual(np_ans[i], out[i]) self.assertShapeEqual(np_ans[i], tf_ans[i]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitRows(self): for dtype in _TEST_DTYPES: inp = self._makeData((4, 4), dtype) self._compare(inp, 0, 4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitCols(self): for dtype in _TEST_DTYPES: inp = self._makeData((4, 4), dtype) @@ -232,7 +232,7 @@ class SplitOpTest(test.TestCase): self.assertEqual(out[i].shape, expected_shape) self.assertEqual(expected_shape, tf_ans[i].get_shape()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEmpty(self): # Note: np.split returns a rank-0 empty ndarray # if the input ndarray is empty. @@ -244,7 +244,7 @@ class SplitOpTest(test.TestCase): self._testEmpty(inp, 2, 3, (8, 0, 7)) self._testEmpty(inp, 2, 7, (8, 0, 3)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testIdentity(self): for dtype in _TEST_DTYPES: inp = self._makeData((2, 2, 2), dtype) @@ -252,7 +252,7 @@ class SplitOpTest(test.TestCase): self._compare(inp, 1, 1) self._compare(inp, 2, 1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitDim0(self): for dtype in _TEST_DTYPES: self._compare(self._makeData((6, 10, 18), dtype), 0, 3) @@ -281,7 +281,7 @@ class SplitOpTest(test.TestCase): offset += length self.assertAllEqual(result[i], inp[slices]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRandom(self): for dtype in _TEST_DTYPES: for _ in range(5): diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py index a5bd1b6ee072e4e025bf76351a971782b4c23fad..e20daccb28a72f939a448f4cf3bfb283ea616e96 100644 --- a/tensorflow/python/kernel_tests/string_split_op_test.py +++ b/tensorflow/python/kernel_tests/string_split_op_test.py @@ -146,5 +146,101 @@ class StringSplitOpTest(test.TestCase): self.assertAllEqual(shape, [3, 1]) +class StringSplitV2OpTest(test.TestCase): + + def testSplitV2(self): + strings = ["pigs on the wing", "animals"] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings) + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]]) + self.assertAllEqual(values, [b"pigs", b"on", b"the", b"wing", b"animals"]) + self.assertAllEqual(shape, [2, 4]) + + def testSplitV2MultiCharSeparator(self): + # Match Python behavior: + # >>> '1<>2<>3'.split('<>') + # ['1', '2', '3'] + # >>> "<><>4<>5<><>6<>".split("<>") + # ['', '', '4', '5', '', '6', ''] + strings = ["1<>2<>3", "<><>4<>5<><>6<>"] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings, sep="<>") + indices, values, shape = sess.run(tokens) + self.assertAllEqual( + indices, [[0, 0], [0, 1], [0, 2], + [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6]]) + self.assertAllEqual(values, [b"1", b"2", b"3", + b"", b"", b"4", b"5", b"", b"6", b""]) + self.assertAllEqual(shape, [2, 7]) + + def testSplitV2SimpleSeparator(self): + # Match Python behavior: + # >>> '1,2,3'.split(',') + # ['1', '2', '3'] + # >>> '1,2,,3,'.split(',') + # ['1', '2', '', '3', ''] + strings = ["1,2,3", "4,5,,6,"] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings, sep=',') + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], + [1, 0], [1, 1], [1, 2], [1, 3], [1, 4]]) + self.assertAllEqual(values, [b"1", b"2", b"3", + b"4", b"5", b"", b"6", b""]) + self.assertAllEqual(shape, [2, 5]) + + def testSplitV2EmptySeparator(self): + # Match Python behavior: + # >>> '1 2 3'.split() + # ['1', '2', '3'] + #>>> ' 1 2 3 '.split() + #['1', '2', '3'] + strings = ["1 2 3", " 4 5 6 "] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings) + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], + [1, 0], [1, 1], [1, 2]]) + self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"6"]) + self.assertAllEqual(shape, [2, 3]) + + def testSplitV2SimpleSeparatorMaxSplit(self): + # Match Python behavior: + # >>> '1,2,3'.split(',', maxsplit=1) + # ['1', '2,3'] + # >>> '4,5,,6,'.split(',', maxsplit=1) + # ['4', '5,,6,'] + strings = ["1,2,3", "4,5,,6,"] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings, sep=',', maxsplit=1) + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], + [1, 0], [1, 1]]) + self.assertAllEqual(values, [b"1", b"2,3", b"4", b"5,,6,"]) + self.assertAllEqual(shape, [2, 2]) + + def testSplitV2EmptySeparatorMaxSplit(self): + # Match Python behavior: + # '1 2 3'.split(maxsplit=1) + # ['1', '2 3'] + # >>> " 4 5 6 ".split(maxsplit=1) + # ['4', '5 6 '] + strings = ["1 2 3", " 4 5 6 "] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings, maxsplit=1) + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], + [1, 0], [1, 1]]) + self.assertAllEqual(values, [b"1", b"2 3", b"4", b"5 6 "]) + self.assertAllEqual(shape, [2, 2]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index 1b935d5286729e9e802c56e90e2ae7ab72a6e080..0b3a396d6bf46fb46416662a9443ed7b5811e15c 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -150,7 +150,7 @@ class TemplateTest(test.TestCase): # Parameters are tied, so the loss should have gone down after training. self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_skip_stack_frames(self): first = traceback.format_stack() second = traceback.format_stack() @@ -158,7 +158,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(result)) self.assertNotEqual(len(first), len(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_with_name(self): tmpl1 = template.make_template("s1", variable_scoped_function) tmpl2 = template.make_template("s1", variable_scoped_function) @@ -204,7 +204,7 @@ class TemplateTest(test.TestCase): self.assertEqual(v1, v3) self.assertEqual("s1/dummy:0", v1.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_in_scope(self): tmpl1 = template.make_template("s1", variable_scoped_function) tmpl2 = template.make_template("s1", variable_scoped_function) @@ -221,7 +221,7 @@ class TemplateTest(test.TestCase): self.assertEqual("scope/s1/dummy:0", v1.name) self.assertEqual("scope/s1_1/dummy:0", v3.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_with_internal_reuse(self): tmpl1 = template.make_template("s1", internally_variable_scoped_function) tmpl2 = template.make_template("s1", internally_variable_scoped_function) @@ -237,13 +237,13 @@ class TemplateTest(test.TestCase): with self.assertRaises(ValueError): tmpl1("not_test") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_without_name(self): with self.assertRaisesRegexp( ValueError, "name cannot be None."): template.make_template(None, variable_scoped_function) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_make_template(self): # Test both that we can call it with positional and keywords. tmpl1 = template.make_template( @@ -266,7 +266,7 @@ class TemplateTest(test.TestCase): with self.assertRaises(ValueError): tmpl() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_enforces_no_extra_trainable_variables_eager(self): tmpl = template.make_template("s", function_with_side_create, @@ -287,7 +287,7 @@ class TemplateTest(test.TestCase): trainable=False) self.assertEqual(tmpl(name="1"), tmpl(name="2")) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_internal_variable_reuse(self): def nested(): @@ -310,7 +310,7 @@ class TemplateTest(test.TestCase): self.assertEqual("s1/nested/x:0", v1.name) self.assertEqual("s1_1/nested/x:0", v3.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_nested_templates(self): def nested_template(): @@ -360,7 +360,7 @@ class TemplateTest(test.TestCase): self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name) self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_nested_templates_with_defun(self): def variable_scoped_function_no_return_value(trainable=True): @@ -429,7 +429,7 @@ class TemplateTest(test.TestCase): "a", partial, create_graph_function_=True) self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_immediate_scope_creation(self): # Create templates in scope a then call in scope b. make_template should # capture the scope the first time it is called, and make_immediate_template @@ -454,7 +454,7 @@ class TemplateTest(test.TestCase): self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name) self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_scope_access(self): # Ensure that we can access the scope inside the template, because the name # of that scope may be different from the name we pass to make_template, due @@ -479,7 +479,7 @@ class TemplateTest(test.TestCase): # Template is called at the top level, so there is no preceding "foo_2". self.assertEqual(tc.variable_scope.name, "blah") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_custom_getter(self): # Custom getter that maintains call count and forwards to true getter custom_getter_count = [0] @@ -512,7 +512,7 @@ class TemplateTest(test.TestCase): tmpl2() self.assertEqual(custom_getter_count[0], 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_fails_gracefully(self): for create_scope_now in [True, False]: def module_function_with_one_arg(inputs): @@ -535,7 +535,7 @@ class TemplateTest(test.TestCase): templatized_function(data) self.assertTrue(templatized_function._variables_created) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_name_scopes_for_variable_scopes(self): # Test that name scopes are not unnecessarily uniquified (but are # still uniquified when necessary). @@ -586,7 +586,7 @@ class TemplateTest(test.TestCase): "Second application of template should also get " "a freshly uniquified name scope.") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_global_variables(self): # Make sure global_variables are created. with variable_scope.variable_scope("foo"): @@ -608,7 +608,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(ta.global_variables)) self.assertEqual(2, len(tb.global_variables)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trainable_variables(self): # Make sure trainable_variables are created. with variable_scope.variable_scope("foo2"): @@ -632,7 +632,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(ta.variables)) self.assertEqual(1, len(tb.variables)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_non_trainable_variables(self): # Make sure non_trainable_variables are created. with variable_scope.variable_scope("foo2"): @@ -675,7 +675,7 @@ class TemplateTest(test.TestCase): self.assertEqual(0, len(ta.local_variables)) self.assertEqual(1, len(tb.local_variables)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_make_template_with_defun(self): def variable_scoped_function_no_return_value(scope_name): diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index c0b36f143d109eb28e2784b49e8fd4099b5799a6..6de6fbe7679fa8e95d3032b04fb81b43ac3a60d9 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -26,11 +26,13 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops @@ -73,7 +75,7 @@ class TensorArrayTest(test.TestCase): super(TensorArrayTest, cls).tearDownClass() session_lib.Session.reset(cls._workers[0].target) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteRead(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -121,11 +123,11 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayWritePack(dtypes.complex128) self._testTensorArrayWritePack(dtypes.string) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWritePack(self): self._testTensorArrayWritePackMaybeLegacy() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEmptyTensorArrayPack(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -159,7 +161,7 @@ class TensorArrayTest(test.TestCase): convert([[4.0, 5.0], [104.0, 105.0], [204.0, 205.0], [6.0, 7.0], [106.0, 107.0], [8.0, 9.0]]), c0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteConcat(self): self._testTensorArrayWriteConcat(dtypes.float32) self._testTensorArrayWriteConcat(dtypes.float64) @@ -182,7 +184,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]], self.evaluate(ta.write(1, [[4.0, 5.0]]).concat())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self): self._testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros() @@ -198,7 +200,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]], self.evaluate(ta.write(1, [[4.0, 5.0]]).concat())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self): self._testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros() @@ -249,7 +251,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayUnpackRead(dtypes.complex128) self._testTensorArrayUnpackRead(dtypes.string) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayUnpackRead(self): self._testTensorArrayUnpackReadMaybeLegacy() @@ -295,7 +297,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual(convert([]).reshape(0, 2), d1) self.assertAllEqual(convert([[3.0, 301.0]]), d2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArraySplitRead(self): self._testTensorArraySplitRead(dtypes.float32) self._testTensorArraySplitRead(dtypes.float64) @@ -395,7 +397,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual(t_g_ta_0, t_g_ta_1) self.assertAllEqual([[4.0, 5.0]], d_r1_0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteWrongIndexOrDataTypeFails(self): with self.test_session(use_gpu=True): ta = _make_ta(3, "foo", dtype=dtypes.float32) @@ -414,7 +416,7 @@ class TensorArrayTest(test.TestCase): "resizeable and size is: 3"): self.evaluate(ta.write(3, 3.0).flow) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadWrongIndexOrDataTypeFails(self): with self.test_session(use_gpu=True): ta = _make_ta(3, "foo", dtype=dtypes.float32) @@ -448,7 +450,7 @@ class TensorArrayTest(test.TestCase): "it has already been written to."): self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayConcatIncompatibleShapesFails(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -480,7 +482,7 @@ class TensorArrayTest(test.TestCase): with self.assertRaisesOpError("shape"): self.evaluate(w3.concat()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArraySplitIncompatibleShapesFails(self): with self.test_session(use_gpu=True): in_eager_mode = context.executing_eagerly() @@ -549,7 +551,59 @@ class TensorArrayTest(test.TestCase): dtypes.complex64, dtypes.complex128): self._testTensorArrayWriteGradientAddMultipleAdds(dtype) - @test_util.run_in_graph_and_eager_modes() + def testTensorArrayGradWithShapeKnownElementShape(self): + with self.test_session(use_gpu=True) as sess: + ta = tensor_array_ops.TensorArray( + size=3, + dtype=dtypes.float32, + element_shape=tensor_shape.TensorShape([2, 3])) + handle, flow = data_flow_ops.tensor_array_grad_with_shape( + handle=ta.handle, + flow_in=ta.flow, + shape_to_prepend=tensor_shape.TensorShape([4, 5]), + source="source") + ta_grad = tensor_array_ops.TensorArray( + dtypes.float32, handle=handle, flow=flow) + value = array_ops.placeholder(dtypes.float32) + ta_grad = ta_grad.write(0, value) + read_value = ta_grad.read(0) + + # Make sure shape inference worked. + self.assertAllEqual([None, None, 2, 3], read_value.shape.as_list()) + # Writing with wrong shape should not work. + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Could not write to TensorArray"): + fed_value = np.random.random([2, 3]) + sess.run(read_value, feed_dict={value: fed_value}) + # Writing with correct shape should work. + fed_value = np.random.random([4, 5, 2, 3]) + self.assertAllClose(fed_value, + sess.run(read_value, feed_dict={value: fed_value})) + + def testTensorArrayGradWithShapeUnknownElementShape(self): + with self.test_session(use_gpu=True) as sess: + ta = tensor_array_ops.TensorArray( + size=3, dtype=dtypes.float32, + element_shape=None) # Note that element_shape is unknown + handle, flow = data_flow_ops.tensor_array_grad_with_shape( + handle=ta.handle, + flow_in=ta.flow, + shape_to_prepend=tensor_shape.TensorShape([4, 5]), + source="source") + ta_grad = tensor_array_ops.TensorArray( + dtypes.float32, handle=handle, flow=flow) + value = array_ops.placeholder(dtypes.float32) + ta_grad = ta_grad.write(0, value) + read_value = ta_grad.read(0) + + # Make sure shape inference worked. + self.assertIsNone(read_value.shape.ndims) + # Write with some shape and check read value. + fed_value = np.random.random([4, 5, 7]) + self.assertAllClose(fed_value, + sess.run(read_value, feed_dict={value: fed_value})) + + @test_util.run_in_graph_and_eager_modes def testMultiTensorArray(self): with self.test_session(use_gpu=True): h1 = tensor_array_ops.TensorArray( @@ -652,7 +706,7 @@ class TensorArrayTest(test.TestCase): def testTensorArrayGradientWritePackConcatAndRead(self): self._testTensorArrayGradientWritePackConcatAndRead() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadTwice(self): with self.test_session(use_gpu=True): value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) @@ -757,14 +811,14 @@ class TensorArrayTest(test.TestCase): def testTensorArrayGradientDynamicUnpackRead(self): self._testTensorArrayGradientDynamicUnpackRead() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCloseTensorArray(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) self.evaluate(ta.close()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSizeTensorArray(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -772,7 +826,7 @@ class TensorArrayTest(test.TestCase): s = ta.size() self.assertAllEqual(3, self.evaluate(s)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWriteCloseTensorArray(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -870,7 +924,7 @@ class TensorArrayTest(test.TestCase): self.assertAllClose(grad_val.sum(axis=0), var_grad_t) self.assertAllClose(grad_val.sum(axis=0), state0_grad_t) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWhileLoopWritePackGradients(self): self._testWhileLoopWritePackGradients( dynamic_size=False, dtype=dtypes.float32) @@ -882,7 +936,7 @@ class TensorArrayTest(test.TestCase): self._testWhileLoopWritePackGradients( dynamic_size=True, dtype=dtypes.float32) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradSerialTwoLoops(self): with self.test_session(use_gpu=True): def loop(x): @@ -1059,7 +1113,7 @@ class TensorArrayTest(test.TestCase): r5 = w5.read(0) self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def _testUnpackShape(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -1093,7 +1147,7 @@ class TensorArrayTest(test.TestCase): def testUnpackShape(self): self._testUnpackShape() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitShape(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -1235,7 +1289,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([10.0, -10.0], read_vals[1]) self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteGatherAndGradients(self): with self.test_session(use_gpu=True) as session: ta = tensor_array_ops.TensorArray( @@ -1379,7 +1433,7 @@ class TensorArrayTest(test.TestCase): self.assertFalse( [s for s in dev_stats[d] if "/TensorArray" in s.node_name]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayIdentity(self): with self.test_session(use_gpu=True): ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2, diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 9dc4ec0f9625ccf399807316c9c46309432bb2e7..054c6f9dd79156bc4b4f3179528fe56235fdf369 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -57,7 +57,7 @@ class VariableScopeTest(test.TestCase): v1 = vs.get_variable("v", [1]) self.assertEqual(v, v1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResource(self): vs = variable_scope._get_default_variable_store() v1 = vs.get_variable("v", [1], use_resource=True) @@ -87,7 +87,7 @@ class VariableScopeTest(test.TestCase): self.assertEqual( set(expected_names), set([v.name for v in vs._vars.values()])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeInitializer(self): init = init_ops.constant_initializer(0.3) with variable_scope.variable_scope("tower0") as tower: @@ -100,7 +100,7 @@ class VariableScopeTest(test.TestCase): self.evaluate(variables_lib.variables_initializer([w])) self.assertAllClose(self.evaluate(w.value()), 0.3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeConstraint(self): constraint = lambda x: 0. * x with variable_scope.variable_scope("tower1") as tower: @@ -117,7 +117,7 @@ class VariableScopeTest(test.TestCase): variables_lib.global_variables_initializer().run() self.assertAllEqual(compat.as_bytes(v.eval()), b"") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeDType(self): with variable_scope.variable_scope("tower2") as tower: with variable_scope.variable_scope("foo", dtype=dtypes.float16): @@ -197,7 +197,33 @@ class VariableScopeTest(test.TestCase): self.assertAllEqual([v1, v2], [v3, v4]) f() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes + def testEagerVariablesStoreAddsToCollections(self): + store = variable_scope.EagerVariableStore() + with store.as_default(): + trainable = variable_scope.get_variable("v1", [], trainable=True) + not_trainable = variable_scope.get_variable("v2", [], trainable=False) + concat = variable_scope.get_variable( + "v3", [], collections=[ops.GraphKeys.CONCATENATED_VARIABLES]) + self.assertEqual( + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES), + [trainable, not_trainable]) + self.assertEqual( + ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), + [trainable, concat]) + self.assertEqual( + ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES), [concat]) + + @test_util.run_in_graph_and_eager_modes + def testEagerVariablesOutsideStoreNotAddedToCollections(self): + if not context.executing_eagerly(): + return + variable_scope.get_variable("v1", [], trainable=True) + variable_scope.get_variable("v2", [], trainable=False) + self.assertFalse(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + self.assertFalse(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + + @test_util.run_in_graph_and_eager_modes def testInitFromNonTensorValue(self): v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32) self.evaluate(variables_lib.variables_initializer([v])) @@ -213,7 +239,7 @@ class VariableScopeTest(test.TestCase): with self.assertRaises(error): variable_scope.get_variable("x4", initializer={}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFromNonInitializer(self): # Test various dtypes with zeros initializer as following: types = [ @@ -268,7 +294,7 @@ class VariableScopeTest(test.TestCase): v_tower = variable_scope.get_variable("v", []) self.assertFalse(v_tower.value().device.startswith(caching_device)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeRegularizer(self): init = init_ops.constant_initializer(0.3) @@ -313,7 +339,7 @@ class VariableScopeTest(test.TestCase): losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) self.assertEqual(3, len(losses)) # No new loss added. - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitializeFromValue(self): init = constant_op.constant(0.1) w = variable_scope.get_variable("v", initializer=init) @@ -402,7 +428,7 @@ class VariableScopeTest(test.TestCase): sess.run(v0.initializer) sess.run(add) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetVariableScope(self): # Test the get_variable_scope() function and setting properties of result. init = init_ops.constant_initializer(0.3) @@ -423,7 +449,7 @@ class VariableScopeTest(test.TestCase): new_init = variable_scope.get_variable_scope().initializer self.assertEqual(new_init, None) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScope(self): with variable_scope.variable_scope("tower4") as tower: self.assertEqual(tower.name, "tower4") @@ -442,7 +468,7 @@ class VariableScopeTest(test.TestCase): with ops.name_scope("scope") as sc: self.assertEqual(sc, "tower6/tower4/scope/") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeNameScope(self): with ops.name_scope("testVarScopeNameScope1"): with variable_scope.variable_scope("tower") as tower: @@ -935,7 +961,7 @@ class VariableScopeTest(test.TestCase): self.assertEqual( constant_op.constant([], name="c").name, "another/inner/c:0") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetLocalVar(self): # Check that local variable respects naming. with variable_scope.variable_scope("outer") as outer: @@ -1227,6 +1253,31 @@ class VariableScopeWithCustomGetterTest(test.TestCase): self.assertEqual(v3, v4) self.assertEqual(3, called[0]) # skipped one in the first new_scope + def testSynchronizationAndAggregationWithCustomGetter(self): + called = [0] + synchronization = variable_scope.VariableSynchronization.AUTO + aggregation = variable_scope.VariableAggregation.NONE + + def custom_getter(getter, *args, **kwargs): + called[0] += 1 + + # Verify synchronization and aggregation kwargs are as expected. + self.assertEqual(kwargs["synchronization"], synchronization) + self.assertEqual(kwargs["aggregation"], aggregation) + return getter(*args, **kwargs) + + with variable_scope.variable_scope("scope", custom_getter=custom_getter): + variable_scope.get_variable("v", [1]) + self.assertEqual(1, called[0]) + + with variable_scope.variable_scope("scope", custom_getter=custom_getter): + synchronization = variable_scope.VariableSynchronization.ON_READ + aggregation = variable_scope.VariableAggregation.MEAN + variable_scope.get_variable( + "v1", [1], synchronization=synchronization, aggregation=aggregation) + + self.assertEqual(2, called[0]) + def testCustomGetterWithReuse(self): # Custom getter can choose to behave differently on reused variables. def custom_getter(getter, *args, **kwargs): @@ -1329,6 +1380,23 @@ class VariableScopeWithCustomGetterTest(test.TestCase): self.assertAllEqual(variable_names, ["forced_name"]) + called = [False] + + def creater_c(next_creator, **kwargs): + called[0] = True + self.assertEqual(kwargs["synchronization"], + variable_scope.VariableSynchronization.ON_WRITE) + self.assertEqual(kwargs["aggregation"], + variable_scope.VariableAggregation.MEAN) + return next_creator(**kwargs) + + with variable_scope.variable_creator_scope(creater_c): + variable_scope.get_variable( + "v", [], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) + self.assertTrue(called[0]) + class PartitionInfoTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 27599868b74be323189b872c2147c6a33f84d170..62d596da91682c396c04efbc64cf063c8e29e7cc 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -496,6 +496,23 @@ class VariablesTestCase(test.TestCase): with self.assertRaises(ValueError): sess.run(v.initialized_value()) + def testTrainableInProto(self): + with ops.Graph().as_default(): + non_trainable_variable = variables.Variable( + trainable=False, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + False, + variables.Variable(variable_def=non_trainable_variable.to_proto()) + .trainable) + trainable_variable = variables.Variable( + trainable=True, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + True, + variables.Variable(variable_def=trainable_variable.to_proto()) + .trainable) + def testLoad(self): with self.test_session(): var = variables.Variable(np.zeros((5, 5), np.float32)) diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 1cf7d2abd16dcbfeefaebc1b4f5363eac74caac5..b8969a41aba1f8ee84233ce7ac398193183d292f 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -22,7 +22,7 @@ import copy from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.util import function_utils @@ -191,7 +191,7 @@ class Layer(base_layer.Layer): RuntimeError: If called with partioned variable regularization and eager execution is enabled. """ - + def _should_add_regularizer(variable, existing_variable_set): if isinstance(variable, tf_variables.PartitionedVariable): for var in variable: diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index ab49e37b90e183034ae7ab720fa92b06f39b2aed..298e96e711cbf8a0f625f95d737d1e7a83f4431d 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -39,7 +39,7 @@ from tensorflow.python.platform import test class BaseLayerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerProperties(self): layer = base_layers.Layer(name='my_layer') self.assertEqual(layer.variables, []) @@ -53,13 +53,13 @@ class BaseLayerTest(test.TestCase): layer = base_layers.Layer(name='my_layer', trainable=False) self.assertEqual(layer.trainable, False) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInt64Layer(self): layer = base_layers.Layer(name='my_layer', dtype='int64') layer.add_variable('my_var', [2, 2]) self.assertEqual(layer.name, 'my_layer') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAddWeight(self): layer = base_layers.Layer(name='my_layer') @@ -116,7 +116,7 @@ class BaseLayerTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'activity_regularizer'): core_layers.Dense(1, activity_regularizer=lambda *args, **kwargs: 0.) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCall(self): class MyLayer(base_layers.Layer): @@ -132,7 +132,7 @@ class BaseLayerTest(test.TestCase): # op is only supported in GRAPH mode self.assertEqual(outputs.op.name, 'my_layer/Square') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDeepCopy(self): class MyLayer(base_layers.Layer): @@ -155,7 +155,7 @@ class BaseLayerTest(test.TestCase): self.assertEqual(layer_copy._graph, layer._graph) self.assertEqual(layer_copy._private_tensor, layer._private_tensor) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScopeNaming(self): class PrivateLayer(base_layers.Layer): @@ -203,7 +203,7 @@ class BaseLayerTest(test.TestCase): my_layer_scoped1.apply(inputs) self.assertEqual(my_layer_scoped1._scope.name, 'var_scope/my_layer_1') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecNdimCheck(self): class CustomerLayer(base_layers.Layer): @@ -230,7 +230,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1], [2]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecMinNdimCheck(self): class CustomerLayer(base_layers.Layer): @@ -258,7 +258,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[[1], [2]]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecMaxNdimCheck(self): class CustomerLayer(base_layers.Layer): @@ -286,7 +286,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1], [2]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecDtypeCheck(self): class CustomerLayer(base_layers.Layer): @@ -306,7 +306,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant(1.0, dtype=dtypes.float32)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecAxesCheck(self): class CustomerLayer(base_layers.Layer): @@ -328,7 +328,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1, 2], [3, 4], [5, 6]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecShapeCheck(self): class CustomerLayer(base_layers.Layer): @@ -348,7 +348,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1, 2, 3], [4, 5, 6]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoInputSpec(self): class CustomerLayer(base_layers.Layer): @@ -369,7 +369,7 @@ class BaseLayerTest(test.TestCase): layer.apply(array_ops.placeholder('int32')) layer.apply(array_ops.placeholder('int32', shape=(2, 3))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_count_params(self): dense = core_layers.Dense(16) dense.build((None, 4)) @@ -379,7 +379,7 @@ class BaseLayerTest(test.TestCase): with self.assertRaises(ValueError): dense.count_params() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDictInputOutput(self): class DictLayer(base_layers.Layer): @@ -589,6 +589,5 @@ class BaseLayerTest(test.TestCase): ValueError, 'Input graph and Layer graph are not the same'): layer.apply(constant_op.constant([[1.]])) - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index 34a1487e748e41eebae8b87b17c34d0deda8597f..36cef3855e5233bf878a7dab178cb2a5f4a779c2 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -23,7 +23,7 @@ from __future__ import print_function from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import layers as keras_layers +from tensorflow.python.keras import layers as keras_layers from tensorflow.python.layers import base from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops @@ -217,7 +217,6 @@ def conv1d(inputs, bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs) @@ -421,7 +420,6 @@ def conv2d(inputs, bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs) @@ -627,7 +625,6 @@ def conv3d(inputs, bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs) @@ -1266,7 +1263,6 @@ def conv2d_transpose(inputs, bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs) @@ -1438,7 +1434,6 @@ def conv3d_transpose(inputs, bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs) diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index 6d8e9eac878bb2eb65bfa29e872a0576a39af662..aadff231dabb06a7c05446fb92f758de57a744da 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -27,7 +27,7 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin import numpy as np -from tensorflow.python.keras._impl.keras import layers as keras_layers +from tensorflow.python.keras import layers as keras_layers from tensorflow.python.layers import base from tensorflow.python.ops import init_ops from tensorflow.python.util.tf_export import tf_export @@ -184,7 +184,6 @@ def dense( bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _scope=name, _reuse=reuse) return layer.apply(inputs) diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index cf45b07637108422f1c612390bb01efdad6d5bcf..040c1cddc0f2540eec5fcf3442bed3f4800bec7c 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -41,7 +41,7 @@ from tensorflow.python.platform import test class DenseTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDenseProperties(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense') self.assertEqual(dense.units, 2) @@ -91,14 +91,14 @@ class DenseTest(test.TestCase): core_layers.Dense(5)(inputs) core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')(inputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallTensorDot(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense') inputs = random_ops.random_uniform((5, 4, 3), seed=1) outputs = dense(inputs) self.assertListEqual([5, 4, 2], outputs.get_shape().as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoBias(self): dense = core_layers.Dense(2, use_bias=False, name='my_dense') inputs = random_ops.random_uniform((5, 2), seed=1) @@ -112,7 +112,7 @@ class DenseTest(test.TestCase): self.assertEqual(dense.kernel.name, 'my_dense/kernel:0') self.assertEqual(dense.bias, None) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNonTrainable(self): dense = core_layers.Dense(2, trainable=False, name='my_dense') inputs = random_ops.random_uniform((5, 2), seed=1) @@ -125,7 +125,7 @@ class DenseTest(test.TestCase): self.assertEqual( len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOutputShape(self): dense = core_layers.Dense(7, activation=nn_ops.relu, name='my_dense') inputs = random_ops.random_uniform((5, 3), seed=1) @@ -165,7 +165,7 @@ class DenseTest(test.TestCase): dense = core_layers.Dense(4, name='my_dense') dense(inputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testActivation(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1') inputs = random_ops.random_uniform((5, 3), seed=1) @@ -325,7 +325,7 @@ class DenseTest(test.TestCase): var_key = 'test2/dense/kernel' self.assertEqual(var_dict[var_key].name, '%s:0' % var_key) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testComputeOutputShape(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1') ts = tensor_shape.TensorShape @@ -347,7 +347,7 @@ class DenseTest(test.TestCase): dense.compute_output_shape(ts([None, 4, 3])).as_list()) # pylint: enable=protected-access - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConstraints(self): k_constraint = lambda x: x / math_ops.reduce_sum(x) b_constraint = lambda x: x / math_ops.reduce_max(x) @@ -369,7 +369,7 @@ def _get_variable_dict_from_varstore(): class DropoutTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDropoutProperties(self): dp = core_layers.Dropout(0.5, name='dropout') self.assertEqual(dp.rate, 0.5) @@ -377,7 +377,7 @@ class DropoutTest(test.TestCase): dp.apply(array_ops.ones(())) self.assertEqual(dp.name, 'dropout') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBooleanLearningPhase(self): dp = core_layers.Dropout(0.5) inputs = array_ops.ones((5, 3)) @@ -402,7 +402,7 @@ class DropoutTest(test.TestCase): np_output = sess.run(dropped, feed_dict={training: False}) self.assertAllClose(np.ones((5, 5)), np_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicNoiseShape(self): inputs = array_ops.ones((5, 3, 2)) noise_shape = [None, 1, None] diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 33284b0d695272db5a4e0d757d6f24b1930068de..f7bc10a6a634d4f821894f1f07106ba340d421af 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -24,7 +24,7 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin import numpy as np -from tensorflow.python.keras._impl.keras import layers as keras_layers +from tensorflow.python.keras import layers as keras_layers from tensorflow.python.layers import base from tensorflow.python.ops import init_ops from tensorflow.python.util.tf_export import tf_export @@ -44,7 +44,7 @@ class BatchNormalization(keras_layers.BatchNormalization, base.Layer): normalized, typically the features axis/axes. For instance, after a `Conv2D` layer with `data_format="channels_first"`, set `axis=1`. If a list of axes is provided, each axis in `axis` will be normalized - simultaneously. Default is `-1` which takes uses last axis. Note: when + simultaneously. Default is `-1` which uses the last axis. Note: when using multi-axis batch norm, the `beta`, `gamma`, `moving_mean`, and `moving_variance` variables are the same rank as the input Tensor, with dimension size 1 in all reduced (non-axis) dimensions). @@ -308,7 +308,6 @@ def batch_normalization(inputs, virtual_batch_size=virtual_batch_size, adjustment=adjustment, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs, training=training) diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py index 75abe56f51f2a206ea3e5a5dad032446c150293a..c53cca3d312470c6fc22b4cca0bb9c76ed0865af 100644 --- a/tensorflow/python/layers/pooling.py +++ b/tensorflow/python/layers/pooling.py @@ -19,7 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras import layers as keras_layers +from tensorflow.python.keras import layers as keras_layers from tensorflow.python.layers import base from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index 77fa2c1f66d2214dbb08e4d0ad3437fa4fe02822..fde3a83770280038b777a141693d117dace4b41f 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -446,6 +446,16 @@ npy_bool NPyBfloat16_NonZero(void* data, void* arr) { return x != static_cast(0); } +int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) { + bfloat16* const buffer = reinterpret_cast(buffer_raw); + const float start(buffer[0]); + const float delta = static_cast(buffer[1]) - start; + for (npy_intp i = 2; i < length; ++i) { + buffer[i] = static_cast(start + i * delta); + } + return 0; +} + // NumPy casts // Performs a NumPy array cast from type 'From' to 'To'. @@ -548,6 +558,7 @@ bool Initialize() { NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN; NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap; NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero; + NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill; Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type; npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr); diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index 09d4b01fa43babdc09f8f255e79bbed539ddc04c..bc928cd9e5ef4d5a0ec0ce73e853e3e022a1f6fa 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -245,6 +245,20 @@ class Bfloat16NumPyTest(test.TestCase): np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)), atol=2e-2) + def testArange(self): + self.assertAllEqual( + np.arange(100, dtype=np.float32).astype(bfloat16), + np.arange(100, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16), + np.arange(-10.5, 7.8, 0.5, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16), + np.arange(-0., -7., -0.25, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16), + np.arange(-16384., 16384., 64., dtype=bfloat16)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index 9df38d464ca6ad40f22b720902e1c6f127cf846d..ec1ba7b8f7d611ad659ac483505a7d86bf4b31e5 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -312,6 +312,40 @@ Status GetPyArrayDescrForTensor(const TF_Tensor* tensor, return Status::OK(); } + +inline void FastMemcpy(void* dst, const void* src, size_t size) { + // clang-format off + switch (size) { + // Most compilers will generate inline code for fixed sizes, + // which is significantly faster for small copies. + case 1: memcpy(dst, src, 1); break; + case 2: memcpy(dst, src, 2); break; + case 3: memcpy(dst, src, 3); break; + case 4: memcpy(dst, src, 4); break; + case 5: memcpy(dst, src, 5); break; + case 6: memcpy(dst, src, 6); break; + case 7: memcpy(dst, src, 7); break; + case 8: memcpy(dst, src, 8); break; + case 9: memcpy(dst, src, 9); break; + case 10: memcpy(dst, src, 10); break; + case 11: memcpy(dst, src, 11); break; + case 12: memcpy(dst, src, 12); break; + case 13: memcpy(dst, src, 13); break; + case 14: memcpy(dst, src, 14); break; + case 15: memcpy(dst, src, 15); break; + case 16: memcpy(dst, src, 16); break; +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX) && \ + !defined(IS_MOBILE_PLATFORM) + // On Linux, memmove appears to be faster than memcpy for + // large sizes, strangely enough. + default: memmove(dst, src, size); break; +#else + default: memcpy(dst, src, size); break; +#endif + } + // clang-format on +} + } // namespace // Converts the given TF_Tensor to a numpy ndarray. @@ -362,8 +396,8 @@ Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) { " bytes but TF_Tensor was ", TF_TensorByteSize(tensor.get()), " bytes"); } else { - memcpy(PyArray_DATA(py_array), TF_TensorData(tensor.get()), - PyArray_NBYTES(py_array)); + FastMemcpy(PyArray_DATA(py_array), TF_TensorData(tensor.get()), + PyArray_NBYTES(py_array)); } // PyArray_Return turns rank 0 arrays into numpy scalars @@ -377,7 +411,7 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) { // Make sure we dereference this array object in case of error, etc. Safe_PyObjectPtr array_safe(make_safe( - PyArray_FromAny(ndarray, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr))); + PyArray_FromAny(ndarray, nullptr, 0, 0, NPY_ARRAY_CARRAY_RO, nullptr))); if (!array_safe) return errors::InvalidArgument("Not a ndarray."); PyArrayObject* array = reinterpret_cast(array_safe.get()); diff --git a/tensorflow/python/lib/core/numpy.h b/tensorflow/python/lib/core/numpy.h index 25322b458b8475882830599dd4ae02f10d97094b..d4621d61ee98b9eb4b19213145059d242c88f40c 100644 --- a/tensorflow/python/lib/core/numpy.h +++ b/tensorflow/python/lib/core/numpy.h @@ -29,7 +29,9 @@ limitations under the License. #define NO_IMPORT_ARRAY #endif +// Place `` before to avoid build failure in macOS. #include +#include #include "numpy/arrayobject.h" #include "numpy/ufuncobject.h" diff --git a/tensorflow/python/lib/core/py_exception_registry.cc b/tensorflow/python/lib/core/py_exception_registry.cc index 6637de632b48e4dfc8219543161464b10dcdbe12..d03cf8930b9e2b12d92c72678b501f6a9e659768 100644 --- a/tensorflow/python/lib/core/py_exception_registry.cc +++ b/tensorflow/python/lib/core/py_exception_registry.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/python/lib/core/py_exception_registry.h" - #include +#include "tensorflow/python/lib/core/py_exception_registry.h" + namespace tensorflow { PyExceptionRegistry* PyExceptionRegistry::singleton_ = nullptr; diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 8c6bb7955a4e29daddd92860e41d7105192eb24b..57139986af7d2adc3670529d1bb22233f167ced0 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include + #include "numpy/arrayobject.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" @@ -33,8 +35,6 @@ limitations under the License. #include "tensorflow/python/lib/core/py_util.h" #include "tensorflow/python/lib/core/safe_ptr.h" -#include - namespace tensorflow { namespace { @@ -55,37 +55,35 @@ struct PyCall { string token; // The device on which Tensors are stored; only used for EagerPyFunc. - Device* device; - - // True if and only if the op has been placed on a GPU. - bool gpu; + Device* device = nullptr; // True if the call is associated with an EagerPyFunc. - bool eager; + bool eager = false; // Inputs and outputs of this function invocation. std::vector ins; std::vector out; }; +bool IsCPUDevice(const Device* d) { + return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; +} + // Givens the 'call', prepares the token and inputs as a python tuple // that is appropriate for calling the trampoline. Status MakeArgTuple(const PyCall* call, PyObject** tuple) { int64 n = call->ins.size(); PyObject* lst = PyList_New(n); CHECK(lst); + // TFE_TensorHandle assumes that CPU is identified by nullptr. + Device* device = IsCPUDevice(call->device) ? nullptr : call->device; for (int64 i = 0; i < n; ++i) { PyObject* arg = nullptr; const Tensor& t = call->ins[i]; if (call->eager) { - if (call->gpu) { - arg = EagerTensorFromHandle( - new TFE_TensorHandle(t, call->device, call->device)); - } else { - // TFE_TensorHandle assumes that CPU is identified by `nullptr`. - arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr, nullptr)); - } + arg = EagerTensorFromHandle(new TFE_TensorHandle(t, device, device)); if (arg == nullptr) { + Py_DECREF(lst); return errors::Internal("Unable to procure EagerTensor from Tensor."); } } else { @@ -97,8 +95,9 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) { } PyList_SetItem(lst, i, arg); } - *tuple = Py_BuildValue("(sON)", call->token.c_str(), - call->gpu ? Py_True : Py_False, lst); + const char* device_name = + device == nullptr ? nullptr : device->attributes().name().c_str(); + *tuple = Py_BuildValue("(ssN)", call->token.c_str(), device_name, lst); CHECK(*tuple); return Status::OK(); } @@ -167,9 +166,40 @@ bool IsSingleNone(PyObject* obj) { } // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`. +// Validates that `output_tensor` is backed by memory in `expected_device` +// (which is assumed to be a local device, one on which the kernel was +// executed.) +// +// It may be nice to copy the tensor to the right device instead of failing if +// it isn't already there. This is left as a future exercise. The required +// device-copying logic is implemented in Python at the moment. tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, + const Device* expected_device, const Tensor** output_tensor) { - return EagerTensor_Handle(eager_tensor)->handle->Tensor(output_tensor); + auto handle = EagerTensor_Handle(eager_tensor)->handle; + Device* actual_device = nullptr; + TF_RETURN_IF_ERROR(handle->Device(&actual_device)); + TF_RETURN_IF_ERROR(handle->Tensor(output_tensor)); + // actual_device may be nullptr, which implies local CPU. + if (expected_device == actual_device) return Status::OK(); + const string& expected_device_name = expected_device->attributes().name(); + if (actual_device == nullptr) { + if (!IsCPUDevice(expected_device)) { + return errors::Internal( + "expected the py_func to return a Tensor backed by memory in ", + expected_device_name, + ", but is actually backed by local host memory. This is a bug."); + } + return Status::OK(); + } + const string& actual_device_name = actual_device->attributes().name(); + if (actual_device_name != expected_device_name) { + return errors::Internal( + "expected the py_func to return a Tensor backed by memory in ", + expected_device_name, ", but is actually in ", actual_device_name, + ". This is a bug."); + } + return Status::OK(); } // Calls the registered py function through the trampoline. @@ -224,7 +254,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { const PyObject* item = PyList_GetItem(result, i); if (EagerTensor_CheckExact(item)) { const Tensor* tensor = nullptr; - s = ExtractTensorFromEagerTensor(item, &tensor); + s = ExtractTensorFromEagerTensor(item, call->device, &tensor); if (s.ok()) t = *tensor; } else { s = errors::FailedPrecondition( @@ -245,7 +275,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { DCHECK(call->eager); if (result != Py_None) { const Tensor* t = nullptr; - s = ExtractTensorFromEagerTensor(result, &t); + s = ExtractTensorFromEagerTensor(result, call->device, &t); if (s.ok()) call->out.push_back(*t); } } else if (PyArray_Check(result)) { @@ -449,13 +479,11 @@ class PyFuncOp : public OpKernel { explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_)); eager_ = type_string() == "EagerPyFunc"; - gpu_ = ctx->device_type().type_string() == DEVICE_GPU; } void Compute(OpKernelContext* ctx) override { PyCall call; call.token = token_; - call.gpu = gpu_; call.eager = eager_; if (call.eager) { // Eager's C API uses `Device`, whereas `OpKernelContext` stores a @@ -464,6 +492,7 @@ class PyFuncOp : public OpKernel { if (call.device == nullptr) { ctx->CtxFailureWithWarning( errors::Internal("Unrecognized device class")); + return; } } @@ -508,9 +537,6 @@ class PyFuncOp : public OpKernel { private: string token_; - // True if and only if this op has been placed on a GPU. - bool gpu_; - // True if and only if this op should execute the python function eagerly, // i.e., if and only if the eager attribute is set. bool eager_; diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 32ea737a99067877e7f527e44d261a0b7c2eb07e..3b4f12ae31b9e905ed15e86533e648b4c95736e1 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -51,6 +51,10 @@ bool IsPyInt(PyObject* obj) { #endif } +bool IsPyDouble(PyObject* obj) { + return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type. +} + bool IsPyFloat(PyObject* obj) { return PyFloat_Check(obj) || PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types @@ -84,6 +88,41 @@ bool IsPyDimension(PyObject* obj) { return ret; } +// Sets *elem to a NEW reference to an element in seq on success. +// REQUIRES: PySequence_Check(seq) && PySequence_Length(seq) > 0. +Status SampleElementFromSequence(PyObject* seq, PyObject** elem) { + *elem = PySequence_GetItem(seq, 0); + if (*elem != nullptr) return Status::OK(); + // seq may implement the sequence protocol (i.e., implement __getitem__) + // but may legitimately not have a 0-th element (__getitem__(self, 0) + // raises a KeyError). For example: + // seq = pandas.Series([0, 1, 2], index=[2, 4, 6]) + // + // We don't actually care for the element at key 0, any element will do + // for inferring the element types. All elements are expected to + // have the same type, and this will be validated when converting + // to an EagerTensor. + PyErr_Clear(); + Safe_PyObjectPtr iter(PyObject_GetIter(seq)); + if (PyErr_Occurred()) { + return errors::InvalidArgument("Cannot infer dtype of a ", + Py_TYPE(seq)->tp_name, + " object: ", PyExceptionFetch()); + } + *elem = PyIter_Next(iter.get()); + if (PyErr_Occurred()) { + return errors::InvalidArgument( + "Cannot infer dtype of a ", Py_TYPE(seq)->tp_name, + " object, as iter().next() failed: ", PyExceptionFetch()); + } + if (*elem == nullptr) { + return errors::InvalidArgument("Cannot infer dtype of a ", + Py_TYPE(seq)->tp_name, + " object since it is an empty sequence"); + } + return Status::OK(); +} + Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { std::vector refs_to_clean; while (true) { @@ -94,7 +133,9 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { auto length = PySequence_Length(obj); if (length > 0) { shape->AddDim(length); - obj = PySequence_GetItem(obj, 0); + PyObject* elem = nullptr; + TF_RETURN_IF_ERROR(SampleElementFromSequence(obj, &elem)); + obj = elem; refs_to_clean.push_back(make_safe(obj)); continue; } else if (length == 0) { @@ -113,8 +154,10 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { "Attempted to convert an invalid sequence to a Tensor."); } } - } else if (IsPyFloat(obj)) { + } else if (IsPyDouble(obj)) { *dtype = DT_DOUBLE; + } else if (IsPyFloat(obj)) { + *dtype = DT_FLOAT; } else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) { // Have to test for bool before int, since IsInt(True/False) == true. *dtype = DT_BOOL; @@ -433,7 +476,7 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { break; } switch (infer_dtype) { - case DT_DOUBLE: + case DT_FLOAT: // TODO(josh11b): Handle mixed floats and complex numbers? if (requested_dtype == DT_INVALID) { // TensorFlow uses float32s to represent floating point numbers @@ -446,7 +489,8 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { // final type. RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); } - + case DT_DOUBLE: + RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); case DT_INT64: if (requested_dtype == DT_INVALID) { const char* error = ConvertInt32(obj, shape, ret); diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc index 00cbf0c532cf80d3bb27afe168ecde963ba3591d..6b6c82015fd2b73e410d64306ecbd613ccf1967c 100644 --- a/tensorflow/python/lib/core/py_util.cc +++ b/tensorflow/python/lib/core/py_util.cc @@ -15,9 +15,12 @@ limitations under the License. #include "tensorflow/python/lib/core/py_util.h" +// Place `` before to avoid build failure in macOS. +#include +#include + #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" -#include namespace tensorflow { namespace { diff --git a/tensorflow/python/lib/core/safe_ptr.h b/tensorflow/python/lib/core/safe_ptr.h index 32d286888666bde8742403bb8e231b3d6d4bf695..35d71f7629e54027eb3a8bdd7f4275e325ac7f11 100644 --- a/tensorflow/python/lib/core/safe_ptr.h +++ b/tensorflow/python/lib/core/safe_ptr.h @@ -19,6 +19,7 @@ limitations under the License. #include #include + #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py index 59f5075f177ef5335115cb4f24182d28a9b547c8..f22fb253e4d59813226f0e9741cabcfbf0cdcd1a 100644 --- a/tensorflow/python/lib/io/file_io.py +++ b/tensorflow/python/lib/io/file_io.py @@ -21,6 +21,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import binascii import os import uuid @@ -33,6 +34,10 @@ from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export +# A good default block size depends on the system in question. +# A somewhat conservative default chosen here. +_DEFAULT_BLOCK_SIZE = 16 * 1024 * 1024 + class FileIO(object): """FileIO class that exposes methods to read / write to / from files. @@ -551,3 +556,56 @@ def stat(filename): with errors.raise_exception_on_not_ok_status() as status: pywrap_tensorflow.Stat(compat.as_bytes(filename), file_statistics, status) return file_statistics + + +def filecmp(filename_a, filename_b): + """Compare two files, returning True if they are the same, False otherwise. + + We check size first and return False quickly if the files are different sizes. + If they are the same size, we continue to generating a crc for the whole file. + + You might wonder: why not use Python's filecmp.cmp() instead? The answer is + that the builtin library is not robust to the many different filesystems + TensorFlow runs on, and so we here perform a similar comparison with + the more robust FileIO. + + Args: + filename_a: string path to the first file. + filename_b: string path to the second file. + + Returns: + True if the files are the same, False otherwise. + """ + size_a = FileIO(filename_a, "rb").size() + size_b = FileIO(filename_b, "rb").size() + if size_a != size_b: + return False + + # Size is the same. Do a full check. + crc_a = file_crc32(filename_a) + crc_b = file_crc32(filename_b) + return crc_a == crc_b + + +def file_crc32(filename, block_size=_DEFAULT_BLOCK_SIZE): + """Get the crc32 of the passed file. + + The crc32 of a file can be used for error checking; two files with the same + crc32 are considered equivalent. Note that the entire file must be read + to produce the crc32. + + Args: + filename: string, path to a file + block_size: Integer, process the files by reading blocks of `block_size` + bytes. Use -1 to read the file as once. + + Returns: + hexadecimal as string, the crc32 of the passed file. + """ + crc = 0 + with FileIO(filename, mode="rb") as f: + chunk = f.read(n=block_size) + while chunk: + crc = binascii.crc32(chunk, crc) + chunk = f.read(n=block_size) + return hex(crc & 0xFFFFFFFF) diff --git a/tensorflow/python/lib/io/file_io_test.py b/tensorflow/python/lib/io/file_io_test.py index 223858edfa84eaa1c7879a9774dcc836de4f4672..c21eb931037f1728149456d62b1534f59527cfdb 100644 --- a/tensorflow/python/lib/io/file_io_test.py +++ b/tensorflow/python/lib/io/file_io_test.py @@ -491,5 +491,96 @@ class FileIoTest(test.TestCase): v = file_io.file_exists(file_path) self.assertEqual(v, True) + def testFilecmp(self): + file1 = os.path.join(self._base_dir, "file1") + file_io.write_string_to_file(file1, "This is a sentence\n" * 100) + + file2 = os.path.join(self._base_dir, "file2") + file_io.write_string_to_file(file2, "This is another sentence\n" * 100) + + file3 = os.path.join(self._base_dir, "file3") + file_io.write_string_to_file(file3, u"This is another sentence\n" * 100) + + self.assertFalse(file_io.filecmp(file1, file2)) + self.assertTrue(file_io.filecmp(file2, file3)) + + def testFilecmpSameSize(self): + file1 = os.path.join(self._base_dir, "file1") + file_io.write_string_to_file(file1, "This is a sentence\n" * 100) + + file2 = os.path.join(self._base_dir, "file2") + file_io.write_string_to_file(file2, "This is b sentence\n" * 100) + + file3 = os.path.join(self._base_dir, "file3") + file_io.write_string_to_file(file3, u"This is b sentence\n" * 100) + + self.assertFalse(file_io.filecmp(file1, file2)) + self.assertTrue(file_io.filecmp(file2, file3)) + + def testFilecmpBinary(self): + file1 = os.path.join(self._base_dir, "file1") + file_io.FileIO(file1, "wb").write("testing\n\na") + + file2 = os.path.join(self._base_dir, "file2") + file_io.FileIO(file2, "wb").write("testing\n\nb") + + file3 = os.path.join(self._base_dir, "file3") + file_io.FileIO(file3, "wb").write("testing\n\nb") + + file4 = os.path.join(self._base_dir, "file4") + file_io.FileIO(file4, "wb").write("testing\n\ntesting") + + self.assertFalse(file_io.filecmp(file1, file2)) + self.assertFalse(file_io.filecmp(file1, file4)) + self.assertTrue(file_io.filecmp(file2, file3)) + + def testFileCrc32(self): + file1 = os.path.join(self._base_dir, "file1") + file_io.write_string_to_file(file1, "This is a sentence\n" * 100) + crc1 = file_io.file_crc32(file1) + + file2 = os.path.join(self._base_dir, "file2") + file_io.write_string_to_file(file2, "This is another sentence\n" * 100) + crc2 = file_io.file_crc32(file2) + + file3 = os.path.join(self._base_dir, "file3") + file_io.write_string_to_file(file3, "This is another sentence\n" * 100) + crc3 = file_io.file_crc32(file3) + + self.assertTrue(crc1 != crc2) + self.assertEqual(crc2, crc3) + + def testFileCrc32WithBytes(self): + file1 = os.path.join(self._base_dir, "file1") + file_io.write_string_to_file(file1, "This is a sentence\n" * 100) + crc1 = file_io.file_crc32(file1, block_size=24) + + file2 = os.path.join(self._base_dir, "file2") + file_io.write_string_to_file(file2, "This is another sentence\n" * 100) + crc2 = file_io.file_crc32(file2, block_size=24) + + file3 = os.path.join(self._base_dir, "file3") + file_io.write_string_to_file(file3, "This is another sentence\n" * 100) + crc3 = file_io.file_crc32(file3, block_size=-1) + + self.assertTrue(crc1 != crc2) + self.assertEqual(crc2, crc3) + + def testFileCrc32Binary(self): + file1 = os.path.join(self._base_dir, "file1") + file_io.FileIO(file1, "wb").write("testing\n\n") + crc1 = file_io.file_crc32(file1) + + file2 = os.path.join(self._base_dir, "file2") + file_io.FileIO(file2, "wb").write("testing\n\n\n") + crc2 = file_io.file_crc32(file2) + + file3 = os.path.join(self._base_dir, "file3") + file_io.FileIO(file3, "wb").write("testing\n\n\n") + crc3 = file_io.file_crc32(file3) + + self.assertTrue(crc1 != crc2) + self.assertEqual(crc2, crc3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc1a25f420b434e6aa7d37cdf65f693e4d8c01a --- /dev/null +++ b/tensorflow/python/lib/io/tf_record_test.py @@ -0,0 +1,322 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf_record.TFRecordWriter and tf_record.tf_record_iterator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +import six + +from tensorflow.python.framework import errors_impl +from tensorflow.python.lib.io import tf_record +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +prefix_path = "third_party/tensorflow/core/lib" + +# pylint: disable=invalid-name +TFRecordCompressionType = tf_record.TFRecordCompressionType +# pylint: enable=invalid-name + +# Edgar Allan Poe's 'Eldorado' +_TEXT = b"""Gaily bedight, + A gallant knight, + In sunshine and in shadow, + Had journeyed long, + Singing a song, + In search of Eldorado. + + But he grew old + This knight so bold + And o'er his heart a shadow + Fell as he found + No spot of ground + That looked like Eldorado. + + And, as his strength + Failed him at length, + He met a pilgrim shadow + 'Shadow,' said he, + 'Where can it be + This land of Eldorado?' + + 'Over the Mountains + Of the Moon' + Down the Valley of the Shadow, + Ride, boldly ride,' + The shade replied, + 'If you seek for Eldorado!' + """ + + +class TFCompressionTestCase(test.TestCase): + + def setUp(self): + super(TFCompressionTestCase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + def _Record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _CreateFiles(self, options=None, prefix=""): + filenames = [] + for i in range(self._num_files): + name = prefix + "tfrecord.%d.txt" % i + records = [self._Record(i, j) for j in range(self._num_records)] + fn = self._WriteRecordsToFile(records, name, options) + filenames.append(fn) + return filenames + + def _WriteRecordsToFile(self, records, name="tfrecord", options=None): + fn = os.path.join(self.get_temp_dir(), name) + with tf_record.TFRecordWriter(fn, options=options) as writer: + for r in records: + writer.write(r) + return fn + + def _ZlibCompressFile(self, infile, name="tfrecord.z"): + # zlib compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = zlib.compress(f.read()) + + zfn = os.path.join(self.get_temp_dir(), name) + with open(zfn, "wb") as f: + f.write(cdata) + return zfn + + def _GzipCompressFile(self, infile, name="tfrecord.gz"): + # gzip compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = f.read() + + gzfn = os.path.join(self.get_temp_dir(), name) + with gzip.GzipFile(gzfn, "wb") as f: + f.write(cdata) + return gzfn + + def _ZlibDecompressFile(self, infile, name="tfrecord"): + with open(infile, "rb") as f: + cdata = zlib.decompress(f.read()) + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + def _GzipDecompressFile(self, infile, name="tfrecord"): + with gzip.GzipFile(infile, "rb") as f: + cdata = f.read() + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + +class TFRecordWriterTest(TFCompressionTestCase): + + def setUp(self): + super(TFRecordWriterTest, self).setUp() + + def _AssertFilesEqual(self, a, b, equal): + for an, bn in zip(a, b): + with open(an, "rb") as af, open(bn, "rb") as bf: + if equal: + self.assertEqual(af.read(), bf.read()) + else: + self.assertNotEqual(af.read(), bf.read()) + + def testWriteReadZLibFiles(self): + # Write uncompressed then compress manually. + options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) + files = self._CreateFiles(options, prefix="uncompressed") + zlib_files = [ + self._ZlibCompressFile(fn, "tfrecord_%s.z" % i) + for i, fn in enumerate(files) + ] + self._AssertFilesEqual(files, zlib_files, False) + + # Now write compressd and verify same. + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + compressed_files = self._CreateFiles(options, prefix="compressed") + self._AssertFilesEqual(compressed_files, zlib_files, True) + + # Decompress compress and verify same. + uncompressed_files = [ + self._ZlibDecompressFile(fn, "tfrecord_%s.z" % i) + for i, fn in enumerate(compressed_files) + ] + self._AssertFilesEqual(uncompressed_files, files, True) + + def testWriteReadGzipFiles(self): + # Write uncompressed then compress manually. + options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) + files = self._CreateFiles(options, prefix="uncompressed") + gzip_files = [ + self._GzipCompressFile(fn, "tfrecord_%s.gz" % i) + for i, fn in enumerate(files) + ] + self._AssertFilesEqual(files, gzip_files, False) + + # Now write compressd and verify same. + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + compressed_files = self._CreateFiles(options, prefix="compressed") + + # Note: Gzips written by TFRecordWriter add 'tfrecord_0' so + # compressed_files can't be compared with gzip_files + + # Decompress compress and verify same. + uncompressed_files = [ + self._GzipDecompressFile(fn, "tfrecord_%s.gz" % i) + for i, fn in enumerate(compressed_files) + ] + self._AssertFilesEqual(uncompressed_files, files, True) + + +class TFRecordWriterZlibTest(TFCompressionTestCase): + + def testZLibFlushRecord(self): + original = [b"small record"] + fn = self._WriteRecordsToFile(original, "small_record") + with open(fn, "rb") as h: + buff = h.read() + + # creating more blocks and trailing blocks shouldn't break reads + compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS) + + output = b"" + for c in buff: + if isinstance(c, int): + c = six.int2byte(c) + output += compressor.compress(c) + output += compressor.flush(zlib.Z_FULL_FLUSH) + + output += compressor.flush(zlib.Z_FULL_FLUSH) + output += compressor.flush(zlib.Z_FULL_FLUSH) + output += compressor.flush(zlib.Z_FINISH) + + # overwrite the original file with the compressed data + with open(fn, "wb") as h: + h.write(output) + + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + actual = list(tf_record.tf_record_iterator(fn, options=options)) + self.assertEqual(actual, original) + + def testZlibReadWrite(self): + """Verify that files produced are zlib compatible.""" + original = [b"foo", b"bar"] + fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord") + zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z") + + # read the compressed contents and verify. + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + actual = list(tf_record.tf_record_iterator(zfn, options=options)) + self.assertEqual(actual, original) + + def testZlibReadWriteLarge(self): + """Verify that writing large contents also works.""" + + # Make it large (about 5MB) + original = [_TEXT * 10240] + fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord") + zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z") + + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + actual = list(tf_record.tf_record_iterator(zfn, options=options)) + self.assertEqual(actual, original) + + def testGzipReadWrite(self): + """Verify that files produced are gzip compatible.""" + original = [b"foo", b"bar"] + fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord") + gzfn = self._GzipCompressFile(fn, "tfrecord.gz") + + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + actual = list(tf_record.tf_record_iterator(gzfn, options=options)) + self.assertEqual(actual, original) + + +class TFRecordIteratorTest(TFCompressionTestCase): + + def setUp(self): + super(TFRecordIteratorTest, self).setUp() + self._num_records = 7 + + def testIterator(self): + records = [self._Record(0, i) for i in range(self._num_records)] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(records, "compressed_records", options) + + reader = tf_record.tf_record_iterator(fn, options) + for expected in records: + record = next(reader) + self.assertAllEqual(expected, record) + with self.assertRaises(StopIteration): + record = next(reader) + + def testWriteZlibRead(self): + """Verify compression with TFRecordWriter is zlib library compatible.""" + original = [b"foo", b"bar"] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(original, "write_zlib_read.tfrecord.z", + options) + + zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord") + actual = list(tf_record.tf_record_iterator(zfn)) + self.assertEqual(actual, original) + + def testWriteZlibReadLarge(self): + """Verify compression for large records is zlib library compatible.""" + # Make it large (about 5MB) + original = [_TEXT * 10240] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(original, "write_zlib_read_large.tfrecord.z", + options) + zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tfrecord") + actual = list(tf_record.tf_record_iterator(zfn)) + self.assertEqual(actual, original) + + def testWriteGzipRead(self): + original = [b"foo", b"bar"] + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + fn = self._WriteRecordsToFile(original, "write_gzip_read.tfrecord.gz", + options) + + gzfn = self._GzipDecompressFile(fn, "write_gzip_read.tfrecord") + actual = list(tf_record.tf_record_iterator(gzfn)) + self.assertEqual(actual, original) + + def testBadFile(self): + """Verify that tf_record_iterator throws an exception on bad TFRecords.""" + fn = os.path.join(self.get_temp_dir(), "bad_file") + with tf_record.TFRecordWriter(fn) as writer: + writer.write(b"123") + fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated") + with open(fn, "rb") as f: + with open(fn_truncated, "wb") as f2: + # DataLossError requires that we've written the header, so this must + # be at least 12 bytes. + f2.write(f.read(14)) + with self.assertRaises(errors_impl.DataLossError): + for _ in tf_record.tf_record_iterator(fn_truncated): + pass + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 3678bd4c1f6a4500622b6d9e8334cb1ebae46578..fe459a96b98733f8a706b0c3b84000c5a74894ad 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -568,7 +568,6 @@ ops.NotDifferentiable("Size") @ops.RegisterGradient("Tile") def _TileGrad(op, grad): """Sum reduces grad along the tiled dimensions.""" - assert isinstance(grad, ops.Tensor) input_shape = array_ops.shape(op.inputs[0]) # We interleave multiples and input_shape to get split_shape, # reshape grad to split_shape, and reduce along all even @@ -581,6 +580,13 @@ def _TileGrad(op, grad): split_shape = array_ops.reshape( array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1]) axes = math_ops.range(0, array_ops.size(split_shape), 2) + # Sum reduces grad along the first dimension for IndexedSlices + if isinstance(grad, ops.IndexedSlices): + grad = math_ops.unsorted_segment_sum( + grad.values, + math_ops.mod(grad.indices, input_shape[0]), + input_shape[0]) + split_shape = array_ops.concat([[1], split_shape[1:]], axis=0) input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes) # Fix shape inference if not context.executing_eagerly(): diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index c639c6b932ccfae4bca8f8ceeead0a0da39fd327..361667ec49aba9705787c3c7ac096add36afb40b 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import gen_math_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_array_ops import * +from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import @@ -1623,7 +1624,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True): Args: tensor: A `Tensor`. dtype: A type for the returned `Tensor`. Must be `float32`, `float64`, - `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`, + `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`, `complex64`, `complex128` or `bool`. name: A name for the operation (optional). optimize: if true, attempt to statically determine the shape of 'tensor' @@ -1897,7 +1898,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl and paddings_constant is not None): new_shape = [] for padding, dim in zip(paddings_constant, input_shape.as_list()): - if padding is None or dim is None or not all(padding): + if padding is None or dim is None or any((x is None for x in padding)): new_shape.append(None) else: new_shape.append(sum(padding) + dim) @@ -2609,14 +2610,6 @@ def where(condition, x=None, y=None, name=None): raise ValueError("x and y must both be non-None or both be None.") -@tf_export("reverse") -def reverse(tensor, axis, name=None): - return gen_array_ops.reverse_v2(tensor, axis, name) - - -reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ - - # pylint: disable=redefined-builtin @tf_export("reverse_sequence") @deprecation.deprecated_args( diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py index 2a2bcdd9d69b7a0aed1e7f3d3197cf6d7dd98451..868a4f6b84df2c0d1b8b55a254f16f1be5ee1f1d 100644 --- a/tensorflow/python/ops/boosted_trees_ops.py +++ b/tensorflow/python/ops/boosted_trees_ops.py @@ -25,6 +25,8 @@ from tensorflow.python.ops import resources # Re-exporting ops used by other modules. # pylint: disable=unused-import from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index cabc1e724cdb667f4d0c5059ff1d78854a45b30c..375a5ec2c30069c955152e590b3edea0319de73a 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -341,8 +341,8 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): y_sum, y_np[:y_sum])) index_and_values_str = '' - if x.shape == y.shape: - # If the shapes of x and y are the same, + if x.shape == y.shape and x.shape.as_list(): + # If the shapes of x and y are the same (and not scalars), # Get the values that actually differed and their indices. # If shapes are different this information is more confusing # than useful. diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..98668facd5bc56892fa00f258dfebcbe93c063da --- /dev/null +++ b/tensorflow/python/ops/collective_ops.py @@ -0,0 +1,133 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow collective Ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import device +from tensorflow.python.ops import gen_collective_ops + + +def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op, + subdiv_offsets=(0,)): + """Reduces tensors collectively, across devices. + + Args: + t: the tensor to be reduced. + group_size: the total number of tensors to be collectively reduced. + Each must reside on a different device. + group_key: an integer identifying the group of devices. + instance_key: an integer identifying the participating group of Ops. + merge_op: string naming the binary Op to be applied to compute each + partial reduction. + final_op: string naming the unary Op to be applied to each fully + reduced value. Can be 'Id' for no operation. + subdiv_offsets: a list of integer offsets into the tensor at which each + independent subdivision should begin. Use [0] if no subdivision should + be done. + + Returns: + An Op implementing the distributed reduction. + + Raises: + ValueError: if any of the input parameter constraints are not met. + """ + if not device.canonical_name(t.device): + raise ValueError('Device assignment required for collective ops') + if group_size <= 1: + raise ValueError('Parameter group_size to add_reduce must be at least 2.') + return gen_collective_ops.collective_reduce(t, + group_size=group_size, + group_key=group_key, + instance_key=instance_key, + merge_op=merge_op, + final_op=final_op, + subdiv_offsets=subdiv_offsets) + + +def broadcast_send(t, shape, dtype, group_size, group_key, instance_key): + """Broadcasts one tensor to a group of others, across devices. + + Args: + t: the tensor to be sent. + shape: the shape of the tensor being sent, which must agree with t. + dtype: the type of the tensor being sent, which must agree with t. + group_size: one plus the number of receiving tensors, i.e. the total + number of devices participating. Each tensor must reside on a + different device. + group_key: an integer identifying the group of devices. + instance_key: an integer identifying the participating group of Ops. + + Returns: + An Op implementing the distributed broadcast send. + + Raises: + ValueError: if any of the input parameter constraints are not met. + + Note that the shape and dtype arguments appear redundant since they + should be obtainable from t. The are two reasons for including + them. First, the shape and type of tensors passed via broadcast must + be known ahead of time in their most specific form so that the receive + side can allocate memory for the operation and shape/type inference can + carry forward from there. Including the same declarations on the + send side clarifies a commitment already made. Secondly, having nearly + identical use syntax for send and receive sides may simplify tool-driven + generation of broadcast. + """ + if not device.canonical_name(t.device): + raise ValueError('Device assignment required for collective ops') + if group_size <= 1: + raise ValueError( + 'Parameter group_size to broadcast_send must be at least 2.') + if t.shape != shape: + raise ValueError( + 'Shape of broadcast_send tensor not equal to delcared shape') + if t.dtype != dtype: + raise ValueError( + 'Type of broadcast_send tensor not equal to declared type') + return gen_collective_ops.collective_bcast_send(t, + shape=shape, + group_size=group_size, + group_key=group_key, + instance_key=instance_key) + + +def broadcast_recv(shape, dtype, group_size, group_key, instance_key): + """Receives a broadcasts tensor, across devices. + + Args: + shape: Shape of the tensor to be received. + dtype: Type of the tensor to be received. + group_size: one plus the number of receiving tensors, i.e. the total + number of devices participating. Each tensor must reside on a + different device. + group_key: an integer identifying the group of devices. + instance_key: an integer identifying the participating group of Ops. + + Returns: + An Op implementing the broadcast receive. + + Raises: + ValueError: if any of the input parameter constraints are not met. + """ + if group_size <= 1: + raise ValueError( + 'Parameter group_size to broadcast_send must be at least 2.') + return gen_collective_ops.collective_bcast_recv(shape=shape, + T=dtype, + group_size=group_size, + group_key=group_key, + instance_key=instance_key) diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc64ef9f631faf2f76c3dbb3e70e1f37bbe4b1a --- /dev/null +++ b/tensorflow/python/ops/collective_ops_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Collective Operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import collective_ops +from tensorflow.python.platform import test + +# TODO(tucker): Make these ops work in eager mode. b/79776476 + + +class CollectiveOpTest(test.TestCase): + + def _testCollectiveReduce(self, t0, t1, expected): + group_key = 1 + instance_key = 1 + with self.test_session( + config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess: + with ops.device('/CPU:0'): + in0 = constant_op.constant(t0) + colred0 = collective_ops.all_reduce(in0, 2, group_key, instance_key, + 'Add', 'Div') + with ops.device('/CPU:1'): + in1 = constant_op.constant(t1) + colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key, + 'Add', 'Div') + run_options = config_pb2.RunOptions() + run_options.experimental.collective_graph_key = 1 + results = sess.run([colred0, colred1], options=run_options) + self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5) + self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5) + + def testCollectiveReduce(self): + self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1], + [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3], + [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2]) + + def _testCollectiveBroadcast(self, t0): + group_key = 1 + instance_key = 1 + with self.test_session( + config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess: + with ops.device('/CPU:0'): + in0 = constant_op.constant(t0) + out0 = collective_ops.broadcast_send(in0, in0.shape, in0.dtype, + 2, group_key, instance_key) + with ops.device('/CPU:1'): + c1 = constant_op.constant(t0) + out1 = collective_ops.broadcast_recv(c1.shape, c1.dtype, + 2, group_key, instance_key) + run_options = config_pb2.RunOptions() + run_options.experimental.collective_graph_key = 1 + results = sess.run([out0, out1], options=run_options) + self.assertAllClose(results[0], t0, rtol=1e-5, atol=1e-5) + self.assertAllClose(results[1], t0, rtol=1e-5, atol=1e-5) + + def testCollectiveBroadcast(self): + self._testCollectiveBroadcast([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/applications/nasnet/__init__.py b/tensorflow/python/ops/cond_v2.py similarity index 61% rename from tensorflow/python/keras/applications/nasnet/__init__.py rename to tensorflow/python/ops/cond_v2.py index 94eb145b85b85b2e52ca37e7aebc681c1f054e16..76173e0f309b80402a15acdab5d2af49f35de741 100644 --- a/tensorflow/python/keras/applications/nasnet/__init__.py +++ b/tensorflow/python/ops/cond_v2.py @@ -11,18 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== -"""NASNet Keras applications.""" +# ============================================================================= +"""cond_v2 wrapper module. + +This imports the cond_v2 method and all necessary dependencies (this is to avoid +circular dependencies in the cond_v2 implementation). See cond_v2_impl for more +information. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.nasnet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetLarge -from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetMobile -from tensorflow.python.keras._impl.keras.applications.nasnet import preprocess_input +# pylint: disable=unused-import +from tensorflow.python.framework import function +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.ops import gradients_impl -del absolute_import -del division -del print_function +from tensorflow.python.ops.cond_v2_impl import cond_v2 +# pylint: enable=unused-import diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..d310f83dca97889157eb078b11a3ca51caae2fc2 --- /dev/null +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -0,0 +1,479 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""cond_v2 and gradient. + +This is a version of cond that emits a single If op, as well as the gradient +function for If ops produced by cond_v2. This will eventually replace the +current tf.cond implementation once it reaches feature and performance parity. + +NOTE: most users of cond_v2 should import cond_v2, not this module! This module +does not contain all the necessary imports to prevent circular dependencies, +while cond_v2 does. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python.framework import c_api_util +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import gen_functional_ops +from tensorflow.python.util import compat + + +# The following modules cannot be imported directly because they cause circular +# dependencies. These are set in each corresponding module. +_function = None +_function_def_to_graph = None +_gradients_impl = None + +# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify +# that they aren't part of the official public API. These protected members +# often need to be used by implementation code however. Rather than litter the +# code with pylint comments, we ignore protected access violations for +# readability. +# pylint: disable=protected-access + + +def cond_v2(pred, true_fn, false_fn, name="cond"): + """Like tf.cond, except emits a single If op.""" + if not name: + name = "cond" + + with ops.name_scope(name) as scope: + # Identify if there is a caller device, & get the innermost if possible. + device_stack = ops.get_default_graph()._device_function_stack + caller_device = device_stack[-1] if device_stack else None + + caller_colocation_stack = ops.get_default_graph()._colocation_stack + caller_container = ops.get_default_graph()._container + caller_collection_ref = ops.get_default_graph()._collections + + func_name_prefix = scope.replace("/", "_") + + true_graph = _function.func_graph_from_py_func( + true_fn, [], [], + name="%strue" % func_name_prefix, + device=caller_device, + colocation_stack=caller_colocation_stack, + collections_ref=caller_collection_ref, + container=caller_container) + false_graph = _function.func_graph_from_py_func( + false_fn, [], [], + name="%sfalse" % func_name_prefix, + device=caller_device, + colocation_stack=caller_colocation_stack, + collections_ref=caller_collection_ref, + container=caller_container) + _check_same_outputs(true_graph, false_graph) + + # Add inputs to true_graph and false_graph to make them match. Note that + # this modifies true_graph and false_graph. + cond_inputs = _make_inputs_match(true_graph, false_graph, + true_graph.extra_inputs, + false_graph.extra_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # the gradient computation. + + true_intermediates = _get_intermediates(true_graph) + false_intermediates = _get_intermediates(false_graph) + + # Save the original number of outputs to return to the caller. + num_cond_outputs = len(true_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_outputs, extra_false_outputs = _pad_params( + true_graph, false_graph, true_intermediates, false_intermediates) + + true_graph.outputs.extend(extra_true_outputs) + false_graph.outputs.extend(extra_false_outputs) + + # Create the If op. + tensors = gen_functional_ops._if( + pred, cond_inputs, [t.dtype for t in true_graph.outputs], + _create_new_tf_function(true_graph), + _create_new_tf_function(false_graph), + name=scope) + + # Set the flag to enable lowering on the `if` op if necessary + # Lowering allows cond_v2 to avoid some of the limitations of Functions, + # allowing users to specify devices & colocation inside of cond_v2 branches, + # and enabling non-strict evaluation & partial pruning of cond_v2 branches. + # This brings cond_v2 closer to feature parity with tf.cond. + # + # However, we do not lower `If` in the XLA context because it is easier for + # XLA to apply its own optimizations when dealing with un-lowered `If` + # operators than with lowered switch/merge control flow. + # + # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output + if_op = tensors[0].op + if not control_flow_util.IsInXLAContext(if_op): + if_op._set_attr("_lower_using_switch_merge", + attr_value_pb2.AttrValue(b=True)) + + return tensors[:num_cond_outputs] + + +@ops.RegisterGradient("If") +def _IfGrad(op, *grads): # pylint: disable=invalid-name + """The gradient of an If op produced by cond_v2.""" + true_graph, false_graph = _get_func_graphs(op) + + # Create grad functions that compute the gradient of the true/false forward + # graphs. These functions will capture tensors from the forward pass + # functions. + true_grad_graph = _create_grad_func( + true_graph, grads, _get_grad_fn_name(true_graph)) + false_grad_graph = _create_grad_func( + false_graph, grads, _get_grad_fn_name(false_graph)) + + assert ([t.dtype for t in true_grad_graph.outputs] == + [t.dtype for t in false_grad_graph.outputs]) + + # Match up the captured grad function inputs with outputs of 'op' and other + # external tensors. + true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph) + false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph) + + # Make the inputs to true_grad_graph and false_grad_graph match. Note that + # this modifies true_grad_graph and false_grad_graph. + grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph, + true_grad_inputs, false_grad_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # higher-order gradient computations. + + true_grad_intermediates = _get_intermediates(true_grad_graph) + false_grad_intermediates = _get_intermediates(false_grad_graph) + + # Save the original number of gradient outputs to return. + num_grad_outputs = len(true_grad_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_grad_outputs, extra_false_grad_outputs = _pad_params( + true_grad_graph, false_grad_graph, + true_grad_intermediates, false_grad_intermediates) + + true_grad_graph.outputs.extend(extra_true_grad_outputs) + false_grad_graph.outputs.extend(extra_false_grad_outputs) + + # Create the gradient If op. + tensors = gen_functional_ops._if( + op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs], + _create_new_tf_function(true_grad_graph), + _create_new_tf_function(false_grad_graph)) + + # The predicate has no gradient. + return [None] + tensors[:num_grad_outputs] + + +def _get_func_graphs(if_op): + """Returns `_FuncGraph`s for the input op branches. + + Args: + if_op: The _If Operation. + + Returns: + A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch. + """ + def _get_func_graph_for_branch(branch_name): + """Generates and returns a _FuncGraph for the given branch.""" + extra_inputs = if_op.inputs[1:] # First input is pred. + input_shapes = [t.shape for t in extra_inputs] + func_name = if_op.get_attr(branch_name).name + fdef = if_op.graph._get_function(func_name).definition + func_graph = _function_def_to_graph.function_def_to_graph( + fdef, input_shapes) + func_graph.extra_inputs = extra_inputs + func_graph.extra_args = func_graph.inputs + func_graph._captured = dict(zip(extra_inputs, func_graph.inputs)) + return func_graph + + return (_get_func_graph_for_branch("then_branch"), + _get_func_graph_for_branch("else_branch")) + + +def _grad_fn(func_graph, grads): + """The gradient function for each conditional branch. + + This function builds the gradient graph of the corresponding forward-pass + conditional branch in `func_graph`. This is done by differentiating + func_graph's outputs w.r.t. its inputs. + + Args: + func_graph: function._FuncGraph. The corresponding forward-pass function. + grads: The list of input gradient Tensors. + + Returns: + The output gradient Tensors. + """ + # Filter out untrainable function outputs. + # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes + # cause _GradientsHelper to raise an exception (e.g. the implementation + # doesn't expect 'ys' to contain boolean tensors). + assert len(func_graph.outputs) == len(grads) + ys = [] + grad_ys = [] + for y, grad_y in zip(func_graph.outputs, grads): + if not _gradients_impl._IsTrainable(y): + continue + ys.append(y) + grad_ys.append(grad_y) + + # Build the gradient graph. Note that this builds the gradient computation of + # func_graph in the current graph, which requires capturing tensors from + # func_graph. The captured func_graph tensors are resolved to external tensors + # in _get_grad_inputs. + result = _gradients_impl._GradientsHelper( + ys, func_graph.inputs, grad_ys=grad_ys, + src_graph=func_graph) + + # Functions can't return None; replace Nones with zero tensors. + # TODO(b/80444525): don't return anything here and make _IfGrad return None if + # both branches have zero gradient. + for i in range(len(result)): + if result[i] is None: + result[i] = array_ops.zeros_like(func_graph.inputs[i]) + + return result + + +def _create_grad_func(func_graph, grads, name): + """Returns the _FuncGraph representation of _grad_fn.""" + return _function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), + [], [], name) + + +def _get_grad_inputs(if_op, cond_graph, grad_graph): + """Returns the tensors we should pass to grad_graph. + + This method handles tensors captured from cond_graph in grad_graph. It + converts these to suitable input tensors from the outer graph. + + Args: + if_op: Operation. The forward-pass If op that uses cond_graph. + cond_graph: function._FuncGraph. The forward-pass function. + grad_graph: function._FuncGraph. The gradients function. + + Returns: + A list of inputs tensors to be passed to grad_graph. + """ + inputs = [] + + # Maps placeholders in cond_graph -> input tensor in outer graph. + forward_input_map = {v: k for k, v in cond_graph._captured.items()} + + for t in grad_graph.extra_inputs: + if t.graph == ops.get_default_graph(): + # t is in the outer graph (e.g. one of the input gradients). + inputs.append(t) + elif t in forward_input_map: + # t is an input placeholder in cond_graph. Get the corresponding input + # tensor in the outer graph. + assert t.graph == cond_graph + assert forward_input_map[t].graph == ops.get_default_graph() + inputs.append(forward_input_map[t]) + else: + # t is an intermediate value in cond_graph. Get the corresponding output + # of 'if_op' (note that all intermediate values are outputs). + assert t.graph == cond_graph + output_idx = cond_graph.outputs.index(t) + inputs.append(if_op.outputs[output_idx]) + + return inputs + + +def _create_new_tf_function(func_graph): + """Converts func_graph to a TF_Function and adds it to the current graph. + + Args: + func_graph: function._FuncGraph + + Returns: + The name of the new TF_Function. + """ + c_func = c_api.TF_GraphToFunction_wrapper( + func_graph._c_graph, + compat.as_str(func_graph.name), + False, # append_hash_to_fn_name + None, # opers + [t._as_tf_output() for t in func_graph.inputs], + [t._as_tf_output() for t in func_graph.outputs], + [], + None, # opts + None) # description + _ = c_api_util.ScopedTFFunction(c_func) + + # TODO(b/109833212): this sucks, we're serializing the TF_Function*, + # deserializing it into a Python FunctionDef, then reserializing it to create + # a new TF_Function that we add to the graph. + fdef = _function.function_def_from_tf_function(c_func) + defined_func = _function._from_definition(fdef) + defined_func.add_to_graph(ops.get_default_graph()) + + return func_graph.name + + +def _get_intermediates(func_graph): + """Returns all tensors in `func_graph` that aren't inputs or outputs.""" + intermediates = [] + for op in func_graph.get_operations(): + for t in op.outputs: + if t in func_graph.inputs: continue + if t in func_graph.outputs: continue + intermediates.append(t) + return intermediates + + +def _separate_unique_inputs(true_inputs, false_inputs): + """Separates tensors appearing only in true_inputs or false_inputs, or both. + + Args: + true_inputs: list of Tensors + false_inputs: list of Tensors + + Returns: + Three lists of Tensors: + 1. The tensors that appear in both true_inputs and false_inputs + 2. The tensors that only appear in true_inputs + 3. The tensors that only appear in false_inputs + """ + true_inputs = set(true_inputs) + false_inputs = set(false_inputs) + + shared_inputs = true_inputs.intersection(false_inputs) + true_only_inputs = true_inputs - false_inputs + false_only_inputs = false_inputs - true_inputs + + return list(shared_inputs), list(true_only_inputs), list(false_only_inputs) + + +def _pad_params(true_graph, false_graph, true_params, false_params): + """Returns new param lists that have matching signatures. + + This is done by mirroring each param list in the other using dummy params. + There is no merging of params. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_params: a list of Tensors from true_graph + false_params: a list of Tensors from false_graph + + Returns: + A new list of Tensors in true_graph and a new list of Tensors in + false_graph. The two lists have the same number of Tensors, with matching + types and shapes across the lists. + """ + new_true_params = (true_params + + _create_dummy_params(true_graph, false_params)) + new_false_inputs = (_create_dummy_params(false_graph, true_params) + + false_params) + return new_true_params, new_false_inputs + + +def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): + """Modifies true_graph and false_graph so they have the same input signature. + + This method reorders and/or adds parameters to true_graph and false_graph so + they have the same input signature, and updates the 'inputs', 'extra_inputs', + and '_captured' fields of both graphs accordingly. It uses the input tensors + from the outer graph to avoid duplicating shared arguments. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_inputs: a list of Tensors in the outer graph. The inputs for + true_graph. + false_inputs: a list of Tensors in the outer graph. The inputs for + false_graph. + + Returns: + A new list of Tensors from the outer graph that are the new inputs for both + true_graph and false_graph. This is a deduped version of true_inputs + + false_inputs. + """ + shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs( + true_inputs, false_inputs) + + new_inputs = shared_inputs + true_only_inputs + false_only_inputs + + true_input_to_param = dict(zip(true_inputs, true_graph.inputs)) + false_input_to_param = dict(zip(false_inputs, false_graph.inputs)) + + true_graph.inputs = ( + [true_input_to_param[t] for t in shared_inputs] + + [true_input_to_param[t] for t in true_only_inputs] + + _create_dummy_params(true_graph, false_only_inputs)) + + false_graph.inputs = ( + [false_input_to_param[t] for t in shared_inputs] + + _create_dummy_params(false_graph, true_only_inputs) + + [false_input_to_param[t] for t in false_only_inputs]) + + # Rewrite the _FuncGraphs' state to reflect the new inputs. + true_graph.extra_inputs = new_inputs + false_graph.extra_inputs = new_inputs + + true_graph._captured = dict(zip(new_inputs, true_graph.inputs)) + false_graph._captured = dict(zip(new_inputs, false_graph.inputs)) + + return new_inputs + + +def _create_dummy_params(func_graph, template_tensors): + """Creates tensors in func_graph to represent template_tensors. + + Args: + func_graph: function._FuncGraph. + template_tensors: a list of tensors in the outer graph. + + Returns: + A list of tensors in func_graph. + """ + with func_graph.as_default(): + return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape) + for t in template_tensors] + + +def _get_grad_fn_name(func_graph): + """Returns a unique name to use for the grad function of `func_graph`.""" + name = "%s_grad" % func_graph.name + + base_name = name + counter = 1 + if ops.get_default_graph()._is_function(name): + name = "%s_%s" % (base_name, counter) + counter += 1 + + return name + + +def _check_same_outputs(true_graph, false_graph): + """Raises an error if true_graph and false_graph have different outputs.""" + true_output_types = [t.dtype for t in true_graph.outputs] + false_output_types = [t.dtype for t in false_graph.outputs] + if (len(true_graph.outputs) != len(false_graph.outputs) or + true_output_types != false_output_types): + raise ValueError( + "true_fn() and false_fn() must return the same number and type of " + "arguments, got:\n" + " true_fn: %s\n" + " false_fn: %s" % (true_output_types, false_output_types)) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 5ebdb190791efffae2695b2c26ea00d5a3510878..04545cceb7e166d227a46974ba3602e3cfd36512 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -24,6 +24,7 @@ from __future__ import print_function import abc import collections import functools +import os import six @@ -38,6 +39,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import control_flow_util as util from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_control_flow_ops @@ -57,6 +59,10 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use from tensorflow.python.util.tf_export import tf_export + +_ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0" + + # We override the 'tuple' for a control flow op, so we keep python's # existing 'tuple' for later use in this module. _basetuple = tuple @@ -596,7 +602,6 @@ def _EnforceShapeInvariant(merge_var, next_var): enter = merge_var.op.inputs[0].op assert util.IsLoopEnter(enter) input_t = enter.inputs[0] - assert input_t.shape == m_shape raise ValueError( "Input tensor '%s' enters the loop with shape %s, but has shape %s " "after one iteration. To allow the shape to vary across iterations, " @@ -1192,20 +1197,18 @@ class ControlFlowState(object): to backprop. """ loop_exits = [] - for _, grad_state in self._map.items(): - # pylint: disable=protected-access + for grad_state in self._map.values(): for y in grad_state.forward_loop_exits: - if pending_count[y.op._id] == 0: + if pending_count[y.op] == 0: grad_state.pending_exits_count -= 1 - if y.op._id not in to_ops_set: + if y.op not in to_ops_set: grad_state.unused_exits.append(y) if grad_state.pending_exits_count == 0: loop_exits.extend(grad_state.unused_exits) # Need to include Enters in backprop for higher-order gradients. for y in grad_state.forward_context.loop_enters: - if pending_count[y.op._id] == 0: - pending_count[y.op._id] = 1 - # pylint: enable=protected-access + if pending_count[y.op] == 0: + pending_count[y.op] = 1 return loop_exits def EnterGradWhileContext(self, op, before): @@ -1243,8 +1246,8 @@ class ControlFlowState(object): # We need to include all exits of a loop for backprop. for loop_exit in grad_state.forward_loop_exits: - if not between_ops[loop_exit.op._id]: - between_ops[loop_exit.op._id] = True + if loop_exit.op not in between_ops: + between_ops.add(loop_exit.op) between_op_list.append(loop_exit.op) def ZerosLikeForExit(self, val): @@ -1996,6 +1999,9 @@ def cond(pred, ``` """ + if _ENABLE_COND_V2: + return cond_v2_impl.cond_v2(pred, true_fn, false_fn, name) + # We needed to make true_fn/false_fn keyword arguments for # backwards-compatibility. This check exists so that we can convert back to # having them be positional arguments. @@ -2731,7 +2737,8 @@ class WhileContext(ControlFlowContext): self.outer_context.Exit() else: shape_acc = array_ops.zeros_like( - array_ops.shape_internal(op.inputs[0], optimize=False), + array_ops.shape_internal(op.inputs[0], optimize=False, + out_type=dense_shape.dtype), optimize=False) if self.outer_context: @@ -2925,7 +2932,8 @@ class WhileContext(ControlFlowContext): return original_body_result, exit_vars - def BuildLoop(self, pred, body, loop_vars, shape_invariants): + def BuildLoop(self, pred, body, loop_vars, shape_invariants, + return_same_structure): """Add the loop termination condition and body to the graph.""" # Keep original_loop_vars to identify which are TensorArrays @@ -2936,9 +2944,10 @@ class WhileContext(ControlFlowContext): loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars) try: self.Enter() - # _BuildLoop calls _update_input in several places. _lock ensures a - # Session.run call cannot occur between creating and mutating new ops. - with ops.get_default_graph()._lock: # pylint: disable=protected-access + # _BuildLoop calls _update_input in several places. _mutation_lock() + # ensures a Session.run call cannot occur between creating and mutating + # new ops. + with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access original_body_result, exit_vars = self._BuildLoop( pred, body, original_loop_vars, loop_vars, shape_invariants) finally: @@ -2952,7 +2961,11 @@ class WhileContext(ControlFlowContext): packed_exit_vars = nest.pack_sequence_as( structure=original_body_result, flat_sequence=exit_vars_with_tensor_arrays) - return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars + + if return_same_structure: + return packed_exit_vars + else: + return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars def _FixControlInputsAndContext(self, enters): graph = ops.get_default_graph() @@ -2992,7 +3005,8 @@ def while_loop(cond, back_prop=True, swap_memory=False, name=None, - maximum_iterations=None): + maximum_iterations=None, + return_same_structure=False): """Repeat `body` while the condition `cond` is true. `cond` is a callable returning a boolean scalar tensor. `body` is a callable @@ -3068,11 +3082,16 @@ def while_loop(cond, to run. If provided, the `cond` output is AND-ed with an additional condition ensuring the number of iterations executed is no greater than `maximum_iterations`. + return_same_structure: If True, output has same structure as `loop_vars`. If + eager execution is enabled, this is ignored (and always treated as True). Returns: - The output tensors for the loop variables after the loop. When the length - of `loop_vars` is 1 this is a Tensor, TensorArray or IndexedSlice and when - the length of `loop_vars` is greater than 1 it returns a list. + The output tensors for the loop variables after the loop. + If `return_same_structure` is True, the return value has the same + structure as `loop_vars`. + If `return_same_structure` is False, the return value is a Tensor, + TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list + otherwise. Raises: TypeError: if `cond` or `body` is not callable. @@ -3127,6 +3146,7 @@ def while_loop(cond, happen is that the thread updating `x` can never get ahead of the counter thread because the thread incrementing `x` depends on the value of the counter. + ```python import tensorflow as tf @@ -3208,7 +3228,8 @@ def while_loop(cond, # be encapsulated in the root context. if loop_context.outer_context is None: ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) - result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants) + result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, + return_same_structure) if maximum_iterations is not None: return result[1] else: @@ -3341,12 +3362,6 @@ def group(*inputs, **kwargs): if not hasattr(inp, "device"): raise TypeError("Expected tf.group() expected Tensor arguments not " "'%s' with type '%s'" % (inp, type(inp))) - if not hasattr(inp, "device"): - if isinstance(inp, list): - raise TypeError("To call tf.group() with a list, use " - "tf.group(*[...]) not tf.group([...]).") - raise TypeError("Expected tf.group() expected Tensor arguments not " - "'%s' with type '%s'" % (inp, type(inp))) dev = inp.device if dev in ops_on_device: ops_on_device[dev].append(inp) diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index 59bb925df0f25b3bf88112bc3eb1b13b21ace414..153548ae92cfecfe5c750746b1425abcf3747b1b 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -939,7 +939,7 @@ class CaseTest(test_util.TensorFlowTestCase): class WhileLoopTestCase(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWhileLoopWithSingleVariable(self): i = constant_op.constant(0) c = lambda i: math_ops.less(i, 10) @@ -948,7 +948,7 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase): self.assertEqual(self.evaluate(r), 10) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self): i = constant_op.constant(0) c = lambda i: math_ops.less(i, 10) @@ -958,6 +958,28 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase): # Expect a tuple since that is what the body returns. self.assertEqual(self.evaluate(r), (10,)) + def testWhileLoopSameReturnShape_False(self): + i = constant_op.constant(0) + c = lambda i, _: math_ops.less(i, 10) + + # Body returns a [tensor, []] + b = lambda i, _: [math_ops.add(i, 1), []] + + # Should only return the tensor. + r = control_flow_ops.while_loop(c, b, [i, []]) + self.assertEqual(self.evaluate(r), 10) + + def testWhileLoopSameReturnShape_True(self): + i = constant_op.constant(0) + c = lambda i, _: math_ops.less(i, 10) + + # Body returns a [tensor, []] + b = lambda i, _: [math_ops.add(i, 1), []] + + # Should only return the original structure. + r = control_flow_ops.while_loop(c, b, [i, []], return_same_structure=True) + self.assertEqual(self.evaluate(r), [10, []]) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/conv2d_benchmark.py b/tensorflow/python/ops/conv2d_benchmark.py index 907df85cd954d2a897ba9a0c4b21be8586859380..aacdaa7ad019d8aae2d0b533cde8412ab0f0fa22 100644 --- a/tensorflow/python/ops/conv2d_benchmark.py +++ b/tensorflow/python/ops/conv2d_benchmark.py @@ -21,6 +21,8 @@ from __future__ import print_function import itertools import time +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,22 +30,32 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables +from tensorflow.python.platform import flags from tensorflow.python.platform import test +FLAGS = flags.FLAGS -def build_graph(device, input_shape, filter_shape, strides, padding, dtype, - num_iters, warmup_iters): +flags.DEFINE_boolean( + "enable_layout_optimizer", False, + "If true, enables layout optimizer to update input data format for faster " + "execution of convolution ops.") + + +def build_graph(device, dtype, data_format, input_shape, filter_shape, strides, + padding, num_iters, warmup_iters): """builds a graph containing a sequence of conv2d operations. Args: device: String, the device to run on. + dtype: Data type for the convolution. + data_format: A string from: "NHWC" or "NCHW". Data format for input and + output data. input_shape: Shape of the input tensor. filter_shape: Shape of the filter tensor. strides: A list of ints. 1-D of length 4. The stride of sliding window for each dimension of input. padding: A string from: "SAME", "VALID". The type of padding algorithm to use. - dtype: Data type for the convolution. num_iters: number of iterations to run conv2d. warmup_iters: number of iterations for warmup runs. @@ -57,22 +69,23 @@ def build_graph(device, input_shape, filter_shape, strides, padding, dtype, random_ops.truncated_normal(filter_shape, dtype=dtype)) outputs = [] - conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC") + conv2d_op = nn_ops.conv2d( + inp, filt, strides, padding, data_format=data_format) outputs.append(conv2d_op) for _ in range(1, num_iters): with ops.control_dependencies([conv2d_op]): conv2d_op = nn_ops.conv2d( - inp, filt, strides, padding, data_format="NHWC") + inp, filt, strides, padding, data_format=data_format) outputs.append(conv2d_op) warmup_groups = [] warmup_conv2d_op = nn_ops.conv2d( - inp, filt, strides, padding, data_format="NHWC") + inp, filt, strides, padding, data_format=data_format) warmup_groups.append(warmup_conv2d_op) for _ in range(1, warmup_iters): with ops.control_dependencies([warmup_conv2d_op]): warmup_conv2d_op = nn_ops.conv2d( - inp, filt, strides, padding, data_format="NHWC") + inp, filt, strides, padding, data_format=data_format) warmup_groups.append(warmup_conv2d_op) return control_flow_ops.group(*warmup_groups), control_flow_ops.group( *outputs) @@ -81,12 +94,15 @@ def build_graph(device, input_shape, filter_shape, strides, padding, dtype, class Conv2DBenchmark(test.Benchmark): """Benchmark conv2d!""" - def _run_graph(self, device, input_shape, filter_shape, strides, padding, - dtype, num_iters, warmup_iters): + def _run_graph(self, device, dtype, data_format, input_shape, filter_shape, + strides, padding, num_iters, warmup_iters): """runs the graph and print its execution time. Args: device: String, the device to run on. + dtype: Data type for the convolution. + data_format: A string from: "NHWC" or "NCHW". Data format for input and + output data. input_shape: Shape of the input tensor. filter_shape: Shape of the filter tensor. strides: A list of ints. 1-D of length 4. The stride of sliding @@ -94,7 +110,6 @@ class Conv2DBenchmark(test.Benchmark): padding: A string from: "SAME", "VALID". The type of padding algorithm to use. num_iters: Number of iterations to run the benchmark. - dtype: Data type for the convolution. num_iters: number of iterations to run conv2d. warmup_iters: number of iterations for warmup runs. @@ -103,10 +118,27 @@ class Conv2DBenchmark(test.Benchmark): """ graph = ops.Graph() with graph.as_default(): - warmup_outputs, outputs = build_graph(device, input_shape, filter_shape, - strides, padding, dtype, num_iters, - warmup_iters) - with session_lib.Session(graph=graph) as session: + warmup_outputs, outputs = build_graph(device, dtype, data_format, + input_shape, filter_shape, strides, + padding, num_iters, warmup_iters) + + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.opt_level = -1 + rewrite_options = config.graph_options.rewrite_options + + # Disable layout optimizer to not change input data_format. + rewrite_options.layout_optimizer = ( + rewriter_config_pb2.RewriterConfig.ON if FLAGS.enable_layout_optimizer + else rewriter_config_pb2.RewriterConfig.OFF) + # Convolution ops are effectively noop in the test graph as we are not + # fetching the convolution outputs. Disable dependency optimizer to not + # remove the conv ops. + rewrite_options.dependency_optimization = ( + rewriter_config_pb2.RewriterConfig.OFF) + + with session_lib.Session(graph=graph, config=config) as session: + # TODO(hinsu): Use run_op_benchmark method from test.Benchmark to run + # benchmark along with warmup. variables.global_variables_initializer().run() # warmup runs session.run(warmup_outputs) @@ -114,20 +146,21 @@ class Conv2DBenchmark(test.Benchmark): start_time = time.time() session.run(outputs) duration = (time.time() - start_time) / num_iters - print("%s %s inputshape:%s filtershape:%s strides:%s padding:%s " + print("%s %s %s inputshape:%s filtershape:%s strides:%s padding:%s " "%d iters: %.8f sec" % - (device, str(dtype), str(input_shape).replace(" ", ""), - str(filter_shape).replace(" ", ""), + (device, str(dtype), data_format, str(input_shape).replace( + " ", ""), str(filter_shape).replace(" ", ""), str(strides).replace(" ", ""), padding, num_iters, duration)) name_template = ( - "conv2d_{device}_{datatype}_input_shape_{inputshape}_" + "conv2d_{device}_{datatype}_{data_format}_input_shape_{inputshape}_" "filter_shape_{filtershape}_strides_{strides}_padding_{padding}") self.report_benchmark( name=name_template.format( device=device, datatype=str(dtype), + data_format=str(data_format), inputshape=str(input_shape).replace(" ", ""), filtershape=str(filter_shape).replace(" ", ""), strides=str(strides).replace(" ", ""), @@ -140,24 +173,37 @@ class Conv2DBenchmark(test.Benchmark): def benchmark_conv2d(self): print("conv2d benchmark:") - h = 500 - w = 500 - fh = 3 - fw = 3 - input_shapes = [] - filter_shapes = [] data_types = [dtypes.float32, dtypes.float16] - for b, c in itertools.product([4, 16, 32], [i for i in range(3, 16)]): - input_shapes += [[b, h, w, c]] - filter_shapes += [[fh, fw, c, b]] - strides = [[1, 2, 2, 1]] + data_formats = ["NHWC", "NCHW"] + in_channels = list(range(3, 16)) + out_channels = [4, 16, 32] + hw_strides = [[2, 2]] paddings = ["VALID", "SAME"] - for ishape, fshape in zip(input_shapes, filter_shapes): - for dtype in data_types: - for stride in strides: - for padding in paddings: - self._run_graph("gpu", ishape, fshape, stride, padding, dtype, 80, - 2) + + args_lists = [ + data_types, data_formats, in_channels, out_channels, hw_strides, + paddings + ] + for args in itertools.product(*args_lists): + dtype, data_format, in_channel, out_channel, hw_stride, padding = args + + # Keep batch size same as out channels just to reduce the number of + # different configurations to benchmark. + batch_size = out_channel + h, w, fh, fw = 500, 500, 3, 3 + if data_format == "NHWC": + ishape = [batch_size, h, w, in_channel] + stride = [1] + hw_stride + [1] + elif data_format == "NCHW": + ishape = [batch_size, in_channel, h, w] + stride = [1, 1] + hw_stride + else: + raise ValueError("Unknown data_format: " + str(data_format)) + fshape = [fh, fw, in_channel, out_channel] + num_iters = 80 + warmup_iters = 2 + self._run_graph("gpu", dtype, data_format, ishape, fshape, stride, + padding, num_iters, warmup_iters) if __name__ == "__main__": diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index d934f27cb96f4a65e2adf860e0c5e08b7bd0b7d4..ca24f11054039472baaefd301e45f57c9444f60d 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -89,7 +89,7 @@ def custom_gradient(f): operations in `f` to `x`. - `grad_fn` is a function with the signature `g(*grad_ys)` which returns a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect - to the `Tensor`s in `x. `grad_ys` is a `Tensor` or sequence of + to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of `Tensor`s the same size as `y` holding the initial value gradients for each `Tensor` in `y`. If `f` uses `Variable`s (that are not part of the inputs), i.e. through `get_variable`, then `grad_fn` should have diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 62c5adc385a2e87d27298c72f8dd2f67303119df..abf597ca55c647cca3f6012ed602a815298e1ed3 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_data_flow_ops import * @@ -129,11 +130,6 @@ class QueueBase(object): @{tf.RandomShuffleQueue} for concrete implementations of this class, and instructions on how to create them. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, dtypes, shapes, names, queue_ref): @@ -157,12 +153,7 @@ class QueueBase(object): Raises: ValueError: If one of the arguments is invalid. - RuntimeError: If eager execution is enabled. """ - if context.executing_eagerly(): - raise RuntimeError( - "Queues are not supported when eager execution is enabled. " - "Instead, please use tf.data to get data into your model.") self._dtypes = dtypes if shapes is not None: if len(shapes) != len(dtypes): @@ -179,6 +170,8 @@ class QueueBase(object): self._queue_ref = queue_ref if context.executing_eagerly(): self._name = context.context().scope_name + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + queue_ref, None) else: self._name = self._queue_ref.op.name.split("/")[-1] @@ -605,6 +598,11 @@ class QueueBase(object): else: return gen_data_flow_ops.queue_size(self._queue_ref, name=name) +def _shared_name(shared_name): + if context.executing_eagerly(): + return str(ops.uid()) + return shared_name + @tf_export("RandomShuffleQueue") class RandomShuffleQueue(QueueBase): @@ -612,11 +610,6 @@ class RandomShuffleQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -690,7 +683,7 @@ class RandomShuffleQueue(QueueBase): min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -702,11 +695,6 @@ class FIFOQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -752,7 +740,7 @@ class FIFOQueue(QueueBase): component_types=dtypes, shapes=shapes, capacity=capacity, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -767,11 +755,6 @@ class PaddingFIFOQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -831,7 +814,7 @@ class PaddingFIFOQueue(QueueBase): component_types=dtypes, shapes=shapes, capacity=capacity, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -843,11 +826,6 @@ class PriorityQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -899,7 +877,7 @@ class PriorityQueue(QueueBase): component_types=types, shapes=shapes, capacity=capacity, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) priority_dtypes = [_dtypes.int64] + types diff --git a/tensorflow/python/ops/distributions/bernoulli.py b/tensorflow/python/ops/distributions/bernoulli.py index d7fb3f1f783cceef280e07b6110098a80011b19f..84d9d40a35095643dc29946f8827cfd486a9fd9a 100644 --- a/tensorflow/python/ops/distributions/bernoulli.py +++ b/tensorflow/python/ops/distributions/bernoulli.py @@ -71,7 +71,7 @@ class Bernoulli(distribution.Distribution): Raises: ValueError: If p and logits are passed, or if neither are passed. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits=logits, diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py index b6978486004affbf97ee0e5da15fee7ab092eb32..99d30b0bd112b62c625a94b43da589f9717d0774 100644 --- a/tensorflow/python/ops/distributions/beta.py +++ b/tensorflow/python/ops/distributions/beta.py @@ -84,13 +84,24 @@ class Beta(distribution.Distribution): Distribution parameters are automatically broadcast in all functions; see examples for details. + Warning: The samples can be zero due to finite precision. + This happens more often when some of the concentrations are very small. + Make sure to round the samples to `np.finfo(dtype).tiny` before computing the + density. + + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) + #### Examples ```python # Create a batch of three Beta distributions. alpha = [1, 2, 3] beta = [1, 2, 3] - dist = Beta(alpha, beta) + dist = tf.distributions.Beta(alpha, beta) dist.sample([4, 5]) # Shape [4, 5, 3] @@ -106,7 +117,7 @@ class Beta(distribution.Distribution): # Create batch_shape=[2, 3] via parameter broadcast: alpha = [[1.], [2]] # Shape [2, 1] beta = [3., 4, 5] # Shape [3] - dist = Beta(alpha, beta) + dist = tf.distributions.Beta(alpha, beta) # alpha broadcast as: [[1., 1, 1,], # [2, 2, 2]] @@ -122,6 +133,18 @@ class Beta(distribution.Distribution): dist.prob(x) # Shape [2, 3] ``` + Compute the gradients of samples w.r.t. the parameters: + + ```python + alpha = tf.constant(1.0) + beta = tf.constant(2.0) + dist = tf.distributions.Beta(alpha, beta) + samples = dist.sample(5) # Shape [5] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, [alpha, beta]) + ``` + """ def __init__(self, @@ -150,7 +173,7 @@ class Beta(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration1, concentration0]) as name: self._concentration1 = self._maybe_assert_valid_concentration( ops.convert_to_tensor(concentration1, name="concentration1"), @@ -165,7 +188,7 @@ class Beta(distribution.Distribution): dtype=self._total_concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, graph_parents=[self._concentration1, self._concentration0, @@ -321,7 +344,7 @@ class BetaWithSoftplusConcentration(Beta): validate_args=False, allow_nan_stats=True, name="BetaWithSoftplusConcentration"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration1, concentration0]) as name: super(BetaWithSoftplusConcentration, self).__init__( diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py index caceadf53a0a0816379b0d75808ec756b558e861..b65e64d401b800884bdee7ad883bd7cc41dcfd20 100644 --- a/tensorflow/python/ops/distributions/bijector_impl.py +++ b/tensorflow/python/ops/distributions/bijector_impl.py @@ -160,13 +160,20 @@ class Bijector(object): 3. `log_det_jacobian(x)` - "The log of the determinant of the matrix of all first-order partial - derivatives of the inverse function." + "The log of the absolute value of the determinant of the matrix of all + first-order partial derivatives of the inverse function." Useful for inverting a transformation to compute one probability in terms of another. Geometrically, the Jacobian determinant is the volume of the transformation and is used to scale the probability. + We take the absolute value of the determinant before log to avoid NaN + values. Geometrically, a negative determinant corresponds to an + orientation-reversing transformation. It is ok for us to discard the sign + of the determinant because we only integrate everywhere-nonnegative + functions (probability densities) and the correct orientation is always the + one that produces a nonnegative integrand. + By convention, transformations of random variables are named in terms of the forward transformation. The forward transformation creates samples, the inverse is useful for computing probabilities. @@ -1021,7 +1028,7 @@ class Bijector(object): axis=self._get_event_reduce_dims(min_event_ndims, event_ndims)) # The multiplication by ones can change the inferred static shape so we try # to recover as much as possible. - event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) if (event_ndims_ is not None and y.shape.ndims is not None and ildj.shape.ndims is not None): @@ -1036,7 +1043,7 @@ class Bijector(object): def _get_event_reduce_dims(self, min_event_ndims, event_ndims): """Compute the reduction dimensions given event_ndims.""" - event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) if event_ndims_ is not None: return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)] @@ -1046,9 +1053,18 @@ class Bijector(object): def _check_valid_event_ndims(self, min_event_ndims, event_ndims): """Check whether event_ndims is atleast min_event_ndims.""" - event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_ = tensor_util.constant_value(event_ndims) assertions = [] + + if not event_ndims.dtype.is_integer: + raise ValueError("Expected integer dtype, got dtype {}".format( + event_ndims.dtype)) + if event_ndims_ is not None: + if event_ndims.shape.ndims != 0: + raise ValueError("Expected scalar event_ndims, got shape {}".format( + event_ndims.shape)) if min_event_ndims > event_ndims_: raise ValueError("event_ndims ({}) must be larger than " "min_event_ndims ({})".format( @@ -1056,17 +1072,29 @@ class Bijector(object): elif self.validate_args: assertions += [ check_ops.assert_greater_equal(event_ndims, min_event_ndims)] + + if event_ndims.shape.is_fully_defined(): + if event_ndims.shape.ndims != 0: + raise ValueError("Expected scalar shape, got ndims {}".format( + event_ndims.shape.ndims)) + + elif self.validate_args: + assertions += [ + check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")] return assertions - def _maybe_get_event_ndims_statically(self, event_ndims): + def _maybe_get_static_event_ndims(self, event_ndims): """Helper which returns tries to return an integer static value.""" event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) - if isinstance(event_ndims_, np.ndarray): - if (event_ndims_.dtype not in (np.int32, np.int64) or - len(event_ndims_.shape)): + if isinstance(event_ndims_, (np.generic, np.ndarray)): + if event_ndims_.dtype not in (np.int32, np.int64): + raise ValueError("Expected integer dtype, got dtype {}".format( + event_ndims_.dtype)) + + if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape): raise ValueError("Expected a scalar integer, got {}".format( event_ndims_)) - event_ndims_ = event_ndims_.tolist() + event_ndims_ = int(event_ndims_) return event_ndims_ diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py index bbdc8c455af66c7c6ad1866302e84f50cf221f9b..dd25fce2ec860456fdbbad903032cf4bcda9daba 100644 --- a/tensorflow/python/ops/distributions/categorical.py +++ b/tensorflow/python/ops/distributions/categorical.py @@ -32,12 +32,8 @@ from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.util.tf_export import tf_export -def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32): +def _broadcast_cat_event_and_params(event, params, base_dtype): """Broadcasts the event or distribution parameters.""" - if event.shape.ndims is None: - raise NotImplementedError( - "Cannot broadcast with an event tensor of unknown rank.") - if event.dtype.is_integer: pass elif event.dtype.is_floating: @@ -47,15 +43,18 @@ def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32): else: raise TypeError("`value` should have integer `dtype` or " "`self.dtype` ({})".format(base_dtype)) - - if params.get_shape()[:-1] == event.get_shape(): - params = params - else: - params *= array_ops.ones_like( - array_ops.expand_dims(event, -1), dtype=params.dtype) + shape_known_statically = ( + params.shape.ndims is not None and + params.shape[:-1].is_fully_defined() and + event.shape.is_fully_defined()) + if not shape_known_statically or params.shape[:-1] != event.shape: + params *= array_ops.ones_like(event[..., array_ops.newaxis], + dtype=params.dtype) params_shape = array_ops.shape(params)[:-1] event *= array_ops.ones(params_shape, dtype=event.dtype) - event.set_shape(tensor_shape.TensorShape(params.get_shape()[:-1])) + if params.shape.ndims is not None: + event.set_shape(tensor_shape.TensorShape(params.shape[:-1])) + return event, params @@ -182,7 +181,7 @@ class Categorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits=logits, diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py index 8d0d1d860bf4a7efadd8b6bf101708974b87e5d1..9104a1d071af3d7b7d40838148f2e49301fa39ba 100644 --- a/tensorflow/python/ops/distributions/dirichlet.py +++ b/tensorflow/python/ops/distributions/dirichlet.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.util.tf_export import tf_export @@ -89,13 +90,24 @@ class Dirichlet(distribution.Distribution): Distribution parameters are automatically broadcast in all functions; see examples for details. + Warning: Some components of the samples can be zero due to finite precision. + This happens more often when some of the concentrations are very small. + Make sure to round the samples to `np.finfo(dtype).tiny` before computing the + density. + + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) + #### Examples ```python # Create a single trivariate Dirichlet, with the 3rd class being three times # more frequent than the first. I.e., batch_shape=[], event_shape=[3]. alpha = [1., 2, 3] - dist = Dirichlet(alpha) + dist = tf.distributions.Dirichlet(alpha) dist.sample([4, 5]) # shape: [4, 5, 3] @@ -117,7 +129,7 @@ class Dirichlet(distribution.Distribution): # Create batch_shape=[2], event_shape=[3]: alpha = [[1., 2, 3], [4, 5, 6]] # shape: [2, 3] - dist = Dirichlet(alpha) + dist = tf.distributions.Dirichlet(alpha) dist.sample([4, 5]) # shape: [4, 5, 2, 3] @@ -128,6 +140,17 @@ class Dirichlet(distribution.Distribution): dist.prob(x) # shape: [2] ``` + Compute the gradients of samples w.r.t. the parameters: + + ```python + alpha = tf.constant([1.0, 2.0, 3.0]) + dist = tf.distributions.Dirichlet(alpha) + samples = dist.sample(5) # Shape [5, 3] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, alpha) + ``` + """ def __init__(self, @@ -154,7 +177,7 @@ class Dirichlet(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration]) as name: self._concentration = self._maybe_assert_valid_concentration( ops.convert_to_tensor(concentration, name="concentration"), @@ -164,7 +187,7 @@ class Dirichlet(distribution.Distribution): dtype=self._concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, graph_parents=[self._concentration, self._total_concentration], @@ -289,11 +312,86 @@ class Dirichlet(distribution.Distribution): if not self.validate_args: return x return control_flow_ops.with_dependencies([ - check_ops.assert_positive( - x, - message="samples must be positive"), - distribution_util.assert_close( + check_ops.assert_positive(x, message="samples must be positive"), + check_ops.assert_near( array_ops.ones([], dtype=self.dtype), math_ops.reduce_sum(x, -1), message="sample last-dimension must sum to `1`"), ], x) + + +@kullback_leibler.RegisterKL(Dirichlet, Dirichlet) +def _kl_dirichlet_dirichlet(d1, d2, name=None): + """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet. + + Args: + d1: instance of a Dirichlet distribution object. + d2: instance of a Dirichlet distribution object. + name: (optional) Name to use for created operations. + default is "kl_dirichlet_dirichlet". + + Returns: + Batchwise KL(d1 || d2) + """ + with ops.name_scope(name, "kl_dirichlet_dirichlet", values=[ + d1.concentration, d2.concentration]): + # The KL between Dirichlet distributions can be derived as follows. We have + # + # Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)] + # + # where B(a) is the multivariate Beta function: + # + # B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n]) + # + # The KL is + # + # KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))} + # + # so we'll need to know the log density of the Dirichlet. This is + # + # log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a) + # + # The only term that matters for the expectations is the log(x[i]). To + # compute the expectation of this term over the Dirichlet density, we can + # use the following facts about the Dirichlet in exponential family form: + # 1. log(x[i]) is a sufficient statistic + # 2. expected sufficient statistics (of any exp family distribution) are + # equal to derivatives of the log normalizer with respect to + # corresponding natural parameters: E{T[i](x)} = dA/d(eta[i]) + # + # To proceed, we can rewrite the Dirichlet density in exponential family + # form as follows: + # + # Dir(x; a) = exp{eta(a) . T(x) - A(a)} + # + # where '.' is the dot product of vectors eta and T, and A is a scalar: + # + # eta[i](a) = a[i] - 1 + # T[i](x) = log(x[i]) + # A(a) = log B(a) + # + # Now, we can use fact (2) above to write + # + # E_Dir(x; a)[log(x[i])] + # = dA(a) / da[i] + # = d/da[i] log B(a) + # = d/da[i] (sum_j lgamma(a[j])) - lgamma(sum_j a[j]) + # = digamma(a[i])) - digamma(sum_j a[j]) + # + # Putting it all together, we have + # + # KL[Dir(x; a) || Dir(x; b)] + # = E_Dir(x; a){log(Dir(x; a) / Dir(x; b)} + # = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])} - (lbeta(a) - lbeta(b)) + # = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b) + # = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))] + # - lbeta(a) + lbeta(b)) + + digamma_sum_d1 = math_ops.digamma( + math_ops.reduce_sum(d1.concentration, axis=-1, keepdims=True)) + digamma_diff = math_ops.digamma(d1.concentration) - digamma_sum_d1 + concentration_diff = d1.concentration - d2.concentration + + return (math_ops.reduce_sum(concentration_diff * digamma_diff, axis=-1) - + special_math_ops.lbeta(d1.concentration) + + special_math_ops.lbeta(d2.concentration)) diff --git a/tensorflow/python/ops/distributions/dirichlet_multinomial.py b/tensorflow/python/ops/distributions/dirichlet_multinomial.py index 3a35e0caa0f411dbb413aa5e8fe68143e0914db9..5350c8284704a15f71bce7be0c44e298067d3692 100644 --- a/tensorflow/python/ops/distributions/dirichlet_multinomial.py +++ b/tensorflow/python/ops/distributions/dirichlet_multinomial.py @@ -191,7 +191,7 @@ class DirichletMultinomial(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[total_count, concentration]) as name: # Broadcasting works because: # * The broadcasting convention is to prepend dimensions of size [1], and diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index a6579e3246d4c2c6bf266754868cba43515e3256..c03ef967e68474b0313de01d48252c8274e37a21 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -212,7 +212,7 @@ class ReparameterizationType(object): reparameterized, and straight-through gradients are either partially unsupported or are not supported at all. In this case, for purposes of e.g. RL or variational inference, it is generally safest to wrap the - sample results in a `stop_gradients` call and instead use policy + sample results in a `stop_gradients` call and use policy gradients / surrogate loss instead. """ @@ -525,7 +525,7 @@ class Distribution(_BaseDistribution): """Dictionary of parameters used to instantiate this `Distribution`.""" # Remove "self", "__class__", or other special variables. These can appear # if the subclass used: - # `parameters = distribution_util.parent_frame_arguments()`. + # `parameters = dict(locals())`. return dict((k, v) for k, v in self._parameters.items() if not k.startswith("__") and k != "self") @@ -722,11 +722,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._log_prob(value, **kwargs) - except NotImplementedError as original_exception: - try: - return math_ops.log(self._prob(value, **kwargs)) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.log(self._prob(value, **kwargs)) def log_prob(self, value, name="log_prob"): """Log probability density/mass function. @@ -749,11 +746,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._prob(value, **kwargs) - except NotImplementedError as original_exception: - try: - return math_ops.exp(self._log_prob(value, **kwargs)) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.exp(self._log_prob(value, **kwargs)) def prob(self, value, name="prob"): """Probability density/mass function. @@ -776,11 +770,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._log_cdf(value, **kwargs) - except NotImplementedError as original_exception: - try: - return math_ops.log(self._cdf(value, **kwargs)) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.log(self._cdf(value, **kwargs)) def log_cdf(self, value, name="log_cdf"): """Log cumulative distribution function. @@ -813,11 +804,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._cdf(value, **kwargs) - except NotImplementedError as original_exception: - try: - return math_ops.exp(self._log_cdf(value, **kwargs)) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.exp(self._log_cdf(value, **kwargs)) def cdf(self, value, name="cdf"): """Cumulative distribution function. @@ -846,11 +834,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._log_survival_function(value, **kwargs) - except NotImplementedError as original_exception: - try: - return math_ops.log1p(-self.cdf(value, **kwargs)) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.log1p(-self.cdf(value, **kwargs)) def log_survival_function(self, value, name="log_survival_function"): """Log survival function. @@ -884,11 +869,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._survival_function(value, **kwargs) - except NotImplementedError as original_exception: - try: - return 1. - self.cdf(value, **kwargs) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return 1. - self.cdf(value, **kwargs) def survival_function(self, value, name="survival_function"): """Survival function. @@ -933,10 +915,7 @@ class Distribution(_BaseDistribution): def _call_quantile(self, value, name, **kwargs): with self._name_scope(name, values=[value]): value = ops.convert_to_tensor(value, name="value") - try: - return self._quantile(value, **kwargs) - except NotImplementedError as original_exception: - raise original_exception + return self._quantile(value, **kwargs) def quantile(self, value, name="quantile"): """Quantile function. Aka "inverse cdf" or "percent point function". @@ -982,11 +961,8 @@ class Distribution(_BaseDistribution): with self._name_scope(name): try: return self._variance() - except NotImplementedError as original_exception: - try: - return math_ops.square(self._stddev()) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.square(self._stddev()) def _stddev(self): raise NotImplementedError("stddev is not implemented") @@ -1014,11 +990,8 @@ class Distribution(_BaseDistribution): with self._name_scope(name): try: return self._stddev() - except NotImplementedError as original_exception: - try: - return math_ops.sqrt(self._variance()) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.sqrt(self._variance()) def _covariance(self): raise NotImplementedError("covariance is not implemented") diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py index 1e08f48d529b164ddbb77d4d36ba0bd3390475b7..4325a14449dd9a13dabb65a240ede452544c761a 100644 --- a/tensorflow/python/ops/distributions/exponential.py +++ b/tensorflow/python/ops/distributions/exponential.py @@ -27,7 +27,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import gamma -from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.util.tf_export import tf_export @@ -91,7 +90,7 @@ class Exponential(gamma.Gamma): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) # Even though all statistics of are defined for valid inputs, this is not # true in the parent class "Gamma." Therefore, passing # allow_nan_stats=True @@ -104,9 +103,6 @@ class Exponential(gamma.Gamma): allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name) - # While the Gamma distribution is not reparameterizable, the exponential - # distribution is. - self._reparameterization_type = True self._parameters = parameters self._graph_parents += [self._rate] @@ -144,7 +140,7 @@ class ExponentialWithSoftplusRate(Exponential): validate_args=False, allow_nan_stats=True, name="ExponentialWithSoftplusRate"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[rate]) as name: super(ExponentialWithSoftplusRate, self).__init__( rate=nn.softplus(rate, name="softplus_rate"), diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py index 7ca690d9d2f8348a103adb57abe57f2ce058f17c..b631f0247c59e518fbd4925065d33345d4ea8e47 100644 --- a/tensorflow/python/ops/distributions/gamma.py +++ b/tensorflow/python/ops/distributions/gamma.py @@ -55,7 +55,7 @@ class Gamma(distribution.Distribution): ```none pdf(x; alpha, beta, x > 0) = x**(alpha - 1) exp(-x beta) / Z - Z = Gamma(alpha) beta**alpha + Z = Gamma(alpha) beta**(-alpha) ``` where: @@ -85,14 +85,35 @@ class Gamma(distribution.Distribution): Distribution parameters are automatically broadcast in all functions; see examples for details. - WARNING: This distribution may draw 0-valued samples for small `concentration` - values. See note in `tf.random_gamma` docstring. + Warning: The samples of this distribution are always non-negative. However, + the samples that are smaller than `np.finfo(dtype).tiny` are rounded + to this value, so it appears more often than it should. + This should only be noticeable when the `concentration` is very small, or the + `rate` is very large. See note in `tf.random_gamma` docstring. + + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) #### Examples ```python - dist = Gamma(concentration=3.0, rate=2.0) - dist2 = Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + dist = tf.distributions.Gamma(concentration=3.0, rate=2.0) + dist2 = tf.distributions.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + ``` + + Compute the gradients of samples w.r.t. the parameters: + + ```python + concentration = tf.constant(3.0) + rate = tf.constant(2.0) + dist = tf.distributions.Gamma(concentration, rate) + samples = dist.sample(5) # Shape [5] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, [concentration, rate]) ``` """ @@ -126,7 +147,7 @@ class Gamma(distribution.Distribution): Raises: TypeError: if `concentration` and `rate` are different dtypes. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration, rate]) as name: with ops.control_dependencies([ check_ops.assert_positive(concentration), @@ -141,7 +162,7 @@ class Gamma(distribution.Distribution): dtype=self._concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, graph_parents=[self._concentration, self._rate], @@ -261,7 +282,7 @@ class GammaWithSoftplusConcentrationRate(Gamma): validate_args=False, allow_nan_stats=True, name="GammaWithSoftplusConcentrationRate"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration, rate]) as name: super(GammaWithSoftplusConcentrationRate, self).__init__( concentration=nn.softplus(concentration, diff --git a/tensorflow/python/ops/distributions/laplace.py b/tensorflow/python/ops/distributions/laplace.py index ee3a6a40ff78fb95545c3af4eff85e84a1abd515..be17cf2527eacb7930333cbb03f1dba860b19dea 100644 --- a/tensorflow/python/ops/distributions/laplace.py +++ b/tensorflow/python/ops/distributions/laplace.py @@ -33,7 +33,6 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import special_math -from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.util.tf_export import tf_export @@ -101,7 +100,7 @@ class Laplace(distribution.Distribution): Raises: TypeError: if `loc` and `scale` are of different dtype. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): @@ -218,7 +217,7 @@ class LaplaceWithSoftplusScale(Laplace): validate_args=False, allow_nan_stats=True, name="LaplaceWithSoftplusScale"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: super(LaplaceWithSoftplusScale, self).__init__( loc=loc, diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py index 036ba45cccf49913c339d305e110960c988f8ccf..d0943e8eee69a5ef23d1829651801513e4bc1d69 100644 --- a/tensorflow/python/ops/distributions/multinomial.py +++ b/tensorflow/python/ops/distributions/multinomial.py @@ -182,7 +182,7 @@ class Multinomial(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._total_count = ops.convert_to_tensor(total_count, name="total_count") if validate_args: diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py index 0620aae10d0d3bb1f902171ed5c033630520dd7a..d0a987ba7ccf0b7ef2f38c2fc868540f6b5ff005 100644 --- a/tensorflow/python/ops/distributions/normal.py +++ b/tensorflow/python/ops/distributions/normal.py @@ -32,7 +32,6 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import special_math -from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.util.tf_export import tf_export @@ -132,7 +131,7 @@ class Normal(distribution.Distribution): Raises: TypeError: if `loc` and `scale` have different `dtype`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): @@ -244,7 +243,7 @@ class NormalWithSoftplusScale(Normal): validate_args=False, allow_nan_stats=True, name="NormalWithSoftplusScale"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[scale]) as name: super(NormalWithSoftplusScale, self).__init__( loc=loc, diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py index 9330b930b5140b50e2b8f060c70702bd47b2c4ae..e0cf6f86f10eec76bf94cd74f64202c452425886 100644 --- a/tensorflow/python/ops/distributions/student_t.py +++ b/tensorflow/python/ops/distributions/student_t.py @@ -80,6 +80,12 @@ class StudentT(distribution.Distribution): variance. However it is not actually the std. deviation; the Student's t-distribution std. dev. is `scale sqrt(df / (df - 2))` when `df > 2`. + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) + #### Examples Examples of initialization of one or a batch of distributions. @@ -118,6 +124,19 @@ class StudentT(distribution.Distribution): dist.prob(3.0) ``` + Compute the gradients of samples w.r.t. the parameters: + + ```python + df = tf.constant(2.0) + loc = tf.constant(2.0) + scale = tf.constant(11.0) + dist = tf.distributions.StudentT(df=df, loc=loc, scale=scale) + samples = dist.sample(5) # Shape [5] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, [df, loc, scale]) + ``` + """ # pylint: enable=line-too-long @@ -157,7 +176,7 @@ class StudentT(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[df, loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(df)] if validate_args else []): @@ -168,7 +187,7 @@ class StudentT(distribution.Distribution): (self._df, self._loc, self._scale)) super(StudentT, self).__init__( dtype=self._scale.dtype, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, @@ -349,7 +368,7 @@ class StudentTWithAbsDfSoftplusScale(StudentT): validate_args=False, allow_nan_stats=True, name="StudentTWithAbsDfSoftplusScale"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[df, scale]) as name: super(StudentTWithAbsDfSoftplusScale, self).__init__( df=math_ops.floor(math_ops.abs(df)), diff --git a/tensorflow/python/ops/distributions/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py index 9392464ec11613cf318a9a86124bdba6f46ba595..e80bf9ee4272c832f1ab53e6bf2b1f0cf6faafee 100644 --- a/tensorflow/python/ops/distributions/transformed_distribution.py +++ b/tensorflow/python/ops/distributions/transformed_distribution.py @@ -252,7 +252,7 @@ class TransformedDistribution(distribution_lib.Distribution): name: Python `str` name prefixed to Ops created by this class. Default: `bijector.name + distribution.name`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) name = name or (("" if bijector is None else bijector.name) + distribution.name) with ops.name_scope(name, values=[event_shape, batch_shape]) as name: @@ -416,7 +416,7 @@ class TransformedDistribution(distribution_lib.Distribution): # For caching to work, it is imperative that the bijector is the first to # modify the input. x = self.bijector.inverse(y) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims) if self.bijector._is_injective: # pylint: disable=protected-access @@ -435,13 +435,15 @@ class TransformedDistribution(distribution_lib.Distribution): log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices) log_prob += math_ops.cast(ildj, log_prob.dtype) if self._is_maybe_event_override and isinstance(event_ndims, int): - log_prob.set_shape(array_ops.broadcast_static_shape( - x.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape)) + log_prob.set_shape( + array_ops.broadcast_static_shape( + y.get_shape().with_rank_at_least(1)[:-event_ndims], + self.batch_shape)) return log_prob def _prob(self, y): x = self.bijector.inverse(y) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims) if self.bijector._is_injective: # pylint: disable=protected-access return self._finish_prob_for_one_fiber(y, x, ildj, event_ndims) @@ -459,8 +461,10 @@ class TransformedDistribution(distribution_lib.Distribution): prob = math_ops.reduce_prod(prob, self._reduce_event_indices) prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype)) if self._is_maybe_event_override and isinstance(event_ndims, int): - prob.set_shape(array_ops.broadcast_static_shape( - y.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape)) + prob.set_shape( + array_ops.broadcast_static_shape( + y.get_shape().with_rank_at_least(1)[:-event_ndims], + self.batch_shape)) return prob def _log_cdf(self, y): @@ -618,15 +622,14 @@ class TransformedDistribution(distribution_lib.Distribution): return array_ops.transpose( x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n))) - def _maybe_get_event_ndims_statically(self): + def _maybe_get_static_event_ndims(self): if self.event_shape.ndims is not None: return self.event_shape.ndims event_ndims = array_ops.size(self.event_shape_tensor()) + event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) - static_event_ndims = tensor_util.constant_value(event_ndims) - - if static_event_ndims is not None: - return static_event_ndims + if event_ndims_ is not None: + return event_ndims_ return event_ndims diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py index dfa10331e3e9d64c5d3cae2c52373edfbcda723a..e66c4a37e7c6330656bdefb6639670896a488fc7 100644 --- a/tensorflow/python/ops/distributions/uniform.py +++ b/tensorflow/python/ops/distributions/uniform.py @@ -29,7 +29,6 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.util.tf_export import tf_export @@ -103,7 +102,7 @@ class Uniform(distribution.Distribution): Raises: InvalidArgumentError: if `low >= high` and `validate_args=False`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[low, high]) as name: with ops.control_dependencies([ check_ops.assert_less( diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 59c89d21f9142fbbaf52b1d273873194a30a0f7f..3e480a79f52b178789a2d34e98c6af31048c07b1 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -36,43 +36,6 @@ from tensorflow.python.ops import nn from tensorflow.python.util import tf_inspect -def assert_close( - x, y, data=None, summarize=None, message=None, name="assert_close"): - """Assert that x and y are within machine epsilon of each other. - - Args: - x: Floating-point `Tensor` - y: Floating-point `Tensor` - data: The tensors to print out if the condition is `False`. Defaults to - error message and first few entries of `x` and `y`. - summarize: Print this many entries of each tensor. - message: A string to prefix to the default message. - name: A name for this operation (optional). - - Returns: - Op raising `InvalidArgumentError` if |x - y| > machine epsilon. - """ - message = message or "" - x = ops.convert_to_tensor(x, name="x") - y = ops.convert_to_tensor(y, name="y") - - if data is None: - data = [ - message, - "Condition x ~= y did not hold element-wise: x = ", x, "y = ", y - ] - - if x.dtype.is_integer: - return check_ops.assert_equal( - x, y, data=data, summarize=summarize, message=message, name=name) - - with ops.name_scope(name, "assert_close", [x, y, data]): - tol = np.finfo(x.dtype.as_numpy_dtype).eps - condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol)) - return control_flow_ops.Assert( - condition, data, summarize=summarize) - - def assert_integer_form( x, data=None, summarize=None, message=None, int_dtype=None, name="assert_integer_form"): @@ -179,6 +142,7 @@ def maybe_get_static_value(x, dtype=None): if x is None: return x try: + # This returns an np.ndarray. x_ = tensor_util.constant_value(x) except TypeError: x_ = x @@ -240,8 +204,12 @@ def get_logits_and_probs(logits=None, dependencies = [check_ops.assert_non_negative(probs)] if multidimensional: probs = embed_check_categorical_event_shape(probs) - dependencies += [assert_close(math_ops.reduce_sum(probs, -1), one, - message="probs does not sum to 1.")] + dependencies += [ + check_ops.assert_near( + math_ops.reduce_sum(probs, -1), + one, + message="probs does not sum to 1.") + ] else: dependencies += [check_ops.assert_less_equal( probs, one, message="probs has components greater than 1.")] @@ -823,8 +791,8 @@ def fill_triangular(x, upper=False, name=None): Triangular matrix elements are filled in a clockwise spiral. See example, below. - If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1, - b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e., + If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is + `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e., `n = int(np.sqrt(0.25 + 2. * m) - 0.5)`. Example: @@ -913,10 +881,11 @@ def fill_triangular(x, upper=False, name=None): # = 2 (n**2 / 2 + n / 2) - n**2 # = n**2 + n - n**2 # = n + ndims = prefer_static_rank(x) if upper: - x_list = [x, array_ops.reverse(x[..., n:], axis=[-1])] + x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])] else: - x_list = [x[..., n:], array_ops.reverse(x, axis=[-1])] + x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])] new_shape = ( static_final_shape.as_list() if static_final_shape.is_fully_defined() @@ -930,6 +899,74 @@ def fill_triangular(x, upper=False, name=None): return x +def fill_triangular_inverse(x, upper=False, name=None): + """Creates a vector from a (batch of) triangular matrix. + + The vector is created from the lower-triangular or upper-triangular portion + depending on the value of the parameter `upper`. + + If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is + `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`. + + Example: + + ```python + fill_triangular_inverse( + [[4, 0, 0], + [6, 5, 0], + [3, 2, 1]]) + + # ==> [1, 2, 3, 4, 5, 6] + + fill_triangular_inverse( + [[1, 2, 3], + [0, 5, 6], + [0, 0, 4]], upper=True) + + # ==> [1, 2, 3, 4, 5, 6] + ``` + + Args: + x: `Tensor` representing lower (or upper) triangular elements. + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + name: Python `str`. The name to give this op. + + Returns: + flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower + (or upper) triangular elements from `x`. + """ + + with ops.name_scope(name, "fill_triangular_inverse", values=[x]): + x = ops.convert_to_tensor(x, name="x") + if x.shape.with_rank_at_least(2)[-1].value is not None: + n = np.int32(x.shape[-1].value) + m = np.int32((n * (n + 1)) // 2) + static_final_shape = x.shape[:-2].concatenate([m]) + else: + n = array_ops.shape(x)[-1] + m = (n * (n + 1)) // 2 + static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate( + [None]) + ndims = prefer_static_rank(x) + if upper: + initial_elements = x[..., 0, :] + triangular_portion = x[..., 1:, :] + else: + initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2]) + triangular_portion = x[..., :-1, :] + rotated_triangular_portion = array_ops.reverse( + array_ops.reverse(triangular_portion, axis=[ndims - 1]), + axis=[ndims - 2]) + consolidated_matrix = triangular_portion + rotated_triangular_portion + end_sequence = array_ops.reshape( + consolidated_matrix, + array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0)) + y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1) + y.set_shape(static_final_shape) + return y + + def tridiag(below=None, diag=None, above=None, name=None): """Creates a matrix with values set above, below, and on the diagonal. diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index bcc717b043f226a18344de31b36f09d5064f25a3..27c2fa701760f000db2463aaba0b496b3550ddff 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops # Imports gradient definitions. @@ -30,6 +31,7 @@ from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-impor from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -43,8 +45,8 @@ def _clip(params, ids, max_norm): Args: params: A `Tensor` of embeddings retrieved by `gather`. ids: The `ids` argument that was passed to `gather`. - max_norm: If provided, the embeddings are l2-normalized to the value of - max_norm. + max_norm: If not `None`, each embedding is clipped if its l2-norm is + larger than this value. Returns: A `Tensor` with the same type as `params`. @@ -290,8 +292,8 @@ def embedding_lookup( in `indices` are always validated to be within range. If assigned to GPU, out-of-bound indices result in safe but unspecified behavior, which may include raising an error. - max_norm: If provided, embedding values are l2-normalized to the value of - max_norm. + max_norm: If not `None`, each embedding is clipped if its l2-norm is + larger than this value. Returns: A `Tensor` with the same type as the tensors in `params`. @@ -346,8 +348,8 @@ def embedding_lookup_sparse(params, "mean" is the weighted sum divided by the total weight. "sqrtn" is the weighted sum divided by the square root of the sum of the squares of the weights. - max_norm: If provided, each embedding is normalized to have l2 norm equal - to max_norm before combining. + max_norm: If not `None`, each embedding is clipped if its l2-norm is + larger than this value, before combining. Returns: A dense tensor representing the combined embeddings for the @@ -479,3 +481,158 @@ def embedding_lookup_sparse(params, assert False, "Unrecognized combiner" return embeddings + + +@tf_export("nn.safe_embedding_lookup_sparse") +def safe_embedding_lookup_sparse(embedding_weights, + sparse_ids, + sparse_weights=None, + combiner='mean', + default_id=None, + name=None, + partition_strategy='div', + max_norm=None): + """Lookup embedding results, accounting for invalid IDs and empty features. + + The partitioned embedding in `embedding_weights` must all be the same shape + except for the first dimension. The first dimension is allowed to vary as the + vocabulary size is not necessarily a multiple of `P`. `embedding_weights` + may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a + partitioner. + + Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs + with non-positive weight. For an entry with no features, the embedding vector + for `default_id` is returned, or the 0-vector if `default_id` is not supplied. + + The ids and weights may be multi-dimensional. Embeddings are always aggregated + along the last dimension. + + Args: + embedding_weights: A list of `P` float `Tensor`s or values representing + partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable` + created by partitioning along dimension 0. The total unpartitioned + shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the + vocab size and `e_1, ..., e_m` are the embedding dimensions. + sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the + ids. `d_0` is typically batch size. + sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing + float weights corresponding to `sparse_ids`, or `None` if all weights + are be assumed to be 1.0. + combiner: A string specifying how to combine embedding results for each + entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" + the default. + default_id: The id to use for an entry with no features. + name: A name for this operation (optional). + partition_strategy: A string specifying the partitioning strategy. + Currently `"div"` and `"mod"` are supported. Default is `"div"`. + max_norm: If not `None`, all embeddings are l2-normalized to max_norm before + combining. + + + Returns: + Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`. + + Raises: + ValueError: if `embedding_weights` is empty. + """ + if embedding_weights is None: + raise ValueError('Missing embedding_weights %s.' % embedding_weights) + if isinstance(embedding_weights, variables.PartitionedVariable): + embedding_weights = list(embedding_weights) # get underlying Variables. + if not isinstance(embedding_weights, list): + embedding_weights = [embedding_weights] + if len(embedding_weights) < 1: + raise ValueError('Missing embedding_weights %s.' % embedding_weights) + + dtype = sparse_weights.dtype if sparse_weights is not None else None + embedding_weights = [ + ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights + ] + + with ops.name_scope(name, 'embedding_lookup', + embedding_weights + [sparse_ids, + sparse_weights]) as scope: + # Reshape higher-rank sparse ids and weights to linear segment ids. + original_shape = sparse_ids.dense_shape + original_rank_dim = sparse_ids.dense_shape.get_shape()[0] + original_rank = ( + array_ops.size(original_shape) + if original_rank_dim.value is None + else original_rank_dim.value) + sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [ + math_ops.reduce_prod( + array_ops.slice(original_shape, [0], [original_rank - 1])), + array_ops.gather(original_shape, original_rank - 1)]) + if sparse_weights is not None: + sparse_weights = sparse_tensor.SparseTensor( + sparse_ids.indices, + sparse_weights.values, sparse_ids.dense_shape) + + # Prune invalid ids and weights. + sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights) + if combiner != 'sum': + sparse_ids, sparse_weights = _prune_invalid_weights( + sparse_ids, sparse_weights) + + # Fill in dummy values for empty features, if necessary. + sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids, + default_id or + 0) + if sparse_weights is not None: + sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0) + + result = embedding_lookup_sparse( + embedding_weights, + sparse_ids, + sparse_weights, + combiner=combiner, + partition_strategy=partition_strategy, + name=None if default_id is None else scope, + max_norm=max_norm) + + if default_id is None: + # Broadcast is_row_empty to the same shape as embedding_lookup_result, + # for use in Select. + is_row_empty = array_ops.tile( + array_ops.reshape(is_row_empty, [-1, 1]), + array_ops.stack([1, array_ops.shape(result)[1]])) + + result = array_ops.where(is_row_empty, + array_ops.zeros_like(result), + result, + name=scope) + + # Reshape back from linear ids back into higher-dimensional dense result. + final_result = array_ops.reshape( + result, + array_ops.concat([ + array_ops.slice( + math_ops.cast(original_shape, dtypes.int32), [0], + [original_rank - 1]), + array_ops.slice(array_ops.shape(result), [1], [-1]) + ], 0)) + final_result.set_shape(tensor_shape.unknown_shape( + (original_rank_dim - 1).value).concatenate(result.get_shape()[1:])) + return final_result + + +def _prune_invalid_ids(sparse_ids, sparse_weights): + """Prune invalid IDs (< 0) from the input ids and weights.""" + is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) + if sparse_weights is not None: + is_id_valid = math_ops.logical_and( + is_id_valid, + array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) + sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) + if sparse_weights is not None: + sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) + return sparse_ids, sparse_weights + + +def _prune_invalid_weights(sparse_ids, sparse_weights): + """Prune invalid weights (< 0) from the input ids and weights.""" + if sparse_weights is not None: + is_weights_valid = math_ops.greater(sparse_weights.values, 0) + sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) + sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) + return sparse_ids, sparse_weights diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 394ad0b1a2284ac147a09f165fb1f50d24f4cedc..53ae6d843fecf3be93ab35a73e890da3962c5aea 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -455,7 +455,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, lambda i, _: i < n, compute, (i, accs_ta), parallel_iterations=parallel_iterations, back_prop=back_prop, - swap_memory=swap_memory) + swap_memory=swap_memory, + maximum_iterations=n) results_flat = [r.stack() for r in r_a] n_static = elems_flat[0].get_shape().with_rank_at_least(1)[0] @@ -944,6 +945,61 @@ def For(start, # pylint: enable=invalid-name,protected-access -def partitioned_call(args, f): - return gen_functional_ops.partitioned_call( - args=args, Tout=[o.type for o in f.definition.signature.output_arg], f=f) +def partitioned_call(args, f, tout=None, executing_eagerly=None): + """Executes a function while respecting device annotations. + + Currently, only those functions that execute within the same address space + can be executed. + + Args: + args: The arguments of the function, including captured inputs. + f: The function to execute; an instance of `_DefinedFunction` or + `_EagerDefinedFunction`. + tout: a list containing the output dtypes enums; if `None`, inferred from + the signature of `f`. + executing_eagerly: (Optional) A boolean indicating whether the context is + executing eagerly. If `None`, fetched from the global context. + + Returns: + The list of `Tensor`s returned by invoking `f(args)`. If the function does + not return anything, then returns `None` if eager execution is enabled, or + the `Operation` if not. + """ + + if tout is None: + tout = tuple(x.type for x in f.definition.signature.output_arg) + + if executing_eagerly is None: + executing_eagerly = context.executing_eagerly() + + if executing_eagerly or len(tout): + if f.stateful_ops: + outputs = gen_functional_ops.stateful_partitioned_call( + args=args, Tout=tout, f=f) + else: + outputs = gen_functional_ops.partitioned_call(args=args, Tout=tout, f=f) + return outputs if outputs else None + + # The generated binding returns an empty list for functions that don't + # return any Tensors, hence the need to use `create_op` directly. + args = [ops.internal_convert_to_tensor(x) for x in args] + tin_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + type=[x.dtype.as_datatype_enum for x in args])) + tout_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue(type=tout)) + func_attr = attr_value_pb2.AttrValue( + func=attr_value_pb2.NameAttrList(name=f.name)) + + graph = ops.get_default_graph() + f.add_to_graph(graph) + op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall" + op = graph.create_op( + op_name, + args, + tout, + compute_shapes=False, + name="PartitionedFunctionCall", + attrs={"Tin": tin_attr, "Tout": tout_attr, "f": func_attr}) + outputs = op.outputs + return outputs if outputs else op diff --git a/tensorflow/python/ops/gradient_checker.py b/tensorflow/python/ops/gradient_checker.py index 12afcd0b517d5e85112c067ccaca5693e5a4e231..94c8d7933523a315523cf7b2d34d8263785b6eeb 100644 --- a/tensorflow/python/ops/gradient_checker.py +++ b/tensorflow/python/ops/gradient_checker.py @@ -283,10 +283,10 @@ def compute_gradient(x, numbers. For example, if `x` is complex with shape `[m]` and `y` is complex with shape `[n]`, each Jacobian `J` will have shape `[m * 2, n * 2]` with - J[:m, :n] = d(Re y)/d(Re x) - J[:m, n:] = d(Im y)/d(Re x) - J[m:, :n] = d(Re y)/d(Im x) - J[m:, n:] = d(Im y)/d(Im x) + J[::2, ::2] = d(Re y)/d(Re x) + J[::2, 1::2] = d(Im y)/d(Re x) + J[1::2, ::2] = d(Re y)/d(Im x) + J[1::2, 1::2] = d(Im y)/d(Im x) Args: x: a tensor or list of tensors diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 069b5a43086100c76089ac2b5023d004280e9d8d..b64a66be03ba09e0660b7067420b61f91cf191a3 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections import contextlib +import sys import warnings import numpy as np @@ -30,12 +31,14 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops # pylint: disable=unused-import +from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util @@ -47,12 +50,17 @@ from tensorflow.python.ops import logging_ops # pylint: disable=unused-import from tensorflow.python.ops import manip_grad # pylint: disable=unused-import from tensorflow.python.ops import math_grad # pylint: disable=unused-import from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_grad # pylint: disable=unused-import from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import spectral_grad # pylint: disable=unused-import from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export +# This is to avoid a circular dependency with cond_v2_impl. +cond_v2_impl._gradients_impl = sys.modules[__name__] # pylint: disable=protected-access + # Warn the user if we convert a sparse representation to dense with at # least this number of elements. _LARGE_SPARSE_NUM_ELEMENTS = 100000000 @@ -107,93 +115,74 @@ ops.register_tensor_conversion_function(ops.IndexedSlices, _IndexedSlicesToTensor) -def _MarkReachedOps(from_ops, reached_ops): +def _MarkReachedOps(from_ops, reached_ops, func_graphs): """Mark all ops reached from "from_ops". Args: from_ops: list of Operations. - reached_ops: list of booleans, indexed by operation id. + reached_ops: set of Operations. + func_graphs: list of function._FuncGraphs. This method will traverse through + these functions if they capture from_ops or any reachable ops. """ queue = collections.deque() queue.extend(from_ops) while queue: op = queue.popleft() - if not reached_ops[op._id]: - reached_ops[op._id] = True + if op not in reached_ops: + reached_ops.add(op) for output in op.outputs: if _IsBackpropagatable(output): - queue.extend(output.consumers()) - - -def _GatherInputs(to_ops, reached_ops): - """List all inputs of to_ops that are in reached_ops. - - Args: - to_ops: list of Operations. - reached_ops: list of booleans, indexed by operation id. - - Returns: - The list of all inputs of to_ops that are in reached_ops. - That list includes all elements of to_ops. - """ - inputs = [] - queue = collections.deque() - queue.extend(to_ops) - while queue: - op = queue.popleft() - # We are interested in this op. - if reached_ops[op._id]: - inputs.append(op) - # Clear the boolean so we won't add the inputs again. - reached_ops[op._id] = False - for inp in op.inputs: - queue.append(inp.op) - return inputs + queue.extend(_Consumers(output, func_graphs)) -def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops): +def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs, + xs): """Initialize the pending count for ops between two lists of Operations. - 'pending_count[op._id]' indicates the number of backprop inputs + 'pending_count[op]' indicates the number of backprop inputs to this operation. Args: - graph: a Graph. to_ops: list of Operations. from_ops: list of Operations. colocate_gradients_with_ops: Python bool. See docstring of gradients(). + func_graphs: list of function._FuncGraphs. This method will traverse through + these functions if they capture from_ops or any reachable ops. This is + useful if to_ops occur in a function and from_ops are in an outer function + or graph. + xs: list of Tensors. Returns: - A tuple containing: (1) the subset of to_ops ids reachable from from_ops - by a path of zero or more backpropagatable tensors, (2) a list of integers - indexed by operation id, indicating the number of backprop inputs to this - operation, and (3) a ControlFlowState object which is not None if the ops - between from_ops and to_ops contain control flow loops. + A tuple containing: (1) the subset of to_ops reachable from from_ops by a + path of zero or more backpropagatable tensors, (2) a mapping from operation + to the number of backprop inputs to that op, and (3) a ControlFlowState + object which is not None if the ops between from_ops and to_ops contain + control flow loops. """ # Mark reachable ops from from_ops. - reached_ops = [False] * (graph._last_id + 1) - _MarkReachedOps(from_ops, reached_ops) - # reached_ops[X] iff X is reachable from from_ops by a path of zero or more + reached_ops = set() + _MarkReachedOps(from_ops, reached_ops, func_graphs) + # X in reached_ops iff X is reachable from from_ops by a path of zero or more # backpropagatable tensors. - reachable_to_ops = set(op._id for op in to_ops if reached_ops[op._id]) # pylint: disable=protected-access + reachable_to_ops = set(op for op in to_ops if op in reached_ops) # Mark between ops. - between_ops = [False] * (graph._last_id + 1) + between_ops = set() between_op_list = [] queue = collections.deque() queue.extend(to_ops) while queue: op = queue.popleft() # We are interested in this op. - if reached_ops[op._id]: - between_ops[op._id] = True + if op in reached_ops: + between_ops.add(op) between_op_list.append(op) # Clear the boolean so we won't add the inputs again. - reached_ops[op._id] = False - for inp in op.inputs: + reached_ops.remove(op) + for inp in _Inputs(op, xs): queue.append(inp.op) - # between_ops[X] iff X is on a path of zero or more backpropagatable tensors + # X in between_ops iff X is on a path of zero or more backpropagatable tensors # between from_ops and to_ops # 'loop_state' is None if there are no while loops. @@ -201,11 +190,11 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops): between_op_list, between_ops, colocate_gradients_with_ops) # Initialize pending count for between ops. - pending_count = [0] * (graph._last_id + 1) + pending_count = collections.defaultdict(int) for op in between_op_list: - for x in op.inputs: - if between_ops[x.op._id]: - pending_count[x.op._id] += 1 + for x in _Inputs(op, xs): + if x.op in between_ops: + pending_count[x.op] += 1 return reachable_to_ops, pending_count, loop_state @@ -324,22 +313,23 @@ def _VerifyGeneratedGradients(grads, op): "inputs %d" % (len(grads), op.node_def, len(op.inputs))) -def _StopOps(from_ops, stop_gradient_ops, pending_count): +def _StopOps(from_ops, stop_gradient_ops, pending_count, xs): """The set of ops that terminate the gradient computation. This computes the frontier of the forward graph *before* which backprop should stop. Operations in the returned set will not be differentiated. This set is defined as the subset of `from_ops` containing ops that have no predecessor in `from_ops`. `pending_count` is the result of - `_PendingCount(g, xs, from_ops)`. An 'op' has predecessors in `from_ops` - iff pending_count[op._id] > 0. + `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops` + iff pending_count[op] > 0. In addition, none of `stop_gradient_ops` will be differentiated. Args: from_ops: list of Operations. stop_gradient_ops: list of Operations never to backprop through. - pending_count: List of integers, indexed by operation id. + pending_count: mapping from operation to number of backprop inputs. + xs: list of Tensors. Returns: The set of operations. @@ -347,13 +337,13 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count): stop_ops = set() for op in from_ops: is_stop_op = True - for inp in op.inputs: - if pending_count[inp.op._id] > 0: + for inp in _Inputs(op, xs): + if pending_count[inp.op] > 0: is_stop_op = False break if is_stop_op: - stop_ops.add(op._id) - stop_ops.update(op._id for op in stop_gradient_ops) # pylint: disable=protected-access + stop_ops.add(op) + stop_ops.update(op for op in stop_gradient_ops) return stop_ops @@ -367,17 +357,26 @@ def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pyli yield -def _SymGrad(op, out_grads): +def _IsPartitionedCall(op): + return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall" + + +def _SymGrad(op, out_grads, xs): """Backprop through a function call node op given its outputs' gradients.""" - f_in = [x for x in op.inputs] + out_grads - f_types = [x.dtype for x in op.inputs] + f_in = [x for x in _Inputs(op, xs)] + out_grads + f_types = [x.dtype for x in _Inputs(op, xs)] f = attr_value_pb2.NameAttrList() - f.name = op.type + if _IsPartitionedCall(op): + f.name = op.get_attr("f").name + else: + f.name = op.type for k in op.node_def.attr: f.attr[k].CopyFrom(op.node_def.attr[k]) - # pylint: disable=protected-access - in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f) - # pylint: enable=protected-access + # TODO(apassos) use a better dtype here + in_grads = functional_ops.symbolic_gradient( + input=f_in, + Tout=[x if x != dtypes.resource else dtypes.float32 for x in f_types], + f=f) return in_grads @@ -418,7 +417,7 @@ def _MaybeCompile(scope, op, func, grad_fn): return grad_fn() -def _RaiseNoGradWrtInitialLoopValError(op, from_ops): +def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs): """Raises an error if we backprop through a loop var.""" # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error # message. @@ -432,7 +431,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops): if curr_op in from_ops: target_op = curr_op break - queue.extend(t.op for t in curr_op.inputs) + queue.extend(t.op for t in _Inputs(curr_op, xs)) assert target_op raise ValueError( "Cannot compute gradient inside while loop with respect to op '%s'. " @@ -442,6 +441,68 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops): % target_op.name) +def _MaybeCaptured(t): + """If t is a captured value placeholder, returns the original captured value. + + Args: + t: Tensor + + Returns: + A tensor, potentially from a different Graph/function._FuncGraph. + """ + # pylint: disable=protected-access + if isinstance(t.op.graph, function._FuncGraph) and t.op.type == "Placeholder": + for input_t, placeholder_t in t.op.graph._captured.items(): + if t == placeholder_t: + return _MaybeCaptured(input_t) + # pylint: enable=protected-access + return t + + +# TODO(skyewm): plumbing xs through everywhere is ugly, consider making +# _GradientsHelper a class with xs as a member variable. +def _Inputs(op, xs): + """Returns the inputs of op, crossing closure boundaries where necessary. + + Args: + op: Operation + xs: list of Tensors we are differentiating w.r.t. + + Returns: + A list of tensors. The tensors may be from multiple + Graph/function._FuncGraphs if op is in a function._FuncGraph and has + captured inputs. + """ + if isinstance(op.graph, function._FuncGraph): # pylint: disable=protected-access + # If we're differentiating w.r.t. `t`, do not attempt to traverse through it + # to a captured value. The algorithm needs to "see" `t` in this case, even + # if it's a function input for a captured value, whereas usually we'd like + # to traverse through these closures as if the captured value was the direct + # input to op. + return [t if (t in xs) else _MaybeCaptured(t) for t in op.inputs] + else: + return op.inputs + + +def _Consumers(t, func_graphs): + """Returns the consumers of t, crossing closure boundaries where necessary. + + Args: + t: Tensor + func_graphs: a list of function._FuncGraphs that may have captured t. + + Returns: + A list of tensors. The tensors will be from the current graph and/or + func_graphs. + """ + consumers = t.consumers() + for func in func_graphs: + for input_t, placeholder in func._captured.items(): # pylint: disable=protected-access + if input_t == t: + consumers.extend(_Consumers(placeholder, func_graphs)) + return consumers + + @tf_export("gradients") def gradients(ys, xs, @@ -527,21 +588,38 @@ def gradients(ys, RuntimeError: if called in Eager mode. """ - # Creating the gradient graph for control flow mutates Operations. _lock - # ensures a Session.run call cannot occur between creating and mutating new - # ops. - with ops.get_default_graph()._lock: # pylint: disable=protected-access + # Creating the gradient graph for control flow mutates Operations. + # _mutation_lock ensures a Session.run call cannot occur between creating and + # mutating new ops. + with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients) -def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, - gate_gradients, aggregation_method, stop_gradients): +def _GradientsHelper(ys, + xs, + grad_ys=None, + name="gradients", + colocate_gradients_with_ops=False, + gate_gradients=False, + aggregation_method=None, + stop_gradients=None, + src_graph=None): """Implementation of gradients().""" if context.executing_eagerly(): - raise RuntimeError("tf.gradients not supported when eager execution " - "is enabled. Use tf.contrib.eager.GradientTape " - "instead.") + raise RuntimeError("tf.gradients is not supported when eager execution " + "is enabled. Use tf.GradientTape instead.") + if src_graph is None: + src_graph = ops.get_default_graph() + + # If src_graph is a _FuncGraph (i.e. a function body), gather it and all + # ancestor graphs. This is necessary for correctly handling captured values. + func_graphs = [] + curr_graph = src_graph + while isinstance(curr_graph, function._FuncGraph): # pylint: disable=protected-access + func_graphs.append(curr_graph) + curr_graph = curr_graph._outer_graph # pylint: disable=protected-access + ys = _AsList(ys) xs = _AsList(xs) stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients) @@ -576,12 +654,13 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, # Initialize the pending count for ops in the connected subgraph from ys # to the xs. if len(ys) > 1: - ys = [array_ops.identity(y) if y.consumers() else y for y in ys] + ys = [array_ops.identity(y) if _Consumers(y, func_graphs) else y + for y in ys] to_ops = [t.op for t in ys] from_ops = [t.op for t in xs] stop_gradient_ops = [t.op for t in stop_gradients] reachable_to_ops, pending_count, loop_state = _PendingCount( - ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops) + to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs) # Iterate over the collected ops. # @@ -603,12 +682,10 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, for op in to_ops: # 'ready' handles the case where one output gradient relies on # another output's gradient. - # pylint: disable=protected-access - ready = (pending_count[op._id] == 0) - if ready and op._id not in to_ops_set and op._id in reachable_to_ops: - to_ops_set.add(op._id) + ready = (pending_count[op] == 0) + if ready and op not in to_ops_set and op in reachable_to_ops: + to_ops_set.add(op) queue.append(op) - # pylint: enable=protected-access if loop_state: loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set) @@ -617,7 +694,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)) queue.append(y.op) - stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count) + stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs) while queue: # generate gradient subgraph for op. op = queue.popleft() @@ -631,13 +708,19 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, grad_fn = None func_call = None + is_partitioned_call = _IsPartitionedCall(op) # pylint: disable=protected-access - is_func_call = ops.get_default_graph()._is_function(op.type) + is_func_call = ( + src_graph._is_function(op.type) or is_partitioned_call) # pylint: enable=protected-access has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads) - if has_out_grads and (op._id not in stop_ops): + if has_out_grads and (op not in stop_ops): if is_func_call: - func_call = ops.get_default_graph()._get_function(op.type) + if is_partitioned_call: + func_call = src_graph._get_function( # pylint: disable=protected-access + compat.as_bytes(op.get_attr("f").name)) + else: + func_call = src_graph._get_function(op.type) # pylint: disable=protected-access # Note that __defun is not set if the graph is # imported. If it's set, we prefer to access the original # defun. @@ -666,7 +749,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, op._control_flow_context.IsWhileContext() and op._control_flow_context == ops.get_default_graph()._get_control_flow_context()): - _RaiseNoGradWrtInitialLoopValError(op, from_ops) + _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs) # pylint: enable=protected-access if (grad_fn or is_func_call) and has_out_grads: @@ -687,7 +770,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i) with ops.name_scope(op.name + "_grad"): # pylint: disable=protected-access - with ops.get_default_graph()._original_op(op): + with src_graph._original_op(op): # pylint: enable=protected-access if grad_fn: # If grad_fn was found, do not use SymbolicGradient even for @@ -698,7 +781,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, # For function call ops, we add a 'SymbolicGradient' # node to the graph to compute gradients. in_grads = _MaybeCompile(grad_scope, op, func_call, - lambda: _SymGrad(op, out_grads)) + lambda: _SymGrad(op, out_grads, xs)) in_grads = _AsList(in_grads) _VerifyGeneratedGradients(in_grads, op) if gate_gradients and len([x for x in in_grads @@ -713,8 +796,8 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, else: # If no grad_fn is defined or none of out_grads is available, # just propagate a list of None backwards. - in_grads = [None] * len(op.inputs) - for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)): + in_grads = [None] * len(_Inputs(op, xs)) + for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)): if in_grad is not None: if (isinstance(in_grad, ops.Tensor) and t_in.dtype != dtypes.resource): @@ -732,7 +815,8 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, loop_state.ExitGradWhileContext(op, before=False) # Update pending count for the inputs of op and enqueue ready ops. - _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state) + _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, + xs) if loop_state: loop_state.PostProcessing() @@ -751,16 +835,14 @@ def _HasAnyNotNoneGrads(grads, op): return False -def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state): +def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, + xs): """Update pending count for the inputs of op and enqueue ready ops.""" - for x in op.inputs: - # pylint: disable=protected-access - pending_count[x.op._id] -= 1 - ready = (pending_count[x.op._id] == 0) + for x in _Inputs(op, xs): + pending_count[x.op] -= 1 + ready = (pending_count[x.op] == 0) if loop_state and not ready: - ready = ( - pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op)) - # pylint: enable=protected-access + ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op) if ready: if control_flow_util.IsLoopExit(x.op): # if x is an exit without real gradient, defer processing them. @@ -1004,21 +1086,32 @@ def _AggregatedGrads(grads, logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad), tensor_shape, used) else: - out_grad = math_ops._as_indexed_slices_list( - [g for g in out_grad if g is not None]) - out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad] - # Form IndexedSlices out of the concatenated values and - # indices. - out_grads[i] = ops.IndexedSlices( - array_ops.concat([x.values for x in out_grad], 0), - array_ops.concat([x.indices for x in out_grad], 0), - out_grad[0].dense_shape) + out_grads[i] = _AggregateIndexedSlicesGradients(out_grad) else: # not out_grad # out_grads[i] is [], thus its aggregation is simply None. out_grads[i] = None return out_grads +def _AggregateIndexedSlicesGradients(grads): + """Aggregates gradients of type `IndexedSlices` by concatenation.""" + if len(grads) < 1: + return None + elif len(grads) == 1: + return grads[0] + else: + grads = math_ops._as_indexed_slices_list( # pylint: disable=protected-access + [g for g in grads if g is not None]) + grads = [_HandleNestedIndexedSlices(x) for x in grads] # pylint: disable=protected-access + # Form IndexedSlices out of the concatenated values and indices. + concat_grad = ops.IndexedSlices( + array_ops.concat([x.values for x in grads], axis=0), + array_ops.concat([x.indices for x in grads], axis=0), + grads[0].dense_shape) + + return concat_grad + + # TODO(vrv): Make this available when we want to make it public. def _hessian_vector_product(ys, xs, v): """Multiply the Hessian of `ys` wrt `xs` by `v`. diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 096d0ce794e785716b71bd275f3c1fd19cf345f8..d02fcf4ee27c180003e5b026e486a4ec0ad11e7d 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -57,93 +57,8 @@ from tensorflow.python.ops.nn_ops import bias_add from tensorflow.python.platform import googletest -def _OpsBetween(graph, to_ops, from_ops): - """Build the list of operations between two lists of Operations. - - Args: - graph: a Graph. - to_ops: list of Operations. - from_ops: list of Operations. - - Returns: - The list of operations between "from_ops" and "to_ops", sorted by - decreasing operation id. This list contains all elements of to_ops. - - TODO(touts): Think about returning an empty list if from_ops are not - reachable from to_ops. Presently it returns to_ops in that case. - """ - # List of booleans, indexed by operation id, indicating if - # an op is reached from the output of "input_ops". - reached_ops = [False] * (graph._last_id + 1) - # We only care to reach up to "output_ops" so we mark the - # output ops as reached to avoid recursing past them. - for op in to_ops: - reached_ops[op._id] = True - gradients_impl._MarkReachedOps(from_ops, reached_ops) - between_ops = gradients_impl._GatherInputs(to_ops, reached_ops) - between_ops.sort(key=lambda x: -x._id) - return between_ops - - -@test_util.with_c_api class GradientsTest(test_util.TensorFlowTestCase): - def _OpNames(self, op_list): - return ["%s/%d" % (str(op.name), op._id) for op in op_list] - - def _assertOpListEqual(self, ops1, ops2): - self.assertEquals(self._OpNames(ops1), self._OpNames(ops2)) - - def testOpsBetweenSimple(self): - with ops.Graph().as_default() as g: - t1 = constant(1.0) - t2 = constant(2.0) - t3 = array_ops.stack([t1, t2]) - # Full graph - self._assertOpListEqual([t3.op, t2.op, t1.op], - _OpsBetween(g, [t3.op], [t1.op, t2.op])) - # Only t1, t3. - self._assertOpListEqual([t3.op, t1.op], _OpsBetween(g, [t3.op], [t1.op])) - - def testOpsBetweenUnreachable(self): - with ops.Graph().as_default() as g: - t1 = constant(1.0) - t2 = constant(2.0) - _ = array_ops.stack([t1, t2]) - t4 = constant(1.0) - t5 = constant(2.0) - t6 = array_ops.stack([t4, t5]) - # Elements of to_ops are always listed. - self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op])) - - def testOpsBetweenCut(self): - with ops.Graph().as_default() as g: - t1 = constant(1.0) - t2 = constant(2.0) - t3 = array_ops.stack([t1, t2]) - t4 = constant([1.0]) - t5 = array_ops.concat([t4, t3], 0) - t6 = constant([2.0]) - t7 = array_ops.concat([t5, t6], 0) - self._assertOpListEqual([t7.op, t5.op, t4.op], - _OpsBetween(g, [t7.op], [t4.op])) - - def testOpsBetweenCycle(self): - with ops.Graph().as_default() as g: - t1 = constant(1.0) - t2 = constant(2.0) - t3 = array_ops.stack([t1, t2]) - t4 = array_ops.concat([t3, t3, t3], 0) - t5 = constant([1.0]) - t6 = array_ops.concat([t4, t5], 0) - t7 = array_ops.concat([t6, t3], 0) - self._assertOpListEqual([t6.op, t4.op, t3.op], - _OpsBetween(g, [t6.op], [t3.op])) - self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op], - _OpsBetween(g, [t7.op], [t1.op, t5.op])) - self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op], - _OpsBetween(g, [t6.op], [t2.op, t5.op])) - def testGradients(self): with ops.Graph().as_default(): inp = constant(1.0, shape=[32, 100], name="in") @@ -522,6 +437,96 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): grad_func=grad_func, python_grad_func=self._PythonGradient) f.add_to_graph(ops.Graph()) + def testGradientWrtCaptured(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + + @function.Defun() + def Foo(): + y = math_ops.multiply(x, 2.0, name="y") + g = gradients_impl.gradients(y, x) + return g[0] + + f = Foo() + with self.test_session() as sess: + self.assertEqual(sess.run(f), 2.0) + + def testGradientOfCaptured(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + y = math_ops.multiply(x, 2.0, name="y") + + @function.Defun() + def Foo(): + g = gradients_impl.gradients(y, x) + return g[0] + + f = Foo() + with self.test_session() as sess: + self.assertEqual(sess.run(f), 2.0) + + def testCapturedResourceVariable(self): + with ops.Graph().as_default(): + var = resource_variable_ops.ResourceVariable(1.0, name="var") + + @function.Defun() + def Foo(): + y = math_ops.multiply(var, 2.0, name="y") + g = gradients_impl.gradients(y, var) + return g[0] + + f = Foo() + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertEqual(sess.run(f), 2.0) + + def testCapturedNested(self): + with ops.Graph().as_default(): + x1 = constant_op.constant(1.0, name="x1") + x2 = constant_op.constant(2.0, name="x2") + x3 = math_ops.multiply(x1, x2, name="x3") + + @function.Defun() + def Outer(): + outer1 = array_ops.identity(x1, name="outer1") + + @function.Defun() + def Inner(): + inner1 = array_ops.identity(outer1, name="inner1") + inner2 = array_ops.identity(x2, name="inner2") + inner3 = array_ops.identity(x3, name="inner3") + return gradients_impl.gradients([inner1, inner2, inner3, x1], + [x1, x2]) + + return Inner() + + x1_grad, x2_grad = Outer() + with self.test_session() as sess: + # 1.0 + None + 2.0 + 1.0 = 4.0 + self.assertEqual(sess.run(x1_grad), 4.0) + # None + 1.0 + 1.0 + None = 2.0 + self.assertEqual(sess.run(x2_grad), 2.0) + + def testCapturedFromFunction(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + + @function.Defun() + def Outer(): + y = math_ops.multiply(x, 2.0, name="y") + + @function.Defun() + def Inner(): + z = math_ops.multiply(y, 3.0, name="z") + g = gradients_impl.gradients(z, y) + return g[0] + + return Inner() + + z_grad = Outer() + with self.test_session() as sess: + self.assertEqual(sess.run(z_grad), 3.0) + class StopGradientTest(test_util.TensorFlowTestCase): @@ -948,5 +953,53 @@ class CustomGradientTest(test_util.TensorFlowTestCase): self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0]) +class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase): + + def _assert_indexed_slices_equal(self, left, right): + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + def testNoGradients(self): + self.assertIsNone(gradients_impl._AggregateIndexedSlicesGradients([])) + + def testOneGradient(self): + t = math_ops._as_indexed_slices(constant_op.constant( + [[1., 2.], [0, 0], [3., 4.]])) + result = gradients_impl._AggregateIndexedSlicesGradients([t]) + self._assert_indexed_slices_equal(t, result) + + def testMultipleGradients(self): + t0 = math_ops._as_indexed_slices(constant_op.constant( + [[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices(constant_op.constant( + [[0., 0.], [5, 6], [7., 8.]])) + total = constant_op.constant( + [[1., 2.], [5, 6], [10., 12.]]) + result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1]) + self._assert_indexed_slices_equal(total, result) + + def testMultipleGradientsWithNones(self): + t0 = math_ops._as_indexed_slices(constant_op.constant( + [[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices(constant_op.constant( + [[0., 0.], [5, 6], [7., 8.]])) + t3 = None + total = constant_op.constant( + [[1., 2.], [5, 6], [10., 12.]]) + result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1, t3]) + self._assert_indexed_slices_equal(total, result) + + def testMixedTensorAndIndexedSlices(self): + t0 = math_ops._as_indexed_slices(constant_op.constant( + [[1., 2.], [0, 0], [3., 4.]])) + t1 = constant_op.constant( + [[0., 0.], [5, 6], [7., 8.]]) + total = constant_op.constant( + [[1., 2.], [5, 6], [10., 12.]]) + result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1]) + self._assert_indexed_slices_equal(total, result) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 54e27b87dfb55a366ecb8f05d164f724bd5ee70a..a2eae452ae551eb1792e5b21477d31c55d64fd79 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops @@ -56,6 +57,7 @@ ops.NotDifferentiable('NonMaxSuppression') ops.NotDifferentiable('NonMaxSuppressionV2') +# pylint: disable=invalid-name def _assert(cond, ex_type, msg): """A polymorphic assert, works with tensors and boolean expressions. @@ -258,14 +260,14 @@ def random_flip_up_down(image, seed=None): dimension, which is `height`. Otherwise output the image as-is. Args: - image: A 3-D tensor of shape `[height, width, channels].` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. seed: A Python integer. Used to create a random seed. See @{tf.set_random_seed} for behavior. Returns: - A 3-D tensor of the same type and shape as `image`. - + A tensor of the same type and shape as `image`. Raises: ValueError: if the shape of `image` not supported. """ @@ -280,13 +282,14 @@ def random_flip_left_right(image, seed=None): second dimension, which is `width`. Otherwise output the image as-is. Args: - image: A 3-D tensor of shape `[height, width, channels].` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. seed: A Python integer. Used to create a random seed. See @{tf.set_random_seed} for behavior. Returns: - A 3-D tensor of the same type and shape as `image`. + A tensor of the same type and shape as `image`. Raises: ValueError: if the shape of `image` not supported. @@ -297,7 +300,8 @@ def random_flip_left_right(image, seed=None): def _random_flip(image, flip_index, seed, scope_name): """Randomly (50% chance) flip an image along axis `flip_index`. Args: - image: A 3-D tensor of shape `[height, width, channels].` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. flip_index: The dimension along which to flip the image. Vertical: 0, Horizontal: 1 seed: A Python integer. Used to create a random seed. See @@ -306,22 +310,37 @@ def _random_flip(image, flip_index, seed, scope_name): scope_name: Name of the scope in which the ops are added. Returns: - A 3-D tensor of the same type and shape as `image`. + A tensor of the same type and shape as `image`. Raises: ValueError: if the shape of `image` not supported. """ with ops.name_scope(None, scope_name, [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = _Assert3DImage(image) - uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) - mirror_cond = math_ops.less(uniform_random, .5) - result = control_flow_ops.cond( - mirror_cond, - lambda: array_ops.reverse(image, [flip_index]), - lambda: image, - name=scope) - return fix_image_flip_shape(image, result) + image = _AssertAtLeast3DImage(image) + shape = image.get_shape() + if shape.ndims == 3 or shape.ndims is None: + uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) + mirror_cond = math_ops.less(uniform_random, .5) + result = control_flow_ops.cond( + mirror_cond, + lambda: array_ops.reverse(image, [flip_index]), + lambda: image, + name=scope + ) + return fix_image_flip_shape(image, result) + elif shape.ndims == 4: + uniform_random = random_ops.random_uniform( + [array_ops.shape(image)[0]], 0, 1.0, seed=seed + ) + mirror_cond = math_ops.less(uniform_random, .5) + return array_ops.where( + mirror_cond, + image, + functional_ops.map_fn(lambda x: array_ops.reverse(x, [flip_index]), image, dtype=image.dtype) + ) + else: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') @tf_export('image.flip_left_right') @@ -523,7 +542,7 @@ def transpose_image(image): @tf_export('image.central_crop') def central_crop(image, central_fraction): - """Crop the central region of the image. + """Crop the central region of the image(s). Remove the outer parts of an image but retain the central region of the image along each dimension. If we specify central_fraction = 0.5, this function @@ -536,15 +555,19 @@ def central_crop(image, central_fraction): | | where "X" is the central 50% of the image. -------- + This function works on either a single image (`image` is a 3-D Tensor), or a + batch of images (`image` is a 4-D Tensor). + Args: - image: 3-D float Tensor of shape [height, width, depth] + image: Either a 3-D float Tensor of shape [height, width, depth], or a 4-D + Tensor of shape [batch_size, height, width, depth]. central_fraction: float (0, 1], fraction of size to crop Raises: ValueError: if central_crop_fraction is not within (0, 1]. Returns: - 3-D float Tensor + 3-D / 4-D float Tensor, as per the input. """ with ops.name_scope(None, 'central_crop', [image]): image = ops.convert_to_tensor(image, name='image') @@ -553,24 +576,75 @@ def central_crop(image, central_fraction): if central_fraction == 1.0: return image - image = _Assert3DImage(image) + _AssertAtLeast3DImage(image) + rank = image.get_shape().ndims + if rank != 3 and rank != 4: + raise ValueError('`image` should either be a Tensor with rank = 3 or ' + 'rank = 4. Had rank = {}.'.format(rank)) + + # Helper method to return the `idx`-th dimension of `tensor`, along with + # a boolean signifying if the dimension is dynamic. + def _get_dim(tensor, idx): + static_shape = tensor.get_shape()[idx].value + if static_shape is not None: + return static_shape, False + return array_ops.shape(tensor)[idx], True + + # Get the height, width, depth (and batch size, if the image is a 4-D + # tensor). + if rank == 3: + img_h, dynamic_h = _get_dim(image, 0) + img_w, dynamic_w = _get_dim(image, 1) + img_d = image.get_shape()[2] + else: + img_bs = image.get_shape()[0] + img_h, dynamic_h = _get_dim(image, 1) + img_w, dynamic_w = _get_dim(image, 2) + img_d = image.get_shape()[3] + + # Compute the bounding boxes for the crop. The type and value of the + # bounding boxes depend on the `image` tensor's rank and whether / not the + # dimensions are statically defined. + if dynamic_h: + img_hd = math_ops.to_double(img_h) + bbox_h_start = math_ops.to_int32((img_hd - img_hd * central_fraction) / 2) + else: + img_hd = float(img_h) + bbox_h_start = int((img_hd - img_hd * central_fraction) / 2) - img_shape = array_ops.shape(image) - depth = image.get_shape()[2] - img_h = math_ops.to_double(img_shape[0]) - img_w = math_ops.to_double(img_shape[1]) - bbox_h_start = math_ops.to_int32((img_h - img_h * central_fraction) / 2) - bbox_w_start = math_ops.to_int32((img_w - img_w * central_fraction) / 2) + if dynamic_w: + img_wd = math_ops.to_double(img_w) + bbox_w_start = math_ops.to_int32((img_wd - img_wd * central_fraction) / 2) + else: + img_wd = float(img_w) + bbox_w_start = int((img_wd - img_wd * central_fraction) / 2) - bbox_h_size = img_shape[0] - bbox_h_start * 2 - bbox_w_size = img_shape[1] - bbox_w_start * 2 + bbox_h_size = img_h - bbox_h_start * 2 + bbox_w_size = img_w - bbox_w_start * 2 + + if rank == 3: + bbox_begin = array_ops.stack([bbox_h_start, bbox_w_start, 0]) + bbox_size = array_ops.stack([bbox_h_size, bbox_w_size, -1]) + else: + bbox_begin = array_ops.stack([0, bbox_h_start, bbox_w_start, 0]) + bbox_size = array_ops.stack([-1, bbox_h_size, bbox_w_size, -1]) - bbox_begin = array_ops.stack([bbox_h_start, bbox_w_start, 0]) - bbox_size = array_ops.stack([bbox_h_size, bbox_w_size, -1]) image = array_ops.slice(image, bbox_begin, bbox_size) - # The first two dimensions are dynamic and unknown. - image.set_shape([None, None, depth]) + # Reshape the `image` tensor to the desired size. + if rank == 3: + image.set_shape([ + None if dynamic_h else bbox_h_size, + None if dynamic_w else bbox_w_size, + img_d + ]) + else: + image.set_shape([ + img_bs, + None if dynamic_h else bbox_h_size, + None if dynamic_w else bbox_w_size, + img_d + ]) return image @@ -866,12 +940,13 @@ class ResizeMethod(object): def resize_images(images, size, method=ResizeMethod.BILINEAR, - align_corners=False): + align_corners=False, + preserve_aspect_ratio=False): """Resize `images` to `size` using the specified `method`. Resized images will be distorted if their original aspect ratio is not the same as `size`. To avoid distortions see - @{tf.image.resize_image_with_crop_or_pad}. + @{tf.image.resize_image_with_pad}. `method` can be one of: @@ -898,6 +973,10 @@ def resize_images(images, align_corners: bool. If True, the centers of the 4 corner pixels of the input and output tensors are aligned, preserving the values at the corner pixels. Defaults to `False`. + preserve_aspect_ratio: Whether to preserve the aspect ratio. If this is set, + then `images` will be resized to a size that fits in `size` while + preserving the aspect ratio of the original image. Scales up the image if + `size` is bigger than the current size of the `image`. Defaults to False. Raises: ValueError: if the shape of `images` is incompatible with the @@ -936,6 +1015,28 @@ def resize_images(images, new_height_const = size_const_as_shape[0].value new_width_const = size_const_as_shape[1].value + if preserve_aspect_ratio: + # Get the current shapes of the image, even if dynamic. + _, current_height, current_width, _ = _ImageDimensions(images, rank=4) + + # do the computation to find the right scale and height/width. + scale_factor_height = (math_ops.to_float(new_height_const) / + math_ops.to_float(current_height)) + scale_factor_width = (math_ops.to_float(new_width_const) / + math_ops.to_float(current_width)) + scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width) + scaled_height_const = math_ops.to_int32(scale_factor * + math_ops.to_float(current_height)) + scaled_width_const = math_ops.to_int32(scale_factor * + math_ops.to_float(current_width)) + + # NOTE: Reset the size and other constants used later. + size = ops.convert_to_tensor([scaled_height_const, scaled_width_const], + dtypes.int32, name='size') + size_const_as_shape = tensor_util.constant_value_as_shape(size) + new_height_const = size_const_as_shape[0].value + new_width_const = size_const_as_shape[1].value + # If we can determine that the height and width will be unmodified by this # transformation, we avoid performing the resize. if all(x is not None @@ -969,6 +1070,106 @@ def resize_images(images, return images +@tf_export('image.resize_image_with_pad') +def resize_image_with_pad(image, + target_height, + target_width, + method=ResizeMethod.BILINEAR): + """Resizes and pads an image to a target width and height. + + Resizes an image to a target width and height by keeping + the aspect ratio the same without distortion. If the target + dimensions don't match the image dimensions, the image + is resized and then padded with zeroes to match requested + dimensions. + + Args: + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. + target_height: Target height. + target_width: Target width. + method: Method to use for resizing image. See `resize_images()` + + Raises: + ValueError: if `target_height` or `target_width` are zero or negative. + + Returns: + Resized and padded image. + If `images` was 4-D, a 4-D float Tensor of shape + `[batch, new_height, new_width, channels]`. + If `images` was 3-D, a 3-D float Tensor of shape + `[new_height, new_width, channels]`. + """ + with ops.name_scope(None, 'resize_image_with_pad', [image]): + image = ops.convert_to_tensor(image, name='image') + image_shape = image.get_shape() + is_batch = True + if image_shape.ndims == 3: + is_batch = False + image = array_ops.expand_dims(image, 0) + elif image_shape.ndims is None: + is_batch = False + image = array_ops.expand_dims(image, 0) + image.set_shape([None] * 4) + elif image_shape.ndims != 4: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') + + assert_ops = _CheckAtLeast3DImage(image, require_static=False) + assert_ops += _assert(target_width > 0, ValueError, + 'target_width must be > 0.') + assert_ops += _assert(target_height > 0, ValueError, + 'target_height must be > 0.') + + image = control_flow_ops.with_dependencies(assert_ops, image) + + def max_(x, y): + if _is_tensor(x) or _is_tensor(y): + return math_ops.maximum(x, y) + else: + return max(x, y) + + _, height, width, _ = _ImageDimensions(image, rank=4) + + # convert values to float, to ease divisions + f_height = math_ops.cast(height, dtype=dtypes.float64) + f_width = math_ops.cast(width, dtype=dtypes.float64) + f_target_height = math_ops.cast(target_height, dtype=dtypes.float64) + f_target_width = math_ops.cast(target_width, dtype=dtypes.float64) + + # Find the ratio by which the image must be adjusted + # to fit within the target + ratio = max_(f_width / f_target_width, f_height / f_target_height) + resized_height_float = f_height / ratio + resized_width_float = f_width / ratio + resized_height = math_ops.cast( + math_ops.floor(resized_height_float), dtype=dtypes.int32) + resized_width = math_ops.cast( + math_ops.floor(resized_width_float), dtype=dtypes.int32) + + padding_height = (f_target_height - resized_height_float) / 2 + padding_width = (f_target_width - resized_width_float) / 2 + f_padding_height = math_ops.floor(padding_height) + f_padding_width = math_ops.floor(padding_width) + p_height = max_(0, math_ops.cast(f_padding_height, dtype=dtypes.int32)) + p_width = max_(0, math_ops.cast(f_padding_width, dtype=dtypes.int32)) + + # Resize first, then pad to meet requested dimensions + resized = resize_images(image, [resized_height, resized_width], method) + + padded = pad_to_bounding_box(resized, p_height, p_width, target_height, + target_width) + + if padded.get_shape().ndims is None: + raise ValueError('padded contains no shape.') + + _ImageDimensions(padded, rank=4) + + if not is_batch: + padded = array_ops.squeeze(padded, squeeze_dims=[0]) + + return padded + + @tf_export('image.per_image_standardization') def per_image_standardization(image): """Linearly scales `image` to have zero mean and unit norm. @@ -1396,6 +1597,75 @@ def adjust_hue(image, delta, name=None): return convert_image_dtype(rgb_altered, orig_dtype) +# pylint: disable=invalid-name +@tf_export('image.random_jpeg_quality') +def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None): + """Randomly changes jpeg encoding quality for inducing jpeg noise. + + `min_jpeg_quality` must be in the interval `[0, 100]` and less than + `max_jpeg_quality`. + `max_jpeg_quality` must be in the interval `[0, 100]`. + + Args: + image: RGB image or images. Size of the last dimension must be 3. + min_jpeg_quality: Minimum jpeg encoding quality to use. + max_jpeg_quality: Maximum jpeg encoding quality to use. + seed: An operation-specific seed. It will be used in conjunction + with the graph-level seed to determine the real seeds that will be + used in this operation. Please see the documentation of + set_random_seed for its interaction with the graph-level random seed. + + Returns: + Adjusted image(s), same shape and DType as `image`. + + Raises: + ValueError: if `min_jpeg_quality` or `max_jpeg_quality` is invalid. + """ + if (min_jpeg_quality < 0 or max_jpeg_quality < 0 or + min_jpeg_quality > 100 or max_jpeg_quality > 100): + raise ValueError('jpeg encoding range must be between 0 and 100.') + + if min_jpeg_quality >= max_jpeg_quality: + raise ValueError('`min_jpeg_quality` must be less than `max_jpeg_quality`.') + + np.random.seed(seed) + jpeg_quality = np.random.randint(min_jpeg_quality, max_jpeg_quality) + return adjust_jpeg_quality(image, jpeg_quality) + + +@tf_export('image.adjust_jpeg_quality') +def adjust_jpeg_quality(image, jpeg_quality, name=None): + """Adjust jpeg encoding quality of an RGB image. + + This is a convenience method that adjusts jpeg encoding quality of an + RGB image. + + `image` is an RGB image. The image's encoding quality is adjusted + to `jpeg_quality`. + `jpeg_quality` must be in the interval `[0, 100]`. + + Args: + image: RGB image or images. Size of the last dimension must be 3. + jpeg_quality: int. jpeg encoding quality. + name: A name for this operation (optional). + + Returns: + Adjusted image(s), same shape and DType as `image`. + """ + with ops.name_scope(name, 'adjust_jpeg_quality', [image]) as name: + image = ops.convert_to_tensor(image, name='image') + # Remember original dtype to so we can convert back if needed + orig_dtype = image.dtype + # Convert to uint8 + image = convert_image_dtype(image, dtypes.uint8) + # Encode image to jpeg with given jpeg quality + image = gen_image_ops.encode_jpeg(image, quality=jpeg_quality) + # Decode jpeg image + image = gen_image_ops.decode_jpeg(image) + # Convert back to original dtype and return + return convert_image_dtype(image, orig_dtype) + + @tf_export('image.random_saturation') def random_saturation(image, lower, upper, seed=None): """Adjust the saturation of an RGB image by a random factor. @@ -1483,13 +1753,13 @@ def is_jpeg(contents, name=None): @tf_export('image.decode_image') -def decode_image(contents, channels=None, name=None): +def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None): """Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`, and `decode_png`. Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the - appropriate operation to convert the input bytes `string` into a `Tensor` of - type `uint8`. + appropriate operation to convert the input bytes `string` into a `Tensor` + of type `dtype`. Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as opposed to `decode_bmp`, `decode_jpeg` and `decode_png`, which return 3-D @@ -1501,10 +1771,11 @@ def decode_image(contents, channels=None, name=None): contents: 0-D `string`. The encoded image bytes. channels: An optional `int`. Defaults to `0`. Number of color channels for the decoded image. + dtype: The desired DType of the returned `Tensor`. name: A name for the operation (optional) Returns: - `Tensor` with type `uint8` with shape `[height, width, num_channels]` for + `Tensor` with type `dtype` and shape `[height, width, num_channels]` for BMP, JPEG, and PNG images and shape `[num_frames, height, width, 3]` for GIF images. @@ -1528,7 +1799,7 @@ def decode_image(contents, channels=None, name=None): channels_msg = 'Channels must be in (None, 0, 3) when decoding BMP images' assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) with ops.control_dependencies([assert_decode, assert_channels]): - return gen_image_ops.decode_bmp(contents) + return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype) def _gif(): # Create assert to make sure that channels is not set to 1 @@ -1541,7 +1812,7 @@ def decode_image(contents, channels=None, name=None): channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images' assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) with ops.control_dependencies([assert_channels]): - return gen_image_ops.decode_gif(contents) + return convert_image_dtype(gen_image_ops.decode_gif(contents), dtype) def check_gif(): # Create assert op to check that bytes are GIF decodable @@ -1550,7 +1821,11 @@ def decode_image(contents, channels=None, name=None): def _png(): """Decodes a PNG image.""" - return gen_image_ops.decode_png(contents, channels) + return convert_image_dtype( + gen_image_ops.decode_png(contents, channels, + dtype=dtypes.uint8 + if dtype == dtypes.uint8 + else dtypes.uint16), dtype) def check_png(): """Checks if an image is PNG.""" @@ -1566,7 +1841,8 @@ def decode_image(contents, channels=None, name=None): 'images') assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) with ops.control_dependencies([assert_channels]): - return gen_image_ops.decode_jpeg(contents, channels) + return convert_image_dtype( + gen_image_ops.decode_jpeg(contents, channels), dtype) # Decode normal JPEG images (start with \xff\xd8\xff\xe0) # as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1). @@ -1727,7 +2003,7 @@ def sample_distorted_bounding_box(image_size, width / height within this range. area_range: An optional list of `floats`. Defaults to `[0.05, 1]`. The cropped area of the image must contain a fraction of the - supplied image within in this range. + supplied image within this range. max_attempts: An optional `int`. Defaults to `100`. Number of attempts at generating a cropped region of the image of the specified constraints. After `max_attempts` failures, return the @@ -1772,7 +2048,7 @@ def non_max_suppression(boxes, scores, max_output_size, iou_threshold=0.5, - score_threshold=0.0, + score_threshold=float('-inf'), name=None): """Greedily selects a subset of bounding boxes in descending order of score. diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index c437c12c2744792eaee197bf7d2a5f2b75d280bf..cf9761803bf9654e21ec12e1f1c7193b3e88c020 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -533,6 +533,37 @@ class FlipImageBenchmark(test.Benchmark): iters=benchmark_rounds, wall_time=step_time) + def _benchmarkBatchedRandomFlipLeftRight(self, device, cpu_count): + image_shape = [16, 299, 299, 3] + warmup_rounds = 100 + benchmark_rounds = 1000 + config = config_pb2.ConfigProto() + if cpu_count is not None: + config.inter_op_parallelism_threads = 1 + config.intra_op_parallelism_threads = cpu_count + with session.Session("", graph=ops.Graph(), config=config) as sess: + with ops.device(device): + inputs = variables.Variable( + random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255, + trainable=False, + dtype=dtypes.float32) + run_op = image_ops.random_flip_left_right(inputs) + sess.run(variables.global_variables_initializer()) + for i in xrange(warmup_rounds + benchmark_rounds): + if i == warmup_rounds: + start = time.time() + sess.run(run_op) + end = time.time() + step_time = (end - start) / benchmark_rounds + tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all") + print("benchmarkBatchedRandomFlipLeftRight_16_299_299_3_%s step_time: " + "%.2f us" % + (tag, step_time * 1e6)) + self.report_benchmark( + name="benchmarkBatchedRandomFlipLeftRight_16_299_299_3_%s" % (tag), + iters=benchmark_rounds, + wall_time=step_time) + def benchmarkFlipLeftRightCpu1(self): self._benchmarkFlipLeftRight("/cpu:0", 1) @@ -551,6 +582,15 @@ class FlipImageBenchmark(test.Benchmark): def benchmarkRandomFlipLeftRightGpu(self): self._benchmarkRandomFlipLeftRight(test.gpu_device_name(), None) + def benchmarkBatchedRandomFlipLeftRightCpu1(self): + self._benchmarkBatchedRandomFlipLeftRight("/cpu:0", 1) + + def benchmarkBatchedRandomFlipLeftRightCpuAll(self): + self._benchmarkBatchedRandomFlipLeftRight("/cpu:0", None) + + def benchmarkBatchedRandomFlipLeftRightGpu(self): + self._benchmarkBatchedRandomFlipLeftRight(test.gpu_device_name(), None) + class AdjustHueBenchmark(test.Benchmark): @@ -987,7 +1027,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) - y = image_ops.random_flip_left_right(x_tf) + y = image_ops.random_flip_left_right(x_tf, seed=seed) self.assertTrue(y.op.name.startswith("random_flip_left_right")) count_flipped = 0 @@ -1008,6 +1048,50 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): self.assertGreaterEqual(count_flipped, 20) self.assertGreaterEqual(count_unflipped, 20) + def testRandomFlipLeftRightWithBatch(self): + batch_size = 16 + seed = 42 + + # create single item of test data + x_np_raw = np.array( + [[1, 2, 3], [1, 2, 3]], dtype=np.uint8 + ).reshape([1, 2, 3, 1]) + y_np_raw = np.array( + [[3, 2, 1], [3, 2, 1]], dtype=np.uint8 + ).reshape([1, 2, 3, 1]) + + # create batched test data + x_np = np.vstack([x_np_raw for _ in range(batch_size)]) + y_np = np.vstack([y_np_raw for _ in range(batch_size)]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.random_flip_left_right(x_tf, seed=seed) + self.assertTrue(y.op.name.startswith("random_flip_left_right")) + + count_flipped = 0 + count_unflipped = 0 + for _ in range(100): + y_tf = y.eval() + + # check every element of the batch + for i in range(batch_size): + if y_tf[i][0][0] == 1: + self.assertAllEqual(y_tf[i], x_np[i]) + count_unflipped += 1 + else: + self.assertAllEqual(y_tf[i], y_np[i]) + count_flipped += 1 + + # 100 trials, each containing batch_size elements + # Mean: 50 * batch_size + # Std Dev: ~5 * sqrt(batch_size) + # Six Sigma: 50 * batch_size - (5 * 6 * sqrt(batch_size)) + # = 50 * batch_size - 30 * sqrt(batch_size) = 800 - 30 * 4 = 680 + six_sigma = 50 * batch_size - 30 * np.sqrt(batch_size) + self.assertGreaterEqual(count_flipped, six_sigma) + self.assertGreaterEqual(count_unflipped, six_sigma) + def testInvolutionUpDown(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) @@ -1057,9 +1141,11 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1]) + seed = 42 + with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) - y = image_ops.random_flip_up_down(x_tf, seed=42) + y = image_ops.random_flip_up_down(x_tf, seed=seed) self.assertTrue(y.op.name.startswith("random_flip_up_down")) count_flipped = 0 count_unflipped = 0 @@ -1079,6 +1165,50 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): self.assertGreaterEqual(count_flipped, 20) self.assertGreaterEqual(count_unflipped, 20) + def testRandomFlipUpDownWithBatch(self): + batch_size = 16 + seed = 42 + + # create single item of test data + x_np_raw = np.array( + [[1, 2, 3], [4, 5, 6]], dtype=np.uint8 + ).reshape([1, 2, 3, 1]) + y_np_raw = np.array( + [[4, 5, 6], [1, 2, 3]], dtype=np.uint8 + ).reshape([1, 2, 3, 1]) + + # create batched test data + x_np = np.vstack([x_np_raw for _ in range(batch_size)]) + y_np = np.vstack([y_np_raw for _ in range(batch_size)]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.random_flip_up_down(x_tf, seed=seed) + self.assertTrue(y.op.name.startswith("random_flip_up_down")) + + count_flipped = 0 + count_unflipped = 0 + for _ in range(100): + y_tf = y.eval() + + # check every element of the batch + for i in range(batch_size): + if y_tf[i][0][0] == 1: + self.assertAllEqual(y_tf[i], x_np[i]) + count_unflipped += 1 + else: + self.assertAllEqual(y_tf[i], y_np[i]) + count_flipped += 1 + + # 100 trials, each containing batch_size elements + # Mean: 50 * batch_size + # Std Dev: ~5 * sqrt(batch_size) + # Six Sigma: 50 * batch_size - (5 * 6 * sqrt(batch_size)) + # = 50 * batch_size - 30 * sqrt(batch_size) = 800 - 30 * 4 = 680 + six_sigma = 50 * batch_size - 30 * np.sqrt(batch_size) + self.assertGreaterEqual(count_flipped, six_sigma) + self.assertGreaterEqual(count_unflipped, six_sigma) + def testInvolutionTranspose(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) @@ -1156,6 +1286,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): #Ops that support 4D input for op in [ image_ops.flip_left_right, image_ops.flip_up_down, + image_ops.random_flip_left_right, image_ops.random_flip_up_down, image_ops.transpose_image, image_ops.rot90 ]: transformed_unknown_dims_4 = op(p_unknown_dims_4) @@ -1166,14 +1297,6 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): "must be at least three-dimensional"): op(p_wrong_rank) - for op in [ - image_ops.random_flip_left_right, - image_ops.random_flip_up_down, - ]: - with self.assertRaisesRegexp(ValueError, "must be three-dimensional"): - op(p_wrong_rank) - - def testRot90GroupOrder(self): image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3]) with self.test_session(use_gpu=True): @@ -1208,41 +1331,6 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_np = np.rot90(image, k=k, axes=(1, 2)) self.assertAllEqual(y_np, y_tf.eval({k_placeholder: k})) -class RandomFlipTest(test_util.TensorFlowTestCase): - - def testRandomLeftRight(self): - x_np = np.array([0, 1], dtype=np.uint8).reshape([1, 2, 1]) - num_iterations = 500 - - hist = [0, 0] - with self.test_session(use_gpu=True): - x_tf = constant_op.constant(x_np, shape=x_np.shape) - y = image_ops.random_flip_left_right(x_tf) - for _ in xrange(num_iterations): - y_np = y.eval().flatten()[0] - hist[y_np] += 1 - - # Ensure that each entry is observed within 4 standard deviations. - four_stddev = 4.0 * np.sqrt(num_iterations / 2.0) - self.assertAllClose(hist, [num_iterations / 2.0] * 2, atol=four_stddev) - - def testRandomUpDown(self): - x_np = np.array([0, 1], dtype=np.uint8).reshape([2, 1, 1]) - num_iterations = 500 - - hist = [0, 0] - with self.test_session(use_gpu=True): - x_tf = constant_op.constant(x_np, shape=x_np.shape) - y = image_ops.random_flip_up_down(x_tf) - for _ in xrange(num_iterations): - y_np = y.eval().flatten()[0] - hist[y_np] += 1 - - # Ensure that each entry is observed within 4 standard deviations. - four_stddev = 4.0 * np.sqrt(num_iterations / 2.0) - self.assertAllClose(hist, [num_iterations / 2.0] * 2, atol=four_stddev) - - class AdjustContrastTest(test_util.TensorFlowTestCase): def _testContrast(self, x_np, y_np, contrast_factor): @@ -1585,14 +1673,16 @@ class CentralCropTest(test_util.TensorFlowTestCase): self.assertEqual(y.get_shape().as_list(), post_shape) def testNoOp(self): - x_shape = [13, 9, 3] - x_np = np.ones(x_shape, dtype=np.float32) - with self.test_session(use_gpu=True): - x = constant_op.constant(x_np, shape=x_shape) - y = image_ops.central_crop(x, 1.0) - y_tf = y.eval() - self.assertAllEqual(y_tf, x_np) - self.assertEqual(y.op.name, x.op.name) + x_shapes = [[13, 9, 3], [5, 13, 9, 3]] + for x_shape in x_shapes: + x_np = np.ones(x_shape, dtype=np.float32) + for use_gpu in [True, False]: + with self.test_session(use_gpu=use_gpu): + x = constant_op.constant(x_np, shape=x_shape) + y = image_ops.central_crop(x, 1.0) + y_tf = y.eval() + self.assertAllEqual(y_tf, x_np) + self.assertEqual(y.op.name, x.op.name) def testCropping(self): x_shape = [4, 8, 1] @@ -1601,6 +1691,23 @@ class CentralCropTest(test_util.TensorFlowTestCase): [1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8]], dtype=np.int32).reshape(x_shape) y_np = np.array([[3, 4, 5, 6], [3, 4, 5, 6]]).reshape([2, 4, 1]) + for use_gpu in [True, False]: + with self.test_session(use_gpu=use_gpu): + x = constant_op.constant(x_np, shape=x_shape) + y = image_ops.central_crop(x, 0.5) + y_tf = y.eval() + self.assertAllEqual(y_tf, y_np) + self.assertAllEqual(y_tf.shape, y_np.shape) + + x_shape = [2, 4, 8, 1] + x_np = np.array( + [[1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8], + [1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8], + [8, 7, 6, 5, 4, 3, 2, 1], [8, 7, 6, 5, 4, 3, 2, 1], + [8, 7, 6, 5, 4, 3, 2, 1], [8, 7, 6, 5, 4, 3, 2, 1]], + dtype=np.int32).reshape(x_shape) + y_np = np.array([[[3, 4, 5, 6], [3, 4, 5, 6]], + [[6, 5, 4, 3], [6, 5, 4, 3]]]).reshape([2, 2, 4, 1]) with self.test_session(use_gpu=True): x = constant_op.constant(x_np, shape=x_shape) y = image_ops.central_crop(x, 0.5) @@ -1610,52 +1717,87 @@ class CentralCropTest(test_util.TensorFlowTestCase): def testCropping2(self): # Test case for 10315 - x_shape = [240, 320, 3] - x_np = np.zeros(x_shape, dtype=np.int32) - y_np = np.zeros([80, 106, 3], dtype=np.int32) - with self.test_session(use_gpu=True): - x = array_ops.placeholder(shape=x_shape, dtype=dtypes.int32) - y = image_ops.central_crop(x, 0.33) - y_tf = y.eval(feed_dict={x: x_np}) - self.assertAllEqual(y_tf, y_np) - self.assertAllEqual(y_tf.shape, y_np.shape) + x_shapes = [[240, 320, 3], [5, 240, 320, 3]] + expected_y_shapes = [[80, 106, 3], [5, 80, 106, 3]] + + for x_shape, y_shape in zip(x_shapes, expected_y_shapes): + x_np = np.zeros(x_shape, dtype=np.int32) + y_np = np.zeros(y_shape, dtype=np.int32) + for use_gpu in [True, False]: + with self.test_session(use_gpu=use_gpu): + x = array_ops.placeholder(shape=x_shape, dtype=dtypes.int32) + y = image_ops.central_crop(x, 0.33) + y_tf = y.eval(feed_dict={x: x_np}) + self.assertAllEqual(y_tf, y_np) + self.assertAllEqual(y_tf.shape, y_np.shape) def testShapeInference(self): - # Test no-op fraction=1.0 + # Test no-op fraction=1.0, with 3-D tensors. self._assertShapeInference([50, 60, 3], 1.0, [50, 60, 3]) self._assertShapeInference([None, 60, 3], 1.0, [None, 60, 3]) self._assertShapeInference([50, None, 3], 1.0, [50, None, 3]) self._assertShapeInference([None, None, 3], 1.0, [None, None, 3]) self._assertShapeInference([50, 60, None], 1.0, [50, 60, None]) self._assertShapeInference([None, None, None], 1.0, [None, None, None]) - self._assertShapeInference(None, 1.0, None) - # TODO(toddw): Currently central_crop() doesn't infer the result shape even - # when it's possible. If we change it to do so, we can test as follows: - # - # self._assertShapeInference([50, 60, 3], 0.5, [25, 30, 3]) - # self._assertShapeInference([None, 60, 3], 0.5, [None, 30, 3]) - # self._assertShapeInference([50, None, 3], 0.5, [25, None, 3]) - # self._assertShapeInference([None, None, 3], 0.5, [None, None, 3]) - # self._assertShapeInference([50, 60, None], 0.5, [25, 30, None]) - # self._assertShapeInference([None, None, None], 0.5, [None, None, None]) - # self._assertShapeInference(None, 0.5, None) - def testError(self): + # Test no-op fraction=0.5, with 3-D tensors. + self._assertShapeInference([50, 60, 3], 0.5, [26, 30, 3]) + self._assertShapeInference([None, 60, 3], 0.5, [None, 30, 3]) + self._assertShapeInference([50, None, 3], 0.5, [26, None, 3]) + self._assertShapeInference([None, None, 3], 0.5, [None, None, 3]) + self._assertShapeInference([50, 60, None], 0.5, [26, 30, None]) + self._assertShapeInference([None, None, None], 0.5, [None, None, None]) + + # Test no-op fraction=1.0, with 4-D tensors. + self._assertShapeInference([5, 50, 60, 3], 1.0, [5, 50, 60, 3]) + self._assertShapeInference([5, None, 60, 3], 1.0, [5, None, 60, 3]) + self._assertShapeInference([5, 50, None, 3], 1.0, [5, 50, None, 3]) + self._assertShapeInference([5, None, None, 3], 1.0, [5, None, None, 3]) + self._assertShapeInference([5, 50, 60, None], 1.0, [5, 50, 60, None]) + self._assertShapeInference([5, None, None, None], 1.0, + [5, None, None, None]) + self._assertShapeInference([None, None, None, None], 1.0, + [None, None, None, None]) + + # Test no-op fraction=0.5, with 4-D tensors. + self._assertShapeInference([5, 50, 60, 3], 0.5, [5, 26, 30, 3]) + self._assertShapeInference([5, None, 60, 3], 0.5, [5, None, 30, 3]) + self._assertShapeInference([5, 50, None, 3], 0.5, [5, 26, None, 3]) + self._assertShapeInference([5, None, None, 3], 0.5, [5, None, None, 3]) + self._assertShapeInference([5, 50, 60, None], 0.5, [5, 26, 30, None]) + self._assertShapeInference([5, None, None, None], 0.5, + [5, None, None, None]) + self._assertShapeInference([None, None, None, None], 0.5, + [None, None, None, None]) + + def testErrorOnInvalidCentralCropFractionValues(self): x_shape = [13, 9, 3] x_np = np.ones(x_shape, dtype=np.float32) - with self.test_session(use_gpu=True): - x = constant_op.constant(x_np, shape=x_shape) - with self.assertRaises(ValueError): - _ = image_ops.central_crop(x, 0.0) - with self.assertRaises(ValueError): - _ = image_ops.central_crop(x, 1.01) + for use_gpu in [True, False]: + with self.test_session(use_gpu=use_gpu): + x = constant_op.constant(x_np, shape=x_shape) + with self.assertRaises(ValueError): + _ = image_ops.central_crop(x, 0.0) + with self.assertRaises(ValueError): + _ = image_ops.central_crop(x, 1.01) + + def testErrorOnInvalidShapes(self): + x_shapes = [None, [], [3], [3, 9], [3, 9, 3, 9, 3]] + for x_shape in x_shapes: + x_np = np.ones(x_shape, dtype=np.float32) + for use_gpu in [True, False]: + with self.test_session(use_gpu=use_gpu): + x = constant_op.constant(x_np, shape=x_shape) + with self.assertRaises(ValueError): + _ = image_ops.central_crop(x, 0.5) def testNameScope(self): x_shape = [13, 9, 3] x_np = np.ones(x_shape, dtype=np.float32) - with self.test_session(use_gpu=True): - y = image_ops.central_crop(x_np, 1.0) - self.assertTrue(y.op.name.startswith("central_crop")) + for use_gpu in [True, False]: + with self.test_session(use_gpu=use_gpu): + y = image_ops.central_crop(x_np, 1.0) + self.assertTrue(y.op.name.startswith("central_crop")) class PadToBoundingBoxTest(test_util.TensorFlowTestCase): @@ -2457,6 +2599,182 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): y = image_ops.resize_images(single_image, [55, 66]) self.assertTrue(y.op.name.startswith("resize_images")) + def _ResizeImageCall(self, x, max_h, max_w, preserve_aspect_ratio, + use_tensor_inputs): + if use_tensor_inputs: + target_max = ops.convert_to_tensor([max_h, max_w]) + x_tensor = array_ops.placeholder(x.dtype, shape=[None] * x.ndim) + feed_dict = {x_tensor: x} + else: + target_max = [max_h, max_w] + x_tensor = x + feed_dict = {} + + y = image_ops.resize_images(x_tensor, target_max, + preserve_aspect_ratio=preserve_aspect_ratio) + + with self.test_session(use_gpu=True): + return y.eval(feed_dict=feed_dict) + + def _assertResizeEqual(self, x, x_shape, y, y_shape, + preserve_aspect_ratio=True, + use_tensor_inputs_options=None): + use_tensor_inputs_options = use_tensor_inputs_options or [False, True] + target_height, target_width, _ = y_shape + x = np.array(x).reshape(x_shape) + y = np.array(y).reshape(y_shape) + + for use_tensor_inputs in use_tensor_inputs_options: + y_tf = self._ResizeImageCall(x, target_height, target_width, + preserve_aspect_ratio, use_tensor_inputs) + self.assertAllClose(y, y_tf) + + def _assertResizeCheckShape(self, x, x_shape, target_shape, + y_shape, preserve_aspect_ratio=True, + use_tensor_inputs_options=None): + use_tensor_inputs_options = use_tensor_inputs_options or [False, True] + target_height, target_width = target_shape + x = np.array(x).reshape(x_shape) + y = np.zeros(y_shape) + + for use_tensor_inputs in use_tensor_inputs_options: + y_tf = self._ResizeImageCall(x, target_height, target_width, + preserve_aspect_ratio, use_tensor_inputs) + self.assertShapeEqual(y, ops.convert_to_tensor(y_tf)) + + def testPreserveAspectRatioMultipleImages(self): + x_shape = [10, 100, 100, 10] + x = np.random.uniform(size=x_shape) + + self._assertResizeCheckShape(x, x_shape, [250, 250], [10, 250, 250, 10], + preserve_aspect_ratio=False) + + def testPreserveAspectRatioNoOp(self): + x_shape = [10, 10, 10] + x = np.random.uniform(size=x_shape) + + self._assertResizeEqual(x, x_shape, x, x_shape) + + def testPreserveAspectRatioSmaller(self): + x_shape = [100, 100, 10] + x = np.random.uniform(size=x_shape) + + self._assertResizeCheckShape(x, x_shape, [75, 50], [50, 50, 10]) + + def testPreserveAspectRatioSmallerMultipleImages(self): + x_shape = [10, 100, 100, 10] + x = np.random.uniform(size=x_shape) + + self._assertResizeCheckShape(x, x_shape, [75, 50], [10, 50, 50, 10]) + + def testPreserveAspectRatioLarger(self): + x_shape = [100, 100, 10] + x = np.random.uniform(size=x_shape) + + self._assertResizeCheckShape(x, x_shape, [150, 200], [150, 150, 10]) + + def testPreserveAspectRatioSameRatio(self): + x_shape = [1920, 1080, 3] + x = np.random.uniform(size=x_shape) + + self._assertResizeCheckShape(x, x_shape, [3840, 2160], [3840, 2160, 3]) + + +class ResizeImageWithPadTest(test_util.TensorFlowTestCase): + + def _ResizeImageWithPad(self, x, target_height, target_width, + use_tensor_inputs): + if use_tensor_inputs: + target_height = ops.convert_to_tensor(target_height) + target_width = ops.convert_to_tensor(target_width) + x_tensor = array_ops.placeholder(x.dtype, shape=[None] * x.ndim) + feed_dict = {x_tensor: x} + else: + x_tensor = x + feed_dict = {} + + y = image_ops.resize_image_with_pad(x_tensor, target_height, + target_width) + if not use_tensor_inputs: + self.assertTrue(y.get_shape().is_fully_defined()) + + with self.test_session(use_gpu=True): + return y.eval(feed_dict=feed_dict) + + def _assertReturns(self, + x, + x_shape, + y, + y_shape, + use_tensor_inputs_options=None): + use_tensor_inputs_options = use_tensor_inputs_options or [False, True] + target_height, target_width, _ = y_shape + x = np.array(x).reshape(x_shape) + y = np.array(y).reshape(y_shape) + + for use_tensor_inputs in use_tensor_inputs_options: + y_tf = self._ResizeImageWithPad(x, target_height, target_width, + use_tensor_inputs) + self.assertAllClose(y, y_tf) + + def _assertRaises(self, + x, + x_shape, + target_height, + target_width, + err_msg, + use_tensor_inputs_options=None): + use_tensor_inputs_options = use_tensor_inputs_options or [False, True] + x = np.array(x).reshape(x_shape) + + for use_tensor_inputs in use_tensor_inputs_options: + try: + self._ResizeImageWithPad(x, target_height, target_width, + use_tensor_inputs) + except Exception as e: # pylint: disable=broad-except + if err_msg not in str(e): + raise + else: + raise AssertionError("Exception not raised: %s" % err_msg) + + def _assertShapeInference(self, pre_shape, height, width, post_shape): + image = array_ops.placeholder(dtypes.float32, shape=pre_shape) + y = image_ops.resize_image_with_pad(image, height, width) + self.assertEqual(y.get_shape().as_list(), post_shape) + + def testNoOp(self): + x_shape = [10, 10, 10] + x = np.random.uniform(size=x_shape) + + self._assertReturns(x, x_shape, x, x_shape) + + def testPad(self): + # Reduce vertical dimension + x = [1, 2, 3, 4, 5, 6, 7, 8] + x_shape = [2, 4, 1] + + y = [0, 1, 3, 0] + y_shape = [1, 4, 1] + + self._assertReturns(x, x_shape, y, y_shape) + + # Reduce horizontal dimension + x = [1, 2, 3, 4, 5, 6, 7, 8] + x_shape = [2, 4, 1] + + y = [1, 3, 0, 0] + y_shape = [2, 2, 1] + + self._assertReturns(x, x_shape, y, y_shape) + + x = [1, 2, 3, 4, 5, 6, 7, 8] + x_shape = [2, 4, 1] + + y = [1, 3] + y_shape = [1, 2, 1] + + self._assertReturns(x, x_shape, y, y_shape) + class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase): @@ -3746,5 +4064,88 @@ class SobelEdgesTest(test_util.TensorFlowTestCase): self.assertAllClose(expected_batch, actual_sobel) +class DecodeImageTest(test_util.TensorFlowTestCase): + + def testJpegUint16(self): + with self.test_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/jpeg/testdata" + jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg")) + image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16) + image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0), + dtypes.uint16) + image0, image1 = sess.run([image0, image1]) + self.assertAllEqual(image0, image1) + + def testPngUint16(self): + with self.test_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/png/testdata" + png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png")) + image0 = image_ops.decode_image(png0, dtype=dtypes.uint16) + image1 = image_ops.convert_image_dtype( + image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16) + image0, image1 = sess.run([image0, image1]) + self.assertAllEqual(image0, image1) + + def testGifUint16(self): + with self.test_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/gif/testdata" + gif0 = io_ops.read_file(os.path.join(base, "scan.gif")) + image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16) + image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0), + dtypes.uint16) + image0, image1 = sess.run([image0, image1]) + self.assertAllEqual(image0, image1) + + def testBmpUint16(self): + with self.test_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/bmp/testdata" + bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp")) + image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16) + image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0), + dtypes.uint16) + image0, image1 = sess.run([image0, image1]) + self.assertAllEqual(image0, image1) + + def testJpegFloat32(self): + with self.test_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/jpeg/testdata" + jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg")) + image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32) + image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0), + dtypes.float32) + image0, image1 = sess.run([image0, image1]) + self.assertAllEqual(image0, image1) + + def testPngFloat32(self): + with self.test_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/png/testdata" + png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png")) + image0 = image_ops.decode_image(png0, dtype=dtypes.float32) + image1 = image_ops.convert_image_dtype( + image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32) + image0, image1 = sess.run([image0, image1]) + self.assertAllEqual(image0, image1) + + def testGifFloat32(self): + with self.test_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/gif/testdata" + gif0 = io_ops.read_file(os.path.join(base, "scan.gif")) + image0 = image_ops.decode_image(gif0, dtype=dtypes.float32) + image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0), + dtypes.float32) + image0, image1 = sess.run([image0, image1]) + self.assertAllEqual(image0, image1) + + def testBmpFloat32(self): + with self.test_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/bmp/testdata" + bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp")) + image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32) + image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0), + dtypes.float32) + image0, image1 = sess.run([image0, image1]) + self.assertAllEqual(image0, image1) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 1f8d8dc4f3e7b84cea9850f5da08d8c5a189e096..5bfc5ce2a7a1913b097ee67d1b18d684b5ebcaa5 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -43,7 +43,8 @@ from tensorflow.python.ops import linalg_ops_impl from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.deprecation import ( + deprecated, deprecated_arg_values) from tensorflow.python.util.tf_export import tf_export @@ -86,7 +87,7 @@ class Initializer(object): @tf_export("keras.initializers.Zeros", "initializers.zeros", - "zeros_initializer") + "zeros_initializer", "keras.initializers.zeros") class Zeros(Initializer): """Initializer that generates tensors initialized to 0.""" @@ -102,7 +103,8 @@ class Zeros(Initializer): return {"dtype": self.dtype.name} -@tf_export("keras.initializers.Ones", "initializers.ones", "ones_initializer") +@tf_export("keras.initializers.Ones", "initializers.ones", "ones_initializer", + "keras.initializers.ones") class Ones(Initializer): """Initializer that generates tensors initialized to 1.""" @@ -119,7 +121,7 @@ class Ones(Initializer): @tf_export("keras.initializers.Constant", "initializers.constant", - "constant_initializer") + "constant_initializer", "keras.initializers.constant") class Constant(Initializer): """Initializer that generates tensors with constant values. @@ -225,7 +227,8 @@ class Constant(Initializer): @tf_export("keras.initializers.RandomUniform", "initializers.random_uniform", - "random_uniform_initializer") + "random_uniform_initializer", "keras.initializers.uniform", + "keras.initializers.random_uniform") class RandomUniform(Initializer): """Initializer that generates tensors with a uniform distribution. @@ -262,7 +265,8 @@ class RandomUniform(Initializer): @tf_export("keras.initializers.RandomNormal", "initializers.random_normal", - "random_normal_initializer") + "random_normal_initializer", "keras.initializers.normal", + "keras.initializers.random_normal") class RandomNormal(Initializer): """Initializer that generates tensors with a normal distribution. @@ -299,7 +303,8 @@ class RandomNormal(Initializer): @tf_export("keras.initializers.TruncatedNormal", - "initializers.truncated_normal", "truncated_normal_initializer") + "initializers.truncated_normal", "truncated_normal_initializer", + "keras.initializers.truncated_normal") class TruncatedNormal(Initializer): """Initializer that generates a truncated normal distribution. @@ -405,8 +410,10 @@ class UniformUnitScaling(Initializer): class VarianceScaling(Initializer): """Initializer capable of adapting its scale to the shape of weights tensors. - With `distribution="normal"`, samples are drawn from a truncated normal - distribution centered on zero, with `stddev = sqrt(scale / n)` + With `distribution="truncated_normal" or "untruncated_normal"`, + samples are drawn from a truncated/untruncated normal + distribution with a mean of zero and a standard deviation (after truncation, + if used) `stddev = sqrt(scale / n)` where n is: - number of input units in the weight tensor, if mode = "fan_in" - number of output units, if mode = "fan_out" @@ -429,10 +436,14 @@ class VarianceScaling(Initializer): "distribution" arguments. """ + @deprecated_arg_values( + None, + "`normal` is a deprecated alias for `truncated_normal`", + distribution="normal") def __init__(self, scale=1.0, mode="fan_in", - distribution="normal", + distribution="truncated_normal", seed=None, dtype=dtypes.float32): if scale <= 0.: @@ -440,7 +451,8 @@ class VarianceScaling(Initializer): if mode not in {"fan_in", "fan_out", "fan_avg"}: raise ValueError("Invalid `mode` argument:", mode) distribution = distribution.lower() - if distribution not in {"normal", "uniform"}: + if distribution not in {"normal", "uniform", + "truncated_normal", "untruncated_normal"}: raise ValueError("Invalid `distribution` argument:", distribution) self.scale = scale self.mode = mode @@ -462,10 +474,15 @@ class VarianceScaling(Initializer): scale /= max(1., fan_out) else: scale /= max(1., (fan_in + fan_out) / 2.) - if self.distribution == "normal": - stddev = math.sqrt(scale) + if self.distribution == "normal" or self.distribution == "truncated_normal": + # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) + stddev = math.sqrt(scale) / .87962566103423978 return random_ops.truncated_normal( shape, 0.0, stddev, dtype, seed=self.seed) + elif self.distribution == "untruncated_normal": + stddev = math.sqrt(scale) + return random_ops.random_normal( + shape, 0.0, stddev, dtype, seed=self.seed) else: limit = math.sqrt(3.0 * scale) return random_ops.random_uniform( @@ -482,7 +499,7 @@ class VarianceScaling(Initializer): @tf_export("keras.initializers.Orthogonal", "initializers.orthogonal", - "orthogonal_initializer") + "orthogonal_initializer", "keras.initializers.orthogonal") class Orthogonal(Initializer): """Initializer that generates an orthogonal matrix. @@ -546,7 +563,9 @@ class ConvolutionDeltaOrthogonal(Initializer): The shape of the tensor must have length 3, 4 or 5. The number of input filters must not exceed the number of output filters. The center pixels of the - tensor form an orthogonal matrix. Other pixels are set to be zero. + tensor form an orthogonal matrix. Other pixels are set to be zero. See + algorithm 2 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393 + Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. @@ -667,6 +686,7 @@ class ConvolutionOrthogonal2D(ConvolutionOrthogonal): filters must not exceed the number of output filters. The orthogonality(==isometry) is exact when the inputs are circular padded. There are finite-width effects with non-circular padding (e.g. zero padding). + See algorithm 1 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393 Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. @@ -802,6 +822,7 @@ class ConvolutionOrthogonal1D(ConvolutionOrthogonal): filters must not exceed the number of output filters. The orthogonality(==isometry) is exact when the inputs are circular padded. There are finite-width effects with non-circular padding (e.g. zero padding). + See algorithm 1 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393 Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. @@ -918,6 +939,7 @@ class ConvolutionOrthogonal3D(ConvolutionOrthogonal): filters must not exceed the number of output filters. The orthogonality(==isometry) is exact when the inputs are circular padded. There are finite-width effects with non-circular padding (e.g. zero padding). + See algorithm 1 [Xiao et al., 2018] in: https://arxiv.org/abs/1806.05393 Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. @@ -1062,7 +1084,8 @@ class ConvolutionOrthogonal3D(ConvolutionOrthogonal): return self._dict_to_tensor(p, ksize, ksize, ksize) -@tf_export("keras.initializers.Identity", "initializers.identity") +@tf_export("keras.initializers.Identity", "initializers.identity", + "keras.initializers.identity") class Identity(Initializer): """Initializer that generates the identity matrix. diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 8cfe964b1c0a572f43a14c66885e74ea105b0916..20c46fbb82b0671c6cc586eafdd7fa346d8b4e6d 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -42,7 +42,7 @@ __all__ = ["LinearOperator"] class LinearOperator(object): """Base class defining a [batch of] linear operator[s]. - Subclasses of `LinearOperator` provide a access to common methods on a + Subclasses of `LinearOperator` provide access to common methods on a (batch) matrix, without the need to materialize the matrix. This allows: * Matrix free computations @@ -69,11 +69,11 @@ class LinearOperator(object): #### Shape compatibility - `LinearOperator` sub classes should operate on a [batch] matrix with + `LinearOperator` subclasses should operate on a [batch] matrix with compatible shape. Class docstrings should define what is meant by compatible - shape. Some sub-classes may not support batching. + shape. Some subclasses may not support batching. - An example is: + Examples: `x` is a batch matrix with compatible shape for `matmul` if diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py index 1b5bb9470c4406ad075f2f6d5c38661311472727..78c85db557047ebcc3dd655deae62acbcef929c7 100644 --- a/tensorflow/python/ops/linalg/linear_operator_test_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py @@ -102,7 +102,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): raise NotImplementedError("operator_build_infos has not been implemented.") @abc.abstractmethod - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): """Build a batch matrix and an Operator that should have similar behavior. Every operator acts like a (batch) matrix. This method returns both @@ -118,9 +118,6 @@ class LinearOperatorDerivedClassTest(test.TestCase): Returns: operator: `LinearOperator` subclass instance. mat: `Tensor` representing operator. - feed_dict: Dictionary. - If placholder is True, this must contains everything needed to be fed - to sess.run calls at runtime to make the operator work. """ # Create a matrix as a numpy array with desired shape/dtype. # Create a LinearOperator that should have the same behavior as the matrix. @@ -189,12 +186,12 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_dense = operator.to_dense() if not use_placeholder: self.assertAllEqual(build_info.shape, op_dense.get_shape()) - op_dense_v, mat_v = sess.run([op_dense, mat], feed_dict=feed_dict) + op_dense_v, mat_v = sess.run([op_dense, mat]) self.assertAC(op_dense_v, mat_v) def test_det(self): @@ -204,14 +201,13 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_det = operator.determinant() if not use_placeholder: self.assertAllEqual(build_info.shape[:-2], op_det.get_shape()) op_det_v, mat_det_v = sess.run( - [op_det, linalg_ops.matrix_determinant(mat)], - feed_dict=feed_dict) + [op_det, linalg_ops.matrix_determinant(mat)]) self.assertAC(op_det_v, mat_det_v) def test_log_abs_det(self): @@ -221,7 +217,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_log_abs_det = operator.log_abs_determinant() _, mat_log_abs_det = linalg.slogdet(mat) @@ -229,7 +225,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): self.assertAllEqual( build_info.shape[:-2], op_log_abs_det.get_shape()) op_log_abs_det_v, mat_log_abs_det_v = sess.run( - [op_log_abs_det, mat_log_abs_det], feed_dict=feed_dict) + [op_log_abs_det, mat_log_abs_det]) self.assertAC(op_log_abs_det_v, mat_log_abs_det_v) def _test_matmul(self, with_batch): @@ -246,7 +242,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) x = self._make_x( operator, adjoint=adjoint, with_batch=with_batch) @@ -264,7 +260,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): self.assertAllEqual(op_matmul.get_shape(), mat_matmul.get_shape()) op_matmul_v, mat_matmul_v = sess.run( - [op_matmul, mat_matmul], feed_dict=feed_dict) + [op_matmul, mat_matmul]) self.assertAC(op_matmul_v, mat_matmul_v) def test_matmul(self): @@ -289,7 +285,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) rhs = self._make_rhs( operator, adjoint=adjoint, with_batch=with_batch) @@ -307,8 +303,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): if not use_placeholder: self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape()) - op_solve_v, mat_solve_v = sess.run( - [op_solve, mat_solve], feed_dict=feed_dict) + op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve]) self.assertAC(op_solve_v, mat_solve_v) def test_solve(self): @@ -326,14 +321,13 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_trace = operator.trace() mat_trace = math_ops.trace(mat) if not use_placeholder: self.assertAllEqual(op_trace.get_shape(), mat_trace.get_shape()) - op_trace_v, mat_trace_v = sess.run( - [op_trace, mat_trace], feed_dict=feed_dict) + op_trace_v, mat_trace_v = sess.run([op_trace, mat_trace]) self.assertAC(op_trace_v, mat_trace_v) def test_add_to_tensor(self): @@ -343,15 +337,14 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_plus_2mat = operator.add_to_tensor(2 * mat) if not use_placeholder: self.assertAllEqual(build_info.shape, op_plus_2mat.get_shape()) - op_plus_2mat_v, mat_v = sess.run( - [op_plus_2mat, mat], feed_dict=feed_dict) + op_plus_2mat_v, mat_v = sess.run([op_plus_2mat, mat]) self.assertAC(op_plus_2mat_v, 3 * mat_v) @@ -362,7 +355,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_diag_part = operator.diag_part() mat_diag_part = array_ops.matrix_diag_part(mat) @@ -372,7 +365,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): op_diag_part.get_shape()) op_diag_part_, mat_diag_part_ = sess.run( - [op_diag_part, mat_diag_part], feed_dict=feed_dict) + [op_diag_part, mat_diag_part]) self.assertAC(op_diag_part_, mat_diag_part_) diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py index 3cbbf3412a2a1bd974354a5819d410b4074ab47d..b6b98d5c86fd3285b35377c9158dcdb649b88a83 100644 --- a/tensorflow/python/ops/linalg_grad.py +++ b/tensorflow/python/ops/linalg_grad.py @@ -55,6 +55,17 @@ def _MatrixDeterminantGrad(op, grad): return multipliers * a_adj_inv +@ops.RegisterGradient("LogMatrixDeterminant") +def _LogMatrixDeterminantGrad(op, _, grad_b): + """Gradient for LogMatrixDeterminant.""" + a = op.inputs[0] + c = op.outputs[1] + a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True) + multipliers = array_ops.reshape( + grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0)) + return multipliers * a_adj_inv + + @ops.RegisterGradient("Cholesky") def _CholeskyGrad(op, grad): """Gradient for Cholesky.""" diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index 0e547689cc51857adb77791bfb94c2527cdffef2..fb51fbc6264d3b797d134005cbf1e700d0a9990c 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -366,6 +366,10 @@ class KeyValueTensorInitializer(TableInitializerBase): with ops.name_scope( self._name, values=(table.table_ref, self._keys, self._values)) as scope: + if context.executing_eagerly(): + # Ensure a unique name when eager execution is enabled to avoid spurious + # sharing issues. + scope += str(ops.uid()) init_op = gen_lookup_ops.initialize_table_v2( table.table_ref, self._keys, self._values, name=scope) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) @@ -1108,6 +1112,10 @@ def index_table_from_tensor(vocabulary_list, shared_name = "" with ops.name_scope(None, "hash_table") as hash_table_scope: + if context.executing_eagerly(): + # Ensure a unique name when eager execution is enabled to avoid spurious + # sharing issues. + shared_name += str(ops.uid()) table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys init = KeyValueTensorInitializer( table_keys, diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 9fc545c9678e7eb33a7ad35e2a84f890885e09af..66633c8b12f60c86760f906aa8e4312c7394e796 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -192,6 +192,11 @@ def compute_weighted_loss( on some model parameters but you do not want this to affect the loss gradient, you need to apply @{tf.stop_gradient} to `weights` before passing them to `compute_weighted_loss`. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ Reduction.validate(reduction) with ops.name_scope(scope, "weighted_loss", (losses, weights)): @@ -260,6 +265,11 @@ def absolute_difference( ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid or if `labels` or `predictions` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -306,6 +316,11 @@ def cosine_distance( Raises: ValueError: If `predictions` shape doesn't match `labels` shape, or `axis`, `labels`, `predictions` or `weights` is `None`. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ axis = deprecated_argument_lookup("axis", axis, "dim", dim) if axis is None: @@ -334,8 +349,11 @@ def hinge_loss(labels, logits, weights=1.0, scope=None, Args: labels: The ground truth output tensor. Its shape should match the shape of - logits. The values of the tensor are expected to be 0.0 or 1.0. - logits: The logits, a float tensor. + logits. The values of the tensor are expected to be 0.0 or 1.0. Internally + the {0,1} labels are converted to {-1,1} when calculating the hinge loss. + logits: The logits, a float tensor. Note that logits are assumed to be + unbounded and 0-centered. A value > 0 (resp. < 0) is considered a positive + (resp. negative) binary prediction. weights: Optional `Tensor` whose rank is either 0, or the same rank as `labels`, and must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `losses` dimension). @@ -350,6 +368,11 @@ def hinge_loss(labels, logits, weights=1.0, scope=None, Raises: ValueError: If the shapes of `logits` and `labels` don't match or if `labels` or `logits` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -413,6 +436,11 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -474,6 +502,11 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -537,6 +570,11 @@ def mean_pairwise_squared_error( ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -615,6 +653,11 @@ def mean_squared_error( ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -667,6 +710,11 @@ def sigmoid_cross_entropy( ValueError: If the shape of `logits` doesn't match that of `multi_class_labels` or if the shape of `weights` is invalid, or if `weights` is None. Also if `multi_class_labels` or `logits` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if multi_class_labels is None: raise ValueError("multi_class_labels must not be None.") @@ -728,6 +776,11 @@ def softmax_cross_entropy( ValueError: If the shape of `logits` doesn't match that of `onehot_labels` or if the shape of `weights` is invalid or if `weights` is None. Also if `onehot_labels` or `logits` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if onehot_labels is None: raise ValueError("onehot_labels must not be None.") @@ -825,7 +878,8 @@ def sparse_softmax_cross_entropy( exception when this op is run on CPU, and return `NaN` for corresponding loss and gradient rows on GPU. logits: Unscaled log probabilities of shape - `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`. + `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32` or + `float64`. weights: Coefficients for the loss. This must be scalar or broadcastable to `labels` (i.e. same rank and each dimension is either 1 or the same). scope: the scope for the operations performed in computing the loss. @@ -839,6 +893,11 @@ def sparse_softmax_cross_entropy( Raises: ValueError: If the shapes of `logits`, `labels`, and `weights` are incompatible, or if any of them are None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") diff --git a/tensorflow/python/ops/losses/util.py b/tensorflow/python/ops/losses/util.py index 10646af8a983f149cf0620bf355cf0bc1fa697fb..97bba46661d056fd336c68988e3bc17ef4232487 100644 --- a/tensorflow/python/ops/losses/util.py +++ b/tensorflow/python/ops/losses/util.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -32,7 +33,10 @@ def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES): loss: A loss `Tensor`. loss_collection: Optional collection to add the loss to. """ - if loss_collection: + # Since we have no way of figuring out when a training iteration starts or + # ends, holding on to a loss when executing eagerly is indistingishable from + # leaking memory. We instead leave the collection empty. + if loss_collection and not context.executing_eagerly(): ops.add_to_collection(loss_collection, loss) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 563c0b3ab3f6316b89f5ea76f5d075d9f4b77eea..f0c6bd532fcdb76922ce4d5aa7fa13936db81b2f 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -620,29 +620,59 @@ def _DigammaGrad(op, grad): return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x) +@ops.RegisterGradient("BesselI0e") +def _BesselI0eGrad(op, grad): + """Compute gradient of bessel_i0e(x) with respect to its argument.""" + x = op.inputs[0] + y = op.outputs[0] + with ops.control_dependencies([grad]): + return grad * (math_ops.bessel_i1e(x) - math_ops.sign(x) * y) + + +@ops.RegisterGradient("BesselI1e") +def _BesselI1eGrad(op, grad): + """Compute gradient of bessel_i1e(x) with respect to its argument.""" + x = op.inputs[0] + y = op.outputs[0] + with ops.control_dependencies([grad]): + # For x = 0, the correct gradient is 0.5. + # However, the main branch gives NaN because of the division by x, so + # we impute the gradient manually. + # An alternative solution is to express the gradient via bessel_i0e and + # bessel_i2e, but the latter is not yet implemented in Eigen. + eps = np.finfo(x.dtype.as_numpy_dtype).eps + zeros = array_ops.zeros_like(x) + x_is_not_tiny = math_ops.abs(x) > eps + safe_x = array_ops.where(x_is_not_tiny, x, eps + zeros) + dy_dx = math_ops.bessel_i0e(safe_x) - y * ( + math_ops.sign(safe_x) + math_ops.reciprocal(safe_x)) + return grad * array_ops.where(x_is_not_tiny, dy_dx, 0.5 + zeros) + + @ops.RegisterGradient("Igamma") def _IgammaGrad(op, grad): - """Returns gradient of igamma(a, x) with respect to x.""" - # TODO(ebrevdo): Perhaps add the derivative w.r.t. a + """Returns gradient of igamma(a, x) with respect to a and x.""" a = op.inputs[0] x = op.inputs[1] sa = array_ops.shape(a) sx = array_ops.shape(x) - unused_ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx) + ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx) - # Perform operations in log space before summing, because Gamma(a) - # and Gamma'(a) can grow large. - partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a)) - # TODO(b/36815900): Mark None return values as NotImplemented - return (None, array_ops.reshape( - math_ops.reduce_sum(partial_x * grad, rx), sx)) + with ops.control_dependencies([grad]): + partial_a = gen_math_ops.igamma_grad_a(a, x) + # Perform operations in log space before summing, because Gamma(a) + # and Gamma'(a) can grow large. + partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) + - math_ops.lgamma(a)) + return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa), + array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) @ops.RegisterGradient("Igammac") def _IgammacGrad(op, grad): - """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. x.""" - _, igamma_grad_x = _IgammaGrad(op, grad) - return None, -igamma_grad_x + """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x.""" + igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad) + return (-igamma_grad_a, -igamma_grad_x) @ops.RegisterGradient("Betainc") diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 3a31ef7f8814906883665a55f7b90710fd0baf2f..cdb6dc8f22919420ff44e217578315d17cb93d8c 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -125,8 +125,8 @@ def abs(x, name=None): # pylint: disable=redefined-builtin ``` Args: - x: A `Tensor` or `SparseTensor` of type `float32`, `float64`, `int32`, - `int64`, `complex64` or `complex128`. + x: A `Tensor` or `SparseTensor` of type `float16`, `float32`, `float64`, + `int32`, `int64`, `complex64` or `complex128`. name: A name for the operation (optional). Returns: @@ -370,7 +370,7 @@ def erf(x, name=None): """Computes the Gauss error function of `x` element-wise. Args: - x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`, + x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, `float32`, `float64`. name: A name for the operation (optional). @@ -430,10 +430,10 @@ def pow(x, y, name=None): # pylint: disable=redefined-builtin ``` Args: - x: A `Tensor` of type `float32`, `float64`, `int32`, `int64`, `complex64`, - or `complex128`. - y: A `Tensor` of type `float32`, `float64`, `int32`, `int64`, `complex64`, - or `complex128`. + x: A `Tensor` of type `float16`, `float32`, `float64`, `int32`, `int64`, + `complex64`, or `complex128`. + y: A `Tensor` of type `float16`, `float32`, `float64`, `int32`, `int64`, + `complex64`, or `complex128`. name: A name for the operation (optional). Returns: @@ -600,7 +600,7 @@ def round(x, name=None): # pylint: disable=redefined-builtin ``` Args: - x: A `Tensor` of type `float32` or `float64`. + x: A `Tensor` of type `float16`, `float32`, `float64`, `int32`, or `int64`. name: A name for the operation (optional). Returns: @@ -1257,7 +1257,7 @@ def reduce_sum(input_tensor, entry in `axis`. If `keepdims` is true, the reduced dimensions are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a + If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. For example: @@ -1397,7 +1397,7 @@ def reduce_mean(input_tensor, entry in `axis`. If `keepdims` is true, the reduced dimensions are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a + If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. For example: @@ -1469,7 +1469,7 @@ def reduce_prod(input_tensor, entry in `axis`. If `keepdims` is true, the reduced dimensions are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a + If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. Args: @@ -1519,7 +1519,7 @@ def reduce_min(input_tensor, entry in `axis`. If `keepdims` is true, the reduced dimensions are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a + If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. Args: @@ -1568,7 +1568,7 @@ def reduce_max(input_tensor, entry in `axis`. If `keepdims` is true, the reduced dimensions are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a + If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. Args: @@ -1617,7 +1617,7 @@ def reduce_all(input_tensor, entry in `axis`. If `keepdims` is true, the reduced dimensions are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a + If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. For example: @@ -1675,7 +1675,7 @@ def reduce_any(input_tensor, entry in `axis`. If `keepdims` is true, the reduced dimensions are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a + If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. For example: @@ -1990,7 +1990,7 @@ def matmul(a, sparse_matmul_types = [dtypes.bfloat16, dtypes.float32] use_sparse_matmul = ( a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types) - if (a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16 and + if ((a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16) and a.dtype != b.dtype): # matmul currently doesn't handle mixed-precision inputs. use_sparse_matmul = True @@ -2225,8 +2225,8 @@ def sigmoid(x, name=None): Returns: A Tensor with the same type as `x`. - @compatibility(numpy) - Equivalent to np.scipy.special.expit + @compatibility(scipy) + Equivalent to scipy.special.expit @end_compatibility """ with ops.name_scope(name, "Sigmoid", [x]) as name: @@ -2954,6 +2954,67 @@ def polyval(coeffs, x, name=None): p = c + p * x return p + +@tf_export("math.bessel_i0e") +def bessel_i0e(x, name=None): + """Computes the Bessel i0e function of `x` element-wise. + + Exponentially scaled modified Bessel function of order 0 defined as + `bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`. + + This function is faster and numerically stabler than `bessel_i0(x)`. + + Args: + x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, + `float32`, `float64`. + name: A name for the operation (optional). + + Returns: + A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. + + @compatibility(scipy) + Equivalent to scipy.special.i0e + @end_compatibility + """ + with ops.name_scope(name, "bessel_i0e", [x]) as name: + if isinstance(x, sparse_tensor.SparseTensor): + x_i0e = gen_math_ops.bessel_i0e(x.values, name=name) + return sparse_tensor.SparseTensor( + indices=x.indices, values=x_i0e, dense_shape=x.dense_shape) + else: + return gen_math_ops.bessel_i0e(x, name=name) + + +@tf_export("math.bessel_i1e") +def bessel_i1e(x, name=None): + """Computes the Bessel i1e function of `x` element-wise. + + Exponentially scaled modified Bessel function of order 1 defined as + `bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. + + This function is faster and numerically stabler than `bessel_i1(x)`. + + Args: + x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, + `float32`, `float64`. + name: A name for the operation (optional). + + Returns: + A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. + + @compatibility(scipy) + Equivalent to scipy.special.i1e + @end_compatibility + """ + with ops.name_scope(name, "bessel_i1e", [x]) as name: + if isinstance(x, sparse_tensor.SparseTensor): + x_i1e = gen_math_ops.bessel_i1e(x.values, name=name) + return sparse_tensor.SparseTensor( + indices=x.indices, values=x_i1e, dense_shape=x.dense_shape) + else: + return gen_math_ops.bessel_i1e(x, name=name) + + # FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow # 1.0 API so we leave these here for backwards compatibility. fft = gen_spectral_ops.fft diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 980c92b0d592bccc34e1fbee636ebdd39056f2fc..6b709e5e7faf0a74f966f446ba9d33ee1087908a 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -37,14 +37,14 @@ log = np.log class ReduceTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceAllDims(self): x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) with test_util.device(use_gpu=True): y_tf = self.evaluate(math_ops.reduce_sum(x)) self.assertEqual(y_tf, 21) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceExplicitAxes(self): x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) with test_util.device(use_gpu=True): @@ -57,7 +57,7 @@ class ReduceTest(test_util.TensorFlowTestCase): for axis in (None, (0, 1), (-1, -2), (-2, -1, 0, 1)): self.assertEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), 21) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceInvalidAxis(self): if context.executing_eagerly(): # The shape check is in run a graph construction time. In eager mode, @@ -150,7 +150,7 @@ class LogSumExpTest(test_util.TensorFlowTestCase): class RoundTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRounding(self): x = np.arange(-5.0, 5.0, .25) for dtype in [np.float32, np.double, np.int32]: @@ -194,7 +194,7 @@ class ModTest(test_util.TensorFlowTestCase): class SquaredDifferenceTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSquaredDifference(self): for dtype in [np.int32, np.float16]: x = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) @@ -207,7 +207,7 @@ class SquaredDifferenceTest(test_util.TensorFlowTestCase): class ApproximateEqualTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testApproximateEqual(self): for dtype in [np.float32, np.double]: x = dtype(1) @@ -235,10 +235,19 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase): z_tf = self.evaluate(math_ops.approximate_equal(x, y, tolerance=0.0001)) self.assertAllEqual(z, z_tf) + def testApproximateEqualShape(self): + for dtype in [np.float32, np.double]: + x = np.array([1, 2], dtype=dtype) + y = np.array([[1, 2]], dtype=dtype) + # The inputs 'x' and 'y' must have the same shape. + with self.assertRaisesRegexp( + ValueError, "Shapes must be equal rank, but are 1 and 2"): + math_ops.approximate_equal(x, y) + class ScalarMulTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsRefs(self): if context.executing_eagerly(): var = resource_variable_ops.ResourceVariable(10, name="var") @@ -250,14 +259,14 @@ class ScalarMulTest(test_util.TensorFlowTestCase): self.evaluate(init) self.assertEqual(30, self.evaluate(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsConstant(self): const = constant_op.constant(10) result = math_ops.scalar_mul(3, const) with test_util.device(use_gpu=True): self.assertEqual(30, self.evaluate(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsTensor(self): tensor = array_ops.ones([10, 10]) result = math_ops.scalar_mul(3, tensor) @@ -266,7 +275,7 @@ class ScalarMulTest(test_util.TensorFlowTestCase): with test_util.device(use_gpu=True): self.assertAllEqual(self.evaluate(expected), self.evaluate(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsIndexedSlices(self): values = constant_op.constant([2, 3, 5, 7, 0, -1], shape=[3, 2]) indices = constant_op.constant([0, 2, 5]) diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 47eea6ef6b58abd4819544e29783048964104922..bfd225b0d837783fc854835f862fb4a12550fffc 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -34,21 +34,55 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export def metric_variable(shape, dtype, validate_shape=True, name=None): - """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections.""" - - return variable_scope.variable( - lambda: array_ops.zeros(shape, dtype), - trainable=False, - collections=[ - ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES - ], - validate_shape=validate_shape, - name=name) + """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections. + + If running in a `DistributionStrategy` context, the variable will be + "tower local". This means: + + * The returned object will be a container with separate variables + per replica/tower of the model. + + * When writing to the variable, e.g. using `assign_add` in a metric + update, the update will be applied to the variable local to the + replica/tower. + + * To get a metric's result value, we need to sum the variable values + across the replicas/towers before computing the final answer. + Furthermore, the final answer should be computed once instead of + in every replica/tower. Both of these are accomplished by + running the computation of the final result value inside + `tf.contrib.distribute.get_tower_context().merge_call(fn)`. + Inside the `merge_call()`, ops are only added to the graph once + and access to a tower-local variable in a computation returns + the sum across all replicas/towers. + + Args: + shape: Shape of the created variable. + dtype: Type of the created variable. + validate_shape: (Optional) Whether shape validation is enabled for + the created variable. + name: (Optional) String name of the created variable. + + Returns: + A (non-trainable) variable initialized to zero, or if inside a + `DistributionStrategy` scope a tower-local variable container. + """ + with distribute_lib.get_tower_context().tower_local_var_scope( + variable_scope.VariableAggregation.SUM): + # Note that "tower local" implies trainable=False. + return variable_scope.variable( + lambda: array_ops.zeros(shape, dtype), + collections=[ + ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES + ], + validate_shape=validate_shape, + name=name) def _remove_squeezable_dimensions(predictions, labels, weights): @@ -333,11 +367,15 @@ def mean(values, with ops.control_dependencies([values]): update_count_op = state_ops.assign_add(count, num_values) - mean_t = _safe_div(total, count, 'value') - update_op = _safe_div(update_total_op, update_count_op, 'update_op') + def aggregate_across_towers(_, t, c): + mean_t = _safe_div(t, c, 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_t) + return mean_t - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_t) + mean_t = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, total, count) + update_op = _safe_div(update_total_op, update_count_op, 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -572,6 +610,17 @@ def _confusion_matrix_at_thresholds(labels, return values, update_ops +def _aggregate_variable(v, collections): + + def f(distribution, value): + value = distribution.read_var(value) + if collections: + ops.add_to_collections(collections, value) + return value + + return distribute_lib.get_tower_context().merge_call(f, v) + + @tf_export('metrics.auc') def auc(labels, predictions, @@ -757,14 +806,18 @@ def auc(labels, raise ValueError('Invalid summation_method: %s' % summation_method) # sum up the areas of all the trapeziums - auc_value = compute_auc(values['tp'], values['fn'], values['tn'], - values['fp'], 'value') + def aggregate_auc(_, values): + auc_value = compute_auc(values['tp'], values['fn'], values['tn'], + values['fp'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, auc_value) + return auc_value + + auc_value = distribute_lib.get_tower_context().merge_call( + aggregate_auc, values) update_op = compute_auc(update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'], 'update_op') - if metrics_collections: - ops.add_to_collections(metrics_collections, auc_value) - if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -992,15 +1045,18 @@ def mean_per_class_accuracy(labels, update_total_op = state_ops.scatter_add(total, labels, ones) update_count_op = state_ops.scatter_add(count, labels, is_correct) - per_class_accuracy = _safe_div(count, total, None) + def aggregate_mean_accuracy(_, count, total): + per_class_accuracy = _safe_div(count, total, None) + mean_accuracy_v = math_ops.reduce_mean( + per_class_accuracy, name='mean_accuracy') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_accuracy_v) + return mean_accuracy_v - mean_accuracy_v = math_ops.reduce_mean( - per_class_accuracy, name='mean_accuracy') - update_op = _safe_div(update_count_op, update_total_op, name='update_op') - - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_accuracy_v) + mean_accuracy_v = distribute_lib.get_tower_context().merge_call( + aggregate_mean_accuracy, count, total) + update_op = _safe_div(update_count_op, update_total_op, name='update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1071,7 +1127,7 @@ def mean_iou(labels, total_cm, update_op = _streaming_confusion_matrix(labels, predictions, num_classes, weights) - def compute_mean_iou(name): + def compute_mean_iou(total_cm, name): """Compute the mean intersection-over-union via the confusion matrix.""" sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0)) sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1)) @@ -1098,10 +1154,14 @@ def mean_iou(labels, math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0) return result - mean_iou_v = compute_mean_iou('mean_iou') + def mean_iou_across_towers(_, v): + mean_iou_v = compute_mean_iou(v, 'mean_iou') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_iou_v) + return mean_iou_v - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_iou_v) + mean_iou_v = distribute_lib.get_tower_context().merge_call( + mean_iou_across_towers, total_cm) if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1310,12 +1370,16 @@ def mean_tensor(values, with ops.control_dependencies([values]): update_count_op = state_ops.assign_add(count, num_values) - mean_t = _safe_div(total, count, 'value') - update_op = _safe_div(update_total_op, update_count_op, 'update_op') + def aggregate_across_towers(_, t, c): + mean_t = _safe_div(t, c, 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_t) + return mean_t - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_t) + mean_t = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, total, count) + update_op = _safe_div(update_total_op, update_count_op, 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1413,12 +1477,9 @@ def _count_condition(values, weights = math_ops.to_float(weights) values = math_ops.multiply(values, weights) - value_tensor = array_ops.identity(count) - update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) - - if metrics_collections: - ops.add_to_collections(metrics_collections, value_tensor) + value_tensor = _aggregate_variable(count, metrics_collections) + update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1525,13 +1586,12 @@ def false_negatives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('fn',)) - if metrics_collections: - ops.add_to_collections(metrics_collections, values['fn']) + fn_value = _aggregate_variable(values['fn'], metrics_collections) if updates_collections: ops.add_to_collections(updates_collections, update_ops['fn']) - return values['fn'], update_ops['fn'] + return fn_value, update_ops['fn'] @tf_export('metrics.false_positives') @@ -1635,13 +1695,12 @@ def false_positives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('fp',)) - if metrics_collections: - ops.add_to_collections(metrics_collections, values['fp']) + fp_value = _aggregate_variable(values['fp'], metrics_collections) if updates_collections: ops.add_to_collections(updates_collections, update_ops['fp']) - return values['fp'], update_ops['fp'] + return fp_value, update_ops['fp'] @tf_export('metrics.true_negatives') @@ -1745,13 +1804,12 @@ def true_negatives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('tn',)) - if metrics_collections: - ops.add_to_collections(metrics_collections, values['tn']) + tn_value = _aggregate_variable(values['tn'], metrics_collections) if updates_collections: ops.add_to_collections(updates_collections, update_ops['tn']) - return values['tn'], update_ops['tn'] + return tn_value, update_ops['tn'] @tf_export('metrics.true_positives') @@ -1855,13 +1913,12 @@ def true_positives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('tp',)) - if metrics_collections: - ops.add_to_collections(metrics_collections, values['tp']) + tp_value = _aggregate_variable(values['tp'], metrics_collections) if updates_collections: ops.add_to_collections(updates_collections, update_ops['tp']) - return values['tp'], update_ops['tp'] + return tp_value, update_ops['tp'] @tf_export('metrics.precision') @@ -1945,13 +2002,17 @@ def precision(labels, return array_ops.where( math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name) - p = compute_precision(true_p, false_p, 'value') - update_op = compute_precision(true_positives_update_op, - false_positives_update_op, 'update_op') + def once_across_towers(_, true_p, false_p): + p = compute_precision(true_p, false_p, 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, p) + return p - if metrics_collections: - ops.add_to_collections(metrics_collections, p) + p = distribute_lib.get_tower_context().merge_call( + once_across_towers, true_p, false_p) + update_op = compute_precision(true_positives_update_op, + false_positives_update_op, 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2025,13 +2086,17 @@ def precision_at_thresholds(labels, def compute_precision(tp, fp, name): return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name) - prec = compute_precision(values['tp'], values['fp'], 'value') - update_op = compute_precision(update_ops['tp'], update_ops['fp'], - 'update_op') + def precision_across_towers(_, values): + prec = compute_precision(values['tp'], values['fp'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, prec) + return prec - if metrics_collections: - ops.add_to_collections(metrics_collections, prec) + prec = distribute_lib.get_tower_context().merge_call( + precision_across_towers, values) + update_op = compute_precision(update_ops['tp'], update_ops['fp'], + 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2050,7 +2115,7 @@ def recall(labels, The `recall` function creates two local variables, `true_positives` and `false_negatives`, that are used to compute the recall. This value is ultimately returned as `recall`, an idempotent operation that simply divides - `true_positives` by the sum of `true_positives` and `false_negatives`. + `true_positives` by the sum of `true_positives` and `false_negatives`. For estimation of the metric over a stream of data, the function creates an `update_op` that updates these variables and returns the `recall`. `update_op` @@ -2117,13 +2182,17 @@ def recall(labels, math_ops.greater(true_p + false_n, 0), math_ops.div(true_p, true_p + false_n), 0, name) - rec = compute_recall(true_p, false_n, 'value') - update_op = compute_recall(true_positives_update_op, - false_negatives_update_op, 'update_op') + def once_across_towers(_, true_p, false_n): + rec = compute_recall(true_p, false_n, 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, rec) + return rec - if metrics_collections: - ops.add_to_collections(metrics_collections, rec) + rec = distribute_lib.get_tower_context().merge_call( + once_across_towers, true_p, false_n) + update_op = compute_recall(true_positives_update_op, + false_negatives_update_op, 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2552,11 +2621,17 @@ def recall_at_top_k(labels, class_id=class_id, weights=weights) - metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope) + def aggregate_across_towers(_, tp, fn): + metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope) + if metrics_collections: + ops.add_to_collections(metrics_collections, metric) + return metric + + metric = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, tp, fn) + update = math_ops.div( tp_update, math_ops.add(tp_update, fn_update), name='update') - if metrics_collections: - ops.add_to_collections(metrics_collections, metric) if updates_collections: ops.add_to_collections(updates_collections, update) return metric, update @@ -2627,12 +2702,16 @@ def recall_at_thresholds(labels, def compute_recall(tp, fn, name): return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name) - rec = compute_recall(values['tp'], values['fn'], 'value') - update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') + def recall_across_towers(_, values): + rec = compute_recall(values['tp'], values['fn'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, rec) + return rec - if metrics_collections: - ops.add_to_collections(metrics_collections, rec) + rec = distribute_lib.get_tower_context().merge_call( + recall_across_towers, values) + update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2698,13 +2777,16 @@ def root_mean_squared_error(labels, mse, update_mse_op = mean_squared_error(labels, predictions, weights, None, None, name or 'root_mean_squared_error') + def once_across_towers(_, mse): + rmse = math_ops.sqrt(mse) + if metrics_collections: + ops.add_to_collections(metrics_collections, rmse) + return rmse - rmse = math_ops.sqrt(mse) - update_rmse_op = math_ops.sqrt(update_mse_op) - - if metrics_collections: - ops.add_to_collections(metrics_collections, rmse) + rmse = distribute_lib.get_tower_context().merge_call( + once_across_towers, mse) + update_rmse_op = math_ops.sqrt(update_mse_op) if updates_collections: ops.add_to_collections(updates_collections, update_rmse_op) @@ -2797,15 +2879,19 @@ def sensitivity_at_specificity(labels, return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon, name) - sensitivity = compute_sensitivity_at_specificity( - values['tp'], values['tn'], values['fp'], values['fn'], 'value') + def aggregate_across_towers(_, values): + sensitivity = compute_sensitivity_at_specificity( + values['tp'], values['tn'], values['fp'], values['fn'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, sensitivity) + return sensitivity + + sensitivity = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, values) + update_op = compute_sensitivity_at_specificity( update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 'update_op') - - if metrics_collections: - ops.add_to_collections(metrics_collections, sensitivity) - if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -3070,11 +3156,16 @@ def _streaming_sparse_average_precision_at_top_k(labels, total_update = state_ops.assign_add(total_var, batch_total, name='update') # Divide total by max to get mean, for both vars and the update ops. - mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean') - update = _safe_scalar_div(total_update, max_update, name=scope) + def aggregate_across_towers(_, total_var, max_var): + mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_average_precision) + return mean_average_precision - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_average_precision) + mean_average_precision = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, total_var, max_var) + + update = _safe_scalar_div(total_update, max_update, name=scope) if updates_collections: ops.add_to_collections(updates_collections, update) @@ -3351,11 +3442,17 @@ def precision_at_top_k(labels, class_id=class_id, weights=weights) - metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope) + def aggregate_across_towers(_, tp, fp): + metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope) + if metrics_collections: + ops.add_to_collections(metrics_collections, metric) + return metric + + metric = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, tp, fp) + update = math_ops.div( tp_update, math_ops.add(tp_update, fp_update), name='update') - if metrics_collections: - ops.add_to_collections(metrics_collections, metric) if updates_collections: ops.add_to_collections(updates_collections, update) return metric, update @@ -3583,15 +3680,19 @@ def specificity_at_sensitivity(labels, return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon, name) - specificity = compute_specificity_at_sensitivity( - values['tp'], values['tn'], values['fp'], values['fn'], 'value') + def aggregate_across_towers(_, values): + specificity = compute_specificity_at_sensitivity( + values['tp'], values['tn'], values['fp'], values['fn'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, specificity) + return specificity + + specificity = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, values) + update_op = compute_specificity_at_sensitivity( update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 'update_op') - - if metrics_collections: - ops.add_to_collections(metrics_collections, specificity) - if updates_collections: ops.add_to_collections(updates_collections, update_op) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 783d4858925d3e2b1ca210a8162a2b4df07d3089..f47f38e29e328ea92bfc494d60673c70a58274d3 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -621,7 +621,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): """Calculate the mean and variance of based on the sufficient statistics. Args: - counts: A `Tensor` containing a the total count of the data (one value). + counts: A `Tensor` containing the total count of the data (one value). mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly shifted) sum of the elements to average over. variance_ss: A `Tensor` containing the variance sufficient statistics: the @@ -689,6 +689,9 @@ def moments( # Compute true mean while keeping the dims for proper broadcasting. mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean") # sample variance, not unbiased variance + # Note: stop_gradient does not change the gradient that gets + # backpropagated to the mean from the variance calculation, + # because that gradient is zero variance = math_ops.reduce_mean( math_ops.squared_difference(y, array_ops.stop_gradient(mean)), axes, diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 54b08a564b1d7da3d4e30d593c18ee3811cff3d4..41d54a6c2f9d8cd961cea398da679fd81361b848 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -2009,7 +2009,8 @@ def sparse_softmax_cross_entropy_with_logits( exception when this op is run on CPU, and return `NaN` for corresponding loss and gradient rows on GPU. logits: Unscaled log probabilities of shape - `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`. + `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or + `float64`. name: A name for the operation (optional). Returns: @@ -2166,7 +2167,7 @@ def _calc_conv_flops(graph, node): filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) filter_in_depth = int(filter_shape[2]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats( "flops", (output_count * filter_in_depth * filter_height * filter_width * 2)) @@ -2184,7 +2185,7 @@ def _calc_depthwise_conv_flops(graph, node): output_shape.assert_is_fully_defined() filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats("flops", (output_count * filter_height * filter_width * 2)) @@ -2311,13 +2312,22 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1: raise ValueError("keep_prob must be a scalar tensor or a float in the " "range (0, 1], got %g" % keep_prob) - keep_prob = ops.convert_to_tensor( - keep_prob, dtype=x.dtype, name="keep_prob") - keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) - # Do nothing if we know keep_prob == 1 - if tensor_util.constant_value(keep_prob) == 1: + # Early return if nothing needs to be dropped. + if isinstance(keep_prob, float) and keep_prob == 1: return x + if context.executing_eagerly(): + if isinstance(keep_prob, ops.EagerTensor): + if keep_prob.numpy() == 1: + return x + else: + keep_prob = ops.convert_to_tensor( + keep_prob, dtype=x.dtype, name="keep_prob") + keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) + + # Do nothing if we know keep_prob == 1 + if tensor_util.constant_value(keep_prob) == 1: + return x noise_shape = _get_noise_shape(x, noise_shape) @@ -2585,7 +2595,7 @@ def _calc_dilation2d_flops(graph, node): output_shape.assert_is_fully_defined() filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats("flops", (output_count * filter_height * filter_width * 2)) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 035b4735affbd37f9de94057eed6f7b5d9aadd6e..ae24ca0552e7ba2823ec9404ecc848f510cce464 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -76,7 +76,7 @@ class SoftmaxTest(test_lib.TestCase): z = u.sum(1)[:, np.newaxis] return u / z - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSoftmax(self): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) @@ -123,7 +123,7 @@ class LogPoissonLossTest(test_lib.TestCase): lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.) return lpl - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogPoissonLoss(self): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) @@ -164,7 +164,7 @@ class LogSoftmaxTest(test_lib.TestCase): u = x - m return u - np.log(np.sum(np.exp(u), 1, keepdims=True)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogSoftmax(self): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) @@ -201,7 +201,7 @@ class LogSoftmaxTest(test_lib.TestCase): class L2LossTest(test_lib.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testL2Loss(self): for dtype in [dtypes.float32, dtypes.float64]: x = constant_op.constant( @@ -235,7 +235,7 @@ class L2NormalizeTest(test_lib.TestCase): norm = np.apply_along_axis(np.linalg.norm, dim, x) return x / np.expand_dims(norm, dim) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testL2Normalize(self): x_shape = [20, 7, 3] np.random.seed(1) @@ -246,7 +246,7 @@ class L2NormalizeTest(test_lib.TestCase): y_tf = nn_impl.l2_normalize(x_tf, dim) self.assertAllClose(y_np, self.evaluate(y_tf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testL2NormalizeDimArray(self): x_shape = [20, 7, 3] np.random.seed(1) diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..065c2caedc9d334543512941f3513e45360b460f --- /dev/null +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -0,0 +1,129 @@ +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "parallel_for", + srcs = [ + "__init__.py", + "control_flow_ops.py", + "gradients.py", + "pfor.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":control_flow_ops", + ":gradients", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:functional_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python:util", + "@absl_py//absl/flags", + ], +) + +py_library( + name = "pfor_lib", + srcs = ["pfor.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:functional_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "@absl_py//absl/flags", + ], +) + +py_library( + name = "control_flow_ops", + srcs = ["control_flow_ops.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":pfor_lib", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:util", + ], +) + +cuda_py_test( + name = "control_flow_ops_test", + srcs = ["control_flow_ops_test.py"], + additional_deps = [ + ":control_flow_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:gradients", + "//tensorflow/python:logging_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:tensor_array_grad", + "//tensorflow/python:random_ops", + "//tensorflow/python:util", + ], +) + +py_library( + name = "gradients", + srcs = ["gradients.py"], + srcs_version = "PY2AND3", + deps = [ + ":control_flow_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:util", + ], +) + +cuda_py_test( + name = "gradients_test", + size = "large", + srcs = ["gradients_test.py"], + additional_deps = [ + ":control_flow_ops", + ":gradients", + "//third_party/py/numpy", + "//tensorflow/python:layers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:random_ops", + "//tensorflow/python/ops/losses", + ], + tags = ["no_gpu"], # TODO(b/80127739): test is flaky +) diff --git a/tensorflow/python/ops/parallel_for/__init__.py b/tensorflow/python/ops/parallel_for/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b49d865968b0bab02380cb934431f4933590570e --- /dev/null +++ b/tensorflow/python/ops/parallel_for/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for pfor, for_loop, jacobian.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops.parallel_for import * # pylint: disable=wildcard-import +from tensorflow.python.ops.parallel_for.control_flow_ops import for_loop +from tensorflow.python.ops.parallel_for.control_flow_ops import pfor +from tensorflow.python.ops.parallel_for.gradients import batch_jacobian +from tensorflow.python.ops.parallel_for.gradients import jacobian +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'pfor', + 'for_loop', + 'jacobian', + 'batch_jacobian', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ccf2eb82146969532c84b7d56d40974e94337507 --- /dev/null +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""for_loop and pfor ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops.parallel_for.pfor import PFor +from tensorflow.python.util import nest + + +def for_loop(loop_fn, loop_fn_dtypes, iters): + """Runs `loop_fn` `iters` times and stacks the outputs. + + + Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and + stacks corresponding outputs of the different runs. + + Args: + loop_fn: A function that takes an int32 scalar tf.Tensor object representing + the iteration number, and returns a possibly nested structure of tensor + objects. The shape of these outputs should not depend on the input. + loop_fn_dtypes: dtypes for the outputs of loop_fn. + iters: Number of iterations for which to run loop_fn. + + Returns: + Returns a nested structure of stacked output tensor objects with the same + nested structure as the output of `loop_fn`. + """ + + flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes) + + def while_body(i, *ta_list): + """Body of while loop.""" + fn_output = nest.flatten(loop_fn(i)) + if len(fn_output) != len(flat_loop_fn_dtypes): + raise ValueError( + "Number of expected outputs, %d, does not match the number of " + "actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes), + len(fn_output))) + outputs = [] + for out, ta in zip(fn_output, ta_list): + # TODO(agarwal): support returning Operation objects from loop_fn. + assert isinstance(out, ops.Tensor) + outputs.append(ta.write(i, array_ops.expand_dims(out, 0))) + return tuple([i + 1] + outputs) + + ta_list = control_flow_ops.while_loop( + lambda i, *ta: i < iters, while_body, [0] + [ + tensor_array_ops.TensorArray(dtype, iters) + for dtype in flat_loop_fn_dtypes + ])[1:] + + # TODO(rachelim): enable this for sparse tensors + return nest.pack_sequence_as(loop_fn_dtypes, [ta.concat() for ta in ta_list]) + + +def pfor(loop_fn, iters): + """Equivalent to running `loop_fn` `iters` times and stacking the outputs. + + `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters` + times, with input from 0 to `iters - 1`, and stacking corresponding output of + each iteration. However the implementation does not use a tf.while_loop. + Instead it adds new operations to the graph that collectively compute the same + value as what running `loop_fn` in a loop would compute. + + + This is an experimental feature and currently has a lot of limitations: + - There should be no data depenendency between the different iterations. For + example, a future iteration should not depend on a value or side-effect of + a previous iteration. + - Stateful kernels may mostly not be supported since these often imply a + data dependency or ordering of the iterations. We do support a limited set + of such stateful kernels though (like RandomFoo, Variable operations like + reads, etc). + - Conversion works only on a limited set of kernels for which a converter + has been registered. + - loop_fn cannot currently contain control flow operations like + tf.while_loop or tf.cond. + - `loop_fn` should return nested structure of Tensors or Operations. However + if an Operation is returned, it should have zero outputs. + - The shape and dtype of `loop_fn` outputs should not depend on the input + to loop_fn. + + Args: + loop_fn: A function that takes an int32 scalar tf.Tensor object representing + the iteration number, and returns a possibly nested structure of Tensor or + Operation objects. + iters: Number of iterations for which to run loop_fn. + + Returns: + Returns a nested structure of stacked tensor objects with the same nested + structure as the output of `loop_fn`. + """ + existing_ops = set(ops.get_default_graph().get_operations()) + with ops.name_scope("loop_body"): + loop_var = array_ops.placeholder(dtypes.int32, shape=[]) + loop_fn_outputs = loop_fn(loop_var) + new_ops = set(ops.get_default_graph().get_operations()) - existing_ops + iters = ops.convert_to_tensor(iters) + with ops.name_scope("pfor"): + converter = PFor(loop_var, iters, new_ops) + outputs = [] + for loop_fn_output in nest.flatten(loop_fn_outputs): + outputs.append(converter.convert(loop_fn_output)) + return nest.pack_sequence_as(loop_fn_outputs, outputs) diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e66cb0b874b183d53cc34dbb3aa3d182e255a4 --- /dev/null +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -0,0 +1,1404 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for pfor and for_loop.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from absl import flags +import numpy as np + +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gradients as gradient_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops +from tensorflow.python.platform import test +from tensorflow.python.util import nest + + +class PForTest(test.TestCase): + + def _run_targets(self, targets1, targets2=None, run_init=True): + targets1 = nest.flatten(targets1) + targets2 = ([] if targets2 is None else nest.flatten(targets2)) + assert len(targets1) == len(targets2) or not targets2 + if run_init: + init = variables.global_variables_initializer() + self.evaluate(init) + return self.evaluate(targets1 + targets2) + + def run_and_assert_equal(self, targets1, targets2): + outputs = self._run_targets(targets1, targets2) + outputs = nest.flatten(outputs) # flatten SparseTensorValues + n = len(outputs) // 2 + for i in range(n): + if outputs[i + n].dtype != np.object: + self.assertAllClose(outputs[i + n], outputs[i], rtol=1e-4, atol=1e-5) + else: + self.assertAllEqual(outputs[i + n], outputs[i]) + + def _test_loop_fn(self, loop_fn, iters, loop_fn_dtypes=dtypes.float32): + t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters) + t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters) + self.run_and_assert_equal(t1, t2) + + def test_op_conversion_fallback_to_while_loop(self): + # Note that we used top_k op for this test. If a converter gets defined for + # it, we will need to find another op for which a converter has not been + # defined. + x = random_ops.random_uniform([3, 2, 4]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return nn.top_k(x_i) + + with self.assertRaisesRegexp(ValueError, "No converter defined"): + self._test_loop_fn( + loop_fn, 3, loop_fn_dtypes=[dtypes.float32, dtypes.int32]) + flags.FLAGS.op_conversion_fallback_to_while_loop = True + self._test_loop_fn( + loop_fn, 3, loop_fn_dtypes=[dtypes.float32, dtypes.int32]) + flags.FLAGS.op_conversion_fallback_to_while_loop = False + + +class ArrayTest(PForTest): + + def test_gather(self): + x = random_ops.random_uniform([3, 3, 3]) + + def loop_fn(i): + outputs = [] + x_i = array_ops.gather(x, i) + for y in [x, x_i]: + axes = [0, 2, -1] if y == x else [0] + for axis in axes: + outputs.append(array_ops.gather(y, 2, axis=axis)) + outputs.append(array_ops.gather(y, i, axis=axis)) + outputs.append(array_ops.gather(y, [i], axis=axis)) + outputs.append(array_ops.gather(y, [i, 2], axis=axis)) + outputs.append(array_ops.gather(y, [[2, i], [i, 1]], axis=axis)) + return outputs + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 20) + + def test_shape(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return array_ops.shape(x_i), array_ops.shape(x_i, out_type=dtypes.int64) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32, dtypes.int64]) + + def test_size(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return array_ops.size(x_i), array_ops.size(x_i, out_type=dtypes.int64) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32, dtypes.int64]) + + def test_rank(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return array_ops.rank(x_i) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_shape_n(self): + x = random_ops.random_uniform([3, 2, 3]) + y = random_ops.random_uniform([3]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + y_i = array_ops.gather(y, i) + return array_ops.shape_n([x_i, x, y, y_i]), array_ops.shape_n( + [x_i, x, y, y_i], out_type=dtypes.int64) + + self._test_loop_fn( + loop_fn, 3, loop_fn_dtypes=[dtypes.int32] * 4 + [dtypes.int64] * 4) + + def test_reshape(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.reshape(x1, [-1]), array_ops.reshape(x1, [1, 3, 1, -1]) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_expand_dims(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.expand_dims( + x1, axis=-1), array_ops.expand_dims( + x1, axis=1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_slice(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.slice(x1, begin=(0, 1), size=(2, 1)) + + self._test_loop_fn(loop_fn, 3) + + def test_tile(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.tile(x1, [2, 1]) + + self._test_loop_fn(loop_fn, 3) + + def test_tile_loop_dependent(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.tile(x1, [i, 1]) + + with self.assertRaisesRegexp(ValueError, "expected to be loop invariant"): + pfor_control_flow_ops.pfor(loop_fn, 2) + + def test_pack(self): + x = random_ops.random_uniform([3, 2, 3]) + y = random_ops.random_uniform([2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.stack([x1, y], axis=-1) + + self._test_loop_fn(loop_fn, 1) + + def test_unpack(self): + x = random_ops.random_uniform([3, 2, 3, 4]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return array_ops.unstack( + x_i, 4, axis=-1), array_ops.unstack( + x_i, 3, axis=1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 7) + + def test_pad(self): + x = random_ops.random_uniform([3, 2, 3]) + padding = constant_op.constant([[1, 2], [3, 4]]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.pad(x1, padding, mode="CONSTANT") + + self._test_loop_fn(loop_fn, 3) + + def test_split(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.split(x1, 2, axis=0), array_ops.split(x1, 3, axis=-1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 5) + + def test_transpose(self): + x = random_ops.random_uniform([3, 2, 3, 4]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.transpose(x1, [2, 1, 0]) + + self._test_loop_fn(loop_fn, 3) + + def test_zeros_like(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + z = array_ops.zeros_like(x1), + return z, z + x1 + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_concat_v2(self): + x = random_ops.random_uniform([3, 2, 3]) + y = random_ops.random_uniform([2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.concat( + [x1, x1, y], axis=0), array_ops.concat( + [x1, x1, y], axis=-1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_unary_cwise_ops(self): + for op in [array_ops.identity, array_ops.stop_gradient]: + x = random_ops.random_uniform([3, 5]) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + x1 = array_ops.gather(x, i) + y = op(x1) + x1 + loss = nn.l2_loss(y) + return op(x), y, gradient_ops.gradients(loss, x1) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3) + + def test_strided_slice(self): + x = random_ops.random_uniform([3, 3, 4, 4, 2, 2, 2]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + y = x_i[:2, ::2, 1::3, ..., array_ops.newaxis, 1] + loss = nn.l2_loss(y) + return y, gradient_ops.gradients(loss, x_i) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + +class MathTest(PForTest): + + def test_unary_cwise_ops(self): + for op in [ + math_ops.tanh, nn.relu, math_ops.sigmoid, math_ops.negative, + math_ops.square + ]: + x = random_ops.random_uniform([3, 5]) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + x1 = array_ops.gather(x, i) + y = op(x1) + loss = math_ops.reduce_sum(y * y) + return op(x), y, gradient_ops.gradients(loss, x1) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3) + + def test_unary_cwise_no_grad(self): + for op in [math_ops.ceil, math_ops.floor, math_ops.logical_not]: + x = random_ops.random_uniform([3, 5]) + if op == math_ops.logical_not: + x = x > 0 + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + return op(array_ops.gather(x, i)) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=x.dtype) + + def test_binary_cwise_ops(self): + logical_ops = [ + math_ops.logical_and, math_ops.logical_or, math_ops.logical_xor + ] + bool_ops = [ + math_ops.less, math_ops.less_equal, math_ops.greater, + math_ops.greater_equal, math_ops.equal, math_ops.not_equal + ] + float_ops = [ + math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.divide, + math_ops.maximum, math_ops.minimum + ] + for op in logical_ops + bool_ops + float_ops: + x = random_ops.random_uniform([7, 3, 5]) + y = random_ops.random_uniform([3, 5]) + if op in logical_ops: + x = x > 0 + y = y > 0 + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + x1 = array_ops.gather(x, i) + y1 = array_ops.gather(y, i) + return op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1) + + # pylint: enable=cell-var-from-loop + + dtype = dtypes.float32 if op in float_ops else dtypes.bool + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtype] * 5) + + def test_addn(self): + x = random_ops.random_uniform([2, 3, 5]) + y = random_ops.random_uniform([3, 5]) + z = random_ops.random_uniform([3, 5]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return math_ops.add_n([x1, y, z]) + + self._test_loop_fn(loop_fn, 2) + + def test_matmul(self): + for tr_a in (True, False): + for tr_b in (True, False): + for stack_a in (True, False): + for stack_b in (True, False): + shape_a = (5, 3) if tr_a else (3, 5) + if stack_a: + shape_a = (2,) + shape_a + shape_b = (7, 5) if tr_b else (5, 7) + if stack_b: + shape_b = (2,) + shape_b + + x = random_ops.random_uniform(shape_a) + y = random_ops.random_uniform(shape_b) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) if stack_a else x + b = array_ops.gather(y, i) if stack_b else y + return math_ops.matmul(a, b, transpose_a=tr_a, transpose_b=tr_b) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + def test_batch_matmul(self): + for tr_a in (True, False): + for tr_b in (True, False): + for stack_a in (True, False): + for stack_b in (True, False): + shape_a = (4, 5, 3) if tr_a else (4, 3, 5) + if stack_a: + shape_a = (2,) + shape_a + shape_b = (4, 7, 5) if tr_b else (4, 5, 7) + if stack_b: + shape_b = (2,) + shape_b + + x = random_ops.random_uniform(shape_a) + y = random_ops.random_uniform(shape_b) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) if stack_a else x + b = array_ops.gather(y, i) if stack_b else y + return math_ops.matmul(a, b, transpose_a=tr_a, transpose_b=tr_b) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + def test_reduction(self): + x = random_ops.random_uniform([2, 3, 4, 5]) + for op in [ + math_ops.reduce_sum, math_ops.reduce_prod, math_ops.reduce_max, + math_ops.reduce_min + ]: + for axis in ([1], None, [0, 2]): + for keepdims in (True, False): + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) + return op(a, axis=axis, keepdims=keepdims) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + def test_cum_sum(self): + x = random_ops.random_uniform([2, 3, 4, 5]) + for axis in (1, -2): + for exclusive in (True, False): + for reverse in (True, False): + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) + return math_ops.cumsum( + a, axis=axis, exclusive=exclusive, reverse=reverse) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + def test_cum_prod(self): + x = random_ops.random_uniform([2, 3, 4, 5]) + for axis in (1, -2): + for exclusive in (True, False): + for reverse in (True, False): + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) + return math_ops.cumprod( + a, axis=axis, exclusive=exclusive, reverse=reverse) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + def test_bias_add(self): + x_shape = [2, 3, 4, 5, 6] + x = random_ops.random_uniform(x_shape) + for data_format in ("NCHW", "NHWC"): + bias_dim = 2 if data_format == "NCHW" else -1 + bias_shape = x_shape[bias_dim] + bias = random_ops.random_uniform([bias_shape]) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) + y = nn.bias_add(a, bias, data_format=data_format) + loss = math_ops.reduce_sum(y * y) + return y, gradient_ops.gradients(loss, bias) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn( + loop_fn, 2, loop_fn_dtypes=[dtypes.float32, dtypes.float32]) + + def test_unsorted_segment_sum(self): + t = random_ops.random_uniform([3, 3, 2]) + segment_ids = constant_op.constant([[0, 0, 2], [0, 1, 2], [2, 2, 2]]) + num_segments = 3 + + def loop_fn(i): + data = array_ops.gather(t, i) + data_0 = array_ops.gather(t, 0) + seg_ids = array_ops.gather(segment_ids, i) + return (math_ops.unsorted_segment_sum(data, seg_ids, num_segments), + math_ops.unsorted_segment_sum(data_0, seg_ids, num_segments)) + + self._test_loop_fn(loop_fn, 3, [dtypes.float32] * 2) + + def test_cast(self): + x = constant_op.constant([[1], [2]]) + y = constant_op.constant([[1.0], [2.0]]) + + def loop_fn(i): + return (math_ops.cast(array_ops.gather(x, i), dtypes.float32), + math_ops.cast(array_ops.gather(y, i), dtypes.int32)) + + self._test_loop_fn( + loop_fn, 2, loop_fn_dtypes=[dtypes.float32, dtypes.int32]) + + def test_tanh_axpy(self): + a = constant_op.constant(3.) + x = random_ops.random_uniform([4, 5]) + y = random_ops.random_uniform([6, 5]) + n = x.shape[0] + + def loop_fn(i): + return math_ops.tanh(a * array_ops.gather(x, i) + array_ops.gather(y, i)) + + self._test_loop_fn(loop_fn, n) + + def test_select(self): + cond = constant_op.constant([True, False]) + a = random_ops.random_uniform([2, 3, 5]) + b = random_ops.random_uniform([2, 3, 5]) + for cond_shape in [2], [2, 3], [2, 3, 5]: + cond = random_ops.random_uniform(cond_shape) > 0.5 + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a_i = array_ops.gather(a, i) + b_i = array_ops.gather(b, i) + cond_i = array_ops.gather(cond, i) + return array_ops.where(cond_i, a_i, b_i) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + +class NNTest(PForTest): + + def test_conv2d(self): + x = random_ops.random_uniform([3, 2, 12, 12, 3]) + filt = random_ops.random_uniform([3, 3, 3, 7]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return nn.conv2d( + x1, filt, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC") + + self._test_loop_fn(loop_fn, 3) + + def test_conv2d_backprop_input(self): + x_shape = [2, 12, 12, 3] + filt = random_ops.random_uniform([3, 3, 3, 7]) + grad = random_ops.random_uniform([3, 2, 5, 5, 7]) + + def loop_fn(i): + grad1 = array_ops.gather(grad, i) + return nn.conv2d_backprop_input( + x_shape, + filt, + grad1, + strides=[1, 2, 2, 1], + padding="VALID", + data_format="NHWC") + + self._test_loop_fn(loop_fn, 3) + + def test_conv2d_backprop_filter(self): + x = random_ops.random_uniform([3, 2, 12, 12, 3]) + x_0 = array_ops.gather(x, 0) + filter_sizes = [3, 3, 3, 7] + grad = random_ops.random_uniform([3, 2, 5, 5, 7]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + grad_i = array_ops.gather(grad, i) + return [ + nn.conv2d_backprop_filter( + inp, + filter_sizes, + grad_i, + strides=[1, 2, 2, 1], + padding="VALID", + data_format="NHWC") for inp in [x_i, x_0] + ] + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_avg_pool(self): + x = random_ops.random_uniform([3, 2, 12, 12, 3]) + ksize = [1, 3, 3, 1] + + def loop_fn(i): + x1 = array_ops.gather(x, i) + output = nn.avg_pool( + x1, ksize, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC") + loss = nn.l2_loss(output) + return output, gradient_ops.gradients(loss, x1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_max_pool(self): + x = random_ops.random_uniform([3, 2, 12, 12, 3]) + ksize = [1, 3, 3, 1] + + def loop_fn(i): + x1 = array_ops.gather(x, i) + output = nn.max_pool( + x1, ksize, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC") + loss = nn.l2_loss(output) + return output, gradient_ops.gradients(loss, x1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_fused_batch_norm(self): + data_formats = ["NHWC"] + if test.is_gpu_available(): + data_formats.append("NCHW") + for is_training in (True, False): + for data_format in data_formats: + if data_format == "NCHW": + x = random_ops.random_uniform([3, 1, 2, 5, 5]) + else: + x = random_ops.random_uniform([3, 1, 5, 5, 2]) + scale = random_ops.random_uniform([2]) + offset = random_ops.random_uniform([2]) + mean = None if is_training else random_ops.random_uniform([2]) + variance = None if is_training else random_ops.random_uniform([2]) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + x1 = array_ops.gather(x, i) + outputs = nn.fused_batch_norm( + x1, + scale, + offset, + mean=mean, + variance=variance, + epsilon=0.01, + data_format=data_format, + is_training=is_training) + outputs = list(outputs) + # We only test the first value of outputs when is_training is False. + # It looks like CPU and GPU have different outputs for batch_mean and + # batch_variance for this case. + if not is_training: + outputs[1] = constant_op.constant(0.) + outputs[2] = constant_op.constant(0.) + loss = nn.l2_loss(outputs[0]) + gradients = gradient_ops.gradients(loss, [x1, scale, offset]) + return outputs + gradients + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 6) + + def test_softmax_cross_entropy_with_logits(self): + logits = random_ops.random_uniform([3, 2, 4]) + labels = random_ops.random_uniform([3, 2, 4]) + labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True) + + def loop_fn(i): + logits_i = array_ops.gather(logits, i) + labels_i = array_ops.gather(labels, i) + loss = nn.softmax_cross_entropy_with_logits( + labels=labels_i, logits=logits_i) + return loss, gradient_ops.gradients(math_ops.reduce_sum(loss), logits_i) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + +class RandomTest(PForTest): + + # The random values generated in the two implementations are not guaranteed to + # match. So we only check the returned shapes. + def run_and_assert_equal(self, targets1, targets2): + outputs = self._run_targets(targets1, targets2) + n = len(outputs) // 2 + for i in range(n): + self.assertAllEqual(outputs[i].shape, outputs[i + n].shape) + + def test_random_uniform(self): + + def loop_fn(_): + return random_ops.random_uniform([3]) + + self._test_loop_fn(loop_fn, 5) + + def test_random_uniform_int(self): + + def loop_fn(_): + return random_ops.random_uniform([3], maxval=1, dtype=dtypes.int32) + + self._test_loop_fn(loop_fn, 5, loop_fn_dtypes=dtypes.int32) + + def test_random_standard_normal(self): + + def loop_fn(_): + return random_ops.random_normal([3]) + + self._test_loop_fn(loop_fn, 5) + + def test_truncated_normal(self): + + def loop_fn(_): + return random_ops.truncated_normal([3]) + + self._test_loop_fn(loop_fn, 5) + + def test_random_gamma(self): + + def loop_fn(_): + return random_ops.random_gamma([3], alpha=[0.5]) + + self._test_loop_fn(loop_fn, 5) + + def test_random_poisson_v2(self): + + def loop_fn(_): + return random_ops.random_poisson(lam=[1.3], shape=[3]) + + self._test_loop_fn(loop_fn, 5) + + +class LoggingTest(PForTest): + + def test_print(self): + x = random_ops.random_uniform([3, 5]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return logging_ops.Print( + x1, [x1, "x1", array_ops.shape(x1)], summarize=10) + + self._test_loop_fn(loop_fn, 3) + + def test_assert(self): + + def loop_fn(i): + return control_flow_ops.Assert(i < 10, [i, [10], [i + 1]]) + + # TODO(agarwal): make this work with for_loop. + with session.Session() as sess: + sess.run(pfor_control_flow_ops.pfor(loop_fn, 3)) + + +class TensorArrayTest(PForTest): + + def test_create_outside_and_read(self): + + ta = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1) + + def loop_fn(i): + return ta.read(i), ta.read(0) + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2) + + def test_create_outside_and_gather(self): + + ta = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1) + + def loop_fn(i): + return ta.gather([i]), ta.gather([0, 1]) + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2) + + def test_create_outside_and_write_and_scatter(self): + + t = tensor_array_ops.TensorArray(dtypes.int32, 10, clear_after_read=False) + handle = t.handle + + def loop_fn(i): + ta = t.write(i + 2, 2 * i).write(i, 5) + ta = ta.scatter([4 + i], [4]).scatter([6 + i, 8 + i], [6 + i, 8 + i]) + return ta.flow + + t1 = pfor_control_flow_ops.pfor(loop_fn, iters=2) + out1 = tensor_array_ops.TensorArray( + dtypes.int32, handle=handle, flow=t1[-1]).stack() + output1 = self._run_targets(out1) + + t2 = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, iters=2) + out2 = tensor_array_ops.TensorArray( + dtypes.int32, handle=handle, flow=t2[-1]).stack() + output2 = self._run_targets(out2) + self.assertAllClose(output2, output1) + + def test_create_inside_and_write(self): + + def loop_fn(i): + # TODO(agarwal): switching the order of writes to ta1 does not work. + ta1 = tensor_array_ops.TensorArray(dtypes.int32, 2).write(0, i).write( + 1, 1) + ta2 = tensor_array_ops.TensorArray(dtypes.int32, 1).write(0, 1) + return ta1.stack(), ta2.stack() + + self._test_loop_fn(loop_fn, 3, [dtypes.int32] * 2) + + def test_create_inside_and_scatter(self): + + def loop_fn(i): + # TODO(agarwal): switching the order of scatter to ta1 does not work. + ta1 = tensor_array_ops.TensorArray(dtypes.int32, 2).scatter( + [0], [[i, 2]]).scatter([1], [[1, 2]]) + ta2 = tensor_array_ops.TensorArray(dtypes.int32, + 2).scatter([0], [3]).scatter([1], [4]) + return ta1.stack(), ta2.stack() + + self._test_loop_fn(loop_fn, 3, [dtypes.int32] * 2) + + def test_create_inside_and_read(self): + + def loop_fn(i): + ta1 = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1) + ta2 = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2) + # TODO(agarwal): ta1.read(i) currently is not supported. + return ta1.read(0), ta2.read(0), ta2.read(i) + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 3) + + def test_create_inside_and_gather(self): + + def loop_fn(i): + ta1 = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1) + ta2 = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2) + # TODO(agarwal): ta1.read(i) currently is not supported. + return ta1.gather([0, 1]), ta2.gather([0, 1]), ta2.gather([i]) + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 3) + + def test_grad(self): + x = random_ops.random_uniform([3, 2]) + ta = tensor_array_ops.TensorArray( + dtypes.float32, 3, clear_after_read=False).unstack(x) + y = math_ops.square(ta.stack()) + + def loop_fn(i): + y_i = array_ops.gather(y, i) + grad = gradient_ops.gradients(y_i, x)[0] + return array_ops.gather(grad, i) + + t1 = pfor_control_flow_ops.pfor(loop_fn, iters=3) + # y = x * x. Hence dy/dx = 2 * x. + actual_grad = 2.0 * x + with session.Session() as sess: + actual_grad, computed_grad = sess.run([t1, actual_grad]) + self.assertAllClose(actual_grad, computed_grad) + + +class StackTest(PForTest): + + def test_stack_inside_loop_invariant(self): + + def loop_fn(_): + s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) + op1 = data_flow_ops.stack_push_v2(s, 1) + with ops.control_dependencies([op1]): + op2 = data_flow_ops.stack_push_v2(s, 2) + with ops.control_dependencies([op2]): + e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + with ops.control_dependencies([e2]): + e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + return e1, e2 + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2) + + def test_stack_inside_push_loop_dependent(self): + + def loop_fn(i): + s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) + op1 = data_flow_ops.stack_push_v2(s, i) + with ops.control_dependencies([op1]): + op2 = data_flow_ops.stack_push_v2(s, 2) + with ops.control_dependencies([op2]): + e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + with ops.control_dependencies([e2]): + e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + return e1, e2 + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2) + + def test_stack_outside_pop(self): + s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) + op = data_flow_ops.stack_push_v2(s, 5) + with ops.control_dependencies([op]): + op = data_flow_ops.stack_push_v2(s, 6) + with ops.control_dependencies([op]): + op = data_flow_ops.stack_push_v2(s, 7) + + def loop_fn(_): + e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + with ops.control_dependencies([e1]): + e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + return e1, e2 + + with ops.control_dependencies([op]): + e1, e2 = pfor_control_flow_ops.pfor(loop_fn, iters=2) + with ops.control_dependencies([e1, e2]): + e3 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + v1, v2, v3 = self._run_targets([e1, e2, e3], run_init=False) + self.assertAllEqual([7, 7], v1) + self.assertAllEqual([6, 6], v2) + self.assertAllEqual(5, v3) + + def test_stack_outside_push(self): + s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) + + def loop_fn(_): + return data_flow_ops.stack_push_v2(s, 7) + + with self.assertRaisesRegexp(ValueError, "StackPushV2 not allowed.*"): + pfor_control_flow_ops.pfor(loop_fn, iters=2) + + +# TODO(agarwal): test nested while_loops. This currently requires converting a +# tf.cond. +class ControlFlowTest(PForTest): + + def test_while_outside_loop(self): + + x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0]) + + def loop_fn(i): + return x + i + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_invariant_while(self): + + def loop_fn(_): + return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0]) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_invariant_while_with_control_dependency(self): + + def loop_fn(i): + with ops.control_dependencies([i]): + return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, + [0]) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_while_with_stateful_ops(self): + + def loop_fn(_): + return control_flow_ops.while_loop( + lambda j, x: j < 4, + lambda j, x: (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0] + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_while_unstacked_condition(self): + + def loop_fn(i): + return control_flow_ops.while_loop(lambda j, x: j < 4, + lambda j, x: (j + 1, x + i), [0, 0]) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32, dtypes.int32]) + + def test_while(self): + x = random_ops.random_uniform([3, 5]) + lengths = constant_op.constant([4, 0, 2]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + lengths_i = array_ops.gather(lengths, i) + + _, total = control_flow_ops.while_loop( + lambda j, _: j < lengths_i, + lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.]) + return total + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32]) + + def test_while_jacobian(self): + x = random_ops.random_uniform([1, 3]) + y = random_ops.random_uniform([3, 3]) + + # out = x @ y @ y @ y @ y, where @ is matmul operator. + _, out = control_flow_ops.while_loop( + lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)), + [0, x]) + + def loop_fn(i): + out_i = array_ops.gather(out, i, axis=1) + return array_ops.reshape(gradient_ops.gradients(out_i, x)[0], [-1]) + + out = pfor_control_flow_ops.pfor(loop_fn, iters=3) + + # The above code does not work with tf.while_loop instead of pfor. So we + # manually compute the expected output here. + # Note that gradient of output w.r.t is (y @ y @ y @ y)^T. + expected_output = y + for _ in range(3): + expected_output = math_ops.matmul(expected_output, y) + expected_output = array_ops.transpose(expected_output, [1, 0]) + + with session.Session() as sess: + out, expected = sess.run([out, expected_output]) + self.assertAllClose(expected, out) + + def test_tensor_array_as_loop_variable(self): + + def loop_fn(i): + + def body(j, ta): + ta = ta.write(j, i + j * j) + return j + 1, ta + + _, ta = control_flow_ops.while_loop( + lambda j, _: j < 4, body, + (0, tensor_array_ops.TensorArray(dtypes.int32, size=4))) + return ta.stack() + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_read_tensor_array_partitioned_indices(self): + # Note that tensor array values are pfor loop dependent, and the while loop + # termination condition is also dependent on pfor iteration. + def loop_fn(i): + ta = tensor_array_ops.TensorArray(dtypes.int32, size=6) + ta = ta.unstack(i + list(range(5))) + + def body(j, s): + return j + 1, s + ta.read(j) + + _, s = control_flow_ops.while_loop(lambda j, _: j < i, + body, + (0, 0)) + return s + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_external_while_loop_grad(self): + # Here we test that external while_loops that are extended from inside pfor + # (due to gradient calls) are not actually converted. If the below was + # converted all pfor iterations would write to the same tensor array + # indices. + x = constant_op.constant(1.) + + def body(j, ta): + ta = ta.write(j, x) + return j + 1, ta + + _, ta = control_flow_ops.while_loop( + lambda j, _: j < 4, body, + (0, tensor_array_ops.TensorArray(dtypes.float32, size=4))) + out = ta.stack() + + def loop_fn(i): + out_i = array_ops.gather(out, i) + return gradient_ops.gradients(out_i, x)[0] + + with session.Session() as sess: + # out is [x, x, x]. Hence the gradients should be [1, 1, 1]. + self.assertAllEqual([1, 1, 1], + sess.run(pfor_control_flow_ops.pfor(loop_fn, 3))) + + def test_tensor_array_grad(self): + inp = constant_op.constant(np.random.rand(3, 4, 2), dtype=dtypes.float32) + ta = tensor_array_ops.TensorArray(dtypes.float32, size=3) + ta = ta.unstack(inp) + + def loop_fn(i): + + def body(j, x): + value = ta.gather([j]) + value = array_ops.gather(array_ops.reshape(value, [4, 2]), i) + return j + 1, x + value + + _, out = control_flow_ops.while_loop(lambda j, _: j < 3, body, + (0, array_ops.zeros([2]))) + out = math_ops.reduce_prod(out) + return out, gradient_ops.gradients(out, inp)[0] + + pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4) + # Note that tf.while_loop does not work in the setup above. So we manually + # construct the equivalent computation of the above loops here. + real_out = math_ops.reduce_sum(inp, reduction_indices=[0]) + real_out = math_ops.reduce_prod(real_out, reduction_indices=[1]) + # Note that gradients of real_out will accumulate the gradients across the + # output value. Hence we do the same aggregation on pfor_out_grad. + real_out_grad = gradient_ops.gradients(real_out, inp)[0] + sum_pfor_out_grad = math_ops.reduce_sum( + pfor_out_grad, reduction_indices=[0]) + + with session.Session() as sess: + v1, v2, v1_grad, v2_grad = sess.run( + [pfor_out, real_out, sum_pfor_out_grad, real_out_grad]) + self.assertAllClose(v1, v2) + self.assertAllClose(v1_grad, v2_grad) + + +def dynamic_lstm_input_fn(batch_size, state_size, max_steps): + # We make inputs and sequence_length constant so that multiple session.run + # calls produce the same result. + inputs = constant_op.constant( + np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32) + sequence_length = np.random.randint(0, size=[batch_size], high=max_steps + 1) + sequence_length = constant_op.constant(sequence_length, dtype=dtypes.int32) + return inputs, sequence_length + + +def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps): + cell = cell_fn(state_size) + inputs, sequence_length = dynamic_lstm_input_fn(batch_size, + state_size, + max_steps) + inputs_ta = tensor_array_ops.TensorArray( + dtypes.float32, size=max_steps, element_shape=[batch_size, state_size]) + inputs_time_major = array_ops.transpose(inputs, [1, 0, 2]) + inputs_ta = inputs_ta.unstack(inputs_time_major) + zeros = array_ops.zeros([state_size]) + + def loop_fn(i): + sequence_length_i = array_ops.gather(sequence_length, i) + + def body_fn(t, state, ta): + inputs_t = array_ops.expand_dims( + array_ops.gather(inputs_ta.read(t), i), 0) + output, new_state = cell(inputs_t, state) + output = array_ops.reshape(output, [-1]) + # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the + # array_ops.where when t < min(sequence_length). Doing that requires + # supporting tf.cond pfor conversion. + done = t >= sequence_length_i + output = array_ops.where(done, zeros, output) + ta = ta.write(t, output) + new_state = [array_ops.where(done, s, ns) for s, ns in + zip(nest.flatten(state), nest.flatten(new_state))] + new_state = nest.pack_sequence_as(state, new_state) + return t + 1, new_state, ta + + def condition_fn(t, _, unused): + del unused + return t < max_steps + + initial_state = cell.zero_state(1, dtypes.float32) + _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [ + 0, initial_state, + tensor_array_ops.TensorArray(dtypes.float32, max_steps) + ]) + + new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)] + new_state = nest.pack_sequence_as(initial_state, new_state) + return ta.stack(), new_state + + pfor_output = pfor_control_flow_ops.pfor(loop_fn, batch_size) + tf_output = rnn.dynamic_rnn( + cell, + inputs, + sequence_length=sequence_length, + initial_state=cell.zero_state(batch_size, dtypes.float32)) + return pfor_output, tf_output + + +class RNNTest(PForTest): + + def test_dynamic_rnn(self): + pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, + 3, 5, 7) + self.run_and_assert_equal(pfor_outputs, tf_outputs) + + def test_dynamic_lstm(self): + pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell, + 3, 5, 7) + self.run_and_assert_equal(pfor_outputs, tf_outputs) + + +# TODO(agarwal): benchmark numbers on GPU for graphs based on while_loop +# conversion don't look good. Some of it seems like lot of copies between host +# and device. Optimize that. +class Benchmarks(test.Benchmark): + + def _run(self, targets, iters, name=None): + + def _done(t): + # Note that we don't use tf.control_dependencies since that will not make + # sure that the computation on GPU has actually finished. So we fetch the + # first element of the output, and assume that this will not be called on + # empty tensors. + return array_ops.gather(array_ops.reshape(t, [-1]), 0) + + targets = [_done(x) for x in nest.flatten(targets)] + sess = session.Session() + with sess: + init = variables.global_variables_initializer() + sess.run(init) + sess.run(targets) + begin = time.time() + for _ in range(iters): + sess.run(targets) + end = time.time() + avg_time_ms = 1000 * (end - begin) / iters + self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name) + return avg_time_ms + + def benchmark_basic_while(self): + with ops.Graph().as_default(): + + def loop_fn(i): + _, s = control_flow_ops.while_loop( + lambda t, x: t < i, + lambda t, x: (t + 1, x + i), + [0, 0]) + return s + + iters = 50 + pfor_output = pfor_control_flow_ops.pfor(loop_fn, iters) + for_loop_output = pfor_control_flow_ops.for_loop(loop_fn, dtypes.int32, + iters) + self._run(pfor_output, 100, name="pfor_basic") + self._run(for_loop_output, 100, name="for_loop_basic") + + def benchmark_dynamic_rnn(self): + with ops.Graph().as_default(): + pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, + 128, 512, 16) + self._run(pfor_outputs, 100, name="pfor_rnn") + self._run(tf_outputs, 100, name="tf_rnn") + + def benchmark_dynamic_lstm(self): + with ops.Graph().as_default(): + pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell, + 128, 512, 16) + self._run(pfor_outputs, 100, name="pfor_lstm") + self._run(tf_outputs, 100, name="tf_lstm") + + +class SparseTest(PForTest): + + def test_var_loop_len(self): + num_iters = array_ops.placeholder(dtypes.int32) + + def loop_fn(_): + return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6], + [3]) # [0, 2, 0] + + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + with self.test_session() as sess: + sess.run(pfor, feed_dict={num_iters: 3}) + + def test_sparse_result_none_stacked(self): + num_iters = 10 + + def loop_fn(_): + return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6], + [3]) # [0, 2, 0] + + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + + indices = [[i, j] for i in range(num_iters) for j in range(3)] + values = [4, 5, 6] * num_iters + dense_shapes = [num_iters, 3] + # Expected result: [[4, 5, 6], [4, 5, 6], [4, 5, 6], ...] + manual = sparse_tensor.SparseTensor(indices, values, dense_shapes) + self.run_and_assert_equal(pfor, manual) + + def test_sparse_result_all_stacked(self): + num_iters = 10 + + def loop_fn(i): + i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0) + indices = array_ops.expand_dims(i, 0) + return sparse_tensor.SparseTensor(indices, i, i + 1) # [0, ..., 0, i] + + # Expected result: [[0], [0, 1], [0, 0, 2], [0, 0, 0, 3], ...] + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)], + list(range(num_iters)), + (num_iters, num_iters)) + self.run_and_assert_equal(pfor, manual) + + def test_sparse_result_indices_stacked(self): + num_iters = 10 + + def loop_fn(i): + i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0) + indices = array_ops.expand_dims(i, 0) + return sparse_tensor.SparseTensor(indices, [1], [num_iters]) + + # Expected result: identity matrix size num_iters * num_iters + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)], + [1] * num_iters, (num_iters, num_iters)) + self.run_and_assert_equal(pfor, manual) + + def test_sparse_result_values_stacked(self): + num_iters = 10 + + def loop_fn(i): + i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0) + return sparse_tensor.SparseTensor([[0]], i, [num_iters]) # [i, 0, ..., 0] + + # Expected result: [[1, 0, ...], [2, 0, ...], [3, 0, ...], ...] + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)], + list(range(num_iters)), + (num_iters, num_iters)) + self.run_and_assert_equal(pfor, manual) + + def test_sparse_result_shapes_stacked(self): + num_iters = 10 + + def loop_fn(i): + i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0) + return sparse_tensor.SparseTensor([[0]], [1], i + 1) # [1, 0, ..., 0] + + # Expected result: [[1, 0, 0, ...], [1, 0, 0, ...], ...] + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)], + [1] * num_iters, (num_iters, num_iters)) + self.run_and_assert_equal(pfor, manual) + + def test_sparse_result_shapes_stacked_2D(self): + num_iters = 10 + + def loop_fn(i): + i = array_ops.expand_dims(math_ops.cast(i + 1, dtypes.int64), 0) + shape = array_ops.concat([i, i], 0) + return sparse_tensor.SparseTensor([[0, 0]], [1], shape) # [1, 0, ..., 0] + + # Expected result: [[[1, 0, ...], [0, ..., 0], [0, ..., 0], ...], ...] + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + manual = sparse_tensor.SparseTensor([[i, 0, 0] for i in range(num_iters)], + [1] * num_iters, + (num_iters, num_iters, num_iters)) + self.run_and_assert_equal(pfor, manual) + + +class ParsingTest(PForTest): + + def test_decode_csv(self): + csv_tensor = constant_op.constant([["1:2:3"], ["::"], ["7:8:9"]]) + kwargs = {"record_defaults": [[10], [20], [30]], "field_delim": ":"} + + def loop_fn(i): + line = array_ops.gather(csv_tensor, i) + return parsing_ops.decode_csv(line, **kwargs) + + self._test_loop_fn(loop_fn, iters=3, loop_fn_dtypes=[dtypes.int32] * 3) + + def test_parse_single_example(self): + + def _int64_feature(*values): + return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values)) + + def _bytes_feature(*values): + return feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=[v.encode("utf-8") for v in values])) + + examples = constant_op.constant([ + example_pb2.Example( + features=feature_pb2.Features( + feature={ + "dense_int": _int64_feature(i), + "dense_str": _bytes_feature(str(i)), + "sparse_int": _int64_feature(i, i * 2, i * 4, i * 8), + "sparse_str": _bytes_feature(*["abc"] * i) + })).SerializeToString() for i in range(10) + ]) + + features = { + "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0), + "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""), + "sparse_int": parsing_ops.VarLenFeature(dtypes.int64), + "sparse_str": parsing_ops.VarLenFeature(dtypes.string), + } + + def loop_fn(i): + example_proto = array_ops.gather(examples, i) + f = parsing_ops.parse_single_example(example_proto, features) + return f + + pfor = pfor_control_flow_ops.pfor(loop_fn, iters=10) + manual = parsing_ops.parse_example(examples, features) + self.run_and_assert_equal(pfor, manual) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3d5c9b86ed186f76e113351646b3dda153e72b --- /dev/null +++ b/tensorflow/python/ops/parallel_for/gradients.py @@ -0,0 +1,126 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Jacobian ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import gradients as gradient_ops +from tensorflow.python.ops.parallel_for import control_flow_ops +from tensorflow.python.util import nest + + +def jacobian(output, inputs, use_pfor=True): + """Computes jacobian of `output` w.r.t. `inputs`. + + Args: + output: A tensor. + inputs: A tensor or a nested structure of tensor objects. + use_pfor: If true, uses pfor for computing the jacobian. Else uses + tf.while_loop. + + Returns: + A tensor or a nested strucutre of tensors with the same structure as + `inputs`. Each entry is the jacobian of `output` w.rt. to the corresponding + value in `inputs`. If output has shape [y_1, ..., y_n] and inputs_i has + shape [x_1, ..., x_m], the corresponding jacobian has shape + [y_1, ..., y_n, x_1, ..., x_m]. + """ + flat_inputs = nest.flatten(inputs) + output_shape = array_ops.shape(output) + output = array_ops.reshape(output, [-1]) + + def loop_fn(i): + y = array_ops.gather(output, i) + return gradient_ops.gradients(y, flat_inputs) + + try: + output_size = int(output.shape[0]) + except TypeError: + output_size = array_ops.shape(output)[0] + + if use_pfor: + pfor_outputs = control_flow_ops.pfor(loop_fn, output_size) + else: + pfor_outputs = control_flow_ops.for_loop( + loop_fn, [output.dtype] * len(flat_inputs), output_size) + + for i, out in enumerate(pfor_outputs): + new_shape = array_ops.concat( + [output_shape, array_ops.shape(out)[1:]], axis=0) + out = array_ops.reshape(out, new_shape) + pfor_outputs[i] = out + + return nest.pack_sequence_as(inputs, pfor_outputs) + + +def batch_jacobian(output, inp, use_pfor=True): + """Computes and stacks jacobians of `output[i,...]` w.r.t. `input[i,...]`. + + e.g. + x = tf.constant([[1, 2], [3, 4]], dtype=tf.float32) + y = x * x + jacobian = batch_jacobian(y, x) + # => [[[2, 0], [0, 4]], [[6, 0], [0, 8]]] + + Args: + output: A tensor with shape [b, y1, ..., y_n]. `output[i,...]` should + only depend on `inp[i,...]`. + inp: A tensor with shape [b, x1, ..., x_m] + use_pfor: If true, uses pfor for computing the Jacobian. Else uses a + tf.while_loop. + + Returns: + A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]` + is the jacobian of `output[i, ...]` w.r.t. `inp[i, ...]`, i.e. stacked + per-example jacobians. + + Raises: + ValueError: if first dimension of `output` and `inp` do not match. + """ + output_shape = output.shape + if not output_shape[0].is_compatible_with(inp.shape[0]): + raise ValueError("Need first dimension of output shape (%s) and inp shape " + "(%s) to match." % (output.shape, inp.shape)) + if output_shape.is_fully_defined(): + batch_size = int(output_shape[0]) + output_row_size = output_shape.num_elements() // batch_size + else: + output_shape = array_ops.shape(output) + batch_size = output_shape[0] + output_row_size = array_ops.size(output) // batch_size + inp_shape = array_ops.shape(inp) + # Flatten output to 2-D. + with ops.control_dependencies( + [check_ops.assert_equal(batch_size, inp_shape[0])]): + output = array_ops.reshape(output, [batch_size, output_row_size]) + + def loop_fn(i): + y = array_ops.gather(output, i, axis=1) + return gradient_ops.gradients(y, inp)[0] + + if use_pfor: + pfor_output = control_flow_ops.pfor(loop_fn, output_row_size) + else: + pfor_output = control_flow_ops.for_loop(loop_fn, output.dtype, + output_row_size) + pfor_output = array_ops.reshape(pfor_output, + [output_row_size, batch_size, -1]) + output = array_ops.transpose(pfor_output, [1, 0, 2]) + new_shape = array_ops.concat([output_shape, inp_shape[1:]], axis=0) + return array_ops.reshape(output, new_shape) diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py new file mode 100644 index 0000000000000000000000000000000000000000..310a2154f71c29702de1d43d8fc4af931b3217eb --- /dev/null +++ b/tensorflow/python/ops/parallel_for/gradients_test.py @@ -0,0 +1,568 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for jacobian and batch_jacobian ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import time + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import training as keras_training +from tensorflow.python.layers import layers as tf_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients as gradient_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.ops.parallel_for import control_flow_ops +from tensorflow.python.ops.parallel_for import gradients +from tensorflow.python.platform import test +from tensorflow.python.util import nest + + +class FullyConnectedModel(object): + + def __init__(self, activation_size, num_layers): + self._layers = [ + tf_layers.Dense(activation_size, activation=nn.relu) + for _ in range(num_layers) + ] + + def __call__(self, inp): + activation = inp + for layer in self._layers: + activation = layer(activation) + return activation + + +def fully_connected_model_fn(batch_size, activation_size, num_layers): + model = FullyConnectedModel(activation_size, num_layers) + inp = random_ops.random_normal([batch_size, activation_size]) + return inp, model(inp) + + +def lstm_model_fn(batch_size, state_size, steps): + inputs = [ + random_ops.random_normal([batch_size, state_size]) for _ in range(steps) + ] + cell = rnn_cell.BasicLSTMCell(state_size) + init_state = cell.zero_state(batch_size, dtypes.float32) + state = init_state + for inp in inputs: + _, state = cell(inp, state) + return init_state.c, state.c + + +def dynamic_lstm_model_fn(batch_size, state_size, max_steps): + # We make inputs and sequence_length constant so that multiple session.run + # calls produce the same result. + inputs = constant_op.constant( + np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32) + sequence_length = constant_op.constant( + np.random.randint(0, size=[batch_size], high=max_steps + 1), + dtype=dtypes.int32) + + cell = rnn_cell.BasicLSTMCell(state_size) + initial_state = cell.zero_state(batch_size, dtypes.float32) + return inputs, rnn.dynamic_rnn( + cell, + inputs, + sequence_length=sequence_length, + initial_state=initial_state) + + +def create_fc_batch_jacobian(batch_size, activation_size, num_layers): + inp, output = fully_connected_model_fn(batch_size, activation_size, + num_layers) + pfor_jacobian = gradients.batch_jacobian(output, inp, use_pfor=True) + while_jacobian = gradients.batch_jacobian(output, inp, use_pfor=False) + return pfor_jacobian, while_jacobian + + +def create_lstm_batch_jacobian(batch_size, state_size, steps): + inp, output = lstm_model_fn(batch_size, state_size, steps) + pfor_jacobian = gradients.batch_jacobian(output, inp, use_pfor=True) + while_jacobian = gradients.batch_jacobian(output, inp, use_pfor=False) + return pfor_jacobian, while_jacobian + + +def create_dynamic_lstm_batch_jacobian(batch_size, state_size, max_steps): + inp, (_, final_state) = dynamic_lstm_model_fn(batch_size, state_size, + max_steps) + pfor_jacobian = gradients.batch_jacobian(final_state.c, inp, use_pfor=True) + # Note that use_pfor=False does not work above given the current limitations + # on implementation of while_loop. So we statically unroll the looping in the + # jacobian computation. + while_gradients = [ + gradient_ops.gradients(array_ops.gather(final_state.c, i, axis=1), inp)[0] + for i in range(state_size) + ] + return pfor_jacobian, while_gradients + + +def create_lstm_batch_hessian(batch_size, state_size, steps): + inp, output = lstm_model_fn(batch_size, state_size, steps) + pfor_jacobian = gradients.batch_jacobian(output, inp, use_pfor=True) + pfor_jacobian = array_ops.reshape(pfor_jacobian, [batch_size, -1]) + pfor_hessian = gradients.batch_jacobian(pfor_jacobian, inp, use_pfor=True) + # TODO(agarwal): using two nested while_loop doesn't seem to work here. + # Hence we use pfor_jacobian for computing while_hessian. + while_jacobian = pfor_jacobian + while_hessian = gradients.batch_jacobian(while_jacobian, inp, use_pfor=False) + return pfor_hessian, while_hessian + + +def create_lstm_hessian(batch_size, state_size, steps): + _, output = lstm_model_fn(batch_size, state_size, steps) + weights = variables.trainable_variables() + pfor_jacobians = gradients.jacobian(output, weights, use_pfor=True) + pfor_hessians = [ + gradients.jacobian(x, weights, use_pfor=True) for x in pfor_jacobians + ] + # TODO(agarwal): using two nested while_loop doesn't seem to work here. + # Hence we use pfor_jacobians for computing while_hessians. + while_jacobians = pfor_jacobians + while_hessians = [ + gradients.jacobian(x, weights, use_pfor=False) for x in while_jacobians + ] + return pfor_hessians, while_hessians + + +def create_fc_per_eg_grad(batch_size, activation_size, num_layers): + inp = random_ops.random_normal([batch_size, activation_size]) + layers = [ + tf_layers.Dense(activation_size, activation=nn.relu) + for _ in range(num_layers) + ] + projection = tf_layers.Dense(1) + + def model_fn(activation): + for layer in layers: + activation = layer(activation) + activation = projection(activation) + activation = nn.l2_loss(activation) + return gradient_ops.gradients(activation, variables.trainable_variables()) + + def loop_fn(i): + return model_fn(array_ops.expand_dims(array_ops.gather(inp, i), 0)) + + pfor_outputs = control_flow_ops.pfor(loop_fn, batch_size) + loop_fn_dtypes = [x.dtype for x in variables.trainable_variables()] + while_outputs = control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, batch_size) + return pfor_outputs, while_outputs + + +def create_lstm_per_eg_grad(batch_size, state_size, steps): + inputs = [ + random_ops.random_normal([batch_size, state_size]) for _ in range(steps) + ] + cell = rnn_cell.BasicLSTMCell(state_size) + init_state = cell.zero_state(batch_size, dtypes.float32) + + def model_fn(inps, init_state): + state = init_state + for inp in inps: + _, state = cell(inp, state) + output = nn.l2_loss(state.c) + return gradient_ops.gradients(output, variables.trainable_variables()) + + def loop_fn(i): + loop_inputs = [ + array_ops.expand_dims(array_ops.gather(x, i), 0) for x in inputs + ] + loop_init_state = rnn_cell.LSTMStateTuple( + *[array_ops.expand_dims(array_ops.gather(x, i), 0) for x in init_state]) + return model_fn(loop_inputs, loop_init_state) + + pfor_outputs = control_flow_ops.pfor(loop_fn, batch_size) + loop_fn_dtypes = [x.dtype for x in variables.trainable_variables()] + while_outputs = control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, batch_size) + return pfor_outputs, while_outputs + + +# Importing the code from tensorflow_models seems to cause errors. Hence we +# duplicate the model definition here. +# TODO(agarwal): Use the version in tensorflow_models/official instead. +class Mnist(keras_training.Model): + + def __init__(self, data_format): + """Creates a model for classifying a hand-written digit. + + Args: + data_format: Either 'channels_first' or 'channels_last'. + """ + super(Mnist, self).__init__() + if data_format == "channels_first": + self._input_shape = [-1, 1, 28, 28] + else: + assert data_format == "channels_last" + self._input_shape = [-1, 28, 28, 1] + + self.conv1 = tf_layers.Conv2D( + 32, 5, padding="same", data_format=data_format, activation=nn.relu) + self.conv2 = tf_layers.Conv2D( + 64, 5, padding="same", data_format=data_format, activation=nn.relu) + self.fc1 = tf_layers.Dense(1024, activation=nn.relu) + self.fc2 = tf_layers.Dense(10) + self.dropout = tf_layers.Dropout(0.4) + self.max_pool2d = tf_layers.MaxPooling2D( + (2, 2), (2, 2), padding="same", data_format=data_format) + + def __call__(self, inputs, training): + """Add operations to classify a batch of input images. + + Args: + inputs: A Tensor representing a batch of input images. + training: A boolean. Set to True to add operations required only when + training the classifier. + + Returns: + A logits Tensor with shape [, 10]. + """ + y = array_ops.reshape(inputs, self._input_shape) + y = self.conv1(y) + y = self.max_pool2d(y) + y = self.conv2(y) + y = self.max_pool2d(y) + y = tf_layers.flatten(y) + y = self.fc1(y) + y = self.dropout(y, training=training) + return self.fc2(y) + + +def create_mnist_per_eg_grad(batch_size, data_format, training): + images = random_ops.random_uniform([batch_size, 28, 28]) + sparse_labels = np.random.randint( + low=0, high=10, size=[batch_size]).astype(np.int32) + labels = np.zeros((batch_size, 10)).astype(np.float32) + labels[np.arange(batch_size), sparse_labels] = 1. + model = Mnist(data_format) + + def loop_fn(i): + image = array_ops.gather(images, i) + label = array_ops.gather(labels, i) + logits = array_ops.reshape(model(image, training=training), [-1]) + loss = losses.softmax_cross_entropy( + logits=logits, onehot_labels=label, reduction=losses.Reduction.NONE) + return gradient_ops.gradients(loss, variables.trainable_variables()) + + pfor_outputs = control_flow_ops.pfor(loop_fn, batch_size) + while_outputs = control_flow_ops.for_loop( + loop_fn, [dtypes.float32] * len(variables.trainable_variables()), + batch_size) + return pfor_outputs, while_outputs + + +def create_mnist_per_eg_jacobian(batch_size, data_format, training): + images = random_ops.random_uniform([batch_size, 28, 28]) + model = Mnist(data_format) + + def loop_fn(i, use_pfor): + image = array_ops.gather(images, i) + logits = array_ops.reshape(model(image, training=training), [-1]) + return gradients.jacobian( + logits, variables.trainable_variables(), use_pfor=use_pfor) + + pfor_outputs = control_flow_ops.pfor( + functools.partial(loop_fn, use_pfor=True), + batch_size) + while_outputs = control_flow_ops.for_loop( + functools.partial(loop_fn, use_pfor=False), + [dtypes.float32] * len(variables.trainable_variables()), batch_size) + return pfor_outputs, while_outputs + + +def create_fc_per_eg_jacobians(batch_size, activation_size, num_layers): + model = FullyConnectedModel(activation_size=activation_size, + num_layers=num_layers) + inp = random_ops.random_normal([batch_size, activation_size]) + output = model(inp) + jacobians = gradients.jacobian(output, variables.trainable_variables()) + + def loop_fn(i, use_pfor): + inp_i = array_ops.expand_dims(array_ops.gather(inp, i), 0) + output = array_ops.reshape(model(inp_i), [-1]) + return gradients.jacobian( + output, variables.trainable_variables(), use_pfor=use_pfor) + + per_eg_jacobians_pfor = control_flow_ops.pfor( + functools.partial(loop_fn, use_pfor=True), + batch_size) + per_eg_jacobians_while = control_flow_ops.for_loop( + functools.partial(loop_fn, use_pfor=False), + [dtypes.float32] * len(variables.trainable_variables()), batch_size) + return jacobians, per_eg_jacobians_pfor, per_eg_jacobians_while + + +class GradientsTest(test.TestCase): + + def run_and_assert_equal(self, targets1, targets2, atol=1e-4, rtol=1e-4): + targets1 = nest.flatten(targets1) + targets2 = nest.flatten(targets2) + assert len(targets1) == len(targets2) + init = variables.global_variables_initializer() + self.evaluate(init) + outputs = self.evaluate(targets1 + targets2) + n = len(outputs) // 2 + for i in range(n): + self.assertAllClose(outputs[i], outputs[i + n], rtol=rtol, atol=atol) + + def test_jacobian_fixed_shape(self): + x = random_ops.random_uniform([2, 2]) + y = math_ops.matmul(x, x, transpose_a=True) + jacobian_pfor = gradients.jacobian(y, x, use_pfor=True) + jacobian_while = gradients.jacobian(y, x, use_pfor=False) + answer = ops.convert_to_tensor([[ + gradient_ops.gradients(y[0][0], x)[0], + gradient_ops.gradients(y[0][1], x)[0] + ], [ + gradient_ops.gradients(y[1][0], x)[0], + gradient_ops.gradients(y[1][1], x)[0] + ]]) + self.run_and_assert_equal(answer, jacobian_pfor) + self.run_and_assert_equal(answer, jacobian_while) + + def test_jacobian_unknown_shape(self): + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32, shape=[None, None]) + y = math_ops.matmul(x, x, transpose_a=True) + jacobian_pfor = gradients.jacobian(y, x, use_pfor=True) + jacobian_while = gradients.jacobian(y, x, use_pfor=False) + answer = ops.convert_to_tensor([[ + gradient_ops.gradients(y[0][0], x)[0], + gradient_ops.gradients(y[0][1], x)[0] + ], [ + gradient_ops.gradients(y[1][0], x)[0], + gradient_ops.gradients(y[1][1], x)[0] + ]]) + ans, pfor_value, while_value = sess.run( + [answer, jacobian_pfor, jacobian_while], + feed_dict={x: [[1, 2], [3, 4]]}) + self.assertAllClose(ans, pfor_value) + self.assertAllClose(ans, while_value) + + def test_batch_jacobian_bad_shapes(self): + x = random_ops.random_uniform([2, 2]) + y = random_ops.random_uniform([3, 2]) + with self.assertRaisesRegexp(ValueError, "Need first dimension of output"): + gradients.batch_jacobian(y, x, use_pfor=True) + + def test_batch_jacobian_bad_unknown_shapes(self): + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32) + y = array_ops.concat([x, x], axis=0) + jacobian = gradients.batch_jacobian(y, x) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "assertion failed"): + sess.run(jacobian, feed_dict={x: [[1, 2], [3, 4]]}) + + def test_batch_jacobian_fixed_shape(self): + x = random_ops.random_uniform([2, 3, 5]) + y = x * x + batch_jacobian_pfor = gradients.batch_jacobian(y, x, use_pfor=True) + batch_jacobian_while = gradients.batch_jacobian(y, x, use_pfor=False) + two_x = 2 * x + answer = array_ops.stack( + [array_ops.diag(two_x[0]), + array_ops.diag(two_x[1])]) + self.run_and_assert_equal(answer, batch_jacobian_pfor) + self.run_and_assert_equal(answer, batch_jacobian_while) + + def test_batch_jacobian_unknown_shape(self): + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32) + y = x * x + batch_jacobian_pfor = gradients.batch_jacobian(y, x, use_pfor=True) + batch_jacobian_while = gradients.batch_jacobian(y, x, use_pfor=False) + two_x = 2 * x + answer = array_ops.stack( + [array_ops.diag(two_x[0]), + array_ops.diag(two_x[1])]) + ans, pfor_value, while_value = sess.run( + [answer, batch_jacobian_pfor, batch_jacobian_while], + feed_dict={x: [[1, 2], [3, 4]]}) + self.assertAllClose(ans, pfor_value) + self.assertAllClose(ans, while_value) + + def test_fc_batch_jacobian(self): + pfor_jacobian, while_jacobian = create_fc_batch_jacobian(8, 4, 2) + self.run_and_assert_equal(pfor_jacobian, while_jacobian) + + def test_lstm_batch_jacobian(self): + pfor_jacobian, while_jacobian = create_lstm_batch_jacobian(8, 4, 2) + self.run_and_assert_equal(pfor_jacobian, while_jacobian) + + def test_dynamic_lstm_batch_jacobian(self): + pfor_jacobian, while_gradients = create_dynamic_lstm_batch_jacobian(8, 4, 3) + with session.Session() as sess: + init = variables.global_variables_initializer() + sess.run(init) + pfor = sess.run(pfor_jacobian) + for i in range(4): + while_i = sess.run(while_gradients[i]) + self.assertAllClose(while_i, pfor[:, i, ...]) + + def test_lstm_hessian(self): + pfor_hessian, while_hessian = create_lstm_hessian(2, 2, 2) + self.run_and_assert_equal(pfor_hessian, while_hessian) + + def test_lstm_batch_hessian(self): + pfor_hessian, while_hessian = create_lstm_batch_hessian(2, 2, 2) + self.run_and_assert_equal(pfor_hessian, while_hessian) + + def test_fc_per_eg_grad(self): + pfor_outputs, while_outputs = create_fc_per_eg_grad(8, 4, 2) + self.run_and_assert_equal(pfor_outputs, while_outputs) + + def test_lstm_per_eg_grad(self): + pfor_outputs, while_outputs = create_lstm_per_eg_grad(8, 4, 2) + self.run_and_assert_equal(pfor_outputs, while_outputs) + + def test_mnist_per_eg_grad(self): + data_format = ("channels_first" + if test.is_gpu_available() else "channels_last") + # Note that we we are setting training=False here so that dropout produces + # the same result with pfor and with while_loop. + pfor_outputs, while_outputs = create_mnist_per_eg_grad( + 4, data_format, training=False) + self.run_and_assert_equal(pfor_outputs, while_outputs, rtol=1e-3) + + def test_mnist_per_eg_jacobian(self): + data_format = ("channels_first" + if test.is_gpu_available() else "channels_last") + # Note that we we are setting training=False here so that dropout produces + # the same result with pfor and with while_loop. + pfor_outputs, while_outputs = create_mnist_per_eg_jacobian( + 2, data_format, training=False) + self.run_and_assert_equal(pfor_outputs, while_outputs, rtol=1e-3) + + def test_fc_jacobian(self): + jacobians, per_eg_jacobians_pfor, per_eg_jacobians_while = ( + create_fc_per_eg_jacobians(batch_size=8, + activation_size=4, + num_layers=2)) + self.run_and_assert_equal(jacobians, per_eg_jacobians_pfor, + rtol=2e-3, atol=1e-3) + self.run_and_assert_equal(jacobians, per_eg_jacobians_while, + rtol=2e-3, atol=1e-3) + + +class GradientsBenchmarks(test.Benchmark): + + def _run(self, targets, iters, name=None): + + def _done(t): + # Note that we don't use tf.control_dependencies since that will not make + # sure that the computation on GPU has actually finished. So we fetch the + # first element of the output, and assume that this will not be called on + # empty tensors. + return array_ops.gather(array_ops.reshape(t, [-1]), 0) + + targets = [_done(x) for x in nest.flatten(targets)] + sess = session.Session() + with sess: + init = variables.global_variables_initializer() + sess.run(init) + sess.run(targets) + begin = time.time() + for _ in range(iters): + sess.run(targets) + end = time.time() + avg_time_ms = 1000 * (end - begin) / iters + self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name) + return avg_time_ms + + def benchmark_fc_batch_jacobian(self): + with ops.Graph().as_default(): + pfor_jacobian, while_jacobian = create_fc_batch_jacobian(100, 32, 20) + self._run(pfor_jacobian, 100, name="fc_batch_jacobian_pfor") + self._run(while_jacobian, 20, name="fc_batch_jacobian_while") + + def benchmark_lstm_batch_jacobian(self): + with ops.Graph().as_default(): + pfor_jacobian, while_jacobian = create_lstm_batch_jacobian(100, 32, 8) + self._run(pfor_jacobian, 100, name="lstm_batch_jacobian_pfor") + self._run(while_jacobian, 20, name="lstm_batch_jacobian_while") + + def benchmark_lstm_hessian(self): + with ops.Graph().as_default(): + pfor_hessian, while_hessian = create_lstm_hessian(2, 2, 10) + self._run(pfor_hessian, 20, name="lstm_hessian_pfor") + self._run(while_hessian, 3, name="lstm_hessian_while_pfor") + + def benchmark_lstm_batch_hessian(self): + with ops.Graph().as_default(): + pfor_hessian, while_hessian = create_lstm_batch_hessian(4, 4, 10) + self._run(pfor_hessian, 100, name="lstm_batch_hessian_pfor") + self._run(while_hessian, 20, name="lstm_batch_hessian_while_pfor") + + def benchmark_fc_per_eg_grad(self): + with ops.Graph().as_default(): + pfor_outputs, while_outputs = create_fc_per_eg_grad(100, 32, 3) + self._run(pfor_outputs, 100, name="fc_per_eg_grad_pfor") + self._run(while_outputs, 20, name="fc_per_eg_grad_while") + + def benchmark_lstm_per_eg_grad(self): + with ops.Graph().as_default(): + pfor_outputs, while_outputs = create_lstm_per_eg_grad(100, 32, 8) + self._run(pfor_outputs, 100, name="lstm_per_eg_grad_pfor") + self._run(while_outputs, 20, name="lstm_per_eg_grad_while") + + def benchmark_mnist_per_eg_grad(self): + with ops.Graph().as_default(): + data_format = ("channels_first" + if test.is_gpu_available() else "channels_last") + pfor_outputs, while_outputs = create_mnist_per_eg_grad( + 128, data_format, training=True) + self._run(pfor_outputs, 20, name="mnist_per_eg_grad_pfor") + self._run(while_outputs, 20, name="mnist_per_eg_grad_while") + + def benchmark_mnist_per_eg_jacobian(self): + with ops.Graph().as_default(): + data_format = ("channels_first" + if test.is_gpu_available() else "channels_last") + pfor_outputs, while_outputs = create_mnist_per_eg_jacobian( + 16, data_format, training=True) + self._run(pfor_outputs, 20, name="mnist_per_eg_jacobian_pfor") + self._run(while_outputs, 20, name="mnist_per_eg_jacobian_while") + + def benchmark_fc_per_eg_jacobian(self): + with ops.Graph().as_default(): + jacobians, per_eg_jacobians_pfor, per_eg_jacobians_while = ( + create_fc_per_eg_jacobians(batch_size=128, + activation_size=32, + num_layers=3)) + self._run(jacobians, 30, name="fc_jacobians_pfor") + self._run(per_eg_jacobians_pfor, 100, + name="fc_per_eg_jacobians_pfor") + self._run(per_eg_jacobians_while, 10, + name="fc_per_eg_jacobians_while") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py new file mode 100644 index 0000000000000000000000000000000000000000..ec4ef0f1ab58750502d76a8de120cc3c5ea16c99 --- /dev/null +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -0,0 +1,2552 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Compiled parallel-for loop.""" +# pylint: disable=missing-docstring + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from absl import flags + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import gen_sparse_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest + +flags.DEFINE_bool( + "op_conversion_fallback_to_while_loop", False, + "If true, falls back to using a while loop for ops for " + "which a converter is not defined.") + + +def _stack(t, length): + """stacks `t` `length` times.""" + ones = array_ops.ones_like(array_ops.shape(t)) + multiples = array_ops.concat([length, ones], 0) + t = array_ops.tile(array_ops.expand_dims(t, 0), multiples) + return wrap(t, True) + + +# The following stateful ops can be safely called once, and with the same +# signature as the unconverted version, if their inputs are loop invariant. +# TODO(agarwal): implement a strategy for converting Variable reads/writes. The +# plan is to map each read/write in the loop_fn to a corresponding merged +# read/write in the converted graph. Writes need to be mergeable (e.g. +# AssignAdd) to be used in `pfor`. Given a certain read/write order in the +# loop_fn, doing a one-to-one conversion will simulate executing such +# instructions in lock-step across all iterations. +passthrough_stateful_ops = set([ + "VariableV2", + "VarHandleOp", + "ReadVariableOp", + "StackV2", + "TensorArrayWriteV3", + "TensorArrayReadV3", + "TensorArraySizeV3", +]) + + +def _is_stateful_pfor_op(op): + if isinstance(op, WhileOp): + return op.is_stateful + if op.type == "Const": + # Const didn't have an op_def. + return False + if op.type in passthrough_stateful_ops: + return False + assert hasattr(op, "op_def") and op.op_def is not None, op + return op.op_def.is_stateful + + +# pylint: disable=protected-access +class WhileOp(object): + """Object for storing state for converting the outputs of a while_loop.""" + + def __init__(self, exit_node, pfor_ops): + """Initializer. + + Args: + exit_node: A tensor output from the while_loop. + pfor_ops: list of ops inside the current pfor loop. + """ + self._pfor_ops = set(pfor_ops) + self._pfor_op_ids = set([x._id for x in pfor_ops]) + assert isinstance(exit_node, ops.Tensor) + self._while_context = exit_node.op._get_control_flow_context() + assert isinstance(self._while_context, control_flow_ops.WhileContext) + self._context_name = self._while_context.name + self._condition = self._while_context.pivot.op.inputs[0] + # Parts of an external while_loop could be created inside a pfor loop. + # However for the purpose here, we declare such loops to be external. Also + # note that we check if the condition was created inside or outside to + # determine if the while_loop was first created inside or outside. + # TODO(agarwal): check that the Enter and Exit of this loop are unstacked. + self._is_inside_loop = self.op_is_inside_loop(self._condition.op) + if self._is_inside_loop: + for e in self._while_context.loop_exits: + assert self.op_is_inside_loop(e.op) + + # Note the code below tries to reverse engineer an existing while_loop graph + # by assuming the following pattern of nodes. + # + # NextIteration <---- Body <--- Enter + # | ^ + # V ___| Y + # Enter -> Merge -> Switch___ + # ^ | N + # | V + # LoopCond Exit + + # Node that elements in the list below correspond one-to-one with each + # other. i.e. these lists are the same size, and the i_th entry corresponds + # to different Operations/Tensors of a single cycle as illustrated above. + # List of Switch ops (ops.Operation) that feed into an Exit Node. + self._exit_switches = [] + # List of inputs (ops.Tensor) to NextIteration. + self._body_outputs = [] + # List of list of control inputs of the NextIteration nodes. + self._next_iter_control_inputs = [] + # List of Merge ops (ops.Operation). + self._enter_merges = [] + # List of output (ops.Tensor) of Exit nodes. + self._outputs = [] + + # List of Enter Tensors. + # There are two types of Enter nodes: + # - The Enter nodes that are used in the `loop_vars` argument to + # `while_loop` (see + # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect + # these Enter nodes immediately below by tracing backwards from the Exit + # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the + # diagram above. This allows us to have a 1:1 correspondence between the + # self._outputs and the first elements in self._enters. + # - The Enter nodes that are used only by the body. They don't appear in the + # `loop_vars` and are not returned from the `while_loop`. In Python code, + # they are usually captured by the body lambda. We collect them below by + # iterating over all the ops in the graph. They are appended to the end of + # self._enters or self._direct_enters, and don't correspond to any outputs + # in self._outputs. Note that we keep the resource/variant Enter nodes in + # self._direct_enters and the constructed while_loop's body uses them + # directly as opposed to passing them as loop variables. This is done + # because the while_body cannot partition the resource/variant Tensors, so + # it has to leave them unchanged. + self._enters = [] + self._direct_enters = [] + + for e in self._while_context.loop_exits: + self._outputs.append(e.op.outputs[0]) + switch = e.op.inputs[0].op + assert switch.type == "Switch", switch + self._exit_switches.append(switch) + merge = switch.inputs[0].op + assert merge.type == "Merge", merge + self._enter_merges.append(merge) + enter = merge.inputs[0].op + assert enter.type == "Enter", enter + self._enters.append(enter.outputs[0]) + next_iter = merge.inputs[1].op + assert next_iter.type == "NextIteration", next_iter + self._body_outputs.append(next_iter.inputs[0]) + self._next_iter_control_inputs.append(next_iter.control_inputs) + + # Collect all the Enter nodes that are not part of `loop_vars`, the second + # category described above. + # Also track whether the loop body has any stateful ops. + self._is_stateful = False + for op in ops.get_default_graph().get_operations(): + # TODO(agarwal): make sure this works with nested case. + control_flow_context = op._get_control_flow_context() + if control_flow_context is None: + continue + if control_flow_context.name == self._context_name: + self._is_stateful |= _is_stateful_pfor_op(op) + if op.type == "Enter": + output = op.outputs[0] + if output not in self._enters: + if output.dtype in (dtypes.resource, dtypes.variant): + if output not in self._direct_enters: + self._direct_enters.append(output) + else: + self._enters.append(output) + + def __str__(self): + """String representation.""" + return "while_loop(%s)" % self.name + + @property + def inputs(self): + """Input to all the Enter nodes.""" + return [x.op.inputs[0] for x in self._enters + self._direct_enters] + + @property + def control_inputs(self): + """Control input to all the Enter nodes.""" + control_inputs = [] + for x in self._enters + self._direct_enters: + control_inputs.extend(x.op.control_inputs) + return control_inputs + + @property + def outputs(self): + """Outputs of all the Exit nodes.""" + return self._outputs + + @property + def name(self): + """Context name for the while loop.""" + return self._context_name + + @property + def is_inside_loop(self): + """Returns true if the while_loop was created inside the pfor.""" + return self._is_inside_loop + + def op_is_inside_loop(self, op): + """True if op was created inside the pfor loop body.""" + assert isinstance(op, ops.Operation) + # Note that we use self._pfor_op_ids for the check and not self._pfor_ops + # since it appears there tensorflow API could return different python + # objects representing the same Operation node. + return op._id in self._pfor_op_ids + + @property + def is_stateful(self): + return self._is_stateful + + @property + def pfor_converter(self): + """Return a converter for the while loop.""" + return self + + def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs, + inputs_stacked): + """Create a PFor object for converting parts of the while_loop. + + Args: + parent_pfor: PFor object being used for converting the while_loop. + indices: int32 Tensor of ids for the iterations that are still active + (i.e. did not exit the while_loop). + cond_stacked: True if the while_loop condition is stacked. + inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note + that these Tensors are a subset of the loop variables for the generated + while_loop. + inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`, + indicating if the value is stacked or not. + + Returns: + A PFor instance. The instance is initialized by adding conversion mappings + of nodes that will be external to the conversion that the returned + instance will be used for. e.g. Enter nodes as well as Merge and Switch + outputs are mapped to converted values. + """ + num_outputs = len(self._outputs) + assert len(inputs) == len(self._enters) + assert len(inputs_stacked) == len(self._enters) + loop_var = parent_pfor.loop_var + loop_len = array_ops.size(indices) + pfor = PFor( + loop_var, + loop_len, + pfor_ops=self._pfor_ops, + all_indices=indices, + all_indices_partitioned=cond_stacked) + # Map all inputs of Enter nodes in self._direct_enters to their converted + # values. + for enter in self._direct_enters: + enter_input = enter.op.inputs[0] + converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper( + enter_input) + # Since these are resources / variants, they should be unstacked. + assert not stacked and not is_sparse_stacked, (enter, converted_enter) + pfor._add_conversion(enter, wrap(converted_enter, False)) + + # Map all Enter nodes to the inputs. + for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked): + pfor._add_conversion(enter, wrap(inp, stacked)) + # Map outputs of Switch and Merge. + for i in range(num_outputs): + wrapped_inp = wrap(inputs[i], inputs_stacked[i]) + merge = self._enter_merges[i] + pfor._add_conversion(merge.outputs[0], wrapped_inp) + # Note that second output of Merge is typically not used, except possibly + # as a control dependency. To avoid trying to output the correct value, we + # employ a hack here. We output a dummy invalid value with an incorrect + # dtype. This will allow control dependency to work but if using it as an + # input, it should typically lead to errors during graph construction due + # to dtype mismatch. + # TODO(agarwal): Check in the original graph to see if there are any + # consumers of this Tensor that use it as an input. + pfor._add_conversion(merge.outputs[1], + wrap(constant_op.constant(-1.0), False)) + switch = self._exit_switches[i] + # Don't need to worry about switch.output[0] which will feed to Exit node. + pfor._add_conversion(switch.outputs[1], wrapped_inp) + return pfor + + def _convert_enter(self, parent_pfor, enter): + """Converts an Enter node.""" + inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0]) + control_inputs = [ + parent_pfor._convert_helper(x).t for x in enter.op.control_inputs + ] + if control_inputs: + with ops.control_dependencies(control_inputs): + inp = array_ops.identity(inp) + return inp, stacked + + def _maybe_stacked(self, cache, inp): + """Heuristic to figue out if the coverting inp leads to a stacked value. + + + Args: + cache: map from Tensor to boolean indicating stacked/unstacked. + inp: input Tensor. + + Returns: + True if `inp` could get stacked. If the function returns False, the + converted value should be guaranteed to be unstacked. If returning True, + it may or may not be stacked. + """ + if inp in cache: + return cache[inp] + if not self.op_is_inside_loop(inp.op): + return False + op = inp.op + output = False + if op.type in [ + "Shape", + "Rank" + "ShapeN", + "ZerosLike", + "TensorArrayV3", + "TensorArraySizeV3", + ]: + output = False + elif _is_stateful_pfor_op(op): + # This may be fairly aggressive. + output = True + elif op.type == "Exit": + # This may be fairly aggressive. + output = True + else: + for t in op.inputs: + if self._maybe_stacked(cache, t): + output = True + break + cache[inp] = output + return output + + def _create_init_values(self, pfor_input): + """Create arguments passed to converted while_loop.""" + with ops.name_scope("while_init"): + loop_len_vector = pfor_input.pfor.loop_len_vector + loop_len = loop_len_vector[0] + num_outputs = len(self._outputs) + + inputs = [] + maybe_stacked_cache = {} + # Convert all the Enters. Need to do this before checking for stacking + # below. + for i, enter in enumerate(self._enters): + inp, stacked = self._convert_enter(pfor_input.pfor, enter) + inputs.append(inp) + maybe_stacked_cache[enter] = stacked + # Since this enter node is part of the `loop_vars`, it corresponds to an + # output and its preceding switch. We mark this switch's output the same + # stackness, to act at the base case for the logic below. Below, we will + # be going through the body figuring out which inputs might need to be + # stacked and which inputs can safely remain unstacked. + if i < num_outputs: + maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked + + # Shape invariants for init_values corresponding to self._enters. + input_shape_invariants = [] + # TensorArrays for outputs of converted while loop + output_tas = [] + # Shape invariants for output TensorArrays. + ta_shape_invariants = [] + # List of booleans indicating stackness of inputs, i.e. tensors + # corresponding to self._enters. + inputs_stacked = [] + for i, inp in enumerate(inputs): + enter = self._enters[i] + inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter) + # Note that even when an input is unstacked, the body could make it + # stacked. we use a heuristic below to figure out if body may be making + # it stacked. + if i < num_outputs: + body_output = self._body_outputs[i] + if enter.op in self._pfor_ops: + body_output_stacked = self._maybe_stacked(maybe_stacked_cache, + body_output) + else: + # If constructed outside of pfor loop, then the output would not be + # stacked. + body_output_stacked = False + if body_output_stacked and not inp_stacked: + inp = _stack(inp, loop_len_vector).t + inputs[i] = inp + inp_stacked = True + # TODO(agarwal): other attributes for the TensorArray ? + output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len)) + ta_shape_invariants.append(tensor_shape.TensorShape(None)) + + inputs_stacked.append(inp_stacked) + input_shape_invariants.append(tensor_shape.TensorShape(None)) + + # See documentation for __call__ for the structure of init_values. + init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas + # TODO(agarwal): try stricter shape invariants + shape_invariants = ( + [tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None) + ] + input_shape_invariants + ta_shape_invariants) + + return init_values, inputs_stacked, shape_invariants + + def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): + """Handles case when condition is unstacked. + + Note that all iterations end together. So we don't need to partition the + inputs. When all iterations are done, we write the inputs to the + TensorArrays. Note that we only write to index 0 of output_tas. Since all + iterations end together, they can all be output together. + """ + not_all_done = array_ops.reshape(conditions, []) + new_output_tas = [] + # pylint: disable=cell-var-from-loop + for i, out_ta in enumerate(output_tas): + inp = inputs[i] + new_output_tas.append( + control_flow_ops.cond(not_all_done, + lambda: out_ta, + lambda: out_ta.write(0, inp))) + # pylint: enable=cell-var-from-loop + return not_all_done, indices, inputs, new_output_tas + + def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, + output_tas): + num_outputs = len(self._outputs) + # Compute if all iterations are done. + not_all_done = math_ops.reduce_any(conditions) + conditions_int = math_ops.cast(conditions, dtypes.int32) + # Partition the indices. + done_indices, new_indices = data_flow_ops.dynamic_partition( + indices, conditions_int, 2) + + new_inputs = [] + new_output_tas = [] + for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): + # Partition the inputs. + if stacked: + done_inp, new_inp = data_flow_ops.dynamic_partition( + inp, conditions_int, 2) + else: + # TODO(agarwal): avoid this stacking. See TODO earlier in + # _process_cond_unstacked. + done_inp = _stack(inp, [array_ops.size(done_indices)]).t + new_inp = inp + new_inputs.append(new_inp) + # For iterations that are done, write them to TensorArrays. + if i < num_outputs: + out_ta = output_tas[i] + # Note that done_indices can be empty. done_inp should also be empty in + # that case. + new_output_tas.append(out_ta.scatter(done_indices, done_inp)) + return not_all_done, new_indices, new_inputs, new_output_tas + + def _process_body(self, pfor_input, inputs_stacked, + new_indices, cond_stacked, new_inputs, + not_all_done): + """Convert the body function.""" + + def true_fn(control_inputs, body_pfor, body_output, stacked): + """Converts the body function for all but last iteration. + + This essentially converts body_output. Additionally, it needs to handle + any control dependencies on the NextIteration node. So it creates another + Identity node with the converted dependencies. + """ + converted_control_inp = [] + for x in control_inputs: + for t in x.outputs: + converted_control_inp.append(body_pfor._convert_helper(t).t) + if stacked: + # Note convert always does the stacking. + output = body_pfor.convert(body_output) + else: + output, convert_stacked, _ = body_pfor._convert_helper(body_output) + assert convert_stacked == stacked, body_output + with ops.control_dependencies(converted_control_inp): + return array_ops.identity(output) + + body_pfor = self._init_pfor(pfor_input.pfor, new_indices, + cond_stacked, new_inputs, + inputs_stacked) + new_outputs = [] + + for i, (body_output, stacked) in enumerate( + zip(self._body_outputs, inputs_stacked)): + control_inp = self._next_iter_control_inputs[i] + out_dtype = body_output.dtype + # Note that we want to run the body only if not all pfor iterations are + # done. If all are done, we return empty tensors since these values will + # not be used. Notice that the value returned by the loop is based on + # TensorArrays and not directly on these returned values. + # pylint: disable=cell-var-from-loop + new_output = control_flow_ops.cond( + not_all_done, + lambda: true_fn(control_inp, body_pfor, body_output, stacked), + lambda: constant_op.constant([], dtype=out_dtype)) + # pylint: enable=cell-var-from-loop + new_outputs.append(new_output) + return new_outputs + + def __call__(self, pfor_input): + """Converter for the while_loop. + + The conversion of a while_loop is another while_loop. + + The arguments to this converted while_loop are as follows: + not_all_done: Boolean scalar Tensor indicating if all the pfor iterations + are done. + indices: int32 1-D Tensor storing the id of the iterations that are not + done. + args: Remaining arguments. These can be divided into 3 categories: + - First set of arguments are the tensors that correspond to the initial + elements of self._enters. The elements that appear in original while + loop's `loop_vars`. + - The second set of arguments are the tensors that correspond to the + remaining elements of self._enters. These are the tensors that directly + enter the original while loop body. + - Finally, the last set of arguments are TensorArrays. These TensorArrays + correspond to the outputs of the original while_loop, i.e. to the + elements in self._outputs. Each TensorArray has `PFor.loop_len` + elements, i.e. the number of pfor iterations. At the end, the i'th + element of each TensorArray will contain the output computed by the + i'th iteration of pfor. Note that elements can be written into these + tensors arrays in any order, depending on when the corresponding pfor + iteration is done. + If the original while_loop had `k` tensors in its `loop_vars` and its body + directly captured `m` tensors, the `args` will contain `2 * k + m` values. + + In each iteration, the while_loop body recomputes the condition for all + active pfor iterations to see which of them are now done. It then partitions + all the inputs and passes them along to the converted body. Values for all + the iterations that are done are written to TensorArrays indexed by the pfor + iteration number. When all iterations are done, the TensorArrays are stacked + to get the final value. + + Args: + pfor_input: A PForInput object corresponding to the output of any Exit + node from this while loop. + + Returns: + List of converted outputs. + """ + # Create init_values that will be passed to the while_loop. + init_values, inputs_stacked, shape_invariants = self._create_init_values( + pfor_input) + # Note that we use a list as a hack since we need the nested function body + # to set the value of cond_is_stacked. python2.x doesn't support nonlocal + # variables. + cond_is_stacked = [None] + + def cond(not_all_done, *_): + return not_all_done + + def body(not_all_done, indices, *args): + # See documentatin for __call__ for the structure of *args. + num_enters = len(self._enters) + inputs = args[:num_enters] + output_tas = args[num_enters:] + # TODO(agarwal): see which outputs have consumers and only populate the + # TensorArrays corresonding to those. Or do those paths get trimmed out + # from inside the while_loop body? + assert len(inputs) >= len(output_tas) + assert len(inputs) == len(inputs_stacked) + + # Convert condition + with ops.name_scope("while_cond"): + # Note that we set cond_stacked to True here. At this point we don't + # know if it could be loop invariant, hence the conservative value is + # to assume stacked. + cond_pfor = self._init_pfor(pfor_input.pfor, indices, + cond_stacked=True, + inputs=inputs, + inputs_stacked=inputs_stacked) + conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition) + cond_is_stacked[0] = cond_stacked + + # Recompute the new condition, write outputs of done iterations, and + # partition the inputs if needed. + if not cond_stacked: + (not_all_done, new_indices, + new_inputs, new_output_tas) = self._process_cond_unstacked( + conditions, indices, inputs, output_tas) + else: + (not_all_done, new_indices, + new_inputs, new_output_tas) = self._process_cond_stacked( + conditions, indices, inputs, inputs_stacked, output_tas) + + # Convert body + with ops.name_scope("while_body"): + # Compute the outputs from the body. + new_outputs = self._process_body(pfor_input, inputs_stacked, + new_indices, cond_stacked, new_inputs, + not_all_done) + + # Note that the first num_outputs new values of inputs are computed using + # the body. Rest of them were direct Enters into the condition/body and + # the partitioning done earlier is sufficient to give the new value. + num_outputs = len(self._outputs) + new_args = ([not_all_done, new_indices] + new_outputs + list( + new_inputs[num_outputs:]) + new_output_tas) + return tuple(new_args) + + while_outputs = control_flow_ops.while_loop( + cond, body, init_values, shape_invariants=shape_invariants) + output_tas = while_outputs[-len(self._outputs):] + outputs = [] + assert cond_is_stacked[0] is not None + for inp_stacked, ta in zip(inputs_stacked, output_tas): + if cond_is_stacked[0]: + outputs.append(wrap(ta.stack(), True)) + else: + # Note that if while_loop condition is unstacked, all iterations exit at + # the same time and we wrote those outputs in index 0 of the tensor + # array. + outputs.append(wrap(ta.read(0), inp_stacked)) + return outputs + + +class _PforInput(object): + """Input object passed to registered pfor converters.""" + + def __init__(self, pfor, op, inputs): + """Creates a _PforInput object. + + Args: + pfor: PFor converter object. + op: the Operation object that is being converted. + inputs: list of WrappedTensor objects representing converted values of the + inputs of `op`. + """ + self.pfor = pfor + self._op = op + self._inputs = inputs + + def stack_inputs(self, stack_indices=None): + """Stacks unstacked inputs at `stack_indices`. + + Args: + stack_indices: indices of inputs at which stacking is done. If None, + stacking is done at all indices. + """ + if stack_indices is None: + stack_indices = range(len(self._inputs)) + length = self.pfor.loop_len_vector + for i in stack_indices: + inp = self._inputs[i] + if not inp.is_stacked: + self._inputs[i] = _stack(inp.t, length) + + def expanddim_inputs_for_broadcast(self): + """Reshapes stacked inputs to prepare them for broadcast. + + Since stacked inputs have an extra leading dimension, automatic broadcasting + rules could incorrectly try to expand dimensions before that leading + dimension. To avoid that, we reshape these stacked inputs to the maximum + rank they will need to be broadcasted to. + """ + if not self._inputs: + return + + # Find max rank + def _get_rank(x): + rank = array_ops.rank(x.t) + if not x.is_stacked: + rank += 1 + return rank + + ranks = [_get_rank(x) for x in self._inputs] + max_rank = ranks[0] + for rank in ranks[1:]: + max_rank = math_ops.maximum(rank, max_rank) + + for i, inp in enumerate(self._inputs): + if inp.is_stacked: + shape = array_ops.shape(inp.t) + rank_diff = array_ops.reshape(max_rank - ranks[i], [1]) + ones = array_ops.tile([1], rank_diff) + new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0) + self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True) + + @property + def inputs(self): + return self._inputs + + @property + def num_inputs(self): + return len(self._inputs) + + def input(self, index): + assert len(self._inputs) > index, (index, self._inputs) + return self._inputs[index] + + def stacked_input(self, index): + t, is_stacked, _ = self.input(index) + if not is_stacked: + op_type = self.op_type + op_def = getattr(self._op, "op_def", None) + if op_def is None: + input_name = "at index %d" % index + else: + input_name = "\"%s\"" % op_def.input_arg[index].name + raise ValueError("Input %s of op \"%s\" expected to be not loop invariant" + ".\nError while converting op %s" + "with converted inputs\n%s" % (input_name, op_type, + self._op, self.inputs)) + return t + + def unstacked_input(self, index): + t, is_stacked, _ = self.input(index) + if is_stacked: + op_type = self.op_type + op_def = getattr(self._op, "op_def", None) + if op_def is None: + input_name = "at index %d" % index + else: + input_name = "\"%s\"" % op_def.input_arg[index].name + raise ValueError("Input %s of op \"%s\" expected to be loop invariant" + ".\nError while converting op %s" + "with converted inputs\n%s" % (input_name, op_type, + self._op, self.inputs)) + return t + + @property + def op(self): + return self._op + + @property + def op_type(self): + return self._op.type + + def get_attr(self, attr): + return self._op.get_attr(attr) + + @property + def outputs(self): + return self._op.outputs + + def output(self, index): + assert index < len(self._op.outputs) + return self._op.outputs[index] + + +_pfor_converter_registry = {} + + +class RegisterPFor(object): + """Utility to register converters for pfor. + + Usage: + @RegisterPFor(foo_op_type) + def _foo_converter(pfor_input): + ... + + The above will register conversion function `_foo_converter` for handling + conversion of `foo_op_type`. During conversion, the registered functin will be + called with a single argument of type `PForInput` which will contain state + needed for the conversion. This registered function should output a list of + WrappedTensor object with the same length as the number of outputs of op being + converted. If the op had zero outputs, then it should return a ops.Operation + object. + """ + + def __init__(self, op_type): + """Creates an object to register a converter for op with type `op_type`.""" + self.op_type = op_type + + def __call__(self, converter): + name = self.op_type + assert name not in _pfor_converter_registry, "Re-registering %s " % name + _pfor_converter_registry[name] = converter + return converter + + +class RegisterPForWithArgs(RegisterPFor): + """Utility to register converters for pfor. + + Usage: + @RegisteRPFor(foo_op_type, foo=value, ....) + def _foo_converter(pfor_input, foo=None, ....): + ... + + See RegisterPFor for details on the conversion function. + `RegisterPForWithArgs` allows binding extra arguments to the + conversion function at registration time. + """ + + def __init__(self, op_type, *args, **kw_args): + super(RegisterPForWithArgs, self).__init__(op_type) + self._args = args + self._kw_args = kw_args + + def __call__(self, converter): + + def _f(pfor_input): + return converter(pfor_input, self.op_type, *self._args, **self._kw_args) + + super(RegisterPForWithArgs, self).__call__(_f) + return converter + + +def _create_op(op_type, inputs, op_dtypes, attrs=None): + """Utility to create an op.""" + return ops.get_default_graph().create_op( + op_type, inputs, op_dtypes, attrs=attrs, compute_device=True) + + +WrappedTensor = collections.namedtuple("WrappedTensor", + ["t", "is_stacked", "is_sparse_stacked"]) +"""Wrapper around the result of a Tensor conversion. + +The additional fields are useful for keeping track of the conversion state as +data flows through the ops in the loop body. For every op whose output is a +Tensor, its converter should return either a WrappedTensor or a list of +WrappedTensors. + +Args: + t: The converted tensor + is_stacked: True if the tensor is stacked, i.e. represents the results of all + the iterations of the loop, where each row i of the tensor corresponds to + that op's output on iteration i of the loop. False if the tensor is not + stacked, i.e. represents the result of the op on of a single iteration of + the loop, where the result does not vary between iterations. + is_sparse_stacked: True if the tensor corresponds to a component tensor + (indices, values, or dense_shape) of a sparse tensor, and has been logically + stacked via a sparse conversion. +""" + + +def wrap(tensor, is_stacked=True, is_sparse_stacked=False): + """Helper to create a WrappedTensor object.""" + assert isinstance(is_stacked, bool) + assert isinstance(is_sparse_stacked, bool) + assert isinstance(tensor, ops.Tensor) + assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is " + "stacked via a sparse " + "conversion, it must also be " + "stacked.") + return WrappedTensor(tensor, is_stacked, is_sparse_stacked) + + +def _fallback_converter(pfor_input): + logging.warn("Using a while_loop for converting %s", pfor_input.op_type) + output_dtypes = [x.dtype for x in pfor_input.outputs] + iters = pfor_input.pfor.loop_len_vector[0] + + def while_body(i, *ta_list): + """Body of while loop.""" + inputs = [ + x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs + ] + op_outputs = _create_op( + pfor_input.op_type, + inputs, + output_dtypes, + attrs=pfor_input.op.node_def.attr).outputs + + outputs = [] + for out, ta in zip(op_outputs, ta_list): + assert isinstance(out, ops.Tensor) + outputs.append(ta.write(i, array_ops.expand_dims(out, 0))) + return tuple([i + 1] + outputs) + + ta_list = control_flow_ops.while_loop( + lambda i, *ta: i < iters, while_body, [0] + [ + tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes + ])[1:] + return tuple([wrap(ta.concat(), True) for ta in ta_list]) + + +class PFor(object): + """Implementation of rewrite of parallel-for loops. + + This class takes a DAG or a set of DAGs representing the body of a + parallel-for loop, and adds new operations to the graph that implements + functionality equivalent to running that loop body for a specified number of + iterations. This new set of nodes may or may not use a tensorflow loop + construct. + + The process of conversion does not delete or change any existing operations. + It only adds operations that efficiently implement the equivalent + functionality. We refer to the added ops as "converted ops". + + The conversion process uses a simple greedy heuristic. It walks the loop body + and tries to express the functionality of running each node in a loop with a + new set of nodes. When converting an op several cases are possible: + - The op is not inside the loop body. Hence it can be used as is. + - The op does not depend on the iteration number and is stateless. In this + case, it can be used as is. + - The op is not stateful, and depends on iteration number only through control + dependencies. In this case, we can create a single op with same inputs and + attributes, but with "converted" control dependencies. + - The op is not stateful, and all its inputs are loop invariant. In this + case, similar to above, we can create a single op with same inputs and + attributes, but with "converted" control dependencies. + - The op is stateful or at least one of the inputs is not loop invariant. In + this case, we run the registered converter for that op to create a set of + converted ops. All nodes in the set will have converted control dependencies + corresponding to control dependencies of the original op. If the op returned + multiple outputs, "converted outputs" could be produced by different ops in + this set. + """ + + def __init__(self, + loop_var, + loop_len, + pfor_ops, + all_indices=None, + all_indices_partitioned=False): + """Creates an object to rewrite a parallel-for loop. + + Args: + loop_var: ops.Tensor output of a Placeholder operation. The value should + be an int32 scalar representing the loop iteration number. + loop_len: A scalar or scalar Tensor representing the number of iterations + the loop is run for. + pfor_ops: List of all ops inside the loop body. + all_indices: If not None, an int32 vector with size `loop_len` + representing the iteration ids that are still active. These values + should be unique and sorted. However they may not be contiguous. This is + typically the case when inside a control flow construct which has + partitioned the indices of the iterations that are being converted. + all_indices_partitioned: If True, this object is being constructed from a + control flow construct where not all the pfor iterations are guaranteed + to be active. + """ + assert isinstance(loop_var, ops.Tensor) + assert loop_var.op.type == "Placeholder" + self._loop_var = loop_var + loop_len_value = tensor_util.constant_value(loop_len) + if loop_len_value is not None: + loop_len = loop_len_value + self._loop_len_vector = array_ops.reshape(loop_len, [1]) + self._all_indices_partitioned = all_indices_partitioned + if all_indices_partitioned: + assert all_indices is not None + self.all_indices = ( + math_ops.range(loop_len) if all_indices is None else all_indices) + + self._conversion_map = {} + self._conversion_map[loop_var] = wrap(self.all_indices, True) + self._pfor_ops = set(pfor_ops) + self._pfor_op_ids = set([x._id for x in pfor_ops]) + + def op_is_inside_loop(self, op): + """True if op was created inside the pfor loop body.""" + assert isinstance(op, ops.Operation) + # Note that we use self._pfor_op_ids for the check and not self._pfor_ops + # since it appears there tensorflow API could return different python + # objects representing the same Operation node. + return op._id in self._pfor_op_ids + + def _convert_sparse(self, y): + """Returns the converted value corresponding to SparseTensor y. + + For SparseTensors, instead of stacking the component tensors separately, + resulting in component tensors with shapes (N, m, rank), (N, m), and (N, + rank) respectively for indices, values, and dense_shape (where N is the loop + length and m is the number of sparse tensor values per loop iter), we want + to logically stack the SparseTensors, to create a SparseTensor whose + components are size (N * m, rank + 1), (N * m, ), and (rank + 1,) + respectively. + + Here, we try to get the conversion of each component tensor. + If the tensors are stacked via a sparse conversion, return the resulting + SparseTensor composed of the converted components. Otherwise, the component + tensors are either unstacked or stacked naively. In the latter case, we + unstack the component tensors to reform loop_len SparseTensor elements, + then correctly batch them. + + The unstacked tensors must have the same rank. Each dimension of each + SparseTensor will expand to be the largest among all SparseTensor elements + for that dimension. For example, if there are N SparseTensors of rank 3 + being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i), + the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)). + + Args: + y: A tf.SparseTensor. + + Returns: + A tf.SparseTensor that is the converted value corresponding to y. + """ + outputs = [ + self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape) + ] + assert all(isinstance(o, WrappedTensor) for o in outputs) + + if all(w.is_sparse_stacked for w in outputs): + return sparse_tensor.SparseTensor(*[w.t for w in outputs]) + + assert not any(w.is_sparse_stacked for w in outputs), ( + "Error converting SparseTensor. All components should be logically " + "stacked, or none.") + + # If component tensors were not sparsely stacked, they are either unstacked + # or stacked without knowledge that they are components of sparse tensors. + # In this case, we have to restack them. + return self._restack_sparse_tensor_logically( + *[self._unwrap_or_tile(w) for w in outputs]) + + def _restack_sparse_tensor_logically(self, indices, values, shape): + sparse_tensor_rank = indices.get_shape()[-1].value + if sparse_tensor_rank is not None: + sparse_tensor_rank += 1 + + def map_fn(args): + res = gen_sparse_ops.serialize_sparse( + args[0], args[1], args[2], out_type=dtypes.variant) + return res + + # Applies a map function to the component tensors to serialize each + # sparse tensor element and batch them all, then deserializes the batch. + # TODO(rachelim): Try to do this without map_fn -- add the right offsets + # to shape and indices tensors instead. + result = functional_ops.map_fn( + map_fn, [indices, values, shape], dtype=dtypes.variant) + return sparse_ops.deserialize_sparse( + result, dtype=values.dtype, rank=sparse_tensor_rank) + + def _unwrap_or_tile(self, wrapped_tensor): + """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it.""" + output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked + if is_stacked: + return output + else: + return _stack(output, self._loop_len_vector).t + + def convert(self, y): + """Returns the converted value corresponding to y. + + Args: + y: A ops.Tensor or a ops.Operation object. If latter, y should not have + any outputs. + + Returns: + If y does not need to be converted, it returns y as is. Else it returns + the "converted value" corresponding to y. + """ + if isinstance(y, sparse_tensor.SparseTensor): + return self._convert_sparse(y) + output = self._convert_helper(y) + if isinstance(output, WrappedTensor): + assert isinstance(y, ops.Tensor) + return self._unwrap_or_tile(output) + else: + assert isinstance(y, ops.Operation) + assert not y.outputs + assert isinstance(output, ops.Operation) + return output + + def _was_converted(self, t): + """True if t is not a conversion of itself.""" + converted_t = self._conversion_map[t] + return converted_t.t is not t + + def _add_conversion(self, old_output, new_output): + self._conversion_map[old_output] = new_output + + def _convert_helper(self, op_or_tensor): + stack = [op_or_tensor] + while stack: + y = stack[0] + if y in self._conversion_map: + assert isinstance(self._conversion_map[y], + (WrappedTensor, ops.Operation)) + stack.pop(0) + continue + if isinstance(y, ops.Operation): + assert not y.outputs, ( + "We only support converting Operation objects with no outputs. " + "Got %s", y) + y_op = y + else: + assert isinstance(y, ops.Tensor), y + y_op = y.op + + is_while_loop = y_op.type == "Exit" + if is_while_loop: + while_op = WhileOp(y, pfor_ops=self._pfor_ops) + is_inside_loop = while_op.is_inside_loop + # If all nodes in the while_loop graph were created inside the pfor, we + # treat the whole loop subgraph as a single op (y_op) and try to convert + # it. For while_loops that are created completely or partially outside, + # we treat them as external and should be able to simply return the Exit + # node output as is without needing any conversion. Note that for + # while_loops that are partially constructed inside, we assume they will + # be loop invariant. If that is not the case, it will create runtime + # errors since the converted graph would depend on the self._loop_var + # placeholder. + if is_inside_loop: + y_op = while_op + else: + is_inside_loop = self.op_is_inside_loop(y_op) + + # If this op was not created inside the loop body, we will return as is. + # 1. Convert inputs and control inputs. + + def _add_to_stack(x): + if x not in self._conversion_map: + stack.insert(0, x) + return True + else: + return False + + if is_inside_loop: + added_to_stack = False + for inp in y_op.inputs: + added_to_stack |= _add_to_stack(inp) + for cinp in y_op.control_inputs: + if cinp.outputs: + for t in cinp.outputs: + added_to_stack |= _add_to_stack(t) + else: + added_to_stack |= _add_to_stack(cinp) + if added_to_stack: + continue + + converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs] + some_input_converted = any( + [self._was_converted(x) for x in y_op.inputs]) + some_input_stacked = any([x.is_stacked for x in converted_inputs]) + + converted_control_ops = set() + some_control_input_converted = False + for cinp in y_op.control_inputs: + if cinp.outputs: + for t in cinp.outputs: + converted_t = self._conversion_map[t] + if self._was_converted(t): + some_control_input_converted = True + converted_control_ops.add(converted_t.t.op) + else: + converted_cinp = self._conversion_map[cinp] + assert isinstance(converted_cinp, ops.Operation) + if converted_cinp != cinp: + some_control_input_converted = True + converted_control_ops.add(converted_cinp) + converted_control_ops = list(converted_control_ops) + is_stateful = _is_stateful_pfor_op(y_op) + else: + converted_inputs = [] + converted_control_ops = [] + logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op, + converted_inputs, converted_control_ops) + + # 2. Convert y_op + # If converting a while_loop, we let the while_loop convertor deal with + # putting the control dependencies appropriately. + control_dependencies = [] if is_while_loop else converted_control_ops + with ops.control_dependencies(control_dependencies), ops.name_scope( + y_op.name + "/pfor/"): + # None of the inputs and control inputs were converted. + if (not is_inside_loop or + (not is_stateful and not some_input_converted and + not some_control_input_converted)): + if y == y_op: + assert not isinstance(y_op, WhileOp) + new_outputs = y_op + else: + new_outputs = [wrap(x, False) for x in y_op.outputs] + elif not (is_stateful or is_while_loop or some_input_stacked): + # All inputs are unstacked or uncoverted but some control inputs are + # converted. + # TODO(rachelim): Handle the case where some inputs are sparsely + # stacked (i.e. any([x.is_sparse_stacked for x in converted_inputs])) + new_op = _create_op(y_op.type, [x.t for x in converted_inputs], + [x.dtype for x in y_op.outputs], + y_op.node_def.attr) + if y == y_op: + new_outputs = new_op + else: + new_outputs = [wrap(x, False) for x in new_op.outputs] + else: + # Either some inputs are not loop invariant or op is stateful. + if hasattr(y_op, "pfor_converter"): + converter = y_op.pfor_converter + else: + converter = _pfor_converter_registry.get(y_op.type, None) + if converter is None: + if flags.FLAGS.op_conversion_fallback_to_while_loop: + converter = _fallback_converter + else: + raise ValueError( + "No converter defined for %s\n%s\ninputs: %s. " + "\nEither add a converter or set " + "--op_conversion_fallback_to_while_loop=True, " + "which may run slower" % (y_op.type, y_op, converted_inputs)) + # TODO(rachelim): Handle the case where some inputs are sparsely + # stacked. We should only call the converter if it supports handling + # those inputs. + new_outputs = converter(_PforInput(self, y_op, converted_inputs)) + if isinstance(new_outputs, WrappedTensor): + new_outputs = [new_outputs] + assert isinstance(new_outputs, + (list, tuple, ops.Operation)), new_outputs + logging.vlog(2, "converted %s %s", y_op, new_outputs) + + # Insert into self._conversion_map + if y == y_op: + assert isinstance(new_outputs, ops.Operation) + self._add_conversion(y_op, new_outputs) + else: + for old_output, new_output in zip(y_op.outputs, new_outputs): + assert isinstance(new_output, WrappedTensor), (new_output, y, y_op) + self._add_conversion(old_output, new_output) + stack.pop(0) + + return self._conversion_map[op_or_tensor] + + @property + def loop_len_vector(self): + """Returns a single element vector whose value is number of iterations.""" + return self._loop_len_vector + + @property + def loop_var(self): + """Returns placeholder loop variable.""" + return self._loop_var + + @property + def pfor_ops(self): + return self._pfor_ops + + @property + def all_indices_partitioned(self): + """all_indices_partitioned property. + + Returns: + True if we are inside a control flow construct and not all pfor iterations + may be active. + """ + return self._all_indices_partitioned + +# nn_ops + + +def _flatten_first_two_dims(x): + """Merges first two dimensions.""" + old_shape = array_ops.shape(x) + new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0) + return array_ops.reshape(x, new_shape) + + +def _unflatten_first_dim(x, first_dim): + """Splits first dimension into [first_dim, -1].""" + old_shape = array_ops.shape(x) + new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0) + return array_ops.reshape(x, new_shape) + + +def _inputs_with_flattening(pfor_input, input_indices): + """Stacks and flattens first dim of inputs at indices `input_indices`.""" + if input_indices is None: + input_indices = [] + pfor_input.stack_inputs(stack_indices=input_indices) + inputs = [] + for i in range(pfor_input.num_inputs): + if i in input_indices: + inp = pfor_input.stacked_input(i) + inp = _flatten_first_two_dims(inp) + else: + inp = pfor_input.unstacked_input(i) + inputs.append(inp) + return inputs + + +@RegisterPForWithArgs("Conv2D", dims=[0]) +@RegisterPForWithArgs("AvgPool", dims=[0]) +@RegisterPForWithArgs("MaxPool", dims=[0]) +@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2]) +@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1]) +def _convert_flatten_batch(pfor_input, op_type, dims): + del op_type + inputs = _inputs_with_flattening(pfor_input, dims) + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + n = pfor_input.pfor.loop_len_vector + outputs = [_unflatten_first_dim(x, n) for x in outputs] + return [wrap(x, True) for x in outputs] + + +_channel_flatten_input_cache = {} + + +def _channel_flatten_input(x, data_format): + """Merge the stack dimension with the channel dimension. + + If S is pfor's stacking dimension, then, + - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose + should be cheap. + - for SNHWC, we transpose to NHWCS. + We then merge the S and C dimension. + + Args: + x: ops.Tensor to transform. + data_format: "NCHW" or "NHWC". + + Returns: + A 3-element tuple with the transformed value, along with the shape for + reshape and order for transpose required to transform back. + """ + + graph = ops.get_default_graph() + cache_key = (graph, x, data_format) + if cache_key not in _channel_flatten_input_cache: + x_shape = array_ops.shape(x) + if data_format == b"NCHW": + order = [1, 0, 2, 3, 4] + shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0) + reverse_order = order + else: + order = [1, 2, 3, 0, 4] + shape = array_ops.concat([x_shape[1:4], [-1]], axis=0) + reverse_order = [3, 0, 1, 2, 4] + # Move S dimension next to C dimension. + x = array_ops.transpose(x, order) + reverse_shape = array_ops.shape(x) + # Reshape to merge the S and C dimension. + x = array_ops.reshape(x, shape) + outputs = x, reverse_order, reverse_shape + _channel_flatten_input_cache[cache_key] = outputs + else: + outputs = _channel_flatten_input_cache[cache_key] + return outputs + + +# Note that with training=True, running FusedBatchNorm on individual examples +# is very different from running FusedBatchNorm on a batch of those examples. +# This is because, for the latter case, the operation can be considered as first +# computing the mean and variance over all the examples and then using these +# to scale all those examples. This creates a data dependency between these +# different "iterations" since the inputs to the scaling step depends on the +# statistics coming from all these inputs. +# As with other kernels, the conversion here effectively runs the kernel +# independently for each iteration, and returns outputs by stacking outputs from +# each of those iterations. +@RegisterPFor("FusedBatchNorm") +def _convert_fused_batch_norm(pfor_input): + is_training = pfor_input.get_attr("is_training") + # When BatchNorm is used with training=False, mean and variance are provided + # externally and used as is by the op. Thus, we can merge the S and N + # dimensions as we do for regular operations. + # When BatchNorm is used with training=True, mean and variance are computed + # for each channel across the batch dimension (first one). If we merge S and N + # dimensions, mean and variances will be computed over a larger set. So, we + # merge the S and C dimensions instead. + if not is_training: + # We return zeros for batch_mean and batch_variance output. Note that CPU + # and GPU seem to have different behavior for those two outputs. CPU outputs + # zero because these values are not used during inference. GPU outputs + # something, probably real means and variances. + inputs = _inputs_with_flattening(pfor_input, [0]) + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + y = outputs[0] + n = pfor_input.pfor.loop_len_vector + y = _unflatten_first_dim(y, n) + mean = pfor_input.unstacked_input(3) + zeros = array_ops.zeros_like(mean) + return [wrap(y, True), wrap(zeros, False), wrap(zeros, False)] + + pfor_input.stack_inputs() + data_format = pfor_input.get_attr("data_format") + # We merge the first dimension with the "C" dimension, run FusedBatchNorm, and + # then transpose back. + x = pfor_input.stacked_input(0) + x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format) + # Note that we stack all the other inputs as well so that they are the same + # size as the new size of the channel dimension. + inputs = [x] + [ + array_ops.reshape(pfor_input.stacked_input(i), [-1]) + for i in range(1, pfor_input.num_inputs) + ] + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + y = outputs[0] + y = array_ops.reshape(y, reverse_shape) + y = array_ops.transpose(y, reverse_order) + n = pfor_input.pfor.loop_len_vector + outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] + outputs = [y] + outputs + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("FusedBatchNormGrad") +def _convert_fused_batch_norm_grad(pfor_input): + pfor_input.stack_inputs() + data_format = pfor_input.get_attr("data_format") + y_backprop = pfor_input.stacked_input(0) + y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format) + x = pfor_input.stacked_input(1) + x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format) + inputs = [y_backprop, x] + [ + array_ops.reshape(pfor_input.stacked_input(i), [-1]) + for i in range(2, pfor_input.num_inputs) + ] + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + x_backprop = outputs[0] + x_backprop = array_ops.reshape(x_backprop, x_reverse_shape) + x_backprop = array_ops.transpose(x_backprop, x_reverse_order) + n = pfor_input.pfor.loop_len_vector + outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] + outputs = [x_backprop] + outputs + return [wrap(output, True) for output in outputs] + + +@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0) +@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0) +def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims, + shape_dim): + del op_type + inputs = _inputs_with_flattening(pfor_input, flatten_dims) + n = pfor_input.pfor.loop_len_vector + # Adjust the `input_sizes` input. + ones = array_ops.ones( + [array_ops.shape(inputs[shape_dim])[0] - 1], dtype=n.dtype) + inputs[shape_dim] *= array_ops.concat([n, ones], axis=0) + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + outputs = [_unflatten_first_dim(x, n) for x in outputs] + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("Conv2DBackpropFilter") +def _convert_conv2d_backprop_filter(pfor_input): + pfor_input.stack_inputs(stack_indices=[2]) + inputs, inputs_stacked, _ = pfor_input.input(0) + filter_sizes = pfor_input.unstacked_input(1) + grads = pfor_input.stacked_input(2) + strides = pfor_input.get_attr("strides") + padding = pfor_input.get_attr("padding") + use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu") + data_format = pfor_input.get_attr("data_format") + dilations = pfor_input.get_attr("dilations") + if inputs_stacked: + # TODO(agarwal): Implement this efficiently. + logging.warn("Conv2DBackpropFilter uses a while_loop. Fix that!") + + def while_body(i, ta): + inp_i = inputs[i, ...] + grad_i = grads[i, ...] + output = nn_ops.conv2d_backprop_filter( + inp_i, + filter_sizes, + grad_i, + strides=strides, + padding=padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + data_format=data_format, + dilations=dilations) + return i + 1, ta.write(i, array_ops.expand_dims(output, 0)) + + n = array_ops.reshape(pfor_input.pfor.loop_len_vector, []) + _, ta = control_flow_ops.while_loop( + lambda i, ta: i < n, while_body, + (0, tensor_array_ops.TensorArray(inputs.dtype, n))) + output = ta.concat() + return wrap(output, True) + else: + # We merge the stack dimension with the channel dimension of the gradients + # and pretend we had a larger filter (see change to filter_sizes below). + # Once the filter backprop is computed, we reshape and transpose back + # appropriately. + grads, _, _ = _channel_flatten_input(grads, data_format) + n = pfor_input.pfor.loop_len_vector + old_filter_sizes = filter_sizes + filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0) + output = nn_ops.conv2d_backprop_filter( + inputs, + filter_sizes, + grads, + strides=strides, + padding=padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + data_format=data_format, + dilations=dilations) + new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0) + output = array_ops.reshape(output, new_filter_shape) + output = array_ops.transpose(output, [3, 0, 1, 2, 4]) + return wrap(output, True) + + +# array_ops + + +@RegisterPForWithArgs("Identity", array_ops.identity) +@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient) +def _convert_identity(pfor_input, op_type, op_func): + del op_type + return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) + + +@RegisterPFor("Reshape") +def _convert_reshape(pfor_input): + t = pfor_input.stacked_input(0) + shape = pfor_input.unstacked_input(1) + new_dim = array_ops.shape(t)[:1] + new_shape = array_ops.concat([new_dim, shape], axis=0) + return wrap(array_ops.reshape(t, new_shape), True) + + +@RegisterPFor("ExpandDims") +def _convert_expanddims(pfor_input): + t = pfor_input.stacked_input(0) + dim = pfor_input.unstacked_input(1) + dim += math_ops.cast(dim >= 0, dtypes.int32) + return wrap(array_ops.expand_dims(t, axis=dim), True) + + +@RegisterPFor("Slice") +def _convert_slice(pfor_input): + t = pfor_input.stacked_input(0) + begin = pfor_input.unstacked_input(1) + size = pfor_input.unstacked_input(2) + begin = array_ops.concat([[0], begin], axis=0) + size = array_ops.concat([[-1], size], axis=0) + return wrap(array_ops.slice(t, begin, size), True) + + +@RegisterPFor("Tile") +def _convert_tile(pfor_input): + t = pfor_input.stacked_input(0) + multiples = pfor_input.unstacked_input(1) + multiples = array_ops.concat([[1], multiples], 0) + return wrap(array_ops.tile(t, multiples), True) + + +@RegisterPFor("Pack") +def _convert_pack(pfor_input): + pfor_input.stack_inputs() + axis = pfor_input.get_attr("axis") + if axis >= 0: + axis += 1 + return wrap( + array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True) + + +@RegisterPFor("Unpack") +def _convert_unpack(pfor_input): + value = pfor_input.stacked_input(0) + axis = pfor_input.get_attr("axis") + if axis >= 0: + axis += 1 + num = pfor_input.get_attr("num") + return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)] + + +@RegisterPFor("Pad") +def _convert_pad(pfor_input): + t = pfor_input.stacked_input(0) + paddings = pfor_input.unstacked_input(1) + paddings = array_ops.concat([[[0, 0]], paddings], 0) + return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True) + + +@RegisterPFor("Split") +def _convert_split(pfor_input): + split_dim = pfor_input.unstacked_input(0) + t = pfor_input.stacked_input(1) + num_split = pfor_input.get_attr("num_split") + split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) + return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)] + + +@RegisterPFor("Transpose") +def _convert_transpose(pfor_input): + t = pfor_input.stacked_input(0) + perm = pfor_input.unstacked_input(1) + new_perm = array_ops.concat([[0], perm + 1], axis=0) + return wrap(array_ops.transpose(t, new_perm), True) + + +@RegisterPFor("ZerosLike") +def _convert_zeroslike(pfor_input): + t = pfor_input.stacked_input(0) + shape = array_ops.shape(t)[1:] + return wrap(array_ops.zeros(shape, dtype=t.dtype), False) + + +@RegisterPFor("Gather") +@RegisterPFor("GatherV2") +def _convert_gather(pfor_input): + param, param_stacked, _ = pfor_input.input(0) + indices, indices_stacked, _ = pfor_input.input(1) + op_type = pfor_input.op_type + if op_type == "Gather": + validate_indices = pfor_input.get_attr("validate_indices") + axis = 0 + else: + validate_indices = None + axis = pfor_input.unstacked_input(2) + axis_value = tensor_util.constant_value(axis) + if axis_value is not None: + axis = axis_value + if indices_stacked and not param_stacked: + if indices == pfor_input.pfor.all_indices and axis == 0: + param_shape0 = param.shape[0].value + indices_shape0 = indices.shape[0].value + if param_shape0 is not None and indices_shape0 == param_shape0: + # Note that with loops and conditionals, indices may not be contiguous. + # However they will be sorted and unique. So if the shape matches, then + # it must be picking up all the rows of param. + return wrap(param, True) + # TODO(agarwal): use array_ops.slice here. + output = array_ops.gather( + param, indices, validate_indices=validate_indices, axis=axis) + if axis != 0: + axis = control_flow_ops.cond( + axis < 0, lambda: axis + array_ops.rank(param), lambda: axis) + order = array_ops.concat( + [[axis], + math_ops.range(axis), + math_ops.range(axis + 1, array_ops.rank(output))], + axis=0) + output = control_flow_ops.cond( + math_ops.equal(axis, 0), lambda: output, + lambda: array_ops.transpose(output, order)) + return wrap(output, True) + if param_stacked: + loop_len_vector = pfor_input.pfor.loop_len_vector + pfor_input.stack_inputs(stack_indices=[1]) + indices = pfor_input.stacked_input(1) + param_flat = _flatten_first_two_dims(param) + + # Recompute indices to handle stacked param. + indices_offset = math_ops.range( + loop_len_vector[0]) * array_ops.shape(param)[1] + # Reshape indices_offset to allow broadcast addition + ones = array_ops.ones([array_ops.rank(indices) - 1], dtype=dtypes.int32) + new_shape = array_ops.concat([loop_len_vector, ones], axis=0) + indices_offset = array_ops.reshape(indices_offset, new_shape) + indices += indices_offset + + # TODO(agarwal): handle axis != 0. May need to transpose param or + # array_ops.gather_nd. + if isinstance(axis, ops.Tensor): + axis_value = tensor_util.constant_value(axis) + else: + try: + axis_value = int(axis) + except TypeError: + axis_value = None + msg = ("Gather, where indices and param are both loop dependent, currently " + "requires axis=0") + if axis_value is not None and axis_value != 0: + raise ValueError("Error while converting %s. %s. Got axis=%d" % + (pfor_input.op, msg, axis)) + with ops.control_dependencies( + [check_ops.assert_equal(axis, 0, message=msg)]): + output = array_ops.gather(param_flat, indices) + return wrap(output, True) + + +@RegisterPFor("ConcatV2") +def _convert_concatv2(pfor_input): + n = pfor_input.num_inputs + pfor_input.stack_inputs(stack_indices=range(n - 1)) + axis = pfor_input.unstacked_input(n - 1) + axis += math_ops.cast(axis >= 0, axis.dtype) + return wrap( + array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis), + True) + + +@RegisterPFor("StridedSlice") +def _convert_strided_slice(pfor_input): + inp = pfor_input.stacked_input(0) + begin = pfor_input.unstacked_input(1) + end = pfor_input.unstacked_input(2) + strides = pfor_input.unstacked_input(3) + begin_mask = pfor_input.get_attr("begin_mask") + end_mask = pfor_input.get_attr("end_mask") + ellipsis_mask = pfor_input.get_attr("ellipsis_mask") + new_axis_mask = pfor_input.get_attr("new_axis_mask") + shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") + + begin = array_ops.concat([[0], begin], axis=0) + end = array_ops.concat([[0], end], axis=0) + strides = array_ops.concat([[1], strides], axis=0) + begin_mask = begin_mask << 1 | 1 + end_mask = end_mask << 1 | 1 + ellipsis_mask <<= 1 + new_axis_mask <<= 1 + shrink_axis_mask <<= 1 + return wrap( + array_ops.strided_slice( + inp, + begin, + end, + strides, + begin_mask=begin_mask, + end_mask=end_mask, + ellipsis_mask=ellipsis_mask, + new_axis_mask=new_axis_mask, + shrink_axis_mask=shrink_axis_mask), True) + + +@RegisterPFor("StridedSliceGrad") +def _convert_strided_slice_grad(pfor_input): + shape = pfor_input.unstacked_input(0) + begin = pfor_input.unstacked_input(1) + end = pfor_input.unstacked_input(2) + strides = pfor_input.unstacked_input(3) + dy = pfor_input.stacked_input(4) + begin_mask = pfor_input.get_attr("begin_mask") + end_mask = pfor_input.get_attr("end_mask") + ellipsis_mask = pfor_input.get_attr("ellipsis_mask") + new_axis_mask = pfor_input.get_attr("new_axis_mask") + shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") + + shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) + begin = array_ops.concat([[0], begin], axis=0) + end = array_ops.concat([[0], end], axis=0) + strides = array_ops.concat([[1], strides], axis=0) + begin_mask = begin_mask << 1 | 1 + end_mask = end_mask << 1 | 1 + ellipsis_mask <<= 1 + new_axis_mask <<= 1 + shrink_axis_mask <<= 1 + return wrap( + array_ops.strided_slice_grad( + shape, + begin, + end, + strides, + dy, + begin_mask=begin_mask, + end_mask=end_mask, + ellipsis_mask=ellipsis_mask, + new_axis_mask=new_axis_mask, + shrink_axis_mask=shrink_axis_mask), True) + + +# math_ops + + +@RegisterPFor("MatMul") +def _convert_matmul(pfor_input): + # TODO(agarwal): Check if tiling is faster than two transposes. + a, a_stacked, _ = pfor_input.input(0) + b, b_stacked, _ = pfor_input.input(1) + tr_a = pfor_input.get_attr("transpose_a") + tr_b = pfor_input.get_attr("transpose_b") + if a_stacked and b_stacked: + output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True) + return output + elif a_stacked: + if tr_a: + a = array_ops.transpose(a, [0, 2, 1]) + if a.shape.is_fully_defined(): + x, y, z = a.shape + else: + x, y, z = [ + array_ops.reshape(i, []) + for i in array_ops.split(array_ops.shape(a), 3) + ] + a = array_ops.reshape(a, [x * y, z]) + prod = math_ops.matmul(a, b, transpose_b=tr_b) + return wrap(array_ops.reshape(prod, [x, y, -1]), True) + else: + assert b_stacked + if tr_b: + perm = [2, 0, 1] + b = array_ops.transpose(b, perm) + else: + # As an optimization, if one of the first two dimensions is 1, then we can + # reshape instead of transpose. + # TODO(agarwal): This check can be done inside Transpose kernel. + b_shape = array_ops.shape(b) + min_dim = math_ops.minimum(b_shape[0], b_shape[1]) + perm = control_flow_ops.cond( + math_ops.equal(min_dim, 1), lambda: [0, 1, 2], lambda: [1, 0, 2]) + new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]]) + b = array_ops.transpose(b, perm) + b = array_ops.reshape(b, new_shape) + + if b.shape.is_fully_defined(): + x, y, z = b.shape + else: + x, y, z = [ + array_ops.reshape(i, []) + for i in array_ops.split(array_ops.shape(b), 3) + ] + b = array_ops.reshape(b, [x, y * z]) + prod = math_ops.matmul(a, b, transpose_a=tr_a) + prod = array_ops.reshape(prod, [-1, y, z]) + prod = array_ops.transpose(prod, [1, 0, 2]) + return wrap(prod, True) + + +@RegisterPFor("BatchMatMul") +def _convert_batch_mat_mul(pfor_input): + # TODO(agarwal): There may be a more efficient way to do this instead of + # stacking the inputs. + pfor_input.stack_inputs() + x = pfor_input.stacked_input(0) + y = pfor_input.stacked_input(1) + adj_x = pfor_input.get_attr("adj_x") + adj_y = pfor_input.get_attr("adj_y") + + x = _flatten_first_two_dims(x) + y = _flatten_first_two_dims(y) + output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) + output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector) + return wrap(output, True) + + +@RegisterPForWithArgs("Sum", math_ops.reduce_sum) +@RegisterPForWithArgs("Prod", math_ops.reduce_prod) +@RegisterPForWithArgs("Max", math_ops.reduce_max) +@RegisterPForWithArgs("Min", math_ops.reduce_min) +def _convert_reduction(pfor_input, _, op_func): + t = pfor_input.stacked_input(0) + indices = pfor_input.unstacked_input(1) + # Shift positive indices by one to account for the extra dimension. + indices += math_ops.cast(indices >= 0, dtypes.int32) + keep_dims = pfor_input.get_attr("keep_dims") + return wrap(op_func(t, indices, keepdims=keep_dims), True) + + +@RegisterPForWithArgs("Cumsum", math_ops.cumsum) +@RegisterPForWithArgs("Cumprod", math_ops.cumprod) +def _convert_cumfoo(pfor_input, _, op_func): + t = pfor_input.stacked_input(0) + axis = pfor_input.unstacked_input(1) + # Shift positive indices by one to account for the extra dimension. + axis += math_ops.cast(axis >= 0, dtypes.int32) + exclusive = pfor_input.get_attr("exclusive") + reverse = pfor_input.get_attr("reverse") + return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True) + + +@RegisterPFor("BiasAdd") +def _convert_biasadd(pfor_input): + t = pfor_input.stacked_input(0) + bias = pfor_input.unstacked_input(1) + data_format = pfor_input.get_attr("data_format") + if data_format != b"NCHW": + return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True) + shape = array_ops.shape(t) + flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0) + t = array_ops.reshape(t, flattened_shape) + t = nn_ops.bias_add(t, bias, data_format=b"NCHW") + t = array_ops.reshape(t, shape) + return wrap(t, True) + + +@RegisterPFor("UnsortedSegmentSum") +def _convert_unsortedsegmentsum(pfor_input): + data, data_stacked, _ = pfor_input.input(0) + # TODO(agarwal): handle unstacked? + segment_ids = pfor_input.stacked_input(1) + # TODO(agarwal): handle stacked? + num_segments = pfor_input.unstacked_input(2) + if not data_stacked: + data = _stack(data, pfor_input.pfor.loop_len_vector).t + segment_shape = array_ops.shape(segment_ids) + n = segment_shape[0] + ones = array_ops.ones_like(segment_shape)[1:] + segment_offset = num_segments * math_ops.range(n) + segment_offset = array_ops.reshape(segment_offset, + array_ops.concat([[n], ones], axis=0)) + segment_ids += segment_offset + num_segments *= n + output = math_ops.unsorted_segment_sum(data, segment_ids, num_segments) + new_output_shape = array_ops.concat( + [[n, -1], array_ops.shape(output)[1:]], axis=0) + output = array_ops.reshape(output, new_output_shape) + return wrap(output, True) + + +@RegisterPFor("Cast") +def _convert_cast(pfor_input): + inp = pfor_input.stacked_input(0) + dtype = pfor_input.get_attr("DstT") + return wrap(math_ops.cast(inp, dtype), True) + + +# Note that ops handled here do not have attributes except "T", and hence don't +# need extra arguments passed to the cwise_op call below. +@RegisterPForWithArgs("Add", math_ops.add) +@RegisterPForWithArgs("Ceil", math_ops.ceil) +@RegisterPForWithArgs("Equal", math_ops.equal) +@RegisterPForWithArgs("NotEqual", math_ops.not_equal) +@RegisterPForWithArgs("Floor", math_ops.floor) +@RegisterPForWithArgs("Greater", math_ops.greater) +@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal) +@RegisterPForWithArgs("Less", math_ops.less) +@RegisterPForWithArgs("LessEqual", math_ops.less_equal) +@RegisterPForWithArgs("LogicalOr", math_ops.logical_or) +@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and) +@RegisterPForWithArgs("LogicalNot", math_ops.logical_not) +@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor) +@RegisterPForWithArgs("Maximum", math_ops.maximum) +@RegisterPForWithArgs("Minimum", math_ops.minimum) +@RegisterPForWithArgs("Mul", math_ops.multiply) +@RegisterPForWithArgs("Neg", math_ops.negative) +@RegisterPForWithArgs("RealDiv", math_ops.divide) +@RegisterPForWithArgs("Relu", nn_ops.relu) +@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid) +@RegisterPForWithArgs("Square", math_ops.square) +@RegisterPForWithArgs("Sub", math_ops.subtract) +@RegisterPForWithArgs("Tanh", math_ops.tanh) +def _convert_cwise(pfor_input, op_type, op_func): + del op_type + pfor_input.expanddim_inputs_for_broadcast() + return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) + + +@RegisterPFor("Shape") +def _convert_shape(pfor_input): + out_type = pfor_input.get_attr("out_type") + return wrap( + array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:], + False) + + +@RegisterPFor("ShapeN") +def _convert_shape_n(pfor_input): + out_type = pfor_input.get_attr("out_type") + shapes = [ + array_ops.shape(x, out_type=out_type)[1:] + if stacked else array_ops.shape(x) for x, stacked, _ in pfor_input.inputs + ] + return [wrap(x, False) for x in shapes] + + +@RegisterPFor("Size") +def _convert_size(pfor_input): + out_type = pfor_input.get_attr("out_type") + n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type) + return wrap( + array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n, + False) + + +@RegisterPFor("Rank") +def _convert_rank(pfor_input): + return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False) + + +@RegisterPFor("AddN") +def _convert_addn(pfor_input): + # AddN does not support broadcasting. + pfor_input.stack_inputs() + return wrap(math_ops.add_n([x.t for x in pfor_input.inputs]), True) + + +@RegisterPFor("BiasAddGrad") +def _convert_biasaddgrad(pfor_input): + grad = pfor_input.stacked_input(0) + fmt = pfor_input.get_attr("data_format") + if fmt == b"NCHW": + output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False) + else: + grad_shape = array_ops.shape(grad) + last_dim_shape = grad_shape[-1] + first_dim_shape = grad_shape[0] + output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape]) + output = math_ops.reduce_sum(output, axis=[1], keepdims=False) + return wrap(output, True) + + +# Some required ops are not exposed under the tf namespace. Hence relying on +# _create_op to create them. +@RegisterPForWithArgs("ReluGrad") +@RegisterPForWithArgs("TanhGrad") +@RegisterPForWithArgs("SigmoidGrad") +def _convert_grads(pfor_input, op_type, *args, **kw_args): + del args + del kw_args + # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we + # have to use tiling here. + pfor_input.stack_inputs() + outputs = _create_op( + op_type, [x.t for x in pfor_input.inputs], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("Select") +def _convert_select(pfor_input): + pfor_input.stack_inputs() + cond = pfor_input.stacked_input(0) + t = pfor_input.stacked_input(1) + e = pfor_input.stacked_input(2) + cond_rank = array_ops.rank(cond) + cond, t, e = control_flow_ops.cond( + cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]), + lambda: [cond, t, e]) + outputs = _create_op( + pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + n = pfor_input.pfor.loop_len_vector + out = control_flow_ops.cond(cond_rank > 1, + lambda: _unflatten_first_dim(outputs[0], n), + lambda: outputs[0]) + return [wrap(out, True) for x in outputs] + + +# random_ops + + +@RegisterPForWithArgs("RandomUniform") +@RegisterPForWithArgs("RandomUniformInt") +@RegisterPForWithArgs("RandomStandardNormal") +@RegisterPForWithArgs("TruncatedNormal") +@RegisterPForWithArgs("RandomGamma") +@RegisterPForWithArgs("RandomPoissonV2") +def _convert_random(pfor_input, op_type, *args, **kw_args): + del args + del kw_args + inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)] + # inputs[0] is "shape" + inputs[0] = array_ops.concat( + [pfor_input.pfor.loop_len_vector, inputs[0]], axis=0) + logging.warning( + "Note that %s inside pfor op may not give same output as " + "inside a sequential loop.", op_type) + outputs = _create_op( + op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + return [wrap(x, True) for x in outputs] + + +# logging_ops + + +@RegisterPFor("Assert") +def _convert_assert(pfor_input): + cond, cond_stacked, _ = pfor_input.input(0) + if cond_stacked: + cond = math_ops.reduce_all(cond) + + data_list = [x.t for x in pfor_input.inputs][1:] + return _create_op("Assert", [cond] + data_list, [], + attrs=pfor_input.op.node_def.attr) + + +@RegisterPFor("Print") +def _convert_print(pfor_input): + # Note that we don't stack all the inputs. Hence unstacked values are printed + # once here vs multiple times in a while_loop. + pfor_input.stack_inputs([0]) + outputs = _create_op( + "Print", [x.t for x in pfor_input.inputs], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + return [wrap(x, True) for x in outputs] + + +# data_flow_ops + +# TensorArray conversion is tricky since we don't support arrays of +# TensorArrays. For converting them, we consider two distinct cases: +# +# 1. The array is constructed outside the pfor call, and read/written inside the +# loop. +# This is an easier case since we don't need to make an array of TensorArrays. +# A correctness requirement is that these parallel iterations shouldn't attempt +# to write to the same location. Hence at conversion time we disallow indices to +# be loop-invariant as that would guarantee a collision. Even if the indices are +# not loop-invariant, they could conflict and that shall trigger runtime errors. +# +# 2. The array is constructed and used entirely inside each pfor iteration. +# For simplicity, here we require that the indices used for write/scatter are +# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in +# different pfor iterations. We consider two sub_cases: +# +# 2a Elements written to the array are "stacked" +# To simulate multiple TensorArrays, we may increase the dimension of each +# element of the array. i.e. the i_th row of the j_th entry of the converted +# TensorArray corresponds to to the j_th entry of the TensorArray in the i_th +# pfor iteration. +# +# 2b Elements written to the array are "unstacked" +# In this case we don't increase the dimensions to avoid redundant tiling. Each +# iteration is trying to write the same value. So we convert that to a single +# write. +# +# Here are some tricks used to implement the above: +# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of +# trying to trace whether future writes are stacked or unstacked in order to set +# this attr, we set it to correspond to unknown shape. +# - We use the "flow" output of the different ops to track whether the array +# elements are stacked or unstacked. If a stacked write/scatter is done, we make +# the flow stacked as well. +# - We use some heuristic traversal of the graph to track whether the +# TensorArray handle was created inside or outside the pfor loop. + + +@RegisterPFor("TensorArrayV3") +def _convert_tensor_array_v3(pfor_input): + size = pfor_input.unstacked_input(0) + dtype = pfor_input.get_attr("dtype") + dynamic_size = pfor_input.get_attr("dynamic_size") + clear_after_read = pfor_input.get_attr("clear_after_read") + identical_element_shapes = pfor_input.get_attr("identical_element_shapes") + tensor_array_name = pfor_input.get_attr("tensor_array_name") + handle, flow = data_flow_ops.tensor_array_v3( + size, + dtype=dtype, + # We don't set element shape since we don't know if writes are stacked or + # not yet. + element_shape=None, + dynamic_size=dynamic_size, + clear_after_read=clear_after_read, + identical_element_shapes=identical_element_shapes, + tensor_array_name=tensor_array_name) + # Note we keep flow unstacked for now since we don't know if writes will be + # stacked or not. + return wrap(handle, False), wrap(flow, False) + + +@RegisterPFor("TensorArraySizeV3") +def _convert_tensor_array_size_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + flow, flow_stacked, _ = pfor_input.input(1) + if flow_stacked: + flow = _unstack_flow(flow) + size = data_flow_ops.tensor_array_size_v3(handle, flow) + return wrap(size, False) + + +def _handle_inside_pfor(pfor_input, handle): + """Returns True if handle was created inside the pfor loop.""" + # We use some heuristic to find the original TensorArray creation op. + # The logic should handle the common cases (except cond based subgraphs). + # In theory the user could perform different operations on the handle (like + # Reshape, stack multiple handles, etc) which could break this logic. + # TODO(agarwal): handle Switch/Merge. + while handle.op.type in ("Enter", "Identity"): + handle = handle.op.inputs[0] + if handle.op.type not in [ + "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"]: + raise ValueError("Unable to find source for handle %s" % handle) + else: + return pfor_input.pfor.op_is_inside_loop(handle.op) + + +def _unstack_flow(value): + # TODO(agarwal): consider looking if this is a Tile op then get its input. + # This may avoid running the Tile operations. + return array_ops.gather(value, 0) + + +@RegisterPFor("TensorArrayReadV3") +def _convert_tensor_array_read_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + index, index_stacked, _ = pfor_input.input(1) + dtype = pfor_input.get_attr("dtype") + flow, flow_stacked, _ = pfor_input.input(2) + if flow_stacked: + flow = _unstack_flow(flow) + + is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside_pfor: + # Note that if we are inside a control flow construct inside the pfor, and + # only some of the iterations are doing the read (i.e. + # `all_indices_partitioned` is True), then the read operation should only + # return values for the currently active pfor iterations (`all_indices` + # below). Hence, whenever the returned value is stacked (i.e. `flow` is + # stacked), we may need to do an extra gather after reading the values. Also + # note that if `is_inside` is false, then values in the tensor array are + # unstacked. So the check is only needed in this branch. + all_indices = pfor_input.pfor.all_indices + all_indices_partitioned = pfor_input.pfor.all_indices_partitioned + # Note: flow_stacked indicates if values in the TensorArray are stacked or + # not. + if index_stacked: + if flow_stacked: + raise ValueError( + "It looks like TensorArrayReadV3 was called on a TensorArray whose" + " values are not loop-invariant, and the read indices were also" + " not loop invariant. This is currently unsupported.") + value = data_flow_ops.tensor_array_gather_v3( + handle, index, flow, dtype=dtype) + return wrap(value, True) + value = data_flow_ops.tensor_array_read_v3( + handle, index, flow, dtype=dtype) + if flow_stacked and all_indices_partitioned: + value = array_ops.gather(value, all_indices) + return wrap(value, flow_stacked) + # Values in the TensorArray should be unstacked (since different iterations + # couldn't write to the same location). So whether output is stacked or not + # depends on index_stacked. + if index_stacked: + value = data_flow_ops.tensor_array_gather_v3( + handle, index, flow, dtype=dtype) + else: + value = data_flow_ops.tensor_array_read_v3( + handle, index, flow, dtype=dtype) + return wrap(value, index_stacked) + + +@RegisterPFor("TensorArrayWriteV3") +def _convert_tensor_array_write_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + index, index_stacked, _ = pfor_input.input(1) + value, value_stacked, _ = pfor_input.input(2) + flow, flow_stacked, _ = pfor_input.input(3) + if value_stacked and pfor_input.pfor.all_indices_partitioned: + # Looks like we are in a control flow in a pfor where not all iterations are + # active now. We don't allow that since that could lead to different indices + # having different shapes which will be hard to merge later. + raise ValueError("Writing non loop invariant values to TensorArray from " + "inside a while_loop/cond not supported.") + if flow_stacked: + flow = _unstack_flow(flow) + is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside: + if index_stacked: + raise ValueError("Need indices for %s to be loop invariant" % handle) + if not flow_stacked and not value_stacked: + flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) + return wrap(flow_out, False) + else: + if not value_stacked: + value = _stack(value, pfor_input.pfor.loop_len_vector).t + # TODO(agarwal): Note that if flow is unstacked and value is stacked, then + # this may or may not be a safe situation. flow is unstacked both for a + # freshly created TensorArray, as well as after unstacked values are + # written to it. If it is the latter, then we cannot write a stacked value + # now since that may cause runtime errors due to different shapes in the + # array. At the moment we are not able to handle this gracefully and + # distinguish between the two cases. That would require some heuristic + # traversal of the graph to figure out whether all the writes are + # unstacked or not. + flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + else: + if not index_stacked: + raise ValueError("Need indices for %s to be not loop invariant" % handle) + # Note that even when index_stacked is true, actual values in index may + # still not be unique. However that will cause runtime error when executing + # the scatter operation below. + if not value_stacked: + value = _stack(value, pfor_input.pfor.loop_len_vector).t + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + + +def _transpose_first_two_dims(value): + # TODO(agarwal): optimize if one of the dims == 1. + value_shape = array_ops.shape(value) + v0 = value_shape[0] + v1 = value_shape[1] + value = array_ops.reshape(value, [v0, v1, -1]) + value = array_ops.transpose(value, [1, 0, 2]) + new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0) + return array_ops.reshape(value, new_shape) + + +@RegisterPFor("TensorArrayGatherV3") +def _convert_tensor_array_gather_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + indices, indices_stacked, _ = pfor_input.input(1) + indices = array_ops.reshape(indices, [-1]) + flow, flow_stacked, _ = pfor_input.input(2) + if flow_stacked: + flow = _unstack_flow(flow) + dtype = pfor_input.get_attr("dtype") + # TODO(agarwal): support element_shape attr? + + n = pfor_input.pfor.loop_len_vector + value = data_flow_ops.tensor_array_gather_v3( + handle, indices, flow, dtype=dtype) + is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside: + # flow_stacked indicates if values in the TensorArray are stacked or not. + if indices_stacked: + if flow_stacked: + raise ValueError( + "It looks like TensorArrayGatherV3 was called on a TensorArray " + "whose values are not loop-invariant, and the indices were also " + "not loop invariant. This is currently unsupported.") + else: + value = _unflatten_first_dim(value, n) + return wrap(value, True) + else: + if flow_stacked: + # Since elements in this array are stacked and `value` was produced by + # gather, its first two dims are "gathered elements" and "stack + # dimension". Our semantics require these two to be flipped. + value = _transpose_first_two_dims(value) + return wrap(value, flow_stacked) + else: + # Values in the TensorArray should be unstacked (since different iterations + # couldn't write to the same location). So whether output is stacked or not + # depends on indices_stacked. + if indices_stacked: + value = _unflatten_first_dim(value, n) + return wrap(value, indices_stacked) + + +@RegisterPFor("TensorArrayScatterV3") +def _convert_tensor_array_scatter_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + indices, indices_stacked, _ = pfor_input.input(1) + indices = array_ops.reshape(indices, [-1]) + value, value_stacked, _ = pfor_input.input(2) + flow, flow_stacked, _ = pfor_input.input(3) + + if flow_stacked: + flow = _unstack_flow(flow) + + is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside: + if indices_stacked: + raise ValueError("Need indices for %s to be loop invariant" % handle) + # Note that flow_stacked indicates if existing values in the array are + # stacked or not. + if not flow_stacked and not value_stacked: + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, + flow) + return wrap(flow_out, False) + if not value_stacked: + # TODO(agarwal): tile in the second dimension directly instead of + # transposing below. + value = _stack(value, pfor_input.pfor.loop_len_vector).t + + value = _transpose_first_two_dims(value) + # TODO(agarwal): Note that if a previous write was unstacked, flow will be + # unstacked, and a stacked value may be written here which may cause + # runtime error due to different elements having different shape. We do + # not try to prevent that. + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, + flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + if not indices_stacked: + raise ValueError("Need indices for %s to be not loop invariant" % handle) + if not value_stacked: + value = _stack(value, pfor_input.pfor.loop_len_vector).t + value = _flatten_first_two_dims(value) + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, + flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + + +@RegisterPFor("TensorArrayGradV3") +def _convert_tensor_array_grad_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + flow, flow_stacked, _ = pfor_input.input(1) + if flow_stacked: + flow = _unstack_flow(flow) + source = pfor_input.get_attr("source") + # TODO(agarwal): For now, we assume that gradients are stacked if the + # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong + # will give runtime error due to incorrect shape being written to the + # accumulator. It is difficult to know in advance if gradients written will be + # stacked or not. Note that flow being stacked is not indicative of the + # gradient being stacked or not. Revisit this later. + shape_to_prepend = pfor_input.pfor.loop_len_vector + grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape( + handle=handle, + flow_in=flow, + shape_to_prepend=shape_to_prepend, + source=source) + flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t + return [wrap(grad_handle, False), wrap(flow_out, True)] + + +# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar +# to TensorArrays, we convert them by changing the dimension of the elements +# inside the stack. +# +# We consider two cases: +# +# 1. StackV2 is constructed and used entirely inside the pfor loop. +# We keep a single Stack and perform the push/pop operations of all the +# iterations in lock-step. We also assume that all the iterations perform these +# operations. In case of dynamic control flow, if only some of the iterations +# try to perform a push/pop, then the conversion may not work correctly and may +# cause undefined behavior. +# TODO(agarwal): test StackV2 with dynamic control flow. +# +# 2. StackV2 is constructed outside the pfor loop. +# Performing stack push/pop in a parallel fashion is ill-defined. However given +# that reading stacks created externally is a common operation when computing +# jacobians, we provide some special semantics here as follows. +# - disallow push operations to the stack +# - pop operations are performed in lock step by all iterations, similar to the +# case when the stack is created inside. A single value is popped during the +# lock-step operation and broadcast to all the iterations. Values in the stack +# are assumed to be loop-invariant. +# +# Some other implementation details: +# We use an ugly logic to find whether values in Stack data structure are +# loop invariant or not. When converting push/pop operations, we keep track of +# whether the last conversion used a stacked value or not (see _stack_cache +# below). As a result if an unstacked value is written first, subsequent stacked +# writes are disallowed when they could have been allowed in theory. + +# Map from cache key based on StackV2 handle to a bool indicating whether values +# are stacked or not. +# TODO(agarwal): move _stack_cache inside pfor? +_stack_cache = {} + + +def _stack_cache_key(pfor_input): + """Create cache key corresponding to a stack handle.""" + op_type = pfor_input.op_type + assert op_type in ["StackPushV2", "StackPopV2"], op_type + orig_handle = pfor_input.op.inputs[0] + while orig_handle.op.type in ["Identity", "Enter"]: + orig_handle = orig_handle.op.inputs[0] + assert orig_handle.op.type == "StackV2", orig_handle.op + return ops.get_default_graph(), pfor_input.pfor, orig_handle + + +def _stack_handle_inside_pfor(handle, pfor_input): + while handle.op.type in ["Identity", "Enter"]: + handle = handle.op.inputs[0] + assert handle.op.type == "StackV2", ( + "Unable to find StackV2 op. Got %s" % handle.op) + return pfor_input.pfor.op_is_inside_loop(handle.op) + + +@RegisterPFor("StackPushV2") +def _convert_stack_push_v2(pfor_input): + handle = pfor_input.unstacked_input(0) + elem, elem_stacked, _ = pfor_input.input(1) + swap_memory = pfor_input.get_attr("swap_memory") + + if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input): + raise ValueError("StackPushV2 not allowed on stacks created outside pfor") + stack_cache_key = _stack_cache_key(pfor_input) + stacked = _stack_cache.get(stack_cache_key, None) + if stacked is None: + stacked = elem_stacked + _stack_cache[stack_cache_key] = stacked + else: + # If we previously made it unstacked then we can't revert to being stacked. + if not stacked and elem_stacked: + raise ValueError( + "It looks like the stack was previously determined to be loop" + " invariant, but we are now trying to push a loop dependent value" + " to it. This is currently unsupported.") + if stacked and not elem_stacked: + elem = _stack(elem, pfor_input.pfor.loop_len_vector).t + out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory) + return wrap(out, stacked) + + +# Note that inputs to this convertor will be unstacked. However it should get +# called since it is a stateful op. +@RegisterPFor("StackPopV2") +def _convert_stack_pop_v2(pfor_input): + handle = pfor_input.unstacked_input(0) + stack_cache_key = _stack_cache_key(pfor_input) + stacked = _stack_cache.get(stack_cache_key, None) + # If a StackPushV2 has not been converted yet, we default to unstacked since + # the push could be outside of pfor, or the covertor may not be called if the + # inputs are unconverted. + if stacked is None: + stacked = False + _stack_cache[stack_cache_key] = False + elem_type = pfor_input.get_attr("elem_type") + out = data_flow_ops.stack_pop_v2(handle, elem_type) + return wrap(out, stacked) + + +# parsing_ops + + +@RegisterPFor("DecodeCSV") +def _convert_decode_csv(pfor_input): + lines = pfor_input.stacked_input(0) + record_defaults = [ + pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) + ] + field_delim = pfor_input.get_attr("field_delim") + use_quote_delim = pfor_input.get_attr("use_quote_delim") + select_cols = pfor_input.get_attr("select_cols") + if not select_cols: + select_cols = None + return [ + wrap(t, True) for t in parsing_ops.decode_csv( + lines, + record_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + select_cols=select_cols) + ] + + +@RegisterPFor("ParseSingleExample") +def _convert_parse_single_example(pfor_input): + serialized = pfor_input.stacked_input(0) + dense_defaults = [ + pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) + ] + sparse_keys = pfor_input.get_attr("sparse_keys") + dense_keys = pfor_input.get_attr("dense_keys") + sparse_types = pfor_input.get_attr("sparse_types") + dense_shapes = pfor_input.get_attr("dense_shapes") + output = gen_parsing_ops.parse_example( + serialized=serialized, + names=[], + dense_defaults=dense_defaults, + sparse_keys=sparse_keys, + dense_keys=dense_keys, + sparse_types=sparse_types, + dense_shapes=dense_shapes) + return [wrap(t, True, True) for t in nest.flatten(output)] diff --git a/tensorflow/python/ops/random_grad.py b/tensorflow/python/ops/random_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..baa8e2e2cd33d37312b5b14bea3c248c06ff2e50 --- /dev/null +++ b/tensorflow/python/ops/random_grad.py @@ -0,0 +1,65 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradients for operators defined in random_ops.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_random_ops +from tensorflow.python.ops import math_ops + + +def add_leading_unit_dimensions(x, num_dimensions): + new_shape = array_ops.concat( + [array_ops.ones([num_dimensions], dtype=dtypes.int32), + array_ops.shape(x)], axis=0) + return array_ops.reshape(x, new_shape) + + +@ops.RegisterGradient("RandomGamma") +def _RandomGammaGrad(op, grad): # pylint: disable=invalid-name + """Returns the gradient of a Gamma sample w.r.t. alpha. + + The gradient is computed using implicit differentiation, see + "Implicit Reparameterization Gradients" (https://arxiv.org/abs/1805.08498). + + Args: + op: A `RandomGamma` operation. We assume that the inputs to the operation + are `shape` and `alpha` tensors, and the output is the `sample` tensor. + grad: The incoming gradient `dloss / dsample` of the same shape as + `op.outputs[0]`. + + Returns: + A `Tensor` with derivatives `dloss / dalpha` + """ + shape = op.inputs[0] + alpha = op.inputs[1] + sample = op.outputs[0] + + with ops.control_dependencies([grad]): + # Make the parameters alpha broadcastable with samples by appending + # unit dimensions. + num_sample_dimensions = array_ops.shape(shape)[0] + alpha_broadcastable = add_leading_unit_dimensions( + alpha, num_sample_dimensions) + partial_a = gen_random_ops.random_gamma_grad(alpha_broadcastable, sample) + + # The first input is shape; the second input is alpha. + return (None, math_ops.reduce_sum( + grad * partial_a, axis=math_ops.range(num_sample_dimensions))) diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 6a2dd3f1cd55eea1d3b652a31cd2784c411c2ce0..b8738adf66e6ff51962ed44dce7cd4b95544e271 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -368,25 +368,41 @@ def random_gamma(shape, `alpha` is the shape parameter describing the distribution(s), and `beta` is the inverse scale parameter(s). - Example: + Note: Because internal calculations are done using `float64` and casting has + `floor` semantics, we must manually map zero outcomes to the smallest + possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This + means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise + should. This bias can only happen for small values of `alpha`, i.e., + `alpha << 1` or large values of `beta`, i.e., `beta >> 1`. - samples = tf.random_gamma([10], [0.5, 1.5]) - # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents - # the samples drawn from each distribution + The samples are differentiable w.r.t. alpha and beta. + The derivatives are computed using the approach described in the paper - samples = tf.random_gamma([7, 5], [0.5, 1.5]) - # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] - # represents the 7x5 samples drawn from each of the two distributions + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) - samples = tf.random_gamma([30], [[1.],[3.],[5.]], beta=[[3., 4.]]) - # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions. + Example: - Note: Because internal calculations are done using `float64` and casting has - `floor` semantics, we must manually map zero outcomes to the smallest - possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This - means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise - should. This bias can only happen for small values of `alpha`, i.e., - `alpha << 1` or large values of `beta`, i.e., `beta >> 1`. + ```python + samples = tf.random_gamma([10], [0.5, 1.5]) + # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents + # the samples drawn from each distribution + + samples = tf.random_gamma([7, 5], [0.5, 1.5]) + # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] + # represents the 7x5 samples drawn from each of the two distributions + + alpha = tf.constant([[1.],[3.],[5.]]) + beta = tf.constant([[3., 4.]]) + samples = tf.random_gamma([30], alpha=alpha, beta=beta) + # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions. + + loss = tf.reduce_mean(tf.square(samples)) + dloss_dalpha, dloss_dbeta = tf.gradients(loss, [alpha, beta]) + # unbiased stochastic derivatives of the loss function + alpha.shape == dloss_dalpha.shape # True + beta.shape == dloss_dbeta.shape # True + ``` Args: shape: A 1-D integer Tensor or Python array. The shape of the output samples @@ -406,8 +422,9 @@ def random_gamma(shape, name: Optional name for the operation. Returns: - samples: a `Tensor` of shape `tf.concat(shape, tf.shape(alpha + beta))` - with values of type `dtype`. + samples: a `Tensor` of shape + `tf.concat([shape, tf.shape(alpha + beta)], axis=0)` with values of type + `dtype`. """ with ops.name_scope(name, "random_gamma", [shape, alpha, beta]): shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32) @@ -421,8 +438,6 @@ def random_gamma(shape, gen_random_ops.random_gamma( shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta) -ops.NotDifferentiable("RandomGamma") - @tf_export("random_poisson") def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): @@ -432,13 +447,15 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): Example: - samples = tf.random_poisson([0.5, 1.5], [10]) - # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents - # the samples drawn from each distribution + ```python + samples = tf.random_poisson([0.5, 1.5], [10]) + # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents + # the samples drawn from each distribution - samples = tf.random_poisson([12.2, 3.3], [7, 5]) - # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] - # represents the 7x5 samples drawn from each of the two distributions + samples = tf.random_poisson([12.2, 3.3], [7, 5]) + # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] + # represents the 7x5 samples drawn from each of the two distributions + ``` Args: lam: A Tensor or Python value or N-D array of type `dtype`. @@ -455,8 +472,8 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): name: Optional name for the operation. Returns: - samples: a `Tensor` of shape `tf.concat(shape, tf.shape(lam))` with - values of type `dtype`. + samples: a `Tensor` of shape `tf.concat([shape, tf.shape(lam)], axis=0)` + with values of type `dtype`. """ with ops.name_scope(name, "random_poisson", [lam, shape]): shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 288006fad7c57adf4e845e8946d5a58039d25dd1..15cafbbde50335de0dc0cd8849425c07b4ac81d3 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib + from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 from tensorflow.python import pywrap_tensorflow @@ -115,6 +117,18 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): return handle +@contextlib.contextmanager +def _handle_graph(handle): + # Note: might have an eager tensor but not be executing eagerly when building + # functions. + if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) + or ops.has_default_graph()): + yield + else: + with handle.graph.as_default(): + yield + + class EagerResourceDeleter(object): """An object which cleans up a resource handle. @@ -159,7 +173,8 @@ class EagerResourceDeleter(object): def shape_safe_assign_variable_handle(handle, shape, value, name=None): """Helper that checks shape compatibility and assigns variable.""" - value_tensor = ops.convert_to_tensor(value) + with _handle_graph(handle): + value_tensor = ops.convert_to_tensor(value) shape.assert_is_compatible_with(value_tensor.shape) return gen_resource_variable_ops.assign_variable_op(handle, value_tensor, @@ -492,6 +507,9 @@ class ResourceVariable(variables.Variable): else: self._cached_value = None if not context.executing_eagerly(): + # Eager variables are only added to collections if they are part of an + # eager variable store (otherwise in an interactive session they would + # hog memory and cause OOM). This is done in ops/variable_scope.py. ops.add_to_collections(collections, self) elif ops.GraphKeys.GLOBAL_STEP in collections: ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) @@ -536,6 +554,7 @@ class ResourceVariable(variables.Variable): import_scope=import_scope)) else: self._initial_value = None + self._trainable = getattr(variable_def, "trainable", True) if variable_def.snapshot_name: snapshot = g.as_graph_element( ops.prepend_name_scope( @@ -561,6 +580,21 @@ class ResourceVariable(variables.Variable): self._constraint = None self._cached_shape_as_list = None + @contextlib.contextmanager + def _assign_dependencies(self): + """Makes assignments depend on the cached value, if any. + + This prevents undefined behavior with reads not ordered wrt writes. + + Yields: + None. + """ + if self._cached_value is not None: + with ops.control_dependencies([self._cached_value]): + yield + else: + yield + def __nonzero__(self): return self.__bool__() @@ -705,7 +739,7 @@ class ResourceVariable(variables.Variable): return self._save_slice_info def _read_variable_op(self): - if hasattr(self, "_trainable") and self._trainable: + if self.trainable: tape.watch_variable(self) return gen_resource_variable_ops.read_variable_op(self._handle, self._dtype) @@ -730,7 +764,7 @@ class ResourceVariable(variables.Variable): def sparse_read(self, indices, name=None): """Reads the value of this variable sparsely, using `gather`.""" with ops.name_scope("Gather" if name is None else name) as name: - if self._trainable: + if self.trainable: tape.watch_variable(self) value = gen_resource_variable_ops.resource_gather( self._handle, indices, dtype=self._dtype, name=name) @@ -771,6 +805,7 @@ class ResourceVariable(variables.Variable): var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name, export_scope) var_def.is_resource = True + var_def.trainable = self.trainable if self._save_slice_info: var_def.save_slice_info_def.MergeFrom( self._save_slice_info.to_proto(export_scope=export_scope)) @@ -816,14 +851,15 @@ class ResourceVariable(variables.Variable): operator: string. The operator name. """ + tensor_oper = getattr(ops.Tensor, operator) def _run_op(a, *args): # pylint: disable=protected-access value = a._AsTensor() - return getattr(ops.Tensor, operator)(value, *args) + return tensor_oper(value, *args) # Propagate __doc__ to wrapper try: - _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__ + _run_op.__doc__ = tensor_oper.__doc__ except AttributeError: pass @@ -850,8 +886,10 @@ class ResourceVariable(variables.Variable): # TODO(apassos): this here and below is not atomic. Consider making it # atomic if there's a way to do so without a performance cost for those who # don't need it. - assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op( - self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), name=name) + with _handle_graph(self.handle), self._assign_dependencies(): + assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op( + self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), + name=name) if read_value: return self._lazy_read(assign_sub_op) return assign_sub_op @@ -872,14 +910,16 @@ class ResourceVariable(variables.Variable): it will return the `Operation` that does the assignment, and when in eager mode it will return `None`. """ - assign_add_op = gen_resource_variable_ops.assign_add_variable_op( - self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), name=name) + with _handle_graph(self.handle), self._assign_dependencies(): + assign_add_op = gen_resource_variable_ops.assign_add_variable_op( + self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), + name=name) if read_value: return self._lazy_read(assign_add_op) return assign_add_op def _lazy_read(self, op): - if hasattr(self, "_trainable") and self._trainable: + if self.trainable: tape.watch_variable(self) return _UnreadVariable( self._handle, self.dtype, self._shape, self._in_graph_mode, @@ -902,30 +942,34 @@ class ResourceVariable(variables.Variable): it will return the `Operation` that does the assignment, and when in eager mode it will return `None`. """ - value_tensor = ops.convert_to_tensor(value, dtype=self.dtype) - self._shape.assert_is_compatible_with(value_tensor.shape) - assign_op = gen_resource_variable_ops.assign_variable_op( - self.handle, value_tensor, name=name) - if read_value: - return self._lazy_read(assign_op) + # Note: not depending on the cached value here since this can used to + # initialize the variable. + with _handle_graph(self.handle): + value_tensor = ops.convert_to_tensor(value, dtype=self.dtype) + self._shape.assert_is_compatible_with(value_tensor.shape) + assign_op = gen_resource_variable_ops.assign_variable_op( + self.handle, value_tensor, name=name) + if read_value: + return self._lazy_read(assign_op) return assign_op def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask): - return self._lazy_read( - gen_array_ops.resource_strided_slice_assign( - ref=self.handle, - begin=begin, - end=end, - strides=strides, - value=value, - name=name, - begin_mask=begin_mask, - end_mask=end_mask, - ellipsis_mask=ellipsis_mask, - new_axis_mask=new_axis_mask, - shrink_axis_mask=shrink_axis_mask)) + with _handle_graph(self.handle), self._assign_dependencies(): + return self._lazy_read( + gen_array_ops.resource_strided_slice_assign( + ref=self.handle, + begin=begin, + end=end, + strides=strides, + value=ops.convert_to_tensor(value, dtype=self.dtype), + name=name, + begin_mask=begin_mask, + end_mask=end_mask, + ellipsis_mask=ellipsis_mask, + new_axis_mask=new_axis_mask, + shrink_axis_mask=shrink_axis_mask)) def __int__(self): if self.dtype != dtypes.int32 and self.dtype != dtypes.int64: @@ -955,32 +999,28 @@ class ResourceVariable(variables.Variable): def __imul__(self, unused_other): raise RuntimeError("Variable *= value not supported. Use " - "variable.assign_mul(value) to modify the variable " - "value and variable = variable * value to get a new " - "Tensor object.") + "`var.assign(var * value)` to modify the variable or " + "`var = var * value` to get a new Tensor object.") def __idiv__(self, unused_other): raise RuntimeError("Variable /= value not supported. Use " - "variable.assign_div(value) to modify the variable " - "value and variable = variable / value to get a new " - "Tensor object.") + "`var.assign(var / value)` to modify the variable or " + "`var = var / value` to get a new Tensor object.") def __itruediv__(self, unused_other): raise RuntimeError("Variable /= value not supported. Use " - "variable.assign_div(value) to modify the variable " - "value and variable = variable / value to get a new " - "Tensor object.") + "`var.assign(var / value)` to modify the variable or " + "`var = var / value` to get a new Tensor object.") def __irealdiv__(self, unused_other): raise RuntimeError("Variable /= value not supported. Use " - "variable.assign_div(value) to modify the variable " - "value and variable = variable / value to get a new " - "Tensor object.") + "`var.assign(var / value)` to modify the variable or " + "`var = var / value` to get a new Tensor object.") def __ipow__(self, unused_other): raise RuntimeError("Variable **= value not supported. Use " - "value and variable = variable ** value to get a new " - "Tensor object.") + "`var.assign(var ** value)` to modify the variable or " + "`var = var ** value` to get a new Tensor object.") pywrap_tensorflow.TFE_Py_RegisterResourceVariableType(ResourceVariable) @@ -1024,6 +1064,10 @@ class _UnreadVariable(ResourceVariable): self._graph_element = self.read_value() self._handle_deleter = deleter + @property + def name(self): + return self._parent_op.name + def value(self): return self._read_variable_op() diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 10d576c95bc4fd3147da44ee1522dc829bcab83d..deba133fb9910f28c7f902f334174734c3c742f7 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops @@ -131,6 +132,18 @@ def _maybe_tensor_shape_from_tensor(shape): return shape +def _should_cache(): + """Returns True if a default caching device should be set, otherwise False.""" + if context.executing_eagerly(): + return False + # Don't set a caching device when running in a loop, since it is possible that + # train steps could be wrapped in a tf.while_loop. In that scenario caching + # prevents forward computations in loop iterations from re-reading the + # updated weights. + ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + return control_flow_util.GetContainingWhileContext(ctxt) is None + + # pylint: disable=unused-argument def _rnn_step( time, sequence_length, min_sequence_length, max_sequence_length, @@ -558,7 +571,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -828,7 +841,8 @@ def _dynamic_rnn_loop(cell, final_outputs = nest.pack_sequence_as( structure=cell.output_size, flat_sequence=final_outputs) if not in_graph_mode: - final_outputs = array_ops.stack(final_outputs, axis=0) + final_outputs = nest.map_structure_up_to( + cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs) return (final_outputs, final_state) @@ -1014,7 +1028,7 @@ def raw_rnn(cell, loop_fn, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -1227,7 +1241,7 @@ def static_rnn(cell, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index e9a2d2d0f19c409eb5578efbfd62fa766796230c..82a044a0d4c8710f5ade0aa460f4354a0dd35deb 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -47,6 +47,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -979,6 +980,7 @@ class DropoutWrapper(RNNCell): but not `callable`. ValueError: if any of the keep_probs are not between 0 and 1. """ + super(DropoutWrapper, self).__init__() assert_like_rnncell("cell", cell) if (dropout_state_filter_visitor is not None @@ -1153,6 +1155,7 @@ class ResidualWrapper(RNNCell): Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs and outputs. """ + super(ResidualWrapper, self).__init__() self._cell = cell if isinstance(cell, checkpointable.CheckpointableBase): self._track_checkpointable(self._cell, name="cell") @@ -1210,6 +1213,7 @@ class DeviceWrapper(RNNCell): cell: An instance of `RNNCell`. device: A device string or function, for passing to `tf.device`. """ + super(DeviceWrapper, self).__init__() self._cell = cell if isinstance(cell, checkpointable.CheckpointableBase): self._track_checkpointable(self._cell, name="cell") @@ -1328,7 +1332,7 @@ class MultiRNNCell(RNNCell): return cur_inp, new_states -class _SlimRNNCell(RNNCell, checkpointable.NotCheckpointable): +class _SlimRNNCell(RNNCell, checkpointable_tracking.NotCheckpointable): """A simple wrapper for slim.rnn_cells.""" def __init__(self, cell_fn): diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index f87c5dc5e39b7b3a58ca3cf2e9cd943cb081a020..1e3f662ff34f67d2b5f226427c8a03d82b9f2a7c 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Script Language Operators. See the @{$python/script_ops} guide.""" # pylint: disable=g-bad-name @@ -24,35 +23,61 @@ import threading # Used by py_util.cc to get tracebacks. import traceback # pylint: disable=unused-import +import weakref import numpy as np import six from tensorflow.python import pywrap_tensorflow +from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_script_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export +# Map from EagerPyFunc token to tuple (tape, eager args, eager outputs); +# used for differentiation. +tape_cache = {} + class EagerFunc(object): """A wrapper for a function owned by an EagerPyFunc.""" - def __init__(self, func, Tout): + def __init__(self, func, Tout, is_grad_func): """Constructs an EagerFunc. Args: func: The function to wrap. Tout: A list of datatypes for the output; an empty list if the output is None. + is_grad_func: Whether this EagerFunc is the gradient of another + EagerPyFunc. """ self._func = func self._out_dtypes = Tout + self._is_grad_func = is_grad_func def _convert(self, value, dtype): + """Converts `value` to a tensor of type `dtype`, with error checking. + + Args: + value: The tensor to convert. + dtype: The desired dtype. + + Returns: + A tensor of type `dtype`, or a zeros tensor if value is None and + this function is in fact a grdient function. + + Raises: + RuntimeError: if `value` is a variable. + """ + if isinstance(value, resource_variable_ops.ResourceVariable): raise RuntimeError( "Attempting to return a variable from an eagerly executed py_func. " @@ -60,22 +85,39 @@ class EagerFunc(object): "be returned; to return the value of a variable, make sure to obtain " "the Tensor backing it by calling `.read_value()` on the variable in " "question: %s" % value) + if value is None and self._is_grad_func: + # Gradient functions may legitimately return a list that contains + # both Tensors and Python Nones. Unfortuantely this breaks the + # OpKernel, so for now we replace None objects with zeros, which is + # mathematically correct but will prevent short-circuiting gradient + # computations. + # + # TODO(akshayka): Make it possible to return a list of both Tensors and + # Nones from an EagerPyFunc. + return constant_op.constant(0.0, dtype=dtype) return ops.convert_to_tensor(value, dtype=dtype) - def __call__(self, on_gpu, args): + def __call__(self, device, token, args): """Passes `args` to `self._func`, which is executed eagerly.""" - with context.eager_mode(): + + with context.eager_mode(), backprop.GradientTape() as tape: + for tensor in args: + tape.watch(tensor) ret = self._func(*args) - maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu() - if isinstance(ret, (tuple, list)): - return [ - maybe_copy_to_gpu(self._convert(x, dtype=dtype)) - for (x, dtype) in zip(ret, self._out_dtypes) - ] - elif ret is None: - return ret - else: - return maybe_copy_to_gpu(self._convert(ret, dtype=self._out_dtypes[0])) + # Use tf.identity to copy the returned tensors to device if neccesary. + with ops.device(device): + if isinstance(ret, (tuple, list)): + outputs = [ + array_ops.identity(self._convert(x, dtype=dtype)) + for (x, dtype) in zip(ret, self._out_dtypes) + ] + elif ret is None: + outputs = None + else: + outputs = array_ops.identity( + self._convert(ret, dtype=self._out_dtypes[0])) + tape_cache[compat.as_bytes(token)] = (tape, args, outputs) + return outputs class FuncRegistry(object): @@ -88,11 +130,14 @@ class FuncRegistry(object): def __init__(self): self._lock = threading.Lock() self._unique_id = 0 # GUARDED_BY(self._lock) - self._funcs = {} + # Only store weakrefs to the funtions. The strong reference is stored in + # the graph. + self._funcs = weakref.WeakValueDictionary() def insert(self, func): """Registers `func` and returns a unique token for this entry.""" token = self._next_unique_token() + # Store a weakref to the function self._funcs[token] = func return token @@ -129,14 +174,14 @@ class FuncRegistry(object): else: return result - def __call__(self, token, on_gpu, args): + def __call__(self, token, device, args): """Calls the registered function for `token` with args. Args: token: A key into this `FuncRegistry` identifying which function to call. - on_gpu: A boolean indicating whether or not `token`'s corresponding - operation was placed on GPU; only used if the function registered for - `token` is an `EagerPyFunc`. + device: Name of the device on which outputs of `token`'s corresponding + operation should be placed. Used iff the function registered for `token` + is an EagerPyFunc. args: The arguments to pass to the function registered for `token`. Returns: @@ -145,11 +190,18 @@ class FuncRegistry(object): Raises: ValueError: if no function is registered for `token`. """ - func = self._funcs[token] + func = self._funcs.get(token, None) if func is None: raise ValueError("callback %s is not found" % token) if isinstance(func, EagerFunc): - return func(on_gpu, args) + # NB: Different invocations of the same py_func will share the same + # token, and the entries they stash in the tape_cache will collide. + # In practice, when executing a graph, this should only happen if + # the py_func is in a while_loop whose iterations are run in parallel + # or if the graph is being driven by concurrent session.run() calls. + # + # TODO(akshayka): Key the tape cache in a thread-safe way. + return func(device, token, args) else: ret = func(*args) # Strings seem to lead to a memory leak here if they're not wrapped in a @@ -180,20 +232,13 @@ _py_funcs = FuncRegistry() pywrap_tensorflow.InitializePyTrampoline(_py_funcs) -class CleanupFunc(object): - """A helper class to remove a registered function from _py_funcs.""" - - def __init__(self, token): - self._token = token - - def __del__(self): - if _py_funcs is not None: - # If _py_funcs is None, the program is most likely in shutdown, and the - # _py_funcs object has been destroyed already. - _py_funcs.remove(self._token) - - -def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): +def _internal_py_func(func, + inp, + Tout, + stateful=None, + eager=False, + is_grad_func=False, + name=None): """See documentation for py_func and eager_py_func.""" is_list_or_tuple = False @@ -203,7 +248,7 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): Tout = [Tout] if eager: - func = EagerFunc(func, Tout) + func = EagerFunc(func, Tout, is_grad_func) token = _py_funcs.insert(func) # We tie the registered function's lifetime with the current default graph, @@ -216,17 +261,15 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): # bound to that of the outer graph instead. graph = graph._outer_graph - cleanup = CleanupFunc(token) - # TODO(zhifengc): Consider adding a Graph method to collect # `cleanup` objects in one of its member. - if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"): - graph._cleanup_py_funcs_used_in_graph = [] + if not hasattr(graph, "_py_funcs_used_in_graph"): + graph._py_funcs_used_in_graph = [] - # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph - # will be destroyed and their __del__ will remove the 'token' from - # the funcs registry. - graph._cleanup_py_funcs_used_in_graph.append(cleanup) + # Store a reference to the function in the graph to ensure it stays alive + # as long as the graph lives. When the graph is destroyed, the function + # is left to the garbage collector for destruction as well. + graph._py_funcs_used_in_graph.append(func) # pylint: enable=protected-access if eager: @@ -242,34 +285,56 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): return result if is_list_or_tuple else result[0] +# TODO(akshayka): Implement higher-order derivatives. +@ops.RegisterGradient("EagerPyFunc") +def _EagerPyFuncGrad(op, dy): + """Computes the gradient of an EagerPyFunc.""" + + token = op.get_attr("token") + + def eagerly_executed_grad(dy): + tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token)) + return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy) + + with ops.control_dependencies(op.outputs): + return _internal_py_func( + func=eagerly_executed_grad, + inp=[dy] if isinstance(dy, ops.Tensor) else dy, + Tout=[tensor.dtype for tensor in op.inputs], + eager=True, + is_grad_func=True) + + def eager_py_func(func, inp, Tout, name=None): """Wraps a python function into a TensorFlow op that executes it eagerly. This function allows expressing computations in a TensorFlow graph as Python functions. In particular, it wraps a Python function `func` - in a TensorFlow operation that executes it with eager exeuction enabled. As a - consequence, `tf.contrib.eager.py_func` makes it possible to express control - flow using Python constructs (`if`, `while`, `for`, etc.), instead of - TensorFlow control flow constructs (@{tf.cond}, @{tf.while_loop}). For - example, you might use `tf.contrib.eager.py_func` to implement the log huber - function: + in a once-differentiable TensorFlow operation that executes it with eager + exeuction enabled. As a consequence, `tf.contrib.eager.py_func` makes it + possible to express control flow using Python constructs (`if`, `while`, + `for`, etc.), instead of TensorFlow control flow constructs (@{tf.cond}, + @{tf.while_loop}). For example, you might use `tf.contrib.eager.py_func` to + implement the log huber function: ```python def log_huber(x, m): if tf.abs(x) <= m: - return x ** 2 + return x**2 else: - return m ** 2 * (1 - 2 * tf.log(m) + tf.log(x ** 2)) + return m**2 * (1 - 2 * tf.log(m) + tf.log(x**2)) x = tf.placeholder(tf.float32) m = tf.placeholder(tf.float32) y = tf.contrib.eager.py_func(func=log_huber, inp=[x, m], Tout=tf.float32) + dy_dx = tf.gradients(y, x)[0] with tf.Session() as sess: # The session executes `log_huber` eagerly. Given the feed values below, - # it will take the second branch, so `output` evaluates to 7.24372. - output = sess.run(y, feed_dict={x: 3.0, m: 2.0}) + # it will take the first branch, so `y` evaluates to 1.0 and + # `dy_dx` evaluates to 2.0. + y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0}) ``` You can also use `tf.contrib.eager.py_func` to debug your models at runtime @@ -278,7 +343,7 @@ def eager_py_func(func, inp, Tout, name=None): or print statements as desired, and wrap those functions in `tf.contrib.eager.py_func`. - For more information on eager execution, see @{$programmers_guide/eager}. + For more information on eager execution, see @{$guide/eager}. `tf.contrib.eager.py_func` is similar in spirit to @{tf.py_func}, but unlike the latter, the former lets you use TensorFlow operations in the wrapped @@ -288,10 +353,6 @@ def eager_py_func(func, inp, Tout, name=None): that take Tensors as inputs, execute TensorFlow operations in their bodies, and return Tensors as outputs. - `tf.contrib.eager.py_func` is not differentiable, though a gradient may be - implemented in the future; if you would like to differentiate through it, - please file an issue on Github. - Like @{tf.py_func}, `tf.contrib.eager.py_func` has the following limitations with respect to serialization and distribution: diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py index 97353d6c747cb7e4d3c1fa92ad61af24fb17de91..1223b290ff6cfcfba27f40c05556c85b59e77148 100644 --- a/tensorflow/python/ops/sparse_grad.py +++ b/tensorflow/python/ops/sparse_grad.py @@ -116,6 +116,35 @@ def _SparseReduceSumGrad(op, out_grad): None, None) +@ops.RegisterGradient("SparseSlice") +def _SparseSliceGrad(op, *grads): + """The backward operator for the SparseSlice op. + + This op takes in the upstream gradient w.r.t. non-empty values of + the sliced `SparseTensor`, and outputs the gradients w.r.t. + the non-empty values of input `SparseTensor`. + + Args: + op: the SparseSlice op + *grads: the incoming gradients, one element per output of `op` + + Returns: + Gradient for each of the 5 input tensors of SparseSlice: + (indices, values, shape, start, size) + The gradients for the indices, shape, start and the size are None. + """ + backprop_val_grad = grads[1] + input_indices = op.inputs[0] + input_start = op.inputs[3] + output_indices = op.outputs[0] + + val_grad = gen_sparse_ops.sparse_slice_grad( + backprop_val_grad, input_indices, input_start, output_indices) + val_grad.set_shape(op.inputs[1].get_shape()) + # (indices, values, shape, start, size) + return (None, val_grad, None, None, None) + + @ops.RegisterGradient("SparseTensorDenseMatMul") def _SparseTensorDenseMatMulGrad(op, grad): """Gradients for the dense tensor in the SparseTensorDenseMatMul op. diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 6204adef3bb5dc96dab4a16bf05824d32627fccc..9a10abfcf736be783bfcd7907ec6f357912828ab 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -34,7 +34,7 @@ from tensorflow.python.util.tf_export import tf_export # TODO(b/27419586) Change docstring for required dtype of x once int allowed @tf_export('lbeta') -def lbeta(x, name='lbeta'): +def lbeta(x, name=None): r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension. Given one-dimensional `z = [z_0,...,z_{K-1}]`, we define @@ -64,7 +64,7 @@ def lbeta(x, name='lbeta'): # This is consistent with a convention that the sum over the empty set 0, and # the product is 1. # This is standard. See https://en.wikipedia.org/wiki/Empty_set. - with ops.name_scope(name, values=[x]): + with ops.name_scope(name, 'lbeta', [x]): x = ops.convert_to_tensor(x, name='x') # Note reduce_sum([]) = 0. @@ -82,6 +82,54 @@ def lbeta(x, name='lbeta'): return result +@tf_export('math.bessel_i0') +def bessel_i0(x, name=None): + """Computes the Bessel i0 function of `x` element-wise. + + Modified Bessel function of order 0. + + It is preferable to use the numerically stabler function `i0e(x)` instead. + + Args: + x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, + `float32`, `float64`. + name: A name for the operation (optional). + + Returns: + A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. + + @compatibility(scipy) + Equivalent to scipy.special.i0 + @end_compatibility + """ + with ops.name_scope(name, 'bessel_i0', [x]): + return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i0e(x) + + +@tf_export('math.bessel_i1') +def bessel_i1(x, name=None): + """Computes the Bessel i1 function of `x` element-wise. + + Modified Bessel function of order 1. + + It is preferable to use the numerically stabler function `i1e(x)` instead. + + Args: + x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, + `float32`, `float64`. + name: A name for the operation (optional). + + Returns: + A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. + + @compatibility(scipy) + Equivalent to scipy.special.i1 + @end_compatibility + """ + with ops.name_scope(name, 'bessel_i1', [x]): + return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i1e(x) + + @tf_export('einsum', 'linalg.einsum') def einsum(equation, *inputs, **kwargs): """A generalized contraction between tensors of arbitrary dimension. @@ -153,6 +201,8 @@ def einsum(equation, *inputs, **kwargs): indices in its subscript, or - the input shapes are inconsistent along a particular axis. """ + equation = equation.replace(' ', '') + name = kwargs.pop('name', None) if kwargs: raise TypeError('invalid keyword arguments for this function: ' + ', '.join( diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index d7c3a7e8dc7c2ad611cf47718dddcf38700ce304..9bc4098d5b63c3e8ee4f9c14332e65b3d2875d8b 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -25,23 +25,25 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.platform import test - +from tensorflow.python.platform import tf_logging class LBetaTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes def test_one_dimensional_arg(self): # Should evaluate to 1 and 1/2. x_one = [1, 1.] x_one_half = [2, 1.] with self.test_session(use_gpu=True): - self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_one)).eval()) - self.assertAllClose(0.5, - math_ops.exp( - special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose( + 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one)))) + self.assertAllClose( + 0.5, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half)))) self.assertEqual([], special_math_ops.lbeta(x_one).get_shape()) def test_one_dimensional_arg_dynamic(self): @@ -52,7 +54,8 @@ class LBetaTest(test.TestCase): ph = array_ops.placeholder(dtypes.float32) beta_ph = math_ops.exp(special_math_ops.lbeta(ph)) self.assertAllClose(1, beta_ph.eval(feed_dict={ph: x_one})) - self.assertAllClose(0.5, beta_ph.eval(feed_dict={ph: x_one_half})) + self.assertAllClose(0.5, + beta_ph.eval(feed_dict={ph: x_one_half})) def test_four_dimensional_arg_with_partial_shape_dynamic(self): x_ = np.ones((3, 2, 3, 4)) @@ -65,15 +68,17 @@ class LBetaTest(test.TestCase): with self.test_session(use_gpu=True): x_ph = array_ops.placeholder(dtypes.float32, [3, 2, 3, None]) beta_ph = math_ops.exp(special_math_ops.lbeta(x_ph)) - self.assertAllClose(expected_beta_x, beta_ph.eval(feed_dict={x_ph: x_})) + self.assertAllClose(expected_beta_x, + beta_ph.eval(feed_dict={x_ph: x_})) + @test_util.run_in_graph_and_eager_modes def test_two_dimensional_arg(self): # Should evaluate to 1/2. x_one_half = [[2, 1.], [2, 1.]] with self.test_session(use_gpu=True): - self.assertAllClose([0.5, 0.5], - math_ops.exp( - special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose( + [0.5, 0.5], + self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half)))) self.assertEqual((2,), special_math_ops.lbeta(x_one_half).get_shape()) def test_two_dimensional_arg_dynamic(self): @@ -82,50 +87,59 @@ class LBetaTest(test.TestCase): with self.test_session(use_gpu=True): ph = array_ops.placeholder(dtypes.float32) beta_ph = math_ops.exp(special_math_ops.lbeta(ph)) - self.assertAllClose([0.5, 0.5], beta_ph.eval(feed_dict={ph: x_one_half})) + self.assertAllClose([0.5, 0.5], + beta_ph.eval(feed_dict={ph: x_one_half})) + @test_util.run_in_graph_and_eager_modes def test_two_dimensional_proper_shape(self): # Should evaluate to 1/2. x_one_half = [[2, 1.], [2, 1.]] with self.test_session(use_gpu=True): - self.assertAllClose([0.5, 0.5], - math_ops.exp( - special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose( + [0.5, 0.5], + self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half)))) self.assertEqual( (2,), - array_ops.shape(special_math_ops.lbeta(x_one_half)).eval()) + self.evaluate(array_ops.shape(special_math_ops.lbeta(x_one_half)))) self.assertEqual( tensor_shape.TensorShape([2]), special_math_ops.lbeta(x_one_half).get_shape()) + @test_util.run_in_graph_and_eager_modes def test_complicated_shape(self): with self.test_session(use_gpu=True): x = ops.convert_to_tensor(np.random.rand(3, 2, 2)) - self.assertAllEqual((3, 2), - array_ops.shape(special_math_ops.lbeta(x)).eval()) + self.assertAllEqual( + (3, 2), self.evaluate(array_ops.shape(special_math_ops.lbeta(x)))) self.assertEqual( tensor_shape.TensorShape([3, 2]), special_math_ops.lbeta(x).get_shape()) + @test_util.run_in_graph_and_eager_modes def test_length_1_last_dimension_results_in_one(self): # If there is only one coefficient, the formula still works, and we get one # as the answer, always. x_a = [5.5] x_b = [0.1] with self.test_session(use_gpu=True): - self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_a)).eval()) - self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_b)).eval()) + self.assertAllClose( + 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_a)))) + self.assertAllClose( + 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_b)))) self.assertEqual((), special_math_ops.lbeta(x_a).get_shape()) + @test_util.run_in_graph_and_eager_modes def test_empty_rank1_returns_negative_infinity(self): with self.test_session(use_gpu=True): x = constant_op.constant([], shape=[0]) lbeta_x = special_math_ops.lbeta(x) expected_result = constant_op.constant(-np.inf, shape=()) - self.assertAllEqual(expected_result.eval(), lbeta_x.eval()) + self.assertAllEqual(self.evaluate(expected_result), + self.evaluate(lbeta_x)) self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape()) + @test_util.run_in_graph_and_eager_modes def test_empty_rank2_with_zero_last_dim_returns_negative_infinity(self): with self.test_session(use_gpu=True): event_size = 0 @@ -134,9 +148,11 @@ class LBetaTest(test.TestCase): lbeta_x = special_math_ops.lbeta(x) expected_result = constant_op.constant(-np.inf, shape=[batch_size]) - self.assertAllEqual(expected_result.eval(), lbeta_x.eval()) + self.assertAllEqual(self.evaluate(expected_result), + self.evaluate(lbeta_x)) self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape()) + @test_util.run_in_graph_and_eager_modes def test_empty_rank2_with_zero_batch_dim_returns_empty(self): with self.test_session(use_gpu=True): batch_size = 0 @@ -146,10 +162,40 @@ class LBetaTest(test.TestCase): expected_result = constant_op.constant([], shape=[batch_size]) - self.assertAllEqual(expected_result.eval(), lbeta_x.eval()) + self.assertAllEqual(self.evaluate(expected_result), + self.evaluate(lbeta_x)) self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape()) +class BesselTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def test_bessel_i0(self): + x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32) + x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64) + try: + from scipy import special # pylint: disable=g-import-not-at-top + self.assertAllClose(special.i0(x_single), + self.evaluate(special_math_ops.bessel_i0(x_single))) + self.assertAllClose(special.i0(x_double), + self.evaluate(special_math_ops.bessel_i0(x_double))) + except ImportError as e: + tf_logging.warn('Cannot test special functions: %s' % str(e)) + + @test_util.run_in_graph_and_eager_modes + def test_bessel_i1(self): + x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32) + x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64) + try: + from scipy import special # pylint: disable=g-import-not-at-top + self.assertAllClose(special.i1(x_single), + self.evaluate(special_math_ops.bessel_i1(x_single))) + self.assertAllClose(special.i1(x_double), + self.evaluate(special_math_ops.bessel_i1(x_double))) + except ImportError as e: + tf_logging.warn('Cannot test special functions: %s' % str(e)) + + class EinsumTest(test.TestCase): simple_cases = [ @@ -195,6 +241,12 @@ class EinsumTest(test.TestCase): 'iJ,Jk->ik', 'iJ,Ki->JK', 'iJk,Jklm->Jk' + 'ij, jk, kl -> il', + 'a, ab, abc -> abc', + 'ab, ab, cd, cd, ef, ef -> ', + 'abc, bac', + 'iJ, Ki -> JK', + 'iJk, Jklm -> Jk' ] long_cases = [ @@ -203,6 +255,8 @@ class EinsumTest(test.TestCase): 'ea,fb,gc,hd,abcd->efgh', 'ea,fb,abcd,gc,hd->efgh', 'abhe,hidj,jgba,hiab,gab', + 'efc, dbc, acf, fd -> abe', + 'abhe, hidj, jgba, hiab, gab', ] invalid_cases = [ @@ -273,20 +327,20 @@ class EinsumTest(test.TestCase): input_axes, _, _ = axes.partition('->') for idx in input_axes.split(','): - shape = [all_axes[ax] for ax in idx] + shape = [all_axes[ax] for ax in idx if ax.isalpha()] input_vals.append(np.random.random(shape)) input_tensors = [constant_op.constant(val) for val in input_vals] output_tensor = special_math_ops.einsum(axes, *input_tensors) with self.test_session(use_gpu=True): - output_value = output_tensor.eval() + output_value = self.evaluate(output_tensor) correct_value = np.einsum(axes, *input_vals) err = np.abs(correct_value - output_value).max() - print(axes, err) - assert err < 1e-8 + # print(axes, err) + self.assertLess(err, 1e-8) def test_input_is_placeholder(self): with ops.Graph().as_default(): @@ -298,8 +352,7 @@ class EinsumTest(test.TestCase): m0: [[1, 2, 3]], m1: [[2], [1], [1]], } - np.testing.assert_almost_equal([[7]], sess.run( - out, feed_dict=feed_dict)) + self.assertAllClose([[7]], sess.run(out, feed_dict=feed_dict)) with ops.Graph().as_default(): m0 = array_ops.placeholder(dtypes.int32, shape=(None, 3)) @@ -310,7 +363,7 @@ class EinsumTest(test.TestCase): m0: [[1, 2, 3]], m1: [2, 1, 1], } - np.testing.assert_almost_equal([7], sess.run(out, feed_dict=feed_dict)) + self.assertAllClose([7], sess.run(out, feed_dict=feed_dict)) # Tests for placeholders which have two or more None values with ops.Graph().as_default(): @@ -322,8 +375,7 @@ class EinsumTest(test.TestCase): m0: [[[1, 2]]], m1: [[3], [2]], } - np.testing.assert_almost_equal([[[7]]], - sess.run(out, feed_dict=feed_dict)) + self.assertAllClose([[[7]]], sess.run(out, feed_dict=feed_dict)) with ops.Graph().as_default(): m0 = array_ops.placeholder(dtypes.int32, shape=(2, 1)) @@ -334,8 +386,7 @@ class EinsumTest(test.TestCase): m0: [[3], [2]], m1: [[[1, 2]]], } - np.testing.assert_almost_equal([[[7]]], - sess.run(out, feed_dict=feed_dict)) + self.assertAllClose([[[7]]], sess.run(out, feed_dict=feed_dict)) with ops.Graph().as_default(): m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2)) @@ -346,8 +397,7 @@ class EinsumTest(test.TestCase): m0: [[[1, 2]]], m1: [3, 2], } - np.testing.assert_almost_equal([[7]], sess.run( - out, feed_dict=feed_dict)) + self.assertAllClose([[7]], sess.run(out, feed_dict=feed_dict)) with ops.Graph().as_default(): m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2)) @@ -358,8 +408,7 @@ class EinsumTest(test.TestCase): m0: [[[[1, 2]], [[2, 1]]]], m1: [[3, 2]], } - np.testing.assert_almost_equal([[[7, 8]]], - sess.run(out, feed_dict=feed_dict)) + self.assertAllClose([[[7, 8]]], sess.run(out, feed_dict=feed_dict)) if __name__ == '__main__': diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py index 28054f50ef3b1227f12376b4b3700a7618270d65..293aace7282eb0f8dde9da75b0d353a560c0ecb9 100644 --- a/tensorflow/python/ops/spectral_ops.py +++ b/tensorflow/python/ops/spectral_ops.py @@ -167,8 +167,8 @@ def _validate_dct_arguments(dct_type, n, axis, norm): raise NotImplementedError("The DCT length argument is not implemented.") if axis != -1: raise NotImplementedError("axis must be -1. Got: %s" % axis) - if dct_type != 2: - raise ValueError("Only the Type II DCT is supported.") + if dct_type not in (2, 3): + raise ValueError("Only Types II and III (I)DCT are supported.") if norm not in (None, "ortho"): raise ValueError( "Unknown normalization. Expected None or 'ortho', got: %s" % norm) @@ -179,18 +179,20 @@ def _validate_dct_arguments(dct_type, n, axis, norm): def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`. - Currently only Type II is supported. Implemented using a length `2N` padded - @{tf.spectral.rfft}, as described here: https://dsp.stackexchange.com/a/10606 + Currently only Types II and III are supported. Type II is implemented using a + length `2N` padded @{tf.spectral.rfft}, as described here: + https://dsp.stackexchange.com/a/10606. Type III is a fairly straightforward + inverse of Type II (i.e. using a length `2N` padded @{tf.spectral.irfft}). @compatibility(scipy) - Equivalent to scipy.fftpack.dct for the Type-II DCT. + Equivalent to scipy.fftpack.dct for Type-II and Type-III DCT. https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html @end_compatibility Args: input: A `[..., samples]` `float32` `Tensor` containing the signals to take the DCT of. - type: The DCT type to perform. Must be 2. + type: The DCT type to perform. Must be 2 or 3. n: For future expansion. The length of the transform. Must be `None`. axis: For future expansion. The axis to compute the DCT along. Must be `-1`. norm: The normalization to apply. `None` for no normalization or `'ortho'` @@ -201,8 +203,8 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl A `[..., samples]` `float32` `Tensor` containing the DCT of `input`. Raises: - ValueError: If `type` is not `2`, `n` is not `None, `axis` is not `-1`, or - `norm` is not `None` or `'ortho'`. + ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not + `-1`, or `norm` is not `None` or `'ortho'`. [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform """ @@ -214,22 +216,91 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl axis_dim = input.shape[-1].value or _array_ops.shape(input)[-1] axis_dim_float = _math_ops.to_float(axis_dim) - scale = 2.0 * _math_ops.exp(_math_ops.complex( - 0.0, -_math.pi * _math_ops.range(axis_dim_float) / - (2.0 * axis_dim_float))) - - # TODO(rjryan): Benchmark performance and memory usage of the various - # approaches to computing a DCT via the RFFT. - dct2 = _math_ops.real( - rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale) - - if norm == "ortho": - n1 = 0.5 * _math_ops.rsqrt(axis_dim_float) - n2 = n1 * _math_ops.sqrt(2.0) - # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. - weights = _array_ops.pad( - _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], - constant_values=n2) - dct2 *= weights - - return dct2 + if type == 2: + scale = 2.0 * _math_ops.exp( + _math_ops.complex( + 0.0, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 / + axis_dim_float)) + + # TODO(rjryan): Benchmark performance and memory usage of the various + # approaches to computing a DCT via the RFFT. + dct2 = _math_ops.real( + rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale) + + if norm == "ortho": + n1 = 0.5 * _math_ops.rsqrt(axis_dim_float) + n2 = n1 * _math_ops.sqrt(2.0) + # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. + weights = _array_ops.pad( + _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], + constant_values=n2) + dct2 *= weights + + return dct2 + + elif type == 3: + if norm == "ortho": + n1 = _math_ops.sqrt(axis_dim_float) + n2 = n1 * _math_ops.sqrt(0.5) + # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. + weights = _array_ops.pad( + _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], + constant_values=n2) + input *= weights + else: + input *= axis_dim_float + scale = 2.0 * _math_ops.exp( + _math_ops.complex( + 0.0, + _math_ops.range(axis_dim_float) * _math.pi * 0.5 / + axis_dim_float)) + dct3 = _math_ops.real( + irfft( + scale * _math_ops.complex(input, 0.0), + fft_length=[2 * axis_dim]))[..., :axis_dim] + + return dct3 + + +# TODO(rjryan): Implement `type`, `n` and `axis` parameters. +@tf_export("spectral.idct") +def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin + """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`. + + Currently only Types II and III are supported. Type III is the inverse of + Type II, and vice versa. + + Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is + not `'ortho'`. That is: + `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`. + When `norm='ortho'`, we have: + `signal == idct(dct(signal, norm='ortho'), norm='ortho')`. + + @compatibility(scipy) + Equivalent to scipy.fftpack.idct for Type-II and Type-III DCT. + https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html + @end_compatibility + + Args: + input: A `[..., samples]` `float32` `Tensor` containing the signals to take + the DCT of. + type: The IDCT type to perform. Must be 2 or 3. + n: For future expansion. The length of the transform. Must be `None`. + axis: For future expansion. The axis to compute the DCT along. Must be `-1`. + norm: The normalization to apply. `None` for no normalization or `'ortho'` + for orthonormal normalization. + name: An optional name for the operation. + + Returns: + A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`. + + Raises: + ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not + `-1`, or `norm` is not `None` or `'ortho'`. + + [idct]: + https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms + """ + _validate_dct_arguments(type, n, axis, norm) + inverse_type = {2: 3, 3: 2}[type] + return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name) diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index a2d24711e2291bafcf5736c6206ceb09ac210453..d0e5f700254fa5273cb707e59ac0d141fdc13627 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import cudnn_rnn_grad from tensorflow.python.ops import data_flow_grad from tensorflow.python.ops import manip_grad from tensorflow.python.ops import math_grad +from tensorflow.python.ops import random_grad from tensorflow.python.ops import sparse_grad from tensorflow.python.ops import spectral_grad from tensorflow.python.ops import state_grad diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 94d7458ec8735836566033faae95a3aed3af1824..8cb6a0537e928effbcf4c475bcc4e974182da2a7 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -338,7 +338,6 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None): Args: ref: A Variable. indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. - A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. updates: A `Tensor`. Must have the same type as `ref`. A Tensor. Must have the same type as ref. A tensor of updated @@ -355,10 +354,9 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None): if ref.dtype._is_ref_dtype: return gen_state_ops.scatter_nd_update( ref, indices, updates, use_locking, name) - with ops.control_dependencies([gen_state_ops.resource_scatter_nd_update( - ref.handle, indices, ops.convert_to_tensor(updates, dtype=ref.dtype), - use_locking, name)]): - return ref.read_value() + return ref._lazy_read(gen_state_ops.resource_scatter_nd_update( # pylint: disable=protected-access + ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), + name=name)) @tf_export("scatter_add") @@ -396,7 +394,7 @@ def scatter_add(ref, indices, updates, use_locking=False, name=None): A tensor of indices into the first dimension of `ref`. updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated values to store in `ref`. - use_locking: An optional `bool`. Defaults to `True`. + use_locking: An optional `bool`. Defaults to `False`. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. name: A name for the operation (optional). @@ -411,3 +409,67 @@ def scatter_add(ref, indices, updates, use_locking=False, name=None): return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add( # pylint: disable=protected-access ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), name=name)) + + +@tf_export("scatter_nd_add") +def scatter_nd_add(ref, indices, updates, use_locking=False, name=None): + r"""Applies sparse addition to individual values or slices in a Variable. + + `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + + `indices` must be integer tensor, containing indices into `ref`. + It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + + The innermost dimension of `indices` (with length `K`) corresponds to + indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th + dimension of `ref`. + + `updates` is `Tensor` of rank `Q-1+P-K` with shape: + + ``` + [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. + ``` + + For example, say we want to add 4 scattered elements to a rank-1 tensor to + 8 elements. In Python, that update would look like this: + + ```python + ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1] ,[7]]) + updates = tf.constant([9, 10, 11, 12]) + add = tf.scatter_nd_add(ref, indices, updates) + with tf.Session() as sess: + print sess.run(add) + ``` + + The resulting update to ref would look like this: + + [1, 13, 3, 14, 14, 6, 7, 20] + + See @{tf.scatter_nd} for more details about how to make updates to + slices. + + Args: + ref: A mutable `Tensor`. Must be one of the following types: `float32`, + `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, + `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, + `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node. + indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. + A tensor of indices into ref. + updates: A `Tensor`. Must have the same type as `ref`. + A tensor of updated values to add to ref. + use_locking: An optional `bool`. Defaults to `False`. + An optional bool. Defaults to True. If True, the assignment will + be protected by a lock; otherwise the behavior is undefined, + but may exhibit less contention. + name: A name for the operation (optional). + + Returns: + A mutable `Tensor`. Has the same type as `ref`. + """ + if ref.dtype._is_ref_dtype: + return gen_state_ops.scatter_nd_add( + ref, indices, updates, use_locking, name) + return ref._lazy_read(gen_state_ops.resource_scatter_nd_add( # pylint: disable=protected-access + ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), + name=name)) diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index ae79c0194954a052db799d7a00ce1ddc584ea6ed..0280c89c10f264dcc37c89598599d377e8ac9e07 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -91,6 +91,59 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv shape.set_shape([2]) return sparse_tensor.SparseTensor(indices, values, shape) +@tf_export("strings.split") +def string_split_v2(source, sep=None, maxsplit=-1): + """Split elements of `source` based on `sep` into a `SparseTensor`. + + Let N be the size of source (typically N will be the batch size). Split each + element of `source` based on `sep` and return a `SparseTensor` + containing the split tokens. Empty tokens are ignored. + + For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', + then the output will be + + st.indices = [0, 0; + 0, 1; + 1, 0; + 1, 1; + 1, 2] + st.shape = [2, 3] + st.values = ['hello', 'world', 'a', 'b', 'c'] + + If `sep` is given, consecutive delimiters are not grouped together and are + deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and + sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty + string, consecutive whitespace are regarded as a single separator, and the + result will contain no empty strings at the startor end if the string has + leading or trailing whitespace. + + Note that the above mentioned behavior matches python's str.split. + + Args: + source: `1-D` string `Tensor`, the strings to split. + sep: `0-D` string `Tensor`, the delimiter character. + maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result. + + Raises: + ValueError: If sep is not a string. + + Returns: + A `SparseTensor` of rank `2`, the strings split according to the delimiter. + The first column of the indices corresponds to the row in `source` and the + second column corresponds to the index of the split component in this row. + """ + if sep is None: + sep = '' + sep = ops.convert_to_tensor(sep, dtype=dtypes.string) + source = ops.convert_to_tensor(source, dtype=dtypes.string) + + indices, values, shape = gen_string_ops.string_split_v2( + source, sep=sep, maxsplit=maxsplit) + indices.set_shape([None, 2]) + values.set_shape([None]) + shape.set_shape([2]) + return sparse_tensor.SparseTensor(indices, values, shape) + def _reduce_join_reduction_dims(x, axis, reduction_indices): """Returns range(rank(x) - 1, 0, -1) if reduction_indices is None.""" diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index b80f84eb7cde264c5a7c83eafacc344adb50b80a..00150fe68820da711c76f642baced45163a8727c 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -306,10 +306,11 @@ def create_db_writer(db_uri, def _make_summary_writer(name, factory, **kwargs): resource = gen_summary_ops.summary_writer(shared_name=name) init_op_fn = lambda: factory(resource, **kwargs) - # TODO(apassos): Consider doing this instead. - # if not context.executing_eagerly(): - # ops.get_default_session().run(init_op) - ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op_fn()) + init_op = init_op_fn() + if not context.executing_eagerly(): + # TODO(apassos): Consider doing this instead. + # ops.get_default_session().run(init_op) + ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op) return SummaryWriter(resource, init_op_fn) @@ -380,7 +381,8 @@ def summary_writer_function(name, tensor, function, family=None): with ops.device("cpu:0"): op = smart_cond.smart_cond( should_record_summaries(), record, _nothing, name="") - ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access + if not context.executing_eagerly(): + ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access return op diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 355b0d961e2105bf19105dbc6f8a9ddfc41c0d30..161d9687d6b0af58a3e8aef5518d70432e70691c 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_util from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util.deprecation import deprecated @@ -295,66 +296,6 @@ class Template(checkpointable.CheckpointableBase): # which is not the same as whether the scope has been created. self._variables_created = False - def _checkpointable_custom_creator(self, next_creator, name, initial_value, - checkpointable_parent=None, **kwargs): - """A variable creation hook which adds Checkpointable dependencies. - - Set during the `Template`'s first wrapped function execution. Ensures that - (a) `Template` objects depend on `Template`s created inside them which - create variables, and (b) that any variables not in a more deeply nested - `Template` are added as dependencies directly. - - The `checkpointable_parent` argument is passed between `Template` custom - creators but ignored when the variable object itself is created. This - argument indicates (if not `None`) that a more deeply nested `Template` has - already added the variable as a dependency, and that parent `Template`s - should add a dependency on that `Template` rather than on the variable - directly. - - Args: - next_creator: See `variable_scope.variable_creator_scope`; the next - creator in the chain. - name: The (full, scope-influenced) name of the variable. The scope name - for the Template itself is stripped for the purposes of object-based - dependency tracking, but scopes within Templates are respected. - initial_value: See `variable_scope.variable_creator_scope`. Taken - explicitly so the argument can be re-named and used with - `Checkpointable._add_variable_with_custom_getter`. - checkpointable_parent: If not None, a more deeply nested Template object - to add a dependency on (rather than depending on the variable directly). - **kwargs: Passed through to the next creator. - Returns: - The output of `next_creator`: the fetched/created variable object. - """ - def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): - inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which - # we don't want to propagate. - return next_creator( - initial_value=initializer, - name=name, - **inner_kwargs) - if name.startswith(self._variable_scope.name): - scope_stripped_name = name[len(self._variable_scope.name) + 1:] - if not checkpointable_parent: - return self._add_variable_with_custom_getter( - initializer=initial_value, - name=scope_stripped_name, - getter=_call_next_creator_renaming_initializer, - # Disable error checking for Checkpointable. Exceptions are instead - # raised if necessary when the object-based saver tries to - # save/restore the object. - overwrite=True, - checkpointable_parent=self, - **kwargs) - else: - self._track_checkpointable( - checkpointable_parent, - name=checkpointable_parent._variable_scope.name[ # pylint: disable=protected-access - len(self._variable_scope.name) + 1:], - overwrite=True) - return next_creator(name=name, initial_value=initial_value, - checkpointable_parent=self, **kwargs) - def _call_func(self, args, kwargs): try: vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) @@ -365,8 +306,7 @@ class Template(checkpointable.CheckpointableBase): else: # The first time we run, restore variables if necessary (via # Checkpointable). - with variable_scope.variable_creator_scope( - self._checkpointable_custom_creator): + with checkpointable_util.capture_dependencies(template=self): result = self._func(*args, **kwargs) if self._variables_created: @@ -634,8 +574,7 @@ class EagerTemplate(Template): else: # The first time we run, restore variables if necessary (via # Checkpointable). - with variable_scope.variable_creator_scope( - self._checkpointable_custom_creator): + with checkpointable_util.capture_dependencies(template=self): result = self._func(*args, **kwargs) if self._variables_created: diff --git a/tensorflow/python/ops/tensor_array_grad.py b/tensorflow/python/ops/tensor_array_grad.py index 1f70d695485ca0aab22c532099caad1b361d3637..d34134980400999ee2b0de9362423b2ec495868f 100644 --- a/tensorflow/python/ops/tensor_array_grad.py +++ b/tensorflow/python/ops/tensor_array_grad.py @@ -34,6 +34,7 @@ ops.NotDifferentiable("TensorArrayCloseV2") ops.NotDifferentiable("TensorArrayV3") ops.NotDifferentiable("TensorArrayGradV3") +ops.NotDifferentiable("TensorArrayGradWithShape") ops.NotDifferentiable("TensorArraySizeV3") ops.NotDifferentiable("TensorArrayCloseV3") diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 8d93d24b149a2fd27b956e73d9e866b61ca97287..1e06bf07d5aaa88a4a30760450cffc32a20f4ca5 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -44,9 +44,11 @@ from tensorflow.python.util import function_utils from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -__all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope", - "get_variable", "get_local_variable", "variable_scope", - "variable_op_scope", "no_regularizer"] +__all__ = [ + "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable", + "get_local_variable", "variable_scope", "variable_op_scope", + "no_regularizer", "VariableSynchronization", "VariableAggregation" +] class _PartitionInfo(object): @@ -188,6 +190,38 @@ class _ReuseMode(enum.Enum): # REUSE_FALSE = 2 # REUSE_TRUE = 3 + +@tf_export("VariableSynchronization") +class VariableSynchronization(enum.Enum): + """Indicates when a distributed variable will be synced.""" + + # Indicates that the synchronization will be determined by the current + # `DistributionStrategy` (eg. With `MirroredStrategy` this would be + # `ON_WRITE`). + AUTO = 0 + + # Indicates that there will only be one copy of the variable, so there is no + # need to sync. + NONE = 1 + + # Indicates that the variable will be aggregated across devices + # every time it is updated. + ON_WRITE = 2 + + # Indicates that the variable will be aggregated across devices + # when it is read (eg. when checkpointing or when evaluating an op that uses + # the variable). + ON_READ = 3 + + +@tf_export("VariableAggregation") +class VariableAggregation(enum.Enum): + """Indicates how a distributed variable will be aggregated.""" + NONE = 0 + SUM = 1 + MEAN = 2 + + AUTO_REUSE = _ReuseMode.AUTO_REUSE tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE") AUTO_REUSE.__doc__ = """ @@ -214,11 +248,23 @@ class _VariableStore(object): self._partitioned_vars = {} # A dict of the stored PartitionedVariables. self._store_eager_variables = False - def get_variable(self, name, shape=None, dtype=dtypes.float32, - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, caching_device=None, - partitioner=None, validate_shape=True, use_resource=None, - custom_getter=None, constraint=None): + def get_variable(self, + name, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + reuse=None, + trainable=True, + collections=None, + caching_device=None, + partitioner=None, + validate_shape=True, + use_resource=None, + custom_getter=None, + constraint=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): """Gets an existing variable with these parameters or create a new one. If a variable with the given name is already stored, we return the stored @@ -291,6 +337,14 @@ class _VariableStore(object): variable and return the Tensor for the projected value (which must have the same shape). Constraints are not safe to use when doing asynchronous distributed training. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + @{tf.VariableSynchronization}. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + @{tf.VariableAggregation}. Returns: The created or existing `Variable` (or `PartitionedVariable`, if a @@ -343,11 +397,22 @@ class _VariableStore(object): # it to custom_getter. # Note: the parameters of _true_getter, and their documentation, match # *exactly* item-for-item with the docstring of this method. - def _true_getter(name, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, caching_device=None, - partitioner=None, validate_shape=True, use_resource=None, - constraint=None): + def _true_getter( # pylint: disable=missing-docstring + name, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + reuse=None, + trainable=True, + collections=None, + caching_device=None, + partitioner=None, + validate_shape=True, + use_resource=None, + constraint=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): is_scalar = (shape is not None and isinstance(shape, collections_lib.Sequence) and not shape) @@ -397,11 +462,20 @@ class _VariableStore(object): "name was already created with partitioning?" % name) return self._get_single_variable( - name=name, shape=shape, dtype=dtype, - initializer=initializer, regularizer=regularizer, reuse=reuse, - trainable=trainable, collections=collections, - caching_device=caching_device, validate_shape=validate_shape, - use_resource=use_resource, constraint=constraint) + name=name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + reuse=reuse, + trainable=trainable, + collections=collections, + caching_device=caching_device, + validate_shape=validate_shape, + use_resource=use_resource, + constraint=constraint, + synchronization=synchronization, + aggregation=aggregation) if custom_getter is not None: # Handle backwards compatibility with getter arguments that were added @@ -420,6 +494,8 @@ class _VariableStore(object): "partitioner": partitioner, "validate_shape": validate_shape, "use_resource": use_resource, + "synchronization": synchronization, + "aggregation": aggregation, } # `fn_args` can handle functions, `functools.partial`, `lambda`. if "constraint" in function_utils.fn_args(custom_getter): @@ -427,12 +503,21 @@ class _VariableStore(object): return custom_getter(**custom_getter_kwargs) else: return _true_getter( - name, shape=shape, dtype=dtype, - initializer=initializer, regularizer=regularizer, - reuse=reuse, trainable=trainable, collections=collections, - caching_device=caching_device, partitioner=partitioner, - validate_shape=validate_shape, use_resource=use_resource, - constraint=constraint) + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + reuse=reuse, + trainable=trainable, + collections=collections, + caching_device=caching_device, + partitioner=partitioner, + validate_shape=validate_shape, + use_resource=use_resource, + constraint=constraint, + synchronization=synchronization, + aggregation=aggregation) def _get_partitioned_variable( self, name, partitioner, shape=None, dtype=dtypes.float32, @@ -693,7 +778,9 @@ class _VariableStore(object): caching_device=None, validate_shape=True, use_resource=None, - constraint=None): + constraint=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): """Get or create a single Variable (e.g. a shard or entire variable). See the documentation of get_variable above (ignore partitioning components) @@ -713,6 +800,8 @@ class _VariableStore(object): validate_shape: see get_variable. use_resource: see get_variable. constraint: see get_variable. + synchronization: see get_variable. + aggregation: see get_variable. Returns: A Variable. See documentation of get_variable above. @@ -793,7 +882,17 @@ class _VariableStore(object): dtype=variable_dtype, validate_shape=validate_shape, constraint=constraint, - use_resource=use_resource) + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) + if context.executing_eagerly() and self._store_eager_variables: + if collections: + ops.add_to_collections(collections, v) + else: + ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v) + if trainable: + ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v) + if not context.executing_eagerly() or self._store_eager_variables: # In eager mode we do not want to keep default references to Variable # objects as this will prevent their memory from being released. @@ -1044,7 +1143,9 @@ class VariableScope(object): validate_shape=True, use_resource=None, custom_getter=None, - constraint=None): + constraint=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): """Gets an existing variable with this name or create a new one.""" if regularizer is None: regularizer = self._regularizer @@ -1082,12 +1183,22 @@ class VariableScope(object): if dtype is None: dtype = self._dtype return var_store.get_variable( - full_name, shape=shape, dtype=dtype, initializer=initializer, - regularizer=regularizer, reuse=reuse, trainable=trainable, - collections=collections, caching_device=caching_device, - partitioner=partitioner, validate_shape=validate_shape, - use_resource=use_resource, custom_getter=custom_getter, - constraint=constraint) + full_name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + reuse=reuse, + trainable=trainable, + collections=collections, + caching_device=caching_device, + partitioner=partitioner, + validate_shape=validate_shape, + use_resource=use_resource, + custom_getter=custom_getter, + constraint=constraint, + synchronization=synchronization, + aggregation=aggregation) def _get_partitioned_variable(self, var_store, @@ -1261,13 +1372,13 @@ class EagerVariableStore(object): def trainable_variables(self): # pylint: disable=protected-access - return sorted([x for x in self._store._vars.values() if x._trainable], + return sorted([x for x in self._store._vars.values() if x.trainable], key=lambda x: x.name) # pylint: enable=protected-access def non_trainable_variables(self): # pylint: disable=protected-access - return sorted([x for x in self._store._vars.values() if not x._trainable], + return sorted([x for x in self._store._vars.values() if not x.trainable], key=lambda x: x.name) # pylint: enable=protected-access @@ -1296,7 +1407,7 @@ class EagerVariableStore(object): new_var = resource_variable_ops.ResourceVariable( var.read_value(), name=stripped_var_name, - trainable=var._trainable) + trainable=var.trainable) new_store._store._vars[key] = new_var return new_store # pylint: enable=protected-access @@ -1318,14 +1429,28 @@ def get_variable(name, validate_shape=True, use_resource=None, custom_getter=None, - constraint=None): + constraint=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): return get_variable_scope().get_variable( - _get_default_variable_store(), name, shape=shape, dtype=dtype, - initializer=initializer, regularizer=regularizer, trainable=trainable, - collections=collections, caching_device=caching_device, - partitioner=partitioner, validate_shape=validate_shape, - use_resource=use_resource, custom_getter=custom_getter, - constraint=constraint) + _get_default_variable_store(), + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=trainable, + collections=collections, + caching_device=caching_device, + partitioner=partitioner, + validate_shape=validate_shape, + use_resource=use_resource, + custom_getter=custom_getter, + constraint=constraint, + synchronization=synchronization, + aggregation=aggregation) + + get_variable_or_local_docstring = ( """%s @@ -1422,29 +1547,44 @@ get_variable.__doc__ = get_variable_or_local_docstring % ( # The argument list for get_local_variable must match arguments to get_variable. # So, if you are updating the arguments, also update arguments to get_variable. @tf_export("get_local_variable") -def get_local_variable(name, - shape=None, - dtype=None, - initializer=None, - regularizer=None, - trainable=False, # pylint: disable=unused-argument - collections=None, - caching_device=None, - partitioner=None, - validate_shape=True, - use_resource=None, - custom_getter=None, - constraint=None): +def get_local_variable( # pylint: disable=missing-docstring + name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=False, # pylint: disable=unused-argument + collections=None, + caching_device=None, + partitioner=None, + validate_shape=True, + use_resource=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE, + custom_getter=None, + constraint=None): if collections: collections += [ops.GraphKeys.LOCAL_VARIABLES] else: collections = [ops.GraphKeys.LOCAL_VARIABLES] return get_variable( - name, shape=shape, dtype=dtype, initializer=initializer, - regularizer=regularizer, trainable=False, collections=collections, - caching_device=caching_device, partitioner=partitioner, - validate_shape=validate_shape, use_resource=use_resource, - custom_getter=custom_getter, constraint=constraint) + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=False, + collections=collections, + caching_device=caching_device, + partitioner=partitioner, + validate_shape=validate_shape, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation, + custom_getter=custom_getter, + constraint=constraint) + + get_local_variable.__doc__ = get_variable_or_local_docstring % ( "Gets an existing *local* variable or creates a new one.", "Behavior is the same as in `get_variable`, except that variables are\n" @@ -1778,6 +1918,23 @@ class variable_scope(object): assert v.name == "foo/bar/v:0" ``` + Simple example of how to reenter a premade variable scope safely: + + ```python + with tf.variable_scope("foo") as vs: + pass + + # Re-enter the variable scope. + with tf.variable_scope(vs, + auxiliary_name_scope=False) as vs1: + # Restore the original name_scope. + with tf.name_scope(vs1.original_name_scope): + v = tf.get_variable("v", [1]) + assert v.name == "foo/v:0" + c = tf.constant([1], name="c") + assert c.name == "foo/c:0" + ``` + Basic example of sharing a variable AUTO_REUSE: ```python @@ -1900,7 +2057,8 @@ class variable_scope(object): for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create variables if they do not exist, and return them otherwise; if None, we inherit the parent scope's reuse flag. When eager execution is enabled, - this argument is always forced to be tf.AUTO_REUSE. + new variables are always created unless an EagerVariableStore or + template is currently active. dtype: type of variables created in this scope (defaults to the type in the passed scope, or inherited from parent scope). use_resource: If False, all variables will be regular Variables. If True, @@ -1915,7 +2073,9 @@ class variable_scope(object): (which must have the same shape). Constraints are not safe to use when doing asynchronous distributed training. auxiliary_name_scope: If `True`, we create an auxiliary name scope with - the scope. If `False`, we don't touch name scope. + the scope. If `False`, we don't create it. Note that the argument is + not inherited, and it only takes effect for once when creating. You + should only use it for re-entering a premade variable scope. Returns: A scope that can be captured and reused. @@ -2186,6 +2346,12 @@ def default_variable_creator(next_creator=None, **kwargs): dtype = kwargs.get("dtype", None) constraint = kwargs.get("constraint", None) use_resource = kwargs.get("use_resource", None) + + # Enforce `ON_READ` variables to be not trainable. + synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO) + if synchronization == VariableSynchronization.ON_READ: + trainable = False + if use_resource is None: use_resource = get_variable_scope().use_resource if use_resource or (use_resource is None and context.executing_eagerly()): @@ -2220,18 +2386,28 @@ def variable(initial_value=None, name=None, dtype=None, constraint=None, - use_resource=None): + use_resource=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access previous_getter = _make_getter(getter, previous_getter) - return previous_getter(initial_value=initial_value, - trainable=trainable, - collections=collections, - validate_shape=validate_shape, - caching_device=caching_device, - name=name, dtype=dtype, - constraint=constraint, - use_resource=use_resource) + + # Reset `aggregation` that is explicitly set as `None` to the enum None value. + if aggregation is None: + aggregation = VariableAggregation.NONE + return previous_getter( + initial_value=initial_value, + trainable=trainable, + collections=collections, + validate_shape=validate_shape, + caching_device=caching_device, + name=name, + dtype=dtype, + constraint=constraint, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) @tf_contextlib.contextmanager @@ -2283,6 +2459,14 @@ def variable_creator_scope(variable_creator): constraint: A constraint function to be applied to the variable after updates by some algorithms. use_resource: if True, a ResourceVariable is always created. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + @{tf.VariableSynchronization}. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + @{tf.VariableAggregation}. This set may grow over time, so it's important the signature of creators is as mentioned above. diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 294ee0e32831928d34d6bae5668898609a37dc59..9a09cdaa52425713cf18362dd8726fe7207c604f 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -123,6 +123,30 @@ class Variable(checkpointable.CheckpointableBase): various `Optimizer` classes use this collection as the default list of variables to optimize. + WARNING: tf.Variable objects have a non-intuitive memory model. A Variable is + represented internally as a mutable Tensor which can non-deterministically + alias other Tensors in a graph. The set of operations which consume a Variable + and can lead to aliasing is undetermined and can change across TensorFlow + versions. Avoid writing code which relies on the value of a Variable either + changing or not changing as other operations happen. For example, using + Variable objects or simple functions thereof as predicates in a `tf.cond` is + dangerous and error-prone: + + ``` + v = tf.Variable(True) + tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken. + ``` + + Here replacing tf.Variable with tf.contrib.eager.Variable will fix any + nondeterminism issues. + + To use the replacement for variables which does + not have these issues: + + * Replace `tf.Variable` with `tf.contrib.eager.Variable`; + * Call `tf.get_variable_scope().set_use_resource(True)` inside a + `tf.variable_scope` before the `tf.get_variable()` call. + @compatibility(eager) `tf.Variable` is not compatible with eager execution. Use `tf.contrib.eager.Variable` instead which is compatible with both eager @@ -235,7 +259,7 @@ class Variable(checkpointable.CheckpointableBase): constraint=constraint) def __repr__(self): - if context.executing_eagerly(): + if context.executing_eagerly() and not self._in_graph_mode: return "" % ( self.name, self.get_shape(), self.dtype.name, ops.numpy_text(self.read_value(), is_repr=True)) @@ -317,6 +341,7 @@ class Variable(checkpointable.CheckpointableBase): self._update_uid = initial_value.checkpoint_position.restore_uid initial_value = initial_value.wrapped_value + self._trainable = trainable if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] with ops.init_scope(): @@ -426,6 +451,7 @@ class Variable(checkpointable.CheckpointableBase): import_scope=import_scope)) else: self._initial_value = None + self._trainable = getattr(variable_def, "trainable", True) self._snapshot = g.as_graph_element( ops.prepend_name_scope(variable_def.snapshot_name, import_scope=import_scope)) @@ -519,6 +545,10 @@ class Variable(checkpointable.CheckpointableBase): self._ref().set_shape(shape) self.value().set_shape(shape) + @property + def trainable(self): + return self._trainable + def eval(self, session=None): """In a session, computes and returns the value of this variable. @@ -1026,6 +1056,7 @@ class Variable(checkpointable.CheckpointableBase): # For backwards compatibility. var_def.initial_value_name = ops.strip_name_scope( self._initial_value.name, export_scope) + var_def.trainable = self.trainable var_def.initializer_name = ops.strip_name_scope( self.initializer.name, export_scope) var_def.snapshot_name = ops.strip_name_scope( @@ -1062,39 +1093,40 @@ class Variable(checkpointable.CheckpointableBase): def __imul__(self, other): logging.log_first_n( logging.WARN, - "Variable *= will be deprecated. Use variable.assign_mul" - " if you want assignment to the variable value or 'x = x * y'" + "Variable *= will be deprecated. Use `var.assign(var * other)`" + " if you want assignment to the variable value or `x = x * y`" " if you want a new python Tensor object.", 1) return self * other def __idiv__(self, other): logging.log_first_n( logging.WARN, - "Variable /= will be deprecated. Use variable.assign_div" - " if you want assignment to the variable value or 'x = x / y'" + "Variable /= will be deprecated. Use `var.assign(var / other)`" + " if you want assignment to the variable value or `x = x / y`" " if you want a new python Tensor object.", 1) return self / other def __itruediv__(self, other): logging.log_first_n( logging.WARN, - "Variable /= will be deprecated. Use variable.assign_div" - " if you want assignment to the variable value or 'x = x / y'" + "Variable /= will be deprecated. Use `var.assign(var / other)`" + " if you want assignment to the variable value or `x = x / y`" " if you want a new python Tensor object.", 1) return self / other def __irealdiv__(self, other): logging.log_first_n( logging.WARN, - "Variable /= will be deprecated. Use variable.assign_div" - " if you want assignment to the variable value or 'x = x / y'" + "Variable /= will be deprecated. Use `var.assign(var / other)`" + " if you want assignment to the variable value or `x = x / y`" " if you want a new python Tensor object.", 1) return self / other def __ipow__(self, other): logging.log_first_n( logging.WARN, - "Variable **= will be deprecated. Use 'x = x ** y'" + "Variable **= will be deprecated. Use `var.assign(var ** other)`" + " if you want assignment to the variable value or `x = x ** y`" " if you want a new python Tensor object.", 1) return self ** other @@ -1691,6 +1723,8 @@ def report_uninitialized_variables(var_list=None, var_list.append(op.outputs[0]) with ops.name_scope(name): # Run all operations on CPU + if var_list: + init_vars = [state_ops.is_variable_initialized(v) for v in var_list] with ops.device("/cpu:0"): if not var_list: # Return an empty tensor so we only need to check for returned tensor @@ -1698,9 +1732,7 @@ def report_uninitialized_variables(var_list=None, return array_ops.constant([], dtype=dtypes.string) else: # Get a 1-D boolean tensor listing whether each variable is initialized. - variables_mask = math_ops.logical_not( - array_ops.stack( - [state_ops.is_variable_initialized(v) for v in var_list])) + variables_mask = math_ops.logical_not(array_ops.stack(init_vars)) # Get a 1-D string tensor containing all the variable names. variable_names_tensor = array_ops.constant( [s.op.name for s in var_list]) diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i index 478dd46f7e6965f8727e5741f2ccdfdc69247980..2e06f26fa4c43b934a1c7ff08eba6c9b755f8085 100644 --- a/tensorflow/python/platform/base.i +++ b/tensorflow/python/platform/base.i @@ -233,7 +233,7 @@ _COPY_TYPEMAPS(unsigned int, mode_t); // Typemaps to automatically raise a Python exception from bad output TF_Status. // TODO(b/77295559): expand this to all TF_Status* output params and deprecate // raise_exception_on_not_ok_status (currently it only affects the C API). -%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) { +%typemap(in, numinputs=0) TF_Status* status { $1 = TF_NewStatus(); } diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py index 9e49188c1ef353d345c97ea0295aa1a68283605e..f9891f3b1e2e94f61329babd1409e3efacc7f5b3 100644 --- a/tensorflow/python/profiler/model_analyzer_test.py +++ b/tensorflow/python/profiler/model_analyzer_test.py @@ -707,8 +707,10 @@ class PrintModelAnalysisTest(test.TestCase): a = array_ops.constant(np.ones((100, 100))) b = array_ops.constant(np.ones((100, 100))) c = a * b + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.min_graph_nodes = -1 - with session.Session() as sess: + with session.Session(config=config) as sess: run_options = config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE) run_metadata = config_pb2.RunMetadata() diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 5ee55301df986998b22b8b57b5f01b1f6b4918ac..5d7535cf34f7396b7ff6aebd3984046e51c98347 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -33,8 +33,9 @@ limitations under the License. %rename("%s") TFE_ContextAsyncClearError; %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; +%rename("%s") TFE_Py_SetEagerTensorProfiler; %rename("%s") TFE_Py_RegisterExceptionClass; -%rename("%s") TFE_Py_RegisterBackwardFunctionGetter; +%rename("%s") TFE_Py_RegisterGradientFunction; %rename("%s") TFE_Py_RegisterFallbackExceptionClass; %rename("%s") TFE_Py_RegisterResourceVariableType; %rename("%s") TFE_Py_Execute; @@ -42,6 +43,7 @@ limitations under the License. %rename("%s") TFE_Py_RecordGradient; %rename("%s") TFE_Py_UID; %rename("%s") TFE_Py_TapeSetNew; +%rename("%s") TFE_Py_TapeSetAdd; %rename("%s") TFE_Py_TapeSetRemove; %rename("%s") TFE_Py_TapeSetStopOnThread; %rename("%s") TFE_Py_TapeSetRestartOnThread; @@ -57,8 +59,10 @@ limitations under the License. %rename("%s") TFE_ContextOptionsSetConfig; %rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy; %rename("%s") TFE_ContextOptionsSetAsync; +%rename("%s") TFE_ContextOptionsSetServerDef; %rename("%s") TFE_DeleteContextOptions; %rename("%s") TFE_Py_TensorShapeSlice; +%rename("%s") TFE_Py_TensorShapeOnDevice; %{ #include "tensorflow/python/eager/pywrap_tfe.h" @@ -152,9 +156,12 @@ limitations under the License. if (EagerTensor_CheckExact(elem)) { (*$1)[i] = EagerTensor_Handle(elem); } else { - SWIG_exception_fail(SWIG_TypeError, - "provided list of inputs contains objects other " - "than 'EagerTensor'"); + SWIG_exception_fail( + SWIG_TypeError, + tensorflow::strings::StrCat( + "provided list of inputs contains objects other " + "than 'EagerTensor'. Item ", + i, " is ", elem->ob_type->tp_name).c_str()); } } } diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 2609a5d222659f6ebf775d6baa48bd7bc39fd7f6..076f2d8760fe00035ef5830a02d22e82c54dd768 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -87,6 +87,30 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python:variables", + ], +) + +py_test( + name = "loader_test", + size = "small", + srcs = ["loader_test.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + ":builder", + ":loader", + ":signature_def_utils", + ":utils", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", ], ) @@ -149,6 +173,7 @@ py_test( "//tensorflow/python:saver_test_utils", "//tensorflow/python:state_ops", "//tensorflow/python:test_ops", + "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variables", ], diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 8f1d5a099f8dce262ae902230a010e532c279e3e..e58be804c2738dbad0e2f90c21d6eff3832a8148 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -104,10 +104,10 @@ class SavedModelBuilder(object): Args: assets_collection_to_add: The collection where the asset paths are setup. """ - asset_source_filepath_list = _maybe_save_assets(assets_collection_to_add) + asset_filename_map = _maybe_save_assets(assets_collection_to_add) # Return if there are no assets to write. - if len(asset_source_filepath_list) is 0: + if not asset_filename_map: tf_logging.info("No assets to write.") return @@ -119,12 +119,10 @@ class SavedModelBuilder(object): file_io.recursive_create_dir(assets_destination_dir) # Copy each asset from source path to destination path. - for asset_source_filepath in asset_source_filepath_list: - asset_source_filename = os.path.basename(asset_source_filepath) - + for asset_basename, asset_source_filepath in asset_filename_map.items(): asset_destination_filepath = os.path.join( compat.as_bytes(assets_destination_dir), - compat.as_bytes(asset_source_filename)) + compat.as_bytes(asset_basename)) # Only copy the asset file to the destination if it does not already # exist. This is to ensure that an asset with the same name defined as @@ -272,6 +270,18 @@ class SavedModelBuilder(object): self._add_train_op(train_op) + def _maybe_create_saver(self, saver=None): + """Creates a sharded saver if one does not already exist.""" + if not saver: + # Initialize a saver to generate a sharded output for all saveables in the + # current scope. + saver = tf_saver.Saver( + variables._all_saveable_objects(), # pylint: disable=protected-access + sharded=True, + write_version=saver_pb2.SaverDef.V2, + allow_empty=True) + return saver + def add_meta_graph(self, tags, signature_def_map=None, @@ -279,7 +289,8 @@ class SavedModelBuilder(object): legacy_init_op=None, clear_devices=False, main_op=None, - strip_default_attrs=False): + strip_default_attrs=False, + saver=None): # pylint: disable=line-too-long """Adds the current meta graph to the SavedModel. @@ -304,6 +315,9 @@ class SavedModelBuilder(object): strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + saver: An instance of tf.train.Saver that will be used to export the + metagraph. If None, a sharded Saver that restores all variables will + be used. Raises: AssertionError: If the variables for the SavedModel have not been saved @@ -322,18 +336,11 @@ class SavedModelBuilder(object): # Add assets and ops self._add_collections(assets_collection, legacy_init_op, main_op, None) - # Initialize a saver to generate a sharded output for all saveables in the - # current scope. - saver = tf_saver.Saver( - variables._all_saveable_objects(), # pylint: disable=protected-access - sharded=True, - write_version=saver_pb2.SaverDef.V2, - allow_empty=True) + saver = self._maybe_create_saver(saver) # The graph almost certainly previously contained at least one Saver, and # possibly several (e.g. one for loading a pretrained embedding, and another - # for the model weights). However, a *new* Saver was just created that - # includes all of the variables. Removing the preexisting ones was the + # for the model weights). Removing the preexisting ones was the # motivation for the clear_extraneous_savers option, but it turns out that # there are edge cases where that option breaks the graph. Until that is # resolved, we just leave the option set to False for now. @@ -352,7 +359,8 @@ class SavedModelBuilder(object): legacy_init_op=None, clear_devices=False, main_op=None, - strip_default_attrs=False): + strip_default_attrs=False, + saver=None): # pylint: disable=line-too-long """Adds the current meta graph to the SavedModel and saves variables. @@ -379,6 +387,9 @@ class SavedModelBuilder(object): strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + saver: An instance of tf.train.Saver that will be used to export the + metagraph and save variables. If None, a sharded Saver that restores + all variables will be used. """ # pylint: enable=line-too-long @@ -405,13 +416,7 @@ class SavedModelBuilder(object): compat.as_text(variables_dir), compat.as_text(constants.VARIABLES_FILENAME)) - # Initialize a saver to generate a sharded output for all saveables in the - # current scope. - saver = tf_saver.Saver( - variables._all_saveable_objects(), # pylint: disable=protected-access - sharded=True, - write_version=saver_pb2.SaverDef.V2, - allow_empty=True) + saver = self._maybe_create_saver(saver) # Save the variables. Also, disable writing the checkpoint state proto. The # file is not used during SavedModel loading. In addition, since a @@ -423,8 +428,7 @@ class SavedModelBuilder(object): # The graph almost certainly previously contained at least one Saver, and # possibly several (e.g. one for loading a pretrained embedding, and another - # for the model weights). However, a *new* Saver was just created that - # includes all of the variables. Removing the preexisting ones was the + # for the model weights). Removing the preexisting ones was the # motivation for the clear_extraneous_savers option, but it turns out that # there are edge cases where that option breaks the graph. Until that is # resolved, we just leave the option set to False for now. @@ -476,16 +480,17 @@ def _maybe_save_assets(assets_collection_to_add=None): assets_collection_to_add: The collection where the asset paths are setup. Returns: - The list of filepaths to the assets in the assets collection. + A dict of asset basenames for saving to the original full path to the asset. Raises: ValueError: Indicating an invalid filepath tensor. """ - asset_source_filepath_list = [] + # Map of target file names to original filenames + asset_filename_map = {} if assets_collection_to_add is None: tf_logging.info("No assets to save.") - return asset_source_filepath_list + return asset_filename_map # Iterate over the supplied asset collection, build the `AssetFile` proto # and add them to the collection with key `constants.ASSETS_KEY`, in the @@ -495,15 +500,71 @@ def _maybe_save_assets(assets_collection_to_add=None): if not asset_source_filepath: raise ValueError("Invalid asset filepath tensor %s" % asset_tensor) - asset_source_filename = os.path.basename(asset_source_filepath) + asset_filename = _get_asset_filename_to_add( + asset_source_filepath, asset_filename_map) # Build `AssetFile` proto and add it to the asset collection in the graph. - _add_asset_to_collection(asset_source_filename, asset_tensor) + # Note that this should be done even when the file is a duplicate of an + # already-added file, as the tensor reference should still exist. + _add_asset_to_collection(asset_filename, asset_tensor) - asset_source_filepath_list.append(asset_source_filepath) + # In the cases where we are adding a duplicate, this will result in the + # last of the filepaths being the one used for copying the file to the + # SavedModel. Since the files in question are the same, it doesn't matter + # either way. + asset_filename_map[asset_filename] = asset_source_filepath tf_logging.info("Assets added to graph.") - return asset_source_filepath_list + return asset_filename_map + + +def _get_asset_filename_to_add(asset_filepath, asset_filename_map): + """Get a unique basename to add to the SavedModel if this file is unseen. + + Assets come from users as full paths, and we save them out to the + SavedModel as basenames. In some cases, the basenames collide. Here, + we dedupe asset basenames by first checking if the file is the same, + and, if different, generate and return an index-suffixed basename + that can be used to add the asset to the SavedModel. + + Args: + asset_filepath: the full path to the asset that is being saved + asset_filename_map: a dict of filenames used for saving the asset in + the SavedModel to full paths from which the filenames were derived. + + Returns: + Uniquified filename string if the file is not a duplicate, or the original + filename if the file has already been seen and saved. + """ + asset_filename = os.path.basename(asset_filepath) + + if asset_filename not in asset_filename_map: + # This is an unseen asset. Safe to add. + return asset_filename + + other_asset_filepath = asset_filename_map[asset_filename] + if other_asset_filepath == asset_filepath: + # This is the same file, stored twice in the collection list. No need + # to make unique. + return asset_filename + + # Else, asset_filename is in the map, and the filepath is different. Dedupe. + if not file_io.filecmp(asset_filepath, other_asset_filepath): + # Files are different; dedupe filenames. + return _get_unique_asset_filename(asset_filename, asset_filename_map) + + # Files are the same; don't make unique. + return asset_filename + + +def _get_unique_asset_filename(asset_filename, asset_filename_map): + i = 1 + unique_filename = asset_filename + while unique_filename in asset_filename_map: + unique_filename = compat.as_bytes("_").join( + [compat.as_bytes(asset_filename), compat.as_bytes(str(i))]) + i += 1 + return unique_filename def _asset_path_from_tensor(path_tensor): diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index bebf1d5e0d3cc6ac0e431230577704365d37a437..e5f649fdabb5cc2600a6fdd0e5ed9950d6bb23c2 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -28,6 +28,7 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants from tensorflow.python.training import saver as tf_saver @@ -79,12 +80,14 @@ def _parse_saved_model(export_dir): constants.SAVED_MODEL_FILENAME_PB)) -def _get_asset_tensors(export_dir, meta_graph_def_to_load): +def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None): """Gets the asset tensors, if defined in the meta graph def to load. Args: export_dir: Directory where the SavedModel is located. meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. + import_scope: Optional `string` -- if specified, prepend this followed by + '/' to all returned asset tensor names. Returns: A dictionary of asset tensors, keyed by the name of the asset tensor. The @@ -104,7 +107,10 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load): for asset_any_proto in assets_any_proto: asset_proto = meta_graph_pb2.AssetFileDef() asset_any_proto.Unpack(asset_proto) - asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join( + tensor_name = asset_proto.tensor_info.name + if import_scope: + tensor_name = "%s/%s" % (import_scope, tensor_name) + asset_tensor_dict[tensor_name] = os.path.join( compat.as_bytes(assets_directory), compat.as_bytes(asset_proto.filename)) return asset_tensor_dict @@ -179,7 +185,7 @@ def maybe_saved_model_directory(export_dir): @tf_export("saved_model.loader.load") -def load(sess, tags, export_dir, **saver_kwargs): +def load(sess, tags, export_dir, import_scope=None, **saver_kwargs): """Loads the model from a SavedModel as specified by tags. Args: @@ -189,6 +195,10 @@ def load(sess, tags, export_dir, **saver_kwargs): SavedModel `save()` API. export_dir: Directory in which the SavedModel protocol buffer and variables to be loaded are located. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. **saver_kwargs: Optional keyword arguments passed through to Saver. Returns: @@ -198,11 +208,56 @@ def load(sess, tags, export_dir, **saver_kwargs): Raises: RuntimeError: MetaGraphDef associated with the tags cannot be found. """ - with sess.graph.as_default(): - # Build the SavedModel protocol buffer and find requested meta graph def. - saved_model = _parse_saved_model(export_dir) + loader = SavedModelLoader(export_dir) + return loader.load(sess, tags, import_scope, **saver_kwargs) + + +class SavedModelLoader(object): + """Load graphs and restore variable values from a `SavedModel`.""" + + def __init__(self, export_dir): + """Creates a `SavedModelLoader`. + + Args: + export_dir: Directory in which the SavedModel protocol buffer and + variables to be loaded are located. + """ + self._export_dir = export_dir + self._variables_path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes(constants.VARIABLES_DIRECTORY), + compat.as_bytes(constants.VARIABLES_FILENAME)) + self._saved_model = _parse_saved_model(export_dir) + + @property + def export_dir(self): + """Directory containing the SavedModel.""" + return self._export_dir + + @property + def variables_path(self): + """Path to variable checkpoint files.""" + return self._variables_path + + @property + def saved_model(self): + """SavedModel object parsed from the export directory.""" + return self._saved_model + + def get_meta_graph_def_from_tags(self, tags): + """Return MetaGraphDef with the exact specified tags. + + Args: + tags: A list or set of string tags that identify the MetaGraphDef. + + Returns: + MetaGraphDef with the same tags. + + Raises: + RuntimeError: if no metagraphs were found with the associated tags. + """ found_match = False - for meta_graph_def in saved_model.meta_graphs: + for meta_graph_def in self._saved_model.meta_graphs: if set(meta_graph_def.meta_info_def.tags) == set(tags): meta_graph_def_to_load = meta_graph_def found_match = True @@ -214,31 +269,100 @@ def load(sess, tags, export_dir, **saver_kwargs): " could not be found in SavedModel. To inspect available tag-sets in" " the SavedModel, please use the SavedModel CLI: `saved_model_cli`" ) - - # Build a saver by importing the meta graph def to load. - saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs) - - if saver: - # Build the checkpoint path where the variables are located. - variables_path = os.path.join( - compat.as_bytes(export_dir), - compat.as_bytes(constants.VARIABLES_DIRECTORY), - compat.as_bytes(constants.VARIABLES_FILENAME)) - - # Restore the variables using the built saver in the provided session. - saver.restore(sess, variables_path) - else: - tf_logging.info("The specified SavedModel has no variables; no " - "checkpoints were restored.") - - # Get asset tensors, if any. - asset_tensors_dictionary = _get_asset_tensors(export_dir, - meta_graph_def_to_load) - - main_op_tensor = ( - _get_main_op_tensor(meta_graph_def_to_load) or - (_get_legacy_init_op_tensor(meta_graph_def_to_load))) - if main_op_tensor is not None: - sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) - return meta_graph_def_to_load + + def load_graph(self, graph, tags, import_scope=None, **saver_kwargs): + """Load ops and nodes from SavedModel MetaGraph into graph. + + Args: + graph: tf.Graph object. + tags: a set of string tags identifying a MetaGraphDef. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. + + Returns: + Saver defined by the MetaGraph, which can be used to restore the variable + values. + """ + meta_graph_def = self.get_meta_graph_def_from_tags(tags) + with graph.as_default(): + return tf_saver.import_meta_graph( + meta_graph_def, import_scope=import_scope, **saver_kwargs) + + def restore_variables(self, sess, saver, import_scope=None): + """Restore SavedModel variable values into the session. + + Args: + sess: tf.Session to restore variable values. + saver: a tf.train.Saver object. Can be None if there are no variables in + graph. This may be the saver returned by the load_graph() function, or a + default `tf.train.Saver()`. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + + Raises: + ValueError: if no saver was passed to the saver argument, and there are + variables in the graph. + """ + with sess.graph.as_default(): + if (saver is None and + not variables._all_saveable_objects(scope=import_scope)): # pylint: disable=protected-access + tf_logging.info("The specified SavedModel has no variables; no " + "checkpoints were restored.") + elif isinstance(saver, tf_saver.Saver): + saver.restore(sess, self._variables_path) + else: + raise ValueError( + "No tf.train.Saver object was passed to the function " + "SavedModelLoader.restore_variables. Since there are variables in " + "the graph, a saver is required.") + + def run_init_ops(self, sess, tags, import_scope=None): + """Run initialization ops defined in the `MetaGraphDef`. + + Args: + sess: tf.Session to restore variable values. + tags: a set of string tags identifying a MetaGraphDef. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + """ + meta_graph_def = self.get_meta_graph_def_from_tags(tags) + with sess.graph.as_default(): + # Get asset tensors, if any. + asset_tensors_dictionary = _get_asset_tensors( + self._export_dir, meta_graph_def, import_scope=import_scope) + + main_op_tensor = ( + _get_main_op_tensor(meta_graph_def) or + (_get_legacy_init_op_tensor(meta_graph_def))) + if main_op_tensor is not None: + sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) + + def load(self, sess, tags, import_scope=None, **saver_kwargs): + """Load the MetaGraphDef graph and restore variable values into the session. + + Args: + sess: tf.Session to restore variable values. + tags: a set of string tags identifying a MetaGraphDef. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. + + Returns: + `MetagraphDef` proto of the graph that was loaded. + """ + with sess.graph.as_default(): + saver = self.load_graph(sess.graph, tags, import_scope, + **saver_kwargs) + self.restore_variables(sess, saver, import_scope) + self.run_init_ops(sess, tags, import_scope) + return self.get_meta_graph_def_from_tags(tags) diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ce18859f6b9e4c141c4b27f3643c8d4004eb56f6 --- /dev/null +++ b/tensorflow/python/saved_model/loader_test.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SavedModelLoader class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.client import session +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.saved_model import builder as saved_model_builder +from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.saved_model import utils +from tensorflow.python.training import saver as tf_saver + + +def _get_export_dir(label): + return os.path.join(test.get_temp_dir(), label) + +SIMPLE_ADD_SAVED_MODEL = _get_export_dir("simple_add_saved_model") +SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op") + + +class SavedModelLoaderTest(test.TestCase): + + def setUp(self): + """Write test SavedModels to a temp directory.""" + with session.Session(graph=ops.Graph()) as sess: + x = variables.Variable(5, name="x") + y = variables.Variable(11, name="y") + z = x + y + sess.run(variables.global_variables_initializer()) + + foo_sig_def = signature_def_utils.build_signature_def( + {"foo_input": utils.build_tensor_info(x)}, + {"foo_output": utils.build_tensor_info(z)}) + bar_sig_def = signature_def_utils.build_signature_def( + {"bar_x": utils.build_tensor_info(x), + "bar_y": utils.build_tensor_info(y)}, + {"bar_z": utils.build_tensor_info(z)}) + + builder = saved_model_builder.SavedModelBuilder(SIMPLE_ADD_SAVED_MODEL) + builder.add_meta_graph_and_variables( + sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def}) + builder.save() + + # Write SavedModel with a main_op + assign_op = control_flow_ops.group(state_ops.assign(y, 7)) + + builder = saved_model_builder.SavedModelBuilder(SAVED_MODEL_WITH_MAIN_OP) + builder.add_meta_graph_and_variables( + sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def}, + main_op=assign_op) + builder.save() + + def tearDown(self): + file_io.delete_recursively(test.get_temp_dir()) + + def test_load_function(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) + + loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + loader2.load(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) + + def test_load_graph(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + graph = ops.Graph() + loader.load_graph(graph, ["foo_graph"]) + + x = graph.get_tensor_by_name("x:0") + y = graph.get_tensor_by_name("y:0") + + with self.assertRaises(KeyError): + graph.get_tensor_by_name("z:0") + + with self.test_session(graph=graph) as sess: + # Check that x and y are not initialized + with self.assertRaises(errors.FailedPreconditionError): + sess.run(x) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(y) + + def test_load_with_import_scope(self): + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz") + + # The default saver should not work when the import scope is set. + with self.assertRaises(errors.NotFoundError): + loader.restore_variables(sess, tf_saver.Saver()) + + loader.restore_variables(sess, saver) + loader.run_init_ops(sess, ["foo_graph"]) + + self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("baz/y:0").eval()) + + # Test combined load function. + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo_graph"], import_scope="baa") + self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval()) + + def test_restore_variables(self): + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + x = variables.Variable(0, name="x") + y = variables.Variable(0, name="y") + z = x * y + + sess.run(variables.global_variables_initializer()) + + # There are variables to restore, so a saver must be created. + with self.assertRaises(ValueError): + loader.restore_variables(sess, None) + + loader.restore_variables(sess, tf_saver.Saver()) + self.assertEqual(55, z.eval()) + + def test_run_init_op(self): + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + graph = ops.Graph() + saver = loader.load_graph(graph, ["foo_graph"]) + with self.test_session(graph=graph) as sess: + loader.restore_variables(sess, saver) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) + + loader.run_init_ops(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) + + def test_parse_saved_model(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"]) + self.assertIsNotNone(meta_graph) + self.assertIn("foo", meta_graph.signature_def) + self.assertIn("bar", meta_graph.signature_def) + + def test_load_invalid_meta_graph(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + with self.assertRaises(RuntimeError): + loader.get_meta_graph_def_from_tags([]) + with self.assertRaises(RuntimeError): + loader.get_meta_graph_def_from_tags([""]) + with self.assertRaises(RuntimeError): + loader.get_meta_graph_def_from_tags(["not_a_graph"]) + + def test_load_saved_model_with_no_variables(self): + """Test that SavedModel runs saver when there appear to be no variables. + + When no variables are detected, this may mean that the variables were saved + to different collections, or the collections weren't saved to the + SavedModel. If the SavedModel MetaGraphDef contains a saver, it should still + run in either of these cases. + """ + path = _get_export_dir("no_variable_saved_model") + with session.Session(graph=ops.Graph()) as sess: + x = variables.Variable(5, name="x", collections=["not_global_variable"]) + y = variables.Variable(11, name="y", collections=["not_global_variable"]) + self.assertFalse(variables._all_saveable_objects()) + z = x + y + sess.run(variables.variables_initializer([x, y])) + + foo_sig_def = signature_def_utils.build_signature_def( + {"foo_input": utils.build_tensor_info(x)}, + {"foo_output": utils.build_tensor_info(z)}) + + builder = saved_model_builder.SavedModelBuilder(path) + builder.add_meta_graph_and_variables( + sess, ["foo_graph"], {"foo": foo_sig_def}, + saver=tf_saver.Saver([x, y])) + builder.save() + + loader = loader_impl.SavedModelLoader(path) + with self.test_session(graph=ops.Graph()) as sess: + saver = loader.load_graph(sess.graph, ["foo_graph"]) + self.assertFalse(variables._all_saveable_objects()) + self.assertIsNotNone(saver) + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 1b83d60df926f727be2469a2ea1409bc0c59d6c5..fb4732aca21d4661aaea21a472475690687a42be 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -44,6 +44,7 @@ from tensorflow.python.saved_model import main_op from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import saver_test_utils +from tensorflow.python.training import training from tensorflow.python.util import compat SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123") @@ -64,9 +65,12 @@ class SavedModelTest(test.TestCase): self.assertEqual(variable_value, v.eval()) def _build_asset_collection(self, asset_file_name, asset_file_contents, - asset_file_tensor_name): + asset_file_tensor_name, asset_subdir=""): + parent_dir = os.path.join( + compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_subdir)) + file_io.recursive_create_dir(parent_dir) asset_filepath = os.path.join( - compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_file_name)) + compat.as_bytes(parent_dir), compat.as_bytes(asset_file_name)) file_io.write_string_to_file(asset_filepath, asset_file_contents) asset_file_tensor = constant_op.constant( asset_filepath, name=asset_file_tensor_name) @@ -77,10 +81,11 @@ class SavedModelTest(test.TestCase): def _validate_asset_collection(self, export_dir, graph_collection_def, expected_asset_file_name, expected_asset_file_contents, - expected_asset_tensor_name): + expected_asset_tensor_name, + asset_id=0): assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value asset = meta_graph_pb2.AssetFileDef() - assets_any[0].Unpack(asset) + assets_any[asset_id].Unpack(asset) assets_path = os.path.join( compat.as_bytes(export_dir), compat.as_bytes(constants.ASSETS_DIRECTORY), @@ -634,6 +639,141 @@ class SavedModelTest(test.TestCase): compat.as_bytes("ignored.txt")) self.assertFalse(file_io.file_exists(ignored_asset_path)) + def testAssetsNameCollisionDiffFile(self): + export_dir = self._get_export_dir("test_assets_name_collision_diff_file") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar bak", "asset_file_tensor", + asset_subdir="1") + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz", "asset_file_tensor_1", + asset_subdir="2") + + builder.add_meta_graph_and_variables( + sess, ["foo"], assets_collection=asset_collection) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + foo_graph = loader.load(sess, ["foo"], export_dir) + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar bak", + "asset_file_tensor:0") + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt_1", "foo bar baz", + "asset_file_tensor_1:0", + asset_id=1) + + def testAssetsNameCollisionSameFilepath(self): + export_dir = self._get_export_dir("test_assets_name_collision_same_path") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz", "asset_file_tensor") + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz", "asset_file_tensor_1") + + builder.add_meta_graph_and_variables( + sess, ["foo"], assets_collection=asset_collection) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + foo_graph = loader.load(sess, ["foo"], export_dir) + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar baz", + "asset_file_tensor:0") + # The second tensor should be recorded, but the same. + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar baz", + "asset_file_tensor_1:0", + asset_id=1) + ignored_asset_path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes(constants.ASSETS_DIRECTORY), + compat.as_bytes("hello42.txt_1")) + self.assertFalse(file_io.file_exists(ignored_asset_path)) + + def testAssetsNameCollisionSameFile(self): + export_dir = self._get_export_dir("test_assets_name_collision_same_file") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz", "asset_file_tensor", + asset_subdir="1") + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz", "asset_file_tensor_1", + asset_subdir="2") + + builder.add_meta_graph_and_variables( + sess, ["foo"], assets_collection=asset_collection) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + foo_graph = loader.load(sess, ["foo"], export_dir) + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar baz", + "asset_file_tensor:0") + # The second tensor should be recorded, but the same. + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar baz", + "asset_file_tensor_1:0", + asset_id=1) + ignored_asset_path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes(constants.ASSETS_DIRECTORY), + compat.as_bytes("hello42.txt_1")) + self.assertFalse(file_io.file_exists(ignored_asset_path)) + + def testAssetsNameCollisionManyFiles(self): + export_dir = self._get_export_dir("test_assets_name_collision_many_files") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + for i in range(5): + idx = str(i) + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz " + idx, "asset_file_tensor_" + idx, + asset_subdir=idx) + + builder.add_meta_graph_and_variables( + sess, ["foo"], assets_collection=asset_collection) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + foo_graph = loader.load(sess, ["foo"], export_dir) + for i in range(1, 5): + idx = str(i) + self._validate_asset_collection( + export_dir, foo_graph.collection_def, "hello42.txt_" + idx, + "foo bar baz " + idx, "asset_file_tensor_{}:0".format(idx), + asset_id=i) + + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar baz 0", + "asset_file_tensor_0:0") + def testCustomMainOp(self): export_dir = self._get_export_dir("test_main_op") builder = saved_model_builder.SavedModelBuilder(export_dir) @@ -983,6 +1123,133 @@ class SavedModelTest(test.TestCase): self.assertEqual(b"k1", v1.keys().eval()) self.assertEqual(3.0, v1.values().eval()) + def testCustomSaver(self): + export_dir = self._get_export_dir("test_custom_saver") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + variables.Variable(1, name="v1") + sess.run(variables.global_variables_initializer()) + custom_saver = training.Saver(name="my_saver") + builder.add_meta_graph_and_variables(sess, ["tag"], saver=custom_saver) + + # Save the SavedModel to disk. + builder.save() + + with ops.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + saved_graph = loader.load(sess, ["tag"], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue("my_saver/restore_all" in graph_ops) + self.assertFalse("save/restore_all" in graph_ops) + self.assertEqual( + saved_graph.saver_def.restore_op_name, "my_saver/restore_all") + + def testNoCustomSaver(self): + export_dir = self._get_export_dir("test_no_custom_saver") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + variables.Variable(1, name="v1") + sess.run(variables.global_variables_initializer()) + training.Saver(name="my_saver") + builder.add_meta_graph_and_variables(sess, ["tag"]) + + # Save the SavedModel to disk. + builder.save() + + with ops.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + saved_graph = loader.load(sess, ["tag"], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue("my_saver/restore_all" in graph_ops) + self.assertTrue("save/restore_all" in graph_ops) + self.assertEqual( + saved_graph.saver_def.restore_op_name, "save/restore_all") + + def testMultipleCustomSavers(self): + export_dir = self._get_export_dir("test_multiple_custom_savers") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + variables.Variable(1, name="v1") + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph_and_variables(sess, ["tag_0"]) + + saver_1 = training.Saver() + builder.add_meta_graph(["tag_1"], saver=saver_1) + + saver_2 = training.Saver() + builder.add_meta_graph(["tag_2"], saver=saver_2) + + # Save the SavedModel to disk. + builder.save() + + def _validate_custom_saver(tag_name, saver_name): + with ops.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + saved_graph = loader.load(sess, [tag_name], export_dir) + self.assertEqual( + saved_graph.saver_def.restore_op_name, + saver_name) + + _validate_custom_saver("tag_0", "save/restore_all") + _validate_custom_saver("tag_1", "save_1/restore_all") + _validate_custom_saver("tag_2", "save_2/restore_all") + + def testImportScope(self): + export_dir = self._get_export_dir("test_scoped_assets") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + # Build a SavedModel with a variable, an asset, and a constant tensor. + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + asset_collection = self._build_asset_collection("foo.txt", "content_foo", + "asset_file_tensor") + constant_op.constant("constant value", name="constant_tensor_name") + builder.add_meta_graph_and_variables( + sess, ["tag_name"], assets_collection=asset_collection) + + # Save the asset file path for later comparison. + asset_file_path = asset_collection[0].eval() + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + # Restore the SavedModel under an import_scope in a new graph/session. + graph_proto = loader.load( + sess, ["tag_name"], export_dir, import_scope="scope_name") + + # The loaded variable tensor should be scoped, but its contents should be + # unchanged. + self.assertEqual( + "scope_name/v:0", + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].name) + self.assertEqual( + 42, + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) + + # The loaded asset tensor should be scoped, but the asset file path and + # contents should be unchanged. + asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS) + self.assertEqual(1, len(asset_collection)) + self.assertEqual(asset_file_path, asset_collection[0].eval()) + self.assertEqual("scope_name/asset_file_tensor:0", + asset_collection[0].name) + # The static asset data inside graph_proto.collection_def should not be + # scoped. + self._validate_asset_collection(export_dir, graph_proto.collection_def, + "foo.txt", "content_foo", + "asset_file_tensor:0") + + # The constant tensor should be scoped, but its contents should be + # unchanged. + self.assertEqual( + compat.as_bytes("constant value"), + ops.get_default_graph().get_tensor_by_name( + "scope_name/constant_tensor_name:0").eval()) + def testClearDevices(self): export_dir = self._get_export_dir("test_clear_devices") builder = saved_model_builder.SavedModelBuilder(export_dir) diff --git a/tensorflow/python/tools/import_pb_to_tensorboard.py b/tensorflow/python/tools/import_pb_to_tensorboard.py old mode 100755 new mode 100644 diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 5b9d25d449d43d8420e0f30fa8b907d41171d5e5..38fed5335ef39e9832c8b47e3c872ada453aa645 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -15,7 +15,7 @@ """Command-line interface to inspect and execute a graph in a SavedModel. For detailed usages and examples, please refer to: -https://www.tensorflow.org/programmers_guide/saved_model_cli +https://www.tensorflow.org/guide/saved_model_cli """ @@ -720,7 +720,7 @@ def create_parser(): '\'input4_key=[{"id":[26],"weights":[0.5, 0.5]}]\' \\\n' ' --outdir=/out\n\n' 'For more information about input file format, please see:\n' - 'https://www.tensorflow.org/programmers_guide/saved_model_cli\n') + 'https://www.tensorflow.org/guide/saved_model_cli\n') parser_run = subparsers.add_parser( 'run', description=run_msg, formatter_class=argparse.RawTextHelpFormatter) parser_run.add_argument( diff --git a/tensorflow/python/training/adadelta.py b/tensorflow/python/training/adadelta.py index c08e3cca007dc17f1112d53bf729c1accf61b5df..95eca76496992f7ac66643a4c94d7e9e812cecf8 100644 --- a/tensorflow/python/training/adadelta.py +++ b/tensorflow/python/training/adadelta.py @@ -46,6 +46,13 @@ class AdadeltaOptimizer(optimizer.Optimizer): use_locking: If `True` use locks for update operations. name: Optional name prefix for the operations created when applying gradients. Defaults to "Adadelta". + + @compatibility(eager) + When eager execution is enabled, `learning_rate`, `rho`, and `epsilon` can + each be a callable that takes no arguments and returns the actual value to + use. This can be useful for changing these values across different + invocations of optimizer functions. + @end_compatibility """ super(AdadeltaOptimizer, self).__init__(use_locking, name) self._lr = learning_rate @@ -63,9 +70,13 @@ class AdadeltaOptimizer(optimizer.Optimizer): self._zeros_slot(v, "accum_update", self._name) def _prepare(self): - self._lr_t = ops.convert_to_tensor(self._lr, name="lr") - self._rho_t = ops.convert_to_tensor(self._rho, name="rho") - self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon") + lr = self._call_if_callable(self._lr) + rho = self._call_if_callable(self._rho) + epsilon = self._call_if_callable(self._epsilon) + + self._lr_t = ops.convert_to_tensor(lr, name="lr") + self._rho_t = ops.convert_to_tensor(rho, name="rho") + self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") def _apply_dense(self, grad, var): accum = self.get_slot(var, "accum") diff --git a/tensorflow/python/training/adadelta_test.py b/tensorflow/python/training/adadelta_test.py index 50f435236b41fcda7ab5ea37a4e96b72dd1043e7..2678016d24b99b30cbf7021d67e33910051e2561 100644 --- a/tensorflow/python/training/adadelta_test.py +++ b/tensorflow/python/training/adadelta_test.py @@ -20,8 +20,10 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -32,44 +34,52 @@ from tensorflow.python.training import adadelta class AdadeltaOptimizerTest(test.TestCase): - def doTestBasic(self, use_resource=False): + def doTestBasic(self, use_resource=False, use_callable_params=False): num_updates = 4 # number of ADADELTA steps to perform for dtype in [dtypes.half, dtypes.float32]: for grad in [0.2, 0.1, 0.01]: for lr in [1.0, 0.5, 0.1]: - with self.test_session(): - var0_init = [1.0, 2.0] - var1_init = [3.0, 4.0] - if use_resource: - var0 = resource_variable_ops.ResourceVariable( - var0_init, dtype=dtype) - var1 = resource_variable_ops.ResourceVariable( - var1_init, dtype=dtype) - else: - var0 = variables.Variable(var0_init, dtype=dtype) - var1 = variables.Variable(var1_init, dtype=dtype) - - grads = constant_op.constant([grad, grad], dtype=dtype) - - accum = 0.0 - accum_update = 0.0 - - # ADADELTA gradient optimizer - rho = 0.95 - epsilon = 1e-8 - adadelta_opt = adadelta.AdadeltaOptimizer(lr, rho, epsilon) + var0_init = [1.0, 2.0] + var1_init = [3.0, 4.0] + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_init, dtype=dtype) + var1 = resource_variable_ops.ResourceVariable( + var1_init, dtype=dtype) + else: + var0 = variables.Variable(var0_init, dtype=dtype) + var1 = variables.Variable(var1_init, dtype=dtype) + + grads = constant_op.constant([grad, grad], dtype=dtype) + + accum = 0.0 + accum_update = 0.0 + + # ADADELTA gradient optimizer + rho = 0.95 + epsilon = 1e-8 + if use_callable_params: + adadelta_opt = adadelta.AdadeltaOptimizer( + learning_rate=lambda: lr, # pylint: disable=cell-var-from-loop + rho=lambda: rho, # pylint: disable=cell-var-from-loop + epsilon=lambda: epsilon) # pylint: disable=cell-var-from-loop + else: + adadelta_opt = adadelta.AdadeltaOptimizer( + learning_rate=lr, rho=rho, epsilon=epsilon) + if not context.executing_eagerly(): adadelta_update = adadelta_opt.apply_gradients( zip([grads, grads], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + # TODO(lxuechen): This is hard to test in eager mode, + # since the optimizer is not fully initialized until the first + # call to `apply_gradients` opt_vars = adadelta_opt.variables() self.assertStartsWith(opt_vars[0].name, var0._shared_name) self.assertStartsWith(opt_vars[1].name, var0._shared_name) self.assertStartsWith(opt_vars[2].name, var1._shared_name) self.assertStartsWith(opt_vars[3].name, var1._shared_name) self.assertEqual(4, len(opt_vars)) - - variables.global_variables_initializer().run() - # Assign slots slot = [None] * 2 slot_update = [None] * 2 @@ -91,36 +101,42 @@ class AdadeltaOptimizerTest(test.TestCase): self.assertEquals(slot_update[1].get_shape(), var1.get_shape()) self.assertFalse(slot_update[1] in variables.trainable_variables()) - # Fetch params to validate initial values - self.assertAllClose(var0_init, var0.eval()) - self.assertAllClose(var1_init, var1.eval()) - - update = [None] * num_updates - tot_update = 0 - for step in range(num_updates): - # Run adadelta update for comparison - adadelta_update.run() - - # Perform initial update without previous accum values - accum = accum * rho + (grad**2) * (1 - rho) - update[step] = (np.sqrt(accum_update + epsilon) * - (1. / np.sqrt(accum + epsilon)) * grad) - accum_update = (accum_update * rho + (update[step]**2) * - (1.0 - rho)) - tot_update += update[step] * lr + # Fetch params to validate initial values + self.assertAllClose(var0_init, self.evaluate(var0)) + self.assertAllClose(var1_init, self.evaluate(var1)) + update = [None] * num_updates + tot_update = 0 + for step in range(num_updates): + # Run adadelta update for comparison + if not context.executing_eagerly(): + self.evaluate(adadelta_update) + else: + adadelta_opt.apply_gradients(zip([grads, grads], [var0, var1])) + + # Perform initial update without previous accum values + accum = accum * rho + (grad**2) * (1 - rho) + update[step] = ( + np.sqrt(accum_update + epsilon) * + (1. / np.sqrt(accum + epsilon)) * grad) + accum_update = ( + accum_update * rho + (update[step]**2) * (1.0 - rho)) + tot_update += update[step] * lr + + if not context.executing_eagerly(): # Check that the accumulators have been updated + # TODO(lxuechen): This is hard to test in eager mode for slot_idx in range(2): self.assertAllCloseAccordingToType( np.array([accum, accum], dtype=dtype.as_numpy_dtype()), - slot[slot_idx].eval(), + self.evaluate(slot[slot_idx]), rtol=1e-5) self.assertAllCloseAccordingToType( np.array( [accum_update, accum_update], dtype=dtype.as_numpy_dtype()), - slot_update[slot_idx].eval(), + self.evaluate(slot_update[slot_idx]), rtol=1e-5) # Check that the parameters have been updated @@ -128,22 +144,28 @@ class AdadeltaOptimizerTest(test.TestCase): np.array( [var0_init[0] - tot_update, var0_init[1] - tot_update], dtype=dtype.as_numpy_dtype()), - var0.eval(), + self.evaluate(var0), rtol=1e-5) self.assertAllCloseAccordingToType( np.array( [var1_init[0] - tot_update, var1_init[1] - tot_update], dtype=dtype.as_numpy_dtype()), - var1.eval(), + self.evaluate(var1), rtol=1e-5) def testBasic(self): - self.doTestBasic(use_resource=False) + with self.test_session(): + self.doTestBasic(use_resource=False) + @test_util.run_in_graph_and_eager_modes(reset_test=True) def testResourceBasic(self): self.doTestBasic(use_resource=True) + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.test_session(): diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py index deb4e6f546379eff330235dbc302a30c44193830..6778f3c735a70fc32ed299bc9d800b270f06cc66 100644 --- a/tensorflow/python/training/adagrad.py +++ b/tensorflow/python/training/adagrad.py @@ -51,6 +51,13 @@ class AdagradOptimizer(optimizer.Optimizer): Raises: ValueError: If the `initial_accumulator_value` is invalid. + + @compatibility(eager) + When eager execution is enabled, `learning_rate` can be a callable that + takes no arguments and returns the actual value to use. This can be useful + for changing these values across different invocations of optimizer + functions. + @end_compatibility """ if initial_accumulator_value <= 0.0: raise ValueError("initial_accumulator_value must be positive: %s" % @@ -78,8 +85,9 @@ class AdagradOptimizer(optimizer.Optimizer): "accumulator", self._name) def _prepare(self): - self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate, - name="learning_rate") + learning_rate = self._call_if_callable(self._learning_rate) + self._learning_rate_tensor = ops.convert_to_tensor( + learning_rate, name="learning_rate") def _apply_dense(self, grad, var): acc = self.get_slot(var, "accumulator") diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py index 15b007b46dea6b3125c5f7bffe8782594bb23692..c9aec33d0916781e3d1a41b996083da92a4ae839 100644 --- a/tensorflow/python/training/adagrad_test.py +++ b/tensorflow/python/training/adagrad_test.py @@ -20,9 +20,11 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -34,40 +36,63 @@ from tensorflow.python.training import adagrad class AdagradOptimizerTest(test.TestCase): - def doTestBasic(self, use_locking=False, use_resource=False): + def doTestBasic(self, + use_locking=False, + use_resource=False, + use_callable_params=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - if use_resource: - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - else: - var0 = variables.Variable([1.0, 2.0], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - ada_opt = adagrad.AdagradOptimizer( - 3.0, initial_accumulator_value=0.1, use_locking=use_locking) + if use_resource: + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + else: + var0 = variables.Variable([1.0, 2.0], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + + learning_rate = lambda: 3.0 + if not use_callable_params: + learning_rate = learning_rate() + + ada_opt = adagrad.AdagradOptimizer( + learning_rate, initial_accumulator_value=0.1, use_locking=use_locking) + + if not context.executing_eagerly(): ada_update = ada_opt.apply_gradients( zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - # Run 3 steps of adagrad - for _ in range(3): - ada_update.run() - # Validate updated params - self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval()) + self.evaluate(variables.global_variables_initializer()) + + # Fetch params to validate initial values + v0_val, v1_val = self.evaluate([var0, var1]) + self.assertAllClose([1.0, 2.0], v0_val) + self.assertAllClose([3.0, 4.0], v1_val) + + # Run 3 steps of adagrad + for _ in range(3): + if not context.executing_eagerly(): + self.evaluate(ada_update) + else: + ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + # Validate updated params + v0_val, v1_val = self.evaluate([var0, var1]) + self.assertAllCloseAccordingToType( + np.array([-1.6026098728179932, -0.6026098728179932]), v0_val) + self.assertAllCloseAccordingToType( + np.array([2.715679168701172, 3.715679168701172]), v1_val) def testBasic(self): self.doTestBasic(use_locking=False) + @test_util.run_in_graph_and_eager_modes(reset_test=True) def testBasicResource(self): self.doTestBasic(use_locking=False, use_resource=True) + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic( + use_locking=False, use_resource=True, use_callable_params=True) + def testBasicLocked(self): self.doTestBasic(use_locking=True) diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py index 6fa3ff66583ce07a6ee7b0d8158c851ea578637c..b65c88e972454da14dc5161a19cd26280d51d28f 100644 --- a/tensorflow/python/training/adam.py +++ b/tensorflow/python/training/adam.py @@ -85,6 +85,13 @@ class AdamOptimizer(optimizer.Optimizer): use_locking: If True use locks for update operations. name: Optional name for the operations created when applying gradients. Defaults to "Adam". + + @compatibility(eager) + When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and + `epsilon` can each be a callable that takes no arguments and returns the + actual value to use. This can be useful for changing these values across + different invocations of optimizer functions. + @end_compatibility """ super(AdamOptimizer, self).__init__(use_locking, name) self._lr = learning_rate @@ -128,10 +135,15 @@ class AdamOptimizer(optimizer.Optimizer): self._zeros_slot(v, "v", self._name) def _prepare(self): - self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate") - self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1") - self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2") - self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon") + lr = self._call_if_callable(self._lr) + beta1 = self._call_if_callable(self._beta1) + beta2 = self._call_if_callable(self._beta2) + epsilon = self._call_if_callable(self._epsilon) + + self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") + self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") + self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") + self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") def _apply_dense(self, grad, var): m = self.get_slot(var, "m") diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py index 9be8b6aafefa33977511cde24dd2e87dd6c3b81a..ccdc7e384da2ae792a681298c7076fc582d362df 100644 --- a/tensorflow/python/training/adam_test.py +++ b/tensorflow/python/training/adam_test.py @@ -150,7 +150,7 @@ class AdamOptimizerTest(test.TestCase): self.assertAllClose(aggregated_update_var.eval(), repeated_index_update_var.eval()) - def doTestBasic(self, use_resource=False): + def doTestBasic(self, use_resource=False, use_callable_params=False): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): with self.test_session(graph=ops.Graph()): # Initialize variables for numpy implementation. @@ -171,7 +171,17 @@ class AdamOptimizerTest(test.TestCase): grads0 = constant_op.constant(grads0_np) grads1 = constant_op.constant(grads1_np) - opt = adam.AdamOptimizer() + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = adam.AdamOptimizer(learning_rate=learning_rate) update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) opt_variables = opt.variables() beta1_power, beta2_power = opt._get_beta_accumulators() @@ -180,11 +190,10 @@ class AdamOptimizerTest(test.TestCase): self.assertIn(beta1_power, opt_variables) self.assertIn(beta2_power, opt_variables) - with ops.Graph().as_default(): - # Shouldn't return non-slot variables from other graphs. - self.assertEqual(0, len(opt.variables())) - if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values self.assertAllClose([1.0, 2.0], self.evaluate(var0)) @@ -222,6 +231,10 @@ class AdamOptimizerTest(test.TestCase): def testResourceBasic(self): self.doTestBasic(use_resource=True) + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.test_session(): diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index df528d54d6503ef1626d0b3bc4b5afe9e0616c31..b0dd188db14a46aae44f8150095cf9ed337ee8a7 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -28,6 +28,7 @@ from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 from tensorflow.core.util.event_pb2 import SessionLog from tensorflow.python.client import timeline +from tensorflow.python.framework import errors from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.platform import gfile @@ -336,6 +337,8 @@ class CheckpointSaverListener(object): def after_save(self, session, global_step_value): print('Done writing checkpoint.') + if decided_to_stop_training(): + return True def end(self, session, global_step_value): print('Done with the session.') @@ -354,6 +357,11 @@ class CheckpointSaverListener(object): implementors should implement the `end()` method to handle actions related to the last checkpoint save. But the listener should not act twice if `after_save()` already handled this last checkpoint save. + + A `CheckpointSaverListener` can request training to be stopped, by returning + True in `after_save`. Please note that, in replicated distributed training + setting, only `chief` should use this behavior. Otherwise each worker will do + their own evaluation, which may be wasteful of resources. """ def begin(self): @@ -453,7 +461,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): global_step = run_context.session.run(self._global_step_tensor) if self._timer.should_trigger_for_step(global_step): self._timer.update_last_triggered_step(global_step) - self._save(run_context.session, global_step) + if self._save(run_context.session, global_step): + run_context.request_stop() def end(self, session): last_step = session.run(self._global_step_tensor) @@ -463,7 +472,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): l.end(session, last_step) def _save(self, session, step): - """Saves the latest checkpoint.""" + """Saves the latest checkpoint, returns should_stop.""" logging.info("Saving checkpoints for %d into %s.", step, self._save_path) for l in self._listeners: @@ -475,8 +484,14 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), step) + should_stop = False for l in self._listeners: - l.after_save(session, step) + if l.after_save(session, step): + logging.info( + "A CheckpointSaverListener requested that training be stopped. " + "listener: {}".format(l)) + should_stop = True + return should_stop def _get_saver(self): if self._saver is not None: @@ -804,8 +819,25 @@ class FinalOpsHook(session_run_hook.SessionRunHook): def end(self, session): if self._final_ops is not None: - self._final_ops_values = session.run(self._final_ops, - feed_dict=self._final_ops_feed_dict) + try: + self._final_ops_values = session.run( + self._final_ops, feed_dict=self._final_ops_feed_dict) + except (errors.OutOfRangeError, StopIteration) as e: + logging.warning( + "An OutOfRangeError or StopIteration exception is raised by the " + "code in FinalOpsHook. This typically means the Ops running by the " + "FinalOpsHook have a dependency back to some input source, which " + "should not happen. For example, for metrics in " + "tf.estimator.Estimator, all metrics functions return two Ops: " + "`value_op` and `update_op`. Estimator.evaluate calls the " + "`update_op` for each batch of the data in input source and, once " + "it is exhausted, it call the `value_op` to get the metric values. " + "The `value_op` here should have dependency back to variables " + "reading only, rather than reading another batch from input. " + "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers " + "another data reading, which ends OutOfRangeError/StopIteration. " + "Please fix that.") + raise e @tf_export("train.FeedFnHook") diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index 7344ce2758658e3a39add365dd7f089e693c4c4a..b49a871a56a3402fc46a0185d02e6099b6e69f79 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -29,8 +29,10 @@ from tensorflow.contrib.framework.python.framework import checkpoint_utils from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.testing.python.framework import fake_summary_writer from tensorflow.python.client import session as session_lib +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -58,6 +60,7 @@ class MockCheckpointSaverListener( self.before_save_count = 0 self.after_save_count = 0 self.end_count = 0 + self.ask_for_stop = False def begin(self): self.begin_count += 1 @@ -67,6 +70,8 @@ class MockCheckpointSaverListener( def after_save(self, session, global_step): self.after_save_count += 1 + if self.ask_for_stop: + return True def end(self, session, global_step): self.end_count += 1 @@ -471,6 +476,25 @@ class CheckpointSaverHookTest(test.TestCase): 'end': 1 }, listener_counts) + def test_listener_stops_training_in_after_save(self): + with ops.Graph().as_default(): + scaffold = monitored_session.Scaffold() + variables.get_or_create_global_step() + train_op = training_util._increment_global_step(1) + listener = MockCheckpointSaverListener() + hook = basic_session_run_hooks.CheckpointSaverHook( + self.model_dir, save_steps=1, scaffold=scaffold, listeners=[listener]) + with monitored_session.SingularMonitoredSession( + hooks=[hook], scaffold=scaffold, + checkpoint_dir=self.model_dir) as sess: + sess.run(train_op) + self.assertFalse(sess.should_stop()) + sess.run(train_op) + self.assertFalse(sess.should_stop()) + listener.ask_for_stop = True + sess.run(train_op) + self.assertTrue(sess.should_stop()) + def test_listener_with_default_saver(self): with ops.Graph().as_default(): global_step = variables.get_or_create_global_step() @@ -1306,6 +1330,26 @@ class FinalOpsHookTest(test.TestCase): self.assertListEqual(expected_values, hook.final_ops_values.tolist()) + def test_final_ops_triggers_out_of_range_error(self): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.range(1) + iterator = dataset.make_one_shot_iterator() + read_ops = iterator.get_next() + final_ops = read_ops + + hook = basic_session_run_hooks.FinalOpsHook(final_ops) + hook.begin() + + with session_lib.Session() as session: + session.run(read_ops) + with test.mock.patch.object(tf_logging, 'warning') as mock_log: + with self.assertRaisesRegexp(errors.OutOfRangeError, + 'End of sequence'): + hook.end(session) + self.assertRegexpMatches( + str(mock_log.call_args), + 'dependency back to some input source') + def test_final_ops_with_dictionary(self): with ops.Graph().as_default(): expected_values = [4, -3] diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index e7f88de1d2290a49f3b7bdf47417016d7e7c9cea..5b372e82b3f637b78db4388b58b8d04a838fbe60 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -147,7 +147,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): partitioner=lambda shape, dtype: [5, 1]) # Initialize all variables in `new_scope_1` from `old_scope_1`. - init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/', 'new_scope_1'}) + init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1'}) # Use names to specify which variables to initialize from checkpoint. init_from_checkpoint('/tmp/model.ckpt', @@ -219,8 +219,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): else: var_name = ",".join([v.name for v in var]) _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt) - logging.info("Initialize variable %s from checkpoint %s with %s", - var_name, ckpt_dir_or_file, tensor_name_in_ckpt) + logging.debug("Initialize variable %s from checkpoint %s with %s", + var_name, ckpt_dir_or_file, tensor_name_in_ckpt) else: scopes = "" # TODO(vihanjain): Support list of 'current_var_or_name' here. @@ -261,8 +261,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): if var is None: var = _collect_partitioned_variable(var_name, store_vars) _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name) - logging.info("Initialize variable %s from checkpoint %s with %s", - var_name, ckpt_dir_or_file, full_tensor_name) + logging.debug("Initialize variable %s from checkpoint %s with %s", + var_name, ckpt_dir_or_file, full_tensor_name) def _get_checkpoint_filename(ckpt_dir_or_file): diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD index a7ae6e50a9975d8d164d5c7455aedb8b39ed802e..35007653a09f4b4990be19ef6b14bf6084a7f14c 100644 --- a/tensorflow/python/training/checkpointable/BUILD +++ b/tensorflow/python/training/checkpointable/BUILD @@ -22,8 +22,9 @@ py_library( "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:io_ops_gen", - "//tensorflow/python:ops", + "//tensorflow/python:platform", "//tensorflow/python:saveable_object", "//tensorflow/python:util", "//tensorflow/python/eager:context", @@ -40,12 +41,67 @@ py_test( ], ) +py_library( + name = "tracking", + srcs = ["tracking.py"], + srcs_version = "PY2AND3", + deps = [ + ":base", + ":data_structures", + ], +) + +py_test( + name = "tracking_test", + srcs = ["tracking_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":base", + ":tracking", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "layer_utils", + srcs = ["layer_utils.py"], + srcs_version = "PY2AND3", +) + +py_library( + name = "data_structures", + srcs = ["data_structures.py"], + srcs_version = "PY2AND3", + deps = [ + ":base", + ":layer_utils", + ], +) + +py_test( + name = "data_structures_test", + srcs = ["data_structures_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":data_structures", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:layers", + ], +) + py_library( name = "util", srcs = ["util.py"], srcs_version = "PY2AND3", deps = [ ":base", + ":tracking", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py index e378f0e898c88d6f3e4fa22e93ca2165f2325660..ee35b01328436911fd7926b25b14433377ec4188 100644 --- a/tensorflow/python/training/checkpointable/base.py +++ b/tensorflow/python/training/checkpointable/base.py @@ -33,6 +33,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saveable_object from tensorflow.python.util import nest from tensorflow.python.util import serialization +from tensorflow.python.util import tf_decorator # Key where the object graph proto is saved in a TensorBundle @@ -340,6 +341,34 @@ _SlotVariableRestoration = collections.namedtuple( ]) +def no_automatic_dependency_tracking(method): + """Disables automatic dependency tracking on attribute assignment. + + Use to decorate any method of a Checkpointable object. Attribute assignment in + that method will not add dependencies (also respected in Model). Harmless if + used in a class which does not do automatic dependency tracking (which means + it's safe to use in base classes which may have subclasses which also inherit + from Checkpointable). + + Args: + method: The method to decorate. + Returns: + A decorated method which sets and un-sets automatic dependency tracking for + the object the method is called on (not thread safe). + """ + + def _method_wrapper(self, *args, **kwargs): + previous_value = getattr(self, "_setattr_tracking", True) + self._setattr_tracking = False # pylint: disable=protected-access + try: + method(self, *args, **kwargs) + finally: + self._setattr_tracking = previous_value # pylint: disable=protected-access + + return tf_decorator.make_decorator( + target=method, decorator_func=_method_wrapper) + + class CheckpointableBase(object): """Base class for `Checkpointable` objects without automatic dependencies. @@ -349,6 +378,11 @@ class CheckpointableBase(object): checks. """ + # CheckpointableBase does not do automatic dependency tracking, but uses the + # no_automatic_dependency_tracking decorator so it can avoid adding + # dependencies if a subclass is Checkpointable / inherits from Model (both of + # which have __setattr__ overrides). + @no_automatic_dependency_tracking def _maybe_initialize_checkpointable(self): """Initialize dependency management. @@ -386,6 +420,10 @@ class CheckpointableBase(object): # building. self._name_based_restores = set() + def _no_dependency(self, value): + """If automatic dependency tracking is enabled, ignores `value`.""" + return value + def _name_based_attribute_restore(self, checkpoint): """Restore the object's attributes from a name-based checkpoint.""" self._name_based_restores.add(checkpoint) @@ -463,7 +501,7 @@ class CheckpointableBase(object): ValueError: If the variable name is not unique. """ self._maybe_initialize_checkpointable() - if not overwrite and self._lookup_dependency(name) is not None: + if overwrite and self._lookup_dependency(name) is not None: raise ValueError( ("A variable named '%s' already exists in this Checkpointable, but " "Checkpointable._add_variable called to create another with " @@ -591,11 +629,11 @@ class CheckpointableBase(object): self._unconditional_checkpoint_dependencies): if name == old_name: self._unconditional_checkpoint_dependencies[index] = new_reference - else: + elif current_object is None: self._unconditional_checkpoint_dependencies.append(new_reference) - + self._handle_deferred_dependencies( + name=name, checkpointable=checkpointable) self._unconditional_dependency_names[name] = checkpointable - self._handle_deferred_dependencies(name=name, checkpointable=checkpointable) return checkpointable def _handle_deferred_dependencies(self, name, checkpointable): @@ -733,86 +771,3 @@ class CheckpointableBase(object): return {OBJECT_CONFIG_JSON_KEY: functools.partial( PythonStringStateSaveable, state_callback=_state_callback)} - - -class NoDependency(object): - """Allows attribute assignment to `Checkpointable` objects with no dependency. - - Example usage: - ```python - obj = Checkpointable() - obj.has_dependency = tf.Variable(0., name="dep") - obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) - assert obj.no_dependency.name == "nodep:0" - ``` - - `obj` in this example has a dependency on the variable "dep", and both - attributes contain un-wrapped `Variable` objects. - - `NoDependency` also works with `tf.keras.Model`, but only for checkpoint - dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) - `Layer` to the attribute without a checkpoint dependency, but the `Model` will - still track the `Layer` (so it will appear in `Model.layers`, and its - variables will appear in `Model.variables`). - """ - - def __init__(self, value): - self.value = value - - -class NotCheckpointable(object): - """Marks instances of child classes as unsaveable using an object-based API. - - Useful for marking objects which would otherwise look checkpointable because - of inheritance (e.g. through `Layer`) as not checkpointable. Inheriting from - `NotCheckpointable` does not prevent an object from being assigned to any - attributes, but will throw an error on save/restore. - """ - pass - - -class Checkpointable(CheckpointableBase): - """Manages dependencies on other objects. - - `Checkpointable` objects may have dependencies: other `Checkpointable` objects - which should be saved if the object declaring the dependency is saved. A - correctly saveable program has a dependency graph such that if changing a - global variable affects an object (e.g. changes the behavior of any of its - methods) then there is a chain of dependencies from the influenced object to - the variable. - - Dependency edges have names, and are created implicitly when a - `Checkpointable` object is assigned to an attribute of another - `Checkpointable` object. For example: - - ``` - obj = Checkpointable() - obj.v = ResourceVariable(0.) - ``` - - The `Checkpointable` object `obj` now has a dependency named "v" on a - variable. - - `Checkpointable` objects may specify `Tensor`s to be saved and restored - directly (e.g. a `Variable` indicating how to save itself) rather than through - dependencies on other objects. See - `Checkpointable._gather_saveables_for_checkpoint` for details. - """ - - def __setattr__(self, name, value): - """Support self.foo = checkpointable syntax.""" - # Perform the attribute assignment, and potentially call other __setattr__ - # overrides such as that for tf.keras.Model. - no_dependency = isinstance(value, NoDependency) - if no_dependency: - value = value.value - super(Checkpointable, self).__setattr__(name, value) - if not no_dependency and isinstance(value, CheckpointableBase): - self._track_checkpointable( - value, name=name, - # Allow the user to switch the Checkpointable which is tracked by this - # name, since assigning a new variable to an attribute has - # historically been fine (e.g. Adam did this). - # TODO(allenl): Should this be a warning once Checkpointable save/load - # is usable? - overwrite=True) diff --git a/tensorflow/python/training/checkpointable/base_test.py b/tensorflow/python/training/checkpointable/base_test.py index 0a274cdfed5af83a69513e9b26bf427f284a4df7..950e9c5b535a8314e1068b772f48a14b572df691 100644 --- a/tensorflow/python/training/checkpointable/base_test.py +++ b/tensorflow/python/training/checkpointable/base_test.py @@ -17,33 +17,25 @@ from __future__ import division from __future__ import print_function from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import base class InterfaceTests(test.TestCase): - def testMultipleAssignment(self): - root = checkpointable.Checkpointable() - root.leaf = checkpointable.Checkpointable() - root.leaf = root.leaf - duplicate_name_dep = checkpointable.Checkpointable() + def testOverwrite(self): + root = base.CheckpointableBase() + leaf = base.CheckpointableBase() + root._track_checkpointable(leaf, name="leaf") + (current_name, current_dependency), = root._checkpoint_dependencies + self.assertIs(leaf, current_dependency) + self.assertEqual("leaf", current_name) + duplicate_name_dep = base.CheckpointableBase() with self.assertRaises(ValueError): root._track_checkpointable(duplicate_name_dep, name="leaf") - # No error; we're overriding __setattr__, so we can't really stop people - # from doing this while maintaining backward compatibility. - root.leaf = duplicate_name_dep root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True) - - def testNoDependency(self): - root = checkpointable.Checkpointable() - hasdep = checkpointable.Checkpointable() - root.hasdep = hasdep - nodep = checkpointable.Checkpointable() - root.nodep = checkpointable.NoDependency(nodep) - self.assertEqual(1, len(root._checkpoint_dependencies)) - self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep) - self.assertIs(root.hasdep, hasdep) - self.assertIs(root.nodep, nodep) + (current_name, current_dependency), = root._checkpoint_dependencies + self.assertIs(duplicate_name_dep, current_dependency) + self.assertEqual("leaf", current_name) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py new file mode 100644 index 0000000000000000000000000000000000000000..019d43f09c10a4975a9b483593af30b5bbe06089 --- /dev/null +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -0,0 +1,478 @@ +"""Checkpointable data structures.""" +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import six + +from tensorflow.python.ops import variables +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import layer_utils + + +class NoDependency(object): + """Allows attribute assignment to `Checkpointable` objects with no dependency. + + Example usage: + ```python + obj = Checkpointable() + obj.has_dependency = tf.Variable(0., name="dep") + obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) + assert obj.no_dependency.name == "nodep:0" + ``` + + `obj` in this example has a dependency on the variable "dep", and both + attributes contain un-wrapped `Variable` objects. + + `NoDependency` also works with `tf.keras.Model`, but only for checkpoint + dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) + `Layer` to the attribute without a checkpoint dependency, but the `Model` will + still track the `Layer` (so it will appear in `Model.layers`, and its + variables will appear in `Model.variables`). + """ + + def __init__(self, value): + self.value = value + + +def _wrap_or_unwrap(value): + """Wraps basic data structures, unwraps NoDependency objects.""" + if isinstance(value, NoDependency): + return value.value + if isinstance(value, base.CheckpointableBase): + return value # Skip conversion for already checkpointable objects. + elif isinstance(value, list): + return _ListWrapper(value) + else: + return value + # TODO(allenl): Handle other common data structures. Tuples will require + # special casing (tuple subclasses are not weak referenceable, so replacement + # with a wrapper that subclasses tuple on attribute assignment works poorly, + # and replacement with a wrapper that isn't a tuple is also problematic), + # probably a tree traversal where the leaves are non-tuples(/namedtuples) to + # come up with names. Dictionaries should look like lists. + + +def sticky_attribute_assignment(checkpointable, name, value): + """Adds dependencies, generally called from __setattr__. + + This behavior is shared between Checkpointable and Model. + + Respects NoDependency indicators, but otherwise makes checkpointable objects + out of common data structures and tracks objects by their attribute names. + + Args: + checkpointable: The object to add dependencies to (generally the one having + an attribute assigned). + name: The attribute name being assigned. + value: The value being assigned. Not necessarily a checkpointable object. + + Returns: + The value which should be stored in the attribute (unwrapped from a + NoDependency object if necessary). + """ + if isinstance(value, NoDependency): + add_dependency = False + else: + add_dependency = True + value = _wrap_or_unwrap(value) + if not add_dependency: + return value + if isinstance(value, base.CheckpointableBase): + checkpointable._track_checkpointable( # pylint: disable=protected-access + value, name=name, + # Allow the user to switch the Checkpointable which is tracked by this + # name, since assigning a new variable to an attribute has + # historically been fine (e.g. Adam did this). + overwrite=True) + return value + + +class CheckpointableDataStructure(base.CheckpointableBase): + """Base class for data structures which contain checkpointable objects.""" + + def __init__(self): + # An append-only ordered set + self._layers = [] + + self.trainable = True + self._extra_variables = [] + + def _track_value(self, value, name): + """Add a dependency on `value`.""" + value = sticky_attribute_assignment( + checkpointable=self, value=value, name=name) + if isinstance(value, variables.Variable): + self._extra_variables.append(value) + if not isinstance(value, base.CheckpointableBase): + raise ValueError( + ("Only checkpointable objects (such as Layers or Optimizers) may be " + "stored in a List object. Got %s, which does not inherit from " + "CheckpointableBase.") % (value,)) + if (isinstance(value, CheckpointableDataStructure) + or layer_utils.is_layer(value)): + # Check for object-identity rather than with __eq__ to avoid + # de-duplicating empty container types. Automatically generated list + # wrappers keep things like "[] == []" true, which means "[] in [[]]" is + # also true. This becomes not true once one of the lists is mutated. + if not any((layer is value for layer in self._layers)): + self._layers.append(value) + if hasattr(value, "_use_resource_variables"): + # In subclassed models, legacy layers (tf.layers) must always use + # resource variables. + value._use_resource_variables = True # pylint: disable=protected-access + return value + + @property + def layers(self): + return layer_utils.filter_empty_layer_containers(self._layers) + + @property + def trainable_weights(self): + return layer_utils.gather_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) + + @property + def non_trainable_weights(self): + return layer_utils.gather_non_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) + + @property + def weights(self): + return self.trainable_weights + self.non_trainable_weights + + @property + def trainable_variables(self): + return self.trainable_weights + + @property + def non_trainable_variables(self): + return self.non_trainable_weights + + @property + def variables(self): + return self.weights + + @property + def updates(self): + """Aggregate updates from any `Layer` instances.""" + # Updates and conditional losses are forwarded as-is rather than being + # filtered based on inputs, since this is just a container and won't ever + # have any inputs. + aggregated = [] + for layer in self.layers: + aggregated += layer.updates + return aggregated + + @property + def losses(self): + """Aggregate losses from any `Layer` instances.""" + aggregated = [] + for layer in self.layers: + aggregated += layer.losses + return aggregated + + def __hash__(self): + # Support object-identity hashing, so these structures can be used as keys + # in sets/dicts. + return id(self) + + def __eq__(self, other): + # Similar to Tensors, checkpointable data structures use object-identity + # equality to support set/dict membership. + return self is other + + +class List(CheckpointableDataStructure, collections.Sequence): + """An append-only sequence type which is checkpointable. + + Maintains checkpoint dependencies on its contents (which must also be + checkpointable), and forwards any `Layer` metadata such as updates and losses. + + Note that `List` is purely a container. It lets a `tf.keras.Model` or + other checkpointable object know about its contents, but does not call any + `Layer` instances which are added to it. To indicate a sequence of `Layer` + instances which should be called sequentially, use `tf.keras.Sequential`. + + Example usage: + ```python + class HasList(tf.keras.Model): + + def __init__(self): + super(HasList, self).__init__() + self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)]) + self.layer_list.append(layers.Dense(4)) + + def call(self, x): + aggregation = 0. + for l in self.layer_list: + x = l(x) + aggregation += tf.reduce_sum(x) + return aggregation + ``` + + This kind of wrapping is necessary because `Checkpointable` objects do not + (yet) deeply inspect regular Python data structures, so for example assigning + a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a + checkpoint dependency and does not add the `Layer` instance's weights to its + parent `Model`. + """ + + def __init__(self, *args, **kwargs): + """Construct a new sequence. Arguments are passed to `list()`.""" + super(List, self).__init__() + self._storage = self._make_storage(*args, **kwargs) + for index, element in enumerate(self._storage): + self._storage[index] = self._track_value( + element, name=self._name_element(index)) + + def _make_storage(self, *args, **kwargs): + """Determines the backing storage (overridden in subclasses).""" + return list(*args, **kwargs) + + def _name_element(self, index): + return "%d" % (index,) + + def append(self, value): + """Add a new checkpointable value.""" + value = self._track_value(value, self._name_element(len(self._storage))) + self._storage.append(value) + + def extend(self, values): + """Add a sequence of checkpointable values.""" + for value in values: + self._storage.append(self._track_value( + value, name=self._name_element(len(self._storage)))) + + def __iadd__(self, values): + self.extend(values) + return self + + def __add__(self, other): + if isinstance(other, List): + return self.__class__(self._storage + other._storage) # pylint: disable=protected-access + else: + return self.__class__(self._storage + other) + + def __radd__(self, other): + return self + other + + def __getitem__(self, key): + return self._storage[key] + + def __len__(self): + return len(self._storage) + + def __repr__(self): + return "List(%s)" % (repr(self._storage),) + + +class _ListWrapper(List, collections.MutableSequence, + # Shadowed, but there for isinstance checks. + list): + """Wraps the built-in `list` to support restore-on-create for variables. + + Unlike `List`, this sequence type is mutable in the same ways built-in lists + are. Instead of throwing an error immediately like `List`, it records + problematic mutations (e.g. assigning a new element to a position already + occupied, meaning both elements get the same names at different times) and + refuses to save. + + On assignment to an attribute of a Model or Checkpointable object, Python + lists are replaced with _ListWrapper. Wrapping a list in a + `tf.contrib.checkpoint.NoDependency` object prevents this. + """ + + def __init__(self, wrapped_list): + """Construct a new list wrapper. + + Args: + wrapped_list: The initial value of the data structure. A shallow copy may + be maintained for error checking. `wrapped_list` itself should not be + modified directly after constructing the `_ListWrapper`, and if changes + are detected the `_ListWrapper` will throw an exception on save. + """ + # Monotonic flags which indicate this object would not be restored properly, + # and therefore should throw an error on save to avoid giving the impression + # that restoring it will work. + self._non_append_mutation = False + self._external_modification = False + super(_ListWrapper, self).__init__(wrapped_list) + self._last_wrapped_list_snapshot = list(self._storage) + + def _make_storage(self, wrapped_list): + """Use the user's original list for storage.""" + return wrapped_list + + def _check_external_modification(self): + """Checks for any changes to the wrapped list not through the wrapper.""" + if self._external_modification or self._non_append_mutation: + return + if self._storage != self._last_wrapped_list_snapshot: + self._external_modification = True + self._last_wrapped_list_snapshot = None + + def _update_snapshot(self): + """Acknowledges tracked changes to the wrapped list.""" + if self._external_modification or self._non_append_mutation: + return + self._last_wrapped_list_snapshot = list(self._storage) + + @property + def _checkpoint_dependencies(self): + self._check_external_modification() + if self._non_append_mutation: + raise ValueError( + ("Unable to save the object %s (a list wrapper constructed to track " + "checkpointable TensorFlow objects). A list element was replaced " + "(__setitem__), deleted, or inserted. In order to support " + "restoration on object creation, tracking is exclusively for " + "append-only data structures.\n\nIf you don't need this list " + "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency " + "object; it will be automatically un-wrapped and subsequently " + "ignored." % (self,))) + if self._external_modification: + raise ValueError( + ("Unable to save the object %s (a list wrapper constructed to track " + "checkpointable TensorFlow objects). The wrapped list was modified " + "outside the wrapper (its final value was %s, its value when a " + "checkpoint dependency was added was %s), which breaks restoration " + "on object creation.\n\nIf you don't need this list checkpointed, " + "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be " + "automatically un-wrapped and subsequently ignored." % ( + self, self._storage, self._last_wrapped_list_snapshot))) + return super(_ListWrapper, self)._checkpoint_dependencies + + def __delitem__(self, key): + self._non_append_mutation = True + del self._storage[key] + + def __setitem__(self, key, value): + self._non_append_mutation = True + self._storage[key] = value + + def append(self, value): + """Add a new checkpointable value.""" + self._check_external_modification() + super(_ListWrapper, self).append(value) + self._update_snapshot() + + def extend(self, values): + """Add a sequence of checkpointable values.""" + self._check_external_modification() + super(_ListWrapper, self).extend(values) + self._update_snapshot() + + def __eq__(self, other): + return self._storage == getattr(other, "_storage", other) + + def __ne__(self, other): + return self._storage != getattr(other, "_storage", other) + + def __lt__(self, other): + return self._storage < getattr(other, "_storage", other) + + def __le__(self, other): + return self._storage <= getattr(other, "_storage", other) + + def __gt__(self, other): + return self._storage > getattr(other, "_storage", other) + + def __ge__(self, other): + return self._storage >= getattr(other, "_storage", other) + + def __hash__(self): + # List wrappers need to compare like regular lists, and so like regular + # lists they don't belong in hash tables. + raise TypeError("unhashable type: 'ListWrapper'") + + def insert(self, index, obj): + self._non_append_mutation = True + self._storage.insert(index, obj) + + def _track_value(self, value, name): + """Allows storage of non-checkpointable objects.""" + try: + value = super(_ListWrapper, self)._track_value(value=value, name=name) + except ValueError: + # Even if this value isn't checkpointable, we need to make sure + # NoDependency objects get unwrapped. + value = sticky_attribute_assignment( + checkpointable=self, value=value, name=name) + return value + + def __repr__(self): + return "ListWrapper(%s)" % (repr(self._storage),) + + +class Mapping(CheckpointableDataStructure, collections.Mapping): + """An append-only checkpointable mapping data structure with string keys. + + Maintains checkpoint dependencies on its contents (which must also be + checkpointable), named based on its keys. + + Note that once a key has been added, it may not be deleted or replaced. If + names may not be unique, see `tf.contrib.checkpoint.UniqueNameTracker`. + """ + + def __init__(self, *args, **kwargs): + """Construct a new sequence. Arguments are passed to `dict()`.""" + super(Mapping, self).__init__() + self._storage = dict(*args, **kwargs) + self._storage.update( + {key: self._track_value( + value, name=self._name_element(key)) + for key, value in self._storage.items()}) + + def _name_element(self, key): + if not isinstance(key, six.string_types): + raise TypeError( + "Mapping accepts only string keys, but got a key %s." + % repr(key)) + return str(key) + + def __setitem__(self, key, value): + name = self._name_element(key) + value = self._track_value(value, name=name) + current_value = self._storage.setdefault(key, value) + if current_value is not value: + raise ValueError( + ("Mappings are an append-only data structure. Tried to overwrite the " + "key '%s' with value %s, but it already contains %s") + % (key, value, current_value)) + + def update(self, *args, **kwargs): + for key, value in dict(*args, **kwargs).items(): + self[key] = value + + def __getitem__(self, key): + return self._storage[key] + + def __len__(self): + return len(self._storage) + + def __repr__(self): + return "Mapping(%s)" % (repr(self._storage),) + + def __iter__(self): + return iter(self._storage) diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ec8c9da8090c968e8931f96949f5b982dd94f215 --- /dev/null +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -0,0 +1,303 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy + +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core +from tensorflow.python.keras.layers import normalization +from tensorflow.python.layers import core as non_keras_core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import tracking + + +class HasList(training.Model): + + def __init__(self): + super(HasList, self).__init__() + self.layer_list = data_structures.List([core.Dense(3)]) + self.layer_list.append(core.Dense(4)) + self.layer_list.extend( + [core.Dense(5), + core.Dense(6, kernel_regularizer=math_ops.reduce_sum)]) + self.layer_list += [ + core.Dense(7, bias_regularizer=math_ops.reduce_sum), + core.Dense(8) + ] + self.layer_list += ( + data_structures.List([core.Dense(9)]) + data_structures.List( + [core.Dense(10)])) + self.layer_list.extend( + data_structures.List( + list(sequence=[core.Dense(11)]) + [core.Dense(12)])) + self.layers_with_updates = data_structures.List( + sequence=(normalization.BatchNormalization(),)) + + def call(self, x): + aggregation = 0. + for l in self.layer_list: + x = l(x) + aggregation += math_ops.reduce_sum(x) + bn, = self.layers_with_updates + return bn(x) / aggregation + + +class ListTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testTracking(self): + model = HasList() + output = model(array_ops.ones([32, 2])) + self.assertAllEqual([32, 12], output.shape) + self.assertEqual(2, len(model.layers)) + self.assertIs(model.layer_list, model.layers[0]) + self.assertEqual(10, len(model.layers[0].layers)) + for index in range(10): + self.assertEqual(3 + index, model.layers[0].layers[index].units) + self.assertEqual(2, len(model._checkpoint_dependencies)) + self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref) + self.assertIs(model.layers_with_updates, + model._checkpoint_dependencies[1].ref) + self.assertEqual( + 10, len(model._checkpoint_dependencies[0].ref._checkpoint_dependencies)) + self.evaluate([v.initializer for v in model.variables]) + self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]])) + save_path = os.path.join(self.get_temp_dir(), "ckpt") + model.save_weights(save_path) + self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3]))) + model.load_weights(save_path) + self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]], + self.evaluate(model.variables[0])) + + def testUpdatesForwarded(self): + with context.graph_mode(): + model = HasList() + model_input = array_ops.ones([32, 2]) + model(model_input) + self.assertGreater(len(model.layers_with_updates[0].updates), 0) + self.assertEqual(set(model.layers_with_updates[0].updates), + set(model.updates)) + + with context.eager_mode(): + model = HasList() + model_input = array_ops.ones([32, 2]) + model(model_input) + self.assertEqual(0, len(model.updates)) + + @test_util.run_in_graph_and_eager_modes + def testLossesForwarded(self): + model = HasList() + model_input = array_ops.ones([32, 2]) + model(model_input) + self.assertEqual(2, len(model.losses)) + + def testModelContainersCompareEqual(self): + class HasEqualContainers(training.Model): + + def __init__(self): + super(HasEqualContainers, self).__init__() + self.l1 = [] + self.l2 = [] + + model = HasEqualContainers() + model.l1.append(HasEqualContainers()) + model.l2.append(HasEqualContainers()) + self.assertEqual([model.l1, model.l2], model.layers) + + def testNotCheckpointable(self): + class NotCheckpointable(object): + pass + + with self.assertRaises(ValueError): + data_structures.List([NotCheckpointable()]) + + def testCallNotImplemented(self): + with self.assertRaisesRegexp(TypeError, "not callable"): + data_structures.List()(1.) + + def testNoPop(self): + with self.assertRaises(AttributeError): + data_structures.List().pop() + + def testNesting(self): + with context.graph_mode(): + inner = data_structures.List() + outer = data_structures.List([inner]) + inner.append(non_keras_core.Dense(1)) + inner[0](array_ops.ones([2, 3])) + self.assertEqual(2, len(outer.variables)) + self.assertIsInstance( + outer.variables[0], + resource_variable_ops.ResourceVariable) + + def testNonLayerVariables(self): + v = resource_variable_ops.ResourceVariable([1.]) + l = data_structures.List([v]) + self.assertTrue(l.trainable) + self.assertEqual([], l.layers) + self.assertEqual([v], l.variables) + self.assertEqual([v], l.trainable_weights) + self.assertEqual([], l.non_trainable_variables) + l.trainable = False + self.assertEqual([v], l.variables) + self.assertEqual([], l.trainable_variables) + self.assertEqual([v], l.non_trainable_variables) + l.trainable = True + v2 = resource_variable_ops.ResourceVariable(1., trainable=False) + l.append(v2) + self.assertEqual([v, v2], l.weights) + self.assertEqual([v], l.trainable_weights) + self.assertEqual([v2], l.non_trainable_weights) + + def testListWrapperBasic(self): + # _ListWrapper, unlike List, compares like the built-in list type (since it + # is used to automatically replace lists). + a = tracking.Checkpointable() + b = tracking.Checkpointable() + self.assertEqual([a, a], + [a, a]) + self.assertEqual(data_structures._ListWrapper([a, a]), + data_structures._ListWrapper([a, a])) + self.assertEqual([a, a], + data_structures._ListWrapper([a, a])) + self.assertEqual(data_structures._ListWrapper([a, a]), + [a, a]) + self.assertNotEqual([a, a], + [b, a]) + self.assertNotEqual(data_structures._ListWrapper([a, a]), + data_structures._ListWrapper([b, a])) + self.assertNotEqual([a, a], + data_structures._ListWrapper([b, a])) + self.assertLess([a], [a, b]) + self.assertLess(data_structures._ListWrapper([a]), + data_structures._ListWrapper([a, b])) + self.assertLessEqual([a], [a, b]) + self.assertLessEqual(data_structures._ListWrapper([a]), + data_structures._ListWrapper([a, b])) + self.assertGreater([a, b], [a]) + self.assertGreater(data_structures._ListWrapper([a, b]), + data_structures._ListWrapper([a])) + self.assertGreaterEqual([a, b], [a]) + self.assertGreaterEqual(data_structures._ListWrapper([a, b]), + data_structures._ListWrapper([a])) + self.assertEqual([a], data_structures._ListWrapper([a])) + self.assertEqual([a], list(data_structures.List([a]))) + self.assertEqual([a, a], data_structures._ListWrapper([a]) + [a]) + self.assertEqual([a, a], [a] + data_structures._ListWrapper([a])) + self.assertIsInstance(data_structures._ListWrapper([a]), list) + + def testWrapperChangesList(self): + l = [] + l_wrapper = data_structures._ListWrapper(l) + l_wrapper.append(1) + self.assertEqual([1], l) + + def testListChangesWrapper(self): + l = [] + l_wrapper = data_structures._ListWrapper(l) + l.append(1) + self.assertEqual([1], l_wrapper) + + def testHashing(self): + has_sequences = set([data_structures.List(), + data_structures.List()]) + self.assertEqual(2, len(has_sequences)) + self.assertNotIn(data_structures.List(), has_sequences) + with self.assertRaises(TypeError): + has_sequences.add(data_structures._ListWrapper([])) + + +class HasMapping(training.Model): + + def __init__(self): + super(HasMapping, self).__init__() + self.layer_dict = data_structures.Mapping(output=core.Dense(7)) + self.layer_dict["norm"] = data_structures.List() + self.layer_dict["dense"] = data_structures.List() + self.layer_dict["dense"].extend( + [core.Dense(5), + core.Dense(6, kernel_regularizer=math_ops.reduce_sum)]) + self.layer_dict["norm"].append( + normalization.BatchNormalization()) + self.layer_dict["norm"].append( + normalization.BatchNormalization()) + + def call(self, x): + aggregation = 0. + for norm, dense in zip(self.layer_dict["norm"], self.layer_dict["dense"]): + x = norm(dense(x)) + aggregation += math_ops.reduce_sum(x) + return self.layer_dict["output"](x) / aggregation + + +class MappingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testTracking(self): + model = HasMapping() + output = model(array_ops.ones([32, 2])) + self.assertAllEqual([32, 7], output.shape) + self.assertEqual(1, len(model.layers)) + self.assertIs(model.layer_dict, model.layers[0]) + self.assertEqual(3, len(model.layers[0].layers)) + self.assertEqual(1, len(model._checkpoint_dependencies)) + self.assertIs(model.layer_dict, model._checkpoint_dependencies[0].ref) + self.evaluate([v.initializer for v in model.variables]) + test_var = model.layer_dict["output"].kernel + self.evaluate(test_var.assign(array_ops.ones([6, 7]))) + save_path = os.path.join(self.get_temp_dir(), "ckpt") + model.save_weights(save_path) + self.evaluate(test_var.assign(array_ops.zeros([6, 7]))) + model.load_weights(save_path) + self.assertAllEqual(numpy.ones([6, 7]), + self.evaluate(test_var)) + + def testNoOverwrite(self): + mapping = data_structures.Mapping() + original = data_structures.List() + mapping["a"] = original + with self.assertRaises(ValueError): + mapping["a"] = data_structures.List() + self.assertIs(original, mapping["a"]) + with self.assertRaises(AttributeError): + del mapping["a"] + mapping.update(b=data_structures.Mapping()) + with self.assertRaises(ValueError): + mapping.update({"b": data_structures.Mapping()}) + + def testNonStringKeys(self): + mapping = data_structures.Mapping() + with self.assertRaises(TypeError): + mapping[1] = data_structures.List() + + def testHashing(self): + has_mappings = set([data_structures.Mapping(), + data_structures.Mapping()]) + self.assertEqual(2, len(has_mappings)) + self.assertNotIn(data_structures.Mapping(), has_mappings) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..978fcb2252cd4481b8286bdf3afd58b30ce6d665 --- /dev/null +++ b/tensorflow/python/training/checkpointable/layer_utils.py @@ -0,0 +1,93 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities related to layer/model functionality.""" + +# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils +# once __init__ files no longer require all of tf.keras to be imported together. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def is_layer(obj): + """Implicit check for Layer-like objects.""" + # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer). + return (hasattr(obj, "call") + and hasattr(obj, "build") + and hasattr(obj, "variables")) + + +def filter_empty_layer_containers(layer_list): + """Filter out empty Layer-like containers.""" + return [layer for layer in layer_list + # Filter out only empty Checkpointable data structures. Empty Networks + # will still show up in Model.layers. + if is_layer(layer) or getattr(layer, "layers", True)] + + +def gather_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected trainable weights/variables. + """ + if not trainable: + return [] + weights = [] + for layer in sub_layers: + weights += layer.trainable_weights + trainable_extra_variables = [ + v for v in extra_variables if v.trainable] + return weights + trainable_extra_variables + + +def gather_non_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the non-trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected non-trainable weights/variables. + """ + trainable_extra_variables = [] + non_trainable_extra_variables = [] + for v in extra_variables: + if v.trainable: + trainable_extra_variables.append(v) + else: + non_trainable_extra_variables.append(v) + weights = [] + for layer in sub_layers: + weights += layer.non_trainable_weights + if not trainable: + trainable_weights = [] + for layer in sub_layers: + trainable_weights += layer.trainable_weights + return (trainable_weights + trainable_extra_variables + + weights + non_trainable_extra_variables) + return weights + non_trainable_extra_variables diff --git a/tensorflow/python/training/checkpointable/tracking.py b/tensorflow/python/training/checkpointable/tracking.py new file mode 100644 index 0000000000000000000000000000000000000000..bd0bed9d46f2e75633e3bf1230eded3708ec1c8b --- /dev/null +++ b/tensorflow/python/training/checkpointable/tracking.py @@ -0,0 +1,72 @@ +"""Dependency tracking for checkpointable objects.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import data_structures + + +class NotCheckpointable(object): + """Marks instances of child classes as unsaveable using an object-based API. + + Useful for marking objects which would otherwise look checkpointable because + of inheritance (e.g. through `Layer`) as not checkpointable. Inheriting from + `NotCheckpointable` does not prevent an object from being assigned to any + attributes, but will throw an error on save/restore. + """ + pass + + +class Checkpointable(base.CheckpointableBase): + """Manages dependencies on other objects. + + `Checkpointable` objects may have dependencies: other `Checkpointable` objects + which should be saved if the object declaring the dependency is saved. A + correctly saveable program has a dependency graph such that if changing a + global variable affects an object (e.g. changes the behavior of any of its + methods) then there is a chain of dependencies from the influenced object to + the variable. + + Dependency edges have names, and are created implicitly when a + `Checkpointable` object is assigned to an attribute of another + `Checkpointable` object. For example: + + ``` + obj = Checkpointable() + obj.v = ResourceVariable(0.) + ``` + + The `Checkpointable` object `obj` now has a dependency named "v" on a + variable. + + `Checkpointable` objects may specify `Tensor`s to be saved and restored + directly (e.g. a `Variable` indicating how to save itself) rather than through + dependencies on other objects. See + `Checkpointable._gather_saveables_for_checkpoint` for details. + """ + + def __setattr__(self, name, value): + """Support self.foo = checkpointable syntax.""" + if getattr(self, "_setattr_tracking", True): + value = data_structures.sticky_attribute_assignment( + checkpointable=self, value=value, name=name) + super(Checkpointable, self).__setattr__(name, value) + + def _no_dependency(self, value): + """Override to allow CheckpointableBase to disable dependency tracking.""" + return data_structures.NoDependency(value) diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py new file mode 100644 index 0000000000000000000000000000000000000000..96da0d6e4720b44815de137c0efdd74645bae0fc --- /dev/null +++ b/tensorflow/python/training/checkpointable/tracking_test.py @@ -0,0 +1,171 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy + +from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util +from tensorflow.python.util import nest + + +class InterfaceTests(test.TestCase): + + def testMultipleAssignment(self): + root = tracking.Checkpointable() + root.leaf = tracking.Checkpointable() + root.leaf = root.leaf + duplicate_name_dep = tracking.Checkpointable() + with self.assertRaisesRegexp(ValueError, "already declared"): + root._track_checkpointable(duplicate_name_dep, name="leaf") + # No error; we're overriding __setattr__, so we can't really stop people + # from doing this while maintaining backward compatibility. + root.leaf = duplicate_name_dep + root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True) + self.assertIs(duplicate_name_dep, root._lookup_dependency("leaf")) + (_, dep_object), = root._checkpoint_dependencies + self.assertIs(duplicate_name_dep, dep_object) + + def testNoDependency(self): + root = tracking.Checkpointable() + hasdep = tracking.Checkpointable() + root.hasdep = hasdep + nodep = tracking.Checkpointable() + root.nodep = data_structures.NoDependency(nodep) + self.assertEqual(1, len(root._checkpoint_dependencies)) + self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep) + self.assertIs(root.hasdep, hasdep) + self.assertIs(root.nodep, nodep) + + class NoDependencyModel(training.Model): + + @base.no_automatic_dependency_tracking + def __init__(self): + super(NoDependencyModel, self).__init__() + self.a = [] + self.b = tracking.Checkpointable() + + nodeps = NoDependencyModel() + self.assertEqual([nodeps], util.list_objects(nodeps)) + + def testListBasic(self): + a = tracking.Checkpointable() + b = tracking.Checkpointable() + a.l = [b] + c = tracking.Checkpointable() + a.l.append(c) + a_deps = util.list_objects(a) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + direct_a_dep, = a._checkpoint_dependencies + self.assertEqual("l", direct_a_dep.name) + self.assertIn(b, direct_a_dep.ref) + self.assertIn(c, direct_a_dep.ref) + + @test_util.run_in_graph_and_eager_modes + def testMutationDirtiesList(self): + a = tracking.Checkpointable() + b = tracking.Checkpointable() + a.l = [b] + c = tracking.Checkpointable() + a.l.insert(0, c) + checkpoint = util.Checkpoint(a=a) + with self.assertRaisesRegexp(ValueError, "A list element was replaced"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testOutOfBandEditDirtiesList(self): + a = tracking.Checkpointable() + b = tracking.Checkpointable() + held_reference = [b] + a.l = held_reference + c = tracking.Checkpointable() + held_reference.append(c) + checkpoint = util.Checkpoint(a=a) + with self.assertRaisesRegexp(ValueError, "The wrapped list was modified"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testNestedLists(self): + a = tracking.Checkpointable() + a.l = [] + b = tracking.Checkpointable() + a.l.append([b]) + c = tracking.Checkpointable() + a.l[0].append(c) + a_deps = util.list_objects(a) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + a.l[0].append(1) + d = tracking.Checkpointable() + a.l[0].append(d) + a_deps = util.list_objects(a) + self.assertIn(d, a_deps) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + self.assertNotIn(1, a_deps) + e = tracking.Checkpointable() + f = tracking.Checkpointable() + a.l1 = [[], [e]] + a.l1[0].append(f) + a_deps = util.list_objects(a) + self.assertIn(e, a_deps) + self.assertIn(f, a_deps) + checkpoint = util.Checkpoint(a=a) + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + a.l[0].append(data_structures.NoDependency([])) + a.l[0][-1].append(5) + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + # Dirtying the inner list means the root object is unsaveable. + a.l[0][1] = 2 + with self.assertRaisesRegexp(ValueError, "A list element was replaced"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testNoDepList(self): + a = training.Model() + a.l1 = data_structures.NoDependency([]) + a.l1.insert(1, 0) + self.assertTrue(isinstance(a.l1, list)) + checkpoint = util.Checkpoint(a=a) + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + a.l2 = [] + a.l2.insert(1, 0) + with self.assertRaisesRegexp(ValueError, "A list element was replaced"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testAssertions(self): + a = tracking.Checkpointable() + a.l = [numpy.zeros([2, 2])] + self.assertAllEqual([numpy.zeros([2, 2])], a.l) + self.assertAllClose([numpy.zeros([2, 2])], a.l) + nest.map_structure(self.assertAllClose, a.l, [numpy.zeros([2, 2])]) + a.tensors = [array_ops.ones([2, 2]), array_ops.zeros([3, 3])] + self.assertAllClose([numpy.ones([2, 2]), numpy.zeros([3, 3])], + self.evaluate(a.tensors)) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index 96e6d10791f396ad7f9f73cce9356dd4cbe3ce9d..6ae5765b133cc72b67f3d9864d0f67abf33f0648 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -39,8 +39,11 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import saveable_object as saveable_object_lib from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import tracking from tensorflow.python.util import deprecation +from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export @@ -91,7 +94,7 @@ class _CheckpointRestoreCoordinator(object): # use them (for example because of inconsistent references when # loading). Used to make status assertions fail when loading checkpoints # that don't quite match. - self.all_python_objects = weakref.WeakSet() + self.all_python_objects = _ObjectIdentityWeakSet() self.save_path = save_path self.dtype_map = dtype_map # When graph building, contains a list of ops to run to restore objects from @@ -113,7 +116,7 @@ class _CheckpointRestoreCoordinator(object): # `node` refers to an `Optimizer`, since only these have slot variables. self.slot_restorations.setdefault( slot_reference.original_variable_node_id, []).append( - checkpointable_lib._SlotVariableRestoration( # pylint: disable=protected-access + base._SlotVariableRestoration( # pylint: disable=protected-access optimizer_id=node_index, slot_variable_id=slot_reference.slot_variable_node_id, slot_name=slot_reference.slot_name)) @@ -257,27 +260,145 @@ def object_metadata(save_path): reader = pywrap_tensorflow.NewCheckpointReader(save_path) try: object_graph_string = reader.get_tensor( - checkpointable_lib.OBJECT_GRAPH_PROTO_KEY) + base.OBJECT_GRAPH_PROTO_KEY) except errors_impl.NotFoundError: raise ValueError( ('The specified checkpoint "%s" does not appear to be object-based (it ' 'is missing the key "%s"). Likely it was created with a name-based ' 'saver and does not contain an object dependency graph.') % ( - save_path, checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)) + save_path, base.OBJECT_GRAPH_PROTO_KEY)) object_graph_proto = ( checkpointable_object_graph_pb2.CheckpointableObjectGraph()) object_graph_proto.ParseFromString(object_graph_string) return object_graph_proto +class _ObjectIdentityWrapper(object): + """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped. + + Since __eq__ is based on object identity, it's safe to also define __hash__ + based on object ids. This lets us add unhashable types like checkpointable + _ListWrapper objects to object-identity collections. + """ + + def __init__(self, wrapped): + self._wrapped = wrapped + + @property + def unwrapped(self): + return self._wrapped + + def __eq__(self, other): + if isinstance(other, _ObjectIdentityWrapper): + return self._wrapped is other._wrapped # pylint: disable=protected-access + return self._wrapped is other + + def __hash__(self): + # Wrapper id() is also fine for weakrefs. In fact, we rely on + # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is + # weakref.ref(a) in _WeakObjectIdentityWrapper. + return id(self._wrapped) + + +class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper): + + def __init__(self, wrapped): + super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped)) + + @property + def unwrapped(self): + return self._wrapped() + + +class _ObjectIdentityDictionary(collections.MutableMapping): + """A mutable mapping data structure which compares using "is". + + This is necessary because we have checkpointable objects (_ListWrapper) which + have behavior identical to built-in Python lists (including being unhashable + and comparing based on the equality of their contents by default). + """ + + def __init__(self): + self._storage = {} + + def _wrap_key(self, key): + return _ObjectIdentityWrapper(key) + + def __getitem__(self, key): + return self._storage[self._wrap_key(key)] + + def __setitem__(self, key, value): + self._storage[self._wrap_key(key)] = value + + def __delitem__(self, key): + del self._storage[self._wrap_key(key)] + + def __len__(self): + return len(self._storage) + + def __iter__(self): + for key in self._storage: + yield key.unwrapped + + +class _ObjectIdentityWeakKeyDictionary(_ObjectIdentityDictionary): + """Like weakref.WeakKeyDictionary, but compares objects with "is".""" + + def _wrap_key(self, key): + return _WeakObjectIdentityWrapper(key) + + def __len__(self): + # Iterate, discarding old weak refs + return len(list(self._storage)) + + def __iter__(self): + keys = self._storage.keys() + for key in keys: + unwrapped = key.unwrapped + if unwrapped is None: + del self[key] + else: + yield unwrapped + + +class _ObjectIdentityWeakSet(collections.MutableSet): + """Like weakref.WeakSet, but compares objects with "is".""" + + def __init__(self): + self._storage = set() + + def __contains__(self, key): + return _WeakObjectIdentityWrapper(key) in self._storage + + def discard(self, key): + self._storage.discard(_WeakObjectIdentityWrapper(key)) + + def add(self, key): + self._storage.add(_WeakObjectIdentityWrapper(key)) + + def __len__(self): + # Iterate, discarding old weak refs + return len(list(self)) + + def __iter__(self): + keys = list(self._storage) + for key in keys: + unwrapped = key.unwrapped + if unwrapped is None: + self.discard(key) + else: + yield unwrapped + + def _breadth_first_checkpointable_traversal(root_checkpointable): """Find shortest paths to all variables owned by dependencies of root.""" bfs_sorted = [] to_visit = collections.deque([root_checkpointable]) - path_to_root = {root_checkpointable: ()} + path_to_root = _ObjectIdentityDictionary() + path_to_root[root_checkpointable] = () while to_visit: current_checkpointable = to_visit.popleft() - if isinstance(current_checkpointable, checkpointable_lib.NotCheckpointable): + if isinstance(current_checkpointable, tracking.NotCheckpointable): raise NotImplementedError( ("The object %s does not support object-based saving. File a feature " "request if this limitation bothers you. In the meantime, you can " @@ -335,7 +456,7 @@ def _slot_variable_naming_for_optimizer(optimizer_path): def _serialize_slot_variables(checkpointable_objects, node_ids, object_names): """Gather and name slot variables.""" non_slot_objects = list(checkpointable_objects) - slot_variables = {} + slot_variables = _ObjectIdentityDictionary() for checkpointable in non_slot_objects: if isinstance(checkpointable, optimizer_lib.Optimizer): naming_scheme = _slot_variable_naming_for_optimizer( @@ -498,11 +619,12 @@ def _serialize_object_graph(root_checkpointable, saveables_cache): """ checkpointable_objects, path_to_root = ( _breadth_first_checkpointable_traversal(root_checkpointable)) - object_names = { - obj: _object_prefix_from_path(path) - for obj, path in path_to_root.items()} - node_ids = {node: node_id for node_id, node - in enumerate(checkpointable_objects)} + object_names = _ObjectIdentityDictionary() + for obj, path in path_to_root.items(): + object_names[obj] = _object_prefix_from_path(path) + node_ids = _ObjectIdentityDictionary() + for node_id, node in enumerate(checkpointable_objects): + node_ids[node] = node_id slot_variables = _serialize_slot_variables( checkpointable_objects=checkpointable_objects, node_ids=node_ids, @@ -533,11 +655,12 @@ def list_objects(root_checkpointable): # to run. checkpointable_objects, path_to_root = ( _breadth_first_checkpointable_traversal(root_checkpointable)) - object_names = { - obj: _object_prefix_from_path(path) - for obj, path in path_to_root.items()} - node_ids = {node: node_id for node_id, node - in enumerate(checkpointable_objects)} + object_names = _ObjectIdentityDictionary() + for obj, path in path_to_root.items(): + object_names[obj] = _object_prefix_from_path(path) + node_ids = _ObjectIdentityDictionary() + for node_id, node in enumerate(checkpointable_objects): + node_ids[node] = node_id _serialize_slot_variables( checkpointable_objects=checkpointable_objects, node_ids=node_ids, @@ -564,6 +687,93 @@ def gather_initializers(root_checkpointable): if hasattr(c, "initializer") and c.initializer is not None] +@tf_contextlib.contextmanager +def capture_dependencies(template): + """Capture variables created within this scope as `Template` dependencies. + + Requires that `template.variable_scope` is active. + + This scope is intended as a compatibility measure, allowing a checkpointable + object to add dependencies on variables created in a block of code which is + not aware of object-based saving (and instead uses variable names + heavily). This is how `Template` objects add dependencies on variables and + sub-`Template`s. Where possible, use `tf.make_template` directly. + + Args: + template: The `Template` object to register dependencies with. + + Yields: + None (when used as a context manager). + """ + name_prefix = template.variable_scope.name + + def _checkpointable_custom_creator(next_creator, name, initial_value, + checkpointable_parent=None, **kwargs): + """A variable creation hook which adds Checkpointable dependencies. + + Set for example during a `Template`'s first wrapped function + execution. Ensures that (a) `template` depends on any checkpointable + objects using their own `capture_dependencies` scope inside this scope which + create variables, and (b) that any variables not in a more deeply nested + scope are added as dependencies directly. + + The `checkpointable_parent` argument is passed between custom creators but + ignored when the variable object itself is created. This argument indicates + (if not `None`) that a more deeply nested scope has already added the + variable as a dependency, and that parent scopes should add a dependency on + that object rather than on the variable directly. + + Args: + next_creator: See `variable_scope.variable_creator_scope`; the next + creator in the chain. + name: The (full, scope-influenced) name of the variable. The `name_prefix` + itself is stripped for the purposes of object-based dependency tracking, + but scopes opened within this scope are respected. + initial_value: See `variable_scope.variable_creator_scope`. Taken + explicitly so the argument can be re-named and used with + `Checkpointable._add_variable_with_custom_getter`. + checkpointable_parent: If not None, a more deeply nested checkpointable + object and its name prefix which were passed to `capture_dependencies` + to add a dependency on (rather than depending on the variable directly). + **kwargs: Passed through to the next creator. + + Returns: + The output of `next_creator`: the fetched/created variable object. + """ + def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): + inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which + # we don't want to propagate. + return next_creator( + initial_value=initializer, + name=name, + **inner_kwargs) + if name.startswith(name_prefix): + scope_stripped_name = name[len(name_prefix) + 1:] + if not checkpointable_parent: + return template._add_variable_with_custom_getter( # pylint: disable=protected-access + initializer=initial_value, + name=scope_stripped_name, + getter=_call_next_creator_renaming_initializer, + # Disable error checking for Checkpointable. Exceptions are instead + # raised if necessary when the object-based saver tries to + # save/restore the object. + overwrite=True, + checkpointable_parent=(template, name_prefix), + **kwargs) + else: + parent_object, parent_name_prefix = checkpointable_parent + template._track_checkpointable( # pylint: disable=protected-access + parent_object, + name=parent_name_prefix[len(name_prefix) + 1:], + overwrite=True) + return next_creator( + name=name, initial_value=initial_value, + checkpointable_parent=(template, name_prefix), **kwargs) + + with variable_scope.variable_creator_scope(_checkpointable_custom_creator): + yield + + class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): def __init__(self, tensor, name): @@ -899,7 +1109,7 @@ class CheckpointableSaver(object): else: # Maps Checkpointable objects -> attribute names -> SaveableObjects, to # avoid re-creating SaveableObjects when graph building. - self._saveable_object_cache = weakref.WeakKeyDictionary() + self._saveable_object_cache = _ObjectIdentityWeakKeyDictionary() @property def _root_checkpointable(self): @@ -950,11 +1160,11 @@ class CheckpointableSaver(object): with ops.device("/cpu:0"): object_graph_tensor = constant_op.constant( graph_proto.SerializeToString(), dtype=dtypes.string) - assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables + assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables named_variables.append( _NoRestoreSaveable( tensor=object_graph_tensor, - name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)) + name=base.OBJECT_GRAPH_PROTO_KEY)) if (self._last_save_object_graph != graph_proto # When executing eagerly, we need to re-create SaveableObjects each time # save() is called so they pick up new Tensors passed to their @@ -1044,7 +1254,7 @@ class CheckpointableSaver(object): dtype_map = reader.get_variable_to_dtype_map() try: object_graph_string = reader.get_tensor( - checkpointable_lib.OBJECT_GRAPH_PROTO_KEY) + base.OBJECT_GRAPH_PROTO_KEY) except errors_impl.NotFoundError: # The object graph proto does not exist in this checkpoint. Try the # name-based compatibility mode. @@ -1090,7 +1300,7 @@ class CheckpointableSaver(object): "file a feature request if this limitation bothers you.") self._last_restore_checkpoint = checkpoint self._last_restore_object_graph = object_graph_proto - checkpointable_lib._CheckpointPosition( # pylint: disable=protected-access + base._CheckpointPosition( # pylint: disable=protected-access checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable) load_status = CheckpointLoadStatus( checkpoint, @@ -1100,7 +1310,7 @@ class CheckpointableSaver(object): @tf_export("train.Checkpoint") -class Checkpoint(checkpointable_lib.Checkpointable): +class Checkpoint(tracking.Checkpointable): """Groups checkpointable objects, saving and restoring them. `Checkpoint`'s constructor accepts keyword arguments whose values are types @@ -1202,7 +1412,7 @@ class Checkpoint(checkpointable_lib.Checkpointable): """ super(Checkpoint, self).__init__() for k, v in sorted(kwargs.items(), key=lambda item: item[0]): - if not isinstance(v, checkpointable_lib.CheckpointableBase): + if not isinstance(v, base.CheckpointableBase): raise ValueError( ("`Checkpoint` was expecting a checkpointable object (an object " "derived from `CheckpointableBase`), got %s. If you believe this " @@ -1221,7 +1431,7 @@ class Checkpoint(checkpointable_lib.Checkpointable): with ops.device("/cpu:0"): # add_variable creates a dependency named "save_counter"; NoDependency # prevents creating a second dependency named "_save_counter". - self._save_counter = checkpointable_lib.NoDependency( + self._save_counter = data_structures.NoDependency( add_variable(self, name="save_counter", initializer=0, dtype=dtypes.int64)) diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index 8968aad283b69923f5d10afb7769e89f41c197bf..3c1a4a6f83c20a74961bf3e1263b2a33d3e36f05 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -32,9 +32,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras._impl.keras.engine import sequential -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops @@ -44,11 +44,12 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import adam from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import tracking from tensorflow.python.training.checkpointable import util as checkpointable_utils -class NonLayerCheckpointable(checkpointable.Checkpointable): +class NonLayerCheckpointable(tracking.Checkpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() @@ -101,7 +102,7 @@ class InterfaceTests(test.TestCase): name="duplicate", initial_value=1.) duplicate = checkpointable_utils.add_variable( obj, name="duplicate", shape=[]) - with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"): + with self.assertRaisesRegexp(ValueError, "'duplicate'.*already declared"): checkpointable_utils.add_variable(obj, name="duplicate", shape=[]) self.evaluate(checkpointable_utils.gather_initializers(obj)) @@ -136,7 +137,7 @@ class InterfaceTests(test.TestCase): def testInitNotCalled(self): - class NoInit(checkpointable.Checkpointable): + class NoInit(tracking.Checkpointable): def __init__(self): pass @@ -145,7 +146,7 @@ class InterfaceTests(test.TestCase): checkpointable_utils.add_variable(NoInit(), "var", shape=[]) def testShapeDtype(self): - root = checkpointable.Checkpointable() + root = tracking.Checkpointable() v1 = checkpointable_utils.add_variable( root, name="v1", initializer=3., dtype=dtypes.float64) self.assertEqual(dtypes.float64, v1.dtype) @@ -177,7 +178,7 @@ class InterfaceTests(test.TestCase): def testNotCheckpointable(self): class CallsFunctionalStuff( - checkpointable.NotCheckpointable, checkpointable.Checkpointable): + tracking.NotCheckpointable, tracking.Checkpointable): pass test_dir = self.get_temp_dir() @@ -187,7 +188,7 @@ class InterfaceTests(test.TestCase): checkpoint.save(prefix) class CallsFunctionalStuffOtherMRO( - checkpointable.Checkpointable, checkpointable.NotCheckpointable): + tracking.Checkpointable, tracking.NotCheckpointable): pass checkpoint_reversed = checkpointable_utils.Checkpoint( @@ -217,7 +218,7 @@ class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject): self._mirrored_variable.assign(tensor)) -class _OwnsMirroredVariables(checkpointable.CheckpointableBase): +class _OwnsMirroredVariables(base.CheckpointableBase): """A Checkpointable object which returns a more complex SaveableObject.""" def __init__(self): @@ -232,7 +233,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase): primary_variable=self.non_dep_variable, mirrored_variable=self.mirrored, name=name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + return {base.VARIABLE_VALUE_KEY: _saveable_factory} # The Saver sorts by name before parsing, so we need a name property. @property @@ -355,7 +356,7 @@ class CheckpointingTests(test.TestCase): optimizer_node.slot_variables[0] .slot_variable_node_id].attributes[0].checkpoint_key) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMoreComplexSaveableReturned(self): v = _OwnsMirroredVariables() checkpoint = checkpointable_utils.Checkpoint(v=v) @@ -375,7 +376,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(44., self.evaluate(v.non_dep_variable)) self.assertEqual(44., self.evaluate(v.mirrored)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMoreComplexSaveableReturnedWithGlobalName(self): # The same object can also be saved using the name-based saver. v = _OwnsMirroredVariables() @@ -391,7 +392,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(42., self.evaluate(v.non_dep_variable)) self.assertEqual(42., self.evaluate(v.mirrored)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestore(self): model = MyModel() optimizer = adam.AdamOptimizer(0.001) @@ -512,7 +513,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(training_continuation + 1, session.run(root.save_counter)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAgnosticUsage(self): """Graph/eager agnostic usage.""" # Does create garbage when executing eagerly due to ops.Graph() creation. @@ -546,7 +547,7 @@ class CheckpointingTests(test.TestCase): self.evaluate(root.save_counter)) # pylint: disable=cell-var-from-loop - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWithDefun(self): num_training_steps = 2 checkpoint_directory = self.get_temp_dir() @@ -590,7 +591,7 @@ class CheckpointingTests(test.TestCase): # pylint: enable=cell-var-from-loop def _get_checkpoint_name(self, name): - root = checkpointable.Checkpointable() + root = tracking.Checkpointable() checkpointable_utils.add_variable( root, name=name, shape=[1, 2], dtype=dtypes.float64) (named_variable,), _, _ = checkpointable_utils._serialize_object_graph( @@ -611,18 +612,18 @@ class CheckpointingTests(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNumberedPath(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() + root = tracking.Checkpointable() + leaf = tracking.Checkpointable() root.leaf = leaf checkpointable_utils.add_variable(leaf, name="v", shape=[]) (named_variable,), _, _ = checkpointable_utils._serialize_object_graph( root, saveables_cache=None) self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", named_variable.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLocalNameValidation(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() + root = tracking.Checkpointable() + leaf = tracking.Checkpointable() # Dots are escaped, which avoids conflicts with reserved names. root._track_checkpointable(leaf, name=".ATTRIBUTES") checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[]) @@ -660,16 +661,16 @@ class CheckpointingTests(test.TestCase): optimizer.apply_gradients( [(g, v) for g, v in zip(grad, model.vars)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLateDependencyTracking(self): - class Dependency(checkpointable.Checkpointable): + class Dependency(tracking.Checkpointable): def build(self): self.var = checkpointable_utils.add_variable( self, "var", initializer=0.) - class LateDependencies(checkpointable.Checkpointable): + class LateDependencies(tracking.Checkpointable): def add_dep(self): self.dep = Dependency() @@ -692,16 +693,16 @@ class CheckpointingTests(test.TestCase): status.run_restore_ops() self.assertEqual(123., self.evaluate(load_into.dep.var)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDepAfterVar(self): - class Dependency(checkpointable.Checkpointable): + class Dependency(tracking.Checkpointable): def build(self): self.var = checkpointable_utils.add_variable( self, "var", initializer=0.) - class DepAfterVar(checkpointable.Checkpointable): + class DepAfterVar(tracking.Checkpointable): def add_dep(self): dep = Dependency() @@ -724,11 +725,11 @@ class CheckpointingTests(test.TestCase): status.run_restore_ops() self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = checkpointable.Checkpointable() + root = tracking.Checkpointable() root.var = checkpointable_utils.add_variable( root, name="var", initializer=0.) optimizer = adam.AdamOptimizer(0.1) @@ -751,7 +752,7 @@ class CheckpointingTests(test.TestCase): 14.)) slots_path = checkpointable_utils.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "with_slots")) - new_root = checkpointable.Checkpointable() + new_root = tracking.Checkpointable() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). slot_status = checkpointable_utils.CheckpointableSaver( @@ -789,11 +790,11 @@ class CheckpointingTests(test.TestCase): self.evaluate(train_op) slot_status.assert_consumed() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOverlappingRestores(self): checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep = checkpointable.Checkpointable() + save_root = tracking.Checkpointable() + save_root.dep = tracking.Checkpointable() save_root.dep.var = checkpointable_utils.add_variable( save_root.dep, name="var", initializer=0.) self.evaluate(state_ops.assign(save_root.dep.var, 12.)) @@ -802,13 +803,13 @@ class CheckpointingTests(test.TestCase): self.evaluate(state_ops.assign(save_root.dep.var, 13.)) second_path = saver.save(os.path.join(checkpoint_directory, "second")) - first_root = checkpointable.Checkpointable() - second_root = checkpointable.Checkpointable() + first_root = tracking.Checkpointable() + second_root = tracking.Checkpointable() first_status = checkpointable_utils.CheckpointableSaver( first_root).restore(first_path) second_status = checkpointable_utils.CheckpointableSaver( second_root).restore(second_path) - load_dep = checkpointable.Checkpointable() + load_dep = tracking.Checkpointable() load_dep.var = checkpointable_utils.add_variable( load_dep, name="var", shape=[]) first_root.dep = load_dep @@ -822,13 +823,13 @@ class CheckpointingTests(test.TestCase): # Try again with the order of the restore() reversed. The last restore # determines the final value. - first_root = checkpointable.Checkpointable() - second_root = checkpointable.Checkpointable() + first_root = tracking.Checkpointable() + second_root = tracking.Checkpointable() second_status = checkpointable_utils.CheckpointableSaver( second_root).restore(second_path) first_status = checkpointable_utils.CheckpointableSaver( first_root).restore(first_path) - load_dep = checkpointable.Checkpointable() + load_dep = tracking.Checkpointable() load_dep.var = checkpointable_utils.add_variable( load_dep, name="var", shape=[]) first_root.dep = load_dep @@ -840,39 +841,39 @@ class CheckpointingTests(test.TestCase): second_status.run_restore_ops() self.assertEqual(12., self.evaluate(load_dep.var)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAmbiguousLoad(self): # Not OK to split one checkpoint object into two checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep_one = checkpointable.Checkpointable() - save_root.dep_two = checkpointable.Checkpointable() - dep_three = checkpointable.Checkpointable() + save_root = tracking.Checkpointable() + save_root.dep_one = tracking.Checkpointable() + save_root.dep_two = tracking.Checkpointable() + dep_three = tracking.Checkpointable() save_root.dep_one.dep_three = dep_three save_root.dep_two.dep_three = dep_three checkpointable_utils.add_variable(dep_three, name="var", initializer=0.) self.evaluate(checkpointable_utils.gather_initializers(save_root)) save_path = checkpointable_utils.CheckpointableSaver(save_root).save( os.path.join(checkpoint_directory, "ckpt")) - load_root = checkpointable.Checkpointable() + load_root = tracking.Checkpointable() status = checkpointable_utils.CheckpointableSaver(load_root).restore( save_path) - load_root.dep_one = checkpointable.Checkpointable() - load_root.dep_two = checkpointable.Checkpointable() - load_root.dep_one.dep_three = checkpointable.Checkpointable() - load_root.dep_two.dep_three = checkpointable.Checkpointable() + load_root.dep_one = tracking.Checkpointable() + load_root.dep_two = tracking.Checkpointable() + load_root.dep_one.dep_three = tracking.Checkpointable() + load_root.dep_two.dep_three = tracking.Checkpointable() checkpointable_utils.add_variable( load_root.dep_one.dep_three, name="var", initializer=0.) with self.assertRaises(AssertionError): status.assert_consumed() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testObjectsCombined(self): # Currently fine to load two checkpoint objects into one Python object checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep_one = checkpointable.Checkpointable() - save_root.dep_two = checkpointable.Checkpointable() + save_root = tracking.Checkpointable() + save_root.dep_one = tracking.Checkpointable() + save_root.dep_two = tracking.Checkpointable() checkpointable_utils.add_variable( save_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64) checkpointable_utils.add_variable( @@ -880,8 +881,8 @@ class CheckpointingTests(test.TestCase): self.evaluate(checkpointable_utils.gather_initializers(save_root)) save_path = checkpointable_utils.CheckpointableSaver(save_root).save( os.path.join(checkpoint_directory, "ckpt")) - load_root = checkpointable.Checkpointable() - load_root.dep_one = checkpointable.Checkpointable() + load_root = tracking.Checkpointable() + load_root.dep_one = tracking.Checkpointable() load_root.dep_two = load_root.dep_one v1 = checkpointable_utils.add_variable( load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64) @@ -893,12 +894,12 @@ class CheckpointingTests(test.TestCase): self.assertEqual(32., self.evaluate(v1)) self.assertEqual(64., self.evaluate(v2)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDependencyLoop(self): # Note: this test creates garbage during eager execution because it # purposefully creates a reference cycle. - first = checkpointable.Checkpointable() - second = checkpointable.Checkpointable() + first = tracking.Checkpointable() + second = tracking.Checkpointable() first.second = second second.first = first first.v = checkpointable_utils.add_variable( @@ -911,10 +912,10 @@ class CheckpointingTests(test.TestCase): os.path.join(checkpoint_directory, "ckpt")) # Test deferred loading - first_load = checkpointable.Checkpointable() + first_load = tracking.Checkpointable() status = checkpointable_utils.CheckpointableSaver( first_load).restore(save_path) - second_load = checkpointable.Checkpointable() + second_load = tracking.Checkpointable() first_load.second = second_load second_load.first = first_load with self.assertRaises(AssertionError): @@ -939,13 +940,13 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRestoreOnAssign(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") save_graph = ops.Graph() with save_graph.as_default(), self.test_session(save_graph): - first = checkpointable.Checkpointable() + first = tracking.Checkpointable() first.var1 = variable_scope.get_variable( name="outside_var", initializer=0.) first.var2 = variable_scope.get_variable( @@ -956,7 +957,7 @@ class CheckpointingTests(test.TestCase): checkpoint_prefix) restore_graph = ops.Graph() with restore_graph.as_default(), self.test_session(restore_graph): - second = checkpointable.Checkpointable() + second = tracking.Checkpointable() second.var2 = variable_scope.get_variable( name="blah", initializer=0.) status = checkpointable_utils.CheckpointableSaver( @@ -978,7 +979,7 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) @@ -989,11 +990,11 @@ class CheckpointingTests(test.TestCase): saver.save(checkpoint_prefix) self.assertEqual(before_ops, graph.get_operations()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCheckpointCleanup(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) self.evaluate(checkpointable_utils.gather_initializers(obj)) saver = checkpointable_utils.Checkpoint(obj=obj) @@ -1009,11 +1010,11 @@ class CheckpointingTests(test.TestCase): expected_filenames, os.listdir(checkpoint_directory)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCheckpointCleanupChangingVarList(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) self.evaluate(checkpointable_utils.gather_initializers(obj)) checkpoint = checkpointable_utils.Checkpoint(obj=obj) @@ -1062,7 +1063,7 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) @@ -1132,7 +1133,7 @@ class CheckpointingTests(test.TestCase): beta1_power, _ = optimizer._get_beta_accumulators() self.assertAllEqual(3., self.evaluate(beta1_power)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_sequential(self): model = sequential.Sequential() checkpoint = checkpointable_utils.Checkpoint(model=model) @@ -1164,7 +1165,7 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual([1., 2., 3., 4., 5.], self.evaluate(deferred_second_dense.bias)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_initialize_if_not_restoring(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") @@ -1243,9 +1244,21 @@ class CheckpointingTests(test.TestCase): self.assertEqual(42., self.evaluate(optimizer.variables()[0])) +class _ManualScope(tracking.Checkpointable): + + def __call__(self): + with variable_scope.variable_scope("ManualScope") as vs: + self.variable_scope = vs + with checkpointable_utils.capture_dependencies(template=self): + return self._build() + + def _build(self): + return variable_scope.get_variable(name="in_manual_scope", shape=[]) + + class TemplateTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_checkpointable_save_restore(self): def _templated(): @@ -1255,14 +1268,23 @@ class TemplateTests(test.TestCase): v2 = variable_scope.get_variable( "v2", shape=[1], initializer=init_ops.zeros_initializer(), use_resource=True) - return v, v + 1., v2 + manual = _ManualScope() + return v, v + 1., v2, manual, manual() save_template = template.make_template("s1", _templated) - v1_save, _, v2_save = save_template() + v1_save, _, v2_save, manual_scope, manual_scope_v = save_template() + six.assertCountEqual( + self, + [v1_save, v2_save, manual_scope, manual_scope_v, save_template], + checkpointable_utils.list_objects(save_template)) + manual_dep, = manual_scope._checkpoint_dependencies + self.assertEqual("in_manual_scope", manual_dep.name) + self.assertIs(manual_scope_v, manual_dep.ref) optimizer = adam.AdamOptimizer(0.0) save_root = checkpointable_utils.Checkpoint( my_template=save_template, optimizer=optimizer) optimizer.minimize(v1_save.read_value) + self.evaluate([v.initializer for v in save_template.variables]) self.evaluate([v.initializer for v in optimizer.variables()]) self.evaluate(v1_save.assign([12.])) self.evaluate(v2_save.assign([14.])) @@ -1275,17 +1297,19 @@ class TemplateTests(test.TestCase): load_root = checkpointable_utils.Checkpoint( my_template=load_template, optimizer=load_optimizer) status = load_root.restore(save_path) - var, var_plus_one, var2 = load_template() + var, var_plus_one, var2, _, _ = load_template() load_optimizer.minimize(var.read_value) - self.assertEqual(2, len(load_template._checkpoint_dependencies)) + self.assertEqual(3, len(load_template._checkpoint_dependencies)) self.assertEqual("v", load_template._checkpoint_dependencies[0].name) self.assertEqual("v2", load_template._checkpoint_dependencies[1].name) + self.assertEqual("ManualScope", + load_template._checkpoint_dependencies[2].name) status.assert_consumed().run_restore_ops() self.assertAllEqual([12.], self.evaluate(var)) self.assertAllEqual([13.], self.evaluate(var_plus_one)) self.assertAllEqual([14.], self.evaluate(var2)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_checkpointable_save_restore_nested(self): def _inner_template(): @@ -1386,7 +1410,7 @@ class CheckpointCompatibilityTests(test.TestCase): sess=session, save_path=checkpoint_prefix, global_step=root.optimizer_step) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoadFromNameBasedSaver(self): """Save a name-based checkpoint, load it using the object-based API.""" with test_util.device(use_gpu=True): @@ -1448,7 +1472,7 @@ class CheckpointCompatibilityTests(test.TestCase): class PythonMetadataTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveLoad(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") diff --git a/tensorflow/python/training/device_util.py b/tensorflow/python/training/device_util.py index e31fa02d60679d218a62f4e2affc16f0f5bc51c3..70e1ca4b5d77e5e7529cb0d06a9ffb4657dc74fe 100644 --- a/tensorflow/python/training/device_util.py +++ b/tensorflow/python/training/device_util.py @@ -27,13 +27,15 @@ def canonicalize(d, default=None): """Canonicalize device string. If d has missing components, the rest would be deduced from the `default` - argument or from '/job:localhost/replica:0/task:0/device:CPU:0'. For example: + argument or from '/replica:0/task:0/device:CPU:0'. For example: If d = '/cpu:0', default='/job:worker/task:1', it returns '/job:worker/replica:0/task:1/device:CPU:0'. If d = '/cpu:0', default='/job:worker', it returns '/job:worker/replica:0/task:0/device:CPU:0'. If d = '/gpu:0', default=None, it returns - '/job:localhost/replica:0/task:0/device:GPU:0'. + '/replica:0/task:0/device:GPU:0'. + + Note: This uses "job:localhost" as the default if executing eagerly. Args: d: a device string. @@ -47,7 +49,9 @@ def canonicalize(d, default=None): "Device type '%s' must be all-caps." % (d.device_type,)) # Fill in missing device fields using defaults. result = tf_device.DeviceSpec( - job="localhost", replica=0, task=0, device_type="CPU", device_index=0) + replica=0, task=0, device_type="CPU", device_index=0) + if context.executing_eagerly(): + result.job = "localhost" if default: result.merge_from(tf_device.DeviceSpec.from_string(default)) result.merge_from(d) diff --git a/tensorflow/python/training/device_util_test.py b/tensorflow/python/training/device_util_test.py index 61525e21f508bcef5b61fd077d288b93803f1aa8..cdbb08229d2f06c2cfeeb855b32665f7c03ea969 100644 --- a/tensorflow/python/training/device_util_test.py +++ b/tensorflow/python/training/device_util_test.py @@ -52,7 +52,7 @@ class DeviceUtilTest(test.TestCase): def testCanonicalizeWithoutDefaultDevice(self): self.assertEqual( device_util.canonicalize("/cpu:0"), - "/job:localhost/replica:0/task:0/device:CPU:0") + "/replica:0/task:0/device:CPU:0") self.assertEqual( device_util.canonicalize("/job:worker/cpu:0"), "/job:worker/replica:0/task:0/device:CPU:0") diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index ab8b37bb655bfc3c222ed661b6d48f0ecdc3a858..d33fd7376a7244535f7a0f393dd6125b125b8018 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import threading -import six from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import ops @@ -222,11 +221,11 @@ def has_distribution_strategy(): def get_loss_reduction(): - """Reduce `method_string` corresponding to the last loss reduction.""" + """Reduce `aggregation` corresponding to the last loss reduction.""" loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access if loss_reduction == losses_impl.Reduction.SUM: - return "sum" - return "mean" + return variable_scope.VariableAggregation.SUM + return variable_scope.VariableAggregation.MEAN # ------------------------------------------------------------------------------ @@ -527,15 +526,21 @@ class DistributionStrategy(object): V(`v`), output will have locality V(`v`) as well. * `d.update_non_slot(d.non_slot_devices(), fn)`: in cross-tower context, like `d.update()` except with locality N. - * `d.fetch(t)`: Copy `t` with any locality to the client's CPU device. + * `d.read_var(v)`: Gets the (read-only) value of the variable `v` (on + the device determined by the current device scope), aggregating + across towers for tower-local variables. Frequently, this will be + done automatically when using `v` in an expression or fetching it in + a cross-tower context, but this function can be used to force that + conversion happens at a particular point in time (for example, to + add the result of the conversion to a graph collection). The standard pattern for updating variables is to: 1. Wrap your input dataset in `d.distribute_dataset()` and create an iterator. 2. Define each tower `d.call_for_each_tower()` up to the point of getting a list of gradient, variable pairs. - 3. Call `d.reduce("sum", t, v)` or `d.batch_reduce()` to sum the - gradients (with locality T) into values with locality V(`v`). + 3. Call `d.reduce(VariableAggregation.SUM, t, v)` or `d.batch_reduce()` to sum + the gradients (with locality T) into values with locality V(`v`). 4. Call `d.update(v)` for each variable to update its value. Steps 3 and 4 are done automatically by class `Optimizer` if you call @@ -609,18 +614,18 @@ class DistributionStrategy(object): # Note: should support "colocate_with" argument. raise NotImplementedError("must be implemented in descendants") - def tower_local_var_scope(self, reduce_method): + def tower_local_var_scope(self, aggregation): """Inside this scope, new variables will not be mirrored. There will still be one component variable per tower, but there is no requirement that they stay in sync. Instead, when saving them - or calling `fetch()`, we use the value that results when calling - `reduce()` on all the towers' variables. + or calling `read_var()`, we use the value that results when + calling `reduce()` on all the towers' variables. Note: tower-local implies not trainable. Instead, it is expected that each tower will directly update (using `assign_add()` or whatever) its local variable instance but only the aggregated - value (accessible using `fetch()`) will be exported from the + value (accessible using `read_var()`) will be exported from the model. When it is acceptable to only aggregate on export, we greatly reduce communication overhead by using tower-local variables. @@ -631,21 +636,41 @@ class DistributionStrategy(object): random numbers. Args: - reduce_method: String used as a `method_string` to `reduce()` - to get the value to save when checkpointing. + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. Returns: A context manager. """ + # TODO(psv): Remove this after adding support for synchronization and + # aggregation parameters in get_variable() and mirrored strategy. def create_tower_local_variable(next_creator, *args, **kwargs): _require_distribution_strategy_scope(self) kwargs["use_resource"] = True - kwargs["tower_local_reduce_method"] = reduce_method + + # Set synchronization to be ON_READ for tower local variables. + kwargs["synchronization"] = variable_scope.VariableSynchronization.ON_READ + kwargs["aggregation"] = aggregation return next_creator(*args, **kwargs) _require_distribution_strategy_scope(self) return variable_scope.variable_creator_scope(create_tower_local_variable) + def read_var(self, v): + """Reads the value of a variable. + + Returns the aggregate value of a tower-local variable, or the + (read-only) value of any other variable. + + Args: + v: A variable allocated within the scope of this `DistributionStrategy`. + + Returns: + A tensor representing the value of `v`, aggregated across towers if + necessary. + """ + raise NotImplementedError("must be implemented in descendants") + def colocate_vars_with(self, colocate_with_variable): """Scope that controls which devices variables will be created on. @@ -796,12 +821,12 @@ class DistributionStrategy(object): def _call_for_each_tower(self, fn, *args, **kwargs): raise NotImplementedError("must be implemented in descendants") - def reduce(self, method_string, value, destinations=None): + def reduce(self, aggregation, value, destinations=None): """Combine (via e.g. sum or mean) values across towers. Args: - method_string: A string indicating how to combine values, either - "sum" or "mean". + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. value: A per-device value with one value per tower. destinations: An optional mirrored variable, a device string, list of device strings. The return value will be copied to all @@ -816,18 +841,21 @@ class DistributionStrategy(object): # TODO(josh11b): Return an unwrapped value if colocate_with is a # single device. _require_cross_tower_context(self) - assert method_string in ("sum", "mean") - return self._reduce(method_string, value, destinations) + assert aggregation in [ + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ] + return self._reduce(aggregation, value, destinations) - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): raise NotImplementedError("must be implemented in descendants") - def batch_reduce(self, method_string, value_destination_pairs): + def batch_reduce(self, aggregation, value_destination_pairs): """Combine multiple `reduce` calls into one for faster execution. Args: - method_string: A string indicating how to combine values, either - "sum" or "mean". + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. value_destination_pairs: A sequence of (value, destinations) pairs. See `reduce()` for a description. @@ -836,12 +864,17 @@ class DistributionStrategy(object): """ # TODO(josh11b): More docstring _require_cross_tower_context(self) - assert method_string in ("sum", "mean") - return self._batch_reduce(method_string, value_destination_pairs) - - def _batch_reduce(self, method_string, value_destination_pairs): - return [self.reduce(method_string, t, destinations=v) - for t, v in value_destination_pairs] + assert aggregation in [ + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ] + return self._batch_reduce(aggregation, value_destination_pairs) + + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self.reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] def update(self, var, fn, *args, **kwargs): """Run `fn` to update `var` using inputs mirrored to the same devices. @@ -897,30 +930,6 @@ class DistributionStrategy(object): def _update_non_slot(self, colocate_with, fn, *args, **kwargs): raise NotImplementedError("must be implemented in descendants") - def fetch(self, val, destination="/device:CPU:0", fn=lambda x: x): - """Return a copy of `val` or `fn(val)` on `destination`. - - This is useful for getting a mirrored value onto a device. It - will attempt to avoid a copy by checking if the value is already - on the destination device. - - Args: - val: Value (which may be mirrored) to copy. - destination: A device string to copy the value to. - fn: An optional function to apply to the value on the source - device, before copying. - - Returns: - A `Tensor` on `destination`. - """ - _require_cross_tower_context(self) - assert isinstance(destination, six.string_types) - destination = device_util.resolve(destination) - return self._fetch(val, destination, fn) - - def _fetch(self, val, destination, fn): - raise NotImplementedError("must be implemented in descendants") - def unwrap(self, value): """Returns the list of all per-device values contained in `value`. @@ -946,7 +955,7 @@ class DistributionStrategy(object): return control_flow_ops.group(value, name=name) # Special handling for the common case of one op. v, = value - if isinstance(v, ops.Tensor): + if hasattr(v, "op"): v = v.op return v @@ -1094,9 +1103,9 @@ class TowerContext(object): finally: _pop_per_thread_mode() - def tower_local_var_scope(self, reduce_method): + def tower_local_var_scope(self, aggregation): """Alias for distribution_strategy.tower_local_var_scope().""" - return self._distribution_strategy.tower_local_var_scope(reduce_method) + return self._distribution_strategy.tower_local_var_scope(aggregation) @property def is_single_tower(self): @@ -1144,13 +1153,12 @@ class _DefaultDistributionStrategy(DistributionStrategy): def creator(next_creator, *args, **kwargs): _require_distribution_strategy_scope(self) - kwargs.pop("tower_local_reduce_method", None) return next_creator(*args, **kwargs) return _CurrentDistributionContext( self, variable_scope.variable_creator_scope(creator)) - def tower_local_var_scope(self, reduce_method): + def tower_local_var_scope(self, aggregation): """Does not set to resource variables.""" def create_tower_local_variable(next_creator, *args, **kwargs): _require_distribution_strategy_scope(self) @@ -1180,9 +1188,9 @@ class _DefaultDistributionStrategy(DistributionStrategy): with TowerContext(self, tower_id=0): return fn(*args, **kwargs) - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): # TODO(josh11b): Use destinations? - del method_string, destinations + del aggregation, destinations return value def _update(self, var, fn, *args, **kwargs): @@ -1197,11 +1205,8 @@ class _DefaultDistributionStrategy(DistributionStrategy): with ops.colocate_with(colocate_with), UpdateContext(colocate_with): return fn(*args, **kwargs) - def _fetch(self, var, destination, fn): - with ops.colocate_with(var): - var = fn(var) - with ops.device(destination): - return array_ops.identity(var) + def read_var(self, tower_local_var): + return array_ops.identity(tower_local_var) def _unwrap(self, distributed_value): return [distributed_value] diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py index 0a4f19c31f6714e1211f9deed9703c02192cc2c0..694145ede73c1c9121cbc4c4e2d6f61e93165d09 100644 --- a/tensorflow/python/training/distribute_test.py +++ b/tensorflow/python/training/distribute_test.py @@ -29,6 +29,14 @@ class _TestTowerContext(distribute.TowerContext): return kwargs["test_arg"] +def _get_test_variable(name, synchronization, aggregation): + return { + "name": name, + "synchronization": synchronization, + "aggregation": aggregation + } + + class _TestStrategy(distribute.DistributionStrategy): def _call_for_each_tower(self, fn, *args, **kwargs): @@ -36,7 +44,8 @@ class _TestStrategy(distribute.DistributionStrategy): return fn(*args, **kwargs) def _create_variable(self, next_creator, *args, **kwargs): - return kwargs["name"] + return _get_test_variable(kwargs["name"], kwargs["synchronization"], + kwargs["aggregation"]) def _assert_in_default_state(t): @@ -61,7 +70,11 @@ class TestStrategyTest(test.TestCase): self.assertTrue(distribute.has_distribution_strategy()) self.assertIs(dist, distribute.get_distribution_strategy()) self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo")) - self.assertEqual("bar", variable_scope.variable(1.0, name="bar")) + expected_value = _get_test_variable( + "bar", variable_scope.VariableSynchronization.AUTO, + variable_scope.VariableAggregation.NONE) + self.assertDictEqual(expected_value, + variable_scope.variable(1.0, name="bar")) with self.assertRaises(RuntimeError): dist.call_for_each_tower(run_fn) @@ -77,7 +90,27 @@ class TestStrategyTest(test.TestCase): self.assertIs(dist, distribute.get_cross_tower_context()) self.assertTrue(distribute.has_distribution_strategy()) self.assertIs(dist, distribute.get_distribution_strategy()) - self.assertEqual("baz", variable_scope.variable(1.0, name="baz")) + expected_value = _get_test_variable( + "baz", variable_scope.VariableSynchronization.AUTO, + variable_scope.VariableAggregation.NONE) + self.assertDictEqual(expected_value, + variable_scope.variable(1.0, name="baz")) + _assert_in_default_state(self) + + def testSettingSynchronizationAndAggregation(self): + _assert_in_default_state(self) + dist = _TestStrategy() + with dist.scope(): + expected_value = _get_test_variable( + "baz", variable_scope.VariableSynchronization.ON_WRITE, + variable_scope.VariableAggregation.MEAN) + self.assertDictEqual( + expected_value, + variable_scope.variable( + 1.0, + name="baz", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN)) _assert_in_default_state(self) diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py index 6caf29d83af546f821314179e17f7bf1a693ff1a..ef50f6315dd623647e000b9b713d3ae557c31427 100644 --- a/tensorflow/python/training/gradient_descent.py +++ b/tensorflow/python/training/gradient_descent.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -41,6 +40,13 @@ class GradientDescentOptimizer(optimizer.Optimizer): use_locking: If True use locks for update operations. name: Optional name prefix for the operations created when applying gradients. Defaults to "GradientDescent". + + @compatibility(eager) + When eager execution is enabled, `learning_rate` can be a callable that + takes no arguments and returns the actual value to use. This can be useful + for changing these values across different invocations of optimizer + functions. + @end_compatibility """ super(GradientDescentOptimizer, self).__init__(use_locking, name) self._learning_rate = learning_rate @@ -71,6 +77,6 @@ class GradientDescentOptimizer(optimizer.Optimizer): return var.scatter_sub(delta, use_locking=self._use_locking) def _prepare(self): - if not context.executing_eagerly() or self._learning_rate_tensor is None: - self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate, - name="learning_rate") + learning_rate = self._call_if_callable(self._learning_rate) + self._learning_rate_tensor = ops.convert_to_tensor( + learning_rate, name="learning_rate") diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py index 5370cafbcfab6e5ea46685db997989bf6f218a1a..b304e924212c49d84b7c85e01869603b47fc1222 100644 --- a/tensorflow/python/training/gradient_descent_test.py +++ b/tensorflow/python/training/gradient_descent_test.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -80,6 +83,32 @@ class GradientDescentOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval()) + def testBasicCallableParams(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + lr = lambda: 3.0 + sgd_op = gradient_descent.GradientDescentOptimizer(lr).apply_gradients( + zip([grads0, grads1], [var0, var1])) + # TODO(apassos) calling initialize_resources on all resources here + # doesn't work because the sessions and graph are reused across unit + # tests and this would mean trying to reinitialize variables. Figure out + # a long-term solution for this. + resources.initialize_resources([var0, var1]).run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + def testMinimizeResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.test_session(): @@ -218,6 +247,26 @@ class GradientDescentOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]], var1.eval()) + def testCapturingInDefunWhileExecutingEagerly(self): + with context.eager_mode(): + optimizer = gradient_descent.GradientDescentOptimizer(1.0) + + def step(): + v = resource_variable_ops.ResourceVariable(1.0) + with backprop.GradientTape() as tape: + loss = v ** 2 + grad = tape.gradient(loss, v) + optimizer.apply_gradients([(grad, v)]) + return v.read_value() + + compiled_step = function.defun(step) + + self.assertEqual(float(step()), -1.0) + self.assertEqual(float(compiled_step()), -1.0) + # This shouldn't fail; in particular, the learning rate tensor should + # be an EagerTensor once again, not a graph Tensor. + self.assertEqual(float(step()), -1.0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py index 10ab4c1137ff226d88902143d4f2281ad77de531..51190264e81ad177c56a6864b616aee52d954c43 100644 --- a/tensorflow/python/training/learning_rate_decay.py +++ b/tensorflow/python/training/learning_rate_decay.py @@ -19,6 +19,7 @@ from __future__ import print_function import math +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -87,6 +88,12 @@ def exponential_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for exponential_decay.") @@ -95,14 +102,22 @@ def exponential_decay(learning_rate, [learning_rate, global_step, decay_steps, decay_rate]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) decay_rate = math_ops.cast(decay_rate, dtype) - p = global_step / decay_steps - if staircase: - p = math_ops.floor(p) - return math_ops.multiply( - learning_rate, math_ops.pow(decay_rate, p), name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + return math_ops.multiply( + learning_rate, math_ops.pow(decay_rate, p), name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.piecewise_constant") @@ -141,48 +156,62 @@ def piecewise_constant(x, boundaries, values, name=None): ValueError: if types of `x` and `boundaries` do not match, or types of all `values` do not match or the number of elements in the lists does not match. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if len(boundaries) != len(values) - 1: raise ValueError( "The length of boundaries should be 1 less than the length of values") with ops.name_scope(name, "PiecewiseConstant", [x, boundaries, values, name]) as name: - x = ops.convert_to_tensor(x) - # Avoid explicit conversion to x's dtype. This could result in faulty - # comparisons, for example if floats are converted to integers. boundaries = ops.convert_n_to_tensor(boundaries) - for i, b in enumerate(boundaries): - if b.dtype.base_dtype != x.dtype.base_dtype: - # We can promote int32 boundaries to int64 without loss of precision. - # This covers the most common case where the user passes in boundaries - # as an array of Python integers. - if (b.dtype.base_dtype == dtypes.int32 and - x.dtype.base_dtype == dtypes.int64): - b = math_ops.cast(b, x.dtype.base_dtype) - boundaries[i] = b - else: - raise ValueError( - "Boundaries (%s) must have the same dtype as x (%s)." % - (b.dtype.base_dtype, x.dtype.base_dtype)) - # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing. values = ops.convert_n_to_tensor(values) - for v in values[1:]: - if v.dtype.base_dtype != values[0].dtype.base_dtype: - raise ValueError( - "Values must have elements all with the same dtype (%s vs %s)." % - (values[0].dtype.base_dtype, v.dtype.base_dtype)) - pred_fn_pairs = [] - pred_fn_pairs.append((x <= boundaries[0], lambda: values[0])) - pred_fn_pairs.append((x > boundaries[-1], lambda: values[-1])) - for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]): - # Need to bind v here; can do this with lambda v=v: ... - pred = (x > low) & (x <= high) - pred_fn_pairs.append((pred, lambda v=v: v)) - - # The default isn't needed here because our conditions are mutually - # exclusive and exhaustive, but tf.case requires it. - default = lambda: values[0] - return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + x_recomp = ops.convert_to_tensor(x) + # Avoid explicit conversion to x's dtype. This could result in faulty + # comparisons, for example if floats are converted to integers. + for i, b in enumerate(boundaries): + if b.dtype.base_dtype != x_recomp.dtype.base_dtype: + # We can promote int32 boundaries to int64 without loss of precision. + # This covers the most common case where the user passes in boundaries + # as an array of Python integers. + if (b.dtype.base_dtype == dtypes.int32 and + x_recomp.dtype.base_dtype == dtypes.int64): + b = math_ops.cast(b, x_recomp.dtype.base_dtype) + boundaries[i] = b + else: + raise ValueError( + "Boundaries (%s) must have the same dtype as x (%s)." % + (b.dtype.base_dtype, x_recomp.dtype.base_dtype)) + # TODO(rdipietro): Ensure that boundaries' elements strictly increases. + for v in values[1:]: + if v.dtype.base_dtype != values[0].dtype.base_dtype: + raise ValueError( + "Values must have elements all with the same dtype (%s vs %s)." % + (values[0].dtype.base_dtype, v.dtype.base_dtype)) + pred_fn_pairs = [] + pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0])) + pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1])) + for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]): + # Need to bind v here; can do this with lambda v=v: ... + pred = (x_recomp > low) & (x_recomp <= high) + pred_fn_pairs.append((pred, lambda v=v: v)) + + # The default isn't needed here because our conditions are mutually + # exclusive and exhaustive, but tf.case requires it. + default = lambda: values[0] + return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.polynomial_decay") @@ -263,6 +292,12 @@ def polynomial_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for polynomial_decay.") @@ -272,27 +307,35 @@ def polynomial_decay(learning_rate, ]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) - decay_steps = math_ops.cast(decay_steps, dtype) end_learning_rate = math_ops.cast(end_learning_rate, dtype) power = math_ops.cast(power, dtype) - if cycle: - # Find the first multiple of decay_steps that is bigger than global_step. - # If global_step is zero set the multiplier to 1 - multiplier = control_flow_ops.cond( - math_ops.equal(global_step, 0), lambda: 1.0, - lambda: math_ops.ceil(global_step / decay_steps)) - decay_steps = math_ops.multiply(decay_steps, multiplier) - else: - # Make sure that the global_step used is not bigger than decay_steps. - global_step = math_ops.minimum(global_step, decay_steps) - - p = math_ops.div(global_step, decay_steps) - return math_ops.add( - math_ops.multiply(learning_rate - end_learning_rate, - math_ops.pow(1 - p, power)), - end_learning_rate, - name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + decay_steps_recomp = math_ops.cast(decay_steps, dtype) + if cycle: + # Find the first multiple of decay_steps that is bigger than + # global_step. If global_step is zero set the multiplier to 1 + multiplier = control_flow_ops.cond( + math_ops.equal(global_step_recomp, 0), lambda: 1.0, + lambda: math_ops.ceil(global_step_recomp / decay_steps)) + decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier) + else: + # Make sure that the global_step used is not bigger than decay_steps. + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + + p = math_ops.div(global_step_recomp, decay_steps_recomp) + return math_ops.add( + math_ops.multiply(learning_rate - end_learning_rate, + math_ops.pow(1 - p, power)), + end_learning_rate, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.natural_exp_decay") @@ -350,6 +393,12 @@ def natural_exp_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for natural_exp_decay.") @@ -357,14 +406,23 @@ def natural_exp_decay(learning_rate, [learning_rate, global_step, decay_rate]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) decay_rate = math_ops.cast(decay_rate, dtype) - p = global_step / decay_steps - if staircase: - p = math_ops.floor(p) - exponent = math_ops.exp(math_ops.multiply(math_ops.negative(decay_rate), p)) - return math_ops.multiply(learning_rate, exponent, name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + exponent = math_ops.exp( + math_ops.multiply(math_ops.negative(decay_rate), p)) + return math_ops.multiply(learning_rate, exponent, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.inverse_time_decay") @@ -432,6 +490,12 @@ def inverse_time_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for inverse_time_decay.") @@ -439,15 +503,23 @@ def inverse_time_decay(learning_rate, [learning_rate, global_step, decay_rate]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) decay_rate = math_ops.cast(decay_rate, dtype) - p = global_step / decay_steps - if staircase: - p = math_ops.floor(p) - const = math_ops.cast(constant_op.constant(1), learning_rate.dtype) - denom = math_ops.add(const, math_ops.multiply(decay_rate, p)) - return math_ops.div(learning_rate, denom, name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + const = math_ops.cast(constant_op.constant(1), dtype) + denom = math_ops.add(const, math_ops.multiply(decay_rate, p)) + return math_ops.div(learning_rate, denom, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.cosine_decay") @@ -492,6 +564,12 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("cosine decay requires global_step") @@ -499,15 +577,23 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): [learning_rate, global_step]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) - global_step = math_ops.minimum(global_step, decay_steps) - completed_fraction = global_step / decay_steps - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction)) - decayed = (1 - alpha) * cosine_decayed + alpha - return math_ops.multiply(learning_rate, decayed) + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + completed_fraction = global_step_recomp / decay_steps + cosine_decayed = 0.5 * (1.0 + math_ops.cos( + constant_op.constant(math.pi) * completed_fraction)) + + decayed = (1 - alpha) * cosine_decayed + alpha + return math_ops.multiply(learning_rate, decayed) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.cosine_decay_restarts") @@ -561,6 +647,12 @@ def cosine_decay_restarts(learning_rate, learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("cosine decay restarts requires global_step") @@ -568,40 +660,48 @@ def cosine_decay_restarts(learning_rate, learning_rate = ops.convert_to_tensor( learning_rate, name="initial_learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) first_decay_steps = math_ops.cast(first_decay_steps, dtype) alpha = math_ops.cast(alpha, dtype) t_mul = math_ops.cast(t_mul, dtype) m_mul = math_ops.cast(m_mul, dtype) - completed_fraction = global_step / first_decay_steps + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + completed_fraction = global_step_recomp / first_decay_steps - def compute_step(completed_fraction, geometric=False): - if geometric: - i_restart = math_ops.floor( - math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / - math_ops.log(t_mul)) + def compute_step(completed_fraction, geometric=False): + """Helper for `cond` operation.""" + if geometric: + i_restart = math_ops.floor( + math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / + math_ops.log(t_mul)) - sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) - completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart + sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) + completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart - else: - i_restart = math_ops.floor(completed_fraction) - completed_fraction = completed_fraction - i_restart + else: + i_restart = math_ops.floor(completed_fraction) + completed_fraction -= i_restart + + return i_restart, completed_fraction - return i_restart, completed_fraction + i_restart, completed_fraction = control_flow_ops.cond( + math_ops.equal(t_mul, 1.0), + lambda: compute_step(completed_fraction, geometric=False), + lambda: compute_step(completed_fraction, geometric=True)) - i_restart, completed_fraction = control_flow_ops.cond( - math_ops.equal(t_mul, 1.0), - lambda: compute_step(completed_fraction, geometric=False), - lambda: compute_step(completed_fraction, geometric=True)) + m_fac = m_mul**i_restart + cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos( + constant_op.constant(math.pi) * completed_fraction)) + decayed = (1 - alpha) * cosine_decayed + alpha - m_fac = m_mul**i_restart - cosine_decayed = 0.5 * m_fac * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction)) - decayed = (1 - alpha) * cosine_decayed + alpha + return math_ops.multiply(learning_rate, decayed, name=name) - return math_ops.multiply(learning_rate, decayed, name=name) + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.linear_cosine_decay") @@ -664,6 +764,12 @@ def linear_cosine_decay(learning_rate, learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("linear cosine decay requires global_step") @@ -671,21 +777,28 @@ def linear_cosine_decay(learning_rate, [learning_rate, global_step]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) num_periods = math_ops.cast(num_periods, dtype) - global_step = math_ops.minimum(global_step, decay_steps) alpha = math_ops.cast(alpha, dtype) beta = math_ops.cast(beta, dtype) - linear_decayed = (decay_steps - global_step) / decay_steps - completed_fraction = global_step / decay_steps - fraction = 2.0 * num_periods * completed_fraction - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + linear_decayed = (decay_steps - global_step_recomp) / decay_steps + completed_fraction = global_step_recomp / decay_steps + fraction = 2.0 * num_periods * completed_fraction + cosine_decayed = 0.5 * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + + linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta + return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) - linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta - return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.noisy_linear_cosine_decay") @@ -756,6 +869,12 @@ def noisy_linear_cosine_decay(learning_rate, learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("noisy linear cosine decay requires global_step") @@ -763,29 +882,36 @@ def noisy_linear_cosine_decay(learning_rate, [learning_rate, global_step]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) - global_step = math_ops.minimum(global_step, decay_steps) initial_variance = math_ops.cast(initial_variance, dtype) variance_decay = math_ops.cast(variance_decay, dtype) num_periods = math_ops.cast(num_periods, dtype) alpha = math_ops.cast(alpha, dtype) beta = math_ops.cast(beta, dtype) - linear_decayed = (decay_steps - global_step) / decay_steps - variance = initial_variance / ( - math_ops.pow(1.0 + global_step, variance_decay)) - std = math_ops.sqrt(variance) - noisy_linear_decayed = ( - linear_decayed + - random_ops.random_normal(linear_decayed.shape, stddev=std)) - - completed_fraction = global_step / decay_steps - fraction = 2.0 * num_periods * completed_fraction - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) - noisy_linear_cosine_decayed = ( - (alpha + noisy_linear_decayed) * cosine_decayed + beta) - - return math_ops.multiply( - learning_rate, noisy_linear_cosine_decayed, name=name) + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + linear_decayed = (decay_steps - global_step_recomp) / decay_steps + variance = initial_variance / ( + math_ops.pow(1.0 + global_step_recomp, variance_decay)) + std = math_ops.sqrt(variance) + noisy_linear_decayed = ( + linear_decayed + random_ops.random_normal( + linear_decayed.shape, stddev=std)) + + completed_fraction = global_step_recomp / decay_steps + fraction = 2.0 * num_periods * completed_fraction + cosine_decayed = 0.5 * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + noisy_linear_cosine_decayed = ( + (alpha + noisy_linear_decayed) * cosine_decayed + beta) + + return math_ops.multiply( + learning_rate, noisy_linear_cosine_decayed, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py index 60306e4f1239a759ea1f68492a1211d5f0858997..4f3cf01822c5b56c8fd05f859c3a1db302a57625 100644 --- a/tensorflow/python/training/learning_rate_decay_test.py +++ b/tensorflow/python/training/learning_rate_decay_test.py @@ -21,12 +21,9 @@ from __future__ import print_function import math from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util -from tensorflow.python.ops import gen_state_ops # Import resource_variable_ops for the variables-to-tensor implicit conversion. from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.training import learning_rate_decay @@ -34,31 +31,35 @@ from tensorflow.python.training import learning_rate_decay class LRDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testContinuous(self): - with self.test_session(): - step = 5 - decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96) - expected = .05 * 0.96 ** (5.0 / 10.0) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + self.evaluate(variables.global_variables_initializer()) + step = 5 + decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96) + expected = .05 * 0.96**(5.0 / 10.0) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testStaircase(self): - with self.test_session(): - step = gen_state_ops.variable(shape=[], dtype=dtypes.int32, - name="step", container="", shared_name="") - assign_100 = state_ops.assign(step, 100) - assign_1 = state_ops.assign(step, 1) - assign_2 = state_ops.assign(step, 2) - decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96, - staircase=True) - # No change to learning rate - assign_1.op.run() - self.assertAllClose(decayed_lr.eval(), .1, 1e-6) - assign_2.op.run() - self.assertAllClose(decayed_lr.eval(), .1, 1e-6) + if context.executing_eagerly(): + step = resource_variable_ops.ResourceVariable(0) + self.evaluate(variables.global_variables_initializer()) + decayed_lr = learning_rate_decay.exponential_decay( + .1, step, 3, 0.96, staircase=True) + + # No change to learning rate due to staircase + expected = .1 + self.evaluate(step.assign(1)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + expected = .1 + self.evaluate(step.assign(2)) + self.assertAllClose(self.evaluate(decayed_lr), .1, 1e-6) + # Decayed learning rate - assign_100.op.run() expected = .1 * 0.96 ** (100 // 3) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + self.evaluate(step.assign(100)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) def testVariables(self): with self.test_session(): @@ -79,38 +80,44 @@ class LRDecayTest(test_util.TensorFlowTestCase): expected = .1 * 0.96 ** (100 // 3) self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPiecewiseConstant(self): x = resource_variable_ops.ResourceVariable(-999) - def pc(): - return learning_rate_decay.piecewise_constant(x, [100, 110, 120], - [1.0, 0.1, 0.01, 0.001]) + decayed_lr = learning_rate_decay.piecewise_constant( + x, [100, 110, 120], [1.0, 0.1, 0.01, 0.001]) self.evaluate(variables.global_variables_initializer()) - self.assertAllClose(self.evaluate(pc()), 1.0, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 1.0, 1e-6) self.evaluate(x.assign(100)) - self.assertAllClose(self.evaluate(pc()), 1.0, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 1.0, 1e-6) self.evaluate(x.assign(105)) - self.assertAllClose(self.evaluate(pc()), 0.1, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.1, 1e-6) self.evaluate(x.assign(110)) - self.assertAllClose(self.evaluate(pc()), 0.1, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.1, 1e-6) self.evaluate(x.assign(120)) - self.assertAllClose(self.evaluate(pc()), 0.01, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.01, 1e-6) self.evaluate(x.assign(999)) - self.assertAllClose(self.evaluate(pc()), 0.001, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.001, 1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPiecewiseConstantEdgeCases(self): x_int = resource_variable_ops.ResourceVariable( 0, dtype=variables.dtypes.int32) boundaries, values = [-1.0, 1.0], [1, 2, 3] with self.assertRaises(ValueError): - learning_rate_decay.piecewise_constant(x_int, boundaries, values) + decayed_lr = learning_rate_decay.piecewise_constant( + x_int, boundaries, values) + if context.executing_eagerly(): + decayed_lr() + x = resource_variable_ops.ResourceVariable(0.0) boundaries, values = [-1.0, 1.0], [1.0, 2, 3] with self.assertRaises(ValueError): - learning_rate_decay.piecewise_constant(x, boundaries, values) + decayed_lr = learning_rate_decay.piecewise_constant( + x, boundaries, values) + if context.executing_eagerly(): + decayed_lr() # Test that ref types are valid. if not context.executing_eagerly(): @@ -123,221 +130,205 @@ class LRDecayTest(test_util.TensorFlowTestCase): x_int64 = resource_variable_ops.ResourceVariable( 0, dtype=variables.dtypes.int64) boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7] - def pc(): - return learning_rate_decay.piecewise_constant(x_int64, boundaries, values) + decayed_lr = learning_rate_decay.piecewise_constant( + x_int64, boundaries, values) self.evaluate(variables.global_variables_initializer()) - self.assertAllClose(self.evaluate(pc()), 0.4, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.4, 1e-6) self.evaluate(x_int64.assign(1)) - self.assertAllClose(self.evaluate(pc()), 0.4, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.4, 1e-6) self.evaluate(x_int64.assign(2)) - self.assertAllClose(self.evaluate(pc()), 0.5, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.5, 1e-6) self.evaluate(x_int64.assign(3)) - self.assertAllClose(self.evaluate(pc()), 0.6, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.6, 1e-6) self.evaluate(x_int64.assign(4)) - self.assertAllClose(self.evaluate(pc()), 0.7, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.7, 1e-6) class LinearDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testHalfWay(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.0 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = lr * 0.5 - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.0 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = lr * 0.5 + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testEnd(self): - with self.test_session(): - step = 10 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 10 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testHalfWayWithEnd(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = (lr + end_lr) * 0.5 - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = (lr + end_lr) * 0.5 + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEnd(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEndWithCycle(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - cycle=True) - expected = (lr - end_lr) * 0.25 + end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, cycle=True) + expected = (lr - end_lr) * 0.25 + end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class SqrtDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testHalfWay(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.0 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = lr * 0.5 ** power - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.0 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = lr * 0.5**power + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testEnd(self): - with self.test_session(): - step = 10 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 10 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testHalfWayWithEnd(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = (lr - end_lr) * 0.5 ** power + end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = (lr - end_lr) * 0.5**power + end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEnd(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEndWithCycle(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power, cycle=True) - expected = (lr - end_lr) * 0.25 ** power + end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power, cycle=True) + expected = (lr - end_lr) * 0.25**power + end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class PolynomialDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testBeginWithCycle(self): - with self.test_session(): - lr = 0.001 - decay_steps = 10 - step = 0 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, - decay_steps, cycle=True) - expected = lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + lr = 0.001 + decay_steps = 10 + step = 0 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, decay_steps, cycle=True) + expected = lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class ExponentialDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testDecay(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step, - k, decay_rate) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr * math.exp(-i / k * decay_rate) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step, k, + decay_rate) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr * math.exp(-i / k * decay_rate) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) + @test_util.run_in_graph_and_eager_modes def testStaircase(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, - step, - k, - decay_rate, - staircase=True) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr * math.exp(-decay_rate * (i // k)) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.natural_exp_decay( + initial_lr, step, k, decay_rate, staircase=True) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr * math.exp(-decay_rate * (i // k)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) class InverseDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testDecay(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, - step, - k, + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, step, k, decay_rate) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr / (1 + i / k * decay_rate) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr / (1 + i / k * decay_rate) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) + + @test_util.run_in_graph_and_eager_modes def testStaircase(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, - step, - k, - decay_rate, - staircase=True) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr / (1 + decay_rate * (i // k)) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.inverse_time_decay( + initial_lr, step, k, decay_rate, staircase=True) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr / (1 + decay_rate * (i // k)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) class CosineDecayTest(test_util.TensorFlowTestCase): @@ -348,34 +339,35 @@ class CosineDecayTest(test_util.TensorFlowTestCase): decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) return (1.0 - alpha) * decay + alpha + @test_util.run_in_graph_and_eager_modes def testDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay( - initial_lr, step, num_training_steps) - expected = self.np_cosine_decay(step, num_training_steps) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay(initial_lr, step, + num_training_steps) + expected = self.np_cosine_decay(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testAlpha(self): num_training_steps = 1000 initial_lr = 1.0 alpha = 0.1 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay( - initial_lr, step, num_training_steps, alpha) - expected = self.np_cosine_decay(step, num_training_steps, alpha) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay(initial_lr, step, + num_training_steps, alpha) + expected = self.np_cosine_decay(step, num_training_steps, alpha) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class CosineDecayRestartsTest(test_util.TensorFlowTestCase): + def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0, alpha=0.0): fac = 1.0 while step >= decay_steps: - step = step - decay_steps + step -= decay_steps decay_steps *= t_mul fac *= m_mul @@ -383,51 +375,51 @@ class CosineDecayRestartsTest(test_util.TensorFlowTestCase): decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) return (1.0 - alpha) * decay + alpha + @test_util.run_in_graph_and_eager_modes def testDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps) - expected = self.np_cosine_decay_restarts(step, num_training_steps) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps) + expected = self.np_cosine_decay_restarts(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testAlpha(self): num_training_steps = 1000 initial_lr = 1.0 alpha = 0.1 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps, alpha=alpha) - expected = self.np_cosine_decay_restarts(step, num_training_steps, - alpha=alpha) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps, alpha=alpha) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, alpha=alpha) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testMMul(self): num_training_steps = 1000 initial_lr = 1.0 m_mul = 0.9 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps, m_mul=m_mul) - expected = self.np_cosine_decay_restarts(step, num_training_steps, - m_mul=m_mul) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps, m_mul=m_mul) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, m_mul=m_mul) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testTMul(self): num_training_steps = 1000 initial_lr = 1.0 t_mul = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps, t_mul=t_mul) - expected = self.np_cosine_decay_restarts(step, num_training_steps, - t_mul=t_mul) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps, t_mul=t_mul) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, t_mul=t_mul) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class LinearCosineDecayTest(test_util.TensorFlowTestCase): @@ -444,65 +436,63 @@ class LinearCosineDecayTest(test_util.TensorFlowTestCase): cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction)) return (alpha + linear_decayed) * cosine_decayed + beta + @test_util.run_in_graph_and_eager_modes def testDefaultDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.linear_cosine_decay( - initial_lr, step, num_training_steps) - expected = self.np_linear_cosine_decay(step, num_training_steps) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.linear_cosine_decay( + initial_lr, step, num_training_steps) + expected = self.np_linear_cosine_decay(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testNonDefaultDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.linear_cosine_decay( - initial_lr, - step, - num_training_steps, - alpha=0.1, - beta=1e-4, - num_periods=5) - expected = self.np_linear_cosine_decay( - step, - num_training_steps, - alpha=0.1, - beta=1e-4, - num_periods=5) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.linear_cosine_decay( + initial_lr, + step, + num_training_steps, + alpha=0.1, + beta=1e-4, + num_periods=5) + expected = self.np_linear_cosine_decay( + step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class NoisyLinearCosineDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testDefaultNoisyLinearCosine(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - # No numerical check because of noise - decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( - initial_lr, step, num_training_steps) - decayed_lr.eval() + # No numerical check because of noise + decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( + initial_lr, step, num_training_steps) + # Cannot be deterministically tested + self.evaluate(decayed_lr) + @test_util.run_in_graph_and_eager_modes def testNonDefaultNoisyLinearCosine(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - # No numerical check because of noise - decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( - initial_lr, - step, - num_training_steps, - initial_variance=0.5, - variance_decay=0.1, - alpha=0.1, - beta=1e-4, - num_periods=5) - decayed_lr.eval() + # No numerical check because of noise + decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( + initial_lr, + step, + num_training_steps, + initial_variance=0.5, + variance_decay=0.1, + alpha=0.1, + beta=1e-4, + num_periods=5) + # Cannot be deterministically tested + self.evaluate(decayed_lr) if __name__ == "__main__": diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py index bd9fa79d8feac68c149f787ee8501bdddb173d33..cb3ec6f053e2e7f5aa80152ed233c8fbb6920be0 100644 --- a/tensorflow/python/training/momentum.py +++ b/tensorflow/python/training/momentum.py @@ -61,8 +61,8 @@ class MomentumOptimizer(optimizer.Optimizer): variable(s) track the values called `theta_t + mu*v_t` in the paper. @compatibility(eager) - When eager execution is enabled, learning_rate and momentum can each be a - callable that takes no arguments and returns the actual value to use. This + When eager execution is enabled, `learning_rate` and `momentum` can each be + a callable that takes no arguments and returns the actual value to use. This can be useful for changing these values across different invocations of optimizer functions. @end_compatibility diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py index 7bd57ad3d854534e196fa7b72bebbd7195e6bca8..f7e78071d8b15500ba607a5cd16aefcaf9d2abfe 100644 --- a/tensorflow/python/training/momentum_test.py +++ b/tensorflow/python/training/momentum_test.py @@ -134,7 +134,6 @@ class MomentumOptimizerTest(test.TestCase): with context.eager_mode(): self.doTestBasic(use_resource=True, use_callable_params=True) - @test_util.run_in_graph_and_eager_modes(reset_test=True) def testVariablesAcrossGraphs(self): optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5) with ops.Graph().as_default(): @@ -142,10 +141,7 @@ class MomentumOptimizerTest(test.TestCase): [1.0, 2.0], dtype=dtypes.float32, name="var0") var1 = resource_variable_ops.ResourceVariable( [3.0, 4.0], dtype=dtypes.float32, name="var1") - if context.executing_eagerly(): - loss = lambda: math_ops.reduce_sum(var0 + var1) - else: - loss = math_ops.reduce_sum(var0 + var1) + loss = math_ops.reduce_sum(var0 + var1) optimizer.minimize(loss) optimizer_variables = optimizer.variables() self.assertStartsWith(optimizer_variables[0].name, "var0") @@ -157,10 +153,7 @@ class MomentumOptimizerTest(test.TestCase): [1.0, 2.0], dtype=dtypes.float32, name="var2") var3 = resource_variable_ops.ResourceVariable( [3.0, 4.0], dtype=dtypes.float32, name="var3") - if context.executing_eagerly(): - loss = lambda: math_ops.reduce_sum(var2 + var3) - else: - loss = math_ops.reduce_sum(var2 + var3) + loss = math_ops.reduce_sum(var2 + var3) optimizer.minimize(loss) optimizer_variables = optimizer.variables() self.assertStartsWith(optimizer_variables[0].name, "var2") diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index fece3370f343173de46bc447c478264864708dca..7b06bffa4b29b92dd8d3df5d8eaa6ebec1ea44b1 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -298,7 +298,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name stop_grace_period_secs=120, log_step_count_steps=100, max_wait_secs=7200, - save_checkpoint_steps=USE_DEFAULT): + save_checkpoint_steps=USE_DEFAULT, + summary_dir=None): """Creates a `MonitoredSession` for training. For a chief, this utility sets proper session initializer/restorer. It also @@ -348,6 +349,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then the default checkpoint saver isn't used. If both are provided, then only `save_checkpoint_secs` is used. Default not enabled. + summary_dir: A string. Optional path to a directory where to + save summaries. If None, checkpoint_dir is used instead. Returns: A `MonitoredSession` object. @@ -388,11 +391,12 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name master=master, config=config) - if checkpoint_dir: + summary_dir = summary_dir or checkpoint_dir + if summary_dir: if log_step_count_steps and log_step_count_steps > 0: all_hooks.append( basic_session_run_hooks.StepCounterHook( - output_dir=checkpoint_dir, every_n_steps=log_step_count_steps)) + output_dir=summary_dir, every_n_steps=log_step_count_steps)) if (save_summaries_steps and save_summaries_steps > 0) or ( save_summaries_secs and save_summaries_secs > 0): @@ -400,7 +404,9 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name scaffold=scaffold, save_steps=save_summaries_steps, save_secs=save_summaries_secs, - output_dir=checkpoint_dir)) + output_dir=summary_dir)) + + if checkpoint_dir: if (save_checkpoint_secs and save_checkpoint_secs > 0) or ( save_checkpoint_steps and save_checkpoint_steps > 0): all_hooks.append(basic_session_run_hooks.CheckpointSaverHook( diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 61fc828a840c490b0f787119134a0941f60f947a..60cc54c2645a0f44195bbb86013e0306387aa8aa 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -344,6 +344,11 @@ class ExponentialMovingAverage(object): self._name = name self._averages = {} + @property + def name(self): + """The name of this ExponentialMovingAverage object.""" + return self._name + def apply(self, var_list=None): """Maintains moving averages of variables. @@ -394,7 +399,7 @@ class ExponentialMovingAverage(object): if isinstance(var, variables.Variable): avg = slot_creator.create_slot(var, var.initialized_value(), - self._name, + self.name, colocate_with_primary=True) # NOTE(mrry): We only add `tf.Variable` objects to the # `MOVING_AVERAGE_VARIABLES` collection. @@ -402,7 +407,7 @@ class ExponentialMovingAverage(object): else: avg = slot_creator.create_zeros_slot( var, - self._name, + self.name, colocate_with_primary=(var.op.type in ["Variable", "VariableV2", "VarHandleOp"])) @@ -410,7 +415,7 @@ class ExponentialMovingAverage(object): zero_debias_true.add(avg) self._averages[var] = avg - with ops.name_scope(self._name) as scope: + with ops.name_scope(self.name) as scope: decay = ops.convert_to_tensor(self._decay, name="decay") if self._num_updates is not None: num_updates = math_ops.cast(self._num_updates, @@ -462,7 +467,7 @@ class ExponentialMovingAverage(object): if var in self._averages: return self._averages[var].op.name return ops.get_default_graph().unique_name( - var.op.name + "/" + self._name, mark_as_used=False) + var.op.name + "/" + self.name, mark_as_used=False) def variables_to_restore(self, moving_avg_variables=None): """Returns a map of names to `Variables` to restore. diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index 6717811bbb0f05723a5ad0fbcbfba75249d0d43b..3e85e6bfa7b20da061d36ab0fba4913e402f8f0c 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -263,6 +263,7 @@ class ExponentialMovingAverageTest(test.TestCase): tensor2 = v0 + v1 ema = moving_averages.ExponentialMovingAverage( 0.25, zero_debias=zero_debias, name="foo") + self.assertEqual("foo", ema.name) self.assertEqual("v0/foo", ema.average_name(v0)) self.assertEqual("v1/foo", ema.average_name(v1)) self.assertEqual("add/foo", ema.average_name(tensor2)) diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index a9287a0f0d0391cc6e0b297cce18eebaf9f64291..971ed5c8b5ed3bd78b0d467e5c3fa4b7a72c96a1 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -461,7 +461,8 @@ class Optimizer( # Have to be careful to call distribute_lib.get_loss_reduction() # *after* loss() is evaluated, so we know what loss reduction it uses. # TODO(josh11b): Test that we handle weight decay in a reasonable way. - if distribute_lib.get_loss_reduction() == "mean": + if (distribute_lib.get_loss_reduction() == + variable_scope.VariableAggregation.MEAN): num_towers = distribute_lib.get_distribution_strategy().num_towers if num_towers > 1: loss_value *= (1. / num_towers) @@ -478,7 +479,8 @@ class Optimizer( "be a function when eager execution is enabled.") # Scale loss if using a "mean" loss reduction and multiple towers. - if distribute_lib.get_loss_reduction() == "mean": + if (distribute_lib.get_loss_reduction() == + variable_scope.VariableAggregation.MEAN): num_towers = distribute_lib.get_distribution_strategy().num_towers if num_towers > 1: loss *= (1. / num_towers) @@ -649,7 +651,8 @@ class Optimizer( towers. If `global_step` was not None, that operation also increments `global_step`. """ - reduced_grads = distribution.batch_reduce("sum", grads_and_vars) + reduced_grads = distribution.batch_reduce( + variable_scope.VariableAggregation.SUM, grads_and_vars) var_list = [v for _, v in grads_and_vars] grads_and_vars = zip(reduced_grads, var_list) # Note that this is called in a cross-tower context. @@ -730,15 +733,15 @@ class Optimizer( if not named_slots: return None - if hasattr(var, "_mirrored_container"): + if hasattr(var, "_distributed_container"): # NOTE: If this isn't patched, then there is no `handle` in # `_resource_apply_dense`. - mirrored_container = var._mirrored_container() - assert mirrored_container is not None + distributed_container = var._distributed_container() + assert distributed_container is not None if context.executing_eagerly(): - key = mirrored_container._unique_id + key = distributed_container._unique_id else: - key = (mirrored_container.graph, mirrored_container._shared_name) + key = (distributed_container.graph, distributed_container._shared_name) # pylint: enable=protected-access mirrored_slot = named_slots.get(key, None) if mirrored_slot is None: return None @@ -839,7 +842,7 @@ class Optimizer( def _get_non_slot_variable(self, name, graph=None): non_slot = self._non_slot_dict.get((name, graph), None) - if hasattr(non_slot, "_mirrored_container"): + if hasattr(non_slot, "_distributed_container"): # This is a mirrored non-slot. In order to enable code like `_finish` # to assign to a non-slot, return the current context replica. return non_slot.get() @@ -1211,3 +1214,7 @@ class Optimizer( self._deferred_slot_restorations.setdefault( slot_name, {}).setdefault(variable_key, []).append( slot_variable_position) + + def _call_if_callable(self, param): + """Call the function if param is callable.""" + return param() if callable(param) else param diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py index 0cab6410e83ca1880a0a4a80d2cfa5c17517af95..dfe9176beaf27f3cfa945eee8693ba7c5e9551fa 100644 --- a/tensorflow/python/training/optimizer_test.py +++ b/tensorflow/python/training/optimizer_test.py @@ -34,7 +34,7 @@ from tensorflow.python.training import gradient_descent class OptimizerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBasic(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -112,7 +112,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], var1.eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoVariables(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: # pylint: disable=cell-var-from-loop @@ -127,7 +127,7 @@ class OptimizerTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'No.*variables'): sgd_op.minimize(loss) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -145,7 +145,7 @@ class OptimizerTest(test.TestCase): # var1 has no gradient sgd_op.minimize(loss, var_list=[var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_Minimize(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -161,7 +161,7 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.minimize(loss, var_list=[var0, var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_ApplyGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -175,7 +175,7 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.apply_gradients([(None, var0), (None, var1)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientsAsVariables(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -215,7 +215,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([-14., -13.], self.evaluate(var0)) self.assertAllClose([-6., -5.], self.evaluate(var1)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testComputeGradientsWithTensors(self): x = ops.convert_to_tensor(1.0) def f(): diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py index 341b970c92e42b4fe392d91f57219d713d2513e5..f38c9861d64aa258cde07ccd3041d3c50932c33b 100644 --- a/tensorflow/python/training/rmsprop.py +++ b/tensorflow/python/training/rmsprop.py @@ -92,6 +92,13 @@ class RMSPropOptimizer(optimizer.Optimizer): computation and memory. Defaults to False. name: Optional name prefix for the operations created when applying gradients. Defaults to "RMSProp". + + @compatibility(eager) + When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and + `epsilon` can each be a callable that takes no arguments and returns the + actual value to use. This can be useful for changing these values across + different invocations of optimizer functions. + @end_compatibility """ super(RMSPropOptimizer, self).__init__(use_locking, name) self._learning_rate = learning_rate @@ -120,12 +127,15 @@ class RMSPropOptimizer(optimizer.Optimizer): self._zeros_slot(v, "momentum", self._name) def _prepare(self): - self._learning_rate_tensor = ops.convert_to_tensor( - self._learning_rate, name="learning_rate") - self._decay_tensor = ops.convert_to_tensor(self._decay, name="decay") - self._momentum_tensor = ops.convert_to_tensor( - self._momentum, name="momentum") - self._epsilon_tensor = ops.convert_to_tensor(self._epsilon, name="epsilon") + lr = self._call_if_callable(self._learning_rate) + decay = self._call_if_callable(self._decay) + momentum = self._call_if_callable(self._momentum) + epsilon = self._call_if_callable(self._epsilon) + + self._learning_rate_tensor = ops.convert_to_tensor(lr, name="learning_rate") + self._decay_tensor = ops.convert_to_tensor(decay, name="decay") + self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum") + self._epsilon_tensor = ops.convert_to_tensor(epsilon, name="epsilon") def _apply_dense(self, grad, var): rms = self.get_slot(var, "rms") diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py index ee5385596c8b11e607969f94153f7e4f5d2d4cdd..604332738456bfc8b3ff24242f6032bf95273072 100644 --- a/tensorflow/python/training/rmsprop_test.py +++ b/tensorflow/python/training/rmsprop_test.py @@ -24,6 +24,7 @@ import math import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -141,7 +142,7 @@ class RMSPropOptimizerTest(test.TestCase): self.assertAllClose([3.0, 4.0], var1.eval()) # Run 4 steps of RMSProp - for t in range(1, 5): + for _ in range(1, 5): update.run() var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( @@ -261,7 +262,7 @@ class RMSPropOptimizerTest(test.TestCase): self.assertAllClose([3.0, 4.0], var1.eval()) # Run 4 steps of RMSProp - for t in range(1, 5): + for _ in range(1, 5): update.run() var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy( @@ -444,6 +445,55 @@ class RMSPropOptimizerTest(test.TestCase): (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))) ]), var1.eval()) + def testCallableParams(self): + with context.eager_mode(): + for dtype in [dtypes.half, dtypes.float32]: + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + + learning_rate = lambda: 2.0 + decay = lambda: 0.9 + momentum = lambda: 0.0 + epsilon = lambda: 1.0 + opt = rmsprop.RMSPropOptimizer(learning_rate, decay, momentum, epsilon) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + # Step 1: the rms accumulators where 1. So we should see a normal + # update: v -= grad * learning_rate + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) + ]), self.evaluate(var0)) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) + ]), self.evaluate(var1)) + # Step 2: the root mean square accumulators contain the previous update. + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)) + ]), self.evaluate(var0)) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)) + ]), self.evaluate(var1)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index e46c7f141397c6fd038896bff14605ad047cf57d..1ee975fbe48e8ba724d8f40040b122c5c02aa352 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -22,7 +22,6 @@ from __future__ import print_function import collections import os.path import re -import sys import time import uuid @@ -206,21 +205,19 @@ class BaseSaverBuilder(object): filename_tensor: String Tensor. saveables: List of BaseSaverBuilder.SaveableObject objects. preferred_shard: Int. Shard to open first when loading a sharded file. - restore_sequentially: Bool. If true, each restore is sequential. + restore_sequentially: Unused. Bool. If true, each restore is sequential. Returns: A list of Tensors resulting from reading 'saveable' from 'filename'. """ + del restore_sequentially all_tensors = [] - assign_ops = [] for saveable in saveables: - restore_control_inputs = assign_ops[-1:] if restore_sequentially else [] with ops.device(_set_cpu0(saveable.device) if saveable.device else None): - with ops.control_dependencies(restore_control_inputs): - all_tensors.extend( - self.restore_op(filename_tensor, saveable, preferred_shard)) + all_tensors.extend( + self.restore_op(filename_tensor, saveable, preferred_shard)) return all_tensors # pylint: disable=unused-argument @@ -1045,8 +1042,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None): ckpt = CheckpointState() text_format.Merge(file_content, ckpt) if not ckpt.model_checkpoint_path: - raise ValueError("Invalid checkpoint state loaded from %s", - checkpoint_dir) + raise ValueError("Invalid checkpoint state loaded from " + + checkpoint_dir) # For relative model_checkpoint_path and all_model_checkpoint_paths, # prepend checkpoint_dir. if not os.path.isabs(ckpt.model_checkpoint_path): @@ -1373,23 +1370,6 @@ class Saver(object): name, _ = p return name - def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"): - """Returns the meta graph filename. - - Args: - checkpoint_filename: Name of the checkpoint file. - meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. - - Returns: - MetaGraph file name. - """ - # If the checkpoint_filename is sharded, the checkpoint_filename could - # be of format model.ckpt-step#-?????-of-shard#. For example, - # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. - basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) - meta_graph_filename = ".".join([basename, meta_graph_suffix]) - return meta_graph_filename - def _RecordLastCheckpoint(self, latest_save_path): """Manages the list of the latest checkpoints.""" if not self.saver_def.max_to_keep: @@ -1430,24 +1410,12 @@ class Saver(object): # Otherwise delete the files. try: - checkpoint_prefix = self._CheckpointFilename(p) - self._delete_file_if_exists( - self._MetaGraphFilename(checkpoint_prefix, meta_graph_suffix)) - if self.saver_def.version == saver_pb2.SaverDef.V2: - # V2 has a metadata file and some data files. - self._delete_file_if_exists(checkpoint_prefix + ".index") - self._delete_file_if_exists(checkpoint_prefix + - ".data-?????-of-?????") - else: - # V1, Legacy. Exact match on the data file. - self._delete_file_if_exists(checkpoint_prefix) + remove_checkpoint( + self._CheckpointFilename(p), self.saver_def.version, + meta_graph_suffix) except Exception as e: # pylint: disable=broad-except logging.warning("Ignoring: %s", str(e)) - def _delete_file_if_exists(self, filespec): - for pathname in file_io.get_matching_files(filespec): - file_io.delete_file(pathname) - def as_saver_def(self): """Generates a `SaverDef` representation of this saver. @@ -1669,7 +1637,7 @@ class Saver(object): raise exc if write_meta_graph: - meta_graph_filename = self._MetaGraphFilename( + meta_graph_filename = _meta_graph_filename( checkpoint_file, meta_graph_suffix=meta_graph_suffix) if not context.executing_eagerly(): with sess.graph.as_default(): @@ -1737,12 +1705,17 @@ class Saver(object): save_path: Path where parameters were previously saved. Raises: - ValueError: If save_path is None. + ValueError: If save_path is None or not a valid checkpoint. """ if self._is_empty: return if save_path is None: raise ValueError("Can't load save_path when it is None.") + + if not checkpoint_exists(compat.as_text(save_path)): + raise ValueError("The passed save_path is not a valid checkpoint: " + + compat.as_text(save_path)) + logging.info("Restoring parameters from %s", compat.as_text(save_path)) try: if context.executing_eagerly(): @@ -1750,19 +1723,24 @@ class Saver(object): else: sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path}) - except errors.NotFoundError: - exception_type, exception_value, exception_traceback = sys.exc_info() - # The checkpoint would not be loaded successfully as is. Try to parse it - # as an object-based checkpoint. + except errors.NotFoundError as err: + # There are three common conditions that might cause this error: + # 0. The file is missing. We ignore here, as this is checked above. + # 1. This is an object-based checkpoint trying name-based loading. + # 2. The graph has been altered and a variable or other name is missing. + + # 1. The checkpoint would not be loaded successfully as is. Try to parse + # it as an object-based checkpoint. try: reader = pywrap_tensorflow.NewCheckpointReader(save_path) object_graph_string = reader.get_tensor( checkpointable.OBJECT_GRAPH_PROTO_KEY) except errors.NotFoundError: - # This is not an object-based checkpoint, or the checkpoint doesn't - # exist. Re-raise the original exception. - six.reraise(exception_type, exception_value, exception_traceback) - del exception_traceback # avoid reference cycles + # 2. This is not an object-based checkpoint, which likely means there + # is a graph mismatch. Re-raise the original error with + # a helpful message (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a Variable name or other graph key that is missing") # This is an object-based checkpoint. We'll print a warning and then do # the restore. @@ -1774,6 +1752,11 @@ class Saver(object): self._restore_from_object_based_checkpoint( sess=sess, save_path=save_path, object_graph_string=object_graph_string) + except errors.InvalidArgumentError as err: + # There is a mismatch between the graph and the checkpoint being loaded. + # We add a more reasonable error message here to help users (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a mismatch between the current graph and the graph") def _restore_from_object_based_checkpoint(self, sess, save_path, object_graph_string): @@ -1966,7 +1949,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False, return Saver(saver_def=meta_graph_def.saver_def, name=scope) else: - if variables._all_saveable_objects(): # pylint: disable=protected-access + if variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access # Return the default saver instance for all graph variables. return Saver() else: @@ -2117,6 +2100,63 @@ def get_checkpoint_mtimes(checkpoint_prefixes): return mtimes +@tf_export("train.remove_checkpoint") +def remove_checkpoint(checkpoint_prefix, + checkpoint_format_version=saver_pb2.SaverDef.V2, + meta_graph_suffix="meta"): + """Removes a checkpoint given by `checkpoint_prefix`. + + Args: + checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result + of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of + sharded/non-sharded or V1/V2. + checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to + `SaverDef.V2`. + meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. + """ + _delete_file_if_exists( + _meta_graph_filename(checkpoint_prefix, meta_graph_suffix)) + if checkpoint_format_version == saver_pb2.SaverDef.V2: + # V2 has a metadata file and some data files. + _delete_file_if_exists(checkpoint_prefix + ".index") + _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????") + else: + # V1, Legacy. Exact match on the data file. + _delete_file_if_exists(checkpoint_prefix) + + +def _delete_file_if_exists(filespec): + """Deletes files matching `filespec`.""" + for pathname in file_io.get_matching_files(filespec): + file_io.delete_file(pathname) + + +def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): + """Returns the meta graph filename. + + Args: + checkpoint_filename: Name of the checkpoint file. + meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. + + Returns: + MetaGraph file name. + """ + # If the checkpoint_filename is sharded, the checkpoint_filename could + # be of format model.ckpt-step#-?????-of-shard#. For example, + # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. + basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) + meta_graph_filename = ".".join([basename, meta_graph_suffix]) + return meta_graph_filename + + +def _wrap_restore_error_with_msg(err, extra_verbiage): + err_msg = ("Restoring from checkpoint failed. This is most likely " + "due to {} from the checkpoint. Please ensure that you " + "have not altered the graph expected based on the checkpoint. " + "Original error:\n\n{}").format(extra_verbiage, err.message) + return err.__class__(err.node_def, err.op, err_msg) + + ops.register_proto_function( ops.GraphKeys.SAVERS, proto_type=saver_pb2.SaverDef, diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index dd5174f17c27b0d4d7ae5ae6d3b365b1719a49fa..ae9c244aaf372dcbcf365cf3e6a21ae77d9ae7d0 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -51,8 +51,8 @@ from tensorflow.python.framework import graph_io from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import test_util -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -77,7 +77,8 @@ from tensorflow.python.training import saver as saver_module from tensorflow.python.training import saver_test_utils from tensorflow.python.training import training_util from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable_base +from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import compat @@ -169,7 +170,7 @@ class SaverTest(test.TestCase): def testBasic(self): self.basicSaveRestore(variables.Variable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResourceBasic(self): self.basicSaveRestore(resource_variable_ops.ResourceVariable) @@ -250,7 +251,7 @@ class SaverTest(test.TestCase): self.assertAllEqual(w3.eval(), 3.0) self.assertAllEqual(w4.eval(), 4.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResourceSaveRestoreCachingDevice(self): save_path = os.path.join(self.get_temp_dir(), "resource_cache") with self.test_session(graph=ops_lib.Graph()) as sess: @@ -366,8 +367,8 @@ class SaverTest(test.TestCase): for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): with self.test_session() as sess: save = saver_module.Saver({"v0": v0}, write_version=ver) - with self.assertRaisesRegexp(errors.NotFoundError, - "Failed to find any matching files for"): + with self.assertRaisesRegexp( + ValueError, "The passed save_path is not a valid checkpoint:"): save.restore(sess, "invalid path") def testInt64(self): @@ -669,7 +670,7 @@ class SaverTest(test.TestCase): save.restore(sess, save_path) self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveWithGlobalStep(self, pad_step_number=False): save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step") global_step_int = 5 @@ -807,7 +808,7 @@ class SaveRestoreShardedTest(test.TestCase): self.assertEqual(save_path + "-?????-of-00002", val) else: self.assertEqual(save_path, val) - meta_graph_filename = save._MetaGraphFilename(val) + meta_graph_filename = saver_module._meta_graph_filename(val) self.assertEqual(save_path + ".meta", meta_graph_filename) if save._write_version is saver_pb2.SaverDef.V1: @@ -1183,13 +1184,13 @@ class MaxToKeepTest(test.TestCase): self.assertEqual([s3, s2], save.last_checkpoints) self.assertFalse(saver_module.checkpoint_exists(s1)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) self.assertTrue(saver_module.checkpoint_exists(s3)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], @@ -1200,13 +1201,13 @@ class MaxToKeepTest(test.TestCase): self.assertEqual([s2, s1], save.last_checkpoints) self.assertFalse(saver_module.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) self.assertTrue(saver_module.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1220,14 +1221,14 @@ class MaxToKeepTest(test.TestCase): # Created by the first helper. self.assertTrue(saver_module.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) # Deleted by the first helper. self.assertFalse(saver_module.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], @@ -1238,13 +1239,13 @@ class MaxToKeepTest(test.TestCase): self.assertEqual([s2, s1], save2.last_checkpoints) self.assertFalse(saver_module.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) self.assertTrue(saver_module.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1258,14 +1259,14 @@ class MaxToKeepTest(test.TestCase): # Created by the first helper. self.assertTrue(saver_module.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) # Deleted by the first helper. self.assertFalse(saver_module.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) # Even though the file for s1 exists, this saver isn't aware of it, which # is why it doesn't end up in the checkpoint state. self.assertCheckpointState( @@ -1278,13 +1279,13 @@ class MaxToKeepTest(test.TestCase): self.assertEqual([s2, s1], save3.last_checkpoints) self.assertFalse(saver_module.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) self.assertTrue(saver_module.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1315,7 +1316,7 @@ class MaxToKeepTest(test.TestCase): else: self.assertEqual(4, len(gfile.Glob(s1 + "*"))) - self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1))) + self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1))) s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s1, s2], save.last_checkpoints) @@ -1323,27 +1324,27 @@ class MaxToKeepTest(test.TestCase): self.assertEqual(2, len(gfile.Glob(s1))) else: self.assertEqual(4, len(gfile.Glob(s1 + "*"))) - self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1))) + self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s2))) else: self.assertEqual(4, len(gfile.Glob(s2 + "*"))) - self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2))) + self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2))) s3 = save.save(sess, os.path.join(save_dir, "s3")) self.assertEqual([s2, s3], save.last_checkpoints) self.assertEqual(0, len(gfile.Glob(s1 + "*"))) - self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1))) + self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s2))) else: self.assertEqual(4, len(gfile.Glob(s2 + "*"))) - self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2))) + self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s3))) else: self.assertEqual(4, len(gfile.Glob(s3 + "*"))) - self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3))) + self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3))) def testNoMaxToKeep(self): save_dir = self._get_test_dir("no_max_to_keep") @@ -1383,7 +1384,7 @@ class MaxToKeepTest(test.TestCase): s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False) self.assertTrue(saver_module.checkpoint_exists(s1)) - self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1))) + self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1))) class KeepCheckpointEveryNHoursTest(test.TestCase): @@ -1393,7 +1394,7 @@ class KeepCheckpointEveryNHoursTest(test.TestCase): gfile.MakeDirs(test_dir) return test_dir - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes @test.mock.patch.object(saver_module, "time") def testNonSharded(self, mock_time): save_dir = self._get_test_dir("keep_checkpoint_every_n_hours") @@ -1513,7 +1514,7 @@ class SaveRestoreWithVariableNameMap(test.TestCase): self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNonReshapeResourceVariable(self): self._testNonReshape(resource_variable_ops.ResourceVariable) @@ -2337,6 +2338,46 @@ class MetaGraphTest(test.TestCase): 10, size=[1, 10]) }) + def testImportIntoNamescopeWithoutVariables(self): + # Save a simple graph that contains no variables into a checkpoint. + test_dir = self._get_test_dir("no_vars_graph") + filename = os.path.join(test_dir, "ckpt") + graph_1 = ops_lib.Graph() + with session.Session(graph=graph_1) as sess: + constant_op.constant([1, 2, 3], name="x") + constant_op.constant([1, 2, 3], name="y") + saver = saver_module.Saver(allow_empty=True) + saver.save(sess, filename) + + # Create a fresh graph. + graph_2 = ops_lib.Graph() + with session.Session(graph=graph_2) as sess: + # Restore the above checkpoint under scope "subgraph_1". + new_saver_1 = saver_module.import_meta_graph( + filename + ".meta", graph=graph_2, import_scope="subgraph_1") + # There are no variables to restore, so import_meta_graph should not + # return a Saver. + self.assertIsNone(new_saver_1) + + # Create a variable in graph_2 under scope "my_scope". + variables.Variable(array_ops.zeros([10]), name="my_scope/my_var") + sess.run(variables.global_variables_initializer()) + # Restore the checkpoint into a different scope "subgraph_2". + new_saver_2 = saver_module.import_meta_graph( + filename + ".meta", graph=graph_2, import_scope="subgraph_2") + # Because the variable does not live in scope "subgraph_2", + # import_meta_graph should not attempt to restore the variable. So, + # import_meta_graph still won't return a Saver instance. + self.assertIsNone(new_saver_2) + + # However, if we restore the checkpoint under scope "my_scope", + # import_meta_graph will detect the variable and return a Saver for + # restoring it. This should happen even when the variable does not + # originate from graph_1. + new_saver_3 = saver_module.import_meta_graph( + filename + ".meta", graph=graph_2, import_scope="my_scope") + self.assertIsInstance(new_saver_3, saver_module.Saver) + def testImportIntoImplicitNamescope(self): # Test that we can import a meta graph into an implicit namescope. test_dir = self._get_test_dir("import_into_namescope") @@ -2579,6 +2620,20 @@ class SaverUtilsTest(test.TestCase): self.assertEqual(2, len(mtimes)) self.assertTrue(mtimes[1] >= mtimes[0]) + def testRemoveCheckpoint(self): + for sharded in (False, True): + for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): + with self.test_session(graph=ops_lib.Graph()) as sess: + unused_v = variables.Variable(1.0, name="v") + variables.global_variables_initializer().run() + saver = saver_module.Saver(sharded=sharded, write_version=version) + + path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) + ckpt_prefix = saver.save(sess, path) + self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix)) + saver_module.remove_checkpoint(ckpt_prefix, version) + self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix)) + class ScopedGraphTest(test.TestCase): @@ -2883,7 +2938,7 @@ class ScopedGraphTest(test.TestCase): self.assertEqual(2.0, var_dict2["variable2:0"].eval()) -class _OwnsAVariableSimple(checkpointable.CheckpointableBase): +class _OwnsAVariableSimple(checkpointable_base.CheckpointableBase): """A Checkpointable object which can be saved using a tf.train.Saver.""" def __init__(self): @@ -2891,7 +2946,7 @@ class _OwnsAVariableSimple(checkpointable.CheckpointableBase): name="non_dep_variable", initializer=6., use_resource=True) def _gather_saveables_for_checkpoint(self): - return {checkpointable.VARIABLE_VALUE_KEY: self.non_dep_variable} + return {checkpointable_base.VARIABLE_VALUE_KEY: self.non_dep_variable} # The Saver sorts by name before parsing, so we need a name property. @property @@ -2916,7 +2971,7 @@ class _MirroringSaveable( self._mirrored_variable.assign(tensor)) -class _OwnsMirroredVariables(checkpointable.CheckpointableBase): +class _OwnsMirroredVariables(checkpointable_base.CheckpointableBase): """A Checkpointable object which returns a more complex SaveableObject.""" def __init__(self): @@ -2931,7 +2986,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase): primary_variable=self.non_dep_variable, mirrored_variable=self.mirrored, name=name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + return {checkpointable_base.VARIABLE_VALUE_KEY: _saveable_factory} # The Saver sorts by name before parsing, so we need a name property. @property @@ -2939,7 +2994,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase): return self.non_dep_variable.name -class NonLayerCheckpointable(checkpointable.Checkpointable): +class NonLayerCheckpointable(checkpointable_tracking.Checkpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() @@ -2965,7 +3020,7 @@ class MyModel(training.Model): class CheckpointableCompatibilityTests(test.TestCase): # TODO(allenl): Track down python3 reference cycles in these tests. - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNotSaveableButIsCheckpointable(self): v = _OwnsAVariableSimple() saver = saver_module.Saver(var_list=[v]) @@ -2978,7 +3033,7 @@ class CheckpointableCompatibilityTests(test.TestCase): saver.restore(sess, save_path) self.assertEqual(42., self.evaluate(v.non_dep_variable)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMoreComplexSaveableReturned(self): v = _OwnsMirroredVariables() saver = saver_module.Saver(var_list=[v]) @@ -3082,17 +3137,33 @@ class CheckpointableCompatibilityTests(test.TestCase): errors.NotFoundError, "Key b not found in checkpoint"): b_saver.restore(sess=sess, save_path=save_path) - def testCheckpointNotFoundErrorRaised(self): - # Restore does some tricky exception handling to figure out if it should - # load an object-based checkpoint. Tests that the exception handling isn't - # too broad. - a = resource_variable_ops.ResourceVariable(1., name="a") - saver = saver_module.Saver([a]) - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.NotFoundError, - "Failed to find any matching files for path_which_does_not_exist"): - saver.restore(sess=sess, save_path="path_which_does_not_exist") + with self.assertRaises(errors.NotFoundError) as cs: + b_saver.restore(sess=sess, save_path=save_path) + + # Make sure we don't have a confusing "During handling of the above + # exception" block in Python 3. + self.assertNotIn("NewCheckpointReader", cs.exception.message) + + def testGraphChangedForRestoreErrorRaised(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + with ops_lib.Graph().as_default() as g: + a = variables.Variable(1., name="a") + a_saver = saver_module.Saver([a]) + + with self.test_session(graph=g) as sess: + sess.run(a.initializer) + save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) + + with ops_lib.Graph().as_default() as g: + a = variables.Variable([1.], name="a") + a_saver = saver_module.Saver([a]) + with self.test_session(graph=g) as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "a mismatch between the current graph and the graph"): + a_saver.restore(sess=sess, save_path=save_path) def testLoadFromObjectBasedGraph(self): checkpoint_directory = self.get_temp_dir() diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py index 3cb3877cc2fb846477640058b5bcdfea97ee6828..974f75777f43ab4ef3be2edea564d1ad902e4fd5 100644 --- a/tensorflow/python/training/session_manager.py +++ b/tensorflow/python/training/session_manager.py @@ -95,7 +95,8 @@ class SessionManager(object): ready_op=None, ready_for_local_init_op=None, graph=None, - recovery_wait_secs=30): + recovery_wait_secs=30, + local_init_run_options=None): """Creates a SessionManager. The `local_init_op` is an `Operation` that is run always after a new session @@ -127,6 +128,8 @@ class SessionManager(object): to run local_init_op. graph: The `Graph` that the model will use. recovery_wait_secs: Seconds between checks for the model to be ready. + local_init_run_options: RunOptions to be passed to session.run when + executing the local_init_op. Raises: ValueError: If ready_for_local_init_op is not None but local_init_op is @@ -141,6 +144,7 @@ class SessionManager(object): self._graph = graph self._recovery_wait_secs = recovery_wait_secs self._target = None + self._local_init_run_options = local_init_run_options if ready_for_local_init_op is not None and local_init_op is None: raise ValueError("If you pass a ready_for_local_init_op " "you must also pass a local_init_op " @@ -485,7 +489,7 @@ class SessionManager(object): is_ready_for_local_init, msg = self._model_ready_for_local_init(sess) if is_ready_for_local_init: logging.info("Running local_init_op.") - sess.run(self._local_init_op) + sess.run(self._local_init_op, options=self._local_init_run_options) logging.info("Done running local_init_op.") return True, None else: diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index 7389e344c7d8eef8e26c4d24c0985ff66276deea..372ea415df0ee299ebb51b2369c1027eb2db4865 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -225,7 +225,8 @@ class Supervisor(object): checkpoint_basename="model.ckpt", session_manager=None, summary_writer=USE_DEFAULT, - init_fn=None): + init_fn=None, + local_init_run_options=None): """Create a `Supervisor`. Args: @@ -294,6 +295,8 @@ class Supervisor(object): init_fn: Optional callable used to initialize the model. Called after the optional `init_op` is called. The callable must accept one argument, the session being initialized. + local_init_run_options: RunOptions to be passed as the SessionManager + local_init_run_options parameter. Returns: A `Supervisor`. @@ -327,6 +330,7 @@ class Supervisor(object): self._recovery_wait_secs = recovery_wait_secs self._stop_grace_secs = stop_grace_secs self._init_fn = init_fn + self._local_init_run_options = local_init_run_options # Set all attributes related to checkpointing and writing events to None. # Afterwards, set them appropriately for chief supervisors, as these are @@ -362,7 +366,8 @@ class Supervisor(object): ready_op=self._ready_op, ready_for_local_init_op=self._ready_for_local_init_op, graph=self._graph, - recovery_wait_secs=self._recovery_wait_secs) + recovery_wait_secs=self._recovery_wait_secs, + local_init_run_options=self._local_init_run_options) else: self._session_manager = session_manager diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index d05e1d2c830b2aa7008c9cba9f28eb6230d8bc82..0877b2a8a2fc7d59c4075c7d37c52ab691ec0361 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -119,18 +119,18 @@ def create_global_step(graph=None): graph = graph or ops.get_default_graph() if get_global_step(graph) is not None: raise ValueError('"global_step" already exists.') + if context.executing_eagerly(): + with ops.device('cpu:0'): + return variable_scope.get_variable( + ops.GraphKeys.GLOBAL_STEP, + shape=[], + dtype=dtypes.int64, + initializer=init_ops.zeros_initializer(), + trainable=False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES, + ops.GraphKeys.GLOBAL_STEP]) # Create in proper graph and base name_scope. with graph.as_default() as g, g.name_scope(None): - if context.executing_eagerly(): - with ops.device('cpu:0'): - return variable_scope.get_variable( - ops.GraphKeys.GLOBAL_STEP, - shape=[], - dtype=dtypes.int64, - initializer=init_ops.zeros_initializer(), - trainable=False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES, - ops.GraphKeys.GLOBAL_STEP]) return variable_scope.get_variable( ops.GraphKeys.GLOBAL_STEP, shape=[], diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py index 4d4fb394c1272d2bf510bb594d70b9aa2edb3df2..ec740abdd15ae2904f79246429deaa5fc831dad5 100644 --- a/tensorflow/python/training/warm_starting_util.py +++ b/tensorflow/python/training/warm_starting_util.py @@ -33,7 +33,7 @@ from tensorflow.python.training import saver from tensorflow.python.util.tf_export import tf_export -@tf_export("train.VocabInfo", "estimator.VocabInfo") +@tf_export("train.VocabInfo", allow_multiple_exports=True) class VocabInfo( collections.namedtuple("VocabInfo", [ "new_vocab", @@ -237,6 +237,62 @@ def _warm_start_var_with_vocab(var, # pylint: enable=protected-access +def _get_grouped_variables(vars_to_warm_start): + """Collects and groups (possibly partitioned) variables into a dictionary. + + The variables can be provided explicitly through vars_to_warm_start, or they + are retrieved from collections (see below). + + Args: + vars_to_warm_start: One of the following: + + - A regular expression (string) that captures which variables to + warm-start (see tf.get_collection). This expression will only consider + variables in the TRAINABLE_VARIABLES collection. + - A list of Variables to warm-start. + - A list of strings, each representing a full variable name to warm-start. + - `None`, in which case only variables specified in + `var_name_to_vocab_info` will be warm-started. + Returns: + A dictionary mapping variable names (strings) to lists of Variables. + Raises: + ValueError: If vars_to_warm_start is not a string, `None`, a list of + `Variables`, or a list of strings. + """ + if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None: + # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match + # everything (in TRAINABLE_VARIABLES) here. + list_of_vars = ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, + scope=vars_to_warm_start) + elif isinstance(vars_to_warm_start, list): + if all([isinstance(v, str) for v in vars_to_warm_start]): + list_of_vars = [] + for v in vars_to_warm_start: + list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, + scope=v) + elif all([_is_variable(v) for v in vars_to_warm_start]): + list_of_vars = vars_to_warm_start + else: + raise ValueError("If `vars_to_warm_start` is a list, it must be all " + "`Variable` or all `str`. Given types are {}".format( + [type(v) for v in vars_to_warm_start])) + else: + raise ValueError("`vars_to_warm_start must be a `list` or `str`. Given " + "type is {}".format(type(vars_to_warm_start))) + # We have to deal with partitioned variables, since get_collection flattens + # out the list. + grouped_variables = {} + for v in list_of_vars: + if not isinstance(v, list): + var_name = _infer_var_name([v]) + else: + var_name = _infer_var_name(v) + grouped_variables.setdefault(var_name, []).append(v) + + return grouped_variables + + @tf_export("train.warm_start") def warm_start(ckpt_to_initialize_from, vars_to_warm_start=".*", @@ -251,10 +307,19 @@ def warm_start(ckpt_to_initialize_from, ckpt_to_initialize_from: [Required] A string specifying the directory with checkpoint file(s) or path to checkpoint from which to warm-start the model parameters. - vars_to_warm_start: [Optional] A regular expression that captures which - variables to warm-start (see tf.get_collection). Defaults to `'.*'`, - which warm-starts all variables. If `None` is explicitly given, only - variables specified in `var_name_to_vocab_info` will be warm-started. + vars_to_warm_start: [Optional] One of the following: + + - A regular expression (string) that captures which variables to + warm-start (see tf.get_collection). This expression will only consider + variables in the TRAINABLE_VARIABLES collection. + - A list of Variables to warm-start. + - A list of strings, each representing a full variable name to warm-start. + - `None`, in which case only variables specified in + `var_name_to_vocab_info` will be warm-started. + + Defaults to `'.*'`, which warm-starts all variables in the + TRAINABLE_VARIABLES collection. Note that this excludes variables such as + accumulators and moving statistics from batch norm. var_name_to_vocab_info: [Optional] Dict of variable names (strings) to VocabInfo. The variable names should be "full" variables, not the names of the partitions. If not explicitly provided, the variable is assumed to @@ -274,21 +339,7 @@ def warm_start(ckpt_to_initialize_from, if var_name_to_prev_var_name is None: var_name_to_prev_var_name = {} logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,)) - # We have to deal with partitioned variables, since get_collection flattens - # out the list. - grouped_variables = {} - # Both vars_to_warm_start = '.*' and - # vars_to_warm_start = None will match everything here. - for v in ops.get_collection( - # TODO(eddz): Allow for different collections here (to support - # warm-starting accumulators). - ops.GraphKeys.TRAINABLE_VARIABLES, - scope=vars_to_warm_start): - if not isinstance(v, list): - var_name = _infer_var_name([v]) - else: - var_name = _infer_var_name(v) - grouped_variables.setdefault(var_name, []).append(v) + grouped_variables = _get_grouped_variables(vars_to_warm_start) # Keep track of which var_names in var_name_to_prev_var_name and # var_name_to_vocab_info have been used. Err on the safer side by throwing an diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py index 7e8cbd6baeea160075b61d1191c8f1da5fe2163c..6a4c207d79edf22d635c38fe98589396e781e84e 100644 --- a/tensorflow/python/training/warm_starting_util_test.py +++ b/tensorflow/python/training/warm_starting_util_test.py @@ -36,6 +36,7 @@ from tensorflow.python.training import warm_starting_util as ws_util ones = init_ops.ones_initializer norms = init_ops.truncated_normal_initializer rand = init_ops.random_uniform_initializer +zeros = init_ops.zeros_initializer class WarmStartingUtilTest(test.TestCase): @@ -305,6 +306,46 @@ class WarmStartingUtilTest(test.TestCase): self.assertAllEqual([[0.5], [0.], [0.]], fruit_weights_vars[1].eval(sess)) + def testWarmStart_ListOfVariables(self): + # Save checkpoint from which to warm-start. + _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1], + initializer=ones()) + # Verify we initialized the values correctly. + self.assertAllEqual(np.ones([10, 1]), prev_int_val) + + # New graph, new session with warm-starting. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Initialize with zeros. + var = variable_scope.get_variable( + "v1", + shape=[10, 1], + initializer=zeros()) + ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=[var]) + sess.run(variables.global_variables_initializer()) + # Verify weights were correctly warm-started (init overridden to ones). + self.assertAllEqual(var.eval(), prev_int_val) + + def testWarmStart_ListOfStrings(self): + # Save checkpoint from which to warm-start. + _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1], + initializer=ones()) + # Verify we initialized the values correctly. + self.assertAllEqual(np.ones([10, 1]), prev_int_val) + + # New graph, new session with warm-starting. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Initialize with zeros. + var = variable_scope.get_variable( + "v1", + shape=[10, 1], + initializer=zeros()) + ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=["v1"]) + sess.run(variables.global_variables_initializer()) + # Verify weights were correctly warm-started (init overridden to ones). + self.assertAllEqual(var.eval(), prev_int_val) + def testWarmStart_SparseColumnIntegerized(self): # Create feature column. sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10) diff --git a/tensorflow/python/util/lock_util.py b/tensorflow/python/util/lock_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0424960666323870fb1db83804857dd838cfe9ae --- /dev/null +++ b/tensorflow/python/util/lock_util.py @@ -0,0 +1,128 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Locking related utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + + +class GroupLock(object): + """A lock to allow many members of a group to access a resource exclusively. + + This lock provides a way to allow access to a resource by multiple threads + belonging to a logical group at the same time, while restricting access to + threads from all other groups. You can think of this as an extension of a + reader-writer lock, where you allow multiple writers at the same time. We + made it generic to support multiple groups instead of just two - readers and + writers. + + Simple usage example with two groups accessing the same resource: + + ```python + lock = GroupLock(num_groups=2) + + # In a member of group 0: + with lock.group(0): + # do stuff, access the resource + # ... + + # In a member of group 1: + with lock.group(1): + # do stuff, access the resource + # ... + ``` + + Using as a context manager with `.group(group_id)` is the easiest way. You + can also use the `acquire` and `release` method directly. + """ + + def __init__(self, num_groups=2): + """Initialize a group lock. + + Args: + num_groups: The number of groups that will be accessing the resource under + consideration. Should be a positive number. + + Returns: + A group lock that can then be used to synchronize code. + + Raises: + ValueError: If num_groups is less than 1. + """ + if num_groups < 1: + raise ValueError("num_groups must be a positive integer, got {}".format( + num_groups)) + self._ready = threading.Condition(threading.Lock()) + self._num_groups = num_groups + self._group_member_counts = [0] * self._num_groups + + def group(self, group_id): + """Enter a context where the lock is with group `group_id`. + + Args: + group_id: The group for which to acquire and release the lock. + + Returns: + A context manager which will acquire the lock for `group_id`. + """ + self._validate_group_id(group_id) + return self._Context(self, group_id) + + def acquire(self, group_id): + """Acquire the group lock for a specific group `group_id`.""" + self._validate_group_id(group_id) + + self._ready.acquire() + while self._another_group_active(group_id): + self._ready.wait() + self._group_member_counts[group_id] += 1 + self._ready.release() + + def release(self, group_id): + """Release the group lock for a specific group `group_id`.""" + self._validate_group_id(group_id) + + self._ready.acquire() + self._group_member_counts[group_id] -= 1 + if self._group_member_counts[group_id] == 0: + self._ready.notifyAll() + self._ready.release() + + def _another_group_active(self, group_id): + return any( + c > 0 for g, c in enumerate(self._group_member_counts) if g != group_id) + + def _validate_group_id(self, group_id): + if group_id < 0 or group_id >= self._num_groups: + raise ValueError( + "group_id={} should be between 0 and num_groups={}".format( + group_id, self._num_groups)) + + class _Context(object): + """Context manager helper for `GroupLock`.""" + + def __init__(self, lock, group_id): + self._lock = lock + self._group_id = group_id + + def __enter__(self): + self._lock.acquire(self._group_id) + + def __exit__(self, type_arg, value_arg, traceback_arg): + del type_arg, value_arg, traceback_arg + self._lock.release(self._group_id) diff --git a/tensorflow/python/util/lock_util_test.py b/tensorflow/python/util/lock_util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cda8f952259c9e117e0bd7ff3cac35e764856f43 --- /dev/null +++ b/tensorflow/python/util/lock_util_test.py @@ -0,0 +1,63 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lock_util.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import time + +from absl.testing import parameterized + +from tensorflow.python.platform import test +from tensorflow.python.util import lock_util + + +class GroupLockTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters(1, 2, 3, 5, 10) + def testGroups(self, num_groups): + lock = lock_util.GroupLock(num_groups) + num_threads = 10 + finished = set() + + def thread_fn(thread_id): + time.sleep(random.random() * 0.1) + group_id = thread_id % num_groups + with lock.group(group_id): + time.sleep(random.random() * 0.1) + self.assertGreater(lock._group_member_counts[group_id], 0) + for g, c in enumerate(lock._group_member_counts): + if g != group_id: + self.assertEqual(0, c) + finished.add(thread_id) + + threads = [ + self.checkedThread(target=thread_fn, args=(i,)) + for i in range(num_threads) + ] + + for i in range(num_threads): + threads[i].start() + for i in range(num_threads): + threads[i].join() + + self.assertEqual(set(range(num_threads)), finished) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 1104768ae8f69598f686eb2ffee8b69e43051011..d63f59a8c8e836d3f8ad3686da0b0b3f010a9225 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -167,11 +167,14 @@ def assert_same_structure(nest1, nest2, check_types=True): Args: nest1: an arbitrarily nested structure. nest2: an arbitrarily nested structure. - check_types: if `True` (default) types of sequences are checked as - well, including the keys of dictionaries. If set to `False`, for example - a list and a tuple of objects will look the same if they have the same + check_types: if `True` (default) types of sequences are checked as well, + including the keys of dictionaries. If set to `False`, for example a + list and a tuple of objects will look the same if they have the same size. Note that namedtuples with identical name and fields are always - considered to have the same shallow structure. + considered to have the same shallow structure. Two types will also be + considered the same if they are both list subtypes (which allows "list" + and "_ListWrapper" from checkpointable dependency tracking to compare + equal). Raises: ValueError: If the two structures do not have the same number of elements or diff --git a/tensorflow/python/util/serialization_test.py b/tensorflow/python/util/serialization_test.py index f16fa5377b5dcdcf2a1aff5994fb6005a22e2e70..9d9cac272592f6b73b4c78f38310d7b89a89e05d 100644 --- a/tensorflow/python/util/serialization_test.py +++ b/tensorflow/python/util/serialization_test.py @@ -23,10 +23,10 @@ import json from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util -from tensorflow.python.keras._impl.keras.engine import input_layer -from tensorflow.python.keras._impl.keras.engine import sequential -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import input_layer +from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.platform import test from tensorflow.python.util import serialization @@ -47,7 +47,7 @@ class SerializationTests(test.TestCase): self.assertIs(round_trip[0], None) self.assertEqual(round_trip[1], 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_serialize_sequential(self): model = sequential.Sequential() model.add(core.Dense(4)) @@ -61,7 +61,7 @@ class SerializationTests(test.TestCase): self.assertAllEqual([1, 1], input_round_trip[0]["config"]["batch_input_shape"]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_serialize_model(self): x = input_layer.Input(shape=[3]) y = core.Dense(10)(x) diff --git a/tensorflow/python/util/stat_summarizer.i b/tensorflow/python/util/stat_summarizer.i index 6aeaa0e31b9b48f7e6705ab7146828cc0e0e5e08..73fa85494b72d920d00577c826b76c3381d963a4 100644 --- a/tensorflow/python/util/stat_summarizer.i +++ b/tensorflow/python/util/stat_summarizer.i @@ -73,7 +73,7 @@ void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss); return ss; } } - +%include "tensorflow/core/util/stat_summarizer_options.h" %include "tensorflow/core/util/stat_summarizer.h" %unignoreall @@ -88,9 +88,4 @@ def NewStatSummarizer(unused): def DeleteStatSummarizer(stat_summarizer): _DeleteStatSummarizer(stat_summarizer) - -NewStatSummarizer._tf_api_names = ["contrib.stat_summarizer.NewStatSummarizer"] -DeleteStatSummarizer._tf_api_names = [ - "contrib.stat_summarizer.DeleteStatSummarizer"] -StatSummarizer._tf_api_names = ["contrib.stat_summarizer.StatSummarizer"] %} diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index a30b8b1336358372bed3b6fa0b853e79c1535036..e154ffb68a4f0ccdebf5320cad7d3da056117197 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -41,17 +41,35 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import functools import sys from tensorflow.python.util import tf_decorator +ESTIMATOR_API_NAME = 'estimator' +TENSORFLOW_API_NAME = 'tensorflow' + +_Attributes = collections.namedtuple( + 'ExportedApiAttributes', ['names', 'constants']) + +# Attribute values must be unique to each API. +API_ATTRS = { + TENSORFLOW_API_NAME: _Attributes( + '_tf_api_names', + '_tf_api_constants'), + ESTIMATOR_API_NAME: _Attributes( + '_estimator_api_names', + '_estimator_api_constants') +} + class SymbolAlreadyExposedError(Exception): """Raised when adding API names to symbol that already has API names.""" pass -class tf_export(object): # pylint: disable=invalid-name +class api_export(object): # pylint: disable=invalid-name """Provides ways to export symbols to the TensorFlow API.""" def __init__(self, *args, **kwargs): @@ -59,12 +77,15 @@ class tf_export(object): # pylint: disable=invalid-name Args: *args: API names in dot delimited format. - **kwargs: Optional keyed arguments. Currently only supports 'overrides' - argument. overrides: List of symbols that this is overriding - (those overrided api exports will be removed). Note: passing overrides - has no effect on exporting a constant. + **kwargs: Optional keyed arguments. + overrides: List of symbols that this is overriding + (those overrided api exports will be removed). Note: passing overrides + has no effect on exporting a constant. + api_name: Name of the API you want to generate (e.g. `tensorflow` or + `estimator`). Default is `tensorflow`. """ self._names = args + self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME) self._overrides = kwargs.get('overrides', []) def __call__(self, func): @@ -77,29 +98,27 @@ class tf_export(object): # pylint: disable=invalid-name The input function with _tf_api_names attribute set. Raises: - SymbolAlreadyExposedError: Raised when a symbol already has API names. + SymbolAlreadyExposedError: Raised when a symbol already has API names + and kwarg `allow_multiple_exports` not set. """ + api_names_attr = API_ATTRS[self._api_name].names + # Undecorate overridden names for f in self._overrides: _, undecorated_f = tf_decorator.unwrap(f) - del undecorated_f._tf_api_names # pylint: disable=protected-access + delattr(undecorated_f, api_names_attr) _, undecorated_func = tf_decorator.unwrap(func) # Check for an existing api. We check if attribute name is in # __dict__ instead of using hasattr to verify that subclasses have # their own _tf_api_names as opposed to just inheriting it. - if '_tf_api_names' in undecorated_func.__dict__: - # pylint: disable=protected-access + if api_names_attr in undecorated_func.__dict__: raise SymbolAlreadyExposedError( 'Symbol %s is already exposed as %s.' % - (undecorated_func.__name__, undecorated_func._tf_api_names)) - # pylint: enable=protected-access - - # Complete the export by creating/overriding attribute - # pylint: disable=protected-access - undecorated_func._tf_api_names = self._names - # pylint: enable=protected-access + (undecorated_func.__name__, getattr( + undecorated_func, api_names_attr))) # pylint: disable=protected-access + setattr(undecorated_func, api_names_attr, self._names) return func def export_constant(self, module_name, name): @@ -121,8 +140,12 @@ class tf_export(object): # pylint: disable=invalid-name name: (string) Current constant name. """ module = sys.modules[module_name] - if not hasattr(module, '_tf_api_constants'): - module._tf_api_constants = [] # pylint: disable=protected-access + if not hasattr(module, API_ATTRS[self._api_name].constants): + setattr(module, API_ATTRS[self._api_name].constants, []) # pylint: disable=protected-access - module._tf_api_constants.append((self._names, name)) + getattr(module, API_ATTRS[self._api_name].constants).append( + (self._names, name)) + +tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME) +estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME) diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py index 9bad4a24814e9d61e17fbb98b5ee861f48558c66..fbd65617670b15bfc69506bab1e83369081502af 100644 --- a/tensorflow/python/util/tf_inspect.py +++ b/tensorflow/python/util/tf_inspect.py @@ -18,8 +18,11 @@ from __future__ import division from __future__ import print_function from collections import namedtuple +import functools import inspect as _inspect +import six + from tensorflow.python.util import tf_decorator ArgSpec = _inspect.ArgSpec @@ -39,27 +42,164 @@ def currentframe(): return _inspect.stack()[1][0] -def getargspec(object): # pylint: disable=redefined-builtin +def getargspec(obj): """TFDecorator-aware replacement for inspect.getargspec. Args: - object: A callable, possibly decorated. + obj: A function, partial function, or callable object, possibly + decorated. Returns: The `ArgSpec` that describes the signature of the outermost decorator that changes the callable's signature. If the callable is not decorated, - `inspect.getargspec()` will be called directly on the callable. + `inspect.getargspec()` will be called directly on the object. + + Raises: + ValueError: When callable's signature can not be expressed with + ArgSpec. + TypeError: For objects of unsupported types. """ - decorators, target = tf_decorator.unwrap(object) - return next((d.decorator_argspec for d in decorators - if d.decorator_argspec is not None), _inspect.getargspec(target)) + if isinstance(obj, functools.partial): + return _get_argspec_for_partial(obj) + + decorators, target = tf_decorator.unwrap(obj) + + spec = next((d.decorator_argspec + for d in decorators + if d.decorator_argspec is not None), None) + if spec: + return spec + try: + # Python3 will handle most callables here (not partial). + return _inspect.getargspec(target) + except TypeError: + pass -def getfullargspec(obj): # pylint: disable=redefined-builtin - """TFDecorator-aware replacement for `inspect.getfullargspec`/`getargspec`. + if isinstance(target, type): + try: + return _inspect.getargspec(target.__init__) + except TypeError: + pass - This wrapper uses `inspect.getfullargspec` if available and falls back to - `inspect.getargspec` in Python 2. + try: + return _inspect.getargspec(target.__new__) + except TypeError: + pass + + # The `type(target)` ensures that if a class is received we don't return + # the signature of it's __call__ method. + return _inspect.getargspec(type(target).__call__) + + +def _get_argspec_for_partial(obj): + """Implements `getargspec` for `functools.partial` objects. + + Args: + obj: The `functools.partial` obeject + Returns: + An `inspect.ArgSpec` + Raises: + ValueError: When callable's signature can not be expressed with + ArgSpec. + """ + # When callable is a functools.partial object, we construct its ArgSpec with + # following strategy: + # - If callable partial contains default value for positional arguments (ie. + # object.args), then final ArgSpec doesn't contain those positional arguments. + # - If callable partial contains default value for keyword arguments (ie. + # object.keywords), then we merge them with wrapped target. Default values + # from callable partial takes precedence over those from wrapped target. + # + # However, there is a case where it is impossible to construct a valid + # ArgSpec. Python requires arguments that have no default values must be + # defined before those with default values. ArgSpec structure is only valid + # when this presumption holds true because default values are expressed as a + # tuple of values without keywords and they are always assumed to belong to + # last K arguments where K is number of default values present. + # + # Since functools.partial can give default value to any argument, this + # presumption may no longer hold in some cases. For example: + # + # def func(m, n): + # return 2 * m + n + # partialed = functools.partial(func, m=1) + # + # This example will result in m having a default value but n doesn't. This is + # usually not allowed in Python and can not be expressed in ArgSpec correctly. + # + # Thus, we must detect cases like this by finding first argument with default + # value and ensures all following arguments also have default values. When + # this is not true, a ValueError is raised. + + n_prune_args = len(obj.args) + partial_keywords = obj.keywords or {} + + args, varargs, keywords, defaults = getargspec(obj.func) + + # Pruning first n_prune_args arguments. + args = args[n_prune_args:] + + # Partial function may give default value to any argument, therefore length + # of default value list must be len(args) to allow each argument to + # potentially be given a default value. + all_defaults = [None] * len(args) + if defaults: + all_defaults[-len(defaults):] = defaults + + # Fill in default values provided by partial function in all_defaults. + for kw, default in six.iteritems(partial_keywords): + idx = args.index(kw) + all_defaults[idx] = default + + # Find first argument with default value set. + first_default = next((idx for idx, x in enumerate(all_defaults) if x), None) + + # If no default values are found, return ArgSpec with defaults=None. + if first_default is None: + return ArgSpec(args, varargs, keywords, None) + + # Checks if all arguments have default value set after first one. + invalid_default_values = [ + args[i] for i, j in enumerate(all_defaults) if not j and i > first_default + ] + + if invalid_default_values: + raise ValueError('Some arguments %s do not have default value, but they ' + 'are positioned after those with default values. This can ' + 'not be expressed with ArgSpec.' % invalid_default_values) + + return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:])) + + +if hasattr(_inspect, 'getfullargspec'): + _getfullargspec = _inspect.getfullargspec +else: + + def _getfullargspec(target): + """A python2 version of getfullargspec. + + Args: + target: the target object to inspect. + Returns: + A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations. + """ + argspecs = _inspect.getargspec(target) + fullargspecs = FullArgSpec( + args=argspecs.args, + varargs=argspecs.varargs, + varkw=argspecs.keywords, + defaults=argspecs.defaults, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + return fullargspecs + + +def getfullargspec(obj): + """TFDecorator-aware replacement for `inspect.getfullargspec`. + + This wrapper emulates `inspect.getfullargspec` in[^)]* Python2. Args: obj: A callable, possibly decorated. @@ -70,34 +210,10 @@ def getfullargspec(obj): # pylint: disable=redefined-builtin callable is not decorated, `inspect.getfullargspec()` will be called directly on the callable. """ - if hasattr(_inspect, 'getfullargspec'): - spec_fn = _inspect.getfullargspec - else: - def spec_fn(target): - """Spec function that adding default value from FullArgSpec. - - It is used when getfullargspec is not available (eg in PY2). - - Args: - target: the target object to inspect. - Returns: - The full argument specs with empty kwonlyargs, kwonlydefaults and - annotations. - """ - argspecs = _inspect.getargspec(target) - fullargspecs = FullArgSpec( - args=argspecs.args, - varargs=argspecs.varargs, - varkw=argspecs.keywords, - defaults=argspecs.defaults, - kwonlyargs=[], - kwonlydefaults=None, - annotations={}) - return fullargspecs - decorators, target = tf_decorator.unwrap(obj) - return next((d.decorator_argspec for d in decorators - if d.decorator_argspec is not None), spec_fn(target)) + return next((d.decorator_argspec + for d in decorators + if d.decorator_argspec is not None), _getfullargspec(target)) def getcallargs(func, *positional, **named): diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py index 129408449ebb45ac3a322f163a13b705cbb31f0c..beaf350de1e469a7675a4b55ff341419262b79b2 100644 --- a/tensorflow/python/util/tf_inspect_test.py +++ b/tensorflow/python/util/tf_inspect_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import inspect from tensorflow.python.platform import test @@ -109,6 +110,187 @@ class TfInspectTest(test.TestCase): outer_argspec) self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator)) + def testGetArgSpecOnPartialPositionalArgumentOnly(self): + """Tests getargspec on partial function with only positional arguments.""" + + def func(m, n): + return 2 * m + n + + partial_func = functools.partial(func, 7) + argspec = tf_inspect.ArgSpec( + args=['n'], varargs=None, keywords=None, defaults=None) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialInvalidArgspec(self): + """Tests getargspec on partial function that doesn't have valid argspec.""" + + def func(m, n, l, k=4): + return 2 * m + l + n * k + + partial_func = functools.partial(func, n=7) + + exception_message = (r"Some arguments \['l'\] do not have default value, " + "but they are positioned after those with default " + "values. This can not be expressed with ArgSpec.") + with self.assertRaisesRegexp(ValueError, exception_message): + tf_inspect.getargspec(partial_func) + + def testGetArgSpecOnPartialValidArgspec(self): + """Tests getargspec on partial function with valid argspec.""" + + def func(m, n, l, k=4): + return 2 * m + l + n * k + + partial_func = functools.partial(func, n=7, l=2) + argspec = tf_inspect.ArgSpec( + args=['m', 'n', 'l', 'k'], + varargs=None, + keywords=None, + defaults=(7, 2, 4)) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialNoArgumentsLeft(self): + """Tests getargspec on partial function that prunes all arguments.""" + + def func(m, n): + return 2 * m + n + + partial_func = functools.partial(func, 7, 10) + argspec = tf_inspect.ArgSpec( + args=[], varargs=None, keywords=None, defaults=None) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialKeywordArgument(self): + """Tests getargspec on partial function that prunes some arguments.""" + + def func(m, n): + return 2 * m + n + + partial_func = functools.partial(func, n=7) + argspec = tf_inspect.ArgSpec( + args=['m', 'n'], varargs=None, keywords=None, defaults=(7,)) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialKeywordArgumentWithDefaultValue(self): + """Tests getargspec on partial function that prunes argument by keyword.""" + + def func(m=1, n=2): + return 2 * m + n + + partial_func = functools.partial(func, n=7) + argspec = tf_inspect.ArgSpec( + args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7)) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialWithVarargs(self): + """Tests getargspec on partial function with variable arguments.""" + + def func(m, *arg): + return m + len(arg) + + partial_func = functools.partial(func, 7, 8) + argspec = tf_inspect.ArgSpec( + args=[], varargs='arg', keywords=None, defaults=None) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialWithVarkwargs(self): + """Tests getargspec on partial function with variable keyword arguments.""" + + def func(m, n, **kwarg): + return m * n + len(kwarg) + + partial_func = functools.partial(func, 7) + argspec = tf_inspect.ArgSpec( + args=['n'], varargs=None, keywords='kwarg', defaults=None) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialWithDecorator(self): + """Tests getargspec on decorated partial function.""" + + @test_decorator('decorator') + def func(m=1, n=2): + return 2 * m + n + + partial_func = functools.partial(func, n=7) + argspec = tf_inspect.ArgSpec( + args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7)) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialWithDecoratorThatChangesArgspec(self): + """Tests getargspec on partial function with decorated argspec.""" + + argspec = tf_inspect.ArgSpec( + args=['a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(1, 'hello')) + decorator = tf_decorator.TFDecorator('', test_undecorated_function, '', + argspec) + partial_argspec = tf_inspect.ArgSpec( + args=['a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(2, 1, 'hello')) + partial_with_decorator = functools.partial(decorator, a=2) + + self.assertEqual(argspec, tf_inspect.getargspec(decorator)) + self.assertEqual(partial_argspec, + tf_inspect.getargspec(partial_with_decorator)) + + def testGetArgSpecOnCallableObject(self): + + class Callable(object): + + def __call__(self, a, b=1, c='hello'): + pass + + argspec = tf_inspect.ArgSpec( + args=['self', 'a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(1, 'hello')) + + test_obj = Callable() + self.assertEqual(argspec, tf_inspect.getargspec(test_obj)) + + def testGetArgSpecOnInitClass(self): + + class InitClass(object): + + def __init__(self, a, b=1, c='hello'): + pass + + argspec = tf_inspect.ArgSpec( + args=['self', 'a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(1, 'hello')) + + self.assertEqual(argspec, tf_inspect.getargspec(InitClass)) + + def testGetArgSpecOnNewClass(self): + + class NewClass(object): + + def __new__(cls, a, b=1, c='hello'): + pass + + argspec = tf_inspect.ArgSpec( + args=['cls', 'a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(1, 'hello')) + + self.assertEqual(argspec, tf_inspect.getargspec(NewClass)) + def testGetDoc(self): self.assertEqual('Test Decorated Function With Defaults Docstring.', tf_inspect.getdoc(test_decorated_function_with_defaults)) diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 9c8d50da7351d91d435719565d82a1dd3f19c043..366f8a0deb533c3ee258ea618136d44a28160f8f 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -14,8 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/python/util/util.h" +#include +#include +#include + +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/python/lib/core/safe_ptr.h" namespace tensorflow { @@ -25,6 +31,9 @@ namespace { // Type object for collections.Sequence. This is set by RegisterSequenceClass. PyObject* CollectionsSequenceType = nullptr; +PyTypeObject* SparseTensorValueType = nullptr; + +const int kMaxItemsInCache = 1024; bool WarnedThatSetIsNotSequence = false; @@ -135,6 +144,12 @@ class ValIterator { Py_ssize_t index_; }; +mutex g_type_to_sequence_map(LINKER_INITIALIZED); +std::unordered_map* IsTypeSequenceMap() { + static auto* const m = new std::unordered_map; + return m; +} + // Returns 1 if `o` is considered a sequence for the purposes of Flatten(). // Returns 0 otherwise. // Returns -1 if an error occurred. @@ -155,64 +170,149 @@ int IsSequenceHelper(PyObject* o) { .c_str()); return -1; } + + // Try not to return to Python - see if the type has already been seen + // before. + + auto* type_to_sequence_map = IsTypeSequenceMap(); + auto* type = Py_TYPE(o); + + { + mutex_lock l(g_type_to_sequence_map); + auto it = type_to_sequence_map->find(type); + if (it != type_to_sequence_map->end()) { + return it->second; + } + } + + // NOTE: We explicitly release the g_type_to_sequence_map mutex, + // because PyObject_IsInstance() may release the GIL, allowing another thread + // concurrent entry to this function. int is_instance = PyObject_IsInstance(o, CollectionsSequenceType); + + // Don't cache a failed is_instance check. if (is_instance == -1) return -1; - return static_cast(is_instance != 0 && !IsString(o)); + + bool is_sequence = static_cast(is_instance != 0 && !IsString(o)); + + // NOTE: This is never decref'd, but we don't want the type to get deleted + // as long as it is in the map. This should not be too much of a + // leak, as there should only be a relatively small number of types in the + // map, and an even smaller number that are eligible for decref. As a + // precaution, we limit the size of the map to 1024. + { + mutex_lock l(g_type_to_sequence_map); + if (type_to_sequence_map->size() < kMaxItemsInCache) { + Py_INCREF(type); + type_to_sequence_map->insert({type, is_sequence}); + } + } + + return is_sequence; } -bool FlattenHelper(PyObject* nested, PyObject* list) { - // if nested is not a sequence, append itself and exit - int is_seq = IsSequenceHelper(nested); - if (is_seq == -1) return false; - if (!is_seq) { - return PyList_Append(list, nested) != -1; +bool IsSparseTensorValueType(PyObject* o) { + if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) { + return false; } - // if nested if dictionary, sort it by key and recurse on each value - if (PyDict_Check(nested)) { - PyObject* keys = PyDict_Keys(nested); - if (PyList_Sort(keys) == -1) return false; - Py_ssize_t size = PyList_Size(keys); - for (Py_ssize_t i = 0; i < size; ++i) { - // We know that key and val will not be deleted because nested owns - // a reference to them and callers of flatten must not modify nested - // while the method is running. - PyObject* key = PyList_GET_ITEM(keys, i); - PyObject* val = PyDict_GetItem(nested, key); - if (Py_EnterRecursiveCall(" in flatten")) { - Py_DECREF(keys); - return false; - } - const bool success = FlattenHelper(val, list); - Py_LeaveRecursiveCall(); - if (!success) { - Py_DECREF(keys); - return false; - } - } - Py_DECREF(keys); - return true; + return PyObject_TypeCheck(o, SparseTensorValueType) == 1; +} + +int IsSequenceForDataHelper(PyObject* o) { + return IsSequenceHelper(o) == 1 && !PyList_Check(o) && + !IsSparseTensorValueType(o); +} + +bool GetNextValuesForDict(PyObject* nested, + std::vector* next_values) { + std::vector result; + + PyObject* keys = PyDict_Keys(nested); + if (PyList_Sort(keys) == -1) return false; + Py_ssize_t size = PyList_Size(keys); + for (Py_ssize_t i = 0; i < size; ++i) { + // We know that key and item will not be deleted because nested owns + // a reference to them and callers of flatten must not modify nested + // while the method is running. + PyObject* key = PyList_GET_ITEM(keys, i); + PyObject* item = PyDict_GetItem(nested, key); + Py_INCREF(item); + next_values->emplace_back(item); } + Py_DECREF(keys); + return true; +} - // iterate and recurse +bool GetNextValuesForIterable(PyObject* nested, + std::vector* next_values) { PyObject* item; PyObject* iterator = PyObject_GetIter(nested); + if (iterator == nullptr || PyErr_Occurred()) { + return false; + } while ((item = PyIter_Next(iterator)) != nullptr) { + next_values->emplace_back(item); + } + Py_DECREF(iterator); + return true; +} + +// GetNextValues returns the values that the FlattenHelper function will recurse +// over next. +bool GetNextValues(PyObject* nested, + std::vector* next_values) { + if (PyDict_Check(nested)) { + // if nested is dictionary, sort it by key and recurse on each value + return GetNextValuesForDict(nested, next_values); + } + // iterate and recurse + return GetNextValuesForIterable(nested, next_values); +} + +// Similar to above, just specialized for the functions in the data pacakage. +bool GetNextValuesForData(PyObject* nested, + std::vector* next_values) { + if (PyDict_Check(nested)) { + // if nested is dictionary, sort it by key and recurse on each value + return GetNextValuesForDict(nested, next_values); + } else if (IsSparseTensorValueType(nested)) { + // if nested is a SparseTensorValue, just return itself as a single item + Py_INCREF(nested); + next_values->emplace_back(nested); + return true; + } + // iterate and recurse + return GetNextValuesForIterable(nested, next_values); +} + +bool FlattenHelper( + PyObject* nested, PyObject* list, + const std::function& is_sequence_helper, + const std::function*)>& + next_values_getter) { + // if nested is not a sequence, append itself and exit + int is_seq = is_sequence_helper(nested); + if (is_seq == -1) return false; + if (!is_seq) { + return PyList_Append(list, nested) != -1; + } + + std::vector next_values; + // Get the next values to recurse over. + if (!next_values_getter(nested, &next_values)) return false; + + for (const auto& item : next_values) { if (Py_EnterRecursiveCall(" in flatten")) { - Py_DECREF(iterator); - Py_DECREF(item); return false; } - bool success = FlattenHelper(item, list); + const bool success = + FlattenHelper(item.get(), list, is_sequence_helper, next_values_getter); Py_LeaveRecursiveCall(); if (!success) { - Py_DECREF(iterator); - Py_DECREF(item); return false; } - Py_DECREF(item); } - Py_DECREF(iterator); return true; } @@ -294,7 +394,11 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, type2->tp_name); return true; } - } else if (type1 != type2) { + } else if (type1 != type2 + /* If both sequences are list types, don't complain. This allows + one to be a list subclass (e.g. _ListWrapper used for automatic + dependency tracking.) */ + && !(PyList_Check(o1) && PyList_Check(o2))) { *is_type_error = true; *error_msg = tensorflow::strings::StrCat( "The two namedtuples don't have the same sequence type. " @@ -351,7 +455,7 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, } } -} // anonymous namespace +} // namespace void RegisterSequenceClass(PyObject* sequence_class) { if (!PyType_Check(sequence_class)) { @@ -366,11 +470,38 @@ void RegisterSequenceClass(PyObject* sequence_class) { CollectionsSequenceType = sequence_class; } +void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) { + if (!PyType_Check(sparse_tensor_value_class)) { + PyErr_SetString( + PyExc_TypeError, + tensorflow::strings::StrCat( + "Expecting a class definition for `SparseTensorValue`. Got ", + Py_TYPE(sparse_tensor_value_class)->tp_name) + .c_str()); + return; + } + SparseTensorValueType = + reinterpret_cast(sparse_tensor_value_class); +} + bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } PyObject* Flatten(PyObject* nested) { PyObject* list = PyList_New(0); - if (FlattenHelper(nested, list)) { + if (FlattenHelper(nested, list, IsSequenceHelper, GetNextValues)) { + return list; + } else { + Py_DECREF(list); + return nullptr; + } +} + +bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; } + +PyObject* FlattenForData(PyObject* nested) { + PyObject* list = PyList_New(0); + if (FlattenHelper(nested, list, IsSequenceForDataHelper, + GetNextValuesForData)) { return list; } else { Py_DECREF(list); diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 4bb80d8289e958074fa7beb133e95908283a5b7b..70efc10c9abe7c57da61311bb2eb7ae362a48e3d 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -118,6 +118,30 @@ PyObject* Flatten(PyObject* nested); // the type from the module. This approach also requires some trigger from // Python so that we know that Python interpreter had been initialzied. void RegisterSequenceClass(PyObject* sequence_class); +// Similar to the above function, except for the +// sparse_tensor.SparseTensorValue class. +void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class); + +// The tensorflow.python.data package has its own nest utility that follows very +// slightly different semantics for its functions than the tensorflow.python +// nest utility. Returns a true if its input is a collections.Sequence (except +// strings). +// +// Main differences are (this is copied from nest.py in the +// tensorflow.data.util): +// +// 1. It removes support for lists as a level of nesting in nested structures. +// 2. It adds support for `SparseTensorValue` as an atomic element. + +// IsSequence specialized for the data package. Additional comments about +// difference in functionality can be found in nest.py in tensorflow.data.util +// and in the comments for Flatten above. +bool IsSequenceForData(PyObject* o); + +// IsSequence specialized for the data package. Additional comments about +// difference in functionality can be found in nest.py in tensorflow.data.util +// and in the comments for Flatten above. +PyObject* FlattenForData(PyObject* nested); } // namespace swig } // namespace tensorflow diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index b7f201b6fe6fd18af2bb833df2d08bfedb23a185..9f3b11b982bb0d52f903b09975cc7029fa8cb013 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -31,6 +31,9 @@ limitations under the License. %unignore tensorflow::swig::RegisterSequenceClass; %noexception tensorflow::swig::RegisterSequenceClass; +%unignore tensorflow::swig::RegisterSparseTensorValueClass; +%noexception tensorflow::swig::RegisterSparseTensorValueClass; + %unignore tensorflow::swig::IsSequence; %noexception tensorflow::swig::IsSequence; @@ -46,6 +49,12 @@ limitations under the License. %unignore tensorflow::swig::Flatten; %noexception tensorflow::swig::Flatten; +%unignore tensorflow::swig::IsSequenceForData; +%noexception tensorflow::swig::IsSequenceForData; + +%unignore tensorflow::swig::FlattenForData; +%noexception tensorflow::swig::FlattenForData; + %include "tensorflow/python/util/util.h" %unignoreall diff --git a/tensorflow/security/advisory/tfsa-2018-001.md b/tensorflow/security/advisory/tfsa-2018-001.md new file mode 100644 index 0000000000000000000000000000000000000000..bb97543a21988b4370ddac912102add6a10e2b35 --- /dev/null +++ b/tensorflow/security/advisory/tfsa-2018-001.md @@ -0,0 +1,34 @@ +## TFSA-2018-001: BMP File Parser Out-of-bounds Read. + +### CVE Number + +CVE-2018-7574 + +### Issue Description + +The BMP (bitmap image file graphics format) decoder had an out-of-bounds read +due to insufficient checking of header sizes and signed integer values. + +### Impact + +The most likely consequence of this vulnerability would be that an invalid BMP +file could lead to an unhandled process crash, but may permit read access to +unintended regions of the TensorFlow process memory. + +### Vulnerable Versions + +TensorFlow 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0 + +### Mitigation + +We have patched the vulnerability in GitHub commit +[49f73c55](https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae4333c55). +If users are running TensorFlow in production or on untrusted data, they are +encouraged to apply this patch. + +Additionally, this patch has already been integrated into TensorFlow 1.7.0 and +newer. + +### Credits + +This issue was discovered by the Blade Team of Tencent. diff --git a/tensorflow/security/advisory/tfsa-2018-002.md b/tensorflow/security/advisory/tfsa-2018-002.md new file mode 100644 index 0000000000000000000000000000000000000000..fad7fdd40f6dcc651ee72e0496f99377ebe24dbc --- /dev/null +++ b/tensorflow/security/advisory/tfsa-2018-002.md @@ -0,0 +1,33 @@ +## TFSA-2018-002: GIF File Parsing Null Pointer Dereference Error + +### CVE Number + +CVE-2018-7576 + +### Issue Description + +When parsing certain invalid GIF files, an internal function in the GIF decoder +returned a null pointer, which was subsequently used as an argument to strcat. + +### Impact + +A maliciously crafted GIF could be used to cause the TensorFlow process to +crash. + +### Vulnerable Versions + +TensorFlow 1.0.0, 1.0.1, 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1 1.4.1, 1.5.0, 1.5.1 + +### Mitigation + +We have patched the vulnerability in GitHub commit +[c4843158](https://github.com/tensorflow/tensorflow/commit/c48431588e7cf8aff61d4c299231e3e925144df8). +If users are running TensorFlow in production or on untrusted data, they are +encouraged to apply this patch. + +Additionally, this patch has already been integrated into TensorFlow 1.6.0 and +newer. + +### Credits + +This issue was discovered by the Blade Team of Tencent. diff --git a/tensorflow/security/advisory/tfsa-2018-003.md b/tensorflow/security/advisory/tfsa-2018-003.md new file mode 100644 index 0000000000000000000000000000000000000000..747d37064c02db84b92e669512b5ca4e40c431a2 --- /dev/null +++ b/tensorflow/security/advisory/tfsa-2018-003.md @@ -0,0 +1,48 @@ +## TFSA-2018-003: TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability + +### CVE Number + +CVE-2018-8825 + +### Issue Description + +The TensorFlow Lite TOCO compiler does not perform correct boundary checks when +reading from some fields within TFLite files. + +As background, TFLite files are based on the FlatBuffers serialization format, +which does not have bounds checking built-in, rather it relies on the clients to +handle the appropriate security checks by themselves. + +In particular, TOCO is not performing correct bounds checks in the following places: +* Out of bounds read in TOCO in import.cc:42 +* Null dereference in TOCO in import.cc:135 +* Out of bounds read in TOCO in import.cc:104 +* Null dereference in TOCO in import.cc:121 +* Out of bounds read in TOCO in import.cc:62 +* Out of bounds read in TOCO in operator.cc:48 +* Out of bounds read in TOCO graph_transformations (propagate_fixed_sizes.cc:93) + + +### Impact + +Users passing a malformed or malicious version of a TFLite graph into TOCO will +cause TOCO to crash or cause a buffer overflow, potentially allowing malicious +code to be executed. + +### Vulnerable Versions + +TensorFlow 1.5.0, 1.5.1, 1.6.0, 1.7.0 + +### Mitigation + +We have patched the vulnerability in GitHub commits [41335abb](https://github.com/tensorflow/tensorflow/commit/41335abb46f80ca644b5738550daef6136ba5476) and +[8badd11d](https://github.com/tensorflow/tensorflow/commit/8badd11d875a826bd318ed439909d5c47a7fb811). +If users are running the TensorFlow TFLite TOCO compiler in production or on +untrusted data, they are encouraged to apply this patch. + +Additionally, we have released TensorFlow version 1.7.1 to mitigate this +vulnerability. + +### Credits + +This issue was discovered by the Blade Team of Tencent. diff --git a/tensorflow/security/advisory/tfsa-2018-004.md b/tensorflow/security/advisory/tfsa-2018-004.md new file mode 100644 index 0000000000000000000000000000000000000000..3af28defa1387fc8ff99c9f07ae2ff2bcda9b268 --- /dev/null +++ b/tensorflow/security/advisory/tfsa-2018-004.md @@ -0,0 +1,35 @@ +## TFSA-2018-004: Checkpoint Meta File Out-of-Bounds Read + +### CVE Number + +CVE-2018-7575 + +### Issue Description + +The block size in meta file might contain a large int64 value which causes +an integer overflow upon addition. Subsequent code using n as index may cause +an out-of-bounds read. + +### Impact + +A maliciously crafted meta checkpoint could be used to cause the TensorFlow +process to perform an out of bounds read on in process memory. + +### Vulnerable Versions + +TensorFlow 1.0.0, 1.0.1, 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, 1.7.0 + +### Mitigation + +We have patched the vulnerability in GitHub commit +[d107fee1](https://github.com/tensorflow/tensorflow/commit/d107fee1e4a9a4462f01564798d345802acc2aef). +If users are running TensorFlow on untrusted meta checkpoints, such as those +downloaded from the Internet, in production or on untrusted data, they are +encouraged to apply this patch. + +Additionally, we have released TensorFlow version 1.7.1 to mitigate this +vulnerability. + +### Credits + +This issue was discovered by the Blade Team of Tencent. diff --git a/tensorflow/security/advisory/tfsa-2018-005.md b/tensorflow/security/advisory/tfsa-2018-005.md new file mode 100644 index 0000000000000000000000000000000000000000..c0f339fd976f5635fe774141e671a31d27523a0b --- /dev/null +++ b/tensorflow/security/advisory/tfsa-2018-005.md @@ -0,0 +1,36 @@ +## TFSA-2018-005: Old Snappy Library Usage Resulting in Memcpy Parameter Overlap + +### CVE Number + +CVE-2018-7577 + +### Issue Description + +TensorFlow checkpoint meta file uses Google's [https://github.com/google/snappy](snappy) +compression/decompression library. There is a memcpy-param-overlap issue in the +version of snappy currently used by TensorFlow. + +### Impact + +A maliciously crafted checkpoint meta file could cause TensorFlow to crash or +read from other parts of its process memory. + +### Vulnerable Versions + +TensorFlow 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, 1.7.0 + +### Mitigation + +We have patched the vulnerability in GitHub commit +[dfa9921e](https://github.com/tensorflow/tensorflow/commit/dfa9921e6343727b05f42f8d4a918b19528ff994) +by upgrading the version of the snappy library used by TensorFlow to v1.1.7. + +If users are loading untrusted checkpoints in TensorFlow, we encourage users to +apply the patch to upgrade snappy. + +Additionally, we have released TensorFlow version 1.7.1 to mitigate this +vulnerability. + +### Credits + +This issue was discovered by the Blade Team of Tencent. diff --git a/tensorflow/security/advisory/tfsa-2018-006.md b/tensorflow/security/advisory/tfsa-2018-006.md new file mode 100644 index 0000000000000000000000000000000000000000..17f514d8d2b5435d3325cc2e30bb4e48fe3284cf --- /dev/null +++ b/tensorflow/security/advisory/tfsa-2018-006.md @@ -0,0 +1,35 @@ +## TFSA-2018-006: Crafted Configuration File results in Invalid Memory Access + +### CVE Number + +CVE-2018-10055 + +### Issue Description + +A maliciously crafted configuration file passed into the TensorFlow XLA compiler +could cause an invalid memory access and/or a heap buffer overflow. + +### Impact + +A maliciously crafted configuration file could cause TensorFlow to crash or +read from other parts of its process memory. + +### Vulnerable Versions + +TensorFlow 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, 1.7.0 + +### Mitigation + +We have patched the vulnerability in GitHub commit +[c89ab82a](https://github.com/tensorflow/tensorflow/commit/c89ab82a82585cdaa90bf4911980e9e845909e78). + +If users are loading untrusted configurations in TensorFlow, we encourage users +to apply the patch to upgrade snappy or upgrade the version of TensorFlow they +are currently using. + +Additionally, we have released TensorFlow version 1.7.1 to mitigate this +vulnerability. + +### Credits + +This issue was discovered by the Blade Team of Tencent. diff --git a/tensorflow/security/index.md b/tensorflow/security/index.md new file mode 100644 index 0000000000000000000000000000000000000000..ea39e17ab2bb417bba1ebe4a589833309fc2c626 --- /dev/null +++ b/tensorflow/security/index.md @@ -0,0 +1,18 @@ +# TensorFlow Security Advisories + +We regularly publish security advisories about using TensorFlow. + +*Note*: In conjunction with these security advisories, we strongly encourage +TensorFlow users to read and understand TensorFlow's security model as outlined +in (https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md)[SECURITY.md]. + +| Advisory Number | Type | Versions affected | Reported by | Additional Information | +|-----------------|--------------------|:-----------------:|-----------------------|-----------------------------| +| [TFSA-2018-006](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-006.md) | Crafted Configuration File results in Invalid Memory Access | <= 1.7 | Blade Team of Tencent | | +| [TFSA-2018-005](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-005.md) | Old Snappy Library Usage Resulting in Memcpy Parameter Overlap | <= 1.7 | Blade Team of Tencent | | +| [TFSA-2018-004](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-004.md) | Checkpoint Meta File Out-of-Bounds Read | <= 1.7 | Blade Team of Tencent | | +| [TFSA-2018-003](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-003.md) | TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | | +| [TFSA-2018-002](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-002.md) | GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | | +| [TFSA-2018-001](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-001.md) | BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | | +| - | Out Of Bounds Read | <= 1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | + diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD index c68cda01002b1c5bbc2facb95b1eba214fbad7cb..e742f8e8d51d0217b631ebdc23ee65263c1ce0f0 100644 --- a/tensorflow/stream_executor/BUILD +++ b/tensorflow/stream_executor/BUILD @@ -2,6 +2,7 @@ licenses(["restricted"]) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") STREAM_EXECUTOR_HEADERS = glob([ "*.h", @@ -33,7 +34,6 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "@local_config_cuda//cuda:cuda_headers", @@ -48,11 +48,18 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:ptr_util", - "//tensorflow/compiler/xla:statusor", "@local_config_cuda//cuda:cuda_headers", ] + if_static([":stream_executor_impl"]), ) +cc_header_only_library( + name = "stream_executor_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":stream_executor", + ], +) + cc_library( name = "cuda_platform", srcs = if_cuda_is_configured( diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index b8ec4248442b9fb7b54dbc00b314746034d9eaed..874bf0e8cb481bf9e506e6d9b71c19afbe89d644 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -16,11 +16,7 @@ limitations under the License. #include "cuda/include/cublas_v2.h" #include "cuda/include/cuda.h" -#if CUDA_VERSION >= 8000 #define SE_CUDA_DATA_HALF CUDA_R_16F -#else -#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF -#endif #include "tensorflow/stream_executor/cuda/cuda_blas.h" @@ -45,10 +41,8 @@ limitations under the License. // approach when the issue is fixed. #if CUDA_VERSION < 9000 #include "cuda/include/cuda_fp16.h" -#if CUDA_VERSION >= 7050 #define EIGEN_HAS_CUDA_FP16 #endif -#endif #include "third_party/eigen3/Eigen/Core" @@ -547,9 +541,7 @@ cublasSideMode_t CUDABlasSide(blas::Side side) { // blas::ComputationType to a cudaDataType_t. // // These are used to build the argument type and computation type args to -// cublasGemmEx. cublasGemmEx and cudaDataType_t are available only on -// CUDA >= 8.0. -#if CUDA_VERSION >= 8000 +// cublasGemmEx. template struct CUDADataType; @@ -624,8 +616,6 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) { return CUDA_C_64F; } } -#endif - } // namespace template @@ -2165,10 +2155,7 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( const HostOrDeviceScalar &beta, DeviceMemory *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { -// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx. -#if CUDA_VERSION < 8000 - return false; -#else + // GPUs < sm_50 don't support cublasGemmEx. int cc_major, cc_minor; if (stream->parent()->GetDeviceDescription().cuda_compute_capability( &cc_major, &cc_minor) && @@ -2194,6 +2181,15 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( } } + // Return false if we might be hitting a cuBLAS bug that produces the wrong + // result. See nvbugs/2156201, b/79126339. +#if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020 + if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) && + std::max({m, n, k}) >= 2097153 && cc_major < 7) { + return false; + } +#endif + cudaDataType_t cuda_in_type = CUDADataType::type; // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast, // we do the following compile-time check on the default value: @@ -2223,7 +2219,6 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( timer->GetElapsedMilliseconds()); } return result; -#endif } bool CUDABlas::GetBlasGemmAlgorithms( @@ -2233,7 +2228,6 @@ bool CUDABlas::GetBlasGemmAlgorithms( // Note that when CUDA version and compute capability is not sufficient, we // still return the out_algorithms. Caller needs to make sure that in this case, // the returned vector is empty. -#if CUDA_VERSION >= 8000 for (cublasGemmAlgo_t algo : { CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4, @@ -2249,7 +2243,6 @@ bool CUDABlas::GetBlasGemmAlgorithms( }) { out_algorithms->push_back(algo); } -#endif return true; } diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc index 46e5deed8474dfa0c0ce6402bd6e5e2675491b31..124d5905b91cbf839437e763728cc76ad0d671dc 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc @@ -124,15 +124,20 @@ void Diagnostician::LogDiagnosticInformation() { #ifdef __APPLE__ CFStringRef kext_ids[1]; kext_ids[0] = kDriverKextIdentifier; - CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void**)kext_ids, 1, &kCFTypeArrayCallBacks); - CFDictionaryRef kext_infos = KextManagerCopyLoadedKextInfo(kext_id_query, nullptr); + CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void **)kext_ids, 1, + &kCFTypeArrayCallBacks); + CFDictionaryRef kext_infos = + KextManagerCopyLoadedKextInfo(kext_id_query, nullptr); CFRelease(kext_id_query); CFDictionaryRef cuda_driver_info = nullptr; - if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, (const void**)&cuda_driver_info)) { - bool started = CFBooleanGetValue((CFBooleanRef)CFDictionaryGetValue(cuda_driver_info, CFSTR("OSBundleStarted"))); + if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, + (const void **)&cuda_driver_info)) { + bool started = CFBooleanGetValue((CFBooleanRef)CFDictionaryGetValue( + cuda_driver_info, CFSTR("OSBundleStarted"))); if (!started) { - LOG(INFO) << "kernel driver is installed, but does not appear to be running on this host " + LOG(INFO) << "kernel driver is installed, but does not appear to be " + "running on this host " << "(" << port::Hostname() << ")"; } } else { @@ -210,27 +215,27 @@ port::StatusOr Diagnostician::FindDsoVersion() { "was unable to find libcuda.so DSO loaded into this program")); #if defined(__APPLE__) - // OSX CUDA libraries have names like: libcuda_310.41.15_mercury.dylib - const string prefix("libcuda_"); - const string suffix("_mercury.dylib"); - for (uint32_t image_index = 0; image_index < _dyld_image_count(); ++image_index) { - const string path(_dyld_get_image_name(image_index)); - const size_t suffix_pos = path.rfind(suffix); - const size_t prefix_pos = path.rfind(prefix, suffix_pos); - if (prefix_pos == string::npos || - suffix_pos == string::npos) { - // no match - continue; - } - const size_t start = prefix_pos + prefix.size(); - if (start >= suffix_pos) { - // version not included - continue; - } - const size_t length = suffix_pos - start; - const string version = path.substr(start, length); - result = StringToDriverVersion(version); + // OSX CUDA libraries have names like: libcuda_310.41.15_mercury.dylib + const string prefix("libcuda_"); + const string suffix("_mercury.dylib"); + for (uint32_t image_index = 0; image_index < _dyld_image_count(); + ++image_index) { + const string path(_dyld_get_image_name(image_index)); + const size_t suffix_pos = path.rfind(suffix); + const size_t prefix_pos = path.rfind(prefix, suffix_pos); + if (prefix_pos == string::npos || suffix_pos == string::npos) { + // no match + continue; + } + const size_t start = prefix_pos + prefix.size(); + if (start >= suffix_pos) { + // version not included + continue; } + const size_t length = suffix_pos - start; + const string version = path.substr(start, length); + result = StringToDriverVersion(version); + } #else #if !defined(PLATFORM_WINDOWS) && !defined(ANDROID_TEGRA) // Callback used when iterating through DSOs. Looks for the driver-interfacing @@ -313,12 +318,15 @@ port::StatusOr Diagnostician::FindKernelDriverVersion() { #if defined(__APPLE__) CFStringRef kext_ids[1]; kext_ids[0] = kDriverKextIdentifier; - CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void**)kext_ids, 1, &kCFTypeArrayCallBacks); - CFDictionaryRef kext_infos = KextManagerCopyLoadedKextInfo(kext_id_query, nullptr); + CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void **)kext_ids, 1, + &kCFTypeArrayCallBacks); + CFDictionaryRef kext_infos = + KextManagerCopyLoadedKextInfo(kext_id_query, nullptr); CFRelease(kext_id_query); CFDictionaryRef cuda_driver_info = nullptr; - if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, (const void**)&cuda_driver_info)) { + if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, + (const void **)&cuda_driver_info)) { // NOTE: OSX CUDA driver does not currently store the same driver version // in kCFBundleVersionKey as is returned by cuDriverGetVersion CFRelease(kext_infos); diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 7ace7fd3031e935dd11cfff170f64a47bda0c8d1..84916385a89b6e2bafb8a3c0a8f435ec9626e816 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/lib/core/errors.h" @@ -53,6 +54,35 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin); namespace { +static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher"); + +// Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS. +#define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS) + +// If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current +// function with a non-successful port::Status. +#define RETURN_IF_CUDNN_ERROR(expr) \ + do { \ + cudnnStatus_t _status = expr; \ + if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) { \ + std::ostringstream oss; \ + oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \ + << "): '" << #expr << "'"; \ + return port::Status(port::error::UNKNOWN, oss.str().c_str()); \ + } \ + } while (false) + +// Returns whether status is 'ok', and potentially logs the error. +bool IsStatusOk(const port::Status& status, bool report_error) { + if (status.ok()) { + return true; + } + if (report_error) { + LOG(ERROR) << status.error_message(); + } + return false; +} + // Converts (via narrowing) a type T value to a type U, and checks that the // value has no value change due to the conversion. template @@ -87,28 +117,20 @@ string ToString(cudnnStatus_t status) { return "CUDNN_STATUS_NOT_SUPPORTED"; case CUDNN_STATUS_LICENSE_ERROR: return "CUDNN_STATUS_LICENSE_ERROR"; + case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING: + return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING"; +#if CUDNN_VERSION >= 7000 + case CUDNN_STATUS_RUNTIME_IN_PROGRESS: + return "CUDNN_STATUS_RUNTIME_IN_PROGRESS"; + case CUDNN_STATUS_RUNTIME_FP_OVERFLOW: + return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW"; +#endif default: return port::StrCat("(status), ">"); } } -#if CUDNN_VERSION >= 6000 -string ToString(libraryPropertyType type) { - switch (type) { - case MAJOR_VERSION: - return "MAJOR_VERSION"; - case MINOR_VERSION: - return "MINOR_VERSION"; - case PATCH_LEVEL: - return "PATCH_LEVEL"; - default: - return port::StrCat( - "(type), ">"); - } -} -#endif - template cudnnDataType_t GetCudnnDataType(); @@ -150,9 +172,9 @@ class CudnnHandle { } // namespace -// Wraps a cuDNN handle and provides access to it through CudnnHandle instances, -// which also locks a mutex, acquires the CUDA context, and sets the stream -// that cuDNN should use to enqueue any work. +// Wraps a cuDNN handle and provides access to it through CudnnHandle +// instances, which also locks a mutex, acquires the CUDA context, and sets +// the stream that cuDNN should use to enqueue any work. // // Note: CudnnSupport::cudnn_ should be the only instantiation of this class. class CudnnAccess { @@ -167,13 +189,13 @@ class CudnnAccess { // Creates a CudnnHandle instance for stream. // - // cuDNN API calls using the same handle instance need to be serialized across - // threads. This is guaranteed by CudnnHandle instances locking the mutex - // owned by this class. + // cuDNN API calls using the same handle instance need to be serialized + // across threads. This is guaranteed by CudnnHandle instances locking the + // mutex owned by this class. // // Most cuDNN APIs taking a handle perform work on a CUDA stream. The - // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN to - // use the provided stream. + // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN + // to use the provided stream. // // The stream argument may be null, which translates to the legacy default // stream. See @@ -187,7 +209,6 @@ class CudnnAccess { CUstream cu_stream = stream ? AsCUDAStreamValue(stream) : cudaStreamLegacy; auto status = cudnnSetStream(handle_, cu_stream); CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream."; - using my_mutex_lock = mutex_lock; return CudnnHandle(std::move(context), std::move(lock), handle_); } @@ -201,6 +222,8 @@ class CudnnAccess { namespace { +// A helper function to return the internal compute type for +// RNNs in cudnn. cudnnDataType_t GetRnnComputeType(dnn::DataType data_type); cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) { @@ -213,12 +236,8 @@ cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) { case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT: case CUDNN_CONVOLUTION_FWD_ALGO_FFT: case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING: -#if CUDNN_VERSION >= 5000 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD: -#endif -#if CUDNN_VERSION >= 5100 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED: -#endif return algo; default: LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: " @@ -235,12 +254,8 @@ cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo( case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1: case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT: case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING: -#if CUDNN_VERSION >= 5000 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD: -#endif -#if CUDNN_VERSION >= 5100 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED: -#endif return algo; default: LOG(FATAL) @@ -258,12 +273,13 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo( case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1: case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT: case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3: -#if CUDNN_VERSION >= 5100 // Based on cudnn.h, the following is not implemented. // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD: case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED: -#endif return algo; + // Produces incorrect results for some shapes. Disabled for now, see + // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes. + // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING: default: LOG(FATAL) << "Unsupported Cudnn convolution backward algorithm for filter: " @@ -271,17 +287,10 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo( } } -#if CUDNN_VERSION >= 6000 -port::Status GetCudnnProperty(libraryPropertyType type, int* value) { - cudnnStatus_t status = cudnnGetProperty(type, value); - if (status != CUDNN_STATUS_SUCCESS) { - const string error = - port::StrCat("cudnnGetProperty failed for type: ", ToString(type), - " with status: ", ToString(status)); - LOG(ERROR) << error; - return port::Status(port::error::INTERNAL, error); - } - return port::Status::OK(); +port::StatusOr GetCudnnProperty(libraryPropertyType type) { + int value; + RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value)); + return value; } cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) { @@ -300,19 +309,11 @@ cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) { } } } -#endif port::Status GetLoadedCudnnVersion(CudnnVersion* version) { -#if CUDNN_VERSION >= 6000 - TF_RETURN_IF_ERROR(GetCudnnProperty(MAJOR_VERSION, &version->major_version)); - TF_RETURN_IF_ERROR(GetCudnnProperty(MINOR_VERSION, &version->minor_version)); - TF_RETURN_IF_ERROR(GetCudnnProperty(PATCH_LEVEL, &version->patch_level)); -#else - size_t loaded_version = ::cudnnGetVersion(); - version->major_version = loaded_version / 1000; - version->minor_version = (loaded_version / 100) % 10; - version->patch_level = loaded_version % 100; -#endif + SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION)); + SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION)); + SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL)); return port::Status::OK(); } @@ -335,9 +336,11 @@ port::Status CudnnSupport::Init() { ". CuDNN library major and minor version needs to match or have " "higher minor version in case of CuDNN 7.0 or later version. If " "using a binary install, upgrade your CuDNN library. If building " - "from sources, make sure the library loaded at runtime is compatible " + "from sources, make sure the library loaded at runtime is " + "compatible " "with the version specified during compile configuration."); LOG(ERROR) << error; + cudnnDestroy(cudnn_handle); return port::Status(port::error::INTERNAL, error); } @@ -345,23 +348,17 @@ port::Status CudnnSupport::Init() { return port::Status::OK(); } - LOG(ERROR) << "could not create cudnn handle: " << ToString(status); + CHECK_EQ(cudnn_handle, nullptr); + LOG(ERROR) << "Could not create cudnn handle: " << ToString(status); if (status == CUDNN_STATUS_NOT_INITIALIZED) { auto result = cuda::Diagnostician::FindKernelDriverVersion(); if (!result.ok()) { - LOG(ERROR) << "error retrieving driver version: " + LOG(ERROR) << "Error retrieving driver version: " << DriverVersionStatusToString(result); } else { const auto& version = result.ValueOrDie(); - LOG(ERROR) << "possibly insufficient driver version: " + LOG(ERROR) << "Possibly insufficient driver version: " << DriverVersionToString(version); - // OS X kernel driver does not report version accurately -#if !defined(__APPLE__) - if (std::get<0>(version) < 340) { - LOG(ERROR) - << "cudnn library is only supported on 340.XX+ driver versions"; - } -#endif } } @@ -380,18 +377,129 @@ CudnnSupport::GetVersion() { namespace { -// Turns a BatchDescriptor structure into a cudnn tensor handle within a scope. -class ScopedTensorDescriptor { - public: - ScopedTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor, - cudnnDataType_t elem_type) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn tensor descriptor: " - << ToString(status); - } +// Deleter functors for cuDNN types that need to be deleted. +struct TensorDescriptorDeleter { + void operator()(cudnnTensorDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor)); + } +}; +struct FilterDescriptorDeleter { + void operator()(cudnnFilterDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor)); + } +}; +struct ConvolutionDescriptorDeleter { + void operator()(cudnnConvolutionDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor)); + } +}; +struct PoolingDescriptorDeleter { + void operator()(cudnnPoolingDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor)); + } +}; +struct LrnDescriptorDeleter { + void operator()(cudnnLRNDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor)); + } +}; + +struct ActivationDescriptorDeleter { + void operator()(cudnnActivationDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor)); + } +}; +struct DropoutDescriptorDeleter { + void operator()(cudnnDropoutDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor)); + } +}; +struct RnnDescriptorDeleter { + void operator()(cudnnRNNDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor)); + } +}; +struct PersistentRnnPlanDeleter { + void operator()(cudnnPersistentRNNPlan_t plan) const { + CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan)); + } +}; +// RAII wrappers for cuDNN types. +using TensorDescriptor = + std::unique_ptr; +using FilterDescriptor = + std::unique_ptr; +using ConvolutionDescriptor = + std::unique_ptr; +using PoolingDescriptor = + std::unique_ptr; +using LrnDescriptor = std::unique_ptr; +using ActivationDescriptor = + std::unique_ptr; +using DropoutDescriptor = + std::unique_ptr; +using RnnDescriptor = std::unique_ptr; +using PersistentRnnPlan = + std::unique_ptr; + +// Factory methods for cuDNN types. +TensorDescriptor CreateTensorDescriptor() { + cudnnTensorDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result)); + return TensorDescriptor(result); +} +FilterDescriptor CreateFilterDescriptor() { + cudnnFilterDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result)); + return FilterDescriptor(result); +} +ConvolutionDescriptor CreateConvolutionDescriptor() { + cudnnConvolutionDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result)); + return ConvolutionDescriptor(result); +} +PoolingDescriptor CreatePoolingDescriptor() { + cudnnPoolingDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result)); + return PoolingDescriptor(result); +} +LrnDescriptor CreateLrnDescriptor() { + cudnnLRNDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result)); + return LrnDescriptor(result); +} +ActivationDescriptor CreateActivationDescriptor() { + cudnnActivationDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result)); + return ActivationDescriptor(result); +} +DropoutDescriptor CreateDropoutDescriptor() { + cudnnDropoutDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result)); + return DropoutDescriptor(result); +} +RnnDescriptor CreateRnnDescriptor() { + cudnnRNNDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result)); + return RnnDescriptor(result); +} +PersistentRnnPlan CreatePersistentRnnPlan(cudnnRNNDescriptor_t rnn_desc, + int batch_size, + cudnnDataType_t data_type) { + cudnnPersistentRNNPlan_t result; + CHECK_CUDNN_OK( + cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result)); + return PersistentRnnPlan(result); +} + +// Turns a BatchDescriptor structure into a cudnn tensor handle within a +// scope. +class CudnnTensorDescriptor { + public: + CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor, + cudnnDataType_t elem_type) + : handle_(CreateTensorDescriptor()) { switch (batch_descriptor.layout()) { case dnn::DataLayout::kBatchYXDepth: case dnn::DataLayout::kBatchDepthYX: { @@ -409,28 +517,17 @@ class ScopedTensorDescriptor { &CheckedNarrowing); std::transform(dims64.cbegin(), dims64.cend(), dims.begin(), &CheckedNarrowing); - status = cudnnSetTensorNdDescriptor(handle_, elem_type, nd, dims.data(), - strides.data()); - - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not convert BatchDescriptor " - << batch_descriptor.ToString() - << " to cudnn tensor descriptor: " << ToString(status); - } + CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd, + dims.data(), strides.data())) + << "batch_descriptor: " << batch_descriptor.ToString(); } break; -#if CUDNN_VERSION >= 6000 case dnn::DataLayout::kBatchDepthYX4: { - status = cudnnSetTensor4dDescriptor( - handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type, + CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor( + handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type, batch_descriptor.count(), batch_descriptor.feature_map_count(), - batch_descriptor.height(), batch_descriptor.width()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not convert BatchDescriptor " - << batch_descriptor.ToString() - << " to cudnn tensor descriptor: " << ToString(status); - } + batch_descriptor.height(), batch_descriptor.width())) + << "batch_descriptor: " << batch_descriptor.ToString(); } break; -#endif default: LOG(FATAL) << "Unsupported tensor format " << DataLayoutString(batch_descriptor.layout()); @@ -438,55 +535,41 @@ class ScopedTensorDescriptor { } } - ~ScopedTensorDescriptor() { - cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn tensor descriptor: " - << ToString(status); - } - } - - cudnnTensorDescriptor_t handle() const { return handle_; } + cudnnTensorDescriptor_t handle() const { return handle_.get(); } private: - cudnnTensorDescriptor_t handle_; // Owned. + TensorDescriptor handle_; - SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor); + SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor); }; -// Turns a FilterDescriptor structure into a cudnn filter handle within a scope. -class ScopedFilterDescriptor { +// Turns a FilterDescriptor structure into a cudnn filter handle within a +// scope. +class CudnnFilterDescriptor { public: - ScopedFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor, - cudnnDataType_t elem_type) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreateFilterDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn filter descriptor: " - << ToString(status); - } - -#if CUDNN_VERSION >= 5000 + CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor, + cudnnDataType_t elem_type) + : handle_(CreateFilterDescriptor()) { // TODO(b/23032134): Even if the filter layout is not supported, - // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because it - // does not take layout as an input. Maybe force cuDNN by giving wrong + // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because + // it does not take layout as an input. Maybe force cuDNN by giving wrong // inputs intentionally? cudnnTensorFormat_t format; switch (filter_descriptor.layout()) { case dnn::FilterLayout::kOutputInputYX: format = CUDNN_TENSOR_NCHW; break; -#if CUDNN_VERSION >= 6000 + case dnn::FilterLayout::kOutputYXInput: + format = CUDNN_TENSOR_NHWC; + break; case dnn::FilterLayout::kOutputInputYX4: format = CUDNN_TENSOR_NCHW_VECT_C; break; -#endif default: LOG(FATAL) << "Unsupported filter format " << FilterLayoutString(filter_descriptor.layout()); break; } -#endif std::vector dims(2 + filter_descriptor.ndims()); dims[0] = filter_descriptor.output_feature_map_count(); @@ -494,35 +577,20 @@ class ScopedFilterDescriptor { const auto& spatial_dims = filter_descriptor.input_filter_dims(); std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2); - status = cudnnSetFilterNdDescriptor(handle_, elem_type, -#if CUDNN_VERSION >= 5000 - format, -#endif - dims.size(), dims.data()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn filter descriptor: " - << ToString(status); - } - } - - ~ScopedFilterDescriptor() { - cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn filter descriptor: " - << ToString(status); - } + CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format, + dims.size(), dims.data())); } - cudnnFilterDescriptor_t handle() const { return handle_; } + cudnnFilterDescriptor_t handle() const { return handle_.get(); } private: - cudnnFilterDescriptor_t handle_; // Owned. + FilterDescriptor handle_; // Owned. - SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor); + SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor); }; // A helper function to decide whether to enable the TENSOR_OP_MATH math type -static bool TensorOpMathEnabled() { +bool TensorOpMathEnabled() { static bool is_enabled = [] { bool is_disabled = false; TF_CHECK_OK( @@ -535,7 +603,7 @@ static bool TensorOpMathEnabled() { // A helper function to decide whether to enable the TENSOR_OP_MATH math type // for RNNs. -static bool RnnTensorOpMathEnabled() { +bool RnnTensorOpMathEnabled() { static bool is_enabled = [] { bool is_disabled = false; TF_CHECK_OK( @@ -546,15 +614,16 @@ static bool RnnTensorOpMathEnabled() { return is_enabled; } -// A helper function to decide whether to use CUDNN_BATCHNORM_SPATIAL_PERSISTENT -// in batchnorm. This mode can be faster in some tasks because an optimized path -// may be selected for CUDNN_DATA_FLOAT and CUDNN_DATA_HALF data types, compute -// capability 6.0 or higher. The reason we set it to false by default is that -// this mode may use scaled atomic integer reduction that may cause a numerical -// overflow for certain input data range. +// A helper function to decide whether to use +// CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in +// some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT +// and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The +// reason we set it to false by default is that this mode may use scaled +// atomic integer reduction that may cause a numerical overflow for certain +// input data range. // TODO(yangzihao): Use autotune to choose between this mode and // CUDNN_BATCHNORM_SPATIAL mode. -static bool BatchnormSpatialPersistentEnabled() { +bool BatchnormSpatialPersistentEnabled() { static bool is_enabled = [] { bool is_enabled = false; TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar( @@ -567,24 +636,18 @@ static bool BatchnormSpatialPersistentEnabled() { // Turns a ConvolutionDescriptor structure into a cudnn convolution handle // within a scope. -class ScopedConvolutionDescriptor { +class CudnnConvolutionDescriptor { public: - ScopedConvolutionDescriptor( + CudnnConvolutionDescriptor( const dnn::ConvolutionDescriptor& convolution_descriptor, cudnnDataType_t data_type) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreateConvolutionDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn convolution descriptor: " - << ToString(status); - } + : handle_(CreateConvolutionDescriptor()) { const auto& strides64 = convolution_descriptor.strides(); const auto& padding64 = convolution_descriptor.padding(); const auto& dilations64 = convolution_descriptor.dilations(); - if (convolution_descriptor.pad_alignment() == - dnn::PadAlignment::kTensorFlowPadding) { - LOG(ERROR) << "TensorFlow padding alignment is not supported."; - } + CHECK_NE(convolution_descriptor.pad_alignment(), + dnn::PadAlignment::kTensorFlowPadding) + << "TensorFlow padding alignment is not supported."; // cuDNN requires arrays of ints. std::vector strides(convolution_descriptor.ndims()); @@ -599,18 +662,14 @@ class ScopedConvolutionDescriptor { std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(), &CheckedNarrowing); - status = cudnnSetConvolutionNdDescriptor( - handle_, convolution_descriptor.ndims(), padding.data(), strides.data(), - dilations.data(), + CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor( + handle_.get(), convolution_descriptor.ndims(), padding.data(), + strides.data(), dilations.data(), // NOTE(keveman): cuDNN supports convolution and cross correlation. // However, almost all the use cases do cross correlation, so just // hard coding it here. - CUDNN_CROSS_CORRELATION, data_type); + CUDNN_CROSS_CORRELATION, data_type)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn convolution descriptor: " - << ToString(status); - } // NOTE(benbarsdell): This only applies if tensor op math is enabled // and algo selection is set to Default. this->set_use_tensor_op_math(true); @@ -618,60 +677,39 @@ class ScopedConvolutionDescriptor { #if CUDNN_MAJOR >= 7 VLOG(2) << "Requesting grouped convolution: " << convolution_descriptor.group_count(); - status = cudnnSetConvolutionGroupCount( - handle_, convolution_descriptor.group_count()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn convolution group count: " - << ToString(status); - } + CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount( + handle_.get(), convolution_descriptor.group_count())); #else CHECK_EQ(convolution_descriptor.group_count(), 1) << "Requested grouped convolution for cuDNN version < 7"; #endif } - void set_use_tensor_op_math(bool use_tensor_op_math) { + void set_use_tensor_op_math(bool use_tensor_op_math) const { #if CUDNN_VERSION >= 7000 cudnnMathType_t math_type = (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH); if (TensorOpMathEnabled()) { - cudnnStatus_t status = cudnnSetConvolutionMathType(handle_, math_type); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn convolution math type: " - << ToString(status); - } + CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type)); } #endif } - ~ScopedConvolutionDescriptor() { - cudnnStatus_t status = cudnnDestroyConvolutionDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn convolution descriptor: " - << ToString(status); - } - } - - cudnnConvolutionDescriptor_t handle() const { return handle_; } + cudnnConvolutionDescriptor_t handle() const { return handle_.get(); } private: - cudnnConvolutionDescriptor_t handle_; // Owned. + ConvolutionDescriptor handle_; // Owned. - SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor); + SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor); }; // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle // within a scope. -class ScopedPoolingDescriptor { +class CudnnPoolingDescriptor { public: - explicit ScopedPoolingDescriptor( + explicit CudnnPoolingDescriptor( const dnn::PoolingDescriptor& pooling_descriptor) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreatePoolingDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn pooling descriptor: " - << ToString(status); - } + : handle_(CreatePoolingDescriptor()) { const std::vector strides64 = pooling_descriptor.strides(); const std::vector padding64 = pooling_descriptor.padding(); const std::vector shape64 = pooling_descriptor.window(); @@ -687,48 +725,29 @@ class ScopedPoolingDescriptor { std::transform(shape64.cbegin(), shape64.cend(), shape.begin(), &CheckedNarrowing); bool propagate_nans = pooling_descriptor.propagate_nans(); - status = cudnnSetPoolingNdDescriptor( - handle_, + CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor( + handle_.get(), (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum ? CUDNN_POOLING_MAX : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING), -#if CUDNN_VERSION >= 5000 - propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, -#endif - nd, shape.data(), padding.data(), strides.data()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn pooling descriptor: " - << ToString(status); - } - } - ~ScopedPoolingDescriptor() { - cudnnStatus_t status = cudnnDestroyPoolingDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn pooling descriptor: " - << ToString(status); - } + propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd, + shape.data(), padding.data(), strides.data())); } - cudnnPoolingDescriptor_t handle() const { return handle_; } + cudnnPoolingDescriptor_t handle() const { return handle_.get(); } private: - cudnnPoolingDescriptor_t handle_; // Owned. + PoolingDescriptor handle_; // Owned. - SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor); + SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor); }; // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle. -class ScopedNormalizeDescriptor { +class CudnnNormalizeDescriptor { public: - explicit ScopedNormalizeDescriptor( + explicit CudnnNormalizeDescriptor( const dnn::NormalizeDescriptor& normalize_descriptor) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreateLRNDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn LRN descriptor: " - << ToString(status); - } - + : handle_(CreateLrnDescriptor()) { // The range specifies that the indices in the closed range // [i - range, i + range] should be included in the normalization for index // i. The lrnN value is the total number of elements in the range, so @@ -749,43 +768,26 @@ class ScopedNormalizeDescriptor { double lrnBeta = normalize_descriptor.beta(); double lrnK = normalize_descriptor.bias(); - status = cudnnSetLRNDescriptor(handle_, lrnN, lrnAlpha, lrnBeta, lrnK); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn LRN descriptor: " << ToString(status); - } + CHECK_CUDNN_OK( + cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK)); } - ~ScopedNormalizeDescriptor() { - cudnnStatus_t status = cudnnDestroyLRNDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn LRN descriptor: " - << ToString(status); - } - } - - cudnnLRNDescriptor_t handle() const { return handle_; } + cudnnLRNDescriptor_t handle() const { return handle_.get(); } private: - cudnnLRNDescriptor_t handle_; // Owned. + LrnDescriptor handle_; // Owned. - SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor); + SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor); }; -#if CUDNN_VERSION >= 5000 // Turns a ActivationDescriptor structure into a cudnn activation // descriptor handle within a scope. -class ScopedActivationDescriptor { +class CudnnActivationDescriptor { public: - ScopedActivationDescriptor(dnn::ActivationMode activation_mode, - cudnnNanPropagation_t nan_propagation, - double value_max) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreateActivationDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn activation descriptor: " - << ToString(status); - } - + CudnnActivationDescriptor(dnn::ActivationMode activation_mode, + cudnnNanPropagation_t nan_propagation, + double value_max) + : handle_(CreateActivationDescriptor()) { double relu_ceiling = 0.0; cudnnActivationMode_t mode; switch (activation_mode) { @@ -811,30 +813,17 @@ class ScopedActivationDescriptor { << static_cast(activation_mode); } - status = cudnnSetActivationDescriptor(handle_, mode, nan_propagation, - relu_ceiling); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn activation descriptor: " - << ToString(status); - } - } - - ~ScopedActivationDescriptor() { - cudnnStatus_t status = cudnnDestroyActivationDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn activation descriptor: " - << ToString(status); - } + CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode, + nan_propagation, relu_ceiling)); } - cudnnActivationDescriptor_t handle() const { return handle_; } + cudnnActivationDescriptor_t handle() const { return handle_.get(); } private: - cudnnActivationDescriptor_t handle_; // Owned. + ActivationDescriptor handle_; // Owned. - SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor); + SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor); }; -#endif cudnnDataType_t ToCudnnDataType( dnn::DataType data_type, @@ -844,18 +833,14 @@ cudnnDataType_t ToCudnnDataType( case dnn::DataType::kDouble: case dnn::DataType::kHalf: return static_cast(data_type); -#if CUDNN_VERSION >= 6000 case dnn::DataType::kInt8: return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4 : CUDNN_DATA_INT8; -#endif default: LOG(FATAL) << "Invalid DNN data type: " << static_cast(data_type); } } -#if CUDNN_VERSION >= 5000 - cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) { switch (input_mode) { case dnn::RnnInputMode::kRnnLinearSkip: @@ -903,121 +888,74 @@ int CudnnDataTypeToByteSize(cudnnDataType_t data_type) { } } -#endif // CUDNN_VERSION +class CudnnDropoutDescriptor { + explicit CudnnDropoutDescriptor(DropoutDescriptor handle) + : handle_(std::move(handle)) {} -template -class MixinBase : public Base {}; -template <> -class MixinBase {}; - -#if CUDNN_VERSION >= 5000 - -#define CUDNN_RETURN_IF_FAIL(STATUS, ...) \ - if (!SE_PREDICT_TRUE((STATUS) == CUDNN_STATUS_SUCCESS)) { \ - string error_msg = port::StrCat(ToString(STATUS), " ", __VA_ARGS__); \ - SetFailure(port::Status(port::error::UNKNOWN, error_msg)); \ - LOG(ERROR) << error_msg; \ - return; \ - } - -// TODO(csigg): Remove inheritance for code reuse. -template -class CudnnDescriptorCommon : public MixinBase { public: - bool ok() const { return status_.ok(); } - port::Status Status() const { return status_; } + CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default; - protected: - void SetFailure(const port::Status& status) { status_.Update(status); } - port::Status status_; -}; + static port::StatusOr Create( + const CudnnHandle& cudnn, float dropout, uint64 seed, + ScratchAllocator* state_allocator) { + DropoutDescriptor handle = CreateDropoutDescriptor(); -class CudnnDropoutDescriptor : public CudnnDescriptorCommon { - public: - CudnnDropoutDescriptor(const CudnnHandle& cudnn, float dropout, uint64 seed, - ScratchAllocator* state_allocator) - : handle_(nullptr) { - cudnnStatus_t status; - status = cudnnCreateDropoutDescriptor(&handle_); - CUDNN_RETURN_IF_FAIL(status, "Failed to create dropout descriptor"); - - if (dropout == 0.f) { - return; + if (dropout == 0.0f) { + // Return 'empty' dropout descriptor. + return CudnnDropoutDescriptor(std::move(handle)); } DeviceMemory state_memory; if (state_allocator) { size_t state_sizes_in_bytes = 0; - status = cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes); - CUDNN_RETURN_IF_FAIL(status, "Failed to query dropout state sizes"); - - auto allocated = - state_allocator->AllocateBytes(nullptr, state_sizes_in_bytes); - if (!allocated.ok() || - (state_memory = allocated.ValueOrDie()) == nullptr) { - string error_msg = - port::StrCat("Failed to allocate Cudnn dropout state memory of ", - state_sizes_in_bytes, " bytes."); - status_ = port::Status(port::error::UNKNOWN, error_msg); - LOG(ERROR) << error_msg; - return; - } + RETURN_IF_CUDNN_ERROR( + cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes)); + SE_ASSIGN_OR_RETURN(state_memory, state_allocator->AllocateBytes( + nullptr, state_sizes_in_bytes)); } - status = cudnnSetDropoutDescriptor(handle_, cudnn.handle(), dropout, - state_memory.opaque(), - state_memory.size(), seed); - CUDNN_RETURN_IF_FAIL( - status, port::StrCat( - "Failed to set dropout descriptor with state memory size: ", - state_memory.size(), " bytes.")); - } + RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor( + handle.get(), cudnn.handle(), dropout, state_memory.opaque(), + state_memory.size(), seed)); - ~CudnnDropoutDescriptor() { - cudnnStatus_t status = cudnnDestroyDropoutDescriptor(handle_); - // TODO(csigg): This is a no-op (error is not reported). Same below. - CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: "); + return CudnnDropoutDescriptor(std::move(handle)); } - cudnnDropoutDescriptor_t handle() const { - if (!ok()) return nullptr; - return handle_; - } + cudnnDropoutDescriptor_t handle() const { return handle_.get(); } private: - cudnnDropoutDescriptor_t handle_; // Owned. - float dropout_; - uint64 seed_; + DropoutDescriptor handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor); }; -class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon { - public: - typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion; +class CudnnRnnParamsDescriptor { typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions; - CudnnRnnParamsDescriptor(const CudnnHandle& cudnn, - const CudnnRnnDescriptor& rnn_desc); - ~CudnnRnnParamsDescriptor() { - cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_); - CUDNN_RETURN_IF_FAIL(status, "Failed to destroy RNN filter descriptor"); - } - cudnnFilterDescriptor_t handle() const { - if (!ok()) return nullptr; - return handle_; - } + + CudnnRnnParamsDescriptor(FilterDescriptor handle, int64 params_size_in_bytes, + ParamsRegions weights, ParamsRegions biases) + : handle_(std::move(handle)), + params_size_in_bytes_(params_size_in_bytes), + weights_(std::move(weights)), + biases_(std::move(biases)) {} + + public: + CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default; + + static port::StatusOr Create( + const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type, + cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode, + cudnnDirectionMode_t direction_mode, int num_layers); + + cudnnFilterDescriptor_t handle() const { return handle_.get(); } int64 params_size_in_bytes() const { return params_size_in_bytes_; } ParamsRegions params_weights() const { - if (!ok()) return ParamsRegions(); return weights_; } ParamsRegions params_biases() const { - if (!ok()) return ParamsRegions(); return biases_; } private: - int GetRegionCountPerLayer() const; - cudnnFilterDescriptor_t handle_; - const CudnnRnnDescriptor* rnn_desc_; + FilterDescriptor handle_; int64 params_size_in_bytes_; ParamsRegions weights_; ParamsRegions biases_; @@ -1026,112 +964,98 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon { } // namespace -class CudnnRnnDescriptor : public CudnnDescriptorCommon { - public: - CudnnRnnDescriptor(const CudnnHandle& cudnn, int num_layers, int hidden_size, - int input_size, int batch_size, +class CudnnRnnDescriptor : public dnn::RnnDescriptor { + CudnnRnnDescriptor(const CudnnHandle& cudnn, cuda::RnnDescriptor rnn_desc, + PersistentRnnPlan rnn_plan, int num_layers, + int hidden_size, int input_size, int batch_size, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type, cudnnDataType_t compute_type, const dnn::AlgorithmConfig& algorithm_config, - float dropout, uint64 seed, - ScratchAllocator* state_allocator) - : rnn_desc_(nullptr), + CudnnDropoutDescriptor dropout_desc, + CudnnRnnParamsDescriptor params_desc) + : rnn_desc_(std::move(rnn_desc)), + rnn_plan_(std::move(rnn_plan)), num_layers_(num_layers), hidden_size_(hidden_size), input_size_(input_size), batch_size_(batch_size), -#if CUDNN_VERSION >= 6000 - rnn_plan_(nullptr), -#endif + rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())), input_mode_(input_mode), direction_mode_(direction_mode), rnn_mode_(rnn_mode), data_type_(data_type), compute_type_(compute_type), - algorithm_config_(algorithm_config) { - // Create the dropout handle. - cudnn_dropout_desc_.reset( - new CudnnDropoutDescriptor(cudnn, dropout, seed, state_allocator)); - if (!cudnn_dropout_desc_->ok()) { - SetFailure(cudnn_dropout_desc_->Status()); - return; - } + algorithm_config_(algorithm_config), + dropout_desc_(std::move(dropout_desc)), + params_desc_(std::move(params_desc)) {} + + public: + CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default; + + static port::StatusOr Create( + const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size, + int batch_size, cudnnRNNInputMode_t input_mode, + cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, + cudnnDataType_t data_type, cudnnDataType_t compute_type, + const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, + ScratchAllocator* state_allocator) { + SE_ASSIGN_OR_RETURN( + CudnnDropoutDescriptor dropout_desc, + CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator)); + + cuda::RnnDescriptor rnn_desc = CreateRnnDescriptor(); + cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm()); - // Create the RNN handle - cudnnStatus_t status = cudnnCreateRNNDescriptor(&rnn_desc_); - CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor"); -#if CUDNN_VERSION >= 6000 // TODO: allow the user to choose an algorithm. - rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm()); - status = cudnnSetRNNDescriptor_v6( - cudnn.handle(), /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size, - /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(), + RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( + cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/hidden_size, + /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode, /*direction=*/direction_mode, - /*mode=*/rnn_mode, /*algo=*/rnn_algo_, /*dataType=*/compute_type); - CUDNN_RETURN_IF_FAIL(status, ::tensorflow::strings::Printf( - "Unable to update RNN descriptor with " - "algo_id: %d and compute_type: %d", - static_cast(rnn_algo_), - static_cast(compute_type))); - - if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) { - CHECK_GE(batch_size_, 0); - status = cudnnCreatePersistentRNNPlan(rnn_desc_, batch_size_, data_type_, - &rnn_plan_); - CUDNN_RETURN_IF_FAIL(status, "Unable to create persistent RNN plan."); - status = cudnnSetPersistentRNNPlan(rnn_desc_, rnn_plan_); - CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan."); + /*mode=*/rnn_mode, /*algo=*/rnn_algo, + /*dataType=*/compute_type)); + + PersistentRnnPlan rnn_plan; + if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) { + CHECK_GE(batch_size, 0); + rnn_plan = CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type); + RETURN_IF_CUDNN_ERROR( + cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get())); } -#else - CHECK(algorithm_config_.is_default()) - << "Non-default algorithm not supported for CUDA version < 6.0"; - status = cudnnSetRNNDescriptor( - /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size, - /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(), - /*inputMode=*/input_mode, /*direction=*/direction_mode, - /*mode=*/rnn_mode, /*dataType=*/compute_type); - CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor"); -#endif // Create the params handle. - cudnn_params_desc_.reset(new CudnnRnnParamsDescriptor(cudnn, *this)); - if (!cudnn_params_desc_->ok()) { - SetFailure(cudnn_params_desc_->Status()); - return; - } - set_use_tensor_op_math(algorithm_config_.algorithm().tensor_ops_enabled()); - } - ~CudnnRnnDescriptor() override { - if (rnn_desc_) { - cudnnStatus_t status; -#if CUDNN_VERSION >= 6000 - if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) { - status = cudnnDestroyPersistentRNNPlan(rnn_plan_); - CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan."); - } -#endif - status = cudnnDestroyRNNDescriptor(rnn_desc_); - CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor"); - } - } - void set_use_tensor_op_math(bool use_tensor_op_math) { + SE_ASSIGN_OR_RETURN(auto params_desc, + CudnnRnnParamsDescriptor::Create( + cudnn, input_size, data_type, rnn_desc.get(), + rnn_mode, direction_mode, num_layers)); + #if CUDNN_VERSION >= 7000 - cudnnMathType_t math_type = - (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH); - if (RnnTensorOpMathEnabled()) { - cudnnStatus_t status = cudnnSetRNNMatrixMathType(rnn_desc_, math_type); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn RNN math type: " << ToString(status); - } + // Require explicit algorithm config to enable tensor cores. Some configs + // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against + // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799). + // We can only reasonably expect the user to handle the subsequent failure + // in profile mode, which is run with algorithms returned from + // GetRnnAlgorithms() (which are non-default and explicitly set whether to + // use tensor ops). + if (RnnTensorOpMathEnabled() && + !algorithm_config.algorithm().is_default()) { + cudnnMathType_t math_type = + algorithm_config.algorithm().tensor_ops_enabled() + ? CUDNN_TENSOR_OP_MATH + : CUDNN_DEFAULT_MATH; + CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type)); } #endif + + return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan), + num_layers, hidden_size, input_size, batch_size, + input_mode, direction_mode, rnn_mode, data_type, + compute_type, algorithm_config, + std::move(dropout_desc), std::move(params_desc)); } - cudnnRNNDescriptor_t handle() const { - if (!ok()) return nullptr; - return rnn_desc_; - } + + cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); } int num_layers() const { return num_layers_; } int hidden_size() const { return hidden_size_; } int input_size() const { return input_size_; } @@ -1145,210 +1069,164 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon { return algorithm_config_; } int64 ParamsSizeInBytes() const override { - return cudnn_params_desc_->params_size_in_bytes(); - } - cudnnDropoutDescriptor_t dropout_handle() const { - if (!cudnn_dropout_desc_) return nullptr; - return cudnn_dropout_desc_->handle(); + return params_desc_.params_size_in_bytes(); } cudnnFilterDescriptor_t params_handle() const { - if (!cudnn_params_desc_) return nullptr; - return cudnn_params_desc_->handle(); + return params_desc_.handle(); } ParamsRegions ParamsWeightRegions() const override { - if (!ok()) return ParamsRegions(); - return cudnn_params_desc_->params_weights(); + return params_desc_.params_weights(); } ParamsRegions ParamsBiasRegions() const override { - if (!ok()) return ParamsRegions(); - return cudnn_params_desc_->params_biases(); + return params_desc_.params_biases(); } private: - cudnnRNNDescriptor_t rnn_desc_; + cuda::RnnDescriptor rnn_desc_; + PersistentRnnPlan rnn_plan_; int num_layers_; int hidden_size_; int input_size_; // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC // algorithm. int batch_size_; -#if CUDNN_VERSION >= 6000 cudnnRNNAlgo_t rnn_algo_; - cudnnPersistentRNNPlan_t rnn_plan_; -#endif cudnnRNNInputMode_t input_mode_; cudnnDirectionMode_t direction_mode_; cudnnRNNMode_t rnn_mode_; cudnnDataType_t data_type_; cudnnDataType_t compute_type_; dnn::AlgorithmConfig algorithm_config_; - std::unique_ptr cudnn_dropout_desc_; - std::unique_ptr cudnn_params_desc_; + CudnnDropoutDescriptor dropout_desc_; + CudnnRnnParamsDescriptor params_desc_; SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor); }; namespace { -CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor( - const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc) - : handle_(nullptr), rnn_desc_(&rnn_desc), params_size_in_bytes_(0) { - cudnnTensorDescriptor_t input_desc = nullptr; - { - // Query the params size. - auto status = cudnnCreateTensorDescriptor(&input_desc); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create tensor descriptor"); - int dims[] = {1, rnn_desc.input_size(), 1}; - int strides[] = {dims[1] * dims[2], dims[2], 1}; - status = cudnnSetTensorNdDescriptor( - /*tensorDesc=*/input_desc, /*dataType=*/rnn_desc.data_type(), - /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims, - /*strideA=*/strides); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor"); - - size_t params_size = 0; - status = cudnnGetRNNParamsSize( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*xDesc=*/input_desc, /*sizeInBytes=*/¶ms_size, - /*dataType=*/rnn_desc.data_type()); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get RNN parameter size"); - params_size_in_bytes_ = static_cast(params_size); - } - - { - // Create the params descriptor. - auto status = cudnnCreateFilterDescriptor(&handle_); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor"); - int dims[] = {static_cast(params_size_in_bytes_), 1, 1}; - status = cudnnSetFilterNdDescriptor( - /*filterDesc=*/handle_, /*dataType=*/rnn_desc.data_type(), - /*format=*/CUDNN_TENSOR_NCHW, /*nbDims=*/sizeof(dims) / sizeof(dims[0]), - /*filterDimA=*/dims); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to update RNN filter descriptor"); - } +port::StatusOr CudnnRnnParamsDescriptor::Create( + const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type, + cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode, + cudnnDirectionMode_t direction_mode, int num_layers) { + // Query the params size. + TensorDescriptor input_desc = CreateTensorDescriptor(); + int tensor_dims[] = {1, input_size, 1}; + int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1}; + RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor( + /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type, + /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]), + /*dimA=*/tensor_dims, + /*strideA=*/strides)); + + size_t params_size = 0; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + /*xDesc=*/input_desc.get(), /*sizeInBytes=*/¶ms_size, + /*dataType=*/data_type)); + int64 params_size_in_bytes = static_cast(params_size); + + FilterDescriptor filter_desc = CreateFilterDescriptor(); + int filter_dims[] = {static_cast(params_size_in_bytes), 1, 1}; + RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor( + /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type, + /*format=*/CUDNN_TENSOR_NCHW, + /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]), + /*filterDimA=*/filter_dims)); + + // Create the weights and biases into the params buffer + int region_count_per_layer = [&] { + switch (rnn_mode) { + case CUDNN_RNN_RELU: + case CUDNN_RNN_TANH: + return 2; + case CUDNN_LSTM: + return 8; + case CUDNN_GRU: + return 6; + default: + LOG(FATAL) << "Invalid RNN Mode: " << static_cast(rnn_mode); + return 0; + } + }(); - { - // Create the weights and biases into the params buffer - int region_count_per_layer = GetRegionCountPerLayer(); - cudnnFilterDescriptor_t region_desc_handle = nullptr; - auto status = cudnnCreateFilterDescriptor(®ion_desc_handle); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create filter descriptor"); - const int layer_count = rnn_desc.direction_mode() == CUDNN_UNIDIRECTIONAL - ? rnn_desc.num_layers() - : 2 * rnn_desc.num_layers(); - for (int layer = 0; layer < layer_count; layer++) { - for (int region = 0; region < region_count_per_layer; region++) { - for (int type = 0; type < 2; type++) { - void* offset = nullptr; - if (type == 0) { - status = cudnnGetRNNLinLayerMatrixParams( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_, - /*w=*/nullptr, /*linLayerID=*/region, - /*linLayerMatDesc=*/region_desc_handle, - /*linLayerMat=*/&offset); - CUDNN_RETURN_IF_FAIL( - status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams"); - } else { - status = cudnnGetRNNLinLayerBiasParams( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_, - /*w=*/nullptr, /*linLayerID=*/region, - /*linLayerBiasDesc=*/region_desc_handle, - /*linLayerBias=*/&offset); - CUDNN_RETURN_IF_FAIL( - status, "Cudnn fails to call cudnnGetRNNLinLayerBiasParams"); - } - int dims[] = {1, 1, 1}; - cudnnDataType_t data_type; - cudnnTensorFormat_t tensor_format; - int n_dims; - status = cudnnGetFilterNdDescriptor( - /*filterDesc=*/region_desc_handle, - /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), - /*dataType=*/&data_type, /*format=*/&tensor_format, - /*nbDims=*/&n_dims, /*filterDimA=*/dims); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get filter description"); - int64 size = dims[0] * dims[1] * dims[2] * - CudnnDataTypeToByteSize(rnn_desc.data_type()); - ParamsRegion region = {reinterpret_cast(offset), size}; - if (type == 0) { - weights_.push_back(region); - } else { - biases_.push_back(region); - } - } + FilterDescriptor region_desc_handle = CreateFilterDescriptor(); + const int layer_count = + direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers; + + ParamsRegions weights; + ParamsRegions biases; + + for (int layer = 0; layer < layer_count; layer++) { + for (int region = 0; region < region_count_per_layer; region++) { + for (int type = 0; type < 2; type++) { + void* offset = nullptr; + RETURN_IF_CUDNN_ERROR((type == 0 ? cudnnGetRNNLinLayerMatrixParams + : cudnnGetRNNLinLayerBiasParams)( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + /*layer=*/layer, /*xDesc=*/input_desc.get(), + /*wDesc=*/filter_desc.get(), + /*w=*/nullptr, /*linLayerID=*/region, + /*linLayerMatDesc=*/region_desc_handle.get(), + /*linLayerMat or linLayerBias=*/&offset)); + int dims[] = {1, 1, 1}; + cudnnDataType_t data_type; + cudnnTensorFormat_t tensor_format; + int n_dims; + RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor( + /*filterDesc=*/region_desc_handle.get(), + /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), + /*dataType=*/&data_type, /*format=*/&tensor_format, + /*nbDims=*/&n_dims, /*filterDimA=*/dims)); + int64 size = + dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); + dnn::RnnDescriptor::ParamsRegion region = { + reinterpret_cast(offset), size}; + (type == 0 ? weights : biases).push_back(region); } } - status = cudnnDestroyFilterDescriptor(region_desc_handle); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy filter descriptor"); - } - - { - // Release the dummy input tensor descriptor. - auto status = cudnnDestroyTensorDescriptor(input_desc); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy tensor descriptor"); } -} -int CudnnRnnParamsDescriptor::GetRegionCountPerLayer() const { - auto rnn_mode = rnn_desc_->rnn_mode(); - switch (rnn_mode) { - case CUDNN_RNN_RELU: - case CUDNN_RNN_TANH: - return 2; - case CUDNN_LSTM: - return 8; - case CUDNN_GRU: - return 6; - default: - LOG(FATAL) << "Invalid RNN Mode: " << static_cast(rnn_mode); - } + return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes, + weights, biases); } } // namespace class CudnnRnnSequenceTensorDescriptor - : public CudnnDescriptorCommon { - public: + : public dnn::RnnSequenceTensorDescriptor { CudnnRnnSequenceTensorDescriptor(CUDAExecutor* parent, int seq_length, int batch_size, int data_size, - cudnnDataType_t data_type) + cudnnDataType_t data_type, + TensorDescriptor handle) : parent_(parent), seq_length_(seq_length), batch_size_(batch_size), data_size_(data_size), - data_type_(data_type) { - cudnnTensorDescriptor_t handle = nullptr; - if (seq_length <= 0) { - string error_msg = - port::StrCat("sequence length must be positive: ", seq_length); - LOG(ERROR) << error_msg; - SetFailure(port::Status(port::error::UNKNOWN, error_msg)); - return; - } - cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle); - CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor"); + data_type_(data_type), + handle_(std::move(handle)), + handles_(seq_length, handle_.get()) {} + + public: + CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) = + default; + + static port::StatusOr Create( + CUDAExecutor* parent, int seq_length, int batch_size, int data_size, + cudnnDataType_t data_type) { + CHECK_GT(seq_length, 0); int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; - status = cudnnSetTensorNdDescriptor( - /*tensorDesc=*/handle, /*dataType=*/data_type, + TensorDescriptor tensor_desc = CreateTensorDescriptor(); + RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor( + /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type, /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims, - /*strideA=*/strides); - CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor"); - // Replicate handle across the number of steps. - handles_.assign(seq_length, handle); - } - - ~CudnnRnnSequenceTensorDescriptor() override { - // Only the first one needs to be destroyed. All others are the same. - cudnnStatus_t status = cudnnDestroyTensorDescriptor(handles_[0]); - CUDNN_RETURN_IF_FAIL(status, - "Failed to destroy sequence tensor descriptor"); + /*strideA=*/strides)); + return CudnnRnnSequenceTensorDescriptor(parent, seq_length, batch_size, + data_size, data_type, + std::move(tensor_desc)); } const cudnnTensorDescriptor_t* handles() const { - if (!ok()) return nullptr; - CHECK(!handles_.empty()) << "handles cannot be empty"; return handles_.data(); } @@ -1362,51 +1240,39 @@ class CudnnRnnSequenceTensorDescriptor int batch_size_; int data_size_; cudnnDataType_t data_type_; - std::vector handles_; + TensorDescriptor handle_; + std::vector handles_; // Copies of handle_. SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor); }; -class CudnnRnnStateTensorDescriptor - : public CudnnDescriptorCommon { +class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor { public: CudnnRnnStateTensorDescriptor(CUDAExecutor* parent, int num_layers, int batch_size, int data_size, cudnnDataType_t data_type) : parent_(parent), - handle_(nullptr), + handle_(CreateTensorDescriptor()), num_layers_(num_layers), batch_size_(batch_size), data_size_(data_size), data_type_(data_type) { - cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_); - CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor"); int dims[] = {num_layers, batch_size, data_size}; int strides[] = {dims[1] * dims[2], dims[2], 1}; - status = cudnnSetTensorNdDescriptor( - /*tensorDesc=*/handle_, /*dataType=*/data_type, + CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor( + /*tensorDesc=*/handle_.get(), /*dataType=*/data_type, /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims, - /*strideA=*/strides); - CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor"); + /*strideA=*/strides)); } - ~CudnnRnnStateTensorDescriptor() override { - if (!handle_) { - cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_); - CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN state tensor"); - } - } + cudnnTensorDescriptor_t handle() const { return handle_.get(); } - cudnnTensorDescriptor_t handle() const { - if (!ok()) return nullptr; - return handle_; - } int num_layers() const { return num_layers_; } int batch_size() const { return batch_size_; } int data_size() const { return data_size_; } private: CUDAExecutor* parent_; - cudnnTensorDescriptor_t handle_; + TensorDescriptor handle_; int num_layers_; int batch_size_; int data_size_; @@ -1426,7 +1292,7 @@ struct RnnModelDims { }; template -bool ExtractAndCheckRnnForward( +port::StatusOr ExtractAndCheckRnnForward( const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -1439,103 +1305,89 @@ bool ExtractAndCheckRnnForward( const CudnnRnnStateTensorDescriptor& output_h_desc, const DeviceMemory& output_h_data, const CudnnRnnStateTensorDescriptor& output_c_desc, - const DeviceMemory& output_c_data, RnnModelDims* model_dims) { + const DeviceMemory& output_c_data) { // extract model parameters - model_dims->num_layers = rnn_desc.num_layers(); - model_dims->batch_size = input_desc.batch_size(); - model_dims->seq_length = input_desc.seq_length(); - model_dims->hidden_size = rnn_desc.hidden_size(); - model_dims->input_size = input_desc.data_size(); - model_dims->dir_count = + RnnModelDims model_dims; + model_dims.num_layers = rnn_desc.num_layers(); + model_dims.batch_size = input_desc.batch_size(); + model_dims.seq_length = input_desc.seq_length(); + model_dims.hidden_size = rnn_desc.hidden_size(); + model_dims.input_size = input_desc.data_size(); + model_dims.dir_count = (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1; // check parameters if (!(input_h_desc.num_layers() == - model_dims->num_layers * model_dims->dir_count && - input_h_desc.batch_size() == model_dims->batch_size && - input_h_desc.data_size() == model_dims->hidden_size)) { - LOG(ERROR) << "Invalid input_h shape"; - return false; + model_dims.num_layers * model_dims.dir_count && + input_h_desc.batch_size() == model_dims.batch_size && + input_h_desc.data_size() == model_dims.hidden_size)) { + return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape"); } if (!(input_h_desc.num_layers() == input_c_desc.num_layers() && input_h_desc.batch_size() == input_c_desc.batch_size() && input_h_desc.data_size() == input_c_desc.data_size())) { - LOG(ERROR) << "Invalid input_c shape"; - return false; + return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape"); } - if (!(output_desc.seq_length() == model_dims->seq_length && - output_desc.batch_size() == model_dims->batch_size && + if (!(output_desc.seq_length() == model_dims.seq_length && + output_desc.batch_size() == model_dims.batch_size && output_desc.data_size() == - model_dims->hidden_size * model_dims->dir_count)) { - LOG(ERROR) << "Invalid output shape"; - return false; + model_dims.hidden_size * model_dims.dir_count)) { + return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape"); } if (!(input_h_desc.num_layers() == output_h_desc.num_layers() && input_h_desc.batch_size() == output_h_desc.batch_size() && input_h_desc.data_size() == output_h_desc.data_size())) { - LOG(ERROR) << "Invalid output_h shape"; - return false; + return port::Status(port::error::INVALID_ARGUMENT, + "Invalid output_h shape"); } if (!(input_h_desc.num_layers() == output_c_desc.num_layers() && input_h_desc.batch_size() == output_c_desc.batch_size() && input_h_desc.data_size() == output_c_desc.data_size())) { - LOG(ERROR) << "Invalid output_h shape"; - return false; + return port::Status(port::error::INVALID_ARGUMENT, + "Invalid output_c shape"); } - return true; + return model_dims; } -bool CheckRNNParameterSize(const CudnnHandle& cudnn, - const CudnnRnnDescriptor& rnn_desc, - const CudnnRnnSequenceTensorDescriptor& input_desc) { +port::Status CheckRNNParameterSize( + const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc) { size_t params_size_in_bytes = 0; - cudnnStatus_t status = cudnnGetRNNParamsSize( + RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/¶ms_size_in_bytes, - /*dataType=*/rnn_desc.data_type()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "Unable to check RNN param size: " << ToString(status); - return false; + /*dataType=*/rnn_desc.data_type())); + if (static_cast(params_size_in_bytes) != + rnn_desc.ParamsSizeInBytes()) { + return port::Status(port::error::INVALID_ARGUMENT, + "Mismatching RNN parameter size"); } - return static_cast(params_size_in_bytes) == - rnn_desc.ParamsSizeInBytes(); + return port::Status::OK(); } -bool CreateRnnWorkspace(Stream* stream, const CudnnHandle& cudnn, - const CudnnRnnDescriptor& rnn_desc, - const CudnnRnnSequenceTensorDescriptor& input_desc, - ScratchAllocator* workspace_allocator, - DeviceMemory* workspace) { +port::StatusOr> CreateRnnWorkspace( + Stream* stream, const CudnnHandle& cudnn, + const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + ScratchAllocator* workspace_allocator) { // Query the workspace size. size_t workspace_size_in_bytes = 0; - cudnnStatus_t status = cudnnGetRNNWorkspaceSize( + RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/input_desc.seq_length(), /*xDesc=*/input_desc.handles(), - /*sizeInBytes=*/&workspace_size_in_bytes); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "Unable to query workspace size: " << ToString(status); - return false; - } + /*sizeInBytes=*/&workspace_size_in_bytes)); // Allocate the workspace. - if (workspace_size_in_bytes > 0) { - auto allocated = - workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); - if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) { - LOG(ERROR) << port::StrCat("Failed to allocate RNN workspace of ", - workspace_size_in_bytes, " bytes."); - return false; - } - } else { - *workspace = DeviceMemory(); + if (workspace_size_in_bytes == 0) { + return DeviceMemory(); } - return true; + return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); } } // namespace template -bool CudnnSupport::DoRnnForwardImpl( +port::Status CudnnSupport::DoRnnForwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -1552,57 +1404,34 @@ bool CudnnSupport::DoRnnForwardImpl( ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { - // extract model parameters - RnnModelDims model_dims; - bool res = ExtractAndCheckRnnForward( - rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, *output_data, - output_h_desc, *output_h_data, output_c_desc, *output_c_data, - &model_dims); - if (!res) { - LOG(ERROR) << "Invalid parameters for RNN Model"; - return false; - } + SE_ASSIGN_OR_RETURN( + RnnModelDims model_dims, + ExtractAndCheckRnnForward( + rnn_desc, input_desc, input_data, input_h_desc, input_h_data, + input_c_desc, input_c_data, params, output_desc, *output_data, + output_h_desc, *output_h_data, output_c_desc, *output_c_data)); auto cudnn = cudnn_->GetHandle(parent_, stream); - // check params size - if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) { - LOG(ERROR) << "Invalid parameters"; - return false; - } - - // create the workspace - DeviceMemory workspace; - if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, - workspace_allocator, &workspace)) { - LOG(ERROR) << "Unable to create rnn workspace"; - return false; - } + SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + SE_ASSIGN_OR_RETURN(DeviceMemory workspace, + CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, + workspace_allocator)) // query the reserve space size // allocate the reserve space DeviceMemory reserve_space; if (is_training) { size_t reserve_space_size_in_bytes = 0; - cudnnStatus_t status = cudnnGetRNNTrainingReserveSize( + RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(), - /*sizeInBytes=*/&reserve_space_size_in_bytes); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "Unable to query reserve space size: " << ToString(status); - return false; - } + /*sizeInBytes=*/&reserve_space_size_in_bytes)); if (reserve_space_size_in_bytes > 0) { - auto allocated = reserve_space_allocator->AllocateBytes( - stream, reserve_space_size_in_bytes); - if (!allocated.ok() || - (reserve_space = allocated.ValueOrDie()) == nullptr) { - LOG(ERROR) << "Failed to allocate RNN reserve space of " - << reserve_space_size_in_bytes << " bytes."; - return false; - } + SE_ASSIGN_OR_RETURN(reserve_space, + reserve_space_allocator->AllocateBytes( + stream, reserve_space_size_in_bytes)); } } @@ -1610,20 +1439,16 @@ bool CudnnSupport::DoRnnForwardImpl( const bool is_profiling = output_profile_result != nullptr; if (is_profiling) { timer.reset(new CUDATimer(parent_)); - if (!timer->Init()) { - return false; - } // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - if (!timer->Start(AsCUDAStream(stream))) { - return false; + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); } } - // make the forward call - cudnnStatus_t status; + if (!is_training) { - status = cudnnRNNForwardInference( + RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(), /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(), @@ -1633,9 +1458,9 @@ bool CudnnSupport::DoRnnForwardImpl( /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(), - /*workSpaceSizeInBytes=*/workspace.size()); + /*workSpaceSizeInBytes=*/workspace.size())); } else { - status = cudnnRNNForwardTraining( + RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(), /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(), @@ -1647,35 +1472,24 @@ bool CudnnSupport::DoRnnForwardImpl( /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(), /*workSpaceSizeInBytes=*/workspace.size(), /*reserveSpace=*/reserve_space.opaque(), - /*reserveSpaceSizeInBytes=*/reserve_space.size()); + /*reserveSpaceSizeInBytes=*/reserve_space.size())); } + if (is_profiling) { if (!timer->Stop(AsCUDAStream(stream))) { - return false; - } - if (status == CUDNN_STATUS_SUCCESS) { - auto algo_desc = rnn_desc.algorithm_config().algorithm(); - output_profile_result->set_algorithm(algo_desc); - output_profile_result->set_elapsed_time_in_ms( - timer->GetElapsedMilliseconds()); - } - } - if (status != CUDNN_STATUS_SUCCESS) { - // Silently return when we are profiling. - if (!is_profiling) { - LOG(ERROR) << "Failed to call " - << (is_training ? "cudnnRNNForwardTraining " - : "cudnnRNNForwardInference ") - << ToString(status); - return false; + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } + auto algo_desc = rnn_desc.algorithm_config().algorithm(); + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); } - return true; + return port::Status::OK(); } template -bool CudnnSupport::DoRnnBackwardImpl( +port::Status CudnnSupport::DoRnnBackwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -1699,53 +1513,38 @@ bool CudnnSupport::DoRnnBackwardImpl( DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { - // extract model parameters - RnnModelDims model_dims; - bool res = ExtractAndCheckRnnForward( - rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims); - if (!res) { - LOG(ERROR) << "Invalid parameters for RNN Model"; - return false; - } + SE_ASSIGN_OR_RETURN( + RnnModelDims model_dims, + ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, + params, output_desc, output_data, output_h_desc, + output_h_data, output_c_desc, output_c_data)); auto cudnn = cudnn_->GetHandle(parent_, stream); - // check params size - if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) { - LOG(ERROR) << "Invalid parameters"; - return false; - } - - // create the workspace - DeviceMemory workspace; - if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, - workspace_allocator, &workspace)) { - LOG(ERROR) << "Unable to create rnn workspace"; - return false; - } + SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + SE_ASSIGN_OR_RETURN(DeviceMemory workspace, + CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, + workspace_allocator)); std::unique_ptr timer; const bool is_profiling = output_profile_result != nullptr; if (is_profiling) { timer.reset(new CUDATimer(parent_)); - if (!timer->Init()) { - return false; - } // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - if (!timer->Start(AsCUDAStream(stream))) { - return false; + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); } } - // make the backward data call - cudnnStatus_t status = cudnnRNNBackwardData( + + RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.seq_length, /*yDesc=*/output_desc.handles(), /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(), - /*dy=*/output_backprop_data.opaque(), /*dhyDesc=*/output_h_desc.handle(), + /*dy=*/output_backprop_data.opaque(), + /*dhyDesc=*/output_h_desc.handle(), /*dhy=*/output_h_backprop_data.opaque(), /*dcyDesc=*/output_c_desc.handle(), /*dcy=*/output_c_backprop_data.opaque(), @@ -1756,24 +1555,17 @@ bool CudnnSupport::DoRnnBackwardImpl( /*dhxDesc=*/input_h_desc.handle(), /*dhx=*/input_h_backprop_data->opaque(), /*dcxDesc=*/input_c_desc.handle(), - /*dcx=*/input_c_backprop_data->opaque(), /*workspace=*/workspace.opaque(), + /*dcx=*/input_c_backprop_data->opaque(), + /*workspace=*/workspace.opaque(), /*workSpaceSizeInBytes=*/workspace.size(), /*reserveSpace=*/reserve_space_data->opaque(), - /*reserveSpaceSizeInBytes=*/reserve_space_data->size()); - - if (status != CUDNN_STATUS_SUCCESS) { - if (is_profiling) { - timer->Stop(AsCUDAStream(stream)); - } - LOG(ERROR) << "Failed to call cudnnRNNBackwardData: " << ToString(status); - return false; - } + /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); if (params_backprop_data != nullptr) { // Clear the dw to zeros. stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); // make the backward weight call - status = cudnnRNNBackwardWeights( + RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(), /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(), @@ -1783,19 +1575,12 @@ bool CudnnSupport::DoRnnBackwardImpl( /*dwDesc=*/rnn_desc.params_handle(), /*dw=*/params_backprop_data->opaque(), /*reserveSpace=*/reserve_space_data->opaque(), - /*reserveSpaceSizeInBytes=*/reserve_space_data->size()); - if (status != CUDNN_STATUS_SUCCESS) { - if (is_profiling) { - timer->Stop(AsCUDAStream(stream)); - } - LOG(ERROR) << "Failed to call cudnnRNNBackwardWeights: " - << ToString(status); - return false; - } + /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); } + if (is_profiling) { if (!timer->Stop(AsCUDAStream(stream))) { - return false; + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } auto algo_desc = rnn_desc.algorithm_config().algorithm(); output_profile_result->set_algorithm(algo_desc); @@ -1803,11 +1588,9 @@ bool CudnnSupport::DoRnnBackwardImpl( timer->GetElapsedMilliseconds()); } - return true; + return port::Status::OK(); } -#endif // CUDNN_VERSION - port::StatusOr> CudnnSupport::createRnnDescriptor( int num_layers, int hidden_size, int input_size, int batch_size, @@ -1815,73 +1598,40 @@ CudnnSupport::createRnnDescriptor( dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, ScratchAllocator* state_allocator) { -#if CUDNN_VERSION >= 5000 // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's // not enqueueing anything into a stream, we pass in the null stream. auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr); - std::unique_ptr rnn_desc(new CudnnRnnDescriptor( - cudnn, num_layers, hidden_size, input_size, batch_size, - ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode), - ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type), - GetRnnComputeType(data_type), algorithm_config, dropout, seed, - state_allocator)); - if (!rnn_desc->ok()) { - return rnn_desc->Status(); - } - return port::StatusOr>( - std::move(rnn_desc)); -#else - string error_msg = - port::StrCat("createRnnDescriptor needs at least Cudnn 5.0 to work. ", - "Current Cudnn version: ", CUDNN_VERSION, ". "); - LOG(ERROR) << error_msg; - return port::Status(port::error::UNIMPLEMENTED, error_msg); -#endif // CUDNN_VERSION + SE_ASSIGN_OR_RETURN( + CudnnRnnDescriptor rnn_desc, + CudnnRnnDescriptor::Create( + cudnn, num_layers, hidden_size, input_size, batch_size, + ToCudnnRnnInputMode(input_mode), + ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode), + ToCudnnDataType(data_type), GetRnnComputeType(data_type), + algorithm_config, dropout, seed, state_allocator)); + return std::unique_ptr( + new CudnnRnnDescriptor(std::move(rnn_desc))); } port::StatusOr> CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size, int data_size, dnn::DataType data_type) { -#if CUDNN_VERSION >= 5000 - std::unique_ptr seq_desc( - new CudnnRnnSequenceTensorDescriptor(parent_, seq_length, batch_size, - data_size, - ToCudnnDataType(data_type))); - if (!seq_desc->ok()) { - return seq_desc->Status(); - } - return port::StatusOr>( - std::move(seq_desc)); -#else - string error_msg = port::StrCat( - "createRnnSequenceTensorDescriptor needs at least Cudnn 5.0 to work. ", - "Current Cudnn version: ", CUDNN_VERSION, ". "); - LOG(ERROR) << error_msg; - return port::Status(port::error::UNIMPLEMENTED, error_msg); -#endif // CUDNN_VERSION + SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, + CudnnRnnSequenceTensorDescriptor::Create( + parent_, seq_length, batch_size, data_size, + ToCudnnDataType(data_type))); + return std::unique_ptr( + new CudnnRnnSequenceTensorDescriptor(std::move(descriptor))); } port::StatusOr> CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) { -#if CUDNN_VERSION >= 5000 - std::unique_ptr state_desc( + return std::unique_ptr( new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size, data_size, ToCudnnDataType(data_type))); - if (!state_desc->ok()) { - return state_desc->Status(); - } - return port::StatusOr>( - std::move(state_desc)); -#else - string error_msg = port::StrCat( - "createRnnStateTensorDescriptor needs at least Cudnn 5.0 to work. ", - "Current Cudnn version: ", CUDNN_VERSION, ". "); - LOG(ERROR) << error_msg; - return port::Status(port::error::UNIMPLEMENTED, error_msg); -#endif // CUDNN_VERSION } bool CudnnSupport::DoRnnForward( @@ -1902,7 +1652,6 @@ bool CudnnSupport::DoRnnForward( ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -1918,15 +1667,14 @@ bool CudnnSupport::DoRnnForward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnForwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); -#else - return false; -#endif // CUDNN_VERSION + return IsStatusOk( + DoRnnForwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, is_training, + reserve_space_allocator, workspace_allocator, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoRnnForward( @@ -1946,7 +1694,6 @@ bool CudnnSupport::DoRnnForward( ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -1962,15 +1709,14 @@ bool CudnnSupport::DoRnnForward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnForwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); -#else - return false; -#endif // CUDNN_VERSION + return IsStatusOk( + DoRnnForwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, is_training, + reserve_space_allocator, workspace_allocator, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoRnnForward( @@ -1991,7 +1737,6 @@ bool CudnnSupport::DoRnnForward( ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -2007,15 +1752,14 @@ bool CudnnSupport::DoRnnForward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnForwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); -#else - return false; -#endif // CUDNN_VERSION + return IsStatusOk( + DoRnnForwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, is_training, + reserve_space_allocator, workspace_allocator, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoRnnBackward( @@ -2043,7 +1787,6 @@ bool CudnnSupport::DoRnnBackward( DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -2059,17 +1802,17 @@ bool CudnnSupport::DoRnnBackward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnBackwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, output_backprop_data, output_h_backprop_data, - output_c_backprop_data, input_backprop_data, input_h_backprop_data, - input_c_backprop_data, params_backprop_data, reserve_space_data, - workspace_allocator, output_profile_result); -#else - return false; -#endif // CUDNN_VERSION + return IsStatusOk( + DoRnnBackwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + input_backprop_data, input_h_backprop_data, input_c_backprop_data, + params_backprop_data, reserve_space_data, workspace_allocator, + output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoRnnBackward( @@ -2096,7 +1839,6 @@ bool CudnnSupport::DoRnnBackward( DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -2112,17 +1854,17 @@ bool CudnnSupport::DoRnnBackward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnBackwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, output_backprop_data, output_h_backprop_data, - output_c_backprop_data, input_backprop_data, input_h_backprop_data, - input_c_backprop_data, params_backprop_data, reserve_space_data, - workspace_allocator, output_profile_result); -#else - return false; -#endif // CUDNN_VERSION + return IsStatusOk( + DoRnnBackwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + input_backprop_data, input_h_backprop_data, input_c_backprop_data, + params_backprop_data, reserve_space_data, workspace_allocator, + output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoRnnBackward( @@ -2150,7 +1892,6 @@ bool CudnnSupport::DoRnnBackward( DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -2166,124 +1907,351 @@ bool CudnnSupport::DoRnnBackward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnBackwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, output_backprop_data, output_h_backprop_data, - output_c_backprop_data, input_backprop_data, input_h_backprop_data, - input_c_backprop_data, params_backprop_data, reserve_space_data, - workspace_allocator, output_profile_result); -#else - return false; -#endif // CUDNN_VERSION + return IsStatusOk( + DoRnnBackwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + input_backprop_data, input_h_backprop_data, input_c_backprop_data, + params_backprop_data, reserve_space_data, workspace_allocator, + output_profile_result), + /*report_error=*/!output_profile_result); } namespace { -inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo( - const CudnnHandle& cudnn, const ScopedTensorDescriptor& input_nd, - const ScopedFilterDescriptor& filter, - const ScopedConvolutionDescriptor& conv, - const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit, +// TODO(csigg): Merge a lot of duplicate code below for forward, backward data, +// and backward filter. + +port::StatusOr GetCudnnConvolutionForwardAlgo( + const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd, + const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, + const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit, size_t memory_limit_bytes) { cudnnConvolutionFwdPreference_t preference = specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; - cudnnConvolutionFwdAlgo_t algo_to_use; - auto status = cudnnGetConvolutionForwardAlgorithm( + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm( cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(), - output_nd.handle(), preference, memory_limit_bytes, &algo_to_use); - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) - << "Unable to find a suitable algorithm for doing forward convolution"; + output_nd.handle(), preference, memory_limit_bytes, &algo_to_use)); return algo_to_use; } -dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm( - Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmConfig& algorithm_config, bool is_profiling, - const ScopedTensorDescriptor& input_nd, - const ScopedFilterDescriptor& filter, - const ScopedConvolutionDescriptor& conv, - const ScopedTensorDescriptor& output_nd, - ScratchAllocator* scratch_allocator, DeviceMemory* scratch) { - cudnnConvolutionFwdAlgo_t algo; - bool use_tensor_ops; - if (algorithm_config.algorithm().is_default()) { - use_tensor_ops = true; +port::StatusOr +GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn, + const CudnnTensorDescriptor& input_nd, + const CudnnFilterDescriptor& filter, + const CudnnConvolutionDescriptor& conv, + const CudnnTensorDescriptor& output_nd, + bool specify_workspace_limit, + size_t memory_limit_bytes) { + cudnnConvolutionBwdDataPreference_t preference = + specify_workspace_limit + ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT + : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE; + cudnnConvolutionBwdDataAlgo_t algo_to_use; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm( + cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(), + input_nd.handle(), preference, memory_limit_bytes, &algo_to_use)); + return algo_to_use; +} - auto memory_limit_bytes = - scratch_allocator == nullptr - ? 0 - : scratch_allocator->GetMemoryLimitInBytes(stream); - if (memory_limit_bytes < 0) { - memory_limit_bytes = 0; - } +port::StatusOr +GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, + const CudnnTensorDescriptor& input_nd, + const CudnnFilterDescriptor& filter, + const CudnnConvolutionDescriptor& conv, + const CudnnTensorDescriptor& output_nd, + bool specify_workspace_limit, + size_t memory_limit_bytes) { + cudnnConvolutionBwdFilterPreference_t preference = + specify_workspace_limit + ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT + : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; + cudnnConvolutionBwdFilterAlgo_t algo_to_use; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm( + cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(), + filter.handle(), preference, memory_limit_bytes, &algo_to_use)); + return algo_to_use; +} - algo = GetCudnnConvolutionForwardAlgo( - cudnn, input_nd, filter, conv, output_nd, - /*specify_workspace_limit=*/scratch_allocator != nullptr, - memory_limit_bytes); - } else { - use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled(); - algo = ToConvForwardAlgo(algorithm_config.algorithm()); - } +port::StatusOr> AllocateCudnnConvolutionForwardWorkspace( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmDesc& algorithm_desc, + const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, + const CudnnConvolutionDescriptor& conv, + const CudnnTensorDescriptor& output_nd, + ScratchAllocator* scratch_allocator) { + // TODO(csigg): This has side effects on the convolution descriptor. It is + // functionally correct because the convolution is run with the algorithm of + // the last call to this function, but should be fixed anyway. + conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + + // Query the size of the workspace and allocate it. size_t size_in_bytes; - auto status = cudnnGetConvolutionForwardWorkspaceSize( + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize( cudnn.handle(), /*xDesc=*/input_nd.handle(), /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*yDesc=*/output_nd.handle(), /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); + /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc), + /*sizeInBytes=*/&size_in_bytes)); int64 size_in_bytes_int64 = size_in_bytes; - if (TF_PREDICT_FALSE(status != CUDNN_STATUS_SUCCESS)) { - CHECK(is_profiling) << "Cannot query the size of workspace needed " - "for the specified algorithm: " - << algorithm_config.algorithm().algo_id() << " " - << ToString(status); - // Silently return when we are profiling. - return dnn::AlgorithmDesc(); + + if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { + return port::Status( + port::error::INTERNAL, + "cudnnGetConvolutionForwardWorkspaceSize() returned " + "negative sizeInBytes value. This could be a cudnn bug."); + } + + if (size_in_bytes_int64 == 0) { + return DeviceMemory(); + } + + if (TF_PREDICT_FALSE(!scratch_allocator)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No scratch allocator provided"); } + + return scratch_allocator->AllocateBytes(stream, size_in_bytes); +} + +port::StatusOr> +AllocateCudnnConvolutionBackwardDataWorkspace( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmDesc& algorithm_desc, + const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, + const CudnnConvolutionDescriptor& conv, + const CudnnTensorDescriptor& output_nd, + ScratchAllocator* scratch_allocator) { + // TODO(csigg): This has side effects on the convolution descriptor. It is + // functionally correct because the convolution is run with the algorithm of + // the last call to this function, but should be fixed anyway. + conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + + // Query the size of the workspace and allocate it. + size_t size_in_bytes; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize( + cudnn.handle(), + /*wDesc=*/filter.handle(), + /*dyDesc=*/output_nd.handle(), + /*convDesc=*/conv.handle(), + /*dxDesc=*/input_nd.handle(), + /*algo=*/ToConvBackwardDataAlgo(algorithm_desc), + /*sizeInBytes=*/&size_in_bytes)); + int64 size_in_bytes_int64 = size_in_bytes; + if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { - LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - if (TF_PREDICT_TRUE(is_profiling)) { - return dnn::AlgorithmDesc(); - } - } else if (size_in_bytes_int64 > 0) { - port::StatusOr> allocated; - if (TF_PREDICT_TRUE(scratch_allocator)) { - allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (TF_PREDICT_TRUE(allocated.ok())) { - *scratch = allocated.ValueOrDie(); - } else { - if (TF_PREDICT_TRUE(is_profiling)) { - // Silently return when we are profiling. - return dnn::AlgorithmDesc(); - } - LOG(WARNING) << allocated.status().error_message(); - // For the int8 case, we fail at this point since the no_scratch - // algorithm should be set to dnn::kDefaultAlgorithm. - CHECK(!algorithm_config.algorithm_no_scratch().is_default()) - << "The primary convolution algorithm failed memory allocation, " - "while a secondary algorithm is not provided."; - } - } - if (TF_PREDICT_FALSE(!allocated.ok())) { - if (algorithm_config.algorithm_no_scratch().is_default()) { - use_tensor_ops = true; - algo = GetCudnnConvolutionForwardAlgo( - cudnn, input_nd, filter, conv, output_nd, - /*specify_workspace_limit=*/false, 0); - } else { - use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled(); - algo = ToConvForwardAlgo(algorithm_config.algorithm_no_scratch()); - } - } + return port::Status( + port::error::INTERNAL, + "cudnnGetConvolutionBackwardDataWorkspaceSize() returned " + "negative sizeInBytes value. This could be a cudnn bug."); + } + + if (size_in_bytes_int64 == 0) { + return DeviceMemory(); + } + + if (TF_PREDICT_FALSE(!scratch_allocator)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No scratch allocator provided"); + } + + return scratch_allocator->AllocateBytes(stream, size_in_bytes); +} + +port::StatusOr> +AllocateCudnnConvolutionBackwardFilterWorkspace( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmDesc& algorithm_desc, + const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, + const CudnnConvolutionDescriptor& conv, + const CudnnTensorDescriptor& output_nd, + ScratchAllocator* scratch_allocator) { + // TODO(csigg): This has side effects on the convolution descriptor. It is + // functionally correct because the convolution is run with the algorithm of + // the last call to this function, but should be fixed anyway. + conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + + // Query the size of the workspace and allocate it. + size_t size_in_bytes; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize( + cudnn.handle(), + /*xDesc=*/input_nd.handle(), + /*dyDesc=*/output_nd.handle(), + /*convDesc=*/conv.handle(), + /*gradDesc=*/filter.handle(), + /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc), + /*sizeInBytes=*/&size_in_bytes)); + int64 size_in_bytes_int64 = size_in_bytes; + + if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { + return port::Status( + port::error::INTERNAL, + "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned " + "negative sizeInBytes value. This could be a cudnn bug."); + } + + if (size_in_bytes_int64 == 0) { + return DeviceMemory(); + } + + if (TF_PREDICT_FALSE(!scratch_allocator)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No scratch allocator provided"); + } + + return scratch_allocator->AllocateBytes(stream, size_in_bytes); +} + +port::StatusOr GetCudnnConvolutionForwardAlgorithm( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmConfig& algorithm_config, + const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, + const CudnnConvolutionDescriptor& conv, + const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator, + DeviceMemory* scratch) { + dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm(); + if (algorithm_config.algorithm().is_default()) { + // Pick fastest algorithm within memory limit according to cuDNN's + // heuristics. + bool specify_workspace_limit = scratch_allocator != nullptr; + auto memory_limit_bytes = + specify_workspace_limit + ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll) + : 0ll; + SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo, + GetCudnnConvolutionForwardAlgo( + cudnn, input_nd, filter, conv, output_nd, + specify_workspace_limit, memory_limit_bytes)); + algo_desc = dnn::AlgorithmDesc( + algo, algorithm_config.algorithm().tensor_ops_enabled()); + } + + auto scratch_or = AllocateCudnnConvolutionForwardWorkspace( + stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + scratch_allocator); + + if (scratch_or.ok()) { + *scratch = scratch_or.ValueOrDie(); + return algo_desc; + } + + // Failed to allocate workspace for the first algorithm, fall back to the + // no_scratch algorithm. + if (algorithm_config.algorithm_no_scratch().is_default()) { + return port::Status( + port::error::INVALID_ARGUMENT, + "The primary convolution algorithm failed memory allocation, " + "while a secondary algorithm is not provided."); + } + + SE_ASSIGN_OR_RETURN( + *scratch, AllocateCudnnConvolutionForwardWorkspace( + stream, cudnn, algorithm_config.algorithm_no_scratch(), + input_nd, filter, conv, output_nd, scratch_allocator)); + return algorithm_config.algorithm_no_scratch(); +} + +port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmConfig& algorithm_config, + const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, + const CudnnConvolutionDescriptor& conv, + const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator, + DeviceMemory* scratch) { + dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm(); + if (algorithm_config.algorithm().is_default()) { + // Pick fastest algorithm within memory limit according to cuDNN's + // heuristics. + bool specify_workspace_limit = scratch_allocator != nullptr; + auto memory_limit_bytes = + specify_workspace_limit + ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll) + : 0ll; + SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo, + GetCudnnConvolutionBackwardDataAlgo( + cudnn, input_nd, filter, conv, output_nd, + specify_workspace_limit, memory_limit_bytes)); + algo_desc = dnn::AlgorithmDesc( + algo, algorithm_config.algorithm().tensor_ops_enabled()); + } + + auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace( + stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + scratch_allocator); + + if (scratch_or.ok()) { + *scratch = scratch_or.ValueOrDie(); + return algo_desc; + } + + // Failed to allocate workspace for the first algorithm, fall back to the + // no_scratch algorithm. + if (algorithm_config.algorithm_no_scratch().is_default()) { + return port::Status( + port::error::INVALID_ARGUMENT, + "The primary convolution algorithm failed memory allocation, " + "while a secondary algorithm is not provided."); + } + + SE_ASSIGN_OR_RETURN( + *scratch, AllocateCudnnConvolutionBackwardDataWorkspace( + stream, cudnn, algorithm_config.algorithm_no_scratch(), + input_nd, filter, conv, output_nd, scratch_allocator)); + return algorithm_config.algorithm_no_scratch(); +} + +port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmConfig& algorithm_config, + const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, + const CudnnConvolutionDescriptor& conv, + const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator, + DeviceMemory* scratch) { + dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm(); + if (algorithm_config.algorithm().is_default()) { + // Pick fastest algorithm within memory limit according to cuDNN's + // heuristics. + bool specify_workspace_limit = scratch_allocator != nullptr; + auto memory_limit_bytes = + specify_workspace_limit + ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll) + : 0ll; + SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo, + GetCudnnConvolutionBackwardFilterAlgo( + cudnn, input_nd, filter, conv, output_nd, + specify_workspace_limit, memory_limit_bytes)); + algo_desc = dnn::AlgorithmDesc( + algo, algorithm_config.algorithm().tensor_ops_enabled()); + } + + auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace( + stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + scratch_allocator); + + if (scratch_or.ok()) { + *scratch = scratch_or.ValueOrDie(); + return algo_desc; } - return dnn::AlgorithmDesc(algo, use_tensor_ops); + // Failed to allocate workspace for the first algorithm, fall back to the + // no_scratch algorithm. + if (algorithm_config.algorithm_no_scratch().is_default()) { + return port::Status( + port::error::INVALID_ARGUMENT, + "The primary convolution algorithm failed memory allocation, " + "while a secondary algorithm is not provided."); + } + + SE_ASSIGN_OR_RETURN(*scratch, + AllocateCudnnConvolutionBackwardFilterWorkspace( + stream, cudnn, algorithm_config.algorithm(), input_nd, + filter, conv, output_nd, scratch_allocator)); + return algorithm_config.algorithm_no_scratch(); } // A helper class to set env-vars and choose options for cudnn-related @@ -2311,16 +2279,12 @@ class CudnnEnvVar { }; // A helper struct to decide whether to enable the FFT_TILING algorithms for -// forward convolution. Before cudnn v5.1 it works fine but since cudnn v5.1 -// it is turned off due to memory corruption caused by some shapes with this -// algorithm. -// Before NVIDIA fixes the memory corruption bug, users can explicitly -// enable the algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1". +// forward convolution. It is disabled for cuDNN < 7 due to memory corruption +// caused by some shapes with this algorithm. Users can explicitly enable the +// algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1". struct FftTilingForward { static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD"; - // TODO(yangzihao): turn the default to True when the memory corruption bug - // is fixed. - static constexpr bool kDefaultFlag = CUDNN_VERSION < 5100; + static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7000; }; // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms. @@ -2329,10 +2293,9 @@ struct FftTilingForward { // https://github.com/tensorflow/tensorflow/pull/4901 struct WinogradNonfused { static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED"; - // NVIDIA has fixed winograd nonfused bug for cudnn v>=7. - // For cudnn v>=5.1, we have a workaround and for any lower version, we - // disable it by default. - static constexpr bool kDefaultFlag = CUDNN_VERSION >= 5100; + // NVIDIA has fixed winograd nonfused bug for cudnn v>=7. For older versions, + // we have a workaround. + static constexpr bool kDefaultFlag = true; }; // A helper struct to decide whether to use FP32 as the internal compute type @@ -2386,8 +2349,6 @@ struct RnnDoFP32ComputationFP16Input { static constexpr bool kDefaultFlag = false; }; -// A helper function to return the internal compute type for -// RNNs in cudnn. cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) { switch (data_type) { case dnn::DataType::kFloat: @@ -2408,7 +2369,7 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) { } // namespace template -bool CudnnSupport::DoConvolveImpl( +port::Status CudnnSupport::DoConvolveImpl( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, const dnn::FilterDescriptor& filter_descriptor, @@ -2419,11 +2380,11 @@ bool CudnnSupport::DoConvolveImpl( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { cudnnDataType_t cudnn_type = GetCudnnDataType(); - ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type); - ScopedTensorDescriptor output_nd(output_descriptor, cudnn_type); - ScopedFilterDescriptor filter(filter_descriptor, cudnn_type); - ScopedConvolutionDescriptor conv(convolution_descriptor, - GetConvComputeType()); + CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type); + CudnnTensorDescriptor output_nd(output_descriptor, cudnn_type); + CudnnFilterDescriptor filter(filter_descriptor, cudnn_type); + CudnnConvolutionDescriptor conv(convolution_descriptor, + GetConvComputeType()); auto cudnn = cudnn_->GetHandle(parent_, stream); // Alpha is the scaling factor for input. @@ -2438,177 +2399,75 @@ bool CudnnSupport::DoConvolveImpl( : static_cast(&fbeta); const bool is_profiling = output_profile_result != nullptr; - cudnnConvolutionFwdAlgo_t algo; - bool use_tensor_ops; - DeviceMemory scratch; - // TODO(pauldonnelly): Replace the following code with a call to - // GetCudnnConvolutionForwardAlgorithm(). - if (algorithm_config.algorithm().is_default()) { - // With the default algorithm, use Cudnn's heuristics. - auto get_algorithm = [&](bool specify_limit) { - cudnnConvolutionFwdPreference_t preference = - specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; - - auto memory_limit_bytes = - scratch_allocator == nullptr - ? 0 - : scratch_allocator->GetMemoryLimitInBytes(stream); - if (memory_limit_bytes < 0) { - memory_limit_bytes = 0; - } - - cudnnConvolutionFwdAlgo_t algo_to_use; - auto status = cudnnGetConvolutionForwardAlgorithm( - cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(), - output_nd.handle(), - /*preference=*/preference, - /*memoryLimitInBytes=*/memory_limit_bytes, - /*algo=*/&algo_to_use); - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable " - "algorithm for doing forward " - "convolution"; - return algo_to_use; - }; - - algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr); - use_tensor_ops = true; - if (scratch_allocator != nullptr) { - size_t size_in_bytes; - auto status = cudnnGetConvolutionForwardWorkspaceSize( - cudnn.handle(), - /*xDesc=*/input_nd.handle(), - /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*yDesc=*/output_nd.handle(), /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - int64 size_in_bytes_int64 = size_in_bytes; - if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) { - if (size_in_bytes_int64 > 0) { - auto allocated = - scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - } else { - LOG(WARNING) - << "cudnnGetConvolutionForwardWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } - } + DeviceMemory scratch; + SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc, + GetCudnnConvolutionForwardAlgorithm( + stream, cudnn, algorithm_config, input_nd, filter, + conv, output_nd, scratch_allocator, &scratch)); - // If we didn't allocate any scratch space (perhaps because of failed - // allocation), we force a switch back to the "no workspace" algorithm. - if (scratch == nullptr) { - algo = get_algorithm(/*specify_limit=*/false); - } - } else { - // An algorithm has been specified. - dnn::AlgorithmDesc algotype = algorithm_config.algorithm(); - algo = ToConvForwardAlgo(algotype); - use_tensor_ops = algotype.tensor_ops_enabled(); - conv.set_use_tensor_op_math(use_tensor_ops); - size_t size_in_bytes; - auto status = cudnnGetConvolutionForwardWorkspaceSize( - cudnn.handle(), - /*xDesc=*/input_nd.handle(), - /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*yDesc=*/output_nd.handle(), /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - if (status != CUDNN_STATUS_SUCCESS) { - if (is_profiling) { - // Silently return when we are profiling. - return false; - } - LOG(FATAL) << "Cannot query the size of workspace needed for the given " - "algorithm: " - << algorithm_config.algorithm().algo_id(); - } - int64 size_in_bytes_int64 = size_in_bytes; - if (size_in_bytes_int64 > 0) { - if (scratch_allocator == nullptr) { - LOG(FATAL) << "An allocator must be specified when scratch memory is " - "needed"; - } - auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (is_profiling && !allocated.ok()) { - // Silently return when we are profiling. - return false; - } - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - if (scratch == nullptr) { - CHECK(!algorithm_config.algorithm_no_scratch().is_default()) - << "The primary convolution algorithm failed memory allocation, " - "while a secondary algorithm is not provided."; - dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch(); - algo = ToConvForwardAlgo(algotype); - use_tensor_ops = algotype.tensor_ops_enabled(); - conv.set_use_tensor_op_math(use_tensor_ops); - } - } else if (size_in_bytes_int64 < 0) { - LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } - std::unique_ptr timer; + std::unique_ptr timer; if (is_profiling) { timer.reset(new CUDATimer(parent_)); // NOLINT - if (!timer->Init()) { - return false; - } // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - if (!timer->Start(AsCUDAStream(stream))) { - timer->Destroy(); - return false; + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); } } - auto status = cudnnConvolutionForward( + + // Report an error if we might be hitting a cuDNN bug that accesses illegal + // memory. See nvbugs/2138754, b/80018418. + SE_RETURN_IF_ERROR([&] { + if (algo_desc.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) { + return port::Status::OK(); + } + if (input_descriptor.ndims() < 3) { + return port::Status::OK(); + } + // Checks that a*b is within the valid range (as provided by NVIDIA). + auto check_sizes = [](size_t a, size_t b) { + if ((a * b * 4608 - 1) >> 31 == 0) { + return port::Status::OK(); + } + return port::Status( + port::error::FAILED_PRECONDITION, + "This configuration potentially accesses illegal memory."); + }; + SE_RETURN_IF_ERROR(check_sizes(input_descriptor.feature_map_count(), + output_descriptor.feature_map_count())); + SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(), + input_descriptor.feature_map_count())); + SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(), + output_descriptor.feature_map_count())); + return port::Status::OK(); + }()); + + RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward( cudnn.handle(), /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(), /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), - /*algo=*/algo, /*workSpace=*/scratch.opaque(), + /*algo=*/ToConvForwardAlgo(algo_desc), /*workSpace=*/scratch.opaque(), /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/beta, - /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); + /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque())); if (is_profiling) { if (!timer->Stop(AsCUDAStream(stream))) { - timer->Destroy(); - return false; - } - if (status == CUDNN_STATUS_SUCCESS) { - dnn::AlgorithmDesc algotype(algo, use_tensor_ops); - output_profile_result->set_algorithm(algotype); - output_profile_result->set_elapsed_time_in_ms( - timer->GetElapsedMilliseconds()); - } - timer->Destroy(); - } - - if (status != CUDNN_STATUS_SUCCESS) { - // Silently return when we are profiling. - if (!is_profiling) { - LOG(ERROR) << "failed to enqueue convolution on stream: " - << ToString(status); + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - return false; + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); } - return true; + return port::Status::OK(); } template -bool CudnnSupport::DoFusedConvolveImpl( +port::Status CudnnSupport::DoFusedConvolveImpl( Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor, const DeviceMemory& conv_input_data, ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor, @@ -2621,61 +2480,48 @@ bool CudnnSupport::DoFusedConvolveImpl( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION < 6000 - LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only " - "supported for cuDNN version >= 6"; - return false; -#else - ScopedTensorDescriptor conv_input_nd( + if (activation_mode != dnn::ActivationMode::kRelu) { + return port::Status(port::error::INVALID_ARGUMENT, + "cudnnConvolutionBiasActivationForward() only supports " + "Relu activation."); + } + + CudnnTensorDescriptor conv_input_nd( conv_input_descriptor, static_cast(cudnn_data_type)); - ScopedTensorDescriptor output_nd( + CudnnTensorDescriptor output_nd( output_descriptor, static_cast(cudnn_data_type)); - ScopedFilterDescriptor filter(filter_descriptor, - static_cast(cudnn_data_type)); - ScopedTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT); - ScopedConvolutionDescriptor conv( + CudnnFilterDescriptor filter(filter_descriptor, + static_cast(cudnn_data_type)); + CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT); + CudnnConvolutionDescriptor conv( convolution_descriptor, static_cast(cudnn_compute_type)); auto cudnn = cudnn_->GetHandle(parent_, stream); + const bool is_profiling = output_profile_result != nullptr; - DeviceMemory scratch; - dnn::AlgorithmDesc algotype = GetCudnnConvolutionForwardAlgorithm( - stream, cudnn, algorithm_config, is_profiling, conv_input_nd, filter, - conv, output_nd, scratch_allocator, &scratch); - if (algotype.is_default()) { - if (!is_profiling) { - LOG(ERROR) << "No suitable algorithm found"; - } - return false; - } - auto algo = static_cast(algotype.algo_id()); - conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); - if (activation_mode != dnn::ActivationMode::kRelu) { - LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only supports Relu " - "activation."; - return false; - } + DeviceMemory scratch; + SE_ASSIGN_OR_RETURN( + dnn::AlgorithmDesc algo_desc, + GetCudnnConvolutionForwardAlgorithm( + stream, cudnn, algorithm_config, conv_input_nd, filter, conv, + output_nd, scratch_allocator, &scratch)); - std::unique_ptr timer; + std::unique_ptr timer; if (is_profiling) { timer.reset(new CUDATimer(parent_)); // NOLINT - if (!timer->Init()) { - return false; - } // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - if (!timer->Start(AsCUDAStream(stream))) { - timer->Destroy(); - return false; + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); } } // CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for // activation descriptor. Note that this will change the nan propagation // behavior from separate conv, bias, and relu (which by default is // CUDNN_PROPAGATE_NAN. - ScopedActivationDescriptor activation_desc( + CudnnActivationDescriptor activation_desc( activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max()); auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque() : side_input_data.opaque(); @@ -2685,7 +2531,8 @@ bool CudnnSupport::DoFusedConvolveImpl( << "\nconv_input_data.opaque() = " << conv_input_data.opaque() << "\nfilter.handle() = " << filter.handle() << "\nfilter_data.opaque() = " << filter_data.opaque() - << "\nconv.handle() = " << conv.handle() << "\nalgo = " << algo + << "\nconv.handle() = " << conv.handle() + << "\nalgo = " << algo_desc.algo_id() << "\nscratch.opaque() = " << scratch.opaque() << "\nscratch.size() = " << scratch.size() << "\nside_input_scale = " << side_input_scale @@ -2697,42 +2544,29 @@ bool CudnnSupport::DoFusedConvolveImpl( << "\noutput_nd.handle() = " << output_nd.handle() << "\noutput_data->opaque() = " << output_data->opaque(); - auto status = cudnnConvolutionBiasActivationForward( + RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward( cudnn.handle(), /*alpha1=*/&conv_input_scale, /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(), /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(), - /*convDesc=*/conv.handle(), algo, /*workSpace=*/scratch.opaque(), + /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc), + /*workSpace=*/scratch.opaque(), /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale, /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr, /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(), /*activationDesc=*/activation_desc.handle(), - /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); + /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque())); if (is_profiling) { if (!timer->Stop(AsCUDAStream(stream))) { - timer->Destroy(); - return false; - } - if (status == CUDNN_STATUS_SUCCESS) { - output_profile_result->set_algorithm(algotype); - output_profile_result->set_elapsed_time_in_ms( - timer->GetElapsedMilliseconds()); - } - timer->Destroy(); - } - - if (status != CUDNN_STATUS_SUCCESS) { - // Silently return when we are profiling. - if (!is_profiling) { - LOG(ERROR) << "failed to enqueue convolution on stream: " - << ToString(status); + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - return false; + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); } - return true; -#endif // CUDNN_VERSION < 6000 + return port::Status::OK(); } bool CudnnSupport::GetConvolveAlgorithms( @@ -2745,19 +2579,15 @@ bool CudnnSupport::GetConvolveAlgorithms( CUDNN_CONVOLUTION_FWD_ALGO_GEMM, CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, CUDNN_CONVOLUTION_FWD_ALGO_FFT, -#if CUDNN_VERSION >= 5000 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, -#endif // clang-format on }; if (CudnnEnvVar::IsEnabled()) { algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING); } -#if CUDNN_VERSION >= 5100 if (CudnnEnvVar::IsEnabled() && with_winograd_nonfused) { algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED); } -#endif out_algorithms->clear(); for (auto i : algo_types) { @@ -2772,13 +2602,11 @@ bool CudnnSupport::GetConvolveAlgorithms( bool CudnnSupport::GetRnnAlgorithms( std::vector* out_algorithms) { std::vector algo_types = { - // clang-format off -#if CUDNN_VERSION >= 6000 + // clang-format off CUDNN_RNN_ALGO_STANDARD, CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC, -#endif - // clang-format on + // clang-format on }; out_algorithms->clear(); @@ -2797,21 +2625,17 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms( bool with_winograd_nonfused, int cc_major, int cc_minor, std::vector* out_algorithms) { std::vector algo_types = { - // clang-format off + // clang-format off CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, -#if CUDNN_VERSION >= 5000 CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, -#endif - // clang-format on + // clang-format on }; -#if CUDNN_VERSION >= 5100 if (CudnnEnvVar::IsEnabled() && with_winograd_nonfused) { algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED); } -#endif out_algorithms->clear(); for (auto i : algo_types) { @@ -2834,13 +2658,15 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms( CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, // Based on cudnn.h, the following is not implemented. // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD, + + // Produces incorrect results for some shapes. Disabled for now, see + // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes. + // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, // clang-format on }; -#if CUDNN_VERSION >= 5100 if (CudnnEnvVar::IsEnabled() && with_winograd_nonfused) { algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED); } -#endif out_algorithms->clear(); for (auto i : algo_types) { @@ -2864,11 +2690,13 @@ bool CudnnSupport::DoBatchNormalizationForward( DeviceMemory* saved_inv_var, bool is_training, std::function&()> var_to_inv_var, std::function inv_var_to_var) { - return DoBatchNormalizationForwardImpl( - stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset, - estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y, - batch_mean, batch_var, saved_mean, saved_inv_var, is_training, - std::move(var_to_inv_var), std::move(inv_var_to_var)); + return IsStatusOk( + DoBatchNormalizationForwardImpl( + stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, + offset, estimated_mean, estimated_variance, x_desc, scale_offset_desc, + epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var, + is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)), + /*report_error=*/true); } bool CudnnSupport::DoBatchNormalizationForward( @@ -2883,15 +2711,17 @@ bool CudnnSupport::DoBatchNormalizationForward( DeviceMemory* saved_inv_var, bool is_training, std::function&()> var_to_inv_var, std::function inv_var_to_var) { - return DoBatchNormalizationForwardImpl( - stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset, - estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y, - batch_mean, batch_var, saved_mean, saved_inv_var, is_training, - std::move(var_to_inv_var), std::move(inv_var_to_var)); + return IsStatusOk( + DoBatchNormalizationForwardImpl( + stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset, + estimated_mean, estimated_variance, x_desc, scale_offset_desc, + epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var, + is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)), + /*report_error=*/true); } template -bool CudnnSupport::DoBatchNormalizationForwardImpl( +port::Status CudnnSupport::DoBatchNormalizationForwardImpl( Stream* stream, dnn::DataType input_data_type, dnn::DataType scale_data_type, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -2903,8 +2733,8 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( DeviceMemory* saved_mean, DeviceMemory* saved_inv_var, bool is_training, std::function&()> var_to_inv_var, std::function inv_var_to_var) { - ScopedTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type)); - ScopedTensorDescriptor scale_offset_descriptor( + CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type)); + CudnnTensorDescriptor scale_offset_descriptor( scale_offset_desc, ToCudnnDataType(scale_data_type)); cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; #if CUDNN_VERSION >= 7000 @@ -2916,7 +2746,6 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( float zero = 0.0; auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = CUDNN_STATUS_SUCCESS; if (is_training) { CHECK_EQ(batch_mean->is_null(), batch_var->is_null()) << "batch_mean and batch_var must both be null or both be non-null"; @@ -2933,35 +2762,21 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( batch_var_opaque = nullptr; } - status = cudnnBatchNormalizationForwardTraining( + RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining( cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(), - saved_inv_var->opaque()); -#if CUDNN_VERSION < 5000 - CHECK(inv_var_to_var); - inv_var_to_var(); -#endif + saved_inv_var->opaque())); } else { -#if CUDNN_VERSION < 5000 - CHECK(var_to_inv_var); - const void* maybe_inv_var = var_to_inv_var().opaque(); -#else const void* maybe_inv_var = estimated_variance.opaque(); -#endif - status = cudnnBatchNormalizationForwardInference( + RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference( cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var, - epsilon); + epsilon)); } - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue forward batch normalization on stream: " - << ToString(status); - return false; - } - return true; + return port::Status::OK(); } bool CudnnSupport::DoBatchNormalizationBackward( @@ -2972,10 +2787,11 @@ bool CudnnSupport::DoBatchNormalizationBackward( const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop) { - return DoBatchNormalizationBackwardImpl( - stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop, x, scale, mean, - inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop, - offset_backprop); + return IsStatusOk(DoBatchNormalizationBackwardImpl( + stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop, + x, scale, mean, inv_var, x_desc, scale_offset_desc, + epsilon, x_backprop, scale_backprop, offset_backprop), + /*report_error=*/true); } bool CudnnSupport::DoBatchNormalizationBackward( @@ -2986,14 +2802,15 @@ bool CudnnSupport::DoBatchNormalizationBackward( const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop) { - return DoBatchNormalizationBackwardImpl( - stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop, x, scale, mean, - inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop, - offset_backprop); + return IsStatusOk(DoBatchNormalizationBackwardImpl( + stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop, + x, scale, mean, inv_var, x_desc, scale_offset_desc, + epsilon, x_backprop, scale_backprop, offset_backprop), + /*report_error=*/true); } template -bool CudnnSupport::DoBatchNormalizationBackwardImpl( +port::Status CudnnSupport::DoBatchNormalizationBackwardImpl( Stream* stream, int cudnn_input_type, int cudnn_scale_type, const DeviceMemory& y_backprop, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& mean, @@ -3001,9 +2818,9 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl( const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop) { - ScopedTensorDescriptor x_descriptor( + CudnnTensorDescriptor x_descriptor( x_desc, static_cast(cudnn_input_type)); - ScopedTensorDescriptor scale_offset_descriptor( + CudnnTensorDescriptor scale_offset_descriptor( scale_offset_desc, static_cast(cudnn_scale_type)); cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; #if CUDNN_VERSION >= 7000 @@ -3016,19 +2833,14 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnBatchNormalizationBackward( + RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward( cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y_backprop.opaque(), x_descriptor.handle(), x_backprop->opaque(), scale_offset_descriptor.handle(), scale.opaque(), scale_backprop->opaque(), offset_backprop->opaque(), epsilon, - mean.opaque(), inv_var.opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue backward batch normalization on stream: " - << ToString(status); - return false; - } - return true; + mean.opaque(), inv_var.opaque())); + return port::Status::OK(); } bool CudnnSupport::DoConvolve( @@ -3041,10 +2853,12 @@ bool CudnnSupport::DoConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveImpl( - stream, batch_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveImpl( + stream, batch_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output_data, + scratch_allocator, algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolve( @@ -3057,10 +2871,12 @@ bool CudnnSupport::DoConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveImpl( - stream, batch_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveImpl( + stream, batch_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output_data, + scratch_allocator, algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolve( @@ -3073,10 +2889,12 @@ bool CudnnSupport::DoConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveImpl( - stream, batch_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveImpl( + stream, batch_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output_data, + scratch_allocator, algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoFusedConvolve( @@ -3092,13 +2910,15 @@ bool CudnnSupport::DoFusedConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoFusedConvolveImpl( - stream, conv_input_descriptor, conv_input_data, conv_input_scale, - filter_descriptor, filter_data, convolution_descriptor, side_input_data, - side_input_scale, bias_descriptor, biases, activation_mode, - output_descriptor, output_data, scratch_allocator, algorithm_config, - output_profile_result); + return IsStatusOk( + DoFusedConvolveImpl( + stream, conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, + side_input_data, side_input_scale, bias_descriptor, biases, + activation_mode, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoFusedConvolve( @@ -3114,13 +2934,15 @@ bool CudnnSupport::DoFusedConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoFusedConvolveImpl( - stream, conv_input_descriptor, conv_input_data, conv_input_scale, - filter_descriptor, filter_data, convolution_descriptor, side_input_data, - side_input_scale, bias_descriptor, biases, activation_mode, - output_descriptor, output_data, scratch_allocator, algorithm_config, - output_profile_result); + return IsStatusOk( + DoFusedConvolveImpl( + stream, conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, + side_input_data, side_input_scale, bias_descriptor, biases, + activation_mode, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoFusedConvolve( @@ -3137,13 +2959,15 @@ bool CudnnSupport::DoFusedConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoFusedConvolveImpl( - stream, conv_input_descriptor, conv_input_data, conv_input_scale, - filter_descriptor, filter_data, convolution_descriptor, side_input_data, - side_input_scale, bias_descriptor, biases, activation_mode, - output_descriptor, output_data, scratch_allocator, algorithm_config, - output_profile_result); + return IsStatusOk( + DoFusedConvolveImpl( + stream, conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, + side_input_data, side_input_scale, bias_descriptor, biases, + activation_mode, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoFusedConvolve( @@ -3159,11 +2983,6 @@ bool CudnnSupport::DoFusedConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION < 6000 - LOG(WARNING) << "cudnnConvolutionBiasActivationForward() is only " - "supported for cuDNN version >= 6"; - return false; -#else int cc_major, cc_minor; stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor); @@ -3172,63 +2991,17 @@ bool CudnnSupport::DoFusedConvolve( "supported on GPUs with compute capability 6.1 or later."; return false; } - return DoFusedConvolveImpl( - stream, conv_input_descriptor, conv_input_data, conv_input_scale, - filter_descriptor, filter_data, convolution_descriptor, side_input_data, - side_input_scale, bias_descriptor, biases, activation_mode, - output_descriptor, output_data, scratch_allocator, algorithm_config, - output_profile_result); -#endif + return IsStatusOk( + DoFusedConvolveImpl( + stream, conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, + side_input_data, side_input_scale, bias_descriptor, biases, + activation_mode, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } -namespace { -// NOTE(keveman): Temporary data layout transformation until cuDNN supports -// kBatchYXDepth for backward pass. This function allocates temporary memory, -// lays out the source data into the temporary but in the kBatchDepthXY -// layout, and returns the temporary memory. The caller is responsible for -// deallocating the temporary. Since the allocation is done using Stream's -// AllocateTemporaryMemory, a later BlockHostUntilDone could be used for -// deallocation. -// -// transform_scratch is populated with a legitimate temporary allocation iff -// the original output data needs to be transformed. -template -DeviceMemory MaybeTransformLayout( - Stream* stream, const CudnnHandle& cudnn, - dnn::BatchDescriptor* output_descriptor, - DeviceMemory backward_output_data, - std::unique_ptr>* transform_scratch) { - if (output_descriptor->layout() == dnn::DataLayout::kBatchDepthYX) { - return backward_output_data; - } - CHECK(output_descriptor->layout() == dnn::DataLayout::kBatchYXDepth); - *transform_scratch = - stream->AllocateTemporaryArray(backward_output_data.ElementCount()) - .ConsumeValueOrDie(); - dnn::BatchDescriptor transformed_output_descriptor; - transformed_output_descriptor.CloneFrom(*output_descriptor); - transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX); - cudnnDataType_t cudnn_type = GetCudnnDataType(); - ScopedTensorDescriptor orig_out_back_nd(*output_descriptor, cudnn_type); - ScopedTensorDescriptor transformed_out_back_nd(transformed_output_descriptor, - cudnn_type); - - float alpha = 1.0f; - float beta = 0.0f; - auto status = cudnnTransformTensor( - cudnn.handle(), &alpha, orig_out_back_nd.handle(), - backward_output_data.opaque(), &beta, transformed_out_back_nd.handle(), - (*transform_scratch)->mutable_device_memory()->opaque()); - - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "Failed to transform the data layout."; - } - output_descriptor->set_layout(dnn::DataLayout::kBatchDepthYX); - return (*transform_scratch)->device_memory(); -} -} // namespace - bool CudnnSupport::DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc, dnn::DataType input_type, @@ -3237,30 +3010,25 @@ bool CudnnSupport::DoTransformTensor(Stream* stream, dnn::DataType output_type, float scale, DeviceMemoryBase* output_data) { float beta = 0.0f; - ScopedTensorDescriptor input_tensor_desc( + CudnnTensorDescriptor input_tensor_desc( input_desc, ToCudnnDataType(input_type, input_desc.layout())); - ScopedTensorDescriptor output_tensor_desc( + CudnnTensorDescriptor output_tensor_desc( output_desc, ToCudnnDataType(output_type, output_desc.layout())); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnTransformTensor( - cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(), - &beta, output_tensor_desc.handle(), output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "Could not transform a tensor with layout " - << input_desc.ToString() << " and data type " - << static_cast(input_type) << " to another with layout " - << output_desc.ToString() << " and data type " - << static_cast(output_type) << ": " << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnTransformTensor( + cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(), + &beta, output_tensor_desc.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } template -bool CudnnSupport::DoConvolveBackwardDataImpl( +port::Status CudnnSupport::DoConvolveBackwardDataImpl( Stream* stream, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, - const dnn::BatchDescriptor& output_descriptor_in, + const dnn::BatchDescriptor& output_descriptor, DeviceMemory backward_output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::BatchDescriptor& input_descriptor, @@ -3281,192 +3049,71 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); - // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass. - dnn::BatchDescriptor output_descriptor; - output_descriptor.CloneFrom(output_descriptor_in); - std::unique_ptr> transform_scratch; - backward_output_data = - MaybeTransformLayout(stream, cudnn, &output_descriptor, - backward_output_data, &transform_scratch); - - ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type); - ScopedTensorDescriptor in_back_nd(input_descriptor, cudnn_type); - ScopedFilterDescriptor filter(filter_descriptor, cudnn_type); - ScopedConvolutionDescriptor conv(convolution_descriptor, - GetConvComputeType()); + CudnnTensorDescriptor out_back_nd(output_descriptor, cudnn_type); + CudnnTensorDescriptor in_back_nd(input_descriptor, cudnn_type); + CudnnFilterDescriptor filter(filter_descriptor, cudnn_type); + CudnnConvolutionDescriptor conv(convolution_descriptor, + GetConvComputeType()); const bool is_profiling = output_profile_result != nullptr; - cudnnConvolutionBwdDataAlgo_t algo; - DeviceMemory scratch; - - if (algorithm_config.algorithm().is_default()) { - // With the default algorithm, use Cudnn's heuristics. - auto get_algorithm = - [&](bool specify_limit) -> cudnnConvolutionBwdDataAlgo_t { - cudnnConvolutionBwdDataPreference_t preference = - specify_limit ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE; - - auto memory_limit_bytes = - scratch_allocator == nullptr - ? 0 - : scratch_allocator->GetMemoryLimitInBytes(stream); - if (memory_limit_bytes < 0) { - memory_limit_bytes = 0; - } - cudnnConvolutionBwdDataAlgo_t algo_to_use; - cudnnStatus_t status = cudnnGetConvolutionBackwardDataAlgorithm( - cudnn.handle(), - /*filterDesc=*/filter.handle(), - /*diffDesc=*/out_back_nd.handle(), - /*convDesc=*/conv.handle(), - /*gradDesc=*/in_back_nd.handle(), - /*preference=*/preference, - /*memoryLimitInBytes=*/memory_limit_bytes, - /*algo=*/&algo_to_use); - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable " - "algorithm for doing backward " - "data convolution"; - return algo_to_use; - }; - - algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr); - - if (scratch_allocator != nullptr) { - size_t size_in_bytes; - auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( - cudnn.handle(), - /*filterDesc=*/filter.handle(), - /*diffDesc=*/out_back_nd.handle(), - /*convDesc=*/conv.handle(), - /*gradDesc=*/in_back_nd.handle(), - /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - int64 size_in_bytes_int64 = size_in_bytes; - if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) { - if (size_in_bytes_int64 > 0) { - auto allocated = - scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - } else { - LOG(WARNING) - << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } - } - // If we didn't allocate any scratch space (perhaps because of failed - // allocation), we force a switch back to the "no workspace" algorithm. - if (scratch == nullptr) { - algo = get_algorithm(/*specify_limit=*/false); - } - } else { - // An algorithm has been specified. - dnn::AlgorithmDesc algotype = algorithm_config.algorithm(); - algo = ToConvBackwardDataAlgo(algotype); - conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); - size_t size_in_bytes; - auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( - cudnn.handle(), - /*filterDesc=*/filter.handle(), - /*diffDesc=*/out_back_nd.handle(), - /*convDesc=*/conv.handle(), - /*gradDesc=*/in_back_nd.handle(), - /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - if (status != CUDNN_STATUS_SUCCESS) { - if (is_profiling) { - // Silently return when we are profiling. - return false; - } - LOG(FATAL) << "Cannot query the size of workspace needed for the given " - "algorithm: " - << algorithm_config.algorithm().algo_id(); - } - int64 size_in_bytes_int64 = size_in_bytes; - if (size_in_bytes_int64 > 0) { - if (scratch_allocator == nullptr) { - LOG(FATAL) << "An allocator must be specified when scratch memory is " - "needed"; - } - auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (is_profiling && !allocated.ok()) { - // Silently return when we are profiling. - return false; - } - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - if (scratch == nullptr) { - CHECK(!algorithm_config.algorithm_no_scratch().is_default()) - << "The primary convolution algorithm failed memory allocation, " - "while a secondary algorithm is not provided."; - dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch(); - algo = ToConvBackwardDataAlgo(algotype); - conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); - } - } else if (size_in_bytes_int64 < 0) { - LOG(WARNING) << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } + DeviceMemory scratch; + SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc, + GetCudnnConvolutionBackwardDataAlgorithm( + stream, cudnn, algorithm_config, in_back_nd, filter, + conv, out_back_nd, scratch_allocator, &scratch)); - std::unique_ptr timer; + std::unique_ptr timer; if (is_profiling) { timer.reset(new CUDATimer(parent_)); // NOLINT - timer->Init(); // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - timer->Start(AsCUDAStream(stream)); + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); + } } -#if CUDNN_VERSION >= 5000 - auto status = + // Cudnn 7.1.4 has a bug if the workspace of the following convolution is not + // zero-initialized. + // TODO(timshen): Add an nvbugs/ link. + if (CUDNN_VERSION >= 7000 && + algorithm_config.algorithm().algo_id() == + CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 && + cudnn_type == CUDNN_DATA_HALF && + algorithm_config.algorithm().tensor_ops_enabled() && + input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth && + filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX && + output_descriptor.layout() == dnn::DataLayout::kBatchDepthYX && + (convolution_descriptor.vertical_filter_stride() > 1 || + convolution_descriptor.horizontal_filter_stride() > 1)) { + stream->ThenMemZero(&scratch, scratch.size()); + } + + RETURN_IF_CUDNN_ERROR( cudnnConvolutionBackwardData(cudnn.handle(), -#else - auto status = - cudnnConvolutionBackwardData_v3(cudnn.handle(), -#endif /*alpha=*/alpha, /*wDesc=*/filter.handle(), /*w=*/filter_data.opaque(), /*dyDesc=*/out_back_nd.handle(), /*dy=*/backward_output_data.opaque(), /*convDesc=*/conv.handle(), - /*algo=*/algo, + /*algo=*/ToConvBackwardDataAlgo(algo_desc), /*workSpace=*/scratch.opaque(), /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/beta, /*dxDesc=*/in_back_nd.handle(), - /*dx=*/backward_input_data->opaque()); + /*dx=*/backward_input_data->opaque())); if (is_profiling) { - timer->Stop(AsCUDAStream(stream)); - if (status == CUDNN_STATUS_SUCCESS) { - bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled(); - dnn::AlgorithmDesc algotype(algo, use_tensor_ops); - output_profile_result->set_algorithm(algotype); - output_profile_result->set_elapsed_time_in_ms( - timer->GetElapsedMilliseconds()); - } - timer->Destroy(); - } - if (status != CUDNN_STATUS_SUCCESS) { - // Silently return when we are profiling. - if (!is_profiling) { - LOG(ERROR) << "failed to enqueue convolution on stream: " - << ToString(status); + if (!timer->Stop(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - return false; + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); } - return true; + + return port::Status::OK(); } bool CudnnSupport::DoConvolveBackwardData( @@ -3480,11 +3127,13 @@ bool CudnnSupport::DoConvolveBackwardData( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, - output_descriptor, backward_output_data, - convolution_descriptor, input_descriptor, - backward_input_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, + output_descriptor, backward_output_data, + convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolveBackwardData( @@ -3498,11 +3147,13 @@ bool CudnnSupport::DoConvolveBackwardData( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, - output_descriptor, backward_output_data, - convolution_descriptor, input_descriptor, - backward_input_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, + output_descriptor, backward_output_data, + convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolveBackwardData( @@ -3516,18 +3167,20 @@ bool CudnnSupport::DoConvolveBackwardData( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, - output_descriptor, backward_output_data, - convolution_descriptor, input_descriptor, - backward_input_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, + output_descriptor, backward_output_data, + convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } template -bool CudnnSupport::DoConvolveBackwardFilterImpl( +port::Status CudnnSupport::DoConvolveBackwardFilterImpl( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, - const dnn::BatchDescriptor& output_descriptor_in, + const dnn::BatchDescriptor& output_descriptor, DeviceMemory backward_output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::FilterDescriptor& filter_descriptor, @@ -3548,195 +3201,83 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); - // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass. - dnn::BatchDescriptor output_descriptor; - output_descriptor.CloneFrom(output_descriptor_in); - std::unique_ptr> transform_scratch; - backward_output_data = - MaybeTransformLayout(stream, cudnn, &output_descriptor, - backward_output_data, &transform_scratch); - - ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type); - ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type); - ScopedFilterDescriptor filter(filter_descriptor, cudnn_type); - ScopedConvolutionDescriptor conv(convolution_descriptor, - GetConvComputeType()); + CudnnTensorDescriptor out_back_nd(output_descriptor, cudnn_type); + CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type); + CudnnFilterDescriptor filter(filter_descriptor, cudnn_type); + CudnnConvolutionDescriptor conv(convolution_descriptor, + GetConvComputeType()); const bool is_profiling = output_profile_result != nullptr; - cudnnConvolutionBwdFilterAlgo_t algo; - DeviceMemory scratch; - - if (algorithm_config.algorithm().is_default()) { - // With the default algorithm, use Cudnn's heuristics. - - // Lambda that retrieves the algorithm. - // specify_limit will occur when we have a scratch allocator and it succeeds - // in allocating; otherwise, we'll fall back to the "no workspace" version. - auto get_algorithm = [&](bool specify_limit) { - cudnnConvolutionBwdFilterPreference_t preference = - specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; - - auto memory_limit_bytes = - scratch_allocator == nullptr - ? 0 - : scratch_allocator->GetMemoryLimitInBytes(stream); - if (memory_limit_bytes < 0) { - memory_limit_bytes = 0; - } - - cudnnConvolutionBwdFilterAlgo_t algo_to_use; - cudnnStatus_t status = cudnnGetConvolutionBackwardFilterAlgorithm( - cudnn.handle(), - /*srcDesc=*/input_nd.handle(), - /*diffDesc=*/out_back_nd.handle(), - /*convDesc=*/conv.handle(), - /*gradDesc=*/filter.handle(), - /*preference=*/preference, - /*memoryLimitInBytes=*/memory_limit_bytes, - /*algo=*/&algo_to_use); - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable " - "algorithm for doing backward " - "filter convolution"; - return algo_to_use; - }; - - algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr); - - if (scratch_allocator != nullptr) { - size_t size_in_bytes; - auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( - cudnn.handle(), - /*xDesc=*/input_nd.handle(), - /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), - /*gradDesc=*/filter.handle(), /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - int64 size_in_bytes_int64 = size_in_bytes; - if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) { - if (size_in_bytes_int64 > 0) { - auto allocated = - scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - } else { - LOG(WARNING) - << "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } - } - // If we didn't allocate any scratch space (perhaps because of failed - // allocation), we force a switch back to the "no workspace" algorithm. - if (scratch == nullptr) { - algo = get_algorithm(/*specify_limit=*/false); - } - } else { - // An algorithm has been specified. - dnn::AlgorithmDesc algotype = algorithm_config.algorithm(); - algo = ToConvBackwardFilterAlgo(algotype); - conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); - - size_t size_in_bytes; - auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( - cudnn.handle(), - /*xDesc=*/input_nd.handle(), - /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), - /*gradDesc=*/filter.handle(), /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - if (status != CUDNN_STATUS_SUCCESS) { - if (is_profiling) { - // Silently return when we are profiling. - return false; - } - LOG(FATAL) << "Cannot query the size of workspace needed for the given " - "algorithm: " - << algorithm_config.algorithm().algo_id(); - } - int64 size_in_bytes_int64 = size_in_bytes; - if (size_in_bytes_int64 > 0) { - if (scratch_allocator == nullptr) { - LOG(FATAL) << "An allocator must be specified when scratch memory is " - "needed"; - } - auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (is_profiling && !allocated.ok()) { - // Silently return when we are profiling. - return false; - } - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - if (scratch == nullptr) { - CHECK(!algorithm_config.algorithm_no_scratch().is_default()) - << "The primary convolution algorithm failed memory allocation, " - "while a secondary algorithm is not provided."; - dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch(); - algo = ToConvBackwardFilterAlgo(algotype); - conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); - } - } else if (size_in_bytes_int64 < 0) { - LOG(WARNING) - << "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } + DeviceMemory scratch; + SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc, + GetCudnnConvolutionBackwardFilterAlgorithm( + stream, cudnn, algorithm_config, input_nd, filter, + conv, out_back_nd, scratch_allocator, &scratch)); - std::unique_ptr timer; + std::unique_ptr timer; if (is_profiling) { timer.reset(new CUDATimer(parent_)); // NOLINT - timer->Init(); // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - timer->Start(AsCUDAStream(stream)); + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); + } } -#if CUDNN_VERSION >= 5000 - auto status = cudnnConvolutionBackwardFilter( - cudnn.handle(), -#else - auto status = cudnnConvolutionBackwardFilter_v3( + // Report an error if we might be hitting a cuDNN bug that produces incorrect + // results. See nvbugs/2072856 + SE_RETURN_IF_ERROR([&] { + if (algo_desc.algo_id() != CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) { + return port::Status::OK(); + } + if (output_descriptor.height() > 1 && output_descriptor.width() > 1) { + return port::Status::OK(); + } + int convolution_size = output_descriptor.height() > 1 + ? filter_descriptor.input_filter_height() + : filter_descriptor.input_filter_width(); + if (convolution_size <= 32) { + return port::Status::OK(); + } + cudnnConvolutionMode_t convolution_mode; + cudnnDataType_t compute_type; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdDescriptor( + conv.handle(), 0, nullptr, nullptr, nullptr, nullptr, &convolution_mode, + &compute_type)); + if (convolution_mode != CUDNN_CONVOLUTION) { + return port::Status::OK(); + } + return port::Status( + port::error::FAILED_PRECONDITION, + "This configuration potentially produces incorrect results."); + }()); + + RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter( cudnn.handle(), -#endif /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(), /*srcData=*/input_data.opaque(), /*diffDesc=*/out_back_nd.handle(), /*diffData=*/backward_output_data.opaque(), /*convDesc=*/conv.handle(), - /*algo=*/algo, + /*algo=*/ToConvBackwardFilterAlgo(algo_desc), /*workSpace=*/scratch.opaque(), /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/beta, /*gradDesc=*/filter.handle(), - /*gradData=*/backward_filter_data->opaque()); - + /*dw=*/backward_filter_data->opaque())); if (is_profiling) { - timer->Stop(AsCUDAStream(stream)); - if (status == CUDNN_STATUS_SUCCESS) { - bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled(); - dnn::AlgorithmDesc algotype(algo, use_tensor_ops); - output_profile_result->set_algorithm(algotype); - output_profile_result->set_elapsed_time_in_ms( - timer->GetElapsedMilliseconds()); - } - timer->Destroy(); - } - if (status != CUDNN_STATUS_SUCCESS) { - // Silently return when we are profiling. - if (!is_profiling) { - LOG(ERROR) << "failed to enqueue convolution on stream: " - << ToString(status); + if (!timer->Stop(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - return false; + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); } - return true; + + return port::Status::OK(); } bool CudnnSupport::DoConvolveBackwardFilter( @@ -3750,11 +3291,13 @@ bool CudnnSupport::DoConvolveBackwardFilter( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, - output_descriptor, backward_output_data, - convolution_descriptor, filter_descriptor, - backward_filter_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, + output_descriptor, backward_output_data, + convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolveBackwardFilter( @@ -3768,11 +3311,13 @@ bool CudnnSupport::DoConvolveBackwardFilter( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, - output_descriptor, backward_output_data, - convolution_descriptor, filter_descriptor, - backward_filter_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, + output_descriptor, backward_output_data, + convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolveBackwardFilter( @@ -3786,22 +3331,24 @@ bool CudnnSupport::DoConvolveBackwardFilter( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, - output_descriptor, backward_output_data, - convolution_descriptor, filter_descriptor, - backward_filter_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, + output_descriptor, backward_output_data, + convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } template -bool CudnnSupport::DoConvolveBackwardBiasImpl( +port::Status CudnnSupport::DoConvolveBackwardBiasImpl( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, const dnn::BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { cudnnDataType_t cudnn_type = GetCudnnDataType(); - ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type); - ScopedTensorDescriptor bias_nd(bias_descriptor, cudnn_type); + CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type); + CudnnTensorDescriptor bias_nd(bias_descriptor, cudnn_type); // Alpha is the scaling factor for input. float alpha = 1.0; @@ -3809,15 +3356,10 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl( float beta = 0.0; auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnConvolutionBackwardBias( + RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardBias( cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta, - bias_nd.handle(), backward_bias_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue backward convolution on stream: " - << ToString(status); - return false; - } - return true; + bias_nd.handle(), backward_bias_data->opaque())); + return port::Status::OK(); } bool CudnnSupport::DoConvolveBackwardBias( @@ -3825,8 +3367,10 @@ bool CudnnSupport::DoConvolveBackwardBias( const DeviceMemory& input_data, const dnn::BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { - return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, - bias_descriptor, backward_bias_data); + return IsStatusOk( + DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, + bias_descriptor, backward_bias_data), + /*report_error=*/true); } bool CudnnSupport::DoConvolveBackwardBias( @@ -3834,8 +3378,10 @@ bool CudnnSupport::DoConvolveBackwardBias( const DeviceMemory& input_data, const dnn::BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { - return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, - bias_descriptor, backward_bias_data); + return IsStatusOk( + DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, + bias_descriptor, backward_bias_data), + /*report_error=*/true); } bool CudnnSupport::DoConvolveBackwardBias( @@ -3843,8 +3389,10 @@ bool CudnnSupport::DoConvolveBackwardBias( const DeviceMemory& input_data, const dnn::BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { - return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, - bias_descriptor, backward_bias_data); + return IsStatusOk( + DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, + bias_descriptor, backward_bias_data), + /*report_error=*/true); } bool CudnnSupport::DoMatMul(Stream* stream, @@ -3987,7 +3535,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream, const DeviceMemory& biases, const dnn::BatchDescriptor& dimensions, DeviceMemory* output_data) { - ScopedTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT); + CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT); dnn::BatchDescriptor bias_dimensions; bias_dimensions.set_count(1) @@ -3995,7 +3543,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream, .set_height(1) .set_width(1) .set_layout(dnn::DataLayout::kBatchYXDepth); - ScopedTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT); + CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT); // cudnnAddTensor after R3 is in-place, so we need to copy input_data to // output_data before doing the addition, unless the input and @@ -4016,20 +3564,13 @@ bool CudnnSupport::DoBiasAdd(Stream* stream, auto cudnn = cudnn_->GetHandle(parent_, stream); -#if CUDNN_VERSION >= 5000 - auto status = cudnnAddTensor( -#else - auto status = cudnnAddTensor_v3( -#endif - cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), &beta, - input_descriptor.handle(), output_data->opaque()); - - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "stream " << stream << " could not enqueue bias addition."; - return false; - } - - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnAddTensor( + cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), + &beta, input_descriptor.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoActivate(Stream* stream, @@ -4038,61 +3579,23 @@ bool CudnnSupport::DoActivate(Stream* stream, const DeviceMemory& input_data, DeviceMemory* output_data, uint64 options) { -#if CUDNN_VERSION >= 5000 - ScopedActivationDescriptor activation_desc( + CudnnActivationDescriptor activation_desc( activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max()); -#else - cudnnActivationMode_t mode; - switch (activation_mode) { - case dnn::ActivationMode::kRelu6: - // TODO(leary) should probably do a post-pass to clip at 6? - LOG(WARNING) << "user requested Relu6, but providing Relu instead"; - mode = CUDNN_ACTIVATION_RELU; - break; - case dnn::ActivationMode::kReluX: - // TODO(broune) should probably do a post-pass to clip at X? - LOG(WARNING) << "user requested ReluX, but providing Relu instead"; - mode = CUDNN_ACTIVATION_RELU; - break; - case dnn::ActivationMode::kRelu: - mode = CUDNN_ACTIVATION_RELU; - break; - case dnn::ActivationMode::kSigmoid: - mode = CUDNN_ACTIVATION_SIGMOID; - break; - case dnn::ActivationMode::kTanh: - mode = CUDNN_ACTIVATION_TANH; - break; - default: - LOG(ERROR) << "unrecognized activation mode: " - << static_cast(activation_mode); - return false; - } -#endif - ScopedTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT); + CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT); // Alpha is the input scaling factor. float alpha = 1.0; // Beta is the output scaling factor. float beta = 0.0; auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = - cudnnActivationForward(cudnn.handle(), -#if CUDNN_VERSION >= 5000 - activation_desc.handle(), -#else - mode, -#endif - &alpha, input_nd.handle(), input_data.opaque(), - &beta, input_nd.handle(), output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "stream " << stream - << " could not enqueue activation: " << ToString(status); - return false; - } - - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnActivationForward( + cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(), + input_data.opaque(), &beta, input_nd.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolForward( @@ -4106,20 +3609,18 @@ bool CudnnSupport::DoPoolForward( // Beta is the scaling factor for output. double beta = 0.0; - ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE); - ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE); - ScopedPoolingDescriptor pooling_desc(pooling_dimensions); + CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE); + CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE); + CudnnPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingForward( - cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), - input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue forward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingForward( + cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), + input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolForward( @@ -4133,20 +3634,18 @@ bool CudnnSupport::DoPoolForward( // Beta is the scaling factor for output. float beta = 0.0; - ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT); - ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT); - ScopedPoolingDescriptor pooling_desc(pooling_dimensions); + CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT); + CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT); + CudnnPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingForward( - cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), - input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue forward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingForward( + cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), + input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolForward( @@ -4160,19 +3659,17 @@ bool CudnnSupport::DoPoolForward( // Beta is the scaling factor for output. float beta = 0.0; - ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF); - ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF); - ScopedPoolingDescriptor pooling_desc(pooling_dimensions); + CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF); + CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF); + CudnnPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingForward( - cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), - input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue forward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingForward( + cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), + input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolBackward( @@ -4188,22 +3685,20 @@ bool CudnnSupport::DoPoolBackward( // Beta is the scaling factor for output. double beta = 0.0; - ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE); - ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE); - ScopedPoolingDescriptor pooling_desc(pooling_dimensions); + CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE); + CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE); + CudnnPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingBackward( - cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), - output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), - src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), - output_diff_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue backward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward( + cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), + output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), + src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), + output_diff_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolBackward( @@ -4219,22 +3714,20 @@ bool CudnnSupport::DoPoolBackward( // Beta is the scaling factor for output. float beta = 0.0; - ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT); - ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT); - ScopedPoolingDescriptor pooling_desc(pooling_dimensions); + CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT); + CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT); + CudnnPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingBackward( - cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), - output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), - src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), - output_diff_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue backward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward( + cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), + output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), + src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), + output_diff_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolBackward( @@ -4250,22 +3743,20 @@ bool CudnnSupport::DoPoolBackward( // Beta is the scaling factor for output. float beta = 0.0; - ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF); - ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF); - ScopedPoolingDescriptor pooling_desc(pooling_dimensions); + CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF); + CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF); + CudnnPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingBackward( - cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), - output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), - src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), - output_diff_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue backward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward( + cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), + output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), + src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), + output_diff_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoNormalize( @@ -4289,8 +3780,8 @@ bool CudnnSupport::DoNormalizeWithDimensions( return false; } - ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT); - ScopedNormalizeDescriptor normalize(normalize_descriptor); + CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT); + CudnnNormalizeDescriptor normalize(normalize_descriptor); // Alpha is the scaling factor for input. float alpha = 1.0f; @@ -4300,15 +3791,14 @@ bool CudnnSupport::DoNormalizeWithDimensions( auto cudnn = cudnn_->GetHandle(parent_, stream); // Launch the normalization. - auto status = cudnnLRNCrossChannelForward( - cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, - dims.handle(), input_data.opaque(), &beta, dims.handle(), - output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to run cudnnLRNCrossChannelForward"; - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward( + cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, + &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(), + output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoNormalizeBackwardWithDimensions( @@ -4327,23 +3817,22 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions( return false; } - ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT); - ScopedNormalizeDescriptor normalize(normalize_descriptor); + CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT); + CudnnNormalizeDescriptor normalize(normalize_descriptor); float alpha = 1.0f; float beta = 0.0f; auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnLRNCrossChannelBackward( - cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, - dims.handle(), normalized_data.opaque(), dims.handle(), - normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(), - &beta, dims.handle(), raw_variable_gradient->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to run cudnnLRNCrossChannelBackward"; - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward( + cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, + &alpha, dims.handle(), normalized_data.opaque(), dims.handle(), + normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(), + &beta, dims.handle(), raw_variable_gradient->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoDepthConcatenate( @@ -4452,30 +3941,26 @@ bool CudnnSupport::DeriveOutputBatchDescriptor( const dnn::FilterDescriptor& filter_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::BatchDescriptor* output_batch_descriptor) { - ScopedTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT); - ScopedFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT); - ScopedConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT); + CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT); + CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT); + CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT); int dn = batch_descriptor.ndims() + 2; std::vector dims(dn); // in BDYX - auto status = cudnnGetConvolutionNdForwardOutputDim( - conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not get output tensor for convolution: " - << ToString(status); - return false; - } - - output_batch_descriptor->set_count(dims[0]) - .set_feature_map_count(dims[1]) - .set_layout(batch_descriptor.layout()); + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim( + conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data())); + output_batch_descriptor->set_count(dims[0]) + .set_feature_map_count(dims[1]) + .set_layout(batch_descriptor.layout()); - for (int i = 0; i < batch_descriptor.ndims(); i++) { - output_batch_descriptor->set_spatial_dim(static_cast(i), - dims.rbegin()[i]); - } - - return true; + for (int i = 0; i < batch_descriptor.ndims(); i++) { + output_batch_descriptor->set_spatial_dim(static_cast(i), + dims.rbegin()[i]); + } + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } } // namespace cuda diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index e2de3c62d81ae56c28fd4b888c74435ceecc6b22..c924d41cb5239d704e658f0b5452e04087caeba2 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -631,7 +631,7 @@ class CudnnSupport : public dnn::DnnSupport { std::unique_ptr cudnn_; template - bool DoBatchNormalizationForwardImpl( + port::Status DoBatchNormalizationForwardImpl( Stream* stream, dnn::DataType input_data_type, dnn::DataType scale_data_type, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -646,7 +646,7 @@ class CudnnSupport : public dnn::DnnSupport { std::function inv_var_to_var); template - bool DoBatchNormalizationBackwardImpl( + port::Status DoBatchNormalizationBackwardImpl( Stream* stream, int cudnn_input_type, int cudnn_scale_type, const DeviceMemory& y_backprop, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& mean, @@ -656,21 +656,20 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemory* offset_backprop); template - bool DoConvolveImpl(Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const DeviceMemory& input_data, - const dnn::FilterDescriptor& filter_descriptor, - const DeviceMemory& filter_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemory* output_data, - ScratchAllocator* scratch_allocator, - const dnn::AlgorithmConfig& algorithm_config, - dnn::ProfileResult* output_profile_result); + port::Status DoConvolveImpl( + Stream* stream, const dnn::BatchDescriptor& input_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data, ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result); template - bool DoFusedConvolveImpl( + port::Status DoFusedConvolveImpl( Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor, const DeviceMemory& conv_input_data, ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor, @@ -685,9 +684,8 @@ class CudnnSupport : public dnn::DnnSupport { dnn::ProfileResult* output_profile_result); template - bool DoConvolveBackwardDataImpl( - Stream* stream, - const dnn::FilterDescriptor& filter_descriptor, + port::Status DoConvolveBackwardDataImpl( + Stream* stream, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const dnn::BatchDescriptor& output_descriptor, DeviceMemory backward_output_data, @@ -698,10 +696,10 @@ class CudnnSupport : public dnn::DnnSupport { dnn::ProfileResult* output_profile_result); template - bool DoConvolveBackwardFilterImpl( + port::Status DoConvolveBackwardFilterImpl( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, - const dnn::BatchDescriptor& output_descriptor_in, + const dnn::BatchDescriptor& output_descriptor, DeviceMemory backward_output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::FilterDescriptor& filter_descriptor, @@ -711,56 +709,56 @@ class CudnnSupport : public dnn::DnnSupport { dnn::ProfileResult* output_profile_result); template - bool DoConvolveBackwardBiasImpl(Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const DeviceMemory& input_data, - const dnn::BatchDescriptor& bias_descriptor, - DeviceMemory* backward_bias_data); + port::Status DoConvolveBackwardBiasImpl( + Stream* stream, const dnn::BatchDescriptor& input_descriptor, + const DeviceMemory& input_data, + const dnn::BatchDescriptor& bias_descriptor, + DeviceMemory* backward_bias_data); template - bool DoRnnForwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc, - const CudnnRnnSequenceTensorDescriptor& input_desc, - const DeviceMemory& input_data, - const CudnnRnnStateTensorDescriptor& input_h_desc, - const DeviceMemory& input_h_data, - const CudnnRnnStateTensorDescriptor& input_c_desc, - const DeviceMemory& input_c_data, - const DeviceMemory& params, - const CudnnRnnSequenceTensorDescriptor& output_desc, - DeviceMemory* output_data, - const CudnnRnnStateTensorDescriptor& output_h_desc, - DeviceMemory* output_h_data, - const CudnnRnnStateTensorDescriptor& output_c_desc, - DeviceMemory* output_c_data, bool is_training, - ScratchAllocator* reserve_space_allocator, - ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result); + port::Status DoRnnForwardImpl( + Stream* stream, const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory& input_data, + const CudnnRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory& input_h_data, + const CudnnRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory& input_c_data, const DeviceMemory& params, + const CudnnRnnSequenceTensorDescriptor& output_desc, + DeviceMemory* output_data, + const CudnnRnnStateTensorDescriptor& output_h_desc, + DeviceMemory* output_h_data, + const CudnnRnnStateTensorDescriptor& output_c_desc, + DeviceMemory* output_c_data, bool is_training, + ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator, + dnn::ProfileResult* output_profile_result); template - bool DoRnnBackwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc, - const CudnnRnnSequenceTensorDescriptor& input_desc, - const DeviceMemory& input_data, - const CudnnRnnStateTensorDescriptor& input_h_desc, - const DeviceMemory& input_h_data, - const CudnnRnnStateTensorDescriptor& input_c_desc, - const DeviceMemory& input_c_data, - const DeviceMemory& params, - const CudnnRnnSequenceTensorDescriptor& output_desc, - const DeviceMemory& output_data, - const CudnnRnnStateTensorDescriptor& output_h_desc, - const DeviceMemory& output_h_data, - const CudnnRnnStateTensorDescriptor& output_c_desc, - const DeviceMemory& output_c_data, - const DeviceMemory& output_backprop_data, - const DeviceMemory& output_h_backprop_data, - const DeviceMemory& output_c_backprop_data, - DeviceMemory* input_backprop_data, - DeviceMemory* input_h_backprop_data, - DeviceMemory* input_c_backprop_data, - DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, - ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result); + port::Status DoRnnBackwardImpl( + Stream* stream, const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory& input_data, + const CudnnRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory& input_h_data, + const CudnnRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory& input_c_data, const DeviceMemory& params, + const CudnnRnnSequenceTensorDescriptor& output_desc, + const DeviceMemory& output_data, + const CudnnRnnStateTensorDescriptor& output_h_desc, + const DeviceMemory& output_h_data, + const CudnnRnnStateTensorDescriptor& output_c_desc, + const DeviceMemory& output_c_data, + const DeviceMemory& output_backprop_data, + const DeviceMemory& output_h_backprop_data, + const DeviceMemory& output_c_backprop_data, + DeviceMemory* input_backprop_data, + DeviceMemory* input_h_backprop_data, + DeviceMemory* input_c_backprop_data, + DeviceMemory* params_backprop_data, + DeviceMemory* reserve_space_data, + ScratchAllocator* workspace_allocator, + dnn::ProfileResult* output_profile_result); SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport); }; diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index e7e4192dfc7cc041819e1dc789fbf959187f716e..d508f6594a9f9ac3c924b0b952620b6a4ac727ea 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -26,16 +26,16 @@ limitations under the License. #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/human_readable.h" +#include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/notification.h" -#include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/lib/stacktrace.h" #include "tensorflow/stream_executor/lib/static_threadlocal.h" #include "tensorflow/stream_executor/lib/strcat.h" #include "tensorflow/stream_executor/lib/stringprintf.h" +#include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" -#include "tensorflow/stream_executor/lib/inlined_vector.h" bool FLAGS_gpuexec_cuda_driver_inject_init_error = false; bool FLAGS_gpuexec_cuda_sync_around_driver_calls = false; @@ -204,11 +204,11 @@ string ToString(CUresult result) { case 719: return "CUDA_ERROR_LAUNCH_FAILED"; - OSTREAM_CUDA_ERROR(CONTEXT_ALREADY_IN_USE) - OSTREAM_CUDA_ERROR(PEER_ACCESS_UNSUPPORTED) - OSTREAM_CUDA_ERROR(NOT_PERMITTED) - OSTREAM_CUDA_ERROR(NOT_SUPPORTED) - OSTREAM_CUDA_ERROR(UNKNOWN) // Unknown internal error to CUDA. + OSTREAM_CUDA_ERROR(CONTEXT_ALREADY_IN_USE) + OSTREAM_CUDA_ERROR(PEER_ACCESS_UNSUPPORTED) + OSTREAM_CUDA_ERROR(NOT_PERMITTED) + OSTREAM_CUDA_ERROR(NOT_SUPPORTED) + OSTREAM_CUDA_ERROR(UNKNOWN) // Unknown internal error to CUDA. default: return port::StrCat("CUresult(", static_cast(result), ")"); } @@ -470,7 +470,8 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options, } /* static */ port::Status CUDADriver::CreateContext( - CUdevice device, DeviceOptions device_options, CudaContext** context) { + CUdevice device, const DeviceOptions &device_options, + CudaContext **context) { *context = nullptr; int flags = 0; @@ -481,62 +482,45 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options, CUresult res; CUcontext former_context; CUcontext new_context; - { - // TODO(leary) Need to see if NVIDIA can expunge the leakiness in their - // context creation: see http://b/13248943 -#if CUDA_VERSION >= 7000 - { - unsigned int former_primary_context_flags; - int former_primary_context_is_active; - CHECK_EQ(CUDA_SUCCESS, - cuDevicePrimaryCtxGetState(device, &former_primary_context_flags, - &former_primary_context_is_active)); - if (former_primary_context_flags != flags) { - if (former_primary_context_is_active) { - LOG(ERROR) - << "The primary context is active and has a different flag set (" - << former_primary_context_flags << ") than the desired flag set (" - << flags << ")."; - } else { - CHECK_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxSetFlags(device, flags)); - } - } + unsigned int former_primary_context_flags; + int former_primary_context_is_active; + CHECK_EQ(CUDA_SUCCESS, + cuDevicePrimaryCtxGetState(device, &former_primary_context_flags, + &former_primary_context_is_active)); + if (former_primary_context_flags != flags) { + if (former_primary_context_is_active) { + LOG(ERROR) + << "The primary context is active and has a different flag set (" + << former_primary_context_flags << ") than the desired flag set (" + << flags << ")."; + } else { + CHECK_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxSetFlags(device, flags)); } + } - former_context = CUDADriver::CurrentContextOrDie(); - res = cuDevicePrimaryCtxRetain(&new_context, device); - if (former_context != nullptr) { - CUdevice former_device; - if (cuCtxGetDevice(&former_device) == CUDA_SUCCESS) { - if (former_device == device) { - if (former_context == new_context) { - VLOG(2) << "The primary context " << former_context - << " for device " << device - << " exists before initializing the StreamExecutor."; - } else { - LOG(WARNING) - << "A non-primary context " << former_context << " for device " - << device - << " exists before initializing the StreamExecutor. The " - << "primary context is now " << new_context << ". We " - << "haven't verified StreamExecutor works with that."; - } + former_context = CUDADriver::CurrentContextOrDie(); + res = cuDevicePrimaryCtxRetain(&new_context, device); + if (former_context != nullptr) { + CUdevice former_device; + if (cuCtxGetDevice(&former_device) == CUDA_SUCCESS) { + if (former_device == device) { + if (former_context == new_context) { + VLOG(2) << "The primary context " << former_context << " for device " + << device + << " exists before initializing the StreamExecutor."; + } else { + LOG(WARNING) << "A non-primary context " << former_context + << " for device " << device + << " exists before initializing the StreamExecutor. The " + << "primary context is now " << new_context << ". We " + << "haven't verified StreamExecutor works with that."; } - } else { - LOG(ERROR) << "Failed to get the device of the current context " - << former_context; } + } else { + LOG(ERROR) << "Failed to get the device of the current context " + << former_context; } -#else - former_context = CurrentContext(); - if (former_context != nullptr) { - LOG(WARNING) - << "creating context when one is currently active; existing: " - << former_context; - } - res = cuCtxCreate(&new_context, flags, device); -#endif } CHECK_EQ(CUDA_SUCCESS, cuCtxSetCurrent(former_context)); @@ -548,11 +532,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options, return port::Status::OK(); } -#if CUDA_VERSION >= 7000 string message = "failed call to cuDevicePrimaryCtxRetain: " + ToString(res); -#else - string message = "failed call to cuCtxCreate: " + ToString(res); -#endif if (res == CUDA_ERROR_OUT_OF_MEMORY) { uint64 total_memory; if (GetDeviceTotalMemory(device, &total_memory)) { @@ -569,7 +549,6 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options, if (context == nullptr) { return; } -#if CUDA_VERSION >= 7000 CUcontext former_context = CurrentContext(); CUresult res = cuCtxSetCurrent(context->context()); CUdevice device; @@ -577,9 +556,6 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options, cuCtxSetCurrent(former_context); res = cuDevicePrimaryCtxRelease(device); -#else - CUresult res = cuCtxDestroy(context->context()); -#endif if (res != CUDA_SUCCESS) { LOG(ERROR) << "failed to release CUDA context; leaking: " << ToString(res); @@ -948,6 +924,37 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) { } } +/* static */ void *CUDADriver::UnifiedMemoryAllocate(CudaContext *context, + uint64 bytes) { + ScopedActivateContext activation(context); + CUdeviceptr result = 0; + // "Portable" memory is visible to all CUDA contexts. Safe for our use model. + CUresult res = cuMemAllocManaged(&result, bytes, CU_MEM_ATTACH_GLOBAL); + if (res != CUDA_SUCCESS) { + LOG(ERROR) << "failed to alloc " << bytes + << " bytes unified memory; result: " << ToString(res); + return nullptr; + } + void *ptr = reinterpret_cast(result); + VLOG(2) << "allocated " << ptr << " for context " << context << " of " + << bytes << " bytes in unified memory"; + return ptr; +} + +/* static */ void CUDADriver::UnifiedMemoryDeallocate(CudaContext *context, + void *location) { + ScopedActivateContext activation(context); + CUdeviceptr pointer = port::bit_cast(location); + CUresult res = cuMemFree(pointer); + if (res != CUDA_SUCCESS) { + LOG(ERROR) << "failed to free unified memory at " << location + << "; result: " << ToString(res); + } else { + VLOG(2) << "deallocated unified memory at " << location << " for context " + << context; + } +} + /* static */ void *CUDADriver::HostAllocate(CudaContext *context, uint64 bytes) { ScopedActivateContext activation(context); diff --git a/tensorflow/stream_executor/cuda/cuda_driver.h b/tensorflow/stream_executor/cuda/cuda_driver.h index a9969e247e181599f2b3707f6c65c6527dd4683d..3713a5b7b98f8bd5173d649fa592107f06bda27d 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.h +++ b/tensorflow/stream_executor/cuda/cuda_driver.h @@ -106,6 +106,16 @@ class CUDADriver { // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a static void DeviceDeallocate(CudaContext* context, void *location); + // Allocates a unified memory space of size bytes associated with the given + // context via cuMemAllocManaged. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb347ded34dc326af404aa02af5388a32 + static void* UnifiedMemoryAllocate(CudaContext* context, uint64 bytes); + + // Deallocates a unified memory space of size bytes associated with the given + // context via cuMemFree. + // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a + static void UnifiedMemoryDeallocate(CudaContext* context, void* location); + // Allocates page-locked and CUDA-registered memory on the host via // cuMemAllocHost. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gdd8311286d2c2691605362c689bc64e0 @@ -147,7 +157,7 @@ class CUDADriver { // userspace processes is given here: // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g65dc0012348bc84810e2103a40d8e2cf static port::Status CreateContext(CUdevice device, - DeviceOptions device_options, + const DeviceOptions& device_options, CudaContext** context); // Destroys the provided context via cuCtxDestroy. diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index f2be68bc421c1fbc31ea5a054b91130c11949635..f11022ef1dfd4a1a08d035f5328724d93ac808be 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -180,11 +180,11 @@ bool CUDAExecutor::FindOnDiskForComputeCapability( static string GetBinaryDir(bool strip_exe) { char exe_path[PATH_MAX] = {0}; #if defined(__APPLE__) - uint32_t buffer_size = 0U; - _NSGetExecutablePath(nullptr, &buffer_size); - char unresolved_path[buffer_size]; - _NSGetExecutablePath(unresolved_path, &buffer_size); - CHECK_ERR(realpath(unresolved_path, exe_path) ? 1 : -1); + uint32_t buffer_size = 0U; + _NSGetExecutablePath(nullptr, &buffer_size); + char unresolved_path[buffer_size]; + _NSGetExecutablePath(unresolved_path, &buffer_size); + CHECK_ERR(realpath(unresolved_path, exe_path) ? 1 : -1); #else #if defined(PLATFORM_WINDOWS) HMODULE hModule = GetModuleHandle(NULL); diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index f686685474b35acfb54c327401500c42109006d0..773cbfb8a17a416d18ae599bf4f72e1550538dee 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -74,6 +74,14 @@ class CUDAExecutor : public internal::StreamExecutorInterface { void Deallocate(DeviceMemoryBase *mem) override; + void *UnifiedMemoryAllocate(uint64 size) override { + return CUDADriver::UnifiedMemoryAllocate(context_, size); + } + + void UnifiedMemoryDeallocate(void *location) override { + return CUDADriver::UnifiedMemoryDeallocate(context_, location); + } + // CUDA allocation/registration functions are necessary because the driver // internally sets up buffers for DMA operations (and page locks them). // There's no external interface for us to otherwise control these DMA diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc index ebe4dcc90436a7e410596694daf155245d5c94c2..622a4a4edb1fe4163831e9429c1a7ab9262f2727 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform.cc +++ b/tensorflow/stream_executor/cuda/cuda_platform.cc @@ -206,7 +206,6 @@ static void InitializeCudaPlatform() { REGISTER_MODULE_INITIALIZER(cuda_platform, stream_executor::InitializeCudaPlatform()); -DECLARE_MODULE_INITIALIZER(multi_platform_manager); // Note that module initialization sequencing is not supported in the // open-source project, so this will be a no-op there. REGISTER_MODULE_INITIALIZER_SEQUENCE(cuda_platform, multi_platform_manager); diff --git a/tensorflow/stream_executor/cuda/cuda_timer.h b/tensorflow/stream_executor/cuda/cuda_timer.h index 70554ec93120fcb0251ba0995a1ce9d6e5997016..e040cf86fad1f40a708ad4ca28693e31908393f0 100644 --- a/tensorflow/stream_executor/cuda/cuda_timer.h +++ b/tensorflow/stream_executor/cuda/cuda_timer.h @@ -37,8 +37,9 @@ class CUDATimer : public internal::TimerInterface { explicit CUDATimer(CUDAExecutor *parent) : parent_(parent), start_event_(nullptr), stop_event_(nullptr) {} - // Note: teardown is explicitly handled in this API by a call to + // Note: teardown needs to be explicitly handled in this API by a call to // StreamExecutor::DeallocateTimer(), which invokes Destroy(). + // TODO(csigg): Change to RAII. ~CUDATimer() override {} // Allocates the platform-specific pieces of the timer, called as part of diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index eed93efc8d655276d4afc8c651abc90dab7dc3c4..82aa8ceb3298a30a4c117882dc96c504d9d10226 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -141,6 +141,10 @@ string PadAlignmentString(PadAlignment alignment) { return "unknown pad alignment"; } +std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) { + return str << PadAlignmentString(alignment); +} + string ShortPoolingModeString(PoolingMode mode) { switch (mode) { case PoolingMode::kMaximum: @@ -407,6 +411,8 @@ string FilterDescriptor::ToShortString() const { switch (layout_) { case FilterLayout::kOutputInputYX: return port::StrCat(od, id, spatial); + case FilterLayout::kOutputYXInput: + return port::StrCat(od, spatial, id); case FilterLayout::kOutputInputYX4: return port::StrCat(od, id, spatial, "(VECT_C)"); case FilterLayout::kInputYXOutput: diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 38abc66079264dd46634fca1b0d9297844a31aa1..9eca5abe1ae7265ebca0a1ea653823816deaa8f5 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -349,6 +349,8 @@ enum class FilterLayout : int64 { kOutputInputYX = 0, // cuDNN's default filter layout, laid out as: // (major) output feature maps >> input feature maps >> // rows >> columns (minor). + kOutputYXInput, // major to minor: + // (output features, row, columns, input features) kOutputInputYX4, // laid out the same as kOutputInputYX but each element is a // vector of 4 feature maps. kInputYXOutput, // Same as dist_belief's default filter layout. @@ -467,6 +469,9 @@ enum class PadAlignment : int64 { // Returns a string representation of the given padding alignment. string PadAlignmentString(PadAlignment alignment); +// Print alignment to str. Needed to use CHECK_EQ between two PadAlignments. +std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment); + // Describes a convolution. // // Uses the named argument construction form: @@ -708,7 +713,7 @@ class PoolingDescriptor { class AlgorithmDesc { public: typedef int64 Index; - AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(false) {} + AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {} AlgorithmDesc(Index a, bool use_tensor_ops) : algo_(a), tensor_ops_enabled_(use_tensor_ops) {} bool is_default() const { return algo_ == kDefaultAlgorithm; } diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc index 2c4819651acaa2c6ee99c720b2c3d80e5c2ea1a9..c8a629733006e17b7642a59afb8e0cb468f2c538 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.cc +++ b/tensorflow/stream_executor/host/host_gpu_executor.cc @@ -95,7 +95,7 @@ bool HostExecutor::MemcpyDeviceToDevice(Stream *stream, // the nature of the HostExecutor) memcpy on the stream (HostStream) // associated with the HostExecutor. AsHostStream(stream)->EnqueueTask( - [src_mem, dst_mem, size]() { memcpy(src_mem, dst_mem, size); }); + [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); }); return true; } diff --git a/tensorflow/stream_executor/host/host_platform.cc b/tensorflow/stream_executor/host/host_platform.cc index eeb6a06e3d6b9ab9d32866ee647961a4bb4f8a32..410dc9da899cc967b36c1738a6b7c128a98cf70c 100644 --- a/tensorflow/stream_executor/host/host_platform.cc +++ b/tensorflow/stream_executor/host/host_platform.cc @@ -100,7 +100,6 @@ static void InitializeHostPlatform() { REGISTER_MODULE_INITIALIZER(host_platform, stream_executor::host::InitializeHostPlatform()); -DECLARE_MODULE_INITIALIZER(multi_platform_manager); // Note that module initialization sequencing is not supported in the // open-source project, so this will be a no-op there. REGISTER_MODULE_INITIALIZER_SEQUENCE(host_platform, multi_platform_manager); diff --git a/tensorflow/compiler/xla/statusor.cc b/tensorflow/stream_executor/lib/statusor.cc similarity index 89% rename from tensorflow/compiler/xla/statusor.cc rename to tensorflow/stream_executor/lib/statusor.cc index 72ab67ff810e0ec384a22da092363cc7446435bb..e0e851f96ef6fe18ec32ff7d3fd1d1aed18b0343 100644 --- a/tensorflow/compiler/xla/statusor.cc +++ b/tensorflow/stream_executor/lib/statusor.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" -namespace xla { +namespace stream_executor { +namespace port { namespace internal_statusor { void Helper::HandleInvalidStatusCtorArg(Status* status) { @@ -35,4 +36,5 @@ void Helper::Crash(const Status& status) { } } // namespace internal_statusor -} // namespace xla +} // namespace port +} // namespace stream_executor diff --git a/tensorflow/stream_executor/lib/statusor.h b/tensorflow/stream_executor/lib/statusor.h index dab59096740102b94c0ff63c089b83ce052ea264..3c716acb462f1ca25e1d86408386d9eca37265b7 100644 --- a/tensorflow/stream_executor/lib/statusor.h +++ b/tensorflow/stream_executor/lib/statusor.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,19 +13,297 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h" - +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. +// +// The primary use-case for StatusOr is as the return value of a +// function which may fail. +// +// Example client usage for a StatusOr, where T is not a pointer: +// +// StatusOr result = DoBigCalculationThatCouldFail(); +// if (result.ok()) { +// float answer = result.ValueOrDie(); +// printf("Big calculation yielded: %f", answer); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr: +// +// StatusOr result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr>: +// +// StatusOr> result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo = std::move(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example factory implementation returning StatusOr: +// +// StatusOr FooFactory::MakeNewFoo(int arg) { +// if (arg <= 0) { +// return tensorflow::InvalidArgument("Arg must be positive"); +// } else { +// return new Foo(arg); +// } +// } +// +// Note that the assignment operators require that destroying the currently +// stored value cannot invalidate the argument; in other words, the argument +// cannot be an alias for the current value, or anything owned by the current +// value. #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_ -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor_internals.h" namespace stream_executor { namespace port { -// Use XLA's StatusOr so we don't duplicate code. +#if defined(__clang__) +// Only clang supports warn_unused_result as a type annotation. +template +class TF_MUST_USE_RESULT StatusOr; +#endif + +template +class StatusOr : private internal_statusor::StatusOrData, + private internal_statusor::TraitsBase< + std::is_copy_constructible::value, + std::is_move_constructible::value> { + template + friend class StatusOr; + + typedef internal_statusor::StatusOrData Base; + + public: + typedef T element_type; + + // Constructs a new StatusOr with Status::UNKNOWN status. This is marked + // 'explicit' to try to catch cases like 'return {};', where people think + // StatusOr> will be initialized with an empty vector, + // instead of a Status::UNKNOWN status. + explicit StatusOr(); + + // StatusOr will be copy constructible/assignable if T is copy + // constructible. + StatusOr(const StatusOr&) = default; + StatusOr& operator=(const StatusOr&) = default; + + // StatusOr will be move constructible/assignable if T is move + // constructible. + StatusOr(StatusOr&&) = default; + StatusOr& operator=(StatusOr&&) = default; + + // Conversion copy/move constructor, T must be convertible from U. + template ::value>::type* = nullptr> + StatusOr(const StatusOr& other); + template ::value>::type* = nullptr> + StatusOr(StatusOr&& other); + + // Conversion copy/move assignment operator, T must be convertible from U. + template ::value>::type* = nullptr> + StatusOr& operator=(const StatusOr& other); + template ::value>::type* = nullptr> + StatusOr& operator=(StatusOr&& other); + + // Constructs a new StatusOr with the given value. After calling this + // constructor, calls to ValueOrDie() will succeed, and calls to status() will + // return OK. + // + // NOTE: Not explicit - we want to use StatusOr as a return type + // so it is convenient and sensible to be able to do 'return T()' + // when the return type is StatusOr. + // + // REQUIRES: T is copy constructible. + StatusOr(const T& value); + + // Constructs a new StatusOr with the given non-ok status. After calling + // this constructor, calls to ValueOrDie() will CHECK-fail. + // + // NOTE: Not explicit - we want to use StatusOr as a return + // value, so it is convenient and sensible to be able to do 'return + // Status()' when the return type is StatusOr. + // + // REQUIRES: !status.ok(). This requirement is DCHECKed. + // In optimized builds, passing Status::OK() here will have the effect + // of passing tensorflow::error::INTERNAL as a fallback. + StatusOr(const Status& status); + StatusOr& operator=(const Status& status); + + // TODO(b/62186997): Add operator=(T) overloads. + + // Similar to the `const T&` overload. + // + // REQUIRES: T is move constructible. + StatusOr(T&& value); + + // RValue versions of the operations declared above. + StatusOr(Status&& status); + StatusOr& operator=(Status&& status); + + // Returns this->status().ok() + bool ok() const { return this->status_.ok(); } + + // Returns a reference to our status. If this contains a T, then + // returns Status::OK(). + const Status& status() const &; + Status status() &&; + + // Returns a reference to our current value, or CHECK-fails if !this->ok(). + // + // Note: for value types that are cheap to copy, prefer simple code: + // + // T value = statusor.ValueOrDie(); + // + // Otherwise, if the value type is expensive to copy, but can be left + // in the StatusOr, simply assign to a reference: + // + // T& value = statusor.ValueOrDie(); // or `const T&` + // + // Otherwise, if the value type supports an efficient move, it can be + // used as follows: + // + // T value = std::move(statusor).ValueOrDie(); + // + // The std::move on statusor instead of on the whole expression enables + // warnings about possible uses of the statusor object after the move. + // C++ style guide waiver for ref-qualified overloads granted in cl/143176389 + // See go/ref-qualifiers for more details on such overloads. + const T& ValueOrDie() const &; + T& ValueOrDie() &; + const T&& ValueOrDie() const &&; + T&& ValueOrDie() &&; + + T ConsumeValueOrDie() { return std::move(ValueOrDie()); } + + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. + void IgnoreError() const; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementation details for StatusOr + +template +StatusOr::StatusOr() : Base(Status(tensorflow::error::UNKNOWN, "")) {} + +template +StatusOr::StatusOr(const T& value) : Base(value) {} + +template +StatusOr::StatusOr(const Status& status) : Base(status) {} + +template +StatusOr& StatusOr::operator=(const Status& status) { + this->Assign(status); + return *this; +} + +template +StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} + +template +StatusOr::StatusOr(Status&& status) : Base(std::move(status)) {} + +template +StatusOr& StatusOr::operator=(Status&& status) { + this->Assign(std::move(status)); + return *this; +} + +template +template ::value>::type*> +inline StatusOr::StatusOr(const StatusOr& other) + : Base(static_cast::Base&>(other)) {} + +template +template ::value>::type*> +inline StatusOr& StatusOr::operator=(const StatusOr& other) { + if (other.ok()) + this->Assign(other.ValueOrDie()); + else + this->Assign(other.status()); + return *this; +} + +template +template ::value>::type*> +inline StatusOr::StatusOr(StatusOr&& other) + : Base(static_cast::Base&&>(other)) {} + +template +template ::value>::type*> +inline StatusOr& StatusOr::operator=(StatusOr&& other) { + if (other.ok()) { + this->Assign(std::move(other).ValueOrDie()); + } else { + this->Assign(std::move(other).status()); + } + return *this; +} + +template +const Status& StatusOr::status() const & { + return this->status_; +} +template +Status StatusOr::status() && { + return ok() ? Status::OK() : std::move(this->status_); +} + +template +const T& StatusOr::ValueOrDie() const & { + this->EnsureOk(); + return this->data_; +} + +template +T& StatusOr::ValueOrDie() & { + this->EnsureOk(); + return this->data_; +} + +template +const T&& StatusOr::ValueOrDie() const && { + this->EnsureOk(); + return std::move(this->data_); +} + +template +T&& StatusOr::ValueOrDie() && { + this->EnsureOk(); + return std::move(this->data_); +} + template -using StatusOr = ::xla::StatusOr; +void StatusOr::IgnoreError() const { + // no-op +} } // namespace port } // namespace stream_executor diff --git a/tensorflow/compiler/xla/statusor_internals.h b/tensorflow/stream_executor/lib/statusor_internals.h similarity index 94% rename from tensorflow/compiler/xla/statusor_internals.h rename to tensorflow/stream_executor/lib/statusor_internals.h index 14636bd144bc0a155fc96c5a350c658fd2dadfe6..09f88f5825f57c8e654bd079616a074e84de4f30 100644 --- a/tensorflow/compiler/xla/statusor_internals.h +++ b/tensorflow/stream_executor/lib/statusor_internals.h @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ -#define TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_ +#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_ + -#include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/status.h" -namespace xla { +namespace stream_executor { +namespace port { namespace internal_statusor { class Helper { @@ -240,6 +242,7 @@ struct TraitsBase { }; } // namespace internal_statusor -} // namespace xla +} // namespace port +} // namespace stream_executor -#endif // TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_ diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/stream_executor/lib/statusor_test.cc similarity index 99% rename from tensorflow/compiler/xla/statusor_test.cc rename to tensorflow/stream_executor/lib/statusor_test.cc index 377a618ffbd99316d409130df8a39f352664dee0..56584e189208b2576f10650fd56bca6d04ecc6c1 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/stream_executor/lib/statusor_test.cc @@ -15,18 +15,18 @@ limitations under the License. // Unit tests for StatusOr -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/lib/statusor.h" #include #include -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test_benchmark.h" -namespace xla { +namespace stream_executor { +namespace port { namespace { class Base1 { @@ -672,4 +672,5 @@ void BM_StatusOrFactoryFailLongMsg(int iters) { BENCHMARK(BM_StatusOrFactoryFailLongMsg); } // namespace -} // namespace xla +} // namespace port +} // namespace stream_executor diff --git a/tensorflow/stream_executor/multi_platform_manager.h b/tensorflow/stream_executor/multi_platform_manager.h index 7e316879ca0cf9c2a97ee37c556e0f0d9b83e5fa..146a128e85cfe84a844aae0fd50d5a329df2723c 100644 --- a/tensorflow/stream_executor/multi_platform_manager.h +++ b/tensorflow/stream_executor/multi_platform_manager.h @@ -68,6 +68,7 @@ limitations under the License. #include #include +#include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform.h" @@ -182,4 +183,9 @@ class MultiPlatformManager { } // namespace stream_executor +// multi_platform_manager.cc will define this instance. Includers of this header +// should use +// REGISTER_MODULE_INITIALIZER_SEQUENCE(my_platform, multi_platform_manager); +DECLARE_MODULE_INITIALIZER(multi_platform_manager); + #endif // TENSORFLOW_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 4a98cfe16460ff860b6b73fedc21e98b5a3ed9fd..0cd0790a72b49bb259b9c72268535b5d74531cf5 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -192,6 +192,7 @@ string ToVlogString(dnn::DataType data_type) { case dnn::DataType::kInt8: return "dnn::DataType::kInt8"; } + return "unknown DataType"; } // Used together with PARAM to VLOG calls made to the stream. Intended diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 3da1b856d6a41fa0c8d5a77feac33932da392422..e8885e1eb682d9ee67c6b7594f96c0911c7c1fa2 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "tensorflow/core/platform/macros.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/dnn.h" @@ -156,14 +157,13 @@ class Stream { const TypedKernel &kernel, Args... args); // Record a "start" event for the interval timer at this point in the - // stream's - // execution (relative to the previously and subsequently enqueued items in - // the stream's execution). Streams may be started/stopped multiple times. + // stream's execution (relative to the previously and subsequently enqueued + // items in the stream's execution). Streams may be started/stopped multiple + // times. Stream &ThenStartTimer(Timer *t); // Record a "stop" event for the interval timer at this point in the - // stream's - // execution. See also Stream::ThenStartTimer. + // stream's execution. See also Stream::ThenStartTimer. Stream &ThenStopTimer(Timer *t); // TODO(leary) If work is added to the stream that is being depended upon, @@ -179,8 +179,7 @@ class Stream { // // Checks that a stream does not wait for itself, and it is up to the // user to guarantee that a stream does not come to wait on itself in a - // cyclic - // manner; in that case, behavior is undefined. + // cyclic manner; in that case, behavior is undefined. // // N.B. Base recursion case for the variadic ThenWaitFor. Stream &ThenWaitFor(Stream *other); @@ -1351,33 +1350,39 @@ class Stream { DeviceMemory> *x, int incx); // See BlasSupport::DoBlasGemm. - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, float alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, float beta, - DeviceMemory *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, float alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, float beta, - DeviceMemory *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, double alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, double beta, - DeviceMemory *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &b, int ldb, - std::complex beta, - DeviceMemory> *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &b, int ldb, - std::complex beta, - DeviceMemory> *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, float alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + float beta, DeviceMemory *c, + int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, float alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + float beta, DeviceMemory *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, double alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + double beta, DeviceMemory *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, + std::complex alpha, + const DeviceMemory> &a, + int lda, + const DeviceMemory> &b, + int ldb, std::complex beta, + DeviceMemory> *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, + std::complex alpha, + const DeviceMemory> &a, + int lda, + const DeviceMemory> &b, + int ldb, std::complex beta, + DeviceMemory> *c, + int ldc); Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index 2584c92f0c5a1129e2f10aa7148161a8d2d40c50..9c989b971dcee6dd99aa155cd2230ba849d204fe 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -174,6 +174,15 @@ class StreamExecutorInterface { virtual void *AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset, uint64 size) = 0; virtual void Deallocate(DeviceMemoryBase *mem) = 0; + // Allocates unified memory space of the given size, if supported. + // See + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd + // for more details on unified memory. + virtual void *UnifiedMemoryAllocate(uint64 size) { return nullptr; } + + // Deallocates unified memory space previously allocated with + // UnifiedMemoryAllocate. + virtual void UnifiedMemoryDeallocate(void *mem) {} virtual void *HostMemoryAllocate(uint64 size) = 0; virtual void HostMemoryDeallocate(void *mem) = 0; virtual bool HostMemoryRegister(void *mem, uint64 size) = 0; diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index eecd5bfe1f7e7f51e4b45a579d3ac816d3e24b96..000795ff0048dddb0eb4a08956e6de6f5e336f28 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -464,6 +464,20 @@ bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem, return implementation_->GetSymbol(symbol_name, mem, bytes); } +void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) { + void *buffer = implementation_->UnifiedMemoryAllocate(bytes); + VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes + << ") returns " << buffer << StackTraceIfVLOG10(); + return buffer; +} + +void StreamExecutor::UnifiedMemoryDeallocate(void *location) { + VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location=" + << location << ")" << StackTraceIfVLOG10(); + + return implementation_->UnifiedMemoryDeallocate(location); +} + void *StreamExecutor::HostMemoryAllocate(uint64 size) { void *buffer = implementation_->HostMemoryAllocate(size); VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size @@ -596,7 +610,7 @@ port::Status StreamExecutor::SynchronousMemcpyD2H( port::Status StreamExecutor::SynchronousMemcpyH2D( const void *host_src, int64 size, DeviceMemoryBase *device_dst) { VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src - << ", size=" << size << ", device_dst" << device_dst->opaque() << ")" + << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")" << StackTraceIfVLOG10(); port::Status result; diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index e426cf99315a8671d71143cd8813e60de029a59c..ad80a1ba259ce0c6e2785373cc986b8bf34f6460 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -190,6 +190,16 @@ class StreamExecutor { // activated. void GetMemAllocs(std::map *records_out); + // Allocates unified memory space of the given size, if supported. + // See + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd + // for more details on unified memory. + void *UnifiedMemoryAllocate(uint64 bytes); + + // Deallocates unified memory space previously allocated with + // UnifiedMemoryAllocate. + void UnifiedMemoryDeallocate(void *location); + // Allocates a region of host memory and registers it with the platform API. // Memory allocated in this manner (or allocated and registered with // HostMemoryRegister() is required for use in asynchronous memcpy operations, diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d71fd71bbd83add63d11bcd62ae7ecaa2a8be8d1..e4632c48112d40fb96b4c2b510da93678b11efc4 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -148,6 +148,12 @@ def if_windows(a): "//conditions:default": [], }) +def if_not_windows_cuda(a): + return select({ + clean_dep("//tensorflow:with_cuda_support_windows_override"): [], + "//conditions:default": a, + }) + def if_linux_x86_64(a): return select({ clean_dep("//tensorflow:linux_x86_64"): a, @@ -241,6 +247,9 @@ def tf_opts_nortti_if_android(): # LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt) +def tf_features_nomodules_if_android(): + return if_android(["-use_header_modules"]) + # Given a list of "op_lib_names" (a list of files in the ops directory # without their .cc extensions), generate a library for that file. def tf_gen_op_libs(op_lib_names, deps=None, is_external=True): @@ -919,6 +928,7 @@ def tf_gpu_kernel_library(srcs, hdrs=[], **kwargs): copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts() + kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] native.cc_library( srcs=srcs, @@ -959,6 +969,7 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs): if not cuda_deps: cuda_deps = [] + kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] native.cc_library( deps=deps + if_cuda(cuda_deps + [ clean_dep("//tensorflow/core:cuda"), @@ -1301,6 +1312,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): name=basename + "_gpu", srcs=gpu_srcs, copts=_cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]), + features = if_cuda(["-use_header_modules"]), deps=deps + if_cuda(cuda_deps)) cuda_deps.extend([":" + basename + "_gpu"]) @@ -1353,12 +1365,6 @@ register_extension_info( label_regex_for_dep = "{extension_name}", ) -def tf_extension_linkopts(): - return [] # No extension link opts - -def tf_extension_copts(): - return [] # No extension c opts - # In tf_py_wrap_cc generated libraries # module init functions are not exported unless # they contain one of the keywords in the version file @@ -1459,10 +1465,10 @@ def tf_py_wrap_cc(name, tf_cc_shared_object( name=cc_library_name, srcs=[module_name + ".cc"], - copts=(copts + if_not_windows([ + copts=copts + if_not_windows([ "-Wno-self-assign", "-Wno-sign-compare", "-Wno-write-strings" - ]) + tf_extension_copts()), - linkopts=tf_extension_linkopts() + extra_linkopts, + ]), + linkopts=extra_linkopts, linkstatic=1, deps=deps + extra_deps, **kwargs) @@ -1725,7 +1731,7 @@ def tf_py_build_info_genrule(): name="py_build_info_gen", outs=["platform/build_info.py"], cmd= - "$(location //tensorflow/tools/build_info:gen_build_info.py) --raw_generate \"$@\" --build_config " + if_cuda("cuda", "cpu"), + "$(location //tensorflow/tools/build_info:gen_build_info.py) --raw_generate \"$@\" --build_config " + if_cuda("cuda", "cpu"), local=1, tools=[clean_dep("//tensorflow/tools/build_info:gen_build_info.py")],) diff --git a/tensorflow/tf_framework_version_script.lds b/tensorflow/tf_framework_version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..d4977f88c0c340fa236b746efcefd607f4752359 --- /dev/null +++ b/tensorflow/tf_framework_version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + # Hide libjpeg symbols to avoid symbol conflict with OpenCV + local: + jpeg_*; + jinit_*; + jdiv_round_up; + jround_up; + jzero_far; + jcopy_*; + jsimd_*; +}; diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index f46bb4b5fcc5d6eface8617ba5261abf29e34b02..8c760e6f52598a5e7399c9250adf99283572d3a4 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -3,136 +3,69 @@ licenses(["notice"]) # Apache 2.0 -exports_files(["LICENSE"]) +load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") +load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES") -py_binary( - name = "create_python_api", - srcs = ["create_python_api.py"], +exports_files( + [ + "LICENSE", + "create_python_api.py", + ], +) + +py_library( + name = "doc_srcs", + srcs = ["doc_srcs.py"], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ - "//tensorflow/python", + "//tensorflow/python:util", ], ) py_test( name = "create_python_api_test", - srcs = ["create_python_api_test.py"], + srcs = [ + "create_python_api.py", + "create_python_api_test.py", + ], srcs_version = "PY2AND3", deps = [ - ":create_python_api", + ":doc_srcs", "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", ], ) -genrule( - name = "python_api_gen", - # List of API files. This list should include file name for - # every module exported using tf_export. For e.g. if an op is decorated with - # @tf_export('module1.module2', 'module3'). Then, outs should include - # api/module1/module2/__init__.py and api/module3/__init__.py. - # keep sorted - outs = [ - # BEGIN GENERATED FILES - "api/__init__.py", - "api/app/__init__.py", - "api/bitwise/__init__.py", - "api/compat/__init__.py", - "api/contrib/__init__.py", - "api/contrib/stat_summarizer/__init__.py", - "api/data/__init__.py", - "api/distributions/__init__.py", - "api/distributions/bijectors/__init__.py", - "api/errors/__init__.py", - "api/estimator/__init__.py", - "api/estimator/export/__init__.py", - "api/estimator/inputs/__init__.py", - "api/feature_column/__init__.py", - "api/gfile/__init__.py", - "api/graph_util/__init__.py", - "api/image/__init__.py", - "api/initializers/__init__.py", - "api/keras/__init__.py", - "api/keras/activations/__init__.py", - "api/keras/applications/__init__.py", - "api/keras/applications/densenet/__init__.py", - "api/keras/applications/inception_resnet_v2/__init__.py", - "api/keras/applications/inception_v3/__init__.py", - "api/keras/applications/mobilenet/__init__.py", - "api/keras/applications/nasnet/__init__.py", - "api/keras/applications/resnet50/__init__.py", - "api/keras/applications/vgg16/__init__.py", - "api/keras/applications/vgg19/__init__.py", - "api/keras/applications/xception/__init__.py", - "api/keras/backend/__init__.py", - "api/keras/callbacks/__init__.py", - "api/keras/constraints/__init__.py", - "api/keras/datasets/__init__.py", - "api/keras/datasets/boston_housing/__init__.py", - "api/keras/datasets/cifar10/__init__.py", - "api/keras/datasets/cifar100/__init__.py", - "api/keras/datasets/fashion_mnist/__init__.py", - "api/keras/datasets/imdb/__init__.py", - "api/keras/datasets/mnist/__init__.py", - "api/keras/datasets/reuters/__init__.py", - "api/keras/estimator/__init__.py", - "api/keras/initializers/__init__.py", - "api/keras/layers/__init__.py", - "api/keras/losses/__init__.py", - "api/keras/metrics/__init__.py", - "api/keras/models/__init__.py", - "api/keras/optimizers/__init__.py", - "api/keras/preprocessing/__init__.py", - "api/keras/preprocessing/image/__init__.py", - "api/keras/preprocessing/sequence/__init__.py", - "api/keras/preprocessing/text/__init__.py", - "api/keras/regularizers/__init__.py", - "api/keras/utils/__init__.py", - "api/keras/wrappers/__init__.py", - "api/keras/wrappers/scikit_learn/__init__.py", - "api/layers/__init__.py", - "api/linalg/__init__.py", - "api/logging/__init__.py", - "api/losses/__init__.py", - "api/manip/__init__.py", - "api/math/__init__.py", - "api/metrics/__init__.py", - "api/nn/__init__.py", - "api/nn/rnn_cell/__init__.py", - "api/profiler/__init__.py", - "api/python_io/__init__.py", - "api/resource_loader/__init__.py", - "api/strings/__init__.py", - "api/saved_model/__init__.py", - "api/saved_model/builder/__init__.py", - "api/saved_model/constants/__init__.py", - "api/saved_model/loader/__init__.py", - "api/saved_model/main_op/__init__.py", - "api/saved_model/signature_constants/__init__.py", - "api/saved_model/signature_def_utils/__init__.py", - "api/saved_model/tag_constants/__init__.py", - "api/saved_model/utils/__init__.py", - "api/sets/__init__.py", - "api/sparse/__init__.py", - "api/spectral/__init__.py", - "api/summary/__init__.py", - "api/sysconfig/__init__.py", - "api/test/__init__.py", - "api/train/__init__.py", - "api/train/queue_runner/__init__.py", - "api/user_ops/__init__.py", - # END GENERATED FILES +py_test( + name = "tensorflow_doc_srcs_test", + srcs = ["doc_srcs_test.py"], + args = [ + "--package=tensorflow.python", + "--api_name=tensorflow", + ] + TENSORFLOW_API_INIT_FILES, + main = "doc_srcs_test.py", + srcs_version = "PY2AND3", + deps = [ + ":doc_srcs", + "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", ], - cmd = "$(location create_python_api) $(OUTS)", - tools = ["create_python_api"], ) -py_library( - name = "python_api", - srcs = [":python_api_gen"], +py_test( + name = "estimator_doc_srcs_test", + srcs = ["doc_srcs_test.py"], + args = [ + "--package=tensorflow.python.estimator", + "--api_name=estimator", + ] + ESTIMATOR_API_INIT_FILES, + main = "doc_srcs_test.py", srcs_version = "PY2AND3", - visibility = ["//tensorflow:__subpackages__"], deps = [ - "//tensorflow/contrib:contrib_py", # keep - "//tensorflow/python", # keep + ":doc_srcs", + "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", + "//tensorflow/python/estimator:estimator_py", ], ) diff --git a/tensorflow/tools/api/generator/api_gen.bzl b/tensorflow/tools/api/generator/api_gen.bzl new file mode 100644 index 0000000000000000000000000000000000000000..d746b5d3e4f7745d78563eac65ccdf822511a7ef --- /dev/null +++ b/tensorflow/tools/api/generator/api_gen.bzl @@ -0,0 +1,161 @@ +"""Targets for generating TensorFlow Python API __init__.py files.""" + +# keep sorted +TENSORFLOW_API_INIT_FILES = [ + # BEGIN GENERATED FILES + "__init__.py", + "app/__init__.py", + "bitwise/__init__.py", + "compat/__init__.py", + "data/__init__.py", + "debugging/__init__.py", + "distributions/__init__.py", + "distributions/bijectors/__init__.py", + "dtypes/__init__.py", + "errors/__init__.py", + "feature_column/__init__.py", + "gfile/__init__.py", + "graph_util/__init__.py", + "image/__init__.py", + "io/__init__.py", + "initializers/__init__.py", + "keras/__init__.py", + "keras/activations/__init__.py", + "keras/applications/__init__.py", + "keras/applications/densenet/__init__.py", + "keras/applications/inception_resnet_v2/__init__.py", + "keras/applications/inception_v3/__init__.py", + "keras/applications/mobilenet/__init__.py", + "keras/applications/nasnet/__init__.py", + "keras/applications/resnet50/__init__.py", + "keras/applications/vgg16/__init__.py", + "keras/applications/vgg19/__init__.py", + "keras/applications/xception/__init__.py", + "keras/backend/__init__.py", + "keras/callbacks/__init__.py", + "keras/constraints/__init__.py", + "keras/datasets/__init__.py", + "keras/datasets/boston_housing/__init__.py", + "keras/datasets/cifar10/__init__.py", + "keras/datasets/cifar100/__init__.py", + "keras/datasets/fashion_mnist/__init__.py", + "keras/datasets/imdb/__init__.py", + "keras/datasets/mnist/__init__.py", + "keras/datasets/reuters/__init__.py", + "keras/estimator/__init__.py", + "keras/initializers/__init__.py", + "keras/layers/__init__.py", + "keras/losses/__init__.py", + "keras/metrics/__init__.py", + "keras/models/__init__.py", + "keras/optimizers/__init__.py", + "keras/preprocessing/__init__.py", + "keras/preprocessing/image/__init__.py", + "keras/preprocessing/sequence/__init__.py", + "keras/preprocessing/text/__init__.py", + "keras/regularizers/__init__.py", + "keras/utils/__init__.py", + "keras/wrappers/__init__.py", + "keras/wrappers/scikit_learn/__init__.py", + "layers/__init__.py", + "linalg/__init__.py", + "logging/__init__.py", + "losses/__init__.py", + "manip/__init__.py", + "math/__init__.py", + "metrics/__init__.py", + "nn/__init__.py", + "nn/rnn_cell/__init__.py", + "profiler/__init__.py", + "python_io/__init__.py", + "quantization/__init__.py", + "resource_loader/__init__.py", + "strings/__init__.py", + "saved_model/__init__.py", + "saved_model/builder/__init__.py", + "saved_model/constants/__init__.py", + "saved_model/loader/__init__.py", + "saved_model/main_op/__init__.py", + "saved_model/signature_constants/__init__.py", + "saved_model/signature_def_utils/__init__.py", + "saved_model/tag_constants/__init__.py", + "saved_model/utils/__init__.py", + "sets/__init__.py", + "sparse/__init__.py", + "spectral/__init__.py", + "summary/__init__.py", + "sysconfig/__init__.py", + "test/__init__.py", + "train/__init__.py", + "train/queue_runner/__init__.py", + "user_ops/__init__.py", + # END GENERATED FILES +] + +# keep sorted +ESTIMATOR_API_INIT_FILES = [ + # BEGIN GENERATED ESTIMATOR FILES + "__init__.py", + "estimator/__init__.py", + "estimator/export/__init__.py", + "estimator/inputs/__init__.py", + # END GENERATED ESTIMATOR FILES +] + +# Creates a genrule that generates a directory structure with __init__.py +# files that import all exported modules (i.e. modules with tf_export +# decorators). +# +# Args: +# name: name of genrule to create. +# output_files: List of __init__.py files that should be generated. +# This list should include file name for every module exported using +# tf_export. For e.g. if an op is decorated with +# @tf_export('module1.module2', 'module3'). Then, output_files should +# include module1/module2/__init__.py and module3/__init__.py. +# root_init_template: Python init file that should be used as template for +# root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this +# template will be replaced with root imports collected by this genrule. +# srcs: genrule sources. If passing root_init_template, the template file +# must be included in sources. +# api_name: Name of the project that you want to generate API files for +# (e.g. "tensorflow" or "estimator"). +# package: Python package containing the @tf_export decorators you want to +# process +# package_dep: Python library target containing your package. + +def gen_api_init_files( + name, + output_files = TENSORFLOW_API_INIT_FILES, + root_init_template = None, + srcs = [], + api_name = "tensorflow", + package = "tensorflow.python", + package_dep = "//tensorflow/python:no_contrib"): + root_init_template_flag = "" + if root_init_template: + root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" + + api_gen_binary_target = "create_" + package + "_api" + native.py_binary( + name = "create_" + package + "_api", + srcs = ["//tensorflow/tools/api/generator:create_python_api.py"], + main = "//tensorflow/tools/api/generator:create_python_api.py", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + package_dep, + "//tensorflow/tools/api/generator:doc_srcs", + ], + ) + + native.genrule( + name = name, + outs = output_files, + cmd = ( + "$(location :" + api_gen_binary_target + ") " + + root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"), + srcs = srcs, + tools = [":" + api_gen_binary_target ], + visibility = ["//tensorflow:__pkg__"], + ) diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py index 18182090dabab1f0552001e1388e4f74e3514f1a..48d7dcd09eb38f53031afde70fe2e1a9b660ad1a 100644 --- a/tensorflow/tools/api/generator/create_python_api.py +++ b/tensorflow/tools/api/generator/create_python_api.py @@ -25,17 +25,21 @@ import os import sys from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_export +from tensorflow.tools.api.generator import doc_srcs +API_ATTRS = tf_export.API_ATTRS -_API_CONSTANTS_ATTR = '_tf_api_constants' -_API_NAMES_ATTR = '_tf_api_names' -_API_DIR = '/api/' _DEFAULT_PACKAGE = 'tensorflow.python' -_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api' -_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API. - -This file is MACHINE GENERATED! Do not edit. -Generated by: tensorflow/tools/api/generator/create_python_api.py script. +_GENFILES_DIR_SUFFIX = 'genfiles/' +_SYMBOLS_TO_SKIP_EXPLICITLY = { + # Overrides __getattr__, so that unwrapping tf_decorator + # would have side effects. + 'tensorflow.python.platform.flags.FLAGS' +} +_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit. +# Generated by: tensorflow/tools/api/generator/create_python_api.py script. +\"\"\"%s \"\"\" from __future__ import print_function @@ -147,20 +151,21 @@ class _ModuleInitCodeBuilder(object): # the script outputs. module_text_map[''] = module_text_map.get('', '') + ''' _names_with_underscore = [%s] -__all__ = [s for s in dir() if not s.startswith('_')] -__all__.extend([s for s in _names_with_underscore]) +__all__ = [_s for _s in dir() if not _s.startswith('_')] +__all__.extend([_s for _s in _names_with_underscore]) __all__.remove('print_function') ''' % underscore_names_str return module_text_map -def get_api_init_text(package): +def get_api_init_text(package, api_name): """Get a map from destination module to __init__.py code for that module. Args: package: Base python package containing python with target tf_export decorators. + api_name: API you want to generate (e.g. `tensorflow` or `estimator`). Returns: A dictionary where @@ -175,17 +180,20 @@ def get_api_init_text(package): for module in list(sys.modules.values()): # Only look at tensorflow modules. if (not module or not hasattr(module, '__name__') or - package not in module.__name__): + module.__name__ is None or package not in module.__name__): continue # Do not generate __init__.py files for contrib modules for now. if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'): continue for module_contents_name in dir(module): + if (module.__name__ + '.' + module_contents_name + in _SYMBOLS_TO_SKIP_EXPLICITLY): + continue attr = getattr(module, module_contents_name) # If attr is _tf_api_constants attribute, then add the constants. - if module_contents_name == _API_CONSTANTS_ATTR: + if module_contents_name == API_ATTRS[api_name].constants: for exports, value in attr: for export in exports: names = export.split('.') @@ -197,8 +205,9 @@ def get_api_init_text(package): _, attr = tf_decorator.unwrap(attr) # If attr is a symbol with _tf_api_names attribute, then # add import for it. - if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__: - for export in attr._tf_api_names: # pylint: disable=protected-access + if (hasattr(attr, '__dict__') and + API_ATTRS[api_name].names in attr.__dict__): + for export in getattr(attr, API_ATTRS[api_name].names): # pylint: disable=protected-access names = export.split('.') dest_module = '.'.join(names[:-1]) module_code_builder.add_import( @@ -209,6 +218,7 @@ def get_api_init_text(package): # For e.g. if we import 'foo.bar.Value'. Then, we also # import 'bar' in 'foo'. imported_modules = set(module_code_builder.module_imports.keys()) + import_from = '.' for module in imported_modules: if not module: continue @@ -216,11 +226,9 @@ def get_api_init_text(package): parent_module = '' # we import submodules in their parent_module for submodule_index in range(len(module_split)): - import_from = _OUTPUT_MODULE if submodule_index > 0: parent_module += ('.' + module_split[submodule_index-1] if parent_module else module_split[submodule_index-1]) - import_from += '.' + parent_module module_code_builder.add_import( -1, parent_module, import_from, module_split[submodule_index], module_split[submodule_index]) @@ -228,7 +236,65 @@ def get_api_init_text(package): return module_code_builder.build() -def create_api_files(output_files, package): +def get_module(dir_path, relative_to_dir): + """Get module that corresponds to path relative to relative_to_dir. + + Args: + dir_path: Path to directory. + relative_to_dir: Get module relative to this directory. + + Returns: + Name of module that corresponds to the given directory. + """ + dir_path = dir_path[len(relative_to_dir):] + # Convert path separators to '/' for easier parsing below. + dir_path = dir_path.replace(os.sep, '/') + return dir_path.replace('/', '.').strip('.') + + +def get_module_docstring(module_name, package, api_name): + """Get docstring for the given module. + + This method looks for docstring in the following order: + 1. Checks if module has a docstring specified in doc_srcs. + 2. Checks if module has a docstring source module specified + in doc_srcs. If it does, gets docstring from that module. + 3. Checks if module with module_name exists under base package. + If it does, gets docstring from that module. + 4. Returns a default docstring. + + Args: + module_name: module name relative to tensorflow + (excluding 'tensorflow.' prefix) to get a docstring for. + package: Base python package containing python with target tf_export + decorators. + api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + + Returns: + One-line docstring to describe the module. + """ + # Module under base package to get a docstring from. + docstring_module_name = module_name + + doc_sources = doc_srcs.get_doc_sources(api_name) + + if module_name in doc_sources: + docsrc = doc_sources[module_name] + if docsrc.docstring: + return docsrc.docstring + if docsrc.docstring_module_name: + docstring_module_name = docsrc.docstring_module_name + + docstring_module_name = package + '.' + docstring_module_name + if (docstring_module_name in sys.modules and + sys.modules[docstring_module_name].__doc__): + return sys.modules[docstring_module_name].__doc__ + + return 'Public API for tf.%s namespace.' % module_name + + +def create_api_files( + output_files, package, root_init_template, output_dir, api_name): """Creates __init__.py files for the Python API. Args: @@ -236,6 +302,11 @@ def create_api_files(output_files, package): Each file must be under api/ directory. package: Base python package containing python with target tf_export decorators. + root_init_template: Template for top-level __init__.py file. + "#API IMPORTS PLACEHOLDER" comment in the template file will be replaced + with imports. + output_dir: output API root directory. + api_name: API you want to generate (e.g. `tensorflow` or `estimator`). Raises: ValueError: if an output file is not under api/ directory, @@ -243,18 +314,7 @@ def create_api_files(output_files, package): """ module_name_to_file_path = {} for output_file in output_files: - # Convert path separators to '/' for easier parsing below. - normalized_output_file = output_file.replace(os.sep, '/') - if _API_DIR not in output_file: - raise ValueError( - 'Output files must be in api/ directory, found %s.' % output_file) - # Get the module name that corresponds to output_file. - # First get module directory under _API_DIR. - module_dir = os.path.dirname( - normalized_output_file[ - normalized_output_file.rfind(_API_DIR)+len(_API_DIR):]) - # Convert / to . - module_name = module_dir.replace('/', '.').strip('.') + module_name = get_module(os.path.dirname(output_file), output_dir) module_name_to_file_path[module_name] = os.path.normpath(output_file) # Create file for each expected output in genrule. @@ -263,25 +323,36 @@ def create_api_files(output_files, package): os.makedirs(os.path.dirname(file_path)) open(file_path, 'a').close() - module_text_map = get_api_init_text(package) + module_text_map = get_api_init_text(package, api_name) # Add imports to output files. missing_output_files = [] for module, text in module_text_map.items(): # Make sure genrule output file list is in sync with API exports. if module not in module_name_to_file_path: - module_file_path = '"api/%s/__init__.py"' % ( + module_file_path = '"%s/__init__.py"' % ( module.replace('.', '/')) missing_output_files.append(module_file_path) continue + contents = '' + if module or not root_init_template: + contents = ( + _GENERATED_FILE_HEADER % + get_module_docstring(module, package, api_name) + + text + _GENERATED_FILE_FOOTER) + else: + # Read base init file + with open(root_init_template, 'r') as root_init_template_file: + contents = root_init_template_file.read() + contents = contents.replace('# API IMPORTS PLACEHOLDER', text) with open(module_name_to_file_path[module], 'w') as fp: - fp.write(_GENERATED_FILE_HEADER + text + _GENERATED_FILE_FOOTER) + fp.write(contents) if missing_output_files: raise ValueError( 'Missing outputs for python_api_gen genrule:\n%s.' 'Make sure all required outputs are in the ' - 'tensorflow/tools/api/generator/BUILD file.' % + 'tensorflow/tools/api/generator/api_gen.bzl file.' % ',\n'.join(sorted(missing_output_files))) @@ -297,6 +368,20 @@ def main(): '--package', default=_DEFAULT_PACKAGE, type=str, help='Base package that imports modules containing the target tf_export ' 'decorators.') + parser.add_argument( + '--root_init_template', default='', type=str, + help='Template for top level __init__.py file. ' + '"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.') + parser.add_argument( + '--apidir', type=str, required=True, + help='Directory where generated output files are placed. ' + 'gendir should be a prefix of apidir. Also, apidir ' + 'should be a prefix of every directory in outputs.') + parser.add_argument( + '--apiname', required=True, type=str, + choices=API_ATTRS.keys(), + help='The API you want to generate.') + args = parser.parse_args() if len(args.outputs) == 1: @@ -309,7 +394,8 @@ def main(): # Populate `sys.modules` with modules containing tf_export(). importlib.import_module(args.package) - create_api_files(outputs, args.package) + create_api_files(outputs, args.package, args.root_init_template, + args.apidir, args.apiname) if __name__ == '__main__': diff --git a/tensorflow/tools/api/generator/create_python_api_test.py b/tensorflow/tools/api/generator/create_python_api_test.py index 986340cf6d4a1bb18841d781dcd11c0208279ec8..651ec9d040302a4343ae6e0053cf6a4b37a971d4 100644 --- a/tensorflow/tools/api/generator/create_python_api_test.py +++ b/tensorflow/tools/api/generator/create_python_api_test.py @@ -57,7 +57,8 @@ class CreatePythonApiTest(test.TestCase): def testFunctionImportIsAdded(self): imports = create_python_api.get_api_init_text( - package=create_python_api._DEFAULT_PACKAGE) + package=create_python_api._DEFAULT_PACKAGE, + api_name='tensorflow') expected_import = ( 'from tensorflow.python.test_module ' 'import test_op as test_op1') @@ -73,7 +74,8 @@ class CreatePythonApiTest(test.TestCase): def testClassImportIsAdded(self): imports = create_python_api.get_api_init_text( - package=create_python_api._DEFAULT_PACKAGE) + package=create_python_api._DEFAULT_PACKAGE, + api_name='tensorflow') expected_import = ('from tensorflow.python.test_module ' 'import TestClass') self.assertTrue( @@ -82,7 +84,8 @@ class CreatePythonApiTest(test.TestCase): def testConstantIsAdded(self): imports = create_python_api.get_api_init_text( - package=create_python_api._DEFAULT_PACKAGE) + package=create_python_api._DEFAULT_PACKAGE, + api_name='tensorflow') expected = ('from tensorflow.python.test_module ' 'import _TEST_CONSTANT') self.assertTrue(expected in str(imports), diff --git a/tensorflow/tools/api/generator/doc_srcs.py b/tensorflow/tools/api/generator/doc_srcs.py new file mode 100644 index 0000000000000000000000000000000000000000..ad1988494dae4a9d3ee96af5af76f02c52c0dff4 --- /dev/null +++ b/tensorflow/tools/api/generator/doc_srcs.py @@ -0,0 +1,92 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Specifies sources of doc strings for API modules.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.util import tf_export + + +# Specifies docstring source for a module. +# Only one of docstring or docstring_module_name should be set. +# * If docstring is set, then we will use this docstring when +# for the module. +# * If docstring_module_name is set, then we will copy the docstring +# from docstring source module. +DocSource = collections.namedtuple( + 'DocSource', ['docstring', 'docstring_module_name']) +# Each attribute of DocSource is optional. +DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields) + +_TENSORFLOW_DOC_SOURCES = { + 'app': DocSource(docstring_module_name='platform.app'), + 'compat': DocSource(docstring_module_name='util.compat'), + 'distributions': DocSource( + docstring_module_name='ops.distributions.distributions'), + 'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'), + 'errors': DocSource(docstring_module_name='framework.errors'), + 'gfile': DocSource(docstring_module_name='platform.gfile'), + 'graph_util': DocSource(docstring_module_name='framework.graph_util'), + 'image': DocSource(docstring_module_name='ops.image_ops'), + 'keras.estimator': DocSource(docstring_module_name='keras.estimator'), + 'linalg': DocSource(docstring_module_name='ops.linalg_ops'), + 'logging': DocSource(docstring_module_name='ops.logging_ops'), + 'losses': DocSource(docstring_module_name='ops.losses.losses'), + 'manip': DocSource(docstring_module_name='ops.manip_ops'), + 'math': DocSource(docstring_module_name='ops.math_ops'), + 'metrics': DocSource(docstring_module_name='ops.metrics'), + 'nn': DocSource(docstring_module_name='ops.nn_ops'), + 'nn.rnn_cell': DocSource(docstring_module_name='ops.rnn_cell'), + 'python_io': DocSource(docstring_module_name='lib.io.python_io'), + 'resource_loader': DocSource( + docstring_module_name='platform.resource_loader'), + 'sets': DocSource(docstring_module_name='ops.sets'), + 'sparse': DocSource(docstring_module_name='ops.sparse_ops'), + 'spectral': DocSource(docstring_module_name='ops.spectral_ops'), + 'strings': DocSource(docstring_module_name='ops.string_ops'), + 'sysconfig': DocSource(docstring_module_name='platform.sysconfig'), + 'test': DocSource(docstring_module_name='platform.test'), + 'train': DocSource(docstring_module_name='training.training'), + 'train.queue_runner': DocSource( + docstring_module_name='training.queue_runner'), +} + +_ESTIMATOR_DOC_SOURCES = { + 'estimator': DocSource( + docstring_module_name='estimator_lib'), + 'estimator.export': DocSource( + docstring_module_name='export.export_lib'), + 'estimator.inputs': DocSource( + docstring_module_name='inputs.inputs'), +} + + +def get_doc_sources(api_name): + """Get a map from module to a DocSource object. + + Args: + api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + + Returns: + Map from module name to DocSource object. + """ + if api_name == tf_export.TENSORFLOW_API_NAME: + return _TENSORFLOW_DOC_SOURCES + if api_name == tf_export.ESTIMATOR_API_NAME: + return _ESTIMATOR_DOC_SOURCES + return {} diff --git a/tensorflow/tools/api/generator/doc_srcs_test.py b/tensorflow/tools/api/generator/doc_srcs_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dbff904abe6251ad180140c4c7c404f051b17d55 --- /dev/null +++ b/tensorflow/tools/api/generator/doc_srcs_test.py @@ -0,0 +1,83 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for tensorflow.tools.api.generator.doc_srcs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import importlib +import sys + +from tensorflow.python.platform import test +from tensorflow.tools.api.generator import doc_srcs + + +FLAGS = None + + +class DocSrcsTest(test.TestCase): + + def testModulesAreValidAPIModules(self): + for module_name in doc_srcs.get_doc_sources(FLAGS.api_name): + # Convert module_name to corresponding __init__.py file path. + file_path = module_name.replace('.', '/') + if file_path: + file_path += '/' + file_path += '__init__.py' + + self.assertIn( + file_path, FLAGS.outputs, + msg='%s is not a valid API module' % module_name) + + def testHaveDocstringOrDocstringModule(self): + for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items(): + self.assertFalse( + docsrc.docstring and docsrc.docstring_module_name, + msg=('%s contains DocSource has both a docstring and a ' + 'docstring_module_name. Only one of "docstring" or ' + '"docstring_module_name" should be set.') % (module_name)) + + def testDocstringModulesAreValidModules(self): + for _, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items(): + if docsrc.docstring_module_name: + doc_module_name = '.'.join([ + FLAGS.package, docsrc.docstring_module_name]) + self.assertIn( + doc_module_name, sys.modules, + msg=('docsources_module %s is not a valid module under %s.' % + (docsrc.docstring_module_name, FLAGS.package))) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + 'outputs', metavar='O', type=str, nargs='+', + help='create_python_api output files.') + parser.add_argument( + '--package', type=str, + help='Base package that imports modules containing the target tf_export ' + 'decorators.') + parser.add_argument( + '--api_name', type=str, + help='API name: tensorflow or estimator') + FLAGS, unparsed = parser.parse_known_args() + + importlib.import_module(FLAGS.package) + + # Now update argv, so that unittest library does not get confused. + sys.argv = [sys.argv[0]] + unparsed + test.main() diff --git a/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt index 0fb1aaba2831e63cea9b9a38954b361e5cabd072..f1dffd595285098afaeb0ff04e5db35d594f7fac 100644 --- a/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt @@ -1,108 +1,70 @@ path: "tensorflow.AttrValue.ListValue" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "B_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FUNC_FIELD_NUMBER" - mtype: "" - } - member { - name: "F_FIELD_NUMBER" - mtype: "" - } - member { - name: "I_FIELD_NUMBER" - mtype: "" - } - member { - name: "SHAPE_FIELD_NUMBER" - mtype: "" - } - member { - name: "S_FIELD_NUMBER" - mtype: "" - } - member { - name: "TENSOR_FIELD_NUMBER" - mtype: "" - } - member { - name: "TYPE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "ListValue" + field { + name: "s" + number: 2 + label: LABEL_REPEATED + type: TYPE_BYTES + } + field { + name: "i" + number: 3 + label: LABEL_REPEATED + type: TYPE_INT64 + options { + packed: true + } + } + field { + name: "f" + number: 4 + label: LABEL_REPEATED + type: TYPE_FLOAT + options { + packed: true + } + } + field { + name: "b" + number: 5 + label: LABEL_REPEATED + type: TYPE_BOOL + options { + packed: true + } + } + field { + name: "type" + number: 6 + label: LABEL_REPEATED + type: TYPE_ENUM + type_name: ".tensorflow.DataType" + options { + packed: true + } + } + field { + name: "shape" + number: 7 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + field { + name: "tensor" + number: 8 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + } + field { + name: "func" + number: 9 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.NameAttrList" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt index e7a3a1f02faf104a03eecc4a45f5a54ab1a26f9a..6ccd64f428c3b87c807d0af82f67a884187f738c 100644 --- a/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt @@ -1,120 +1,151 @@ path: "tensorflow.AttrValue" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "B_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FUNC_FIELD_NUMBER" - mtype: "" - } - member { - name: "F_FIELD_NUMBER" - mtype: "" - } - member { - name: "I_FIELD_NUMBER" - mtype: "" - } - member { - name: "LIST_FIELD_NUMBER" - mtype: "" - } - member { - name: "ListValue" - mtype: "" - } - member { - name: "PLACEHOLDER_FIELD_NUMBER" - mtype: "" - } - member { - name: "SHAPE_FIELD_NUMBER" - mtype: "" - } - member { - name: "S_FIELD_NUMBER" - mtype: "" - } - member { - name: "TENSOR_FIELD_NUMBER" - mtype: "" - } - member { - name: "TYPE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "AttrValue" + field { + name: "s" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "i" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + oneof_index: 0 + } + field { + name: "f" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + oneof_index: 0 + } + field { + name: "b" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_BOOL + oneof_index: 0 + } + field { + name: "type" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.DataType" + oneof_index: 0 + } + field { + name: "shape" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + oneof_index: 0 + } + field { + name: "tensor" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + oneof_index: 0 + } + field { + name: "list" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.AttrValue.ListValue" + oneof_index: 0 + } + field { + name: "func" + number: 10 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.NameAttrList" + oneof_index: 0 + } + field { + name: "placeholder" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_STRING + oneof_index: 0 + } + nested_type { + name: "ListValue" + field { + name: "s" + number: 2 + label: LABEL_REPEATED + type: TYPE_BYTES + } + field { + name: "i" + number: 3 + label: LABEL_REPEATED + type: TYPE_INT64 + options { + packed: true + } + } + field { + name: "f" + number: 4 + label: LABEL_REPEATED + type: TYPE_FLOAT + options { + packed: true + } + } + field { + name: "b" + number: 5 + label: LABEL_REPEATED + type: TYPE_BOOL + options { + packed: true + } + } + field { + name: "type" + number: 6 + label: LABEL_REPEATED + type: TYPE_ENUM + type_name: ".tensorflow.DataType" + options { + packed: true + } + } + field { + name: "shape" + number: 7 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + field { + name: "tensor" + number: 8 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + } + field { + name: "func" + number: 9 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.NameAttrList" + } + } + oneof_decl { + name: "value" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt index 29bb3be35cba5f261f44811c731ba4c1fc007612..d9b142682899bf5d9fd5d942437359adf8962466 100644 --- a/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt @@ -1,84 +1,21 @@ path: "tensorflow.ConfigProto.DeviceCountEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "DeviceCountEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9e09a8d48ec7a501cb25a30163b5dae84b7c8655 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.ConfigProto.Experimental" +tf_proto { + descriptor { + name: "Experimental" + field { + name: "collective_group_leader" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt index 009d64aed09ddcb47410d6ee6fb42fca42861ddd..4af4ed70ef0698e996905bcb3b2222380b8694d8 100644 --- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt @@ -1,140 +1,136 @@ path: "tensorflow.ConfigProto" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "ALLOW_SOFT_PLACEMENT_FIELD_NUMBER" - mtype: "" - } - member { - name: "CLUSTER_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "DEVICE_COUNT_FIELD_NUMBER" - mtype: "" - } - member { - name: "DEVICE_FILTERS_FIELD_NUMBER" - mtype: "" - } - member { - name: "DeviceCountEntry" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "GPU_OPTIONS_FIELD_NUMBER" - mtype: "" - } - member { - name: "GRAPH_OPTIONS_FIELD_NUMBER" - mtype: "" - } - member { - name: "INTER_OP_PARALLELISM_THREADS_FIELD_NUMBER" - mtype: "" - } - member { - name: "INTRA_OP_PARALLELISM_THREADS_FIELD_NUMBER" - mtype: "" - } - member { - name: "ISOLATE_SESSION_STATE_FIELD_NUMBER" - mtype: "" - } - member { - name: "LOG_DEVICE_PLACEMENT_FIELD_NUMBER" - mtype: "" - } - member { - name: "OPERATION_TIMEOUT_IN_MS_FIELD_NUMBER" - mtype: "" - } - member { - name: "PLACEMENT_PERIOD_FIELD_NUMBER" - mtype: "" - } - member { - name: "RPC_OPTIONS_FIELD_NUMBER" - mtype: "" - } - member { - name: "SESSION_INTER_OP_THREAD_POOL_FIELD_NUMBER" - mtype: "" - } - member { - name: "USE_PER_SESSION_THREADS_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "ConfigProto" + field { + name: "device_count" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.ConfigProto.DeviceCountEntry" + } + field { + name: "intra_op_parallelism_threads" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "inter_op_parallelism_threads" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "use_per_session_threads" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "session_inter_op_thread_pool" + number: 12 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.ThreadPoolOptionProto" + } + field { + name: "placement_period" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "device_filters" + number: 4 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "gpu_options" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GPUOptions" + } + field { + name: "allow_soft_placement" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "log_device_placement" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "graph_options" + number: 10 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphOptions" + } + field { + name: "operation_timeout_in_ms" + number: 11 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "rpc_options" + number: 13 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.RPCOptions" + } + field { + name: "cluster_def" + number: 14 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ClusterDef" + } + field { + name: "isolate_session_state" + number: 15 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "experimental" + number: 16 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ConfigProto.Experimental" + } + nested_type { + name: "DeviceCountEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + options { + map_entry: true + } + } + nested_type { + name: "Experimental" + field { + name: "collective_group_leader" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-event.pbtxt b/tensorflow/tools/api/golden/tensorflow.-event.pbtxt index 9bf8c124288854abc847a59db2c68b29759bfc7a..3b75a1735be76fe77689736e492c42c54ab795c1 100644 --- a/tensorflow/tools/api/golden/tensorflow.-event.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-event.pbtxt @@ -1,112 +1,74 @@ path: "tensorflow.Event" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FILE_VERSION_FIELD_NUMBER" - mtype: "" - } - member { - name: "GRAPH_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "LOG_MESSAGE_FIELD_NUMBER" - mtype: "" - } - member { - name: "META_GRAPH_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "SESSION_LOG_FIELD_NUMBER" - mtype: "" - } - member { - name: "STEP_FIELD_NUMBER" - mtype: "" - } - member { - name: "SUMMARY_FIELD_NUMBER" - mtype: "" - } - member { - name: "TAGGED_RUN_METADATA_FIELD_NUMBER" - mtype: "" - } - member { - name: "WALL_TIME_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Event" + field { + name: "wall_time" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "step" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "file_version" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + oneof_index: 0 + } + field { + name: "graph_def" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "summary" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary" + oneof_index: 0 + } + field { + name: "log_message" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.LogMessage" + oneof_index: 0 + } + field { + name: "session_log" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SessionLog" + oneof_index: 0 + } + field { + name: "tagged_run_metadata" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TaggedRunMetadata" + oneof_index: 0 + } + field { + name: "meta_graph_def" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + oneof_decl { + name: "what" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt index 875d802a9c458e299f73c130bb2b37c5d8828aad..353e63127de174a79c209a05327da2de20bf0dd7 100644 --- a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt @@ -1,116 +1,92 @@ path: "tensorflow.GPUOptions" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "ALLOCATOR_TYPE_FIELD_NUMBER" - mtype: "" - } - member { - name: "ALLOW_GROWTH_FIELD_NUMBER" - mtype: "" - } - member { - name: "DEFERRED_DELETION_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "EXPERIMENTAL_FIELD_NUMBER" - mtype: "" - } - member { - name: "Experimental" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FORCE_GPU_COMPATIBLE_FIELD_NUMBER" - mtype: "" - } - member { - name: "PER_PROCESS_GPU_MEMORY_FRACTION_FIELD_NUMBER" - mtype: "" - } - member { - name: "POLLING_ACTIVE_DELAY_USECS_FIELD_NUMBER" - mtype: "" - } - member { - name: "POLLING_INACTIVE_DELAY_MSECS_FIELD_NUMBER" - mtype: "" - } - member { - name: "VISIBLE_DEVICE_LIST_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "GPUOptions" + field { + name: "per_process_gpu_memory_fraction" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "allow_growth" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "allocator_type" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "deferred_deletion_bytes" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "visible_device_list" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "polling_active_delay_usecs" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "polling_inactive_delay_msecs" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "force_gpu_compatible" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "experimental" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GPUOptions.Experimental" + } + nested_type { + name: "Experimental" + field { + name: "virtual_devices" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.GPUOptions.Experimental.VirtualDevices" + } + field { + name: "use_unified_memory" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "num_dev_to_dev_copy_streams" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + nested_type { + name: "VirtualDevices" + field { + name: "memory_limit_mb" + number: 1 + label: LABEL_REPEATED + type: TYPE_FLOAT + } + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt index 7405202b892bba67a36d86cd43fb7a67ab3be947..cbf655498c02a6521ef45f722f30acd7c13de9cc 100644 --- a/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt @@ -10,6 +10,14 @@ tf_class { name: "gradient" argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "reset" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop_recording" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "watch" argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt index 1495e847cb08ed39ee5e365744ab1d798c3eed41..19eccff03d24719d95ea84ccdad4014aa777ccd5 100644 --- a/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt @@ -1,92 +1,36 @@ path: "tensorflow.GraphDef" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "LIBRARY_FIELD_NUMBER" - mtype: "" - } - member { - name: "NODE_FIELD_NUMBER" - mtype: "" - } - member { - name: "VERSIONS_FIELD_NUMBER" - mtype: "" - } - member { - name: "VERSION_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "GraphDef" + field { + name: "node" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.NodeDef" + } + field { + name: "versions" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.VersionDef" + } + field { + name: "version" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + options { + deprecated: true + } + } + field { + name: "library" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.FunctionDefLibrary" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt index 0844f891cad3d4ea798dec82d318e2bc53c53683..a9f99bc171cc3661031981f467f583b122e43476 100644 --- a/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt @@ -1,112 +1,67 @@ path: "tensorflow.GraphOptions" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "BUILD_COST_MODEL_AFTER_FIELD_NUMBER" - mtype: "" - } - member { - name: "BUILD_COST_MODEL_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "ENABLE_BFLOAT16_SENDRECV_FIELD_NUMBER" - mtype: "" - } - member { - name: "ENABLE_RECV_SCHEDULING_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "INFER_SHAPES_FIELD_NUMBER" - mtype: "" - } - member { - name: "OPTIMIZER_OPTIONS_FIELD_NUMBER" - mtype: "" - } - member { - name: "PLACE_PRUNED_GRAPH_FIELD_NUMBER" - mtype: "" - } - member { - name: "REWRITE_OPTIONS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TIMELINE_STEP_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "GraphOptions" + field { + name: "enable_recv_scheduling" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "optimizer_options" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.OptimizerOptions" + } + field { + name: "build_cost_model" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "build_cost_model_after" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "infer_shapes" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "place_pruned_graph" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "enable_bfloat16_sendrecv" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "timeline_step" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "rewrite_options" + number: 10 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.RewriterConfig" + } + reserved_range { + start: 1 + end: 2 + } + reserved_name: "skip_common_subexpression_elimination" } } diff --git a/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt index 2567d2fe60293833b340d790ac1110f91d018107..d4402f330b8a28eaa61eb2b74c9ca412dce06b62 100644 --- a/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt @@ -1,104 +1,54 @@ path: "tensorflow.HistogramProto" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "BUCKET_FIELD_NUMBER" - mtype: "" - } - member { - name: "BUCKET_LIMIT_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "MAX_FIELD_NUMBER" - mtype: "" - } - member { - name: "MIN_FIELD_NUMBER" - mtype: "" - } - member { - name: "NUM_FIELD_NUMBER" - mtype: "" - } - member { - name: "SUM_FIELD_NUMBER" - mtype: "" - } - member { - name: "SUM_SQUARES_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "HistogramProto" + field { + name: "min" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "max" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "num" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "sum" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "sum_squares" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "bucket_limit" + number: 6 + label: LABEL_REPEATED + type: TYPE_DOUBLE + options { + packed: true + } + } + field { + name: "bucket" + number: 7 + label: LABEL_REPEATED + type: TYPE_DOUBLE + options { + packed: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt b/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt index a43c5eb7e30c3c2b025e750de5786ef4338e4ffc..5023aa96bf3b4f3f550421db5f41872d9f62b70d 100644 --- a/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt @@ -1,112 +1,46 @@ path: "tensorflow.LogMessage" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DEBUGGING" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "ERROR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FATAL" - mtype: "" - } - member { - name: "INFO" - mtype: "" - } - member { - name: "LEVEL_FIELD_NUMBER" - mtype: "" - } - member { - name: "Level" - mtype: "" - } - member { - name: "MESSAGE_FIELD_NUMBER" - mtype: "" - } - member { - name: "UNKNOWN" - mtype: "" - } - member { - name: "WARN" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "LogMessage" + field { + name: "level" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.LogMessage.Level" + } + field { + name: "message" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + enum_type { + name: "Level" + value { + name: "UNKNOWN" + number: 0 + } + value { + name: "DEBUGGING" + number: 10 + } + value { + name: "INFO" + number: 20 + } + value { + name: "WARN" + number: 30 + } + value { + name: "ERROR" + number: 40 + } + value { + name: "FATAL" + number: 50 + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt index 3572126fbfd77dbefc2ecb0246bd732f0c9aec63..0ba09bec4b3fa6e9eaf59978beaa958ebc038b4c 100644 --- a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt @@ -1,84 +1,22 @@ path: "tensorflow.MetaGraphDef.CollectionDefEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "CollectionDefEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.CollectionDef" + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt index b0e983115499c5b5b79459affc931600ad16256b..41c62a407b8577288016f2376c35ba6ec1c3c1ca 100644 --- a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt @@ -1,104 +1,50 @@ path: "tensorflow.MetaGraphDef.MetaInfoDef" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "ANY_INFO_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "META_GRAPH_VERSION_FIELD_NUMBER" - mtype: "" - } - member { - name: "STRIPPED_DEFAULT_ATTRS_FIELD_NUMBER" - mtype: "" - } - member { - name: "STRIPPED_OP_LIST_FIELD_NUMBER" - mtype: "" - } - member { - name: "TAGS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TENSORFLOW_GIT_VERSION_FIELD_NUMBER" - mtype: "" - } - member { - name: "TENSORFLOW_VERSION_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "MetaInfoDef" + field { + name: "meta_graph_version" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "stripped_op_list" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.OpList" + } + field { + name: "any_info" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".google.protobuf.Any" + } + field { + name: "tags" + number: 4 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "tensorflow_version" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tensorflow_git_version" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "stripped_default_attrs" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt index 48fccac99d60b5035e207dc3ddf10054c70c5e61..73dc414a779ded3d1f896e743b7f1f1a443352f0 100644 --- a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt @@ -1,84 +1,22 @@ path: "tensorflow.MetaGraphDef.SignatureDefEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "SignatureDefEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SignatureDef" + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt index 3e683a8715992357c1a2e744cb3cef510ce966ae..d71c2358c93e9597726665fdf8f92e648b2ea772 100644 --- a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt @@ -1,112 +1,133 @@ path: "tensorflow.MetaGraphDef" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "ASSET_FILE_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "COLLECTION_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "CollectionDefEntry" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "GRAPH_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "META_INFO_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "MetaInfoDef" - mtype: "" - } - member { - name: "SAVER_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "SIGNATURE_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "SignatureDefEntry" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "MetaGraphDef" + field { + name: "meta_info_def" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.MetaGraphDef.MetaInfoDef" + } + field { + name: "graph_def" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphDef" + } + field { + name: "saver_def" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SaverDef" + } + field { + name: "collection_def" + number: 4 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.MetaGraphDef.CollectionDefEntry" + } + field { + name: "signature_def" + number: 5 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.MetaGraphDef.SignatureDefEntry" + } + field { + name: "asset_file_def" + number: 6 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.AssetFileDef" + } + nested_type { + name: "MetaInfoDef" + field { + name: "meta_graph_version" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "stripped_op_list" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.OpList" + } + field { + name: "any_info" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".google.protobuf.Any" + } + field { + name: "tags" + number: 4 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "tensorflow_version" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tensorflow_git_version" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "stripped_default_attrs" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + } + nested_type { + name: "CollectionDefEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.CollectionDef" + } + options { + map_entry: true + } + } + nested_type { + name: "SignatureDefEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SignatureDef" + } + options { + map_entry: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt index 2750bd780caa418f933ada2073c5e8d0475c2a33..b119b208772199e5c3596be142f3e0f62d3ed50e 100644 --- a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt @@ -1,84 +1,22 @@ path: "tensorflow.NameAttrList.AttrEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "AttrEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.AttrValue" + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt index d10faf67d027a4dc8c7a32ec31ea22773104508a..fcdb411ffce9b68ac28696f86ca11a47f9e64e8f 100644 --- a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt @@ -1,88 +1,38 @@ path: "tensorflow.NameAttrList" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "ATTR_FIELD_NUMBER" - mtype: "" - } - member { - name: "AttrEntry" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "NAME_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "NameAttrList" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "attr" + number: 2 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.NameAttrList.AttrEntry" + } + nested_type { + name: "AttrEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.AttrValue" + } + options { + map_entry: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt index b1b62d60f1e8c95a5e8cc13bc8162cf1de087195..622e4c3d0f60ce4842a6fd4cc421551aa795fcbf 100644 --- a/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt @@ -1,84 +1,22 @@ path: "tensorflow.NodeDef.AttrEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "AttrEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.AttrValue" + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt index b812b4df2b3c15af3c2c81944a82d9878865b8fb..646fa8abb9b22dbd908ff821cbe66a33ad02ba64 100644 --- a/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt @@ -1,100 +1,56 @@ path: "tensorflow.NodeDef" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "ATTR_FIELD_NUMBER" - mtype: "" - } - member { - name: "AttrEntry" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "DEVICE_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "INPUT_FIELD_NUMBER" - mtype: "" - } - member { - name: "NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "OP_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "NodeDef" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "op" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "input" + number: 3 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "device" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "attr" + number: 5 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.NodeDef.AttrEntry" + } + nested_type { + name: "AttrEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.AttrValue" + } + options { + map_entry: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt index 6cac5c4d99fd7537b8fa852013ab348344be3f7e..3ccf9d459b133b48e5456f02e4780ade8d3042c8 100644 --- a/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt @@ -1,132 +1,74 @@ path: "tensorflow.OptimizerOptions" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DEFAULT" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "DO_COMMON_SUBEXPRESSION_ELIMINATION_FIELD_NUMBER" - mtype: "" - } - member { - name: "DO_CONSTANT_FOLDING_FIELD_NUMBER" - mtype: "" - } - member { - name: "DO_FUNCTION_INLINING_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "GLOBAL_JIT_LEVEL_FIELD_NUMBER" - mtype: "" - } - member { - name: "GlobalJitLevel" - mtype: "" - } - member { - name: "L0" - mtype: "" - } - member { - name: "L1" - mtype: "" - } - member { - name: "Level" - mtype: "" - } - member { - name: "MAX_FOLDED_CONSTANT_IN_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "OFF" - mtype: "" - } - member { - name: "ON_1" - mtype: "" - } - member { - name: "ON_2" - mtype: "" - } - member { - name: "OPT_LEVEL_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "OptimizerOptions" + field { + name: "do_common_subexpression_elimination" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "do_constant_folding" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "max_folded_constant_in_bytes" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "do_function_inlining" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "opt_level" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.OptimizerOptions.Level" + } + field { + name: "global_jit_level" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.OptimizerOptions.GlobalJitLevel" + } + enum_type { + name: "Level" + value { + name: "L1" + number: 0 + } + value { + name: "L0" + number: -1 + } + } + enum_type { + name: "GlobalJitLevel" + value { + name: "DEFAULT" + number: 0 + } + value { + name: "OFF" + number: -1 + } + value { + name: "ON_1" + number: 1 + } + value { + name: "ON_2" + number: 2 + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt index 808fa0fa217a407b2c86459b32fcef46b96afa5c..1287940326c0196e76fff2cf6363622226092504 100644 --- a/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt @@ -1,88 +1,27 @@ path: "tensorflow.RunMetadata" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "COST_GRAPH_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "PARTITION_GRAPHS_FIELD_NUMBER" - mtype: "" - } - member { - name: "STEP_STATS_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "RunMetadata" + field { + name: "step_stats" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.StepStats" + } + field { + name: "cost_graph" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.CostGraphDef" + } + field { + name: "partition_graphs" + number: 3 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphDef" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..537e73aa8969905c108a59688cfd99793ce211f0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.RunOptions.Experimental" +tf_proto { + descriptor { + name: "Experimental" + field { + name: "collective_graph_key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt index 2f3e7f1a847dd3609f06b1af535be6f5968edfaf..cec04a2bf0962455495340da001214914cc8bb36 100644 --- a/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt @@ -1,120 +1,83 @@ path: "tensorflow.RunOptions" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DEBUG_OPTIONS_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FULL_TRACE" - mtype: "" - } - member { - name: "HARDWARE_TRACE" - mtype: "" - } - member { - name: "INTER_OP_THREAD_POOL_FIELD_NUMBER" - mtype: "" - } - member { - name: "NO_TRACE" - mtype: "" - } - member { - name: "OUTPUT_PARTITION_GRAPHS_FIELD_NUMBER" - mtype: "" - } - member { - name: "REPORT_TENSOR_ALLOCATIONS_UPON_OOM_FIELD_NUMBER" - mtype: "" - } - member { - name: "SOFTWARE_TRACE" - mtype: "" - } - member { - name: "TIMEOUT_IN_MS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TRACE_LEVEL_FIELD_NUMBER" - mtype: "" - } - member { - name: "TraceLevel" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "RunOptions" + field { + name: "trace_level" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.RunOptions.TraceLevel" + } + field { + name: "timeout_in_ms" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "inter_op_thread_pool" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "output_partition_graphs" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "debug_options" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.DebugOptions" + } + field { + name: "report_tensor_allocations_upon_oom" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "experimental" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.RunOptions.Experimental" + } + nested_type { + name: "Experimental" + field { + name: "collective_graph_key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + } + enum_type { + name: "TraceLevel" + value { + name: "NO_TRACE" + number: 0 + } + value { + name: "SOFTWARE_TRACE" + number: 1 + } + value { + name: "HARDWARE_TRACE" + number: 2 + } + value { + name: "FULL_TRACE" + number: 3 + } + } + reserved_range { + start: 4 + end: 5 + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt b/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt index ec66d7f3354083f953066e33dff73ba9c185fc16..259f2418740cbfe47cdb4bd871d4f5c6306d25f5 100644 --- a/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt @@ -1,108 +1,44 @@ path: "tensorflow.SessionLog" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "CHECKPOINT" - mtype: "" - } - member { - name: "CHECKPOINT_PATH_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "MSG_FIELD_NUMBER" - mtype: "" - } - member { - name: "START" - mtype: "" - } - member { - name: "STATUS_FIELD_NUMBER" - mtype: "" - } - member { - name: "STATUS_UNSPECIFIED" - mtype: "" - } - member { - name: "STOP" - mtype: "" - } - member { - name: "SessionStatus" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "SessionLog" + field { + name: "status" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.SessionLog.SessionStatus" + } + field { + name: "checkpoint_path" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "msg" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + enum_type { + name: "SessionStatus" + value { + name: "STATUS_UNSPECIFIED" + number: 0 + } + value { + name: "START" + number: 1 + } + value { + name: "STOP" + number: 2 + } + value { + name: "CHECKPOINT" + number: 3 + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt index 067f02ce8cbb1a1f6e65758f37bb1d36927fad98..a66b74b315c6132e8f884bd52e7a3b5bd7f52ccd 100644 --- a/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt @@ -1,84 +1,18 @@ path: "tensorflow.SummaryMetadata.PluginData" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "CONTENT_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "PLUGIN_NAME_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "PluginData" + field { + name: "plugin_name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "content" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt index b9156521ccbee25486113a82ddec1053f8b32e3b..c02575b9626c848e9b871d2cc6febb26a5142f08 100644 --- a/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt @@ -1,92 +1,40 @@ path: "tensorflow.SummaryMetadata" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "DISPLAY_NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "PLUGIN_DATA_FIELD_NUMBER" - mtype: "" - } - member { - name: "PluginData" - mtype: "" - } - member { - name: "SUMMARY_DESCRIPTION_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "SummaryMetadata" + field { + name: "plugin_data" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SummaryMetadata.PluginData" + } + field { + name: "display_name" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "summary_description" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + nested_type { + name: "PluginData" + field { + name: "plugin_name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "content" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt index 781010d75e23c16624b193e9f1041b6d58eef34e..94f712073e0d0dda201fcf7adba849dd45a1229b 100644 --- a/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt @@ -1,96 +1,36 @@ path: "tensorflow.Summary.Audio" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "CONTENT_TYPE_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "ENCODED_AUDIO_STRING_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "LENGTH_FRAMES_FIELD_NUMBER" - mtype: "" - } - member { - name: "NUM_CHANNELS_FIELD_NUMBER" - mtype: "" - } - member { - name: "SAMPLE_RATE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Audio" + field { + name: "sample_rate" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + } + field { + name: "num_channels" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "length_frames" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "encoded_audio_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + field { + name: "content_type" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt index feb9c7ee9270a7d64cf228dffeb1187fbd225704..fc1acb483b3051cba01f5d9bc8501a61965bbc37 100644 --- a/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt @@ -1,92 +1,30 @@ path: "tensorflow.Summary.Image" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "COLORSPACE_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "ENCODED_IMAGE_STRING_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "HEIGHT_FIELD_NUMBER" - mtype: "" - } - member { - name: "WIDTH_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Image" + field { + name: "height" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "width" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "colorspace" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "encoded_image_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt index ffb4f45fc5e2a9db35b57c36a23de5bb74a9517d..feb84b6ee996549ac58aa0e8a4ac560f947b6339 100644 --- a/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt @@ -1,112 +1,74 @@ path: "tensorflow.Summary.Value" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "AUDIO_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "HISTO_FIELD_NUMBER" - mtype: "" - } - member { - name: "IMAGE_FIELD_NUMBER" - mtype: "" - } - member { - name: "METADATA_FIELD_NUMBER" - mtype: "" - } - member { - name: "NODE_NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "OBSOLETE_OLD_STYLE_HISTOGRAM_FIELD_NUMBER" - mtype: "" - } - member { - name: "SIMPLE_VALUE_FIELD_NUMBER" - mtype: "" - } - member { - name: "TAG_FIELD_NUMBER" - mtype: "" - } - member { - name: "TENSOR_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Value" + field { + name: "node_name" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tag" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "metadata" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SummaryMetadata" + } + field { + name: "simple_value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + oneof_index: 0 + } + field { + name: "obsolete_old_style_histogram" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "image" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Image" + oneof_index: 0 + } + field { + name: "histo" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.HistogramProto" + oneof_index: 0 + } + field { + name: "audio" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Audio" + oneof_index: 0 + } + field { + name: "tensor" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + oneof_index: 0 + } + oneof_decl { + name: "value" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt index 38de17fa9e52b87d19413a64271b70755e604610..b2bdff7171804aae114d1e3631e3074b1e4006ba 100644 --- a/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt @@ -1,92 +1,144 @@ path: "tensorflow.Summary" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "Audio" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "Image" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member { - name: "Value" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Summary" + field { + name: "value" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Value" + } + nested_type { + name: "Image" + field { + name: "height" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "width" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "colorspace" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "encoded_image_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + } + nested_type { + name: "Audio" + field { + name: "sample_rate" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + } + field { + name: "num_channels" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "length_frames" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "encoded_audio_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + field { + name: "content_type" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + nested_type { + name: "Value" + field { + name: "node_name" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tag" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "metadata" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SummaryMetadata" + } + field { + name: "simple_value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + oneof_index: 0 + } + field { + name: "obsolete_old_style_histogram" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "image" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Image" + oneof_index: 0 + } + field { + name: "histo" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.HistogramProto" + oneof_index: 0 + } + field { + name: "audio" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Audio" + oneof_index: 0 + } + field { + name: "tensor" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + oneof_index: 0 + } + oneof_decl { + name: "value" + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt index 425c35e0674610904c65a2876b1a184a30355682..0064c8460cb374f1e3f108085a2efed4131dd205 100644 --- a/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt @@ -1,88 +1,24 @@ path: "tensorflow.TensorInfo.CooSparse" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DENSE_SHAPE_TENSOR_NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "INDICES_TENSOR_NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUES_TENSOR_NAME_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "CooSparse" + field { + name: "values_tensor_name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "indices_tensor_name" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "dense_shape_tensor_name" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt index 41ea393be51bd77acb2136affc203d7df332064d..63566c808e55cb4d3b630f0a017fa3a2c8a30de3 100644 --- a/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt @@ -1,96 +1,59 @@ path: "tensorflow.TensorInfo" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "COO_SPARSE_FIELD_NUMBER" - mtype: "" - } - member { - name: "CooSparse" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "DTYPE_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "TENSOR_SHAPE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "TensorInfo" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + oneof_index: 0 + } + field { + name: "coo_sparse" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorInfo.CooSparse" + oneof_index: 0 + } + field { + name: "dtype" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.DataType" + } + field { + name: "tensor_shape" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + nested_type { + name: "CooSparse" + field { + name: "values_tensor_name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "indices_tensor_name" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "dense_shape_tensor_name" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + oneof_decl { + name: "encoding" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..36b534af360835e3c1cbd1f0fb12a38c42232abf --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.VariableAggregation" +tf_class { + is_instance: "" + member { + name: "MEAN" + mtype: "" + } + member { + name: "NONE" + mtype: "" + } + member { + name: "SUM" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt index 8e539069da05fbb192c383d3f5acff78ab9bfeff..ec1f72453fdb540463503a626d75d481907a3676 100644 --- a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt @@ -56,7 +56,7 @@ tf_class { } member_method { name: "get_variable" - argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " } member_method { name: "global_variables" diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7589bb28888774839a3011e1e5581f004313f81d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.VariableSynchronization" +tf_class { + is_instance: "" + member { + name: "AUTO" + mtype: "" + } + member { + name: "NONE" + mtype: "" + } + member { + name: "ON_READ" + mtype: "" + } + member { + name: "ON_WRITE" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt index 8c8912dfabb9b5ee7ce15725064f1bdf2fd35bfd..23b552cc38488bdc15d7deed20f563379dba24c3 100644 --- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt @@ -43,6 +43,10 @@ tf_class { name: "shape" mtype: "" } + member { + name: "trainable" + mtype: "" + } member_method { name: "__init__" argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt index 8e7e945ed1bc26669d7c7f0ed3c2002df9f1883b..834f0954d5bba655a8eb923672d89bac6bb80808 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt @@ -24,7 +24,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " } member_method { name: "cache" @@ -80,7 +80,7 @@ tf_class { } member_method { name: "padded_batch" - argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " } member_method { name: "prefetch" diff --git a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt index 5cfb2fd2f0c6a7b733e70445aa130e96c512205e..4d854a4ceea3907d7d795d0a19d081f4069c9ba9 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -25,7 +25,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " } member_method { name: "cache" @@ -81,7 +81,7 @@ tf_class { } member_method { name: "padded_batch" - argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " } member_method { name: "prefetch" diff --git a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt index 3327e5b274b43c0b424933cb086c894d47ad25cb..601f095a60ae481b895a535efa37341611499499 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt @@ -25,7 +25,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " } member_method { name: "cache" @@ -81,7 +81,7 @@ tf_class { } member_method { name: "padded_batch" - argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " } member_method { name: "prefetch" diff --git a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt index 9d59375282b39564456b4c8aa49435c3836c58ea..587829a4c078e8ab945f66c64f5adad21223dfb1 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt @@ -25,7 +25,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " } member_method { name: "cache" @@ -81,7 +81,7 @@ tf_class { } member_method { name: "padded_batch" - argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " } member_method { name: "prefetch" diff --git a/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt b/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d9efe97821904f5891148b72a0c31e02c9562bd7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.debugging" +tf_module { + member_method { + name: "check_numerics" + argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_finite" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_inf" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_nan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt b/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..98e1feed002ceb4f455aa5ec361d26a159fdad1a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.dtypes" +tf_module { + member_method { + name: "as_string" + argspec: "args=[\'input\', \'precision\', \'scientific\', \'shortest\', \'width\', \'fill\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'False\', \'False\', \'-1\', \'\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt index 099838fa65f6a532a594c08e8a44ead8ce008185..9dbb5d16a4e903a755c86bd0a6241180d1999f4d 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], " + argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt index 87bd19a23a3db727b5c1f13de04e3c11fd91de9b..34a30c2874b90285706c9df6bec8cbbdc3451fe4 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], " + argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt index 111914f643a3b192d496c5b0857b4429da12b1d6..0c6b7e4a821ad47c20b6f6074b575bf83c403653 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt index 67e4ee02d0581207e7dd316196aeb782930e7602..9c1c072124083006a1dd8e04526755dd980ba85a 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\', \'linear_sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\', \'sum\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt index e1289b975e721e94f4a63889f3e0b76b0db23d81..7391d4b07a7e79541091b94fe4a9f38f42d6f68a 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\', \'linear_sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\', \'sum\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt index d030b2f51f019ecc179a09b76c4484e60ada9dd0..f50e375f7cd392567f5c87536c95eb1f6809bc97 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt index cb578759eee2ed43465195a8c4e8760443a60b71..154f171e89571a43a3f905094a1dbd41cbb000d3 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\', \'sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\', \'sum\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt index fcd01bb663c7af22791c3855e6da22d93c667f84..4d46d1e6b68758bf634f9b0f82c279fdfa91a0b8 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\', \'sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\', \'sum\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index acc3fc4c5bbb767997b9844c7268b717184e4ea8..e89b4dbffdfe85f471fb1dd1b976cc701d526c64 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -20,6 +20,10 @@ tf_module { name: "adjust_hue" argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "adjust_jpeg_quality" + argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "adjust_saturation" argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -54,7 +58,7 @@ tf_module { } member_method { name: "decode_image" - argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\'], " } member_method { name: "decode_jpeg" @@ -80,6 +84,10 @@ tf_module { name: "extract_glimpse" argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], " } + member_method { + name: "extract_image_patches" + argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "extract_jpeg_shape" argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " @@ -110,7 +118,7 @@ tf_module { } member_method { name: "non_max_suppression" - argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'None\'], " + argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], " } member_method { name: "pad_to_bounding_box" @@ -144,6 +152,10 @@ tf_module { name: "random_hue" argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "random_jpeg_quality" + argspec: "args=[\'image\', \'min_jpeg_quality\', \'max_jpeg_quality\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "random_saturation" argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -164,9 +176,13 @@ tf_module { name: "resize_image_with_crop_or_pad" argspec: "args=[\'image\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "resize_image_with_pad" + argspec: "args=[\'image\', \'target_height\', \'target_width\', \'method\'], varargs=None, keywords=None, defaults=[\'0\'], " + } member_method { name: "resize_images" - argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\'], varargs=None, keywords=None, defaults=[\'0\', \'False\'], " + argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\', \'preserve_aspect_ratio\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\'], " } member_method { name: "resize_nearest_neighbor" diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt b/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt index a6b6e5eceb62654c9ad567a361f7558a2865e57a..86340913e2506c96499aae05a3ed0d5273c93bba 100644 --- a/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/tensorflow.io.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3a36c168aa703721421b662185fc852fa3d6a3ec --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.io.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.io" +tf_module { + member_method { + name: "decode_base64" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_compressed" + argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + } + member_method { + name: "decode_json_example" + argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_raw" + argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "encode_base64" + argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "matching_files" + argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "parse_tensor" + argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_file" + argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "write_file" + argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt index 2d02647eaab5928c9860433e3ac3e27326830aa0..11cdd6f0b5e48f5835385fdd4e3e5144fb7d5166 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.Model" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -127,7 +127,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt index 60b0c1000ad4f185a84fcd842530d75d3fc8314e..4afad3e4df308d412a1c18dea3b4e99aa1d2c84f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.Sequential" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt index 454823fd23e72c6aa6bf6aa608707fa3b893b986..9eee9b378964a9947b067b7ec495ef6556ab6d0c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.BaseLogger" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt index 86b264c79f63ff78133f0989b5086984a3b16dbd..5bb949c5bb650acee91b14a4d6bf95b36029edf7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.CSVLogger" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-callback.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-callback.pbtxt index 1474b392ff38c0c224725867006721096b951567..a5340d52c1af6d69da30fd710bcee9d832917574 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-callback.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-callback.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.callbacks.Callback" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt index 27d4a208a4108b107ed6a0ffbab733cb1e3d8f46..f71292856cd29b2e52194bec8a586686fbfad667 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt @@ -1,11 +1,11 @@ path: "tensorflow.keras.callbacks.EarlyStopping" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'monitor\', \'min_delta\', \'patience\', \'verbose\', \'mode\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'0\', \'0\', \'auto\'], " + argspec: "args=[\'self\', \'monitor\', \'min_delta\', \'patience\', \'verbose\', \'mode\', \'baseline\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'0\', \'0\', \'auto\', \'None\'], " } member_method { name: "on_batch_begin" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-history.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-history.pbtxt index a7b2deea8286df935db3a85e9569c3097b0b39ce..ee400b31c43829efba156298d5ee807cdafc8a98 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-history.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-history.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.History" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-lambda-callback.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-lambda-callback.pbtxt index 5ee22948ad52ed082a8790a2127bbe4afc182049..df8d7b0ef7afca17338a26388c38827b5b306f95 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-lambda-callback.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-lambda-callback.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.LambdaCallback" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt index d4c85a4519eb922629f107ef7b61c3f11cb27163..ce1a9b694d8708720e0eb677afd25607c6262e9c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.LearningRateScheduler" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-model-checkpoint.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-model-checkpoint.pbtxt index 79f9c88bbcaba136c544be1cb4b620b4ae55e17a..48bb24a05274addca03f11acef99607f78b92e51 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-model-checkpoint.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-model-checkpoint.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.ModelCheckpoint" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt index 543de0ad48b86502fc83374e5e6d82822485f331..d8bb8b2a7d0f491c7ec2b30096a1acaf04681a56 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.ProgbarLogger" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt index 805b1c350e8198f35b586e36b612731b83322ee5..dc27af9552a88650261b4f0694ea0265e6bda05c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.ReduceLROnPlateau" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt index 1d80559a5eeed339972d91d72628bdef2852973d..5a3b791c0adc0d61129d38b2995ee9077cf0988b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.RemoteMonitor" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt index 7de4008c4541b9054543927cad167293c5a4cf5c..2f52464315d8c1b526792c92f5cf8e83ce3ce087 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.TensorBoard" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt index bf17e8736c50031c484f5c08bac65ee3566f7da3..5c2d336353aee7fc98b45620adac4f4bcda05ea0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.callbacks.TerminateOnNaN" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-constraint.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-constraint.pbtxt index 14977c696fbe70a9d19f37581c926b6c0fdb3d11..8e07b7d98e1d832628f65bed19eddca76bfbd51a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-constraint.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-constraint.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.constraints.Constraint" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-max-norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-max-norm.pbtxt index a2269f8a18f5b55ffa88031e8ef3d1c39e0bd423..2b81174b6cd4d57d8d6e20da7f6961442045d908 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-max-norm.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-max-norm.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.constraints.MaxNorm" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-min-max-norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-min-max-norm.pbtxt index afe0d6478dde929aa98556d52ceece03c28c8e5f..a41eda86ac2583b1adfe745f713ac8f8647f7a31 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-min-max-norm.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-min-max-norm.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.constraints.MinMaxNorm" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-non-neg.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-non-neg.pbtxt index e8c4bb90881ae389cd5215c21e44380b62cb7c9c..572e3eea4d985999f513a066b348d088ab01fe54 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-non-neg.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-non-neg.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.constraints.NonNeg" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-unit-norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-unit-norm.pbtxt index d457cb6419ef86e83b5440554b2e97706440a734..fe16c38cc83fb9979ecf0d08ab2cba7a2c38f9b6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-unit-norm.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-unit-norm.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.constraints.UnitNorm" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.max_norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.max_norm.pbtxt index 48128096d4638388c99cc62ecc23322a8d368124..6650bae07a0d32448e748598af3426f85ca8e199 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.constraints.max_norm.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.max_norm.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.constraints.max_norm" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.min_max_norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.min_max_norm.pbtxt index 02eb3fb00c0ae516bac336066fc8ae5818e455d8..9dd3bc92fc4fadee863f30b300ddb60fe0b3d340 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.constraints.min_max_norm.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.min_max_norm.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.constraints.min_max_norm" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.non_neg.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.non_neg.pbtxt index cc1101097ce9c4888e4b239f8ae16a58cabf31db..a565840939f99080b784e4e95302071600a1fa7c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.constraints.non_neg.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.non_neg.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.constraints.non_neg" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.unit_norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.unit_norm.pbtxt index 086f9f2d43c3d340850f02df3e5bcb0cc5a5b8e5..5cbe0da4c1d1ff97fe836f76402cfca92e1cc511 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.constraints.unit_norm.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.unit_norm.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.constraints.unit_norm" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt index 32a6f6ee88815b3dc70e9cca855f73099554953b..03f4064b9ef5093044a9cbb897043d643cf7f83e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.constant.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.constant.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bddc37b907e7573c9fff27a0c3a5f7e199b88a9a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.constant.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.constant" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'False\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.identity.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.identity.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a4c5a6149047ffdaadde1243e4c80feae05cd77b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.identity.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.identity" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.normal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7485772784d40b7bf552efe9bbe8b22fadee3b86 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.ones.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.ones.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a89f78d1e1a47c7cd5a252cfd0a7b2fa23979e90 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.ones.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.ones" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.orthogonal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ee1e9bbae2b7130db5b96309e2d87719169d788a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.orthogonal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.orthogonal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt index 093c56595bd54eef4062d4ac9134e4bb3e7f7d98..14a667870d3118e48bfac03eee9accb3d48a72ce 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt @@ -40,6 +40,46 @@ tf_module { name: "Zeros" mtype: "" } + member { + name: "constant" + mtype: "" + } + member { + name: "identity" + mtype: "" + } + member { + name: "normal" + mtype: "" + } + member { + name: "ones" + mtype: "" + } + member { + name: "orthogonal" + mtype: "" + } + member { + name: "random_normal" + mtype: "" + } + member { + name: "random_uniform" + mtype: "" + } + member { + name: "truncated_normal" + mtype: "" + } + member { + name: "uniform" + mtype: "" + } + member { + name: "zeros" + mtype: "" + } member_method { name: "deserialize" argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a6df1e87a3f68fb16e32dce1ba4ee29f6d86e74e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.random_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_uniform.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..37a0fa0d5508de0026472ff1a3aa532bb8f343cd --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_uniform.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.random_uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.truncated_normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f97e93f0b72d5e959722d15fa9dc35869c550710 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.truncated_normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.truncated_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.uniform.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.uniform.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..58186b1383d8997165bb457e1cb54df86cd02d11 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.uniform.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.zeros.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.zeros.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a262390687f31a5fb79822e69273306b9e1897b5 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.zeros.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.zeros" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt index 010eaf7eb37bf23b079e424613563760eb2959e0..2bf973debb175d27bb80e627d7ccbb41b567020d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Activation" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt index 01d25110b23f9c1425be06d9e4301de2b0985259..03f20e72c2a325cec000cf4a5cfc0f1bbf255c8f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ActivityRegularization" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt index edd7ec0981aa57525ab63fbf0919c7ed4e2ee648..4b46b8d15afb0a2f636962b762e1808312c2f7c3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Add" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt index dd3c77a95e3de033b7f93b10ef6392ac7002a0bf..d8a1c76fd07634ef413152020a397897f2d5b97c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.AlphaDropout" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt index d9945d71cc20bcec1a50cb1583fced78bdad6199..622926bc4b8b2430ee1ab936665acb5744155e0d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.AveragePooling1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt index 3dfe41f4dec8c95f3c5333dd205852ffaf506dca..82100d8e09c8e95730993527293d2b72ce69f1d4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.AveragePooling2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt index 91f0cc9599e817190b4d82eac01ce49af53e9f0b..408061077cdeab2f8fd08c7e972744e5ee383f52 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.AveragePooling3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt index 11586b27bd997d877ffcc7151fc2420d03a88bf9..a3c80311043eeb95b06855f662a5e3d344803ba3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Average" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt index 3bb1bfb1381dc98c1e1862ff8e1992131b556a6f..e2dfaca29f86bd9d91d524ec337afad81e7f2da3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.AvgPool1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt index 3b36febd4495b3c78e2a98d24b1058a50995a244..4f068d2066a450bab77becc85a33662b78ad03e2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.AvgPool2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt index 21b8b0ecc95babb6c683a5f412ebfc6dddae77ec..b8c261a74364e9bb6bf8f6c7463993fbff5e9552 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.AvgPool3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt index df7c84934d90cb1251be160531aae4c74e7317f5..4ccd6cace650e2efd1583c75f6639c8598bb8f20 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.BatchNormalization" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt index 86f0a153e6c499a39d2c0a0302abd7c62bee0b16..2790e5fd850c24bd3e94cd15a6e079e1c9f79868 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Bidirectional" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt index 72e840cc5738fc1ee6fcc981a1803f1c40f52395..b1326bd0e6054b2a3fd36e7ad42cd3d4a0cad8dc 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Concatenate" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index 8b77d3f30d4c397ef4d824ba6ca18b9be0508b1e..e3ac3dbf28da731e14640d5f464547d62391a28f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.ConvLSTM2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt index 117b941336dcc7c648549700e345a4077ea8845b..1117a695a395f495d988464bbf59d4b8e01877e6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Conv1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt index aa64a99a458debb9339d7bf1cf1f413d68dc4de5..b9de1421428dcf61b988df343a22996cfb8fecef 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.Conv2DTranspose" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt index 02473004a3557dac01248c5436fde9c25412f39d..deb535e06e06008a17b80c8e13d8f01ad1535059 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Conv2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index 528c31e00294629ee9fe0c86af7b828ecab8d33c..9a9a223fbad11cafd8620110d80b27d5382dd29c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.Conv3DTranspose" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt index a16038e34ada2ed44fdc83237cfc78dd245ed765..1c59b0bdf624b09a7454f2d51698951a790f393a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Conv3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt index f8993208f39c3333dfa8e392d339fdeb616fafd6..30cf5489f4fcd4af3d0bd957fc9c576c57ee2bbd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Convolution1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt index ad373fab8217d569e69a0601cd4e443aafd4d2c5..0ec69508d5a1992b46d1a7c65255cfb5408ab439 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.Convolution2DTranspose" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt index 17f5b982d01737e72f9cc38d3c55788605b44959..4cd8928403c98abad85bc1349a29148c73003c9d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Convolution2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index b67d1320c4bdf0b0df269bcea5b9f68f5d37f700..4b4912496deac2a79a5b0ea3d1ca0f8fa625301a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.Convolution3DTranspose" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt index 3b558711d8a040aaa803e1a4c4379bc953bd86b5..d0ad9cf56702e585e31a79de0f93d9efd48ed484 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Convolution3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt index 1c03f24bebadd424c979837ddbddcb8a770beeb3..98cff95a7fe9d4e58cf883502df08c58c651cd76 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Cropping1D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt index 6649e5b9fdbcd3212f149b5082ec041f2a9e6df1..2357498b46376ef13de102944b69931a9e7d3584 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Cropping2D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt index c676e861b4cadeff76d49ef31e68b52eb1cc4676..3324cbff304c5106360f3f3d3d608a528fa5fc31 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Cropping3D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt index eef2a589fc218d3b856ff15fbe2022240d101c77..6c81823654b78a936cded4a1d5a6f54e02dc7fc9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.CuDNNGRU" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt index 5a8bc2c11440bca326d4c37010b3d62d9f8dc6a9..487e04fd0790cb39ef6aee8d0498b3aae6726084 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.CuDNNLSTM" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt index 3142724cb3eee3a828378166ca823fab959fa691..137e7cced4e8113dd6a54a837e08cfd5af35c94d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Dense" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt index 468c2d0d31752e056e7747e6036315bb672f5b0b..7161665d2550c1cc3aff1c28f9d7676276b62303 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.DepthwiseConv2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt index 753dfa7759c07826351c07913c3ec2cf4a2f0792..24affa248121bcb1e1a947417a95ad4f5ba55ab2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Dot" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt index 178bfd5f5a4aec700549f585f359d4b1b055a909..7ba19a42695da37b4ad43cdde2c0d4978fd0a1eb 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Dropout" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt index d0f223cafce074500e94ce7272efaaecac79a94f..503aa9162c3a78e9bb42ce16af98451441adbbb7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ELU" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt index 86b4ac12af1d87f202d2b1f7ff12f7e1796bec83..1737e590a29c5777b5eca2b4cb23081aa8ece738 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Embedding" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt index e73a5d1edd858d47b6e10675411b4f3424b865f7..021d024dc2150a75532ea7597d85f36efd2a3cf2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Flatten" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt index cd8a6fe4af84e3da779bfc8ec1b85cf7621ad66c..65387008bf3f78e404d8d8bbd7bb8cd3789bf256 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.GRUCell" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt index f061b9ac345c3d636761fe4a4918181ec7fbb30d..4f791acf0585c95d6c0f1d5ea48e607f9a05188d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GRU" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt index fe2f5535a9c7e144689004f8c647558b95dd9b0f..abc30e54e0630a2d7b4de6074445e155e0ac2782 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.GaussianDropout" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt index fa36090e0c64dc51e7f40593d4c2e09e73baa6b1..20791bb448d17788ea4aebe4900169a70a9703d6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.GaussianNoise" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt index e7fbd4e808c43332c14eb311bde2763665e439e3..449a91d8735c59f563360307cdb35c5a30344d82 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalAveragePooling1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt index 19ab9aecc2dc29c9c758d944e3c3f6c015a0aeac..bb361e129728ddd42c21144937efbc617d98ba30 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalAveragePooling2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt index e6f6254199dac38717fcaba031be4563326a2131..e564bf3216104a902fb6cfbe65b1e2b6dafc2524 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalAveragePooling3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt index 1390ef2fc814db7367a03fb30e7b39171375969f..4cb9cc3ec84d679b78465e43caa5a257466d5676 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalAvgPool1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt index 9091bec1b577f77dccacf88be3ade4b06911c8d2..5ed52b88ae3e2dd25b560206db404952034a04cd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalAvgPool2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt index cce7d96d82b290f5119d430d70da98d1b14d3125..f4559d29d75ef7cd8fcbdeac0a1a2c9e633246bc 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalAvgPool3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt index b91265751b1a95139b9d9eb036d5c8235c0a9361..64e2d061e26997365c461113d3ea15140fef64dd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalMaxPool1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt index 2a165a1d1f9cc67c6e7b08066c2624edd4a2015a..3372ad645388beb54f7ed9e3715449facba07f87 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalMaxPool2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt index 5d082dea963d4de20059b2eeaef9f65915171cc3..08a6860bcd7d9a260e44af87c51796a9cc2af379 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalMaxPool3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt index 308ccf02118507638af06f5eb5adf7d334624187..22c9eab64fde41e1199ecbb1b8b03939653ecd00 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalMaxPooling1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt index f22686742012cf7a79007ca6c812341702f2b74f..74c405ba9b1b465f89c4fef43020181a1a7f3d31 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalMaxPooling2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt index 8fa2c0ae37be0e4213559a2e56f2844cf53023d6..39f6f981931296eb6d31eb6580f93b479ff64ce6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.GlobalMaxPooling3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt index 6283a344cc8e098fd296254bdc98c3d4ea9954d0..7b25e80b6b7653c5e76bf176b54110b1aabaf5ea 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.InputLayer" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-spec.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-spec.pbtxt index 29edabe0483a21d7db35eec04d6ae7a855a82da1..5fd0a47a68c0d4ad218c4c64cc6be8f603d9673a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-spec.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-spec.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.layers.InputSpec" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index bc190ec807eb34c54b85feb3552afff15690a9af..3619b8bfc44373ba6b8e306b020ac63d4b498573 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.LSTMCell" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt index 69ea66f01ae61b89a6aa0c2730e54c1456f64037..8ef3d71dd82efc79e333770d4a7a7c8aee1a4202 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.LSTM" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt index fa395e5becde6352b39330bdc886183b36e015a2..ecbaa9ce2c76bf3d2964a6c79c96c4d67cc3b80e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Lambda" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt index 66260a7de779ce2c60bce849f65fb757523dc74b..9b90db1e5e56d1e5749669bba8dba1cdbd45bb55 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.layers.Layer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -105,7 +105,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt index 914ce32a70a89cc6dd2e1a2fa785d9cb195ff4d7..3c60eaab7f1df15331004685676d74943d5d538f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.LeakyReLU" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt index ede2e0583b0ad3b6628acc5c184e7d71d04c0b07..3dac1ff342ac1b7f984e9af5a6028ef71da701df 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.LocallyConnected1D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt index eed43fbdb356e1fc6eef83b4565fae68efb406e6..7f1b5db4d34f706f2107ef43ab9c5acf67dac9f6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.LocallyConnected2D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt index 15b59a9388c39d8e70d658bf65be65548d2f661c..b3e31000f3bca0821377d70b1d88a20aa8f8e4ef 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Masking" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt index b2a486e1e7cdca163197c41ca805846672124622..bbd9d1b0dc075bb9241f240b423933db20b38b75 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.MaxPool1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt index ebfc8c067ec9318333e45765d42cf2e681e3a9be..fe72beea802d12b996948b00436b274ee7e83177 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.MaxPool2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt index b28948d111a9145304b2676ce8024b13c0b0c209..e9bf57b2b0e60376a28c0abfc16fba393df3e73c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.MaxPool3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt index 2a55d2a675a707953ee0ce733daf2ee10ed618a7..0eecc58a2b6a2846a2c92502cc23bd328f8b5193 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.MaxPooling1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt index 6f80f18ed434d05682b27d0d65372b41a00ea579..96785a7d8559611a19b7f36216dbf0f8a3e39e61 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.MaxPooling2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt index 7c8c5b5c4545cedc448d864b700f7ede49df4f42..42c46cccb37b1ab7ece7760e6858b2180ea833b9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.MaxPooling3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt index ebecf555b84b2e2ae9e74f60da58e2734ae9adde..ac816f68d492cbfc5503c057a869e3e981de9190 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Maximum" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..56e32e9d3690a92c3f6e41bf2b5164c6bf62f443 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Minimum" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt index 841d81774ef8a0b753f99cb153c6be7ff02ca259..9ae99563e9a1b3b0700116ed88c13f94fafe1658 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.Multiply" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt index 5c5b51cd027524c4f9066006424f1722c9b67c1d..815f3bc2d142069adb4e418a4dc6ef82d683373f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.PReLU" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt index 72982667a95cf8fc1960bc2ebb9b8fb0d85ca952..e704992b4a18f6bdbd9474af2ee59ea81534d80a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Permute" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt index 5c273db8b431128a91f532d5f39245f3fb94423a..b3a58fa11eda61baa5c932bcc04fdca7459a215f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.RNN" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f3a96ab895dc9dbf8e2362dbcbfdccdf6af749ec --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.ReLU" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_value\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt index 1be64d5ceb8e679e59f7503e4ae54b5c459efb1a..78f464583b4e8083f4cdd1a8c6b9f377645cd562 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.RepeatVector" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt index 1be686f80007e00cdea8fd301c61a2f019daa0dc..222344fd0497afe9a32d1d05ec37aa160479d88a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Reshape" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt index 68fc2249a5ac03dd1fafcbcb882bdd6b40834883..55fddf576cac6afabe984cd51e2ddbf112a55d25 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.SeparableConv1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt index 04774b4a9a6f45ad38d09f4855d802bd38c2e09d..96314ce49849a50ccc6b968b50c98ddae74c6c70 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.SeparableConv2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt index 9bdde6f9134f502f413f0bcaaf33b691680574c6..88bdf9956603c590940e3ef857765586df7e91d7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.SeparableConvolution1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index a480454928873a5d0e88a8eab9ef5c9480391701..6eeea7a8d1312ada423206378b4c6ee079ffdd73 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.layers.SeparableConvolution2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 8bc0955c782457bb7efbb449db24ab7aac6b7a57..3050d46249003716eb0778104b729ee9cb52b34f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.SimpleRNNCell" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt index d8ee1557f110bb60cd053e8b4ffd6dfde80827ac..dda4c9358ba5faa084ad2e6cf75ff83b6a7b2b20 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.SimpleRNN" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt index 642c75396f7ba1a021087dff840de0182bcd81df..cc6275158b67e94c3c39802cc7c0f9e169c8b144 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Softmax" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt index 45a3e091790f112c6419fef58fcdb679f278f27f..5eb7e750477b17571ef861305806894dd2b9ac38 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.SpatialDropout1D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt index 3c61a8191844bf3177c09ca4ccde7a9308644d1e..500cb8c14ead3eeff28d11b72e2300cc471756d2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.SpatialDropout2D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt index 4909632c12564a75a7e74b2f2f7ef832e2aabff6..1113a7634fa98b499175d90ae7da2d3fb9fb1a13 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.SpatialDropout3D" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -107,7 +107,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index a376019d9b9407304ddee8608917cc41e927a2ee..c4b9f93561de6a5d8ecc19bbae17831466b51fe6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.StackedRNNCells" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..35ad87ad5d91f1cc5d413b0adc8e9e5d1403726a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Subtract" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt index 71d5a91475a9be4e0102eea40f6b86741b32e165..282c98d79a6e1da46e4d7ea2e5c7228754792f09 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ThresholdedReLU" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt index c5cf655d06fa9c1fcb19dee3acddec824c9069d8..acab93706b29fedc1bf7b48da2f5b6636dea48e5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.TimeDistributed" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt index 801465a032bf6519e0653ba29eef46e14c7b990c..a5ec228a074721775d4ec0369345b5439d84e186 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.UpSampling1D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt index eba83057232feca165968a07cdda8929e75f080d..d8d8e0bfe95a6cf2ef61cdb344b963df3f21aabb 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.UpSampling2D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt index a59bd305ae06cb16f25c9a57e1dc6d292a1312f3..97d6dc06fb2e883b20540e4496efa5b39a538263 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.UpSampling3D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt index 3ebb240898b2500bac9dc823cefe262601a2b85a..ea9bb41b9979de9049397892372f37aafc719a68 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Wrapper" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt index 193e354b4c1efa354886c5a680ea4eebeb03b4a2..e6d1d2e089b01c4eb212d01c456f6fa6b850f7de 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ZeroPadding1D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt index 55e71e152cebd285a628d586eb701da1b2567114..f62017305f26519181b1ef86bdd0946d44d16b88 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ZeroPadding2D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt index 97d34a4f24011954857e1d0037b0e7c7a31a489b..07a1fde5bdc35535ca5d8443a97cb85adc54b14a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ZeroPadding3D" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -106,7 +106,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt index 709eb5be55ef180ce9836def4bef601ea4315be0..9d7e5bb8c7808689bedd8abb835e61c1f38fdb1d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt @@ -280,6 +280,10 @@ tf_module { name: "Maximum" mtype: "" } + member { + name: "Minimum" + mtype: "" + } member { name: "Multiply" mtype: "" @@ -296,6 +300,10 @@ tf_module { name: "RNN" mtype: "" } + member { + name: "ReLU" + mtype: "" + } member { name: "RepeatVector" mtype: "" @@ -348,6 +356,10 @@ tf_module { name: "StackedRNNCells" mtype: "" } + member { + name: "Subtract" + mtype: "" + } member { name: "ThresholdedReLU" mtype: "" @@ -408,8 +420,16 @@ tf_module { name: "maximum" argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None" } + member_method { + name: "minimum" + argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } member_method { name: "multiply" argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None" } + member_method { + name: "subtract" + argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt index ae5f6305b7d1bb85c1c6acd8daf5628d83814b27..eca6b915388ebff0103f7ad16f43c6be0df60b7d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt @@ -1,5 +1,25 @@ path: "tensorflow.keras.losses" tf_module { + member_method { + name: "KLD" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAPE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSLE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "binary_crossentropy" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -12,6 +32,10 @@ tf_module { name: "categorical_hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "cosine" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "cosine_proximity" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -28,6 +52,10 @@ tf_module { name: "hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "kld" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "kullback_leibler_divergence" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -36,6 +64,14 @@ tf_module { name: "logcosh" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "mae" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mape" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "mean_absolute_error" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -52,6 +88,14 @@ tf_module { name: "mean_squared_logarithmic_error" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "mse" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "msle" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "poisson" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt index 42729e4237685638d38301cece6e93383ddfffba..a97a9b57587070ec4841b627920ac91737a67997 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt @@ -1,5 +1,25 @@ path: "tensorflow.keras.metrics" tf_module { + member_method { + name: "KLD" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAPE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSLE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "binary_accuracy" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -16,6 +36,10 @@ tf_module { name: "categorical_crossentropy" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "cosine" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "cosine_proximity" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -32,10 +56,22 @@ tf_module { name: "hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "kld" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "kullback_leibler_divergence" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "mae" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mape" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "mean_absolute_error" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -52,6 +88,14 @@ tf_module { name: "mean_squared_logarithmic_error" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "mse" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "msle" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "poisson" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt index 9417f77f9dd8c2950c0d884648ac7c7b2de595c1..62aa929d32b57518abbe924c036062eb7ccd3acf 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.models.Model" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -127,7 +127,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt index e658f8594abeb6c87251f8defc4bd1b3ae376cc9..93ecbbce9b17b9ca6157e65bbabd6c36008c3992 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.keras.models.Sequential" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt index 32667cf31e4aaacf3374ca4a434f32eec5b3e07e..b9ce154bddef609e0aaf6627d6f59de551e51e3b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.optimizers.Adadelta" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt index efca59e8e427d28de36446a49ea4e1ca0bb385eb..d0dc9e37a386a26143365eb443d5ba5fce8a87d9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.optimizers.Adagrad" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt index 5546e2067ab65abce928d609b41b65bbc40246f6..06815fa99a4a474ec131c29d0cbc78bb2b9cb72d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.optimizers.Adam" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt index aaa54a106066266d0a7c19f4609e4cc7ed766d95..47b55fdb44e79e976b6de13d760a7cf175323c6c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.optimizers.Adamax" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt index 1fada7fd9c6eefbb16f1b5a042e6fea607a461a9..8c63a7dda98568b24ea1b3cda15d4c840fbfd804 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.optimizers.Nadam" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-optimizer.pbtxt index ca47e952282e0c1a9ee85d8912e479a0ed5b4e86..53d64dae932e250b9d81b2767a833de3bac8c403 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-optimizer.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.optimizers.Optimizer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt index fd3f97f35dcb18c82188c51345c2e3276a88f23f..a1e9b8cceb95e8f25ac5f414fadacf237be33cd9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.optimizers.RMSprop" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-s-g-d.pbtxt index 25adfd3f0bc89d9dbd3b2b8068e7b4ff99170909..a67fefb1bafebd62db9f6108f0fe1847b5d2e0cb 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-s-g-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-s-g-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.optimizers.SGD" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt index ec0f3d892d9d03a738d34a40afe701e788908a8e..dddace87dca85cae378618fcf4d4e6d005ca9d4a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.preprocessing.image.DirectoryIterator" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt index f5bc04e44c198e5bc60f8361dd32e4ae00250468..c1e2e94f0bea933a630655eda205b6b6daf2eb93 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.preprocessing.image.ImageDataGenerator" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt index 69488d63bf118272d9b3f62027f10ff1c2dd0eff..825d9f1d1d6a828296458b831c65eecae391e0f6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.preprocessing.image.Iterator" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt index 42196ddeee7aab144537eef250c07060923fa6a9..75924a254a6a59232b1e9c9bd01ddb7445cda5d2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.preprocessing.image.NumpyArrayIterator" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt index d9c3215b555c19bc5cf4b32b0d227a9e1b63ce1e..326b1fa4fda1c0554efd8e6ba8dc93fdef0ede0f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.preprocessing.sequence.TimeseriesGenerator" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt index ce91caa1afe081ccf05ecdd4884a3e29ea93d496..b42b12b6c060f59c30590f7cc4892a09881d08d7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.preprocessing.text.Tokenizer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-l1-l2.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-l1-l2.pbtxt index 04dcda38609c7114bdf6e2784938905fc3ef8af3..a45fb7b55e58a5679427752af22dce49203dc1cc 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-l1-l2.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-l1-l2.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.regularizers.L1L2" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-regularizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-regularizer.pbtxt index b0a125f238e58fb8b1213f52fc1fb85781ca5807..641001a646564d0a466739ee6d2bdd31a27beab7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-regularizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-regularizer.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.regularizers.Regularizer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-custom-object-scope.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-custom-object-scope.pbtxt index dda39ed221a06827601a9432f887ddc5f5ee9b01..109682046b990107915d65be3cad86ead3e22688 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.-custom-object-scope.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-custom-object-scope.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.utils.CustomObjectScope" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt index 1c5868e711beeeb072e41630f06ba7d9841defbb..939fd547d06bbd03b7e1a1db1404263ff01fd07c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.utils.GeneratorEnqueuer" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt index ce62c8bafcaec1bf2e6ab3989da68588f7c848e9..6b832051a975b61ba05874c3dda558c63aeaa055 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.utils.HDF5Matrix" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "dtype" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt index 16e1cbe650e1662f8694fd7137ad20a48a90675b..be4496e753f8bdcd76a4761f9bd1804a77380359 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.utils.Progbar" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence-enqueuer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence-enqueuer.pbtxt index 5cf2a07b0b265ba88d7942698640520d53a2f407..a9e499d1009b5a7458080db6c10a948af21c7b6c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence-enqueuer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence-enqueuer.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.utils.SequenceEnqueuer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence.pbtxt index 5b272253e3767941b10d42ef5fef9c09433e9f59..e2dc932dc86dbba49d186e1dbc4bc026a52f6ef5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.utils.Sequence" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt index 8d200f99fd14d6a7735e1a74299159d6b198cd68..67cca3af41dbf68b963fb2315b65f9f843c9a42d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.wrappers.scikit_learn.KerasClassifier" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt index 7a971346d86f4930c7bba872031e049a93445d1d..f4b9b7e277ecdb155327d83c57ec2a997c043555 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.wrappers.scikit_learn.KerasRegressor" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt index c9feadbf5f7c15510faff514128822ffa2646020..11067058d5852669e1672bf3eb8b7c680d0e5dc9 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.AveragePooling1D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -117,7 +117,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt index 8405bee18a9ddde5cddf6a86ea4d037d28ffdaae..3259e706d7f7ea4d0348c1ee586c50f5a2c82b39 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.AveragePooling2D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -117,7 +117,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt index ffe517474d570acecbde91e92557e3275b7536cf..e561f2f415018840420232a97f0ece3f3c60d0d7 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.AveragePooling3D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -117,7 +117,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt index a50b83a67a25dbbf8a07611eff3ec90881b5544f..3124a35c7852a97e79a3cfe575017484f2f5731f 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.BatchNormalization" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt index 56d77595be1c027dca8ee27e2a57a600f8ca387a..b5ec61255ace78c1fa13370727eb5f5084522f4a 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.Conv1D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt index 6ab4e0aea49487b1e5b6a87a4c320e5862ae724e..b2c89ae66f53299289508eef174b5c44a6be2606 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt @@ -1,11 +1,11 @@ path: "tensorflow.layers.Conv2DTranspose" tf_class { is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt index e4d47c7eb008263df0eecbd698a61e54d8994485..9e4f4969dc6e1b6a39cf1d25c5e5e6175fa87c7c 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.Conv2D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt index 9195b548bec1775fbbe8f1dca95085da313980ae..9850e6d7659d311c93dabad73d35f2fcd028dd52 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt @@ -1,11 +1,11 @@ path: "tensorflow.layers.Conv3DTranspose" tf_class { is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt index 4d0033fef85724c752bea0b371ac806ab01b2cb9..be113826cc2b9589e1f8bbde896fbcbe183d4d1b 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.Conv3D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt index 7017921c61e7d3b5899271467f7599c747c84d7f..0d951bf6336ac7b65be57535c1065e5f87a77a0b 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.Dense" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt index 3381b5955cd4b080ba326bc548613fa5314cc791..f1beeed9ef0cb54318249e42b1279680ea117ba8 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.Dropout" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -116,7 +116,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt index af8f55aac68010edc6a0090848ecd8f8fc8ce0b4..b75a012811ff10f055382ea1315eaba506c24ed8 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.Flatten" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -116,7 +116,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-input-spec.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-input-spec.pbtxt index 2ff89f0a6faef905bcafdcb36121f506e9a9977a..fd02c919aeb5a536bd052324618983af699e7c47 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-input-spec.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-input-spec.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.layers.InputSpec" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt index ff6c5b12010a95cebdaa0208ac23cc06987a9b15..80e0fb228b034727854ab1a4df97e25c6bc2cd97 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.layers.Layer" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -114,7 +114,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt index aaabf135cee8ddeb008442d2b4810c76847d4b2d..50ff484d733633e20e9923dbbf1344af7b51ba9a 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.MaxPooling1D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -117,7 +117,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt index 813d322a96a8c1c2e469aa35f77020a69a0ced5b..cea809744cd07cc6ed0d1655f217cb5821e503e4 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.MaxPooling2D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -117,7 +117,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt index 62c46d9fa0f696cc618bef06c62a6e2ed594c916..ab9e89554c81decf5ee7e42dc963da9ab35e65c7 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.MaxPooling3D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -117,7 +117,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt index fb7af3e8881e243994e1105dea1bc0a2c161c358..4362568445e892d6127759c925d47426d49d9927 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt @@ -1,11 +1,11 @@ path: "tensorflow.layers.SeparableConv1D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt index d3dfb84ed39aee4c88c704fa31bd0bdec473c228..3cad824cd3b197b91a749347c860ff926610c081 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt @@ -1,11 +1,11 @@ path: "tensorflow.layers.SeparableConv2D" tf_class { is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt index 00b9238543367546cff96b736f73440214e99e22..3b5845f99a474ed976b91dab4f80ac2f231e7fc1 100644 --- a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt @@ -68,6 +68,10 @@ tf_module { name: "cholesky_solve" argspec: "args=[\'chol\', \'rhs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "cross" + argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "det" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -140,6 +144,14 @@ tf_module { name: "svd" argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], " } + member_method { + name: "tensor_diag" + argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tensor_diag_part" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "tensordot" argspec: "args=[\'a\', \'b\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt index 0b84165285102daf0a8e3dd6542bfc391e50f77b..9add462396ea526ae94678e969c9acf5bce86df1 100644 --- a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt @@ -1,7 +1,35 @@ path: "tensorflow.manip" tf_module { + member_method { + name: "batch_to_space_nd" + argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "gather_nd" + argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reshape" + argspec: "args=[\'tensor\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reverse" + argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "roll" argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "scatter_nd" + argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "space_to_batch_nd" + argspec: "args=[\'input\', \'block_shape\', \'paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tile" + argspec: "args=[\'input\', \'multiples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } } diff --git a/tensorflow/tools/api/golden/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/tensorflow.math.pbtxt index 897718c05e0d10a6f961f33b8c65f5dab1d03f5b..a308c76ebc08df06c0c360579451ea70e60695d4 100644 --- a/tensorflow/tools/api/golden/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.math.pbtxt @@ -1,7 +1,239 @@ path: "tensorflow.math" tf_module { + member_method { + name: "acos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "acosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "asin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "asinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atan2" + argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i0" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i0e" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i1" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i1e" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "betainc" + argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ceil" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "digamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "erfc" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "exp" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "expm1" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "floor" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "greater" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "greater_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igammac" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "invert_permutation" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lgamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log1p" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_and" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_not" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_or" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "maximum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "minimum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "not_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "polygamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "polyval" argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "reciprocal" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rint" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rsqrt" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_mean" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "softplus" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "softsign" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "squared_difference" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "zeta" + argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } } diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt index 455590d866a4c1ebea65ccff51e34f2e0b0479d7..d9e5b0d0fca8bbcf82feb34304f2a1e4f43f48dd 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt @@ -260,6 +260,10 @@ tf_module { name: "relu_layer" argspec: "args=[\'x\', \'weights\', \'biases\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "safe_embedding_lookup_sparse" + argspec: "args=[\'embedding_weights\', \'sparse_ids\', \'sparse_weights\', \'combiner\', \'default_id\', \'name\', \'partition_strategy\', \'max_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'mean\', \'None\', \'None\', \'div\', \'None\'], " + } member_method { name: "sampled_softmax_loss" argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\', \'seed\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt index b1d335278dc2d332fb349daefd6bde4d5b87fcb5..a8d9e120cb4aa965c1d85df59de1fbabc196bf54 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt index 143247e5312f6783d35869e0ceb4040181e3c962..c039890e1f4c1d57e7b795f1f09cff71921f6554 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index 18ce1cb08cdb2b1a9b90645a3c3f2fa21e8e7c66..62c393de34475a8806015bed187572f79cf2a196 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index b4ac45947b5d603ed8dfb986ecfd5d24374c8a6f..f121ba7939acb14681aa6b04b333668dded37aad 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index 3cdfd6c7416f28eaff526a1460169159335710df..4583dc32b2e98d4a9912378fe0e3d841882772fd 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index fc7f72cb7453cdc357d275be9423c0ceee8a2a86..5016b6ac3010e2e184674db4837173c57c44b97e 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index dab10b38a52c021a06a00d66083cbbe4e897a06c..59623fc983a63c2966882aa5113423c0a9e23b72 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index 79f299312bd07896667da559337a727c02a98c7a..e2ab5aaee9456ffbe42894f2384d7bc9c7ad6a6f 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.nn.rnn_cell.RNNCell" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index a29b6e8a51d4d55e0f32f73eeba99b13043caf52..bd2a6d61f8578a2a3c8d94d3a8d5eb49679df2f7 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 3051c4437e9a14bf0ef86adfa8c596b736a6172d..bf2533e1b5d992d818faefa8e5a53aa8f553fa0e 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -260,10 +260,18 @@ tf_module { name: "Variable" mtype: "" } + member { + name: "VariableAggregation" + mtype: "" + } member { name: "VariableScope" mtype: "" } + member { + name: "VariableSynchronization" + mtype: "" + } member { name: "WholeFileReader" mtype: "" @@ -308,6 +316,10 @@ tf_module { name: "data" mtype: "" } + member { + name: "debugging" + mtype: "" + } member { name: "distributions" mtype: "" @@ -316,6 +328,10 @@ tf_module { name: "double" mtype: "" } + member { + name: "dtypes" + mtype: "" + } member { name: "errors" mtype: "" @@ -380,6 +396,10 @@ tf_module { name: "int8" mtype: "" } + member { + name: "io" + mtype: "" + } member { name: "keras" mtype: "" @@ -456,6 +476,10 @@ tf_module { name: "qint8" mtype: "" } + member { + name: "quantization" + mtype: "" + } member { name: "quint16" mtype: "" @@ -792,6 +816,10 @@ tf_module { name: "broadcast_static_shape" argspec: "args=[\'shape_x\', \'shape_y\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "broadcast_to" + argspec: "args=[\'input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "case" argspec: "args=[\'pred_fn_pairs\', \'default\', \'exclusive\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\', \'case\'], " @@ -1130,7 +1158,7 @@ tf_module { } member_method { name: "get_local_variable" - argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'synchronization\', \'aggregation\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\', \'None\'], " } member_method { name: "get_seed" @@ -1146,7 +1174,7 @@ tf_module { } member_method { name: "get_variable" - argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " } member_method { name: "get_variable_scope" @@ -1290,7 +1318,7 @@ tf_module { } member_method { name: "lbeta" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'lbeta\'], " + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "less" @@ -2170,7 +2198,7 @@ tf_module { } member_method { name: "while_loop" - argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], " } member_method { name: "write_file" diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt index bd5c36f390add9cfb31642b80a792d65d59bb3e8..e09c44cc9ce71305692740ba2d63b0940b2e0573 100644 --- a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt @@ -1,80 +1,12 @@ path: "tensorflow.profiler.AdviceProto.Checker" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "REPORTS_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Checker" + field { + name: "reports" + number: 2 + label: LABEL_REPEATED + type: TYPE_STRING + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt index 7c8c68e155c99da4f0c1c1ba2c944719c42c12c7..87462435496fd2eedeb0bc8d92e8a833671b6531 100644 --- a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt @@ -1,84 +1,22 @@ path: "tensorflow.profiler.AdviceProto.CheckersEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "CheckersEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.AdviceProto.Checker" + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt index 1b789f4fc92ed63fc72f3ecfe6be80a99eb3427f..a8a8858ccd5af3fb3dac612eef44e5cb450df914 100644 --- a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt @@ -1,88 +1,41 @@ path: "tensorflow.profiler.AdviceProto" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "CHECKERS_FIELD_NUMBER" - mtype: "" - } - member { - name: "Checker" - mtype: "" - } - member { - name: "CheckersEntry" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "AdviceProto" + field { + name: "checkers" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.AdviceProto.CheckersEntry" + } + nested_type { + name: "CheckersEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.AdviceProto.Checker" + } + options { + map_entry: true + } + } + nested_type { + name: "Checker" + field { + name: "reports" + number: 2 + label: LABEL_REPEATED + type: TYPE_STRING + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt index f0b9605bee1c7cf2f0154f65c475aac49c411f76..afec73f537aadd5d1a274db8d57e37b8c6fa3e74 100644 --- a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt @@ -1,84 +1,22 @@ path: "tensorflow.profiler.GraphNodeProto.InputShapesEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "InputShapesEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt index b80896a8a0f36a9f1c4da76528b9c4e70500ad4c..3c83177005323a277f929d8c769cd7b1eeff4d51 100644 --- a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt @@ -1,188 +1,191 @@ path: "tensorflow.profiler.GraphNodeProto" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "ACCELERATOR_EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "CHILDREN_FIELD_NUMBER" - mtype: "" - } - member { - name: "CPU_EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "DEVICES_FIELD_NUMBER" - mtype: "" - } - member { - name: "EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FLOAT_OPS_FIELD_NUMBER" - mtype: "" - } - member { - name: "INPUT_SHAPES_FIELD_NUMBER" - mtype: "" - } - member { - name: "InputShapesEntry" - mtype: "" - } - member { - name: "NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "OUTPUT_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "PARAMETERS_FIELD_NUMBER" - mtype: "" - } - member { - name: "PEAK_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "REQUESTED_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "RESIDUAL_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "RUN_COUNT_FIELD_NUMBER" - mtype: "" - } - member { - name: "SHAPES_FIELD_NUMBER" - mtype: "" - } - member { - name: "TENSOR_VALUE_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_ACCELERATOR_EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_CPU_EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_DEFINITION_COUNT_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_FLOAT_OPS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_OUTPUT_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_PARAMETERS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_PEAK_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_REQUESTED_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_RESIDUAL_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_RUN_COUNT_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "GraphNodeProto" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tensor_value" + number: 15 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.TFProfTensorProto" + } + field { + name: "run_count" + number: 21 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "exec_micros" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "accelerator_exec_micros" + number: 17 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "cpu_exec_micros" + number: 18 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "requested_bytes" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "peak_bytes" + number: 24 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "residual_bytes" + number: 25 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "output_bytes" + number: 26 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "parameters" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "float_ops" + number: 13 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "devices" + number: 10 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "total_definition_count" + number: 23 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_run_count" + number: 22 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_exec_micros" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_accelerator_exec_micros" + number: 19 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_cpu_exec_micros" + number: 20 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_requested_bytes" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_peak_bytes" + number: 27 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_residual_bytes" + number: 28 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_output_bytes" + number: 29 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_parameters" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_float_ops" + number: 14 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "shapes" + number: 11 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + field { + name: "input_shapes" + number: 16 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.GraphNodeProto.InputShapesEntry" + } + field { + name: "children" + number: 12 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.GraphNodeProto" + } + nested_type { + name: "InputShapesEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + options { + map_entry: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt index 33deff649791322c0e8361a46de6d180f23be9a9..2b08a05437f90b91160fc08e670b2466ae163149 100644 --- a/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt @@ -1,160 +1,134 @@ path: "tensorflow.profiler.MultiGraphNodeProto" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "ACCELERATOR_EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "CHILDREN_FIELD_NUMBER" - mtype: "" - } - member { - name: "CPU_EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FLOAT_OPS_FIELD_NUMBER" - mtype: "" - } - member { - name: "GRAPH_NODES_FIELD_NUMBER" - mtype: "" - } - member { - name: "NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "OUTPUT_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "PARAMETERS_FIELD_NUMBER" - mtype: "" - } - member { - name: "PEAK_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "REQUESTED_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "RESIDUAL_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_ACCELERATOR_EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_CPU_EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_EXEC_MICROS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_FLOAT_OPS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_OUTPUT_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_PARAMETERS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_PEAK_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_REQUESTED_BYTES_FIELD_NUMBER" - mtype: "" - } - member { - name: "TOTAL_RESIDUAL_BYTES_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "MultiGraphNodeProto" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "exec_micros" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "accelerator_exec_micros" + number: 12 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "cpu_exec_micros" + number: 13 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "requested_bytes" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "peak_bytes" + number: 16 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "residual_bytes" + number: 17 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "output_bytes" + number: 18 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "parameters" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "float_ops" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_exec_micros" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_accelerator_exec_micros" + number: 14 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_cpu_exec_micros" + number: 15 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_requested_bytes" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_peak_bytes" + number: 19 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_residual_bytes" + number: 20 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_output_bytes" + number: 21 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_parameters" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_float_ops" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "graph_nodes" + number: 10 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.GraphNodeProto" + } + field { + name: "children" + number: 11 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.MultiGraphNodeProto" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt index 8c4727cf35bdfcdaf3c45e636f6627b3c85102d6..b3adc50c7e14152a81a148df9deccc5272189aad 100644 --- a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt @@ -1,84 +1,21 @@ path: "tensorflow.profiler.OpLogProto.IdToStringEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "IdToStringEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt index 1071a82b5ce1396e235b37af1991908dbd4ca7a4..7510c566ba574e9370f5e54c29023ef4fb5ee804 100644 --- a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt @@ -1,88 +1,38 @@ path: "tensorflow.profiler.OpLogProto" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "ID_TO_STRING_FIELD_NUMBER" - mtype: "" - } - member { - name: "IdToStringEntry" - mtype: "" - } - member { - name: "LOG_ENTRIES_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "OpLogProto" + field { + name: "log_entries" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.OpLogEntry" + } + field { + name: "id_to_string" + number: 2 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.OpLogProto.IdToStringEntry" + } + nested_type { + name: "IdToStringEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + options { + map_entry: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6d865efed0bfdada8dde64e86ddb5d2b2b364c79 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.quantization" +tf_module { + member_method { + name: "dequantize" + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "quantized_concat" + argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt index ca8e5884b18110d4293225e595c030e9629b5663..83bd7035409534abf036c7e2b0d66fcc060ada3a 100644 --- a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt @@ -8,11 +8,11 @@ tf_class { } member_method { name: "add_meta_graph" - argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], " + argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], " } member_method { name: "add_meta_graph_and_variables" - argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], " + argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], " } member_method { name: "save" diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt index 896e2160c693039ab5582be13286f387c08d8f37..511e6b4712d3c55746a39fe9098fa3b649bc75dc 100644 --- a/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.saved_model.loader" tf_module { member_method { name: "load" - argspec: "args=[\'sess\', \'tags\', \'export_dir\'], varargs=None, keywords=saver_kwargs, defaults=None" + argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], " } member_method { name: "maybe_saved_model_directory" diff --git a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt index 4f306540ccfdeac8ce59a394ec77b24284f13ceb..6a421ef12d58dc047905ec916cbe777b4ce19b9a 100644 --- a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt @@ -16,6 +16,10 @@ tf_module { name: "fft3d" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "idct" + argspec: "args=[\'input\', \'type\', \'n\', \'axis\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'2\', \'None\', \'-1\', \'None\', \'None\'], " + } member_method { name: "ifft" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt index a3fbe95bbad4b8c1d803e1002b2cf9ef2812fed0..9a831fed2692b30db6ce991c86f46a42908c0789 100644 --- a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt @@ -1,7 +1,43 @@ path: "tensorflow.strings" tf_module { + member_method { + name: "join" + argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + } member_method { name: "regex_full_match" argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "regex_replace" + argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "split" + argspec: "args=[\'source\', \'sep\', \'maxsplit\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], " + } + member_method { + name: "strip" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "substr" + argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket" + argspec: "args=[\'string_tensor\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket_fast" + argspec: "args=[\'input\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket_strong" + argspec: "args=[\'input\', \'num_buckets\', \'key\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_number" + argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } } diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt index ab3449d80f6108d83b721563427bd07d07a7104b..eb99d0f5334457aa654fed0553af143839328dba 100644 --- a/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt @@ -1,112 +1,74 @@ path: "tensorflow.summary.Event" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FILE_VERSION_FIELD_NUMBER" - mtype: "" - } - member { - name: "GRAPH_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "LOG_MESSAGE_FIELD_NUMBER" - mtype: "" - } - member { - name: "META_GRAPH_DEF_FIELD_NUMBER" - mtype: "" - } - member { - name: "SESSION_LOG_FIELD_NUMBER" - mtype: "" - } - member { - name: "STEP_FIELD_NUMBER" - mtype: "" - } - member { - name: "SUMMARY_FIELD_NUMBER" - mtype: "" - } - member { - name: "TAGGED_RUN_METADATA_FIELD_NUMBER" - mtype: "" - } - member { - name: "WALL_TIME_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Event" + field { + name: "wall_time" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "step" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "file_version" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + oneof_index: 0 + } + field { + name: "graph_def" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "summary" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary" + oneof_index: 0 + } + field { + name: "log_message" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.LogMessage" + oneof_index: 0 + } + field { + name: "session_log" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SessionLog" + oneof_index: 0 + } + field { + name: "tagged_run_metadata" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TaggedRunMetadata" + oneof_index: 0 + } + field { + name: "meta_graph_def" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + oneof_decl { + name: "what" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt index 92ca4872caf1c1ce7e19201b0a612c1a74ef59b0..73de73869c8d1a6808b16fe8853fd21cc8891879 100644 --- a/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt @@ -1,108 +1,44 @@ path: "tensorflow.summary.SessionLog" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "CHECKPOINT" - mtype: "" - } - member { - name: "CHECKPOINT_PATH_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "MSG_FIELD_NUMBER" - mtype: "" - } - member { - name: "START" - mtype: "" - } - member { - name: "STATUS_FIELD_NUMBER" - mtype: "" - } - member { - name: "STATUS_UNSPECIFIED" - mtype: "" - } - member { - name: "STOP" - mtype: "" - } - member { - name: "SessionStatus" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "SessionLog" + field { + name: "status" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.SessionLog.SessionStatus" + } + field { + name: "checkpoint_path" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "msg" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + enum_type { + name: "SessionStatus" + value { + name: "STATUS_UNSPECIFIED" + number: 0 + } + value { + name: "START" + number: 1 + } + value { + name: "STOP" + number: 2 + } + value { + name: "CHECKPOINT" + number: 3 + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt index f93da2196adbc28524f93746a8e047b5c0f610d8..4a8b59cf02ed46ef70f22564f3134214840600fe 100644 --- a/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt @@ -1,80 +1,12 @@ path: "tensorflow.summary.SummaryDescription" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "TYPE_HINT_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "SummaryDescription" + field { + name: "type_hint" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt index 605e305e82cc3f4dd6a0bce68f846a43347a00e2..8b271cf58fc11c8666abd456021afeedc0b14c7a 100644 --- a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt @@ -1,96 +1,36 @@ path: "tensorflow.summary.Summary.Audio" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "CONTENT_TYPE_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "ENCODED_AUDIO_STRING_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "LENGTH_FRAMES_FIELD_NUMBER" - mtype: "" - } - member { - name: "NUM_CHANNELS_FIELD_NUMBER" - mtype: "" - } - member { - name: "SAMPLE_RATE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Audio" + field { + name: "sample_rate" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + } + field { + name: "num_channels" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "length_frames" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "encoded_audio_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + field { + name: "content_type" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt index 0646972196dc728b3f39aad07540aa7b6893ab88..dbbc02dd0506dbcebd1690602b5786b02c3ed4a0 100644 --- a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt @@ -1,92 +1,30 @@ path: "tensorflow.summary.Summary.Image" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "COLORSPACE_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "ENCODED_IMAGE_STRING_FIELD_NUMBER" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "HEIGHT_FIELD_NUMBER" - mtype: "" - } - member { - name: "WIDTH_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Image" + field { + name: "height" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "width" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "colorspace" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "encoded_image_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt index b319cd03d9e8678b36ae107b581aea431a969b4e..4176171cd938e383fe5366153364d8e8e8c1a1ee 100644 --- a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt @@ -1,112 +1,74 @@ path: "tensorflow.summary.Summary.Value" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "AUDIO_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "HISTO_FIELD_NUMBER" - mtype: "" - } - member { - name: "IMAGE_FIELD_NUMBER" - mtype: "" - } - member { - name: "METADATA_FIELD_NUMBER" - mtype: "" - } - member { - name: "NODE_NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "OBSOLETE_OLD_STYLE_HISTOGRAM_FIELD_NUMBER" - mtype: "" - } - member { - name: "SIMPLE_VALUE_FIELD_NUMBER" - mtype: "" - } - member { - name: "TAG_FIELD_NUMBER" - mtype: "" - } - member { - name: "TENSOR_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Value" + field { + name: "node_name" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tag" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "metadata" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SummaryMetadata" + } + field { + name: "simple_value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + oneof_index: 0 + } + field { + name: "obsolete_old_style_histogram" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "image" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Image" + oneof_index: 0 + } + field { + name: "histo" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.HistogramProto" + oneof_index: 0 + } + field { + name: "audio" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Audio" + oneof_index: 0 + } + field { + name: "tensor" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + oneof_index: 0 + } + oneof_decl { + name: "value" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt index 132ef1b7d2e933c3fe953ca2eb19b32133db8186..d6c5e3a87a115b9bdcfd044abe93177eda2af275 100644 --- a/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt @@ -1,92 +1,144 @@ path: "tensorflow.summary.Summary" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "Audio" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "Image" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member { - name: "Value" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Summary" + field { + name: "value" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Value" + } + nested_type { + name: "Image" + field { + name: "height" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "width" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "colorspace" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "encoded_image_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + } + nested_type { + name: "Audio" + field { + name: "sample_rate" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + } + field { + name: "num_channels" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "length_frames" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "encoded_audio_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + field { + name: "content_type" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + nested_type { + name: "Value" + field { + name: "node_name" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tag" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "metadata" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SummaryMetadata" + } + field { + name: "simple_value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + oneof_index: 0 + } + field { + name: "obsolete_old_style_histogram" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "image" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Image" + oneof_index: 0 + } + field { + name: "histo" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.HistogramProto" + oneof_index: 0 + } + field { + name: "audio" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Audio" + oneof_index: 0 + } + field { + name: "tensor" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + oneof_index: 0 + } + oneof_decl { + name: "value" + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt index 4dce20819de06fb3a31d6b044a8c751c22da5c74..27c8873320403cb2e7402ef9f1bb0e7134d5f96b 100644 --- a/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt @@ -1,84 +1,18 @@ path: "tensorflow.summary.TaggedRunMetadata" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "RUN_METADATA_FIELD_NUMBER" - mtype: "" - } - member { - name: "TAG_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "TaggedRunMetadata" + field { + name: "tag" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "run_metadata" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt index 8cf52b817f342a3ccd8bcf5f4f532b886a318f23..87e4f160e5bd5950dfc338649fb531c92cc81b60 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt @@ -1,80 +1,12 @@ path: "tensorflow.train.BytesList" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "BytesList" + field { + name: "value" + number: 1 + label: LABEL_REPEATED + type: TYPE_BYTES + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt index ddc553d7c984b24fe33c03bb90e00e7e81f55d26..2d067e4eff13208cb03ca01b7b8a8018a1e99097 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.train.Checkpoint" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt index 93ff856b09de15f12954bb11802a935b82c1d278..f9de26839f5f6dc1591bfc909ca8e6c02271b5c7 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt @@ -1,80 +1,13 @@ path: "tensorflow.train.ClusterDef" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "JOB_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "ClusterDef" + field { + name: "job" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.JobDef" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt index f7215a20372e981a2fb20f20d9e4cfa43973c7cc..23c30f1ef4fe2dd93e8714655dbb1ef3b8e05c65 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt @@ -1,80 +1,13 @@ path: "tensorflow.train.Example" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FEATURES_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Example" + field { + name: "features" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Features" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt index 737acbe07c93da30b4a206cbdae2efcbc2cb2159..c9fe136e68b5f3cadaff6d4fd0638b7f10d18365 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt @@ -2,6 +2,10 @@ path: "tensorflow.train.ExponentialMovingAverage" tf_class { is_instance: "" is_instance: "" + member { + name: "name" + mtype: "" + } member_method { name: "__init__" argspec: "args=[\'self\', \'decay\', \'num_updates\', \'zero_debias\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'ExponentialMovingAverage\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt index 3ad98354d69453d6f66a858991d4a19e2525d1e0..2a8b3714fc0c4f5e979bc02550a8e08835d53cb4 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt @@ -1,80 +1,13 @@ path: "tensorflow.train.FeatureList" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FEATURE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "FeatureList" + field { + name: "feature" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.Feature" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt index cd171f4ca3ef1e48848be1bd71f8a56685534b8c..cd1d56e606c96b62346b936001a5a0f07a8a8ad8 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt @@ -1,84 +1,22 @@ path: "tensorflow.train.FeatureLists.FeatureListEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "FeatureListEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.FeatureList" + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt index 3d95017d584ad95f96a54ef52a966aa6f2a69a58..3c183a64769b59b104c52b6840e8f351f4b0cef5 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt @@ -1,84 +1,32 @@ path: "tensorflow.train.FeatureLists" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FEATURE_LIST_FIELD_NUMBER" - mtype: "" - } - member { - name: "FeatureListEntry" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "FeatureLists" + field { + name: "feature_list" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.FeatureLists.FeatureListEntry" + } + nested_type { + name: "FeatureListEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.FeatureList" + } + options { + map_entry: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt index 9cca132bba91c46398c2fecb4ff7b45bd5ed2af2..5d0eb871c2f4aeb13d6b8518486f11b1f80d0620 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt @@ -1,88 +1,33 @@ path: "tensorflow.train.Feature" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "BYTES_LIST_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FLOAT_LIST_FIELD_NUMBER" - mtype: "" - } - member { - name: "INT64_LIST_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Feature" + field { + name: "bytes_list" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.BytesList" + oneof_index: 0 + } + field { + name: "float_list" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.FloatList" + oneof_index: 0 + } + field { + name: "int64_list" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Int64List" + oneof_index: 0 + } + oneof_decl { + name: "kind" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt index 858aee03415dead500cdb450f5885a904f620221..f912005f1cc35f12ce6eba5313b0c67adebe70f7 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt @@ -1,84 +1,22 @@ path: "tensorflow.train.Features.FeatureEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "FeatureEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Feature" + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt index 49cd12153bf3078eb1e68cfd6efad6e2673439f4..b788ca1d57e1d679a1b809d85c6aa9bcef01f252 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt @@ -1,84 +1,32 @@ path: "tensorflow.train.Features" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FEATURE_FIELD_NUMBER" - mtype: "" - } - member { - name: "FeatureEntry" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Features" + field { + name: "feature" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.Features.FeatureEntry" + } + nested_type { + name: "FeatureEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Feature" + } + options { + map_entry: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt index e3f01334b547feef87d07166eb3784659c41d542..55d3b46f20e17ec4e6fbac5672e1b0a8ef98552d 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt @@ -1,80 +1,15 @@ path: "tensorflow.train.FloatList" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "FloatList" + field { + name: "value" + number: 1 + label: LABEL_REPEATED + type: TYPE_FLOAT + options { + packed: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt index 8917dc122cfd0b0a7de0a3a74da3c45104d9eaff..1de92b3ab7b5e0ff873a7e8092c7e6c2edcbd2ce 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt @@ -1,80 +1,15 @@ path: "tensorflow.train.Int64List" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "Int64List" + field { + name: "value" + number: 1 + label: LABEL_REPEATED + type: TYPE_INT64 + options { + packed: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt index ac6d81541a43e934ebd131afe07be0bd6e427a7b..58115590a5eebd742afac4b31b5f585e8077e049 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt @@ -1,84 +1,21 @@ path: "tensorflow.train.JobDef.TasksEntry" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "KEY_FIELD_NUMBER" - mtype: "" - } - member { - name: "VALUE_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "TasksEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + options { + map_entry: true + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt index ce34537fa13b92f7900128d769ac3161d2b4d287..d7eb505e27930d6411a589909584f237a7e8b8f5 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt @@ -1,88 +1,37 @@ path: "tensorflow.train.JobDef" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "TASKS_FIELD_NUMBER" - mtype: "" - } - member { - name: "TasksEntry" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "JobDef" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tasks" + number: 2 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.JobDef.TasksEntry" + } + nested_type { + name: "TasksEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + options { + map_entry: true + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt index 84498a64f5b04526e989ec03f1894dcea19d850e..4ec99469e4025603e7ab340b190cbebf7e33eed7 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt @@ -1,120 +1,64 @@ path: "tensorflow.train.SaverDef" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "CheckpointFormatVersion" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FILENAME_TENSOR_NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "KEEP_CHECKPOINT_EVERY_N_HOURS_FIELD_NUMBER" - mtype: "" - } - member { - name: "LEGACY" - mtype: "" - } - member { - name: "MAX_TO_KEEP_FIELD_NUMBER" - mtype: "" - } - member { - name: "RESTORE_OP_NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "SAVE_TENSOR_NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "SHARDED_FIELD_NUMBER" - mtype: "" - } - member { - name: "V1" - mtype: "" - } - member { - name: "V2" - mtype: "" - } - member { - name: "VERSION_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "SaverDef" + field { + name: "filename_tensor_name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "save_tensor_name" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "restore_op_name" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "max_to_keep" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "sharded" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "keep_checkpoint_every_n_hours" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + } + field { + name: "version" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.SaverDef.CheckpointFormatVersion" + } + enum_type { + name: "CheckpointFormatVersion" + value { + name: "LEGACY" + number: 0 + } + value { + name: "V1" + number: 1 + } + value { + name: "V2" + number: 2 + } + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt index 9ab95537021167f368d3a8f6b1e1ec1a3996aa88..6a4553bbc157960696ef17959f532fecdfd54ae8 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt @@ -1,84 +1,20 @@ path: "tensorflow.train.SequenceExample" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "CONTEXT_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "FEATURE_LISTS_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "SequenceExample" + field { + name: "context" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Features" + } + field { + name: "feature_lists" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.FeatureLists" + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt index af0a3b73cc2ff3510e9a0426c28696fe51097f9d..83ee7b3eb91a558765abcde630fe6e0480b9818f 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt @@ -1,96 +1,38 @@ path: "tensorflow.train.ServerDef" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "CLUSTER_FIELD_NUMBER" - mtype: "" - } - member { - name: "DEFAULT_SESSION_CONFIG_FIELD_NUMBER" - mtype: "" - } - member { - name: "DESCRIPTOR" - mtype: "" - } - member { - name: "Extensions" - mtype: "" - } - member { - name: "JOB_NAME_FIELD_NUMBER" - mtype: "" - } - member { - name: "PROTOCOL_FIELD_NUMBER" - mtype: "" - } - member { - name: "TASK_INDEX_FIELD_NUMBER" - mtype: "" - } - member_method { - name: "ByteSize" - } - member_method { - name: "Clear" - } - member_method { - name: "ClearExtension" - } - member_method { - name: "ClearField" - } - member_method { - name: "CopyFrom" - } - member_method { - name: "DiscardUnknownFields" - } - member_method { - name: "FindInitializationErrors" - } - member_method { - name: "FromString" - } - member_method { - name: "HasExtension" - } - member_method { - name: "HasField" - } - member_method { - name: "IsInitialized" - } - member_method { - name: "ListFields" - } - member_method { - name: "MergeFrom" - } - member_method { - name: "MergeFromString" - } - member_method { - name: "ParseFromString" - } - member_method { - name: "RegisterExtension" - } - member_method { - name: "SerializePartialToString" - } - member_method { - name: "SerializeToString" - } - member_method { - name: "SetInParent" - } - member_method { - name: "WhichOneof" - } - member_method { - name: "__init__" +tf_proto { + descriptor { + name: "ServerDef" + field { + name: "cluster" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ClusterDef" + } + field { + name: "job_name" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "task_index" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "default_session_config" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ConfigProto" + } + field { + name: "protocol" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt index cc31bb4e4b396917a00d1162125b6d2e47343322..448764fe081b250e1e22633f118268ad638cb9dd 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\'], " + argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], " } member_method { name: "prepare_session" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt index 1f0e59a1ac2d899a50ff30c7c8da8f91a0258a1e..9677e5a98e4a8308093f51a84d8b1edae405cd2b 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt @@ -104,7 +104,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\'], " + argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\', \'None\'], " } member_method { name: "loop" diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt index 9fb18e77afd7c9c989ad5e967be291406e7239aa..b0fb04d7d4d71e8cb2630ca79284e0ade1db8571 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt @@ -242,7 +242,7 @@ tf_module { } member_method { name: "MonitoredTrainingSession" - argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'\', \'\', \'\', \'None\', \'120\', \'100\', \'7200\', \'\'], " + argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'\', \'\', \'\', \'None\', \'120\', \'100\', \'7200\', \'\', \'None\'], " } member_method { name: "NewCheckpointReader" @@ -400,6 +400,10 @@ tf_module { name: "range_input_producer" argspec: "args=[\'limit\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], " } + member_method { + name: "remove_checkpoint" + argspec: "args=[\'checkpoint_prefix\', \'checkpoint_format_version\', \'meta_graph_suffix\'], varargs=None, keywords=None, defaults=[\'2\', \'meta\'], " + } member_method { name: "replica_device_setter" argspec: "args=[\'ps_tasks\', \'ps_device\', \'worker_device\', \'merge_devices\', \'cluster\', \'ps_ops\', \'ps_strategy\'], varargs=None, keywords=None, defaults=[\'0\', \'/job:ps\', \'/job:worker\', \'True\', \'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt index a58398d645e8397dc8e61a6e0241710c3e34218f..09d7bc03b4f238923db6778ec32ce78ae76eed61 100644 --- a/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/lib/api_objects.proto b/tensorflow/tools/api/lib/api_objects.proto index 0966a5f1d530ecd70c9e904c12816f0aa33b3ada..7207b9c5a9f4db7a8efcea3207adf1eb99df7d5b 100644 --- a/tensorflow/tools/api/lib/api_objects.proto +++ b/tensorflow/tools/api/lib/api_objects.proto @@ -1,5 +1,7 @@ syntax = "proto2"; +import "google/protobuf/descriptor.proto"; + package third_party.tensorflow.tools.api; message TFAPIMember { @@ -24,8 +26,17 @@ message TFAPIClass { repeated TFAPIMethod member_method = 3; }; +message TFAPIProto { + // Suppress generation of the proto API's descriptor() method lest it + // conflict with the standard accessor for the field having the same name. + option no_standard_descriptor_accessor = true; + + optional google.protobuf.DescriptorProto descriptor = 1; +}; + message TFAPIObject { optional string path = 1; optional TFAPIModule tf_module = 2; optional TFAPIClass tf_class = 3; + optional TFAPIProto tf_proto = 4; }; diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py index 0b30f7b4d12e3d78b12f8a229b19cb0e8d530d56..1cf330e70247260cd9e50b18903bdfecad6260e4 100644 --- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py +++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from google.protobuf import message from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -101,6 +102,11 @@ def _SanitizedMRO(obj): return return_list +def _IsProtoClass(obj): + """Returns whether the passed obj is a Protocol Buffer class.""" + return isinstance(obj, type) and issubclass(obj, message.Message) + + class PythonObjectToProtoVisitor(object): """A visitor that summarizes given python objects as protobufs.""" @@ -153,6 +159,13 @@ class PythonObjectToProtoVisitor(object): # Store the constructed module object. self._protos[lib_path] = api_objects_pb2.TFAPIObject( path=lib_path, tf_module=module_obj) + elif _IsProtoClass(parent): + proto_obj = api_objects_pb2.TFAPIProto() + parent.DESCRIPTOR.CopyToProto(proto_obj.descriptor) + + # Store the constructed proto object. + self._protos[lib_path] = api_objects_pb2.TFAPIObject( + path=lib_path, tf_proto=proto_obj) elif tf_inspect.isclass(parent): # Construct a class. class_obj = api_objects_pb2.TFAPIClass() @@ -161,7 +174,7 @@ class PythonObjectToProtoVisitor(object): if name in parent_corner_cases: # If we have an empty entry, skip this object. if parent_corner_cases[name]: - module_obj.member.add(**(parent_corner_cases[name])) + class_obj.member.add(**(parent_corner_cases[name])) else: _AddMember(name, child, class_obj) diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 1ad6b6d1c0ae5ca1ac1329fd49e972840020e4c3..90375a794f64a9edd2bab2671f5870ae02e84e3c 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -35,6 +35,7 @@ import unittest import tensorflow as tf +from google.protobuf import message from google.protobuf import text_format from tensorflow.python.lib.io import file_io @@ -195,6 +196,25 @@ class ApiCompatibilityTest(test.TestCase): else: logging.info('No differences found between API and golden.') + def testNoSubclassOfMessage(self): + + def Visit(path, parent, unused_children): + """A Visitor that crashes on subclasses of generated proto classes.""" + # If the traversed object is a proto Message class + if not (isinstance(parent, type) and + issubclass(parent, message.Message)): + return + if parent is message.Message: + return + # Check that it is a direct subclass of Message. + if message.Message not in parent.__bases__: + raise NotImplementedError( + 'Object tf.%s is a subclass of a generated proto Message. ' + 'They are not yet supported by the API tools.' % path) + visitor = public_api.PublicAPIVisitor(Visit) + visitor.do_not_descend_map['tf'].append('contrib') + traverse.traverse(tf, visitor) + @unittest.skipUnless( sys.version_info.major == 2, 'API compabitility test goldens are generated using python2.') diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index eeb1fab40c4c7a2fff417b18111ecfe8ceabc71a..de93b12b97081feea5be96edf3b6e6dfbe5599b4 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -667,12 +667,12 @@ int Main(int argc, char** argv) { output_prefix, benchmark_name, "meta-init-plus-first-inference", 1, initialization_time_s + (warmup_time_us / 1000000.0) / warmup_runs); - std::map node_type_map_count; - std::map node_type_map_time; - std::map node_type_map_memory; - std::map node_type_map_times_called; + std::map node_type_map_count; + std::map node_type_map_time; + std::map node_type_map_memory; + std::map node_type_map_times_called; - int64 accumulated_us; + int64_t accumulated_us; stats->ComputeStatsByType(&node_type_map_count, &node_type_map_time, &node_type_map_memory, &node_type_map_times_called, &accumulated_us); diff --git a/tensorflow/tools/ci_build/Dockerfile.cmake b/tensorflow/tools/ci_build/Dockerfile.cmake index d5dea4f3e41841aed5aeac02fcca850dbfdfaeb3..e8c319982839b7b5adc17d6fb7ac364660ac76fe 100644 --- a/tensorflow/tools/ci_build/Dockerfile.cmake +++ b/tensorflow/tools/ci_build/Dockerfile.cmake @@ -28,6 +28,8 @@ RUN pip install --upgrade astor RUN pip install --upgrade gast RUN pip install --upgrade numpy RUN pip install --upgrade termcolor +RUN pip install keras_applications==1.0.2 +RUN pip install keras_preprocessing==1.0.1 # Install golang RUN apt-get install -t xenial-backports -y golang-1.9 diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le new file mode 100644 index 0000000000000000000000000000000000000000..e879c34bbdadd7b90973fda0f7c3fdb71a385856 --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le @@ -0,0 +1,20 @@ +FROM ubuntu:16.04 + +LABEL maintainer="William Irons " + +# Copy and run the install scripts. +COPY install/*.sh /install/ +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa +RUN /install/install_deb_packages.sh +RUN apt-get update && apt-get install -y libopenblas-dev +RUN /install/install_hdf5_ppc64le.sh +RUN /install/install_pip_packages.sh +RUN /install/install_bazel_from_source.sh +RUN /install/install_proto3.sh +RUN /install/install_buildifier_from_source.sh +RUN /install/install_auditwheel.sh +RUN /install/install_golang_ppc64le.sh + +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le new file mode 100644 index 0000000000000000000000000000000000000000..89671387472a15c112a09fa2fa7a9798446d135b --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le @@ -0,0 +1,28 @@ +FROM nvidia/cuda-ppc64le:9.0-cudnn7-devel-ubuntu16.04 + +LABEL maintainer="William Irons " + +# In the Ubuntu 16.04 images, cudnn is placed in system paths. Move them to +# /usr/local/cuda +RUN cp -P /usr/include/cudnn.h /usr/local/cuda/include +RUN cp -P /usr/lib/powerpc64le-linux-gnu/libcudnn* /usr/local/cuda/lib64 + +# Copy and run the install scripts. +COPY install/*.sh /install/ +ARG DEBIAN_FRONTEND=noninteractive +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa +RUN /install/install_deb_packages.sh +RUN apt-get update && apt-get install -y libopenblas-dev +RUN /install/install_hdf5_ppc64le.sh +RUN /install/install_pip_packages.sh +RUN /install/install_bazel_from_source.sh +RUN /install/install_golang_ppc64le.sh + +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc +ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH + +# Configure the build for our CUDA configuration. +ENV TF_NEED_CUDA 1 +ENV TF_CUDA_COMPUTE_CAPABILITIES 3.0 diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cpu b/tensorflow/tools/ci_build/Dockerfile.rbe.cpu index 3bc52b9ed611a0f0a4a269a2864d5b349ee9232c..7e5860aeec186d908e5d2884bd690b2e5e43cffa 100644 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cpu +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cpu @@ -1,4 +1,4 @@ -FROM launcher.gcr.io/google/rbe-debian8:r327695 +FROM launcher.gcr.io/google/rbe-ubuntu16-04:r327695 LABEL maintainer="Yu Yi " # Copy install scripts @@ -9,6 +9,6 @@ ENV CC /usr/local/bin/clang ENV CXX /usr/local/bin/clang++ ENV AR /usr/bin/ar -# Run pip install script for RBE Debian8 container. +# Run pip install script for RBE Ubuntu 16-04 container. RUN /install/install_pip_packages_remote.sh RUN /install/install_pip_packages.sh diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh index 5fa75e1d61cceeebfa77439bb64f1c644c9dba70..883bb9364742e74b4a5a7c8b0d41253352d6c2e7 100755 --- a/tensorflow/tools/ci_build/builds/pip.sh +++ b/tensorflow/tools/ci_build/builds/pip.sh @@ -322,6 +322,10 @@ create_activate_virtualenv_and_install_tensorflow() { pip install -v ${PIP_FLAGS} ${WHL_PATH} || \ die "pip install (forcing to reinstall tensorflow) FAILED" echo "Successfully installed pip package ${TF_WHEEL_PATH}" + + # Force downgrade setuptools. + pip install --upgrade setuptools==39.1.0 + } ################################################################################ diff --git a/tensorflow/tools/ci_build/builds/test_user_ops.sh b/tensorflow/tools/ci_build/builds/test_user_ops.sh index c342367bacea9d2ba8152d928b93bf61cf60d0e7..25ecee472524d5346252772b3058a5e824eef217 100755 --- a/tensorflow/tools/ci_build/builds/test_user_ops.sh +++ b/tensorflow/tools/ci_build/builds/test_user_ops.sh @@ -239,8 +239,9 @@ function run_op() { fi } -run_op $("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; print(tf.Session('').run(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT})))") -run_op $("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; tf.enable_eager_execution(); print(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT}))") " in eager mode" +run_op "$("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; print(tf.Session('').run(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT})))")" +run_op "$("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; tf.enable_eager_execution(); print(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT}).numpy())")" " in eager mode" + popd diff --git a/tensorflow/tools/ci_build/builds/with_the_same_user b/tensorflow/tools/ci_build/builds/with_the_same_user index d4bf546d401d058bd205a70c147615c8efc4f4ba..b216e3549f8ab7850c966e5a8e138f3b566f9952 100755 --- a/tensorflow/tools/ci_build/builds/with_the_same_user +++ b/tensorflow/tools/ci_build/builds/with_the_same_user @@ -40,7 +40,7 @@ if [ -n "${CI_BUILD_USER_FORCE_BADNAME}" ]; then ADDUSER_OPTS="--force-badname" fi -getent group "${CI_BUILD_GID}" || addgroup --gid "${CI_BUILD_GID}" "${CI_BUILD_GROUP}" +getent group "${CI_BUILD_GID}" || addgroup ${ADDUSER_OPTS} --gid "${CI_BUILD_GID}" "${CI_BUILD_GROUP}" getent passwd "${CI_BUILD_UID}" || adduser ${ADDUSER_OPTS} \ --gid "${CI_BUILD_GID}" --uid "${CI_BUILD_UID}" \ --gecos "${CI_BUILD_USER} (generated by with_the_same_user script)" \ diff --git a/tensorflow/tools/ci_build/ci_build.sh b/tensorflow/tools/ci_build/ci_build.sh index 072dd6ab995bb41c3197d6c898405be487534593..f6a50d3d4c4f948e37ff841a880b373f1034fd76 100755 --- a/tensorflow/tools/ci_build/ci_build.sh +++ b/tensorflow/tools/ci_build/ci_build.sh @@ -79,7 +79,7 @@ if [[ "${CONTAINER_TYPE}" == "cmake" ]]; then fi # Use nvidia-docker if the container is GPU. -if [[ "${CONTAINER_TYPE}" == "gpu" ]]; then +if [[ "${CONTAINER_TYPE}" == gpu* ]]; then DOCKER_BINARY="nvidia-docker" else DOCKER_BINARY="docker" @@ -99,7 +99,7 @@ BUILD_TAG="${BUILD_TAG:-tf_ci}" # Add extra params for cuda devices and libraries for GPU container. # And clear them if we are not building for GPU. -if [[ "${CONTAINER_TYPE}" != "gpu" ]]; then +if [[ "${CONTAINER_TYPE}" != gpu* ]]; then GPU_EXTRA_PARAMS="" fi @@ -134,6 +134,12 @@ if [[ $? != "0" ]]; then die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}" fi +# If caller wants the with_the_same_user script to allow bad usernames, +# pass the var to the docker environment +if [ -n "${CI_BUILD_USER_FORCE_BADNAME}" ]; then + CI_BUILD_USER_FORCE_BADNAME_ENV="-e CI_BUILD_USER_FORCE_BADNAME=yes" +fi + # Run the command inside the container. echo "Running '${COMMAND[*]}' inside ${DOCKER_IMG_NAME}..." mkdir -p ${WORKSPACE}/bazel-ci_build-cache @@ -148,6 +154,7 @@ ${DOCKER_BINARY} run --rm --pid=host \ -e "CI_BUILD_GROUP=$(id -g -n)" \ -e "CI_BUILD_GID=$(id -g)" \ -e "CI_TENSORFLOW_SUBMODULE_PATH=${CI_TENSORFLOW_SUBMODULE_PATH}" \ + ${CI_BUILD_USER_FORCE_BADNAME_ENV} \ -v ${WORKSPACE}:/workspace \ -w /workspace \ ${GPU_EXTRA_PARAMS} \ diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index 797e0a6db52aa6216486bdc9c6a88ff353c57e15..08e2c3edd2d22fbb7b9912c9ce7ec561dc5a7113 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -59,6 +59,9 @@ # TF_BUILD_BAZEL_CLEAN: # Will perform "bazel clean", if and only if this variable # is set to any non-empty and non-0 value +# TF_BAZEL_BUILD_ONLY: +# If it is set to any non-empty value that is not "0", Bazel +# will only build specified targets # TF_GPU_COUNT: # Run this many parallel tests for serial builds. # For now, only can be edited for PIP builds. @@ -94,10 +97,6 @@ # # This script can be used by Jenkins parameterized / matrix builds. -# TODO(jhseu): Temporary for the gRPC pull request due to the -# protobuf -> protobuf_archive rename. Remove later. -TF_BUILD_BAZEL_CLEAN=1 - # Helper function: Convert to lower case to_lower () { echo "$1" | tr '[:upper:]' '[:lower:]' @@ -132,7 +131,7 @@ BAZEL_CMD="bazel test" BAZEL_BUILD_ONLY_CMD="bazel build" BAZEL_CLEAN_CMD="bazel clean" -DEFAULT_BAZEL_CONFIGS="--config=gcp --config=hdfs" +DEFAULT_BAZEL_CONFIGS="" PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh" PIP_TEST_TUTORIALS_FLAG="--test_tutorials" @@ -168,7 +167,6 @@ else BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:embedding_lookup_test" BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test" BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:fully_connected_test" - # BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/testing:generated_examples_zip_test" BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:hashtable_lookup_test" BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:local_response_norm_test" BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:lsh_projection_test" @@ -263,9 +261,9 @@ function set_script_variable() { # Process container type -if [[ ${CTYPE} == "cpu" ]] || [[ ${CTYPE} == "debian.jessie.cpu" ]]; then +if [[ ${CTYPE} == cpu* ]] || [[ ${CTYPE} == "debian.jessie.cpu" ]]; then : -elif [[ ${CTYPE} == "gpu" ]]; then +elif [[ ${CTYPE} == gpu* ]]; then set_script_variable TF_NEED_CUDA 1 if [[ $TF_CUDA_CLANG == "1" ]]; then @@ -415,6 +413,11 @@ fi # this flag, and it only affects a few tests. EXTRA_ARGS="${EXTRA_ARGS} --distinct_host_configuration=false" +if [[ ! -z "${TF_BAZEL_BUILD_ONLY}" ]] && + [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then + BAZEL_CMD=${BAZEL_BUILD_ONLY_CMD} +fi + # Process PIP install-test option if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || [[ ${TF_BUILD_IS_PIP} == "both" ]]; then @@ -423,12 +426,12 @@ if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || BAZEL_TARGET=${TF_BUILD_BAZEL_TARGET} fi - if [[ ${CTYPE} == "cpu" ]] || \ + if [[ ${CTYPE} == cpu* ]] || \ [[ ${CTYPE} == "debian.jessie.cpu" ]]; then # CPU only command, fully parallel. NO_PIP_MAIN_CMD="${MAIN_CMD} ${BAZEL_CMD} ${OPT_FLAG} ${EXTRA_ARGS} -- "\ "${BAZEL_TARGET}" - elif [[ ${CTYPE} == "gpu" ]]; then + elif [[ ${CTYPE} == gpu* ]]; then # GPU only command, run as many jobs as the GPU count only. NO_PIP_MAIN_CMD="${BAZEL_CMD} ${OPT_FLAG} "\ "--local_test_jobs=${TF_GPU_COUNT} "\ diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 8e8b2191e5c8f3fb6ada929cbc6b327fa0a67584..db37edf8097844646236aace5e3517a8080d70cb 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -100,9 +100,9 @@ do_pylint() { "^tensorflow/contrib/eager/python/evaluator\.py.*\[E0202.*method-hidden "\ "^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\ "^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\ -"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\ -"^tensorflow/python/keras/_impl/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\ -"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\ +"^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable "\ +"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\ +"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\ "^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned" echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\"" @@ -349,12 +349,12 @@ do_external_licenses_check(){ # Blacklist echo ${MISSING_LICENSES_FILE} - grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -v ${MISSING_LICENSES_FILE} > temp.txt + grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -e "@com_github_googlecloudplatform_google_cloud_cpp//google" -v ${MISSING_LICENSES_FILE} > temp.txt mv temp.txt ${MISSING_LICENSES_FILE} # Whitelist echo ${EXTRA_LICENSE_FILE} - grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -v ${EXTRA_LICENSES_FILE} > temp.txt + grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -v ${EXTRA_LICENSES_FILE} > temp.txt mv temp.txt ${EXTRA_LICENSES_FILE} @@ -543,7 +543,7 @@ SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "d SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency" "Check file names for cases") INCREMENTAL_FLAG="" -DEFAULT_BAZEL_CONFIGS="--config=hdfs --config=gcp" +DEFAULT_BAZEL_CONFIGS="" # Parse command-line arguments BAZEL_FLAGS=${DEFAULT_BAZEL_CONFIGS} diff --git a/tensorflow/tools/ci_build/copy_binary.py b/tensorflow/tools/ci_build/copy_binary.py index 420d390d2b9dc1ec25461b3502c63467a7eda16b..148526492d25e9acebe036294175e2814b2ead12 100755 --- a/tensorflow/tools/ci_build/copy_binary.py +++ b/tensorflow/tools/ci_build/copy_binary.py @@ -32,7 +32,8 @@ import shutil import tempfile import zipfile -TF_NIGHTLY_REGEX = r"(.+)tf_nightly(|_gpu)-(\d\.\d\.\d.dev[\d]{0,8})-(.+)\.whl" +TF_NIGHTLY_REGEX = (r"(.+)tf_nightly(|_gpu)-(\d\.[\d]{1,2}" + "\.\d.dev[\d]{0,8})-(.+)\.whl") BINARY_STRING_TEMPLATE = "%s-%s-%s.whl" diff --git a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh new file mode 100755 index 0000000000000000000000000000000000000000..ddad00c5f01a78164903702b03c816c427aeb0b8 --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This script is to be used to install bzel on non x86_64 systems +# It will compile bazel from source and install it in /usr/local/bin + +# Select bazel version. +BAZEL_VERSION="0.11.0" + +set +e +local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') + +if [[ "$local_bazel_ver" == "$BAZEL_VERSION" ]]; then + exit 0 +fi + +set -e + +# Compile bazel from source +mkdir -p /bazel +cd /bazel + +curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-dist.zip +unzip bazel-$BAZEL_VERSION-dist.zip +bash ./compile.sh +cp output/bazel /usr/local/bin/ +rm -rf /bazel diff --git a/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh b/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh new file mode 100755 index 0000000000000000000000000000000000000000..a93c258fad1ca62b0c95f22560110ba231aa0053 --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e +BUILDTOOLS_VERSION="0.11.1" + +# Clone buildtools +git clone -b $BUILDTOOLS_VERSION https://github.com/bazelbuild/buildtools +cd buildtools + +# Build buildifier +bazel build //buildifier +sudo mv bazel-bin/buildifier/linux*stripped/buildifier /usr/local/bin + +# Build buildozer +bazel build //buildozer +sudo mv bazel-bin/buildozer/linux*stripped/buildozer /usr/local/bin diff --git a/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh b/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh new file mode 100755 index 0000000000000000000000000000000000000000..47d23a59b3ee9152ef9812fbe939e20ee7c2b40a --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -ex + +GOLANG_URL="https://storage.googleapis.com/golang/go1.10.linux-ppc64le.tar.gz" + +sudo mkdir -p /usr/local +wget -q -O - "${GOLANG_URL}" | sudo tar -C /usr/local -xz diff --git a/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh b/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh new file mode 100755 index 0000000000000000000000000000000000000000..4989d986b8eb0690f63ecff41f7107371724bc3a --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +#This is required because pypi doesn't have a pre-built h5py binary for ppc64le +#It has to be compiled from source during the install +apt-get update +apt-get install -y libhdf5-dev + +#h5py is not expecting the shared libraries to have _serial in the name. +ln -s /usr/lib/powerpc64le-linux-gnu/libhdf5_serial.so /usr/lib/powerpc64le-linux-gnu/libhdf5.so +ln -s /usr/lib/powerpc64le-linux-gnu/libhdf5_serial_hl.so /usr/lib/powerpc64le-linux-gnu/libhdf5_hl.so + +#pip is not installed yet, so use easy_install +#CPATH is the location of hdf5.h +CPATH=/usr/include/hdf5/serial/ easy_install -U h5py +CPATH=/usr/include/hdf5/serial/ easy_install3 -U h5py diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index 982161cefeefddce7705515b0771b53acfff2706..221b5b80fb48979af09cb99a5c35cbe5fc4e5ca1 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -51,8 +51,8 @@ pip2 install --upgrade markdown==2.6.8 pip3 install --upgrade markdown==2.6.8 # Install protobuf. -pip2 install --upgrade protobuf==3.3.0 -pip3 install --upgrade protobuf==3.3.0 +pip2 install --upgrade protobuf==3.6.0 +pip3 install --upgrade protobuf==3.6.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -109,3 +109,17 @@ pip2 install --upgrade gast pip3 install --upgrade gast pip2 install --upgrade termcolor pip3 install --upgrade termcolor + +# Install last working version of setuptools. +pip2 install --upgrade setuptools==39.1.0 +pip3 install --upgrade setuptools==39.1.0 + +# Keras +pip2 install keras_applications==1.0.2 +pip3 install keras_applications==1.0.2 +pip2 install keras_preprocessing==1.0.1 +pip3 install keras_preprocessing==1.0.1 + +# Install last working version of setuptools. +pip2 install --upgrade setuptools==39.1.0 +pip3 install --upgrade setuptools==39.1.0 diff --git a/tensorflow/tools/ci_build/install/install_proto3.sh b/tensorflow/tools/ci_build/install/install_proto3.sh index 7934002b2c982cd10216016f8614b70b77b58e29..821d50baff325106fceca368d46042401d13c336 100755 --- a/tensorflow/tools/ci_build/install/install_proto3.sh +++ b/tensorflow/tools/ci_build/install/install_proto3.sh @@ -17,7 +17,7 @@ # Install protobuf3. # Select protobuf version. -PROTOBUF_VERSION="3.3.0" +PROTOBUF_VERSION="3.6.0" protobuf_ver_flat=$(echo $PROTOBUF_VERSION | sed 's/\.//g' | sed 's/^0*//g') local_protobuf_ver=$(protoc --version) local_protobuf_ver_flat=$(echo $local_protobuf_ver | sed 's/\.//g' | sed 's/^0*//g') diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 204a82f647eed550a1ad14bd6fed4cd72b0f7dba..45a30c6e82c336a0171c7602e09f2184f1459175 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -39,7 +39,6 @@ if [[ -z $pip35_version ]]; then fi set -e -pip3.5 install --upgrade setuptools pip3.5 install --upgrade pip pip3.5 install --upgrade virtualenv @@ -49,7 +48,7 @@ pip3.5 install --upgrade absl-py pip3.5 install --upgrade six==1.10.0 # Install protobuf. -pip3.5 install --upgrade protobuf==3.3.0 +pip3.5 install --upgrade protobuf==3.6.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -82,4 +81,14 @@ pip3.5 install --upgrade astor pip3.5 install --upgrade gast pip3.5 install --upgrade termcolor +# Install last working version of setuptools. +pip3.5 install --upgrade setuptools==39.1.0 + +# Keras +pip3.5 install keras_applications==1.0.2 +pip3.5 install keras_preprocessing==1.0.1 + +# Install last working version of setuptools. +pip3.5 install --upgrade setuptools==39.1.0 + # LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh) diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh index 275abeb669792423301f09339c786bd3869d6de9..d66b2aa18a7d77dd697031cfd2616712d586280a 100755 --- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh @@ -49,7 +49,6 @@ cd Python-3.6.1 make altinstall ln -s /usr/local/bin/pip3.6 /usr/local/bin/pip3 -pip3 install --upgrade setuptools pip3 install --upgrade pip pip3 install --upgrade virtualenv @@ -61,7 +60,7 @@ pip3 install --upgrade absl-py pip3 install --upgrade six==1.10.0 # Install protobuf. -pip3 install --upgrade protobuf==3.3.0 +pip3 install --upgrade protobuf==3.6.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -98,4 +97,11 @@ pip3 install --upgrade astor pip3 install --upgrade gast pip3 install --upgrade termcolor +# Install last working version of setuptools. +pip3 install --upgrade setuptools==39.1.0 + +# Keras +pip3.5 install keras_applications==1.0.2 +pip3.5 install keras_preprocessing==1.0.1 + # LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh) diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh index 51e10f81f82da7920e9d219eaec3e1eb2973b998..8eeddcdb824e84c8c20e948488fd40f3b26fff01 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh @@ -34,5 +34,5 @@ yes "" | $PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=cc,java -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --config=opt \ - --test_output=errors -- \ + --test_output=errors --test_size_filters=small,medium -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh index ea14848b1ae74ef0c42d14678fde225d465512bf..8eca1987f08491d92953971584ee612b07c13566 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh @@ -33,5 +33,5 @@ yes "" | $PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only --config=opt \ - --test_output=errors -- \ + --test_output=errors --test_size_filters=small,medium -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh index 6d017c8a1f0232deab82278b26797a73b3a8ea9c..2b68de3c5b9bbb0c09ddead7466049827fac4147 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh @@ -33,7 +33,7 @@ yes "" | $PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --config=opt \ - --test_output=errors -- \ + --test_size_filters=small,medium --test_output=errors -- \ //tensorflow/contrib/... \ -//tensorflow/contrib/lite/... \ //tensorflow/contrib/lite:context_test \ @@ -52,7 +52,7 @@ bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \ //tensorflow/contrib/lite/kernels:embedding_lookup_test \ //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test \ //tensorflow/contrib/lite/kernels:fully_connected_test \ - //tensorflow/contrib/lite/testing:generated_examples_zip_test \ + //tensorflow/contrib/lite/testing:generated_zip_tests \ //tensorflow/contrib/lite/kernels:hashtable_lookup_test \ //tensorflow/contrib/lite/kernels:local_response_norm_test \ //tensorflow/contrib/lite/kernels:lsh_projection_test \ diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh index a9accb9dd5b2d23e028a34ac3d99976d5f2f59db..51eb2cd7e67a4eb53cfc033e9543e13ceaf1b963 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh @@ -33,5 +33,5 @@ yes "" | $PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only --config=opt \ - --test_output=errors -- \ + --test_output=errors --test_size_filters=small,medium -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh index 02224d8e9d9efd92b5c1658118bd0c45bdf4f1db..9d2c8383fae5c1a65ff2bf16a496116c519f6ce4 100755 --- a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh +++ b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh @@ -37,5 +37,6 @@ yes "" | $PYTHON_BIN_PATH configure.py bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \ --test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ --build_tests_only --test_output=errors --local_test_jobs=8 --config=opt \ + --test_size_filters=small,medium \ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh new file mode 100755 index 0000000000000000000000000000000000000000..50ee07e727b309c1370edc993928d7165e9eb6cc --- /dev/null +++ b/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export PYTHON_BIN_PATH=`which python2` + +export TF_NEED_CUDA=1 +export TF_CUDA_VERSION=9.0 +export TF_CUDNN_VERSION=7 +export TF_CUDA_COMPUTE_CAPABILITIES=3.7 + +yes "" | $PYTHON_BIN_PATH configure.py + +# Run bazel test command. Double test timeouts to avoid flakes. +# Setting KMP_BLOCKTIME to 0 lets OpenMP threads to sleep right after parallel execution +# in an MKL primitive. This reduces the effects of an oversubscription of OpenMP threads +# caused by executing multiple tests concurrently. +bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test \ + --test_lang_filters=cc,py -k --jobs="${N_JOBS}" \ + --test_timeout 300,450,1200,3600 --build_tests_only --test_env=KMP_BLOCKTIME=0\ + --config=mkl --config=opt --test_output=errors --local_test_jobs=8 \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... + diff --git a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh index 0367a53d1459e7207a76c83e0c1e5c83580722a7..5b3383e1059a189c55ac2f2374a3160df572e5e8 100755 --- a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh +++ b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh @@ -37,5 +37,6 @@ yes "" | $PYTHON_BIN_PATH configure.py bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \ --test_lang_filters=py --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ --build_tests_only --test_output=errors --local_test_jobs=8 --config=opt \ + --test_size_filters=small,medium \ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh index bf992cf63d27f0f169185a38fa33a01cd5375051..f958b3c9b75c302c9e0a0fb84dae9561e939ba73 100755 --- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh +++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh @@ -21,7 +21,12 @@ # See libtensorflow_cpu.sh and libtensorflow_gpu.sh set -ex + +# Current script directory + SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +source "${SCRIPT_DIR}/../builds/builds_common.sh" DOCKER_CONTEXT_PATH="$(realpath ${SCRIPT_DIR}/..)" ROOT_DIR="$(realpath ${SCRIPT_DIR}/../../../../)" diff --git a/tensorflow/python/keras/applications/mobilenet/__init__.py b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-gpu-test.sh old mode 100644 new mode 100755 similarity index 53% rename from tensorflow/python/keras/applications/mobilenet/__init__.py rename to tensorflow/tools/ci_build/linux/mkl/basic-mkl-gpu-test.sh index b809e91193b459a46906443796344c092e1d2a6b..68354bf7c1cd6717bd0e27dc872703bb723925c4 --- a/tensorflow/python/keras/applications/mobilenet/__init__.py +++ b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-gpu-test.sh @@ -1,4 +1,5 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +#!/usr/bin/env bash +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""MobileNet Keras application.""" +# +# Usage: basic_mkl_test.sh -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +# Helper function to traverse directories up until given file is found. +function upsearch () { + test / == "$PWD" && return || \ + test -e "$1" && echo "$PWD" && return || \ + cd .. && upsearch "$1" +} -from tensorflow.python.keras._impl.keras.applications.mobilenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet -from tensorflow.python.keras._impl.keras.applications.mobilenet import preprocess_input +# Set up WORKSPACE. +WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" -del absolute_import -del division -del print_function +BUILD_TAG=mkl-gpu-ci-test CI_BUILD_USER_FORCE_BADNAME=yes ${WORKSPACE}/tensorflow/tools/ci_build/ci_build.sh gpu tensorflow/tools/ci_build/linux/gpu/run_mkl.sh diff --git a/tensorflow/python/keras/applications/inception_v3/__init__.py b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh old mode 100644 new mode 100755 similarity index 53% rename from tensorflow/python/keras/applications/inception_v3/__init__.py rename to tensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh index abf8393ae45d71dc0cb746706abb72f77b82d199..10a09a415a1fd5657efe1734ebf63b9cfc3dfc6e --- a/tensorflow/python/keras/applications/inception_v3/__init__.py +++ b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh @@ -1,4 +1,5 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +#!/usr/bin/env bash +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Inception V3 Keras application.""" +# +# Usage: basic_mkl_test.sh -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +# Helper function to traverse directories up until given file is found. +function upsearch () { + test / == "$PWD" && return || \ + test -e "$1" && echo "$PWD" && return || \ + cd .. && upsearch "$1" +} -from tensorflow.python.keras._impl.keras.applications.inception_v3 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 -from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input +# Set up WORKSPACE. +WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" -del absolute_import -del division -del print_function +BUILD_TAG=mkl-ci-test CI_BUILD_USER_FORCE_BADNAME=yes ${WORKSPACE}/tensorflow/tools/ci_build/ci_build.sh cpu tensorflow/tools/ci_build/linux/cpu/run_mkl.sh diff --git a/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh b/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh new file mode 100755 index 0000000000000000000000000000000000000000..ad22ebe4eb304fe6b6f8613f43f2c7c001111503 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Build a whl and container with Intel(R) MKL support +# Usage: build-dev-container.sh + +# Helper function to traverse directories up until given file is found. +function upsearch () { + test / == "$PWD" && return || \ + test -e "$1" && echo "$PWD" && return || \ + cd .. && upsearch "$1" +} + +# Set up WORKSPACE. +WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" + +TF_DOCKER_BUILD_DEVEL_BRANCH=${TF_DOCKER_BUILD_DEVEL_BRANCH:-master} +TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME:-intel-mkl/tensorflow} +TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION:-nightly} + +echo "TF_DOCKER_BUILD_DEVEL_BRANCH=${TF_DOCKER_BUILD_DEVEL_BRANCH}" +echo "TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME}" +echo "TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION}" + +# build the python 2 container and whl +TF_DOCKER_BUILD_TYPE="MKL" \ + TF_DOCKER_BUILD_IS_DEVEL="YES" \ + TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \ + TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \ + TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}" \ + ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh + +# build the python 3 container and whl +TF_DOCKER_BUILD_TYPE="MKL" \ + TF_DOCKER_BUILD_IS_DEVEL="YES" \ + TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \ + TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \ + TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}" \ + TF_DOCKER_BUILD_PYTHON_VERSION="PYTHON3" \ + ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh + diff --git a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh index 1bd1852ffc570166ecc6efca1420bc54d702ed89..3d27e84b81c586729aff21d0859383c24f436a11 100755 --- a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh +++ b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh @@ -65,6 +65,10 @@ OPENBLAS_SRC_PATH=/tmp/openblas_src/ sudo rm -rf ${OPENBLAS_SRC_PATH} git clone https://github.com/xianyi/OpenBLAS ${OPENBLAS_SRC_PATH} cd ${OPENBLAS_SRC_PATH} +# The commit after this introduced Fortran compile issues. In theory they should +# be solvable using NOFORTRAN=1 on the make command, but my initial tries didn't +# work, so pinning to the last know good version. +git checkout 5a6a2bed9aff0ba8a18651d5514d029c8cae336a # If this path is changed, you'll also need to update # cxx_builtin_include_directory in third_party/toolchains/cpus/arm/CROSSTOOL.tpl OPENBLAS_INSTALL_PATH=/tmp/openblas_install/ @@ -79,6 +83,7 @@ if [[ $1 == "PI_ONE" ]]; then --linkopt=-L${OPENBLAS_INSTALL_PATH}/lib/ --linkopt=-l:libopenblas.a" echo "Building for the Pi One/Zero, with no NEON support" + WHEEL_ARCH=linux_armv6l else PI_COPTS='--copt=-march=armv7-a --copt=-mfpu=neon-vfpv4 --copt=-std=gnu11 --copt=-DS_IREAD=S_IRUSR --copt=-DS_IWRITE=S_IWUSR @@ -86,6 +91,7 @@ else --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_1 --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_2 --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_8' + WHEEL_ARCH=linux_armv7l echo "Building for the Pi Two/Three, with NEON acceleration" fi @@ -100,6 +106,8 @@ bazel build -c opt ${PI_COPTS} \ --copt=-fomit-frame-pointer --cpu=armeabi \ --crosstool_top=@local_config_arm_compiler//:toolchain \ --verbose_failures \ + //tensorflow:libtensorflow.so \ + //tensorflow:libtensorflow_framework.so \ //tensorflow/tools/benchmark:benchmark_model \ //tensorflow/tools/pip_package:build_pip_package @@ -112,10 +120,12 @@ BDIST_OPTS="--universal" \ bazel-bin/tensorflow/tools/pip_package/build_pip_package "${OUTDIR}" OLD_FN=$(ls "${OUTDIR}" | grep -m 1 \.whl) -SUB='s/tensorflow-([^-]+)-([^-]+)-.*/tensorflow-\1-\2-none-any.whl/; print' +SUB='s/tensorflow-([^-]+)-([^-]+)-.*/tensorflow-\1-\2-none-'${WHEEL_ARCH}'.whl/; print' NEW_FN=$(echo "${OLD_FN}" | perl -ne "${SUB}") mv "${OUTDIR}/${OLD_FN}" "${OUTDIR}/${NEW_FN}" cp bazel-bin/tensorflow/tools/benchmark/benchmark_model "${OUTDIR}" +cp bazel-bin/tensorflow/libtensorflow.so "${OUTDIR}" +cp bazel-bin/tensorflow/libtensorflow_framework.so "${OUTDIR}" echo "Output can be found here:" find "${OUTDIR}" diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py index 00bfcfd49bd1d90dccf094de21173ca9e4307319..642dde36a7caae35df764d5d7513df972e1e5615 100755 --- a/tensorflow/tools/ci_build/update_version.py +++ b/tensorflow/tools/ci_build/update_version.py @@ -37,7 +37,7 @@ SETUP_PY = "%s/tools/pip_package/setup.py" % TF_SRC_DIR README_MD = "./README.md" DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel" % TF_SRC_DIR GPU_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-gpu" % TF_SRC_DIR -CPU_MKL_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-cpu-mkl" % TF_SRC_DIR +CPU_MKL_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-mkl" % TF_SRC_DIR RELEVANT_FILES = [TF_SRC_DIR, VERSION_H, SETUP_PY, diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh index 582188fc00b260926820a6add1331cf8fe0c8a9b..e10483e7fdc55926d678b157cffbd98b5d57def6 100644 --- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh +++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh @@ -14,136 +14,33 @@ # limitations under the License. # ============================================================================== # -# C++ tests -failing_cpu_cc_tests="\ - //tensorflow/core/kernels:control_flow_ops_test + \ - //tensorflow/core:example_example_parser_configuration_test + \ - //tensorflow/core:lib_core_status_test + \ - //tensorflow/core:lib_monitoring_collection_registry_test + \ - //tensorflow/core:lib_strings_numbers_test + \ - //tensorflow/core/platform/hadoop:hadoop_file_system_test + \ - //tensorflow/core:platform_file_system_test + \ - //tensorflow/core:platform_logging_test + \ - //tensorflow/core:util_sparse_sparse_tensor_test + \ - //tensorflow/cc:framework_gradient_checker_test + \ - //tensorflow/cc:framework_gradients_test + \ - //tensorflow/cc:gradients_array_grad_test + \ - //tensorflow/cc:gradients_math_grad_test + \ - //tensorflow/cc:gradients_nn_grad_test + \ - //tensorflow/cc/saved_model:loader_test \ -" - -broken_cpu_cc_tests="\ - //tensorflow/cc:framework_cc_ops_test + \ - //tensorflow/core/platform/cloud:time_util_test + \ - //tensorflow/core/platform/cloud:oauth_client_test + \ - //tensorflow/core/platform/cloud:http_request_test + \ - //tensorflow/core/platform/cloud:google_auth_provider_test + \ - //tensorflow/core/platform/cloud:gcs_file_system_test + \ - //tensorflow/core/kernels/cloud:bigquery_table_accessor_test + \ - //tensorflow/core/kernels/hexagon:graph_transferer_test + \ - //tensorflow/core/kernels:remote_fused_graph_execute_utils_test + \ - //tensorflow/core/kernels:requantize_op_test + \ - //tensorflow/core/kernels:requantization_range_op_test + \ - //tensorflow/core/kernels:quantized_reshape_op_test + \ - //tensorflow/core/kernels:quantized_pooling_ops_test + \ - //tensorflow/core/kernels:quantized_matmul_op_test + \ - //tensorflow/core/kernels:quantized_conv_ops_test + \ - //tensorflow/core/kernels:quantized_concat_op_test + \ - //tensorflow/core/kernels:quantized_bias_add_op_test + \ - //tensorflow/core/kernels:quantized_batch_norm_op_test + \ - //tensorflow/core/kernels:quantized_activation_ops_test + \ - //tensorflow/core/kernels:quantize_op_test + \ - //tensorflow/core/kernels:quantize_down_and_shrink_range_op_test + \ - //tensorflow/core/kernels:quantize_and_dequantize_op_test_gpu + \ - //tensorflow/core/kernels:quantize_and_dequantize_op_test + \ - //tensorflow/core/kernels:quantization_utils_test + \ - //tensorflow/core/kernels:debug_ops_test + \ - //tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr_test_gpu + \ - //tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr_test + \ - //tensorflow/core/distributed_runtime/rpc:grpc_tensor_coding_test + \ - //tensorflow/core/distributed_runtime/rpc:grpc_session_test_gpu + \ - //tensorflow/core/distributed_runtime/rpc:grpc_session_test + \ - //tensorflow/core/distributed_runtime/rpc:grpc_channel_test_gpu + \ - //tensorflow/core/distributed_runtime/rpc:grpc_channel_test + \ - //tensorflow/core/distributed_runtime:remote_device_test_gpu + \ - //tensorflow/core/distributed_runtime:remote_device_test + \ - //tensorflow/core/distributed_runtime:executor_test_gpu + \ - //tensorflow/core/distributed_runtime:executor_test + \ - //tensorflow/core/debug:debug_gateway_test + \ - //tensorflow/core/debug:debug_grpc_io_utils_test + \ - //tensorflow/core:util_reporter_test + \ - //tensorflow/core:util_memmapped_file_system_test + \ - //tensorflow/core:platform_subprocess_test + \ - //tensorflow/core:platform_profile_utils_cpu_utils_test + \ - //tensorflow/core:lib_jpeg_jpeg_mem_unittest + \ - //tensorflow/core/debug:debug_io_utils_test \ -" - -# lib_core_threadpool_test is timeout, but it passes when running alone -extra_failing_gpu_cc_tests="\ - //tensorflow/core:lib_core_threadpool_test + \ - //tensorflow/core:cuda_libdevice_path_test + \ - //tensorflow/core:common_runtime_direct_session_test + \ - //tensorflow/core:common_runtime_direct_session_with_tracking_alloc_test + \ - //tensorflow/core:device_tracer_test + \ - //tensorflow/core:ops_math_grad_test \ -" - -exclude_cpu_cc_tests="${failing_cpu_cc_tests} + ${broken_cpu_cc_tests}" - -exclude_gpu_cc_tests="${extra_failing_gpu_cc_tests} + ${exclude_cpu_cc_tests}" function run_configure_for_cpu_build { - # Due to a bug in Bazel: https://github.com/bazelbuild/bazel/issues/2182 - # yes "" | ./configure doesn't work on Windows, so we set all the - # environment variables in advance to avoid interact with the script. - export TF_NEED_CUDA=0 - if [ -z "$TF_ENABLE_XLA" ]; then - export TF_ENABLE_XLA=0 - fi - if [ -z "$TF_NEED_MKL" ]; then - export TF_NEED_MKL=0 - fi - export TF_NEED_VERBS=0 - export TF_NEED_GCP=1 - export TF_NEED_HDFS=0 - export TF_NEED_OPENCL_SYCL=0 - echo "" | ./configure + yes "" | ./configure } function run_configure_for_gpu_build { - # Due to a bug in Bazel: https://github.com/bazelbuild/bazel/issues/2182 - # yes "" | ./configure doesn't work on Windows, so we set all the - # environment variables in advance to avoid interact with the script. + # Enable CUDA support export TF_NEED_CUDA=1 - export TF_CUDA_VERSION=9.0 - export CUDA_TOOLKIT_PATH="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0" - export TF_CUDNN_VERSION=7.0 - if [ -z "$CUDNN_INSTALL_PATH" ]; then - export CUDNN_INSTALL_PATH="C:/tools/cuda" - fi - export TF_CUDA_COMPUTE_CAPABILITIES="3.7" - if [ -z "$TF_ENABLE_XLA" ]; then - export TF_ENABLE_XLA=0 - fi - export TF_NEED_VERBS=0 - export TF_NEED_MKL=0 - export TF_NEED_GCP=0 - export TF_NEED_HDFS=0 - export TF_NEED_OPENCL_SYCL=0 # TODO(pcloudy): Remove this after TensorFlow uses its own CRSOOTOOL # for GPU build on Windows export USE_MSVC_WRAPPER=1 - echo "" | ./configure + yes "" | ./configure } -function set_gcs_remote_cache_options { - echo "build --experimental_remote_spawn_cache" >> "${TMP_BAZELRC}" +function set_remote_cache_options { + echo "build --remote_instance_name=projects/tensorflow-testing-cpu" >> "${TMP_BAZELRC}" echo "build --experimental_remote_platform_override='properties:{name:\"build\" value:\"windows-x64\"}'" >> "${TMP_BAZELRC}" - echo "build --remote_http_cache=https://storage.googleapis.com/$GCS_BUCKET_NAME" >> "${TMP_BAZELRC}" + echo "build --remote_cache=remotebuildexecution.googleapis.com" >> "${TMP_BAZELRC}" + echo "build --tls_enabled=true" >> "${TMP_BAZELRC}" + echo "build --remote_timeout=3600" >> "${TMP_BAZELRC}" + echo "build --auth_enabled=true" >> "${TMP_BAZELRC}" + echo "build --spawn_strategy=remote" >> "${TMP_BAZELRC}" + echo "build --strategy=Javac=remote" >> "${TMP_BAZELRC}" + echo "build --strategy=Closure=remote" >> "${TMP_BAZELRC}" + echo "build --genrule_strategy=remote" >> "${TMP_BAZELRC}" echo "build --google_credentials=$GOOGLE_CLOUD_CREDENTIAL" >> "${TMP_BAZELRC}" } diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh index 0e6c0227b7ffb6b35193e133aa7d3fbcd16ce3c4..8a237e4e28376771742ba93b795950d368660196 100644 --- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh +++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh @@ -50,7 +50,14 @@ export PATH="/c/Program Files/Git/cmd:$PATH" # Make sure we have pip in PATH export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH" +# Setting default values to CUDA related environment variables +export TF_CUDA_VERSION=${TF_CUDA_VERSION:-9.0} +export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7.0} +export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-3.7} +export CUDA_INSTALL_PATH=${CUDA_INSTALL_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"} +export CUDNN_INSTALL_PATH=${CUDNN_INSTALL_PATH:-"C:/tools/cuda"} + # Add Cuda and Cudnn dll directories into PATH -export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/bin:$PATH" -export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/extras/CUPTI/libx64:$PATH" -export PATH="/c/tools/cuda/bin:$PATH" +export PATH="$(cygpath -u "${CUDA_INSTALL_PATH}")/bin:$PATH" +export PATH="$(cygpath -u "${CUDA_INSTALL_PATH}")/extras/CUPTI/libx64:$PATH" +export PATH="$(cygpath -u "${CUDNN_INSTALL_PATH}")/bin:$PATH" diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh index a2300811bb93b9e9d96b9db314943ab08870fcb3..ed7340146789078bf12fc3bbfba46fb0f740ba54 100644 --- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh @@ -54,24 +54,39 @@ function cleanup { trap cleanup EXIT skip_test=0 +release_build=0 for ARG in "$@"; do if [[ "$ARG" == --skip_test ]]; then skip_test=1 - elif [[ "$ARG" == --enable_gcs_remote_cache ]]; then - set_gcs_remote_cache_options + elif [[ "$ARG" == --enable_remote_cache ]]; then + set_remote_cache_options + elif [[ "$ARG" == --release_build ]]; then + release_build=1 fi done -# --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc -# by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521 -echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}" +if [[ "$release_build" != 1 ]]; then + # --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc + # by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521 + # Because this hurts the performance of TF, we don't enable it in release build. + echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}" +fi + +# The host and target platforms are the same in Windows build. So we don't have +# to distinct them. This helps avoid building the same targets twice. +echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}" -echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +# Enable short object file path to avoid long path issue on Windows. +echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}" + +if ! grep -q "import %workspace%/${TMP_BAZELRC}" .bazelrc; then + echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +fi run_configure_for_cpu_build -bazel build --announce_rc -c opt tensorflow/tools/pip_package:build_pip_package || exit $? +bazel build --announce_rc --config=opt tensorflow/tools/pip_package:build_pip_package || exit $? if [[ "$skip_test" == 1 ]]; then exit 0 @@ -92,7 +107,7 @@ N_JOBS="${NUMBER_OF_PROCESSORS}" # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore, # which will result testing system installed tensorflow -bazel test -c opt -k --test_output=errors \ +bazel test --announce_rc --config=opt -k --test_output=errors \ --define=no_tensorflow_py_deps=true --test_lang_filters=py \ --test_tag_filters=-no_pip,-no_windows,-no_oss \ --build_tag_filters=-no_pip,-no_windows,-no_oss --build_tests_only \ diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat index 4656afe0256d03540fed6912677c8e93f9cf9eb6..cec5b717f8ad07c0090ee424f3ae47e60df34a5a 100644 --- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat @@ -30,7 +30,6 @@ IF DEFINED SWIG_EXE (ECHO SWIG_EXE is set to %SWIG_EXE%) ELSE (SET SWIG_EXE="C:\ IF DEFINED PY_EXE (ECHO PY_EXE is set to %PY_EXE%) ELSE (SET PY_EXE="C:\Program Files\Anaconda3\python.exe") IF DEFINED PY_LIB (ECHO PY_LIB is set to %PY_LIB%) ELSE (SET PY_LIB="C:\Program Files\Anaconda3\libs\python35.lib") IF DEFINED CUDNN_HOME (ECHO CUDNN_HOME is set to %CUDNN_HOME%) ELSE (SET CUDNN_HOME="c:\tools\cuda") -verbosity:quiet IF DEFINED DISABLE_FORCEINLINE (ECHO DISABLE_FORCEINLINE is set to %DISABLE_FORCEINLINE%) ELSE (SET DISABLE_FORCEINLINE="OFF") SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh index 922bb67bbf6ce34f55acad6d3399bd810032abd0..fe3bce428fb2feb053cb1b8c097f707dd2762a20 100644 --- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh @@ -42,9 +42,58 @@ source "tensorflow/tools/ci_build/windows/bazel/common_env.sh" \ source "tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh" \ || { echo "Failed to source bazel_test_lib.sh" >&2; exit 1; } +# Recreate an empty bazelrc file under source root +export TMP_BAZELRC=.tmp.bazelrc +rm -f "${TMP_BAZELRC}" +touch "${TMP_BAZELRC}" + +function cleanup { + # Remove all options in .tmp.bazelrc + echo "" > "${TMP_BAZELRC}" +} +trap cleanup EXIT + +skip_test=0 +release_build=0 + +for ARG in "$@"; do + if [[ "$ARG" == --skip_test ]]; then + skip_test=1 + elif [[ "$ARG" == --enable_remote_cache ]]; then + set_remote_cache_options + elif [[ "$ARG" == --release_build ]]; then + release_build=1 + fi +done + +if [[ "$release_build" != 1 ]]; then + # --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc + # by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521 + # Because this hurts the performance of TF, we don't enable it in release build. + echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}" +fi + +# The host and target platforms are the same in Windows build. So we don't have +# to distinct them. This helps avoid building the same targets twice. +echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}" + +# Enable short object file path to avoid long path issue on Windows. +echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}" + +# Disable nvcc warnings to reduce log file size. +echo "build --copt=-nvcc_options=disable-warnings" >> "${TMP_BAZELRC}" + +if ! grep -q "import %workspace%/${TMP_BAZELRC}" .bazelrc; then + echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +fi + run_configure_for_gpu_build -bazel build -c opt tensorflow/tools/pip_package:build_pip_package || exit $? +bazel build --announce_rc --config=opt tensorflow/tools/pip_package:build_pip_package || exit $? + +if [[ "$skip_test" == 1 ]]; then + exit 0 +fi # Create a python test directory to avoid package name conflict PY_TEST_DIR="py_test_dir" @@ -59,8 +108,11 @@ reinstall_tensorflow_pip ${PIP_NAME} # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore, # which will result testing system installed tensorflow # GPU tests are very flaky when running concurrently, so set local_test_jobs=1 -bazel test -c opt -k --test_output=errors \ +bazel test --announce_rc --config=opt -k --test_output=errors \ --define=no_tensorflow_py_deps=true --test_lang_filters=py \ - --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,no_oss \ - --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,no_oss \ - --local_test_jobs=1 --build_tests_only //${PY_TEST_DIR}/tensorflow/python/... + --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss \ + --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss --build_tests_only \ + --local_test_jobs=1 --test_timeout="300,450,1200,3600" \ + --flaky_test_attempts=3 \ + //${PY_TEST_DIR}/tensorflow/python/... \ + //${PY_TEST_DIR}/tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh index 583d1d5f09527861015458c636af2259b34d45f8..fdbd1120b20ea4461a4ec5f84c666d8b62309905 100755 --- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh @@ -41,7 +41,7 @@ run_configure_for_cpu_build # build_libtensorflow_tarball in ../builds/libtensorflow.sh # cannot be used on Windows since it relies on pkg_tar rules. # So we do something special here -bazel build -c opt --copt=/arch:AVX \ +bazel --output_user_root=${TMPDIR} build -c opt --copt=/arch:AVX \ tensorflow:libtensorflow.so \ tensorflow/tools/lib_package:clicenses_generate \ tensorflow/java:libtensorflow_jni.so \ diff --git a/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh b/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh index a410c10b61b9f3f2cf8fc00074237a2bcfcbbf78..d085e21b0305d1c4266db0adb97f586286a12735 100755 --- a/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh +++ b/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh @@ -37,6 +37,7 @@ bazel clean # Run bazel test command. Double test timeouts to avoid flakes. bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test,-no_oss -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ + --test_size_filters=small,medium \ --build_tests_only --test_output=errors --local_test_jobs=8 \ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ --config=xla -- \ diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD index b9032c046e93527fd0f41f183e49e4933029ec62..8c01d15a8060040825fff381367208ff1e322b20 100644 --- a/tensorflow/tools/common/BUILD +++ b/tensorflow/tools/common/BUILD @@ -40,7 +40,24 @@ py_test( srcs = ["traverse_test.py"], srcs_version = "PY2AND3", deps = [ + ":test_module1", + ":test_module2", ":traverse", "//tensorflow/python:platform_test", ], ) + +py_library( + name = "test_module1", + srcs = ["test_module1.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_module2", + ], +) + +py_library( + name = "test_module2", + srcs = ["test_module2.py"], + srcs_version = "PY2AND3", +) diff --git a/tensorflow/python/keras/datasets/cifar100/__init__.py b/tensorflow/tools/common/test_module1.py similarity index 70% rename from tensorflow/python/keras/datasets/cifar100/__init__.py rename to tensorflow/tools/common/test_module1.py index ca93742673341660ba69712feb59c5dd32ea3252..cc185cf36e2616e3d1ba46d356a57bc184a35ae4 100644 --- a/tensorflow/python/keras/datasets/cifar100/__init__.py +++ b/tensorflow/tools/common/test_module1.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""CIFAR100 small image classification dataset.""" +"""A module target for TraverseTest.test_module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data +from tensorflow.tools.common import test_module2 + + +class ModuleClass1(object): + + def __init__(self): + self._m2 = test_module2.ModuleClass2() + + def __model_class1_method__(self): + pass -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/datasets/fashion_mnist/__init__.py b/tensorflow/tools/common/test_module2.py similarity index 83% rename from tensorflow/python/keras/datasets/fashion_mnist/__init__.py rename to tensorflow/tools/common/test_module2.py index 7f5ddecc4707334d52ebf4966f2ec6141cce0d46..d9da99d9c0f8141ad35b9d0d2e6c830b4c3828a8 100644 --- a/tensorflow/python/keras/datasets/fashion_mnist/__init__.py +++ b/tensorflow/tools/common/test_module2.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Fashion-MNIST dataset.""" +"""A module target for TraverseTest.test_module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.fashion_mnist import load_data -del absolute_import -del division -del print_function +class ModuleClass2(object): + + def __init__(self): + pass + + def __model_class1_method__(self): + pass + diff --git a/tensorflow/tools/common/traverse_test.py b/tensorflow/tools/common/traverse_test.py index eb195ec18efbc77e7b6edceff6970aa98f683948..ed410694ce1c86af8f9483575fbc12c8ddde312d 100644 --- a/tensorflow/tools/common/traverse_test.py +++ b/tensorflow/tools/common/traverse_test.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys - from tensorflow.python.platform import googletest +from tensorflow.tools.common import test_module1 +from tensorflow.tools.common import test_module2 from tensorflow.tools.common import traverse @@ -30,10 +30,6 @@ class TestVisitor(object): self.call_log = [] def __call__(self, path, parent, children): - # Do not traverse googletest, it's very deep. - for item in list(children): - if item[1] is googletest: - children.remove(item) self.call_log += [(path, parent, children)] @@ -51,13 +47,12 @@ class TraverseTest(googletest.TestCase): def test_module(self): visitor = TestVisitor() - traverse.traverse(sys.modules[__name__], visitor) + traverse.traverse(test_module1, visitor) called = [parent for _, parent, _ in visitor.call_log] - self.assertIn(TestVisitor, called) - self.assertIn(TraverseTest, called) - self.assertIn(traverse, called) + self.assertIn(test_module1.ModuleClass1, called) + self.assertIn(test_module2.ModuleClass2, called) def test_class(self): visitor = TestVisitor() diff --git a/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl b/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl index 47539b2423e602bb9771541ae5b01ba76c79f56f..f8f63e276cab61900cba9de599a11efc7718d078 100644 --- a/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl +++ b/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl @@ -31,7 +31,11 @@ def _def_file_filter_configure_impl(repository_ctx): vc_path = find_vc_path(repository_ctx) if vc_path == "visual-studio-not-found": auto_configure_fail("Visual C++ build tools not found on your machine") - undname_bin_path = find_msvc_tool(repository_ctx, vc_path, "undname.exe").replace("\\", "\\\\") + + undname = find_msvc_tool(repository_ctx, vc_path, "undname.exe") + if undname == None: + auto_configure_fail("Couldn't find undname.exe under %s, please check your VC installation and set BAZEL_VC environment variable correctly." % vc_path) + undname_bin_path = undname.replace("\\", "\\\\") repository_ctx.template( "def_file_filter.py", diff --git a/tensorflow/tools/dist_test/build_server.sh b/tensorflow/tools/dist_test/build_server.sh index 225c0347416ec8c8fef855946d18e838bd767690..345217d733acec62c599dd6dfeffd4839e5a79bc 100755 --- a/tensorflow/tools/dist_test/build_server.sh +++ b/tensorflow/tools/dist_test/build_server.sh @@ -23,7 +23,7 @@ # E.g.: tensorflow/tf_grpc_test_server:0.11.0rc1 # # whl_file_location: URL from which the TensorFlow whl file will be downloaded. -# E.g.: https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl +# E.g.: https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0-cp27-none-linux_x86_64.whl # E.g.: /path/to/folder/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl # # The optional flag --test lets the script to use the Dockerfile for the diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh index caae7fd5305af9846628eaf00348dd08df4e827f..b0114721bd2435dd2d4b8ee667250d3b824f1207 100755 --- a/tensorflow/tools/dist_test/local_test.sh +++ b/tensorflow/tools/dist_test/local_test.sh @@ -35,7 +35,7 @@ # # Arguments: # whl_file_location: URL from which the TensorFlow whl file will be acquired. -# E.g.: https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl +# E.g.: https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0-cp27-none-linux_x86_64.whl # E.g.: /path/to/folder/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl # # --leave_container_running: Do not stop the docker-in-docker container after @@ -64,9 +64,6 @@ die() { # Configurations DOCKER_IMG_NAME="tensorflow/tf-dist-test-local-cluster" -# Use TensorFlow v1.5.0 for Python 2.7 and CPU only as we set num_gpus to 0 in the below -DEFAULT_WHL_FILE_LOCATION="https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0-cp27-none-linux_x86_64.whl" - # Parse input arguments LEAVE_CONTAINER_RUNNING=0 MODEL_NAME="" @@ -77,8 +74,7 @@ SYNC_REPLICAS_FLAG="" WHL_FILE_LOCATION=${1} if [[ -z "${WHL_FILE_LOCATION}" ]]; then - WHL_FILE_LOCATION=${DEFAULT_WHL_FILE_LOCATION} - echo "use default whl file location" + echo "WARNING: No wheel url passed. Will use latest tf-nightly cpu p2 wheel." fi while true; do @@ -131,7 +127,11 @@ echo "Building in temporary directory: ${BUILD_DIR}" cp -r ${DIR}/* "${BUILD_DIR}"/ || \ die "Failed to copy files to ${BUILD_DIR}" -if [[ $WHL_FILE_LOCATION =~ 'http://' || $WHL_FILE_LOCATION =~ 'https://' ]]; then +# Download whl file into the build context directory. +if [[ -z "${WHL_FILE_LOCATION}" ]]; then + pip2 download --no-deps tf-nightly + cp tf-nightly-*.whl "${BUILD_DIR}"/tensorflow-none-any.whl +elif [[ $WHL_FILE_LOCATION =~ 'http://' || $WHL_FILE_LOCATION =~ 'https://' ]]; then # Download whl file into the build context directory. wget -P "${BUILD_DIR}" "${WHL_FILE_LOCATION}" || \ die "Failed to download tensorflow whl file from URL: ${WHL_FILE_LOCATION}" diff --git a/tensorflow/tools/dist_test/remote_test.sh b/tensorflow/tools/dist_test/remote_test.sh index 935535312d326a70aaf949332435569567cb1ca7..e188c88c8fa725daa619e244072fdb58765ea0a0 100755 --- a/tensorflow/tools/dist_test/remote_test.sh +++ b/tensorflow/tools/dist_test/remote_test.sh @@ -108,7 +108,7 @@ fi # Parse command-line arguments. WHL_URL=${1} if [[ -z "${WHL_URL}" ]]; then - die "whl URL is not specified" + echo "WARNING: No wheel url passed. Will use latest tf-nightly cpu p2 wheel." fi # Create docker build context directory. @@ -121,8 +121,13 @@ cp -r ${DIR}/* ${BUILD_DIR}/ || \ die "Failed to copy files to ${BUILD_DIR}" # Download whl file into the build context directory. -wget -P "${BUILD_DIR}" ${WHL_URL} || \ - die "Failed to download tensorflow whl file from URL: ${WHL_URL}" +if [[ -z "${WHL_URL}" ]]; then + pip2 download --no-deps tf-nightly + cp tf-nightly-*.whl "${BUILD_DIR}"/tensorflow-none-any.whl +else + wget -P "${BUILD_DIR}" ${WHL_URL} || \ + die "Failed to download tensorflow whl file from URL: ${WHL_URL}" +fi # Build docker image for test. docker build ${NO_CACHE_FLAG} \ diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index 406d134699ff182dde219c137f79a27094b09169..57a491255ea968b08e6e9cbaf9dd0178e8d2c3bf 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -76,7 +76,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.8 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.9 --depth=1 https://github.com/tensorflow/tensorflow.git . # TODO(craigcitro): Don't install the pip package, since it makes it # more difficult to experiment with local changes. Instead, just add diff --git a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl deleted file mode 100644 index a6cd44ced1d546846f274ef79aad75bcf950fd03..0000000000000000000000000000000000000000 --- a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl +++ /dev/null @@ -1,83 +0,0 @@ -FROM tensorflow/tensorflow:latest-devel - -LABEL maintainer="Clayne Robison" - -# These arguments are parameterized. Use --build-args to override. -ARG TF_BRANCH=r1.8 -ARG WHL_DIR=/whl - -RUN apt-get update && apt-get install -y --no-install-recommends \ - golang \ - vim \ - emacs \ - && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - -RUN pip --no-cache-dir install --upgrade \ - pip setuptools - -RUN pip --no-cache-dir install wheel - -# Download and build TensorFlow. -WORKDIR / -RUN rm -rf tensorflow && \ - git clone https://github.com/tensorflow/tensorflow.git && \ - cd tensorflow && \ - git checkout ${TF_BRANCH} -WORKDIR /tensorflow - -# Configure the build for CPU with MKL by accepting default build options and -# setting library locations -ENV CI_BUILD_PYTHON=python \ - LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \ - PYTHON_BIN_PATH=/usr/bin/python \ - PYTHON_LIB_PATH=/usr/local/lib/python2.7/dist-packages \ - CC_OPT_FLAGS='-march=native' \ - TF_NEED_JEMALLOC=0 \ - TF_NEED_GCP=1 \ - TF_NEED_CUDA=0 \ - TF_NEED_HDFS=0 \ - TF_NEED_S3=1 \ - TF_NEED_OPENCL=0 \ - TF_NEED_GDR=0 \ - TF_ENABLE_XLA=0 \ - TF_NEED_VERBS=0 \ - TF_NEED_MPI=0 -RUN ./configure - -# Build and Install TensorFlow. -# The 'mkl' option builds with Intel(R) Math Kernel Library (MKL), which detects -# the platform it is currently running on and takes appropriately optimized -# paths. The -march=native option is for code that is not in MKL, and assumes -# this container will be run on the same architecture on which it is built. -RUN LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \ - bazel build --config=mkl \ - --config="opt" \ - --copt="-march=broadwell" \ - --copt="-O3" \ - //tensorflow/tools/pip_package:build_pip_package && \ - mkdir ${WHL_DIR} && \ - bazel-bin/tensorflow/tools/pip_package/build_pip_package ${WHL_DIR} - -# Clean up Bazel cache when done, but leave the whl. -# This will upgrade the default Tensorflow version with the Intel MKL version -RUN pip --no-cache-dir install --upgrade ${WHL_DIR}/tensorflow-*.whl && \ - rm -rf /root/.cache - -WORKDIR /root - -#add welcome message with instructions - -RUN echo '[ ! -z "$TERM" -a -r /etc/motd ] && cat /etc/issue && cat /etc/motd' \ - >> /etc/bash.bashrc \ - ; echo "\ -||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||\n\ -| \n\ -| Docker container running Ubuntu \n\ -| with TensorFlow ${TF_BRANCH} optimized for CPU \n\ -| with Intel(R) MKL \n\ -| \n\ -||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||\n\ -\n "\ - > /etc/motd diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index 2fe47f3356ce26da4174b95d59dce1889d3ec90c..204b5b4dba1b607fb709b7f45d145ceafc33f3e7 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -13,8 +13,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ cuda-cusparse-dev-9-0 \ curl \ git \ - libcudnn7=7.0.5.15-1+cuda9.0 \ - libcudnn7-dev=7.0.5.15-1+cuda9.0 \ + libcudnn7=7.1.4.18-1+cuda9.0 \ + libcudnn7-dev=7.1.4.18-1+cuda9.0 \ libcurl3-dev \ libfreetype6-dev \ libhdf5-serial-dev \ @@ -85,7 +85,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.8 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.9 --depth=1 https://github.com/tensorflow/tensorflow.git . # Configure the build for our CUDA configuration. ENV CI_BUILD_PYTHON python diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl new file mode 100755 index 0000000000000000000000000000000000000000..6dca0e393fa8d61ec819a5f9b5a2e5ffd3c7be92 --- /dev/null +++ b/tensorflow/tools/docker/Dockerfile.devel-mkl @@ -0,0 +1,128 @@ +FROM ubuntu:16.04 + +LABEL maintainer="Clayne Robison " + +# These parameters can be overridden by parameterized_docker_build.sh +ARG TF_BUILD_VERSION=r1.9 +ARG PYTHON="python" +ARG PYTHON3_DEV="" +ARG WHL_DIR="/tmp/pip" +ARG PIP="pip" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python-dev \ + ${PYTHON3_DEV} \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \ + ${PYTHON} get-pip.py && \ + rm get-pip.py + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + ipykernel \ + jupyter \ + matplotlib \ + mock \ + numpy \ + scipy \ + sklearn \ + pandas \ + && \ + ${PYTHON} -m ipykernel.kernelspec + +RUN if [ "${PYTHON}" = "python3" ]; then \ + ln -s -f /usr/bin/python3 /usr/bin/python; \ + fi + +# Set up our notebook config. +COPY jupyter_notebook_config.py /root/.jupyter/ + +# Jupyter has issues with being run directly: +# https://github.com/ipython/ipython/issues/7062 +# We just add a little wrapper script. +COPY run_jupyter.sh / + +# Set up Bazel. + +# Running bazel inside a `docker build` command causes trouble, cf: +# https://github.com/bazelbuild/bazel/issues/134 +# The easiest solution is to set up a bazelrc file forcing --batch. +RUN echo "startup --batch" >>/etc/bazel.bazelrc +# Similarly, we need to workaround sandboxing issues: +# https://github.com/bazelbuild/bazel/issues/418 +RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ + >>/etc/bazel.bazelrc +# Install the most recent bazel release. +ENV BAZEL_VERSION 0.11.0 +WORKDIR / +RUN mkdir /bazel && \ + cd /bazel && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \ + chmod +x bazel-*.sh && \ + ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + cd / && \ + rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh + +# Download and build TensorFlow. +WORKDIR /tensorflow + +# Download and build TensorFlow. +# Enable checking out both tags and branches +RUN export TAG_PREFIX="v" && \ + echo ${TF_BUILD_VERSION} | grep -q ^${TAG_PREFIX}; \ + if [ $? -eq 0 ]; then \ + git clone --depth=1 https://github.com/tensorflow/tensorflow.git . && \ + git fetch --tags && \ + git checkout ${TF_BUILD_VERSION}; \ + else \ + git clone --depth=1 --branch=${TF_BUILD_VERSION} https://github.com/tensorflow/tensorflow.git . ; \ + fi + +RUN yes "" | ${PYTHON} configure.py + +ENV CI_BUILD_PYTHON ${PYTHON} + +# Set bazel build parameters in .bazelrc in parameterized_docker_build.sh +# Use --copt=-march values to get optimized builds appropriate for the hardware +# platform of your choice. +# For ivy-bridge or sandy-bridge +# --copt=-march="avx" \ +# For haswell, broadwell, or skylake +# --copt=-march="avx2" \ +COPY .bazelrc /root/.bazelrc + +RUN tensorflow/tools/ci_build/builds/configured CPU \ + bazel --bazelrc=/root/.bazelrc build -c opt \ + tensorflow/tools/pip_package:build_pip_package && \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package "${WHL_DIR}" && \ + ${PIP} --no-cache-dir install --upgrade "${WHL_DIR}"/tensorflow-*.whl && \ + rm -rf /root/.cache +# Clean up Bazel cache when done. + +# TensorBoard +EXPOSE 6006 +# IPython +EXPOSE 8888 + +WORKDIR /root diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu index bff4a20392076994c75705b73c25dcb740ba1f09..9197651ff4326e9b40264183a94b82e936746010 100644 --- a/tensorflow/tools/docker/Dockerfile.gpu +++ b/tensorflow/tools/docker/Dockerfile.gpu @@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ cuda-cusolver-9-0 \ cuda-cusparse-9-0 \ curl \ - libcudnn7=7.0.5.15-1+cuda9.0 \ + libcudnn7=7.1.4.18-1+cuda9.0 \ libfreetype6-dev \ libhdf5-serial-dev \ libpng12-dev \ diff --git a/tensorflow/tools/docker/Dockerfile.mkl b/tensorflow/tools/docker/Dockerfile.mkl new file mode 100755 index 0000000000000000000000000000000000000000..139395d49102fe2de3e241936095613da3f21bf8 --- /dev/null +++ b/tensorflow/tools/docker/Dockerfile.mkl @@ -0,0 +1,75 @@ +FROM ubuntu:16.04 + +LABEL maintainer="Clayne Robison " + +# This parameter MUST be set by parameterized_docker_build.sh +ARG TF_WHL_URL + +# Optional parameters +ARG TF_BUILD_VERSION=r1.9 +ARG PYTHON="python" +ARG PYTHON_DEV="python-dev" +ARG PIP="pip" + +# Pick up some TF dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python \ + ${PYTHON_DEV} \ + rsync \ + software-properties-common \ + unzip \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ + python get-pip.py && \ + rm get-pip.py + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + ipykernel \ + jupyter \ + matplotlib \ + numpy \ + pandas \ + scipy \ + sklearn \ + && \ + python -m ipykernel.kernelspec + +COPY ${TF_WHL_URL} / +RUN ${PIP} install --no-cache-dir --force-reinstall /${TF_WHL_URL} && \ + rm -rf /${TF_WHL_URL} + +RUN if [ "${PYTHON}" = "python3" ]; then \ + ln -s -f /usr/bin/python3 /usr/bin/python; \ + fi + +# Set up our notebook config. +COPY jupyter_notebook_config.py /root/.jupyter/ + +# Copy sample notebooks. +COPY notebooks /notebooks + +# Jupyter has issues with being run directly: +# https://github.com/ipython/ipython/issues/7062 +# We just add a little wrapper script. +COPY run_jupyter.sh / + +# TensorBoard +EXPOSE 6006 +# IPython +EXPOSE 8888 + +WORKDIR "/notebooks" + +CMD ["/run_jupyter.sh", "--allow-root"] diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh index 05de25f2cb11d76f223a31bc12329e6ab7368e8a..4681c5fd61158e0be998d72bb4329f204808eda7 100755 --- a/tensorflow/tools/docker/parameterized_docker_build.sh +++ b/tensorflow/tools/docker/parameterized_docker_build.sh @@ -19,8 +19,8 @@ # parameterized_docker_build.sh # # The script obeys the following environment variables: -# TF_DOCKER_BUILD_TYPE: (CPU | GPU) -# CPU or GPU image +# TF_DOCKER_BUILD_TYPE: (CPU | GPU | MKL) +# CPU, GPU, or MKL image # # TF_DOCKER_BUILD_IS_DEVEL: (NO | YES) # Is this developer image @@ -87,6 +87,15 @@ # TF_DOCKER_BUILD_OPTIONS # (Optional) # Specifies the desired build options. Defaults to OPT. +# +# TF_DOCKER_BUILD_ARGS +# (Optional) +# A list (array) of docker build args. Will be passed to docker build +# command as list of --build-arg parameters. +# +# TF_BAZEL_BUILD_OPTIONS +# (Optional) +# Bazel compiler flags to be passed to the bazelrc file # Script directory SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -116,6 +125,8 @@ echo " TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME}" echo " TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION}" echo " TF_DOCKER_BUILD_PORT=${TF_DOCKER_BUILD_PORT}" echo " TF_DOCKER_BUILD_PUSH_CMD=${TF_DOCKER_BUILD_PUSH_CMD}" +echo " TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]:-()}" +echo " TF_BAZEL_BUILD_OPTIONS=${TF_BAZEL_BUILD_OPTIONS}" CONTAINER_PORT=${TF_DOCKER_BUILD_PORT:-8888} @@ -149,6 +160,15 @@ fi if [[ ${TF_DOCKER_BUILD_TYPE} == "cpu" ]]; then DOCKER_BINARY="docker" +elif [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + DOCKER_BINARY="docker" + FINAL_TAG="${FINAL_TAG}-mkl" + if [[ ${ORIG_DOCKERFILE} == *"."* ]]; then + # There is already a dot in the tag, use "-" + ORIG_DOCKERFILE="${ORIG_DOCKERFILE}-mkl" + else + ORIG_DOCKERFILE="${ORIG_DOCKERFILE}.mkl" + fi elif [[ ${TF_DOCKER_BUILD_TYPE} == "gpu" ]]; then DOCKER_BINARY="nvidia-docker" @@ -203,6 +223,10 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then export TF_BUILD_OPTIONS=${TF_DOCKER_BUILD_OPTIONS} export TF_BUILD_IS_PIP="PIP" + if [[ "${TF_DOCKER_BUILD_TYPE}" == "mkl" ]]; then + die "FAIL: Non-development MKL builds require a pre-built pip whl." + fi + if [[ "${TF_DOCKER_BUILD_TYPE}" == "gpu" ]]; then export TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS=\ "${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2" @@ -255,25 +279,39 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # Use string replacement to put the correct file name into the Dockerfile PIP_WHL=$(basename "${PIP_WHL}") - # Modify the non-devel Dockerfile to point to the correct pip whl file - # location - sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + TF_DOCKER_BUILD_ARGS+=("--build-arg TF_WHL_URL=${PIP_WHL}" ) + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" + else + # Modify the non-devel Dockerfile to point to the correct pip whl file + # location + sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ "/# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/c"\ "COPY ${PIP_WHL} /\n"\ "RUN pip --no-cache-dir install /${PIP_WHL}" "${ORIG_DOCKERFILE}" \ - > "${DOCKERFILE}" + > "${DOCKERFILE}" + fi echo "Using local pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}" echo - else echo "Downloading pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}" - echo - - # Modify the non-devel Dockerfile to point to the correct pip whl URL. - sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + pushd "${TMP_DIR}/" + curl -O ${TF_DOCKER_BUILD_CENTRAL_PIP} + popd + PIP_WHL_PATH=`find ${TMP_DIR} -name "*.whl"` + PIP_WHL=$(basename "${PIP_WHL_PATH}") + echo "PIP_WHL= ${PIP_WHL}" + echo + TF_DOCKER_BUILD_ARGS+=("--build-arg TF_WHL_URL=${PIP_WHL}") + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" + else + # Modify the non-devel Dockerfile to point to the correct pip whl URL. + sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ "/# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/c"\ "RUN pip --no-cache-dir install ${TF_DOCKER_BUILD_CENTRAL_PIP}" "${ORIG_DOCKERFILE}" \ - > "${DOCKERFILE}" + > "${DOCKERFILE}" + fi fi echo "Modified Dockerfile at: ${DOCKERFILE}" @@ -281,36 +319,66 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # Modify python/pip version if necessary. if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then - if sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ - sed -i -e 's/python-dev/python3-dev/g' "${DOCKERFILE}" && \ - sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ - sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" - then - echo "Modified Dockerfile for python version "\ -"${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON=${TF_DOCKER_BUILD_PYTHON_VERSION}") + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON_DEV=python3-dev") + TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3") + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" else - die "FAILED to modify ${DOCKERFILE} for python3" + if sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ + sed -i -e 's/python-dev/python3-dev/g' "${DOCKERFILE}" && \ + sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ + sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" + then + echo "Modified Dockerfile for python version "\ + "${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + else + die "FAILED to modify ${DOCKERFILE} for python3" + fi fi fi -else +else # TF_DOCKER_BUILD_IS_DEVEL == 'yes' DOCKERFILE="${TMP_DIR}/Dockerfile" - # Modify the devel Dockerfile to specify the git branch - sed "s/^RUN git clone --branch=.* --depth=1/RUN git clone --branch=${TF_DOCKER_BUILD_DEVEL_BRANCH} --depth=1/" \ - "${ORIG_DOCKERFILE}" > "${DOCKERFILE}" + # Set up Dockerfile ARGS for mkl build + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + if [[ -z "${TF_BAZEL_BUILD_OPTIONS// }" ]]; then + TF_BAZEL_BUILD_OPTIONS=("--config=mkl --copt=-mavx --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0") + else + TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" + fi + TF_DOCKER_BUILD_ARGS+=("--build-arg TF_BUILD_VERSION=${TF_DOCKER_BUILD_DEVEL_BRANCH}") + echo "TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]}" + + # Pass the build options to bazel using the user-specific .bazelrc file + echo "build ${TF_BAZEL_BUILD_OPTIONS}" >> ${TMP_DIR}/.bazelrc + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" + else + # Modify the devel Dockerfile to specify the git branch + sed "s/^RUN git clone --branch=.* --depth=1/RUN git clone --branch=${TF_DOCKER_BUILD_DEVEL_BRANCH} --depth=1/" \ + "${ORIG_DOCKERFILE}" > "${DOCKERFILE}" + fi # Modify python/pip version if necessary. if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then - if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \ - sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ - sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \ - sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ - sed -i -e 's/ENV CI_BUILD_PYTHON python/ENV CI_BUILD_PYTHON python3/g' "${DOCKERFILE}" && \ - sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" - then - echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON=${TF_DOCKER_BUILD_PYTHON_VERSION}") + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON3_DEV=python3-dev") + TF_DOCKER_BUILD_ARGS+=("--build-arg WHL_DIR=/tmp/pip3") + TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3") + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" else - die "FAILED to modify ${DOCKERFILE} for python3" + if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \ + sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ + sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \ + sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ + sed -i -e 's/ENV CI_BUILD_PYTHON python/ENV CI_BUILD_PYTHON python3/g' "${DOCKERFILE}" && \ + sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" + then + echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + else + die "FAILED to modify ${DOCKERFILE} for python3" + fi fi fi fi @@ -319,8 +387,11 @@ fi # Intermediate image name with tag IMG="${USER}/tensorflow:${FINAL_TAG}" echo "Building docker image with image name and tag: ${IMG}" +echo "TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]}" +CMD="${DOCKER_BINARY} build ${TF_DOCKER_BUILD_ARGS[@]} --no-cache --pull -t ${IMG} -f ${DOCKERFILE} ${TMP_DIR}" +echo "CMD=${CMD}" +${CMD} -"${DOCKER_BINARY}" build --no-cache --pull -t "${IMG}" -f "${DOCKERFILE}" "${TMP_DIR}" if [[ $? == "0" ]]; then echo "${DOCKER_BINARY} build of ${IMG} succeeded" else @@ -340,7 +411,7 @@ fi DOCKER_RUN_LOG="${TMP_DIR}/docker_run.log" echo "" echo "Running docker container from image ${IMG}..." -echo " (Log file is at: ${DOCKER_RUN_LOG}" +echo " Log file is at: ${DOCKER_RUN_LOG}" echo "" if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then @@ -386,7 +457,6 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # Stop the running docker container sleep 1 "${DOCKER_BINARY}" stop --time=0 ${CONTAINER_ID} - fi diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 58b5ef8345c9de83e2d50cd01fe11e11f51fe298..2403e2d966929b86976bf6a31f8144d9b4f58bc6 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -37,7 +37,11 @@ py_library( srcs = ["parser.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = ["@astor_archive//:astor"], + deps = [ + "//tensorflow/python:platform", + "//tensorflow/python:util", + "@astor_archive//:astor", + ], ) py_test( @@ -92,6 +96,7 @@ py_binary( deps = [ ":generate_lib", "//tensorflow:tensorflow_py", + "//tensorflow/python:util", "//tensorflow/python/debug:debug_py", ], ) diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 111d54d8205f805cc24d21c610acc81610b8d47d..e7634cd5dcf19d5f21b0bd42b282dfe928659a52 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -21,6 +21,7 @@ from __future__ import print_function import argparse import fnmatch import os +import shutil import six @@ -50,7 +51,11 @@ def _is_free_function(py_object, full_name, index): return True -def write_docs(output_dir, parser_config, yaml_toc, root_title='TensorFlow'): +def write_docs(output_dir, + parser_config, + yaml_toc, + root_title='TensorFlow', + search_hints=True): """Write previously extracted docs to disk. Write a docs page for each symbol included in the indices of parser_config to @@ -66,6 +71,8 @@ def write_docs(output_dir, parser_config, yaml_toc, root_title='TensorFlow'): indices. yaml_toc: Set to `True` to generate a "_toc.yaml" file. root_title: The title name for the root level index.md. + search_hints: (bool) include meta-data search hints at the top of each + output file. Raises: ValueError: if `output_dir` is not an absolute path @@ -75,12 +82,8 @@ def write_docs(output_dir, parser_config, yaml_toc, root_title='TensorFlow'): raise ValueError("'output_dir' must be an absolute path.\n" " output_dir='%s'" % output_dir) - try: - if not os.path.exists(output_dir): - os.makedirs(output_dir) - except OSError as e: - print('Creating output dir "%s" failed: %s' % (output_dir, e)) - raise + if not os.path.exists(output_dir): + os.makedirs(output_dir) # These dictionaries are used for table-of-contents generation below # They will contain, after the for-loop below:: @@ -123,8 +126,6 @@ def write_docs(output_dir, parser_config, yaml_toc, root_title='TensorFlow'): module_children.setdefault(subname, []).append(full_name) break - print('Writing docs for %s (%r).' % (full_name, py_object)) - # Generate docs for `py_object`, resolving references. page_info = parser.docs_for_object(full_name, py_object, parser_config) @@ -134,15 +135,20 @@ def write_docs(output_dir, parser_config, yaml_toc, root_title='TensorFlow'): if not os.path.exists(directory): os.makedirs(directory) # This function returns raw bytes in PY2 or unicode in PY3. - text = pretty_docs.build_md_page(page_info) + if search_hints: + content = [page_info.get_metadata_html()] + else: + content = [''] + + content.append(pretty_docs.build_md_page(page_info)) + text = '\n'.join(content) if six.PY3: text = text.encode('utf-8') with open(path, 'wb') as f: f.write(text) - except OSError as e: - print('Cannot write documentation for %s to %s: %s' % (full_name, - directory, e)) - raise + except OSError: + raise OSError( + 'Cannot write documentation for %s to %s' % (full_name, directory)) if yaml_toc: # Generate table of contents @@ -382,16 +388,40 @@ def _build_guide_index(guide_src_dir): class _UpdateTags(py_guide_parser.PyGuideParser): - """Rewrites a Python guide so that each section has an explicit tag.""" + """Rewrites a Python guide so that each section has an explicit id tag. + + "section" here refers to blocks delimited by second level headings. + """ def process_section(self, line_number, section_title, tag): self.replace_line(line_number, '

%s

' % (tag, section_title)) +def update_id_tags_inplace(src_dir): + """Set explicit ids on all second-level headings to ensure back-links work. + + Args: + src_dir: The directory of md-files to convert (inplace). + """ + tag_updater = _UpdateTags() + + for dirpath, _, filenames in os.walk(src_dir): + for base_name in filenames: + if not base_name.endswith('.md'): + continue + full_path = os.path.join(src_dir, dirpath, base_name) + + # Tag updater loads the file, makes the replacements, and returns the + # modified file contents + content = tag_updater.process(full_path) + with open(full_path, 'w') as f: + f.write(content) + + EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt']) -def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): +def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): """Fix @{} references in all files under `src_dir` matching `file_pattern`. A matching directory structure, with the modified files is @@ -412,7 +442,6 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): using fnmatch. Non-matching files are copied unchanged. """ # Iterate through all the source files and process them. - tag_updater = _UpdateTags() for dirpath, _, filenames in os.walk(src_dir): # How to get from `dirpath` to api_docs/python/ relative_path_to_root = os.path.relpath( @@ -421,41 +450,32 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): # Make the directory under output_dir. new_dir = os.path.join(output_dir, os.path.relpath(path=dirpath, start=src_dir)) - try: - if not os.path.exists(new_dir): - os.makedirs(new_dir) - except OSError as e: - print('Creating output dir "%s" failed: %s' % (new_dir, e)) - raise + if not os.path.exists(new_dir): + os.makedirs(new_dir) for base_name in filenames: if base_name in EXCLUDED: - print('Skipping excluded file %s...' % base_name) continue full_in_path = os.path.join(dirpath, base_name) + # Set the `current_doc_full_name` so bad files can be reported on errors. reference_resolver.current_doc_full_name = full_in_path suffix = os.path.relpath(path=full_in_path, start=src_dir) full_out_path = os.path.join(output_dir, suffix) + # Copy files that do not match the file_pattern, unmodified. if not fnmatch.fnmatch(base_name, file_pattern): - print('Copying un-matched file %s...' % suffix) - open(full_out_path, 'wb').write(open(full_in_path, 'rb').read()) + shutil.copyfile(full_in_path, full_out_path) continue - if dirpath.endswith('/api_guides/python'): - print('Processing Python guide %s...' % base_name) - content = tag_updater.process(full_in_path) - else: - print('Processing doc %s...' % suffix) - content = open(full_in_path, 'rb').read().decode('utf-8') + + with open(full_in_path, 'rb') as f: + content = f.read().decode('utf-8') content = reference_resolver.replace_references(content, relative_path_to_root) with open(full_out_path, 'wb') as f: f.write(content.encode('utf-8')) - print('Done.') - class DocGenerator(object): """Main entry point for generating docs.""" @@ -467,6 +487,12 @@ class DocGenerator(object): self._do_not_descend_map = _get_default_do_not_descend_map() self.yaml_toc = True + self.argument_parser.add_argument( + '--no_search_hints', + dest='search_hints', + action='store_false', + default=True) + def add_output_dir_argument(self): self.argument_parser.add_argument( '--output_dir', @@ -536,15 +562,43 @@ class DocGenerator(object): self._do_not_descend_map) def build(self, flags): - """Actually build the docs.""" + """Build all the docs. + + This produces two outputs + + python api docs: + + * generated from modules set with `set_py_modules`. + * written to '{FLAGS.output_dir}/api_docs/python/' + + non-api docs: + + * Everything in '{FLAGS.src_dir}' is copied to '{FLAGS.output_dir}'. + * '@{}' references in '.md' files are replaced with links. + * '.md' files under 'api_guides/python' have explicit ids set for their + second level headings. + + Args: + flags: + * src_dir: Where to fetch the non-api-docs. + * base_dir: Base of the docs directory (Used to build correct + relative links). + * output_dir: Where to write the resulting docs. + + Returns: + The number of errors encountered while processing. + """ + # Extract the python api from the _py_modules doc_index = build_doc_index(flags.src_dir) visitor = self.run_extraction() reference_resolver = self.make_reference_resolver(visitor, doc_index) + # Build the guide_index for the api_docs back links. root_title = getattr(flags, 'root_title', 'TensorFlow') guide_index = _build_guide_index( os.path.join(flags.src_dir, 'api_guides/python')) + # Write the api docs. parser_config = self.make_parser_config(visitor, reference_resolver, guide_index, flags.base_dir) output_dir = os.path.join(flags.output_dir, 'api_docs/python') @@ -553,9 +607,18 @@ class DocGenerator(object): output_dir, parser_config, yaml_toc=self.yaml_toc, - root_title=root_title) - _other_docs(flags.src_dir, flags.output_dir, reference_resolver) - + root_title=root_title, + search_hints=getattr(flags, 'search_hints', True)) + + # Replace all the @{} references in files under `FLAGS.src_dir` + replace_refs(flags.src_dir, flags.output_dir, reference_resolver, '*.md') + # Fix the tags in the guide dir. + guide_dir = os.path.join(flags.output_dir, 'api_guides/python') + if os.path.exists(guide_dir): + update_id_tags_inplace(guide_dir) + + # Report all errors found by the reference resolver, and return the error + # code. parser_config.reference_resolver.log_errors() return parser_config.reference_resolver.num_errors() diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py index ea6d28a02b1f3c07fe8783fd59e345dade1fc804..7a6f9fd9f799db5a14015d77e5297955c76a51cd 100644 --- a/tensorflow/tools/docs/generate_lib_test.py +++ b/tensorflow/tools/docs/generate_lib_test.py @@ -51,7 +51,9 @@ class DummyVisitor(object): class GenerateTest(googletest.TestCase): - def test_write(self): + def get_test_objects(self): + # These are all mutable objects, so rebuild them for each test. + # Don't cache the objects. module = sys.modules[__name__] index = { @@ -98,6 +100,11 @@ class GenerateTest(googletest.TestCase): guide_index={}, base_dir=base_dir) + return reference_resolver, parser_config + + def test_write(self): + _, parser_config = self.get_test_objects() + output_dir = googletest.GetTempDir() generate_lib.write_docs(output_dir, parser_config, yaml_toc=True) @@ -127,6 +134,107 @@ class GenerateTest(googletest.TestCase): os.path.exists( os.path.join(output_dir, 'tf/TestModule/test_function.md'))) + def test_update_id_tags_inplace(self): + test_dir = googletest.GetTempDir() + test_sub_dir = os.path.join(test_dir, 'a/b') + os.makedirs(test_sub_dir) + + test_path1 = os.path.join(test_dir, 'file1.md') + test_path2 = os.path.join(test_sub_dir, 'file2.md') + test_path3 = os.path.join(test_sub_dir, 'file3.notmd') + + with open(test_path1, 'w') as f: + f.write('## abc&123') + + with open(test_path2, 'w') as f: + f.write('# A Level 1 Heading\n') + f.write('## A Level 2 Heading') + + with open(test_path3, 'w') as f: + f.write("## don\'t change this") + + generate_lib.update_id_tags_inplace(test_dir) + + with open(test_path1) as f: + content = f.read() + + self.assertEqual(content, '

abc&123

') + + with open(test_path2) as f: + content = f.read() + + self.assertEqual( + content, '# A Level 1 Heading\n' + '

A Level 2 Heading

') + + with open(test_path3) as f: + content = f.read() + + self.assertEqual(content, "## don\'t change this") + + def test_replace_refes(self): + test_dir = googletest.GetTempDir() + test_in_dir = os.path.join(test_dir, 'in') + test_in_dir_a = os.path.join(test_dir, 'in/a') + test_in_dir_b = os.path.join(test_dir, 'in/b') + os.makedirs(test_in_dir) + os.makedirs(test_in_dir_a) + os.makedirs(test_in_dir_b) + + test_out_dir = os.path.join(test_dir, 'out') + os.makedirs(test_out_dir) + + test_path1 = os.path.join(test_in_dir_a, 'file1.md') + test_path2 = os.path.join(test_in_dir_b, 'file2.md') + test_path3 = os.path.join(test_in_dir_b, 'file3.notmd') + test_path4 = os.path.join(test_in_dir_b, 'OWNERS') + + with open(test_path1, 'w') as f: + f.write('Use `tf.test_function` to test things.') + + with open(test_path2, 'w') as f: + f.write('Use @{tf.TestModule.TestClass.ChildClass} to test things.\n' + "`tf.whatever` doesn't exist") + + with open(test_path3, 'w') as f: + file3_content = ( + 'Not a .md file. Should be copied unchanged:' + '@{tf.TestModule.TestClass.ChildClass}, `tf.test_function`') + f.write(file3_content) + + with open(test_path4, 'w') as f: + f.write('') + + reference_resolver, _ = self.get_test_objects() + generate_lib.replace_refs(test_in_dir, test_out_dir, reference_resolver, + '*.md') + + with open(os.path.join(test_out_dir, 'a/file1.md')) as f: + content = f.read() + self.assertEqual( + content, + 'Use ' + 'tf.test_function to test things.') + + with open(os.path.join(test_out_dir, 'b/file2.md')) as f: + content = f.read() + self.assertEqual( + content, + 'Use ' + '' + 'tf.TestModule.TestClass.ChildClass ' + 'to test things.\n' + '`tf.whatever` doesn\'t exist') + + with open(os.path.join(test_out_dir, 'b/file3.notmd')) as f: + content = f.read() + self.assertEqual(content, file3_content) + + with self.assertRaises(IOError): + # This should fail. The OWNERS file should not be copied + with open(os.path.join(test_out_dir, 'b/OWNERS')) as f: + content = f.read() + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index fb0bd2c2ff438aa9b3fa04719c447a2f3a91a95e..ffb93027ed48dd2106c702758917c0846f20cb1c 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -21,15 +21,16 @@ from __future__ import print_function import ast import collections import functools +import itertools import json import os import re -import sys import astor import six from google.protobuf.message import Message as ProtoMessage +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect @@ -52,7 +53,7 @@ class _Errors(object): template = 'ERROR:\n output file name: %s\n %s\n\n' for full_name, message in self._errors: - print(template % (full_name, message), file=sys.stderr) + logging.warn(template, full_name, message) def append(self, full_name, message): """Add an error to the collection. @@ -614,6 +615,9 @@ def _parse_md_docstring(py_object, relative_path_to_root, reference_resolver): docstring, compatibility = _handle_compatibility(raw_docstring) docstring, function_details = _parse_function_details(docstring) + if 'Generated by: tensorflow/tools/api/generator' in docstring: + docstring = '' + return _DocstringInfo( docstring.split('\n')[0], docstring, function_details, compatibility) @@ -757,8 +761,9 @@ def _generate_signature(func, reverse_index): lookup_text = public_name + default_text[len(internal_name):] break if default_text is lookup_text: - print('WARNING: Using default arg, failed lookup: %s, repr: %r' % - (default_text, default)) + logging.warn( + 'WARNING: Using default arg, failed lookup: %s, repr: %r', + default_text, default) else: default_text = lookup_text else: @@ -906,6 +911,9 @@ class _FunctionPageInfo(object): def add_decorator(self, dec): self._decorators.append(dec) + def get_metadata_html(self): + return _Metadata(self.full_name).build_html() + class _ClassPageInfo(object): """Collects docs for a class page. @@ -1099,6 +1107,14 @@ class _ClassPageInfo(object): """Returns a list of `_LinkInfo` pointing to any nested classes.""" return self._classes + def get_metadata_html(self): + meta_data = _Metadata(self.full_name) + for item in itertools.chain(self.classes, self.properties, self.methods, + self.other_members): + meta_data.append(item) + + return meta_data.build_html() + def _add_class(self, short_name, full_name, obj, doc, url): """Adds a `_LinkInfo` for a nested class to `classes` list. @@ -1150,7 +1166,7 @@ class _ClassPageInfo(object): if short_name in [ '__class__', '__base__', '__weakref__', '__doc__', '__module__', '__dict__', '__abstractmethods__', '__slots__', '__getnewargs__', - '__str__', '__repr__', '__hash__' + '__str__', '__repr__', '__hash__', '__reduce__' ]: continue @@ -1198,8 +1214,6 @@ class _ClassPageInfo(object): if not child_doc.brief.strip() and short_name in [ '__del__', '__copy__' ]: - print('Skipping %s, defined in %s, no docstring.' % (child_name, - defining_class)) continue try: @@ -1330,6 +1344,16 @@ class _ModulePageInfo(object): self._other_members.append( _OtherMemberInfo(short_name, full_name, obj, doc)) + def get_metadata_html(self): + meta_data = _Metadata(self.full_name) + + # Objects with their own pages are not added to the matadata list for the + # module, the module only has a link to the object page. No docs. + for item in self.other_members: + meta_data.append(item) + + return meta_data.build_html() + def collect_docs_for_module(self, parser_config): """Collect information necessary specifically for a module's doc page. @@ -1346,7 +1370,8 @@ class _ModulePageInfo(object): for name in member_names: if name in ['__builtins__', '__doc__', '__file__', - '__name__', '__path__', '__package__']: + '__name__', '__path__', '__package__', + '__cached__', '__loader__', '__spec__']: continue member_full_name = self.full_name + '.' + name if self.full_name else name @@ -1575,7 +1600,8 @@ class _GeneratedFile(object): return True def __str__(self): - return 'Defined in `%s%s`.\n\n' % (self.path_prefix, self.path) + return 'Defined in generated file: `%s%s`.\n\n' % (self.path_prefix, + self.path) def _get_defined_in(py_object, parser_config): @@ -1612,6 +1638,8 @@ def _get_defined_in(py_object, parser_config): if re.match(r'.*/gen_[^/]*\.py$', path): return _GeneratedFile(path, parser_config) + if 'genfiles' in path or 'tools/api/generator' in path: + return _GeneratedFile(path, parser_config) elif re.match(r'.*_pb2\.py$', path): # The _pb2.py files all appear right next to their defining .proto file. return _ProtoFile(path[:-7] + '.proto', parser_config) @@ -1656,3 +1684,41 @@ def generate_global_index(library_name, index, reference_resolver): # TODO(markdaoust): use a _ModulePageInfo -> prety_docs.build_md_page() return '\n'.join(lines) + + +class _Metadata(object): + """A class for building a page's Metadata block. + + Attributes: + name: The name of the page being described by the Metadata block. + """ + + def __init__(self, name): + """Creates a Metadata builder. + + Args: + name: The name of the page being described by the Metadata block. + """ + self.name = name + self._content = [] + + def append(self, item): + """Adds an item from the page to the Metadata block. + + Args: + item: The parsed page section to add. + """ + self._content.append(item.short_name) + + def build_html(self): + """Returns the Metadata block as an Html string.""" + schema = 'http://developers.google.com/ReferenceObject' + parts = ['
' % schema] + + parts.append('' % self.name) + for item in self._content: + parts.append('' % item) + + parts.extend(['
', '']) + + return '\n'.join(parts) diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py index 55ab5bdd49a427e680221f4864b3f31a65b12e8d..63d4fef91cc752b8fa053b92c833349ca3bc8f19 100644 --- a/tensorflow/tools/docs/pretty_docs.py +++ b/tensorflow/tools/docs/pretty_docs.py @@ -27,7 +27,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools import textwrap @@ -58,8 +57,7 @@ def build_md_page(page_info): def _build_function_page(page_info): """Given a FunctionPageInfo object Return the page as an md string.""" - parts = [_Metadata(page_info.full_name).build_html()] - parts.append('# %s\n\n' % page_info.full_name) + parts = ['# %s\n\n' % page_info.full_name] if len(page_info.aliases) > 1: parts.append('### Aliases:\n\n') @@ -83,17 +81,7 @@ def _build_function_page(page_info): def _build_class_page(page_info): """Given a ClassPageInfo object Return the page as an md string.""" - meta_data = _Metadata(page_info.full_name) - for item in itertools.chain( - page_info.classes, - page_info.properties, - page_info.methods, - page_info.other_members): - meta_data.append(item) - - parts = [meta_data.build_html()] - - parts.append('# {page_info.full_name}\n\n'.format(page_info=page_info)) + parts = ['# {page_info.full_name}\n\n'.format(page_info=page_info)] parts.append('## Class `%s`\n\n' % page_info.full_name.split('.')[-1]) if page_info.bases: @@ -186,17 +174,7 @@ def _build_class_page(page_info): def _build_module_page(page_info): """Given a ClassPageInfo object Return the page as an md string.""" - meta_data = _Metadata(page_info.full_name) - - # Objects with their own pages are not added to the matadata list for the - # module, as the only thing on the module page is a link to the object's page. - for item in page_info.other_members: - meta_data.append(item) - - parts = [meta_data.build_html()] - - parts.append( - '# Module: {full_name}\n\n'.format(full_name=page_info.full_name)) + parts = ['# Module: {full_name}\n\n'.format(full_name=page_info.full_name)] if len(page_info.aliases) > 1: parts.append('### Aliases:\n\n') @@ -317,41 +295,3 @@ def _build_function_details(function_details): parts.append(''.join(sub)) return '\n'.join(parts) - - -class _Metadata(object): - """A class for building a page's Metadata block. - - Attributes: - name: The name of the page being described by the Metadata block. - """ - - def __init__(self, name): - """Create a Metadata builder. - - Args: - name: The name of the page being described by the Metadata block. - """ - self.name = name - self._content = [] - - def append(self, item): - """Add an item from the page to the Metadata block. - - Args: - item: The parsed page section to add. - """ - self._content.append(item.short_name) - - def build_html(self): - """Return the Metadata block as an Html string.""" - schema = 'http://developers.google.com/ReferenceObject' - parts = ['
' % schema] - - parts.append('' % self.name) - for item in self._content: - parts.append('' % item) - - parts.extend(['
', '', '']) - - return '\n'.join(parts) diff --git a/tensorflow/tools/docs/py_guide_parser.py b/tensorflow/tools/docs/py_guide_parser.py index 328f42d18f1efb0fd82725a4683abad2df0d5a19..b00694dc40322161f180410630bb4dcfd8c2fb18 100644 --- a/tensorflow/tools/docs/py_guide_parser.py +++ b/tensorflow/tools/docs/py_guide_parser.py @@ -44,7 +44,8 @@ class PyGuideParser(object): def process(self, full_path): """Read and process the file at `full_path`.""" - md_string = open(full_path, 'rb').read().decode('utf-8') + with open(full_path, 'rb') as f: + md_string = f.read().decode('utf-8') self._lines = md_string.split('\n') seen = set() diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py index 73dee98bae8946b747e1b28bd14b0a26edc62736..cc2288a7fa9202efcd077e54b941cc278b25993c 100755 --- a/tensorflow/tools/git/gen_git_source.py +++ b/tensorflow/tools/git/gen_git_source.py @@ -164,14 +164,17 @@ def get_git_version(git_base_path, git_tag_override): "git", str("--git-dir=%s/.git" % git_base_path), str("--work-tree=" + git_base_path), "describe", "--long", "--tags" ]).strip()) - if git_tag_override: + if git_tag_override and val: split_val = val.split("-") - if len(split_val) != 3: + if len(split_val) < 3: raise Exception( ("Expected git version in format 'TAG-COMMITS AFTER TAG-HASH' " "but got '%s'") % val) - split_val[0] = git_tag_override - val = bytes("-".join(split_val)) + # There might be "-" in the tag name. But we can be sure that the final + # two "-" are those inserted by the git describe command. + abbrev_commit = split_val[-1] + val = bytes( + "-".join([git_tag_override, "0", abbrev_commit])) return val if val else unknown_label except (subprocess.CalledProcessError, OSError): return unknown_label diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc index 85660f94a85dce29360525f7bb7474494b3f010f..f85841187670fef0fdc9237886237f84057d6bd5 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc @@ -117,6 +117,31 @@ Status ReplaceSendRecvs(const GraphDef& original_graph_def, return Status::OK(); } +Status RewriteInputsAsPlaceholders(const TransformFuncContext& context, + GraphDef* graph_def) { + std::unordered_set input_names; + for (const string& input_name : context.input_names) { + input_names.insert(ParseTensorName(input_name).first.ToString()); + } + + for (NodeDef& node : *graph_def->mutable_node()) { + if (input_names.find(node.name()) == input_names.end()) { + continue; + } + if (node.op() == "PlaceholderWithDefault") { + node.set_op("Placeholder"); + node.clear_input(); + } else if (node.op() != "Placeholder") { + return errors::InvalidArgument( + "Input '", node.name(), + "' was expected to be a Placeholder or PlaceholderWithDefault op, " + "but was ", + node.op()); + } + } + return Status::OK(); +} + Status RemoveUnusedNodes(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def) { @@ -165,6 +190,7 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def, input_graph_def, [&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; }, output_graph_def); + TF_RETURN_IF_ERROR(RewriteInputsAsPlaceholders(context, output_graph_def)); return Status::OK(); } diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index a082399a87dbaad913be421fe273ba89b6f7340e..dcdc3c29069c212c499aa21e420b47f239ce62f2 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -330,48 +330,6 @@ class ConstantFoldingTest : public ::testing::Test { EXPECT_EQ(0, node_map.count("unused")); } - void TestRemoveUnusedNodesMultipleOutputs() { - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - auto root = tensorflow::Scope::NewRootScope(); - - // a b - // \ / - // shape_n - // \ / - // c - auto a = Placeholder(root.WithOpName("a"), DT_FLOAT); - auto b = Placeholder(root.WithOpName("b"), DT_FLOAT); - auto shape_n = ShapeN(root.WithOpName("shape_n"), {Output(a), Output(b)}); - auto c = Add(root.WithOpName("c"), shape_n[0], shape_n[1]); - - GraphDef graph_def; - TF_ASSERT_OK(root.ToGraphDef(&graph_def)); - GraphDef result_graph_def; - TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes( - graph_def, {{shape_n[0].name()}, {"c"}}, &result_graph_def)); - - // Only one output of shape_n node is fed input. Hence the graph search - // should propagate to inputs of shape_n. Nothing to remove here. - std::map node_map; - graph_transforms::MapNamesToNodes(result_graph_def, &node_map); - EXPECT_EQ(1, node_map.count("a")); - EXPECT_EQ(1, node_map.count("b")); - EXPECT_EQ(1, node_map.count("c")); - - result_graph_def.Clear(); - TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes( - graph_def, {{shape_n[0].name(), shape_n[1].name()}, {"c"}}, - &result_graph_def)); - - // Both outputs of shape_n node are fed inputs. shape_n does not function - // and inputs to shape_n should be removed. - node_map.clear(); - graph_transforms::MapNamesToNodes(result_graph_def, &node_map); - EXPECT_EQ(0, node_map.count("a")); - EXPECT_EQ(0, node_map.count("b")); - EXPECT_EQ(1, node_map.count("c")); - } - void TestMaxConstantSizeInBytes() { auto root = tensorflow::Scope::NewRootScope(); @@ -431,10 +389,6 @@ TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) { TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); } -TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) { - TestRemoveUnusedNodesMultipleOutputs(); -} - TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) { TestMaxConstantSizeInBytes(); } diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc index 7651a03fe51012678d6d6fc495fd82e497aa512b..435f46c107cd9b0a6d64d4c0d52607ec5f41eb4f 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc @@ -191,7 +191,7 @@ class FoldOldBatchNormsTest : public ::testing::Test { std::vector fused_outputs; TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs)); - test::ExpectTensorNear(original_outputs[0], fused_outputs[0], 1e-5); + test::ExpectTensorNear(original_outputs[0], fused_outputs[0], 2e-5); for (const NodeDef& node : fused_graph_def.node()) { EXPECT_NE("FusedBatchNorm", node.op()); diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 77f83b77a0214110e520c85d15ffa38bce65955f..173f418dc8d998bc51d208a04c8671bacf364cdc 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -115,6 +115,7 @@ genrule( "//third_party/fft2d:LICENSE", "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", @@ -130,7 +131,7 @@ genrule( "@highwayhash//:LICENSE", "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", - "@libxsmm_archive//:LICENSE", + "@libxsmm_archive//:LICENSE.md", "@llvm//:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", @@ -156,6 +157,7 @@ genrule( "//third_party/fft2d:LICENSE", "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", @@ -168,7 +170,7 @@ genrule( "@highwayhash//:LICENSE", "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", - "@libxsmm_archive//:LICENSE", + "@libxsmm_archive//:LICENSE.md", "@llvm//:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 677ea65edd91df9eef2347ab305f47a05f6cedaa..c9d53f46c3cff9eceb6eb03a872d05e8afd06047 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -57,14 +57,18 @@ COMMON_PIP_DEPS = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/autograph:autograph", "//tensorflow/contrib/autograph/converters:converters", - "//tensorflow/contrib/autograph/converters:test_lib", + "//tensorflow/contrib/autograph/core:core", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/impl:impl", + "//tensorflow/contrib/autograph/lang:lang", + "//tensorflow/contrib/autograph/operators:operators", "//tensorflow/contrib/autograph/pyct:pyct", "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis", + "//tensorflow/contrib/autograph/pyct/common_transformers:common_transformers", "//tensorflow/contrib/boosted_trees:boosted_trees_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/constrained_optimization:constrained_optimization_pip", - "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", + "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base", "//tensorflow/contrib/data/python/ops:contrib_op_loader", "//tensorflow/contrib/eager/python/examples:examples_pip", "//tensorflow/contrib/eager/python:evaluator", @@ -90,6 +94,7 @@ COMMON_PIP_DEPS = [ "//tensorflow/contrib/timeseries:timeseries_pip", "//tensorflow/contrib/tpu", "//tensorflow/examples/tutorials/mnist:package", + "//tensorflow/python:cond_v2", "//tensorflow/python:distributed_framework_test_lib", "//tensorflow/python:meta_graph_testdata", "//tensorflow/python:spectral_ops_test_util", @@ -125,6 +130,8 @@ filegroup( "@astor_archive//:LICENSE", "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_github_googleapis_googleapis//:LICENSE", + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", "@com_google_absl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", @@ -142,7 +149,7 @@ filegroup( "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", "@kafka//:LICENSE", - "@libxsmm_archive//:LICENSE", + "@libxsmm_archive//:LICENSE.md", "@lmdb//:LICENSE", "@local_config_nccl//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", @@ -173,9 +180,7 @@ sh_binary( "//conditions:default": COMMON_PIP_DEPS + [ ":simple_console", "//tensorflow/contrib/lite/python:interpreter_test_data", - "//tensorflow/contrib/lite/python:tf_lite_py_pip", - "//tensorflow/contrib/lite/toco:toco", - "//tensorflow/contrib/lite/toco/python:toco_wrapper", + "//tensorflow/contrib/lite/python:tflite_convert", "//tensorflow/contrib/lite/toco/python:toco_from_protos", ], }) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([ diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index 1a83c6e7578fed88f0bd7db5a5b620a5281fd95a..9e41514cfa1a70d649eab6fd23a599db4afae2a8 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -24,9 +24,15 @@ function real_path() { function cp_external() { local src_dir=$1 local dest_dir=$2 - for f in `find "$src_dir" -maxdepth 1 -mindepth 1 ! -name '*local_config_cuda*' ! -name '*local_config_tensorrt*' ! -name '*org_tensorflow*'`; do - cp -R "$f" "$dest_dir" + + pushd . + cd "$src_dir" + for f in `find . ! -type d ! -name '*.py' ! -name '*local_config_cuda*' ! -name '*local_config_tensorrt*' ! -name '*org_tensorflow*'`; do + mkdir -p "${dest_dir}/$(dirname ${f})" + cp "${f}" "${dest_dir}/$(dirname ${f})/" done + popd + mkdir -p "${dest_dir}/local_config_cuda/cuda/cuda/" cp "${src_dir}/local_config_cuda/cuda/cuda/cuda_config.h" "${dest_dir}/local_config_cuda/cuda/cuda/" } @@ -41,51 +47,17 @@ function is_windows() { fi } -function main() { +function prepare_src() { if [ $# -lt 1 ] ; then echo "No destination dir provided" exit 1 fi - DEST=$(real_path $1) - TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) + TMPDIR="$1" + mkdir -p "$TMPDIR" + EXTERNAL_INCLUDES="${TMPDIR}/tensorflow/include/external" - PKG_NAME_FLAG="" - GPU_BUILD=0 - NIGHTLY_BUILD=0 - PROJECT_NAME="" - while true; do - if [[ "$1" == "--nightly_flag" ]]; then - NIGHTLY_BUILD=1 - elif [[ "$1" == "--gpu" ]]; then - GPU_BUILD=1 - elif [[ "$1" == "--gpudirect" ]]; then - PKG_NAME_FLAG="--project_name tensorflow_gpudirect" - elif [[ "$1" == "--project_name" ]]; then - shift - if [[ -z "$1" ]]; then - break - fi - PROJECT_NAME="$1" - fi - shift - - if [[ -z "$1" ]]; then - break - fi - done - - if [[ -n ${PROJECT_NAME} ]]; then - PKG_NAME_FLAG="--project_name ${PROJECT_NAME}" - elif [[ ${NIGHTLY_BUILD} == "1" && ${GPU_BUILD} == "1" ]]; then - PKG_NAME_FLAG="--project_name tf_nightly_gpu" - elif [[ ${NIGHTLY_BUILD} == "1" ]]; then - PKG_NAME_FLAG="--project_name tf_nightly" - elif [[ ${GPU_BUILD} == "1" ]]; then - PKG_NAME_FLAG="--project_name tensorflow_gpu" - fi - - echo $(date) : "=== Using tmpdir: ${TMPDIR}" + echo $(date) : "=== Preparing sources in dir: ${TMPDIR}" if [ ! -d bazel-bin/tensorflow ]; then echo "Could not find bazel-bin. Did you run from the root of the build tree?" @@ -102,10 +74,9 @@ function main() { cp -R \ bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles/org_tensorflow/tensorflow \ "${TMPDIR}" - mkdir "${TMPDIR}/external" cp_external \ bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles \ - "${TMPDIR}/external" + "${EXTERNAL_INCLUDES}/" RUNFILES=bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles/org_tensorflow else RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow @@ -114,10 +85,9 @@ function main() { cp -R \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow \ "${TMPDIR}" - mkdir "${TMPDIR}/external" cp_external \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external \ - "${TMPDIR}/external" + "${EXTERNAL_INCLUDES}" # Copy MKL libs over so they can be loaded at runtime so_lib_dir=$(ls $RUNFILES | grep solib) || true if [ -n "${so_lib_dir}" ]; then @@ -132,10 +102,9 @@ function main() { cp -R \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow \ "${TMPDIR}" - mkdir "${TMPDIR}/external" cp_external \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles \ - "${TMPDIR}/external" + "${EXTERNAL_INCLUDES}" # Copy MKL libs over so they can be loaded at runtime so_lib_dir=$(ls $RUNFILES | grep solib) || true if [ -n "${so_lib_dir}" ]; then @@ -148,26 +117,35 @@ function main() { fi mkdir "${TMPDIR}/tensorflow/aux-bin" # Install toco as a binary in aux-bin. - # TODO(aselle): Re-enable this when we find a way to do it without doubling - # the whl size (over the limit). - # cp bazel-bin/tensorflow/contrib/lite/toco/toco ${TMPDIR}/tensorflow/aux-bin/ + cp bazel-bin/tensorflow/contrib/lite/python/tflite_convert ${TMPDIR}/tensorflow/aux-bin/ fi # protobuf pip package doesn't ship with header files. Copy the headers # over so user defined ops can be compiled. mkdir -p ${TMPDIR}/google mkdir -p ${TMPDIR}/third_party - pushd ${RUNFILES%org_tensorflow} + pushd ${RUNFILES%org_tensorflow} > /dev/null for header in $(find protobuf_archive -name \*.h); do mkdir -p "${TMPDIR}/google/$(dirname ${header})" cp "$header" "${TMPDIR}/google/$(dirname ${header})/" done - popd + popd > /dev/null cp -R $RUNFILES/third_party/eigen3 ${TMPDIR}/third_party cp tensorflow/tools/pip_package/MANIFEST.in ${TMPDIR} cp tensorflow/tools/pip_package/README ${TMPDIR} cp tensorflow/tools/pip_package/setup.py ${TMPDIR} +} + +function build_wheel() { + if [ $# -lt 2 ] ; then + echo "No src and dest dir provided" + exit 1 + fi + + TMPDIR="$1" + DEST="$2" + PKG_NAME_FLAG="$3" # Before we leave the top-level directory, make sure we know how to # call python. @@ -175,15 +153,110 @@ function main() { source tools/python_bin_path.sh fi - pushd ${TMPDIR} + pushd ${TMPDIR} > /dev/null rm -f MANIFEST echo $(date) : "=== Building wheel" "${PYTHON_BIN_PATH:-python}" setup.py bdist_wheel ${PKG_NAME_FLAG} >/dev/null mkdir -p ${DEST} cp dist/* ${DEST} - popd - rm -rf ${TMPDIR} + popd > /dev/null echo $(date) : "=== Output wheel file is in: ${DEST}" } +function usage() { + echo "Usage:" + echo "$0 [--src srcdir] [--dst dstdir] [options]" + echo "$0 dstdir [options]" + echo "" + echo " --src prepare sources in srcdir" + echo " will use temporary dir if not specified" + echo "" + echo " --dst build wheel in dstdir" + echo " if dstdir is not set do not build, only prepare sources" + echo "" + echo " Options:" + echo " --project_name set project name to name" + echo " --gpu build tensorflow_gpu" + echo " --gpudirect build tensorflow_gpudirect" + echo " --nightly_flag build tensorflow nightly" + echo "" + exit 1 +} + +function main() { + PKG_NAME_FLAG="" + PROJECT_NAME="" + GPU_BUILD=0 + NIGHTLY_BUILD=0 + SRCDIR="" + DSTDIR="" + CLEANSRC=1 + while true; do + if [[ "$1" == "--help" ]]; then + usage + exit 1 + elif [[ "$1" == "--nightly_flag" ]]; then + NIGHTLY_BUILD=1 + elif [[ "$1" == "--gpu" ]]; then + GPU_BUILD=1 + elif [[ "$1" == "--gpudirect" ]]; then + PKG_NAME_FLAG="--project_name tensorflow_gpudirect" + elif [[ "$1" == "--project_name" ]]; then + shift + if [[ -z "$1" ]]; then + break + fi + PROJECT_NAME="$1" + elif [[ "$1" == "--src" ]]; then + shift + SRCDIR="$(real_path $1)" + CLEANSRC=0 + elif [[ "$1" == "--dst" ]]; then + shift + DSTDIR="$(real_path $1)" + else + DSTDIR="$(real_path $1)" + fi + shift + + if [[ -z "$1" ]]; then + break + fi + done + + if [[ -z "$DSTDIR" ]] && [[ -z "$SRCDIR" ]]; then + echo "No destination dir provided" + usage + exit 1 + fi + + if [[ -z "$SRCDIR" ]]; then + # make temp srcdir if none set + SRCDIR="$(mktemp -d -t tmp.XXXXXXXXXX)" + fi + + prepare_src "$SRCDIR" + + if [[ -z "$DSTDIR" ]]; then + # only want to prepare sources + exit + fi + + if [[ -n ${PROJECT_NAME} ]]; then + PKG_NAME_FLAG="--project_name ${PROJECT_NAME}" + elif [[ ${NIGHTLY_BUILD} == "1" && ${GPU_BUILD} == "1" ]]; then + PKG_NAME_FLAG="--project_name tf_nightly_gpu" + elif [[ ${NIGHTLY_BUILD} == "1" ]]; then + PKG_NAME_FLAG="--project_name tf_nightly" + elif [[ ${GPU_BUILD} == "1" ]]; then + PKG_NAME_FLAG="--project_name tensorflow_gpu" + fi + + build_wheel "$SRCDIR" "$DSTDIR" "$PKG_NAME_FLAG" + + if [[ $CLEANSRC -ne 0 ]]; then + rm -rf "${TMPDIR}" + fi +} + main "$@" diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 319878e1b5ae9ff9d72132e2421062f6ca26197a..c630ca04b885d35da6550d4e5f3e6912b5fd7a00 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -12,6 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""TensorFlow is an open source machine learning framework for everyone. + +TensorFlow is an open source software library for high performance numerical +computation. Its flexible architecture allows easy deployment of computation +across a variety of platforms (CPUs, GPUs, TPUs), and from desktops to clusters +of servers to mobile and edge devices. + +Originally developed by researchers and engineers from the Google Brain team +within Google's AI organization, it comes with strong support for machine +learning and deep learning and the flexible numerical computation core is used +across many other scientific domains. +""" from __future__ import absolute_import from __future__ import division @@ -28,25 +40,12 @@ from setuptools import setup from setuptools.command.install import install as InstallCommandBase from setuptools.dist import Distribution +DOCLINES = __doc__.split('\n') + # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. -_VERSION = '1.8.0' - -_SHORT_DESCRIPTION = ('TensorFlow is an open source machine learning framework ' - 'for everyone.') - -_LONG_DESCRIPTION = ('TensorFlow is an open source software library for high ' - 'performance numerical computation. Its flexible ' - 'architecture allows easy deployment of computation across' - ' a variety of platforms (CPUs, GPUs, TPUs), and from ' - 'desktops to clusters of servers to mobile and edge ' - 'devices. Originally developed by researchers and ' - 'engineers from the Google Brain team within Google\'s AI ' - 'organization, it comes with strong support for machine ' - 'learning and deep learning and the flexible numerical ' - 'computation core is used across many other scientific ' - 'domains.') +_VERSION = '1.9.0-rc0' REQUIRED_PACKAGES = [ 'absl-py >= 0.1.6', @@ -54,7 +53,8 @@ REQUIRED_PACKAGES = [ 'gast >= 0.2.0', 'numpy >= 1.13.3', 'six >= 1.10.0', - 'protobuf >= 3.4.0', + 'protobuf >= 3.6.0', + 'setuptools <= 39.1.0', 'tensorboard >= 1.8.0, < 1.9.0', 'termcolor >= 1.1.0', ] @@ -84,7 +84,7 @@ else: if 'tf_nightly' in project_name: for i, pkg in enumerate(REQUIRED_PACKAGES): if 'tensorboard' in pkg: - REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.9.0a0, < 1.10.0a0' + REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.10.0a0, < 1.11.0a0' break # weakref.finalize and enum were introduced in Python 3.4 @@ -96,7 +96,8 @@ if sys.version_info < (3, 4): CONSOLE_SCRIPTS = [ 'freeze_graph = tensorflow.python.tools.freeze_graph:run_main', 'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main', - 'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main', + 'tflite_convert = tensorflow.contrib.lite.python.tflite_convert:main', + 'toco = tensorflow.contrib.lite.python.tflite_convert:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', # We need to keep the TensorBoard command, even though the console script # is now declared by the tensorboard pip package. If we remove the @@ -169,8 +170,9 @@ class InstallHeaders(Command): # symlink within the directory hierarchy. # NOTE(keveman): Figure out how to customize bdist_wheel package so # we can do the symlink. - if 'external/eigen_archive/' in install_dir: - extra_dir = install_dir.replace('external/eigen_archive', '') + if 'tensorflow/include/external/eigen_archive/' in install_dir: + extra_dir = install_dir.replace( + 'tensorflow/include/external/eigen_archive', '') if not os.path.exists(extra_dir): self.mkpath(extra_dir) self.copy_file(header, extra_dir) @@ -203,13 +205,12 @@ def find_files(pattern, root): yield os.path.join(dirpath, filename) -matches = ['../' + x for x in find_files('*', 'external') if '.py' not in x] - so_lib_paths = [ i for i in os.listdir('.') if os.path.isdir(i) and fnmatch.fnmatch(i, '_solib_*') ] +matches = [] for path in so_lib_paths: matches.extend( ['../' + x for x in find_files('*', path) if '.py' not in x] @@ -224,14 +225,15 @@ headers = (list(find_files('*.h', 'tensorflow/core')) + list(find_files('*.h', 'tensorflow/stream_executor')) + list(find_files('*.h', 'google/protobuf_archive/src')) + list(find_files('*', 'third_party/eigen3')) + - list(find_files('*', 'external/eigen_archive'))) + list(find_files('*', 'tensorflow/include/external/eigen_archive'))) setup( name=project_name, version=_VERSION.replace('-', ''), - description=_SHORT_DESCRIPTION, - long_description=_LONG_DESCRIPTION, + description=DOCLINES[0], + long_description='\n'.join(DOCLINES[2:]), url='https://www.tensorflow.org/', + download_url='https://github.com/tensorflow/tensorflow/tags', author='Google Inc.', author_email='opensource@google.com', # Contained modules and scripts. @@ -257,7 +259,7 @@ setup( }, # PyPI package information. classifiers=[ - 'Development Status :: 4 - Beta', + 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc index 62e29b5128f3a2c9a22eadb02df6162fce352c60..15d7c702819ddec256b779f41b8745633d4a7769 100644 --- a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc +++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc @@ -279,8 +279,13 @@ void Generator::AppendFieldValueAppend(const FieldDescriptor& field, if (omit_default) { Print("if (", field_expr, " != 0) {").Nest(); } - Print("o->AppendEnumName(\"", field.name(), "\", ", - GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, "));"); + Print("const char* enum_name = ", + GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, ");"); + Print("if (enum_name[0]) {").Nest(); + Print("o->AppendEnumName(\"", field.name(), "\", enum_name);"); + Unnest().Print("} else {").Nest(); + Print("o->AppendNumeric(\"", field.name(), "\", ", field_expr, ");"); + Unnest().Print("}"); if (omit_default) { Unnest().Print("}"); } @@ -540,18 +545,24 @@ void Generator::AppendParseMessageFunction(const Descriptor& md) { for (int enum_i = 0; enum_i < enum_d->value_count(); ++enum_i) { const auto* value_d = enum_d->value(enum_i); const string& value_name = value_d->name(); - string condition = StrCat("value == \"", value_name, - "\" || value == \"", value_d->number(), "\""); - if (value_d->number() == 0) { - StrAppend(&condition, " || value == \"-0\""); - } + string condition = StrCat("value == \"", value_name, "\""); Print(enum_i == 0 ? "" : "} else ", "if (", condition, ") {"); Nest(); Print(set_value_prefix, "(", value_prefix, value_name, ");"); Unnest(); } + Print("} else {"); + Nest(); + // Proto3 allows all numeric values. + Print("int32 int_value;"); + Print("if (strings::SafeStringToNumeric(value, &int_value)) {"); + Nest(); + Print(set_value_prefix, "(static_cast<", GetQualifiedName(*enum_d), + ">(int_value));"); + Unnest(); Print("} else {").Nest().Print("return false;").Unnest().Print("}"); + Unnest().Print("}"); } else { Print(field->cpp_type_name(), " value;"); switch (field->cpp_type()) { @@ -803,6 +814,9 @@ void Generator::Generate(const FileDescriptor& fd) { // Add header to cc file. SetOutput(&cc_); Print("// GENERATED FILE - DO NOT MODIFY"); + Print(); + Print("#include "); // for `std::stable_sort()` + Print(); headers = {GetProtoTextHeaderName(fd, true /* impl */)}; AddHeadersToCurrentSection(headers); Print(); diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc index 6f0b4f47de6464aa0f0648f3b0a2fac1e7d3c7cc..e67add72de660b9c8dd566b6db978a8dc489c749 100644 --- a/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc +++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc @@ -455,7 +455,10 @@ TEST(CreateProtoDebugStringLibTest, Enums) { "repeated_nested_enum: 1")); EXPECT_PARSE_SUCCESS("", "optional_nested_enum: -0"); - EXPECT_PARSE_FAILURE("optional_nested_enum: 6"); + // TODO(amauryfa): restore the line below when protobuf::TextFormat also + // supports unknonwn enum values. + // EXPECT_PARSE_SUCCESS("optional_nested_enum: 6", "optional_nested_enum: 6"); + EXPECT_PARSE_FAILURE("optional_nested_enum: 2147483648"); // > INT32_MAX EXPECT_PARSE_FAILURE("optional_nested_enum: BARNONE"); EXPECT_PARSE_FAILURE("optional_nested_enum: 'BAR'"); EXPECT_PARSE_FAILURE("optional_nested_enum: \"BAR\" "); diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py index df71840b64db3a1a451ec74b12d039a412976666..92bb5127dacf316c62cd64b3874b283309deffd5 100644 --- a/tensorflow/tools/quantization/quantize_graph_test.py +++ b/tensorflow/tools/quantization/quantize_graph_test.py @@ -119,8 +119,8 @@ def are_tensors_near(a, b, tolerance): flat_a = a.flatten() flat_b = b.flatten() if len(flat_a) != len(flat_b): - print("Tensors are different sizes: " + str(len(flat_a)) + " vs " + str( - len(flat_b))) + tf_logging.info("Tensors are different sizes: " + str(len(flat_a)) + " vs " + + str(len(flat_b))) return False value_count = len(flat_a) how_many_different = 0 @@ -140,10 +140,10 @@ def are_tensors_near(a, b, tolerance): if how_many_different == 0: return True else: - print("Tensors have {0} different values ({1}%), with mean difference" - " {2} and mean absolute difference {3}".format( - how_many_different, proportion_different * 100, mean_difference, - mean_abs_difference)) + tf_logging.info("Tensors have {0} different values ({1}%), with mean" + " difference {2} and mean absolute difference {3}".format( + how_many_different, proportion_different * 100, + mean_difference, mean_abs_difference)) return False diff --git a/tensorflow/tools/test/upload_test_benchmarks.py b/tensorflow/tools/test/upload_test_benchmarks.py index 9c45359ee1b037ffb01820f874b88b6cabc6d14b..c0305751092c3ed1916f671bd515cb4253f5ada2 100644 --- a/tensorflow/tools/test/upload_test_benchmarks.py +++ b/tensorflow/tools/test/upload_test_benchmarks.py @@ -89,7 +89,6 @@ import shutil from six import text_type from google.cloud import datastore -from six import text_type def is_real_file(dirpath, fname): diff --git a/tensorflow/user_ops/BUILD b/tensorflow/user_ops/BUILD deleted file mode 100644 index 71443cc41eb5ecdd23e1a47712633c77fcd7d395..0000000000000000000000000000000000000000 --- a/tensorflow/user_ops/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -# Description: -# An example for custom op and kernel defined as a TensorFlow plugin. - -package( - default_visibility = ["//tensorflow:internal"], -) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow:tensorflow.bzl", "tf_py_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") - -tf_custom_op_library( - name = "ackermann_op.so", - srcs = ["ackermann_op.cc"], -) - -tf_py_test( - name = "ackermann_test", - size = "small", - srcs = ["ackermann_test.py"], - additional_deps = ["//tensorflow:tensorflow_py"], - data = [":ackermann_op.so"], -) - -tf_custom_op_library( - name = "duplicate_op.so", - srcs = ["duplicate_op.cc"], -) - -tf_py_test( - name = "duplicate_op_test", - size = "small", - srcs = ["duplicate_op_test.py"], - additional_deps = ["//tensorflow:tensorflow_py"], - data = [":duplicate_op.so"], -) - -tf_custom_op_library( - name = "invalid_op.so", - srcs = ["invalid_op.cc"], -) - -tf_py_test( - name = "invalid_op_test", - size = "small", - srcs = ["invalid_op_test.py"], - additional_deps = ["//tensorflow:tensorflow_py"], - data = [":invalid_op.so"], -) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index aa50c0b7f7b96720e1155a1e642586b4bbdc1eb4..4982cc26db3d33d4a126a6b4dd22430a2ca37eb5 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -50,31 +50,31 @@ def tf_workspace(path_prefix="", tf_repo_name=""): mkl_repository( name = "mkl_linux", urls = [ - "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.13/mklml_lnx_2018.0.2.20180127.tgz", - "https://github.com/intel/mkl-dnn/releases/download/v0.13/mklml_lnx_2018.0.2.20180127.tgz", + "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_lnx_2018.0.3.20180406.tgz", + "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_lnx_2018.0.3.20180406.tgz" ], - sha256 = "74844bd77294742bf2396ff040369d1aa4cdd9e826fcd38cf8398ae83564d146", - strip_prefix = "mklml_lnx_2018.0.2.20180127", + sha256 = "d2305244fdc9b87db7426ed4496e87a4b3977ad3374d73b8000e8b7a5b7aa725", + strip_prefix = "mklml_lnx_2018.0.3.20180406", build_file = clean_dep("//third_party/mkl:mkl.BUILD") ) mkl_repository( name = "mkl_windows", urls = [ - "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.13/mklml_win_2018.0.2.20180127.zip", - "https://github.com/intel/mkl-dnn/releases/download/v0.13/mklml_win_2018.0.2.20180127.zip" + "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_win_2018.0.3.20180406.zip", + "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_win_2018.0.3.20180406.zip" ], - sha256 = "d8fbf0faa0684bffa3548005d05fe5cfe56ff9dbc0e15e7612d7ac01055a6ded", - strip_prefix = "mklml_win_2018.0.2.20180127", + sha256 = "a584a5bf1c8d2ad70b90d12b52652030e9a338217719064fdb84b7ad0d693694", + strip_prefix = "mklml_win_2018.0.3.20180406", build_file = clean_dep("//third_party/mkl:mkl.BUILD") ) mkl_repository( name = "mkl_darwin", urls = [ - "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.13/mklml_mac_2018.0.2.20180127.tgz", - "https://github.com/intel/mkl-dnn/releases/download/v0.13/mklml_mac_2018.0.2.20180127.tgz" + "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_mac_2018.0.3.20180406.tgz", + "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_mac_2018.0.3.20180406.tgz" ], - sha256 = "aa740d71e14562bfea56e6829e6dc186e7487cbcf6748a88dec73826b7ec1943", - strip_prefix = "mklml_mac_2018.0.2.20180127", + sha256 = "094e3dfd61c816136dc8d12a45cc611ce26c5f4828176a3644cd0b0efa15a25b", + strip_prefix = "mklml_mac_2018.0.3.20180406", build_file = clean_dep("//third_party/mkl:mkl.BUILD") ) @@ -85,11 +85,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "mkl_dnn", urls = [ - "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/v0.13.tar.gz", - "https://github.com/intel/mkl-dnn/archive/v0.13.tar.gz", + "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/v0.14.tar.gz", + "https://github.com/intel/mkl-dnn/archive/v0.14.tar.gz", ], - sha256 = "d2cfd93a70cfe86ebe054477c530c9b5c1218b70f75856eb6d1956c68ee89e8f", - strip_prefix = "mkl-dnn-0.13", + sha256 = "efebc53882856afec86457a2da644693f5d59c68772d41d640d6b60a8efc4eb0", + strip_prefix = "mkl-dnn-0.14", build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"), ) @@ -107,13 +107,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "eigen_archive", urls = [ - "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/6913f0cf7d06.tar.gz", - "https://bitbucket.org/eigen/eigen/get/6913f0cf7d06.tar.gz", + "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz", + "https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz", ], - sha256 = "791b836cacd03e20bae5bdd25f1c4a5505a0a9975ba94a61eb4e2631fbd1d53a", - strip_prefix = "eigen-eigen-6913f0cf7d06", + sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9", + strip_prefix = "eigen-eigen-fd6845384b86", build_file = clean_dep("//third_party:eigen.BUILD"), - patch_file = clean_dep("//third_party:eigen_fix_cuda_compilation.patch") ) tf_http_archive( @@ -132,11 +131,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "libxsmm_archive", urls = [ - "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz", - "https://github.com/hfp/libxsmm/archive/1.8.1.tar.gz", + "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.9.tar.gz", + "https://github.com/hfp/libxsmm/archive/1.9.tar.gz", ], - sha256 = "2ade869c3f42f23b5263c7d594aa3c7e5e61ac6a3afcaf5d6e42899d2a7986ce", - strip_prefix = "libxsmm-1.8.1", + sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa", + strip_prefix = "libxsmm-1.9", build_file = clean_dep("//third_party:libxsmm.BUILD"), ) @@ -156,19 +155,39 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "com_googlesource_code_re2", urls = [ - "https://mirror.bazel.build/github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz", - "https://github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz", + "https://mirror.bazel.build/github.com/google/re2/archive/2018-04-01.tar.gz", + "https://github.com/google/re2/archive/2018-04-01.tar.gz", ], - sha256 = "e57eeb837ac40b5be37b2c6197438766e73343ffb32368efea793dfd8b28653b", - strip_prefix = "re2-26cd968b735e227361c9703683266f01e5df7857", + sha256 = "2f945446b71336e7f5a2bcace1abcf0b23fbba368266c6a1be33de3de3b3c912", + strip_prefix = "re2-2018-04-01", + ) + + tf_http_archive( + name = "com_github_googlecloudplatform_google_cloud_cpp", + urls = [ + "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/53f822805e77ea7715f5b52c592a162c515c7219.tar.gz", + "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/53f822805e77ea7715f5b52c592a162c515c7219.tar.gz", + ], + sha256 = "06853bfca77ef4aec09db5ab48c548f68ef2e18f17404cbce61f8d9b820f951b", + strip_prefix = "google-cloud-cpp-53f822805e77ea7715f5b52c592a162c515c7219", + ) + + tf_http_archive( + name = "com_github_googleapis_googleapis", + urls = [ + "https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip", + "https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip", + ], + sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378", + strip_prefix="googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb", + build_file = clean_dep("//third_party:googleapis.BUILD"), ) tf_http_archive( name = "gemmlowp", urls = [ - # TODO (yongtang): uncomment once mirror.bazel.build is propagated. - # "https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", + "https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", "https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", ], sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658", @@ -202,6 +221,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): urls = [ "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.12.02.tar.bz2/d15843c3fb7db39af80571ee27ec6fad/nasm-2.12.02.tar.bz2", + "http://www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", ], sha256 = "00b0891c678c065446ca59bcee64719d0096d54d6886e6e472aeee2e170ae324", strip_prefix = "nasm-2.12.02", @@ -300,11 +320,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "absl_py", urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/ea8c4d2ddbf3fba610c4d613260561699b776db8.tar.gz", - "https://github.com/abseil/abseil-py/archive/ea8c4d2ddbf3fba610c4d613260561699b776db8.tar.gz", + "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz", + "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz", ], - sha256 = "c30b48e0d2580ef1412e55c5c0e1dab8db2ee4ab56e2075eccff29c90c7c7059", - strip_prefix = "abseil-py-ea8c4d2ddbf3fba610c4d613260561699b776db8", + sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c", + strip_prefix = "abseil-py-pypi-v0.2.2", ) tf_http_archive( @@ -317,7 +337,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "backports.weakref-1.0rc1/src", build_file = clean_dep("//third_party:backports_weakref.BUILD"), ) - + filegroup_external( name = "org_python_license", licenses = ["notice"], # Python 2.0 @@ -332,11 +352,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "protobuf_archive", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz", + "https://github.com/google/protobuf/archive/v3.6.0.tar.gz", ], - sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", - strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", + sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4", + strip_prefix = "protobuf-3.6.0", ) # We need to import the protobuf library under the names com_google_protobuf @@ -345,31 +365,31 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "com_google_protobuf", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz", + "https://github.com/google/protobuf/archive/v3.6.0.tar.gz", ], - sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", - strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", + sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4", + strip_prefix = "protobuf-3.6.0", ) tf_http_archive( name = "com_google_protobuf_cc", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz", + "https://github.com/google/protobuf/archive/v3.6.0.tar.gz", ], - sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", - strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", + sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4", + strip_prefix = "protobuf-3.6.0", ) tf_http_archive( name = "nsync", urls = [ - "https://mirror.bazel.build/github.com/google/nsync/archive/0559ce013feac8db639ee1bf776aca0325d28777.tar.gz", - "https://github.com/google/nsync/archive/0559ce013feac8db639ee1bf776aca0325d28777.tar.gz", + "https://mirror.bazel.build/github.com/google/nsync/archive/1.20.0.tar.gz", + "https://github.com/google/nsync/archive/1.20.0.tar.gz", ], - sha256 = "6284454c5cd8b1dae2eeb8cf5eb63004de930b5427ed5f6b1aa793513df6b361", - strip_prefix = "nsync-0559ce013feac8db639ee1bf776aca0325d28777", + sha256 = "0c1b03962b2f8450f21e74a5a46116bf2d6009a807c57eb4207e974a8c4bb7dd", + strip_prefix = "nsync-1.20.0", ) tf_http_archive( @@ -394,12 +414,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "pcre", - sha256 = "ccdf7e788769838f8285b3ee672ed573358202305ee361cfec7a4a4fb005bbc7", + sha256 = "69acbc2fbdefb955d42a4c606dfde800c2885711d2979e356c0636efde9ec3b5", urls = [ - "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", - "http://ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", + "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.42.tar.gz", + "http://ftp.exim.org/pub/pcre/pcre-8.42.tar.gz", ], - strip_prefix = "pcre-8.39", + strip_prefix = "pcre-8.42", build_file = clean_dep("//third_party:pcre.BUILD"), ) @@ -417,23 +437,23 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "curl", - sha256 = "ff3e80c1ca6a068428726cd7dd19037a47cc538ce58ef61c59587191039b2ca6", + sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5", urls = [ - "https://mirror.bazel.build/curl.haxx.se/download/curl-7.49.1.tar.gz", - "https://curl.haxx.se/download/curl-7.49.1.tar.gz", + "https://mirror.bazel.build/curl.haxx.se/download/curl-7.60.0.tar.gz", + "https://curl.haxx.se/download/curl-7.60.0.tar.gz", ], - strip_prefix = "curl-7.49.1", + strip_prefix = "curl-7.60.0", build_file = clean_dep("//third_party:curl.BUILD"), ) tf_http_archive( name = "grpc", urls = [ - "https://mirror.bazel.build/github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz", - "https://github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz", + "https://mirror.bazel.build/github.com/grpc/grpc/archive/v1.13.0.tar.gz", + "https://github.com/grpc/grpc/archive/v1.13.0.tar.gz", ], - sha256 = "895b31310e718a61f7335759a778c068a6edde1c089883598a0830cbb7075673", - strip_prefix = "grpc-d184fa229d75d336aedea0041bd59cb93e7e267f", + sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44", + strip_prefix = "grpc-1.13.0", ) @@ -453,33 +473,33 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/e17809bf50a4cdf3cec3b9dc5c9f79d9a45fc32f.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/e17809bf50a4cdf3cec3b9dc5c9f79d9a45fc32f.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/d5d94ca3a7f8526c2e4e5f663f9dc79ae5d39d93.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/d5d94ca3a7f8526c2e4e5f663f9dc79ae5d39d93.tar.gz", ], - sha256 = "1b75cb65517e41aaa70a95af55e45d08f37d0d44a192669b10d7b14b976dcc2a", - strip_prefix = "llvm-e17809bf50a4cdf3cec3b9dc5c9f79d9a45fc32f", - build_file = clean_dep("//third_party/llvm:llvm.BUILD"), + sha256 = "280fdc888e2eb88a3a8cc4e7d3034fffc87f98e3e686be31f8c719c6e5b67d2d", + strip_prefix = "llvm-d5d94ca3a7f8526c2e4e5f663f9dc79ae5d39d93", + build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"), ) tf_http_archive( name = "lmdb", urls = [ - "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", - "https://github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", + "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz", + "https://github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz", ], - sha256 = "108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326", - strip_prefix = "lmdb-LMDB_0.9.19/libraries/liblmdb", + sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28", + strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb", build_file = clean_dep("//third_party:lmdb.BUILD"), ) tf_http_archive( name = "jsoncpp_git", urls = [ - "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", - "https://github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", + "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz", + "https://github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz", ], - sha256 = "07d34db40593d257324ec5fb9debc4dc33f29f8fb44e33a2eeb35503e61d0fe2", - strip_prefix = "jsoncpp-11086dd6a7eba04289944367ca82cea71299ed70", + sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6", + strip_prefix = "jsoncpp-1.8.4", build_file = clean_dep("//third_party:jsoncpp.BUILD"), ) @@ -539,11 +559,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "kafka", urls = [ - "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz", - "https://github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz", + "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz", + "https://github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz", ], - sha256 = "dd035d57c8f19b0b612dd6eefe6e5eebad76f506e302cccb7c2066f25a83585e", - strip_prefix = "librdkafka-0.11.1", + sha256 = "9d8f1eb7b0e29e9ab1168347c939cb7ae5dff00a39cef99e7ef033fd8f92737c", + strip_prefix = "librdkafka-0.11.4", build_file = clean_dep("//third_party:kafka/BUILD"), patch_file = clean_dep("//third_party/kafka:config.patch"), ) @@ -629,6 +649,16 @@ def tf_workspace(path_prefix="", tf_repo_name=""): licenses = ["notice"], # Apache 2.0 ) + java_import_external( + name = "com_squareup_javapoet", + jar_sha256 = "5bb5abdfe4366c15c0da3332c57d484e238bd48260d6f9d6acf2b08fdde1efea", + jar_urls = [ + "http://mirror.bazel.build/repo1.maven.org/maven2/com/squareup/javapoet/1.9.0/javapoet-1.9.0.jar", + "http://repo1.maven.org/maven2/com/squareup/javapoet/1.9.0/javapoet-1.9.0.jar", + ], + licenses = ["notice"], # Apache 2.0 + ) + tf_http_archive( name = "com_google_pprof", urls = [ @@ -686,11 +716,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "flatbuffers", - strip_prefix = "flatbuffers-971a68110e4fc1bace10fcb6deeb189e7e1a34ce", - sha256 = "874088d2ee0d9f8524191f77209556415f03dd44e156276edf19e5b90ceb5f55", + strip_prefix = "flatbuffers-1.9.0", + sha256 = "5ca5491e4260cacae30f1a5786d109230db3f3a6e5a0eb45d0d0608293d247e3", urls = [ - "https://mirror.bazel.build/github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", - "https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", + "https://mirror.bazel.build/github.com/google/flatbuffers/archive/v1.9.0.tar.gz", + "https://github.com/google/flatbuffers/archive/v1.9.0.tar.gz", ], build_file = clean_dep("//third_party/flatbuffers:flatbuffers.BUILD"), ) @@ -724,6 +754,14 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), ) + tf_http_archive( + name = "tflite_mobilenet_ssd_quant", + sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc", + urls = ["https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip", + ], + build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), + ) tf_http_archive( name = "tflite_conv_actions_frozen", @@ -756,6 +794,16 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "ovic", ) + tf_http_archive( + name = "build_bazel_rules_android", + sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_android/archive/v0.1.1.zip", + "https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip", + ], + strip_prefix = "rules_android-0.1.1", + ) + ############################################################################## # BIND DEFINITIONS # @@ -780,10 +828,13 @@ def tf_workspace(path_prefix="", tf_repo_name=""): actual = "@grpc//:grpc_python_plugin", ) - # gRPC has three empty C++ functions which it wants the user to define - # at build time. https://github.com/grpc/grpc/issues/13590 native.bind( name = "grpc_lib", + actual = "@grpc//:grpc++", + ) + + native.bind( + name = "grpc_lib_unsecure", actual = "@grpc//:grpc++_unsecure", ) @@ -821,7 +872,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): # Needed by Protobuf native.bind( name = "python_headers", - actual = clean_dep("//util/python:python_headers"), + actual = clean_dep("//third_party/python_runtime:headers"), ) # Needed by Protobuf diff --git a/third_party/android/BUILD b/third_party/android/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/android/android.bzl.tpl b/third_party/android/android.bzl.tpl new file mode 100644 index 0000000000000000000000000000000000000000..e6ed4994f3ba6d721d717a04b0bd22f54dbb1d79 --- /dev/null +++ b/third_party/android/android.bzl.tpl @@ -0,0 +1,9 @@ +"""Set up configurable Android SDK and NDK dependencies.""" + +def android_workspace(): + # String for replacement in Bazel template. + # These will either be replaced by android_sdk_repository if various ENV + # variables are set when `local_config_android` repo_rule is run, or they + # will be replaced by noops otherwise. + MAYBE_ANDROID_SDK_REPOSITORY + MAYBE_ANDROID_NDK_REPOSITORY diff --git a/third_party/android/android_configure.BUILD.tpl b/third_party/android/android_configure.BUILD.tpl new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/android/android_configure.bzl b/third_party/android/android_configure.bzl new file mode 100644 index 0000000000000000000000000000000000000000..da09bdf39eed90b648ca8f47c79d16e3ec3804bb --- /dev/null +++ b/third_party/android/android_configure.bzl @@ -0,0 +1,87 @@ +"""Repository rule for Android SDK and NDK autoconfiguration. + +`android_configure` depends on the following environment variables: + + * `ANDROID_NDK_HOME`: Location of Android NDK root. + * `ANDROID_SDK_HOME`: Location of Android SDK root. + * `ANDROID_SDK_API_LEVEL`: Desired Android SDK API version. + * `ANDROID_NDK_API_LEVEL`: Desired Android NDK API version. + * `ANDROID_BUILD_TOOLS_VERSION`: Desired Android build tools version. +""" + +# TODO(mikecase): Move logic for getting default values for the env variables +# from configure.py script into this rule. + +_ANDROID_NDK_HOME = "ANDROID_NDK_HOME" +_ANDROID_SDK_HOME = "ANDROID_SDK_HOME" +_ANDROID_NDK_API_VERSION = "ANDROID_NDK_API_LEVEL" +_ANDROID_SDK_API_VERSION = "ANDROID_SDK_API_LEVEL" +_ANDROID_BUILD_TOOLS_VERSION = "ANDROID_BUILD_TOOLS_VERSION" + +_ANDROID_SDK_REPO_TEMPLATE = """ + native.android_sdk_repository( + name="androidsdk", + path="%s", + api_level=%s, + build_tools_version="%s", + ) +""" + +_ANDROID_NDK_REPO_TEMPLATE = """ + native.android_ndk_repository( + name="androidndk", + path="%s", + api_level=%s, + ) +""" + +def _android_autoconf_impl(repository_ctx): + """Implementation of the android_autoconf repository rule.""" + sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME) + sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION) + build_tools_version = repository_ctx.os.environ.get( + _ANDROID_BUILD_TOOLS_VERSION) + ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME) + ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION) + + sdk_rule = "pass" + if all([sdk_home, sdk_api_level, build_tools_version]): + sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % ( + sdk_home, sdk_api_level, build_tools_version) + + ndk_rule = "pass" + if all([ndk_home, ndk_api_level]): + ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level) + + repository_ctx.template( + "BUILD", + Label("//third_party/android:android_configure.BUILD.tpl")) + repository_ctx.template( + "android.bzl", + Label("//third_party/android:android.bzl.tpl"), + substitutions={ + "MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule, + "MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule, + }) + +android_configure = repository_rule( + implementation = _android_autoconf_impl, + environ = [ + _ANDROID_SDK_API_VERSION, + _ANDROID_NDK_API_VERSION, + _ANDROID_BUILD_TOOLS_VERSION, + _ANDROID_NDK_HOME, + _ANDROID_SDK_HOME, + ], +) +"""Writes Android SDK and NDK rules. + +Add the following to your WORKSPACE FILE: + +```python +android_configure(name = "local_config_android") +``` + +Args: + name: A unique name for this workspace rule. +""" diff --git a/third_party/aws.BUILD b/third_party/aws.BUILD index 2dc921933c310aa9ce2bf21798f1b5143386a12d..5426f79e4650a1ce4dcb4a8408691310c864f06c 100644 --- a/third_party/aws.BUILD +++ b/third_party/aws.BUILD @@ -46,6 +46,8 @@ cc_library( "aws-cpp-sdk-core/source/utils/xml/**/*.cpp", "aws-cpp-sdk-core/source/utils/crypto/*.cpp", "aws-cpp-sdk-core/source/utils/crypto/factory/**/*.cpp", + "aws-cpp-sdk-kinesis/include/**/*.h", + "aws-cpp-sdk-kinesis/source/**/*.cpp", "aws-cpp-sdk-s3/include/**/*.h", "aws-cpp-sdk-s3/source/**/*.cpp", ]), @@ -72,6 +74,7 @@ cc_library( }), includes = [ "aws-cpp-sdk-core/include/", + "aws-cpp-sdk-kinesis/include/", "aws-cpp-sdk-s3/include/", ], deps = [ diff --git a/third_party/clang_toolchain/download_clang.bzl b/third_party/clang_toolchain/download_clang.bzl index cfd8bfe98d7851e8192de69bbcdc41cb2b83204b..a014a806a69ecf9d7e43c51daf3672fc5750e706 100644 --- a/third_party/clang_toolchain/download_clang.bzl +++ b/third_party/clang_toolchain/download_clang.bzl @@ -35,18 +35,18 @@ def download_clang(repo_ctx, out_folder): # Latest CLANG_REVISION and CLANG_SUB_REVISION of the Chromiums's release # can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py - CLANG_REVISION = '330570' - CLANG_SUB_REVISION = 2 + CLANG_REVISION = '335091' + CLANG_SUB_REVISION = 1 package_version = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION) checksums = { 'Linux_x64': - '2108e172e05d4904c3c46125a33ab4a1175b36ec2a2226619a243e1d8f397e97', + '17002b75293fccfdd175eacdc9ee47d97b58d7e98fef343384fbbef1b68ce99f', 'Mac': - '481b5c6909f0ea250216061bd45e9c982b4befff65cbfca2ee1090c21a109eac', + '9351e46d28315daaa06a1eb55bd0370ed4aaeb693a2a3e82e48d2737d7723468', 'Win': - '8f04a3ac99d463d4179eb2f68a13575408c3dddc62887a1e441c77123e35e301', + 'e78a1e469224d6f6751b4df4374bf58893ac03900ec924e4c8264888ba4aeb1e', } platform_folder = _get_platform_folder(repo_ctx.os.name) diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD index 4def6f94892329e0d8b594b824babd60ea259351..1638b7216162abca208267ff804c6d92231081f6 100644 --- a/third_party/curl.BUILD +++ b/third_party/curl.BUILD @@ -7,6 +7,7 @@ exports_files(["COPYING"]) CURL_WIN_COPTS = [ "/Iexternal/curl/lib", + "/DBUILDING_LIBCURL", "/DHAVE_CONFIG_H", "/DCURL_DISABLE_FTP", "/DCURL_DISABLE_NTLM", @@ -49,6 +50,8 @@ cc_library( "lib/curl_addrinfo.c", "lib/curl_addrinfo.h", "lib/curl_base64.h", + "lib/curl_ctype.c", + "lib/curl_ctype.h", "lib/curl_des.h", "lib/curl_endian.h", "lib/curl_fnmatch.c", @@ -75,6 +78,7 @@ cc_library( "lib/curl_sec.h", "lib/curl_setup.h", "lib/curl_setup_once.h", + "lib/curl_sha256.h", "lib/curl_sspi.c", "lib/curl_sspi.h", "lib/curl_threads.c", @@ -134,6 +138,8 @@ cc_library( "lib/md5.c", "lib/memdebug.c", "lib/memdebug.h", + "lib/mime.c", + "lib/mime.h", "lib/mprintf.c", "lib/multi.c", "lib/multihandle.h", @@ -153,8 +159,8 @@ cc_library( "lib/pop3.h", "lib/progress.c", "lib/progress.h", - "lib/rawstr.c", - "lib/rawstr.h", + "lib/rand.c", + "lib/rand.h", "lib/rtsp.c", "lib/rtsp.h", "lib/security.c", @@ -162,8 +168,11 @@ cc_library( "lib/select.h", "lib/sendf.c", "lib/sendf.h", + "lib/setopt.c", + "lib/setopt.h", "lib/setup-os400.h", "lib/setup-vms.h", + "lib/sha256.c", "lib/share.c", "lib/share.h", "lib/sigpipe.h", @@ -179,10 +188,10 @@ cc_library( "lib/splay.c", "lib/splay.h", "lib/ssh.h", + "lib/strcase.c", + "lib/strcase.h", "lib/strdup.c", "lib/strdup.h", - "lib/strequal.c", - "lib/strequal.h", "lib/strerror.c", "lib/strerror.h", "lib/strtok.c", @@ -241,13 +250,12 @@ cc_library( }), hdrs = [ "include/curl/curl.h", - "include/curl/curlbuild.h", - "include/curl/curlrules.h", "include/curl/curlver.h", "include/curl/easy.h", "include/curl/mprintf.h", "include/curl/multi.h", "include/curl/stdcheaders.h", + "include/curl/system.h", "include/curl/typecheck-gcc.h", ], copts = select({ @@ -256,6 +264,7 @@ cc_library( "//conditions:default": [ "-Iexternal/curl/lib", "-D_GNU_SOURCE", + "-DBUILDING_LIBCURL", "-DHAVE_CONFIG_H", "-DCURL_DISABLE_FTP", "-DCURL_DISABLE_NTLM", # turning it off in configure is not enough @@ -676,6 +685,7 @@ genrule( "# define SIZEOF_INT 4", "# define SIZEOF_LONG 8", "# define SIZEOF_OFF_T 8", + "# define SIZEOF_CURL_OFF_T 8", "# define SIZEOF_SHORT 2", "# define SIZEOF_SIZE_T 8", "# define SIZEOF_TIME_T 8", diff --git a/third_party/eigen.BUILD b/third_party/eigen.BUILD index e54c1a4501d46b6b68a9b8fcc9ce0b1af0535ef4..759f8a9be92e14537d334c3ec37f036d369d8796 100644 --- a/third_party/eigen.BUILD +++ b/third_party/eigen.BUILD @@ -69,3 +69,9 @@ cc_library( includes = ["."], visibility = ["//visibility:public"], ) + +filegroup( + name = "eigen_header_files", + srcs = EIGEN_MPL2_HEADER_FILES, + visibility = ["//visibility:public"], +) diff --git a/third_party/eigen3/BUILD b/third_party/eigen3/BUILD index f661093bc9f68b845f3000b0a931c66773fb3339..203991b50f56086aa76932595f6797ae3bbf58db 100644 --- a/third_party/eigen3/BUILD +++ b/third_party/eigen3/BUILD @@ -17,21 +17,23 @@ load("//tensorflow:tensorflow.bzl", "if_mkl") # INTEL_MKL end load("//tensorflow:tensorflow.bzl", "if_mkl") +EIGEN3_THIRD_PARTY_HEADERS = [ + "Eigen/Core", + "Eigen/LU", + "Eigen/Cholesky", + "Eigen/Eigenvalues", + "Eigen/QR", + "Eigen/SVD", + "unsupported/Eigen/MatrixFunctions", + "unsupported/Eigen/SpecialFunctions", + "unsupported/Eigen/CXX11/ThreadPool", + "unsupported/Eigen/CXX11/Tensor", + "unsupported/Eigen/CXX11/FixedPoint", +] + glob(["unsupported/Eigen/CXX11/src/FixedPoint/*.h"]) + cc_library( name = "eigen3", - hdrs = glob(["unsupported/Eigen/CXX11/src/FixedPoint/*.h"]) + [ - "Eigen/Core", - "Eigen/LU", - "Eigen/Cholesky", - "Eigen/Eigenvalues", - "Eigen/QR", - "Eigen/SVD", - "unsupported/Eigen/MatrixFunctions", - "unsupported/Eigen/SpecialFunctions", - "unsupported/Eigen/CXX11/ThreadPool", - "unsupported/Eigen/CXX11/Tensor", - "unsupported/Eigen/CXX11/FixedPoint", - ], + hdrs = EIGEN3_THIRD_PARTY_HEADERS, includes = if_mkl(["./mkl_include"]), visibility = ["//visibility:public"], deps = [ @@ -48,3 +50,35 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +filegroup( + name = "eigen_third_party_header_files", + srcs = EIGEN3_THIRD_PARTY_HEADERS, + visibility = ["//visibility:public"], +) + +genrule( + name = "install_eigen_headers", + srcs = [ + "@eigen_archive//:eigen_header_files", + ":eigen_third_party_header_files", + ], + outs = ["include"], + cmd = """ + mkdir $@ + for f in $(locations @eigen_archive//:eigen_header_files) ; do + d="$${f%/*}" + d="$${d#*external/eigen_archive/}" + + mkdir -p "$@/$${d}" + cp "$${f}" "$@/$${d}/" + done + + for f in $(locations :eigen_third_party_header_files) ; do + d="$${f%/*}" + + mkdir -p "$@/$${d}" + cp "$${f}" "$@/$${d}/" + done + """, +) diff --git a/third_party/eigen_fix_cuda_compilation.patch b/third_party/eigen_fix_cuda_compilation.patch deleted file mode 100644 index b921a7c31d5c96c79cd3033b13c60a8f7e63ba75..0000000000000000000000000000000000000000 --- a/third_party/eigen_fix_cuda_compilation.patch +++ /dev/null @@ -1,38 +0,0 @@ -diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h ---- a/Eigen/src/Core/ProductEvaluators.h -+++ b/Eigen/src/Core/ProductEvaluators.h -@@ -137,7 +137,7 @@ struct Assignment::type> - { - typedef Product SrcXprType; -- static EIGEN_STRONG_INLINE -+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op &) - { - Index dstRows = src.rows(); -@@ -390,7 +390,7 @@ struct generic_product_impl::Scalar Scalar; - - template -- static EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) -+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) - { - // Same as: dst.noalias() = lhs.lazyProduct(rhs); - // but easier on the compiler side -@@ -398,14 +398,14 @@ struct generic_product_impl -- static EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) -+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) - { - // dst.noalias() += lhs.lazyProduct(rhs); - call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::add_assign_op()); - } - - template -- static EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) -+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) - { - // dst.noalias() -= lhs.lazyProduct(rhs); - call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::sub_assign_op()); diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md index fbb1fde837b92bc521698d0a517a946da0438dbc..e2fd8009a052d7cbfd01b48af7da6b891ad08c74 100644 --- a/third_party/examples/eager/spinn/README.md +++ b/third_party/examples/eager/spinn/README.md @@ -22,7 +22,7 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth - [`data.py`](../../../../tensorflow/contrib/eager/python/examples/spinn/data.py): Pipeline for loading and preprocessing the [SNLI](https://nlp.stanford.edu/projects/snli/) data and [GloVe](https://nlp.stanford.edu/projects/glove/) word embedding, written - using the [`tf.data`](https://www.tensorflow.org/programmers_guide/datasets) + using the [`tf.data`](https://www.tensorflow.org/guide/datasets) API. - [`spinn.py`](./spinn.py): Model definition and training routines. This example illustrates how one might perform the following actions with diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index 8a2b24aa4e284fd70c7148d26c3c4d6ccd04f98c..67456a5bdfc05f7b41218f5e522e0e74e9065f9b 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -462,7 +462,7 @@ class SNLIClassifierTrainer(tfe.Checkpointable): 2. logits as a dense `Tensor` of shape (batch_size, d_out), where d_out is the output dimension size of the SNLIClassifier. """ - with tfe.GradientTape() as tape: + with tf.GradientTape() as tape: tape.watch(self._model.variables) logits = self._model(premise, premise_transition, diff --git a/third_party/flatbuffers/flatbuffers.BUILD b/third_party/flatbuffers/flatbuffers.BUILD index 824c97be60e7ef148a363b964ed330ba3c5fcb0c..639dff2cd01056cf70e727b39c0a0c537c763c9e 100644 --- a/third_party/flatbuffers/flatbuffers.BUILD +++ b/third_party/flatbuffers/flatbuffers.BUILD @@ -98,6 +98,8 @@ cc_binary( "grpc/src/compiler/cpp_generator.h", "grpc/src/compiler/go_generator.cc", "grpc/src/compiler/go_generator.h", + "grpc/src/compiler/java_generator.cc", + "grpc/src/compiler/java_generator.h", "grpc/src/compiler/schema_interface.h", "src/flatc_main.cpp", "src/idl_gen_cpp.cpp", diff --git a/third_party/googleapis.BUILD b/third_party/googleapis.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..95e999af1886576317aa59d133e8d5c88ba368d3 --- /dev/null +++ b/third_party/googleapis.BUILD @@ -0,0 +1,45 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//visibility:public"]) +licenses(["notice"]) # Apache 2.0 +exports_files(["LICENSE"]) + +load("@protobuf_archive//:protobuf.bzl", "cc_proto_library") + +cc_proto_library( + name = "bigtable_protos", + srcs = [ + "google/bigtable/admin/v2/bigtable_instance_admin.proto", + "google/bigtable/admin/v2/bigtable_table_admin.proto", + "google/bigtable/admin/v2/common.proto", + "google/bigtable/admin/v2/instance.proto", + "google/bigtable/admin/v2/table.proto", + "google/bigtable/v2/bigtable.proto", + "google/bigtable/v2/data.proto", + "google/iam/v1/iam_policy.proto", + "google/iam/v1/policy.proto", + "google/longrunning/operations.proto", + "google/rpc/status.proto", + "google/rpc/error_details.proto", + "google/api/annotations.proto", + "google/api/auth.proto", + "google/api/http.proto", + ], + include = ".", + protoc = "@protobuf_archive//:protoc", + default_runtime = "@protobuf_archive//:protobuf", + deps = ["@protobuf_archive//:cc_wkt_protos"], + use_grpc_plugin = True, +) diff --git a/third_party/gpus/crosstool/CROSSTOOL_clang.tpl b/third_party/gpus/crosstool/CROSSTOOL.tpl similarity index 54% rename from third_party/gpus/crosstool/CROSSTOOL_clang.tpl rename to third_party/gpus/crosstool/CROSSTOOL.tpl index 2f09473ee2ddf9a38ca0c7aa11094690607b532f..1424ff6511dfe0e7e8eef2843201e825e09a91f1 100644 --- a/third_party/gpus/crosstool/CROSSTOOL_clang.tpl +++ b/third_party/gpus/crosstool/CROSSTOOL.tpl @@ -140,9 +140,7 @@ toolchain { flag_group { # All warnings are enabled. Maybe enable -Werror as well? flag: "-Wall" - # Some parts of the codebase set -Werror and hit this warning, so - # switch it off for now. - flag: "-Wno-invalid-partial-specialization" + %{host_compiler_warnings} } } } @@ -278,7 +276,7 @@ toolchain { } # Set clang as a C/C++ compiler. - tool_path { name: "gcc" path: "%{clang_path}" } + tool_path { name: "gcc" path: "%{host_compiler_path}" } # Use the default system toolchain for everything else. tool_path { name: "ar" path: "/usr/bin/ar" } @@ -297,3 +295,245 @@ toolchain { %{host_compiler_includes} } + +toolchain { + abi_version: "local" + abi_libc_version: "local" + compiler: "compiler" + host_system_name: "local" + needsPic: true + target_libc: "macosx" + target_cpu: "darwin" + target_system_name: "local" + toolchain_identifier: "local_darwin" + feature { + name: "c++11" + flag_set { + action: "c++-compile" + flag_group { + flag: "-std=c++11" + } + } + } + + feature { + name: "stdlib" + flag_set { + action: "c++-link-executable" + action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" + flag_group { + flag: "-lc++" + } + } + } + + feature { + name: "determinism" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # Make C++ compilation deterministic. Use linkstamping instead of these + # compiler symbols. + flag: "-Wno-builtin-macro-redefined" + flag: "-D__DATE__=\"redacted\"" + flag: "-D__TIMESTAMP__=\"redacted\"" + flag: "-D__TIME__=\"redacted\"" + } + } + } + + # This feature will be enabled for builds that support pic by bazel. + feature { + name: "pic" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + expand_if_all_available: "pic" + flag: "-fPIC" + } + flag_group { + expand_if_none_available: "pic" + flag: "-fPIE" + } + } + } + + # Security hardening on by default. + feature { + name: "hardening" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases. + # We need to undef it before redefining it as some distributions now + # have it enabled by default. + flag: "-U_FORTIFY_SOURCE" + flag: "-D_FORTIFY_SOURCE=1" + flag: "-fstack-protector" + } + } + flag_set { + action: "c++-link-executable" + flag_group { + flag: "-pie" + } + } + } + + feature { + name: "warnings" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # All warnings are enabled. Maybe enable -Werror as well? + flag: "-Wall" + %{host_compiler_warnings} + } + } + } + + # Keep stack frames for debugging, even in opt mode. + feature { + name: "frame-pointer" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + flag: "-fno-omit-frame-pointer" + } + } + } + + feature { + name: "no-canonical-prefixes" + flag_set { + action: "c-compile" + action: "c++-compile" + action: "c++-link-executable" + action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" + flag_group { + flag:"-no-canonical-prefixes" + } + } + } + + feature { + name: "disable-assertions" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + flag: "-DNDEBUG" + } + } + } + + feature { + name: "linker-bin-path" + + flag_set { + action: "c++-link-executable" + action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" + flag_group { + flag: "-B/usr/bin/" + } + } + } + + feature { + name: "undefined-dynamic" + flag_set { + action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" + action: "c++-link-executable" + flag_group { + flag: "-undefined" + flag: "dynamic_lookup" + } + } + } + + feature { + name: "common" + implies: "stdlib" + implies: "c++11" + implies: "determinism" + implies: "hardening" + implies: "warnings" + implies: "frame-pointer" + implies: "no-canonical-prefixes" + implies: "linker-bin-path" + implies: "undefined-dynamic" + } + + feature { + name: "opt" + implies: "common" + implies: "disable-assertions" + + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # No debug symbols. + # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt + # or even generally? However, that can't happen here, as it requires + # special handling in Bazel. + flag: "-g0" + + # Conservative choice for -O + # -O3 can increase binary size and even slow down the resulting binaries. + # Profile first and / or use FDO if you need better performance than this. + flag: "-O2" + + # Removal of unused code and data at link time (can this increase binary size in some cases?). + flag: "-ffunction-sections" + flag: "-fdata-sections" + } + } + } + + feature { + name: "fastbuild" + implies: "common" + } + + feature { + name: "dbg" + implies: "common" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + flag: "-g" + } + } + } + + # Set clang as a C/C++ compiler. + tool_path { name: "gcc" path: "%{host_compiler_path}" } + + # Use the default system toolchain for everything else. + tool_path { name: "ar" path: "/usr/bin/libtool" } + tool_path { name: "compat-ld" path: "/usr/bin/ld" } + tool_path { name: "cpp" path: "/usr/bin/cpp" } + tool_path { name: "dwp" path: "/usr/bin/dwp" } + tool_path { name: "gcov" path: "/usr/bin/gcov" } + tool_path { name: "ld" path: "/usr/bin/ld" } + tool_path { name: "nm" path: "/usr/bin/nm" } + tool_path { name: "objcopy" path: "/usr/bin/objcopy" } + tool_path { name: "objdump" path: "/usr/bin/objdump" } + tool_path { name: "strip" path: "/usr/bin/strip" } + + # Enabled dynamic linking. + linking_mode_flags { mode: DYNAMIC } + +%{host_compiler_includes} +} diff --git a/third_party/gpus/crosstool/CROSSTOOL_nvcc.tpl b/third_party/gpus/crosstool/CROSSTOOL_nvcc.tpl deleted file mode 100644 index 05290d647ea1b25f073f6e0c2a8de07c0fe65d58..0000000000000000000000000000000000000000 --- a/third_party/gpus/crosstool/CROSSTOOL_nvcc.tpl +++ /dev/null @@ -1,249 +0,0 @@ -major_version: "local" -minor_version: "" -default_target_cpu: "same_as_host" - -default_toolchain { - cpu: "k8" - toolchain_identifier: "local_linux" -} -default_toolchain { - cpu: "piii" - toolchain_identifier: "local_linux" -} -default_toolchain { - cpu: "arm" - toolchain_identifier: "local_linux" -} -default_toolchain { - cpu: "darwin" - toolchain_identifier: "local_darwin" -} -default_toolchain { - cpu: "ppc" - toolchain_identifier: "local_linux" -} - -toolchain { - abi_version: "local" - abi_libc_version: "local" - builtin_sysroot: "" - compiler: "compiler" - host_system_name: "local" - needsPic: true - supports_gold_linker: false - supports_incremental_linker: false - supports_fission: false - supports_interface_shared_objects: false - supports_normalizing_ar: false - supports_start_end_lib: false - supports_thin_archives: false - target_libc: "local" - target_cpu: "local" - target_system_name: "local" - toolchain_identifier: "local_linux" - - tool_path { name: "ar" path: "/usr/bin/ar" } - tool_path { name: "compat-ld" path: "/usr/bin/ld" } - tool_path { name: "cpp" path: "/usr/bin/cpp" } - tool_path { name: "dwp" path: "/usr/bin/dwp" } - # As part of the TensorFlow release, we place some cuda-related compilation - # files in @local_config_cuda//crosstool/clang/bin, and this relative - # path, combined with the rest of our Bazel configuration causes our - # compilation to use those files. - tool_path { name: "gcc" path: "clang/bin/crosstool_wrapper_driver_is_not_gcc" } - # Use "-std=c++11" for nvcc. For consistency, force both the host compiler - # and the device compiler to use "-std=c++11". - cxx_flag: "-std=c++11" - linker_flag: "-Wl,-no-as-needed" - linker_flag: "-lstdc++" - linker_flag: "-B/usr/bin/" - -%{host_compiler_includes} - tool_path { name: "gcov" path: "/usr/bin/gcov" } - - # C(++) compiles invoke the compiler (as that is the one knowing where - # to find libraries), but we provide LD so other rules can invoke the linker. - tool_path { name: "ld" path: "/usr/bin/ld" } - - tool_path { name: "nm" path: "/usr/bin/nm" } - tool_path { name: "objcopy" path: "/usr/bin/objcopy" } - objcopy_embed_flag: "-I" - objcopy_embed_flag: "binary" - tool_path { name: "objdump" path: "/usr/bin/objdump" } - tool_path { name: "strip" path: "/usr/bin/strip" } - - # Anticipated future default. - unfiltered_cxx_flag: "-no-canonical-prefixes" - - # Make C++ compilation deterministic. Use linkstamping instead of these - # compiler symbols. - unfiltered_cxx_flag: "-Wno-builtin-macro-redefined" - unfiltered_cxx_flag: "-D__DATE__=\"redacted\"" - unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\"" - unfiltered_cxx_flag: "-D__TIME__=\"redacted\"" - - # Security hardening on by default. - # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases. - # We need to undef it before redefining it as some distributions now have - # it enabled by default. - compiler_flag: "-U_FORTIFY_SOURCE" - compiler_flag: "-D_FORTIFY_SOURCE=1" - compiler_flag: "-fstack-protector" - compiler_flag: "-fPIE" - linker_flag: "-pie" - linker_flag: "-Wl,-z,relro,-z,now" - - # Enable coloring even if there's no attached terminal. Bazel removes the - # escape sequences if --nocolor is specified. This isn't supported by gcc - # on Ubuntu 14.04. - # compiler_flag: "-fcolor-diagnostics" - - # All warnings are enabled. Maybe enable -Werror as well? - compiler_flag: "-Wall" - # Enable a few more warnings that aren't part of -Wall. - compiler_flag: "-Wunused-but-set-parameter" - # But disable some that are problematic. - compiler_flag: "-Wno-free-nonheap-object" # has false positives - - # Keep stack frames for debugging, even in opt mode. - compiler_flag: "-fno-omit-frame-pointer" - - # Anticipated future default. - linker_flag: "-no-canonical-prefixes" - unfiltered_cxx_flag: "-fno-canonical-system-headers" - # Have gcc return the exit code from ld. - linker_flag: "-pass-exit-codes" - # Stamp the binary with a unique identifier. - linker_flag: "-Wl,--build-id=md5" - linker_flag: "-Wl,--hash-style=gnu" - # Gold linker only? Can we enable this by default? - # linker_flag: "-Wl,--warn-execstack" - # linker_flag: "-Wl,--detect-odr-violations" - - # Include directory for cuda headers. -%{cuda_include_path} - - compilation_mode_flags { - mode: DBG - # Enable debug symbols. - compiler_flag: "-g" - } - compilation_mode_flags { - mode: OPT - - # No debug symbols. - # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt or - # even generally? However, that can't happen here, as it requires special - # handling in Bazel. - compiler_flag: "-g0" - - # Conservative choice for -O - # -O3 can increase binary size and even slow down the resulting binaries. - # Profile first and / or use FDO if you need better performance than this. - compiler_flag: "-O2" - - # Disable assertions - compiler_flag: "-DNDEBUG" - - # Removal of unused code and data at link time (can this increase binary size in some cases?). - compiler_flag: "-ffunction-sections" - compiler_flag: "-fdata-sections" - linker_flag: "-Wl,--gc-sections" - } - linking_mode_flags { mode: DYNAMIC } -} - -toolchain { - abi_version: "local" - abi_libc_version: "local" - builtin_sysroot: "" - compiler: "compiler" - host_system_name: "local" - needsPic: true - target_libc: "macosx" - target_cpu: "darwin" - target_system_name: "local" - toolchain_identifier: "local_darwin" - - tool_path { name: "ar" path: "/usr/bin/libtool" } - tool_path { name: "compat-ld" path: "/usr/bin/ld" } - tool_path { name: "cpp" path: "/usr/bin/cpp" } - tool_path { name: "dwp" path: "/usr/bin/dwp" } - tool_path { name: "gcc" path: "clang/bin/crosstool_wrapper_driver_is_not_gcc" } - cxx_flag: "-std=c++11" - ar_flag: "-static" - ar_flag: "-s" - ar_flag: "-o" - linker_flag: "-lc++" - linker_flag: "-undefined" - linker_flag: "dynamic_lookup" - # TODO(ulfjack): This is wrong on so many levels. Figure out a way to auto-detect the proper - # setting from the local compiler, and also how to make incremental builds correct. - cxx_builtin_include_directory: "/" - tool_path { name: "gcov" path: "/usr/bin/gcov" } - tool_path { name: "ld" path: "/usr/bin/ld" } - tool_path { name: "nm" path: "/usr/bin/nm" } - tool_path { name: "objcopy" path: "/usr/bin/objcopy" } - objcopy_embed_flag: "-I" - objcopy_embed_flag: "binary" - tool_path { name: "objdump" path: "/usr/bin/objdump" } - tool_path { name: "strip" path: "/usr/bin/strip" } - - # Anticipated future default. - unfiltered_cxx_flag: "-no-canonical-prefixes" - # Make C++ compilation deterministic. Use linkstamping instead of these - # compiler symbols. - unfiltered_cxx_flag: "-Wno-builtin-macro-redefined" - unfiltered_cxx_flag: "-D__DATE__=\"redacted\"" - unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\"" - unfiltered_cxx_flag: "-D__TIME__=\"redacted\"" - - # Security hardening on by default. - # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases. - compiler_flag: "-D_FORTIFY_SOURCE=1" - compiler_flag: "-fstack-protector" - - # Enable coloring even if there's no attached terminal. Bazel removes the - # escape sequences if --nocolor is specified. - compiler_flag: "-fcolor-diagnostics" - - # All warnings are enabled. Maybe enable -Werror as well? - compiler_flag: "-Wall" - # Enable a few more warnings that aren't part of -Wall. - compiler_flag: "-Wthread-safety" - compiler_flag: "-Wself-assign" - - # Keep stack frames for debugging, even in opt mode. - compiler_flag: "-fno-omit-frame-pointer" - - # Anticipated future default. - linker_flag: "-no-canonical-prefixes" - - # Include directory for cuda headers. -%{cuda_include_path} - - compilation_mode_flags { - mode: DBG - # Enable debug symbols. - compiler_flag: "-g" - } - compilation_mode_flags { - mode: OPT - # No debug symbols. - # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt or even generally? - # However, that can't happen here, as it requires special handling in Bazel. - compiler_flag: "-g0" - - # Conservative choice for -O - # -O3 can increase binary size and even slow down the resulting binaries. - # Profile first and / or use FDO if you need better performance than this. - compiler_flag: "-O2" - - # Disable assertions - compiler_flag: "-DNDEBUG" - - # Removal of unused code and data at link time (can this increase binary size in some cases?). - compiler_flag: "-ffunction-sections" - compiler_flag: "-fdata-sections" - } -} diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index 2a37c65bc74a0ec5d0f5b2c9a6dd4339e0e46b68..f6b497f813185f82108de470ae39fac60d5d9f34 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -127,6 +127,15 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "cudnn_header", + includes = [ + ".", + "cuda/include", + ], + visibility = ["//visibility:public"], +) + cc_library( name = "cufft", srcs = ["cuda/lib/%{cufft_lib}"], diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index f3a80d3dd35a1bff1b7fe6a5ff5916f393836214..c90c66912d959af109caab51c742d760e0908f30 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -1073,23 +1073,46 @@ def _create_local_cuda_repository(repository_ctx): cc_fullpath = cc if not should_download_clang else "crosstool/" + cc host_compiler_includes = _host_compiler_includes(repository_ctx, cc_fullpath) - cuda_defines = { - "%{cuda_include_path}": _cuda_include_path(repository_ctx, - cuda_config), - "%{host_compiler_includes}": host_compiler_includes, - } + cuda_defines = {} if is_cuda_clang: - cuda_defines["%{clang_path}"] = cc + cuda_defines["%{host_compiler_path}"] = str(cc) + cuda_defines["%{host_compiler_warnings}"] = """ + # Some parts of the codebase set -Werror and hit this warning, so + # switch it off for now. + flag: "-Wno-invalid-partial-specialization" + """ + cuda_defines["%{host_compiler_includes}"] = host_compiler_includes _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty"}) - _tpl(repository_ctx, "crosstool:CROSSTOOL_clang", cuda_defines, out="crosstool/CROSSTOOL") repository_ctx.file("crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "") else: + cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" + cuda_defines["%{host_compiler_warnings}"] = "" + # TODO(klimek): We currently need to inject "/" as builtin directory path + # to disable bazel's dependency checks. + # The problem is that: + # - the python rules symlink the python headers into the bazel root + # - the rules use 'includes' in the BUILD file to redirect includes of the + # python headers through those paths + # - bazel currently uses -isystem for include paths specified via 'includes' + # - gcc follows symlinks when resolving files via -isystem paths, and puts + # the resolved paths into the .d file, which makes the dependency check + # fail for bazel + # There are multiple possible ways to solve this: + # 1. make bazel not use -isystem for paths specified via 'includes' + # 2. cp the headers instead of symlinking them + # + # Once this is fixed, the right builtin directory path is: + # (host_compiler_includes + + # "\n cxx_builtin_include_directory: \"%s\"" % cuda_include_path) + # The cuda directory needs to be passed, as there is currently no rule + # providing the cuda headers in the same way the python headers are + # provided. + cuda_defines["%{host_compiler_includes}"] = "\n cxx_builtin_include_directory: \"/\"" nvcc_path = str(repository_ctx.path("%s/bin/nvcc%s" % (cuda_config.cuda_toolkit_path, ".exe" if cuda_config.cpu_value == "Windows" else ""))) _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":crosstool_wrapper_driver_is_not_gcc"}) - _tpl(repository_ctx, "crosstool:CROSSTOOL_nvcc", cuda_defines, out="crosstool/CROSSTOOL") _tpl(repository_ctx, "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc", { @@ -1100,6 +1123,7 @@ def _create_local_cuda_repository(repository_ctx): "%{cuda_compute_capabilities}": ", ".join( ["\"%s\"" % c for c in cuda_config.compute_capabilities]), }) + _tpl(repository_ctx, "crosstool:CROSSTOOL", cuda_defines, out="crosstool/CROSSTOOL") # Set up cuda_config.h, which is used by # tensorflow/stream_executor/dso_loader.cc. diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD index 4418ac32fc4b08713ff1d1f0d78042803153c886..663a2187336d4a558a42f9fb6c4017a360976050 100644 --- a/third_party/jpeg/jpeg.BUILD +++ b/third_party/jpeg/jpeg.BUILD @@ -291,8 +291,10 @@ cc_library( "jchuff.h", "jconfig.h", "jdct.h", + "jerror.h", "jinclude.h", "jmorecfg.h", + "jpegint.h", "jpeglib.h", "jsimd.h", "jsimddct.h", diff --git a/third_party/jsoncpp.BUILD b/third_party/jsoncpp.BUILD index 65f98410b289a7e324c9ed89e33de1c6010fa21a..cf3cba05556a0bb22a632475c6ab810b8230f355 100644 --- a/third_party/jsoncpp.BUILD +++ b/third_party/jsoncpp.BUILD @@ -6,7 +6,6 @@ cc_library( name = "jsoncpp", srcs = [ "include/json/assertions.h", - "src/lib_json/json_batchallocator.h", "src/lib_json/json_reader.cpp", "src/lib_json/json_tool.h", "src/lib_json/json_value.cpp", @@ -20,9 +19,13 @@ cc_library( "include/json/json.h", "include/json/reader.h", "include/json/value.h", + "include/json/version.h", "include/json/writer.h", ], - copts = ["-DJSON_USE_EXCEPTION=0"], + copts = [ + "-DJSON_USE_EXCEPTION=0", + "-DJSON_HAS_INT64", + ], includes = ["include"], visibility = ["//visibility:public"], deps = [":private"], diff --git a/third_party/kafka/BUILD b/third_party/kafka/BUILD index a839ca717e695f35fac684b510f0a022010e0710..75792b0d87366c304ca29f95f943114ee482dfcd 100644 --- a/third_party/kafka/BUILD +++ b/third_party/kafka/BUILD @@ -60,6 +60,8 @@ cc_library( "src/rdkafka_event.h", "src/rdkafka_feature.c", "src/rdkafka_feature.h", + "src/rdkafka_header.c", + "src/rdkafka_header.h", "src/rdkafka_int.h", "src/rdkafka_interceptor.c", "src/rdkafka_interceptor.h", @@ -93,7 +95,6 @@ cc_library( "src/rdkafka_sasl_int.h", "src/rdkafka_sasl_plain.c", "src/rdkafka_subscription.c", - "src/rdkafka_subscription.h", "src/rdkafka_timer.c", "src/rdkafka_timer.h", "src/rdkafka_topic.c", @@ -105,6 +106,8 @@ cc_library( "src/rdlist.h", "src/rdlog.c", "src/rdlog.h", + "src/rdmurmur2.c", + "src/rdmurmur2.h", "src/rdports.c", "src/rdports.h", "src/rdposix.h", diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD index 78ed1f4e168891367ddc2249da726a6ef16dd5d5..ee49d281abcd54b566edde119f4a5b3e6b07d2a3 100644 --- a/third_party/libxsmm.BUILD +++ b/third_party/libxsmm.BUILD @@ -3,7 +3,7 @@ licenses(["notice"]) # BSD 3-clause -exports_files(["LICENSE"]) +exports_files(["LICENSE.md"]) # Arguments to ./scripts/libxsmm_interface.py, see that file for detailed description. # precision: SP & DP diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.autogenerated.BUILD similarity index 89% rename from third_party/llvm/llvm.BUILD rename to third_party/llvm/llvm.autogenerated.BUILD index 35a1ce36e47584a796f27d4cdfb5ca4406b943e6..8f658539187bcf03bf5bc37118884ec28a85e5dd 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.autogenerated.BUILD @@ -8,10 +8,13 @@ exports_files(["LICENSE.TXT"]) load( "@org_tensorflow//third_party/llvm:llvm.bzl", - "gentbl", - "expand_cmake_vars", - "llvm_target_cmake_vars", + "LLVM_COPTS", + "LLVM_DEFINES", + "LLVM_LINKOPTS", "cmake_var_string", + "expand_cmake_vars", + "gentbl", + "llvm_all_cmake_vars", ) load( "@org_tensorflow//third_party:common.bzl", @@ -39,147 +42,25 @@ llvm_target_asm_printers = llvm_targets llvm_target_disassemblers = llvm_targets -# TODO(phawkins): the set of CMake variables was hardcoded for expediency. -# However, we should really detect many of these via configure-time tests. - -# The set of CMake variables common to all targets. -cmake_vars = { - # Headers - "HAVE_DIRENT_H": 1, - "HAVE_DLFCN_H": 1, - "HAVE_ERRNO_H": 1, - "HAVE_EXECINFO_H": 1, - "HAVE_FCNTL_H": 1, - "HAVE_INTTYPES_H": 1, - "HAVE_PTHREAD_H": 1, - "HAVE_SIGNAL_H": 1, - "HAVE_STDINT_H": 1, - "HAVE_SYS_IOCTL_H": 1, - "HAVE_SYS_MMAN_H": 1, - "HAVE_SYS_PARAM_H": 1, - "HAVE_SYS_RESOURCE_H": 1, - "HAVE_SYS_STAT_H": 1, - "HAVE_SYS_TIME_H": 1, - "HAVE_SYS_TYPES_H": 1, - "HAVE_TERMIOS_H": 1, - "HAVE_UNISTD_H": 1, - "HAVE_ZLIB_H": 1, - - # Features - "HAVE_BACKTRACE": 1, - "BACKTRACE_HEADER": "execinfo.h", - "HAVE_DLOPEN": 1, - "HAVE_FUTIMES": 1, - "HAVE_GETCWD": 1, - "HAVE_GETPAGESIZE": 1, - "HAVE_GETRLIMIT": 1, - "HAVE_GETRUSAGE": 1, - "HAVE_GETTIMEOFDAY": 1, - "HAVE_INT64_T": 1, - "HAVE_ISATTY": 1, - "HAVE_LIBEDIT": 1, - "HAVE_LIBPTHREAD": 1, - "HAVE_LIBZ": 1, - "HAVE_MKDTEMP": 1, - "HAVE_MKSTEMP": 1, - "HAVE_MKTEMP": 1, - "HAVE_PREAD": 1, - "HAVE_PTHREAD_GETSPECIFIC": 1, - "HAVE_PTHREAD_MUTEX_LOCK": 1, - "HAVE_PTHREAD_RWLOCK_INIT": 1, - "HAVE_REALPATH": 1, - "HAVE_SBRK": 1, - "HAVE_SETENV": 1, - "HAVE_SETRLIMIT": 1, - "HAVE_SIGALTSTACK": 1, - "HAVE_STRERROR": 1, - "HAVE_STRERROR_R": 1, - "HAVE_STRTOLL": 1, - "HAVE_SYSCONF": 1, - "HAVE_UINT64_T": 1, - "HAVE__UNWIND_BACKTRACE": 1, - - # LLVM features - "ENABLE_BACKTRACES": 1, - "LLVM_BINDIR": "/dev/null", - "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0, - "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0, - "LLVM_ENABLE_THREADS": 1, - "LLVM_ENABLE_ZLIB": 1, - "LLVM_HAS_ATOMICS": 1, - "LLVM_INCLUDEDIR": "/dev/null", - "LLVM_INFODIR": "/dev/null", - "LLVM_MANDIR": "/dev/null", - "LLVM_NATIVE_TARGET": 1, - "LLVM_NATIVE_TARGETINFO": 1, - "LLVM_NATIVE_TARGETMC": 1, - "LLVM_NATIVE_ASMPRINTER": 1, - "LLVM_NATIVE_ASMPARSER": 1, - "LLVM_NATIVE_DISASSEMBLER": 1, - "LLVM_ON_UNIX": 1, - "LLVM_PREFIX": "/dev/null", - "LLVM_VERSION_MAJOR": 0, - "LLVM_VERSION_MINOR": 0, - "LLVM_VERSION_PATCH": 0, - "LTDL_SHLIB_EXT": ".so", - "PACKAGE_NAME": "llvm", - "PACKAGE_STRING": "llvm tensorflow-trunk", - "PACKAGE_VERSION": "tensorflow-trunk", - "RETSIGTYPE": "void", -} - -# CMake variables specific to the Linux platform -linux_cmake_vars = { - "HAVE_MALLOC_H": 1, - "HAVE_LINK_H": 1, - "HAVE_MALLINFO": 1, - "HAVE_FUTIMENS": 1, -} - -# CMake variables specific to the Darwin (Mac OS X) platform. -darwin_cmake_vars = { - "HAVE_MALLOC_MALLOC_H": 1, -} - -# Select a set of CMake variables based on the platform. -# TODO(phawkins): use a better method to select the right host triple, rather -# than hardcoding x86_64. -all_cmake_vars = select({ - "@org_tensorflow//tensorflow:darwin": cmake_var_string( - cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") + - darwin_cmake_vars, - ), - "@org_tensorflow//tensorflow:linux_ppc64le": cmake_var_string( - cmake_vars + - llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") + - linux_cmake_vars, - ), - "//conditions:default": cmake_var_string( - cmake_vars + - llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") + - linux_cmake_vars, - ), -}) - # Performs CMake variable substitutions on configuration header files. expand_cmake_vars( name = "config_gen", src = "include/llvm/Config/config.h.cmake", - cmake_vars = all_cmake_vars, + cmake_vars = llvm_all_cmake_vars, dst = "include/llvm/Config/config.h", ) expand_cmake_vars( name = "llvm_config_gen", src = "include/llvm/Config/llvm-config.h.cmake", - cmake_vars = all_cmake_vars, + cmake_vars = llvm_all_cmake_vars, dst = "include/llvm/Config/llvm-config.h", ) expand_cmake_vars( name = "abi_breaking_gen", src = "include/llvm/Config/abi-breaking.h.cmake", - cmake_vars = all_cmake_vars, + cmake_vars = llvm_all_cmake_vars, dst = "include/llvm/Config/abi-breaking.h", ) @@ -240,14 +121,7 @@ cc_library( "include/llvm/Config/config.h", "include/llvm/Config/llvm-config.h", ], - defines = [ - "LLVM_ENABLE_STATS", - "__STDC_LIMIT_MACROS", - "__STDC_CONSTANT_MACROS", - "__STDC_FORMAT_MACROS", - "_DEBUG", - "LLVM_BUILD_GLOBAL_ISEL", - ], + defines = LLVM_DEFINES, includes = ["include"], ) @@ -262,17 +136,6 @@ genrule( ) # Rules that apply the LLVM tblgen tool. -gentbl( - name = "intrinsics_gen", - tbl_outs = [("-gen-intrinsic", "include/llvm/IR/Intrinsics.inc")], - tblgen = ":llvm-tblgen", - td_file = "include/llvm/IR/Intrinsics.td", - td_srcs = glob([ - "include/llvm/CodeGen/*.td", - "include/llvm/IR/Intrinsics*.td", - ]), -) - gentbl( name = "attributes_gen", tbl_outs = [("-gen-attrs", "include/llvm/IR/Attributes.inc")], @@ -292,6 +155,42 @@ gentbl( ], ) +gentbl( + name = "instcombine_transforms_gen", + tbl_outs = [( + "-gen-searchable-tables", + "lib/Transforms/InstCombine/InstCombineTables.inc", + )], + tblgen = ":llvm-tblgen", + td_file = "lib/Transforms/InstCombine/InstCombineTables.td", + td_srcs = glob([ + "include/llvm/CodeGen/*.td", + "include/llvm/IR/Intrinsics*.td", + ]) + ["include/llvm/TableGen/SearchableTable.td"], +) + +gentbl( + name = "intrinsic_enums_gen", + tbl_outs = [("-gen-intrinsic-enums", "include/llvm/IR/IntrinsicEnums.inc")], + tblgen = ":llvm-tblgen", + td_file = "include/llvm/IR/Intrinsics.td", + td_srcs = glob([ + "include/llvm/CodeGen/*.td", + "include/llvm/IR/Intrinsics*.td", + ]), +) + +gentbl( + name = "intrinsics_impl_gen", + tbl_outs = [("-gen-intrinsic-impl", "include/llvm/IR/IntrinsicImpl.inc")], + tblgen = ":llvm-tblgen", + td_file = "include/llvm/IR/Intrinsics.td", + td_srcs = glob([ + "include/llvm/CodeGen/*.td", + "include/llvm/IR/Intrinsics*.td", + ]), +) + # Binary targets used by Tensorflow. cc_binary( name = "llvm-tblgen", @@ -299,11 +198,7 @@ cc_binary( "utils/TableGen/*.cpp", "utils/TableGen/*.h", ]), - linkopts = [ - "-lm", - "-ldl", - "-lpthread", - ], + linkopts = LLVM_LINKOPTS, stamp = 0, deps = [ ":config", @@ -319,11 +214,7 @@ cc_binary( "utils/FileCheck/*.cpp", "utils/FileCheck/*.h", ]), - linkopts = [ - "-ldl", - "-lm", - "-lpthread", - ], + linkopts = LLVM_LINKOPTS, stamp = 0, deps = [":support"], ) @@ -494,7 +385,8 @@ cc_library( "include/llvm/Target/AArch64/AsmParser/*.inc", "lib/Target/AArch64/AsmParser/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_desc", ":aarch64_info", @@ -519,7 +411,8 @@ cc_library( "include/llvm/Target/AArch64/InstPrinter/*.inc", "lib/Target/AArch64/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_target_gen", ":aarch64_utils", @@ -542,7 +435,8 @@ cc_library( "include/llvm/Target/AArch64/*.inc", "lib/Target/AArch64/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_asm_printer", ":aarch64_desc", @@ -575,14 +469,16 @@ cc_library( "include/llvm/Target/AArch64/MCTargetDesc/*.inc", "lib/Target/AArch64/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_asm_printer", ":aarch64_info", ":aarch64_target_gen", ":attributes_gen", ":config", - ":intrinsics_gen", + ":intrinsic_enums_gen", + ":intrinsics_impl_gen", ":mc", ":support", ], @@ -601,7 +497,8 @@ cc_library( "include/llvm/Target/AArch64/Disassembler/*.inc", "lib/Target/AArch64/Disassembler/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_desc", ":aarch64_info", @@ -629,7 +526,8 @@ cc_library( "lib/Target/AArch64/AArch64*.h", "lib/Target/AArch64/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":code_gen", ":config", @@ -652,7 +550,8 @@ cc_library( "include/llvm/Target/AArch64/Utils/*.inc", "lib/Target/AArch64/Utils/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_target_gen", ":config", @@ -674,6 +573,8 @@ cc_library( "include/llvm/Transforms/AggressiveInstCombine/*.def", "include/llvm/Transforms/AggressiveInstCombine/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -698,6 +599,8 @@ cc_library( "include/llvm/Analysis/*.def", "include/llvm/Analysis/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":binary_format", ":config", @@ -721,7 +624,8 @@ cc_library( "include/llvm/Target/AMDGPU/MCTargetDesc/*.inc", "lib/Target/AMDGPU/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_asm_printer", ":amdgpu_info", @@ -746,7 +650,8 @@ cc_library( "include/llvm/Target/AMDGPU/Disassembler/*.inc", "lib/Target/AMDGPU/Disassembler/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_desc", ":amdgpu_info", @@ -771,7 +676,8 @@ cc_library( "include/llvm/Target/AMDGPU/TargetInfo/*.inc", "lib/Target/AMDGPU/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_target_gen", ":config", @@ -793,7 +699,8 @@ cc_library( "include/llvm/Target/AMDGPU/Utils/*.inc", "lib/Target/AMDGPU/Utils/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_target_gen", ":config", @@ -816,7 +723,8 @@ cc_library( "include/llvm/Target/AMDGPU/AsmParser/*.inc", "lib/Target/AMDGPU/AsmParser/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_desc", ":amdgpu_info", @@ -841,7 +749,8 @@ cc_library( "include/llvm/Target/AMDGPU/InstPrinter/*.inc", "lib/Target/AMDGPU/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_utils", ":config", @@ -863,7 +772,8 @@ cc_library( "include/llvm/Target/AMDGPU/*.inc", "lib/Target/AMDGPU/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_asm_printer", ":amdgpu_desc", @@ -899,7 +809,8 @@ cc_library( "include/llvm/Target/ARM/AsmParser/*.inc", "lib/Target/ARM/AsmParser/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_desc", ":arm_info", @@ -925,7 +836,8 @@ cc_library( "lib/Target/ARM/*.h", "lib/Target/ARM/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_info", ":arm_target_gen", @@ -949,7 +861,8 @@ cc_library( "include/llvm/Target/ARM/*.inc", "lib/Target/ARM/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":analysis", ":arm_asm_printer", @@ -966,6 +879,7 @@ cc_library( ":selection_dag", ":support", ":target", + ":transform_utils", ], ) @@ -984,14 +898,16 @@ cc_library( "include/llvm/Target/ARM/MCTargetDesc/*.inc", "lib/Target/ARM/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_asm_printer", ":arm_info", ":arm_target_gen", ":attributes_gen", ":config", - ":intrinsics_gen", + ":intrinsic_enums_gen", + ":intrinsics_impl_gen", ":mc", ":mc_disassembler", ":support", @@ -1011,7 +927,8 @@ cc_library( "include/llvm/Target/ARM/Disassembler/*.inc", "lib/Target/ARM/Disassembler/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_desc", ":arm_info", @@ -1036,7 +953,8 @@ cc_library( "include/llvm/Target/ARM/TargetInfo/*.inc", "lib/Target/ARM/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_target_gen", ":config", @@ -1059,7 +977,8 @@ cc_library( "include/llvm/Target/ARM/Utils/*.inc", "lib/Target/ARM/Utils/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_target_gen", ":config", @@ -1081,6 +1000,8 @@ cc_library( "include/llvm/AsmParser/*.def", "include/llvm/AsmParser/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":binary_format", ":config", @@ -1103,6 +1024,8 @@ cc_library( "include/llvm/CodeGen/AsmPrinter/*.inc", "lib/CodeGen/AsmPrinter/*.def", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":binary_format", @@ -1133,6 +1056,8 @@ cc_library( "include/llvm/BinaryFormat/ELFRelocs/*.def", "include/llvm/BinaryFormat/WasmRelocs/*.def", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":support", @@ -1153,6 +1078,8 @@ cc_library( "include/llvm/Bitcode/Reader/*.inc", "include/llvm/Bitcode/BitstreamReader.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":core", @@ -1176,6 +1103,8 @@ cc_library( "include/llvm/Bitcode/BitcodeWriterPass.h", "include/llvm/Bitcode/BitstreamWriter.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -1200,6 +1129,8 @@ cc_library( "include/llvm/CodeGen/*.inc", "include/llvm/CodeGen/**/*.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":bit_reader", @@ -1237,12 +1168,15 @@ cc_library( "include/llvm/*.h", "include/llvm/Analysis/*.def", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":attributes_compat_gen", ":attributes_gen", ":binary_format", ":config", - ":intrinsics_gen", + ":intrinsic_enums_gen", + ":intrinsics_impl_gen", ":support", ], ) @@ -1260,6 +1194,8 @@ cc_library( "include/llvm/DebugInfo/CodeView/*.def", "include/llvm/DebugInfo/CodeView/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":binary_format", ":config", @@ -1281,6 +1217,8 @@ cc_library( "include/llvm/DebugInfo/MSF/*.def", "include/llvm/DebugInfo/MSF/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":support", @@ -1300,6 +1238,8 @@ cc_library( "include/llvm/Demangle/*.def", "include/llvm/Demangle/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [":config"], ) @@ -1316,6 +1256,8 @@ cc_library( "include/llvm/ExecutionEngine/*.def", "include/llvm/ExecutionEngine/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":core", @@ -1340,6 +1282,8 @@ cc_library( "include/llvm/CodeGen/GlobalISel/*.def", "include/llvm/CodeGen/GlobalISel/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":code_gen", @@ -1369,6 +1313,8 @@ cc_library( "include/llvm/Transforms/InstrProfiling.h", "include/llvm/Transforms/PGOInstrumentation.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -1393,10 +1339,13 @@ cc_library( "include/llvm/Transforms/InstCombine/*.def", "include/llvm/Transforms/InstCombine/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", ":core", + ":instcombine_transforms_gen", ":support", ":transform_utils", ], @@ -1418,6 +1367,8 @@ cc_library( "include/llvm/Transforms/IPO/*.def", "include/llvm/Transforms/IPO/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":aggressive_inst_combine", ":analysis", @@ -1451,6 +1402,8 @@ cc_library( "include/llvm/IRReader/*.def", "include/llvm/IRReader/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":asm_parser", ":bit_reader", @@ -1473,6 +1426,8 @@ cc_library( "include/llvm/Linker/*.def", "include/llvm/Linker/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":core", @@ -1494,6 +1449,8 @@ cc_library( "include/llvm/MC/*.def", "include/llvm/MC/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":binary_format", ":config", @@ -1515,6 +1472,8 @@ cc_library( "include/llvm/MC/MCDisassembler/*.def", "include/llvm/MC/MCDisassembler/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -1535,6 +1494,8 @@ cc_library( "include/llvm/MC/MCParser/*.def", "include/llvm/MC/MCParser/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -1555,7 +1516,8 @@ cc_library( "include/llvm/Target/NVPTX/InstPrinter/*.inc", "lib/Target/NVPTX/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/NVPTX"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"], + defines = LLVM_DEFINES, deps = [ "nvptx_target_gen", ":attributes_gen", @@ -1579,7 +1541,8 @@ cc_library( "include/llvm/Target/NVPTX/*.inc", "lib/Target/NVPTX/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/NVPTX"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"], + defines = LLVM_DEFINES, deps = [ ":analysis", ":asm_printer", @@ -1613,7 +1576,8 @@ cc_library( "include/llvm/Target/NVPTX/MCTargetDesc/*.inc", "lib/Target/NVPTX/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/NVPTX"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"], + defines = LLVM_DEFINES, deps = [ "nvptx_target_gen", ":config", @@ -1639,7 +1603,8 @@ cc_library( "lib/Target/NVPTX/NVPTX.h", "lib/Target/NVPTX/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/NVPTX"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"], + defines = LLVM_DEFINES, deps = [ "nvptx_target_gen", ":attributes_gen", @@ -1663,6 +1628,8 @@ cc_library( "include/llvm/Object/*.def", "include/llvm/Object/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":binary_format", ":bit_reader", @@ -1688,6 +1655,8 @@ cc_library( "include/llvm/Transforms/ObjCARC/*.def", "include/llvm/Transforms/ObjCARC/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -1710,13 +1679,17 @@ cc_library( "include/llvm/ExecutionEngine/Orc/*.def", "include/llvm/ExecutionEngine/Orc/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":core", ":execution_engine", + ":mc", ":object", ":runtime_dyld", ":support", + ":target", ":transform_utils", ], ) @@ -1734,7 +1707,8 @@ cc_library( "include/llvm/Target/PowerPC/AsmParser/*.inc", "lib/Target/PowerPC/AsmParser/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -1758,11 +1732,13 @@ cc_library( "include/llvm/Target/PowerPC/InstPrinter/*.inc", "lib/Target/PowerPC/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":attributes_gen", ":config", - ":intrinsics_gen", + ":intrinsic_enums_gen", + ":intrinsics_impl_gen", ":mc", ":powerpc_info", ":powerpc_target_gen", @@ -1783,7 +1759,8 @@ cc_library( "include/llvm/Target/PowerPC/*.inc", "lib/Target/PowerPC/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":analysis", ":asm_printer", @@ -1815,11 +1792,13 @@ cc_library( "include/llvm/Target/PowerPC/MCTargetDesc/*.inc", "lib/Target/PowerPC/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":attributes_gen", ":config", - ":intrinsics_gen", + ":intrinsic_enums_gen", + ":intrinsics_impl_gen", ":mc", ":powerpc_asm_printer", ":powerpc_info", @@ -1841,7 +1820,8 @@ cc_library( "include/llvm/Target/PowerPC/Disassembler/*.inc", "lib/Target/PowerPC/Disassembler/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc_disassembler", @@ -1865,12 +1845,12 @@ cc_library( "lib/Target/PowerPC/PPC*.h", "lib/Target/PowerPC/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":attributes_gen", ":config", ":core", - ":intrinsics_gen", ":powerpc_target_gen", ":support", ":target", @@ -1890,6 +1870,8 @@ cc_library( "include/llvm/ProfileData/*.def", "include/llvm/ProfileData/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":core", @@ -1918,6 +1900,8 @@ cc_library( "include/llvm/ExecutionEngine/RTDyldMemoryManager.h", "include/llvm/ExecutionEngine/RuntimeDyld*.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -1945,6 +1929,8 @@ cc_library( "include/llvm/Transforms/IPO.h", "include/llvm/Transforms/IPO/SCCP.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":aggressive_inst_combine", ":analysis", @@ -1970,6 +1956,8 @@ cc_library( "include/llvm/CodeGen/SelectionDAG/*.def", "include/llvm/CodeGen/SelectionDAG/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":code_gen", @@ -2007,6 +1995,8 @@ cc_library( "include/llvm/BinaryFormat/MachO.def", "include/llvm/Support/VCSRevision.h", ], + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":demangle", @@ -2029,6 +2019,8 @@ cc_library( "include/llvm/TableGen/*.inc", "include/llvm/Target/*.def", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -2054,6 +2046,8 @@ cc_library( "include/llvm/CodeGen/*.def", "include/llvm/CodeGen/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -2078,6 +2072,8 @@ cc_library( "include/llvm/Transforms/Utils/*.def", "include/llvm/Transforms/Utils/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -2101,6 +2097,8 @@ cc_library( "include/llvm/Transforms/Vectorize/*.inc", "include/llvm/Transforms/Vectorize.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -2124,7 +2122,8 @@ cc_library( "include/llvm/Target/X86/AsmParser/*.inc", "lib/Target/X86/AsmParser/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -2149,7 +2148,8 @@ cc_library( "include/llvm/Target/X86/InstPrinter/*.inc", "lib/Target/X86/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -2173,7 +2173,8 @@ cc_library( "include/llvm/Target/X86/*.inc", "lib/Target/X86/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":analysis", ":asm_printer", @@ -2206,7 +2207,8 @@ cc_library( "include/llvm/Target/X86/MCTargetDesc/*.inc", "lib/Target/X86/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -2231,7 +2233,8 @@ cc_library( "include/llvm/Target/X86/Disassembler/*.inc", "lib/Target/X86/Disassembler/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc_disassembler", @@ -2254,7 +2257,8 @@ cc_library( "include/llvm/Target/X86/TargetInfo/*.inc", "lib/Target/X86/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -2276,7 +2280,8 @@ cc_library( "include/llvm/Target/X86/Utils/*.inc", "lib/Target/X86/Utils/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":code_gen", ":config", diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl index 0efcf319bd99be79263a1b9cd23544523a4c8076..2e809e5f147d9e2b359dbf8fcc57575572bc64cd 100644 --- a/third_party/llvm/llvm.bzl +++ b/third_party/llvm/llvm.bzl @@ -105,3 +105,136 @@ def expand_cmake_vars(name, src, dst, cmake_vars): "< $< > $@") ) +# TODO(phawkins): the set of CMake variables was hardcoded for expediency. +# However, we should really detect many of these via configure-time tests. + +# The set of CMake variables common to all targets. +cmake_vars = { + # Headers + "HAVE_DIRENT_H": 1, + "HAVE_DLFCN_H": 1, + "HAVE_ERRNO_H": 1, + "HAVE_EXECINFO_H": 1, + "HAVE_FCNTL_H": 1, + "HAVE_INTTYPES_H": 1, + "HAVE_PTHREAD_H": 1, + "HAVE_SIGNAL_H": 1, + "HAVE_STDINT_H": 1, + "HAVE_SYS_IOCTL_H": 1, + "HAVE_SYS_MMAN_H": 1, + "HAVE_SYS_PARAM_H": 1, + "HAVE_SYS_RESOURCE_H": 1, + "HAVE_SYS_STAT_H": 1, + "HAVE_SYS_TIME_H": 1, + "HAVE_SYS_TYPES_H": 1, + "HAVE_TERMIOS_H": 1, + "HAVE_UNISTD_H": 1, + "HAVE_ZLIB_H": 1, + + # Features + "HAVE_BACKTRACE": 1, + "BACKTRACE_HEADER": "execinfo.h", + "HAVE_DLOPEN": 1, + "HAVE_FUTIMES": 1, + "HAVE_GETCWD": 1, + "HAVE_GETPAGESIZE": 1, + "HAVE_GETRLIMIT": 1, + "HAVE_GETRUSAGE": 1, + "HAVE_GETTIMEOFDAY": 1, + "HAVE_INT64_T": 1, + "HAVE_ISATTY": 1, + "HAVE_LIBEDIT": 1, + "HAVE_LIBPTHREAD": 1, + "HAVE_LIBZ": 1, + "HAVE_MKDTEMP": 1, + "HAVE_MKSTEMP": 1, + "HAVE_MKTEMP": 1, + "HAVE_PREAD": 1, + "HAVE_PTHREAD_GETSPECIFIC": 1, + "HAVE_PTHREAD_MUTEX_LOCK": 1, + "HAVE_PTHREAD_RWLOCK_INIT": 1, + "HAVE_REALPATH": 1, + "HAVE_SBRK": 1, + "HAVE_SETENV": 1, + "HAVE_SETRLIMIT": 1, + "HAVE_SIGALTSTACK": 1, + "HAVE_STRERROR": 1, + "HAVE_STRERROR_R": 1, + "HAVE_STRTOLL": 1, + "HAVE_SYSCONF": 1, + "HAVE_UINT64_T": 1, + "HAVE__UNWIND_BACKTRACE": 1, + + # LLVM features + "ENABLE_BACKTRACES": 1, + "LLVM_BINDIR": "/dev/null", + "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0, + "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0, + "LLVM_ENABLE_THREADS": 1, + "LLVM_ENABLE_ZLIB": 1, + "LLVM_HAS_ATOMICS": 1, + "LLVM_INCLUDEDIR": "/dev/null", + "LLVM_INFODIR": "/dev/null", + "LLVM_MANDIR": "/dev/null", + "LLVM_NATIVE_TARGET": 1, + "LLVM_NATIVE_TARGETINFO": 1, + "LLVM_NATIVE_TARGETMC": 1, + "LLVM_NATIVE_ASMPRINTER": 1, + "LLVM_NATIVE_ASMPARSER": 1, + "LLVM_NATIVE_DISASSEMBLER": 1, + "LLVM_ON_UNIX": 1, + "LLVM_PREFIX": "/dev/null", + "LLVM_VERSION_MAJOR": 0, + "LLVM_VERSION_MINOR": 0, + "LLVM_VERSION_PATCH": 0, + "LTDL_SHLIB_EXT": ".so", + "PACKAGE_NAME": "llvm", + "PACKAGE_STRING": "llvm tensorflow-trunk", + "PACKAGE_VERSION": "tensorflow-trunk", + "RETSIGTYPE": "void", +} + +# CMake variables specific to the Linux platform +linux_cmake_vars = { + "HAVE_MALLOC_H": 1, + "HAVE_LINK_H": 1, + "HAVE_MALLINFO": 1, + "HAVE_FUTIMENS": 1, +} + +# CMake variables specific to the Darwin (Mac OS X) platform. +darwin_cmake_vars = { + "HAVE_MALLOC_MALLOC_H": 1, +} + +# Select a set of CMake variables based on the platform. +# TODO(phawkins): use a better method to select the right host triple, rather +# than hardcoding x86_64. +llvm_all_cmake_vars = select({ + "@org_tensorflow//tensorflow:darwin": cmake_var_string( + cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") + + darwin_cmake_vars), + "@org_tensorflow//tensorflow:linux_ppc64le": cmake_var_string( + cmake_vars + + llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") + + linux_cmake_vars, + ), + "//conditions:default": cmake_var_string( + cmake_vars + + llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") + + linux_cmake_vars), + +}) + +LLVM_LINKOPTS = ["-ldl", "-lm", "-lpthread"] + +LLVM_DEFINES = [ + "LLVM_ENABLE_STATS", + "__STDC_LIMIT_MACROS", + "__STDC_CONSTANT_MACROS", + "__STDC_FORMAT_MACROS", + "_DEBUG", + "LLVM_BUILD_GLOBAL_ISEL", +] + +LLVM_COPTS = [] diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD index 017613abb0246fdcda8ca189da694102abfc8529..a058c46cc424398c7062be329910b5e9e9e2f9cc 100644 --- a/third_party/mkl/BUILD +++ b/third_party/mkl/BUILD @@ -34,7 +34,7 @@ filegroup( "@org_tensorflow//tensorflow:windows": [ "@mkl_windows//:LICENSE", ], - "//conditions:default": [] + "//conditions:default": [], }), visibility = ["//visibility:public"], ) @@ -55,6 +55,6 @@ cc_library( "@mkl_windows//:mkl_headers", "@mkl_windows//:mkl_libs_windows", ], - "//conditions:default": [] + "//conditions:default": [], }), ) diff --git a/third_party/png.BUILD b/third_party/png.BUILD index 76ab32d69c35055b3796b8f612133394758db330..17c5449cc0d66c407689836f8be4872ab713f577 100644 --- a/third_party/png.BUILD +++ b/third_party/png.BUILD @@ -28,7 +28,14 @@ cc_library( "pngwrite.c", "pngwtran.c", "pngwutil.c", - ], + ] + select({ + "@org_tensorflow//tensorflow:linux_ppc64le": [ + "powerpc/powerpc_init.c", + "powerpc/filter_vsx_intrinsics.c", + ], + "//conditions:default": [ + ], + }), hdrs = [ "png.h", "pngconf.h", diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl index 954f21f5f8fe8029c869f8870464a750cfc8a3db..3c7e5c84695e454d96585f285869153e70867955 100644 --- a/third_party/py/python_configure.bzl +++ b/third_party/py/python_configure.bzl @@ -6,6 +6,7 @@ * `PYTHON_LIB_PATH`: Location of python libraries. """ +_BAZEL_SH = "BAZEL_SH" _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" _PYTHON_LIB_PATH = "PYTHON_LIB_PATH" _TF_PYTHON_CONFIG_REPO = "TF_PYTHON_CONFIG_REPO" @@ -152,6 +153,22 @@ def _get_python_bin(repository_ctx): _PYTHON_BIN_PATH, repository_ctx.os.environ.get("PATH", ""))) +def _get_bash_bin(repository_ctx): + """Gets the bash bin path.""" + bash_bin = repository_ctx.os.environ.get(_BAZEL_SH) + if bash_bin != None: + return bash_bin + else: + bash_bin_path = repository_ctx.which("bash") + if bash_bin_path != None: + return str(bash_bin_path) + else: + _fail("Cannot find bash in PATH, please make sure " + + "bash is installed and add its directory in PATH, or --define " + + "%s='/path/to/bash'.\nPATH=%s" % ( + _BAZEL_SH, repository_ctx.os.environ.get("PATH", ""))) + + def _get_python_lib(repository_ctx, python_bin): """Gets the python lib path.""" python_lib = repository_ctx.os.environ.get(_PYTHON_LIB_PATH) @@ -184,14 +201,14 @@ def _get_python_lib(repository_ctx, python_bin): " print(paths[0])\n" + "END") cmd = '%s - %s' % (python_bin, print_lib) - result = repository_ctx.execute(["bash", "-c", cmd]) + result = repository_ctx.execute([_get_bash_bin(repository_ctx), "-c", cmd]) return result.stdout.strip('\n') def _check_python_lib(repository_ctx, python_lib): """Checks the python lib path.""" cmd = 'test -d "%s" -a -x "%s"' % (python_lib, python_lib) - result = repository_ctx.execute(["bash", "-c", cmd]) + result = repository_ctx.execute([_get_bash_bin(repository_ctx), "-c", cmd]) if result.return_code == 1: _fail("Invalid python library path: %s" % python_lib) @@ -199,7 +216,7 @@ def _check_python_lib(repository_ctx, python_lib): def _check_python_bin(repository_ctx, python_bin): """Checks the python bin path.""" cmd = '[[ -x "%s" ]] && [[ ! -d "%s" ]]' % (python_bin, python_bin) - result = repository_ctx.execute(["bash", "-c", cmd]) + result = repository_ctx.execute([_get_bash_bin(repository_ctx), "-c", cmd]) if result.return_code == 1: _fail("--define %s='%s' is not executable. Is it the python binary?" % ( _PYTHON_BIN_PATH, python_bin)) @@ -294,6 +311,7 @@ def _python_autoconf_impl(repository_ctx): python_configure = repository_rule( implementation = _python_autoconf_impl, environ = [ + _BAZEL_SH, _PYTHON_BIN_PATH, _PYTHON_LIB_PATH, _TF_PYTHON_CONFIG_REPO, diff --git a/util/python/BUILD b/third_party/python_runtime/BUILD similarity index 86% rename from util/python/BUILD rename to third_party/python_runtime/BUILD index f5fa0c6d29c905cd9073e5001e993da5c8560ec0..2a1609191fe3515615cf537932532c71f4ef2773 100644 --- a/util/python/BUILD +++ b/third_party/python_runtime/BUILD @@ -3,6 +3,6 @@ licenses(["notice"]) # New BSD, Python Software Foundation package(default_visibility = ["//visibility:public"]) alias( - name = "python_headers", + name = "headers", actual = "@local_config_python//:python_headers", ) diff --git a/third_party/repo.bzl b/third_party/repo.bzl index 36f5aa5bdee43a511abf5634af85643ac7e11cfc..9cee1fcc4b5c2b05ecc09b4f372eadeca9e91be8 100644 --- a/third_party/repo.bzl +++ b/third_party/repo.bzl @@ -16,8 +16,6 @@ _SINGLE_URL_WHITELIST = depset([ "arm_compiler", - "ortools_archive", - "gemmlowp", ]) def _is_windows(ctx): @@ -88,7 +86,9 @@ def _tf_http_archive(ctx): if ctx.attr.patch_file != None: _apply_patch(ctx, ctx.attr.patch_file) if ctx.attr.build_file != None: - ctx.template("BUILD", ctx.attr.build_file, { + # Use BUILD.bazel to avoid conflict with third party projects with + # BUILD or build (directory) underneath. + ctx.template("BUILD.bazel", ctx.attr.build_file, { "%prefix%": ".." if _repos_are_siblings() else "external", }, False) diff --git a/third_party/sqlite.BUILD b/third_party/sqlite.BUILD index 6da795358927f5cb8db7cb0d7ea653b80f8b5226..2876f305f1f74e8bba9a364b1ef582f42c72c313 100644 --- a/third_party/sqlite.BUILD +++ b/third_party/sqlite.BUILD @@ -5,6 +5,7 @@ licenses(["unencumbered"]) # Public Domain SQLITE_COPTS = [ "-Os", + "-DSQLITE_ENABLE_JSON1", "-DHAVE_DECL_STRERROR_R=1", "-DHAVE_STDINT_H=1", "-DHAVE_INTTYPES_H=1", diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..fc3183a754369fc30dbce40c2bf7b6828ea497c3 --- /dev/null +++ b/third_party/toolchains/BUILD @@ -0,0 +1,22 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +# Platform for use with remote execution with +# custom container based off RBE Ubuntu16_04 +# http://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04 +# Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cpu +platform( + name = "rbe_ubuntu16_04-tf", + constraint_values = [ + "@bazel_tools//platforms:x86_64", + "@bazel_tools//platforms:linux", + "@bazel_tools//tools/cpp:clang", + "@bazel_toolchains//constraints:xenial", + ], + remote_execution_properties = """ + properties: { + name: "container-image" + value:"docker://gcr.io/asci-toolchain/nosla-ubuntu16_04-tf@sha256:800a7b68cabef15419695c188ed33ed70adf678c2371b97b236f3ae26c38274d" + }""", +) diff --git a/third_party/toolchains/clang6/CROSSTOOL.tpl b/third_party/toolchains/clang6/CROSSTOOL.tpl index 6b7e5a88086f8e5e67fa86a0e9377c3c2afd535d..ffba9850bb80a880d5b95afacbad296ec1f2df54 100644 --- a/third_party/toolchains/clang6/CROSSTOOL.tpl +++ b/third_party/toolchains/clang6/CROSSTOOL.tpl @@ -76,9 +76,6 @@ toolchain { # This adds a little bit more durability to our Clang build. # - # At the moment, this only only be needed for: - # - add_boringssl_s390x.patch: --Wa,--noexecstack - # # Folks who do maintenance work on TF Bazel Clang should consider # commenting out these lines, while doing that work, to gain a better # understanding of what the intersection of support looks like between GCC diff --git a/tools/bazel.rc b/tools/bazel.rc index 03aa52da1f6e9c113d6db6cb9c1d38b5be21927d..1c1e6afb65ab8da5b689d58ecaec6ac6c8a69bb8 100644 --- a/tools/bazel.rc +++ b/tools/bazel.rc @@ -1,14 +1,8 @@ -# By default, we don't distinct target and host platfroms. -# When doing cross compilation, use --config=cross_compile to distinct them. -build --distinct_host_configuration=false -build:cross_compile --distinct_host_configuration=true - # Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the # target CPU to build transient dependencies correctly. See # https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu build:android --crosstool_top=//external:android/crosstool build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain -build:android --config=cross_compile build:android_arm --config=android build:android_arm --cpu=armeabi-v7a build:android_arm --fat_apk_cpu=armeabi-v7a